Skip to content
Snippets Groups Projects
Commit ffc13cbe authored by Stefan Reck's avatar Stefan Reck
Browse files

Added my orcasong modification as stand alone module.

parent a12617cf
No related branches found
No related tags found
No related merge requests found
OrcaSong Plag
=============
Several changes to the original OrcaSong. Allows setting the desired binning
via a list.
Does not contain all features of OrcaSong, e.g. skipping events, plotting, etc.
\ No newline at end of file
import km3pipe as kp
import km3modules as km
from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker
from orcasong_plag.mc_info_types import get_mc_info_extr
class FileBinner:
    """
    For making binned images.

    Attributes
    ----------
    bin_edges_list : List
        List with the names of the fields to bin, and the respective bin
        edges, including the left- and right-most bin edge.
        Example:
            bin_edges_list = [
                ["pos_z", np.linspace(0, 10, 11)],
                ["time", np.linspace(-50, 550, 101)],
            ]
    mc_info_extr : function or string, optional
        Function that extracts desired mc_info from a blob, which is then
        stored as the "y" datafield in the .h5 file.
        Can also give a str identifier for an existing extractor.
    n_statusbar : int, optional
        Print a statusbar every n blobs. None disables the statusbar.
    n_memory_observer : int, optional
        Print memory usage every n blobs. None disables the observer.
    do_time_preproc : bool
        Do time preprocessing, i.e. add t0 to real data, subtract time
        of first triggered hit.
    chunksize : int
        Chunksize (along axis_0) used for saving the output to a .h5 file.
    complib : str
        Compression library used for saving the output to a .h5 file.
        All PyTables compression filters are available, e.g. 'zlib',
        'lzf', 'blosc', ... .
    complevel : int
        Compression level for the compression filter that is used for
        saving the output to a .h5 file.
    flush_frequency : int
        After how many events the accumulated output should be flushed to
        the harddisk.
        A larger value leads to a faster orcasong execution,
        but it increases the RAM usage as well.

    """
    def __init__(self, bin_edges_list, mc_info_extr=None):
        self.bin_edges_list = bin_edges_list
        self.mc_info_extr = mc_info_extr

        self.n_statusbar = 200
        self.n_memory_observer = 400
        self.do_time_preproc = True

        self.chunksize = 32
        self.complib = 'zlib'
        self.complevel = 1
        self.flush_frequency = 1000

    def run(self, infile, outfile):
        """
        Build the pipeline to make images for the given file, and run it.

        Parameters
        ----------
        infile : str or List
            Path to the input file(s).
        outfile : str
            Path to the output file.

        """
        name, shape = self.get_name_and_shape()
        print("Generating {} images with shape {}".format(name, shape))

        pipe = kp.Pipeline()

        if self.n_statusbar is not None:
            pipe.attach(km.common.StatusBar, every=self.n_statusbar)
        if self.n_memory_observer is not None:
            # use the configured interval instead of a hard-coded 400, so
            # that changing the attribute actually has an effect
            pipe.attach(km.common.MemoryObserver,
                        every=self.n_memory_observer)

        # the pump is always handed a list of file names
        if not isinstance(infile, list):
            infile = [infile]
        pipe.attach(kp.io.hdf5.HDF5Pump, filenames=infile)

        self.attach_binning_modules(pipe)

        pipe.attach(kp.io.HDF5Sink,
                    filename=outfile,
                    complib=self.complib,
                    complevel=self.complevel,
                    chunksize=self.chunksize,
                    flush_frequency=self.flush_frequency)

        pipe.drain()

    def attach_binning_modules(self, pipe):
        """
        Attach the modules that turn a blob into images and mc_info.

        Parameters
        ----------
        pipe : kp.Pipeline
            The pipeline to attach the modules to.

        """
        pipe.attach(km.common.Keep, keys=['EventInfo', 'Header', 'RawHeader',
                                          'McTracks', 'Hits', 'McHits'])

        if self.do_time_preproc:
            pipe.attach(TimePreproc)

        # NOTE: skipping events via data cuts is not implemented here.

        pipe.attach(ImageMaker,
                    bin_edges_list=self.bin_edges_list,
                    store_as="histogram")

        if self.mc_info_extr is not None:
            # a string is an identifier for one of the existing extractors
            if isinstance(self.mc_info_extr, str):
                mc_info_extr = get_mc_info_extr(self.mc_info_extr)
            else:
                mc_info_extr = self.mc_info_extr

            pipe.attach(McInfoMaker,
                        mc_info_extr=mc_info_extr,
                        store_as="mc_info")

        # drop everything but the entries produced by the modules above
        pipe.attach(km.common.Keep, keys=['histogram', 'mc_info'])

    def get_name_and_shape(self):
        """
        Get name and shape of the resulting x data,
        e.g. "pos_z_time", (18, 50).

        Returns
        -------
        name : str
            The names of the binned fields, joined by underscores.
        shape : tuple
            The number of bins of each field, i.e. the number of
            bin edges minus one.

        """
        res_names, shape = [], []
        for bin_name, bin_edges in self.bin_edges_list:
            res_names.append(bin_name)
            shape.append(len(bin_edges) - 1)

        name = "_".join(res_names)
        shape = tuple(shape)

        return name, shape

    def __repr__(self):
        name, shape = self.get_name_and_shape()
        return "<FileBinner: {} {}>".format(name, shape)
