From 105dec4c3a26f4ee05b83862828cc5c2161fbd64 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Fri, 22 Nov 2019 23:57:21 +0100
Subject: [PATCH] Update ztplot

---
 scripts/ztplot.py | 135 ++++++++++++----------------------------------
 1 file changed, 34 insertions(+), 101 deletions(-)

diff --git a/scripts/ztplot.py b/scripts/ztplot.py
index 87db9a9..51bd36d 100755
--- a/scripts/ztplot.py
+++ b/scripts/ztplot.py
@@ -34,6 +34,7 @@ import numpy as np
 import km3pipe as kp
 from km3pipe.io.daq import is_3dmuon, is_3dshower, is_mxshower
 from km3modules.hits import count_multiplicities
+from km3modules.plot import ztplot
 import km3pipe.style
 km3pipe.style.use('km3pipe')
 
@@ -86,8 +87,10 @@ class ZTPlot(kp.Module):
         n_triggered_dus = len(np.unique(hits[hits.triggered == True].du))
         n_triggered_doms = len(np.unique(hits[hits.triggered == True].dom_id))
         if n_triggered_dus < self.min_dus or n_triggered_doms < self.min_doms:
-            print(f"Skipping event with {n_triggered_dus} DUs "
-                  f"and {n_triggered_doms} DOMs.")
+            print(
+                f"Skipping event with {n_triggered_dus} DUs "
+                f"and {n_triggered_doms} DOMs."
+            )
             return blob
 
         print("OK")
@@ -108,116 +111,44 @@ class ZTPlot(kp.Module):
 
     def create_plot(self, event_info, hits):
         print(self.__class__.__name__ + ": updating plot.")
+
         dus = set(hits.du)
         doms = set(hits.dom_id)
-        fontsize = 16
-
-        hits = hits.append_columns('multiplicity',
-                                   np.ones(len(hits))).sorted(by='time')
-
-        for dom in doms:
-            dom_hits = hits[hits.dom_id == dom]
-            mltps, m_ids = count_multiplicities(dom_hits.time)
-            hits['multiplicity'][hits.dom_id == dom] = mltps
-
-        time_offset = np.min(hits[hits.triggered == True].time)
-        hits.time -= time_offset
 
