diff --git a/Makefile b/Makefile
index 792309de70c434184adb2b69d6dc3e9c84fb995d..d06e4c8db783e2aa6f54fda791ee5babdcb5780f 100644
--- a/Makefile
+++ b/Makefile
@@ -15,6 +15,9 @@ clean:
 test:
 	py.test --junitxml=./reports/junit.xml -o junit_suite_name=$(PKGNAME) tests
 
+retest:
+	py.test --junitxml=./reports/junit.xml -o junit_suite_name=$(PKGNAME) tests --last-failed
+
 test-cov:
 	py.test tests --cov $(ALLNAMES) --cov-report term-missing --cov-report xml:reports/coverage.xml --cov-report html:reports/coverage tests
 
@@ -38,4 +41,4 @@ yapf:
 	yapf -i -r $(PKGNAME)
 	yapf -i setup.py
 
-.PHONY: all clean build install install-dev test test-nocov flake8 pep8 dependencies docstyle
+.PHONY: all clean build install install-dev test retest test-nocov flake8 pep8 dependencies docstyle
diff --git a/configs/bundle_ORCA4_corsika_sibyll_2-3c.toml b/configs/bundle_ORCA4_corsika_sibyll_2-3c.toml
index aface2a4ddf4303c0ad929874e0bbc9aff28496b..045bc9b43ebd2deb5f4ea24d9b5ed476d036a0a6 100644
--- a/configs/bundle_ORCA4_corsika_sibyll_2-3c.toml
+++ b/configs/bundle_ORCA4_corsika_sibyll_2-3c.toml
@@ -1,6 +1,5 @@
 # for atmospheric muon bundles, ORCA4, corsika sibyll 2.3c, graph network
 mode = "graph"
-max_n_hits = 2000
 time_window = [-250, 1000]
 extractor = "bundle_mc"
 # center of mupage detector:
diff --git a/configs/bundle_ORCA4_data_v5-40.toml b/configs/bundle_ORCA4_data_v5-40.toml
index 16ad7390cf8e4d503be8a704519d53dcaded769a..936a5335d4ab9e4e78d584ba96029f912abc514c 100644
--- a/configs/bundle_ORCA4_data_v5-40.toml
+++ b/configs/bundle_ORCA4_data_v5-40.toml
@@ -1,6 +1,5 @@
 # for atmospheric muon bundles, data ORCA4, graph network
 mode = "graph"
-max_n_hits = 2000
 time_window = [-250, 1000]
 extractor = "bundle_data"
 
diff --git a/configs/bundle_ORCA4_mupage_v5-40.toml b/configs/bundle_ORCA4_mupage_v5-40.toml
index 63d30d10c22cbdc5e37056b774390cbae94c4874..c1dd917f9e25b8968651b1077c2c6c1e4ddcbfc1 100644
--- a/configs/bundle_ORCA4_mupage_v5-40.toml
+++ b/configs/bundle_ORCA4_mupage_v5-40.toml
@@ -1,6 +1,5 @@
 # for atmospheric muon bundles, mupage ORCA4, graph network
 mode = "graph"
-max_n_hits = 2000
 time_window = [-250, 1000]
 extractor = "bundle_mc"
 
diff --git a/docs/orcasong.rst b/docs/orcasong.rst
index 0dfd92f74ada095f192cb101d992dc62ded53280..be00f22f7565994c81466b378eac851c2bb9252b 100644
--- a/docs/orcasong.rst
+++ b/docs/orcasong.rst
@@ -105,16 +105,12 @@ like this:
 
     from orcasong.core import FileGraph
 
-The FileGraph produces a list of nodes, each representing a hit.
-The length of this list has to be fixed, i.e. be the same for each event.
-Since the number of hits varies from event to event, some events will have to get
-padded, while others might get hits removed. The parameter ``max_n_hits``
-of FileGraph determines this fixed length:
-
-.. code-block:: python
-
-    fg = FileGraph(max_n_hits=2000)
+    fg = FileGraph()
 
+The FileGraph produces a list of nodes, each representing a hit.
+Since the number of hits varies from event to event, the hits of all events are saved
+in a long list (2d array), and a separate dataset is saved that can be used
+to identify which hits belong to which events.
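+
+As a sketch of how such a file could be read back (assuming ``h5py`` and the
+default dataset names ``x`` and ``x_indices``; the file name is a placeholder),
+the hits of a single event can be looked up like this:
+
+.. code-block:: python
+
+    import h5py
+
+    with h5py.File("processed_file.h5", "r") as f:
+        idx = f["x_indices"][0]  # entry of the first event
+        start, n_items = idx["index"], idx["n_items"]
+        hits = f["x"][start:start + n_items]  # shape (n_items, number of hit infos)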
General usage ------------- diff --git a/examples/orcasong_example.toml b/examples/orcasong_example.toml index a869d554edee27abe93ce0300c45b13b23fd640b..926c6d8440264319f05c03afede4a6f733eedf1a 100644 --- a/examples/orcasong_example.toml +++ b/examples/orcasong_example.toml @@ -3,8 +3,8 @@ # the mode to run orcasong in; either 'graph' or 'image' mode="graph" -# arguments for FileGraph or FileBinner (see orcasong.core) -max_n_hits = 2000 +# arguments for FileGraph or FileBinner can be put here +# (see orcasong.core for a list of parameters) time_window = [-100, 5000] # can also give the arguments of orcasong.core.BaseProcessor, # which are shared between modes diff --git a/orcasong/core.py b/orcasong/core.py index 57de103db300fc6a98e787092e4f0c3366e0de3d..8c06b8f1fd883797e945fead7ea517083d4110e7 100644 --- a/orcasong/core.py +++ b/orcasong/core.py @@ -355,35 +355,44 @@ class FileGraph(BaseProcessor): Turn km3 events to graph data. The resulting file will have a dataset "x" of shape - (?, max_n_hits, len(hit_infos) + 1). + (total n_hits, len(hit_infos)). The column names of the last axis (i.e. hit_infos) are saved as attributes of the dataset (f["x"].attrs). - The last column will always be called 'is_valid', and its 0 if - the entry is padded, and 1 otherwise. Parameters ---------- - max_n_hits : int - Maximum number of hits that gets saved per event. If an event has - more, some will get cut randomly! - time_window : tuple, optional - Two ints (start, end). Hits outside of this time window will be cut - away (based on 'Hits/time'). Default: Keep all hits. hit_infos : tuple, optional Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ... Default: Keep all entries. + time_window : tuple, optional + Two ints (start, end). Hits outside of this time window will be cut + away (based on 'Hits/time'). Default: Keep all hits. only_triggered_hits : bool - If true, use only triggered hits. Otherwise, use all hits. + If true, use only triggered hits. Otherwise, use all hits (default). + max_n_hits : int + Maximum number of hits that gets saved per event. If an event has + more, some will get cut randomly! Default: Keep all hits. + fixed_length : bool + If False (default), save hits of events with variable length as + 2d arrays using km3pipe's indices. + If True, pad hits of each event with 0s to a fixed length, + so that they can be stored as 3d arrays like images. + max_n_hits needs to be given in that case, and a column will be + added called 'is_valid', which is 0 if the entry is padded, + and 1 otherwise. + This is inefficient and will cut off hits, so it should not be used. kwargs Options of the BaseProcessor. 
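+
+    Examples
+    --------
+    A minimal usage sketch; the file names below are placeholders::
+
+        fg = FileGraph(hit_infos=("pos_x", "pos_y", "pos_z", "time"))
+        fg.run(infile="some_events.root", outfile="some_events_graph.h5")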
""" - def __init__(self, max_n_hits, + def __init__(self, max_n_hits=None, time_window=None, hit_infos=None, only_triggered_hits=False, + fixed_length=False, **kwargs): self.max_n_hits = max_n_hits + self.fixed_length = fixed_length self.time_window = time_window self.hit_infos = hit_infos self.only_triggered_hits = only_triggered_hits @@ -392,6 +401,7 @@ class FileGraph(BaseProcessor): def get_cmpts_main(self): return [((modules.PointMaker, { "max_n_hits": self.max_n_hits, + "fixed_length": self.fixed_length, "time_window": self.time_window, "hit_infos": self.hit_infos, "dset_n_hits": "EventInfo", @@ -402,3 +412,4 @@ class FileGraph(BaseProcessor): super().finish_file(f, summary) for i, hit_info in enumerate(summary["PointMaker"]["hit_infos"]): f["x"].attrs.create(f"hit_info_{i}", hit_info) + f["x"].attrs.create("indexed", not self.fixed_length) diff --git a/orcasong/modules.py b/orcasong/modules.py index 30d8f0c07b173f49e9a058560a5ad83b49b1aba9..96f8fe241a8dc5011175cfd7929ace452e8c8f66 100644 --- a/orcasong/modules.py +++ b/orcasong/modules.py @@ -276,38 +276,48 @@ class PointMaker(kp.Module): Attributes ---------- - max_n_hits : int - Maximum number of hits that gets saved per event. If an event has - more, some will get cut! - time_window : tuple, optional - Two ints (start, end). Hits outside of this time window will be cut - away (base on 'Hits/time'). - Default: Keep all hits. hit_infos : tuple, optional Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ... Default: Keep all entries. + time_window : tuple, optional + Two ints (start, end). Hits outside of this time window will be cut + away (based on 'Hits/time'). Default: Keep all hits. + only_triggered_hits : bool + If true, use only triggered hits. Otherwise, use all hits (default). + max_n_hits : int + Maximum number of hits that gets saved per event. If an event has + more, some will get cut randomly! Default: Keep all hits. + fixed_length : bool + If False (default), save hits of events with variable length as + 2d arrays using km3pipe's indices. + If True, pad hits of each event with 0s to a fixed length, + so that they can be stored as 3d arrays like images. + max_n_hits needs to be given in that case, and a column will be + added called 'is_valid', which is 0 if the entry is padded, + and 1 otherwise. + This is inefficient and will cut off hits, so it should not be used. dset_n_hits : str, optional If given, store the number of hits that are in the time window as a new column called 'n_hits_intime' in the dataset with this name (usually this is EventInfo). - only_triggered_hits : bool - If true, use only triggered hits. Otherwise, use all hits. 
""" def configure(self): - self.max_n_hits = self.require("max_n_hits") self.hit_infos = self.get("hit_infos", default=None) self.time_window = self.get("time_window", default=None) - self.dset_n_hits = self.get("dset_n_hits", default=None) self.only_triggered_hits = self.get("only_triggered_hits", default=False) + self.max_n_hits = self.get("max_n_hits", default=None) + self.fixed_length = self.get("fixed_length", default=False) + self.dset_n_hits = self.get("dset_n_hits", default=None) self.store_as = "samples" def process(self, blob): + if self.fixed_length and self.max_n_hits is None: + raise ValueError("Have to specify max_n_hits if fixed_length is True") if self.hit_infos is None: self.hit_infos = blob["Hits"].dtype.names points, n_hits = self.get_points(blob) - blob[self.store_as] = kp.NDArray( - np.expand_dims(points, 0), h5loc="x", title="nodes") + blob[self.store_as] = kp.NDArray(points, h5loc="x", title="nodes") if self.dset_n_hits: blob[self.dset_n_hits] = blob[self.dset_n_hits].append_columns( "n_hits_intime", n_hits) @@ -326,11 +336,9 @@ class PointMaker(kp.Module): actual hits, and 0 for if its a padded row. n_hits : int Number of hits in the given time window. + Can be stored as n_hits_intime. """ - points = np.zeros( - (self.max_n_hits, len(self.hit_infos) + 1), dtype="float32") - hits = blob["Hits"] if self.only_triggered_hits: hits = hits[hits.triggered != 0] @@ -342,7 +350,7 @@ class PointMaker(kp.Module): )] n_hits = len(hits) - if n_hits > self.max_n_hits: + if self.max_n_hits is not None and n_hits > self.max_n_hits: # if there are too many hits, take random ones, but keep order indices = np.arange(n_hits) np.random.shuffle(indices) @@ -350,15 +358,27 @@ class PointMaker(kp.Module): which.sort() hits = hits[which] - for i, which in enumerate(self.hit_infos): - data = hits[which] - points[:n_hits, i] = data - # last column is whether there was a hit or no - points[:n_hits, -1] = 1. + if self.fixed_length: + points = np.zeros( + (self.max_n_hits, len(self.hit_infos) + 1), dtype="float32") + for i, which in enumerate(self.hit_infos): + points[:n_hits, i] = hits[which] + # last column is whether there was a hit or no + points[:n_hits, -1] = 1. + # store along new axis + points = np.expand_dims(points, 0) + else: + points = np.zeros( + (len(hits), len(self.hit_infos)), dtype="float32") + for i, which in enumerate(self.hit_infos): + points[:, i] = hits[which] return points, n_hits def finish(self): - return {"hit_infos": tuple(self.hit_infos) + ("is_valid", )} + columns = tuple(self.hit_infos) + if self.fixed_length: + columns += ("is_valid", ) + return {"hit_infos": columns} class EventSkipper(kp.Module): diff --git a/orcasong/tools/concatenate.py b/orcasong/tools/concatenate.py index 0c47fb717d803cf6d627f5f6843f4c8a1326be68..5896c39fad39c9706a9f9d52b25b1489ed8c23af 100644 --- a/orcasong/tools/concatenate.py +++ b/orcasong/tools/concatenate.py @@ -1,8 +1,6 @@ import os import time import h5py -import numpy as np -import argparse import warnings @@ -29,8 +27,8 @@ class FileConcatenator: comptopts : dict Options for compression. They are read from the first input file, but they can be updated as well during init. - cumu_rows : np.array - The cumulative number of rows (axis_0) of the specified + cumu_rows : dict + Foe each dataset, the cumulative number of rows (axis_0) of the specified input .h5 files (i.e. [0,100,200,300,...] if each file has 100 rows). 
""" @@ -39,7 +37,6 @@ class FileConcatenator: print(f"Checking {len(input_files)} files ...") self.input_files, self.cumu_rows = self._get_cumu_rows(input_files) - print(f"Total rows:\t{self.cumu_rows[-1]}") # Get compression options from first file in the list self.comptopts = get_compopts(self.input_files[0]) @@ -47,8 +44,6 @@ class FileConcatenator: self.comptopts.update(comptopts_update) print("\n".join([f" {k}:\t{v}" for k, v in self.comptopts.items()])) - self._modify_folder = False - @classmethod def from_list(cls, list_file, n_files=None, **kwargs): """ @@ -92,7 +87,7 @@ class FileConcatenator: print(f'Processing file {input_file_nmbr+1}/' f'{len(self.input_files)}: {input_file}') with h5py.File(input_file, 'r') as f_in: - self._conc_file(f_in, f_out, input_file, input_file_nmbr) + self._conc_file(f_in, f_out, input_file_nmbr) f_out.flush() elapsed_time = time.time() - start_time @@ -109,46 +104,36 @@ class FileConcatenator: f"\nElapsed time: {elapsed_time/60:.2f} min " f"({elapsed_time/len(self.input_files):.2f} s per file)") - def _conc_file(self, f_in, f_out, input_file, input_file_nmbr): + def _conc_file(self, f_in, f_out, input_file_nmbr): """ Conc one file to the output. """ - for folder_name in f_in: - if is_folder_ignored(folder_name): + for dset_name in f_in: + if is_folder_ignored(dset_name): # we dont need datasets created by pytables anymore continue - input_dataset = f_in[folder_name] + input_dataset = f_in[dset_name] folder_data = input_dataset[()] if input_file_nmbr > 0: # we need to add the current number of the # group_id / index in the file_output to the # group_ids / indices of the file that is to be appended - try: - if folder_name.endswith("_indices") and \ - "index" in folder_data.dtype.names: - column_name = "index" - elif "group_id" in folder_data.dtype.names: - column_name = "group_id" - else: - column_name = None - except TypeError: - column_name = None - if column_name is not None: - # add 1 because the group_ids / indices start with 0 - folder_data[column_name] += \ - np.amax(f_out[folder_name][column_name]) + 1 - - if self._modify_folder: - data_mody = self._modify( - input_file, folder_data, folder_name) - if data_mody is not None: - folder_data = data_mody + last_index = self.cumu_rows[dset_name][input_file_nmbr] - 1 + if (dset_name.endswith("_indices") and + "index" in folder_data.dtype.names): + folder_data["index"] += ( + f_out[dset_name][last_index]["index"] + + f_out[dset_name][last_index]["n_items"] + ) + elif folder_data.dtype.names and "group_id" in folder_data.dtype.names: + # add 1 because the group_ids start with 0 + folder_data["group_id"] += f_out[dset_name][last_index]["group_id"] + 1 if input_file_nmbr == 0: # first file; create the dataset - dset_shape = (self.cumu_rows[-1],) + folder_data.shape[1:] - print(f"\tCreating dataset '{folder_name}' with shape {dset_shape}") + dset_shape = (self.cumu_rows[dset_name][-1],) + folder_data.shape[1:] + print(f"\tCreating dataset '{dset_name}' with shape {dset_shape}") output_dataset = f_out.create_dataset( - folder_name, + dset_name, data=folder_data, maxshape=dset_shape, chunks=(self.comptopts["chunksize"],) + folder_data.shape[1:], @@ -156,37 +141,37 @@ class FileConcatenator: compression_opts=self.comptopts["complevel"], shuffle=self.comptopts["shuffle"], ) - output_dataset.resize(self.cumu_rows[-1], axis=0) + output_dataset.resize(dset_shape[0], axis=0) else: - f_out[folder_name][ - self.cumu_rows[input_file_nmbr]:self.cumu_rows[input_file_nmbr + 1]] = folder_data - - def _modify(self, 
input_file, folder_data, folder_name): - raise NotImplementedError + f_out[dset_name][ + self.cumu_rows[dset_name][input_file_nmbr]: + self.cumu_rows[dset_name][input_file_nmbr + 1] + ] = folder_data def _get_cumu_rows(self, input_files): """ - Get the cumulative number of rows of the input_files. - Also checks if all the files can be safely concatenated to the - first one. + Checks if all the files can be safely concatenated to the first one. """ # names of datasets that will be in the output; read from first file with h5py.File(input_files[0], 'r') as f: - keys_stripped = strip_keys(list(f.keys())) + target_datasets = strip_keys(list(f.keys())) + + errors, valid_input_files = [], [] + cumu_rows = {k: [0] for k in target_datasets} - errors, rows_per_file, valid_input_files = [], [0], [] for i, file_name in enumerate(input_files, start=1): file_name = os.path.abspath(file_name) try: - rows_this_file = _get_rows(file_name, keys_stripped) + rows_this_file = _get_rows(file_name, target_datasets) except Exception as e: errors.append(e) warnings.warn(f"Error during check of file {i}: {file_name}") continue valid_input_files.append(file_name) - rows_per_file.append(rows_this_file) + for k in target_datasets: + cumu_rows[k].append(cumu_rows[k][-1] + rows_this_file[k]) if errors: print("\n------- Errors -------\n----------------------") @@ -200,8 +185,8 @@ class FileConcatenator: raise OSError(err_str) print(f"Valid input files: {len(valid_input_files)}/{len(input_files)}") - print("Datasets:\t" + ", ".join(keys_stripped)) - return valid_input_files, np.cumsum(rows_per_file) + print("Datasets:\t" + ", ".join(target_datasets)) + return valid_input_files, cumu_rows def _get_rows(file_name, target_datasets): @@ -215,11 +200,15 @@ def _get_rows(file_name, target_datasets): f"It has {f.keys()} First file: {target_datasets}" ) # check if all target datasets in the file have the same length - rows = [f[k].shape[0] for k in target_datasets] - if not all(row == rows[0] for row in rows): + # for datasets that have indices: only check indices + plain_rows = [ + f[k].shape[0] for k in target_datasets + if not (f[k].attrs.get("indexed") and f"{k}_indices" in target_datasets) + ] + if not all(row == plain_rows[0] for row in plain_rows): raise ValueError( f"Datasets in file {file_name} have varying length! 
" - f"{dict(zip(target_datasets, rows))}" + f"{dict(zip(target_datasets, plain_rows))}" ) # check if the file has additional datasets apart from the target keys if not all(k in target_datasets for k in strip_keys(list(f.keys()))): @@ -229,7 +218,8 @@ def _get_rows(file_name, target_datasets): f"This file: {strip_keys(list(f.keys()))} " f"First file {target_datasets}" ) - return rows[0] + rows = {k: f[k].shape[0] for k in target_datasets} + return rows def strip_keys(f_keys): @@ -270,7 +260,12 @@ def get_compopts(file): """ with h5py.File(file, 'r') as f: - dset = f[strip_keys(list(f.keys()))[0]] + # for reading the comptopts, take first datsets thats not indexed + dset_names = strip_keys(list(f.keys())) + for dset_name in dset_names: + if f"{dset_name}_indices" not in dset_names: + break + dset = f[dset_name] comptopts = {} comptopts["complib"] = dset.compression if comptopts["complib"] == 'lzf': diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py index 90859dc4412998e5cc21a85e2a7d69882fc592bb..dcd9ed06d1d0ec66a950c6c45b2c480536ecc15d 100644 --- a/orcasong/tools/shuffle2.py +++ b/orcasong/tools/shuffle2.py @@ -6,6 +6,7 @@ import numpy as np import psutil import h5py from km3pipe.sys import peak_memory_usage +import awkward as ak from orcasong.tools.postproc import get_filepath_output, copy_used_files from orcasong.tools.concatenate import copy_attrs @@ -103,9 +104,10 @@ def _shuffle_file( output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime()) ) with h5py.File(input_file, "r") as f_in: - _check_dsets(f_in, datasets) - dset_info = _get_largest_dset(f_in, datasets, max_ram) - print(f"Shuffling datasets {datasets}") + dsets = (*datasets, *_get_indexed_datasets(f_in, datasets)) + _check_dsets(f_in, dsets) + dset_info = _get_largest_dset(f_in, dsets, max_ram) + print(f"Shuffling datasets {dsets}") indices_per_batch = _get_indices_per_batch( dset_info["n_batches"], dset_info["n_chunks"], @@ -113,7 +115,7 @@ def _shuffle_file( ) with h5py.File(temp_output_file, "x") as f_out: - for dset_name in datasets: + for dset_name in dsets: print("Creating dataset", dset_name) _shuffle_dset(f_out, f_in, dset_name, indices_per_batch) print("Done!") @@ -143,9 +145,9 @@ def get_n_iterations( max_ram = get_max_ram(max_ram_fraction=max_ram_fraction) with h5py.File(input_file, "r") as f_in: dset_info = _get_largest_dset(f_in, datasets, max_ram) - n_iterations = int( + n_iterations = np.amax((1, int( np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"])) - ) + ))) print(f"Largest dataset: {dset_info['name']}") print(f"Total chunks: {dset_info['n_chunks']}") print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}") @@ -188,28 +190,53 @@ def _get_largest_dset(f, datasets, max_ram): def _check_dsets(f, datasets): # check if all datasets have the same number of lines - n_lines_list = [len(f[dset_name]) for dset_name in datasets] + n_lines_list = [] + for dset_name in datasets: + if dset_is_indexed(f, dset_name): + dset_name = f"{dset_name}_indices" + n_lines_list.append(len(f[dset_name])) + if not all([n == n_lines_list[0] for n in n_lines_list]): raise ValueError( f"Datasets have different lengths! " f"{n_lines_list}" ) +def _get_indexed_datasets(f, datasets): + indexed_datasets = [] + for dset_name in datasets: + if dset_is_indexed(f, dset_name): + indexed_datasets.append(f"{dset_name}_indices") + return indexed_datasets + + def _get_dset_infos(f, datasets, max_ram): """ Retrieve infos for each dataset. 
""" dset_infos = [] for i, name in enumerate(datasets): - dset = f[name] + if name.endswith("_indices"): + continue + if dset_is_indexed(f, name): + # for indexed dataset: take average bytes in x per line in x_indices + dset_data = f[name] + name = f"{name}_indices" + dset = f[name] + bytes_per_line = ( + np.asarray(dset[0]).nbytes * + len(dset_data) / len(dset) + ) + else: + dset = f[name] + bytes_per_line = np.asarray(dset[0]).nbytes + n_lines = len(dset) chunksize = dset.chunks[0] n_chunks = int(np.ceil(n_lines / chunksize)) - bytes_per_line = np.asarray(dset[0]).nbytes bytes_per_chunk = bytes_per_line * chunksize chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk)) dset_infos.append({ "name": name, - "dset": dset, "n_chunks": n_chunks, "chunks_per_batch": chunks_per_batch, "n_batches": int(np.ceil(n_chunks / chunks_per_batch)), @@ -219,6 +246,16 @@ def _get_dset_infos(f, datasets, max_ram): return dset_infos +def dset_is_indexed(f, dset_name): + if f[dset_name].attrs.get("indexed"): + if f"{dset_name}_indices" not in f: + raise KeyError( + f"{dset_name} is indexed, but {dset_name}_indices is missing!") + return True + else: + return False + + def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch): """ Create a batchwise-shuffled dataset in the output file using given indices. @@ -229,7 +266,12 @@ def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch): for batch_number, indices in enumerate(indices_per_batch): print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}") # remove indices outside of dset - indices = indices[indices < len(dset_in)] + if dset_is_indexed(f_in, dset_name): + max_index = len(f_in[f"{dset_name}_indices"]) + else: + max_index = len(dset_in) + indices = indices[indices < max_index] + # reading has to be done with linearly increasing index # fancy indexing is super slow # so sort -> turn to slices -> read -> conc -> undo sorting @@ -237,8 +279,27 @@ def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch): unsort_ix = np.argsort(sort_ix) fancy_indices = indices[sort_ix] slices = _slicify(fancy_indices) - data = np.concatenate([dset_in[slc] for slc in slices]) - data = data[unsort_ix] + + if dset_is_indexed(f_in, dset_name): + # special treatment for indexed: slice based on indices dataset + slices_indices = [f_in[f"{dset_name}_indices"][slc] for slc in slices] + data = np.concatenate([ + dset_in[slice(*_resolve_indexed(slc))] for slc in slices_indices + ]) + # convert to 3d awkward array, then shuffle, then back to numpy + data_indices = np.concatenate(slices_indices) + data_ak = ak.unflatten(data, data_indices["n_items"]) + data = ak.flatten(data_ak[unsort_ix], axis=1).to_numpy() + + else: + data = np.concatenate([dset_in[slc] for slc in slices]) + data = data[unsort_ix] + + if dset_name.endswith("_indices"): + # recacalculate index + data["index"] = start_idx + np.concatenate([ + [0], np.cumsum(data["n_items"][:-1]) + ]) if batch_number == 0: out_dset = f_out.create_dataset( @@ -271,6 +332,11 @@ def _slicify(fancy_indices): return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))] +def _resolve_indexed(ind): + # based on slice of x_indices, get where to slice in x + return ind["index"][0], ind["index"][-1] + ind["n_items"][-1] + + def _get_temp_filenames(output_file, number): path, file = os.path.split(output_file) return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)] diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 
60f39cb6710daf5461e66d4b561ee79b9b56f837..3a417a284c6eb463095e457e922f0b588f3d0df8 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1,13 +1,14 @@ import tempfile -from unittest import TestCase +import unittest import numpy as np import h5py import orcasong.tools.concatenate as conc +import os __author__ = 'Stefan Reck' -class TestFileConcatenator(TestCase): +class TestFileConcatenator(unittest.TestCase): """ Test concatenation on pre-generated h5 files. They are in tests/data. @@ -58,7 +59,9 @@ class TestFileConcatenator(TestCase): def test_get_cumu_rows(self): fc = conc.FileConcatenator(self.dummy_files) - np.testing.assert_array_equal(fc.cumu_rows, [0, 10, 25]) + self.assertDictEqual( + fc.cumu_rows, {'numpy_array': [0, 10, 25], 'rec_array': [0, 10, 25]} + ) def test_concatenate_used_files(self): fc = conc.FileConcatenator(self.dummy_files) @@ -114,6 +117,56 @@ class TestFileConcatenator(TestCase): ) +class BaseTestClass: + class BaseIndexedFile(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.infile = tempfile.NamedTemporaryFile() + with h5py.File(cls.infile, "w") as f: + cls.x = np.arange(20) + dset_x = f.create_dataset("x", data=cls.x, chunks=True) + dset_x.attrs.create("indexed", True) + cls.indices = np.array( + [(0, 5), (5, 12), (17, 3)], + dtype=[('index', '<i8'), ('n_items', '<i8')] + ) + f.create_dataset("x_indices", data=cls.indices, chunks=True) + + @classmethod + def tearDownClass(cls) -> None: + cls.infile.close() + + +class TestConcatenateIndexed(BaseTestClass.BaseIndexedFile): + def setUp(self) -> None: + self.outfile = "temp_out.h5" + conc.concatenate([self.infile.name] * 2, outfile=self.outfile) + + def tearDown(self) -> None: + if os.path.exists(self.outfile): + os.remove(self.outfile) + + def test_check_x(self): + with h5py.File(self.outfile) as f_out: + np.testing.assert_array_equal( + f_out["x"], + np.concatenate([self.x]*2) + ) + + def test_check_x_indices_n_items(self): + with h5py.File(self.outfile) as f_out: + target_n_items = np.concatenate([self.indices] * 2)["n_items"] + np.testing.assert_array_equal( + f_out["x_indices"]["n_items"], target_n_items) + + def test_check_x_indices_index(self): + with h5py.File(self.outfile) as f_out: + target_n_items = np.concatenate([self.indices] * 2)["n_items"] + target_index = np.concatenate([[0], target_n_items.cumsum()[:-1]]) + np.testing.assert_array_equal( + f_out["x_indices"]["index"], target_index) + + def _create_dummy_file(filepath, columns=10, val_array=1, val_recarray=(1, 3)): """ Create a dummy h5 file with an array and a recarray in it. """ with h5py.File(filepath, "w") as f: diff --git a/tests/test_core.py b/tests/test_core.py index af689ba2600fcf33aa1014939557caa0088490a5..886197dc030ade9c487cd049922799243bb0655e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -83,7 +83,9 @@ class TestFileGraph(TestCase): """ Assert that the FileGraph still produces the same output. 
""" @classmethod def setUpClass(cls): - cls.proc = orcasong.core.FileGraph( + # produce test file, once for fixed_length (old format), and once + # for the new format + cls.proc_fixed_length, cls.proc = [orcasong.core.FileGraph( max_n_hits=3, time_window=[0, 50], hit_infos=["pos_z", "time", "channel_id"], @@ -92,8 +94,14 @@ class TestFileGraph(TestCase): add_t0=False, keep_event_info=True, correct_timeslew=False, - ) + fixed_length=fixed_length, + ) for fixed_length in (True, False)] cls.tmpdir = tempfile.TemporaryDirectory() + + cls.outfile_fixed_length = os.path.join(cls.tmpdir.name, "binned_fixed_length.h5") + cls.proc_fixed_length.run(infile=MUPAGE_FILE, outfile=cls.outfile_fixed_length) + cls.f_fixed_length = h5py.File(cls.outfile_fixed_length, "r") + cls.outfile = os.path.join(cls.tmpdir.name, "binned.h5") cls.proc.run(infile=MUPAGE_FILE, outfile=cls.outfile) cls.f = h5py.File(cls.outfile, "r") @@ -101,25 +109,43 @@ class TestFileGraph(TestCase): @classmethod def tearDownClass(cls): cls.f.close() + cls.f_fixed_length.close() cls.tmpdir.cleanup() + def test_keys_fixed_length(self): + self.assertSetEqual(set(self.f_fixed_length.keys()), { + '_i_event_info', '_i_group_info', '_i_y', + 'event_info', 'group_info', 'x', 'x_indices', 'y'}) + def test_keys(self): - self.assertSetEqual(set(self.f.keys()), { + self.assertSetEqual(set(self.f_fixed_length.keys()), { '_i_event_info', '_i_group_info', '_i_y', 'event_info', 'group_info', 'x', 'x_indices', 'y'}) - def test_x_attrs(self): + def test_x_attrs_fixed_length(self): to_check = { "hit_info_0": "pos_z", "hit_info_1": "time", "hit_info_2": "channel_id", "hit_info_3": "is_valid", + "indexed": False, + } + attrs = dict(self.f_fixed_length["x"].attrs) + for k, v in to_check.items(): + self.assertTrue(attrs[k] == v) + + def test_x_attrs(self): + to_check = { + "hit_info_0": "pos_z", + "hit_info_1": "time", + "hit_info_2": "channel_id", + "indexed": True, } attrs = dict(self.f["x"].attrs) for k, v in to_check.items(): self.assertTrue(attrs[k] == v) - def test_x(self): + def test_x_fixed_length(self): target = np.array([ [[676.941, 13., 30., 1.], [461.111, 32., 9., 1.], @@ -131,8 +157,33 @@ class TestFileGraph(TestCase): [605.111, 9., 4., 1.], [424.889, 46., 29., 1.]] ], dtype=np.float32) + np.testing.assert_equal(target, self.f_fixed_length["x"]) + + def test_x(self): + target = np.array([ + [676.941, 13., 30.], + [461.111, 32., 9.], + [424.941, 1., 30.], + [172.83, 32., 25.], + [316.83, 2., 14.], + [461.059, 1., 3.], + [496.83, 34., 25.], + [605.111, 9., 4.], + [424.889, 46., 29.], + ], dtype=np.float32) np.testing.assert_equal(target, self.f["x"]) + def test_y_fixed_length(self): + y = self.f_fixed_length["y"][()] + target = { + 'event_id': np.array([0., 1., 2.]), + 'run_id': np.array([1., 1., 1.]), + 'trigger_mask': np.array([18., 18., 16.]), + 'group_id': np.array([0, 1, 2]), + } + for k, v in target.items(): + np.testing.assert_equal(y[k], v) + def test_y(self): y = self.f["y"][()] target = { @@ -142,4 +193,4 @@ class TestFileGraph(TestCase): 'group_id': np.array([0, 1, 2]), } for k, v in target.items(): - np.testing.assert_equal(y[k], v) + np.testing.assert_equal(y[k], v) \ No newline at end of file diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 8660f590f707f7ca85565b6c931e2e0c33cd805b..e311ab5319a0287da6c4f57bbe9b73d5558052cb 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -17,6 +17,7 @@ DET_FILE_NEUTRINO = os.path.join(test_dir, "data", "KM3NeT_00000049_20200707.det NO_COMPLE_RECO_FILE = 
os.path.join(test_dir, "data", "arca_test_without_some_jmuon_recos.h5") ARCA_DETX = os.path.join(test_dir, "data", "KM3NeT_-00000001_20171212.detx") + class TestStdRecoExtractor(TestCase): """ Assert that the neutrino info is extracted correctly. File has 18 events. """ @@ -31,6 +32,7 @@ class TestStdRecoExtractor(TestCase): det_file=DET_FILE_NEUTRINO, add_t0=True, keep_event_info=True, + fixed_length=True, ) cls.tmpdir = tempfile.TemporaryDirectory() cls.outfile = os.path.join(cls.tmpdir.name, "binned.h5") @@ -50,6 +52,7 @@ class TestStdRecoExtractor(TestCase): det_file=ARCA_DETX, add_t0=True, keep_event_info=True, + fixed_length=True, ) cls.outfile_arca = os.path.join(cls.tmpdir.name, "binned_arca.h5") cls.proc.run(infile=NO_COMPLE_RECO_FILE, outfile=cls.outfile_arca) diff --git a/tests/test_from_toml.py b/tests/test_from_toml.py index bd3d3d6799e88a4b627fc61104068f117dfcce02..49a1c9475a023a44724d4721e72dd1059117b297 100644 --- a/tests/test_from_toml.py +++ b/tests/test_from_toml.py @@ -27,7 +27,7 @@ class TestSetupProcessorExampleConfig(TestCase): self.assertEqual(self.processor.time_window, [-100, 5000]) def test_max_n_hits(self): - self.assertEqual(self.processor.max_n_hits, 2000) + self.assertEqual(self.processor.max_n_hits, None) def test_chunksize(self): self.assertEqual(self.processor.chunksize, 16) diff --git a/tests/test_modules.py b/tests/test_modules.py index 5e1a375d0a597a3083bda9ce306e5db29063e40b..e8510261f9956a4af5461601eb7f6969953e9b00 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -191,6 +191,19 @@ class TestPointMaker(TestCase): pm = modules.PointMaker( max_n_hits=4) result = pm.process(self.input_blob_1)["samples"] + self.assertTupleEqual( + pm.finish()["hit_infos"], ("t0", "time", "x")) + target = np.array( + [[0.1, 1, 4], + [0.2, 2, 5], + [0.3, 3, 6]], + dtype="float32") + np.testing.assert_array_equal(result, target) + + def test_default_settings_fixed_length(self): + pm = modules.PointMaker( + max_n_hits=4, fixed_length=True) + result = pm.process(self.input_blob_1)["samples"] self.assertTupleEqual( pm.finish()["hit_infos"], ("t0", "time", "x", "is_valid")) target = np.array( @@ -200,12 +213,13 @@ class TestPointMaker(TestCase): [0, 0, 0, 0]]], dtype="float32") np.testing.assert_array_equal(result, target) - def test_input_blob_1(self): + def test_input_blob_1_fixed_length(self): pm = modules.PointMaker( max_n_hits=4, hit_infos=("x", "time"), time_window=None, dset_n_hits=None, + fixed_length=True, ) result = pm.process(self.input_blob_1)["samples"] self.assertTupleEqual( @@ -217,7 +231,7 @@ class TestPointMaker(TestCase): [0, 0, 0]]], dtype="float32") np.testing.assert_array_equal(result, target) - def test_input_blob_1_max_n_hits(self): + def test_input_blob_1_max_n_hits_fixed_length(self): input_blob_long = { "Hits": kp.Table({ "x": np.random.rand(1000).astype("float32"), @@ -227,18 +241,20 @@ class TestPointMaker(TestCase): hit_infos=("x",), time_window=None, dset_n_hits=None, + fixed_length=True, ).process(input_blob_long)["samples"] self.assertSequenceEqual(result.shape, (1, 10, 2)) self.assertTrue(all( np.isin(result[0, :, 0], input_blob_long["Hits"]["x"]))) - def test_input_blob_time_window(self): + def test_input_blob_time_window_fixed_length(self): result = modules.PointMaker( max_n_hits=4, hit_infos=("x", "time"), time_window=[1, 2], dset_n_hits=None, + fixed_length=True, ).process(self.input_blob_1)["samples"] target = np.array( [[[4, 1, 1], @@ -247,12 +263,13 @@ class TestPointMaker(TestCase): [0, 0, 0]]], dtype="float32") 
np.testing.assert_array_equal(result, target) - def test_input_blob_time_window_nhits(self): + def test_input_blob_time_window_nhits_fixed_length(self): result = modules.PointMaker( max_n_hits=4, hit_infos=("x", "time"), time_window=[1, 2], dset_n_hits="EventInfo", + fixed_length=True, ).process(self.input_blob_1)["EventInfo"] print(result) self.assertEqual(result["n_hits_intime"], 2) diff --git a/tests/test_postproc.py b/tests/test_postproc.py index dea566b37903351b103ec81a5ee234b7ad545176..0efac09ccbd45afb06747f2004037b60345e6d87 100644 --- a/tests/test_postproc.py +++ b/tests/test_postproc.py @@ -4,6 +4,7 @@ import h5py import numpy as np import orcasong.tools.postproc as postproc import orcasong.tools.shuffle2 as shuffle2 +from .test_concatenate import BaseTestClass __author__ = 'Stefan Reck' @@ -90,6 +91,40 @@ class TestShuffleV2(TestCase): os.remove(fname) +class TestShuffleIndexed(BaseTestClass.BaseIndexedFile): + def setUp(self) -> None: + self.outfile = "temp_out.h5" + shuffle2.h5shuffle2( + self.infile.name, + output_file=self.outfile, + datasets=("x",), + seed=2, + ) + + def tearDown(self) -> None: + if os.path.exists(self.outfile): + os.remove(self.outfile) + + def test_check_x(self): + with h5py.File(self.outfile) as f_out: + np.testing.assert_array_equal( + f_out["x"], + np.concatenate([np.arange(17, 20), np.arange(5, 17), np.arange(0, 5)]) + ) + + def test_check_x_indices_n_items(self): + with h5py.File(self.outfile) as f_out: + target_n_items = np.array([3, 12, 5]) + np.testing.assert_array_equal( + f_out["x_indices"]["n_items"], target_n_items) + + def test_check_x_indices_index(self): + with h5py.File(self.outfile) as f_out: + target_index = np.array([0, 3, 15]) + np.testing.assert_array_equal( + f_out["x_indices"]["index"], target_index) + + def _make_shuffle_dummy_file(filepath): x = np.random.rand(22, 2) x[:, 0] = np.arange(22)