diff --git a/src/graphnet/data/sqlite_dataset.py b/src/graphnet/data/sqlite_dataset.py
index 3d55ceb6bfedf7ecc09ef77ca4e882bd9474a9a0..d1c23ddd2eb602fcb9d59324e8b0153468eff93c 100644
--- a/src/graphnet/data/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite_dataset.py
@@ -19,6 +19,7 @@ class SQLiteDataset(torch.utils.data.Dataset):
         truth_table: str = 'truth',
         selection: Optional[List[int]] = None,
         dtype: torch.dtype = torch.float32,
+        node_representation = 'pulse',
     ):
 
         # Check(s)
@@ -37,6 +38,8 @@ class SQLiteDataset(torch.utils.data.Dataset):
         assert isinstance(features, (list, tuple))
         assert isinstance(truth, (list, tuple))
 
+        assert node_representation.lower() in ['pulse', 'dom']
+
         self._database = database
         self._pulsemaps = pulsemaps
         self._features = [index_column] + features
@@ -44,6 +47,7 @@ class SQLiteDataset(torch.utils.data.Dataset):
         self._index_column = index_column
         self._truth_table = truth_table
         self._dtype = dtype
+        self._node_representation =  node_representation
 
         self._features_string = ', '.join(self._features)
         self._truth_string = ', '.join(self._truth)
@@ -118,6 +122,36 @@ class SQLiteDataset(torch.utils.data.Dataset):
         except:
             return -1
 
