diff --git a/orcasong/core.py b/orcasong/core.py
index 57de103db300fc6a98e787092e4f0c3366e0de3d..8a2edf7e506d4304b6d82c3cc3751a6d738b96d6 100644
--- a/orcasong/core.py
+++ b/orcasong/core.py
@@ -378,12 +378,14 @@ class FileGraph(BaseProcessor):
         Options of the BaseProcessor.
 
     """
-    def __init__(self, max_n_hits,
+    def __init__(self, max_n_hits=None,
+                 padded=True,
                  time_window=None,
                  hit_infos=None,
                  only_triggered_hits=False,
                  **kwargs):
         self.max_n_hits = max_n_hits
+        self.padded = padded
         self.time_window = time_window
         self.hit_infos = hit_infos
         self.only_triggered_hits = only_triggered_hits
@@ -392,6 +394,7 @@ class FileGraph(BaseProcessor):
     def get_cmpts_main(self):
         return [((modules.PointMaker, {
             "max_n_hits": self.max_n_hits,
+            "padded": self.padded,
             "time_window": self.time_window,
             "hit_infos": self.hit_infos,
             "dset_n_hits": "EventInfo",
@@ -402,3 +405,4 @@ class FileGraph(BaseProcessor):
         super().finish_file(f, summary)
         for i, hit_info in enumerate(summary["PointMaker"]["hit_infos"]):
             f["x"].attrs.create(f"hit_info_{i}", hit_info)
+        f["x"].attrs.create("indexed", not self.padded)
diff --git a/orcasong/modules.py b/orcasong/modules.py
index 30d8f0c07b173f49e9a058560a5ad83b49b1aba9..69677eae35441e4acad465bc4229afee72106b6f 100644
--- a/orcasong/modules.py
+++ b/orcasong/modules.py
@@ -295,7 +295,8 @@ class PointMaker(kp.Module):
 
     """
     def configure(self):
-        self.max_n_hits = self.require("max_n_hits")
+        self.max_n_hits = self.get("max_n_hits", default=None)
+        self.padded = self.get("padded", default=True)
         self.hit_infos = self.get("hit_infos", default=None)
         self.time_window = self.get("time_window", default=None)
         self.dset_n_hits = self.get("dset_n_hits", default=None)
@@ -303,11 +304,12 @@ class PointMaker(kp.Module):
         self.store_as = "samples"
 
     def process(self, blob):
+        if self.padded and self.max_n_hits is None:
+            raise ValueError("Have to specify max_n_hits if padded is True")
         if self.hit_infos is None:
             self.hit_infos = blob["Hits"].dtype.names
         points, n_hits = self.get_points(blob)
-        blob[self.store_as] = kp.NDArray(
-            np.expand_dims(points, 0), h5loc="x", title="nodes")
+        blob[self.store_as] = kp.NDArray(points, h5loc="x", title="nodes")
         if self.dset_n_hits:
             blob[self.dset_n_hits] = blob[self.dset_n_hits].append_columns(
                 "n_hits_intime", n_hits)
@@ -326,11 +328,9 @@ class PointMaker(kp.Module):
             actual hits, and 0 for if its a padded row.
         n_hits : int
             Number of hits in the given time window.
+            Can be stored as n_hits_intime.
 
         """
-        points = np.zeros(
-            (self.max_n_hits, len(self.hit_infos) + 1), dtype="float32")
-
         hits = blob["Hits"]
         if self.only_triggered_hits:
             hits = hits[hits.triggered != 0]
@@ -342,7 +342,7 @@ class PointMaker(kp.Module):
         )]
         n_hits = len(hits)
 
-        if n_hits > self.max_n_hits:
+        if self.max_n_hits is not None and n_hits > self.max_n_hits:
             # if there are too many hits, take random ones, but keep order
             indices = np.arange(n_hits)
             np.random.shuffle(indices)
@@ -350,15 +350,27 @@ class PointMaker(kp.Module):
             which.sort()
             hits = hits[which]
 
-        for i, which in enumerate(self.hit_infos):
-            data = hits[which]
-            points[:n_hits, i] = data
-        # last column is whether there was a hit or no
-        points[:n_hits, -1] = 1.
+        if self.padded:
+            points = np.zeros(
+                (self.max_n_hits, len(self.hit_infos) + 1), dtype="float32")
+            for i, which in enumerate(self.hit_infos):
+                points[:n_hits, i] = hits[which]
+            # last column is whether there was a hit or not
+            points[:n_hits, -1] = 1.
+            # store along new axis
+            points = np.expand_dims(points, 0)
+        else:
+            points = np.zeros(
+                (len(hits), len(self.hit_infos)), dtype="float32")
+            for i, which in enumerate(self.hit_infos):
+                points[:, i] = hits[which]
         return points, n_hits
 
     def finish(self):
-        return {"hit_infos": tuple(self.hit_infos) + ("is_valid", )}
+        columns = tuple(self.hit_infos)
+        if self.padded:
+            columns += ("is_valid", )
+        return {"hit_infos": columns}
 
 
 class EventSkipper(kp.Module):
diff --git a/orcasong/tools/concatenate.py b/orcasong/tools/concatenate.py
index 0c47fb717d803cf6d627f5f6843f4c8a1326be68..2a29e26703080018cc99446295c27361774949a8 100644
--- a/orcasong/tools/concatenate.py
+++ b/orcasong/tools/concatenate.py
@@ -2,7 +2,6 @@ import os
 import time
 import h5py
 import numpy as np
-import argparse
 import warnings
 
 
@@ -29,8 +28,8 @@ class FileConcatenator:
     comptopts : dict
         Options for compression. They are read from the first input
         file, but they can be updated as well during init.
-    cumu_rows : np.array
-        The cumulative number of rows (axis_0) of the specified
+    cumu_rows : dict
+        For each dataset, the cumulative number of rows (axis_0) of the specified
         input .h5 files (i.e. [0,100,200,300,...] if each file has 100 rows).
 
     """
