diff --git a/orcasong/core.py b/orcasong/core.py index a91c4d9d219f8dcfb3eeacf359b1ed99382f7929..67bb916498db498950e25109b8c1c9e22d12cab0 100644 --- a/orcasong/core.py +++ b/orcasong/core.py @@ -367,6 +367,8 @@ class FileGraph(BaseProcessor): hit_infos : tuple, optional Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ... Default: Keep all entries. + only_triggered_hits : bool + If true, use only triggered hits. Otherwise, use all hits. kwargs Options of the BaseProcessor. @@ -374,10 +376,12 @@ class FileGraph(BaseProcessor): def __init__(self, max_n_hits, time_window=None, hit_infos=None, + only_triggered_hits=False, **kwargs): self.max_n_hits = max_n_hits self.time_window = time_window self.hit_infos = hit_infos + self.only_triggered_hits = only_triggered_hits super().__init__(**kwargs) def get_cmpts_main(self): @@ -385,7 +389,9 @@ class FileGraph(BaseProcessor): "max_n_hits": self.max_n_hits, "time_window": self.time_window, "hit_infos": self.hit_infos, - "dset_n_hits": "EventInfo"}))] + "dset_n_hits": "EventInfo", + "only_triggered_hits": self.only_triggered_hits, + }))] def finish_file(self, f, summary): super().finish_file(f, summary) diff --git a/orcasong/modules.py b/orcasong/modules.py index 7645fb776073121d63c4e076ea9cd41342bdc423..4cf4f4d4cd646ba917d078a4afa5ad31dc4c8b96 100644 --- a/orcasong/modules.py +++ b/orcasong/modules.py @@ -287,6 +287,8 @@ class PointMaker(kp.Module): If given, store the number of hits that are in the time window as a new column called 'n_hits_intime' in the dataset with this name (usually this is EventInfo). + only_triggered_hits : bool + If true, use only triggered hits. Otherwise, use all hits. """ def configure(self): @@ -294,6 +296,7 @@ class PointMaker(kp.Module): self.hit_infos = self.get("hit_infos", default=None) self.time_window = self.get("time_window", default=None) self.dset_n_hits = self.get("dset_n_hits", default=None) + self.only_triggered_hits = self.get("only_triggered_hits", default=False) self.store_as = "samples" def process(self, blob): @@ -326,6 +329,8 @@ class PointMaker(kp.Module): (self.max_n_hits, len(self.hit_infos) + 1), dtype="float32") hits = blob["Hits"] + if self.only_triggered_hits: + hits = hits[hits.triggered != 0] if self.time_window is not None: # remove hits outside of time window hits = hits[np.logical_and(