From 1d9ad82a3714af40e7cdb0ddb66f2710cd128a9f Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Tue, 5 Apr 2022 22:11:21 +0200
Subject: [PATCH 1/6] Add initial DOM-wise node representation

---
 src/graphnet/data/sqlite_dataset.py     | 54 ++++++++++++++++---
 src/graphnet/models/detector/icecube.py | 71 +++++++++++++++++++++++++
 src/graphnet/models/gnn/dynedge.py      |  3 ++
 src/graphnet/models/graph_builders.py   |  5 +-
 src/graphnet/models/training/utils.py   |  4 ++
 5 files changed, 128 insertions(+), 9 deletions(-)

diff --git a/src/graphnet/data/sqlite_dataset.py b/src/graphnet/data/sqlite_dataset.py
index 3d55ceb6b..915f3a05a 100644
--- a/src/graphnet/data/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite_dataset.py
@@ -19,6 +19,7 @@ class SQLiteDataset(torch.utils.data.Dataset):
         truth_table: str = 'truth',
         selection: Optional[List[int]] = None,
         dtype: torch.dtype = torch.float32,
+        node_representation = 'pulse',
     ):
 
         # Check(s)
@@ -44,6 +45,7 @@ class SQLiteDataset(torch.utils.data.Dataset):
         self._index_column = index_column
         self._truth_table = truth_table
         self._dtype = dtype
+        self._node_representation =  node_representation
 
         self._features_string = ', '.join(self._features)
         self._truth_string = ', '.join(self._truth)
@@ -118,6 +120,33 @@ class SQLiteDataset(torch.utils.data.Dataset):
         except:
             return -1
 
