diff --git a/docs/conf.py b/docs/conf.py index efcb17330dad70c4bd0f0bf9a31a8e66aee1c5c7..b8b791e9a2a027aee3b0461beae5d90b44588491 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -219,4 +219,4 @@ todo_include_todos = True def setup(app): - app.add_stylesheet('_static/style.css') + app.add_css_file('_static/style.css') diff --git a/orcasong/core.py b/orcasong/core.py index 5840cc7799b433ba58ca95b9e4b334acb9a282b2..292efe531a7e7fd2471cf4c31bf0a4b963ddfb96 100644 --- a/orcasong/core.py +++ b/orcasong/core.py @@ -398,31 +398,31 @@ class FileGraph(BaseProcessor): hit_infos : tuple, optional Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ... Default: Keep all entries. + only_triggered_hits : bool + If true, use only triggered hits. Otherwise, use all hits. kwargs Options of the BaseProcessor. """ - - def __init__(self, max_n_hits, time_window=None, hit_infos=None, **kwargs): + def __init__(self, max_n_hits, + time_window=None, + hit_infos=None, + only_triggered_hits=False, + **kwargs): self.max_n_hits = max_n_hits self.time_window = time_window self.hit_infos = hit_infos + self.only_triggered_hits = only_triggered_hits super().__init__(**kwargs) def get_cmpts_main(self): - return [ - ( - ( - modules.PointMaker, - { - "max_n_hits": self.max_n_hits, - "time_window": self.time_window, - "hit_infos": self.hit_infos, - "dset_n_hits": "EventInfo", - }, - ) - ) - ] + return [((modules.PointMaker, { + "max_n_hits": self.max_n_hits, + "time_window": self.time_window, + "hit_infos": self.hit_infos, + "dset_n_hits": "EventInfo", + "only_triggered_hits": self.only_triggered_hits, + }))] def finish_file(self, f, summary): super().finish_file(f, summary) diff --git a/orcasong/modules.py b/orcasong/modules.py index 0ecb8cca2b9fec94dc26720993b9ff699f549c5e..8d825d45188ba9d90bfd41ec00835ac06711cb3b 100644 --- a/orcasong/modules.py +++ b/orcasong/modules.py @@ -288,6 +288,8 @@ class PointMaker(kp.Module): If given, store the number of hits that are in the time window as a new column called 'n_hits_intime' in the dataset with this name (usually this is EventInfo). + only_triggered_hits : bool + If true, use only triggered hits. Otherwise, use all hits. """ @@ -296,6 +298,7 @@ class PointMaker(kp.Module): self.hit_infos = self.get("hit_infos", default=None) self.time_window = self.get("time_window", default=None) self.dset_n_hits = self.get("dset_n_hits", default=None) + self.only_triggered_hits = self.get("only_triggered_hits", default=False) self.store_as = "samples" def process(self, blob): @@ -329,6 +332,8 @@ class PointMaker(kp.Module): points = np.zeros((self.max_n_hits, len(self.hit_infos) + 1), dtype="float32") hits = blob["Hits"] + if self.only_triggered_hits: + hits = hits[hits.triggered != 0] if self.time_window is not None: # remove hits outside of time window hits = hits[