Skip to content
Snippets Groups Projects

The implementation of the RNN_DynEdge model

Closed Jorge Prado requested to merge github/fork/Aske-Rosted/RNN_GraphNet into main

Created by: Aske-Rosted

An implementation of an model that first puts the DOM data through an RNN network before feeding it to the DynEdge network, along with some functions working on DOM level data in order to facilitate the RNN.

This is a way of tackling the issue which is also addressed in #474 (closed).

Due to limited hardware on my end I have not had the ability to properly performance check this new model.

Merge request reports

Approval is optional

Closed by Jorge PradoJorge Prado 1 year ago (Oct 13, 2023 5:35am UTC)

Merge details

  • The changes were not merged into main.

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
150 152 loss_weight_column: Optional[str] = None,
151 153 loss_weight_default_value: Optional[float] = None,
152 154 seed: Optional[int] = None,
155 timeseries: Optional[bool] = None,
  • Author Owner

    Created by: RasmusOrsoe

    It seems like timeseries is a boolean that toggles between our "default" graph definition and your problem/model specific graph definition. I think this is a good clue that we should introduce a more modular way of choosing graph representations that doesn't make the Dataset inflated with arguments. I think the GraphDefinition mentioned in #462 (closed) is what we need now. I think we should introduce one argument to Dataset called graph_definition: GraphDefinition. This module should be in charge of

    1. Standardizing/scaling of node features via PreprocesingModule (Modular)
    2. Creation of Data object (Problem specific - new implementation for new problem)
    3. All changes/attachments/adjustments made to graph object in Dataset (functionality in base class)
    4. Assignment of edges fromEdgeDefinition (Modular)

    Here is some very specific pseudo-code of how GraphDefinition could look like:

    from typing import Tuple, Any, List, Optional, Union, Dict, Callable
    from abc import ABC, abstractmethod
    import torch
    from torch_geometric.data import Data
    import numpy as np
    
    class GraphDefinition(Model):
        """
        An Abstract class to create graph definitions from.
        """
        @save_model_config
        def __init__(self,
                     preprocessing_module: Optional[PreprocesingModule],
                     edge_definition: Optional[EdgeDefinition]):
            
            # Member Variables
            self._preprocessing = preprocessing_module
            self._edge_definiton = edge_definition
    
            # Base class constructor
            super().__init__(name=__name__, class_name=self.__class__.__name__)
    
        @abstractmethod
        def _create_graph(self, x: np.array) -> Data:
            """ Problem/model specific graph definition. 
                Should not standardize/scale data.
                May assign edges.
    
            Args:
                x: node features for a single event
    
            Returns:
                Data object (a single graph)
            """
    
        def __call__(self, 
                     node_features: List[Tuple[float, ...]],
                     node_feature_names: List[str],
                     truth_dicts: List[Dict[str, Any]],
                     custom_label_functions: List[Dict[str, Callable]],
                     loss_weight_column: Optional[str] = None,
                     loss_weight: Optional[float] = None,
                     loss_weight_default_value: Optional[float] = None,
                     data_path: Optional[str] = None,
                     ) -> Data:
            
            # Standardize / Scale  node features
            node_features = self._preprocessing(node_features, node_feature_names)
    
            # Create graph 
            graph = self._create_graph(node_features)
    
            # Attach number of pulses as static attribute.
            graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32)
    
            # Assign edges
            graph = self._edge_definiton(graph)
    
            # Attach data path - useful for Ensemble datasets.
            if data_path is not None:
                graph['dataset_path'] = data_path
    
            # Attach loss weights if they exist
            graph = self._add_loss_weights(graph = graph,
                                           loss_weight = loss_weight,
                                           loss_weight_column=loss_weight_column,
                                           loss_weight_default_value = loss_weight_default_value)
            
            # Attach default truth labels and node truths
            graph = self._add_truth(graph = graph,
                                    truth_dicts = truth_dicts)
            
            # Attach custom truth labels
            graph = self._add_custom_labels(graph = graph,
                                            custom_label_functions=  custom_label_functions)
            
            # Attach node features as seperate fields. MAY NOT CONTAIN 'x'
            graph = self._add_features(graph = graph,
                                       node_feature_names = node_feature_names)
        
            return graph
        
        def _add_loss_weights(self,
                              loss_weight: np.array,
                              graph: Data,
                              loss_weight_column: str,
                              loss_weight_default_value: Union[int,float],
                              ) -> Data:
            """Attempt to store a weight in the graph that can be used to weight 
                the loss during training. I.e. 
                `graph[loss_weight_column] = loss_weight`
    
            Args:
                loss_weight: The non-negative weight to be stored.
                graph: Data object representing the event.
                loss_weight_column: The name under which the weight is stored in
                                     the graph.
                loss_weight_default_value: The default value used if 
                                            none was retrieved.
    
            Returns:
                A graph with loss weight added, if available.
            """
            # Add loss weight to graph.
            if loss_weight is not None and loss_weight_column is not None:
                # No loss weight was retrieved, i.e., it is missing for the current
                # event.
                if loss_weight < 0:
                    if loss_weight_default_value is None:
                        raise ValueError(
                            "At least one event is missing an entry in "
                            f"{loss_weight_column} "
                            "but loss_weight_default_value is None."
                        )
                    graph[loss_weight_column] = torch.tensor(
                        self._loss_weight_default_value, dtype=self._dtype
                    ).reshape(-1, 1)
                else:
                    graph[loss_weight_column] = torch.tensor(
                        loss_weight, dtype=self._dtype
                    ).reshape(-1, 1)
            return graph
        
        def _add_truth(self, graph: Data, 
                              truth_dicts: List[Dict[str, Any]]
                              ) -> Data:
            """Add truth labels from ´truth_dicts´ to ´graph´. I.e.:
                ´graph[key] = truth_dict[key]´
    
    
            Args:
                graph: graph where the label will be stored
                truth_dicts: dictionary containing the labels
    
            Returns:
                graph with labels
            """
            # Write attributes, either target labels, truth info or original
            # features.
            for truth_dict in truth_dicts:
                for key, value in truth_dict.items():
                    try:
                        graph[key] = torch.tensor(value)
                    except TypeError:
                        # Cannot convert `value` to Tensor due to its data type,
                        # e.g. `str`.
                        self.debug(
                            (
                                f"Could not assign `{key}` with type "
                                f"'{type(value).__name__}' as attribute to graph."
                            )
                        )
            return graph
    
        def _add_features_individually(self,
                          graph: Data,
                          node_feature_names: List[str],
                          ) -> Data:
            # Additionally add original features as (static) attributes
            graph.features = node_feature_names
            for index, feature in enumerate(node_feature_names):
                if feature not in ["x"]: # reserved for node features.
                    graph[feature] = graph.x[:, index].detach()
                else:
                    self.warnonce(f"""Cannot assign graph['x']. This field is 
                                  reserved for node features. Please rename your 
                                  input feature.""")
            return graph
        
        def _add_custom_labels(self, graph: Data,
                               custom_label_functions: Dict[str, Callable]
                               ) -> Data:
            # Add custom labels to the graph
            for key, fn in custom_label_functions.items():
                graph[key] = fn(graph)
            return graph

    Whats missing here is PreprocesingModule (Essentially a refactor of Detector that only does scaling/standardization). Ideally, this class should be able to deal with graphs that are created on a subset of the features that the module preprocesses. E.g. for IceCube86, which has features ["dom_x","dom_y", "dom_z", "dom_time", "charge","rde","pmt_area"] it should be possible to pass graphs with features = ["dom_x", "dom_y", "dom_z", "dom_time"] as this is a legit use case (right now it just throws a tantrum).

    EdgeDefinition is our GraphBuilder class, which I think we should just rename and keep as-is.

    Our "default" graph would then be defined as:

    class DefaultGraph(GraphDefinition):
        def _create_graph(self, node_features: np.array, dtype: torch.dtype) -> Data:
            return Data(x= torch.tensor(node_features, dtype=dtype), edge_index=None)

    Where it can then be instantiated like

    graph_definition = DefaultGraph(edge_definition = KNNGraph,
                                                         preprocessing_module = IceCube86)

    This way, all the code that alters our graph in Dataset has been moved into one self-consistent class. In Dataset.create_graph the code is now simpler:

    def _create_graph(
            self,
            features: List[Tuple[float, ...]],
            truth: Tuple[Any, ...],
            node_truth: Optional[List[Tuple[Any, ...]]] = None,
            loss_weight: Optional[float] = None,
        ) -> Data:
            """Create Pytorch Data (i.e. graph) object.
    
            Args:
                features: List of tuples, containing event features.
                truth: List of tuples, containing truth information.
                node_truth: List of tuples, containing node-level truth.
                loss_weight: A weight associated with the event for weighing the
                    loss.
    
            Returns:
                Graph object.
            """
            # Convert nested list to simple dict
            truth_dict = {
                key: truth[index] for index, key in enumerate(self._truth)
            }
    
            # Define custom labels
            labels_dict = self._get_labels(truth_dict)
    
            # Convert nested list to simple dict
            if node_truth is not None:
                node_truth_array = np.asarray(node_truth)
                assert self._node_truth is not None
                node_truth_dict = {
                    key: node_truth_array[:, index]
                    for index, key in enumerate(self._node_truth)
                }
            
            # Create list of truth dicts with labels
            truth_dicts = [labels_dict, truth_dict]
            if node_truth is not None:
                truth_dicts.append(node_truth_dict)
    
            # Catch cases with no reconstructed pulses
            if len(features):
                node_features= np.asarray(features)[:, 1:]
            else:
                node_features = np.array([]).reshape((0, len(self._features) - 1))
    
            # Construct graph data object
            graph = self._graph_definition(node_features = node_features,
                                            node_feature_names =  self._features[1:], # first entry is index column
                                            truth_dicts = truth_dicts,
                                            custom_label_functions = self._label_fns,
                                            loss_weight_column = self._loss_weight_column,
                                            loss_weight = loss_weight,
                                            loss_weight_default_value = self._loss_weight_default_value,
                                            data_path = self._path)
            return graph

    Similar changes will have to be made to make_graph in GraphNeTI3Module.

    When we then construct Models, we would then pass GraphDefinition instead of Detector, and we could add a self.warnonce if the model is being evaluated on graphs that it was not trained on. Ideally, the graph definition that a model was trained on should be available via the model (e.g. my_model.graph_definition( .. ) ) once its saved and loaded into a new session, like asogaard drew in the diagrams in #462 (closed), such that it's available when the model is deployed.

  • Author Owner

    Created by: RasmusOrsoe

    Review: Commented

    Hey @Aske-Rosted!

    Thank you for sharing this code. I have made a rather large comment on the first part of this PR that relates to your changes in Dataset and the creation of graphs, their scaling and so on.

    Could you take a look at the pseudo code and let me know if this suggested refactor fits your usecase?

  • Jorge Prado mentioned in merge request !557 (closed)

    mentioned in merge request !557 (closed)

  • Jorge Prado mentioned in merge request !558 (merged)

    mentioned in merge request !558 (merged)

  • closed

  • Please register or sign in to reply
    Loading