@abstractmethod
def _setup_geometry_table(self, geometry_table: str) -> None:
    """Locate the geometry table and assign `self._geometry_table`.

    NOTE(review): `Dataset.__init__` calls `self._set_geometry_table(...)`,
    which does not match this hook's name — align the two, or constructing
    a dataset with `include_inactive_sensors=True` fails with
    AttributeError before this hook is ever reached.

    Args:
        geometry_table: Name of the table holding the detector geometry.

    Raises:
        Implementations should raise if `geometry_table` does not exist
        in the backing storage.
    """


@abstractmethod
def _get_inactive_sensors(
    self,
    features: List[Tuple[float, ...]],
    columns: List[str],
    sequential_index: int,
) -> List[Tuple[float, ...]]:
    """Return padded feature rows for sensors with no pulses in the event.

    Args:
        features: Pulse-level rows already read for the event.
        columns: Feature column names to read for the inactive sensors.
        sequential_index: Sequential event index identifying the event.

    Returns:
        One tuple of features per inactive sensor, padded where the
        geometry table carries no measurement.
    """
def _get_inactive_sensors(
    self,
    features: List[Tuple[float, ...]],
    columns: List[str],
    sequential_index: int,
) -> List[Tuple[float, ...]]:
    """Query the geometry table for sensors that saw no pulses in an event.

    Args:
        features: Pulse rows already read for the event. Unused by the
            query itself; kept for interface compatibility with the
            abstract hook on `Dataset`.
        columns: Feature column names to read from the geometry table.
            The index column (e.g. "event_no") is dropped before
            querying, since the geometry table has no per-event index.
        sequential_index: Sequential event index, resolved to the event
            number via `self._get_event_index`.

    Returns:
        Rows of padded features, one tuple per inactive sensor.
    """
    assert self._conn
    index = self._get_event_index(sequential_index)

    # Keys must match the dict built in `Dataset.__init__`; the original
    # code indexed with short keys 'x'/'y'/'z', which raised KeyError.
    x_col = self._sensor_position["sensor_x_position_column"]
    y_col = self._sensor_position["sensor_y_position_column"]
    z_col = self._sensor_position["sensor_z_position_column"]

    # UID scheme "<x>_<y>_<z>" must agree with the one used when the
    # geometry table was created (see `add_geometry_table_to_database`).
    # TEXT is the intended cast; "str" carries no SQLite type affinity.
    uid_expr = (
        f"CAST({x_col} AS TEXT) || '_' || "
        f"CAST({y_col} AS TEXT) || '_' || "
        f"CAST({z_col} AS TEXT)"
    )
    # `self._selection` may be None; fall back to a tautology so the SQL
    # stays well-formed (the original interpolated the literal "None").
    selection = self._selection or "1=1"
    active_query = (
        f"SELECT DISTINCT {uid_expr} AS UID FROM {self._pulsemaps[0]} "
        f"WHERE {self._index_column} = {index} AND {selection}"
    )
    active_uids = [row[0] for row in self._conn.execute(active_query)]

    # Drop the index column: it is not present in the geometry table.
    geometry_columns = [c for c in columns if c != self._index_column]
    columns_str = ", ".join(geometry_columns)

    if not active_uids:
        # No active sensors: every sensor in the geometry table is inactive.
        query = f"SELECT {columns_str} FROM {self._geometry_table}"
        return self._conn.execute(query).fetchall()

    # Parameterised NOT IN replaces the original
    # `str(tuple(np.array(active_result)))`, which produced a malformed
    # SQL tuple for a single UID (trailing comma) and relied on Python
    # repr quoting.
    placeholders = ", ".join("?" * len(active_uids))
    query = (
        f"SELECT {columns_str} FROM {self._geometry_table} "
        f"WHERE UID NOT IN ({placeholders})"
    )
    return self._conn.execute(query, active_uids).fetchall()


def _setup_geometry_table(self, geometry_table: str) -> None:
    """Assign `self._geometry_table` if the table exists in the database.

    Args:
        geometry_table: Name of the sqlite table holding the geometry.

    Raises:
        ValueError: If `geometry_table` is not a table in the database.
    """
    assert isinstance(self._path, str)
    if not database_table_exists(
        database_path=self._path, table_name=geometry_table
    ):
        # The original `assert 1 == 2, ...` is stripped under `python -O`;
        # raise explicitly instead.
        raise ValueError(
            f"Geometry table named {geometry_table} "
            f"is not in the database {self._path}"
        )
    # Fixes the original typo `self._geoemtry_table`, which left
    # `_geometry_table` unset for the query path above.
    self._geometry_table = geometry_table
