diff --git a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml index 7c4a6c0177bbc4940b9dc38d377ff1470d7fa891..345087431df7708338b3d768ddc5851d9bbf4f25 100644 --- a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml +++ b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml @@ -1,4 +1,17 @@ path: /groups/icecube/asogaard/data/example/dev_lvl7_robustness_muon_neutrino_0000.db +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: IceCube86 + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [dom_x, dom_y, dom_z, dom_time, charge, rde, pmt_area] + class_name: KNNGraph pulsemaps: - SRTTWOfflinePulsesDC features: diff --git a/configs/datasets/test_data_sqlite.yml b/configs/datasets/test_data_sqlite.yml index 9ea481d74fc00f0df8ea8ab296925e6d835c41c2..349a8593ba520e83900b772c8dec8796cc5a47ee 100644 --- a/configs/datasets/test_data_sqlite.yml +++ b/configs/datasets/test_data_sqlite.yml @@ -1,26 +1,34 @@ -path: $GRAPHNET/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db -pulsemaps: - - SRTInIcePulses -features: - - dom_x - - dom_y - - dom_z - - dom_time - - charge - - rde - - pmt_area -truth: - - energy - - position_x - - position_y - - position_z - - azimuth - - zenith - - pid - - elasticity - - sim_type - - interaction_type +features: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph index_column: event_no -truth_table: truth -seed: 21 -selection: null \ No newline at end of file +loss_weight_column: null +loss_weight_default_value: null +loss_weight_table: null +node_truth: null +node_truth_table: null +path: /home/iwsatlas1/oersoe/github/graphnet/data/examples/sqlite/prometheus/prometheus-events.db +pulsemaps: total +seed: null +selection: null +string_selection: null +truth: [injection_energy, injection_type, injection_interaction_type, injection_zenith, + injection_azimuth, injection_bjorkenx, injection_bjorkeny, injection_position_x, + injection_position_y, injection_position_z, injection_column_depth, primary_lepton_1_type, + primary_hadron_1_type, primary_lepton_1_position_x, primary_lepton_1_position_y, + primary_lepton_1_position_z, primary_hadron_1_position_x, primary_hadron_1_position_y, + primary_hadron_1_position_z, primary_lepton_1_direction_theta, primary_lepton_1_direction_phi, + primary_hadron_1_direction_theta, primary_hadron_1_direction_phi, primary_lepton_1_energy, + primary_hadron_1_energy, total_energy] +truth_table: mc_truth diff --git a/configs/datasets/training_classification_example_data_sqlite.yml b/configs/datasets/training_classification_example_data_sqlite.yml index b56266de2f2b709b8734044d2be1f48dca5a8446..5d12c3bbfbd5ea044e7d86f2b805a64cfc201f49 100644 --- a/configs/datasets/training_classification_example_data_sqlite.yml +++ b/configs/datasets/training_classification_example_data_sqlite.yml @@ -1,4 +1,17 @@ path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: 
+ arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph pulsemaps: - total features: diff --git a/configs/datasets/training_example_data_parquet.yml b/configs/datasets/training_example_data_parquet.yml index 8df8870c75412d69118394058d941e65ad1817db..11c8d7fb08b0ea5eca54ce1bfde0a0d0698beeac 100644 --- a/configs/datasets/training_example_data_parquet.yml +++ b/configs/datasets/training_example_data_parquet.yml @@ -1,4 +1,17 @@ path: $GRAPHNET/data/examples/parquet/prometheus/prometheus-events.parquet +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph pulsemaps: - total features: diff --git a/configs/datasets/training_example_data_sqlite.yml b/configs/datasets/training_example_data_sqlite.yml index b61074d99738a296d04ff3cb229aedbe884ae223..0de880a77176cfcc049937b03fbf42fb12afb773 100644 --- a/configs/datasets/training_example_data_sqlite.yml +++ b/configs/datasets/training_example_data_sqlite.yml @@ -1,4 +1,17 @@ path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db +graph_definition: + arguments: + columns: [0, 1, 2] + detector: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph pulsemaps: - total features: diff --git a/configs/models/dynedge_PID_classification_example.yml b/configs/models/dynedge_PID_classification_example.yml index a43ea3856fdd930b47fdd42850b83e507b1fdf9c..4b2fd0246bcd2f3952cb9956987f42db78340a1c 100644 --- a/configs/models/dynedge_PID_classification_example.yml +++ b/configs/models/dynedge_PID_classification_example.yml @@ -1,14 +1,4 @@ arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus gnn: ModelConfig: arguments: @@ -21,11 +11,29 @@ arguments: post_processing_layer_sizes: null readout_layer_sizes: null class_name: DynEdge + graph_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + detector: + ModelConfig: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' - optimizer_kwargs: {eps: 1e-03, lr: 1e-05} - scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau' - scheduler_config: {frequency: 1, monitor: val_loss} - scheduler_kwargs: {patience: 1} + optimizer_kwargs: {eps: 0.001, lr: 0.001} + scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR' + scheduler_config: {interval: step} + scheduler_kwargs: + factors: [0.01, 1, 0.01] + milestones: [0, 20.0, 80] tasks: - ModelConfig: arguments: diff --git a/configs/models/dynedge_energy_example.yml b/configs/models/dynedge_energy_example.yml deleted file mode 100644 index 02d647f0c70eaa9d4bdce7184129f56f31809e77..0000000000000000000000000000000000000000 --- a/configs/models/dynedge_energy_example.yml +++ /dev/null @@ -1,44 +0,0 @@ -arguments: - coarsening: 
null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: IceCubeDeepCore - gnn: - ModelConfig: - arguments: - add_global_variables_after_pooling: false - dynedge_layer_sizes: null - features_subset: null - global_pooling_schemes: [min, max, mean, sum] - nb_inputs: 7 - nb_neighbours: 8 - post_processing_layer_sizes: null - readout_layer_sizes: null - class_name: DynEdge - optimizer_class: '!class torch.optim.adam Adam' - optimizer_kwargs: {eps: 0.001, lr: 1e-05} - scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau' - scheduler_config: {frequency: 1, monitor: val_loss} - scheduler_kwargs: {patience: 5} - tasks: - - ModelConfig: - arguments: - hidden_size: 128 - loss_function: - ModelConfig: - arguments: {} - class_name: LogCoshLoss - loss_weight: null - target_labels: energy - transform_inference: null - transform_prediction_and_target: '!lambda x: torch.log10(x)' - transform_support: null - transform_target: null - class_name: EnergyReconstruction -class_name: StandardModel diff --git a/configs/models/dynedge_position_custom_scaling_example.yml b/configs/models/dynedge_position_custom_scaling_example.yml index e986c1529c0b3b46e68c07c11dda017fc8c68cc5..195695a8d6d1b1a2808f1381141885d72119324f 100644 --- a/configs/models/dynedge_position_custom_scaling_example.yml +++ b/configs/models/dynedge_position_custom_scaling_example.yml @@ -1,14 +1,24 @@ arguments: - coarsening: null - detector: + graph_definition: ModelConfig: arguments: - graph_builder: + detector: ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: IceCubeDeepCore + arguments: {} + class_name: Prometheus + dtype: null + edge_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + nb_nearest_neighbours: 8 + class_name: KNNEdges + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: null + class_name: KNNGraph gnn: ModelConfig: arguments: diff --git a/configs/models/example_direction_reconstruction_model.yml b/configs/models/example_direction_reconstruction_model.yml index c04974b437bd563c6d42af1ff88bdb4c37578956..cb1c4d841a28b7727be20ef2a07eb3edb61f14f0 100644 --- a/configs/models/example_direction_reconstruction_model.yml +++ b/configs/models/example_direction_reconstruction_model.yml @@ -1,14 +1,20 @@ arguments: - coarsening: null - detector: + graph_definition: ModelConfig: arguments: - graph_builder: + columns: [0, 1, 2] + detector: ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph gnn: ModelConfig: arguments: diff --git a/configs/models/example_energy_reconstruction_model.yml b/configs/models/example_energy_reconstruction_model.yml index 7ef5a9265b7abf318aa8199dbefbaeeed383a522..827c84748b393a57a6dee27472b79eef5af2eaa9 100644 --- a/configs/models/example_energy_reconstruction_model.yml +++ b/configs/models/example_energy_reconstruction_model.yml @@ -1,14 +1,4 @@ arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - 
class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus gnn: ModelConfig: arguments: @@ -21,11 +11,29 @@ arguments: post_processing_layer_sizes: null readout_layer_sizes: null class_name: DynEdge + graph_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + detector: + ModelConfig: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' - optimizer_kwargs: {eps: 0.001, lr: 1e-05} - scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau' - scheduler_config: {frequency: 1, monitor: val_loss} - scheduler_kwargs: {patience: 5} + optimizer_kwargs: {eps: 0.001, lr: 0.001} + scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR' + scheduler_config: {interval: step} + scheduler_kwargs: + factors: [0.01, 1, 0.01] + milestones: [0, 20.0, 80] tasks: - ModelConfig: arguments: @@ -35,8 +43,9 @@ arguments: arguments: {} class_name: LogCoshLoss loss_weight: null + prediction_labels: null target_labels: total_energy - transform_inference: null + transform_inference: '!lambda x: torch.pow(10,x)' transform_prediction_and_target: '!lambda x: torch.log10(x)' transform_support: null transform_target: null diff --git a/configs/models/example_vertex_position_reconstruction_model.yml b/configs/models/example_vertex_position_reconstruction_model.yml index 8b9c8709c5abd3221ca2077b525894de340c2448..0522a1f2ddc1b914da05b31cd83f4ffdbcb1d7dc 100644 --- a/configs/models/example_vertex_position_reconstruction_model.yml +++ b/configs/models/example_vertex_position_reconstruction_model.yml @@ -1,14 +1,4 @@ arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: Prometheus gnn: ModelConfig: arguments: @@ -21,6 +11,22 @@ arguments: post_processing_layer_sizes: null readout_layer_sizes: null class_name: DynEdge + graph_definition: + ModelConfig: + arguments: + columns: [0, 1, 2] + detector: + ModelConfig: + arguments: {} + class_name: Prometheus + dtype: null + nb_nearest_neighbours: 8 + node_definition: + ModelConfig: + arguments: {} + class_name: NodesAsPulses + node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: {eps: 0.001, lr: 0.001} scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR' diff --git a/examples/02_data/01_read_dataset.py b/examples/02_data/01_read_dataset.py index a9a1f4d01c4aac73081e33da36bf29f6585f1443..302529050856b0a14db32c569f216a9b2dd3b90a 100644 --- a/examples/02_data/01_read_dataset.py +++ b/examples/02_data/01_read_dataset.py @@ -13,8 +13,8 @@ from tqdm import tqdm from graphnet.constants import TEST_PARQUET_DATA, TEST_SQLITE_DATA from graphnet.data.constants import FEATURES, TRUTH from graphnet.data.dataset import Dataset -from graphnet.data.sqlite.sqlite_dataset import SQLiteDataset -from graphnet.data.parquet.parquet_dataset import ParquetDataset +from graphnet.data.dataset import SQLiteDataset +from graphnet.data.dataset import ParquetDataset from graphnet.utilities.argparse import ArgumentParser from graphnet.utilities.logging import Logger diff --git a/examples/04_training/01_train_model.py 
b/examples/04_training/01_train_model.py index 2d82bf2726035ad9fc1ae9d4f37a5377d84830c2..a013326fe5d5ef97484865f6c21ee32399088ee3 100644 --- a/examples/04_training/01_train_model.py +++ b/examples/04_training/01_train_model.py @@ -72,6 +72,7 @@ def main( # Construct dataloaders dataset_config = DatasetConfig.load(dataset_config_path) + print(dataset_config_path) dataloaders = DataLoader.from_dataset_config( dataset_config, **config.dataloader, diff --git a/examples/04_training/02_train_model_without_configs.py b/examples/04_training/02_train_model_without_configs.py index 27e112ce5a609942d184630cf8ad9a1f9d43e452..6d9c5746e87a7a4f9175821db8312e50d3aee5a2 100644 --- a/examples/04_training/02_train_model_without_configs.py +++ b/examples/04_training/02_train_model_without_configs.py @@ -13,7 +13,8 @@ from graphnet.data.constants import FEATURES, TRUTH from graphnet.models import StandardModel from graphnet.models.detector.prometheus import Prometheus from graphnet.models.gnn import DynEdge -from graphnet.models.graph_builders import KNNGraphBuilder +from graphnet.models.graphs import KNNGraph +from graphnet.models.graphs.nodes import NodesAsPulses from graphnet.models.task.reconstruction import EnergyReconstruction from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR from graphnet.training.loss_functions import LogCoshLoss @@ -77,36 +78,44 @@ def main( # Log configuration to W&B wandb_logger.experiment.config.update(config) + # Define graph representation + graph_definition = KNNGraph( + detector=Prometheus(), + node_definition=NodesAsPulses(), + nb_nearest_neighbours=8, + node_feature_names=features, + ) + ( training_dataloader, validation_dataloader, ) = make_train_validation_dataloader( - config["path"], - None, - config["pulsemap"], - features, - truth, + db=config["path"], + graph_definition=graph_definition, + pulsemaps=config["pulsemap"], + features=features, + truth=truth, batch_size=config["batch_size"], num_workers=config["num_workers"], truth_table=truth_table, + selection=None, ) # Building model - detector = Prometheus( - graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8), - ) + gnn = DynEdge( - nb_inputs=detector.nb_outputs, + nb_inputs=graph_definition.nb_outputs, global_pooling_schemes=["min", "max", "mean", "sum"], ) task = EnergyReconstruction( hidden_size=gnn.nb_outputs, target_labels=config["target"], loss_function=LogCoshLoss(), - transform_prediction_and_target=torch.log10, + transform_prediction_and_target=lambda x: torch.log10(x), + transform_inference=lambda x: torch.pow(10, x), ) model = StandardModel( - detector=detector, + graph_definition=graph_definition, gnn=gnn, tasks=[task], optimizer_class=Adam, diff --git a/examples/04_training/03_train_classification_model.py b/examples/04_training/03_train_classification_model.py index b403e4850b96ac8dc07837e9bef15477d24fcd6f..0c537bac03781b304955f42ab95ea9557388060d 100644 --- a/examples/04_training/03_train_classification_model.py +++ b/examples/04_training/03_train_classification_model.py @@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities import rank_zero_only -from graphnet.data.dataset import EnsembleDataset +from graphnet.data.dataset.dataset import EnsembleDataset from graphnet.constants import ( EXAMPLE_OUTPUT_DIR, DATASETS_CONFIG_DIR, diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index 
64cef139d821998062792d95f37e91540e6a9b7c..1eca4f6cd76e0b45ac5c1db5aaed7becd44c4b5e 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -3,14 +3,3 @@ `graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data. """ - -# Configuration -from graphnet.utilities.imports import has_torch_package - -if has_torch_package(): - import torch.multiprocessing - from .dataset import EnsembleDataset - - torch.multiprocessing.set_sharing_strategy("file_system") - -del has_torch_package diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccdd964221a9f509c51d6904e2d7231684dfb5c --- /dev/null +++ b/src/graphnet/data/dataset/__init__.py @@ -0,0 +1,14 @@ +"""Dataset classes for training in GraphNeT.""" +# Configuration +from graphnet.utilities.imports import has_torch_package + +if has_torch_package(): + import torch.multiprocessing + from .dataset import EnsembleDataset, Dataset, ColumnMissingException + from .parquet.parquet_dataset import ParquetDataset + from .sqlite.sqlite_dataset import SQLiteDataset + from .sqlite.sqlite_dataset_perturbed import SQLiteDatasetPerturbed + + torch.multiprocessing.set_sharing_strategy("file_system") + +del has_torch_package diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset/dataset.py similarity index 87% rename from src/graphnet/data/dataset.py rename to src/graphnet/data/dataset/dataset.py index 7300da815154c5550cd4c9add067bd4872e4da47..730f33469fc04a6cd8cdf3c4f89f7f87b6c6e1b8 100644 --- a/src/graphnet/data/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -12,6 +12,7 @@ from typing import ( Tuple, Union, Iterable, + Type, ) import numpy as np @@ -29,12 +30,57 @@ from graphnet.utilities.config import ( save_dataset_config, ) from graphnet.utilities.logging import Logger +from graphnet.models.graphs import GraphDefinition + +from graphnet.utilities.config.parsing import ( + get_all_grapnet_classes, +) class ColumnMissingException(Exception): """Exception to indicate a missing column in a dataset.""" +def load_module(class_name: str) -> Type: + """Load graphnet module from string name. + + Args: + class_name: name of class + + Returns: + graphnet module. 
+ """ + # Get a lookup for all classes in `graphnet` + import graphnet.data + import graphnet.models + import graphnet.training + + namespace_classes = get_all_grapnet_classes( + graphnet.data, graphnet.models, graphnet.training + ) + return namespace_classes[class_name] + + +def parse_graph_definition(cfg: dict) -> GraphDefinition: + """Construct GraphDefinition from DatasetConfig.""" + assert cfg["graph_definition"] is not None + + args = cfg["graph_definition"]["arguments"] + classes = {} + for arg in args.keys(): + if isinstance(args[arg], dict): + if "class_name" in args[arg].keys(): + classes[arg] = load_module(args[arg]["class_name"])( + **args[arg]["arguments"] + ) + new_cfg = deepcopy(args) + new_cfg.update(classes) + graph_definition = load_module(cfg["graph_definition"]["class_name"])( + **new_cfg + ) + return graph_definition + + class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): """Base Dataset class for reading from any intermediate file format.""" @@ -55,9 +101,13 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): assert isinstance(source, DatasetConfig), ( f"Argument `source` of type ({type(source)}) is not a " - "`DatasetConfig" + "`DatasetConfig`" ) + assert ( + "graph_definition" in source.dict().keys() + ), "`DatasetConfig` incompatible with current GraphNeT version." + # Parse set of `selection``. if isinstance(source.selection, dict): return cls._construct_datasets_from_dict(source) @@ -68,7 +118,10 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): ): return cls._construct_dataset_from_list_of_strings(source) - return source._dataset_class(**source.dict()) + cfg = source.dict() + if cfg["graph_definition"] is not None: + cfg["graph_definition"] = parse_graph_definition(cfg) + return source._dataset_class(**cfg) @classmethod def concatenate( @@ -135,6 +188,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): def __init__( self, path: Union[str, List[str]], + graph_definition: GraphDefinition, pulsemaps: Union[str, List[str]], features: List[str], truth: List[str], @@ -195,6 +249,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): subset of events when resolving a string-based selection (e.g., `"10000 random events ~ event_no % 5 > 0"` or `"20% random events ~ event_no % 5 > 0"`). + graph_definition: Method that defines the graph representation. """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -218,6 +273,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): self._index_column = index_column self._truth_table = truth_table self._loss_weight_default_value = loss_weight_default_value + self._graph_definition = graph_definition if node_truth is not None: assert isinstance(node_truth_table, str) @@ -521,10 +577,6 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): ) -> Data: """Create Pytorch Data (i.e. graph) object. - No preprocessing is performed at this stage, just as no node adjancency - is imposed. This means that the `edge_attr` and `edge_weight` - attributes are not set. - Args: features: List of tuples, containing event features. truth: List of tuples, containing truth information. 
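
For reference, the `graph_definition` block that the dataset configs above now carry is resolved by the `parse_graph_definition` helper introduced in this file. A minimal sketch of that round-trip (the `cfg` dict here is a hand-written stand-in for `DatasetConfig.dict()`, mirroring the YAML blocks added to configs/datasets/*.yml, not output captured from the library):

    from graphnet.data.dataset.dataset import parse_graph_definition

    cfg = {
        "graph_definition": {
            "class_name": "KNNGraph",
            "arguments": {
                "columns": [0, 1, 2],
                "nb_nearest_neighbours": 8,
                "dtype": None,
                "node_feature_names": ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
                # Nested dicts carrying a `class_name` are instantiated
                # recursively through `load_module`.
                "detector": {"class_name": "Prometheus", "arguments": {}},
                "node_definition": {"class_name": "NodesAsPulses", "arguments": {}},
            },
        }
    }

    graph_definition = parse_graph_definition(cfg)  # an equivalent `KNNGraph` instance
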
@@ -552,71 +604,33 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): for index, key in enumerate(self._node_truth) } + # Create list of truth dicts with labels + truth_dicts = [labels_dict, truth_dict] + if node_truth is not None: + truth_dicts.append(node_truth_dict) + # Catch cases with no reconstructed pulses if len(features): - data = np.asarray(features)[:, 1:] + node_features = np.asarray(features)[ + :, 1: + ] # first entry is index column else: - data = np.array([]).reshape((0, len(self._features) - 1)) + node_features = np.array([]).reshape((0, len(self._features) - 1)) # Construct graph data object - x = torch.tensor(data, dtype=self._dtype) # pylint: disable=C0103 - n_pulses = torch.tensor(len(x), dtype=torch.int32) - graph = Data(x=x, edge_index=None) - graph.n_pulses = n_pulses - graph.features = self._features[1:] - - # Add loss weight to graph. - if loss_weight is not None and self._loss_weight_column is not None: - # No loss weight was retrieved, i.e., it is missing for the current - # event. - if loss_weight < 0: - if self._loss_weight_default_value is None: - raise ValueError( - "At least one event is missing an entry in " - f"{self._loss_weight_column} " - "but loss_weight_default_value is None." - ) - graph[self._loss_weight_column] = torch.tensor( - self._loss_weight_default_value, dtype=self._dtype - ).reshape(-1, 1) - else: - graph[self._loss_weight_column] = torch.tensor( - loss_weight, dtype=self._dtype - ).reshape(-1, 1) - - # Write attributes, either target labels, truth info or original - # features. - add_these_to_graph = [labels_dict, truth_dict] - if node_truth is not None: - add_these_to_graph.append(node_truth_dict) - for write_dict in add_these_to_graph: - for key, value in write_dict.items(): - try: - graph[key] = torch.tensor(value) - except (TypeError, RuntimeError) as error: - if isinstance(error, TypeError) or (value is None): - # Cannot convert `value` to Tensor due to its data type, - # e.g. `str`. - self.warning( - ( - f"Could not assign `{key}` with type " - f"'{type(value).__name__ if value is not None else 'NoneType'}' as attribute to graph of " - f"event with {self._index_column} == {labels_dict[self._index_column]}" - ) - ) - else: - raise error - # Additionally add original features as (static) attributes - for index, feature in enumerate(graph.features): - if feature not in ["x"]: - graph[feature] = graph.x[:, index].detach() - - # Add custom labels to the graph - for key, fn in self._label_fns.items(): - graph[key] = fn(graph) - - # Add Dataset Path. Useful if multiple datasets are concatenated. 
- graph["dataset_path"] = self._path + assert self._graph_definition is not None + graph = self._graph_definition( + node_features=node_features, + node_feature_names=self._features[ + 1: + ], # first entry is index column + truth_dicts=truth_dicts, + custom_label_functions=self._label_fns, + loss_weight_column=self._loss_weight_column, + loss_weight=loss_weight, + loss_weight_default_value=self._loss_weight_default_value, + data_path=self._path, + ) return graph def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/graphnet/data/dataset/parquet/__init__.py b/src/graphnet/data/dataset/parquet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..edfc62b4e83e0e79d008a116d8bdabb2ad9993e3 --- /dev/null +++ b/src/graphnet/data/dataset/parquet/__init__.py @@ -0,0 +1,11 @@ +"""Datasets using parquet backend.""" +# Configuration +from graphnet.utilities.imports import has_torch_package + +if has_torch_package(): + import torch.multiprocessing + from .parquet_dataset import ParquetDataset + + torch.multiprocessing.set_sharing_strategy("file_system") + +del has_torch_package diff --git a/src/graphnet/data/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py similarity index 98% rename from src/graphnet/data/parquet/parquet_dataset.py rename to src/graphnet/data/dataset/parquet/parquet_dataset.py index 7839bd983e08b0b00dca362eac7d95b73252d323..bb63e18001520003b8e85590ab70df02692cddab 100644 --- a/src/graphnet/data/parquet/parquet_dataset.py +++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import numpy as np import awkward as ak -from graphnet.data.dataset import Dataset, ColumnMissingException +from graphnet.data.dataset.dataset import Dataset, ColumnMissingException class ParquetDataset(Dataset): diff --git a/src/graphnet/data/dataset/sqlite/__init__.py b/src/graphnet/data/dataset/sqlite/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..74b164e6ebba031edc39529d18cf1ec57cec1d3d --- /dev/null +++ b/src/graphnet/data/dataset/sqlite/__init__.py @@ -0,0 +1,6 @@ +"""Datasets using SQLite backend.""" +from graphnet.utilities.imports import has_torch_package + +if has_torch_package(): + from .sqlite_dataset import SQLiteDataset + from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed diff --git a/src/graphnet/data/sqlite/sqlite_dataset.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset.py similarity index 98% rename from src/graphnet/data/sqlite/sqlite_dataset.py rename to src/graphnet/data/dataset/sqlite/sqlite_dataset.py index e61623c46d60c16abe474af3543fc379d59b75ed..a0b06ff6676201a4994a990f24fde2187e4bb464 100644 --- a/src/graphnet/data/sqlite/sqlite_dataset.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional, Tuple, Union import pandas as pd import sqlite3 -from graphnet.data.dataset import Dataset, ColumnMissingException +from graphnet.data.dataset.dataset import Dataset, ColumnMissingException class SQLiteDataset(Dataset): diff --git a/src/graphnet/data/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py similarity index 98% rename from src/graphnet/data/sqlite/sqlite_dataset_perturbed.py rename to src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py index 6fb977ccd7a35b0aff42a782422893cb4afaa75c..755d96b82df1b111e319bf12a86a44a4668a7338 100644 --- 
a/src/graphnet/data/sqlite/sqlite_dataset_perturbed.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py @@ -6,7 +6,7 @@ import numpy as np import torch from torch_geometric.data import Data -from graphnet.data.sqlite.sqlite_dataset import SQLiteDataset +from .sqlite_dataset import SQLiteDataset class SQLiteDatasetPerturbed(SQLiteDataset): diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py index fc0b2f7a02070608ac6c23a0c56863d9ee21917c..616d89c166fa46f0d0d2c59e9e554561434ed5a7 100644 --- a/src/graphnet/data/parquet/__init__.py +++ b/src/graphnet/data/parquet/__init__.py @@ -1,10 +1,2 @@ """Parquet-specific implementation of data classes.""" - -from graphnet.utilities.imports import has_torch_package - from .parquet_dataconverter import ParquetDataConverter - -if has_torch_package(): - from .parquet_dataset import ParquetDataset - -del has_torch_package diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index f632f2be85992b8c28d29c8fc15b7ccbecd45e2c..e4ac554a789cee81ebca6b49a28be05d4777d1a0 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -1,11 +1,4 @@ """SQLite-specific implementation of data classes.""" - -from graphnet.utilities.imports import has_torch_package - from .sqlite_dataconverter import SQLiteDataConverter from .sqlite_utilities import create_table_and_save_to_sql - -if has_torch_package(): - from .sqlite_dataset import SQLiteDataset - from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed from .sqlite_utilities import run_sql_code, save_to_sql diff --git a/src/graphnet/models/detector/__init__.py b/src/graphnet/models/detector/__init__.py index 17fba8fd2ae8f572504ae3d11d0b2309452e30d3..060b7ca038953cad94d0f9eebe5c5f5f66968b82 100644 --- a/src/graphnet/models/detector/__init__.py +++ b/src/graphnet/models/detector/__init__.py @@ -1,3 +1,4 @@ """Detector-specific modules, for data ingestion and standardisation.""" from .icecube import IceCube86, IceCubeDeepCore +from .detector import Detector diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index 4ad3cce482d6d8accd23bad60c226e1ec4a35776..25660c6d7abefa5dca3421b2ecbcbcb58949959c 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -1,115 +1,49 @@ """Base detector-specific `Model` class(es).""" from abc import abstractmethod -from typing import List +from typing import Dict, Callable, List -import torch from torch_geometric.data import Data -from torch_geometric.data.batch import Batch +import torch -from graphnet.models.graph_builders import GraphBuilder from graphnet.models import Model -from graphnet.utilities.config import save_model_config from graphnet.utilities.decorators import final +from graphnet.utilities.config import save_model_config class Detector(Model): """Base class for all detector-specific read-ins in graphnet.""" - @property - @abstractmethod - def features(self) -> List[str]: - """List of features used/assumed by inheriting `Detector` objects.""" - @save_model_config - def __init__( - self, graph_builder: GraphBuilder, scalers: List[dict] = None - ): + def __init__(self) -> None: """Construct `Detector`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) - # Member variables - self._graph_builder = graph_builder - self._scalers = scalers - if self._scalers: - self.info( - ( - "Will use scalers rather than standard 
preprocessing " - f"in {self.__class__.__name__}." - ) - ) + @property + @abstractmethod + def feature_map(self) -> Dict[str, Callable]: + """List of features used/assumed by inheriting `Detector` objects.""" @final - def forward(self, data: Data) -> Data: + def forward( + self, node_features: torch.tensor, node_feature_names: List[str] + ) -> Data: """Pre-process graph `Data` features and build graph adjacency.""" - # Check(s) - assert data.x.size()[1] == self.nb_inputs, ( - "Got graph data with incompatible size, ", - f"{data.x.size()} vs. {self.nb_inputs} expected", - ) - - # Graph-bulding - # @NOTE: `.clone` is necessary to avoid modifying original tensor in-place. - data = self._graph_builder(data).clone() - - if self._scalers: - # # Scaling individual features - # x_numpy = data.x.detach().cpu().numpy() - # for key, scaler in self._scalers.items(): - # ix = self.features.index(key) - # data.x[:,ix] = torch.tensor(scaler.transform(x_numpy[:,ix])).type_as(data.x) - - # Scaling groups of features | @TEMP, probably - x_numpy = data.x.detach().cpu().numpy() - - data.x[:, :3] = torch.tensor( - self._scalers["xyz"].transform(x_numpy[:, :3]) # type: ignore[call-overload] - ).type_as(data.x) + return self._standardize(node_features, node_feature_names) - data.x[:, 3:] = torch.tensor( - self._scalers["features"].transform(x_numpy[:, 3:]) # type: ignore[call-overload] - ).type_as(data.x) - - else: - # Implementation-specific forward pass (e.g. preprocessing) - data = self._forward(data) - - return data - - @abstractmethod - def _forward(self, data: Data) -> Data: - """Syntax like `.forward`, for implentation in inheriting classes.""" - - @property - def nb_inputs(self) -> int: - """Return number of input features.""" - return len(self.features) - - @property - def nb_outputs(self) -> int: - """Return number of output features. - - This the default, but may be overridden by specific inheriting classes. - """ - return self.nb_inputs - - def _validate_features(self, data: Data) -> None: - if isinstance(data, Batch): - # `data.features` is "transposed" and each list element contains only duplicate entries. - - if ( - len(data.features[0]) == data.num_graphs - and len(set(data.features[0])) == 1 - ): - data_features = [features[0] for features in data.features] - - # `data.features` is not "transposed" and each list element - # contains the original features. - else: - data_features = data.features[0] - else: - data_features = data.features - assert ( - data_features == self.features - ), f"Features on Data and Detector differ: {data_features} vs. 
{self.features}"
+    @final
+    def _standardize(
+        self, node_features: torch.tensor, node_feature_names: List[str]
+    ) -> Data:
+        for idx, feature in enumerate(node_feature_names):
+            try:
+                node_features[:, idx] = self.feature_map()[feature](  # type: ignore
+                    node_features[:, idx]
+                )
+            except KeyError as e:
+                self.warning(
+                    f"""No Standardization function found for '{feature}'"""
+                )
+                raise e
+        return node_features
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index e6addb8e37f178bbae3994fdd8ff5cba759d36c9..c2028755ad5207316c320c427772c655775c7fd9 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -1,5 +1,6 @@
 """IceCube-specific `Detector` class(es)."""

+from typing import Dict, Callable
 import torch
 from torch_geometric.data import Data

@@ -15,32 +16,33 @@ from graphnet.models.detector.detector import Detector
 class IceCube86(Detector):
     """`Detector` class for IceCube-86."""

-    # Implementing abstract class attribute
-    features = FEATURES.ICECUBE86
-
-    def _forward(self, data: Data) -> Data:
-        """Ingest data, build graph, and preprocess features.
-
-        Args:
-            data: Input graph data.
-
-        Returns:
-            Connected and preprocessed graph data.
-        """
-        # Check(s)
-        self._validate_features(data)
-
-        # Preprocessing
-        data.x[:, 0] /= 500.0  # dom_x
-        data.x[:, 1] /= 500.0  # dom_y
-        data.x[:, 2] /= 500.0  # dom_z
-        data.x[:, 3] = (data.x[:, 3] - 1.0e04) / 3.0e4  # dom_time
-        data.x[:, 4] = torch.log10(data.x[:, 4]) / 3.0  # charge
-        data.x[:, 5] -= 1.25  # rde
-        data.x[:, 5] /= 0.25
-        data.x[:, 6] /= 0.05  # pmt_area
-
-        return data
+    def feature_map(self) -> Dict[str, Callable]:
+        """Map standardization functions to each dimension."""
+        feature_map = {
+            "dom_x": self._dom_xyz,
+            "dom_y": self._dom_xyz,
+            "dom_z": self._dom_xyz,
+            "dom_time": self._dom_time,
+            "charge": self._charge,
+            "rde": self._rde,
+            "pmt_area": self._pmt_area,
+        }
+        return feature_map
+
+    def _dom_xyz(self, x: torch.tensor) -> torch.tensor:
+        return x / 500.0
+
+    def _dom_time(self, x: torch.tensor) -> torch.tensor:
+        return (x - 1.0e04) / 3.0e4
+
+    def _charge(self, x: torch.tensor) -> torch.tensor:
+        return torch.log10(x) / 3.0
+
+    def _rde(self, x: torch.tensor) -> torch.tensor:
+        return (x - 1.25) / 0.25
+
+    def _pmt_area(self, x: torch.tensor) -> torch.tensor:
+        return x / 0.05


 class IceCubeKaggle(Detector):
diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py
index 2a35886ac014e3ef82bcf98dd38d7a62e504f860..f21f9c41363546393c61783df78994f34e297320 100644
--- a/src/graphnet/models/detector/prometheus.py
+++ b/src/graphnet/models/detector/prometheus.py
@@ -1,6 +1,7 @@
 """Prometheus-specific `Detector` class(es)."""

-from torch_geometric.data import Data
+from typing import Dict, Callable
+import torch

 from graphnet.models.detector.detector import Detector

@@ -8,27 +9,21 @@ from graphnet.models.detector.detector import Detector
 class Prometheus(Detector):
     """`Detector` class for Prometheus prototype."""

-    features = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]
+    def feature_map(self) -> Dict[str, Callable]:
+        """Map standardization functions to each dimension."""
+        feature_map = {
+            "sensor_pos_x": self._sensor_pos_xy,
+            "sensor_pos_y": self._sensor_pos_xy,
+            "sensor_pos_z": self._sensor_pos_z,
+            "t": self._t,
+        }
+        return feature_map

-    def _forward(self, data: Data) -> Data:
-        """Ingest data, build graph, and preprocess features.
+    def _sensor_pos_xy(self, x: torch.tensor) -> torch.tensor:
+        return x / 100

-        Args:
-            data: Input graph data.
+    def _sensor_pos_z(self, x: torch.tensor) -> torch.tensor:
+        return (x + 350) / 100

-        Returns:
-            Connected and preprocessed graph data.
-        """
-        # Check(s)
-        self._validate_features(data)
-
-        # Preprocessing
-        data.x[:, 0] /= 100.0  # dom_x
-        data.x[:, 1] /= 100.0  # dom_y
-        data.x[:, 2] += 350.0  # dom_z
-        data.x[:, 2] /= 100.0
-        data.x[:, 3] /= 1.05e04  # dom_time
-        data.x[:, 3] -= 1.0
-        data.x[:, 3] *= 20.0
-
-        return data
+    def _t(self, x: torch.tensor) -> torch.tensor:
+        return ((x / 1.05e04) - 1.0) * 20.0
diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea5066307d0d9f5525f8b6b93895546b9003aae3
--- /dev/null
+++ b/src/graphnet/models/graphs/__init__.py
@@ -0,0 +1,10 @@
+"""Modules for constructing graphs.
+
+´GraphDefinition´ defines the nodes and their features, and contains general
+graph-manipulation. ´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.
+"""
+
+
+from .graph_definition import GraphDefinition
+from .graphs import KNNGraph
diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da8baa7c2adb6556fb08525b6f1a6a2ddc1d181
--- /dev/null
+++ b/src/graphnet/models/graphs/edges/__init__.py
@@ -0,0 +1,7 @@
+"""Modules for constructing graphs.
+
+´GraphDefinition´ defines the nodes and their features, and contains general
+graph-manipulation. ´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.
+"""
+from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges
diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py
new file mode 100644
index 0000000000000000000000000000000000000000..28507058bb5a4353025bd68cd7e6fc7d93207582
--- /dev/null
+++ b/src/graphnet/models/graphs/edges/edges.py
@@ -0,0 +1,188 @@
+"""Class(es) for building/connecting graphs."""
+
+from typing import List
+from abc import abstractmethod, ABC
+
+import torch
+from torch_geometric.nn import knn_graph, radius_graph
+from torch_geometric.data import Data
+
+from graphnet.utilities.config import save_model_config
+from graphnet.models.utils import calculate_distance_matrix
+from graphnet.models import Model
+
+
+class EdgeDefinition(Model):  # pylint: disable=too-few-public-methods
+    """Base class for graph building."""
+
+    def forward(self, graph: Data) -> Data:
+        """Construct edges.
+
+        Calls the problem-specific ´_construct_edges´ implementation.
+
+        Args:
+            graph: a graph without edges
+
+        Returns:
+            graph: a graph with edges
+        """
+        if graph.edge_index is not None:
+            self.warnonce(
+                "GraphBuilder received graph with pre-existing "
+                "structure. Will overwrite."
+            )
+        return self._construct_edges(graph)
+
+    @abstractmethod
+    def _construct_edges(self, graph: Data) -> Data:
+        """Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.
+
+        Args:
+            graph: graph without edges
+
+        Returns:
+            graph: graph with edges assigned.
+        """
+
+
+class KNNEdges(EdgeDefinition):  # pylint: disable=too-few-public-methods
+    """Builds edges from the k-nearest neighbours."""
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_nearest_neighbours: int,
+        columns: List[int] = [0, 1, 2],
+    ):
+        """K-NN Edge definition.
+ + Will connect nodes together with their ´nb_nearest_neighbours´ + nearest neighbours in the feature space given by ´columns´. + + Args: + nb_nearest_neighbours: number of neighbours. + columns: Node features to use for distance calculation. + Defaults to [0,1,2]. + """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Member variable(s) + self._nb_nearest_neighbours = nb_nearest_neighbours + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Define K-NN edges.""" + graph.edge_index = knn_graph( + graph.x[:, self._columns], + self._nb_nearest_neighbours, + graph.batch, + ).to(self.device) + + return graph + + +class RadialEdges(EdgeDefinition): + """Builds graph from a sphere of chosen radius centred at each node.""" + + @save_model_config + def __init__( + self, + radius: float, + columns: List[int] = [0, 1, 2], + ): + """Radial edges. + + Connects each node to other nodes that are within a sphere of + radius ´r´ centered at the node. The feature space of ´r´ is defined + by ´columns´ + + Args: + radius: radius of sphere + columns: columns of the node feature matrix used. + Defaults to [0,1,2]. + """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Member variable(s) + self._radius = radius + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Define radial edges.""" + graph.edge_index = radius_graph( + graph.x[:, self._columns], + self._radius, + graph.batch, + ).to(self.device) + + return graph + + +class EuclideanEdges(EdgeDefinition): # pylint: disable=too-few-public-methods + """Builds edges according to Euclidean distance between nodes. + + See https://arxiv.org/pdf/1809.06166.pdf. + """ + + @save_model_config + def __init__( + self, + sigma: float, + threshold: float = 0.0, + columns: List[int] = None, + ): + """Construct `EuclideanEdges`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Check(s) + if columns is None: + columns = [0, 1, 2] + + # Member variable(s) + self._sigma = sigma + self._threshold = threshold + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Forward pass.""" + # Constructs the adjacency matrix from the raw, DOM-level data and + # returns this matrix + if graph.edge_index is not None: + self.info( + "WARNING: GraphBuilder received graph with pre-existing " + "structure. Will overwrite." 
+            )

+        xyz_coords = graph.x[:, self._columns]

+        # Construct block-diagonal matrix indicating whether pulses belong to
+        # the same event in the batch
+        batch_mask = graph.batch.unsqueeze(dim=0) == graph.batch.unsqueeze(
+            dim=1
+        )

+        distance_matrix = calculate_distance_matrix(xyz_coords)
+        affinity_matrix = torch.exp(
+            -0.5 * distance_matrix**2 / self._sigma**2
+        )

+        # Use softmax to normalise all adjacencies to one for each node
+        exp_row_sums = torch.exp(affinity_matrix).sum(axis=1)
+        weighted_adj_matrix = torch.exp(
+            affinity_matrix
+        ) / exp_row_sums.unsqueeze(dim=1)

+        # Only include edges with weights that exceed the chosen threshold (and
+        # are part of the same event)
+        sources, targets = torch.where(
+            (weighted_adj_matrix > self._threshold) & (batch_mask)
+        )
+        edge_weights = weighted_adj_matrix[sources, targets]

+        graph.edge_index = torch.stack((sources, targets))
+        graph.edge_weight = edge_weights

+        return graph
diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf7b32542efb37abe08d229b8dd16e5fff7825d
--- /dev/null
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -0,0 +1,269 @@
+"""Modules for defining graphs.
+
+These are self-contained graph definitions that hold all the graph-altering
+code in graphnet. These modules define what the GNN sees as input and can be
+passed to dataloaders during training and deployment.
+"""
+
+
+from typing import Any, List, Optional, Dict, Callable
+import torch
+from torch_geometric.data import Data
+import numpy as np
+
+from graphnet.utilities.config import save_model_config
+
+from graphnet.models.detector import Detector
+from .edges import EdgeDefinition
+from .nodes import NodeDefinition
+from graphnet.models import Model
+
+
+class GraphDefinition(Model):
+    """An Abstract class to create graph definitions from."""
+
+    @save_model_config
+    def __init__(
+        self,
+        detector: Detector,
+        node_definition: NodeDefinition,
+        edge_definition: Optional[EdgeDefinition] = None,
+        node_feature_names: Optional[List[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        """Construct ´GraphDefinition´.
+
+        The ´detector´ holds ´Detector´-specific code, e.g. scaling/
+        standardization and geometry tables.
+
+        ´node_definition´ defines the nodes in the graph.
+
+        ´edge_definition´ defines the connectivity of the nodes in the graph.
+
+        Args:
+            detector: The corresponding ´Detector´ representing the data.
+            node_definition: Definition of nodes.
+            edge_definition: Definition of edges. Defaults to None.
+            node_feature_names: Names of node feature columns. Defaults to None
+            dtype: data type used for node features. e.g. ´torch.float´
+        """
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
+        # Member Variables
+        self._detector = detector
+        self._edge_definition = edge_definition
+        self._node_definition = node_definition
+        if node_feature_names is None:
+            # Assume all features in Detector are used.
+            node_feature_names = list(self._detector.feature_map().keys())  # type: ignore
+        self._node_feature_names = node_feature_names
+        if dtype is None:
+            dtype = torch.float
+        self._dtype = dtype
+
+        # Set Input / Output dimensions
+        self._node_definition.set_number_of_inputs(
+            node_feature_names=node_feature_names
+        )
+        self.nb_inputs = len(self._node_feature_names)
+        self.nb_outputs = self._node_definition.nb_outputs
+
+    def forward(  # type: ignore
+        self,
+        node_features: np.array,
+        node_feature_names: List[str],
+        truth_dicts: Optional[List[Dict[str, Any]]] = None,
+        custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
+        loss_weight_column: Optional[str] = None,
+        loss_weight: Optional[float] = None,
+        loss_weight_default_value: Optional[float] = None,
+        data_path: Optional[str] = None,
+    ) -> Data:
+        """Construct graph as ´Data´ object.
+
+        Args:
+            node_features: node features for graph. Shape ´[num_nodes, d]´
+            node_feature_names: name of each column. Shape ´[,d]´.
+            truth_dicts: Dictionary containing truth labels.
+            custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
+            loss_weight_column: Name of column that holds loss weight. Defaults to None.
+            loss_weight: Loss weight associated with event. Defaults to None.
+            loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.
+            data_path: Path to dataset data files. Defaults to None.
+
+        Returns:
+            graph
+        """
+        # Checks
+        self._validate_input(
+            node_features=node_features, node_feature_names=node_feature_names
+        )
+
+        # Transform to pytorch tensor
+        node_features = torch.tensor(node_features, dtype=self._dtype)
+
+        # Standardize / Scale node features
+        node_features = self._detector(node_features, node_feature_names)
+
+        # Create graph
+        graph = self._node_definition(node_features)
+
+        # Attach number of pulses as static attribute.
+        graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32)
+
+        # Assign edges
+        if self._edge_definition is not None:
+            graph = self._edge_definition(graph)
+        else:
+            self.warnonce(
+                "No EdgeDefinition provided. Graphs will not have edges defined!"
+            )
+
+        # Attach data path - useful for Ensemble datasets.
+        if data_path is not None:
+            graph["dataset_path"] = data_path
+
+        # Attach loss weights if they exist
+        graph = self._add_loss_weights(
+            graph=graph,
+            loss_weight=loss_weight,
+            loss_weight_column=loss_weight_column,
+            loss_weight_default_value=loss_weight_default_value,
+        )
+
+        # Attach default truth labels and node truths
+        if truth_dicts is not None:
+            graph = self._add_truth(graph=graph, truth_dicts=truth_dicts)
+
+        # Attach custom truth labels
+        if custom_label_functions is not None:
+            graph = self._add_custom_labels(
+                graph=graph, custom_label_functions=custom_label_functions
+            )
+
+        # Attach node features as separate fields. MAY NOT CONTAIN 'x'
+        graph = self._add_features_individually(
+            graph=graph, node_feature_names=node_feature_names
+        )
+
+        # Add GraphDefinition Stamp
+        graph["graph_definition"] = self.__class__.__name__
+        return graph
+
+    def _validate_input(
+        self, node_features: np.array, node_feature_names: List[str]
+    ) -> None:
+
+        # node feature matrix dimension check
+        assert node_features.shape[1] == len(node_feature_names)
+
+        # Check that the features provided as input are the same as those the
+        # ´Graph´ was instantiated with.
+        assert len(node_feature_names) == len(
+            self._node_feature_names
+        ), f"""Input features ({node_feature_names}) is not what
+        {self.__class__.__name__} was instantiated
+        with ({self._node_feature_names})"""
+        for idx in range(len(node_feature_names)):
+            assert (
+                node_feature_names[idx] == self._node_feature_names[idx]
+            ), """ Order of node features is not the same."""
+
+    def _add_loss_weights(
+        self,
+        graph: Data,
+        loss_weight_column: Optional[str] = None,
+        loss_weight: Optional[float] = None,
+        loss_weight_default_value: Optional[float] = None,
+    ) -> Data:
+        """Attempt to store a loss weight in the graph for use during training.
+
+        I.e. `graph[loss_weight_column] = loss_weight`
+
+        Args:
+            loss_weight: The non-negative weight to be stored.
+            graph: Data object representing the event.
+            loss_weight_column: The name under which the weight is stored in
+                 the graph.
+            loss_weight_default_value: The default value used if
+                none was retrieved.
+
+        Returns:
+            A graph with loss weight added, if available.
+        """
+        # Add loss weight to graph.
+        if loss_weight is not None and loss_weight_column is not None:
+            # No loss weight was retrieved, i.e., it is missing for the current
+            # event.
+            if loss_weight < 0:
+                if loss_weight_default_value is None:
+                    raise ValueError(
+                        "At least one event is missing an entry in "
+                        f"{loss_weight_column} "
+                        "but loss_weight_default_value is None."
+                    )
+                graph[loss_weight_column] = torch.tensor(
+                    loss_weight_default_value, dtype=self._dtype
+                ).reshape(-1, 1)
+            else:
+                graph[loss_weight_column] = torch.tensor(
+                    loss_weight, dtype=self._dtype
+                ).reshape(-1, 1)
+        return graph
+
+    def _add_truth(
+        self, graph: Data, truth_dicts: List[Dict[str, Any]]
+    ) -> Data:
+        """Add truth labels from ´truth_dicts´ to ´graph´.
+
+        I.e. ´graph[key] = truth_dict[key]´
+
+
+        Args:
+            graph: graph where the label will be stored
+            truth_dicts: dictionary containing the labels
+
+        Returns:
+            graph with labels
+        """
+        # Write attributes, either target labels, truth info or original
+        # features.
+        for truth_dict in truth_dicts:
+            for key, value in truth_dict.items():
+                try:
+                    graph[key] = torch.tensor(value)
+                except TypeError:
+                    # Cannot convert `value` to Tensor due to its data type,
+                    # e.g. `str`.
+                    self.debug(
+                        (
+                            f"Could not assign `{key}` with type "
+                            f"'{type(value).__name__}' as attribute to graph."
+                        )
+                    )
+        return graph
+
+    def _add_features_individually(
+        self,
+        graph: Data,
+        node_feature_names: List[str],
+    ) -> Data:
+        # Additionally add original features as (static) attributes
+        graph.features = node_feature_names
+        for index, feature in enumerate(node_feature_names):
+            if feature not in ["x"]:  # reserved for node features.
+                graph[feature] = graph.x[:, index].detach()
+            else:
+                self.warnonce(
+                    """Cannot assign graph['x']. This field is reserved for node features.
Please rename your input feature.""" + ) + return graph + + def _add_custom_labels( + self, + graph: Data, + custom_label_functions: Dict[str, Callable[..., Any]], + ) -> Data: + # Add custom labels to the graph + for key, fn in custom_label_functions.items(): + graph[key] = fn(graph) + return graph diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c3ff983f655f12d7ed5d676d49bb2ef71dc9a3 --- /dev/null +++ b/src/graphnet/models/graphs/graphs.py @@ -0,0 +1,47 @@ +"""A module containing different graph representations in GraphNeT.""" + +from typing import List, Optional +import torch + +from graphnet.utilities.config import save_model_config +from .graph_definition import GraphDefinition +from graphnet.models.detector import Detector +from graphnet.models.graphs.edges import EdgeDefinition, KNNEdges +from graphnet.models.graphs.nodes import NodeDefinition + + +class KNNGraph(GraphDefinition): + """A Graph representation where Edges are drawn to nearest neighbours.""" + + @save_model_config + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition, + node_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = None, + nb_nearest_neighbours: int = 8, + columns: List[int] = [0, 1, 2], + ) -> None: + """Construct k-nn graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + node_feature_names: Name of node features. + dtype: data type for node features. + nb_nearest_neighbours: Number of edges for each node. Defaults to 8. + columns: node feature columns used for distance calculation + . Defaults to [0, 1, 2]. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition, + edge_definition=KNNEdges( + nb_nearest_neighbours=nb_nearest_neighbours, + columns=columns, + ), + dtype=dtype, + node_feature_names=node_feature_names, + ) diff --git a/src/graphnet/models/graphs/nodes/__init__.py b/src/graphnet/models/graphs/nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05194b61acfdf76e803606fb29e62b540619edac --- /dev/null +++ b/src/graphnet/models/graphs/nodes/__init__.py @@ -0,0 +1,8 @@ +"""Modules for constructing graphs. + +´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features. 
+""" + +from .nodes import NodeDefinition, NodesAsPulses diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..afe0dcfce2be1e012984eb8e77be24d69d0a1b9a --- /dev/null +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -0,0 +1,74 @@ +"""Class(es) for building/connecting graphs.""" + +from typing import List +from abc import abstractmethod + +import torch +from torch_geometric.nn import knn_graph, radius_graph +from torch_geometric.data import Data +import numpy as np + +from graphnet.utilities.decorators import final +from graphnet.utilities.config import save_model_config +from graphnet.models import Model + + +class NodeDefinition(Model): # pylint: disable=too-few-public-methods + """Base class for graph building.""" + + @save_model_config + def __init__(self) -> None: + """Construct `Detector`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + @final + def forward(self, x: torch.tensor) -> Data: + """Construct nodes from raw node features. + + Args: + x: standardized node features with shape ´[num_pulses, d]´, + where ´d´ is the number of node features. + + Returns: + graph: a graph without edges + """ + graph = self._construct_nodes(x) + return graph + + @property + def nb_outputs(self) -> int: + """Return number of output features. + + This the default, but may be overridden by specific inheriting classes. + """ + return self.nb_inputs + + @final + def set_number_of_inputs(self, node_feature_names: List[str]) -> None: + """Return number of inputs expected by node definition. + + Args: + node_feature_names: name of each node feature column. + """ + assert isinstance(node_feature_names, list) + self.nb_inputs = len(node_feature_names) + + @abstractmethod + def _construct_nodes(self, x: torch.tensor) -> Data: + """Construct nodes from raw node features ´x´. + + Args: + x: standardized node features with shape ´[num_pulses, d]´, + where ´d´ is the number of node features. + + Returns: + graph: graph without edges. 
diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py
index 41b70bb269f49ffede804dc7b542be27676787bb..01fa0a574bf41ff05c7bfef51761cfeb16b042ed 100644
--- a/src/graphnet/models/standard_model.py
+++ b/src/graphnet/models/standard_model.py
@@ -12,7 +12,7 @@ import pandas as pd
 
 from graphnet.models.coarsening import Coarsening
 from graphnet.utilities.config import save_model_config
-from graphnet.models.detector.detector import Detector
+from graphnet.models.graphs import GraphDefinition
 from graphnet.models.gnn.gnn import GNN
 from graphnet.models.model import Model
 from graphnet.models.task import Task
@@ -29,7 +29,7 @@ class StandardModel(Model):
     def __init__(
         self,
         *,
-        detector: Detector,
+        graph_definition: GraphDefinition,
         gnn: GNN,
         tasks: Union[Task, List[Task]],
         coarsening: Optional[Coarsening] = None,
@@ -48,12 +48,12 @@ class StandardModel(Model):
             tasks = [tasks]
         assert isinstance(tasks, (list, tuple))
         assert all(isinstance(task, Task) for task in tasks)
-        assert isinstance(detector, Detector)
+        assert isinstance(graph_definition, GraphDefinition)
         assert isinstance(gnn, GNN)
         assert coarsening is None or isinstance(coarsening, Coarsening)
 
         # Member variable(s)
-        self._detector = detector
+        self._graph_definition = graph_definition
         self._gnn = gnn
         self._tasks = ModuleList(tasks)
         self._coarsening = coarsening
@@ -101,7 +101,7 @@ class StandardModel(Model):
         """Forward pass, chaining model components."""
         if self._coarsening:
             data = self._coarsening(data)
-        data = self._detector(data)
+        assert isinstance(data, Data)
         x = self._gnn(data)
         preds = [task(x) for task in self._tasks]
         return preds
diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py
index 942e76caa81898858263019a37046430685fe824..e1ef7956c150da86992fa23e8f44c4a38bca3633 100644
--- a/src/graphnet/models/utils.py
+++ b/src/graphnet/models/utils.py
@@ -1,6 +1,6 @@
 """Utility functions for `graphnet.models`."""
 
-from typing import List, Tuple
+from typing import List, Tuple, Union
 from torch_geometric.nn import knn_graph
 from torch_geometric.data import Batch
 import torch
diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py
index 52b7634e85817953c601df65363dd5280c08567b..2578ff9a65875e6a4937d95273feff46922f4170 100644
--- a/src/graphnet/training/utils.py
+++ b/src/graphnet/training/utils.py
@@ -12,10 +12,11 @@
 from torch.utils.data import DataLoader
 from torch_geometric.data import Batch, Data
 
 from graphnet.data.dataset import Dataset
-from graphnet.data.sqlite import SQLiteDataset
-from graphnet.data.parquet import ParquetDataset
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
 from graphnet.models import Model
 from graphnet.utilities.logging import Logger
+from graphnet.models.graphs import GraphDefinition
 
 
 def collate_fn(graphs: List[Data]) -> Batch:
@@ -31,6 +32,7 @@
 def make_dataloader(
     db: str,
     pulsemaps: Union[str, List[str]],
+    graph_definition: Optional[GraphDefinition],
     features: List[str],
     truth: List[str],
     *,
@@ -66,6 +68,7 @@
         loss_weight_table=loss_weight_table,
         loss_weight_column=loss_weight_column,
         index_column=index_column,
+        graph_definition=graph_definition,
     )
 
     # adds custom labels to dataset
@@ -89,6 +92,7 @@
 # @TODO: Remove in favour of DataLoader{,.from_dataset_config}
 def make_train_validation_dataloader(
     db: str,
+    graph_definition: Optional[GraphDefinition],
     selection: Optional[List[int]],
     pulsemaps: Union[str, List[str]],
     features: List[str],
@@ -122,19 +126,21 @@
     dataset: Dataset
     if db.endswith(".db"):
         dataset = SQLiteDataset(
-            db,
-            pulsemaps,
-            features,
-            truth,
+            path=db,
+            graph_definition=graph_definition,
+            pulsemaps=pulsemaps,
+            features=features,
+            truth=truth,
             truth_table=truth_table,
             index_column=index_column,
         )
     elif db.endswith(".parquet"):
         dataset = ParquetDataset(
-            db,
-            pulsemaps,
-            features,
-            truth,
+            path=db,
+            graph_definition=graph_definition,
+            pulsemaps=pulsemaps,
+            features=features,
+            truth=truth,
             truth_table=truth_table,
             index_column=index_column,
         )
@@ -179,6 +185,7 @@
         loss_weight_table=loss_weight_table,
         index_column=index_column,
         labels=labels,
+        graph_definition=graph_definition,
     )
 
     training_dataloader = make_dataloader(
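# A hedged sketch of constructing the refactored `StandardModel`, which now
# takes a `GraphDefinition` instead of a `Detector`. `DynEdge`,
# `EnergyReconstruction` and `LogCoshLoss` are assumed from the wider library
# and stand in for any GNN, Task and loss function; `graph_definition` is,
# e.g., the KNNGraph sketched earlier.
from graphnet.models import StandardModel
from graphnet.models.gnn import DynEdge
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.training.loss_functions import LogCoshLoss

gnn = DynEdge(nb_inputs=4, global_pooling_schemes=["min", "max", "mean"])
task = EnergyReconstruction(
    hidden_size=gnn.nb_outputs,
    target_labels=["total_energy"],  # hypothetical target
    loss_function=LogCoshLoss(),
)
model = StandardModel(
    graph_definition=graph_definition,
    gnn=gnn,
    tasks=[task],
)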
diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py
index fc5beb44b01d429f173ce3e9b8d28f12d3ceaef4..bb9ed26785d498a9882d5f84077d87bf88ebc800 100644
--- a/src/graphnet/utilities/config/dataset_config.py
+++ b/src/graphnet/utilities/config/dataset_config.py
@@ -2,6 +2,7 @@
 
 from functools import wraps
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -14,6 +15,12 @@
 from graphnet.utilities.config.base_config import (
     BaseConfig,
     get_all_argument_values,
 )
+from graphnet.utilities.config.parsing import traverse_and_apply
+from .model_config import ModelConfig
+
+if TYPE_CHECKING:
+    from graphnet.models import Model
+
 
 BACKEND_LOOKUP = {
     "db": "sqlite",
@@ -45,8 +52,8 @@ class DatasetConfig(BaseConfig):
     loss_weight_table: Optional[str] = None
     loss_weight_column: Optional[str] = None
     loss_weight_default_value: Optional[float] = None
-
     seed: Optional[int] = None
+    graph_definition: Any = None
 
     def __init__(self, **data: Any) -> None:
         """Construct `DataConfig`.
@@ -139,8 +146,8 @@
     @property
     def _dataset_class(self) -> type:
         """Return the `Dataset` class implementation for this configuration."""
-        from graphnet.data.sqlite import SQLiteDataset
-        from graphnet.data.parquet import ParquetDataset
+        from graphnet.data.dataset.sqlite import SQLiteDataset
+        from graphnet.data.dataset.parquet import ParquetDataset
 
         dataset_class = {
             "sqlite": SQLiteDataset,
@@ -153,6 +160,17 @@
 
 def save_dataset_config(init_fn: Callable) -> Callable:
     """Save the arguments to `__init__` functions as member `DatasetConfig`."""
+
+    def _replace_model_instance_with_config(
+        obj: Union["Model", Any]
+    ) -> Union[ModelConfig, Any]:
+        """Replace `Model` instances in `obj` with their `ModelConfig`."""
+        from graphnet.models import Model
+
+        if isinstance(obj, Model):
+            return obj.config
+        else:
+            return obj
+
     @wraps(init_fn)
     def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
         """Set `DatasetConfig` after calling `init_fn`."""
@@ -162,6 +180,9 @@
         # Get all argument values, including defaults
         cfg = get_all_argument_values(init_fn, *args, **kwargs)
 
+        # Handle nested `Model`s, etc.
+        cfg = traverse_and_apply(cfg, _replace_model_instance_with_config)
+
         # Add `DatasetConfig` as member variables
         self._config = DatasetConfig(**cfg)
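# A hedged sketch of what the `save_dataset_config` change above enables: a
# nested `Model` such as a `GraphDefinition` passed to a `Dataset` is captured
# as its `ModelConfig`, so the dataset config can round-trip through YAML.
# The database path, table names and `config.dump` usage are assumptions.
from graphnet.data.dataset import SQLiteDataset

dataset = SQLiteDataset(
    path="/path/to/events.db",  # hypothetical database
    pulsemaps="total",
    features=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
    truth=["total_energy"],
    truth_table="mc_truth",
    graph_definition=graph_definition,  # a Model; stored as its ModelConfig
)
dataset.config.dump("dataset.yml")  # graph_definition serialised, not pickled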
diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py
index d4733ca6863a27979bcc55e2a5812bd143409ba1..21e18c104fa321fec027fdc91137d1ebdf0635e3 100644
--- a/src/graphnet/utilities/config/model_config.py
+++ b/src/graphnet/utilities/config/model_config.py
@@ -133,7 +133,7 @@ class ModelConfig(BaseConfig):
             fn_kwargs={"trust": trust},
         )
 
-        # Construct model based on arguments
+        # Construct model based on the parsed arguments
         return namespace_classes[self.class_name](**arguments)
 
     @classmethod
diff --git a/tests/utilities/test_dataset_config.py b/tests/utilities/test_dataset_config.py
index a09d0e5ada778ffe4132a9fb02132134dc9c83c1..88f3b3f1cb3bc6f24ab56c740d275b0192478fbe 100644
--- a/tests/utilities/test_dataset_config.py
+++ b/tests/utilities/test_dataset_config.py
@@ -13,8 +13,8 @@ import graphnet
 import graphnet.constants
 from graphnet.data.constants import FEATURES, TRUTH
 from graphnet.data.dataset import Dataset
-from graphnet.data.parquet import ParquetDataset
-from graphnet.data.sqlite import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
+from graphnet.data.dataset import SQLiteDataset
 from graphnet.utilities.config import DatasetConfig
 
 CONFIG_PATHS = {
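# With the relocation shown above, the dataset classes are imported from
# `graphnet.data.dataset` rather than `graphnet.data.sqlite` /
# `graphnet.data.parquet`; downstream code would need the same one-line change.
from graphnet.data.dataset import Dataset, ParquetDataset, SQLiteDataset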