@@ -39,7 +38,6 @@ class FileConcatenator:
         self.skip_errors = skip_errors
         print(f"Checking {len(input_files)} files ...")
         self.input_files, self.cumu_rows = self._get_cumu_rows(input_files)
-        print(f"Total rows:\t{self.cumu_rows[-1]}")
 
         # Get compression options from first file in the list
         self.comptopts = get_compopts(self.input_files[0])
@@ -111,44 +109,40 @@ class FileConcatenator:
 
     def _conc_file(self, f_in, f_out, input_file, input_file_nmbr):
         """ Conc one file to the output. """
-        for folder_name in f_in:
-            if is_folder_ignored(folder_name):
+        for dset_name in f_in:
+            if is_folder_ignored(dset_name):
                 # we dont need datasets created by pytables anymore
                 continue
-            input_dataset = f_in[folder_name]
+            input_dataset = f_in[dset_name]
             folder_data = input_dataset[()]
 
             if input_file_nmbr > 0:
                 # we need to add the current number of the
                 # group_id / index in the file_output to the
                 # group_ids / indices of the file that is to be appended
-                try:
-                    if folder_name.endswith("_indices") and \
-                            "index" in folder_data.dtype.names:
-                        column_name = "index"
-                    elif "group_id" in folder_data.dtype.names:
-                        column_name = "group_id"
-                    else:
-                        column_name = None
-                except TypeError:
-                    column_name = None
-                if column_name is not None:
-                    # add 1 because the group_ids / indices start with 0
-                    folder_data[column_name] += \
-                        np.amax(f_out[folder_name][column_name]) + 1
+                last_index = self.cumu_rows[dset_name][input_file_nmbr] - 1
+                if (dset_name.endswith("_indices") and
+                        "index" in folder_data.dtype.names):
+                    folder_data["index"] += (
+                        f_out[dset_name][last_index]["index"] +
+                        f_out[dset_name][last_index]["n_items"]
+                    )
+                elif folder_data.dtype.names and "group_id" in folder_data.dtype.names:
+                    # add 1 because the group_ids start with 0
+                    folder_data["group_id"] += f_out[dset_name][last_index]["group_id"] + 1
 
             if self._modify_folder:
                 data_mody = self._modify(
-                    input_file, folder_data, folder_name)
+                    input_file, folder_data, dset_name)
                 if data_mody is not None:
                     folder_data = data_mody
 
             if input_file_nmbr == 0:
                 # first file; create the dataset
-                dset_shape = (self.cumu_rows[-1],) + folder_data.shape[1:]
-                print(f"\tCreating dataset '{folder_name}' with shape {dset_shape}")
+                dset_shape = (self.cumu_rows[dset_name][-1],) + folder_data.shape[1:]
+                print(f"\tCreating dataset '{dset_name}' with shape {dset_shape}")
                 output_dataset = f_out.create_dataset(
-                    folder_name,
+                    dset_name,
                     data=folder_data,
                     maxshape=dset_shape,
                     chunks=(self.comptopts["chunksize"],) + folder_data.shape[1:],
@@ -156,37 +150,40 @@ class FileConcatenator:
                     compression_opts=self.comptopts["complevel"],
                     shuffle=self.comptopts["shuffle"],
                 )
-                output_dataset.resize(self.cumu_rows[-1], axis=0)
+                output_dataset.resize(self.cumu_rows[dset_name][-1], axis=0)
 
             else:
