Skip to content
Snippets Groups Projects
Commit 04928c6c authored by Stefan Reck's avatar Stefan Reck
Browse files

Rename 'padded' to 'fixed_length'; the default for fixed_length is now False

parent add7ee7d
No related branches found
No related tags found
1 merge request!23Jagged graph
......@@ -355,42 +355,44 @@ class FileGraph(BaseProcessor):
Turn km3 events to graph data.
The resulting file will have a dataset "x" of shape
(?, max_n_hits, len(hit_infos) + 1).
(total n_hits, len(hit_infos)).
The column names of the last axis (i.e. hit_infos) are saved
as attributes of the dataset (f["x"].attrs).
The last column will always be called 'is_valid', and it is 0 if
the entry is padded, and 1 otherwise.
Parameters
----------
max_n_hits : int
Maximum number of hits that gets saved per event. If an event has
more, some will get cut randomly!
padded : bool
If True, pad hits of each event with 0s to a fixed width, so that they can
be stored as 3d arrays. max_n_hits needs to be given in that case.
If False, save events with variable length as 2d arrays
using km3pipe's indices.
time_window : tuple, optional
Two ints (start, end). Hits outside of this time window will be cut
away (based on 'Hits/time'). Default: Keep all hits.
hit_infos : tuple, optional
Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ...
Default: Keep all entries.
time_window : tuple, optional
Two ints (start, end). Hits outside of this time window will be cut
away (based on 'Hits/time'). Default: Keep all hits.
only_triggered_hits : bool
If true, use only triggered hits. Otherwise, use all hits.
If true, use only triggered hits. Otherwise, use all hits (default).
max_n_hits : int
Maximum number of hits that gets saved per event. If an event has
more, some will get cut randomly! Default: Keep all hits.
fixed_length : bool
If False (default), save hits of events with variable length as
2d arrays using km3pipe's indices.
If True, pad hits of each event with 0s to a fixed length,
so that they can be stored as 3d arrays like images.
max_n_hits needs to be given in that case, and a column will be
added called 'is_valid', which is 0 if the entry is padded,
and 1 otherwise.
This is inefficient and will cut off hits, so it should not be used.
kwargs
Options of the BaseProcessor.
"""
def __init__(self, max_n_hits=None,
padded=True,
time_window=None,
hit_infos=None,
only_triggered_hits=False,
fixed_length=False,
**kwargs):
self.max_n_hits = max_n_hits
self.padded = padded
self.fixed_length = fixed_length
self.time_window = time_window
self.hit_infos = hit_infos
self.only_triggered_hits = only_triggered_hits
......@@ -399,7 +401,7 @@ class FileGraph(BaseProcessor):
def get_cmpts_main(self):
return [((modules.PointMaker, {
"max_n_hits": self.max_n_hits,
"padded": self.padded,
"fixed_length": self.fixed_length,
"time_window": self.time_window,
"hit_infos": self.hit_infos,
"dset_n_hits": "EventInfo",
......@@ -410,4 +412,4 @@ class FileGraph(BaseProcessor):
super().finish_file(f, summary)
for i, hit_info in enumerate(summary["PointMaker"]["hit_infos"]):
f["x"].attrs.create(f"hit_info_{i}", hit_info)
f["x"].attrs.create("indexed", not self.padded)
f["x"].attrs.create("indexed", not self.fixed_length)
......@@ -276,36 +276,44 @@ class PointMaker(kp.Module):
Attributes
----------
max_n_hits : int
Maximum number of hits that gets saved per event. If an event has
more, some will get cut!
time_window : tuple, optional
Two ints (start, end). Hits outside of this time window will be cut
away (based on 'Hits/time').
Default: Keep all hits.
hit_infos : tuple, optional
Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ...
Default: Keep all entries.
time_window : tuple, optional
Two ints (start, end). Hits outside of this time window will be cut
away (based on 'Hits/time'). Default: Keep all hits.
only_triggered_hits : bool
If true, use only triggered hits. Otherwise, use all hits (default).
max_n_hits : int
Maximum number of hits that gets saved per event. If an event has
more, some will get cut randomly! Default: Keep all hits.
fixed_length : bool
If False (default), save hits of events with variable length as
2d arrays using km3pipe's indices.
If True, pad hits of each event with 0s to a fixed length,
so that they can be stored as 3d arrays like images.
max_n_hits needs to be given in that case, and a column will be
added called 'is_valid', which is 0 if the entry is padded,
and 1 otherwise.
This is inefficient and will cut off hits, so it should not be used.
dset_n_hits : str, optional
If given, store the number of hits that are in the time window
as a new column called 'n_hits_intime' in the dataset with
this name (usually this is EventInfo).
only_triggered_hits : bool
If true, use only triggered hits. Otherwise, use all hits.
"""
def configure(self):
self.max_n_hits = self.get("max_n_hits", default=None)
self.padded = self.get("padded", default=True)
self.hit_infos = self.get("hit_infos", default=None)
self.time_window = self.get("time_window", default=None)
self.dset_n_hits = self.get("dset_n_hits", default=None)
self.only_triggered_hits = self.get("only_triggered_hits", default=False)
self.max_n_hits = self.get("max_n_hits", default=None)
self.fixed_length = self.get("fixed_length", default=False)
self.dset_n_hits = self.get("dset_n_hits", default=None)
self.store_as = "samples"
def process(self, blob):
if self.padded and self.max_n_hits is None:
raise ValueError("Have to specify max_n_hits if padded is True")
if self.fixed_length and self.max_n_hits is None:
raise ValueError("Have to specify max_n_hits if fixed_length is True")
if self.hit_infos is None:
self.hit_infos = blob["Hits"].dtype.names
points, n_hits = self.get_points(blob)
......@@ -350,7 +358,7 @@ class PointMaker(kp.Module):
which.sort()
hits = hits[which]
if self.padded:
if self.fixed_length:
points = np.zeros(
(self.max_n_hits, len(self.hit_infos) + 1), dtype="float32")
for i, which in enumerate(self.hit_infos):
......@@ -368,7 +376,7 @@ class PointMaker(kp.Module):
def finish(self):
columns = tuple(self.hit_infos)
if self.padded:
if self.fixed_length:
columns += ("is_valid", )
return {"hit_infos": columns}
......
......@@ -141,7 +141,7 @@ class FileConcatenator:
compression_opts=self.comptopts["complevel"],
shuffle=self.comptopts["shuffle"],
)
output_dataset.resize(self.cumu_rows[dset_name][-1], axis=0)
output_dataset.resize(dset_shape[0], axis=0)
else:
f_out[dset_name][
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment