From a3bbc42ba55d9a78d9f4798a40f50c91d8b70767 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 4 Feb 2023 13:05:33 +0100
Subject: [PATCH 1/6] added add_inactive_sensor and geometry table arg

---
 src/graphnet/data/dataset.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset.py
index d07585223..02089651b 100644
--- a/src/graphnet/data/dataset.py
+++ b/src/graphnet/data/dataset.py
@@ -140,6 +140,8 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         loss_weight_column: Optional[str] = None,
         loss_weight_default_value: Optional[float] = None,
         seed: Optional[int] = None,
+        add_inactive_sensors: bool = False,
+        geometry_table: str = "geometry_table",
     ):
         """Construct Dataset.
 
@@ -185,6 +187,8 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
                 subset of events when resolving a string-based selection (e.g.,
                 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
                 events ~ event_no % 5 > 0"`).
+            add_inactive_sensors: If True, sensors that measured nothing during the event will be appended to the graph.
+            geometry_table: The sqlite table in which the detector geometry is stored.
         """
         # Check(s)
         if isinstance(pulsemaps, str):
@@ -205,6 +209,9 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         self._index_column = index_column
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
+        self._add_inactive_sensors = add_inactive_sensors
+        if self._add_inactive_sensors:
+            self._set_geometry_table(geometry_table)
 
         if node_truth is not None:
             assert isinstance(node_truth_table, str)
-- 
GitLab


From a02a70792ccbb678047b70d972c5ddc0bf1f03f2 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 4 Feb 2023 13:18:57 +0100
Subject: [PATCH 2/6] added sensor position column names as arg

---
 src/graphnet/data/dataset.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset.py
index 02089651b..c8bf8adac 100644
--- a/src/graphnet/data/dataset.py
+++ b/src/graphnet/data/dataset.py
@@ -142,6 +142,9 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         seed: Optional[int] = None,
         add_inactive_sensors: bool = False,
         geometry_table: str = "geometry_table",
+        sensor_x_position: str = "dom_x",
+        sensor_y_position: str = "dom_y",
+        sensor_z_position: str = "dom_z",
     ):
         """Construct Dataset.
 
@@ -189,6 +192,9 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
                 events ~ event_no % 5 > 0"`).
             add_inactive_sensors: If True, sensors that measured nothing during the event will be appended to the graph.
             geometry_table: The sqlite table in which the detector geometry is stored.
+            sensor_x_position: column name of x-coordinate of sensors. e.g. "dom_x",
+            sensor_y_position: column name of y-coordinate of sensors. e.g. "dom_y",
+            sensor_z_position: column name of z-coordinate of sensors. e.g. "dom_z",
         """
         # Check(s)
         if isinstance(pulsemaps, str):
@@ -210,6 +216,11 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
         self._add_inactive_sensors = add_inactive_sensors
+        self._sensor_position = {
+            "sensor_x_position_column": sensor_x_position,
+            "sensor_y_position_column": sensor_y_position,
+            "sensor_z_position_column": sensor_z_position,
+        }
         if self._add_inactive_sensors:
             self._set_geometry_table(geometry_table)
 
@@ -314,6 +325,10 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
     ) -> Optional[int]:
         """Return a the event index corresponding to a `sequential_index`."""
 
