diff --git a/src/graphnet/data/sqlite_dataset.py b/src/graphnet/data/sqlite_dataset.py index 3d55ceb6bfedf7ecc09ef77ca4e882bd9474a9a0..d1c23ddd2eb602fcb9d59324e8b0153468eff93c 100644 --- a/src/graphnet/data/sqlite_dataset.py +++ b/src/graphnet/data/sqlite_dataset.py @@ -19,6 +19,7 @@ class SQLiteDataset(torch.utils.data.Dataset): truth_table: str = 'truth', selection: Optional[List[int]] = None, dtype: torch.dtype = torch.float32, + node_representation = 'pulse', ): # Check(s) @@ -37,6 +38,8 @@ class SQLiteDataset(torch.utils.data.Dataset): assert isinstance(features, (list, tuple)) assert isinstance(truth, (list, tuple)) + assert node_representation.lower() in ['pulse', 'dom'] + self._database = database self._pulsemaps = pulsemaps self._features = [index_column] + features @@ -44,6 +47,7 @@ class SQLiteDataset(torch.utils.data.Dataset): self._index_column = index_column self._truth_table = truth_table self._dtype = dtype + self._node_representation = node_representation self._features_string = ', '.join(self._features) self._truth_string = ', '.join(self._truth) @@ -118,6 +122,36 @@ class SQLiteDataset(torch.utils.data.Dataset): except: return -1 + def _get_unique_positions(self, tensor): + return torch.unique(tensor, return_counts = True, return_inverse=True, dim=0) + + def _make_dom_wise_representation(self,data): + unique_doms, inverse_idx, n_pulses_pr_dom = self._get_unique_positions(data.x[:,[self._features.index('dom_x') -1, + self._features.index('dom_y') -1, + self._features.index('dom_z') -1, + self._features.index('rde') -1, + self._features.index('pmt_area') -1]]) + unique_inverse_indices = torch.unique(inverse_idx) + count = 0 + pulse_statistics = torch.zeros(size = (len(unique_doms), 8)) + #'dom_x','dom_y','dom_z','dom_time','charge','rde','pmt_area' + time_idx = self._features.index('dom_time')-1 + charge_idx = self._features.index('charge')-1 + for unique_inverse_idx in unique_inverse_indices: + time = data.x[inverse_idx == unique_inverse_idx,time_idx] + charge = data.x[inverse_idx == unique_inverse_idx,charge_idx] + pulse_statistics[count,0] = torch.min(time) + pulse_statistics[count,1] = torch.mean(time) + pulse_statistics[count,2] = torch.max(time) + pulse_statistics[count,3] = torch.std(time, unbiased=False) + pulse_statistics[count,4] = torch.min(charge) + pulse_statistics[count,5] = torch.mean(charge) + pulse_statistics[count,6] = torch.max(charge) + pulse_statistics[count,7] = torch.std(charge, unbiased=False) + count +=1 + data.x = torch.cat((unique_doms, n_pulses_pr_dom.unsqueeze(1), pulse_statistics), dim = 1) + return data + def _create_graph(self, features, truth): """Create Pytorch Data (i.e.graph) object. @@ -161,12 +195,18 @@ class SQLiteDataset(torch.utils.data.Dataset): data = np.array([]).reshape((0, len(self._features) - 1)) # Construct graph data object + x = torch.tensor(data, dtype=self._dtype) n_pulses = torch.tensor(len(x), dtype=torch.int32) graph = Data( x=x, edge_index= None ) + if self._node_representation.lower() == 'dom': + graph = self._make_dom_wise_representation(graph) + + + graph.n_pulses = n_pulses graph.features = self._features[1:] diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py index 4297be6e9ea671a123810cad1577476fb18a42d0..22fa5fcf8b2f6dd9fc77979a4a05a8759536f5fc 100644 --- a/src/graphnet/models/detector/icecube.py +++ b/src/graphnet/models/detector/icecube.py @@ -40,6 +40,66 @@ class IceCube86(Detector): return data +class IceCube86_DOM(Detector): + """`Detector` class for IceCube-86 with nodes as doms.""" + + # Implementing abstract class attribute + features = FEATURES.ICECUBE86 + + @property + def nb_inputs(self) -> int: + return len(self.features) + 7 + + @property + def nb_outputs(self): + return self.nb_inputs + + def _forward(self, data: Data) -> Data: + """Ingests data, builds graph (connectivity/adjacency), and preprocesses features. + + Args: + data (Data): Input graph data. + + Returns: + Data: Connected and preprocessed graph data. + """ + + # Check(s) + self._validate_features(data) + + # Preprocessing + data.x[:,0] /= 100. # dom_x + data.x[:,1] /= 100. # dom_y + data.x[:,2] += 350. # dom_z + data.x[:,2] /= 100. + data.x[:,3] -= 1.25 # rde + data.x[:,3] /= 0.25 + data.x[:,4] /= 0.05 # pmt_area + + data.x[:,5] /= 5 # n_pulses + data.x[:,5] -= 2 + + data.x[:,6] /= 1.05e+04 # dom_time + data.x[:,6] -= 1. + data.x[:,6] *= 20. + + data.x[:,7] /= 1.05e+04 # dom_time + data.x[:,7] -= 1. + data.x[:,7] *= 20 + + data.x[:,8] /= 1.05e+04 # dom_time + data.x[:,8] -= 1. + data.x[:,8] *= 20 + + data.x[:,9] /= 1.05e+04 # dom_time + data.x[:,9] -= 1. + data.x[:,9] *= 20 + + data.x[:,10] /= 1. # charge + data.x[:,11] /= 1. # charge + data.x[:,12] /= 1. # charge + data.x[:,13] /= 1. # charge + return data class IceCubeDeepCore(IceCube86): """`Detector` class for IceCube-DeepCore.""" diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 388e482781d3e9535185789121c10c66a53e519d..c138e34e3f75921de57c62a9f0aa62474fef6701 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -31,7 +31,6 @@ class DynEdge(GNN): # Architecture configuration c = layer_size_scale l1, l2, l3, l4, l5,l6 = nb_inputs, c*16*2, c*32*2, c*42*2, c*32*2, c*16*2 - # Base class constructor super().__init__(nb_inputs, l6) diff --git a/src/graphnet/models/graph_builders.py b/src/graphnet/models/graph_builders.py index 4c0f59567c65c0f9b17bbc7a7c18552f48cdcd2e..604eb6f9f911874d0c0d4530e44728990f0bbb7d 100644 --- a/src/graphnet/models/graph_builders.py +++ b/src/graphnet/models/graph_builders.py @@ -3,8 +3,8 @@ from typing import List import torch from torch_geometric.nn import knn_graph,radius_graph -from torch_geometric.data import Data - +from torch_geometric.data import Data, Batch +from graphnet.components.pool import group_identical from graphnet.models.utils import calculate_distance_matrix @@ -45,7 +45,6 @@ class KNNGraphBuilder(GraphBuilder): # pylint: disable=too-few-public-methods return data - class RadialGraphBuilder(GraphBuilder): """Builds graph adjacency according to a sphere of chosen radius centred at each DOM hit""" def __init__ ( diff --git a/src/graphnet/models/training/utils.py b/src/graphnet/models/training/utils.py index 6867cc6bbd23f0787ccf099dbc45a5da1e6414f9..bc2f1bef2828a17cc3bbf631666e3cbdc86f8b58 100644 --- a/src/graphnet/models/training/utils.py +++ b/src/graphnet/models/training/utils.py @@ -26,6 +26,7 @@ def make_dataloader( selection: List[int] = None, num_workers: int = 10, persistent_workers: bool = True, + node_representation: str = 'pulse', ) -> DataLoader: # Check(s) @@ -38,6 +39,7 @@ def make_dataloader( features, truth, selection=selection, + node_representation = node_representation ) def collate_fn(graphs): @@ -70,6 +72,7 @@ def make_train_validation_dataloader( test_size: float = 0.33, num_workers: int = 10, persistent_workers: bool = True, + node_representation: str = 'pulse' ) -> Tuple[DataLoader]: # Reproducibility @@ -98,6 +101,7 @@ def make_train_validation_dataloader( batch_size=batch_size, num_workers=num_workers, persistent_workers=persistent_workers, + node_representation = node_representation, ) training_dataloader = make_dataloader(