From ae4badaa8d602215694d4d1ea3f9e90052c27d5c Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Mon, 26 Jun 2023 17:02:01 +0200 Subject: [PATCH 01/15] added GraphDefinition, EdgeDefinition --- src/graphnet/models/detector/__init__.py | 2 +- src/graphnet/models/graphs/__init__.py | 10 + src/graphnet/models/graphs/edges.py | 188 ++++++++++++++++++ src/graphnet/models/graphs/graph.py | 230 +++++++++++++++++++++++ 4 files changed, 429 insertions(+), 1 deletion(-) create mode 100644 src/graphnet/models/graphs/__init__.py create mode 100644 src/graphnet/models/graphs/edges.py create mode 100644 src/graphnet/models/graphs/graph.py diff --git a/src/graphnet/models/detector/__init__.py b/src/graphnet/models/detector/__init__.py index 17fba8fd2..2d36144da 100644 --- a/src/graphnet/models/detector/__init__.py +++ b/src/graphnet/models/detector/__init__.py @@ -1,3 +1,3 @@ """Detector-specific modules, for data ingestion and standardisation.""" -from .icecube import IceCube86, IceCubeDeepCore +from .icecube import IceCube86, IceCubeDeepCore, Detector diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py new file mode 100644 index 000000000..f2bd5cd78 --- /dev/null +++ b/src/graphnet/models/graphs/__init__.py @@ -0,0 +1,10 @@ +"""Modules for constructing graphs. + +´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features. +""" + + +from .graph import GraphDefinition +from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges diff --git a/src/graphnet/models/graphs/edges.py b/src/graphnet/models/graphs/edges.py new file mode 100644 index 000000000..cd2455f7d --- /dev/null +++ b/src/graphnet/models/graphs/edges.py @@ -0,0 +1,188 @@ +"""Class(es) for building/connecting graphs.""" + +from typing import List +from ABC import abstractmethod + +import torch +from torch_geometric.nn import knn_graph, radius_graph +from torch_geometric.data import Data + +from graphnet.utilities.config import save_model_config +from graphnet.models.utils import calculate_distance_matrix +from graphnet.models import Model + + +class EdgeDefinition(Model): # pylint: disable=too-few-public-methods + """Base class for graph building.""" + + def forward(self, graph: Data) -> Data: + """Construct edges based on problem specific implementation of. + + ´_construct_edges´ + + Args: + graph: a graph without edges + + Returns: + graph: a graph with edges + """ + if graph.edge_index is not None: + self.warnonce( + "GraphBuilder received graph with pre-existing " + "structure. Will overwrite." + ) + return self._construct_edges(graph) + + @abstractmethod + def _construct_edges(self, graph: Data) -> Data: + """Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´. + + Args: + graph: graph without edges + + Returns: + graph: graph with edges assigned. + """ + + +class KNNEdges(EdgeDefinition): # pylint: disable=too-few-public-methods + """Builds edges from the k-nearest neighbours.""" + + @save_model_config + def __init__( + self, + nb_nearest_neighbours: int, + columns: List[int] = [0, 1, 2], + ): + """K-NN Edge definition. + + Will connect nodes together with their ´nb_nearest_neighbours´ + nearest neighbours in the feature space given by ´columns´. + + Args: + nb_nearest_neighbours: number of neighbours. + columns: Node features to use for distance calculation. + Defaults to [0,1,2]. 
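+
+        Example:
+            A minimal, hypothetical usage sketch, assuming ´graph´ is a
+            ´Data´ object whose first three node-feature columns are the
+            x, y, z sensor positions::
+
+                edge_definition = KNNEdges(nb_nearest_neighbours=8)
+                graph = edge_definition(graph)  # assigns ´graph.edge_index´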
+ """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Member variable(s) + self._nb_nearest_neighbours = nb_nearest_neighbours + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Define K-NN edges.""" + graph.edge_index = knn_graph( + graph.x[:, self._columns], + self._nb_nearest_neighbours, + graph.batch, + ).to(self.device) + + return graph + + +class RadialEdges(EdgeDefinition): + """Builds graph from a sphere of chosen radius centred at each node.""" + + @save_model_config + def __init__( + self, + radius: float, + columns: List[int] = [0, 1, 2], + ): + """Radial edges. + + Connects each node to other nodes that are within a sphere of + radius ´r´ centered at the node. The feature space of ´r´ is defined + by ´columns´ + + Args: + radius: radius of sphere + columns: columns of the node feature matrix used. + Defaults to [0,1,2]. + """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Member variable(s) + self._radius = radius + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Define radial edges.""" + graph.edge_index = radius_graph( + graph.x[:, self._columns], + self._radius, + graph.batch, + ).to(self.device) + + return graph + + +class EuclideanEdges(EdgeDefinition): # pylint: disable=too-few-public-methods + """Builds edges according to Euclidean distance between nodes. + + See https://arxiv.org/pdf/1809.06166.pdf. + """ + + @save_model_config + def __init__( + self, + sigma: float, + threshold: float = 0.0, + columns: List[int] = None, + ): + """Construct `EuclideanEdges`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Check(s) + if columns is None: + columns = [0, 1, 2] + + # Member variable(s) + self._sigma = sigma + self._threshold = threshold + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Forward pass.""" + # Constructs the adjacency matrix from the raw, DOM-level data and + # returns this matrix + if graph.edge_index is not None: + self.info( + "WARNING: GraphBuilder received graph with pre-existing " + "structure. Will overwrite." + ) + + xyz_coords = graph.x[:, self._columns] + + # Construct block-diagonal matrix indicating whether pulses belong to + # the same event in the batch + batch_mask = graph.batch.unsqueeze(dim=0) == graph.batch.unsqueeze( + dim=1 + ) + + distance_matrix = calculate_distance_matrix(xyz_coords) + affinity_matrix = torch.exp( + -0.5 * distance_matrix**2 / self._sigma**2 + ) + + # Use softmax to normalise all adjacencies to one for each node + exp_row_sums = torch.exp(affinity_matrix).sum(axis=1) + weighted_adj_matrix = torch.exp( + affinity_matrix + ) / exp_row_sums.unsqueeze(dim=1) + + # Only include edges with weights that exceed the chosen threshold (and + # are part of the same event) + sources, targets = torch.where( + (weighted_adj_matrix > self._threshold) & (batch_mask) + ) + edge_weights = weighted_adj_matrix[sources, targets] + + graph.edge_index = torch.stack((sources, targets)) + graph.edge_weight = edge_weights + + return graph diff --git a/src/graphnet/models/graphs/graph.py b/src/graphnet/models/graphs/graph.py new file mode 100644 index 000000000..70ddc75e0 --- /dev/null +++ b/src/graphnet/models/graphs/graph.py @@ -0,0 +1,230 @@ +"""Modules for defining graphs. + +These are self-contained graph definitions that hold all the graph-altering +code in graphnet. 
These modules define what the GNN sees as input and can be
+passed to dataloaders during training and deployment.
+"""
+
+
+from typing import Tuple, Any, List, Optional, Union, Dict, Callable
+from abc import abstractmethod
+import torch
+from torch_geometric.data import Data
+import numpy as np
+
+from graphnet.utilities.config import save_model_config
+from graphnet.models.detector import Detector
+from graphnet.models import Model
+from graphnet.models.graphs import EdgeDefinition, KNNEdges
+
+
+class GraphDefinition(Model):
+    """An abstract class to create graph definitions from."""
+
+    @save_model_config
+    def __init__(
+        self,
+        detector: Detector,
+        edge_definition: EdgeDefinition,
+    ):
+        """Construct ´GraphDefinition´.
+
+        The ´detector´ holds ´Detector´-specific code, e.g. scaling/standardization
+        and geometry tables.
+
+        ´edge_definition´ defines the connectivity of the graph.
+
+        Args:
+            detector: The corresponding ´Detector´ representing the data.
+                Defaults to None.
+            edge_definition: Your choice in edges. Defaults to None.
+        """
+        # Member Variables
+        self._detector = detector
+        self._edge_definiton = edge_definition
+
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
+    @abstractmethod
+    def _create_graph(self, x: np.array) -> Data:
+        """Problem/model specific graph definition.
+
+        Should not standardize/scale data. May assign edges.
+
+        Args:
+            x: node features for a single event
+
+        Returns:
+            Data object (a single graph)
+        """
+
+    def __call__(
+        self,
+        node_features: List[Tuple[float, ...]],
+        node_feature_names: List[str],
+        truth_dicts: List[Dict[str, Any]],
+        custom_label_functions: Dict[str, Callable[..., Any]],
+        loss_weight_column: Optional[str] = None,
+        loss_weight: Optional[float] = None,
+        loss_weight_default_value: Optional[float] = None,
+        data_path: Optional[str] = None,
+    ) -> Data:
+        """Construct graph as ´Data´ object.
+
+        Args:
+            node_features: node features for graph. Shape ´[num_nodes, d]´.
+            node_feature_names: name of each column. Shape ´[d,]´.
+            truth_dicts: List of dictionaries containing truth labels.
+            custom_label_functions: Custom label functions. See
+                https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
+            loss_weight_column: Name of column that holds loss weight.
+                Defaults to None.
+            loss_weight: Loss weight associated with event. Defaults to None.
+            loss_weight_default_value: Default value for loss weight. Used in
+                instances where some events have no pre-defined loss weight.
+                Defaults to None.
+            data_path: Path to dataset data files. Defaults to None.
+
+        Returns:
+            graph
+        """
+        # Standardize / Scale node features
+        node_features = self._detector(node_features, node_feature_names)
+
+        # Create graph
+        graph = self._create_graph(node_features)
+
+        # Attach number of pulses as static attribute.
+        graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32)
+
+        # Assign edges
+        if self._edge_definiton is not None:
+            graph = self._edge_definiton(graph)
+        else:
+            self.warnonce(
+                "No EdgeDefinition provided. Graphs will not have edges defined!"
+            )
+
+        # Attach data path - useful for Ensemble datasets.
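+        # (The path is stored as a plain string attribute, so that after
+        # several datasets are concatenated, each graph can still be traced
+        # back to the file it came from.)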
+ if data_path is not None: + graph["dataset_path"] = data_path + + # Attach loss weights if they exist + graph = self._add_loss_weights( + graph=graph, + loss_weight=loss_weight, + loss_weight_column=loss_weight_column, + loss_weight_default_value=loss_weight_default_value, + ) + + # Attach default truth labels and node truths + graph = self._add_truth(graph=graph, truth_dicts=truth_dicts) + + # Attach custom truth labels + graph = self._add_custom_labels( + graph=graph, custom_label_functions=custom_label_functions + ) + + # Attach node features as seperate fields. MAY NOT CONTAIN 'x' + graph = self._add_features_individually( + graph=graph, node_feature_names=node_feature_names + ) + + return graph + + def _add_loss_weights( + self, + graph: Data, + loss_weight_column: Optional[str] = None, + loss_weight: Optional[float] = None, + loss_weight_default_value: Optional[float] = None, + ) -> Data: + """Attempt to store a loss weight in the graph for use during training. + + I.e. `graph[loss_weight_column] = loss_weight` + + Args: + loss_weight: The non-negative weight to be stored. + graph: Data object representing the event. + loss_weight_column: The name under which the weight is stored in + the graph. + loss_weight_default_value: The default value used if + none was retrieved. + + Returns: + A graph with loss weight added, if available. + """ + # Add loss weight to graph. + if loss_weight is not None and loss_weight_column is not None: + # No loss weight was retrieved, i.e., it is missing for the current + # event. + if loss_weight < 0: + if loss_weight_default_value is None: + raise ValueError( + "At least one event is missing an entry in " + f"{loss_weight_column} " + "but loss_weight_default_value is None." + ) + graph[loss_weight_column] = torch.tensor( + self._loss_weight_default_value, dtype=self._dtype + ).reshape(-1, 1) + else: + graph[loss_weight_column] = torch.tensor( + loss_weight, dtype=self._dtype + ).reshape(-1, 1) + return graph + + def _add_truth( + self, graph: Data, truth_dicts: List[Dict[str, Any]] + ) -> Data: + """Add truth labels from ´truth_dicts´ to ´graph´. + + I.e. ´graph[key] = truth_dict[key]´ + + + Args: + graph: graph where the label will be stored + truth_dicts: dictionary containing the labels + + Returns: + graph with labels + """ + # Write attributes, either target labels, truth info or original + # features. + for truth_dict in truth_dicts: + for key, value in truth_dict.items(): + try: + graph[key] = torch.tensor(value) + except TypeError: + # Cannot convert `value` to Tensor due to its data type, + # e.g. `str`. + self.debug( + ( + f"Could not assign `{key}` with type " + f"'{type(value).__name__}' as attribute to graph." + ) + ) + return graph + + def _add_features_individually( + self, + graph: Data, + node_feature_names: List[str], + ) -> Data: + # Additionally add original features as (static) attributes + graph.features = node_feature_names + for index, feature in enumerate(node_feature_names): + if feature not in ["x"]: # reserved for node features. + graph[feature] = graph.x[:, index].detach() + else: + self.warnonce( + """Cannot assign graph['x']. This field is reserved for node features. 
Please rename your input feature.""" + ) + return graph + + def _add_custom_labels( + self, + graph: Data, + custom_label_functions: Dict[str, Callable[..., Any]], + ) -> Data: + # Add custom labels to the graph + for key, fn in custom_label_functions.items(): + graph[key] = fn(graph) + return graph -- GitLab From 08cd6b43146daebc11e9c15da5b557dd620a767b Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Mon, 26 Jun 2023 21:53:36 +0200 Subject: [PATCH 02/15] refactor of Detector --- src/graphnet/models/detector/detector.py | 111 ++++------------------- src/graphnet/models/detector/icecube.py | 46 +++++----- 2 files changed, 42 insertions(+), 115 deletions(-) diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index 4ad3cce48..1c4bd61b4 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -1,115 +1,38 @@ """Base detector-specific `Model` class(es).""" from abc import abstractmethod -from typing import List +from typing import Dict, Callable -import torch from torch_geometric.data import Data -from torch_geometric.data.batch import Batch -from graphnet.models.graph_builders import GraphBuilder from graphnet.models import Model -from graphnet.utilities.config import save_model_config from graphnet.utilities.decorators import final class Detector(Model): """Base class for all detector-specific read-ins in graphnet.""" - @property - @abstractmethod - def features(self) -> List[str]: - """List of features used/assumed by inheriting `Detector` objects.""" - - @save_model_config - def __init__( - self, graph_builder: GraphBuilder, scalers: List[dict] = None - ): + def __init__(self, feature_map: Dict[str, Callable]): """Construct `Detector`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) - - # Member variables - self._graph_builder = graph_builder - self._scalers = scalers - if self._scalers: - self.info( - ( - "Will use scalers rather than standard preprocessing " - f"in {self.__class__.__name__}." - ) - ) + self.feature_map = feature_map @final - def forward(self, data: Data) -> Data: + def forward(self, graph: Data) -> Data: """Pre-process graph `Data` features and build graph adjacency.""" # Check(s) - assert data.x.size()[1] == self.nb_inputs, ( - "Got graph data with incompatible size, ", - f"{data.x.size()} vs. {self.nb_inputs} expected", - ) - - # Graph-bulding - # @NOTE: `.clone` is necessary to avoid modifying original tensor in-place. - data = self._graph_builder(data).clone() - - if self._scalers: - # # Scaling individual features - # x_numpy = data.x.detach().cpu().numpy() - # for key, scaler in self._scalers.items(): - # ix = self.features.index(key) - # data.x[:,ix] = torch.tensor(scaler.transform(x_numpy[:,ix])).type_as(data.x) - - # Scaling groups of features | @TEMP, probably - x_numpy = data.x.detach().cpu().numpy() - - data.x[:, :3] = torch.tensor( - self._scalers["xyz"].transform(x_numpy[:, :3]) # type: ignore[call-overload] - ).type_as(data.x) - - data.x[:, 3:] = torch.tensor( - self._scalers["features"].transform(x_numpy[:, 3:]) # type: ignore[call-overload] - ).type_as(data.x) - - else: - # Implementation-specific forward pass (e.g. 
preprocessing) - data = self._forward(data) + assert isinstance(graph, Data) + return self._standardize(graph) - return data - - @abstractmethod - def _forward(self, data: Data) -> Data: - """Syntax like `.forward`, for implentation in inheriting classes.""" - - @property - def nb_inputs(self) -> int: - """Return number of input features.""" - return len(self.features) - - @property - def nb_outputs(self) -> int: - """Return number of output features. - - This the default, but may be overridden by specific inheriting classes. - """ - return self.nb_inputs - - def _validate_features(self, data: Data) -> None: - if isinstance(data, Batch): - # `data.features` is "transposed" and each list element contains only duplicate entries. - - if ( - len(data.features[0]) == data.num_graphs - and len(set(data.features[0])) == 1 - ): - data_features = [features[0] for features in data.features] - - # `data.features` is not "transposed" and each list element - # contains the original features. - else: - data_features = data.features[0] - else: - data_features = data.features - assert ( - data_features == self.features - ), f"Features on Data and Detector differ: {data_features} vs. {self.features}" + @final + def _standardize(self, graph: Data) -> Data: + for feature, idx in graph.features: + try: + graph.x[:, idx] = self.feature_map[feature](graph.x[:, idx]) + except KeyError as e: + self.warning( + f"""No Standardization function found for '{feature}'""" + ) + raise e + return graph diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py index e6addb8e3..231f0b111 100644 --- a/src/graphnet/models/detector/icecube.py +++ b/src/graphnet/models/detector/icecube.py @@ -15,32 +15,36 @@ from graphnet.models.detector.detector import Detector class IceCube86(Detector): """`Detector` class for IceCube-86.""" - # Implementing abstract class attribute - features = FEATURES.ICECUBE86 + def __init__(self) -> None: + """Construct `Detector`.""" + feature_map = { + "dom_x": self._dom_xyz, + "dom_y": self._dom_xyz, + "dom_z": self._dom_xyz, + "dom_time": self._dom_time, + "charge": self._charge, + "rde": self._rde, + "pmt_area": self._pmt_area, + } + # Base class constructor + super().__init__( + feature_map=feature_map, + ) - def _forward(self, data: Data) -> Data: - """Ingest data, build graph, and preprocess features. + def _dom_xyz(self, x: torch.tensor) -> torch.tensor: + return x / 500.0 - Args: - data: Input graph data. + def _dom_time(self, x: torch.tensor) -> torch.tensor: + return x - 1.0e04 / 3.0e4 - Returns: - Connected and preprocessed graph data. 
- """ - # Check(s) - self._validate_features(data) + def _charge(self, x: torch.tensor) -> torch.tensor: + return torch.log10(x) - # Preprocessing - data.x[:, 0] /= 500.0 # dom_x - data.x[:, 1] /= 500.0 # dom_y - data.x[:, 2] /= 500.0 # dom_z - data.x[:, 3] = (data.x[:, 3] - 1.0e04) / 3.0e4 # dom_time - data.x[:, 4] = torch.log10(data.x[:, 4]) / 3.0 # charge - data.x[:, 5] -= 1.25 # rde - data.x[:, 5] /= 0.25 - data.x[:, 6] /= 0.05 # pmt_area + def _rde(self, x: torch.tensor) -> torch.tensor: + return (x - 1.25) / 0.25 - return data + def _pmt_area(self, x: torch.tensor) -> torch.tensor: + return x / 0.05 class IceCubeKaggle(Detector): -- GitLab From 238324c6e89b15d0d16a50cdb89595ffca266e48 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Tue, 27 Jun 2023 10:51:36 +0200 Subject: [PATCH 03/15] add KNNGraph --- src/graphnet/models/graphs/graph.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/graphnet/models/graphs/graph.py b/src/graphnet/models/graphs/graph.py index 70ddc75e0..5aea8f424 100644 --- a/src/graphnet/models/graphs/graph.py +++ b/src/graphnet/models/graphs/graph.py @@ -228,3 +228,24 @@ class GraphDefinition(Model): for key, fn in custom_label_functions.items(): graph[key] = fn(graph) return graph + + +class KNNGraph(GraphDefinition): + """A graph with K-NN Edges.""" + + def __init__( + self, columns: List[int] = [0, 1, 2], nb_nearest_neighbours: int = 8 + ): + """Construct ´KNNGraph´. + + Args: + columns: Node feature dimensions used for K-NN computation. + Defaults to [0, 1, 2]. + nb_nearest_neighbours: Number of neighbours for each node. + Defaults to 8. + """ + super().__init__( + edge_definition=KNNEdges( + nb_nearest_neighbours=nb_nearest_neighbours, columns=columns + ) + ) -- GitLab From ea52b8d23fd5bb525d347748c4621fa7e4db1da9 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 13:11:30 +0200 Subject: [PATCH 04/15] polish --- src/graphnet/models/detector/__init__.py | 3 +- src/graphnet/models/detector/detector.py | 33 +++-- src/graphnet/models/detector/icecube.py | 11 +- src/graphnet/models/graphs/__init__.py | 1 - src/graphnet/models/graphs/edges/__init__.py | 7 + .../models/graphs/{ => edges}/edges.py | 2 +- src/graphnet/models/graphs/graph.py | 124 ++++++++++-------- src/graphnet/models/graphs/nodes/__init__.py | 8 ++ src/graphnet/models/graphs/nodes/nodes.py | 74 +++++++++++ 9 files changed, 190 insertions(+), 73 deletions(-) create mode 100644 src/graphnet/models/graphs/edges/__init__.py rename src/graphnet/models/graphs/{ => edges}/edges.py (99%) create mode 100644 src/graphnet/models/graphs/nodes/__init__.py create mode 100644 src/graphnet/models/graphs/nodes/nodes.py diff --git a/src/graphnet/models/detector/__init__.py b/src/graphnet/models/detector/__init__.py index 2d36144da..87c85fb98 100644 --- a/src/graphnet/models/detector/__init__.py +++ b/src/graphnet/models/detector/__init__.py @@ -1,3 +1,4 @@ """Detector-specific modules, for data ingestion and standardisation.""" -from .icecube import IceCube86, IceCubeDeepCore, Detector +from .icecube import * # IceCube86, IceCubeDeepCore +from .detector import * # Detector diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index 1c4bd61b4..25660c6d7 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -1,38 +1,49 @@ """Base detector-specific `Model` class(es).""" from abc import abstractmethod -from typing import Dict, Callable +from typing 
import Dict, Callable, List from torch_geometric.data import Data +import torch from graphnet.models import Model from graphnet.utilities.decorators import final +from graphnet.utilities.config import save_model_config class Detector(Model): """Base class for all detector-specific read-ins in graphnet.""" - def __init__(self, feature_map: Dict[str, Callable]): + @save_model_config + def __init__(self) -> None: """Construct `Detector`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) - self.feature_map = feature_map + + @property + @abstractmethod + def feature_map(self) -> Dict[str, Callable]: + """List of features used/assumed by inheriting `Detector` objects.""" @final - def forward(self, graph: Data) -> Data: + def forward( + self, node_features: torch.tensor, node_feature_names: List[str] + ) -> Data: """Pre-process graph `Data` features and build graph adjacency.""" - # Check(s) - assert isinstance(graph, Data) - return self._standardize(graph) + return self._standardize(node_features, node_feature_names) @final - def _standardize(self, graph: Data) -> Data: - for feature, idx in graph.features: + def _standardize( + self, node_features: torch.tensor, node_feature_names: List[str] + ) -> Data: + for idx, feature in enumerate(node_feature_names): try: - graph.x[:, idx] = self.feature_map[feature](graph.x[:, idx]) + node_features[:, idx] = self.feature_map()[feature]( # type: ignore + node_features[:, idx] + ) except KeyError as e: self.warning( f"""No Standardization function found for '{feature}'""" ) raise e - return graph + return node_features diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py index 231f0b111..67b168feb 100644 --- a/src/graphnet/models/detector/icecube.py +++ b/src/graphnet/models/detector/icecube.py @@ -1,5 +1,6 @@ """IceCube-specific `Detector` class(es).""" +from typing import Dict, Callable import torch from torch_geometric.data import Data @@ -10,13 +11,14 @@ from graphnet.models.components.pool import ( ) from graphnet.data.constants import FEATURES from graphnet.models.detector.detector import Detector +from graphnet.utilities.config import save_model_config class IceCube86(Detector): """`Detector` class for IceCube-86.""" - def __init__(self) -> None: - """Construct `Detector`.""" + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension.""" feature_map = { "dom_x": self._dom_xyz, "dom_y": self._dom_xyz, @@ -26,10 +28,7 @@ class IceCube86(Detector): "rde": self._rde, "pmt_area": self._pmt_area, } - # Base class constructor - super().__init__( - feature_map=feature_map, - ) + return feature_map def _dom_xyz(self, x: torch.tensor) -> torch.tensor: return x / 500.0 diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index f2bd5cd78..97949c139 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,3 @@ and their features. from .graph import GraphDefinition -from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py new file mode 100644 index 000000000..7da8baa7c --- /dev/null +++ b/src/graphnet/models/graphs/edges/__init__.py @@ -0,0 +1,7 @@ +"""Modules for constructing graphs. 
+
+´GraphDefinition´ defines the nodes and their features, and contains general
+graph-manipulation. ´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.
+"""
+from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges
diff --git a/src/graphnet/models/graphs/edges.py b/src/graphnet/models/graphs/edges/edges.py
similarity index 99%
rename from src/graphnet/models/graphs/edges.py
rename to src/graphnet/models/graphs/edges/edges.py
index cd2455f7d..28507058b 100644
--- a/src/graphnet/models/graphs/edges.py
+++ b/src/graphnet/models/graphs/edges/edges.py
@@ -1,7 +1,7 @@
 """Class(es) for building/connecting graphs."""
 
 from typing import List
-from ABC import abstractmethod
+from abc import abstractmethod, ABC
 
 import torch
 from torch_geometric.nn import knn_graph, radius_graph
diff --git a/src/graphnet/models/graphs/graph.py b/src/graphnet/models/graphs/graph.py
index 5aea8f424..4c929315c 100644
--- a/src/graphnet/models/graphs/graph.py
+++ b/src/graphnet/models/graphs/graph.py
@@ -7,15 +7,20 @@ passed to dataloaders during training and deployment.
 
 
 from typing import Tuple, Any, List, Optional, Union, Dict, Callable
-from abc import abstractmethod
+from abc import abstractmethod, ABC
 import torch
 from torch_geometric.data import Data
 import numpy as np
 
-from graphnet.utilities.config import save_model_config
+from graphnet.utilities.config import Configurable
+from graphnet.utilities.config import (
+    save_model_config,
+)  # .graph_config import save_graph_config, GraphConfig
+from graphnet.utilities.logging import Logger
 from graphnet.models.detector import Detector
+from .edges import EdgeDefinition, KNNEdges
+from .nodes import NodeDefinition
 from graphnet.models import Model
-from graphnet.models.graphs import EdgeDefinition, KNNEdges
 
 
 class GraphDefinition(Model):
@@ -25,46 +30,53 @@ class GraphDefinition(Model):
     def __init__(
         self,
         detector: Detector,
-        edge_definition: EdgeDefinition,
+        node_definition: NodeDefinition,
+        edge_definition: Optional[EdgeDefinition] = None,
+        node_feature_names: Optional[List[str]] = None,
+        dtype: torch.dtype = None,  # torch.float,
     ):
         """Construct ´GraphDefinition´.
 
         The ´detector´ holds ´Detector´-specific code, e.g. scaling/standardization
         and geometry tables.
 
-        ´edge_definition´ defines the connectivity of the graph.
+        ´node_definition´ defines the nodes in the graph.
+
+        ´edge_definition´ defines the connectivity of the nodes in the graph.
 
         Args:
             detector: The corresponding ´Detector´ representing the data.
-                Defaults to None.
-            edge_definition: Your choice in edges. Defaults to None.
+            node_definition: Definition of nodes.
+            edge_definition: Definition of edges. Defaults to None.
+            node_feature_names: Names of node feature columns. Defaults to None.
+            dtype: Data type used for node features, e.g. ´torch.float´.
         """
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
         # Member Variables
         self._detector = detector
         self._edge_definiton = edge_definition
+        self._node_definition = node_definition
+        if node_feature_names is None:
+            # Assume all features in Detector is used.
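+            # (i.e. fall back to the keys of the detector's feature map.)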
+ node_feature_names = list(self._detector.feature_map().keys()) # type: ignore + self._node_feature_names = node_feature_names + self._dtype = dtype + + # Set Input / Output dimensions + self._node_definition.set_number_of_inputs( + node_feature_names=node_feature_names + ) + self.nb_inputs = len(self._node_feature_names) + self.nb_outputs = self._node_definition.nb_outputs - def __call__( + def forward( # type: ignore self, - node_features: List[Tuple[float, ...]], + node_features: np.array, node_feature_names: List[str], - truth_dicts: List[Dict[str, Any]], - custom_label_functions: Dict[str, Callable[..., Any]], + truth_dicts: Optional[List[Dict[str, Any]]] = None, + custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None, loss_weight_column: Optional[str] = None, loss_weight: Optional[float] = None, loss_weight_default_value: Optional[float] = None, @@ -85,11 +97,19 @@ class GraphDefinition(Model): Returns: graph """ + # Checks + self._validate_input( + node_features=node_features, node_feature_names=node_feature_names + ) + + # Transform to pytorch tensor + node_features = torch.tensor(node_features, dtype=self._dtype) + # Standardize / Scale node features node_features = self._detector(node_features, node_feature_names) # Create graph - graph = self._create_graph(node_features) + graph = self._node_definition(node_features) # Attach number of pulses as static attribute. graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32) @@ -115,12 +135,14 @@ class GraphDefinition(Model): ) # Attach default truth labels and node truths - graph = self._add_truth(graph=graph, truth_dicts=truth_dicts) + if truth_dicts is not None: + graph = self._add_truth(graph=graph, truth_dicts=truth_dicts) # Attach custom truth labels - graph = self._add_custom_labels( - graph=graph, custom_label_functions=custom_label_functions - ) + if custom_label_functions is not None: + graph = self._add_custom_labels( + graph=graph, custom_label_functions=custom_label_functions + ) # Attach node features as seperate fields. MAY NOT CONTAIN 'x' graph = self._add_features_individually( @@ -129,6 +151,23 @@ class GraphDefinition(Model): return graph + def _validate_input( + self, node_features: np.array, node_feature_names: List[str] + ) -> None: + + # node feature matrix dimension check + assert node_features.shape[1] == len(node_feature_names) + + # check that provided features for input is the same that the ´Graph´ + # was instantiated with. + assert len(node_feature_names) == len( + self._node_feature_names + ), f"""Input features ({node_feature_names}) is not what {self.__class__.__name__} was instatiated with ({self._node_feature_names})""" + for idx in range(len(node_feature_names)): + assert ( + node_feature_names[idx] == self._node_feature_names[idx] + ), """ Order of node features are not the same.""" + def _add_loss_weights( self, graph: Data, @@ -228,24 +267,3 @@ class GraphDefinition(Model): for key, fn in custom_label_functions.items(): graph[key] = fn(graph) return graph - - -class KNNGraph(GraphDefinition): - """A graph with K-NN Edges.""" - - def __init__( - self, columns: List[int] = [0, 1, 2], nb_nearest_neighbours: int = 8 - ): - """Construct ´KNNGraph´. - - Args: - columns: Node feature dimensions used for K-NN computation. - Defaults to [0, 1, 2]. - nb_nearest_neighbours: Number of neighbours for each node. - Defaults to 8. 
- """ - super().__init__( - edge_definition=KNNEdges( - nb_nearest_neighbours=nb_nearest_neighbours, columns=columns - ) - ) diff --git a/src/graphnet/models/graphs/nodes/__init__.py b/src/graphnet/models/graphs/nodes/__init__.py new file mode 100644 index 000000000..05194b61a --- /dev/null +++ b/src/graphnet/models/graphs/nodes/__init__.py @@ -0,0 +1,8 @@ +"""Modules for constructing graphs. + +´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features. +""" + +from .nodes import NodeDefinition, NodesAsPulses diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py new file mode 100644 index 000000000..afe0dcfce --- /dev/null +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -0,0 +1,74 @@ +"""Class(es) for building/connecting graphs.""" + +from typing import List +from abc import abstractmethod + +import torch +from torch_geometric.nn import knn_graph, radius_graph +from torch_geometric.data import Data +import numpy as np + +from graphnet.utilities.decorators import final +from graphnet.utilities.config import save_model_config +from graphnet.models import Model + + +class NodeDefinition(Model): # pylint: disable=too-few-public-methods + """Base class for graph building.""" + + @save_model_config + def __init__(self) -> None: + """Construct `Detector`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + @final + def forward(self, x: torch.tensor) -> Data: + """Construct nodes from raw node features. + + Args: + x: standardized node features with shape ´[num_pulses, d]´, + where ´d´ is the number of node features. + + Returns: + graph: a graph without edges + """ + graph = self._construct_nodes(x) + return graph + + @property + def nb_outputs(self) -> int: + """Return number of output features. + + This the default, but may be overridden by specific inheriting classes. + """ + return self.nb_inputs + + @final + def set_number_of_inputs(self, node_feature_names: List[str]) -> None: + """Return number of inputs expected by node definition. + + Args: + node_feature_names: name of each node feature column. + """ + assert isinstance(node_feature_names, list) + self.nb_inputs = len(node_feature_names) + + @abstractmethod + def _construct_nodes(self, x: torch.tensor) -> Data: + """Construct nodes from raw node features ´x´. + + Args: + x: standardized node features with shape ´[num_pulses, d]´, + where ´d´ is the number of node features. + + Returns: + graph: graph without edges. 
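+
+        Example:
+            A minimal, hypothetical implementation (cf. ´NodesAsPulses´
+            below) could simply wrap the feature matrix in a ´Data´ object::
+
+                def _construct_nodes(self, x: torch.tensor) -> Data:
+                    return Data(x=x)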
+ """ + + +class NodesAsPulses(NodeDefinition): + """Represent each measured pulse of Cherenkov Radiation as a node.""" + + def _construct_nodes(self, x: torch.tensor) -> Data: + return Data(x=x) -- GitLab From f7e93b29f609dfa02fc03c0ef136724a656355fb Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 13:13:40 +0200 Subject: [PATCH 05/15] polish --- src/graphnet/models/graphs/graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/graph.py b/src/graphnet/models/graphs/graph.py index 4c929315c..a608342b4 100644 --- a/src/graphnet/models/graphs/graph.py +++ b/src/graphnet/models/graphs/graph.py @@ -33,7 +33,7 @@ class GraphDefinition(Model): node_definition: NodeDefinition, edge_definition: Optional[EdgeDefinition] = None, node_feature_names: Optional[List[str]] = None, - dtype: torch.dtype = None, # torch.float, + dtype: Optional[torch.dtype] = None, ): """Construct ´GraphDefinition´. The ´detector´ holds. @@ -62,6 +62,8 @@ class GraphDefinition(Model): # Assume all features in Detector is used. node_feature_names = list(self._detector.feature_map().keys()) # type: ignore self._node_feature_names = node_feature_names + if dtype is None: + dtype = torch.float self._dtype = dtype # Set Input / Output dimensions -- GitLab From 8a1cacec159efda2c058b79217f9f2dcca6c9ab9 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 13:23:36 +0200 Subject: [PATCH 06/15] replace Detector with GraphDefinition for StandardModel --- src/graphnet/models/standard_model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 41b70bb26..a56ee953d 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -12,7 +12,7 @@ import pandas as pd from graphnet.models.coarsening import Coarsening from graphnet.utilities.config import save_model_config -from graphnet.models.detector.detector import Detector +from graphnet.models.graphs import GraphDefinition from graphnet.models.gnn.gnn import GNN from graphnet.models.model import Model from graphnet.models.task import Task @@ -29,7 +29,7 @@ class StandardModel(Model): def __init__( self, *, - detector: Detector, + graph_definition: GraphDefinition, gnn: GNN, tasks: Union[Task, List[Task]], coarsening: Optional[Coarsening] = None, @@ -48,12 +48,12 @@ class StandardModel(Model): tasks = [tasks] assert isinstance(tasks, (list, tuple)) assert all(isinstance(task, Task) for task in tasks) - assert isinstance(detector, Detector) + assert isinstance(graph_definition, GraphDefinition) assert isinstance(gnn, GNN) assert coarsening is None or isinstance(coarsening, Coarsening) # Member variable(s) - self._detector = detector + self._graph_definition = graph_definition self._gnn = gnn self._tasks = ModuleList(tasks) self._coarsening = coarsening @@ -101,7 +101,11 @@ class StandardModel(Model): """Forward pass, chaining model components.""" if self._coarsening: data = self._coarsening(data) - data = self._detector(data) + assert isinstance(data, Data) + if data.graph_definition != self._graph_definition.__class__.__name__: + self.warn( + f"Model expects {self._graph_definition.__class__.__name__} but is given {data.graph_definition}" + ) x = self._gnn(data) preds = [task(x) for task in self._tasks] return preds -- GitLab From ccf9d3261007a73142e1f0e4f21613ba4d792cac Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe 
<rahn@outlook.dk> Date: Wed, 28 Jun 2023 13:57:13 +0200 Subject: [PATCH 07/15] Simplify Dataset, polish --- src/graphnet/data/dataset.py | 93 ++++++------------- src/graphnet/models/detector/__init__.py | 4 +- src/graphnet/models/graphs/__init__.py | 2 +- .../graphs/{graph.py => graph_definition.py} | 12 +-- 4 files changed, 36 insertions(+), 75 deletions(-) rename src/graphnet/models/graphs/{graph.py => graph_definition.py} (96%) diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset.py index 7300da815..1352c8e20 100644 --- a/src/graphnet/data/dataset.py +++ b/src/graphnet/data/dataset.py @@ -29,6 +29,7 @@ from graphnet.utilities.config import ( save_dataset_config, ) from graphnet.utilities.logging import Logger +from graphnet.models.graphs import GraphDefinition class ColumnMissingException(Exception): @@ -150,6 +151,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): loss_weight_column: Optional[str] = None, loss_weight_default_value: Optional[float] = None, seed: Optional[int] = None, + graph_definition: GraphDefinition = None, ): """Construct Dataset. @@ -195,6 +197,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): subset of events when resolving a string-based selection (e.g., `"10000 random events ~ event_no % 5 > 0"` or `"20% random events ~ event_no % 5 > 0"`). + graph_definition: Method that defines the graph representation. """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -218,6 +221,10 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): self._index_column = index_column self._truth_table = truth_table self._loss_weight_default_value = loss_weight_default_value + # self.info( + # f"No GraphDefinition recieved. Defaulting to KNNGraph(nb_neighbours = 8)" + # ) + self._graph_definition = graph_definition if node_truth is not None: assert isinstance(node_truth_table, str) @@ -521,10 +528,6 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): ) -> Data: """Create Pytorch Data (i.e. graph) object. - No preprocessing is performed at this stage, just as no node adjancency - is imposed. This means that the `edge_attr` and `edge_weight` - attributes are not set. - Args: features: List of tuples, containing event features. truth: List of tuples, containing truth information. @@ -552,71 +555,33 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): for index, key in enumerate(self._node_truth) } + # Create list of truth dicts with labels + truth_dicts = [labels_dict, truth_dict] + if node_truth is not None: + truth_dicts.append(node_truth_dict) + # Catch cases with no reconstructed pulses if len(features): - data = np.asarray(features)[:, 1:] + node_features = np.asarray(features)[ + :, 1: + ] # first entry is index column else: - data = np.array([]).reshape((0, len(self._features) - 1)) + node_features = np.array([]).reshape((0, len(self._features) - 1)) # Construct graph data object - x = torch.tensor(data, dtype=self._dtype) # pylint: disable=C0103 - n_pulses = torch.tensor(len(x), dtype=torch.int32) - graph = Data(x=x, edge_index=None) - graph.n_pulses = n_pulses - graph.features = self._features[1:] - - # Add loss weight to graph. - if loss_weight is not None and self._loss_weight_column is not None: - # No loss weight was retrieved, i.e., it is missing for the current - # event. 
- if loss_weight < 0: - if self._loss_weight_default_value is None: - raise ValueError( - "At least one event is missing an entry in " - f"{self._loss_weight_column} " - "but loss_weight_default_value is None." - ) - graph[self._loss_weight_column] = torch.tensor( - self._loss_weight_default_value, dtype=self._dtype - ).reshape(-1, 1) - else: - graph[self._loss_weight_column] = torch.tensor( - loss_weight, dtype=self._dtype - ).reshape(-1, 1) - - # Write attributes, either target labels, truth info or original - # features. - add_these_to_graph = [labels_dict, truth_dict] - if node_truth is not None: - add_these_to_graph.append(node_truth_dict) - for write_dict in add_these_to_graph: - for key, value in write_dict.items(): - try: - graph[key] = torch.tensor(value) - except (TypeError, RuntimeError) as error: - if isinstance(error, TypeError) or (value is None): - # Cannot convert `value` to Tensor due to its data type, - # e.g. `str`. - self.warning( - ( - f"Could not assign `{key}` with type " - f"'{type(value).__name__ if value is not None else 'NoneType'}' as attribute to graph of " - f"event with {self._index_column} == {labels_dict[self._index_column]}" - ) - ) - else: - raise error - # Additionally add original features as (static) attributes - for index, feature in enumerate(graph.features): - if feature not in ["x"]: - graph[feature] = graph.x[:, index].detach() - - # Add custom labels to the graph - for key, fn in self._label_fns.items(): - graph[key] = fn(graph) - - # Add Dataset Path. Useful if multiple datasets are concatenated. - graph["dataset_path"] = self._path + assert self._graph_definition is not None + graph = self._graph_definition( + node_features=node_features, + node_feature_names=self._features[ + 1: + ], # first entry is index column + truth_dicts=truth_dicts, + custom_label_functions=self._label_fns, + loss_weight_column=self._loss_weight_column, + loss_weight=loss_weight, + loss_weight_default_value=self._loss_weight_default_value, + data_path=self._path, + ) return graph def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/graphnet/models/detector/__init__.py b/src/graphnet/models/detector/__init__.py index 87c85fb98..060b7ca03 100644 --- a/src/graphnet/models/detector/__init__.py +++ b/src/graphnet/models/detector/__init__.py @@ -1,4 +1,4 @@ """Detector-specific modules, for data ingestion and standardisation.""" -from .icecube import * # IceCube86, IceCubeDeepCore -from .detector import * # Detector +from .icecube import IceCube86, IceCubeDeepCore +from .detector import Detector diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index 97949c139..13ab6cbba 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -6,4 +6,4 @@ and their features. """ -from .graph import GraphDefinition +from .graph_definition import GraphDefinition diff --git a/src/graphnet/models/graphs/graph.py b/src/graphnet/models/graphs/graph_definition.py similarity index 96% rename from src/graphnet/models/graphs/graph.py rename to src/graphnet/models/graphs/graph_definition.py index a608342b4..ec035ad48 100644 --- a/src/graphnet/models/graphs/graph.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -6,19 +6,15 @@ passed to dataloaders during training and deployment. 
""" -from typing import Tuple, Any, List, Optional, Union, Dict, Callable -from abc import abstractmethod, ABC +from typing import Any, List, Optional, Dict, Callable import torch from torch_geometric.data import Data import numpy as np -from graphnet.utilities.config import Configurable -from graphnet.utilities.config import ( - save_model_config, -) # .graph_config import save_graph_config, GraphConfig -from graphnet.utilities.logging import Logger +from graphnet.utilities.config import save_model_config + from graphnet.models.detector import Detector -from .edges import EdgeDefinition, KNNEdges +from .edges import EdgeDefinition from .nodes import NodeDefinition from graphnet.models import Model -- GitLab From b4b10235561874877cf35b90d0eaf476be493526 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 14:25:13 +0200 Subject: [PATCH 08/15] Restructure --- src/graphnet/data/__init__.py | 11 ----------- src/graphnet/data/dataset/__init__.py | 14 ++++++++++++++ src/graphnet/data/{ => dataset}/dataset.py | 0 .../data/{ => dataset}/parquet/parquet_dataset.py | 2 +- .../data/{ => dataset}/sqlite/sqlite_dataset.py | 2 +- .../sqlite/sqlite_dataset_perturbed.py | 2 +- src/graphnet/data/parquet/__init__.py | 8 -------- src/graphnet/data/sqlite/__init__.py | 7 ------- src/graphnet/training/utils.py | 4 ++-- 9 files changed, 19 insertions(+), 31 deletions(-) create mode 100644 src/graphnet/data/dataset/__init__.py rename src/graphnet/data/{ => dataset}/dataset.py (100%) rename src/graphnet/data/{ => dataset}/parquet/parquet_dataset.py (98%) rename src/graphnet/data/{ => dataset}/sqlite/sqlite_dataset.py (98%) rename src/graphnet/data/{ => dataset}/sqlite/sqlite_dataset_perturbed.py (98%) diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index 64cef139d..1eca4f6cd 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -3,14 +3,3 @@ `graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data. 
""" - -# Configuration -from graphnet.utilities.imports import has_torch_package - -if has_torch_package(): - import torch.multiprocessing - from .dataset import EnsembleDataset - - torch.multiprocessing.set_sharing_strategy("file_system") - -del has_torch_package diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py new file mode 100644 index 000000000..3ccdd9642 --- /dev/null +++ b/src/graphnet/data/dataset/__init__.py @@ -0,0 +1,14 @@ +"""Dataset classes for training in GraphNeT.""" +# Configuration +from graphnet.utilities.imports import has_torch_package + +if has_torch_package(): + import torch.multiprocessing + from .dataset import EnsembleDataset, Dataset, ColumnMissingException + from .parquet.parquet_dataset import ParquetDataset + from .sqlite.sqlite_dataset import SQLiteDataset + from .sqlite.sqlite_dataset_perturbed import SQLiteDatasetPerturbed + + torch.multiprocessing.set_sharing_strategy("file_system") + +del has_torch_package diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset/dataset.py similarity index 100% rename from src/graphnet/data/dataset.py rename to src/graphnet/data/dataset/dataset.py diff --git a/src/graphnet/data/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py similarity index 98% rename from src/graphnet/data/parquet/parquet_dataset.py rename to src/graphnet/data/dataset/parquet/parquet_dataset.py index 7839bd983..bb63e1800 100644 --- a/src/graphnet/data/parquet/parquet_dataset.py +++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import numpy as np import awkward as ak -from graphnet.data.dataset import Dataset, ColumnMissingException +from graphnet.data.dataset.dataset import Dataset, ColumnMissingException class ParquetDataset(Dataset): diff --git a/src/graphnet/data/sqlite/sqlite_dataset.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset.py similarity index 98% rename from src/graphnet/data/sqlite/sqlite_dataset.py rename to src/graphnet/data/dataset/sqlite/sqlite_dataset.py index e61623c46..a0b06ff66 100644 --- a/src/graphnet/data/sqlite/sqlite_dataset.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional, Tuple, Union import pandas as pd import sqlite3 -from graphnet.data.dataset import Dataset, ColumnMissingException +from graphnet.data.dataset.dataset import Dataset, ColumnMissingException class SQLiteDataset(Dataset): diff --git a/src/graphnet/data/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py similarity index 98% rename from src/graphnet/data/sqlite/sqlite_dataset_perturbed.py rename to src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py index 6fb977ccd..755d96b82 100644 --- a/src/graphnet/data/sqlite/sqlite_dataset_perturbed.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py @@ -6,7 +6,7 @@ import numpy as np import torch from torch_geometric.data import Data -from graphnet.data.sqlite.sqlite_dataset import SQLiteDataset +from .sqlite_dataset import SQLiteDataset class SQLiteDatasetPerturbed(SQLiteDataset): diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py index fc0b2f7a0..616d89c16 100644 --- a/src/graphnet/data/parquet/__init__.py +++ b/src/graphnet/data/parquet/__init__.py @@ -1,10 +1,2 @@ """Parquet-specific implementation of data classes.""" - -from graphnet.utilities.imports import 
has_torch_package - from .parquet_dataconverter import ParquetDataConverter - -if has_torch_package(): - from .parquet_dataset import ParquetDataset - -del has_torch_package diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index f632f2be8..e4ac554a7 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -1,11 +1,4 @@ """SQLite-specific implementation of data classes.""" - -from graphnet.utilities.imports import has_torch_package - from .sqlite_dataconverter import SQLiteDataConverter from .sqlite_utilities import create_table_and_save_to_sql - -if has_torch_package(): - from .sqlite_dataset import SQLiteDataset - from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed from .sqlite_utilities import run_sql_code, save_to_sql diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 52b7634e8..2d99c39f1 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -12,8 +12,8 @@ from torch.utils.data import DataLoader from torch_geometric.data import Batch, Data from graphnet.data.dataset import Dataset -from graphnet.data.sqlite import SQLiteDataset -from graphnet.data.parquet import ParquetDataset +from graphnet.data.dataset.sqlite import SQLiteDataset +from graphnet.data.dataset.parquet import ParquetDataset from graphnet.models import Model from graphnet.utilities.logging import Logger -- GitLab From c18e5e18b443edbba27a5090f9afbff6acb323c5 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 14:26:09 +0200 Subject: [PATCH 09/15] simplify imports --- tests/utilities/test_dataset_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_dataset_config.py b/tests/utilities/test_dataset_config.py index a09d0e5ad..88f3b3f1c 100644 --- a/tests/utilities/test_dataset_config.py +++ b/tests/utilities/test_dataset_config.py @@ -13,8 +13,8 @@ import graphnet import graphnet.constants from graphnet.data.constants import FEATURES, TRUTH from graphnet.data.dataset import Dataset -from graphnet.data.parquet import ParquetDataset -from graphnet.data.sqlite import SQLiteDataset +from graphnet.data.dataset import ParquetDataset +from graphnet.data.dataset import SQLiteDataset from graphnet.utilities.config import DatasetConfig CONFIG_PATHS = { -- GitLab From 99d9b5180062a2bbd54a6fca860d0726c930fd45 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 14:27:52 +0200 Subject: [PATCH 10/15] simplify imports --- examples/02_data/01_read_dataset.py | 4 ++-- examples/04_training/03_train_classification_model.py | 2 +- src/graphnet/training/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/02_data/01_read_dataset.py b/examples/02_data/01_read_dataset.py index a9a1f4d01..302529050 100644 --- a/examples/02_data/01_read_dataset.py +++ b/examples/02_data/01_read_dataset.py @@ -13,8 +13,8 @@ from tqdm import tqdm from graphnet.constants import TEST_PARQUET_DATA, TEST_SQLITE_DATA from graphnet.data.constants import FEATURES, TRUTH from graphnet.data.dataset import Dataset -from graphnet.data.sqlite.sqlite_dataset import SQLiteDataset -from graphnet.data.parquet.parquet_dataset import ParquetDataset +from graphnet.data.dataset import SQLiteDataset +from graphnet.data.dataset import ParquetDataset from graphnet.utilities.argparse import ArgumentParser from graphnet.utilities.logging import Logger diff --git 
a/examples/04_training/03_train_classification_model.py b/examples/04_training/03_train_classification_model.py index b403e4850..0c537bac0 100644 --- a/examples/04_training/03_train_classification_model.py +++ b/examples/04_training/03_train_classification_model.py @@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities import rank_zero_only -from graphnet.data.dataset import EnsembleDataset +from graphnet.data.dataset.dataset import EnsembleDataset from graphnet.constants import ( EXAMPLE_OUTPUT_DIR, DATASETS_CONFIG_DIR, diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 2d99c39f1..75f6855b2 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -12,8 +12,8 @@ from torch.utils.data import DataLoader from torch_geometric.data import Batch, Data from graphnet.data.dataset import Dataset -from graphnet.data.dataset.sqlite import SQLiteDataset -from graphnet.data.dataset.parquet import ParquetDataset +from graphnet.data.dataset import SQLiteDataset +from graphnet.data.dataset import ParquetDataset from graphnet.models import Model from graphnet.utilities.logging import Logger -- GitLab From 9b9434dc2e860819761e513a9669f96b2c370d86 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 14:33:15 +0200 Subject: [PATCH 11/15] remove redundant import --- src/graphnet/models/detector/icecube.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py index 67b168feb..c2028755a 100644 --- a/src/graphnet/models/detector/icecube.py +++ b/src/graphnet/models/detector/icecube.py @@ -11,7 +11,6 @@ from graphnet.models.components.pool import ( ) from graphnet.data.constants import FEATURES from graphnet.models.detector.detector import Detector -from graphnet.utilities.config import save_model_config class IceCube86(Detector): -- GitLab From 8984a01a077e759329d727b285a1b1e32ca8dc77 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 14:38:04 +0200 Subject: [PATCH 12/15] refactor of promethus detector class --- src/graphnet/models/detector/prometheus.py | 37 ++++++++++------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py index 2a35886ac..aef337830 100644 --- a/src/graphnet/models/detector/prometheus.py +++ b/src/graphnet/models/detector/prometheus.py @@ -1,6 +1,7 @@ """Prometheus-specific `Detector` class(es).""" -from torch_geometric.data import Data +from typing import Dict, Callable +import torch from graphnet.models.detector.detector import Detector @@ -10,25 +11,21 @@ class Prometheus(Detector): features = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"] - def _forward(self, data: Data) -> Data: - """Ingest data, build graph, and preprocess features. + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension.""" + feature_map = { + "sensor_pos_x": self._sensor_pos_xy, + "sensor_pos_y": self._sensor_pos_xy, + "sensor_pos_z": self._sensor_pos_z, + "t": self._t, + } + return feature_map - Args: - data: Input graph data. + def _sensor_pos_xy(self, x: torch.tensor) -> torch.tensor: + return x / 100 - Returns: - Connected and preprocessed graph data. 
- """ - # Check(s) - self._validate_features(data) + def _sensor_pos_z(self, x: torch.tensor) -> torch.tensor: + return (x + 350) / 100 - # Preprocessing - data.x[:, 0] /= 100.0 # dom_x - data.x[:, 1] /= 100.0 # dom_y - data.x[:, 2] += 350.0 # dom_z - data.x[:, 2] /= 100.0 - data.x[:, 3] /= 1.05e04 # dom_time - data.x[:, 3] -= 1.0 - data.x[:, 3] *= 20.0 - - return data + def _t(self, x: torch.tensor) -> torch.tensor: + return ((x / 1.05e04) - 1.0) * 20.0 -- GitLab From d5c8ade4a46d67f9999fc329de891ff019f7df76 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 14:38:49 +0200 Subject: [PATCH 13/15] refactor of promethus detector class --- src/graphnet/models/detector/prometheus.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py index aef337830..f21f9c413 100644 --- a/src/graphnet/models/detector/prometheus.py +++ b/src/graphnet/models/detector/prometheus.py @@ -9,8 +9,6 @@ from graphnet.models.detector.detector import Detector class Prometheus(Detector): """`Detector` class for Prometheus prototype.""" - features = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"] - def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" feature_map = { -- GitLab From 6d1501904e0c7c813c4320181488b552921ac6b4 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Wed, 28 Jun 2023 15:38:30 +0200 Subject: [PATCH 14/15] refactor training example without configs --- .../02_train_model_without_configs.py | 19 +++++++--- src/graphnet/data/dataset/parquet/__init__.py | 11 ++++++ src/graphnet/data/dataset/sqlite/__init__.py | 6 +++ src/graphnet/models/graphs/__init__.py | 1 + .../models/graphs/graph_definition.py | 3 ++ src/graphnet/models/graphs/graphs.py | 38 +++++++++++++++++++ src/graphnet/models/standard_model.py | 4 -- src/graphnet/training/utils.py | 5 +++ 8 files changed, 77 insertions(+), 10 deletions(-) create mode 100644 src/graphnet/data/dataset/parquet/__init__.py create mode 100644 src/graphnet/data/dataset/sqlite/__init__.py create mode 100644 src/graphnet/models/graphs/graphs.py diff --git a/examples/04_training/02_train_model_without_configs.py b/examples/04_training/02_train_model_without_configs.py index 27e112ce5..4b336d830 100644 --- a/examples/04_training/02_train_model_without_configs.py +++ b/examples/04_training/02_train_model_without_configs.py @@ -13,7 +13,8 @@ from graphnet.data.constants import FEATURES, TRUTH from graphnet.models import StandardModel from graphnet.models.detector.prometheus import Prometheus from graphnet.models.gnn import DynEdge -from graphnet.models.graph_builders import KNNGraphBuilder +from graphnet.models.graphs import KNNGraph +from graphnet.models.graphs.nodes import NodesAsPulses from graphnet.models.task.reconstruction import EnergyReconstruction from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR from graphnet.training.loss_functions import LogCoshLoss @@ -77,11 +78,19 @@ def main( # Log configuration to W&B wandb_logger.experiment.config.update(config) + # Define graph representation + graph_definition = KNNGraph( + detector=Prometheus(), + node_definition=NodesAsPulses(), + nb_nearest_neighbours=8, + ) + ( training_dataloader, validation_dataloader, ) = make_train_validation_dataloader( config["path"], + graph_definition, None, config["pulsemap"], features, @@ -92,11 +101,9 @@ def main( ) # Building model - detector = Prometheus( - 
graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8), - ) + gnn = DynEdge( - nb_inputs=detector.nb_outputs, + nb_inputs=graph_definition.nb_outputs, global_pooling_schemes=["min", "max", "mean", "sum"], ) task = EnergyReconstruction( @@ -106,7 +113,7 @@ def main( transform_prediction_and_target=torch.log10, ) model = StandardModel( - detector=detector, + graph_definition=graph_definition, gnn=gnn, tasks=[task], optimizer_class=Adam, diff --git a/src/graphnet/data/dataset/parquet/__init__.py b/src/graphnet/data/dataset/parquet/__init__.py new file mode 100644 index 000000000..edfc62b4e --- /dev/null +++ b/src/graphnet/data/dataset/parquet/__init__.py @@ -0,0 +1,11 @@ +"""Datasets using parquet backend.""" +# Configuration +from graphnet.utilities.imports import has_torch_package + +if has_torch_package(): + import torch.multiprocessing + from .parquet_dataset import ParquetDataset + + torch.multiprocessing.set_sharing_strategy("file_system") + +del has_torch_package diff --git a/src/graphnet/data/dataset/sqlite/__init__.py b/src/graphnet/data/dataset/sqlite/__init__.py new file mode 100644 index 000000000..74b164e6e --- /dev/null +++ b/src/graphnet/data/dataset/sqlite/__init__.py @@ -0,0 +1,6 @@ +"""Datasets using SQLite backend.""" +from graphnet.utilities.imports import has_torch_package + +if has_torch_package(): + from .sqlite_dataset import SQLiteDataset + from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index 13ab6cbba..ea5066307 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,3 +7,4 @@ and their features. from .graph_definition import GraphDefinition +from .graphs import KNNGraph diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index ec035ad48..0cf7b3254 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -57,6 +57,7 @@ class GraphDefinition(Model): if node_feature_names is None: # Assume all features in Detector is used. node_feature_names = list(self._detector.feature_map().keys()) # type: ignore + print(node_feature_names) self._node_feature_names = node_feature_names if dtype is None: dtype = torch.float @@ -147,6 +148,8 @@ class GraphDefinition(Model): graph=graph, node_feature_names=node_feature_names ) + # Add GraphDefinition Stamp + graph["graph_definition"] = self.__class__.__name__ return graph def _validate_input( diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py new file mode 100644 index 000000000..068869b81 --- /dev/null +++ b/src/graphnet/models/graphs/graphs.py @@ -0,0 +1,38 @@ +"""A module containing different graph representations in GraphNeT.""" + +from typing import List + +from .graph_definition import GraphDefinition +from graphnet.models.detector import Detector +from graphnet.models.graphs.edges import KNNEdges +from graphnet.models.graphs.nodes import NodeDefinition + + +class KNNGraph(GraphDefinition): + """A Graph representation where Edges are drawn to nearest neighbours.""" + + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition, + nb_nearest_neighbours: int = 8, + columns: List[int] = [0, 1, 2], + ): + """Construct k-nn graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + nb_nearest_neighbours: Number of edges for each node. 
Defaults to 8. + columns: node feature columns used for distance calculation + . Defaults to [0, 1, 2]. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition, + edge_definition=KNNEdges( + nb_nearest_neighbours=nb_nearest_neighbours, + columns=columns, + ), + ) diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index a56ee953d..01fa0a574 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -102,10 +102,6 @@ class StandardModel(Model): if self._coarsening: data = self._coarsening(data) assert isinstance(data, Data) - if data.graph_definition != self._graph_definition.__class__.__name__: - self.warn( - f"Model expects {self._graph_definition.__class__.__name__} but is given {data.graph_definition}" - ) x = self._gnn(data) preds = [task(x) for task in self._tasks] return preds diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 75f6855b2..413a04b76 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -16,6 +16,7 @@ from graphnet.data.dataset import SQLiteDataset from graphnet.data.dataset import ParquetDataset from graphnet.models import Model from graphnet.utilities.logging import Logger +from graphnet.models.graphs import GraphDefinition def collate_fn(graphs: List[Data]) -> Batch: @@ -31,6 +32,7 @@ def collate_fn(graphs: List[Data]) -> Batch: def make_dataloader( db: str, pulsemaps: Union[str, List[str]], + graph_definition: Optional[GraphDefinition], features: List[str], truth: List[str], *, @@ -66,6 +68,7 @@ def make_dataloader( loss_weight_table=loss_weight_table, loss_weight_column=loss_weight_column, index_column=index_column, + graph_definition=graph_definition, ) # adds custom labels to dataset @@ -89,6 +92,7 @@ def make_dataloader( # @TODO: Remove in favour of DataLoader{,.from_dataset_config} def make_train_validation_dataloader( db: str, + graph_definition: Optional[GraphDefinition], selection: Optional[List[int]], pulsemaps: Union[str, List[str]], features: List[str], @@ -179,6 +183,7 @@ def make_train_validation_dataloader( loss_weight_table=loss_weight_table, index_column=index_column, labels=labels, + graph_definition=graph_definition, ) training_dataloader = make_dataloader( -- GitLab From fa30515717bd40d27270eb299bbc122f5ebde94b Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe <rahn@outlook.dk> Date: Tue, 18 Jul 2023 13:59:15 +0200 Subject: [PATCH 15/15] polish --- ...dev_lvl7_robustness_muon_neutrino_0000.yml | 13 ++++ configs/datasets/test_data_sqlite.yml | 58 ++++++++++-------- ...ing_classification_example_data_sqlite.yml | 13 ++++ .../training_example_data_parquet.yml | 13 ++++ .../datasets/training_example_data_sqlite.yml | 13 ++++ .../dynedge_PID_classification_example.yml | 36 ++++++----- configs/models/dynedge_energy_example.yml | 44 ------------- ...ynedge_position_custom_scaling_example.yml | 24 +++++--- ...example_direction_reconstruction_model.yml | 20 +++--- .../example_energy_reconstruction_model.yml | 39 +++++++----- ...e_vertex_position_reconstruction_model.yml | 26 +++++--- examples/04_training/01_train_model.py | 1 + .../02_train_model_without_configs.py | 16 ++--- src/graphnet/data/dataset/dataset.py | 61 +++++++++++++++++-- src/graphnet/models/graphs/graphs.py | 15 ++++- src/graphnet/models/utils.py | 2 +- src/graphnet/training/utils.py | 18 +++--- .../utilities/config/dataset_config.py | 27 +++++++- src/graphnet/utilities/config/model_config.py | 2 +- 19 
files changed, 290 insertions(+), 151 deletions(-) delete mode 100644 configs/models/dynedge_energy_example.yml diff --git a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml index 7c4a6c017..345087431 100644 --- a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml +++ b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml @@ -1,4 +1,17 @@ path: /groups/icecube/asogaard/data/example/dev_lvl7_robustness_muon_neutrino_0000.db +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: IceCube86 + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [dom_x, dom_y, dom_z, dom_time, charge, rde, pmt_area] + class_name: KNNGraph pulsemaps: - SRTTWOfflinePulsesDC features: diff --git a/configs/datasets/test_data_sqlite.yml b/configs/datasets/test_data_sqlite.yml index 9ea481d74..349a8593b 100644 --- a/configs/datasets/test_data_sqlite.yml +++ b/configs/datasets/test_data_sqlite.yml @@ -1,26 +1,34 @@ -path: $GRAPHNET/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db -pulsemaps: - - SRTInIcePulses -features: - - dom_x - - dom_y - - dom_z - - dom_time - - charge - - rde - - pmt_area -truth: - - energy - - position_x - - position_y - - position_z - - azimuth - - zenith - - pid - - elasticity - - sim_type - - interaction_type +features: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph index_column: event_no -truth_table: truth -seed: 21 -selection: null \ No newline at end of file +loss_weight_column: null +loss_weight_default_value: null +loss_weight_table: null +node_truth: null +node_truth_table: null +path: /home/iwsatlas1/oersoe/github/graphnet/data/examples/sqlite/prometheus/prometheus-events.db +pulsemaps: total +seed: null +selection: null +string_selection: null +truth: [injection_energy, injection_type, injection_interaction_type, injection_zenith, + injection_azimuth, injection_bjorkenx, injection_bjorkeny, injection_position_x, + injection_position_y, injection_position_z, injection_column_depth, primary_lepton_1_type, + primary_hadron_1_type, primary_lepton_1_position_x, primary_lepton_1_position_y, + primary_lepton_1_position_z, primary_hadron_1_position_x, primary_hadron_1_position_y, + primary_hadron_1_position_z, primary_lepton_1_direction_theta, primary_lepton_1_direction_phi, + primary_hadron_1_direction_theta, primary_hadron_1_direction_phi, primary_lepton_1_energy, + primary_hadron_1_energy, total_energy] +truth_table: mc_truth diff --git a/configs/datasets/training_classification_example_data_sqlite.yml b/configs/datasets/training_classification_example_data_sqlite.yml index b56266de2..5d12c3bbf 100644 --- a/configs/datasets/training_classification_example_data_sqlite.yml +++ b/configs/datasets/training_classification_example_data_sqlite.yml @@ -1,4 +1,17 @@ path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: 
[sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph pulsemaps: - total features: diff --git a/configs/datasets/training_example_data_parquet.yml b/configs/datasets/training_example_data_parquet.yml index 8df8870c7..11c8d7fb0 100644 --- a/configs/datasets/training_example_data_parquet.yml +++ b/configs/datasets/training_example_data_parquet.yml @@ -1,4 +1,17 @@ path: $GRAPHNET/data/examples/parquet/prometheus/prometheus-events.parquet +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph pulsemaps: - total features: diff --git a/configs/datasets/training_example_data_sqlite.yml b/configs/datasets/training_example_data_sqlite.yml index b61074d99..0de880a77 100644 --- a/configs/datasets/training_example_data_sqlite.yml +++ b/configs/datasets/training_example_data_sqlite.yml @@ -1,4 +1,17 @@ path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph pulsemaps: - total features: diff --git a/configs/models/dynedge_PID_classification_example.yml b/configs/models/dynedge_PID_classification_example.yml index a43ea3856..4b2fd0246 100644 --- a/configs/models/dynedge_PID_classification_example.yml +++ b/configs/models/dynedge_PID_classification_example.yml @@ -1,14 +1,4 @@ arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus gnn: ModelConfig: arguments: @@ -21,11 +11,29 @@ arguments: post_processing_layer_sizes: null readout_layer_sizes: null class_name: DynEdge + graph_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + detector: + ModelConfig: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' - optimizer_kwargs: {eps: 1e-03, lr: 1e-05} - scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau' - scheduler_config: {frequency: 1, monitor: val_loss} - scheduler_kwargs: {patience: 1} + optimizer_kwargs: {eps: 0.001, lr: 0.001} + scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR' + scheduler_config: {interval: step} + scheduler_kwargs: + factors: [0.01, 1, 0.01] + milestones: [0, 20.0, 80] tasks: - ModelConfig: arguments: diff --git a/configs/models/dynedge_energy_example.yml b/configs/models/dynedge_energy_example.yml deleted file mode 100644 index 02d647f0c..000000000 --- a/configs/models/dynedge_energy_example.yml +++ /dev/null @@ -1,44 +0,0 @@ -arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: IceCubeDeepCore - gnn: - ModelConfig: - arguments: - add_global_variables_after_pooling: false - dynedge_layer_sizes: 
null - features_subset: null - global_pooling_schemes: [min, max, mean, sum] - nb_inputs: 7 - nb_neighbours: 8 - post_processing_layer_sizes: null - readout_layer_sizes: null - class_name: DynEdge - optimizer_class: '!class torch.optim.adam Adam' - optimizer_kwargs: {eps: 0.001, lr: 1e-05} - scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau' - scheduler_config: {frequency: 1, monitor: val_loss} - scheduler_kwargs: {patience: 5} - tasks: - - ModelConfig: - arguments: - hidden_size: 128 - loss_function: - ModelConfig: - arguments: {} - class_name: LogCoshLoss - loss_weight: null - target_labels: energy - transform_inference: null - transform_prediction_and_target: '!lambda x: torch.log10(x)' - transform_support: null - transform_target: null - class_name: EnergyReconstruction -class_name: StandardModel diff --git a/configs/models/dynedge_position_custom_scaling_example.yml b/configs/models/dynedge_position_custom_scaling_example.yml index e986c1529..195695a8d 100644 --- a/configs/models/dynedge_position_custom_scaling_example.yml +++ b/configs/models/dynedge_position_custom_scaling_example.yml @@ -1,14 +1,24 @@ arguments: - coarsening: null - detector: + graph_definition: ModelConfig: arguments: - graph_builder: + detector: ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: IceCubeDeepCore + arguments: {} + class_name: Prometheus + dtype: null + edge_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + nb_nearest_neighbours: 8 + class_name: KNNEdges + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: null + class_name: KNNGraph gnn: ModelConfig: arguments: diff --git a/configs/models/example_direction_reconstruction_model.yml b/configs/models/example_direction_reconstruction_model.yml index c04974b43..cb1c4d841 100644 --- a/configs/models/example_direction_reconstruction_model.yml +++ b/configs/models/example_direction_reconstruction_model.yml @@ -1,14 +1,20 @@ arguments: - coarsening: null - detector: + graph_definition: ModelConfig: arguments: - graph_builder: + columns: [0, 1, 2] + detector: ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph gnn: ModelConfig: arguments: diff --git a/configs/models/example_energy_reconstruction_model.yml b/configs/models/example_energy_reconstruction_model.yml index 7ef5a9265..827c84748 100644 --- a/configs/models/example_energy_reconstruction_model.yml +++ b/configs/models/example_energy_reconstruction_model.yml @@ -1,14 +1,4 @@ arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus gnn: ModelConfig: arguments: @@ -21,11 +11,29 @@ arguments: post_processing_layer_sizes: null readout_layer_sizes: null class_name: DynEdge + graph_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + detector: + ModelConfig: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: 
[sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' - optimizer_kwargs: {eps: 0.001, lr: 1e-05} - scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau' - scheduler_config: {frequency: 1, monitor: val_loss} - scheduler_kwargs: {patience: 5} + optimizer_kwargs: {eps: 0.001, lr: 0.001} + scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR' + scheduler_config: {interval: step} + scheduler_kwargs: + factors: [0.01, 1, 0.01] + milestones: [0, 20.0, 80] tasks: - ModelConfig: arguments: @@ -35,8 +43,9 @@ arguments: arguments: {} class_name: LogCoshLoss loss_weight: null + prediction_labels: null target_labels: total_energy - transform_inference: null + transform_inference: '!lambda x: torch.pow(10,x)' transform_prediction_and_target: '!lambda x: torch.log10(x)' transform_support: null transform_target: null diff --git a/configs/models/example_vertex_position_reconstruction_model.yml b/configs/models/example_vertex_position_reconstruction_model.yml index 8b9c8709c..0522a1f2d 100644 --- a/configs/models/example_vertex_position_reconstruction_model.yml +++ b/configs/models/example_vertex_position_reconstruction_model.yml @@ -1,14 +1,4 @@ arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus gnn: ModelConfig: arguments: @@ -21,6 +11,22 @@ arguments: post_processing_layer_sizes: null readout_layer_sizes: null class_name: DynEdge + graph_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + detector: + ModelConfig: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: {eps: 0.001, lr: 0.001} scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR' diff --git a/examples/04_training/01_train_model.py b/examples/04_training/01_train_model.py index 2d82bf272..a013326fe 100644 --- a/examples/04_training/01_train_model.py +++ b/examples/04_training/01_train_model.py @@ -72,6 +72,7 @@ def main( # Construct dataloaders dataset_config = DatasetConfig.load(dataset_config_path) + print(dataset_config_path) dataloaders = DataLoader.from_dataset_config( dataset_config, **config.dataloader, diff --git a/examples/04_training/02_train_model_without_configs.py b/examples/04_training/02_train_model_without_configs.py index 4b336d830..6d9c5746e 100644 --- a/examples/04_training/02_train_model_without_configs.py +++ b/examples/04_training/02_train_model_without_configs.py @@ -83,21 +83,22 @@ def main( detector=Prometheus(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, + node_feature_names=features, ) ( training_dataloader, validation_dataloader, ) = make_train_validation_dataloader( - config["path"], - graph_definition, - None, - config["pulsemap"], - features, - truth, + db=config["path"], + graph_definition=graph_definition, + pulsemaps=config["pulsemap"], + features=features, + truth=truth, batch_size=config["batch_size"], num_workers=config["num_workers"], truth_table=truth_table, + selection=None, ) # Building model @@ -110,7 +111,8 @@ def main( hidden_size=gnn.nb_outputs, target_labels=config["target"], loss_function=LogCoshLoss(), - 
transform_prediction_and_target=torch.log10, + transform_prediction_and_target=lambda x: torch.log10(x), + transform_inference=lambda x: torch.pow(10, x), ) model = StandardModel( graph_definition=graph_definition, diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py index 1352c8e20..730f33469 100644 --- a/src/graphnet/data/dataset/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -12,6 +12,7 @@ from typing import ( Tuple, Union, Iterable, + Type, ) import numpy as np @@ -31,11 +32,55 @@ from graphnet.utilities.config import ( from graphnet.utilities.logging import Logger from graphnet.models.graphs import GraphDefinition +from graphnet.utilities.config.parsing import ( + get_all_grapnet_classes, +) + class ColumnMissingException(Exception): """Exception to indicate a missing column in a dataset.""" +def load_module(class_name: str) -> Type: + """Load graphnet module from string name. + + Args: + class_name: name of class + + Returns: + graphnet module. + """ + # Get a lookup for all classes in `graphnet` + import graphnet.data + import graphnet.models + import graphnet.training + + namespace_classes = get_all_grapnet_classes( + graphnet.data, graphnet.models, graphnet.training + ) + return namespace_classes[class_name] + + +def parse_graph_definition(cfg: dict) -> GraphDefinition: + """Construct GraphDefinition from DatasetConfig.""" + assert cfg["graph_definition"] is not None + + args = cfg["graph_definition"]["arguments"] + classes = {} + for arg in args.keys(): + if isinstance(args[arg], dict): + if "class_name" in args[arg].keys(): + classes[arg] = load_module(args[arg]["class_name"])( + **args[arg]["arguments"] + ) + new_cfg = deepcopy(args) + new_cfg.update(classes) + graph_definition = load_module(cfg["graph_definition"]["class_name"])( + **new_cfg + ) + return graph_definition + + class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): """Base Dataset class for reading from any intermediate file format.""" @@ -56,9 +101,13 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): assert isinstance(source, DatasetConfig), ( f"Argument `source` of type ({type(source)}) is not a " - "`DatasetConfig" + "`DatasetConfig`" ) + assert ( + "graph_definition" in source.dict().keys() + ), "`DatasetConfig` incompatible with current GraphNeT version." + # Parse set of `selection``. if isinstance(source.selection, dict): return cls._construct_datasets_from_dict(source) @@ -69,7 +118,10 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): ): return cls._construct_dataset_from_list_of_strings(source) - return source._dataset_class(**source.dict()) + cfg = source.dict() + if cfg["graph_definition"] is not None: + cfg["graph_definition"] = parse_graph_definition(cfg) + return source._dataset_class(**cfg) @classmethod def concatenate( @@ -136,6 +188,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): def __init__( self, path: Union[str, List[str]], + graph_definition: GraphDefinition, pulsemaps: Union[str, List[str]], features: List[str], truth: List[str], @@ -151,7 +204,6 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): loss_weight_column: Optional[str] = None, loss_weight_default_value: Optional[float] = None, seed: Optional[int] = None, - graph_definition: GraphDefinition = None, ): """Construct Dataset. 
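# Illustrative sketch (not part of the patch above): after this change a
# `GraphDefinition` is a required argument of every `Dataset`. Assuming the
# example Prometheus data that ships with GraphNeT, construction might look
# like the following; the database path below is a placeholder.
from graphnet.data.dataset.sqlite import SQLiteDataset
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses

features = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]

graph_definition = KNNGraph(
    detector=Prometheus(),
    node_definition=NodesAsPulses(),
    nb_nearest_neighbours=8,
    node_feature_names=features,
)

dataset = SQLiteDataset(
    path="/path/to/prometheus-events.db",  # placeholder path
    graph_definition=graph_definition,
    pulsemaps="total",
    features=features,
    truth=["total_energy"],
    truth_table="mc_truth",
)

# Each item is a `torch_geometric.data.Data` object built by `KNNGraph`.
graph = dataset[0]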
@@ -221,9 +273,6 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): self._index_column = index_column self._truth_table = truth_table self._loss_weight_default_value = loss_weight_default_value - # self.info( - # f"No GraphDefinition recieved. Defaulting to KNNGraph(nb_neighbours = 8)" - # ) self._graph_definition = graph_definition if node_truth is not None: diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index 068869b81..a9c3ff983 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -1,28 +1,35 @@ """A module containing different graph representations in GraphNeT.""" -from typing import List +from typing import List, Optional +import torch +from graphnet.utilities.config import save_model_config from .graph_definition import GraphDefinition from graphnet.models.detector import Detector -from graphnet.models.graphs.edges import KNNEdges +from graphnet.models.graphs.edges import EdgeDefinition, KNNEdges from graphnet.models.graphs.nodes import NodeDefinition class KNNGraph(GraphDefinition): """A Graph representation where Edges are drawn to nearest neighbours.""" + @save_model_config def __init__( self, detector: Detector, node_definition: NodeDefinition, + node_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = None, nb_nearest_neighbours: int = 8, columns: List[int] = [0, 1, 2], - ): + ) -> None: """Construct k-nn graph representation. Args: detector: Detector that represents your data. node_definition: Definition of nodes in the graph. + node_feature_names: Name of node features. + dtype: data type for node features. nb_nearest_neighbours: Number of edges for each node. Defaults to 8. columns: node feature columns used for distance calculation . Defaults to [0, 1, 2]. 
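# Illustrative sketch (not part of the patch above): with `@save_model_config`
# the constructor arguments are captured on the instance, and nested `Model`s
# are stored as their `ModelConfig`s; this is what produces the
# `graph_definition` blocks added to the YAML configs earlier in this patch.
# A possible round trip, assuming the usual `BaseConfig.dump` helper; the
# output file name is a placeholder.
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses

graph_definition = KNNGraph(
    detector=Prometheus(),
    node_definition=NodesAsPulses(),
    nb_nearest_neighbours=8,
    columns=[0, 1, 2],
)

# `config` mirrors the YAML added to the example configs: class_name
# `KNNGraph` plus nested configs for the detector and node definition.
print(graph_definition.config)
graph_definition.config.dump("knn_graph.yml")  # placeholder file name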
@@ -35,4 +42,6 @@ class KNNGraph(GraphDefinition): nb_nearest_neighbours=nb_nearest_neighbours, columns=columns, ), + dtype=dtype, + node_feature_names=node_feature_names, ) diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py index 942e76caa..e1ef7956c 100644 --- a/src/graphnet/models/utils.py +++ b/src/graphnet/models/utils.py @@ -1,6 +1,6 @@ """Utility functions for `graphnet.models`.""" -from typing import List, Tuple +from typing import List, Tuple, Union from torch_geometric.nn import knn_graph from torch_geometric.data import Batch import torch diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 413a04b76..2578ff9a6 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -126,19 +126,21 @@ def make_train_validation_dataloader( dataset: Dataset if db.endswith(".db"): dataset = SQLiteDataset( - db, - pulsemaps, - features, - truth, + path=db, + graph_definition=graph_definition, + pulsemaps=pulsemaps, + features=features, + truth=truth, truth_table=truth_table, index_column=index_column, ) elif db.endswith(".parquet"): dataset = ParquetDataset( - db, - pulsemaps, - features, - truth, + path=db, + graph_definition=graph_definition, + pulsemaps=pulsemaps, + features=features, + truth=truth, truth_table=truth_table, index_column=index_column, ) diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index fc5beb44b..bb9ed2678 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -2,6 +2,7 @@ from functools import wraps from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -14,6 +15,12 @@ from graphnet.utilities.config.base_config import ( BaseConfig, get_all_argument_values, ) +from graphnet.utilities.config.parsing import traverse_and_apply +from .model_config import ModelConfig + +if TYPE_CHECKING: + from graphnet.models import Model + BACKEND_LOOKUP = { "db": "sqlite", @@ -45,8 +52,8 @@ class DatasetConfig(BaseConfig): loss_weight_table: Optional[str] = None loss_weight_column: Optional[str] = None loss_weight_default_value: Optional[float] = None - seed: Optional[int] = None + graph_definition: Any = None def __init__(self, **data: Any) -> None: """Construct `DataConfig`. @@ -139,8 +146,8 @@ class DatasetConfig(BaseConfig): @property def _dataset_class(self) -> type: """Return the `Dataset` class implementation for this configuration.""" - from graphnet.data.sqlite import SQLiteDataset - from graphnet.data.parquet import ParquetDataset + from graphnet.data.dataset.sqlite import SQLiteDataset + from graphnet.data.dataset.parquet import ParquetDataset dataset_class = { "sqlite": SQLiteDataset, @@ -153,6 +160,17 @@ class DatasetConfig(BaseConfig): def save_dataset_config(init_fn: Callable) -> Callable: """Save the arguments to `__init__` functions as member `DatasetConfig`.""" + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + + if isinstance(obj, Model): + return obj.config + else: + return obj + @wraps(init_fn) def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: """Set `DatasetConfig` after calling `init_fn`.""" @@ -162,6 +180,9 @@ def save_dataset_config(init_fn: Callable) -> Callable: # Get all argument values, including defaults cfg = get_all_argument_values(init_fn, *args, **kwargs) + # Handle nested `Model`s, etc. 
+ cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) + # Add `DatasetConfig` as member variables self._config = DatasetConfig(**cfg) diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index d4733ca68..21e18c104 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -133,7 +133,7 @@ class ModelConfig(BaseConfig): fn_kwargs={"trust": trust}, ) - # Construct model based on arguments + # Construct model based on `class_name` return namespace_classes[self.class_name](**arguments) @classmethod -- GitLab
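Taken together, this series makes the graph definition a first-class, serializable part of both dataset and model configurations. A minimal end-to-end sketch of the intended workflow follows, assuming that the config-loading classmethod shown in the `dataset.py` hunks above is `Dataset.from_config` and that one of the bundled dataset configs is available; the config path is a placeholder.

from graphnet.data.dataset.dataset import Dataset
from graphnet.utilities.config import DatasetConfig

# The nested `graph_definition` block in the YAML is re-instantiated by
# `parse_graph_definition` before the backend-specific `Dataset` subclass
# (SQLite or parquet) is constructed.
config = DatasetConfig.load("configs/datasets/training_example_data_sqlite.yml")
dataset = Dataset.from_config(config)

# Graphs carry a stamp of the `GraphDefinition` that built them, added in
# `GraphDefinition.forward` earlier in the series.
graph = dataset[0]
print(graph["graph_definition"])  # e.g. "KNNGraph"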