diff --git a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml
index 7c4a6c0177bbc4940b9dc38d377ff1470d7fa891..345087431df7708338b3d768ddc5851d9bbf4f25 100644
--- a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml
+++ b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml
@@ -1,4 +1,17 @@
 path: /groups/icecube/asogaard/data/example/dev_lvl7_robustness_muon_neutrino_0000.db
+graph_definition:
+  arguments:
+    columns: [0, 1, 2]
+    detector:
+      arguments: {}
+      class_name: IceCube86
+    dtype: null
+    nb_nearest_neighbours: 8
+    node_definition:
+      arguments: {}
+      class_name: NodesAsPulses
+    node_feature_names: [dom_x, dom_y, dom_z, dom_time, charge, rde, pmt_area]
+  class_name: KNNGraph
 pulsemaps:
   - SRTTWOfflinePulsesDC
 features:
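
Note on the new `graph_definition` blocks: each dataset config now carries the full recipe for building event graphs. The block above corresponds, roughly, to constructing the following objects in Python (a sketch based on the serialized arguments; the config loader does this automatically):

    from graphnet.models.detector import IceCube86
    from graphnet.models.graphs import KNNGraph
    from graphnet.models.graphs.nodes import NodesAsPulses

    # Equivalent of the YAML `graph_definition` block above.
    graph_definition = KNNGraph(
        detector=IceCube86(),
        node_definition=NodesAsPulses(),
        nb_nearest_neighbours=8,
        columns=[0, 1, 2],
        node_feature_names=[
            "dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "pmt_area"
        ],
    )
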
diff --git a/configs/datasets/test_data_sqlite.yml b/configs/datasets/test_data_sqlite.yml
index 9ea481d74fc00f0df8ea8ab296925e6d835c41c2..349a8593ba520e83900b772c8dec8796cc5a47ee 100644
--- a/configs/datasets/test_data_sqlite.yml
+++ b/configs/datasets/test_data_sqlite.yml
@@ -1,26 +1,34 @@
-path: $GRAPHNET/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db
-pulsemaps:
-  - SRTInIcePulses
-features:
-  - dom_x
-  - dom_y
-  - dom_z
-  - dom_time
-  - charge
-  - rde
-  - pmt_area
-truth:
-  - energy
-  - position_x
-  - position_y
-  - position_z
-  - azimuth
-  - zenith
-  - pid
-  - elasticity
-  - sim_type
-  - interaction_type
+features: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+graph_definition:
+  arguments:
+    columns: [0, 1, 2]
+    detector:
+      arguments: {}
+      class_name: Prometheus
+    dtype: null
+    nb_nearest_neighbours: 8
+    node_definition:
+      arguments: {}
+      class_name: NodesAsPulses
+    node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+  class_name: KNNGraph
 index_column: event_no
-truth_table: truth
-seed: 21
-selection: null
\ No newline at end of file
+loss_weight_column: null
+loss_weight_default_value: null
+loss_weight_table: null
+node_truth: null
+node_truth_table: null
+path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db
+pulsemaps: total
+seed: null
+selection: null
+string_selection: null
+truth: [injection_energy, injection_type, injection_interaction_type, injection_zenith,
+  injection_azimuth, injection_bjorkenx, injection_bjorkeny, injection_position_x,
+  injection_position_y, injection_position_z, injection_column_depth, primary_lepton_1_type,
+  primary_hadron_1_type, primary_lepton_1_position_x, primary_lepton_1_position_y,
+  primary_lepton_1_position_z, primary_hadron_1_position_x, primary_hadron_1_position_y,
+  primary_hadron_1_position_z, primary_lepton_1_direction_theta, primary_lepton_1_direction_phi,
+  primary_hadron_1_direction_theta, primary_hadron_1_direction_phi, primary_lepton_1_energy,
+  primary_hadron_1_energy, total_energy]
+truth_table: mc_truth
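
For reference, a sketch of consuming the updated config downstream (assuming the standard `DatasetConfig`/`from_config` round trip, with the path resolved relative to the repository root):

    from graphnet.data.dataset import SQLiteDataset
    from graphnet.utilities.config import DatasetConfig

    # Load the YAML above and build the corresponding dataset.
    config = DatasetConfig.load("configs/datasets/test_data_sqlite.yml")
    dataset = SQLiteDataset.from_config(config)
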
diff --git a/configs/datasets/training_classification_example_data_sqlite.yml b/configs/datasets/training_classification_example_data_sqlite.yml
index b56266de2f2b709b8734044d2be1f48dca5a8446..5d12c3bbfbd5ea044e7d86f2b805a64cfc201f49 100644
--- a/configs/datasets/training_classification_example_data_sqlite.yml
+++ b/configs/datasets/training_classification_example_data_sqlite.yml
@@ -1,4 +1,17 @@
 path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db
+graph_definition:
+  arguments:
+    columns: [0, 1, 2]
+    detector:
+      arguments: {}
+      class_name: Prometheus
+    dtype: null
+    nb_nearest_neighbours: 8
+    node_definition:
+      arguments: {}
+      class_name: NodesAsPulses
+    node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+  class_name: KNNGraph
 pulsemaps:
   - total
 features:
diff --git a/configs/datasets/training_example_data_parquet.yml b/configs/datasets/training_example_data_parquet.yml
index 8df8870c75412d69118394058d941e65ad1817db..11c8d7fb08b0ea5eca54ce1bfde0a0d0698beeac 100644
--- a/configs/datasets/training_example_data_parquet.yml
+++ b/configs/datasets/training_example_data_parquet.yml
@@ -1,4 +1,17 @@
 path: $GRAPHNET/data/examples/parquet/prometheus/prometheus-events.parquet
+graph_definition:
+  arguments:
+    columns: [0, 1, 2]
+    detector:
+      arguments: {}
+      class_name: Prometheus
+    dtype: null
+    nb_nearest_neighbours: 8
+    node_definition:
+      arguments: {}
+      class_name: NodesAsPulses
+    node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+  class_name: KNNGraph
 pulsemaps:
   - total
 features:
diff --git a/configs/datasets/training_example_data_sqlite.yml b/configs/datasets/training_example_data_sqlite.yml
index b61074d99738a296d04ff3cb229aedbe884ae223..0de880a77176cfcc049937b03fbf42fb12afb773 100644
--- a/configs/datasets/training_example_data_sqlite.yml
+++ b/configs/datasets/training_example_data_sqlite.yml
@@ -1,4 +1,17 @@
 path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db
+graph_definition:
+  arguments:
+    columns: [0, 1, 2]
+    detector:
+      arguments: {}
+      class_name: Prometheus
+    dtype: null
+    nb_nearest_neighbours: 8
+    node_definition:
+      arguments: {}
+      class_name: NodesAsPulses
+    node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+  class_name: KNNGraph
 pulsemaps:
   - total
 features:
diff --git a/configs/models/dynedge_PID_classification_example.yml b/configs/models/dynedge_PID_classification_example.yml
index a43ea3856fdd930b47fdd42850b83e507b1fdf9c..4b2fd0246bcd2f3952cb9956987f42db78340a1c 100644
--- a/configs/models/dynedge_PID_classification_example.yml
+++ b/configs/models/dynedge_PID_classification_example.yml
@@ -1,14 +1,4 @@
 arguments:
-  coarsening: null
-  detector:
-    ModelConfig:
-      arguments:
-        graph_builder:
-          ModelConfig:
-            arguments: {columns: null, nb_nearest_neighbours: 8}
-            class_name: KNNGraphBuilder
-        scalers: null
-      class_name: Prometheus
   gnn:
     ModelConfig:
       arguments:
@@ -21,11 +11,29 @@ arguments:
         post_processing_layer_sizes: null
         readout_layer_sizes: null
       class_name: DynEdge
+  graph_definition:
+    ModelConfig:
+      arguments:
+        columns: [0, 1, 2]
+        detector:
+          ModelConfig:
+            arguments: {}
+            class_name: Prometheus
+        dtype: null
+        nb_nearest_neighbours: 8
+        node_definition:
+          ModelConfig:
+            arguments: {}
+            class_name: NodesAsPulses
+        node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+      class_name: KNNGraph
   optimizer_class: '!class torch.optim.adam Adam'
-  optimizer_kwargs: {eps: 1e-03, lr: 1e-05}
-  scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau'
-  scheduler_config: {frequency: 1, monitor: val_loss}
-  scheduler_kwargs: {patience: 1}
+  optimizer_kwargs: {eps: 0.001, lr: 0.001}
+  scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR'
+  scheduler_config: {interval: step}
+  scheduler_kwargs:
+    factors: [0.01, 1, 0.01]
+    milestones: [0, 20.0, 80]
   tasks:
   - ModelConfig:
       arguments:
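
The scheduler swap above replaces epoch-wise `ReduceLROnPlateau` with graphnet's `PiecewiseLinearLR`, which interpolates the learning-rate multiplier linearly between `milestones` (counted in steps, given `interval: step`) using the matching `factors`. A sketch of the same settings passed programmatically (assuming `graph_definition`, `gnn` and `task` are assembled as in the examples further down):

    from torch.optim import Adam
    from graphnet.models import StandardModel
    from graphnet.training.callbacks import PiecewiseLinearLR

    model = StandardModel(
        graph_definition=graph_definition,
        gnn=gnn,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        # LR ramps 0.01 -> 1 over steps 0-20, then decays back to 0.01 by step 80.
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={"milestones": [0, 20, 80], "factors": [0.01, 1, 0.01]},
        scheduler_config={"interval": "step"},
    )
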
diff --git a/configs/models/dynedge_energy_example.yml b/configs/models/dynedge_energy_example.yml
deleted file mode 100644
index 02d647f0c70eaa9d4bdce7184129f56f31809e77..0000000000000000000000000000000000000000
--- a/configs/models/dynedge_energy_example.yml
+++ /dev/null
@@ -1,44 +0,0 @@
-arguments:
-  coarsening: null
-  detector:
-    ModelConfig:
-      arguments:
-        graph_builder:
-          ModelConfig:
-            arguments: {columns: null, nb_nearest_neighbours: 8}
-            class_name: KNNGraphBuilder
-        scalers: null
-      class_name: IceCubeDeepCore
-  gnn:
-    ModelConfig:
-      arguments:
-        add_global_variables_after_pooling: false
-        dynedge_layer_sizes: null
-        features_subset: null
-        global_pooling_schemes: [min, max, mean, sum]
-        nb_inputs: 7
-        nb_neighbours: 8
-        post_processing_layer_sizes: null
-        readout_layer_sizes: null
-      class_name: DynEdge
-  optimizer_class: '!class torch.optim.adam Adam'
-  optimizer_kwargs: {eps: 0.001, lr: 1e-05}
-  scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau'
-  scheduler_config: {frequency: 1, monitor: val_loss}
-  scheduler_kwargs: {patience: 5}
-  tasks:
-  - ModelConfig:
-      arguments:
-        hidden_size: 128
-        loss_function:
-          ModelConfig:
-            arguments: {}
-            class_name: LogCoshLoss
-        loss_weight: null
-        target_labels: energy
-        transform_inference: null
-        transform_prediction_and_target: '!lambda x: torch.log10(x)'
-        transform_support: null
-        transform_target: null
-      class_name: EnergyReconstruction
-class_name: StandardModel
diff --git a/configs/models/dynedge_position_custom_scaling_example.yml b/configs/models/dynedge_position_custom_scaling_example.yml
index e986c1529c0b3b46e68c07c11dda017fc8c68cc5..195695a8d6d1b1a2808f1381141885d72119324f 100644
--- a/configs/models/dynedge_position_custom_scaling_example.yml
+++ b/configs/models/dynedge_position_custom_scaling_example.yml
@@ -1,14 +1,24 @@
 arguments:
-  coarsening: null
-  detector:
+  graph_definition:
     ModelConfig:
       arguments:
-        graph_builder:
+        detector:
           ModelConfig:
-            arguments: {columns: null, nb_nearest_neighbours: 8}
-            class_name: KNNGraphBuilder
-        scalers: null
-      class_name: IceCubeDeepCore
+            arguments: {}
+            class_name: Prometheus
+        dtype: null
+        edge_definition:
+          ModelConfig:
+            arguments:
+              columns: [0, 1, 2]
+              nb_nearest_neighbours: 8
+            class_name: KNNEdges
+        node_definition:
+          ModelConfig:
+            arguments: {}
+            class_name: NodesAsPulses
+        node_feature_names: null
+      class_name: KNNGraph
   gnn:
     ModelConfig:
       arguments:
diff --git a/configs/models/example_direction_reconstruction_model.yml b/configs/models/example_direction_reconstruction_model.yml
index c04974b437bd563c6d42af1ff88bdb4c37578956..cb1c4d841a28b7727be20ef2a07eb3edb61f14f0 100644
--- a/configs/models/example_direction_reconstruction_model.yml
+++ b/configs/models/example_direction_reconstruction_model.yml
@@ -1,14 +1,20 @@
 arguments:
-  coarsening: null
-  detector:
+  graph_definition:
     ModelConfig:
       arguments:
-        graph_builder:
+        columns: [0, 1, 2]
+        detector:
           ModelConfig:
-            arguments: {columns: null, nb_nearest_neighbours: 8}
-            class_name: KNNGraphBuilder
-        scalers: null
-      class_name: Prometheus
+            arguments: {}
+            class_name: Prometheus
+        dtype: null
+        nb_nearest_neighbours: 8
+        node_definition:
+          ModelConfig:
+            arguments: {}
+            class_name: NodesAsPulses
+        node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+      class_name: KNNGraph
   gnn:
     ModelConfig:
       arguments:
diff --git a/configs/models/example_energy_reconstruction_model.yml b/configs/models/example_energy_reconstruction_model.yml
index 7ef5a9265b7abf318aa8199dbefbaeeed383a522..827c84748b393a57a6dee27472b79eef5af2eaa9 100644
--- a/configs/models/example_energy_reconstruction_model.yml
+++ b/configs/models/example_energy_reconstruction_model.yml
@@ -1,14 +1,4 @@
 arguments:
-  coarsening: null
-  detector:
-    ModelConfig:
-      arguments:
-        graph_builder:
-          ModelConfig:
-            arguments: {columns: null, nb_nearest_neighbours: 8}
-            class_name: KNNGraphBuilder
-        scalers: null
-      class_name: Prometheus
   gnn:
     ModelConfig:
       arguments:
@@ -21,11 +11,29 @@ arguments:
         post_processing_layer_sizes: null
         readout_layer_sizes: null
       class_name: DynEdge
+  graph_definition:
+    ModelConfig:
+      arguments:
+        columns: [0, 1, 2]
+        detector:
+          ModelConfig:
+            arguments: {}
+            class_name: Prometheus
+        dtype: null
+        nb_nearest_neighbours: 8
+        node_definition:
+          ModelConfig:
+            arguments: {}
+            class_name: NodesAsPulses
+        node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+      class_name: KNNGraph
   optimizer_class: '!class torch.optim.adam Adam'
-  optimizer_kwargs: {eps: 0.001, lr: 1e-05}
-  scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau'
-  scheduler_config: {frequency: 1, monitor: val_loss}
-  scheduler_kwargs: {patience: 5}
+  optimizer_kwargs: {eps: 0.001, lr: 0.001}
+  scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR'
+  scheduler_config: {interval: step}
+  scheduler_kwargs:
+    factors: [0.01, 1, 0.01]
+    milestones: [0, 20.0, 80]
   tasks:
   - ModelConfig:
       arguments:
@@ -35,8 +43,9 @@ arguments:
             arguments: {}
             class_name: LogCoshLoss
         loss_weight: null
+        prediction_labels: null
         target_labels: total_energy
-        transform_inference: null
+        transform_inference: '!lambda x: torch.pow(10,x)'
         transform_prediction_and_target: '!lambda x: torch.log10(x)'
         transform_support: null
         transform_target: null
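
The added `transform_inference` makes the energy task's log10 training transform explicitly invertible, so inference-time predictions come back in linear energy units. A quick check of the round trip:

    import torch

    transform_prediction_and_target = lambda x: torch.log10(x)
    transform_inference = lambda x: torch.pow(10, x)

    x = torch.tensor([1.0, 10.0, 1000.0])
    assert torch.allclose(transform_inference(transform_prediction_and_target(x)), x)
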
diff --git a/configs/models/example_vertex_position_reconstruction_model.yml b/configs/models/example_vertex_position_reconstruction_model.yml
index 8b9c8709c5abd3221ca2077b525894de340c2448..0522a1f2ddc1b914da05b31cd83f4ffdbcb1d7dc 100644
--- a/configs/models/example_vertex_position_reconstruction_model.yml
+++ b/configs/models/example_vertex_position_reconstruction_model.yml
@@ -1,14 +1,4 @@
 arguments:
-  coarsening: null
-  detector:
-    ModelConfig:
-      arguments:
-        graph_builder:
-          ModelConfig:
-            arguments: {columns: null, nb_nearest_neighbours: 8}
-            class_name: KNNGraphBuilder
-        scalers: null
-      class_name: Prometheus
   gnn:
     ModelConfig:
       arguments:
@@ -21,6 +11,22 @@ arguments:
         post_processing_layer_sizes: null
         readout_layer_sizes: null
       class_name: DynEdge
+  graph_definition:
+    ModelConfig:
+      arguments:
+        columns: [0, 1, 2]
+        detector:
+          ModelConfig:
+            arguments: {}
+            class_name: Prometheus
+        dtype: null
+        nb_nearest_neighbours: 8
+        node_definition:
+          ModelConfig:
+            arguments: {}
+            class_name: NodesAsPulses
+        node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
+      class_name: KNNGraph
   optimizer_class: '!class torch.optim.adam Adam'
   optimizer_kwargs: {eps: 0.001, lr: 0.001}
   scheduler_class: '!class graphnet.training.callbacks PiecewiseLinearLR'
diff --git a/examples/02_data/01_read_dataset.py b/examples/02_data/01_read_dataset.py
index a9a1f4d01c4aac73081e33da36bf29f6585f1443..302529050856b0a14db32c569f216a9b2dd3b90a 100644
--- a/examples/02_data/01_read_dataset.py
+++ b/examples/02_data/01_read_dataset.py
@@ -13,8 +13,8 @@ from tqdm import tqdm
 from graphnet.constants import TEST_PARQUET_DATA, TEST_SQLITE_DATA
 from graphnet.data.constants import FEATURES, TRUTH
 from graphnet.data.dataset import Dataset
-from graphnet.data.sqlite.sqlite_dataset import SQLiteDataset
-from graphnet.data.parquet.parquet_dataset import ParquetDataset
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.logging import Logger
 
diff --git a/examples/04_training/02_train_model_without_configs.py b/examples/04_training/02_train_model_without_configs.py
index 27e112ce5a609942d184630cf8ad9a1f9d43e452..6d9c5746e87a7a4f9175821db8312e50d3aee5a2 100644
--- a/examples/04_training/02_train_model_without_configs.py
+++ b/examples/04_training/02_train_model_without_configs.py
@@ -13,7 +13,8 @@ from graphnet.data.constants import FEATURES, TRUTH
 from graphnet.models import StandardModel
 from graphnet.models.detector.prometheus import Prometheus
 from graphnet.models.gnn import DynEdge
-from graphnet.models.graph_builders import KNNGraphBuilder
+from graphnet.models.graphs import KNNGraph
+from graphnet.models.graphs.nodes import NodesAsPulses
 from graphnet.models.task.reconstruction import EnergyReconstruction
 from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR
 from graphnet.training.loss_functions import LogCoshLoss
@@ -77,36 +78,44 @@ def main(
         # Log configuration to W&B
         wandb_logger.experiment.config.update(config)
 
+    # Define graph representation
+    graph_definition = KNNGraph(
+        detector=Prometheus(),
+        node_definition=NodesAsPulses(),
+        nb_nearest_neighbours=8,
+        node_feature_names=features,
+    )
+
     (
         training_dataloader,
         validation_dataloader,
     ) = make_train_validation_dataloader(
-        config["path"],
-        None,
-        config["pulsemap"],
-        features,
-        truth,
+        db=config["path"],
+        graph_definition=graph_definition,
+        pulsemaps=config["pulsemap"],
+        features=features,
+        truth=truth,
         batch_size=config["batch_size"],
         num_workers=config["num_workers"],
         truth_table=truth_table,
+        selection=None,
     )
 
     # Building model
-    detector = Prometheus(
-        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
-    )
     gnn = DynEdge(
-        nb_inputs=detector.nb_outputs,
+        nb_inputs=graph_definition.nb_outputs,
         global_pooling_schemes=["min", "max", "mean", "sum"],
     )
     task = EnergyReconstruction(
         hidden_size=gnn.nb_outputs,
         target_labels=config["target"],
         loss_function=LogCoshLoss(),
-        transform_prediction_and_target=torch.log10,
+        transform_prediction_and_target=lambda x: torch.log10(x),
+        transform_inference=lambda x: torch.pow(10, x),
     )
     model = StandardModel(
-        detector=detector,
+        graph_definition=graph_definition,
         gnn=gnn,
         tasks=[task],
         optimizer_class=Adam,
diff --git a/examples/04_training/03_train_classification_model.py b/examples/04_training/03_train_classification_model.py
index b403e4850b96ac8dc07837e9bef15477d24fcd6f..0c537bac03781b304955f42ab95ea9557388060d 100644
--- a/examples/04_training/03_train_classification_model.py
+++ b/examples/04_training/03_train_classification_model.py
@@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities import rank_zero_only
-from graphnet.data.dataset import EnsembleDataset
+from graphnet.data.dataset.dataset import EnsembleDataset
 from graphnet.constants import (
     EXAMPLE_OUTPUT_DIR,
     DATASETS_CONFIG_DIR,
diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py
index 64cef139d821998062792d95f37e91540e6a9b7c..1eca4f6cd76e0b45ac5c1db5aaed7becd44c4b5e 100644
--- a/src/graphnet/data/__init__.py
+++ b/src/graphnet/data/__init__.py
@@ -3,14 +3,3 @@
 `graphnet.data` enables converting domain-specific data to industry-standard,
 intermediate file formats and reading this data.
 """
-
-# Configuration
-from graphnet.utilities.imports import has_torch_package
-
-if has_torch_package():
-    import torch.multiprocessing
-    from .dataset import EnsembleDataset
-
-    torch.multiprocessing.set_sharing_strategy("file_system")
-
-del has_torch_package
diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ccdd964221a9f509c51d6904e2d7231684dfb5c
--- /dev/null
+++ b/src/graphnet/data/dataset/__init__.py
@@ -0,0 +1,14 @@
+"""Dataset classes for training in GraphNeT."""
+# Configuration
+from graphnet.utilities.imports import has_torch_package
+
+if has_torch_package():
+    import torch.multiprocessing
+    from .dataset import EnsembleDataset, Dataset, ColumnMissingException
+    from .parquet.parquet_dataset import ParquetDataset
+    from .sqlite.sqlite_dataset import SQLiteDataset
+    from .sqlite.sqlite_dataset_perturbed import SQLiteDatasetPerturbed
+
+    torch.multiprocessing.set_sharing_strategy("file_system")
+
+del has_torch_package
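
With the move into the `dataset` sub-package, all dataset classes are re-exported from one place; a sketch of the migrated imports (the old `graphnet.data.sqlite.sqlite_dataset` and `graphnet.data.parquet.parquet_dataset` paths are gone):

    from graphnet.data.dataset import (
        Dataset,
        EnsembleDataset,
        ParquetDataset,
        SQLiteDataset,
    )
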
diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset/dataset.py
similarity index 87%
rename from src/graphnet/data/dataset.py
rename to src/graphnet/data/dataset/dataset.py
index 7300da815154c5550cd4c9add067bd4872e4da47..730f33469fc04a6cd8cdf3c4f89f7f87b6c6e1b8 100644
--- a/src/graphnet/data/dataset.py
+++ b/src/graphnet/data/dataset/dataset.py
@@ -12,6 +12,7 @@ from typing import (
     Tuple,
     Union,
     Iterable,
+    Type,
 )
 
 import numpy as np
@@ -29,12 +30,57 @@ from graphnet.utilities.config import (
     save_dataset_config,
 )
 from graphnet.utilities.logging import Logger
+from graphnet.models.graphs import GraphDefinition
+
+from graphnet.utilities.config.parsing import (
+    get_all_grapnet_classes,
+)
 
 
 class ColumnMissingException(Exception):
     """Exception to indicate a missing column in a dataset."""
 
 
+def load_module(class_name: str) -> Type:
+    """Load graphnet module from string name.
+
+    Args:
+        class_name: name of class
+
+    Returns:
+        graphnet module.
+    """
+    # Get a lookup for all classes in `graphnet`
+    import graphnet.data
+    import graphnet.models
+    import graphnet.training
+
+    namespace_classes = get_all_grapnet_classes(
+        graphnet.data, graphnet.models, graphnet.training
+    )
+    return namespace_classes[class_name]
+
+
+def parse_graph_definition(cfg: dict) -> GraphDefinition:
+    """Construct GraphDefinition from DatasetConfig."""
+    assert cfg["graph_definition"] is not None
+
+    args = cfg["graph_definition"]["arguments"]
+    classes = {}
+    for arg in args.keys():
+        if isinstance(args[arg], dict):
+            if "class_name" in args[arg].keys():
+                classes[arg] = load_module(args[arg]["class_name"])(
+                    **args[arg]["arguments"]
+                )
+    new_cfg = deepcopy(args)
+    new_cfg.update(classes)
+    graph_definition = load_module(cfg["graph_definition"]["class_name"])(
+        **new_cfg
+    )
+    return graph_definition
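
`parse_graph_definition` rebuilds the nested objects from the plain dict that `DatasetConfig.dict()` produces. A sketch of the expected shape, mirroring the YAML configs above:

    # Hypothetical config dict, as produced by `DatasetConfig.dict()`.
    cfg = {
        "graph_definition": {
            "class_name": "KNNGraph",
            "arguments": {
                "detector": {"class_name": "Prometheus", "arguments": {}},
                "node_definition": {"class_name": "NodesAsPulses", "arguments": {}},
                "node_feature_names": ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
                "nb_nearest_neighbours": 8,
                "columns": [0, 1, 2],
                "dtype": None,
            },
        }
    }
    graph_definition = parse_graph_definition(cfg)  # -> configured KNNGraph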
+
+
 class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
     """Base Dataset class for reading from any intermediate file format."""
 
@@ -55,9 +101,13 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
 
         assert isinstance(source, DatasetConfig), (
             f"Argument `source` of type ({type(source)}) is not a "
-            "`DatasetConfig"
+            "`DatasetConfig`"
         )
 
+        assert (
+            "graph_definition" in source.dict().keys()
+        ), "`DatasetConfig` incompatible with current GraphNeT version."
+
         # Parse set of `selection``.
         if isinstance(source.selection, dict):
             return cls._construct_datasets_from_dict(source)
@@ -68,7 +118,10 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         ):
             return cls._construct_dataset_from_list_of_strings(source)
 
-        return source._dataset_class(**source.dict())
+        cfg = source.dict()
+        if cfg["graph_definition"] is not None:
+            cfg["graph_definition"] = parse_graph_definition(cfg)
+        return source._dataset_class(**cfg)
 
     @classmethod
     def concatenate(
@@ -135,6 +188,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
     def __init__(
         self,
         path: Union[str, List[str]],
+        graph_definition: GraphDefinition,
         pulsemaps: Union[str, List[str]],
         features: List[str],
         truth: List[str],
@@ -195,6 +249,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
                 subset of events when resolving a string-based selection (e.g.,
                 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
                 events ~ event_no % 5 > 0"`).
+            graph_definition: The `GraphDefinition` that defines the graph
+                representation.
         """
         # Base class constructor
         super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -218,6 +273,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         self._index_column = index_column
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
+        self._graph_definition = graph_definition
 
         if node_truth is not None:
             assert isinstance(node_truth_table, str)
@@ -521,10 +577,6 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
     ) -> Data:
         """Create Pytorch Data (i.e. graph) object.
 
-        No preprocessing is performed at this stage, just as no node adjancency
-        is imposed. This means that the `edge_attr` and `edge_weight`
-        attributes are not set.
-
         Args:
             features: List of tuples, containing event features.
             truth: List of tuples, containing truth information.
@@ -552,71 +604,33 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
                 for index, key in enumerate(self._node_truth)
             }
 
+        # Create list of truth dicts with labels
+        truth_dicts = [labels_dict, truth_dict]
+        if node_truth is not None:
+            truth_dicts.append(node_truth_dict)
+
         # Catch cases with no reconstructed pulses
         if len(features):
-            data = np.asarray(features)[:, 1:]
+            node_features = np.asarray(features)[
+                :, 1:
+            ]  # first entry is index column
         else:
-            data = np.array([]).reshape((0, len(self._features) - 1))
+            node_features = np.array([]).reshape((0, len(self._features) - 1))
 
         # Construct graph data object
-        x = torch.tensor(data, dtype=self._dtype)  # pylint: disable=C0103
-        n_pulses = torch.tensor(len(x), dtype=torch.int32)
-        graph = Data(x=x, edge_index=None)
-        graph.n_pulses = n_pulses
-        graph.features = self._features[1:]
-
-        # Add loss weight to graph.
-        if loss_weight is not None and self._loss_weight_column is not None:
-            # No loss weight was retrieved, i.e., it is missing for the current
-            # event.
-            if loss_weight < 0:
-                if self._loss_weight_default_value is None:
-                    raise ValueError(
-                        "At least one event is missing an entry in "
-                        f"{self._loss_weight_column} "
-                        "but loss_weight_default_value is None."
-                    )
-                graph[self._loss_weight_column] = torch.tensor(
-                    self._loss_weight_default_value, dtype=self._dtype
-                ).reshape(-1, 1)
-            else:
-                graph[self._loss_weight_column] = torch.tensor(
-                    loss_weight, dtype=self._dtype
-                ).reshape(-1, 1)
-
-        # Write attributes, either target labels, truth info or original
-        # features.
-        add_these_to_graph = [labels_dict, truth_dict]
-        if node_truth is not None:
-            add_these_to_graph.append(node_truth_dict)
-        for write_dict in add_these_to_graph:
-            for key, value in write_dict.items():
-                try:
-                    graph[key] = torch.tensor(value)
-                except (TypeError, RuntimeError) as error:
-                    if isinstance(error, TypeError) or (value is None):
-                        # Cannot convert `value` to Tensor due to its data type,
-                        # e.g. `str`.
-                        self.warning(
-                            (
-                                f"Could not assign `{key}` with type "
-                                f"'{type(value).__name__ if value is not None else 'NoneType'}' as attribute to graph of "
-                                f"event with {self._index_column} == {labels_dict[self._index_column]}"
-                            )
-                        )
-                    else:
-                        raise error
-        # Additionally add original features as (static) attributes
-        for index, feature in enumerate(graph.features):
-            if feature not in ["x"]:
-                graph[feature] = graph.x[:, index].detach()
-
-        # Add custom labels to the graph
-        for key, fn in self._label_fns.items():
-            graph[key] = fn(graph)
-
-        # Add Dataset Path. Useful if multiple datasets are concatenated.
-        graph["dataset_path"] = self._path
+        assert self._graph_definition is not None
+        graph = self._graph_definition(
+            node_features=node_features,
+            node_feature_names=self._features[
+                1:
+            ],  # first entry is index column
+            truth_dicts=truth_dicts,
+            custom_label_functions=self._label_fns,
+            loss_weight_column=self._loss_weight_column,
+            loss_weight=loss_weight,
+            loss_weight_default_value=self._loss_weight_default_value,
+            data_path=self._path,
+        )
         return graph
 
     def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]:
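
`Dataset._create_graph` now delegates all graph construction to the injected `GraphDefinition`, which makes `graph_definition` a required constructor argument. A sketch of direct (config-free) construction, with table and feature names taken from the example configs above:

    from graphnet.data.dataset import SQLiteDataset
    from graphnet.models.detector.prometheus import Prometheus
    from graphnet.models.graphs import KNNGraph
    from graphnet.models.graphs.nodes import NodesAsPulses

    features = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]
    dataset = SQLiteDataset(
        path="data/examples/sqlite/prometheus/prometheus-events.db",  # adjust to your checkout
        graph_definition=KNNGraph(
            detector=Prometheus(),
            node_definition=NodesAsPulses(),
            nb_nearest_neighbours=8,
            node_feature_names=features,
        ),
        pulsemaps="total",
        features=features,
        truth=["total_energy"],
        truth_table="mc_truth",
    )
    graph = dataset[0]  # a `Data` object built by the graph definition
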
diff --git a/src/graphnet/data/dataset/parquet/__init__.py b/src/graphnet/data/dataset/parquet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..edfc62b4e83e0e79d008a116d8bdabb2ad9993e3
--- /dev/null
+++ b/src/graphnet/data/dataset/parquet/__init__.py
@@ -0,0 +1,11 @@
+"""Datasets using parquet backend."""
+# Configuration
+from graphnet.utilities.imports import has_torch_package
+
+if has_torch_package():
+    import torch.multiprocessing
+    from .parquet_dataset import ParquetDataset
+
+    torch.multiprocessing.set_sharing_strategy("file_system")
+
+del has_torch_package
diff --git a/src/graphnet/data/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py
similarity index 98%
rename from src/graphnet/data/parquet/parquet_dataset.py
rename to src/graphnet/data/dataset/parquet/parquet_dataset.py
index 7839bd983e08b0b00dca362eac7d95b73252d323..bb63e18001520003b8e85590ab70df02692cddab 100644
--- a/src/graphnet/data/parquet/parquet_dataset.py
+++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast
 import numpy as np
 import awkward as ak
 
-from graphnet.data.dataset import Dataset, ColumnMissingException
+from graphnet.data.dataset.dataset import Dataset, ColumnMissingException
 
 
 class ParquetDataset(Dataset):
diff --git a/src/graphnet/data/dataset/sqlite/__init__.py b/src/graphnet/data/dataset/sqlite/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74b164e6ebba031edc39529d18cf1ec57cec1d3d
--- /dev/null
+++ b/src/graphnet/data/dataset/sqlite/__init__.py
@@ -0,0 +1,6 @@
+"""Datasets using SQLite backend."""
+from graphnet.utilities.imports import has_torch_package
+
+if has_torch_package():
+    from .sqlite_dataset import SQLiteDataset
+    from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed
diff --git a/src/graphnet/data/sqlite/sqlite_dataset.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset.py
similarity index 98%
rename from src/graphnet/data/sqlite/sqlite_dataset.py
rename to src/graphnet/data/dataset/sqlite/sqlite_dataset.py
index e61623c46d60c16abe474af3543fc379d59b75ed..a0b06ff6676201a4994a990f24fde2187e4bb464 100644
--- a/src/graphnet/data/sqlite/sqlite_dataset.py
+++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset.py
@@ -4,7 +4,7 @@ from typing import Any, List, Optional, Tuple, Union
 import pandas as pd
 import sqlite3
 
-from graphnet.data.dataset import Dataset, ColumnMissingException
+from graphnet.data.dataset.dataset import Dataset, ColumnMissingException
 
 
 class SQLiteDataset(Dataset):
diff --git a/src/graphnet/data/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py
similarity index 98%
rename from src/graphnet/data/sqlite/sqlite_dataset_perturbed.py
rename to src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py
index 6fb977ccd7a35b0aff42a782422893cb4afaa75c..755d96b82df1b111e319bf12a86a44a4668a7338 100644
--- a/src/graphnet/data/sqlite/sqlite_dataset_perturbed.py
+++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torch_geometric.data import Data
 
-from graphnet.data.sqlite.sqlite_dataset import SQLiteDataset
+from .sqlite_dataset import SQLiteDataset
 
 
 class SQLiteDatasetPerturbed(SQLiteDataset):
diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py
index fc0b2f7a02070608ac6c23a0c56863d9ee21917c..616d89c166fa46f0d0d2c59e9e554561434ed5a7 100644
--- a/src/graphnet/data/parquet/__init__.py
+++ b/src/graphnet/data/parquet/__init__.py
@@ -1,10 +1,2 @@
 """Parquet-specific implementation of data classes."""
-
-from graphnet.utilities.imports import has_torch_package
-
 from .parquet_dataconverter import ParquetDataConverter
-
-if has_torch_package():
-    from .parquet_dataset import ParquetDataset
-
-del has_torch_package
diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py
index f632f2be85992b8c28d29c8fc15b7ccbecd45e2c..e4ac554a789cee81ebca6b49a28be05d4777d1a0 100644
--- a/src/graphnet/data/sqlite/__init__.py
+++ b/src/graphnet/data/sqlite/__init__.py
@@ -1,11 +1,4 @@
 """SQLite-specific implementation of data classes."""
-
-from graphnet.utilities.imports import has_torch_package
-
 from .sqlite_dataconverter import SQLiteDataConverter
 from .sqlite_utilities import create_table_and_save_to_sql
-
-if has_torch_package():
-    from .sqlite_dataset import SQLiteDataset
-    from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed
 from .sqlite_utilities import run_sql_code, save_to_sql
diff --git a/src/graphnet/models/detector/__init__.py b/src/graphnet/models/detector/__init__.py
index 17fba8fd2ae8f572504ae3d11d0b2309452e30d3..060b7ca038953cad94d0f9eebe5c5f5f66968b82 100644
--- a/src/graphnet/models/detector/__init__.py
+++ b/src/graphnet/models/detector/__init__.py
@@ -1,3 +1,4 @@
 """Detector-specific modules, for data ingestion and standardisation."""
 
 from .icecube import IceCube86, IceCubeDeepCore
+from .detector import Detector
diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py
index 4ad3cce482d6d8accd23bad60c226e1ec4a35776..25660c6d7abefa5dca3421b2ecbcbcb58949959c 100644
--- a/src/graphnet/models/detector/detector.py
+++ b/src/graphnet/models/detector/detector.py
@@ -1,115 +1,49 @@
 """Base detector-specific `Model` class(es)."""
 
 from abc import abstractmethod
-from typing import List
+from typing import Dict, Callable, List
 
-import torch
-from torch_geometric.data import Data
-from torch_geometric.data.batch import Batch
+import torch
 
-from graphnet.models.graph_builders import GraphBuilder
 from graphnet.models import Model
-from graphnet.utilities.config import save_model_config
 from graphnet.utilities.decorators import final
+from graphnet.utilities.config import save_model_config
 
 
 class Detector(Model):
     """Base class for all detector-specific read-ins in graphnet."""
 
-    @property
-    @abstractmethod
-    def features(self) -> List[str]:
-        """List of features used/assumed by inheriting `Detector` objects."""
-
     @save_model_config
-    def __init__(
-        self, graph_builder: GraphBuilder, scalers: List[dict] = None
-    ):
+    def __init__(self) -> None:
         """Construct `Detector`."""
         # Base class constructor
         super().__init__(name=__name__, class_name=self.__class__.__name__)
 
-        # Member variables
-        self._graph_builder = graph_builder
-        self._scalers = scalers
-        if self._scalers:
-            self.info(
-                (
-                    "Will use scalers rather than standard preprocessing "
-                    f"in {self.__class__.__name__}."
-                )
-            )
+    @abstractmethod
+    def feature_map(self) -> Dict[str, Callable]:
+        """Map standardization functions to each node feature."""
 
     @final
-    def forward(self, data: Data) -> Data:
-        """Pre-process graph `Data` features and build graph adjacency."""
+    def forward(
+        self, node_features: torch.tensor, node_feature_names: List[str]
+    ) -> torch.tensor:
+        """Standardize node features using the detector's `feature_map`."""
-        # Check(s)
-        assert data.x.size()[1] == self.nb_inputs, (
-            "Got graph data with incompatible size, ",
-            f"{data.x.size()} vs. {self.nb_inputs} expected",
-        )
-
-        # Graph-bulding
-        # @NOTE: `.clone` is necessary to avoid modifying original tensor in-place.
-        data = self._graph_builder(data).clone()
-
-        if self._scalers:
-            # # Scaling individual features
-            # x_numpy = data.x.detach().cpu().numpy()
-            # for key, scaler in self._scalers.items():
-            #     ix = self.features.index(key)
-            #     data.x[:,ix] = torch.tensor(scaler.transform(x_numpy[:,ix])).type_as(data.x)
-
-            # Scaling groups of features | @TEMP, probably
-            x_numpy = data.x.detach().cpu().numpy()
-
-            data.x[:, :3] = torch.tensor(
-                self._scalers["xyz"].transform(x_numpy[:, :3])  # type: ignore[call-overload]
-            ).type_as(data.x)
+        return self._standardize(node_features, node_feature_names)
 
-            data.x[:, 3:] = torch.tensor(
-                self._scalers["features"].transform(x_numpy[:, 3:])  # type: ignore[call-overload]
-            ).type_as(data.x)
-
-        else:
-            # Implementation-specific forward pass (e.g. preprocessing)
-            data = self._forward(data)
-
-        return data
-
-    @abstractmethod
-    def _forward(self, data: Data) -> Data:
-        """Syntax like `.forward`, for implentation in inheriting classes."""
-
-    @property
-    def nb_inputs(self) -> int:
-        """Return number of input features."""
-        return len(self.features)
-
-    @property
-    def nb_outputs(self) -> int:
-        """Return number of output features.
-
-        This the default, but may be overridden by specific inheriting classes.
-        """
-        return self.nb_inputs
-
-    def _validate_features(self, data: Data) -> None:
-        if isinstance(data, Batch):
-            # `data.features` is "transposed" and each list element contains only duplicate entries.
-
-            if (
-                len(data.features[0]) == data.num_graphs
-                and len(set(data.features[0])) == 1
-            ):
-                data_features = [features[0] for features in data.features]
-
-            # `data.features` is not "transposed" and each list element
-            # contains the original features.
-            else:
-                data_features = data.features[0]
-        else:
-            data_features = data.features
-        assert (
-            data_features == self.features
-        ), f"Features on Data and Detector differ: {data_features} vs. {self.features}"
+    @final
+    def _standardize(
+        self, node_features: torch.tensor, node_feature_names: List[str]
+    ) -> torch.tensor:
+        for idx, feature in enumerate(node_feature_names):
+            try:
+                node_features[:, idx] = self.feature_map()[feature](  # type: ignore
+                    node_features[:, idx]
+                )
+            except KeyError as e:
+                self.warning(
+                    f"No standardization function found for '{feature}'."
+                )
+                raise e
+        return node_features
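
Under the new `Detector` API, a subclass only supplies a `feature_map` of per-feature standardization functions, which `Detector.forward` applies column-wise. A minimal sketch of a custom detector (`MyDetector` and its scalings are hypothetical):

    from typing import Callable, Dict

    import torch

    from graphnet.models.detector.detector import Detector

    class MyDetector(Detector):
        """Hypothetical detector with two standardized features."""

        def feature_map(self) -> Dict[str, Callable]:
            return {
                "sensor_x": lambda x: x / 500.0,      # scale positions
                "t": lambda x: (x - 1.0e4) / 3.0e4,   # centre and scale times
            }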
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index e6addb8e37f178bbae3994fdd8ff5cba759d36c9..c2028755ad5207316c320c427772c655775c7fd9 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -1,5 +1,6 @@
 """IceCube-specific `Detector` class(es)."""
 
+from typing import Dict, Callable
 import torch
 from torch_geometric.data import Data
 
@@ -15,32 +16,33 @@ from graphnet.models.detector.detector import Detector
 class IceCube86(Detector):
     """`Detector` class for IceCube-86."""
 
-    # Implementing abstract class attribute
-    features = FEATURES.ICECUBE86
-
-    def _forward(self, data: Data) -> Data:
-        """Ingest data, build graph, and preprocess features.
-
-        Args:
-            data: Input graph data.
-
-        Returns:
-            Connected and preprocessed graph data.
-        """
-        # Check(s)
-        self._validate_features(data)
-
-        # Preprocessing
-        data.x[:, 0] /= 500.0  # dom_x
-        data.x[:, 1] /= 500.0  # dom_y
-        data.x[:, 2] /= 500.0  # dom_z
-        data.x[:, 3] = (data.x[:, 3] - 1.0e04) / 3.0e4  # dom_time
-        data.x[:, 4] = torch.log10(data.x[:, 4]) / 3.0  # charge
-        data.x[:, 5] -= 1.25  # rde
-        data.x[:, 5] /= 0.25
-        data.x[:, 6] /= 0.05  # pmt_area
-
-        return data
+    def feature_map(self) -> Dict[str, Callable]:
+        """Map standardization functions to each dimension."""
+        feature_map = {
+            "dom_x": self._dom_xyz,
+            "dom_y": self._dom_xyz,
+            "dom_z": self._dom_xyz,
+            "dom_time": self._dom_time,
+            "charge": self._charge,
+            "rde": self._rde,
+            "pmt_area": self._pmt_area,
+        }
+        return feature_map
+
+    def _dom_xyz(self, x: torch.tensor) -> torch.tensor:
+        return x / 500.0
+
+    def _dom_time(self, x: torch.tensor) -> torch.tensor:
+        return (x - 1.0e04) / 3.0e4
+
+    def _charge(self, x: torch.tensor) -> torch.tensor:
+        return torch.log10(x) / 3.0
+
+    def _rde(self, x: torch.tensor) -> torch.tensor:
+        return (x - 1.25) / 0.25
+
+    def _pmt_area(self, x: torch.tensor) -> torch.tensor:
+        return x / 0.05
 
 
 class IceCubeKaggle(Detector):
diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py
index 2a35886ac014e3ef82bcf98dd38d7a62e504f860..f21f9c41363546393c61783df78994f34e297320 100644
--- a/src/graphnet/models/detector/prometheus.py
+++ b/src/graphnet/models/detector/prometheus.py
@@ -1,6 +1,7 @@
 """Prometheus-specific `Detector` class(es)."""
 
-from torch_geometric.data import Data
+from typing import Dict, Callable
+import torch
 
 from graphnet.models.detector.detector import Detector
 
@@ -8,27 +9,21 @@ from graphnet.models.detector.detector import Detector
 class Prometheus(Detector):
     """`Detector` class for Prometheus prototype."""
 
-    features = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]
+    def feature_map(self) -> Dict[str, Callable]:
+        """Map standardization functions to each dimension."""
+        feature_map = {
+            "sensor_pos_x": self._sensor_pos_xy,
+            "sensor_pos_y": self._sensor_pos_xy,
+            "sensor_pos_z": self._sensor_pos_z,
+            "t": self._t,
+        }
+        return feature_map
 
-    def _forward(self, data: Data) -> Data:
-        """Ingest data, build graph, and preprocess features.
+    def _sensor_pos_xy(self, x: torch.tensor) -> torch.tensor:
+        return x / 100
 
-        Args:
-            data: Input graph data.
+    def _sensor_pos_z(self, x: torch.tensor) -> torch.tensor:
+        return (x + 350) / 100
 
-        Returns:
-            Connected and preprocessed graph data.
-        """
-        # Check(s)
-        self._validate_features(data)
-
-        # Preprocessing
-        data.x[:, 0] /= 100.0  # dom_x
-        data.x[:, 1] /= 100.0  # dom_y
-        data.x[:, 2] += 350.0  # dom_z
-        data.x[:, 2] /= 100.0
-        data.x[:, 3] /= 1.05e04  # dom_time
-        data.x[:, 3] -= 1.0
-        data.x[:, 3] *= 20.0
-
-        return data
+    def _t(self, x: torch.tensor) -> torch.tensor:
+        return ((x / 1.05e04) - 1.0) * 20.0
diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea5066307d0d9f5525f8b6b93895546b9003aae3
--- /dev/null
+++ b/src/graphnet/models/graphs/__init__.py
@@ -0,0 +1,10 @@
+"""Modules for constructing graphs.
+
+´GraphDefinition´ defines the nodes and their features,  and contains general
+graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.
+"""
+
+
+from .graph_definition import GraphDefinition
+from .graphs import KNNGraph
diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da8baa7c2adb6556fb08525b6f1a6a2ddc1d181
--- /dev/null
+++ b/src/graphnet/models/graphs/edges/__init__.py
@@ -0,0 +1,7 @@
+"""Modules for constructing graphs.
+
+´GraphDefinition´ defines the nodes and their features,  and contains general
+graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.
+"""
+from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges
diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py
new file mode 100644
index 0000000000000000000000000000000000000000..28507058bb5a4353025bd68cd7e6fc7d93207582
--- /dev/null
+++ b/src/graphnet/models/graphs/edges/edges.py
@@ -0,0 +1,188 @@
+"""Class(es) for building/connecting graphs."""
+
+from typing import List, Optional
+from abc import abstractmethod, ABC
+
+import torch
+from torch_geometric.nn import knn_graph, radius_graph
+from torch_geometric.data import Data
+
+from graphnet.utilities.config import save_model_config
+from graphnet.models.utils import calculate_distance_matrix
+from graphnet.models import Model
+
+
+class EdgeDefinition(Model):  # pylint: disable=too-few-public-methods
+    """Base class for graph building."""
+
+    def forward(self, graph: Data) -> Data:
+        """Construct edges based on problem specific implementation of.
+
+        ´_construct_edges´
+
+        Args:
+            graph: a graph without edges
+
+        Returns:
+            graph: a graph with edges
+        """
+        if graph.edge_index is not None:
+            self.warnonce(
+                "GraphBuilder received graph with pre-existing "
+                "structure. Will overwrite."
+            )
+        return self._construct_edges(graph)
+
+    @abstractmethod
+    def _construct_edges(self, graph: Data) -> Data:
+        """Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.
+
+        Args:
+            graph: graph without edges
+
+        Returns:
+            graph: graph with edges assigned.
+        """
+
+
+class KNNEdges(EdgeDefinition):  # pylint: disable=too-few-public-methods
+    """Builds edges from the k-nearest neighbours."""
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_nearest_neighbours: int,
+        columns: List[int] = [0, 1, 2],
+    ):
+        """K-NN Edge definition.
+
+        Will connect each node to its `nb_nearest_neighbours` nearest
+        neighbours in the feature space given by `columns`.
+
+        Args:
+            nb_nearest_neighbours: Number of neighbours.
+            columns: Node features to use for the distance calculation.
+                Defaults to [0, 1, 2].
+        """
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
+        # Member variable(s)
+        self._nb_nearest_neighbours = nb_nearest_neighbours
+        self._columns = columns
+
+    def _construct_edges(self, graph: Data) -> Data:
+        """Define K-NN edges."""
+        graph.edge_index = knn_graph(
+            graph.x[:, self._columns],
+            self._nb_nearest_neighbours,
+            graph.batch,
+        ).to(self.device)
+
+        return graph
+
+
+class RadialEdges(EdgeDefinition):
+    """Builds graph from a sphere of chosen radius centred at each node."""
+
+    @save_model_config
+    def __init__(
+        self,
+        radius: float,
+        columns: List[int] = [0, 1, 2],
+    ):
+        """Radial edges.
+
+        Connects each node to all other nodes within a sphere of radius
+        `radius` centred on the node, in the feature space given by `columns`.
+
+        Args:
+            radius: Radius of the sphere.
+            columns: Columns of the node feature matrix to use.
+                Defaults to [0, 1, 2].
+        """
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
+        # Member variable(s)
+        self._radius = radius
+        self._columns = columns
+
+    def _construct_edges(self, graph: Data) -> Data:
+        """Define radial edges."""
+        graph.edge_index = radius_graph(
+            graph.x[:, self._columns],
+            self._radius,
+            graph.batch,
+        ).to(self.device)
+
+        return graph
+
+
+class EuclideanEdges(EdgeDefinition):  # pylint: disable=too-few-public-methods
+    """Builds edges according to Euclidean distance between nodes.
+
+    See https://arxiv.org/pdf/1809.06166.pdf.
+    """
+
+    @save_model_config
+    def __init__(
+        self,
+        sigma: float,
+        threshold: float = 0.0,
+        columns: Optional[List[int]] = None,
+    ):
+        """Construct `EuclideanEdges`."""
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
+        # Check(s)
+        if columns is None:
+            columns = [0, 1, 2]
+
+        # Member variable(s)
+        self._sigma = sigma
+        self._threshold = threshold
+        self._columns = columns
+
+    def _construct_edges(self, graph: Data) -> Data:
+        """Forward pass."""
+        # Constructs the adjacency matrix from the raw, DOM-level data and
+        # returns this matrix
+        if graph.edge_index is not None:
+            self.warning(
+                "EuclideanEdges received a graph with pre-existing "
+                "structure. Will overwrite."
+            )
+
+        xyz_coords = graph.x[:, self._columns]
+
+        # Construct block-diagonal matrix indicating whether pulses belong to
+        # the same event in the batch
+        batch_mask = graph.batch.unsqueeze(dim=0) == graph.batch.unsqueeze(
+            dim=1
+        )
+
+        distance_matrix = calculate_distance_matrix(xyz_coords)
+        affinity_matrix = torch.exp(
+            -0.5 * distance_matrix**2 / self._sigma**2
+        )
+
+        # Use softmax to normalise all adjacencies to one for each node
+        exp_row_sums = torch.exp(affinity_matrix).sum(axis=1)
+        weighted_adj_matrix = torch.exp(
+            affinity_matrix
+        ) / exp_row_sums.unsqueeze(dim=1)
+
+        # Only include edges with weights that exceed the chosen threshold (and
+        # are part of the same event)
+        sources, targets = torch.where(
+            (weighted_adj_matrix > self._threshold) & (batch_mask)
+        )
+        edge_weights = weighted_adj_matrix[sources, targets]
+
+        graph.edge_index = torch.stack((sources, targets))
+        graph.edge_weight = edge_weights
+
+        return graph
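
Each `EdgeDefinition` operates on an edgeless `Data` object and fills in `edge_index`. A minimal sketch with random nodes:

    import torch
    from torch_geometric.data import Data

    from graphnet.models.graphs.edges import KNNEdges

    x = torch.rand(10, 4)  # 10 nodes with 4 features each
    graph = Data(x=x)      # no edges yet
    edge_definition = KNNEdges(nb_nearest_neighbours=3, columns=[0, 1, 2])
    graph = edge_definition(graph)
    print(graph.edge_index.shape)  # torch.Size([2, 30]): 3 edges per node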
diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf7b32542efb37abe08d229b8dd16e5fff7825d
--- /dev/null
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -0,0 +1,270 @@
+"""Modules for defining graphs.
+
+These are self-contained graph definitions that hold all the graph-altering
+code in graphnet. These modules define what the GNN sees as input and can be
+passed to dataloaders during training and deployment.
+"""
+
+
+from typing import Any, List, Optional, Dict, Callable
+import torch
+from torch_geometric.data import Data
+import numpy as np
+
+from graphnet.utilities.config import save_model_config
+
+from graphnet.models.detector import Detector
+from .edges import EdgeDefinition
+from .nodes import NodeDefinition
+from graphnet.models import Model
+
+
+class GraphDefinition(Model):
+    """An Abstract class to create graph definitions from."""
+
+    @save_model_config
+    def __init__(
+        self,
+        detector: Detector,
+        node_definition: NodeDefinition,
+        edge_definition: Optional[EdgeDefinition] = None,
+        node_feature_names: Optional[List[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        """Construct ´GraphDefinition´. The ´detector´ holds.
+
+        ´Detector´-specific code. E.g. scaling/standardization and geometry
+        tables.
+
+        ´node_definition´ defines the nodes in the graph.
+
+        ´edge_definition´ defines the connectivity of the nodes in the graph.
+
+        Args:
+            detector: The corresponding ´Detector´ representing the data.
+            node_definition: Definition of nodes.
+            edge_definition: Definition of edges. Defaults to None.
+            node_feature_names: Names of node feature columns. Defaults to None
+            dtype: data type used for node features. e.g. ´torch.float´
+        """
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
+        # Member Variables
+        self._detector = detector
+        self._edge_definition = edge_definition
+        self._node_definition = node_definition
+        if node_feature_names is None:
+            # Assume all features in the `Detector` are used.
+            node_feature_names = list(self._detector.feature_map().keys())  # type: ignore
+        self._node_feature_names = node_feature_names
+        if dtype is None:
+            dtype = torch.float
+        self._dtype = dtype
+
+        # Set Input / Output dimensions
+        self._node_definition.set_number_of_inputs(
+            node_feature_names=node_feature_names
+        )
+        self.nb_inputs = len(self._node_feature_names)
+        self.nb_outputs = self._node_definition.nb_outputs
+
+    def forward(  # type: ignore
+        self,
+        node_features: np.ndarray,
+        node_feature_names: List[str],
+        truth_dicts: Optional[List[Dict[str, Any]]] = None,
+        custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
+        loss_weight_column: Optional[str] = None,
+        loss_weight: Optional[float] = None,
+        loss_weight_default_value: Optional[float] = None,
+        data_path: Optional[str] = None,
+    ) -> Data:
+        """Construct graph as ´Data´ object.
+
+        Args:
+            node_features: Node features for the graph. Shape `[num_nodes, d]`.
+            node_feature_names: Name of each of the `d` feature columns.
+            truth_dicts: Dictionary containing truth labels.
+            custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
+            loss_weight_column: Name of column that holds loss weight. Defaults to None.
+            loss_weight: Loss weight associated with event. Defaults to None.
+            loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.
+            data_path: Path to dataset data files. Defaults to None.
+
+        Returns:
+            The constructed graph as a `Data` object.
+        """
+        # Checks
+        self._validate_input(
+            node_features=node_features, node_feature_names=node_feature_names
+        )
+
+        # Transform to pytorch tensor
+        node_features = torch.tensor(node_features, dtype=self._dtype)
+
+        # Standardize / Scale  node features
+        node_features = self._detector(node_features, node_feature_names)
+
+        # Create graph
+        graph = self._node_definition(node_features)
+
+        # Attach number of pulses as static attribute.
+        graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32)
+
+        # Assign edges
+        if self._edge_definition is not None:
+            graph = self._edge_definition(graph)
+        else:
+            self.warnonce(
+                "No EdgeDefinition provided. Graphs will not have edges defined!"
+            )
+
+        # Attach data path - useful for Ensemble datasets.
+        if data_path is not None:
+            graph["dataset_path"] = data_path
+
+        # Attach loss weights if they exist
+        graph = self._add_loss_weights(
+            graph=graph,
+            loss_weight=loss_weight,
+            loss_weight_column=loss_weight_column,
+            loss_weight_default_value=loss_weight_default_value,
+        )
+
+        # Attach default truth labels and node truths
+        if truth_dicts is not None:
+            graph = self._add_truth(graph=graph, truth_dicts=truth_dicts)
+
+        # Attach custom truth labels
+        if custom_label_functions is not None:
+            graph = self._add_custom_labels(
+                graph=graph, custom_label_functions=custom_label_functions
+            )
+
+        # Attach node features as separate fields. May not contain 'x'.
+        graph = self._add_features_individually(
+            graph=graph, node_feature_names=node_feature_names
+        )
+
+        # Add GraphDefinition Stamp
+        graph["graph_definition"] = self.__class__.__name__
+        return graph
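+
+    # Illustrative sketch of a typical call (the feature values and names
+    # below are assumptions for the example, mirroring the IceCube86 configs
+    # in this PR):
+    #
+    #     features = np.array(
+    #         [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]]
+    #     )  # shape [num_nodes, d]
+    #     graph = graph_definition(
+    #         node_features=features,
+    #         node_feature_names=[
+    #             "dom_x", "dom_y", "dom_z", "dom_time",
+    #             "charge", "rde", "pmt_area",
+    #         ],
+    #     )
+    #     assert graph.n_pulses == 1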
+
+    def _validate_input(
+        self, node_features: np.ndarray, node_feature_names: List[str]
+    ) -> None:
+        # Check dimensions of the node feature matrix.
+        assert node_features.shape[1] == len(node_feature_names), (
+            f"Number of columns in node_features ({node_features.shape[1]}) "
+            f"does not match the number of feature names "
+            f"({len(node_feature_names)})."
+        )
+
+        # Check that the features provided as input are the same that the
+        # ´GraphDefinition´ was instantiated with.
+        assert len(node_feature_names) == len(
+            self._node_feature_names
+        ), (
+            f"Input features ({node_feature_names}) do not match what "
+            f"{self.__class__.__name__} was instantiated with "
+            f"({self._node_feature_names})."
+        )
+        for idx in range(len(node_feature_names)):
+            assert (
+                node_feature_names[idx] == self._node_feature_names[idx]
+            ), "Order of node features does not match."
+
+    def _add_loss_weights(
+        self,
+        graph: Data,
+        loss_weight_column: Optional[str] = None,
+        loss_weight: Optional[float] = None,
+        loss_weight_default_value: Optional[float] = None,
+    ) -> Data:
+        """Attempt to store a loss weight in the graph for use during training.
+
+        I.e. `graph[loss_weight_column] = loss_weight`
+
+        Args:
+            graph: Data object representing the event.
+            loss_weight_column: The name under which the weight is stored in
+                the graph.
+            loss_weight: The non-negative weight to be stored.
+            loss_weight_default_value: The default value used if none was
+                retrieved.
+
+        Returns:
+            A graph with loss weight added, if available.
+        """
+        # Add loss weight to graph.
+        if loss_weight is not None and loss_weight_column is not None:
+            # No loss weight was retrieved, i.e., it is missing for the current
+            # event.
+            if loss_weight < 0:
+                if loss_weight_default_value is None:
+                    raise ValueError(
+                        "At least one event is missing an entry in "
+                        f"{loss_weight_column} "
+                        "but loss_weight_default_value is None."
+                    )
+                graph[loss_weight_column] = torch.tensor(
+                    loss_weight_default_value, dtype=self._dtype
+                ).reshape(-1, 1)
+            else:
+                graph[loss_weight_column] = torch.tensor(
+                    loss_weight, dtype=self._dtype
+                ).reshape(-1, 1)
+        return graph
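+
+    # For example, with loss_weight_column="event_weight" (a hypothetical
+    # column name) and loss_weight=0.5, this stores
+    # graph["event_weight"] == tensor([[0.5]]).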
+
+    def _add_truth(
+        self, graph: Data, truth_dicts: List[Dict[str, Any]]
+    ) -> Data:
+        """Add truth labels from ´truth_dicts´ to ´graph´.
+
+        I.e. ´graph[key] = truth_dict[key]´
+
+        Args:
+            graph: Graph where the labels will be stored.
+            truth_dicts: List of dictionaries containing the labels.
+
+        Returns:
+            Graph with labels attached.
+        """
+        # Write attributes, either target labels, truth info or original
+        # features.
+        for truth_dict in truth_dicts:
+            for key, value in truth_dict.items():
+                try:
+                    graph[key] = torch.tensor(value)
+                except TypeError:
+                    # Cannot convert `value` to Tensor due to its data type,
+                    # e.g. `str`.
+                    self.debug(
+                        (
+                            f"Could not assign `{key}` with type "
+                            f"'{type(value).__name__}' as attribute to graph."
+                        )
+                    )
+        return graph
+
+    def _add_features_individually(
+        self,
+        graph: Data,
+        node_feature_names: List[str],
+    ) -> Data:
+        # Additionally add original features as (static) attributes
+        graph.features = node_feature_names
+        for index, feature in enumerate(node_feature_names):
+            if feature not in ["x"]:  # reserved for node features.
+                graph[feature] = graph.x[:, index].detach()
+            else:
+                self.warnonce(
+                    """Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature."""
+                )
+        return graph
+
+    def _add_custom_labels(
+        self,
+        graph: Data,
+        custom_label_functions: Dict[str, Callable[..., Any]],
+    ) -> Data:
+        # Add custom labels to the graph
+        for key, fn in custom_label_functions.items():
+            graph[key] = fn(graph)
+        return graph
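+
+
+# Illustrative sketch of attaching a custom label (the label name and
+# function below are assumptions for the example; any callable taking the
+# graph works):
+#
+#     graph = graph_definition(
+#         node_features=features,
+#         node_feature_names=feature_names,
+#         custom_label_functions={"n_nodes": lambda graph: graph.n_pulses},
+#     )
+#     # graph["n_nodes"] now holds the number of pulses in the event.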
diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c3ff983f655f12d7ed5d676d49bb2ef71dc9a3
--- /dev/null
+++ b/src/graphnet/models/graphs/graphs.py
@@ -0,0 +1,47 @@
+"""A module containing different graph representations in GraphNeT."""
+
+from typing import List, Optional
+import torch
+
+from graphnet.utilities.config import save_model_config
+from .graph_definition import GraphDefinition
+from graphnet.models.detector import Detector
+from graphnet.models.graphs.edges import KNNEdges
+from graphnet.models.graphs.nodes import NodeDefinition
+
+
+class KNNGraph(GraphDefinition):
+    """A Graph representation where Edges are drawn to nearest neighbours."""
+
+    @save_model_config
+    def __init__(
+        self,
+        detector: Detector,
+        node_definition: NodeDefinition,
+        node_feature_names: Optional[List[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        nb_nearest_neighbours: int = 8,
+        columns: List[int] = [0, 1, 2],
+    ) -> None:
+        """Construct k-nn graph representation.
+
+        Args:
+            detector: Detector that represents your data.
+            node_definition: Definition of nodes in the graph.
+            node_feature_names: Names of node features.
+            dtype: Data type for node features.
+            nb_nearest_neighbours: Number of edges for each node.
+                Defaults to 8.
+            columns: Node feature columns used for distance calculation.
+                Defaults to [0, 1, 2].
+        """
+        # Base class constructor
+        super().__init__(
+            detector=detector,
+            node_definition=node_definition,
+            edge_definition=KNNEdges(
+                nb_nearest_neighbours=nb_nearest_neighbours,
+                columns=columns,
+            ),
+            dtype=dtype,
+            node_feature_names=node_feature_names,
+        )
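+
+
+# Example construction, mirroring the ´KNNGraph´ entries in the YAML dataset
+# configs of this PR (assumes the IceCube86 detector; swap in the detector
+# matching your data):
+#
+#     from graphnet.models.detector.icecube import IceCube86
+#     from graphnet.models.graphs.nodes import NodesAsPulses
+#
+#     graph_definition = KNNGraph(
+#         detector=IceCube86(),
+#         node_definition=NodesAsPulses(),
+#         nb_nearest_neighbours=8,
+#         columns=[0, 1, 2],
+#     )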
diff --git a/src/graphnet/models/graphs/nodes/__init__.py b/src/graphnet/models/graphs/nodes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05194b61acfdf76e803606fb29e62b540619edac
--- /dev/null
+++ b/src/graphnet/models/graphs/nodes/__init__.py
@@ -0,0 +1,8 @@
+"""Modules for constructing graphs.
+
+´GraphDefinition´ defines the nodes and their features,  and contains general
+graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.
+"""
+
+from .nodes import NodeDefinition, NodesAsPulses
diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..afe0dcfce2be1e012984eb8e77be24d69d0a1b9a
--- /dev/null
+++ b/src/graphnet/models/graphs/nodes/nodes.py
@@ -0,0 +1,74 @@
+"""Class(es) for building/connecting graphs."""
+
+from typing import List
+from abc import abstractmethod
+
+import torch
+from torch_geometric.data import Data
+
+from graphnet.utilities.decorators import final
+from graphnet.utilities.config import save_model_config
+from graphnet.models import Model
+
+
+class NodeDefinition(Model):  # pylint: disable=too-few-public-methods
+    """Base class for graph building."""
+
+    @save_model_config
+    def __init__(self) -> None:
+        """Construct `Detector`."""
+        # Base class constructor
+        super().__init__(name=__name__, class_name=self.__class__.__name__)
+
+    @final
+    def forward(self, x: torch.Tensor) -> Data:
+        """Construct nodes from raw node features.
+
+        Args:
+            x: Standardized node features with shape ´[num_pulses, d]´,
+                where ´d´ is the number of node features.
+
+        Returns:
+            graph: a graph without edges
+        """
+        graph = self._construct_nodes(x)
+        return graph
+
+    @property
+    def nb_outputs(self) -> int:
+        """Return number of output features.
+
+        This is the default, but it may be overridden by inheriting classes.
+        """
+        return self.nb_inputs
+
+    @final
+    def set_number_of_inputs(self, node_feature_names: List[str]) -> None:
+        """Return number of inputs expected by node definition.
+
+        Args:
+            node_feature_names: name of each node feature column.
+        """
+        assert isinstance(node_feature_names, list)
+        self.nb_inputs = len(node_feature_names)
+
+    @abstractmethod
+    def _construct_nodes(self, x: torch.Tensor) -> Data:
+        """Construct nodes from raw node features ´x´.
+
+        Args:
+            x: Standardized node features with shape ´[num_pulses, d]´,
+                where ´d´ is the number of node features.
+
+        Returns:
+            graph: graph without edges.
+        """
+
+
+class NodesAsPulses(NodeDefinition):
+    """Represent each measured pulse of Cherenkov Radiation as a node."""
+
+    def _construct_nodes(self, x: torch.Tensor) -> Data:
+        return Data(x=x)
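+
+
+# A minimal sketch of a custom node definition (hypothetical, for
+# illustration): subclasses only need to implement ´_construct_nodes´.
+#
+#     class NodesWithOffset(NodeDefinition):
+#         """Shift all node features by a constant offset."""
+#
+#         def _construct_nodes(self, x: torch.Tensor) -> Data:
+#             return Data(x=x + 1.0)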
diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py
index 41b70bb269f49ffede804dc7b542be27676787bb..01fa0a574bf41ff05c7bfef51761cfeb16b042ed 100644
--- a/src/graphnet/models/standard_model.py
+++ b/src/graphnet/models/standard_model.py
@@ -12,7 +12,7 @@ import pandas as pd
 
 from graphnet.models.coarsening import Coarsening
 from graphnet.utilities.config import save_model_config
-from graphnet.models.detector.detector import Detector
+from graphnet.models.graphs import GraphDefinition
 from graphnet.models.gnn.gnn import GNN
 from graphnet.models.model import Model
 from graphnet.models.task import Task
@@ -29,7 +29,7 @@ class StandardModel(Model):
     def __init__(
         self,
         *,
-        detector: Detector,
+        graph_definition: GraphDefinition,
         gnn: GNN,
         tasks: Union[Task, List[Task]],
         coarsening: Optional[Coarsening] = None,
@@ -48,12 +48,12 @@ class StandardModel(Model):
             tasks = [tasks]
         assert isinstance(tasks, (list, tuple))
         assert all(isinstance(task, Task) for task in tasks)
-        assert isinstance(detector, Detector)
+        assert isinstance(graph_definition, GraphDefinition)
         assert isinstance(gnn, GNN)
         assert coarsening is None or isinstance(coarsening, Coarsening)
 
         # Member variable(s)
-        self._detector = detector
+        self._graph_definition = graph_definition
         self._gnn = gnn
         self._tasks = ModuleList(tasks)
         self._coarsening = coarsening
@@ -101,7 +101,7 @@ class StandardModel(Model):
         """Forward pass, chaining model components."""
         if self._coarsening:
             data = self._coarsening(data)
-        data = self._detector(data)
+        assert isinstance(data, Data)
         x = self._gnn(data)
         preds = [task(x) for task in self._tasks]
         return preds
diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py
index 942e76caa81898858263019a37046430685fe824..e1ef7956c150da86992fa23e8f44c4a38bca3633 100644
--- a/src/graphnet/models/utils.py
+++ b/src/graphnet/models/utils.py
@@ -1,6 +1,6 @@
 """Utility functions for `graphnet.models`."""
 
-from typing import List, Tuple
+from typing import List, Tuple, Union
 from torch_geometric.nn import knn_graph
 from torch_geometric.data import Batch
 import torch
diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py
index 52b7634e85817953c601df65363dd5280c08567b..2578ff9a65875e6a4937d95273feff46922f4170 100644
--- a/src/graphnet/training/utils.py
+++ b/src/graphnet/training/utils.py
@@ -12,10 +12,11 @@ from torch.utils.data import DataLoader
 from torch_geometric.data import Batch, Data
 
 from graphnet.data.dataset import Dataset
-from graphnet.data.sqlite import SQLiteDataset
-from graphnet.data.parquet import ParquetDataset
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
 from graphnet.models import Model
 from graphnet.utilities.logging import Logger
+from graphnet.models.graphs import GraphDefinition
 
 
 def collate_fn(graphs: List[Data]) -> Batch:
@@ -31,6 +32,7 @@ def collate_fn(graphs: List[Data]) -> Batch:
 def make_dataloader(
     db: str,
     pulsemaps: Union[str, List[str]],
+    graph_definition: Optional[GraphDefinition],
     features: List[str],
     truth: List[str],
     *,
@@ -66,6 +68,7 @@ def make_dataloader(
         loss_weight_table=loss_weight_table,
         loss_weight_column=loss_weight_column,
         index_column=index_column,
+        graph_definition=graph_definition,
     )
 
     # adds custom labels to dataset
@@ -89,6 +92,7 @@ def make_dataloader(
 # @TODO: Remove in favour of DataLoader{,.from_dataset_config}
 def make_train_validation_dataloader(
     db: str,
+    graph_definition: Optional[GraphDefinition],
     selection: Optional[List[int]],
     pulsemaps: Union[str, List[str]],
     features: List[str],
@@ -122,19 +126,21 @@ def make_train_validation_dataloader(
         dataset: Dataset
         if db.endswith(".db"):
             dataset = SQLiteDataset(
-                db,
-                pulsemaps,
-                features,
-                truth,
+                path=db,
+                graph_definition=graph_definition,
+                pulsemaps=pulsemaps,
+                features=features,
+                truth=truth,
                 truth_table=truth_table,
                 index_column=index_column,
             )
         elif db.endswith(".parquet"):
             dataset = ParquetDataset(
-                db,
-                pulsemaps,
-                features,
-                truth,
+                path=db,
+                graph_definition=graph_definition,
+                pulsemaps=pulsemaps,
+                features=features,
+                truth=truth,
                 truth_table=truth_table,
                 index_column=index_column,
             )
@@ -179,6 +185,7 @@ def make_train_validation_dataloader(
         loss_weight_table=loss_weight_table,
         index_column=index_column,
         labels=labels,
+        graph_definition=graph_definition,
     )
 
     training_dataloader = make_dataloader(
diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py
index fc5beb44b01d429f173ce3e9b8d28f12d3ceaef4..bb9ed26785d498a9882d5f84077d87bf88ebc800 100644
--- a/src/graphnet/utilities/config/dataset_config.py
+++ b/src/graphnet/utilities/config/dataset_config.py
@@ -2,6 +2,7 @@
 
 from functools import wraps
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -14,6 +15,12 @@ from graphnet.utilities.config.base_config import (
     BaseConfig,
     get_all_argument_values,
 )
+from graphnet.utilities.config.parsing import traverse_and_apply
+from .model_config import ModelConfig
+
+if TYPE_CHECKING:
+    from graphnet.models import Model
+
 
 BACKEND_LOOKUP = {
     "db": "sqlite",
@@ -45,8 +52,8 @@ class DatasetConfig(BaseConfig):
     loss_weight_table: Optional[str] = None
     loss_weight_column: Optional[str] = None
     loss_weight_default_value: Optional[float] = None
-
     seed: Optional[int] = None
+    graph_definition: Any = None
 
     def __init__(self, **data: Any) -> None:
         """Construct `DataConfig`.
@@ -139,8 +146,8 @@ class DatasetConfig(BaseConfig):
     @property
     def _dataset_class(self) -> type:
         """Return the `Dataset` class implementation for this configuration."""
-        from graphnet.data.sqlite import SQLiteDataset
-        from graphnet.data.parquet import ParquetDataset
+        from graphnet.data.dataset.sqlite import SQLiteDataset
+        from graphnet.data.dataset.parquet import ParquetDataset
 
         dataset_class = {
             "sqlite": SQLiteDataset,
@@ -153,6 +160,17 @@ class DatasetConfig(BaseConfig):
 def save_dataset_config(init_fn: Callable) -> Callable:
     """Save the arguments to `__init__` functions as member `DatasetConfig`."""
 
+    def _replace_model_instance_with_config(
+        obj: Union["Model", Any]
+    ) -> Union[ModelConfig, Any]:
+        """Replace `Model` instances in `obj` with their `ModelConfig`."""
+        from graphnet.models import Model
+
+        if isinstance(obj, Model):
+            return obj.config
+        else:
+            return obj
+
     @wraps(init_fn)
     def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
         """Set `DatasetConfig` after calling `init_fn`."""
@@ -162,6 +180,9 @@ def save_dataset_config(init_fn: Callable) -> Callable:
         # Get all argument values, including defaults
         cfg = get_all_argument_values(init_fn, *args, **kwargs)
 
+        # Handle nested `Model`s, etc.
+        cfg = traverse_and_apply(cfg, _replace_model_instance_with_config)
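+        # E.g. a `GraphDefinition` instance passed to a `Dataset` is replaced
+        # by its `ModelConfig`, so the resulting YAML is self-contained and
+        # can re-instantiate the graph definition.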
+
         # Add `DatasetConfig` as member variables
         self._config = DatasetConfig(**cfg)
 
diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py
index d4733ca6863a27979bcc55e2a5812bd143409ba1..21e18c104fa321fec027fdc91137d1ebdf0635e3 100644
--- a/src/graphnet/utilities/config/model_config.py
+++ b/src/graphnet/utilities/config/model_config.py
@@ -133,7 +133,7 @@ class ModelConfig(BaseConfig):
             fn_kwargs={"trust": trust},
         )
 
         # Construct model based on arguments
         return namespace_classes[self.class_name](**arguments)
 
     @classmethod
diff --git a/tests/utilities/test_dataset_config.py b/tests/utilities/test_dataset_config.py
index a09d0e5ada778ffe4132a9fb02132134dc9c83c1..88f3b3f1cb3bc6f24ab56c740d275b0192478fbe 100644
--- a/tests/utilities/test_dataset_config.py
+++ b/tests/utilities/test_dataset_config.py
@@ -13,8 +13,8 @@ import graphnet
 import graphnet.constants
 from graphnet.data.constants import FEATURES, TRUTH
 from graphnet.data.dataset import Dataset
-from graphnet.data.parquet import ParquetDataset
-from graphnet.data.sqlite import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
+from graphnet.data.dataset import SQLiteDataset
 from graphnet.utilities.config import DatasetConfig
 
 CONFIG_PATHS = {