-                f_out[folder_name][
-                    self.cumu_rows[input_file_nmbr]:self.cumu_rows[input_file_nmbr + 1]] = folder_data
+                f_out[dset_name][
+                    self.cumu_rows[dset_name][input_file_nmbr]:
+                    self.cumu_rows[dset_name][input_file_nmbr + 1]
+                ] = folder_data
 
     def _modify(self, input_file, folder_data, folder_name):
         raise NotImplementedError
 
     def _get_cumu_rows(self, input_files):
         """
-        Get the cumulative number of rows of the input_files.
-        Also checks if all the files can be safely concatenated to the
-        first one.
+        Checks if all the files can be safely concatenated to the first one.
 
         """
         # names of datasets that will be in the output; read from first file
         with h5py.File(input_files[0], 'r') as f:
-            keys_stripped = strip_keys(list(f.keys()))
+            target_datasets = strip_keys(list(f.keys()))
+
+        errors, valid_input_files = [], []
+        cumu_rows = {k: [0] for k in target_datasets}
 
-        errors, rows_per_file, valid_input_files = [], [0], []
         for i, file_name in enumerate(input_files, start=1):
             file_name = os.path.abspath(file_name)
             try:
-                rows_this_file = _get_rows(file_name, keys_stripped)
+                rows_this_file = _get_rows(file_name, target_datasets)
             except Exception as e:
                 errors.append(e)
                 warnings.warn(f"Error during check of file {i}: {file_name}")
                 continue
             valid_input_files.append(file_name)
-            rows_per_file.append(rows_this_file)
+            for k in target_datasets:
+                cumu_rows[k].append(cumu_rows[k][-1] + rows_this_file[k])
 
         if errors:
             print("\n------- Errors -------\n----------------------")
@@ -200,8 +197,8 @@ class FileConcatenator:
                 raise OSError(err_str)
 
         print(f"Valid input files: {len(valid_input_files)}/{len(input_files)}")
-        print("Datasets:\t" + ", ".join(keys_stripped))
-        return valid_input_files, np.cumsum(rows_per_file)
+        print("Datasets:\t" + ", ".join(target_datasets))
+        return valid_input_files, cumu_rows
 
 
 def _get_rows(file_name, target_datasets):
@@ -215,11 +212,15 @@ def _get_rows(file_name, target_datasets):
                 f"It has {f.keys()} First file: {target_datasets}"
             )
         # check if all target datasets in the file have the same length
-        rows = [f[k].shape[0] for k in target_datasets]
-        if not all(row == rows[0] for row in rows):
+        # for datasets that have indices: only check indices
+        plain_rows = [
+            f[k].shape[0] for k in target_datasets
+            if not (f[k].attrs.get("indexed") and f"{k}_indices" in target_datasets)
+        ]
+        if not all(row == plain_rows[0] for row in plain_rows):
             raise ValueError(
                 f"Datasets in file {file_name} have varying length! "
-                f"{dict(zip(target_datasets, rows))}"
+                f"{dict(zip(target_datasets, plain_rows))}"
             )
         # check if the file has additional datasets apart from the target keys
         if not all(k in target_datasets for k in strip_keys(list(f.keys()))):
@@ -229,7 +230,8 @@ def _get_rows(file_name, target_datasets):
                 f"This file: {strip_keys(list(f.keys()))} "
                 f"First file {target_datasets}"
             )
-    return rows[0]
+        rows = {k: f[k].shape[0] for k in target_datasets}
+    return rows
 
 
 def strip_keys(f_keys):
diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index 90859dc4412998e5cc21a85e2a7d69882fc592bb..e0c3f2d6e3882f1fb3be1403da320de3acdd3ef9 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -6,6 +6,7 @@ import numpy as np
 import psutil
 import h5py
 from km3pipe.sys import peak_memory_usage
+import awkward as ak
 
 from orcasong.tools.postproc import get_filepath_output, copy_used_files
 from orcasong.tools.concatenate import copy_attrs
@@ -103,9 +104,10 @@ def _shuffle_file(
         output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     )
     with h5py.File(input_file, "r") as f_in:
