diff --git a/orcasong/core.py b/orcasong/core.py index 36e5f45fee91913b08ee05b634f807bb607230bb..8c06b8f1fd883797e945fead7ea517083d4110e7 100644 --- a/orcasong/core.py +++ b/orcasong/core.py @@ -355,42 +355,44 @@ class FileGraph(BaseProcessor): Turn km3 events to graph data. The resulting file will have a dataset "x" of shape - (?, max_n_hits, len(hit_infos) + 1). + (total n_hits, len(hit_infos)). The column names of the last axis (i.e. hit_infos) are saved as attributes of the dataset (f["x"].attrs). - The last column will always be called 'is_valid', and its 0 if - the entry is padded, and 1 otherwise. Parameters ---------- - max_n_hits : int - Maximum number of hits that gets saved per event. If an event has - more, some will get cut randomly! - padded : bool - If True, pad hits of each event with 0s to a fixed width, so that they can - be stored as 3d arrays. max_n_hits needs to be given in that case. - If False, save events with variable length as a 2d arrays - using km3pipe's indices. - time_window : tuple, optional - Two ints (start, end). Hits outside of this time window will be cut - away (based on 'Hits/time'). Default: Keep all hits. hit_infos : tuple, optional Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ... Default: Keep all entries. + time_window : tuple, optional + Two ints (start, end). Hits outside of this time window will be cut + away (based on 'Hits/time'). Default: Keep all hits. only_triggered_hits : bool - If true, use only triggered hits. Otherwise, use all hits. + If true, use only triggered hits. Otherwise, use all hits (default). + max_n_hits : int + Maximum number of hits that gets saved per event. If an event has + more, some will get cut randomly! Default: Keep all hits. + fixed_length : bool + If False (default), save hits of events with variable length as + 2d arrays using km3pipe's indices. + If True, pad hits of each event with 0s to a fixed length, + so that they can be stored as 3d arrays like images. + max_n_hits needs to be given in that case, and a column will be + added called 'is_valid', which is 0 if the entry is padded, + and 1 otherwise. + This is inefficient and will cut off hits, so it should not be used. kwargs Options of the BaseProcessor. """ def __init__(self, max_n_hits=None, - padded=True, time_window=None, hit_infos=None, only_triggered_hits=False, + fixed_length=False, **kwargs): self.max_n_hits = max_n_hits - self.padded = padded + self.fixed_length = fixed_length self.time_window = time_window self.hit_infos = hit_infos self.only_triggered_hits = only_triggered_hits @@ -399,7 +401,7 @@ class FileGraph(BaseProcessor): def get_cmpts_main(self): return [((modules.PointMaker, { "max_n_hits": self.max_n_hits, - "padded": self.padded, + "fixed_length": self.fixed_length, "time_window": self.time_window, "hit_infos": self.hit_infos, "dset_n_hits": "EventInfo", @@ -410,4 +412,4 @@ class FileGraph(BaseProcessor): super().finish_file(f, summary) for i, hit_info in enumerate(summary["PointMaker"]["hit_infos"]): f["x"].attrs.create(f"hit_info_{i}", hit_info) - f["x"].attrs.create("indexed", not self.padded) + f["x"].attrs.create("indexed", not self.fixed_length) diff --git a/orcasong/modules.py b/orcasong/modules.py index 69677eae35441e4acad465bc4229afee72106b6f..96f8fe241a8dc5011175cfd7929ace452e8c8f66 100644 --- a/orcasong/modules.py +++ b/orcasong/modules.py @@ -276,36 +276,44 @@ class PointMaker(kp.Module): Attributes ---------- - max_n_hits : int - Maximum number of hits that gets saved per event. If an event has - more, some will get cut! - time_window : tuple, optional - Two ints (start, end). Hits outside of this time window will be cut - away (base on 'Hits/time'). - Default: Keep all hits. hit_infos : tuple, optional Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ... Default: Keep all entries. + time_window : tuple, optional + Two ints (start, end). Hits outside of this time window will be cut + away (based on 'Hits/time'). Default: Keep all hits. + only_triggered_hits : bool + If true, use only triggered hits. Otherwise, use all hits (default). + max_n_hits : int + Maximum number of hits that gets saved per event. If an event has + more, some will get cut randomly! Default: Keep all hits. + fixed_length : bool + If False (default), save hits of events with variable length as + 2d arrays using km3pipe's indices. + If True, pad hits of each event with 0s to a fixed length, + so that they can be stored as 3d arrays like images. + max_n_hits needs to be given in that case, and a column will be + added called 'is_valid', which is 0 if the entry is padded, + and 1 otherwise. + This is inefficient and will cut off hits, so it should not be used. dset_n_hits : str, optional If given, store the number of hits that are in the time window as a new column called 'n_hits_intime' in the dataset with this name (usually this is EventInfo). - only_triggered_hits : bool - If true, use only triggered hits. Otherwise, use all hits. """ def configure(self): - self.max_n_hits = self.get("max_n_hits", default=None) - self.padded = self.get("padded", default=True) self.hit_infos = self.get("hit_infos", default=None) self.time_window = self.get("time_window", default=None) - self.dset_n_hits = self.get("dset_n_hits", default=None) self.only_triggered_hits = self.get("only_triggered_hits", default=False) + self.max_n_hits = self.get("max_n_hits", default=None) + self.fixed_length = self.get("fixed_length", default=False) + self.dset_n_hits = self.get("dset_n_hits", default=None) self.store_as = "samples" def process(self, blob): - if self.padded and self.max_n_hits is None: - raise ValueError("Have to specify max_n_hits if padded is True") + if self.fixed_length and self.max_n_hits is None: + raise ValueError("Have to specify max_n_hits if fixed_length is True") if self.hit_infos is None: self.hit_infos = blob["Hits"].dtype.names points, n_hits = self.get_points(blob) @@ -350,7 +358,7 @@ class PointMaker(kp.Module): which.sort() hits = hits[which] - if self.padded: + if self.fixed_length: points = np.zeros( (self.max_n_hits, len(self.hit_infos) + 1), dtype="float32") for i, which in enumerate(self.hit_infos): @@ -368,7 +376,7 @@ class PointMaker(kp.Module): def finish(self): columns = tuple(self.hit_infos) - if self.padded: + if self.fixed_length: columns += ("is_valid", ) return {"hit_infos": columns} diff --git a/orcasong/tools/concatenate.py b/orcasong/tools/concatenate.py index db1cb8e5d4fb5c80ee1656abdbd8dfce86ebf4a3..5896c39fad39c9706a9f9d52b25b1489ed8c23af 100644 --- a/orcasong/tools/concatenate.py +++ b/orcasong/tools/concatenate.py @@ -141,7 +141,7 @@ class FileConcatenator: compression_opts=self.comptopts["complevel"], shuffle=self.comptopts["shuffle"], ) - output_dataset.resize(self.cumu_rows[dset_name][-1], axis=0) + output_dataset.resize(dset_shape[0], axis=0) else: f_out[dset_name][