From b318e195f6b0b75ce373ce5e35e7d75043265b67 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 3 Feb 2024 19:10:52 +0100
Subject: [PATCH 1/2] bugfix + minkowski mypy

---
 src/graphnet/models/graphs/edges/minkowski.py | 11 +++---
 src/graphnet/models/standard_model.py         | 39 +++++++++++++------
 2 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py
index 5d1134ec5..2526de1cb 100644
--- a/src/graphnet/models/graphs/edges/minkowski.py
+++ b/src/graphnet/models/graphs/edges/minkowski.py
@@ -69,12 +69,13 @@ class MinkowskiKNNEdges(EdgeDefinition):
         row = []
         col = []
         for batch in range(x.shape[0]):
+            x_masked = x[batch][mask[batch]]
             distance_mat = compute_minkowski_distance_mat(
-                x_masked := x[batch][mask[batch]],
-                x_masked,
-                self.c,
-                self.space_coords,
-                self.time_coord,
+                x=x_masked,
+                y=x_masked,
+                c=self.c,
+                space_coords=self.space_coords,
+                time_coord=self.time_coord,
             )
             num_points = x_masked.shape[0]
             num_edges = min(self.nb_nearest_neighbours, num_points)
diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py
index 663664996..098f63d98 100644
--- a/src/graphnet/models/standard_model.py
+++ b/src/graphnet/models/standard_model.py
@@ -410,6 +410,9 @@ class StandardModel(Model):
             f"number of output columns ({predictions.shape[1]}) don't match."
         )
 
+        # Check if predictions are on event- or pulse-level
+        pulse_level_predictions = len(predictions) > len(dataloader.dataset)
+
         # Get additional attributes
         attributes: Dict[str, List[np.ndarray]] = OrderedDict(
             [(attr, []) for attr in additional_attributes]
@@ -423,25 +426,39 @@ class StandardModel(Model):
                 # Check if node level predictions
                 # If true, additional attributes are repeated
                 # to make dimensions fit
-                if len(predictions) != len(dataloader.dataset):
+                if pulse_level_predictions:
                     if len(attribute) < np.sum(
                         batch.n_pulses.detach().cpu().numpy()
                     ):
                         attribute = np.repeat(
                             attribute, batch.n_pulses.detach().cpu().numpy()
                         )
-                        try:
-                            assert len(attribute) == len(batch.x)
-                        except AssertionError:
-                            self.warning_once(
-                                "Could not automatically adjust length"
-                                f"of additional attribute {attr} to match length of"
-                                f"predictions. Make sure {attr} is a graph-level or"
-                                "node-level attribute. Attribute skipped."
-                            )
-                            pass
                 attributes[attr].extend(attribute)
 
+        # Confirm that attributes match length of predictions
+        skip_attributes = []
+        for attr in attributes.keys():
+            try:
+                assert len(attributes[attr]) == len(predictions)
+            except AssertionError:
+                self.warning_once(
+                    "Could not automatically adjust length"
+                    f" of additional attribute '{attr}' to match length of"
+                    f" predictions.This error can be caused by heavy"
+                    " disagreement between number of examples in the"
+                    " dataset vs. actual events in the dataloader, e.g. "
+                    " heavy filtering of events in `collate_fn` passed to"
+                    " `dataloader`. This can also be caused by requesting"
+                    " pulse-level attributes for `Task`s that produce"
+                    " event-level predictions. Attribute skipped."
+                )
+                skip_attributes.append(attr)
+
+        # Remove bad attributes
+        for attr in skip_attributes:
+            attributes.pop(attr)
+            additional_attributes.remove(attr)
+
         data = np.concatenate(
             [predictions]
             + [
-- 
GitLab


From 36a2384ff06afabec9fc143ef3d548672b398761 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 3 Feb 2024 19:52:56 +0100
Subject: [PATCH 2/2] check requirement links

---
 requirements/torch_cpu.txt   | 2 +-
 requirements/torch_gpu.txt   | 2 +-
 requirements/torch_macos.txt | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/requirements/torch_cpu.txt b/requirements/torch_cpu.txt
index 59e273288..babb4fb8e 100644
--- a/requirements/torch_cpu.txt
+++ b/requirements/torch_cpu.txt
@@ -1,2 +1,2 @@
 --find-links https://download.pytorch.org/whl/cpu
---find-links https://data.pyg.org/whl/torch-2.1.0+cpu.html
\ No newline at end of file
+--find-links https://data.pyg.org/whl/torch-2.2.0+cpu.html
\ No newline at end of file
diff --git a/requirements/torch_gpu.txt b/requirements/torch_gpu.txt
index 1f1abba3f..ddcb85038 100644
--- a/requirements/torch_gpu.txt
+++ b/requirements/torch_gpu.txt
@@ -1,4 +1,4 @@
 # Contains packages requirements for GPU installation
 --find-links https://download.pytorch.org/whl/torch_stable.html
 torch==2.1.0+cu118
---find-links https://data.pyg.org/whl/torch-2.1.0+cu118.html
+--find-links https://data.pyg.org/whl/torch-2.2.0+cu118.html
diff --git a/requirements/torch_macos.txt b/requirements/torch_macos.txt
index 3e9d75df4..2b5009a8e 100644
--- a/requirements/torch_macos.txt
+++ b/requirements/torch_macos.txt
@@ -1,2 +1,2 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
---find-links https://data.pyg.org/whl/torch-2.1.0+cpu.html
\ No newline at end of file
+--find-links https://data.pyg.org/whl/torch-2.2.0+cpu.html
\ No newline at end of file
-- 
GitLab