Skip to content
Snippets Groups Projects

The implementation of the RNN_DynEdge model

Closed Jorge Prado requested to merge github/fork/Aske-Rosted/RNN_GraphNet into main
7 files
+ 615
1
Compare changes
  • Side-by-side
  • Inline
Files
7
+ 242
1
@@ -17,6 +17,8 @@ from typing import (
import numpy as np
import torch
from torch_geometric.data import Data
from torch.utils.data import TensorDataset
from graphnet.constants import GRAPHNET_ROOT_DIR
from graphnet.data.utilities.string_selection_resolver import (
@@ -150,6 +152,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
loss_weight_column: Optional[str] = None,
loss_weight_default_value: Optional[float] = None,
seed: Optional[int] = None,
timeseries: Optional[bool] = None,
):
"""Construct Dataset.
@@ -195,6 +198,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
subset of events when resolving a string-based selection (e.g.,
`"10000 random events ~ event_no % 5 > 0"` or `"20% random
events ~ event_no % 5 > 0"`).
timeseries: Whether the dataset is a timeseries dataset. Defaults to None (Not a timeseries dataset).
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -218,6 +222,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
self._index_column = index_column
self._truth_table = truth_table
self._loss_weight_default_value = loss_weight_default_value
self._timeseries = timeseries
if node_truth is not None:
assert isinstance(node_truth_table, str)
@@ -375,7 +380,14 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
features, truth, node_truth, loss_weight = self._query(
sequential_index
)
graph = self._create_graph(features, truth, node_truth, loss_weight)
if self._timeseries is not False:
graph = self._create_timeseries(
features, truth, node_truth, loss_weight
)
else:
graph = self._create_graph(
features, truth, node_truth, loss_weight
)
return graph
# Internal method(s)
@@ -616,6 +628,223 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
graph["dataset_path"] = self._path
return graph
def _create_timeseries(
    self,
    features: List[Tuple[float, ...]],
    truth: Tuple[Any, ...],
    node_truth: Optional[List[Tuple[Any, ...]]] = None,
    loss_weight: Optional[float] = None,
) -> Data:
    """Create a `TimeSeries` data object for a single event.

    Pulses are grouped per DOM, the groups are ordered by decreasing
    number of pulses, and the pulses inside each group are ordered by
    decreasing `dom_time`. The groups are then packed into a packed
    sequence whose components are stored on the returned object as
    `time_series`, while `x` holds one summary row per DOM.

    No node adjacency is imposed at this stage, so `edge_index` is
    `None` and the `edge_attr` and `edge_weight` attributes are not set.

    Args:
        features: List of tuples, one per pulse, containing the event
            features. The first entry of each tuple is the index column
            and is dropped.
        truth: Tuple containing event-level truth information, ordered
            like `self._truth`.
        node_truth: List of tuples, containing node-level truth.
        loss_weight: A weight associated with the event for weighing the
            loss.

    Returns:
        The `TimeSeries` object for the event.

    Raises:
        ValueError: If the loss weight is missing for the event and no
            default value was configured.
    """
    # Convert nested list to simple dict
    truth_dict = {
        key: truth[index] for index, key in enumerate(self._truth)
    }

    # Define custom labels
    labels_dict = self._get_labels(truth_dict)

    # Convert nested list to simple dict
    if node_truth is not None:
        node_truth_array = np.asarray(node_truth)
        assert self._node_truth is not None
        node_truth_dict = {
            key: node_truth_array[:, index]
            for index, key in enumerate(self._node_truth)
        }

    # Catch cases with no reconstructed pulses
    if len(features):
        data = np.asarray(features)[:, 1:]
    else:
        data = np.array([]).reshape((0, len(self._features) - 1))

    x = torch.tensor(data, dtype=self._dtype)  # pylint: disable=C0103
    n_pulses = torch.tensor(len(x), dtype=torch.long)

    # Names of the columns of `x`. The leading index column was dropped
    # above, so every column lookup below goes through this list rather
    # than through `self._features` directly.
    column_names = list(self._features[1:])

    # Assign a contiguous integer id to every DOM. The check is made on
    # the column *names*; `features` holds the pulse rows themselves.
    if "dom_number" in column_names:
        number_index = column_names.index("dom_number")
        # Re-map to 0..n-1 so that sparse DOM numbers cannot create empty
        # groups further down.
        dom_ids = torch.unique(x[:, number_index], return_inverse=True)[-1]
        # TODO: removing dom_number from features should maybe be
        # optional.
        x = x[:, torch.arange(x.shape[-1]) != number_index]
        del column_names[number_index]
    else:
        position_index = [
            column_names.index(key) for key in ["dom_x", "dom_y", "dom_z"]
        ]
        dom_ids = torch.unique(
            x[:, position_index], return_inverse=True, dim=0, sorted=False
        )[-1]

    # Make the pulses of each DOM contiguous, then cut `x` per DOM.
    x = x[dom_ids.sort()[-1]]
    n_activations = torch.bincount(dom_ids)
    per_dom = torch.split(x, n_activations.tolist())

    # Packing with `enforce_sorted=True` needs the longest sequence first.
    dom_order = n_activations.sort(descending=True)[-1].tolist()

    # Within each DOM, order the pulses by decreasing time. Staying in
    # torch avoids a NumPy object-array round trip, which breaks when all
    # DOMs happen to have the same number of pulses.
    time_index = column_names.index("dom_time")
    sequences = []
    for dom in dom_order:
        pulses = per_dom[dom]
        sequences.append(
            pulses[torch.argsort(pulses[:, time_index]).flip(0)]
        )

    # NOTE(review): an event without pulses yields an empty `sequences`
    # list, which the packing step presumably cannot handle - confirm how
    # such events should be represented.
    packed, xyztt = Dataset.DOM_time_series_to_pack_sequence(
        sequences, device="cpu", features=column_names
    )

    # Construct data object
    graph = TimeSeries(x=xyztt, edge_index=None)
    graph.time_series = [
        packed.data,
        packed.batch_sizes,
        packed.sorted_indices,
        packed.unsorted_indices,
    ]
    graph.n_pulses = n_pulses
    graph.features = column_names

    # Add loss weight to graph.
    if loss_weight is not None and self._loss_weight_column is not None:
        # No loss weight was retrieved, i.e., it is missing for the current
        # event.
        if loss_weight < 0:
            if self._loss_weight_default_value is None:
                raise ValueError(
                    "At least one event is missing an entry in "
                    f"{self._loss_weight_column} "
                    "but loss_weight_default_value is None."
                )
            graph[self._loss_weight_column] = torch.tensor(
                self._loss_weight_default_value, dtype=self._dtype
            ).reshape(-1, 1)
        else:
            graph[self._loss_weight_column] = torch.tensor(
                loss_weight, dtype=self._dtype
            ).reshape(-1, 1)

    # Write attributes, either target labels, truth info or original
    # features.
    add_these_to_graph = [labels_dict, truth_dict]
    if node_truth is not None:
        add_these_to_graph.append(node_truth_dict)
    for write_dict in add_these_to_graph:
        for key, value in write_dict.items():
            try:
                graph[key] = torch.tensor(value)
            except TypeError:
                # Cannot convert `value` to Tensor due to its data type,
                # e.g. `str`.
                self.debug(
                    (
                        f"Could not assign `{key}` with type "
                        f"'{type(value).__name__}' as attribute to graph."
                    )
                )

    # Additionally add original features as (static) attributes.
    # NOTE(review): `graph.x` holds the five per-DOM summary columns, so
    # this presumably relies on the leading features being dom_x, dom_y,
    # dom_z, dom_time and charge, in that order - confirm.
    for index, feature in enumerate(graph.features[:-2]):
        if feature not in ["x"]:
            graph[feature] = graph.x[:, index]

    # Add custom labels to the graph
    for key, fn in self._label_fns.items():
        graph[key] = fn(graph)

    # Add Dataset Path. Useful if multiple datasets are concatenated.
    graph["dataset_path"] = self._path
    return graph
@staticmethod
def DOM_time_series_to_pack_sequence(
    tensor: List[torch.Tensor],
    device: torch.device,
    features: List[str],
) -> Tuple[torch.nn.utils.rnn.PackedSequence, torch.Tensor]:
    """Pack per-DOM pulse series and build one summary row per DOM.

    Args:
        tensor: Sequence of 2D tensors, one per DOM, whose columns are
            ordered like `features`. The sequences must already be ordered
            by decreasing length, as the packing is done with
            `enforce_sorted=True`.
        device: Device on which the mean time and summed charge are
            created (a device string such as `"cpu"` is passed by the
            caller in this file).
        features: Column names of the tensors in `tensor`. Must contain
            `dom_x`, `dom_y`, `dom_z`, `dom_time` and `charge`.

    Returns:
        A tuple `(packed, xyztc)`: the packed sequence of the DOM time
        series, and a tensor with one row per DOM holding the x/y/z
        coordinates of the first pulse, the mean `dom_time` and the summed
        `charge`.
    """
    # Column lookups do not depend on the DOM, so resolve them once.
    position_index = [
        features.index(name) for name in ("dom_x", "dom_y", "dom_z")
    ]
    time_index = features.index("dom_time")
    charge_index = features.index("charge")

    summaries = []
    for series in tensor:
        # The position is read from the first pulse of the series only.
        position = series[0, position_index]
        time_and_charge = torch.as_tensor(
            [
                series[:, time_index].mean(),
                series[:, charge_index].sum(),
            ],
            device=device,
        )
        summaries.append(torch.cat([position, time_and_charge]))
    xyztc = torch.stack(summaries)

    packed = torch.nn.utils.rnn.pack_sequence(tensor, enforce_sorted=True)
    return packed, xyztc
@torch.jit.ignore
def sorted_jit_ignore(tensor: torch.Tensor) -> torch.Tensor:
    """Order the elements of a sequence from longest to shortest.

    Args:
        tensor: Iterable whose elements support `len`.

    Returns:
        A new list with the elements of `tensor` ordered by decreasing
        length. Elements of equal length keep their relative order.
    """
    ordered = list(tensor)
    ordered.sort(key=len, reverse=True)
    return ordered
def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Return dictionary of labels, to be added as graph attributes."""
if "pid" in truth_dict.keys():
@@ -664,6 +893,18 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
return -1
class TimeSeries(Data):
    """`Data` subclass used to hold time-series events.

    It only overrides the concatenation dimension reported for the
    `time_series` attribute; every other key is handled by `Data`.
    """

    def __cat_dim__(
        self: Data, key: str, value: Any, *args: Any, **kwargs: Dict[str, Any]
    ) -> Any:
        """Return the concatenation dimension to use for `key`."""
        if key != "time_series":
            # Defer to the parent class for all regular attributes.
            return super().__cat_dim__(key, value, *args, **kwargs)
        return 0
class EnsembleDataset(torch.utils.data.ConcatDataset):
"""Construct a single dataset from a collection of datasets."""
Loading