+    def _get_unique_positions(self, tensor):
+        return torch.unique(tensor, return_counts = True, return_inverse=True, dim=0)
+
+    def _make_dom_wise_representation(self,data):
+        unique_doms, inverse_idx, n_pulses_pr_dom = self._get_unique_positions(data.x[:,[self._features.index('dom_x')   -1,
+                                                                                        self._features.index('dom_y')    -1,
+                                                                                        self._features.index('dom_z')    -1,
+                                                                                        self._features.index('rde')      -1,
+                                                                                        self._features.index('pmt_area') -1]])
+        unique_inverse_indices = torch.unique(inverse_idx)
+        count = 0
+        pulse_statistics = torch.zeros(size = (len(unique_doms), 8))
+        #'dom_x','dom_y','dom_z','dom_time','charge','rde','pmt_area'
+        time_idx = self._features.index('dom_time')-1
+        charge_idx = self._features.index('charge')-1
+        for unique_inverse_idx in unique_inverse_indices:
+            time   = data.x[inverse_idx == unique_inverse_idx,time_idx]
+            charge = data.x[inverse_idx == unique_inverse_idx,charge_idx]
+            pulse_statistics[count,0] = torch.min(time)
+            pulse_statistics[count,1] = torch.mean(time)
+            pulse_statistics[count,2] = torch.max(time)
+            pulse_statistics[count,3] = torch.std(time, unbiased=False)
+            pulse_statistics[count,4] = torch.min(charge)
+            pulse_statistics[count,5] = torch.mean(charge)
+            pulse_statistics[count,6] = torch.max(charge)
+            pulse_statistics[count,7] = torch.std(charge, unbiased=False)
+            count +=1   
+        data.x = torch.cat((unique_doms, n_pulses_pr_dom.unsqueeze(1), pulse_statistics), dim = 1)
+        return data
+
     def _create_graph(self, features, truth):
         """Create Pytorch Data (i.e.graph) object.
 
@@ -161,12 +195,18 @@ class SQLiteDataset(torch.utils.data.Dataset):
             data = np.array([]).reshape((0, len(self._features) - 1))
 
         # Construct graph data object
+       
         x = torch.tensor(data, dtype=self._dtype)
         n_pulses = torch.tensor(len(x), dtype=torch.int32)
         graph = Data(
             x=x,
             edge_index= None
         )
+        if self._node_representation.lower() == 'dom':
+            graph = self._make_dom_wise_representation(graph)
+        
+            
+        
         graph.n_pulses = n_pulses
         graph.features = self._features[1:]
 
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 4297be6e9ea671a123810cad1577476fb18a42d0..22fa5fcf8b2f6dd9fc77979a4a05a8759536f5fc 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -40,6 +40,66 @@ class IceCube86(Detector):
 
         return data
 
+class IceCube86_DOM(Detector):
+    """`Detector` class for IceCube-86 with nodes as doms."""
+
+    # Implementing abstract class attribute
+    features = FEATURES.ICECUBE86
+
+    @property
+    def nb_inputs(self) -> int:
+        return len(self.features) + 7
+
+    @property
+    def nb_outputs(self):
+        return self.nb_inputs
+        
+    def _forward(self, data: Data) -> Data:
+        """Ingests data, builds graph (connectivity/adjacency), and preprocesses features.
+
+        Args:
+            data (Data): Input graph data.
+
+        Returns:
+            Data: Connected and preprocessed graph data.
+        """
+
+        # Check(s)
+        self._validate_features(data)
+
+        # Preprocessing
+        data.x[:,0] /= 100.  # dom_x
+        data.x[:,1] /= 100.  # dom_y
+        data.x[:,2] += 350.  # dom_z
+        data.x[:,2] /= 100.
+        data.x[:,3] -= 1.25  # rde
+        data.x[:,3] /= 0.25
+        data.x[:,4] /= 0.05  # pmt_area
+
+        data.x[:,5] /= 5  # n_pulses
+        data.x[:,5] -= 2
+
+        data.x[:,6] /= 1.05e+04  # dom_time
+        data.x[:,6] -= 1.
+        data.x[:,6] *= 20.
+
+        data.x[:,7] /= 1.05e+04  # dom_time
+        data.x[:,7] -= 1.
+        data.x[:,7] *= 20
+
+        data.x[:,8] /= 1.05e+04  # dom_time
+        data.x[:,8] -= 1.
+        data.x[:,8] *= 20
+
+        data.x[:,9] /= 1.05e+04  # dom_time
+        data.x[:,9] -= 1.
+        data.x[:,9] *= 20
+
+        data.x[:,10] /= 1.  # charge
+        data.x[:,11] /= 1.  # charge
+        data.x[:,12] /= 1.  # charge
+        data.x[:,13] /= 1.  # charge
+        return data
 
 class IceCubeDeepCore(IceCube86):
     """`Detector` class for IceCube-DeepCore."""
diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py
index 388e482781d3e9535185789121c10c66a53e519d..c138e34e3f75921de57c62a9f0aa62474fef6701 100644
--- a/src/graphnet/models/gnn/dynedge.py
+++ b/src/graphnet/models/gnn/dynedge.py
@@ -31,7 +31,6 @@ class DynEdge(GNN):
         # Architecture configuration
         c = layer_size_scale
         l1, l2, l3, l4, l5,l6 = nb_inputs, c*16*2, c*32*2, c*42*2, c*32*2, c*16*2
-
         # Base class constructor
         super().__init__(nb_inputs, l6)
 
diff --git a/src/graphnet/models/graph_builders.py b/src/graphnet/models/graph_builders.py
index 4c0f59567c65c0f9b17bbc7a7c18552f48cdcd2e..604eb6f9f911874d0c0d4530e44728990f0bbb7d 100644
--- a/src/graphnet/models/graph_builders.py
+++ b/src/graphnet/models/graph_builders.py
@@ -3,8 +3,8 @@ from typing import List
 
 import torch
 from torch_geometric.nn import knn_graph,radius_graph
-from torch_geometric.data import Data
-
+from torch_geometric.data import Data, Batch
+from graphnet.components.pool import group_identical
 from graphnet.models.utils import calculate_distance_matrix
 
 
@@ -45,7 +45,6 @@ class KNNGraphBuilder(GraphBuilder):  # pylint: disable=too-few-public-methods
 
         return data
 
-
 class RadialGraphBuilder(GraphBuilder):  
     """Builds graph adjacency according to a sphere of chosen radius centred at each DOM hit"""
     def __init__ (
diff --git a/src/graphnet/models/training/utils.py b/src/graphnet/models/training/utils.py
index 6867cc6bbd23f0787ccf099dbc45a5da1e6414f9..bc2f1bef2828a17cc3bbf631666e3cbdc86f8b58 100644
--- a/src/graphnet/models/training/utils.py
+++ b/src/graphnet/models/training/utils.py
@@ -26,6 +26,7 @@ def make_dataloader(
     selection: List[int] = None,
     num_workers: int = 10,
     persistent_workers: bool = True,
+    node_representation: str = 'pulse',
 ) -> DataLoader:
 
     # Check(s)
@@ -38,6 +39,7 @@ def make_dataloader(
         features,
         truth,
         selection=selection,
+        node_representation = node_representation
     )
 
     def collate_fn(graphs):
@@ -70,6 +72,7 @@ def make_train_validation_dataloader(
     test_size: float = 0.33,
     num_workers: int = 10,
     persistent_workers: bool = True,
+    node_representation: str = 'pulse'
 ) -> Tuple[DataLoader]:
 
     # Reproducibility
@@ -98,6 +101,7 @@ def make_train_validation_dataloader(
         batch_size=batch_size,
         num_workers=num_workers,
         persistent_workers=persistent_workers,
+        node_representation = node_representation,
     )
 
     training_dataloader = make_dataloader(