+    def _get_unique_positions(self, tensor):
+        return torch.unique(tensor, return_counts = True, return_inverse=True, dim=0)
+
+    def _make_dom_wise_representation(self,data):
+        unique_doms, inverse_idx, n_pulses_pr_dom = self._get_unique_positions(data.x[:,[0,1,2,5,6]])
+        unique_inverse_indices = torch.unique(inverse_idx)
+        count = 0
+        pulse_statistics = torch.zeros(size = (len(unique_doms), 8))
+        #'dom_x','dom_y','dom_z','dom_time','charge','rde','pmt_area'
+        for unique_inverse_idx in unique_inverse_indices:
+            time   = data.x[inverse_idx == unique_inverse_idx,3]
+            charge = data.x[inverse_idx == unique_inverse_idx,4]
+            pulse_statistics[count,0] = torch.min(time)
+            pulse_statistics[count,1] = torch.mean(time)
+            pulse_statistics[count,2] = torch.max(time)
+            pulse_statistics[count,3] = torch.std(time)
+            pulse_statistics[count,4] = torch.min(charge)
+            pulse_statistics[count,5] = torch.mean(charge)
+            pulse_statistics[count,6] = torch.max(charge)
+            pulse_statistics[count,7] = torch.std(charge)
+            count +=1
+        #print(unique_doms.shape)
+        #print(n_pulses_pr_dom.shape)
+        #print(pulse_statistics.shape)      
+        data.x = torch.cat((unique_doms, n_pulses_pr_dom.unsqueeze(1), pulse_statistics), dim = 1)
+        return data
+
     def _create_graph(self, features, truth):
         """Create Pytorch Data (i.e.graph) object.
 
@@ -161,12 +190,25 @@ class SQLiteDataset(torch.utils.data.Dataset):
             data = np.array([]).reshape((0, len(self._features) - 1))
 
         # Construct graph data object
-        x = torch.tensor(data, dtype=self._dtype)
-        n_pulses = torch.tensor(len(x), dtype=torch.int32)
-        graph = Data(
-            x=x,
-            edge_index= None
-        )
+        if self._node_representation.lower() == 'pulse':
+            x = torch.tensor(data, dtype=self._dtype)
+            n_pulses = torch.tensor(len(x), dtype=torch.int32)
+            graph = Data(
+                x=x,
+                edge_index= None
+            )
+        elif self._node_representation.lower() == 'dom':
+            x = torch.tensor(data, dtype=self._dtype)
+            n_pulses = torch.tensor(len(x), dtype=torch.int32)
+            graph = Data(
+                x=x,
+                edge_index= None
+            )
+            graph = self._make_dom_wise_representation(graph)
+        else:
+            print('WARNING: node representation %s not recognized!'%self._node_representation)
+            
+        
         graph.n_pulses = n_pulses
         graph.features = self._features[1:]
 
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 4297be6e9..23e1ca10d 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -40,6 +40,77 @@ class IceCube86(Detector):
 
         return data
 
+class IceCube86_v2(Detector):
+    """`Detector` class for IceCube-86 with nodes as doms."""
+
+    # Implementing abstract class attribute
+    features = FEATURES.ICECUBE86
+
+    @property
+    def nb_inputs(self) -> int:
+        return len(self.features) + 7
+
+    @property
+    def nb_outputs(self):
+        return self.nb_inputs + 7
+    def _forward(self, data: Data) -> Data:
+        """Ingests data, builds graph (connectivity/adjacency), and preprocesses features.
+
+        Args:
+            data (Data): Input graph data.
+
+        Returns:
+            Data: Connected and preprocessed graph data.
+        """
+
+        # Check(s)
+        self._validate_features(data)
+
+        #pulse_statistics[count,0] = torch.min(time)
+        #pulse_statistics[count,1] = torch.mean(time)
+        #pulse_statistics[count,2] = torch.max(time)
+        #pulse_statistics[count,3] = torch.std(time)
+        #pulse_statistics[count,4] = torch.min(charge)
+        #pulse_statistics[count,5] = torch.mean(charge)
+        #pulse_statistics[count,6] = torch.max(charge)
+        #pulse_statistics[count,7] = torch.std(charge)
+        #unique_doms, n_pulses_pr_dom.unsqueeze(1), pulse_statistics
+
+        # Preprocessing
+        data.x[:,0] /= 100.  # dom_x
+        data.x[:,1] /= 100.  # dom_y
+        data.x[:,2] += 350.  # dom_z
+        data.x[:,2] /= 100.
+        data.x[:,3] -= 1.25  # rde
+        data.x[:,3] /= 0.25
+        data.x[:,4] /= 0.05  # pmt_area
+
+        data.x[:,5] /= 5  # n_pulses
+        data.x[:,5] -= 2
+
+        data.x[:,6] /= 1.05e+04  # dom_time
+        data.x[:,6] -= 1.
+        data.x[:,6] *= 20.
+
+        data.x[:,7] /= 1.05e+04  # dom_time
+        data.x[:,7] -= 1.
+        data.x[:,7] *= 20
+
+        data.x[:,8] /= 1.05e+04  # dom_time
+        data.x[:,8] -= 1.
+        data.x[:,8] *= 20
+
+        data.x[:,9] /= 1.05e+04  # dom_time
+        data.x[:,9] -= 1.
+        data.x[:,9] *= 20
+
+        data.x[:,10] /= 1.  # charge
+        data.x[:,11] /= 1.  # charge
+        data.x[:,12] /= 1.  # charge
+        data.x[:,13] /= 1.  # charge
+        print('detector out')
+        print(data.x.shape)
+        return data
 
 class IceCubeDeepCore(IceCube86):
     """`Detector` class for IceCube-DeepCore."""
diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py
index 388e48278..db5729fdb 100644
--- a/src/graphnet/models/gnn/dynedge.py
+++ b/src/graphnet/models/gnn/dynedge.py
@@ -106,6 +106,9 @@ class DynEdge(GNN):
         # Convenience variables
         x, edge_index, batch = data.x, data.edge_index, data.batch
 
+        print('modelout')
+        print(x.shape)
+
         # Calculate homophily (scalar variables)
         h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
 
diff --git a/src/graphnet/models/graph_builders.py b/src/graphnet/models/graph_builders.py
index 4c0f59567..604eb6f9f 100644
--- a/src/graphnet/models/graph_builders.py
+++ b/src/graphnet/models/graph_builders.py
@@ -3,8 +3,8 @@ from typing import List
 
 import torch
 from torch_geometric.nn import knn_graph,radius_graph
-from torch_geometric.data import Data
-
+from torch_geometric.data import Data, Batch
+from graphnet.components.pool import group_identical
 from graphnet.models.utils import calculate_distance_matrix
 
 
@@ -45,7 +45,6 @@ class KNNGraphBuilder(GraphBuilder):  # pylint: disable=too-few-public-methods
 
         return data
 
-
 class RadialGraphBuilder(GraphBuilder):  
     """Builds graph adjacency according to a sphere of chosen radius centred at each DOM hit"""
     def __init__ (
diff --git a/src/graphnet/models/training/utils.py b/src/graphnet/models/training/utils.py
index 6867cc6bb..bc2f1bef2 100644
--- a/src/graphnet/models/training/utils.py
+++ b/src/graphnet/models/training/utils.py
@@ -26,6 +26,7 @@ def make_dataloader(
     selection: List[int] = None,
     num_workers: int = 10,
     persistent_workers: bool = True,
+    node_representation: str = 'pulse',
 ) -> DataLoader:
 
     # Check(s)
@@ -38,6 +39,7 @@ def make_dataloader(
         features,
         truth,
         selection=selection,
+        node_representation = node_representation
     )
 
     def collate_fn(graphs):
@@ -70,6 +72,7 @@ def make_train_validation_dataloader(
     test_size: float = 0.33,
     num_workers: int = 10,
     persistent_workers: bool = True,
+    node_representation: str = 'pulse'
 ) -> Tuple[DataLoader]:
 
     # Reproducibility
@@ -98,6 +101,7 @@ def make_train_validation_dataloader(
         batch_size=batch_size,
         num_workers=num_workers,
         persistent_workers=persistent_workers,
+        node_representation = node_representation,
     )
 
     training_dataloader = make_dataloader(
-- 
GitLab


From b7420b8163c8eb4f30a4dd0b3784582fe4dcb974 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Fri, 8 Apr 2022 10:35:42 +0200
Subject: [PATCH 2/6] dimensional fixes in detector module

---
 src/graphnet/data/sqlite_dataset.py     | 9 +++------
 src/graphnet/models/detector/icecube.py | 4 +---
 src/graphnet/models/gnn/dynedge.py      | 4 ----
 3 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/src/graphnet/data/sqlite_dataset.py b/src/graphnet/data/sqlite_dataset.py
index 915f3a05a..ca8116327 100644
--- a/src/graphnet/data/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite_dataset.py
@@ -135,15 +135,12 @@ class SQLiteDataset(torch.utils.data.Dataset):
             pulse_statistics[count,0] = torch.min(time)
             pulse_statistics[count,1] = torch.mean(time)
             pulse_statistics[count,2] = torch.max(time)
-            pulse_statistics[count,3] = torch.std(time)
+            pulse_statistics[count,3] = torch.std(time, unbiased=False)
             pulse_statistics[count,4] = torch.min(charge)
             pulse_statistics[count,5] = torch.mean(charge)
             pulse_statistics[count,6] = torch.max(charge)
-            pulse_statistics[count,7] = torch.std(charge)
-            count +=1
-        #print(unique_doms.shape)
-        #print(n_pulses_pr_dom.shape)
-        #print(pulse_statistics.shape)      
+            pulse_statistics[count,7] = torch.std(charge, unbiased=False)
+            count +=1   
         data.x = torch.cat((unique_doms, n_pulses_pr_dom.unsqueeze(1), pulse_statistics), dim = 1)
         return data
 
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 23e1ca10d..22a4b0387 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -52,7 +52,7 @@ class IceCube86_v2(Detector):
 
     @property
     def nb_outputs(self):
-        return self.nb_inputs + 7
+        return self.nb_inputs
     def _forward(self, data: Data) -> Data:
         """Ingests data, builds graph (connectivity/adjacency), and preprocesses features.
 