"""
Functions that extract info from a blob for the mc_info / y datafield
in the h5 files.
"""
import numpy as np
def get_mc_info_extr(mc_info_type):
    """
    Get an existing mc info extractor function.

    Parameters
    ----------
    mc_info_type : str
        Identifier of the extractor. Currently only "mupage" is known.

    Returns
    -------
    mc_info_extr : function
        The extractor function, which takes a blob and returns a dict.

    Raises
    ------
    ValueError
        If mc_info_type is not the identifier of a known extractor.

    """
    if mc_info_type == "mupage":
        return get_mupage_mc

    # format() instead of "+": a non-string identifier must still end up in
    # the ValueError, not in a TypeError from the string concatenation
    raise ValueError("Unknown mc_info_type {}".format(mc_info_type))
def get_mupage_mc(blob):
    """
    Extract the mc_info of an event from a mupage muon simulation.

    Per-muon quantities are taken from the first muon of the bundle, the
    energy is summed over all muons, and the vertex is the energy-weighted
    mean of the individual muon positions.

    Parameters
    ----------
    blob : dict
        The blob from the pipeline. Its 'EventInfo' and 'McTracks'
        entries are read.

    Returns
    -------
    track : dict
        The info for mc_info.

    """
    event_info = blob['EventInfo']
    mc_tracks = blob['McTracks']

    # Bundle-wide quantities are read from the first muon: type and
    # direction are assumed to be the same for all muons in a bundle
    # (for the time this is marked as an open question in the original).
    first_muon = mc_tracks[0]
    muon_energies = mc_tracks.energy

    track = {
        'event_id': event_info.event_id[0],
        'particle_type': first_muon.type,
        # total energy of the bundle
        'energy': np.sum(muon_energies),
        # always 1 actually
        'is_cc': first_muon.is_cc,
        # always 0 actually
        'bjorkeny': first_muon.bjorkeny,
        'dir_x': first_muon.dir_x,
        'dir_y': first_muon.dir_y,
        'dir_z': first_muon.dir_z,
        'time_interaction': first_muon.time,
        # NOTE(review): unlike event_id, run_id is not indexed with [0]
        'run_id': event_info.run_id,
        # vertex: mean of the muon positions, weighted with their energy
        'vertex_pos_x': np.average(mc_tracks.pos_x, weights=muon_energies),
        'vertex_pos_y': np.average(mc_tracks.pos_y, weights=muon_energies),
        'vertex_pos_z': np.average(mc_tracks.pos_z, weights=muon_energies),
        # takes position of time_residual_vertex in 'neutrino' case
        'n_muons': mc_tracks.shape[0],
    }
    return track
