From 89445fe86fd45675acb20ab676b7530475d0fcbb Mon Sep 17 00:00:00 2001
From: Aske-Rosted <askerosted@gmail.com>
Date: Fri, 19 May 2023 17:21:08 +0900
Subject: [PATCH 1/3] The implementation of the RNN_DynEdge model

---
 .../models/components/DOM_handling.py  | 147 ++++++++++++++++++
 src/graphnet/models/gnn/__init__.py    |   1 +
 src/graphnet/models/gnn/node_rnn.py    |  81 ++++++++++
 src/graphnet/models/gnn/rnn_dynedge.py |  90 +++++++++++
 4 files changed, 319 insertions(+)
 create mode 100644 src/graphnet/models/components/DOM_handling.py
 create mode 100644 src/graphnet/models/gnn/node_rnn.py
 create mode 100644 src/graphnet/models/gnn/rnn_dynedge.py

diff --git a/src/graphnet/models/components/DOM_handling.py b/src/graphnet/models/components/DOM_handling.py
new file mode 100644
index 000000000..0bfc95b4d
--- /dev/null
+++ b/src/graphnet/models/components/DOM_handling.py
@@ -0,0 +1,147 @@
+"""Functions for handling event level DOM data."""
+from typing import Tuple, Optional
+from torch_geometric.nn.pool import knn_graph
+
+import torch
+
+
+@torch.jit.script
+def append_dom_id(
+    tensor: torch.Tensor,
+    batch: torch.Tensor,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Assign a unique ID to each DOM.
+
+    The ID is assigned based on the position of the DOM in the input tensor. Requires x, y, z as the first three columns of the input tensor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        batch (torch.Tensor, optional): Batch tensor. Defaults to None.
+        device (torch.device): Device to use.
+
+    Returns:
+        torch.Tensor: Tensor of DOM IDs.
+    """
+    inverse_matrix = torch.zeros(
+        (tensor.shape[0], 3), device=device, dtype=torch.int64
+    )
+    for i in range(3):
+        _, inverse = torch.unique(tensor[:, i], return_inverse=True)
+        inverse_matrix[:, i] = inverse
+
+    if batch is None:
+        batch = torch.zeros(tensor.shape[0], device=device, dtype=torch.int64)
+
+    inverse_matrix = torch.hstack([batch.unsqueeze(1), inverse_matrix])
+    for i in range(3):
+        inverse_matrix[:, 1] = inverse_matrix[:, 0] + (
+            (torch.max(inverse_matrix[:, 0]) + 1) * (inverse_matrix[:, 1] + 1)
+        )
+
+        _, inverse_matrix[:, 1] = torch.unique(
+            inverse_matrix[:, 1], return_inverse=True
+        )
+
+        inverse_matrix = inverse_matrix[:, -(3 - i) :]
+
+    inverse_matrix = inverse_matrix.flatten()
+    tensor = torch.hstack([tensor, inverse_matrix.unsqueeze(1)])
+    return tensor
+
+
+torch.jit.script
+
+
+def DOM_to_time_series(
+    tensor: torch.Tensor, batch: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Create time series for each activated DOM from DOM activation data.
+
+    Returns a time series view of the DOM activations as well as an updated batch tensor. REQUIRES the DOM ID to be the last column of the input tensor.
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        batch (torch.Tensor): Batch tensor.
+
+    Returns:
+        tensor (torch.Tensor): Time series DOM data.
+        sort_batch (torch.Tensor): Batch tensor.
+    """
+    dom_activation_sort = tensor[:, -1].sort()[-1]
+    tensor, batch = tensor[dom_activation_sort], batch[dom_activation_sort]
+    bin_count = torch.bincount(tensor[:, -1].type(torch.int64)).cumsum(0)
+    batch = batch[bin_count - 1]
+    tensor = tensor[:, :-1]
+    tensor = torch.tensor_split(tensor, bin_count.cpu()[:-1])
+    lengths_index = (
+        torch.as_tensor([v.size(0) for v in tensor]).sort()[-1].flip(0)
+    )
+    batch = batch[lengths_index]
+
+    return tensor, batch
+
+
+torch.jit.script
+
+
+def DOM_time_series_to_pack_sequence(
+    tensor: torch.Tensor, device: torch.device
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Pack DOM time series data into a packed sequence.
+
+    Also returns a tensor of x/y/z-coordinates and time of first/mean activation for each DOM.
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        device (torch.device): Device to use.
+
+    Returns:
+        tensor (torch.Tensor): Packed sequence of DOM time series data.
+        xyztt (torch.Tensor): Tensor of x/y/z-coordinates as well as time of mean & first activations for each DOM.
+    """
+    tensor = sorted_jit_ignore(tensor)
+    xyztt = torch.stack(
+        [
+            torch.cat(
+                [
+                    v[0, :3],
+                    torch.as_tensor(
+                        [v[:, 4].mean(), v[:, 4].min()], device=device
+                    ),
+                ]
+            )
+            for v in tensor
+        ]
+    )
+
+    tensor = torch.nn.utils.rnn.pack_sequence(tensor, enforce_sorted=True)
+    return tensor, xyztt
+
+
+@torch.jit.ignore
+def sorted_jit_ignore(tensor: torch.Tensor) -> torch.Tensor:
+    """Sort a tensor based on the length of the elements.
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+
+    Returns:
+        torch.Tensor: Sorted tensor.
+        torch.Tensor: Indices of sorted tensor.
+    """
+    tensor = sorted(tensor, key=len, reverse=True)
+    return tensor
+
+
+@torch.jit.ignore
+def knn_graph_ignore(
+    x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Create a kNN graph based on the input data.
+
+    Args:
+        x: Input data.
+        k: Number of neighbours.
+        batch: Batch index.
+    """
+    return knn_graph(x=x, k=k, batch=batch)
diff --git a/src/graphnet/models/gnn/__init__.py b/src/graphnet/models/gnn/__init__.py
index 4da7f5b21..b5092e9fe 100644
--- a/src/graphnet/models/gnn/__init__.py
+++ b/src/graphnet/models/gnn/__init__.py
@@ -3,3 +3,4 @@
 from .convnet import ConvNet
 from .dynedge import DynEdge
 from .dynedge_jinst import DynEdgeJINST
+from .node_rnn import Node_RNN
diff --git a/src/graphnet/models/gnn/node_rnn.py b/src/graphnet/models/gnn/node_rnn.py
new file mode 100644
index 000000000..1c5e1e969
--- /dev/null
+++ b/src/graphnet/models/gnn/node_rnn.py
@@ -0,0 +1,81 @@
+"""Implementation of the NodeTimeRNN model.
+
+(cannot be used as a standalone model)
+"""
+import torch
+
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.components.DOM_handling import (
+    append_dom_id,
+    DOM_time_series_to_pack_sequence,
+    DOM_to_time_series,
+    knn_graph_ignore,
+)
+from graphnet.utilities.config import save_model_config
+from torch_geometric.data import Data
+
+
+class Node_RNN(GNN):
+    """Implementation of the RNN model architecture.
+
+    The model takes as input the typical DOM data format and transforms it
+    into a time series of DOM activations per DOM, before applying an RNN
+    layer and outputting an RNN output for each DOM.
+    """
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_inputs: int,
+        hidden_size: int,
+        num_layers: int,
+        nb_neighbours: int,
+    ) -> None:
+        """Construct `Node_RNN`.
+
+        Args:
+            nb_inputs: Number of features in the input data.
+            hidden_size: Number of features for the RNN output and hidden layers.
+            num_layers: Number of layers in the RNN.
+            nb_neighbours: Number of neighbours to use when reconstructing the graph representation.
+        """
+        self._nb_neighbours = nb_neighbours
+        self._num_layers = num_layers
+        self._hidden_size = hidden_size
+        self._nb_inputs = nb_inputs
+        super().__init__(nb_inputs, hidden_size + 5)
+
+        self._rnn = torch.nn.RNN(
+            num_layers=self._num_layers,
+            input_size=self._nb_inputs,
+            hidden_size=self._hidden_size,
+        )
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Apply learnable forward pass to the GNN."""
+        x, batch = data.x, data.batch
+
+        x = append_dom_id(x, batch, device=self.device)  # add dom id to data
+        x, batch = DOM_to_time_series(
+            x, batch
+        )  # create time series for each dom
+        x, xyztt = DOM_time_series_to_pack_sequence(
+            x, device=self.device
+        )  # pack time series into pack sequence
+        x = self._rnn(x)[-1][0]  # apply rnn layer
+
+        x = torch.hstack(
+            [xyztt, x]
+        )  # reintroduce x/y/z-coordinates and time of first/mean activation for each DOM
+
+        batch, sort_index = batch.sort()
+        data.x = x[sort_index]
+        data.batch = batch
+        edge_index = knn_graph_ignore(
+            x=data.x,
+            k=self._nb_neighbours,
+            batch=data.batch,
+        ).to(self.device)
+        data.edge_index = edge_index
+        return data
diff --git a/src/graphnet/models/gnn/rnn_dynedge.py b/src/graphnet/models/gnn/rnn_dynedge.py
new file mode 100644
index 000000000..e3e58b891
--- /dev/null
+++ b/src/graphnet/models/gnn/rnn_dynedge.py
@@ -0,0 +1,90 @@
+"""RNN_DynEdge model implementation."""
+from typing import List, Optional, Tuple, Union
+
+import torch
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.gnn.dynedge import DynEdge
+from graphnet.models.gnn.node_rnn import Node_RNN
+
+from graphnet.utilities.config import save_model_config
+from torch_geometric.data import Data
+
+
+class RNN_DynEdge(GNN):
+    """The RNN_DynEdge model class."""
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_inputs: int,
+        *,
+        nb_neighbours: int = 8,
+        RNN_layers: int = 1,
+        RNN_hidden_size: int = 64,
+        features_subset: Optional[Union[List[int], slice]] = None,
+        dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
+        post_processing_layer_sizes: Optional[List[int]] = None,
+        readout_layer_sizes: Optional[List[int]] = None,
+        global_pooling_schemes: Optional[Union[str, List[str]]] = None,
+        add_global_variables_after_pooling: bool = False,
+    ):
+        """Initialize the RNN_DynEdge model.
+
+        Args:
+            nb_inputs (int): Number of input features.
+            nb_neighbours (int, optional): Number of neighbours to consider.
+                Defaults to 8.
+            RNN_layers (int, optional): Number of RNN layers.
+                Defaults to 1.
+            RNN_hidden_size (int, optional): Size of the hidden state of the RNN. Also determines the size of the output of the RNN.
+                Defaults to 64.
+            features_subset (Optional[Union[List[int], slice]], optional): Subset of features to use.
+            dynedge_layer_sizes (Optional[List[Tuple[int, ...]]], optional): List of tuples of integers representing the sizes of the hidden layers of the DynEdge model.
+            post_processing_layer_sizes (Optional[List[int]], optional): List of integers representing the sizes of the hidden layers of the post-processing model.
+            readout_layer_sizes (Optional[List[int]], optional): List of integers representing the sizes of the hidden layers of the readout model.
+            global_pooling_schemes (Optional[Union[str, List[str]]], optional): Pooling schemes to use. Defaults to None.
+            add_global_variables_after_pooling (bool, optional): Whether to add global variables after pooling. Defaults to False.
+        """
+        self._nb_neighbours = nb_neighbours
+        self._nb_inputs = nb_inputs
+        self._RNN_layers = RNN_layers
+        self._RNN_hidden_size = RNN_hidden_size
+
+        self._features_subset = features_subset
+        self._dynedge_layer_sizes = dynedge_layer_sizes
+        self._post_processing_layer_sizes = post_processing_layer_sizes
+        self._global_pooling_schemes = global_pooling_schemes
+        self._add_global_variables_after_pooling = (
+            add_global_variables_after_pooling
+        )
+        if readout_layer_sizes is None:
+            readout_layer_sizes = [
+                128,
+            ]
+        self._readout_layer_sizes = readout_layer_sizes
+
+        super().__init__(nb_inputs, self._readout_layer_sizes[-1])
+
+        self._rnn = Node_RNN(
+            num_layers=self._RNN_layers,
+            nb_inputs=self._nb_inputs,
+            hidden_size=self._RNN_hidden_size,
+            nb_neighbours=self._nb_neighbours,
+        )
+
+        self._dynedge = DynEdge(
+            nb_inputs=self._RNN_hidden_size + 5,
+            nb_neighbours=self._nb_neighbours,
+            dynedge_layer_sizes=self._dynedge_layer_sizes,
+            post_processing_layer_sizes=self._post_processing_layer_sizes,
+            readout_layer_sizes=self._readout_layer_sizes,
+            global_pooling_schemes=self._global_pooling_schemes,
+            add_global_variables_after_pooling=self._add_global_variables_after_pooling,
+        )
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Apply learnable forward pass of the RNN and DynEdge models."""
+        data = self._rnn(data)
+        readout = self._dynedge(data)
+
+        return readout
--
GitLab


From b88a0c005353cf8633d439672f511ab443283377 Mon Sep 17 00:00:00 2001
From: Aske-Rosted <askerosted@gmail.com>
Date: Mon, 12 Jun 2023 16:39:03 +0900
Subject: [PATCH 2/3] Change generation of Timeseries to dataset class

---
 src/graphnet/data/dataset.py            | 243 +++++++++++++++++-
 src/graphnet/models/detector/icecube.py |  66 +++++
 .../utilities/config/dataset_config.py  |   1 +
 3 files changed, 309 insertions(+), 1 deletion(-)

diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset.py
index a0110988b..22a5c185c 100644
--- a/src/graphnet/data/dataset.py
+++ b/src/graphnet/data/dataset.py
@@ -17,6 +17,8 @@ from typing import (
 import numpy as np
 import torch
 from torch_geometric.data import Data
+from torch.utils.data import TensorDataset
+
 
 from graphnet.constants import GRAPHNET_ROOT_DIR
 from graphnet.data.utilities.string_selection_resolver import (
@@ -150,6 +152,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         loss_weight_column: Optional[str] = None,
         loss_weight_default_value: Optional[float] = None,
         seed: Optional[int] = None,
+        timeseries: Optional[bool] = None,
     ):
         """Construct Dataset.
 
@@ -195,6 +198,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
                 subset of events when resolving a string-based selection
                 (e.g., `"10000 random events ~ event_no % 5 > 0"` or
                 `"20% random events ~ event_no % 5 > 0"`).
+            timeseries: Whether the dataset is a timeseries dataset. Defaults to None (not a timeseries dataset).
         """
         # Base class constructor
         super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -218,6 +222,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         self._index_column = index_column
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
+        self._timeseries = timeseries
 
         if node_truth is not None:
             assert isinstance(node_truth_table, str)
@@ -375,7 +380,14 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         features, truth, node_truth, loss_weight = self._query(
             sequential_index
         )
-        graph = self._create_graph(features, truth, node_truth, loss_weight)
+        if self._timeseries:
+            graph = self._create_timeseries(
+                features, truth, node_truth, loss_weight
+            )
+        else:
+            graph = self._create_graph(
+                features, truth, node_truth, loss_weight
+            )
         return graph
 
     # Internal method(s)
@@ -616,6 +628,223 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         graph["dataset_path"] = self._path
         return graph
 
+    def _create_timeseries(
+        self,
+        features: List[Tuple[float, ...]],
+        truth: Tuple[Any, ...],
+        node_truth: Optional[List[Tuple[Any, ...]]] = None,
+        loss_weight: Optional[float] = None,
+    ) -> Data:
+        """Create PyTorch Data (i.e. time series) object.
+
+        No preprocessing is performed at this stage, just as no node adjacency
+        is imposed. This means that the `edge_attr` and `edge_weight`
+        attributes are not set.
+
+        Args:
+            features: List of list of tuples, containing event features.
+            truth: List of tuples, containing truth information.
+            node_truth: List of tuples, containing node-level truth.
+            loss_weight: A weight associated with the event for weighing the
+                loss.
+
+        Returns:
+            Time series object.
+        """
+        # Convert nested list to simple dict
+        truth_dict = {
+            key: truth[index] for index, key in enumerate(self._truth)
+        }
+
+        # Define custom labels
+        labels_dict = self._get_labels(truth_dict)
+
+        # Convert nested list to simple dict
+        if node_truth is not None:
+            node_truth_array = np.asarray(node_truth)
+            assert self._node_truth is not None
+            node_truth_dict = {
+                key: node_truth_array[:, index]
+                for index, key in enumerate(self._node_truth)
+            }
+
+        # Catch cases with no reconstructed pulses
+        if len(features):
+            data = np.asarray(features)[:, 1:]
+        else:
+            data = np.array([]).reshape((0, len(self._features) - 1))
+
+        # Construct graph data object
+        x = torch.tensor(data, dtype=self._dtype)  # pylint: disable=C0103
+        n_pulses = torch.tensor(len(x), dtype=torch.long)
+        if "dom_number" not in self._features:
+            number_index = -1
+            unique_index = [
+                self._features.index(key) - 1
+                for key in ["dom_x", "dom_y", "dom_z"]
+            ]
+            dom_number = torch.unique(
+                x[:, unique_index], return_inverse=True, dim=0, sorted=False
+            )[-1]
+            x = torch.cat([x, dom_number.reshape(-1, 1)], dim=1)
+        else:
+            number_index = self._features.index("dom_number") - 1
+
+        dom_activation_sort = x[:, number_index].sort()[-1]
+        x = x[dom_activation_sort]
+        bin_count = torch.bincount(
+            x[:, number_index].type(torch.int64)
+        ).cumsum(0)
+        # removing dom_number from features should maybe be optional.
+        if number_index != -1:
+            x = x[:, torch.arange(x.shape[-1]) != number_index]
+        else:
+            x = x[:, :-1]
+        n_activations = bin_count - torch.cat(
+            [torch.tensor([0]), bin_count[:-1]]
+        )
+        activations_sort = n_activations.sort(descending=True)[-1]
+        x = np.array(
+            [np.array(_) for _ in torch.tensor_split(x, bin_count[:-1])],
+            dtype=object,
+        )[activations_sort]
+        x = [
+            torch.tensor(
+                dom_list[
+                    np.argsort(
+                        dom_list[
+                            :,
+                            self._features.index("dom_time") - 1,
+                        ]
+                    )[::-1]
+                ]
+            )
+            for dom_list in x
+        ]
+        x, xyztc = Dataset.DOM_time_series_to_pack_sequence(
+            x, device="cpu", features=self._features[1:]
+        )
+        graph = TimeSeries(x=xyztc, edge_index=None)
+        graph.time_series = [
+            x.data,
+            x.batch_sizes,
+            x.sorted_indices,
+            x.unsorted_indices,
+        ]
+        # print(x.batch_sizes)
+        graph.n_pulses = n_pulses
+        graph.features = self._features[1:]
+        # graph.n_activations = n_activations[activations_sort]
+
+        # Add loss weight to graph.
+        if loss_weight is not None and self._loss_weight_column is not None:
+            # No loss weight was retrieved, i.e., it is missing for the current
+            # event.
+            if loss_weight < 0:
+                if self._loss_weight_default_value is None:
+                    raise ValueError(
+                        "At least one event is missing an entry in "
+                        f"{self._loss_weight_column} "
+                        "but loss_weight_default_value is None."
+                    )
+                graph[self._loss_weight_column] = torch.tensor(
+                    self._loss_weight_default_value, dtype=self._dtype
+                ).reshape(-1, 1)
+            else:
+                graph[self._loss_weight_column] = torch.tensor(
+                    loss_weight, dtype=self._dtype
+                ).reshape(-1, 1)
+
+        # Write attributes, either target labels, truth info or original
+        # features.
+        add_these_to_graph = [labels_dict, truth_dict]
+        if node_truth is not None:
+            add_these_to_graph.append(node_truth_dict)
+        for write_dict in add_these_to_graph:
+            for key, value in write_dict.items():
+                try:
+                    graph[key] = torch.tensor(value)
+                except TypeError:
+                    # Cannot convert `value` to Tensor due to its data type,
+                    # e.g. `str`.
+                    self.debug(
+                        (
+                            f"Could not assign `{key}` with type "
+                            f"'{type(value).__name__}' as attribute to graph."
+                        )
+                    )
+
+        # Additionally add original features as (static) attributes
+        for index, feature in enumerate(graph.features[:-2]):
+            if feature not in ["x"]:
+                graph[feature] = graph.x[:, index]
+
+        # Add custom labels to the graph
+        for key, fn in self._label_fns.items():
+            graph[key] = fn(graph)
+
+        # Add Dataset Path. Useful if multiple datasets are concatenated.
+        graph["dataset_path"] = self._path
+        return graph
+
+    def DOM_time_series_to_pack_sequence(
+        tensor: torch.Tensor, device: torch.device, features: List[str]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Pack DOM time series data into a packed sequence.
+
+        Also returns a tensor of x/y/z-coordinates, mean activation time and summed charge for each DOM.
+
+        Args:
+            tensor (torch.Tensor): Input tensor.
+            device (torch.device): Device to use.
+            features (list): List of features in the tensor.
+
+        Returns:
+            tensor (torch.Tensor): Packed sequence of DOM time series data.
+            xyztc (torch.Tensor): Tensor of x/y/z-coordinates as well as mean activation time and summed charge for each DOM.
+        """
+        xyztc = torch.stack(
+            [
+                torch.cat(
+                    [
+                        v[
+                            0,
+                            [
+                                features.index("dom_x"),
+                                features.index("dom_y"),
+                                features.index("dom_z"),
+                            ],
+                        ],
+                        torch.as_tensor(
+                            [
+                                v[:, features.index("dom_time")].mean(),
+                                v[:, features.index("charge")].sum(),
+                            ],
+                            device=device,
+                        ),
+                    ]
+                )
+                for v in tensor
+            ]
+        )
+
+        tensor = torch.nn.utils.rnn.pack_sequence(tensor, enforce_sorted=True)
+        return tensor, xyztc
+
+    @torch.jit.ignore
+    def sorted_jit_ignore(tensor: torch.Tensor) -> torch.Tensor:
+        """Sort a tensor based on the length of the elements.
+
+        Args:
+            tensor (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Sorted tensor.
+            torch.Tensor: Indices of sorted tensor.
+        """
+        tensor = sorted(tensor, key=len, reverse=True)
+        return tensor
+
     def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]:
         """Return dictionary of labels, to be added as graph attributes."""
         if "pid" in truth_dict.keys():
@@ -664,6 +893,18 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
                 return -1
 
 
+class TimeSeries(Data):
+    """Dataset class for time series data."""
+
+    def __cat_dim__(
+        self: Data, key: str, value: Any, *args: Any, **kwargs: Dict[str, Any]
+    ) -> Any:
+        """Specify change in concatenation dimension for specific key."""
+        if key == "time_series":
+            return 0
+        return super().__cat_dim__(key, value, *args, **kwargs)
+
+
 class EnsembleDataset(torch.utils.data.ConcatDataset):
     """Construct a single dataset from a collection of datasets."""
 
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index e6addb8e3..4bd162042 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -71,6 +71,72 @@ class IceCubeKaggle(Detector):
         return data
 
 
+class IceCube86RNN(Detector):
+    """`Detector` class for IceCube-86 with a time-series data input."""
+
+    features = FEATURES.ICECUBE86[:5]
+
+    def _forward(self, data: Data) -> Data:
+        """Ingest data, build graph, and preprocess features.
+
+        Args:
+            data: Input graph data.
+
+        Returns:
+            Connected and preprocessed time-series graph data.
+        """
+        self._validate_features(data)
+
+        data.x[:, 0] /= 500.0  # dom_x
+        data.x[:, 1] /= 500.0  # dom_y
+        data.x[:, 2] /= 500.0  # dom_z
+        data.x[:, 3] = (data.x[:, 3] - 1.0e04) / 3.0e4  # dom_time
+        data.x[:, 4] = torch.log10(data.x[:, 4]) / 3.0  # charge
+
+        data.time_series[0] /= torch.tensor(
+            [500.0, 500.0, 500.0, 1.0, 1.0], device=data.x.device
+        )
+        data.time_series[0][:, 3] = (
+            data.time_series[0][:, 3] - 1.0e04
+        ) / 3.0e4
+        data.time_series[0][:, 4] = (
+            torch.log10(data.time_series[0][:, 4]) / 3.0
+        )
+
+        data.time_series[0] = torch.tensor_split(
+            data.time_series[0], (data.n_pulses.cumsum(0)[:-1]).cpu()
+        )
+
+        data.time_series[1] = torch.tensor_split(
+            data.time_series[1],
+            (
+                torch.argwhere(
+                    torch.gt(data.time_series[1][1:], data.time_series[1][:-1])
+                ).flatten()
+                + 1
+            ).cpu(),
+        )
+
+        time_series = []
+        for sequence, batch_sizes, sorted_indices, unsorted_indices in zip(
+            data.time_series[0],
+            data.time_series[1],
+            data.time_series[2],
+            data.time_series[3],
+        ):
+            time_series.append(
+                torch.nn.utils.rnn.PackedSequence(
+                    data=sequence,
+                    batch_sizes=batch_sizes.cpu(),
+                    sorted_indices=sorted_indices,
+                    unsorted_indices=unsorted_indices,
+                )
+            )
+
+        data.time_series = time_series
+        return data
+
+
 class IceCubeDeepCore(IceCube86):
     """`Detector` class for IceCube-DeepCore."""
 
diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py
index fc5beb44b..782ff7c6f 100644
--- a/src/graphnet/utilities/config/dataset_config.py
+++ b/src/graphnet/utilities/config/dataset_config.py
@@ -47,6 +47,7 @@ class DatasetConfig(BaseConfig):
     loss_weight_default_value: Optional[float] = None
 
     seed: Optional[int] = None
+    timeseries: Optional[bool] = None
 
     def __init__(self, **data: Any) -> None:
         """Construct `DataConfig`.
--
GitLab


From 5a03cb10f3d203d60aee5cba8c9b1c507dd6546f Mon Sep 17 00:00:00 2001
From: Aske-Rosted <askerosted@gmail.com>
Date: Mon, 12 Jun 2023 16:41:52 +0900
Subject: [PATCH 3/3] Update RNN to new dataset

---
 src/graphnet/models/gnn/node_rnn.py    | 31 ++++++++------------------
 src/graphnet/models/gnn/rnn_dynedge.py |  2 +-
 2 files changed, 10 insertions(+), 23 deletions(-)

diff --git a/src/graphnet/models/gnn/node_rnn.py b/src/graphnet/models/gnn/node_rnn.py
index 1c5e1e969..490c1f024 100644
--- a/src/graphnet/models/gnn/node_rnn.py
+++ b/src/graphnet/models/gnn/node_rnn.py
@@ -54,28 +54,15 @@ class Node_RNN(GNN):
 
     def forward(self, data: Data) -> torch.Tensor:
         """Apply learnable forward pass to the GNN."""
-        x, batch = data.x, data.batch
+        rnn_out = []
+        for b_uniq in data.batch.unique():
+            rnn_out.append(
+                self._rnn(data.time_series[b_uniq])[-1][0]
+            )  # apply rnn layer
+        # x = self._rnn(x)[-1][0] # apply rnn layer
+        rnn_out = torch.cat(rnn_out)
 
-        x = append_dom_id(x, batch, device=self.device)  # add dom id to data
-        x, batch = DOM_to_time_series(
-            x, batch
-        )  # create time series for each dom
-        x, xyztt = DOM_time_series_to_pack_sequence(
-            x, device=self.device
-        )  # pack time series into pack sequence
-        x = self._rnn(x)[-1][0]  # apply rnn layer
-
-        x = torch.hstack(
-            [xyztt, x]
+        data.x = torch.hstack(
+            [data.x, rnn_out]
         )  # reintroduce x/y/z-coordinates and time of first/mean activation for each DOM
-
-        batch, sort_index = batch.sort()
-        data.x = x[sort_index]
-        data.batch = batch
-        edge_index = knn_graph_ignore(
-            x=data.x,
-            k=self._nb_neighbours,
-            batch=data.batch,
-        ).to(self.device)
-        data.edge_index = edge_index
         return data
diff --git a/src/graphnet/models/gnn/rnn_dynedge.py b/src/graphnet/models/gnn/rnn_dynedge.py
index e3e58b891..0deec7113 100644
--- a/src/graphnet/models/gnn/rnn_dynedge.py
+++ b/src/graphnet/models/gnn/rnn_dynedge.py
@@ -59,7 +59,7 @@ class RNN_DynEdge(GNN):
             )
         if readout_layer_sizes is None:
             readout_layer_sizes = [
-                128,
+                512,
            ]
         self._readout_layer_sizes = readout_layer_sizes
--
GitLab
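
For orientation, the three patches are meant to work together: a `Dataset` constructed with `timeseries=True` emits `TimeSeries` objects, the `IceCube86RNN` detector converts their `time_series` attribute into per-DOM `PackedSequence`s, `Node_RNN` runs its RNN over those sequences and appends the output to the node features, and `DynEdge` consumes the result. The snippet below is a hypothetical construction sketch, not code from the patch series; the feature count and pooling schemes are illustrative assumptions.

# Hypothetical usage sketch, assuming a graphnet checkout with the three
# patches above applied.
from graphnet.models.gnn.rnn_dynedge import RNN_DynEdge

# Five per-pulse inputs (dom_x, dom_y, dom_z, dom_time, charge), matching
# IceCube86RNN.features = FEATURES.ICECUBE86[:5] from the second patch.
gnn = RNN_DynEdge(
    nb_inputs=5,          # features fed to the per-DOM RNN
    nb_neighbours=8,      # k used when DynEdge builds its kNN graph
    RNN_layers=1,
    RNN_hidden_size=64,   # the internal DynEdge then receives 64 + 5 node features
    global_pooling_schemes=["min", "max", "mean", "sum"],
)

# Data objects passed to `gnn(data)` should come from a Dataset built with
# `timeseries=True` and a detector such as IceCube86RNN, so that
# `data.time_series` holds the packed per-DOM sequences the RNN consumes.
print(gnn.nb_outputs)  # readout size, 512 with the default from the third patch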