+    @abstractmethod
+    def _setup_geometry_table(self) -> None:
+        """Must assign self._geometry_table."""
+
     @abstractmethod
     def query_table(
         self,
-- 
GitLab


From 3ce4420f74973b3e2cd0b1666dc0bd153f6a8a78 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 4 Feb 2023 13:33:15 +0100
Subject: [PATCH 3/6] implemented _setup_geometry_table for sqlite dataset

---
 src/graphnet/data/dataset.py               |  2 +-
 src/graphnet/data/sqlite/sqlite_dataset.py | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset.py
index c8bf8adac..ef880d1c1 100644
--- a/src/graphnet/data/dataset.py
+++ b/src/graphnet/data/dataset.py
@@ -326,7 +326,7 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         """Return a the event index corresponding to a `sequential_index`."""
 
     @abstractmethod
-    def _setup_geometry_table(self) -> None:
+    def _setup_geometry_table(self, geometry_table: str) -> None:
         """Must assign self._geometry_table."""
 
     @abstractmethod
diff --git a/src/graphnet/data/sqlite/sqlite_dataset.py b/src/graphnet/data/sqlite/sqlite_dataset.py
index e61623c46..1aac00747 100644
--- a/src/graphnet/data/sqlite/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite/sqlite_dataset.py
@@ -147,3 +147,21 @@ class SQLiteDataset(Dataset):
                 self._all_connections_established = False
                 self._conn = None
         return self
+
+    def _setup_geometry_table(self, geometry_table: str) -> None:
+        if self._table_exists(geometry_table):
+            self._geoemtry_table = geometry_table
+        else:
+            assert (
+                1 == 2
+            ), f"Geometry table named {geometry_table} is not in the database {self._path}"
+
+    def _table_exists(self, geometry_table: str) -> bool:
+        assert isinstance(self._path, str)
+        with sqlite3.connect(self._path) as conn:
+            query = 'SELECT name FROM sqlite_master WHERE type == "name" '
+            all_tables = conn.execute(query).fetchall()
+        if geometry_table in all_tables:
+            return True
+        else:
+            return False
-- 
GitLab


From d41e362512b763c3a7ea1962fd2e86fa91043513 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 4 Feb 2023 13:57:53 +0100
Subject: [PATCH 4/6] added conditional table query

---
 src/graphnet/data/sqlite/sqlite_dataset.py | 40 +++++++++++++++++++---
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/src/graphnet/data/sqlite/sqlite_dataset.py b/src/graphnet/data/sqlite/sqlite_dataset.py
index 1aac00747..1d02aafbb 100644
--- a/src/graphnet/data/sqlite/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite/sqlite_dataset.py
@@ -3,6 +3,7 @@
 from typing import Any, List, Optional, Tuple, Union
 import pandas as pd
 import sqlite3
+import numpy as np
 
 from graphnet.data.dataset import Dataset, ColumnMissingException
 
@@ -49,6 +50,7 @@ class SQLiteDataset(Dataset):
         """Query table at a specific index, optionally with some selection."""
         # Check(s)
         if isinstance(columns, list):
+            n_features = len(columns)
             columns = ", ".join(columns)
 
         if not selection:  # I.e., `None` or `""`
@@ -67,11 +69,39 @@ class SQLiteDataset(Dataset):
                 combined_selections = (
                     f"{self._index_column} = {index} and {selection}"
                 )
-
-            result = self._conn.execute(
-                f"SELECT {columns} FROM {table} WHERE "
-                f"{combined_selections}"
-            ).fetchall()
+            if (
+                self._add_inactive_sensors
+                and self._sensor_position["x"] in columns
+                and n_features > 1
+            ):  # if this is a pulsemap query
+                active_query = f"select (CAST({self._sensor_position['x']} AS str) || '_' || CAST({self._sensor_position['y']} AS str) || '_' || CAST({self._sensor_position['z']} AS str)) as UID, {columns} from {table} where {self._index_column} = {index} and {selection}"
+                active_result = self._conn.execute(active_query).fetchall()
+                if len(columns.split(", ")) > 1:
+                    columns = ", ".join(
+                        columns.split(", ")[1:]
+                    )  # remove event_no because not in geometry table
+                query = f"select {columns} from {self._geometry_table} where UID not in {str(tuple(np.array(active_result)[:,0]))}"
+                inactive_result = self._conn.execute(query).fetchall()
+                active_result = np.asarray(active_result)[
+                    :, 2:
+                ].tolist()  # drops UID column & event_no
+
+                result = []
+                result.extend(active_result)
+                result.extend(inactive_result)
+                result = (
+                    np.concatenate(
+                        [np.repeat(index, len(result)).reshape(-1, 1), result],
+                        axis=1,
+                    )
+                    .astype("float64")
+                    .tolist()  # adds event_no again.
+                )
+            else:
+                result = self._conn.execute(
+                    f"SELECT {columns} FROM {table} WHERE "
+                    f"{combined_selections}"
+                ).fetchall()
         except sqlite3.OperationalError as e:
             if "no such column" in str(e):
                 raise ColumnMissingException(str(e))
-- 
GitLab


From 6cd8ac98f0523ba7eac5103e59ca62b38190cabc Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 4 Feb 2023 17:10:52 +0100
Subject: [PATCH 5/6] functionality for creating geometry_table

---
 src/graphnet/data/sqlite/sqlite_dataset.py   | 18 +++---
 src/graphnet/data/sqlite/sqlite_utilities.py | 60 +++++++++++++++++++-
 2 files changed, 65 insertions(+), 13 deletions(-)

diff --git a/src/graphnet/data/sqlite/sqlite_dataset.py b/src/graphnet/data/sqlite/sqlite_dataset.py
index 1d02aafbb..12e350c19 100644
--- a/src/graphnet/data/sqlite/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite/sqlite_dataset.py
@@ -7,6 +7,8 @@ import numpy as np
 
 from graphnet.data.dataset import Dataset, ColumnMissingException
 
+from graphnet.data.sqlite.sqlite_utilities import database_table_exists
+
 
 class SQLiteDataset(Dataset):
     """Pytorch dataset for reading data from SQLite databases."""
@@ -179,19 +181,13 @@ class SQLiteDataset(Dataset):
         return self
 
     def _setup_geometry_table(self, geometry_table: str) -> None:
-        if self._table_exists(geometry_table):
+        """Assign the geometry table to self if it exists in the database."""
+        assert isinstance(self._path, str)
+        if database_table_exists(
+            database_path=self._path, table_name=geometry_table
+        ):
             self._geoemtry_table = geometry_table
         else:
             assert (
                 1 == 2
             ), f"Geometry table named {geometry_table} is not in the database {self._path}"
-
-    def _table_exists(self, geometry_table: str) -> bool:
-        assert isinstance(self._path, str)
-        with sqlite3.connect(self._path) as conn:
-            query = 'SELECT name FROM sqlite_master WHERE type == "name" '
-            all_tables = conn.execute(query).fetchall()
-        if geometry_table in all_tables:
-            return True
-        else:
-            return False
diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/sqlite/sqlite_utilities.py
index 23bae802d..ee2bffb89 100644
--- a/src/graphnet/data/sqlite/sqlite_utilities.py
+++ b/src/graphnet/data/sqlite/sqlite_utilities.py
@@ -1,13 +1,65 @@
 """SQLite-specific utility functions for use in `graphnet.data`."""
 
 import os.path
-from typing import List
+from typing import List, Optional, Dict
 
 import pandas as pd
 import sqlalchemy
 import sqlite3
 
 
+def add_geometry_table_to_database(
+    database: str,
+    pulsemap: str,
+    features_to_pad: List[str],
+    padding_value: int = 0,
+    additional_features: List[str] = ["rde", "pmt_area"],
+    sensor_x: str = "dom_x",
+    sensor_y: str = "dom_y",
+    sensor_z: str = "dom_z",
+    gcd_file: Optional[str] = None,
+    table_name: str = "geometry_table",
+) -> None:
+    """Add geometry table to database.
+
+    Args:
+        database: path to sqlite database
+        pulsemap: name of the pulsemap table
+        features_to_pad: list of column names that will be added to the dataframe after sqlite query. Will be padded.
+        padding_value: Value used for padding. Defaults to 0.
+        additional_features: additional features in pulsemap table that you want to include. Defaults to ["rde", "pmt_area"].
+        sensor_x: x-coordinate of sensor positions. Defaults to "dom_x".
+        sensor_y: y-coordinate of sensor positions. Defaults to "dom_y".
+        sensor_z: z-coordinate of sensor positions. Defaults to "dom_z".
+        gcd_file: Path to gcd file. Defaults to None.
+        table_name: Name of the geometry table. Defaults to "geometry_table".
+    """
+    if gcd_file is not None:
+        raise NotImplementedError(
+            "Creation of geometry table from gcd file is not yet supported. Please make a pull request."
+        )
+    else:
+        additional_features_str = ", ".join(additional_features)
+        with sqlite3.connect(database) as con:
+            query = f"select distinct (CAST({sensor_x} AS str) || '_' || CAST({sensor_y} AS str) || '_' || CAST({sensor_z} AS str)) as UID, {sensor_x}, {sensor_y}, {sensor_z}, {additional_features_str} from {pulsemap}"
+            table = pd.read_sql(query, con)
+
+        for feature_to_pad in features_to_pad:
+            table[feature_to_pad] = padding_value
+
+    create_table(
+        table_name=table_name,
+        columns=table.columns,
+        database_path=database,
+        index_column="UID",
+        primary_key_type="STR",
+        integer_primary_key=True,
+    )
+
+    save_to_sql(df=table, table_name=table_name, database_path=database)
+    return
+
+
 def database_exists(database_path: str) -> bool:
     """Check whether database exists at `database_path`."""
     assert database_path.endswith(
@@ -79,6 +131,7 @@ def create_table(
     *,
     index_column: str = "event_no",
     default_type: str = "NOT NULL",
+    primary_key_type: str = "INTEGER",
     integer_primary_key: bool = True,
 ) -> None:
     """Create a table.
@@ -89,6 +142,7 @@ def create_table(
         database_path: Path to the database.
         index_column: Name of the index column.
         default_type: The type used for all non-index columns.
+        primary_key_type: the data type for the primary key. Defaults to INTEGER.
         integer_primary_key: Whether or not to create the `index_column` with
             the `INTEGER PRIMARY KEY` type. Such a column is required to have
             unique, integer values for each row. This is appropriate when the
@@ -103,7 +157,7 @@ def create_table(
         type_ = default_type
         if column == index_column:
             if integer_primary_key:
-                type_ = "INTEGER PRIMARY KEY NOT NULL"
+                type_ = f"{primary_key_type} PRIMARY KEY NOT NULL"
             else:
                 type_ = "NOT NULL"
 
@@ -134,6 +188,7 @@ def create_table_and_save_to_sql(
     index_column: str = "event_no",
     default_type: str = "NOT NULL",
     integer_primary_key: bool = True,
+    primary_key_type: str = "INTEGER",
 ) -> None:
     """Create table if it doesn't exist and save dataframe to it."""
     if not database_table_exists(database_path, table_name):
@@ -144,5 +199,6 @@ def create_table_and_save_to_sql(
             index_column=index_column,
             default_type=default_type,
             integer_primary_key=integer_primary_key,
+            primary_key_type=primary_key_type,
         )
     save_to_sql(df, table_name=table_name, database_path=database_path)
-- 
GitLab


From 8feef0ac789bf2e07eee2cca92be642abb67c3cc Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe <rahn@outlook.dk>
Date: Sat, 4 Feb 2023 19:36:34 +0100
Subject: [PATCH 6/6] implemented _get_inactive_sensors for sqlite dataset

---
 src/graphnet/data/dataset.py               | 40 ++++++++++++++++--
 src/graphnet/data/sqlite/sqlite_dataset.py | 47 ++++++++--------------
 2 files changed, 53 insertions(+), 34 deletions(-)

diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset.py
index ef880d1c1..57f03875c 100644
--- a/src/graphnet/data/dataset.py
+++ b/src/graphnet/data/dataset.py
@@ -140,7 +140,7 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         loss_weight_column: Optional[str] = None,
         loss_weight_default_value: Optional[float] = None,
         seed: Optional[int] = None,
-        add_inactive_sensors: bool = False,
+        include_inactive_sensors: bool = False,
         geometry_table: str = "geometry_table",
         sensor_x_position: str = "dom_x",
         sensor_y_position: str = "dom_y",
@@ -190,7 +190,7 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
                 subset of events when resolving a string-based selection (e.g.,
                 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
                 events ~ event_no % 5 > 0"`).
-            add_inactive_sensors: If True, sensors that measured nothing during the event will be appended to the graph.
+            include_inactive_sensors: If True, sensors that measured nothing during the event will be appended to the graph.
             geometry_table: The sqlite table in which the detector geometry is stored.
             sensor_x_position: column name of x-coordinate of sensors. e.g. "dom_x",
             sensor_y_position: column name of y-coordinate of sensors. e.g. "dom_y",
@@ -215,13 +215,13 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         self._index_column = index_column
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
-        self._add_inactive_sensors = add_inactive_sensors
+        self._include_inactive_sensors = include_inactive_sensors
         self._sensor_position = {
             "sensor_x_position_column": sensor_x_position,
             "sensor_y_position_column": sensor_y_position,
             "sensor_z_position_column": sensor_z_position,
         }
-        if self._add_inactive_sensors:
+        if self._include_inactive_sensors:
             self._set_geometry_table(geometry_table)
 
         if node_truth is not None:
@@ -329,6 +329,15 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
     def _setup_geometry_table(self, geometry_table: str) -> None:
         """Must assign self._geometry_table."""
 
+    @abstractmethod
+    def _get_inactive_sensors(
+        self,
+        features: List[Tuple[float, ...]],
+        columns: List[str],
+        sequential_index: int,
+    ) -> List[Tuple[float, ...]]:
+        """Add sensors that are inactive."""
+
     @abstractmethod
     def query_table(
         self,
@@ -502,6 +511,29 @@ class Dataset(torch.utils.data.Dataset, Configurable, LoggerMixin, ABC):
         else:
             node_truth = None
 
+        if self._include_inactive_sensors:
+            inactive_sensors = self._get_inactive_sensors(
+                features=features,
+                columns=self._features,
+                sequential_index=sequential_index,
+            )
+            result = []
+            result.extend(features)
+            result.extend(inactive_sensors)
+            result = (
+                np.concatenate(
+                    [
+                        np.repeat(sequential_index, len(result)).reshape(
+                            -1, 1
+                        ),
+                        result,
+                    ],
+                    axis=1,
+                )
+                .astype(self._dtype)
+                .tolist()  # NOTE(review): prepends sequential_index, not the event index (patch 4 used event_no here) — confirm intended
+            )
+
         loss_weight: Optional[float] = None  # Default
         if self._loss_weight_column is not None:
             assert self._loss_weight_table is not None
diff --git a/src/graphnet/data/sqlite/sqlite_dataset.py b/src/graphnet/data/sqlite/sqlite_dataset.py
index 12e350c19..6ed79d599 100644
--- a/src/graphnet/data/sqlite/sqlite_dataset.py
+++ b/src/graphnet/data/sqlite/sqlite_dataset.py
@@ -52,7 +52,6 @@ class SQLiteDataset(Dataset):
         """Query table at a specific index, optionally with some selection."""
         # Check(s)
         if isinstance(columns, list):
-            n_features = len(columns)
             columns = ", ".join(columns)
 
         if not selection:  # I.e., `None` or `""`
@@ -71,35 +70,6 @@ class SQLiteDataset(Dataset):
                 combined_selections = (
                     f"{self._index_column} = {index} and {selection}"
                 )
-            if (
-                self._add_inactive_sensors
-                and self._sensor_position["x"] in columns
-                and n_features > 1
-            ):  # if this is a pulsemap query
-                active_query = f"select (CAST({self._sensor_position['x']} AS str) || '_' || CAST({self._sensor_position['y']} AS str) || '_' || CAST({self._sensor_position['z']} AS str)) as UID, {columns} from {table} where {self._index_column} = {index} and {selection}"
-                active_result = self._conn.execute(active_query).fetchall()
-                if len(columns.split(", ")) > 1:
-                    columns = ", ".join(
-                        columns.split(", ")[1:]
-                    )  # remove event_no because not in geometry table
-                query = f"select {columns} from {self._geometry_table} where UID not in {str(tuple(np.array(active_result)[:,0]))}"
-                inactive_result = self._conn.execute(query).fetchall()
-                active_result = np.asarray(active_result)[
-                    :, 2:
-                ].tolist()  # drops UID column & event_no
-
-                result = []
-                result.extend(active_result)
-                result.extend(inactive_result)
-                result = (
-                    np.concatenate(
-                        [np.repeat(index, len(result)).reshape(-1, 1), result],
-                        axis=1,
-                    )
-                    .astype("float64")
-                    .tolist()  # adds event_no again.
-                )
-            else:
                 result = self._conn.execute(
                     f"SELECT {columns} FROM {table} WHERE "
                     f"{combined_selections}"
@@ -111,6 +81,23 @@ class SQLiteDataset(Dataset):
                 raise e
         return result
 
+    def _get_inactive_sensors(
+        self,
+        features: List[Tuple[float, ...]],
+        columns: List[str],
+        sequential_index: int,
+    ) -> List[Tuple[float, ...]]:
+        assert self._conn
+        index = self._get_event_index(sequential_index)
+        active_query = f"select (CAST({self._sensor_position['sensor_x_position_column']} AS str) || '_' || CAST({self._sensor_position['sensor_y_position_column']} AS str) || '_' || CAST({self._sensor_position['sensor_z_position_column']} AS str)) as UID from {self._pulsemaps[0]} where {self._index_column} = {index} and {self._selection}
+        active_result = self._conn.execute(active_query).fetchall()
+        columns_str = ", ".join(
+            columns
+        )  # NOTE(review): unlike the old inline query, event_no is not stripped here — confirm every entry of `columns` exists in the geometry table
+        query = f"select {columns_str} from {self._geometry_table} where UID not in {str(tuple(np.array(active_result)[:, 0]))}"
+        inactive_result = self._conn.execute(query).fetchall()
+        return inactive_result
+
     def _get_all_indices(self) -> List[int]:
         self._establish_connection(0)
         indices = pd.read_sql_query(
-- 
GitLab