"""
Custom km3pipe modules for making nn input files.
"""
import km3pipe as kp
import numpy as np
class McInfoMaker(kp.Module):
    """
    Get the desired mc_info from the blob.

    Required parameters
    -------------------
    mc_info_extr : function
        Called with the blob; returns a dict with the mc info.
    store_as : str
        Key under which the resulting table is stored in the blob.

    """
    def configure(self):
        self.mc_info_extr = self.require('mc_info_extr')
        self.store_as = self.require('store_as')

    def process(self, blob):
        mc_info = self.mc_info_extr(blob)
        # every extracted field is stored as a float64 column
        field_dtypes = [(field, np.float64) for field in mc_info]
        blob[self.store_as] = kp.dataclasses.Table(
            mc_info,
            dtype=field_dtypes,
            h5loc='y',
            name='event_info',
        )
        return blob
class TimePreproc(kp.Module):
    """
    Pipeline module that preprocesses the time in the blob.

    Thin wrapper around the function time_preproc, which does the work.

    Optional parameters
    -------------------
    correct_hits : bool
        Passed on to time_preproc. Default: True.
    correct_mchits : bool
        Passed on to time_preproc. Default: True.

    """
    def configure(self):
        self.correct_hits = self.get('correct_hits', default=True)
        self.correct_mchits = self.get('correct_mchits', default=True)

    def process(self, blob):
        return time_preproc(blob, self.correct_hits, self.correct_mchits)
def time_preproc(blob, correct_hits=True, correct_mchits=True):
    """
    Preprocess the time in the blob.

    t0 will be added to the time for real data, but not simulations.
    Time hits and mchits will be shifted by the time of the first
    triggered hit.

    Parameters
    ----------
    blob : dict
        The blob. Must contain "Hits"; a blob that also contains "McHits"
        is treated as a simulation.
    correct_hits : bool
        If true, overwrite the time of the Hits with the shifted time.
    correct_mchits : bool
        If true, shift the time of the McHits as well. Has no effect if
        the blob contains no McHits.

    Returns
    -------
    blob : dict
        The same blob, with the times changed in place.

    """
    hits = blob["Hits"]
    has_mchits = "McHits" in blob

    hits_time = hits.time
    if not has_mchits:
        # add t0 only for real data, not sims
        hits_time = np.add(hits_time, hits.t0)

    t_first_trigger = np.min(hits_time[hits.triggered == 1])

    if correct_hits:
        hits.time = np.subtract(hits_time, t_first_trigger)

    # real data has no McHits: without this check, the default
    # correct_mchits=True would raise a KeyError for every real-data blob
    if correct_mchits and has_mchits:
        mchits = blob["McHits"]
        mchits.time = np.subtract(mchits.time, t_first_trigger)

    return blob