@@ -108,8 +108,6 @@ class IceCube86_v2(Detector):
         data.x[:,11] /= 1.  # charge
         data.x[:,12] /= 1.  # charge
         data.x[:,13] /= 1.  # charge
-        print('detector out')
-        print(data.x.shape)
         return data
 
 class IceCubeDeepCore(IceCube86):
diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py
index db5729fdb..c138e34e3 100644
--- a/src/graphnet/models/gnn/dynedge.py
+++ b/src/graphnet/models/gnn/dynedge.py
@@ -31,7 +31,6 @@ class DynEdge(GNN):
         # Architecture configuration
         c = layer_size_scale
         l1, l2, l3, l4, l5,l6 = nb_inputs, c*16*2, c*32*2, c*42*2, c*32*2, c*16*2
-
         # Base class constructor
         super().__init__(nb_inputs, l6)
 
@@ -106,9 +105,6 @@ class DynEdge(GNN):
         # Convenience variables
         x, edge_index, batch = data.x, data.edge_index, data.batch
 
-        print('modelout')
-        print(x.shape)
-
         # Calculate homophily (scalar variables)
         h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
 
-- 
GitLab


From 9906b6498da22300167308638fecba023fce2822 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rasmus=20F=2E=20=C3=98rs=C3=B8e?=
 <48880272+RasmusOrsoe@users.noreply.github.com>
