diff --git a/configs/models/dynedge_energy_example.yml b/configs/models/dynedge_energy_example.yml
index 02d647f0c70eaa9d4bdce7184129f56f31809e77..d007af51f5eb8e83235c6b670b867c273e8e7027 100644
--- a/configs/models/dynedge_energy_example.yml
+++ b/configs/models/dynedge_energy_example.yml
@@ -20,6 +20,7 @@ arguments:
         nb_neighbours: 8
         post_processing_layer_sizes: null
         readout_layer_sizes: null
+        use_graph_normalization: false
       class_name: DynEdge
   optimizer_class: '!class torch.optim.adam Adam'
   optimizer_kwargs: {eps: 0.001, lr: 1e-05}
diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py
index 4e9e07b6520dac80a6bd7649fe60b7273d76e790..4b20158cba3429fc09dd26950b28fad376e66ebc 100644
--- a/src/graphnet/models/gnn/dynedge.py
+++ b/src/graphnet/models/gnn/dynedge.py
@@ -3,6 +3,7 @@
 from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor, LongTensor
+import torch_geometric
 from torch_geometric.data import Data
 from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
@@ -34,6 +35,7 @@ class DynEdge(GNN):
         readout_layer_sizes: Optional[List[int]] = None,
         global_pooling_schemes: Optional[Union[str, List[str]]] = None,
         add_global_variables_after_pooling: bool = False,
+        use_graph_normalization: bool = False,
     ):
         """Construct `DynEdge`.
 
@@ -67,6 +69,7 @@ class DynEdge(GNN):
                 after global pooling. The alternative is to added (distribute)
                 them to the individual nodes before any convolutional
                 operations.
+            use_graph_normalization: Whether to apply graph normalization to the nodes.
         """
         # Latent feature subset for computing nearest neighbours in DynEdge.
         if features_subset is None:
@@ -151,9 +154,16 @@ class DynEdge(GNN):
             add_global_variables_after_pooling
         )
 
+        self._use_graph_normalization = use_graph_normalization
+
         # Base class constructor
         super().__init__(nb_inputs, self._readout_layer_sizes[-1])
 
+        # Graph normalization layer (normalizes node features per graph)
+        self._graphnorm = torch_geometric.nn.norm.GraphNorm(
+            nb_inputs, eps=1e-5
+        )
+
         # Remaining member variables()
         self._activation = torch.nn.LeakyReLU()
         self._nb_inputs = nb_inputs
@@ -276,6 +286,9 @@
         # Convenience variables
         x, edge_index, batch = data.x, data.edge_index, data.batch
 
+        if self._use_graph_normalization:
+            x = self._graphnorm(x, batch)
+
         global_variables = self._calculate_global_variables(
             x,
             edge_index,
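
Usage note (not part of the patch): below is a minimal sketch of the code path the new use_graph_normalization flag enables, assuming torch and torch_geometric are installed. The feature dimension and batch layout are illustrative only, not taken from the GraphNeT configs.

    import torch
    from torch_geometric.nn.norm import GraphNorm

    nb_inputs = 7  # number of input node features (illustrative)

    # GraphNorm learns per-feature scale and shift parameters and normalizes
    # each graph in the mini-batch separately, grouping nodes by the batch
    # vector, which is exactly what the patched forward() relies on.
    graphnorm = GraphNorm(nb_inputs, eps=1e-5)

    # Two graphs in one mini-batch: three nodes and two nodes, respectively.
    x = torch.randn(5, nb_inputs)          # node features [num_nodes, nb_inputs]
    batch = torch.tensor([0, 0, 0, 1, 1])  # graph index for every node

    # Mirrors the patched forward(): normalize node features before the
    # global-variable calculation and the DynEdge convolutions.
    x = graphnorm(x, batch)
    print(x.shape)  # torch.Size([5, 7])

Since the layer is constructed unconditionally in the constructor but only applied when the flag is set, configs that leave use_graph_normalization at its default keep the previous behaviour.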