diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index dfb2cedacf97bdb090d8bb202155fc22e49d9093..c29867155ce3c662d594a07ef6d5e93486937e8e 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -71,7 +71,7 @@ class GraphDefinition(Model): def forward( # type: ignore self, - node_features: np.array, + node_features: np.ndarray, node_feature_names: List[str], truth_dicts: Optional[List[Dict[str, Any]]] = None, custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None, diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index afe0dcfce2be1e012984eb8e77be24d69d0a1b9a..b98ad375a359ae1427c7b05a04c4294453b13ecf 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -4,9 +4,7 @@ from typing import List from abc import abstractmethod import torch -from torch_geometric.nn import knn_graph, radius_graph from torch_geometric.data import Data -import numpy as np from graphnet.utilities.decorators import final from graphnet.utilities.config import save_model_config