The implementation of the RNN_DynEdge model
7 files changed: +242 −1
@@ -17,6 +17,8 @@ from typing import (
@@ -150,6 +152,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
@@ -195,6 +198,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
@@ -218,6 +222,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
@@ -375,7 +380,14 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
@@ -616,6 +628,223 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
@@ -664,6 +893,18 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
Created by: RasmusOrsoe
It seems like `timeseries` is a boolean that toggles between our "default" graph definition and your problem/model-specific graph definition. I think this is a good clue that we should introduce a more modular way of choosing graph representations that doesn't inflate `Dataset` with arguments. I think the `GraphDefinition` mentioned in #462 is what we need now. We should introduce one argument to `Dataset`, called `graph_definition: GraphDefinition`. This module should be in charge of:

- `PreprocesingModule` (modular)
- the `Data` object (problem specific: a new implementation for each new problem)
- `Dataset` (functionality in base class)
- `EdgeDefinition` (modular)

Here is some very specific pseudo-code of how `GraphDefinition` could look:

What's missing here is
`PreprocesingModule` (essentially a refactor of `Detector` that only does scaling/standardization). Ideally, this class should be able to deal with graphs that are created on a subset of the features that the module preprocesses. E.g. for IceCube86, which has features `["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "pmt_area"]`, it should be possible to pass graphs with `features = ["dom_x", "dom_y", "dom_z", "dom_time"]`, as this is a legit use case (right now it just throws a tantrum).

`EdgeDefinition` is our `GraphBuilder` class, which I think we should just rename and keep as-is.

Our "default" graph would then be defined, and instantiated, like so:
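A minimal sketch of how these pieces could fit together, under stated assumptions: all signatures, scale values, and the `DefaultIceCube86Graph` name are hypothetical, and plain dicts stand in for `torch_geometric.data.Data`.

```python
from typing import Dict, List, Tuple

Features = Dict[str, List[float]]


class PreprocesingModule:
    """Sketch: a refactor of `Detector` that only does scaling/standardization.

    It scales only the features actually present in the input, so graphs
    built on a subset of the detector's full feature set are a legal input.
    """

    def __init__(self, feature_scales: Dict[str, float]) -> None:
        self._scales = feature_scales

    def __call__(self, features: Features) -> Features:
        return {
            name: [value / self._scales[name] for value in values]
            for name, values in features.items()
            if name in self._scales
        }


class EdgeDefinition:
    """Sketch: stands in for the current `GraphBuilder` (to be renamed)."""

    def __call__(self, n_nodes: int) -> List[Tuple[int, int]]:
        # Placeholder rule: fully connected graph.
        return [(i, j) for i in range(n_nodes) for j in range(n_nodes) if i != j]


class GraphDefinition:
    """Owns all graph construction: preprocessing, Data creation, edges."""

    def __init__(
        self, preprocessing: PreprocesingModule, edge_definition: EdgeDefinition
    ) -> None:
        self._preprocessing = preprocessing
        self._edge_definition = edge_definition

    def __call__(self, features: Features) -> dict:
        x = self._preprocessing(features)
        n_nodes = len(next(iter(x.values())))
        # A real implementation would return `torch_geometric.data.Data`.
        return {"x": x, "edge_index": self._edge_definition(n_nodes)}


# A "default" graph for IceCube86 (scale values are made up):
class DefaultIceCube86Graph(GraphDefinition):
    def __init__(self) -> None:
        super().__init__(
            preprocessing=PreprocesingModule(
                {"dom_x": 500.0, "dom_y": 500.0, "dom_z": 500.0, "dom_time": 1.0e4}
            ),
            edge_definition=EdgeDefinition(),
        )


graph_definition = DefaultIceCube86Graph()
```

With this shape, passing only `["dom_x", "dom_y", "dom_z", "dom_time"]` (or fewer features) works out of the box, because `PreprocesingModule` intersects its scales with the features it is actually given.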
This way, all the code that alters our `graph` in `Dataset` has been moved into one self-consistent class. In `Dataset.create_graph` the code is now simpler:

Similar changes will have to be made to `make_graph` in `GraphNeTI3Module`.
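For illustration, `Dataset.create_graph` could then reduce to a pure delegation. This is a sketch, not the actual graphnet code:

```python
class Dataset:
    """Sketch of a Dataset that no longer owns any graph-altering logic."""

    def __init__(self, graph_definition) -> None:
        # The single new argument proposed above.
        self._graph_definition = graph_definition

    def create_graph(self, features: dict) -> dict:
        # All preprocessing and edge construction happens inside the
        # graph definition; Dataset only delegates.
        return self._graph_definition(features)
```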
When we then construct `Model`s, we would pass `GraphDefinition` instead of `Detector`, and we could add a `self.warnonce` if the model is being evaluated on graphs that it was not trained on. Ideally, the graph definition that a model was trained on should be available via the model (e.g. `my_model.graph_definition( .. )`) once it's saved and loaded into a new session, like asogaard drew in the diagrams in #462, such that it's available when the model is deployed.
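The warn-once behaviour could be sketched as below. Everything here is hypothetical: real graphnet `Model`s are `torch.nn.Module` subclasses, and how the graph definition gets stamped onto each graph is an open design question.

```python
import warnings


class Model:
    """Sketch: stores the GraphDefinition it was trained with and warns
    once if evaluated on graphs built from a different definition."""

    def __init__(self, graph_definition) -> None:
        self.graph_definition = graph_definition
        self._warned = False

    def _warn_once(self, message: str) -> None:
        # Emit the warning only on the first mismatch, not on every batch.
        if not self._warned:
            warnings.warn(message)
            self._warned = True

    def forward(self, graph: dict) -> dict:
        # Assumes the graph carries a reference to the definition that
        # built it (hypothetical "graph_definition" key).
        if graph.get("graph_definition") is not self.graph_definition:
            self._warn_once(
                "Model is being evaluated on graphs it was not trained on."
            )
        return graph
```

A loaded model would then expose `my_model.graph_definition` directly, so deployment code can rebuild graphs exactly as during training.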