-        n_plots = len(dus)
-        n_cols = int(np.ceil(np.sqrt(n_plots)))
-        n_rows = int(n_plots / n_cols) + (n_plots % n_cols > 0)
-        marker_fig, marker_axes = plt.subplots()  # for the marker size hack...
-        fig, axes = plt.subplots(ncols=n_cols,
-                                 nrows=n_rows,
-                                 sharex=True,
-                                 sharey=True,
-                                 figsize=(16, 8),
-                                 constrained_layout=True)
-
-        axes = [axes] if n_plots == 1 else axes.flatten()
-
-        dom_zs = self.calib.detector.pmts.pos_z[
+        grid_lines = self.calib.detector.pmts.pos_z[
             (self.calib.detector.pmts.du == min(dus))
             & (self.calib.detector.pmts.channel_id == 0)]
 
-        for ax, du in zip(axes, dus):
-            for z in dom_zs:
-                ax.axhline(z, lw=1, color='b', ls='--', alpha=0.15)
-            du_hits = hits[hits.du == du]
-            trig_hits = du_hits[du_hits.triggered == True]
-
-            ax.scatter(du_hits.time,
-                       du_hits.pos_z,
-                       s=du_hits.multiplicity * 30,
-                       c='#09A9DE',
-                       label='hit',
-                       alpha=0.5)
-            ax.scatter(trig_hits.time,
-                       trig_hits.pos_z,
-                       s=trig_hits.multiplicity * 30,
-                       alpha=0.8,
-                       marker="+",
-                       c='#FF6363',
-                       label='triggered hit')
-            ax.set_title('DU{0}'.format(int(du)),
-                         fontsize=fontsize,
-                         fontweight='bold')
-
-            # The only way I could create a legend with matching marker sizes
-            max_multiplicity = int(np.max(du_hits.multiplicity))
-            markers = list(
-                range(0, max_multiplicity,
-                      np.ceil(max_multiplicity / 10).astype(int)))
-            custom_markers = [
-                marker_axes.scatter(
-                    [], [], s=mult * 30, color='#09A9DE', lw=0, alpha=0.5)
-                for mult in markers
-            ] + [marker_axes.scatter([], [], s=30, marker="+", c='#FF6363')]
-            ax.legend(custom_markers, ['multiplicity'] +
-                      ["       %d" % m for m in markers[1:]] + ['triggered'],
-                      scatterpoints=1,
-                      markerscale=1,
-                      loc='upper left',
-                      frameon=True,
-                      framealpha=0.7)
-
-        for idx, ax in enumerate(axes):
-            ax.set_ylim(0, self.max_z)
-            ax.tick_params(labelsize=fontsize)
-            ax.yaxis.set_major_locator(
-                ticker.MultipleLocator(self.ytick_distance))
-            xlabels = ax.get_xticklabels()
-            for label in xlabels:
-                label.set_rotation(45)
-
-            if idx % n_cols == 0:
-                ax.set_ylabel('z [m]', fontsize=fontsize)
-            if idx >= len(axes) - n_cols:
-                ax.set_xlabel('time [ns]', fontsize=fontsize)
-
         trigger_params = ' '.join([
-            trig
-            for trig, trig_check in (("MX", is_mxshower), ("3DM", is_3dmuon),
-                                     ("3DS", is_3dshower))
+            trig for trig, trig_check in (("MX",
+                                           is_mxshower), ("3DM", is_3dmuon),
+                                          ("3DS", is_3dshower))
             if trig_check(int(event_info.trigger_mask[0]))
         ])
 
-        plt.suptitle(
+        title = (
             "z-t-Plot for DetID-{0} (t0set: {1}), Run {2}, FrameIndex {3}, "
-            "TriggerCounter {4}, Overlays {5}, Trigger: {8}"
-            "\n{7} UTC (time offset: {6} ns)".format(
+            "TriggerCounter {4}, Overlays {5}, Trigger: {6}"
+            "\n{7} UTC".format(
                 event_info.det_id[0], self.t0set, event_info.run_id[0],
                 event_info.frame_index[0], event_info.trigger_counter[0],
-                event_info.overlays[0], time_offset,
-                datetime.utcfromtimestamp(event_info.utc_seconds),
-                trigger_params),
-            fontsize=fontsize,
-            y=1.05)
+                event_info.overlays[0], trigger_params,
+                datetime.utcfromtimestamp(event_info.utc_seconds)
+            )
+        )[0]
+
+        fig = ztplot(
+            hits,
+            title,
+            max_z=self.max_z,
+            ytick_distance=self.ytick_distance,
+            grid_lines=grid_lines
+        )
 
         filename = 'ztplot'
         f = os.path.join(self.plots_path, filename + '.png')
         f_tmp = os.path.join(self.plots_path, filename + '_tmp.png')
-        plt.savefig(f_tmp, dpi=120, bbox_inches="tight")
+        fig.savefig(f_tmp, dpi=120, bbox_inches="tight")
         if len(doms) > 4:
             plt.savefig(os.path.join(self.plots_path, filename + '_5doms.png'))
         plt.close('all')
@@ -237,12 +168,14 @@ def main():
     ligier_port = int(args['-p'])
 
     pipe = kp.Pipeline()
-    pipe.attach(kp.io.ch.CHPump,
-                host=ligier_ip,
-                port=ligier_port,
-                tags='IO_EVT, IO_SUM',
-                timeout=60 * 60 * 24 * 7,
-                max_queue=2000)
+    pipe.attach(
+        kp.io.ch.CHPump,
+        host=ligier_ip,
+        port=ligier_port,
+        tags='IO_EVT, IO_SUM',
+        timeout=60 * 60 * 24 * 7,
+        max_queue=2000
+    )
     pipe.attach(kp.io.daq.DAQProcessor)
     pipe.attach(ZTPlot, det_id=det_id, plots_path=plots_path)
     pipe.drain()
-- 
GitLab