Commit 5be9ee99 authored by Stefan Reck's avatar Stefan Reck

Merge branch 'dev'

parents f6d406b6 b745211a
......@@ -7,7 +7,8 @@ Concatenate
-----------
Concatenate files resulting from OrcaSong, i.e. merge some h5 files
into a single, bigger one.
into a single, bigger one. The resulting file can still be read in with
km3pipe.
Can be used via the commandline like so::
......@@ -19,3 +20,19 @@ or import as:
from orcasong.tools import FileConcatenator
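A minimal usage sketch (the filenames are placeholders):

.. code-block:: python

    from orcasong.tools import FileConcatenator

    fc = FileConcatenator(["file_1.h5", "file_2.h5"], skip_errors=True)
    fc.concatenate("concatenated.h5")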
Shuffle
-------
Shuffle an h5 file using km3pipe.
Can be used via the commandline like so::
h5shuffle --help
or import the function for general postprocessing:
.. code-block:: python
from orcasong.tools.postproc import postproc_file
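A minimal usage sketch (the filename is a placeholder):

.. code-block:: python

    from orcasong.tools.postproc import postproc_file

    # shuffle the events; the output filename is generated automatically
    output_file = postproc_file("some_file.h5", shuffle=True)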
......@@ -53,6 +53,11 @@ class BaseProcessor:
If True, will keep the "event_info" table [default: True].
keep_mc_tracks : bool
If True, will keep the "McTracks" table. It's large! [default: False]
overwrite : bool
If True, overwrite the output file if it exists already.
If False, throw an error instead.
mc_info_to_float64 : bool
Convert everything in the mc_info array to float64 (Default: True).
Attributes
----------
......@@ -85,7 +90,9 @@ class BaseProcessor:
event_skipper=None,
chunksize=32,
keep_event_info=True,
keep_mc_tracks=False):
keep_mc_tracks=False,
overwrite=True,
mc_info_to_float64=True):
self.mc_info_extr = mc_info_extr
self.det_file = det_file
self.center_time = center_time
......@@ -95,6 +102,8 @@ class BaseProcessor:
self.chunksize = chunksize
self.keep_event_info = keep_event_info
self.keep_mc_tracks = keep_mc_tracks
self.overwrite = overwrite
self.mc_info_to_float64 = mc_info_to_float64
self.n_statusbar = 1000
self.n_memory_observer = 1000
......@@ -119,6 +128,9 @@ class BaseProcessor:
if outfile is None:
outfile = os.path.join(os.getcwd(), "{}_hist.h5".format(
os.path.splitext(os.path.basename(infile))[0]))
if not self.overwrite:
if os.path.isfile(outfile):
raise FileExistsError(f"File exists: {outfile}")
if self.seed:
km.GlobalRandomState(seed=self.seed)
pipe = self.build_pipe(infile, outfile)
......@@ -166,10 +178,7 @@ class BaseProcessor:
def get_cmpts_pre(self, infile):
""" Modules that read and calibrate the events. """
cmpts = []
cmpts.append((kp.io.hdf5.HDF5Pump, {"filename": infile}))
cmpts.append((km.common.Keep, {"keys": [
'EventInfo', 'Header', 'RawHeader', 'McTracks', 'Hits', 'McHits']}))
cmpts = [(kp.io.hdf5.HDF5Pump, {"filename": infile})]
if self.det_file:
cmpts.append((modules.DetApplier, {"det_file": self.det_file}))
......@@ -192,6 +201,7 @@ class BaseProcessor:
if self.mc_info_extr is not None:
cmpts.append((modules.McInfoMaker, {
"mc_info_extr": self.mc_info_extr,
"to_float64": self.mc_info_to_float64,
"store_as": "mc_info"}))
if self.event_skipper is not None:
......@@ -331,8 +341,9 @@ class FileGraph(BaseProcessor):
Turn km3 events to graph data.
The resulting file will have a dataset "x" of shape
(?, max_n_hits, len(hit_infos) + 1), and its title (x.attrs["TITLE"])
is the column names of the last axis, seperated by ', ' (= hit_infos).
(?, max_n_hits, len(hit_infos) + 1).
The column names of the last axis (i.e. hit_infos) are saved
as attributes of the dataset (f["x"].attrs).
The last column will always be called 'is_valid'; it is 0 if
the entry is padded, and 1 otherwise.
......@@ -366,3 +377,8 @@ class FileGraph(BaseProcessor):
"time_window": self.time_window,
"hit_infos": self.hit_infos,
"dset_n_hits": "EventInfo"}))]
def finish_file(self, f, summary):
super().finish_file(f, summary)
for i, hit_info in enumerate(summary["PointMaker"]["hit_infos"]):
f["x"].attrs.create(f"hit_info_{i}", hit_info)
"""
Functions that extract info from a blob for the mc_info / y datafield
in the h5 files.
in the h5 files. Very much WIP.
These are made for the specific given runs. They might not be
applicable to other data, and could cause errors or produce unexpected
......@@ -20,11 +20,9 @@ def get_real_data(blob):
Designed for the 2017 one-line runs.
"""
event_info = blob['EventInfo']
event_info = blob['EventInfo'][0]
track = {
'event_id': event_info.event_id,
# was .event_id[0] up to km3pipe 8.16.0
'run_id': event_info.run_id,
'trigger_mask': event_info.trigger_mask,
}
......
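For illustration, a custom extractor follows the same pattern: take the blob, return a dict with one scalar per key, and pass the function as mc_info_extr to a processor. The fields below are hypothetical and not taken from the repository:

def get_mc_info_example(blob):
    """ Hypothetical extractor sketch; returns one scalar per key. """
    event_info = blob['EventInfo'][0]
    mc_tracks = blob['McTracks']
    return {
        'event_id': event_info.event_id,
        'run_id': event_info.run_id,
        # energy of the first MC track; a made-up choice of label
        'energy': mc_tracks.energy[0],
    }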
......@@ -27,10 +27,19 @@ class McInfoMaker(kp.Module):
def configure(self):
self.mc_info_extr = self.require('mc_info_extr')
self.store_as = self.require('store_as')
self.to_float64 = self.get("to_float64", default=True)
def process(self, blob):
track = self.mc_info_extr(blob)
dtypes = [(key, np.float64) for key in track.keys()]
if self.to_float64:
dtypes = []
for key, v in track.items():
if key in ("group_id", "event_id"):
dtypes.append((key, type(v)))
else:
dtypes.append((key, np.float64))
else:
dtypes = [(k, type(v)) for k, v in track.items()]
kp_hist = kp.dataclasses.Table(
track, dtype=dtypes, h5loc='y', name='event_info')
if len(kp_hist) != 1:
......@@ -80,7 +89,8 @@ class TimePreproc(kp.Module):
blob = self.timeslew(blob)
if self.subtract_t0_mchits and self._has_mchits:
blob = self.subtract_t0_mctime(blob)
blob = self.center_hittime(blob)
if self.center_time:
blob = self.center_hittime(blob)
return blob
......@@ -110,9 +120,8 @@ class TimePreproc(kp.Module):
hits_triggered = blob["Hits"].triggered
t_first_trigger = np.min(hits_time[hits_triggered != 0])
if self.center_time:
self._print_once("Centering time of Hits with first triggered hit")
blob["Hits"].time = np.subtract(hits_time, t_first_trigger)
self._print_once("Centering time of Hits with first triggered hit")
blob["Hits"].time = np.subtract(hits_time, t_first_trigger)
if self._has_mchits:
self._print_once("Centering time of McHits with first triggered hit")
......@@ -318,16 +327,13 @@ class PointMaker(kp.Module):
self.time_window = self.get("time_window", default=None)
self.dset_n_hits = self.get("dset_n_hits", default=None)
self.store_as = "samples"
self._dset_name = None
def process(self, blob):
if self.hit_infos is None:
self.hit_infos = blob["Hits"].dtype.names
if self._dset_name is None:
self._dset_name = ", ".join(tuple(self.hit_infos) + ("is_valid", ))
points, n_hits = self.get_points(blob)
blob[self.store_as] = kp.NDArray(
np.expand_dims(points, 0), h5loc="x", title=self._dset_name)
np.expand_dims(points, 0), h5loc="x", title="nodes")
if self.dset_n_hits:
blob[self.dset_n_hits] = blob[self.dset_n_hits].append_columns(
"n_hits_intime", n_hits)
......@@ -375,6 +381,9 @@ class PointMaker(kp.Module):
points[:n_hits, -1] = 1.
return points, n_hits
def finish(self):
return {"hit_infos": tuple(self.hit_infos) + ("is_valid", )}
class EventSkipper(kp.Module):
"""
......@@ -403,7 +412,10 @@ class EventSkipper(kp.Module):
return blob
def _remove_groupid(self, blob):
""" Workaround until bug in km3pipe is fixed: Drop all group_ids """
"""
Workaround until bug https://git.km3net.de/km3py/km3pipe/-/issues/203
in km3pipe is fixed: Drop all group_ids
"""
if "GroupInfo" in blob:
del blob["GroupInfo"]
for key in blob.keys():
......@@ -421,18 +433,6 @@ class EventSkipper(kp.Module):
)
def _remove_groupid(blob):
""" Workaround until bug in km3pipe is fixed: Drop all group_ids """
if "GroupInfo" in blob:
del blob["GroupInfo"]
for key in blob.keys():
try:
blob[key] = blob[key].drop_columns("group_id")
except AttributeError:
continue
return blob
class DetApplier(kp.Module):
"""
Apply calibration to the Hits and McHits with a detx file.
......
......@@ -10,12 +10,21 @@ __author__ = 'Stefan Reck, Michael Moser'
class FileConcatenator:
"""
For concatenating many small h5 files to a single large one.
For concatenating many small h5 files to a single large one in
km3pipe-compatible format.
Attributes
Parameters
----------
input_files : list
input_files : List
List that contains all filepaths of the input files.
skip_errors : bool
If true, ignore files that can't be concatenated.
comptopts_update : dict, optional
Overwrite the compression options that get read from the
first file. E.g. {'chunksize': 10} to get a chunksize of 10.
Attributes
----------
comptopts : dict
Options for compression. They are read from the first input file,
but they can be updated as well during init.
......@@ -25,20 +34,6 @@ class FileConcatenator:
"""
def __init__(self, input_files, skip_errors=False, comptopts_update=None):
"""
Check the files to concatenate.
Parameters
----------
input_files : List
List that contains all filepaths of the input files.
comptopts_update : dict, optional
Overwrite the compression options that get read from the
first file. E.g. {'chunksize': 10} to get a chunksize of 10.
skip_errors : bool
If true, ignore files that can't be concatenated.
"""
self.skip_errors = skip_errors
print(f"Checking {len(input_files)} files ...")
......@@ -101,29 +96,26 @@ class FileConcatenator:
elapsed_time = time.time() - start_time
if append_used_files:
# include the used filepaths in the file
print("Adding used files to output")
f_out.create_dataset(
"used_files",
data=[n.encode("ascii", "ignore") for n in self.input_files]
)
copy_attrs(self.input_files[0], output_filepath)
print(f"\nConcatenation complete!"
f"\nElapsed time: {elapsed_time/60:.2f} min "
f"({elapsed_time/len(self.input_files):.2f} s per file)")
def _conc_file(self, f_in, f_out, input_file, input_file_nmbr):
""" Conc one file to the output. """
# create metadata
if input_file_nmbr == 0 and 'format_version' in list(f_in.attrs.keys()):
f_out.attrs['format_version'] = f_in.attrs['format_version']
for folder_name in f_in:
if is_folder_ignored(folder_name):
# we don't need the datasets created by pytables anymore
continue
folder_data = f_in[folder_name][()]
input_dataset = f_in[folder_name]
folder_data = input_dataset[()]
if input_file_nmbr > 0:
# we need to add the current number of the
......@@ -153,14 +145,12 @@ class FileConcatenator:
if input_file_nmbr == 0:
# first file; create the dataset
dset_shape = (self.cumu_rows[-1],) + folder_data.shape[1:]
print(f"\tCreating dataset '{folder_name}' with shape "
f"{dset_shape}")
print(f"\tCreating dataset '{folder_name}' with shape {dset_shape}")
output_dataset = f_out.create_dataset(
folder_name,
data=folder_data,
maxshape=dset_shape,
chunks=(self.comptopts["chunksize"],) + folder_data.shape[
1:],
chunks=(self.comptopts["chunksize"],) + folder_data.shape[1:],
compression=self.comptopts["complib"],
compression_opts=self.comptopts["complevel"],
shuffle=self.comptopts["shuffle"],
......@@ -290,6 +280,29 @@ def get_compopts(file):
return comptopts
def copy_attrs(source_file, target_file):
"""
Copy file and dataset attributes from one h5 file to another.
"""
print("Copying attributes")
with h5py.File(source_file, "r") as src:
with h5py.File(target_file, "a") as trg:
_copy_attrs(src, trg)
for dset_name, target_dataset in trg.items():
if dset_name in src:
_copy_attrs(src[dset_name], target_dataset)
def _copy_attrs(src_datset, target_dataset):
for k in src_datset.attrs.keys():
try:
if k not in target_dataset.attrs:
target_dataset.attrs[k] = src_datset.attrs[k]
except TypeError as e:
# the above can fail if the attr is a bool created with pytables
warnings.warn(f"Error: Can not copy attribute {k}: {e}")
def get_parser():
parser = argparse.ArgumentParser(
description='Concatenate many small h5 files to a single large one '
......
"""
Scripts for postprocessing h5 files, e.g. shuffling.
"""
import os
import argparse
import h5py
import km3pipe as kp
import km3modules as km
from orcasong.modules import EventSkipper
from orcasong.tools.concatenate import get_compopts, copy_attrs
def postproc_file(
input_file,
output_file=None,
shuffle=True,
event_skipper=None,
delete=False,
seed=42,
statusbar_every=1000):
"""
Postprocess a file using km3pipe after it has been preprocessed in OrcaSong.
Parameters
----------
input_file : str
Path of the file that will be processed.
output_file : str, optional
If given, this will be the name of the output file.
Otherwise, a name is auto generated.
shuffle : bool
Shuffle order of events.
event_skipper : func, optional
Function that takes the blob as an input, and returns a bool.
If the bool is true, the event will be skipped.
delete : bool
Specifies if the input file should be deleted after processing.
seed : int
Sets a fixed random seed for the shuffling.
statusbar_every : int or None
After how many events a km3pipe status bar should be printed.
Returns
-------
output_file : str
Path to the output file.
"""
if output_file is None:
output_file = get_filepath_output(
input_file, shuffle=shuffle, event_skipper=event_skipper)
if os.path.exists(output_file):
raise FileExistsError(output_file)
print(f'Setting a Global Random State with the seed < {seed} >.')
km.GlobalRandomState(seed=seed)
comptopts = get_compopts(input_file)
# km3pipe uses pytables for saving the shuffled output file,
# which has the name 'zlib' for the 'gzip' filter
if comptopts["complib"] == 'gzip':
comptopts["complib"] = 'zlib'
pipe = kp.Pipeline()
if statusbar_every is not None:
pipe.attach(km.common.StatusBar, every=statusbar_every)
pipe.attach(km.common.MemoryObserver, every=statusbar_every)
pipe.attach(
kp.io.hdf5.HDF5Pump,
filename=input_file,
shuffle=shuffle,
reset_index=True,
)
if event_skipper is not None:
pipe.attach(EventSkipper, event_skipper=event_skipper)
pipe.attach(
kp.io.hdf5.HDF5Sink,
filename=output_file,
complib=comptopts["complib"],
complevel=comptopts["complevel"],
chunksize=comptopts["chunksize"],
flush_frequency=1000,
)
pipe.drain()
copy_used_files(input_file, output_file)
copy_attrs(input_file, output_file)
if delete:
print("Deleting original file")
os.remove(input_file)
print("Done!")
return output_file
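As a sketch, an event_skipper is any callable that takes the blob and returns a bool; the trivial body below keeps every event and is only a placeholder, and the filename is made up:

def my_event_skipper(blob):
    # return True to skip the event; which blob keys are available
    # depends on how the input file was produced
    return False

postproc_file("some_file.h5", shuffle=True, event_skipper=my_event_skipper)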
def copy_used_files(source_file, target_file):
""" Copy the "used_files" dataset from one h5 file to another, if it is present.
"""
with h5py.File(source_file, "r") as src:
if "used_files" in src:
print("Copying used_files dataset")
with h5py.File(target_file, "a") as trg:
trg.create_dataset("used_files", data=src["used_files"])
def get_filepath_output(input_file, shuffle=True, event_skipper=None):
""" Get the filename of the shuffled / rebalanced output file as a str.
"""
fname_adtn = ''
if shuffle:
fname_adtn += '_shuffled'
if event_skipper is not None:
fname_adtn += '_reb'
return f"{os.path.splitext(input_file)[0]}{fname_adtn}.h5"
def h5shuffle():
parser = argparse.ArgumentParser(description='Shuffle an h5 file using km3pipe.')
parser.add_argument('input_file', type=str, help='File to shuffle.')
parser.add_argument('--output_file', type=str,
help='Name of output file. Default: Auto generate name.')
parser.add_argument('--delete', action="store_true",
help='Delete original file afterwards.')
postproc_file(**vars(parser.parse_args()), shuffle=True, event_skipper=None)
......@@ -27,11 +27,11 @@ setup(
'tag_regex': r'^(?P<prefix>v)?(?P<version>[^\+]+)(?P<suffix>.*)?$', },
entry_points={'console_scripts': [
'make_nn_images=legacy.make_nn_images:main',
'shuffle=orcasong_contrib.data_tools.shuffle.shuffle_h5:main',
'concatenate=orcasong.tools.concatenate:main',
'make_dsplit=orcasong_contrib.data_tools.make_data_split.make_data_split:main',
'plot_binstats=orcasong.plotting.plot_binstats:main']}
'h5shuffle=orcasong.tools.postproc:h5shuffle',
'plot_binstats=orcasong.plotting.plot_binstats:main',
'make_nn_images=legacy.make_nn_images:main',
'make_dsplit=orcasong_contrib.data_tools.make_data_split.make_data_split:main']}
)
......
......@@ -70,6 +70,20 @@ class TestFileConcatenator(TestCase):
[n.encode("ascii", "ignore") for n in self.dummy_files],
)
def test_concatenate_attrs(self):
fc = conc.FileConcatenator(self.dummy_files)
with tempfile.TemporaryFile() as tf:
fc.concatenate(tf)
with h5py.File(tf, "r") as f:
target_attrs = dict(f.attrs)
target_dset_attrs = dict(f["numpy_array"].attrs)
with h5py.File(self.dummy_files[0], "r") as f:
source_attrs = dict(f.attrs)
source_dset_attrs = dict(f["numpy_array"].attrs)
self.assertDictEqual(source_attrs, target_attrs)
self.assertDictEqual(source_dset_attrs, target_dset_attrs)
def test_concatenate_array(self):
fc = conc.FileConcatenator(self.dummy_files)
with tempfile.TemporaryFile() as tf:
......@@ -103,13 +117,15 @@ class TestFileConcatenator(TestCase):
def _create_dummy_file(filepath, columns=10, val_array=1, val_recarray=(1, 3)):
""" Create a dummy h5 file with an array and a recarray in it. """
with h5py.File(filepath, "w") as f:
f.create_dataset(
dset = f.create_dataset(
"numpy_array",
data=np.ones(shape=(columns, 7, 3))*val_array,
chunks=(5, 7, 3),
compression="gzip",
compression_opts=1
)
dset.attrs.create("test_dset", "ok")
rec_array = np.array(
[val_recarray + (1, )] * columns,
dtype=[('x', '<f8'), ('y', '<i8'), ("group_id", "<i8")]
......@@ -122,3 +138,4 @@ def _create_dummy_file(filepath, columns=10, val_array=1, val_recarray=(1, 3)):
compression="gzip",
compression_opts=1
)
f.attrs.create("test_file", "ok")
......@@ -105,8 +105,16 @@ class TestFileGraph(TestCase):
'_i_event_info', '_i_group_info', '_i_y',
'event_info', 'group_info', 'x', 'x_indices', 'y'})
def test_x_title(self):
self.assertEqual(self.f["x"].attrs["TITLE"].decode(), "pos_z, time, channel_id, is_valid")
def test_x_attrs(self):
to_check = {
"hit_info_0": "pos_z",
"hit_info_1": "time",
"hit_info_2": "channel_id",
"hit_info_3": "is_valid",
}
attrs = dict(self.f["x"].attrs)
for k, v in to_check.items():
self.assertTrue(attrs[k] == v)
def test_x(self):
target = np.array([
......
......@@ -30,9 +30,31 @@ class TestModules(TestCase):
self.assertSequenceEqual(list(out_blob["test"].dtype.names),
('dom_id_0', 'time_2'))
np.testing.assert_array_equal(out_blob["test"]["dom_id_0"],
np.array([2, ]))
np.array([2, ], dtype="float64"))
np.testing.assert_array_equal(out_blob["test"]["time_2"],
np.array([12.3, ]))
np.array([12.3, ], dtype="float64"))
def test_mc_info_maker_dtype(self):
""" Test the mcinfo maker on some dummy data. """
def mc_info_extr(blob):
hits = blob["Hits"]
return {"dom_id_0": hits.dom_id[0],
"time_2": hits.time[2]}
in_blob = {
"Hits": Table({
'dom_id': np.array([2, 3, 3], dtype="int8"),
'time': np.array([10.1, 11.2, 12.3], dtype="float32"),
})
}
module = modules.McInfoMaker(
mc_info_extr=mc_info_extr, store_as="test", to_float64=False)
out_blob = module.process(in_blob)
np.testing.assert_array_equal(
out_blob["test"]["dom_id_0"], np.array([2, ], dtype="int8"))
np.testing.assert_array_equal(
out_blob["test"]["time_2"], np.array([12.3, ], dtype="float32"))
def test_event_skipper(self):
def event_skipper(blob):
......@@ -160,9 +182,11 @@ class TestPointMaker(TestCase):
}
def test_default_settings(self):
result = modules.PointMaker(
max_n_hits=4).process(self.input_blob_1)["samples"]
self.assertEqual(result.title, 't0, time, x, is_valid')
pm = modules.PointMaker(
max_n_hits=4)
result = pm.process(self.input_blob_1)["samples"]
self.assertTupleEqual(
pm.finish()["hit_infos"], ("t0", "time", "x", "is_valid"))
target = np.array(
[[[0.1, 1, 4, 1],
[0.2, 2, 5, 1],
......@@ -171,13 +195,15 @@ class TestPointMaker(TestCase):
np.testing.assert_array_equal(result, target)
def test_input_blob_1(self):
result = modules.PointMaker(
pm = modules.PointMaker(
max_n_hits=4,
hit_infos=("x", "time"),
time_window=None,
dset_n_hits=None,
).process(self.input_blob_1)["samples"]
self.assertEqual(result.title, 'x, time, is_valid')
)
result = pm.process(self.input_blob_1)["samples"]
self.assertTupleEqual(
pm.finish()["hit_infos"], ("x", "time", "is_valid"))
target = np.array(
[[[4, 1, 1],
[5, 2, 1],
......
from unittest import TestCase
import os
import h5py
import numpy as np
import orcasong.tools.postproc as postproc
__author__ = 'Stefan Reck'
test_dir = os.path.dirname(os.path.realpath(__file__))
MUPAGE_FILE = os.path.join(test_dir, "data", "mupage.root.h5")
class TestPostproc(TestCase):
def setUp(self):
self.output_file = "temp_output.h5"
def tearDown(self):
if os.path.exists(self.output_file):
os.remove(self.output_file)
def test_shuffle(self):
postproc.postproc_file(
input_file=MUPAGE_FILE,
output_file=self.output_file,
shuffle=True,
event_skipper=None,
delete=False,
seed=13,
)
with h5py.File(self.output_file, "r") as f:
np.testing.assert_equal(f["event_info"]["event_id"], np.array([1, 0, 2]))
self.assertTrue("origin" in f.attrs.keys())