Graph Definition
Compare changes
Created by: RasmusOrsoe
This PR addresses the ongoing discussion in #462 (closed) and #521 by changing Model
such that it now consists of the modules
Model = [GraphDefinition, GNN, Task]
Where GraphDefinition
is a single, problem/model dependent class that contains all the code responsible for data representations.
TLDR: Model, Dataset and GraphNeTI3Module now depends on GraphDefinition, which allows us to easily represent data as sequences, images, or whatever your heart desires. This change is breaking; older config files and pickled models are not compatible with these changes, but state_dicts are.
Conceptually, GraphDefinition
contains all the code that alters the raw data from Dataset
before it's passed to GNN
. It's a single, swapable module that can be passed to Dataset
and deployment modules. GraphDefinition
consists of multiple submodules, and the data flow is GraphDefinition = [Detector, NodeDefinition, EdgeDefinition]
and can be seen here. The modules are defined as
NodeDefinition : A generic class that defines what a node represents. Problem-specific versions can be implemented by overwriting the abstract method
def _construct_nodes(self, x: torch.tensor) -> Data:
"""Construct nodes from raw node features ´x´.
Args:
x: standardized node features with shape ´[num_pulses, d]´,
where ´d´ is the number of node features.
Returns:
graph: graph without edges.
"""
_construct_nodes
is the playground we've been missing for a while; it gives us the freedom to fully define exactly how we want the data to be structured for our Models. Here, one can use nodes to represent DOMs (by using Coarsening
or some other method), create images for CNNs, define sequences or other forms of data representations. Our standard of representing pulses as nodes is just
class NodesAsPulses(NodeDefinition):
"""Represent each measured pulse of Cherenkov Radiation as a node."""
def _construct_nodes(self, x: torch.tensor) -> Data:
return Data(x=x)
EdgeDefinition:
A generic class that defines how edges are drawn between nodes in the graph. This is essentially a refactor of our GraphBuilder
. One can create problem-specific implementations by overwriting the abstract method
def _construct_edges(self, graph: Data) -> Data:
"""Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.
Args:
graph: graph without edges
Returns:
graph: graph with edges assigned.
"""
Detector: Virtually unchanged from it's known form. In charge of standardizing data. I cleaned the class up a little bit. In the future, it will hold detector-specific geometry tables as mentioned in #462 (closed).
Our usual k-nn graph with nodes representing pulses can then be created like so:
from graphnet.models.graphs import GraphDefinition
from graphnet.models.graphs.nodes import NodesAsPulses
from graphnet.models.graphs.edges import KNNEdges
from graphnet.models.detector.prometheus import Prometheus
graph_definition = GraphDefinition(node_definiton = NodesAsPulses(nb_nearest_neighbours=8),
edge_definiton = KNNEdges(),
detector = Prometheus(),
)
Alternatively, you can also just import this graph definition directly, as it's included in the PR:
from graphnet.models.graphs import KNNGraph
graph_definition = KNNGraph(
detector=Prometheus(),
node_definition=NodesAsPulses(),
nb_nearest_neighbours=8,
node_feature_names=features,
)
Other things to note:
Dataset
is now simpler, as graph-altering code has been moved to GraphDefinition