Date: Fri, 22 Apr 2022 16:30:50 +0200
Subject: [PATCH 3/6] Update src/graphnet/models/detector/icecube.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
---
 src/graphnet/models/detector/icecube.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 22a4b0387..25a080386 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -66,16 +66,6 @@ class IceCube86_v2(Detector):
         # Check(s)
         self._validate_features(data)
 
-        #pulse_statistics[count,0] = torch.min(time)
-        #pulse_statistics[count,1] = torch.mean(time)
-        #pulse_statistics[count,2] = torch.max(time)
-        #pulse_statistics[count,3] = torch.std(time)
-        #pulse_statistics[count,4] = torch.min(charge)
-        #pulse_statistics[count,5] = torch.mean(charge)
-        #pulse_statistics[count,6] = torch.max(charge)
-        #pulse_statistics[count,7] = torch.std(charge)
-        #unique_doms, n_pulses_pr_dom.unsqueeze(1), pulse_statistics
-
         # Preprocessing
         data.x[:,0] /= 100.  # dom_x
         data.x[:,1] /= 100.  # dom_y
-- 
GitLab


From 3422f09862b4d81998a9f94880dfbff20acc9c2b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rasmus=20F=2E=20=C3=98rs=C3=B8e?=
 <48880272+RasmusOrsoe@users.noreply.github.com>
Date: Fri, 22 Apr 2022 16:31:57 +0200
Subject: [PATCH 4/6] Update src/graphnet/models/detector/icecube.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
---
 src/graphnet/models/detector/icecube.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 25a080386..29629c3f1 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -53,6 +53,7 @@ class IceCube86_v2(Detector):
     @property
     def nb_outputs(self):
         return self.nb_inputs