def add_geometry_table_to_database(
    database: str,
    pulsemap: str,
    features_to_pad: List[str],
    padding_value: int = 0,
    additional_features: Optional[List[str]] = None,
    sensor_x: str = "dom_x",
    sensor_y: str = "dom_y",
    sensor_z: str = "dom_z",
    gcd_file: Optional[str] = None,
    table_name: str = "geometry_table",
) -> None:
    """Add a geometry table to the database.

    The table contains one row per distinct sensor position seen in
    `pulsemap`, keyed by a "<x>_<y>_<z>" UID string, with `features_to_pad`
    filled with `padding_value`.

    Args:
        database: Path to the sqlite database.
        pulsemap: Name of the pulsemap table.
        features_to_pad: Column names added to the table after the sqlite
            query; each is filled with `padding_value`.
        padding_value: Value used for padding. Defaults to 0.
        additional_features: Extra pulsemap columns to carry over. Defaults
            to ["rde", "pmt_area"]. (`None` default avoids the original
            mutable-default-argument pitfall.)
        sensor_x: Column name of the sensors' x-coordinate. Defaults to "dom_x".
        sensor_y: Column name of the sensors' y-coordinate. Defaults to "dom_y".
        sensor_z: Column name of the sensors' z-coordinate. Defaults to "dom_z".
        gcd_file: Path to a gcd file. Defaults to None.
        table_name: Name of the geometry table. Defaults to "geometry_table".

    Raises:
        NotImplementedError: If `gcd_file` is given; building the geometry
            from a gcd file is not yet supported.
    """
    if gcd_file is not None:
        # Was `assert 1 == 2, ...`, which is stripped under `python -O`.
        raise NotImplementedError(
            "Creation of geometry table from gcd file is not yet supported. "
            "Please make a pull request."
        )

    if additional_features is None:
        additional_features = ["rde", "pmt_area"]

    # UID uniquely identifies a sensor by position; the expression must
    # match the one used when querying inactive sensors. TEXT is the
    # intended cast ("str" carries no SQLite type affinity).
    uid_expr = (
        f"(CAST({sensor_x} AS TEXT) || '_' || "
        f"CAST({sensor_y} AS TEXT) || '_' || "
        f"CAST({sensor_z} AS TEXT)) AS UID"
    )
    # Building the column list this way also keeps the SQL well-formed
    # when `additional_features` is empty (the original produced ", FROM").
    select_columns = [uid_expr, sensor_x, sensor_y, sensor_z, *additional_features]
    query = f"SELECT DISTINCT {', '.join(select_columns)} FROM {pulsemap}"

    with sqlite3.connect(database) as con:
        table = pd.read_sql(query, con)

    for feature_to_pad in features_to_pad:
        table[feature_to_pad] = padding_value

    create_table(
        table_name=table_name,
        columns=table.columns,
        database_path=database,
        index_column="UID",
        # UID is a string key: "TEXT" is a real SQLite type, unlike the
        # original "STR". `integer_primary_key=True` is still required to
        # take the PRIMARY KEY branch in `create_table`.
        primary_key_type="TEXT",
        integer_primary_key=True,
    )

    save_to_sql(df=table, table_name=table_name, database_path=database)
Such a column is required to have unique, integer values for each row. This is appropriate when the @@ -103,7 +157,7 @@ def create_table( type_ = default_type if column == index_column: if integer_primary_key: - type_ = "INTEGER PRIMARY KEY NOT NULL" + type_ = f"{primary_key_type} PRIMARY KEY NOT NULL" else: type_ = "NOT NULL" @@ -134,6 +188,7 @@ def create_table_and_save_to_sql( index_column: str = "event_no", default_type: str = "NOT NULL", integer_primary_key: bool = True, + primary_key_type: str = "INTEGER", ) -> None: """Create table if it doesn't exist and save dataframe to it.""" if not database_table_exists(database_path, table_name): @@ -144,5 +199,6 @@ def create_table_and_save_to_sql( index_column=index_column, default_type=default_type, integer_primary_key=integer_primary_key, + primary_key_type=primary_key_type, ) save_to_sql(df, table_name=table_name, database_path=database_path)