class ImageMaker(kp.Module):
    """
    Make a n-d histogram from the Hits in the blob.

    Required parameters
    -------------------
    bin_edges_list : List
        Pairs of the name of a field of the Hits and the bin edges to
        use for it. Each pair is one dimension of the histogram.
    store_as : str
        Key under which the resulting image is stored in the blob.

    """
    def configure(self):
        self.bin_edges_list = self.require('bin_edges_list')
        self.store_as = self.require('store_as')

    def process(self, blob):
        field_names = [field for field, _ in self.bin_edges_list]
        edges = [bin_edges for _, bin_edges in self.bin_edges_list]
        samples = [blob["Hits"][field] for field in field_names]

        counts, _ = np.histogramdd(samples, bins=edges)

        # e.g. "pos_z_time_event_images"
        title = "".join(field + "_" for field in field_names) + "event_images"

        # prepend an axis of length 1, so that the image of this event has
        # the shape (1, *bins)
        # NOTE(review): bin contents above 255 do not fit into uint8
        image = counts[np.newaxis, ...].astype(np.uint8)

        blob[self.store_as] = kp.dataclasses.NDArray(
            image, h5loc='x', title=title)
        return blob
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
For investigating the ideal binning, based on the info in calibrated
.h5 files.
Specialized classes TimePlotter and ZPlotter are available for plotting
the time/ Z-Coordinate.
"""
import numpy as np
import km3pipe as kp
import matplotlib.pyplot as plt
......@@ -8,10 +17,10 @@ import matplotlib.pyplot as plt
class FieldPlotter:
"""
For investigating the ideal binning, based on the info in calibrated
.h5 files.
Baseclass for investigating the ideal binning, based on the info in
a field of calibrated .h5 files.
Intended for 1d binning, like time or pos_z.
Intended for 1d binning, like for fields "time" or "pos_z".
Workflow:
1. Initialize with files, then run .plot() to extract and store
the data, and show the plot interactively.
......@@ -19,7 +28,7 @@ class FieldPlotter:
3. Run .plot() again to show the plot with the adjusted binning on the
stored data.
4. Repeat step 2 and 3 until happy with binning.
(5.) Save plot via .plot(savepath)
(5.) Save plot via .plot(savepath), or get the bin edges via .get_bin_edges()
The plot will have some bins attached in both directions for
better overview.
......@@ -30,15 +39,12 @@ class FieldPlotter:
The .h5 file(s).
field : str
The field to look stuff up, e.g. "time", "pos_z", ...
only_mc : bool
If true, will look up "McHits" in the blob. Otherwise "Hits".
center_events : int
For centering events with their median.
0 : No centering.
1 : Center with median of triggered hits.
2 : Center with median of all hits.
data : ndarray
The extracted data.
filter_for_du : int, optional
Only get hits from one specific du, specified by the int.
hits : ndarray
The extracted Hits.
mc_hits : ndarray
The extracted McHits, if present.
n_events : int
The number of events in the extracted data.
limits : List
......@@ -47,9 +53,11 @@ class FieldPlotter:
The number of bins.
plot_padding : List
Fraction of bins to append to left and right direction
(only in the plot).
(only in the plot for better overview).
x_label : str
X label of the plot.
y_label : str
Y label of the plot.
hist_kwargs : dict
Kwargs for plt.hist
xlim : List
......@@ -58,13 +66,13 @@ class FieldPlotter:
If True, auto plt.show() the plot.
"""
def __init__(self, files, field, only_mc=False, center_events=0):
def __init__(self, files, field):
self.files = files
self.field = field
self.only_mc = only_mc
self.center_events = center_events
self.filter_for_du = None
self.data = None
self.hits = None
self.mc_hits = None
self.n_events = None
self.limits = None
......@@ -79,27 +87,31 @@ class FieldPlotter:
"density": True,
}
self.xlim = None
self.ylim = None
self.show_plots = True
self.last_ylim = None
def plot(self, save_path=None):
def plot(self, only_mc_hits=False, save_path=None):
"""
Generate and store or load the data, then make the plot.
Parameters
----------
only_mc_hits : bool
If true, plot the McHits instead of the Hits.
save_path : str, optional
Save plot to here.
Returns
-------
fig : pyplot figure
fig, ax : pyplot figure
The plot.
"""
if self.data is None:
self.data = self.get_events_data()
fig = self.make_histogram(save_path)
return fig
if self.hits is None:
self.extract()
fig, ax = self.make_histogram(only_mc_hits, save_path)
return fig, ax
def set_binning(self, limits, n_bins):
"""
......@@ -118,7 +130,7 @@ class FieldPlotter:
def get_binning(self):
"""
Get the stored binning.
Set the desired binning.
Returns
-------
......@@ -130,17 +142,28 @@ class FieldPlotter:
"""
return self.limits, self.n_bins
def get_events_data(self):
def get_bin_edges(self):
"""
Get the bin edges as a ndarray.
"""
Get the content of a field from all events in the file(s).
limits, n_bins = self.get_binning()
if limits is None:
raise ValueError("Can not return bin edges: No binning limits set")
bin_edges = np.linspace(limits[0], limits[1], n_bins + 1)
return bin_edges
Returns:
--------
data : ndarray
The desired data.
def extract(self):
"""
Extract the content of a field from all events in the file(s) and
store it.
"""
data_all_events = None
mc_all_events = None
self.n_events = 0
if not isinstance(self.files, list):
......@@ -148,49 +171,67 @@ class FieldPlotter:
else:
files = self.files
for fname in files:
print("File " + fname)
event_pump = kp.io.hdf5.HDF5Pump(filename=fname)
event_pump = kp.io.hdf5.HDF5Pump(filenames=files)
for i, event_blob in enumerate(event_pump):
self.n_events += 1
if i % 2000 == 0:
print("Blob no. "+str(i))
for i, event_blob in enumerate(event_pump):
self.n_events += 1
data_one_event = self._get_hits(event_blob, get_mc_hits=False)
if i % 2000 == 1:
print("Blob no. "+str(i))
if data_all_events is None:
data_all_events = data_one_event
else:
data_all_events = np.concatenate(
[data_all_events, data_one_event], axis=0)
data_one_event = self._get_hits(event_blob)
if "McHits" in event_blob:
mc_one_event = self._get_hits(event_blob, get_mc_hits=True)
if data_all_events is None:
data_all_events = data_one_event
if mc_all_events is None:
mc_all_events = mc_one_event
else:
data_all_events = np.concatenate(
[data_all_events, data_one_event], axis=0)
mc_all_events = np.concatenate(
[mc_all_events, mc_one_event], axis=0)
event_pump.close()
print("Number of events: " + str(self.n_events))
return data_all_events
def make_histogram(self, save_path=None):
self.hits = data_all_events
self.mc_hits = mc_all_events
def make_histogram(self, only_mc_hits=False, save_path=None):
"""
Plot the hist data. Can also save it if given a save path.
Parameters
----------
only_mc_hits : bool
If true, plot the McHits instead of the Hits.
save_path : str, optional
Save the fig to this path.
Returns
-------
fig : pyplot figure
fig, ax : pyplot figure
The plot.
"""
if self.data is None:
if only_mc_hits:
data = self.mc_hits
else:
data = self.hits
if data is None:
raise ValueError("Can not make histogram, no data extracted yet.")
bin_edges = self._get_bin_edges()
bin_edges = self._get_padded_bin_edges()
fig, ax = plt.subplots()
n, bins, patches = plt.hist(self.data, bins=bin_edges, **self.hist_kwargs)
n, bins, patches = plt.hist(data, bins=bin_edges, **self.hist_kwargs)
print("Size of first bin: " + str(bins[1] - bins[0]))
plt.grid(True, zorder=0, linestyle='dotted')
......@@ -205,6 +246,9 @@ class FieldPlotter:
if self.xlim is not None:
plt.xlim(self.xlim)
if self.ylim is not None:
plt.ylim(self.ylim)
plt.ylabel(self.ylabel)
plt.tight_layout()
......@@ -215,9 +259,9 @@ class FieldPlotter:
if self.show_plots:
plt.show()
return fig
return fig, ax
def _get_bin_edges(self):
def _get_padded_bin_edges(self):
"""
Get the padded bin edges.
......@@ -246,36 +290,33 @@ class FieldPlotter:
return bin_edges
def _get_hits(self, event_blob):
def _get_hits(self, blob, get_mc_hits):
"""
Get desired attribute from a event blob.
Get desired attribute from an event blob.
Parameters
----------
event_blob
The km3pipe event blob.
blob
The blob.
get_mc_hits : bool
If true, will get the "McHits" instead of the "Hits".
Returns
-------
blob_data : ndarray
The desired data.
The data.
"""
if self.only_mc:
if get_mc_hits:
field_name = "McHits"
else:
field_name = "Hits"
blob_data = event_blob[field_name][self.field]
if self.center_events == 1:
triggered = event_blob[field_name].triggered
median_trigger = np.median(blob_data[triggered == 1])
blob_data = np.subtract(blob_data, median_trigger)
blob_data = blob[field_name][self.field]
elif self.center_events == 2:
median = np.median(blob_data)
blob_data = np.subtract(blob_data, median)
if self.filter_for_du is not None:
dus = blob[field_name]["du"]
blob_data = blob_data[dus == self.filter_for_du]
return blob_data
......@@ -292,40 +333,91 @@ class FieldPlotter:
xlabel = None
return xlabel
def __repr__(self):
return "<FieldPlotter: {}>".format(self.files)
class TimePreproc(kp.Module):
    """
    Pipeline module that preprocesses the time in the blob.

    Thin wrapper around the function time_preproc, which does the work.

    Optional parameters
    -------------------
    correct_hits : bool
        Passed on to time_preproc. Default: True.
    correct_mchits : bool
        Passed on to time_preproc. Default: True.

    """
    def configure(self):
        self.correct_hits = self.get('correct_hits', default=True)
        self.correct_mchits = self.get('correct_mchits', default=True)

    def process(self, blob):
        return time_preproc(blob, self.correct_hits, self.correct_mchits)
def time_preproc(blob, correct_hits=True, correct_mchits=True):
    """
    Preprocess the time in the blob.

    t0 will be added to the time for real data, but not simulations.
    Time hits and mchits will be shifted by the time of the first
    triggered hit.

    Parameters
    ----------
    blob : dict
        The blob. Must contain "Hits"; a blob that also contains "McHits"
        is treated as a simulation.
    correct_hits : bool
        If true, overwrite the time of the Hits with the shifted time.
    correct_mchits : bool
        If true, shift the time of the McHits as well. Has no effect if
        the blob contains no McHits.

    Returns
    -------
    blob : dict
        The same blob, with the times changed in place.

    """
    hits = blob["Hits"]
    has_mchits = "McHits" in blob

    hits_time = hits.time
    if not has_mchits:
        # add t0 only for real data, not sims
        hits_time = np.add(hits_time, hits.t0)

    t_first_trigger = np.min(hits_time[hits.triggered == 1])

    if correct_hits:
        hits.time = np.subtract(hits_time, t_first_trigger)

    # real data has no McHits: without this check, the default
    # correct_mchits=True would raise a KeyError for every real-data blob
    if correct_mchits and has_mchits:
        mchits = blob["McHits"]
        mchits.time = np.subtract(mchits.time, t_first_trigger)

    return blob
class TimePlotter(FieldPlotter):
"""
For plotting the time.
"""
def __init__(self, files, only_mc=False):
def __init__(self, files):
field = "time"
FieldPlotter.__init__(self, files, field)
def _get_hits(self, blob, get_mc_hits):
blob = time_preproc(blob)
if only_mc:
center_events = 2
if get_mc_hits:
field_name = "McHits"
else:
center_events = 1
field_name = "Hits"
blob_data = blob[field_name][self.field]
FieldPlotter.__init__(self, files,
field,
only_mc=only_mc,
center_events=center_events)
if self.filter_for_du is not None:
dus = blob[field_name]["du"]
blob_data = blob_data[dus == self.filter_for_du]
return blob_data
class ZPlotter(FieldPlotter):
"""
For plotting the z dim.
"""
def __init__(self, files, only_mc=False):
def __init__(self, files):
field = "pos_z"
center_events = 0
FieldPlotter.__init__(self, files,
field,
only_mc=only_mc,
center_events=center_events)
FieldPlotter.__init__(self, files, field)
self.plotting_bins = 100
def _get_bin_edges(self):
def _get_padded_bin_edges(self):
"""
Get the padded bin edges.
......@@ -348,3 +440,6 @@ class ZPlotter(FieldPlotter):
n_bins + 1)
self.limits = bin_edges
self.n_bins = n_bins
def get_bin_edges(self):
return self.limits
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment