
The implementation of the RNN_DynEdge model

Closed Jorge Prado requested to merge github/fork/Aske-Rosted/RNN_GraphNet into main
7 files changed: +615 −1
@@ -17,6 +17,8 @@ from typing import (
 import numpy as np
 import torch
 from torch_geometric.data import Data
+
+from torch.utils.data import TensorDataset
 from graphnet.constants import GRAPHNET_ROOT_DIR
 from graphnet.data.utilities.string_selection_resolver import (
@@ -150,6 +152,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         loss_weight_column: Optional[str] = None,
         loss_weight_default_value: Optional[float] = None,
         seed: Optional[int] = None,
+        timeseries: Optional[bool] = None,
    • Author, Owner. Created by: RasmusOrsoe

      It seems like timeseries is a boolean that toggles between our "default" graph definition and your problem/model-specific graph definition. I think this is a good clue that we should introduce a more modular way of choosing graph representations, one that doesn't inflate Dataset with arguments. I think the GraphDefinition mentioned in #462 (closed) is what we need now. I think we should introduce one argument to Dataset called graph_definition: GraphDefinition. This module should be in charge of

      1. Standardizing/scaling of node features via PreprocessingModule (Modular)
      2. Creation of Data object (Problem specific - new implementation for new problem)
      3. All changes/attachments/adjustments made to graph object in Dataset (functionality in base class)
      4. Assignment of edges from EdgeDefinition (Modular)

      Here is some very specific pseudo-code of what GraphDefinition could look like:

      from typing import Tuple, Any, List, Optional, Union, Dict, Callable
      from abc import abstractmethod
      import torch
      from torch_geometric.data import Data
      import numpy as np

      # Assumed import locations within graphnet; adjust as needed.
      from graphnet.models import Model
      from graphnet.utilities.config import save_model_config
      
      class GraphDefinition(Model):
          """An abstract class to create graph definitions from."""

          @save_model_config
          def __init__(self,
                       preprocessing_module: Optional[PreprocessingModule],
                       edge_definition: Optional[EdgeDefinition]):

              # Member Variables
              self._preprocessing = preprocessing_module
              self._edge_definition = edge_definition

              # Base class constructor
              super().__init__(name=__name__, class_name=self.__class__.__name__)
      
          @abstractmethod
          def _create_graph(self, x: np.ndarray) -> Data:
              """Problem/model specific graph definition.

              Should not standardize/scale data. May assign edges.

              Args:
                  x: node features for a single event

              Returns:
                  Data object (a single graph)
              """
      
          def __call__(self,
                       node_features: List[Tuple[float, ...]],
                       node_feature_names: List[str],
                       truth_dicts: List[Dict[str, Any]],
                       custom_label_functions: Dict[str, Callable],
                       loss_weight_column: Optional[str] = None,
                       loss_weight: Optional[float] = None,
                       loss_weight_default_value: Optional[float] = None,
                       data_path: Optional[str] = None,
                       ) -> Data:
              
              # Standardize / scale node features
              node_features = self._preprocessing(node_features, node_feature_names)
      
              # Create graph 
              graph = self._create_graph(node_features)
      
              # Attach number of pulses as static attribute.
              graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32)
      
              # Assign edges
              graph = self._edge_definition(graph)
      
              # Attach data path - useful for Ensemble datasets.
              if data_path is not None:
                  graph['dataset_path'] = data_path
      
              # Attach loss weights if they exist
              graph = self._add_loss_weights(graph=graph,
                                             loss_weight=loss_weight,
                                             loss_weight_column=loss_weight_column,
                                             loss_weight_default_value=loss_weight_default_value)

              # Attach default truth labels and node truths
              graph = self._add_truth(graph=graph, truth_dicts=truth_dicts)

              # Attach custom truth labels
              graph = self._add_custom_labels(graph=graph,
                                              custom_label_functions=custom_label_functions)

              # Attach node features as separate fields. MAY NOT CONTAIN 'x'
              graph = self._add_features_individually(graph=graph,
                                                      node_feature_names=node_feature_names)
          
              return graph
          
          def _add_loss_weights(self,
                                loss_weight: Optional[float],
                                graph: Data,
                                loss_weight_column: Optional[str],
                                loss_weight_default_value: Optional[Union[int, float]],
                                ) -> Data:
              """Attempt to store a weight in the graph that can be used to
              weight the loss during training, i.e.
              `graph[loss_weight_column] = loss_weight`.

              Args:
                  loss_weight: The non-negative weight to be stored.
                  graph: Data object representing the event.
                  loss_weight_column: The name under which the weight is stored
                      in the graph.
                  loss_weight_default_value: The default value used if
                      none was retrieved.

              Returns:
                  A graph with loss weight added, if available.
              """
              # Add loss weight to graph.
              if loss_weight is not None and loss_weight_column is not None:
                  # No loss weight was retrieved, i.e., it is missing for the current
                  # event.
                  if loss_weight < 0:
                      if loss_weight_default_value is None:
                          raise ValueError(
                              "At least one event is missing an entry in "
                              f"{loss_weight_column} "
                              "but loss_weight_default_value is None."
                          )
                      graph[loss_weight_column] = torch.tensor(
                          loss_weight_default_value, dtype=self._dtype
                      ).reshape(-1, 1)
                  else:
                      graph[loss_weight_column] = torch.tensor(
                          loss_weight, dtype=self._dtype
                      ).reshape(-1, 1)
              return graph
          
          def _add_truth(self, graph: Data,
                         truth_dicts: List[Dict[str, Any]]
                         ) -> Data:
              """Add truth labels from `truth_dicts` to `graph`, i.e.
              `graph[key] = truth_dict[key]`.

              Args:
                  graph: graph where the label will be stored
                  truth_dicts: dictionary containing the labels

              Returns:
                  graph with labels
              """
              # Write attributes, either target labels, truth info or original
              # features.
              for truth_dict in truth_dicts:
                  for key, value in truth_dict.items():
                      try:
                          graph[key] = torch.tensor(value)
                      except TypeError:
                          # Cannot convert `value` to Tensor due to its data type,
                          # e.g. `str`.
                          self.debug(
                              (
                                  f"Could not assign `{key}` with type "
                                  f"'{type(value).__name__}' as attribute to graph."
                              )
                          )
              return graph
      
          def _add_features_individually(self,
                                         graph: Data,
                                         node_feature_names: List[str],
                                         ) -> Data:
              # Additionally add original features as (static) attributes
              graph.features = node_feature_names
              for index, feature in enumerate(node_feature_names):
                  if feature not in ["x"]:  # reserved for node features.
                      graph[feature] = graph.x[:, index].detach()
                  else:
                      self.warnonce(
                          "Cannot assign graph['x']; this field is reserved "
                          "for node features. Please rename your input feature."
                      )
              return graph
          
          def _add_custom_labels(self, graph: Data,
                                 custom_label_functions: Dict[str, Callable]
                                 ) -> Data:
              # Add custom labels to the graph
              for key, fn in custom_label_functions.items():
                  graph[key] = fn(graph)
              return graph

      What's missing here is PreprocessingModule (essentially a refactor of Detector that only does scaling/standardization). Ideally, this class should be able to deal with graphs that are created on a subset of the features that the module preprocesses. E.g. for IceCube86, which has features ["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "pmt_area"], it should be possible to pass graphs with features = ["dom_x", "dom_y", "dom_z", "dom_time"], as this is a legit use case (right now it just throws a tantrum).
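
      To make the subset use case concrete, here is a minimal sketch of what such a PreprocessingModule could look like. The class name, the feature_map layout, and the scaling constants are illustrative assumptions, not existing graphnet API:

      from typing import Callable, Dict, List
      import torch

      class PreprocessingModule:
          """Sketch only. In graphnet this would subclass Model."""

          # One scaling function per feature the detector knows about.
          feature_map: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {}

          def __call__(
              self, node_features: torch.Tensor, node_feature_names: List[str]
          ) -> torch.Tensor:
              # Scale only the columns that are actually present, so graphs
              # built on a subset of the detector's features are handled
              # gracefully instead of raising.
              out = node_features.clone()
              for col, name in enumerate(node_feature_names):
                  fn = self.feature_map.get(name)
                  if fn is not None:
                      out[:, col] = fn(out[:, col])
              return out

      class IceCube86Preprocessing(PreprocessingModule):
          # Illustrative constants only.
          feature_map = {
              "dom_x": lambda x: x / 500.0,
              "dom_y": lambda x: x / 500.0,
              "dom_z": lambda x: x / 500.0,
              "dom_time": lambda t: (t - 1.0e4) / 3.0e4,
          }

      # A graph built on a subset of the full feature set now preprocesses cleanly:
      x = torch.rand(10, 4)
      x_scaled = IceCube86Preprocessing()(x, ["dom_x", "dom_y", "dom_z", "dom_time"])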

      EdgeDefinition is our GraphBuilder class, which I think we should just rename and keep as-is.

      Our "default" graph would then be defined as:

      class DefaultGraph(GraphDefinition):
          def _create_graph(self, node_features: np.ndarray, dtype: torch.dtype) -> Data:
              return Data(x=torch.tensor(node_features, dtype=dtype), edge_index=None)

      Where it can then be instantiated like

      graph_definition = DefaultGraph(edge_definition=KNNGraph,
                                      preprocessing_module=IceCube86)

      This way, all the code that alters our graph in Dataset has been moved into one self-consistent class. In Dataset._create_graph the code is now simpler:

      def _create_graph(
              self,
              features: List[Tuple[float, ...]],
              truth: Tuple[Any, ...],
              node_truth: Optional[List[Tuple[Any, ...]]] = None,
              loss_weight: Optional[float] = None,
          ) -> Data:
              """Create Pytorch Data (i.e. graph) object.
      
              Args:
                  features: List of tuples, containing event features.
                  truth: List of tuples, containing truth information.
                  node_truth: List of tuples, containing node-level truth.
                  loss_weight: A weight associated with the event for weighing the
                      loss.
      
              Returns:
                  Graph object.
              """
              # Convert nested list to simple dict
              truth_dict = {
                  key: truth[index] for index, key in enumerate(self._truth)
              }
      
              # Define custom labels
              labels_dict = self._get_labels(truth_dict)
      
              # Convert nested list to simple dict
              if node_truth is not None:
                  node_truth_array = np.asarray(node_truth)
                  assert self._node_truth is not None
                  node_truth_dict = {
                      key: node_truth_array[:, index]
                      for index, key in enumerate(self._node_truth)
                  }
              
              # Create list of truth dicts with labels
              truth_dicts = [labels_dict, truth_dict]
              if node_truth is not None:
                  truth_dicts.append(node_truth_dict)
      
              # Catch cases with no reconstructed pulses
              if len(features):
                  node_features = np.asarray(features)[:, 1:]
              else:
                  node_features = np.array([]).reshape((0, len(self._features) - 1))

              # Construct graph data object
              graph = self._graph_definition(node_features=node_features,
                                             node_feature_names=self._features[1:],  # first entry is index column
                                             truth_dicts=truth_dicts,
                                             custom_label_functions=self._label_fns,
                                             loss_weight_column=self._loss_weight_column,
                                             loss_weight=loss_weight,
                                             loss_weight_default_value=self._loss_weight_default_value,
                                             data_path=self._path)
              return graph

      Similar changes will have to be made to make_graph in GraphNeTI3Module.
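      For illustration, a rough sketch of how make_graph could then delegate to the graph definition; the _extract_node_features helper and the exact signature are assumptions, not the module's actual API:

      from typing import Any
      import numpy as np
      from torch_geometric.data import Data

      def make_graph(self: Any, frame: Any) -> Data:
          """Hypothetical post-refactor make_graph; feature extraction elided."""
          node_features: np.ndarray = self._extract_node_features(frame)  # assumed helper
          # All graph construction is delegated to the GraphDefinition.
          return self._graph_definition(
              node_features=node_features,
              node_feature_names=self._features,
              truth_dicts=[],             # no truth available at deployment time
              custom_label_functions={},
          )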

      When we construct Models, we would then pass GraphDefinition instead of Detector, and we could add a self.warnonce if the model is being evaluated on graphs that it was not trained on. Ideally, the graph definition that a model was trained on should be available via the model (e.g. my_model.graph_definition( .. )) once it's saved and loaded into a new session, like asogaard drew in the diagrams in #462 (closed), such that it's available when the model is deployed.
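
      A minimal sketch of what that check could look like; the graph_definition attribute on the model and the graph_definition_name attribute on graphs are assumptions, not existing API:

      from torch_geometric.data import Data

      def warn_if_graph_mismatch(model: "Model", data: Data) -> None:
          # Hypothetical guard: warn once if the incoming graphs were not
          # built with the graph definition the model was trained with.
          expected = type(model.graph_definition).__name__
          seen = getattr(data, "graph_definition_name", None)  # assumed attribute
          if seen is not None and seen != expected:
              model.warnonce(
                  f"Model was trained on graphs from {expected} but is being "
                  f"evaluated on graphs from {seen}."
              )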

 ):
     """Construct Dataset.
@@ -195,6 +198,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
                 subset of events when resolving a string-based selection (e.g.,
                 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
                 events ~ event_no % 5 > 0"`).
+            timeseries: Whether the dataset is a timeseries dataset. Defaults to None (not a timeseries dataset).
         """
         # Base class constructor
         super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -218,6 +222,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         self._index_column = index_column
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
+        self._timeseries = timeseries
         if node_truth is not None:
             assert isinstance(node_truth_table, str)
@@ -375,7 +380,14 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         features, truth, node_truth, loss_weight = self._query(
             sequential_index
         )
-        graph = self._create_graph(features, truth, node_truth, loss_weight)
+        if self._timeseries:
+            graph = self._create_timeseries(
+                features, truth, node_truth, loss_weight
+            )
+        else:
+            graph = self._create_graph(
+                features, truth, node_truth, loss_weight
+            )
         return graph

     # Internal method(s)
@@ -616,6 +628,223 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         graph["dataset_path"] = self._path
         return graph
+
+    def _create_timeseries(
+        self,
+        features: List[Tuple[float, ...]],
+        truth: Tuple[Any, ...],
+        node_truth: Optional[List[Tuple[Any, ...]]] = None,
+        loss_weight: Optional[float] = None,
+    ) -> Data:
+        """Create Pytorch Data (i.e. time series) object.
+
+        No preprocessing is performed at this stage, just as no node
+        adjacency is imposed. This means that the `edge_attr` and
+        `edge_weight` attributes are not set.
+
+        Args:
+            features: List of tuples, containing event features.
+            truth: List of tuples, containing truth information.
+            node_truth: List of tuples, containing node-level truth.
+            loss_weight: A weight associated with the event for weighing the
+                loss.
+
+        Returns:
+            Time series object.
+        """
+        # Convert nested list to simple dict
+        truth_dict = {
+            key: truth[index] for index, key in enumerate(self._truth)
+        }
+
+        # Define custom labels
+        labels_dict = self._get_labels(truth_dict)
+
+        # Convert nested list to simple dict
+        if node_truth is not None:
+            node_truth_array = np.asarray(node_truth)
+            assert self._node_truth is not None
+            node_truth_dict = {
+                key: node_truth_array[:, index]
+                for index, key in enumerate(self._node_truth)
+            }
+
+        # Catch cases with no reconstructed pulses
+        if len(features):
+            data = np.asarray(features)[:, 1:]
+        else:
+            data = np.array([]).reshape((0, len(self._features) - 1))
+
+        # Construct graph data object
+        x = torch.tensor(data, dtype=self._dtype)  # pylint: disable=C0103
+        n_pulses = torch.tensor(len(x), dtype=torch.long)
+
+        # Assign each pulse a DOM number. If the features do not contain
+        # one, derive it from the unique (x, y, z) DOM positions.
+        if "dom_number" not in self._features:
+            number_index = -1
+            unique_index = [
+                self._features.index(key) - 1
+                for key in ["dom_x", "dom_y", "dom_z"]
+            ]
+            dom_number = torch.unique(
+                x[:, unique_index], return_inverse=True, dim=0, sorted=False
+            )[-1]
+            x = torch.cat([x, dom_number.reshape(-1, 1)], dim=1)
+        else:
+            number_index = self._features.index("dom_number") - 1
+
+        # Group pulses by DOM and count the activations per DOM.
+        dom_activation_sort = x[:, number_index].sort()[-1]
+        x = x[dom_activation_sort]
+        bin_count = torch.bincount(
+            x[:, number_index].type(torch.int64)
+        ).cumsum(0)
+        # Removing dom_number from features should maybe be optional.
+        if number_index != -1:
+            x = x[:, torch.arange(x.shape[-1]) != number_index]
+        else:
+            x = x[:, :-1]
+        n_activations = bin_count - torch.cat(
+            [torch.tensor([0]), bin_count[:-1]]
+        )
+        # Sort DOMs by their number of activations, descending, as required
+        # by `pack_sequence(..., enforce_sorted=True)` below.
+        activations_sort = n_activations.sort(descending=True)[-1]
+        x = np.array(
+            [np.array(_) for _ in torch.tensor_split(x, bin_count[:-1])],
+            dtype=object,
+        )[activations_sort]
+        # Within each DOM, sort pulses by time, latest first. The `- 1`
+        # accounts for the index column, which is not part of `x`.
+        x = [
+            torch.tensor(
+                dom_list[
+                    np.argsort(
+                        dom_list[:, self._features.index("dom_time") - 1]
+                    )[::-1]
+                ]
+            )
+            for dom_list in x
+        ]
+        x, xyztc = Dataset.DOM_time_series_to_pack_sequence(
+            x, device="cpu", features=self._features[1:]
+        )
+        graph = TimeSeries(x=xyztc, edge_index=None)
+        graph.time_series = [
+            x.data,
+            x.batch_sizes,
+            x.sorted_indices,
+            x.unsorted_indices,
+        ]
+        graph.n_pulses = n_pulses
+        graph.features = self._features[1:]
+
+        # Add loss weight to graph.
+        if loss_weight is not None and self._loss_weight_column is not None:
+            # No loss weight was retrieved, i.e., it is missing for the
+            # current event.
+            if loss_weight < 0:
+                if self._loss_weight_default_value is None:
+                    raise ValueError(
+                        "At least one event is missing an entry in "
+                        f"{self._loss_weight_column} "
+                        "but loss_weight_default_value is None."
+                    )
+                graph[self._loss_weight_column] = torch.tensor(
+                    self._loss_weight_default_value, dtype=self._dtype
+                ).reshape(-1, 1)
+            else:
+                graph[self._loss_weight_column] = torch.tensor(
+                    loss_weight, dtype=self._dtype
+                ).reshape(-1, 1)
+
+        # Write attributes, either target labels, truth info or original
+        # features.
+        add_these_to_graph = [labels_dict, truth_dict]
+        if node_truth is not None:
+            add_these_to_graph.append(node_truth_dict)
+        for write_dict in add_these_to_graph:
+            for key, value in write_dict.items():
+                try:
+                    graph[key] = torch.tensor(value)
+                except TypeError:
+                    # Cannot convert `value` to Tensor due to its data type,
+                    # e.g. `str`.
+                    self.debug(
+                        f"Could not assign `{key}` with type "
+                        f"'{type(value).__name__}' as attribute to graph."
+                    )
+
+        # Additionally add original features as (static) attributes. The last
+        # two feature names have no corresponding column in the condensed
+        # `graph.x` and are skipped.
+        for index, feature in enumerate(graph.features[:-2]):
+            if feature not in ["x"]:
+                graph[feature] = graph.x[:, index]
+
+        # Add custom labels to the graph
+        for key, fn in self._label_fns.items():
+            graph[key] = fn(graph)
+
+        # Add Dataset Path. Useful if multiple datasets are concatenated.
+        graph["dataset_path"] = self._path
+        return graph
+
+    @staticmethod
+    def DOM_time_series_to_pack_sequence(
+        tensor: List[torch.Tensor], device: torch.device, features: List[str]
+    ) -> Tuple[torch.nn.utils.rnn.PackedSequence, torch.Tensor]:
+        """Pack DOM time series data into a pack sequence.
+
+        Also returns a tensor of x/y/z-coordinates, mean activation time and
+        total charge for each DOM.
+
+        Args:
+            tensor: List of per-DOM time series tensors.
+            device: Device to use.
+            features: List of features in the tensor.
+
+        Returns:
+            tensor: Packed sequence of DOM time series data.
+            xyztc: Tensor of x/y/z-coordinates, mean activation time and
+                total charge for each DOM.
+        """
+        xyztc = torch.stack(
+            [
+                torch.cat(
+                    [
+                        v[
+                            0,
+                            [
+                                features.index("dom_x"),
+                                features.index("dom_y"),
+                                features.index("dom_z"),
+                            ],
+                        ],
+                        torch.as_tensor(
+                            [
+                                v[:, features.index("dom_time")].mean(),
+                                v[:, features.index("charge")].sum(),
+                            ],
+                            device=device,
+                        ),
+                    ]
+                )
+                for v in tensor
+            ]
+        )
+
+        tensor = torch.nn.utils.rnn.pack_sequence(tensor, enforce_sorted=True)
+        return tensor, xyztc
+
+    @staticmethod
+    @torch.jit.ignore
+    def sorted_jit_ignore(tensor: List[torch.Tensor]) -> List[torch.Tensor]:
+        """Sort a list of tensors by element length, longest first.
+
+        Args:
+            tensor: Input list of tensors.
+
+        Returns:
+            The list sorted by length, descending.
+        """
+        return sorted(tensor, key=len, reverse=True)

     def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]:
         """Return dictionary of labels, to be added as graph attributes."""
         if "pid" in truth_dict.keys():
@@ -664,6 +893,18 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         return -1
+
+
+class TimeSeries(Data):
+    """Dataset class for time series data."""
+
+    def __cat_dim__(
+        self, key: str, value: Any, *args: Any, **kwargs: Dict[str, Any]
+    ) -> Any:
+        """Specify change in concatenation dimension for specific key."""
+        if key == "time_series":
+            return 0
+        return super().__cat_dim__(key, value, *args, **kwargs)
+
+
 class EnsembleDataset(torch.utils.data.ConcatDataset):
     """Construct a single dataset from a collection of datasets."""