diff --git a/configs/models/dynedge_energy_example.yml b/configs/models/dynedge_energy_example.yml
index 02d647f0c70eaa9d4bdce7184129f56f31809e77..d007af51f5eb8e83235c6b670b867c273e8e7027 100644
--- a/configs/models/dynedge_energy_example.yml
+++ b/configs/models/dynedge_energy_example.yml
@@ -20,6 +20,7 @@ arguments:
         nb_neighbours: 8
         post_processing_layer_sizes: null
         readout_layer_sizes: null
+        use_graph_normalization: null
       class_name: DynEdge
   optimizer_class: '!class torch.optim.adam Adam'
   optimizer_kwargs: {eps: 0.001, lr: 1e-05}
diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py
index 4e9e07b6520dac80a6bd7649fe60b7273d76e790..4b20158cba3429fc09dd26950b28fad376e66ebc 100644
--- a/src/graphnet/models/gnn/dynedge.py
+++ b/src/graphnet/models/gnn/dynedge.py
@@ -3,6 +3,7 @@ from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor, LongTensor
+import torch_geometric
 from torch_geometric.data import Data
 from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
 
@@ -34,6 +35,7 @@ class DynEdge(GNN):
         readout_layer_sizes: Optional[List[int]] = None,
         global_pooling_schemes: Optional[Union[str, List[str]]] = None,
         add_global_variables_after_pooling: bool = False,
+        use_graph_normalization: bool = False,
     ):
         """Construct `DynEdge`.
 
@@ -67,6 +69,7 @@ class DynEdge(GNN):
                 after global pooling. The alternative is to  added (distribute)
                 them to the individual nodes before any convolutional
                 operations.
+            use_graph_normalization: Whether to use graph normalization on nodes
         """
         # Latent feature subset for computing nearest neighbours in DynEdge.
         if features_subset is None:
@@ -151,9 +154,16 @@ class DynEdge(GNN):
             add_global_variables_after_pooling
         )
 
+        self._use_graph_normalization = use_graph_normalization
+
         # Base class constructor
         super().__init__(nb_inputs, self._readout_layer_sizes[-1])
 
+        # GraphNorm
+        self._graphnorm = torch_geometric.nn.norm.GraphNorm(
+            nb_inputs, eps=1e-5
+        )
+
         # Remaining member variables()
         self._activation = torch.nn.LeakyReLU()
         self._nb_inputs = nb_inputs
@@ -276,6 +286,9 @@ class DynEdge(GNN):
         # Convenience variables
         x, edge_index, batch = data.x, data.edge_index, data.batch
 
+        if self._use_graph_normalization:
+            x = self._graphnorm(x, batch)
+
         global_variables = self._calculate_global_variables(
             x,
             edge_index,