-        _check_dsets(f_in, datasets)
-        dset_info = _get_largest_dset(f_in, datasets, max_ram)
-        print(f"Shuffling datasets {datasets}")
+        dsets = (*datasets, *_get_indexed_datasets(f_in, datasets))
+        _check_dsets(f_in, dsets)
+        dset_info = _get_largest_dset(f_in, dsets, max_ram)
+        print(f"Shuffling datasets {dsets}")
         indices_per_batch = _get_indices_per_batch(
             dset_info["n_batches"],
             dset_info["n_chunks"],
@@ -113,7 +115,7 @@ def _shuffle_file(
         )
 
         with h5py.File(temp_output_file, "x") as f_out:
-            for dset_name in datasets:
+            for dset_name in dsets:
                 print("Creating dataset", dset_name)
                 _shuffle_dset(f_out, f_in, dset_name, indices_per_batch)
             print("Done!")
@@ -188,28 +190,53 @@ def _get_largest_dset(f, datasets, max_ram):
 
 def _check_dsets(f, datasets):
     # check if all datasets have the same number of lines
-    n_lines_list = [len(f[dset_name]) for dset_name in datasets]
+    n_lines_list = []
+    for dset_name in datasets:
+        if dset_is_indexed(f, dset_name):
+            dset_name = f"{dset_name}_indices"
+        n_lines_list.append(len(f[dset_name]))
+
     if not all([n == n_lines_list[0] for n in n_lines_list]):
         raise ValueError(
             f"Datasets have different lengths! "
             f"{n_lines_list}"
         )
 
 
+def _get_indexed_datasets(f, datasets):
+    indexed_datasets = []
+    for dset_name in datasets:
+        if dset_is_indexed(f, dset_name):
+            indexed_datasets.append(f"{dset_name}_indices")
+    return indexed_datasets
+
+
 def _get_dset_infos(f, datasets, max_ram):
     """ Retrieve infos for each dataset. """
     dset_infos = []
     for i, name in enumerate(datasets):
-        dset = f[name]
+        if name.endswith("_indices"):
+            continue
+        if dset_is_indexed(f, name):
+            # for indexed dataset: take average bytes in x per line in x_indices
+            dset_data = f[name]
+            name = f"{name}_indices"
+            dset = f[name]
+            bytes_per_line = (
+                np.asarray(dset[0]).nbytes *
+                len(dset_data) / len(dset)
+            )
+        else:
+            dset = f[name]
+            bytes_per_line = np.asarray(dset[0]).nbytes
+
         n_lines = len(dset)
         chunksize = dset.chunks[0]
         n_chunks = int(np.ceil(n_lines / chunksize))
-        bytes_per_line = np.asarray(dset[0]).nbytes
         bytes_per_chunk = bytes_per_line * chunksize
         chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
         dset_infos.append({
             "name": name,
-            "dset": dset,
             "n_chunks": n_chunks,
             "chunks_per_batch": chunks_per_batch,
             "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
@@ -219,6 +246,16 @@ def _get_dset_infos(f, datasets, max_ram):
     return dset_infos
 
 
+def dset_is_indexed(f, dset_name):
+    if f[dset_name].attrs.get("indexed"):
+        if f"{dset_name}_indices" not in f:
+            raise KeyError(
+                f"{dset_name} is indexed, but {dset_name}_indices is missing!")
+        return True
+    else:
+        return False
+
+
 def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
     """
     Create a batchwise-shuffled dataset in the output file using given indices.
@@ -229,7 +266,12 @@ def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
     for batch_number, indices in enumerate(indices_per_batch):
         print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
         # remove indices outside of dset
-        indices = indices[indices < len(dset_in)]
+        if dset_is_indexed(f_in, dset_name):
+            max_index = len(f_in[f"{dset_name}_indices"])
+        else:
+            max_index = len(dset_in)
+        indices = indices[indices < max_index]
+
         # reading has to be done with linearly increasing index
         # fancy indexing is super slow
         # so sort -> turn to slices -> read -> conc -> undo sorting
@@ -237,8 +279,27 @@ def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
         unsort_ix = np.argsort(sort_ix)
         fancy_indices = indices[sort_ix]
         slices = _slicify(fancy_indices)
-        data = np.concatenate([dset_in[slc] for slc in slices])
-        data = data[unsort_ix]
+
+        if dset_is_indexed(f_in, dset_name):
+            # special treatment for indexed: slice based on indices dataset
+            slices_indices = [f_in[f"{dset_name}_indices"][slc] for slc in slices]
+            data = np.concatenate([
+                dset_in[slice(*_resolve_indexed(slc))] for slc in slices_indices
+            ])
+            # convert to 3d awkward array, then shuffle, then back to numpy
+            data_indices = np.concatenate(slices_indices)
+            data_ak = ak.unflatten(data, data_indices["n_items"])
+            data = ak.flatten(data_ak[unsort_ix], axis=1).to_numpy()
+
+        else:
+            data = np.concatenate([dset_in[slc] for slc in slices])
+            data = data[unsort_ix]
+
+        if dset_name.endswith("_indices"):
+            # recalculate index
+            data["index"] = start_idx + np.concatenate([
+                [0], np.cumsum(data["n_items"][:-1])
+            ])
 
         if batch_number == 0:
             out_dset = f_out.create_dataset(
@@ -271,6 +332,11 @@ def _slicify(fancy_indices):
     return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
 
 
+def _resolve_indexed(ind):
+    # based on slice of x_indices, get where to slice in x
+    return ind["index"][0], ind["index"][-1] + ind["n_items"][-1]
+
+
 def _get_temp_filenames(output_file, number):
     path, file = os.path.split(output_file)
     return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]