diff --git a/docs/tools.rst b/docs/tools.rst
index f99eec09fa35ec3952605d9b0009a1b567794cdb..b5d3b328bdbeeaebef03c9c6431cd6559014655d 100644
--- a/docs/tools.rst
+++ b/docs/tools.rst
@@ -7,7 +7,8 @@ Concatenate
 -----------
 
 Concatenate files resulting from OrcaSong, i.e. merge some h5 files
-into a single, bigger one.
+into a single, bigger one. The resulting file can still be read in with
+km3pipe.
 
 Can be used via the commandline like so::
 
@@ -19,3 +20,19 @@
 or import as:
 
     from orcasong.tools import FileConcatenator
+
+Shuffle
+-------
+
+Shuffle an h5 file using km3pipe.
+
+Can be used via the commandline like so::
+
+    h5shuffle --help
+
+or import the function for general postprocessing:
+
+.. code-block:: python
+
+    from orcasong.tools.postproc import postproc_file
+
diff --git a/orcasong/core.py b/orcasong/core.py
index e2065ee562fb4b5b69312d8e0179ec7abf008304..cf0c6895b6241d36a89a395f3d244e0785fc3213 100644
--- a/orcasong/core.py
+++ b/orcasong/core.py
@@ -53,6 +53,11 @@ class BaseProcessor:
         If True, will keep the "event_info" table [default: True].
     keep_mc_tracks : bool
         If True, will keep the "McTracks" table. It's large! [default: False]
+    overwrite : bool
+        If True, overwrite the output file if it exists already.
+        If False, raise an error instead [default: True].
+    mc_info_to_float64 : bool
+        If True, convert everything in the mc_info array to float64
+        [default: True].
 
     Attributes
     ----------
@@ -85,7 +90,9 @@
                  event_skipper=None,
                  chunksize=32,
                  keep_event_info=True,
-                 keep_mc_tracks=False):
+                 keep_mc_tracks=False,
+                 overwrite=True,
+                 mc_info_to_float64=True):
         self.mc_info_extr = mc_info_extr
         self.det_file = det_file
         self.center_time = center_time
@@ -95,6 +102,8 @@
         self.chunksize = chunksize
         self.keep_event_info = keep_event_info
         self.keep_mc_tracks = keep_mc_tracks
+        self.overwrite = overwrite
+        self.mc_info_to_float64 = mc_info_to_float64
 
         self.n_statusbar = 1000
         self.n_memory_observer = 1000
@@ -119,6 +128,9 @@
         if outfile is None:
             outfile = os.path.join(os.getcwd(), "{}_hist.h5".format(
                 os.path.splitext(os.path.basename(infile))[0]))
+        if not self.overwrite:
+            if os.path.isfile(outfile):
+                raise FileExistsError(f"File exists: {outfile}")
         if self.seed:
             km.GlobalRandomState(seed=self.seed)
         pipe = self.build_pipe(infile, outfile)
@@ -166,10 +178,7 @@
 
     def get_cmpts_pre(self, infile):
         """ Modules that read and calibrate the events. """
-        cmpts = []
-        cmpts.append((kp.io.hdf5.HDF5Pump, {"filename": infile}))
-        cmpts.append((km.common.Keep, {"keys": [
-            'EventInfo', 'Header', 'RawHeader', 'McTracks', 'Hits', 'McHits']}))
+        cmpts = [(kp.io.hdf5.HDF5Pump, {"filename": infile})]
 
         if self.det_file:
             cmpts.append((modules.DetApplier, {"det_file": self.det_file}))
@@ -192,6 +201,7 @@
         if self.mc_info_extr is not None:
             cmpts.append((modules.McInfoMaker, {
                 "mc_info_extr": self.mc_info_extr,
+                "to_float64": self.mc_info_to_float64,
                 "store_as": "mc_info"}))
 
         if self.event_skipper is not None:
@@ -331,8 +341,9 @@ class FileGraph(BaseProcessor):
     Turn km3 events to graph data.
 
     The resulting file will have a dataset "x" of shape
-    (?, max_n_hits, len(hit_infos) + 1), and its title (x.attrs["TITLE"])
-    is the column names of the last axis, seperated by ', ' (= hit_infos).
+    (?, max_n_hits, len(hit_infos) + 1).
+    The column names of the last axis (i.e. hit_infos) are saved
+    as attributes of the dataset (f["x"].attrs).
     The last column will always be called 'is_valid', and it's 0 if the
     entry is padded, and 1 otherwise.
 
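For illustration: with this change, the column names can be read back from a
FileGraph output file like so (a minimal sketch; "graphs.h5" stands in for a
hypothetical output file):

    import h5py

    with h5py.File("graphs.h5", "r") as f:
        # finish_file() stores one attribute per column: hit_info_0, hit_info_1, ...
        n_cols = f["x"].shape[-1]
        columns = [f["x"].attrs[f"hit_info_{i}"] for i in range(n_cols)]
        print(columns)  # e.g. ['pos_z', 'time', 'channel_id', 'is_valid']
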
@@ -366,3 +377,8 @@ class FileGraph(BaseProcessor):
                 "time_window": self.time_window,
                 "hit_infos": self.hit_infos,
                 "dset_n_hits": "EventInfo"}))]
+
+    def finish_file(self, f, summary):
+        super().finish_file(f, summary)
+        for i, hit_info in enumerate(summary["PointMaker"]["hit_infos"]):
+            f["x"].attrs.create(f"hit_info_{i}", hit_info)
diff --git a/orcasong/mc_info_extr.py b/orcasong/mc_info_extr.py
index d3c4dd9aa5e362650c36a838be0a2aa535c32cd4..81456541d99e54bce65bf69be8e7db1b50c10985 100644
--- a/orcasong/mc_info_extr.py
+++ b/orcasong/mc_info_extr.py
@@ -1,6 +1,6 @@
 """
 Functions that extract info from a blob for the mc_info / y datafield
-in the h5 files.
+in the h5 files. Very much WIP.
 
 These are made for the specific given runs. They might not be applicable
 to other data, and could cause errors or produce unexpected
@@ -20,11 +20,9 @@ def get_real_data(blob):
 
     Designed for the 2017 one line runs.
     """
-    event_info = blob['EventInfo']
-
+    event_info = blob['EventInfo'][0]
     track = {
         'event_id': event_info.event_id,
-        # was .event_id[0] up to km3pipe 8.16.0
         'run_id': event_info.run_id,
         'trigger_mask': event_info.trigger_mask,
     }
diff --git a/orcasong/modules.py b/orcasong/modules.py
index be2b1c0558362818a6f9f80789eb9570e99214d8..d08696b226d2361ec89e35a3d98afecd1a9e5729 100644
--- a/orcasong/modules.py
+++ b/orcasong/modules.py
@@ -27,10 +27,19 @@ class McInfoMaker(kp.Module):
     def configure(self):
         self.mc_info_extr = self.require('mc_info_extr')
         self.store_as = self.require('store_as')
+        self.to_float64 = self.get("to_float64", default=True)
 
     def process(self, blob):
         track = self.mc_info_extr(blob)
-        dtypes = [(key, np.float64) for key in track.keys()]
+        if self.to_float64:
+            dtypes = []
+            for key, v in track.items():
+                if key in ("group_id", "event_id"):
+                    dtypes.append((key, type(v)))
+                else:
+                    dtypes.append((key, np.float64))
+        else:
+            dtypes = [(k, type(v)) for k, v in track.items()]
         kp_hist = kp.dataclasses.Table(
             track, dtype=dtypes, h5loc='y', name='event_info')
         if len(kp_hist) != 1:
@@ -80,7 +89,8 @@ class TimePreproc(kp.Module):
             blob = self.timeslew(blob)
         if self.subtract_t0_mchits and self._has_mchits:
             blob = self.subtract_t0_mctime(blob)
-        blob = self.center_hittime(blob)
+        if self.center_time:
+            blob = self.center_hittime(blob)
 
         return blob
 
@@ -110,9 +120,8 @@ class TimePreproc(kp.Module):
         hits_triggered = blob["Hits"].triggered
         t_first_trigger = np.min(hits_time[hits_triggered != 0])
 
-        if self.center_time:
-            self._print_once("Centering time of Hits with first triggered hit")
-            blob["Hits"].time = np.subtract(hits_time, t_first_trigger)
+        self._print_once("Centering time of Hits with first triggered hit")
+        blob["Hits"].time = np.subtract(hits_time, t_first_trigger)
 
         if self._has_mchits:
             self._print_once("Centering time of McHits with first triggered hit")
@@ -318,16 +327,13 @@ class PointMaker(kp.Module):
         self.time_window = self.get("time_window", default=None)
         self.dset_n_hits = self.get("dset_n_hits", default=None)
         self.store_as = "samples"
-        self._dset_name = None
 
     def process(self, blob):
         if self.hit_infos is None:
             self.hit_infos = blob["Hits"].dtype.names
-        if self._dset_name is None:
-            self._dset_name = ", ".join(tuple(self.hit_infos) + ("is_valid", ))
         points, n_hits = self.get_points(blob)
         blob[self.store_as] = kp.NDArray(
-            np.expand_dims(points, 0), h5loc="x", title=self._dset_name)
+            np.expand_dims(points, 0), h5loc="x", title="nodes")
         if self.dset_n_hits:
             blob[self.dset_n_hits] = blob[self.dset_n_hits].append_columns(
                 "n_hits_intime", n_hits)
@@ -375,6 +381,9 @@ class PointMaker(kp.Module):
                 points[:n_hits, -1] = 1.
         return points, n_hits
 
+    def finish(self):
+        return {"hit_infos": tuple(self.hit_infos) + ("is_valid", )}
+
 
 class EventSkipper(kp.Module):
     """
@@ -403,7 +412,10 @@ class EventSkipper(kp.Module):
         return blob
 
     def _remove_groupid(self, blob):
-        """ Workaround until bug in km3pipe is fixed: Drop all group_ids """
+        """
+        Workaround until km3pipe bug is fixed: drop all group_ids.
+        See https://git.km3net.de/km3py/km3pipe/-/issues/203
+        """
         if "GroupInfo" in blob:
             del blob["GroupInfo"]
         for key in blob.keys():
@@ -421,18 +433,6 @@
         )
 
 
-def _remove_groupid(blob):
-    """ Workaround until bug in km3pipe is fixed: Drop all group_ids """
-    if "GroupInfo" in blob:
-        del blob["GroupInfo"]
-    for key in blob.keys():
-        try:
-            blob[key] = blob[key].drop_columns("group_id")
-        except AttributeError:
-            continue
-    return blob
-
-
 class DetApplier(kp.Module):
     """
     Apply calibration to the Hits and McHits with a detx file.
diff --git a/orcasong/tools/concatenate.py b/orcasong/tools/concatenate.py
index 193934c3dbba9a805c3d2328425aced6b672c88f..74f9d74ffe14eecff4d2098783527bc3f30d62ce 100644
--- a/orcasong/tools/concatenate.py
+++ b/orcasong/tools/concatenate.py
@@ -10,12 +10,21 @@ __author__ = 'Stefan Reck, Michael Moser'
 
 class FileConcatenator:
     """
-    For concatenating many small h5 files to a single large one.
+    For concatenating many small h5 files to a single large one in
+    km3pipe-compatible format.
 
-    Attributes
+    Parameters
     ----------
-    input_files : list
+    input_files : List
         List that contains all filepaths of the input files.
+    skip_errors : bool
+        If True, ignore files that can't be concatenated.
+    comptopts_update : dict, optional
+        Overwrite the compression options that get read from the
+        first file. E.g. {'chunksize': 10} to get a chunksize of 10.
+
+    Attributes
+    ----------
     comptopts : dict
         Options for compression. They are read from the first input
         file, but they can be updated as well during init.
@@ -25,20 +34,6 @@ class FileConcatenator:
 
     def __init__(self, input_files, skip_errors=False, comptopts_update=None):
-        """
-        Check the files to concatenate.
-
-        Parameters
-        ----------
-        input_files : List
-            List that contains all filepaths of the input files.
-        comptopts_update : dict, optional
-            Overwrite the compression options that get read from the
-            first file. E.g. {'chunksize': 10} to get a chunksize of 10.
-        skip_errors : bool
-            If true, ignore files that can't be concatenated.
-
-        """
         self.skip_errors = skip_errors
 
         print(f"Checking {len(input_files)} files ...")
@@ -101,29 +96,26 @@ class FileConcatenator:
         elapsed_time = time.time() - start_time
 
         if append_used_files:
-            # include the used filepaths in the file
             print("Adding used files to output")
             f_out.create_dataset(
                 "used_files",
                 data=[n.encode("ascii", "ignore") for n in self.input_files]
             )
 
+        copy_attrs(self.input_files[0], output_filepath)
+
         print(f"\nConcatenation complete!"
               f"\nElapsed time: {elapsed_time/60:.2f} min "
               f"({elapsed_time/len(self.input_files):.2f} s per file)")
 
     def _conc_file(self, f_in, f_out, input_file, input_file_nmbr):
         """ Concatenate one file into the output.
""" - # create metadata - if input_file_nmbr == 0 and 'format_version' in list(f_in.attrs.keys()): - f_out.attrs['format_version'] = f_in.attrs['format_version'] - for folder_name in f_in: if is_folder_ignored(folder_name): # we dont need datasets created by pytables anymore continue - - folder_data = f_in[folder_name][()] + input_dataset = f_in[folder_name] + folder_data = input_dataset[()] if input_file_nmbr > 0: # we need to add the current number of the @@ -153,14 +145,12 @@ class FileConcatenator: if input_file_nmbr == 0: # first file; create the dataset dset_shape = (self.cumu_rows[-1],) + folder_data.shape[1:] - print(f"\tCreating dataset '{folder_name}' with shape " - f"{dset_shape}") + print(f"\tCreating dataset '{folder_name}' with shape {dset_shape}") output_dataset = f_out.create_dataset( folder_name, data=folder_data, maxshape=dset_shape, - chunks=(self.comptopts["chunksize"],) + folder_data.shape[ - 1:], + chunks=(self.comptopts["chunksize"],) + folder_data.shape[1:], compression=self.comptopts["complib"], compression_opts=self.comptopts["complevel"], shuffle=self.comptopts["shuffle"], @@ -290,6 +280,29 @@ def get_compopts(file): return comptopts +def copy_attrs(source_file, target_file): + """ + Copy file and dataset attributes from one h5 file to another. + """ + print("Copying attributes") + with h5py.File(source_file, "r") as src: + with h5py.File(target_file, "a") as trg: + _copy_attrs(src, trg) + for dset_name, target_dataset in trg.items(): + if dset_name in src: + _copy_attrs(src[dset_name], target_dataset) + + +def _copy_attrs(src_datset, target_dataset): + for k in src_datset.attrs.keys(): + try: + if k not in target_dataset.attrs: + target_dataset.attrs[k] = src_datset.attrs[k] + except TypeError as e: + # above can fail if attr is bool and created using pt + warnings.warn(f"Error: Can not copy attribute {k}: {e}") + + def get_parser(): parser = argparse.ArgumentParser( description='Concatenate many small h5 files to a single large one ' diff --git a/orcasong/tools/postproc.py b/orcasong/tools/postproc.py new file mode 100644 index 0000000000000000000000000000000000000000..b2191c8ab6abd5ca70143992c41d1b586ab89ab6 --- /dev/null +++ b/orcasong/tools/postproc.py @@ -0,0 +1,125 @@ +""" +Scripts for postprocessing h5 files, e.g. shuffling. +""" +import os +import argparse +import h5py +import km3pipe as kp +import km3modules as km +from orcasong.modules import EventSkipper +from orcasong.tools.concatenate import get_compopts, copy_attrs + + +def postproc_file( + input_file, + output_file=None, + shuffle=True, + event_skipper=None, + delete=False, + seed=42, + statusbar_every=1000): + """ + Postprocess a file using km3pipe after it has been preprocessed in OrcaSong. + + Parameters + ---------- + input_file : str + Path of the file that will be processed. + output_file : str, optional + If given, this will be the name of the output file. + Otherwise, a name is auto generated. + shuffle : bool + Shuffle order of events. + event_skipper : func, optional + Function that takes the blob as an input, and returns a bool. + If the bool is true, the event will be skipped. + delete : bool + Specifies if the input file should be deleted after processing. + seed : int + Sets a fixed random seed for the shuffling. + statusbar_every : int or None + After how many line a km3pipe status should be printed. + + Returns + ------- + output_file : str + Path to the output file. 
+
+    """
+    if output_file is None:
+        output_file = get_filepath_output(
+            input_file, shuffle=shuffle, event_skipper=event_skipper)
+    if os.path.exists(output_file):
+        raise FileExistsError(output_file)
+
+    print(f'Setting a Global Random State with the seed < {seed} >.')
+    km.GlobalRandomState(seed=seed)
+
+    comptopts = get_compopts(input_file)
+    # km3pipe uses pytables for saving the shuffled output file,
+    # which has the name 'zlib' for the 'gzip' filter
+    if comptopts["complib"] == 'gzip':
+        comptopts["complib"] = 'zlib'
+
+    pipe = kp.Pipeline()
+    if statusbar_every is not None:
+        pipe.attach(km.common.StatusBar, every=statusbar_every)
+        pipe.attach(km.common.MemoryObserver, every=statusbar_every)
+    pipe.attach(
+        kp.io.hdf5.HDF5Pump,
+        filename=input_file,
+        shuffle=shuffle,
+        reset_index=True,
+    )
+    if event_skipper is not None:
+        pipe.attach(EventSkipper, event_skipper=event_skipper)
+    pipe.attach(
+        kp.io.hdf5.HDF5Sink,
+        filename=output_file,
+        complib=comptopts["complib"],
+        complevel=comptopts["complevel"],
+        chunksize=comptopts["chunksize"],
+        flush_frequency=1000,
+    )
+    pipe.drain()
+
+    copy_used_files(input_file, output_file)
+    copy_attrs(input_file, output_file)
+    if delete:
+        print("Deleting original file")
+        os.remove(input_file)
+
+    print("Done!")
+    return output_file
+
+
+def copy_used_files(source_file, target_file):
+    """ Copy the "used_files" dataset from one h5 file to another, if it is present.
+    """
+    with h5py.File(source_file, "r") as src:
+        if "used_files" in src:
+            print("Copying used_files dataset")
+            with h5py.File(target_file, "a") as trg:
+                trg.create_dataset("used_files", data=src["used_files"])
+
+
+def get_filepath_output(input_file, shuffle=True, event_skipper=None):
+    """ Get the filename of the shuffled / rebalanced output file as a str.
+    """
+    fname_adtn = ''
+    if shuffle:
+        fname_adtn += '_shuffled'
+    if event_skipper is not None:
+        fname_adtn += '_reb'
+    return f"{os.path.splitext(input_file)[0]}{fname_adtn}.h5"
+
+
+def h5shuffle():
+    parser = argparse.ArgumentParser(description='Shuffle an h5 file using km3pipe.')
+    parser.add_argument('input_file', type=str, help='File to shuffle.')
+    parser.add_argument('--output_file', type=str,
+                        help='Name of output file. Default: auto-generated name.')
+    parser.add_argument('--delete', action="store_true",
+                        help='Delete original file afterwards.')
+
+    postproc_file(**vars(parser.parse_args()), shuffle=True, event_skipper=None)
diff --git a/setup.py b/setup.py
index a6c3389d431a46aac3532b1bd96c8b1b6662d8da..23137e0e2fd554dffc6366b1ff331497db9c901b 100644
--- a/setup.py
+++ b/setup.py
@@ -27,11 +27,11 @@ setup(
         'tag_regex': r'^(?P<prefix>v)?(?P<version>[^\+]+)(?P<suffix>.*)?$',
     },
     entry_points={'console_scripts': [
-        'make_nn_images=legacy.make_nn_images:main',
-        'shuffle=orcasong_contrib.data_tools.shuffle.shuffle_h5:main',
         'concatenate=orcasong.tools.concatenate:main',
-        'make_dsplit=orcasong_contrib.data_tools.make_data_split.make_data_split:main',
-        'plot_binstats=orcasong.plotting.plot_binstats:main']}
+        'h5shuffle=orcasong.tools.postproc:h5shuffle',
+        'plot_binstats=orcasong.plotting.plot_binstats:main',
+        'make_nn_images=legacy.make_nn_images:main',
+        'make_dsplit=orcasong_contrib.data_tools.make_data_split.make_data_split:main']}
 )
 
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py
index 88d43e5adb3eef07f9b8adcb8e1294ba7b1ba08c..47a4b4292483da5e3d6e6746a7847e1c712c9c75 100644
--- a/tests/test_concatenate.py
+++ b/tests/test_concatenate.py
@@ -70,6 +70,20 @@ class TestFileConcatenator(TestCase):
             [n.encode("ascii", "ignore") for n in self.dummy_files],
         )
 
+    def test_concatenate_attrs(self):
+        fc = conc.FileConcatenator(self.dummy_files)
+        with tempfile.TemporaryFile() as tf:
+            fc.concatenate(tf)
+            with h5py.File(tf, "r") as f:
+                target_attrs = dict(f.attrs)
+                target_dset_attrs = dict(f["numpy_array"].attrs)
+        with h5py.File(self.dummy_files[0], "r") as f:
+            source_attrs = dict(f.attrs)
+            source_dset_attrs = dict(f["numpy_array"].attrs)
+
+        self.assertDictEqual(source_attrs, target_attrs)
+        self.assertDictEqual(source_dset_attrs, target_dset_attrs)
+
     def test_concatenate_array(self):
         fc = conc.FileConcatenator(self.dummy_files)
         with tempfile.TemporaryFile() as tf:
@@ -103,13 +117,15 @@
 
 def _create_dummy_file(filepath, columns=10, val_array=1, val_recarray=(1, 3)):
     """ Create a dummy h5 file with an array and a recarray in it.
""" with h5py.File(filepath, "w") as f: - f.create_dataset( + dset = f.create_dataset( "numpy_array", data=np.ones(shape=(columns, 7, 3))*val_array, chunks=(5, 7, 3), compression="gzip", compression_opts=1 ) + dset.attrs.create("test_dset", "ok") + rec_array = np.array( [val_recarray + (1, )] * columns, dtype=[('x', '<f8'), ('y', '<i8'), ("group_id", "<i8")] @@ -122,3 +138,4 @@ def _create_dummy_file(filepath, columns=10, val_array=1, val_recarray=(1, 3)): compression="gzip", compression_opts=1 ) + f.attrs.create("test_file", "ok") diff --git a/tests/test_core.py b/tests/test_core.py index b385e6aba241c9306a8e2c3e1cb2e1a021631cf9..70156ec03df774eaf82195a2d0803e54e3b0c461 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -105,8 +105,16 @@ class TestFileGraph(TestCase): '_i_event_info', '_i_group_info', '_i_y', 'event_info', 'group_info', 'x', 'x_indices', 'y'}) - def test_x_title(self): - self.assertEqual(self.f["x"].attrs["TITLE"].decode(), "pos_z, time, channel_id, is_valid") + def test_x_attrs(self): + to_check = { + "hit_info_0": "pos_z", + "hit_info_1": "time", + "hit_info_2": "channel_id", + "hit_info_3": "is_valid", + } + attrs = dict(self.f["x"].attrs) + for k, v in to_check.items(): + self.assertTrue(attrs[k] == v) def test_x(self): target = np.array([ diff --git a/tests/test_modules.py b/tests/test_modules.py index 67f9732bf0008cbdb925956c69e7d673f12b03bd..242f26ef86e5688a68bdb8a65bf19f3feb273e27 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -30,9 +30,31 @@ class TestModules(TestCase): self.assertSequenceEqual(list(out_blob["test"].dtype.names), ('dom_id_0', 'time_2')) np.testing.assert_array_equal(out_blob["test"]["dom_id_0"], - np.array([2, ])) + np.array([2, ], dtype="float64")) np.testing.assert_array_equal(out_blob["test"]["time_2"], - np.array([12.3, ])) + np.array([12.3, ], dtype="float64")) + + def test_mc_info_maker_dtype(self): + """ Test the mcinfo maker on some dummy data. 
""" + def mc_info_extr(blob): + hits = blob["Hits"] + return {"dom_id_0": hits.dom_id[0], + "time_2": hits.time[2]} + + in_blob = { + "Hits": Table({ + 'dom_id': np.array([2, 3, 3], dtype="int8"), + 'time': np.array([10.1, 11.2, 12.3], dtype="float32"), + }) + } + module = modules.McInfoMaker( + mc_info_extr=mc_info_extr, store_as="test", to_float64=False) + out_blob = module.process(in_blob) + + np.testing.assert_array_equal( + out_blob["test"]["dom_id_0"], np.array([2, ], dtype="int8")) + np.testing.assert_array_equal( + out_blob["test"]["time_2"], np.array([12.3, ], dtype="float32")) def test_event_skipper(self): def event_skipper(blob): @@ -160,9 +182,11 @@ class TestPointMaker(TestCase): } def test_default_settings(self): - result = modules.PointMaker( - max_n_hits=4).process(self.input_blob_1)["samples"] - self.assertEqual(result.title, 't0, time, x, is_valid') + pm = modules.PointMaker( + max_n_hits=4) + result = pm.process(self.input_blob_1)["samples"] + self.assertTupleEqual( + pm.finish()["hit_infos"], ("t0", "time", "x", "is_valid")) target = np.array( [[[0.1, 1, 4, 1], [0.2, 2, 5, 1], @@ -171,13 +195,15 @@ class TestPointMaker(TestCase): np.testing.assert_array_equal(result, target) def test_input_blob_1(self): - result = modules.PointMaker( + pm = modules.PointMaker( max_n_hits=4, hit_infos=("x", "time"), time_window=None, dset_n_hits=None, - ).process(self.input_blob_1)["samples"] - self.assertEqual(result.title, 'x, time, is_valid') + ) + result = pm.process(self.input_blob_1)["samples"] + self.assertTupleEqual( + pm.finish()["hit_infos"], ("x", "time", "is_valid")) target = np.array( [[[4, 1, 1], [5, 2, 1], diff --git a/tests/test_postproc.py b/tests/test_postproc.py new file mode 100644 index 0000000000000000000000000000000000000000..917116895b3280a1ce1c6aa004d6b283ad533059 --- /dev/null +++ b/tests/test_postproc.py @@ -0,0 +1,33 @@ +from unittest import TestCase +import os +import h5py +import numpy as np +import orcasong.tools.postproc as postproc + +__author__ = 'Stefan Reck' + +test_dir = os.path.dirname(os.path.realpath(__file__)) +MUPAGE_FILE = os.path.join(test_dir, "data", "mupage.root.h5") + + +class TestPostproc(TestCase): + def setUp(self): + self.output_file = "temp_output.h5" + + def tearDown(self): + if os.path.exists(self.output_file): + os.remove(self.output_file) + + def test_shuffle(self): + postproc.postproc_file( + input_file=MUPAGE_FILE, + output_file=self.output_file, + shuffle=True, + event_skipper=None, + delete=False, + seed=13, + ) + + with h5py.File(self.output_file, "r") as f: + np.testing.assert_equal(f["event_info"]["event_id"], np.array([1, 0, 2])) + self.assertTrue("origin" in f.attrs.keys())