From 13606161e8dab671cf0d001b1254333c8067c533 Mon Sep 17 00:00:00 2001 From: Morten Holm <Volumunox@gmail.com> Date: Sun, 5 Feb 2023 15:02:58 +0100 Subject: [PATCH 1/4] GraphNorm draft --- src/graphnet/models/gnn/dynedge.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 4e9e07b65..975581848 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -3,6 +3,7 @@ from typing import List, Optional, Sequence, Tuple, Union import torch from torch import Tensor, LongTensor +import torch_geometric from torch_geometric.data import Data from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum @@ -34,6 +35,7 @@ class DynEdge(GNN): readout_layer_sizes: Optional[List[int]] = None, global_pooling_schemes: Optional[Union[str, List[str]]] = None, add_global_variables_after_pooling: bool = False, + use_graph_normalization: bool = False, ): """Construct `DynEdge`. @@ -67,6 +69,7 @@ class DynEdge(GNN): after global pooling. The alternative is to added (distribute) them to the individual nodes before any convolutional operations. + use_graph_normalization: Whether to use graph normalization on nodes """ # Latent feature subset for computing nearest neighbours in DynEdge. if features_subset is None: @@ -151,6 +154,8 @@ class DynEdge(GNN): add_global_variables_after_pooling ) + self._use_graph_normalization = use_graph_normalization + # Base class constructor super().__init__(nb_inputs, self._readout_layer_sizes[-1]) @@ -276,6 +281,11 @@ class DynEdge(GNN): # Convenience variables x, edge_index, batch = data.x, data.edge_index, data.batch + if self._use_graph_normalization: + x = torch_geometric.nn.norm.GraphNorm(x.size(-1), eps=1e-5)( + x, batch + ) + global_variables = self._calculate_global_variables( x, edge_index, -- GitLab From 11a1333c6b4dd618712eee3e9a1b9a6fa65ca78d Mon Sep 17 00:00:00 2001 From: Morten Holm <Volumunox@gmail.com> Date: Tue, 7 Feb 2023 09:25:59 +0100 Subject: [PATCH 2/4] fixed device allocation --- src/graphnet/models/gnn/dynedge.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 975581848..4b20158cb 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -159,6 +159,11 @@ class DynEdge(GNN): # Base class constructor super().__init__(nb_inputs, self._readout_layer_sizes[-1]) + # GraphNorm + self._graphnorm = torch_geometric.nn.norm.GraphNorm( + nb_inputs, eps=1e-5 + ) + # Remaining member variables() self._activation = torch.nn.LeakyReLU() self._nb_inputs = nb_inputs @@ -282,9 +287,7 @@ class DynEdge(GNN): x, edge_index, batch = data.x, data.edge_index, data.batch if self._use_graph_normalization: - x = torch_geometric.nn.norm.GraphNorm(x.size(-1), eps=1e-5)( - x, batch - ) + x = self._graphnorm(x, batch) global_variables = self._calculate_global_variables( x, -- GitLab From 67d2ef1cf3b4383447d4ec8cce51c4a1e7e0984a Mon Sep 17 00:00:00 2001 From: Morten Holm <Volumunox@gmail.com> Date: Tue, 7 Feb 2023 10:14:46 +0100 Subject: [PATCH 3/4] corrected model spelling of GraphNorm --- configs/models/dynedge_energy_example.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/models/dynedge_energy_example.yml b/configs/models/dynedge_energy_example.yml index 02d647f0c..3a4099434 100644 --- a/configs/models/dynedge_energy_example.yml +++ b/configs/models/dynedge_energy_example.yml @@ -20,6 +20,7 @@ arguments: nb_neighbours: 8 post_processing_layer_sizes: null readout_layer_sizes: null + use_graph_normalization: True class_name: DynEdge optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: {eps: 0.001, lr: 1e-05} -- GitLab From 064fc115869f2e62a1fd2702418e7839491e1d0b Mon Sep 17 00:00:00 2001 From: Morten Holm <33733987+MortenHolmRep@users.noreply.github.com> Date: Wed, 8 Feb 2023 18:29:23 +0100 Subject: [PATCH 4/4] set graph normalisation to the default value Default value is False --- configs/models/dynedge_energy_example.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/models/dynedge_energy_example.yml b/configs/models/dynedge_energy_example.yml index 3a4099434..d007af51f 100644 --- a/configs/models/dynedge_energy_example.yml +++ b/configs/models/dynedge_energy_example.yml @@ -20,7 +20,7 @@ arguments: nb_neighbours: 8 post_processing_layer_sizes: null readout_layer_sizes: null - use_graph_normalization: True + use_graph_normalization: null class_name: DynEdge optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: {eps: 0.001, lr: 1e-05} -- GitLab