The implementation of the RNN_DynEdge model
Created by: RasmusOrsoe
It seems like `timeseries` is a boolean that toggles between our "default" graph definition and your problem/model-specific graph definition. I think this is a good clue that we should introduce a more modular way of choosing graph representations that doesn't inflate `Dataset` with arguments. I think the `GraphDefinition` mentioned in #462 (closed) is what we need now. We should introduce one argument to `Dataset` called `graph_definition: GraphDefinition`. This module should be in charge of:

- `PreprocessingModule` (Modular)
- `Data` object (Problem specific - new implementation for new problem)
- `Dataset` (functionality in base class)
- `EdgeDefinition` (Modular)

Here is some very specific pseudo-code of how `GraphDefinition` could look:
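Something along these lines. This is only a rough sketch: the base class, the argument names, and the `forward` signature are placeholders, not a final API.

```python
from typing import List

import torch
from torch_geometric.data import Data


class GraphDefinition(torch.nn.Module):
    """Bundles everything that turns raw node features into a training graph."""

    def __init__(
        self,
        preprocessing_module: "PreprocessingModule",  # scaling/standardization only
        edge_definition: "EdgeDefinition",  # edge construction (today's GraphBuilder)
    ) -> None:
        super().__init__()
        self._preprocessing = preprocessing_module
        self._edge_definition = edge_definition

    def forward(
        self,
        node_features: torch.Tensor,
        node_feature_names: List[str],
    ) -> Data:
        # 1) Scale/standardize the raw node features.
        x = self._preprocessing(node_features, node_feature_names)
        # 2) Wrap them in a (possibly problem-specific) Data object.
        graph = Data(x=x)
        # 3) Attach edges according to the chosen EdgeDefinition.
        return self._edge_definition(graph)
```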
What's missing here is `PreprocessingModule` (essentially a refactor of `Detector` that only does scaling/standardization). Ideally, this class should be able to deal with graphs that are created on a subset of the features that the module preprocesses. E.g. for IceCube86, which has features `["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "pmt_area"]`, it should be possible to pass graphs with `features = ["dom_x", "dom_y", "dom_z", "dom_time"]`, as this is a legitimate use case (right now it just throws a tantrum). `EdgeDefinition` is our `GraphBuilder` class, which I think we should just rename and keep as-is.

Our "default" graph would then be defined as:
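For example (again a sketch; `IceCube86Preprocessing` and `KNNEdges` are placeholder names for the refactored `Detector` and the renamed `KNNGraphBuilder`, respectively):

```python
class IceCube86Graph(GraphDefinition):
    """Our "default" graph: standardized IceCube-86 features with k-NN edges."""

    def __init__(self, nb_nearest_neighbours: int = 8) -> None:
        super().__init__(
            # Placeholder: the refactored Detector that only scales/standardizes.
            preprocessing_module=IceCube86Preprocessing(),
            # Placeholder: today's KNNGraphBuilder, renamed to an EdgeDefinition.
            edge_definition=KNNEdges(nb_nearest_neighbours=nb_nearest_neighbours),
        )
```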
It can then be instantiated like:
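For instance (purely illustrative; the import path and the `Dataset` arguments other than `graph_definition` are assumptions and may not match the actual signature):

```python
from graphnet.data.sqlite import SQLiteDataset  # assumed import path

graph_definition = IceCube86Graph(nb_nearest_neighbours=8)

dataset = SQLiteDataset(
    path="/path/to/database.db",  # illustrative arguments only
    pulsemaps="SRTTWOfflinePulsesDC",
    features=["dom_x", "dom_y", "dom_z", "dom_time"],
    truth=["energy", "zenith"],
    graph_definition=graph_definition,  # the single new argument
)
```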
This way, all the code that alters our `graph` in `Dataset` has been moved into one self-consistent class. In `Dataset.create_graph` the code is now simpler:
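Roughly along these lines (a sketch; the attribute names `self._graph_definition`, `self._features`, `self._truth` and the truth handling are illustrative):

```python
def _create_graph(self, features, truth):
    """Sketch: Dataset delegates all graph construction to its GraphDefinition."""
    # Build the graph; Dataset no longer knows about scaling, Data objects, or edges.
    graph = self._graph_definition(
        node_features=torch.tensor(features, dtype=torch.float),
        node_feature_names=self._features,
    )
    # Attach truth values to the graph as before (details unchanged).
    for name, value in zip(self._truth, truth):
        graph[name] = torch.tensor(value)
    return graph
```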
Similar changes will have to be made to `make_graph` in `GraphNeTI3Module`.

When we then construct `Model`s, we would pass a `GraphDefinition` instead of a `Detector`, and we could add a `self.warnonce` if the model is being evaluated on graphs that it was not trained on. Ideally, the graph definition that a model was trained on should be available via the model (e.g. `my_model.graph_definition( .. )`) once it's saved and loaded into a new session, like asogaard drew in the diagrams in #462 (closed), such that it's available when the model is deployed.
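As a sketch of what that could look like at the `Model` level (the `StandardModel` arguments and the surrounding variables are assumptions; only `graph_definition` is the proposed addition):

```python
# Training: the model stores the graph definition it was trained with.
model = StandardModel(
    graph_definition=graph_definition,  # replaces the current `detector` argument
    gnn=gnn,
    tasks=[task],
)
model.save("trained_model.pth")

# Deployment, in a new session: the same graph definition travels with the model.
model = StandardModel.load("trained_model.pth")
graph = model.graph_definition(node_features, node_feature_names)
prediction = model(graph)
```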