Skip to content
Snippets Groups Projects

Graph Definition

Closed Jorge Prado requested to merge github/fork/RasmusOrsoe/graph_definition into main

Created by: RasmusOrsoe

This PR addresses the ongoing discussion in #462 (closed) and #521 by changing Model such that it now consists of the modules

Model = [GraphDefinition, GNN, Task]

Where GraphDefinition is a single, problem/model dependent class that contains all the code responsible for data representations.

  • Add GraphDefinition
  • Implement our default k-nn graph as importable graph definition
  • Redefine Model to depend on GraphDefinition
  • Refactor Detector
  • Update example scripts
  • Update config files
  • Consider default values
  • Delete redundant modules; GraphBuilder etc.
  • Update getting_started.md
  • Refactor Deployment modules

TLDR: Model, Dataset and GraphNeTI3Module now depends on GraphDefinition, which allows us to easily represent data as sequences, images, or whatever your heart desires. This change is breaking; older config files and pickled models are not compatible with these changes, but state_dicts are.

Conceptually, GraphDefinition contains all the code that alters the raw data from Dataset before it's passed to GNN. It's a single, swapable module that can be passed to Dataset and deployment modules. GraphDefinition consists of multiple submodules, and the data flow is GraphDefinition = [Detector, NodeDefinition, EdgeDefinition] and can be seen here. The modules are defined as

NodeDefinition : A generic class that defines what a node represents. Problem-specific versions can be implemented by overwriting the abstract method

def _construct_nodes(self, x: torch.tensor) -> Data:
        """Construct nodes from raw node features ´x´.

        Args:
            x: standardized node features with shape ´[num_pulses, d]´,
            where ´d´ is the number of node features.

        Returns:
            graph: graph without edges.
        """

_construct_nodes is the playground we've been missing for a while; it gives us the freedom to fully define exactly how we want the data to be structured for our Models. Here, one can use nodes to represent DOMs (by using Coarsening or some other method), create images for CNNs, define sequences or other forms of data representations. Our standard of representing pulses as nodes is just

class NodesAsPulses(NodeDefinition):
    """Represent each measured pulse of Cherenkov Radiation as a node."""

    def _construct_nodes(self, x: torch.tensor) -> Data:
        return Data(x=x)

EdgeDefinition: A generic class that defines how edges are drawn between nodes in the graph. This is essentially a refactor of our GraphBuilder. One can create problem-specific implementations by overwriting the abstract method

def _construct_edges(self, graph: Data) -> Data:
        """Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.

        Args:
            graph: graph without edges

        Returns:
            graph: graph with edges assigned.
        """

Detector: Virtually unchanged from it's known form. In charge of standardizing data. I cleaned the class up a little bit. In the future, it will hold detector-specific geometry tables as mentioned in #462 (closed).

Our usual k-nn graph with nodes representing pulses can then be created like so:

from graphnet.models.graphs import GraphDefinition
from graphnet.models.graphs.nodes import NodesAsPulses
from graphnet.models.graphs.edges import KNNEdges
from graphnet.models.detector.prometheus import Prometheus

graph_definition = GraphDefinition(node_definiton = NodesAsPulses(nb_nearest_neighbours=8),
                                   edge_definiton = KNNEdges(),
                                   detector = Prometheus(),
                                     )
                               

Alternatively, you can also just import this graph definition directly, as it's included in the PR:

from graphnet.models.graphs import KNNGraph

graph_definition = KNNGraph(
        detector=Prometheus(),
        node_definition=NodesAsPulses(),
        nb_nearest_neighbours=8,
        node_feature_names=features,
    )

Other things to note:

  1. Dataset is now simpler, as graph-altering code has been moved to GraphDefinition
  2. Changes are compatible with configuration files.

Merge request reports

Approval is optional

Closed by Jorge PradoJorge Prado 1 year ago (Jul 18, 2023 1:39pm UTC)

Merge details

  • The changes were not merged into main.

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
Please register or sign in to reply
Loading