diff --git a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py index dfe1e80ba40f67151f26b1367c1e73ba7a7d1494..5c667105b7c067abad6c84e7a2f55b839a9edbab 100644 --- a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py +++ b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py @@ -126,7 +126,7 @@ def parse_input(): def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None, complib=None, complevel=None, legacy_mode=False, shuffle=True, - event_skipper=None): + event_skipper=None, filepath_output=None): """ Shuffles a .h5 file where each dataset needs to have the same number of rows (axis_0). The shuffled data is saved to a new .h5 file with the suffix < _shuffled.h5 >. @@ -166,6 +166,9 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None event_skipper : func, optional Function that takes the blob as an input, and returns a bool. If the bool is true, the blob will be skipped. + filepath_output : str, optional + If given, this will be the name of the output file. Otherwise, a name + is auto generated. Returns ------- @@ -185,13 +188,14 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None if complib == 'lzf': complevel = None - filepath_input_without_ext = os.path.splitext(filepath_input)[0] - fname_adtn = '' - if shuffle: - fname_adtn += '_shuffled' - if event_skipper is not None: - fname_adtn += '_reb' - filepath_output = filepath_input_without_ext + fname_adtn + ".h5" + if filepath_output is None: + filepath_input_without_ext = os.path.splitext(filepath_input)[0] + fname_adtn = '' + if shuffle: + fname_adtn += '_shuffled' + if event_skipper is not None: + fname_adtn += '_reb' + filepath_output = filepath_input_without_ext + fname_adtn + ".h5" if not legacy_mode: # set random km3pipe (=numpy) seed diff --git a/orcasong_plag/core.py b/orcasong_plag/core.py index 181aa7fb69f274a0ff0a004f8d7a6d28bca68798..e8883c12d84df609f928453d666aee5a406999a4 100644 --- a/orcasong_plag/core.py +++ b/orcasong_plag/core.py @@ -26,7 +26,7 @@ class FileBinner: bin_edges_list : List List with the names of the fields to bin, and the respective bin edges, including the left- and right-most bin edge. - Example: + Example: For 10 bins in the z direction, and 100 bins in time: bin_edges_list = [ ["pos_z", np.linspace(0, 10, 11)], ["time", np.linspace(-50, 550, 101)], @@ -44,6 +44,10 @@ class FileBinner: It shows the distribution of hits, the bin edges, and how many hits were cut off for each field name in bin_edges_list. It will be saved to the same path as the outfile in run. + keep_event_info : bool + If True, will keep the "event_info" table. + keep_mc_tracks : bool + If True, will keep the "McTracks" table. n_statusbar : int, optional Print a statusbar every n blobs. n_memory_observer : int, optional @@ -70,6 +74,7 @@ class FileBinner: """ def __init__(self, bin_edges_list, mc_info_extr=None, event_skipper=None, add_bin_stats=True): + self.bin_edges_list = bin_edges_list self.mc_info_extr = mc_info_extr self.event_skipper = event_skipper @@ -79,6 +84,9 @@ class FileBinner: else: self.bin_plot_freq = None + self.keep_event_info = True + self.keep_mc_tracks = False + self.n_statusbar = 1000 self.n_memory_observer = 1000 self.do_time_preproc = True @@ -192,7 +200,12 @@ class FileBinner: mc_info_extr=mc_info_extr, store_as="mc_info") - pipe.attach(km.common.Keep, keys=['histogram', 'mc_info']) + keys_keep = ['histogram', 'mc_info'] + if self.keep_event_info: + keys_keep.append('EventInfo') + if self.keep_mc_tracks: + keys_keep.append('McTracks') + pipe.attach(km.common.Keep, keys=keys_keep) pipe.attach(kp.io.HDF5Sink, filename=outfile, diff --git a/orcasong_plag/mc_info_types.py b/orcasong_plag/mc_info_types.py index adcd6fe4934d1cce45181faeb81b12770a38a660..18cd75f12ebdbc2e458d7f397d44da80ac57a6f9 100644 --- a/orcasong_plag/mc_info_types.py +++ b/orcasong_plag/mc_info_types.py @@ -59,50 +59,76 @@ def get_mupage_mc(blob): The info for mc_info. """ - event_id = blob['EventInfo'].event_id[0] - run_id = blob["EventInfo"].run_id + # only one line has hits, but there are two for the mc. This one is active: + active_du = 2 + + track = dict() + + track["event_id"] = blob['EventInfo'].event_id[0] + track["run_id"] = blob["EventInfo"].run_id # run_id = blob['Header'].start_run.run_id.astype('float32') # take 0: assumed that this is the same for all muons in a bundle - particle_type = blob['McTracks'][0].type + track["particle_type"] = blob['McTracks'][0].type # always 1 actually - is_cc = blob['McTracks'][0].is_cc + track["is_cc"] = blob['McTracks'][0].is_cc # always 0 actually - bjorkeny = blob['McTracks'][0].bjorkeny + track["bjorkeny"] = blob['McTracks'][0].bjorkeny # same for all muons in a bundle #TODO not? - time_interaction = blob['McTracks'][0].time + track["time_interaction"] = blob['McTracks'][0].time # takes position of time_residual_vertex in 'neutrino' case - n_muons = blob['McTracks'].shape[0] + track["n_muons"] = blob['McTracks'].shape[0] # sum up the energy of all muons - energy = np.sum(blob['McTracks'].energy) + energy = blob['McTracks'].energy + track["energy"] = np.sum(energy) + # energy in the highest energy muons of the bundle + sorted_energy = np.sort(energy) + track["energy_highest_frac"] = sorted_energy[-1]/np.sum(energy) + if len(energy) > 1: + sec_highest_energy = sorted_energy[-2]/np.sum(energy) + else: + sec_highest_energy = 0. + track["energy_sec_highest_frac"] = sec_highest_energy + # coefficient of variation of energy + track["energy_cvar"] = np.std(energy)/np.mean(energy) + + # get how many mchits were produced per muon in the bundle + origin = blob["McHits"]["origin"][blob["McHits"]["du"] == active_du] + origin_dict = dict(zip(*np.unique(origin, return_counts=True))) + origin_list = [] + for i in range(1, track["n_muons"]+1): + origin_list.append(origin_dict.get(i, 0)) + origin_list = np.array(origin_list) + + # fraction of mchits in the highest mchit muons of the bundle + sorted_origin = np.sort(origin_list) + track["mchit_highest_frac"] = sorted_origin[-1] / np.sum(origin_list) + if len(sorted_origin) > 1: + sec_highest_mchit_frac = sorted_origin[-2] / np.sum(origin_list) + else: + sec_highest_mchit_frac = 0. + track["mchit_sec_highest_frac"] = sec_highest_mchit_frac + + # only muons with at least one mchit in active line + track["n_muons_visible"] = len(origin_list[origin_list > 0]) + # only muons with at least 5 mchits in active line + track["n_muons_thresh"] = len(origin_list[origin_list > 4]) + + # coefficient of variation of the origin of mc hits in the bundle + track["mchit_cvar"] = np.std(origin_list)/np.mean(origin_list) # all muons in a bundle are parallel, so just take dir of first muon - dir_x = blob['McTracks'][0].dir_x - dir_y = blob['McTracks'][0].dir_y - dir_z = blob['McTracks'][0].dir_z + track["dir_x"] = blob['McTracks'][0].dir_x + track["dir_y"] = blob['McTracks'][0].dir_y + track["dir_z"] = blob['McTracks'][0].dir_z # vertex is the weighted (energy) mean of the individual vertices - vertex_pos_x = np.average(blob['McTracks'][:].pos_x, - weights=blob['McTracks'][:].energy) - vertex_pos_y = np.average(blob['McTracks'][:].pos_y, - weights=blob['McTracks'][:].energy) - vertex_pos_z = np.average(blob['McTracks'][:].pos_z, - weights=blob['McTracks'][:].energy) - - track = {'event_id': event_id, - 'particle_type': particle_type, - 'energy': energy, - 'is_cc': is_cc, - 'bjorkeny': bjorkeny, - 'dir_x': dir_x, - 'dir_y': dir_y, - 'dir_z': dir_z, - 'time_interaction': time_interaction, - 'run_id': run_id, - 'vertex_pos_x': vertex_pos_x, - 'vertex_pos_y': vertex_pos_y, - 'vertex_pos_z': vertex_pos_z, - 'n_muons': n_muons} + track["vertex_pos_x"] = np.average(blob['McTracks'][:].pos_x, + weights=blob['McTracks'][:].energy) + track["vertex_pos_y"] = np.average(blob['McTracks'][:].pos_y, + weights=blob['McTracks'][:].energy) + track["vertex_pos_z"] = np.average(blob['McTracks'][:].pos_z, + weights=blob['McTracks'][:].energy) return track diff --git a/orcasong_plag/util/bin_stats_plot.py b/orcasong_plag/util/bin_stats_plot.py index 6b7a86deb9d6e1be40731f971c94af832bed2fd7..eb0c6f5c0650d7a7a8b74c169c874d755efd1dd9 100644 --- a/orcasong_plag/util/bin_stats_plot.py +++ b/orcasong_plag/util/bin_stats_plot.py @@ -186,7 +186,7 @@ def plot_hist_of_files(files, save_as): file.close() -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser(description='Plot the bin stats in h5 files') parser.add_argument('save_as', metavar='S', type=str, help='Where to save the plot to.') @@ -195,3 +195,7 @@ if __name__ == "__main__": args = parser.parse_args() plot_hist_of_files(args.files, args.save_as) + + +if __name__ == "__main__": + main()