+        
     def _forward(self, data: Data) -> Data:
         """Ingests data, builds graph (connectivity/adjacency), and preprocesses features.
 
-- 
GitLab


From 5644ee4e204e93d8f5b85eb377b5079f7f6376ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rasmus=20F=2E=20=C3=98rs=C3=B8e?=
 <48880272+RasmusOrsoe@users.noreply.github.com>
Date: Fri, 22 Apr 2022 16:37:39 +0200
Subject: [PATCH 5/6] Update src/graphnet/models/detector/icecube.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
---
 src/graphnet/models/detector/icecube.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 29629c3f1..22fa5fcf8 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -40,7 +40,7 @@ class IceCube86(Detector):
 
         return data
 
-class IceCube86_v2(Detector):
+class IceCube86_DOM(Detector):
     """`Detector` class for IceCube-86 with nodes as doms."""
 
     # Implementing abstract class attribute
-- 
GitLab


From 7a0ce52f562db31d813904eae0f56474c6f5eee0 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Fri, 22 Apr 2022 17:01:47 +0200
Subject: [PATCH 6/6] Apply Andreas' suggestions

---
 src/graphnet/data/sqlite_dataset.py | 39 +++++++++++++++--------------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/src/graphnet/data/sqlite_dataset.py b/src/graphnet/data/sqlite_dataset.py
index ca8116327..d1c23ddd2 100644
--- a/src/graphnet/data/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite_dataset.py
@@ -38,6 +38,8 @@ class SQLiteDataset(torch.utils.data.Dataset):
         assert isinstance(features, (list, tuple))
         assert isinstance(truth, (list, tuple))
 
+        assert node_representation.lower() in ['pulse', 'dom']
+
         self._database = database
         self._pulsemaps = pulsemaps
         self._features = [index_column] + features
@@ -124,14 +126,20 @@ class SQLiteDataset(torch.utils.data.Dataset):
         return torch.unique(tensor, return_counts = True, return_inverse=True, dim=0)
 
     def _make_dom_wise_representation(self,data):
-        unique_doms, inverse_idx, n_pulses_pr_dom = self._get_unique_positions(data.x[:,[0,1,2,5,6]])
+        unique_doms, inverse_idx, n_pulses_pr_dom = self._get_unique_positions(data.x[:,[self._features.index('dom_x')   -1,
+                                                                                        self._features.index('dom_y')    -1,
+                                                                                        self._features.index('dom_z')    -1,
+                                                                                        self._features.index('rde')      -1,
+                                                                                        self._features.index('pmt_area') -1]])
         unique_inverse_indices = torch.unique(inverse_idx)
         count = 0
         pulse_statistics = torch.zeros(size = (len(unique_doms), 8))
         #'dom_x','dom_y','dom_z','dom_time','charge','rde','pmt_area'
+        time_idx = self._features.index('dom_time')-1
+        charge_idx = self._features.index('charge')-1
         for unique_inverse_idx in unique_inverse_indices:
-            time   = data.x[inverse_idx == unique_inverse_idx,3]
-            charge = data.x[inverse_idx == unique_inverse_idx,4]
+            time   = data.x[inverse_idx == unique_inverse_idx,time_idx]
+            charge = data.x[inverse_idx == unique_inverse_idx,charge_idx]
             pulse_statistics[count,0] = torch.min(time)
             pulse_statistics[count,1] = torch.mean(time)
             pulse_statistics[count,2] = torch.max(time)
@@ -187,23 +195,16 @@ class SQLiteDataset(torch.utils.data.Dataset):
             data = np.array([]).reshape((0, len(self._features) - 1))
 
         # Construct graph data object
-        if self._node_representation.lower() == 'pulse':
-            x = torch.tensor(data, dtype=self._dtype)
-            n_pulses = torch.tensor(len(x), dtype=torch.int32)
-            graph = Data(
-                x=x,
-                edge_index= None
-            )
-        elif self._node_representation.lower() == 'dom':
-            x = torch.tensor(data, dtype=self._dtype)
-            n_pulses = torch.tensor(len(x), dtype=torch.int32)
-            graph = Data(
-                x=x,
-                edge_index= None
-            )
+       
+        x = torch.tensor(data, dtype=self._dtype)
+        n_pulses = torch.tensor(len(x), dtype=torch.int32)
+        graph = Data(
+            x=x,
+            edge_index= None
+        )
+        if self._node_representation.lower() == 'dom':
             graph = self._make_dom_wise_representation(graph)
-        else:
-            print('WARNING: node representation %s not recognized!'%self._node_representation)
+        
             
         
         graph.n_pulses = n_pulses
-- 
GitLab