Commit 37b95aa2 authored by Tamas Gal

Merge branch '58-uproot4-integration-2' into 'master'

Resolve "uproot4 integration"

Closes #58

See merge request !47
parents 7e16ce69 f26ee18a
Pipeline #16353 passed with warnings
@@ -17,3 +17,5 @@ exclude_lines =
if self.debug:
if settings.DEBUG
def __repr__
@njit
@nb.njit
@@ -43,169 +43,133 @@ If you have a question about km3io, please proceed as follows:
- If you haven't found an answer to your question in the documentation, open a git issue with your question, showing us an example of what you have tried first and what you would like to do.
- If you have noticed a bug, please report it in a git issue; we appreciate your contribution.
Tutorial
========
**Table of contents:**

* `Introduction <#introduction>`__
* `Overview of online files <#overview-of-online-files>`__
* `Overview of offline files <#overview-of-offline-files>`__
* `Online files reader <#online-files-reader>`__
* `Reading Events <#reading-events>`__
* `Reading SummarySlices <#reading-summaryslices>`__
* `Reading Timeslices <#reading-timeslices>`__
* `Offline files reader <#offline-file-reader>`__
* `reading events data <#reading-events-data>`__
* `reading usr data of events <#reading-usr-data-of-events>`__
* `reading hits data <#reading-hits-data>`__
* `reading tracks data <#reading-tracks-data>`__
* `reading mc hits data <#reading-mc-hits-data>`__
* `reading mc tracks data <#reading-mc-tracks-data>`__
Introduction
------------
Most of the KM3NeT data is stored in ROOT files. These ROOT files are created using the `KM3NeT Dataformat library <https://git.km3net.de/common/km3net-dataformat>`__.
A ROOT file created with
`Jpp <https://git.km3net.de/common/jpp>`__ is an "online" file and all other software usually produces "offline" files.

km3io is a Python package that provides a set of classes: ``OnlineReader``, ``OfflineReader`` and a special class to read gSeaGen files. All of these ROOT files can be read without installing any other software like Jpp, aanet or ROOT.
Data in km3io is returned as an ``awkward.Array``, an advanced NumPy-like container type that stores
contiguous data for high-performance computations.
Such an ``awkward.Array`` supports any level of nested arrays and records, which can have different lengths, in contrast to NumPy where everything has to be rectangular.

The example below shows the array which contains the ``dir_z`` values
of each track of the first 4 events. The type ``4 * var * float64`` means that
it has 4 subarrays with variable lengths of type ``float64``:
.. code-block:: python3
>>> import km3io
>>> from km3net_testdata import data_path
>>> f = km3io.OfflineReader(data_path("offline/numucc.root"))
>>> f[:4].tracks.dir_z
<Array [[0.213, 0.213, ... 0.229, 0.323]] type='4 * var * float64'>
The same concept applies to everything, including ``hits``, ``mc_hits``,
``mc_tracks``, ``t_sec`` etc.
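For readers without km3io or a sample file at hand, the jagged structure itself can be sketched with plain Python lists (the values below are made up for illustration, not taken from a real file):

```python
# A jagged array is an array of subarrays with different lengths.
# Plain Python lists can mimic what an awkward.Array stores; the
# dir_z values below are made up for illustration.
dir_z = [
    [0.213, 0.213, 0.229, 0.323],  # tracks of event 0
    [0.961],                       # event 1 has a single track
    [],                            # event 2 has no reconstructed tracks
    [0.124, 0.778],                # event 3
]

# Rectangular containers would need padding to hold this; jagged
# containers keep the per-event lengths intact.
lengths = [len(tracks) for tracks in dir_z]
print(lengths)  # [4, 1, 0, 2]
```

This is exactly the variability that the ``4 * var * float64`` type annotation expresses.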
Offline files reader
--------------------
In general an offline file has two methods to fetch data: the header and the events. Let's start with the header.
Reading the file header
"""""""""""""""""""""""
To read an offline file, start by opening it with an ``OfflineReader``:
.. code-block:: python3
>>> import km3io
>>> from km3net_testdata import data_path
>>> f = km3io.OfflineReader(data_path("offline/numucc.root"))
Calling the header can be done with:
.. code-block:: python3
>>> f.header
<km3io.offline.Header at 0x7fcd81025990>
and provides lazy access. In offline files the header is unique and can be printed

.. code-block:: python3
>>> print(f.header)
MC Header:
DAQ(livetime=394)
PDF(i1=4, i2=58)
can(zmin=0, zmax=1027, r=888.4)
can_user: can_user(field_0=0.0, field_1=1027.0, field_2=888.4)
coord_origin(x=0, y=0, z=0)
cut_in(Emin=0, Emax=0, cosTmin=0, cosTmax=0)
cut_nu(Emin=100, Emax=100000000.0, cosTmin=-1, cosTmax=1)
cut_primary(Emin=0, Emax=0, cosTmin=0, cosTmax=0)
cut_seamuon(Emin=0, Emax=0, cosTmin=0, cosTmax=0)
decay: decay(field_0='doesnt', field_1='happen')
detector: NOT
drawing: Volume
genhencut(gDir=2000, Emin=0)
genvol(zmin=0, zmax=1027, r=888.4, volume=2649000000.0, numberOfEvents=100000)
kcut: 2
livetime(numberOfSeconds=0, errorOfSeconds=0)
model(interaction=1, muon=2, scattering=0, numberOfEnergyBins=1, field_4=12)
ngen: 100000.0
norma(primaryFlux=0, numberOfPrimaries=0)
nuflux: nuflux(field_0=0, field_1=3, field_2=0, field_3=0.5, field_4=0.0, field_5=1.0, field_6=3.0)
physics(program='GENHEN', version='7.2-220514', date=181116, time=1138)
seed(program='GENHEN', level=3, iseed=305765867, field_3=0, field_4=0)
simul(program='JSirene', version=11012, date='11/17/18', time=7)
sourcemode: diffuse
spectrum(alpha=-1.4)
start_run(run_id=1)
target: isoscalar
usedetfile: false
xlat_user: 0.63297
xparam: OFF
zed_user: zed_user(field_0=0.0, field_1=3450.0)
To read the values in the header one can call them directly:
.. code-block:: python3
>>> f.header.DAQ.livetime
394
>>> f.header.cut_nu.Emin
100
>>> f.header.genvol.numberOfEvents
100000
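The attribute access above can be pictured as nested named tuples. The following sketch is not the actual km3io implementation; it just rebuilds the ``cut_nu`` line of the printed header with the standard library to show the access pattern:

```python
from collections import namedtuple

# Hypothetical reconstruction of one header line shown above:
#   cut_nu(Emin=100, Emax=100000000.0, cosTmin=-1, cosTmax=1)
CutNu = namedtuple("CutNu", ["Emin", "Emax", "cosTmin", "cosTmax"])
cut_nu = CutNu(Emin=100, Emax=100000000.0, cosTmin=-1, cosTmax=1)

# Attribute access then works just like f.header.cut_nu.Emin:
print(cut_nu.Emin)  # 100
```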
Reading events
""""""""""""""
To start reading events, call the ``events`` method on the file.
Overview of Online files
""""""""""""""""""""""""
Online files are written by the DataWriter (part of Jpp) and contain events, timeslices and summary slices.
Overview of offline files
"""""""""""""""""""""""""
Offline files contain data about events, hits and tracks. Based on the aanet version 2.0.0 documentation, the following tables show the definitions, types and units of the branches found in the events, hits and tracks trees. A description of the file header is also displayed.
.. csv-table:: events keys definitions and units
:header: "type", "name", "definition"
:widths: 20, 20, 80
"int", "id", "offline event identifier"
"int", "det_id", "detector identifier from DAQ"
"int", "mc_id", "identifier of the MC event (as found in ascii or antcc file)"
"int", "run_id", "DAQ run identifier"
"int", "mc_run_id", "MC run identifier"
"int", "frame_index", "from the raw data"
"ULong64_t", "trigger_mask", "trigger mask from raw data (i.e. the trigger bits)"
"ULong64_t", "trigger_counter", "trigger counter"
"unsigned int", "overlays", "number of overlaying triggered events"
"TTimeStamp", "t", "UTC time of the start of the timeslice the event came from"
"vec Hit", "hits", "list of hits"
"vec Trk", "trks", "list of reconstructed tracks (can be several because of prefits,showers, etc)"
"vec double", "w", "MC: Weights w[0]=w1 & w[1]=w2 & w[2]]=w3"
"vec double", "w2list", "MC: factors that make up w[1]=w2"
"vec double", "w3list", "MC: atmospheric flux information"
"double", "mc_t", "MC: time of the mc event"
"vec Hit", "mc_hits", "MC: list of MC truth hits"
"vec Trk", "mc_trks", "MC: list of MC truth tracks"
"string", "comment", "user can use this as he/she likes"
"int", "index", "user can use this as he/she likes"
.. csv-table:: hits keys definitions and units
:header: "type", "name", "definition"
:widths: 20, 20, 80
"int", "id", "hit id"
"int", "dom_id", "module identifier from the data (unique in the detector)"
"unsigned int", "channel_id", "PMT channel id {0,1, .., 31} local to module"
"unsigned int", "tdc", "hit tdc (=time in ns)"
"unsigned int", "tot", "tot value as stored in raw data (int for pyroot)"
"int", "trig", "non-zero if the hit is a trigger hit"
"int", "pmt_id", "global PMT identifier as found in evt files"
"double", "t", "hit time (from calibration or MC truth)"
"double", "a", "hit amplitude (in p.e.)"
"vec", "pos", "hit position"
"vec", "dir", "hit direction i.e. direction of the PMT"
"double", "pure_t", "photon time before pmt simultion (MC only)"
"double", "pure_a", "amptitude before pmt simution (MC only)"
"int", "type", "particle type or parametrisation used for hit (mc only)"
"int", "origin", "track id of the track that created this hit"
"unsigned", "pattern_flags", "some number that you can use to flag the hit"
.. csv-table:: tracks keys definitions and units
:header: "type", "name", "definition"
:widths: 20, 20, 80
"int", "id", "track identifier"
"vec", "pos", "position of the track at time t"
"vec", "dir", "track direction"
"double", "t", "track time (when particle is at pos)"
"double", "E", "Energy (either MC truth or reconstructed)"
"double", "len", "length if applicable"
"double", "lik", "likelihood or lambda value (for aafit: lambda)"
"int", "type", "MC: particle type in PDG encoding"
"int", "rec_type", "identifyer for the overall fitting algorithm/chain/strategy"
"vec int", "rec_stages", "list of identifyers of succesfull fitting stages resulting in this track"
"int", "status", "MC status code"
"int", "mother_id", "MC id of the parent particle"
"vec double", "fitinf", "place to store additional fit info for jgandalf see FitParameters.csv"
"vec int", "hit_ids", "list of associated hit-ids (corresponds to Hit::id)"
"vec double", "error_matrix", "(5x5) error covariance matrix (stored as linear vector)"
"string", "comment", "user comment"
.. csv-table:: offline file header definitions
:header: "name", "definition"
:widths: 40, 80
"DAQ", "livetime"
"cut_primary cut_seamuon cut_in cut_nu", "Emin Emax cosTmin cosTmax"
"generator physics simul", "program version date time"
"seed", "program level iseed"
"PM1_type_area", "type area TTS"
"PDF", "i1 i2"
"model", "interaction muon scattering numberOfEnergyBins"
"can", "zmin zmax r"
"genvol", "zmin zmax r volume numberOfEvents"
"merge", "time gain"
"coord_origin", "x y z"
"translate", "x y z"
"genhencut", "gDir Emin"
"k40", "rate time"
"norma", "primaryFlux numberOfPrimaries"
"livetime", "numberOfSeconds errorOfSeconds"
"flux", "type key file_1 file_2"
"spectrum", "alpha"
"fixedcan", "xcenter ycenter zmin zmax radius"
"start_run", "run_id"
>>> f
OfflineReader (10 events)
>>> f.keys()
{'comment', 'det_id', 'flags', 'frame_index', 'hits', 'id', 'index',
'mc_hits', 'mc_id', 'mc_run_id', 'mc_t', 'mc_tracks', 'mc_trks',
'n_hits', 'n_mc_hits', 'n_mc_tracks', 'n_mc_trks', 'n_tracks',
'n_trks', 'overlays', 'run_id', 't_ns', 't_sec', 'tracks',
'trigger_counter', 'trigger_mask', 'trks', 'usr', 'usr_names',
'w', 'w2list', 'w3list'}
Like the online reader, lazy access is used. Using <TAB> completion gives an overview of available data. Alternatively, the method `keys` can be used on events and its data members containing a structure, to see what is available for reading.
Reading the reconstructed values like energy and direction of an event can be done with:
.. code-block:: python3
>>> f.events.tracks.E
<Array [[117, 117, 0, 0, 0, ... 0, 0, 0, 0, 0]] type='10 * var * float64'>
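Since every event carries a variable number of tracks, a common follow-up is to pick the leading (first) track per event. A minimal sketch with plain lists standing in for the jagged ``tracks.E`` array (the energies are made up):

```python
# Made-up energies; the empty sublist is an event without
# reconstructed tracks.
E = [
    [117.0, 117.0, 0.0],
    [],
    [3.1, 2.2],
]

# Leading-track energy per event, skipping trackless events
# (with awkward one would filter and then slice the first entry).
leading = [tracks[0] for tracks in E if tracks]
print(leading)  # [117.0, 3.1]
```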
Online files reader
-------------------
@@ -332,105 +296,3 @@ channel, time and ToT:
Offline files reader
--------------------
In general an offline file has two methods to fetch data: the header and the events. Let's start with the header.
Reading the file header
"""""""""""""""""""""""
To read an offline file start with opening it with an OfflineReader:
.. code-block:: python3
import km3io
f = km3io.OfflineReader("mcv5.0.gsg_elec-CC_1-500GeV.sirene.jte.jchain.jsh.aanet.1.root")
Calling the header can be done with:
.. code-block:: python3
>>> f.header
<km3io.offline.Header at 0x7fcd81025990>
and provides lazy access. In offline files the header is unique and can be printed
.. code-block:: python3
>>> print(f.header)
MC Header:
DAQ(livetime=35.5)
XSecFile: /project/antares/public_student_software/genie/v3.00.02-hedis/Generator/genie_xsec/gSeaGen/G18_02a_00_000/gxspl-seawater.xml
coord_origin(x=457.8, y=574.3, z=0)
cut_nu(Emin=1, Emax=500, cosTmin=-1, cosTmax=1)
drawing: surface
fixedcan(xcenter=457.8, ycenter=574.3, zmin=0, zmax=475.6, radius=308.2)
genvol(zmin=0, zmax=475.6, r=308.2, volume=148000000.0, numberOfEvents=1000000.0)
simul(program='gSeaGen', version='dev', date=200616, time=223726)
simul_1: simul_1(field_0='GENIE', field_1='3.0.2', field_2=200616, field_3=223726)
simul_2: simul_2(field_0='GENIE_REWEIGHT', field_1='1.0.0', field_2=200616, field_3=223726)
simul_3: simul_3(field_0='JSirene', field_1='13.0.0-alpha.5-113-gaa686a6a-D', field_2='06/17/20', field_3=0)
spectrum(alpha=-3)
start_run(run_id=1)
tgen: 31556900.0
An overview of the values in the header is given in the `Overview of offline files <#overview-of-offline-files>`__.
To read the values in the header one can call them directly:
.. code-block:: python3
>>> f.header.DAQ.livetime
35.5
>>> f.header.cut_nu.Emin
1
>>> f.header.genvol.numberOfEvents
1000000.0
Reading events
""""""""""""""
To start reading events call the events method on the file:
.. code-block:: python3
>>> f.events
<OfflineBranch[events]: 355 elements>
Like the online reader, lazy access is used. Using <TAB> completion gives an overview of available data. Alternatively, the method `keys` can be used on events and its data members containing a structure, to see what is available for reading.
.. code-block:: python3
>>> f.events.keys()
dict_keys(['w2list', 'frame_index', 'overlays', 'comment', 'id', 'w', 'run_id', 'mc_t', 'mc_run_id', 'det_id', 'w3list', 'trigger_mask', 'mc_id', 'flags', 'trigger_counter', 'index', 't_sec', 't_ns', 'n_hits', 'n_mc_hits', 'n_tracks', 'n_mc_tracks'])
>>> f.events.tracks.keys()
dict_keys(['mother_id', 'status', 'lik', 'error_matrix', 'dir_z', 'len', 'rec_type', 'id', 't', 'dir_x', 'rec_stages', 'dir_y', 'fitinf', 'pos_z', 'hit_ids', 'comment', 'type', 'any', 'E', 'pos_y', 'usr_names', 'pos_x'])
Reading the reconstructed values like energy and direction of an event can be done with:
.. code-block:: python3
>>> f.events.tracks.E
<ChunkedArray [[3.8892237665736844 0.0 0.0 ... 0.0 0.0 0.0] [2.2293441683824318 5.203533524801224 6.083598278897039 ... 0.0 0.0 0.0] [3.044857858677666 3.787165776302862 4.5667729757360656 ... 0.0 0.0 0.0] ... [2.205652079790387 2.120769181474425 1.813066579943641 ... 0.0 0.0 0.0] [2.1000775068170343 3.939512272391431 3.697537355163539 ... 0.0 0.0 0.0] [4.213600763523154 1.7412855636388889 1.6657605276356036 ... 0.0 0.0 0.0]] at 0x7fcd5acb0950>
>>> f.events.tracks.E[12]
array([ 4.19391543, 15.3079374 , 10.47125863, ..., 0. ,
0. , 0. ])
>>> f.events.tracks.dir_z
<ChunkedArray [[0.7855203887479368 0.7855203887479368 0.7855203887479368 ... -0.5680647731737454 1.0 1.0] [0.9759269228630431 0.2677622006758061 -0.06664626796127045 ... -2.3205103555187022e-08 1.0 1.0] [-0.12332041078454238 0.09537382569575953 0.09345521875272474 ... -0.6631226836266504 -0.6631226836266504 -0.6631226836266504] ... [-0.1396584943602339 -0.08400681020109765 -0.014562067998281832 ... 1.0 1.0 1.0] [0.011997491147399564 -0.08496327394947281 -0.12675279061755318 ... 0.12053665899140412 1.0 1.0] [0.6548114607791208 0.8115427935470209 0.9043563059276946 ... 1.0 1.0 1.0]] at 0x7fcd73746410>
>>> f.events.tracks.dir_z[12]
array([ 2.39745910e-01, 3.45008838e-01, 4.81870447e-01, 4.55139657e-01, ...,
-2.32051036e-08, 1.00000000e+00])
Since reconstruction stages can be done multiple times and events can have multiple reconstructions, the vectors of reconstructed values can have variable length. Other data members like the header are always the same size. The definitions of data members can be found in the `definitions <https://git.km3net.de/km3py/km3io/-/tree/master/km3io/definitions>`__ folder. The definitions contain fit parameters, header information, reconstruction information, generator output and can be expanded to include more.
To use the definitions imagine the following: the user wants to read out the MC value of the Bjorken-Y of event 12 that was generated with gSeaGen. This can be found in the `gSeaGen definitions <https://git.km3net.de/km3py/km3io/-/blob/master/km3io/definitions/w2list_gseagen.py>`__: `"W2LIST_GSEAGEN_BY": 8,`
This value is saved into `w2list`, so if an event is generated with gSeaGen the value can be fetched like:
.. code-block:: python3
>>> f.events.w2list[12][8]
0.393755
Note that w2list can also contain other values if the event is generated with another generator.
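The lookup described above boils down to indexing ``w2list`` with a named constant. A small sketch, with the index taken from the quoted definition and a hypothetical event record:

```python
# W2LIST_GSEAGEN_BY = 8 is quoted above from km3io's gSeaGen
# definitions; the event record itself is made up for illustration.
W2LIST_GSEAGEN_BY = 8
w2list_event_12 = [0.0] * 8 + [0.393755]  # hypothetical w2list of event 12

bjorken_y = w2list_event_12[W2LIST_GSEAGEN_BY]
print(bjorken_y)  # 0.393755
```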
"""
Reading Offline hits
====================
The following example shows how to access hits data in an offline ROOT file, which is
written by aanet software.
Note: the offline file used here has MC offline data and was intentionally reduced
to 10 events.
"""
import km3io as ki
from km3net_testdata import data_path
#####################################################
# To access offline hits/mc_hits data:
mc_hits = ki.OfflineReader(data_path("offline/numucc.root")).events.mc_hits
hits = ki.OfflineReader(data_path("offline/km3net_offline.root")).events.hits
#####################################################
# Note that not all data is loaded in memory, so printing
# hits will only return how many elements (events) were found in
# the hits branch of the file.
print(hits)
#####################################################
# same for mc hits
print(mc_hits)
#####################################################
# Accessing the hits/mc_hits keys
# -------------------------------
# to explore the hits keys:
keys = hits.keys()
print(keys)
#####################################################
# to explore the mc_hits keys:
mc_keys = mc_hits.keys()
print(mc_keys)
#####################################################
# Accessing hits data
# -------------------------
# to access data in dom_id:
dom_ids = hits.dom_id
print(dom_ids)
#####################################################
# to access the channel ids:
channel_ids = hits.channel_id
print(channel_ids)
#####################################################
# That's it! You can access any key of your interest in the hits
# keys in the exact same way.
#####################################################
# Accessing the mc_hits data
# --------------------------
# similarly, you can access mc_hits data in any key of interest by
# following the same procedure as for hits:
mc_pmt_ids = mc_hits.pmt_id
print(mc_pmt_ids)
#####################################################
# to access the mc_hits time:
mc_t = mc_hits.t
print(mc_t)
#####################################################
# item selection in hits data
# ---------------------------
# hits data can be selected as you would select an item from a numpy array.
# for example, to select DOM ids in the hits corresponding to the first event:
print(hits[0].dom_id)
#####################################################
# or:
print(hits.dom_id[0])
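The equivalence of the two selections above can be sketched with plain Python containers (the events and DOM ids below are made up):

```python
# Sketch: "select the event, then the field" vs "select the field,
# then the event" yield the same values for jagged data.
events = [
    {"dom_id": [801, 802]},  # made-up event 0
    {"dom_id": [803]},       # made-up event 1
]
by_event = events[0]["dom_id"]               # like hits[0].dom_id
by_field = [e["dom_id"] for e in events][0]  # like hits.dom_id[0]
print(by_event)  # [801, 802]
```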
#####################################################
# slicing of hits
# ---------------
# to select a slice of hits data:
print(hits[0:3].channel_id)
#####################################################
# or:
print(hits.channel_id[0:3])
#####################################################
# you can apply masks to hits data as you would do with numpy arrays:
mask = hits.channel_id > 10
print(hits.channel_id[mask])
#####################################################
# or:
print(hits.dom_id[mask])
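The masking idiom can be sketched without awkward or NumPy installed (the channel ids are made up):

```python
# Made-up flat channel ids for one event.
channel_id = [3, 12, 25, 7, 31]

# A boolean mask selects entries position-wise, as in the
# awkward/NumPy expression channel_id[channel_id > 10].
mask = [c > 10 for c in channel_id]
selected = [c for c, m in zip(channel_id, mask) if m]
print(selected)  # [12, 25, 31]
```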
@@ -2,8 +2,7 @@
Reading Offline tracks
======================
The following example shows how to access tracks data in an offline ROOT file.
Note: the offline files used here were intentionally reduced to 10 events.
"""
@@ -11,107 +10,100 @@ import km3io as ki
from km3net_testdata import data_path
#####################################################
# To access offline tracks/mc_tracks data:
# We open the file using the OfflineReader:
f = ki.OfflineReader(data_path("offline/numucc.root"))
mc_tracks = ki.OfflineReader(data_path("offline/numucc.root")).events.mc_tracks
tracks = ki.OfflineReader(data_path("offline/km3net_offline.root")).events.tracks
#####################################################
# To access offline tracks/mc_tracks data:
f.tracks
f.mc_tracks
#####################################################
# Note that not all data is loaded in memory, so printing
# tracks will only return how many elements (events) were found in
# the tracks branch of the file.
# Note that no data is loaded in memory at this point, so printing
# tracks will only return how many sub-branches (corresponding to
# events) were found.
print(tracks)
f.tracks
#####################################################
# same for mc hits
print(mc_tracks)
f.mc_tracks
#####################################################
# Accessing the tracks/mc_tracks keys
# -----------------------------------
# to explore the tracks keys:
# to explore the reconstructed tracks fields:
keys = tracks.keys()
print(keys)
f.tracks.fields
#####################################################
# to explore the mc_tracks keys:
# the same for MC tracks
mc_keys = mc_tracks.keys()
print(mc_keys)
f.mc_tracks.fields
#####################################################
# Accessing tracks data
# ---------------------
# to access data in `E` (tracks energy):
# each field will return a nested `awkward.Array` and load everything into
# memory, so be careful if you are working with larger files.
E = tracks.E
print(E)
f.tracks.E
#####################################################
# to access the likelihood:
######################################################
# The z direction of all reconstructed tracks
likelihood = tracks.lik
print(likelihood)
f.tracks.dir_z
#####################################################
# That's it! you can access any key of your interest in the tracks
# keys in the exact same way.
######################################################
# The likelihoods
#####################################################
# Accessing the mc_tracks data
# ----------------------------
# similarly, you can access mc_tracks data in any key of interest by
# following the same procedure as for tracks:
cos_zenith = mc_tracks.dir_z
print(cos_zenith)
f.tracks.lik
#####################################################
# or:
# To select just a single event or a subset of events, use the indices or slices.
# The following will access all tracks and their fields
# of the third event (0 is the first):
dir_y = mc_tracks.dir_y
print(dir_y)
f[2].tracks
######################################################
# The z direction of all tracks in the third event:
#####################################################
# item selection in tracks data
# -----------------------------
# tracks data can be selected as you would select an item from a numpy array.
# for example, to select E (energy) in the tracks corresponding to the first event:
f[2].tracks.dir_z
print(tracks[0].E)
#####################################################
# or:
# while here, we select the first 3 events. Notice that all fields will return
# nested arrays, as we have seen above where all events were selected.
print(tracks.E[0])
f[:3]
######################################################
# All tracks for the first three events
#####################################################
# slicing of tracks
# -----------------
# to select a slice of tracks data:
f[:3].tracks
print(tracks[0:3].E)
######################################################
# The z directions of all tracks of the first three events
f[:3].tracks.dir_z
#####################################################
# or:
# or events from 3 to 5 (again, 0 indexing):
print(tracks.E[0:3])
#####################################################
# you can apply masks to tracks data as you would do with numpy arrays:
f[2:5]
mask = tracks.lik > 100
######################################################
# the tracks of those events
print(tracks.lik[mask])
f[2:5].tracks
#####################################################
# or:
######################################################
# and just the z directions of those
print(tracks.dir_z[mask])
f[2:5].tracks.dir_z
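The event slicing shown here follows standard Python slice semantics, which can be checked without any ROOT file at hand:

```python
# Stand-in for a reader holding 10 events, indexed 0..9.
event_ids = list(range(10))

# f[2:5] selects the events with indices 2, 3 and 4
# (the upper bound is excluded, as in any Python slice).
selected = event_ids[2:5]
print(selected)  # [2, 3, 4]
```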
@@ -18,19 +18,12 @@ r = ki.OfflineReader(data_path("offline/usr-sample.root"))
#####################################################
# Accessing the usr fields:
usr = r.events.usr
print(usr)
print(r.events.usr_names.tolist())
#####################################################
# to access data of a specific key:
print(usr.DeltaPosZ)
#####################################################
# or
print(usr["RecoQuality"])
print(ki.tools.usr(r.events, "DeltaPosZ"))
@@ -5,4 +5,3 @@ version = get_distribution(__name__).version
from .offline import OfflineReader
from .online import OnlineReader
from .gseagen import GSGReader
from . import patches
@@ -3,46 +3,30 @@
# Filename: gseagen.py
# Author: Johannes Schumann <jschumann@km3net.de>
import uproot3
import numpy as np
import warnings
from .rootio import Branch, BranchMapper
from .rootio import EventReader
from .tools import cached_property
MAIN_TREE_NAME = "Events"
class GSGReader:
class GSGReader(EventReader):
"""reader for gSeaGen ROOT files"""
def __init__(self, file_path=None, fobj=None):
"""GSGReader class is a gSeaGen ROOT file wrapper
Parameters
----------
file_path : file path or file-like object
The file handler. It can be a str or any python path-like object
that points to the file.
"""
self._fobj = uproot3.open(file_path)
header_key = "Header"
event_path = "Events"
skip_keys = [header_key]
@cached_property
def header(self):
header_key = "Header"
if header_key in self._fobj:
if self.header_key in self._fobj:
header = {}
for k, v in self._fobj[header_key].items():
for k, v in self._fobj[self.header_key].items():
v = v.array()[0]
if isinstance(v, bytes):
try:
v = v.decode("utf-8")
except UnicodeDecodeError:
pass
header[k.decode("utf-8")] = v
header[k] = v
return header
else:
warnings.warn("Your file header has an unsupported format")
@cached_property
def events(self):
return Branch(self._fobj, BranchMapper(name="Events", key="Events"))
import binascii
from collections import namedtuple
import uproot3
import logging
import warnings
import numba as nb
import uproot
import numpy as np
import awkward as ak
from .definitions import mc_header, fitparameters, reconstruction
from .definitions import mc_header
from .tools import cached_property, to_num, unfold_indices
from .rootio import Branch, BranchMapper
from .rootio import EventReader
MAIN_TREE_NAME = "E"
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
log = logging.getLogger("offline")
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
def _nested_mapper(key):
"""Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)"""
return "_".join(key.split(".")[1:])
EVENTS_MAP = BranchMapper(
name="events",
key="Evt",
extra={"t_sec": "t.fSec", "t_ns": "t.fNanoSec"},
exclude=EXCLUDE_KEYS,
update={
"n_hits": "hits",
"n_mc_hits": "mc_hits",
"n_tracks": "trks",
"n_mc_tracks": "mc_trks",
},
)
SUBBRANCH_MAPS = [
BranchMapper(
name="tracks",
key="trks",
extra={},
exclude=EXCLUDE_KEYS
+ ["trks.usr_data", "trks.usr", "trks.fUniqueID", "trks.fBits"],
attrparser=_nested_mapper,
flat=False,
toawkward=["fitinf", "rec_stages"],
),
BranchMapper(
name="mc_tracks",
key="mc_trks",
exclude=EXCLUDE_KEYS
+ [
"mc_trks.rec_stages",
"mc_trks.fitinf",
"mc_trks.fUniqueID",
"mc_trks.fBits",
],
attrparser=_nested_mapper,
toawkward=["usr", "usr_names"],
flat=False,
),
BranchMapper(
name="hits",
key="hits",
exclude=EXCLUDE_KEYS
+ [
"hits.usr",
"hits.pmt_id",
"hits.origin",
"hits.a",
"hits.pure_a",
"hits.fUniqueID",
"hits.fBits",
],
attrparser=_nested_mapper,
flat=False,
),
BranchMapper(
name="mc_hits",
key="mc_hits",
exclude=EXCLUDE_KEYS
+ [
"mc_hits.usr",
"mc_hits.dom_id",
"mc_hits.channel_id",
"mc_hits.tdc",
"mc_hits.tot",
"mc_hits.trig",
"mc_hits.fUniqueID",
"mc_hits.fBits",
],
attrparser=_nested_mapper,
flat=False,
),
]
class OfflineBranch(Branch):
@cached_property
def usr(self):
return Usr(self._mapper, self._branch, index_chain=self._index_chain)
class Usr:
"""Helper class to access AAObject `usr` stuff (only for events.usr)"""
def __init__(self, mapper, branch, index_chain=None):
self._mapper = mapper
self._name = mapper.name
self._index_chain = [] if index_chain is None else index_chain
self._branch = branch
self._usr_names = []
self._usr_idx_lookup = {}
self._usr_key = "usr" if mapper.flat else mapper.key + ".usr"
self._initialise()
def _initialise(self):
try:
self._branch[self._usr_key]
# This will raise a KeyError in old aanet files
# which have a different structure and key (usr_data)
# We do not support those (yet)
except (KeyError, IndexError):
print(
"The `usr` fields could not be parsed for the '{}' branch.".format(
self._name
)
)
return
self._usr_names = [
n.decode("utf-8")
for n in self._branch[self._usr_key + "_names"].lazyarray()[0]
]
self._usr_idx_lookup = {
name: index for index, name in enumerate(self._usr_names)
}
data = self._branch[self._usr_key].lazyarray()
if self._index_chain:
data = unfold_indices(data, self._index_chain)
self._usr_data = data
for name in self._usr_names:
setattr(self, name, self[name])
def __getitem__(self, item):
if self._index_chain:
return unfold_indices(self._usr_data, self._index_chain)[
:, self._usr_idx_lookup[item]
]
else:
return self._usr_data[:, self._usr_idx_lookup[item]]
def keys(self):
return self._usr_names
def __str__(self):
entries = []
for name in self.keys():
entries.append("{}: {}".format(name, self[name]))
return "\n".join(entries)
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
class OfflineReader:
class OfflineReader(EventReader):
"""reader for offline ROOT files"""
def __init__(self, file_path=None):
"""OfflineReader class is an offline ROOT file wrapper
Parameters
----------
file_path : path-like object
Path to the file of interest. It can be a str or any python
path-like object that points to the file.
"""
self._fobj = uproot3.open(file_path)
self._filename = file_path
self._tree = self._fobj[MAIN_TREE_NAME]
self._uuid = binascii.hexlify(self._fobj._context.uuid).decode("ascii")
@property
def uuid(self):
return self._uuid
def close(self):
self._fobj.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
@cached_property
def events(self):
"""The `E` branch, containing all offline events."""
return OfflineBranch(
self._tree, mapper=EVENTS_MAP, subbranchmaps=SUBBRANCH_MAPS
)
event_path = "E/Evt"
item_name = "OfflineEvent"
skip_keys = ["t", "AAObject"]
aliases = {
"t_sec": "t/t.fSec",
"t_ns": "t/t.fNanoSec",
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
}
nested_branches = {
"hits": {
"id": "hits.id",
"channel_id": "hits.channel_id",
"dom_id": "hits.dom_id",
"t": "hits.t",
"tot": "hits.tot",
"trig": "hits.trig", # non-zero if the hit is a triggered hit
},
"mc_hits": {
"id": "mc_hits.id",
"pmt_id": "mc_hits.pmt_id",
"t": "mc_hits.t", # hit time (MC truth)
"a": "mc_hits.a", # hit amplitude (in p.e.)
"origin": "mc_hits.origin", # track id of the track that created this hit
"pure_t": "mc_hits.pure_t", # photon time before pmt simultion
"pure_a": "mc_hits.pure_a", # amplitude before pmt simution,
"type": "mc_hits.type", # particle type or parametrisation used for hit
},
"trks": {
"id": "trks.id",
"pos_x": "trks.pos.x",
"pos_y": "trks.pos.y",
"pos_z": "trks.pos.z",
"dir_x": "trks.dir.x",
"dir_y": "trks.dir.y",
"dir_z": "trks.dir.z",
"t": "trks.t",
"E": "trks.E",
"len": "trks.len",
"lik": "trks.lik",
"rec_type": "trks.rec_type",
"rec_stages": "trks.rec_stages",
"fitinf": "trks.fitinf",
},
"mc_trks": {
"id": "mc_trks.id",
"pos_x": "mc_trks.pos.x",
"pos_y": "mc_trks.pos.y",
"pos_z": "mc_trks.pos.z",
"dir_x": "mc_trks.dir.x",
"dir_y": "mc_trks.dir.y",
"dir_z": "mc_trks.dir.z",
"E": "mc_trks.E",
"t": "mc_trks.t",
"len": "mc_trks.len",
# "status": "mc_trks.status", # TODO: check this
# "mother_id": "mc_trks.mother_id", # TODO: check this
"pdgid": "mc_trks.type",
"hit_ids": "mc_trks.hit_ids",
"usr": "mc_trks.usr", # TODO: trouble with uproot4
"usr_names": "mc_trks.usr_names", # TODO: trouble with uproot4
},
}
nested_aliases = {
"tracks": "trks",
"mc_tracks": "mc_trks",
}
@cached_property
def header(self):
"""The file header"""
if "Head" in self._fobj:
header = {}
for n, x in self._fobj["Head"]._map_3c_string_2c_string_3e_.items():
header[n.decode("utf-8")] = x.decode("utf-8").strip()
return Header(header)
return Header(self._fobj["Head"].tojson()["map<string,string>"])
else:
warnings.warn("Your file header has an unsupported format")
......
import awkward0 as ak0
import awkward as ak1
# to avoid infinite recursion
old_getitem = ak0.ChunkedArray.__getitem__
def new_getitem(self, item):
"""Monkey patch the getitem in awkward.ChunkedArray to apply
awkward1.Array masks on awkward.ChunkedArray"""
if isinstance(item, (ak1.Array, ak0.ChunkedArray)):
return ak1.Array(self)[item]
else:
return old_getitem(self, item)
ak0.ChunkedArray.__getitem__ = new_getitem
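The save-then-dispatch monkey-patching pattern used above can be sketched on a plain class without any awkward dependency (all names below are illustrative, not part of km3io):

```python
class Container:
    """A stand-in for awkward0.ChunkedArray."""

    def __getitem__(self, item):
        return ("old", item)

# keep a reference to the original method to avoid infinite recursion
_old_getitem = Container.__getitem__

def _new_getitem(self, item):
    # dispatch on the item type, fall back to the original implementation
    if isinstance(item, list):
        return ("patched", item)
    return _old_getitem(self, item)

Container.__getitem__ = _new_getitem

c = Container()
print(c[3])       # → ('old', 3)
print(c[[1, 2]])  # → ('patched', [1, 2])
```

Holding a module-level reference to the original method is what prevents the patched method from calling itself when it falls through to the default behaviour.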
#!/usr/bin/env python3
from collections import namedtuple
import numpy as np
import awkward as ak
import uproot3
import uproot
from .tools import unfold_indices
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
class BranchMapper:
"""
Mapper helper for keys in a ROOT branch.
Parameters
----------
name: str
The name of the mapper helper which is displayed to the user
key: str
The key of the branch in the ROOT tree.
exclude: ``None``, ``list(str)``
Keys to exclude from parsing.
update: ``None``, ``dict(str: str)``
An update map for keys which are to be presented with a different
key to the user e.g. ``{"n_hits": "hits"}`` will rename the ``hits``
key to ``n_hits``.
extra: ``None``, ``dict(str: str)``
An extra mapper for hidden objects, primarily nested ones like
``t.fSec``, which can be revealed and mapped to e.g. ``t_sec``
via ``{"t_sec": "t.fSec"}``.
attrparser: ``None``, ``function(str) -> str``
The function to be used to create attribute names. This is only
needed if unsupported characters are present, like ``.``, which
would prevent setting valid Python attribute names.
toawkward: ``None``, ``list(str)``
List of keys to convert to awkward arrays (recommended for
doubly ragged arrays)
"""
import logging
def __init__(
self,
name,
key,
extra=None,
exclude=None,
update=None,
attrparser=None,
flat=True,
interpretations=None,
toawkward=None,
):
self.name = name
self.key = key
log = logging.getLogger("km3io.rootio")
self.extra = {} if extra is None else extra
self.exclude = [] if exclude is None else exclude
self.update = {} if update is None else update
self.attrparser = (lambda x: x) if attrparser is None else attrparser
self.flat = flat
self.interpretations = {} if interpretations is None else interpretations
self.toawkward = [] if toawkward is None else toawkward
class EventReader:
"""reader for offline ROOT files"""
class Branch:
"""Branch accessor class"""
event_path = None
item_name = "Event"
skip_keys = [] # ignore these subbranches, even if they exist
aliases = {} # top level aliases -> {fromkey: tokey}
nested_branches = {}
nested_aliases = {}
def __init__(
self,
tree,
mapper,
f,
index_chain=None,
subbranchmaps=None,
keymap=None,
awkward_cache=None,
step_size=2000,
keys=None,
aliases=None,
nested_branches=None,
event_ctor=None,
):
self._tree = tree
self._mapper = mapper
self._index_chain = [] if index_chain is None else index_chain
self._keymap = None
self._branch = tree[mapper.key]
self._subbranches = []
self._subbranchmaps = subbranchmaps
# FIXME preliminary cache to improve performance. Hopefully uproot4
# will fix this automatically!
self._awkward_cache = {} if awkward_cache is None else awkward_cache
"""EventReader base class
Parameters
----------
f : str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
Path to the file of interest or uproot4 filedescriptor.
step_size : int, optional
Number of events to read into the cache when iterating.
Choosing higher numbers may improve the speed but also increases
the memory overhead.
index_chain : list, optional
Keeps track of index chaining.
keys : list or set, optional
Branch keys.
aliases : dict, optional
Branch key aliases.
event_ctor : class or namedtuple, optional
Event constructor.
"""
if isinstance(f, str):
self._fobj = uproot.open(f)
self._filepath = f
elif isinstance(f, uproot.reading.ReadOnlyDirectory):
self._fobj = f
self._filepath = f._file.file_path
else:
raise TypeError("Unsupported file descriptor.")
self._step_size = step_size
self._uuid = self._fobj._file.uuid
self._iterator_index = 0
self._keys = keys
self._event_ctor = event_ctor
self._index_chain = [] if index_chain is None else index_chain
if keymap is None:
self._initialise_keys()
else:
self._keymap = keymap
if subbranchmaps is not None:
for mapper in subbranchmaps:
subbranch = self.__class__(
self._tree,
mapper=mapper,
index_chain=self._index_chain,
awkward_cache=self._awkward_cache,
)
self._subbranches.append(subbranch)
for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch)
if aliases is not None:
self.aliases = aliases
if nested_branches is not None:
self.nested_branches = nested_branches
if self._keys is None:
self._initialise_keys()
if self._event_ctor is None:
self._event_ctor = namedtuple(
self.item_name,
set(
list(self.keys())
+ list(self.aliases)
+ list(self.nested_branches)
+ list(self.nested_aliases)
),
)
def _initialise_keys(self):
"""Create the keymap and instance attributes for branch keys"""
# TODO: this could be a cached property
keys = set(k.decode("utf-8") for k in self._branch.keys()) - set(
self._mapper.exclude
skip_keys = set(self.skip_keys)
all_keys = set(self._fobj[self.event_path].keys())
toplevel_keys = set(k.split("/")[0] for k in all_keys)
valid_aliases = {}
for fromkey, tokey in self.aliases.items():
if tokey in all_keys:
valid_aliases[fromkey] = tokey
self.aliases = valid_aliases
keys = (toplevel_keys - skip_keys).union(
list(valid_aliases) + list(self.nested_aliases)
)
self._keymap = {
**{self._mapper.attrparser(k): k for k in keys},
**self._mapper.extra,
}
self._keymap.update(self._mapper.update)
for k in self._mapper.update.values():
del self._keymap[k]
for key in self._keymap.keys():
setattr(self, key, None)
for key in list(self.nested_branches) + list(self.nested_aliases):
keys.add("n_" + key)
# self._grouped_branches = {k for k in toplevel_keys - skip_keys if isinstance(self._fobj[self.event_path][k].interpretation, uproot.AsGrouped)}
valid_nested_branches = {}
for nested_key, aliases in self.nested_branches.items():
if nested_key in toplevel_keys:
valid_nested_branches[nested_key] = {}
subbranch_keys = self._fobj[self.event_path][nested_key].keys()
for fromkey, tokey in aliases.items():
if tokey in subbranch_keys:
valid_nested_branches[nested_key][fromkey] = tokey
self.nested_branches = valid_nested_branches
self._keys = keys
def __dir__(self):
"""Tab completion in IPython"""
return list(self.keys()) + ["header"]
def keys(self):
return self._keymap.keys()
"""Returns all accessible branch keys, without the skipped ones."""
return self._keys
def __getattribute__(self, attr):
if attr.startswith("_"): # let all private and magic methods pass
return object.__getattribute__(self, attr)
if attr in self._keymap.keys(): # intercept branch key lookups
return self.__getkey__(attr)
@property
def events(self):
# TODO: deprecate this, since `self` is already the container type
return iter(self)
def _keyfor(self, key):
"""Return the correct key for a given alias/key"""
return self.nested_aliases.get(key, key)
def __getattr__(self, attr):
attr = self._keyfor(attr)
# if attr in self.keys() or (attr.startswith("n_") and self._keyfor(attr.split("n_")[1]) in self._grouped_branches):
if attr in self.keys():
return self.__getitem__(attr)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
return object.__getattribute__(self, attr)
def __getitem__(self, key):
# indexing
# TODO: maybe just propagate everything to awkward and let it deal
# with the type?
if isinstance(
key, (slice, int, np.int32, np.int64, list, np.ndarray, ak.Array)
):
if isinstance(key, (int, np.int32, np.int64)):
key = int(key)
return self.__class__(
self._fobj,
index_chain=self._index_chain + [key],
step_size=self._step_size,
aliases=self.aliases,
nested_branches=self.nested_branches,
keys=self.keys(),
event_ctor=self._event_ctor,
)
# group counts, for e.g. n_events, n_hits etc.
if isinstance(key, str) and key.startswith("n_"):
key = self._keyfor(key.split("n_")[1])
arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
return unfold_indices(arr, self._index_chain)
key = self._keyfor(key)
branch = self._fobj[self.event_path]
# These are special branches which are nested, like hits/trks/mc_trks
# We are explicitly grabbing just a predefined set of subbranches
# and also alias them to be backwards compatible (and attribute-accessible)
if key in self.nested_branches:
fields = []
# some fields are not always available, like `usr_names`
for to_field, from_field in self.nested_branches[key].items():
if from_field in branch[key].keys():
fields.append(to_field)
log.debug(fields)
return Branch(
branch[key], fields, self.nested_branches[key], self._index_chain
)
else:
return unfold_indices(
branch[self.aliases.get(key, key)].array(), self._index_chain
)
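The index chain collected during slicing is replayed by `unfold_indices`, which in essence applies each stored selection in order (a simplified sketch; the real helper in `km3io.tools` additionally reports which index in the chain failed):

```python
def unfold_indices_sketch(arr, index_chain):
    # apply each recorded selection one after another
    for index in index_chain:
        arr = arr[index]
    return arr

data = [[1, 2, 3], [4, 5], [6]]
# equivalent to data[0:2][1]
print(unfold_indices_sketch(data, [slice(0, 2), 1]))  # → [4, 5]
```

This is what makes lazy slicing composable: every `__getitem__` call only appends to the chain, and the chain is applied to the actual array data on access.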
def __getkey__(self, key):
interpretation = self._mapper.interpretations.get(key)
def __iter__(self, chunkwise=False):
self._events = self._event_generator(chunkwise=chunkwise)
return self
if key == "usr_names":
# TODO this will be fixed soon in uproot,
# see https://github.com/scikit-hep/uproot/issues/465
interpretation = uproot3.asgenobj(
uproot3.SimpleArray(uproot3.STLVector(uproot3.STLString())),
self._branch[self._keymap[key]]._context,
6,
def _get_iterator_limits(self):
"""Determines start and stop, used for event iteration"""
if len(self._index_chain) > 1:
raise NotImplementedError(
"iteration is currently not supported with nested slices"
)
if key == "usr":
# triple jagged array is wrongly parsed in uproot3
interpretation = uproot3.asgenobj(
uproot3.SimpleArray(uproot3.STLVector(uproot3.asdtype(">f8"))),
self._branch[self._keymap[key]]._context,
6,
if self._index_chain:
s = self._index_chain[0]
if not isinstance(s, slice):
raise NotImplementedError("iteration is only supported with slices")
if s.step is None or s.step == 1:
start = s.start
stop = s.stop
else:
raise NotImplementedError(
"iteration is only supported with single steps"
)
else:
start = None
stop = None
return start, stop
def _event_generator(self, chunkwise=False):
start, stop = self._get_iterator_limits()
if chunkwise:
raise NotImplementedError("iterating over chunks is not implemented yet")
events = self._fobj[self.event_path]
group_count_keys = set(
k for k in self.keys() if k.startswith("n_")
) # extra keys to make it easy to count subbranch lengths
log.debug("group_count_keys: %s", group_count_keys)
keys = set(
list(
set(self.keys())
- set(self.nested_branches.keys())
- set(self.nested_aliases)
- group_count_keys
)
out = self._branch[self._keymap[key]].lazyarray(
interpretation=interpretation, basketcache=BASKET_CACHE
+ list(self.aliases.keys())
) # all top-level keys for regular branches
log.debug("keys: %s", keys)
log.debug("aliases: %s", self.aliases)
events_it = events.iterate(
keys,
aliases=self.aliases,
step_size=self._step_size,
entry_start=start,
entry_stop=stop,
)
if self._index_chain is not None and key in self._mapper.toawkward:
cache_key = self._mapper.name + "/" + key
if cache_key not in self._awkward_cache:
if len(out) > 20000: # It will take more than 10 seconds
print("Creating cache for '{}'.".format(cache_key))
self._awkward_cache[cache_key] = ak.from_iter(out)
out = self._awkward_cache[cache_key]
return unfold_indices(out, self._index_chain)
def __getitem__(self, item):
"""Slicing magic"""
if isinstance(item, str):
return self.__getkey__(item)
if isinstance(item, (np.int32, np.int64)):
item = int(item)
# if item.__class__.__name__ == "ChunkedArray":
# item = np.array(item)
nested = []
nested_keys = (
self.nested_branches.keys()
) # dict-key ordering is an implementation detail
log.debug("nested_keys: %s", nested_keys)
for key in nested_keys:
nested.append(
events[key].iterate(
self.nested_branches[key].keys(),
aliases=self.nested_branches[key],
step_size=self._step_size,
entry_start=start,
entry_stop=stop,
)
)
group_counts = {}
for key in group_count_keys:
group_counts[key] = iter(self[key])
log.debug("group_counts: %s", group_counts)
for event_set, *nested_sets in zip(events_it, *nested):
for _event, *nested_items in zip(event_set, *nested_sets):
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(nested_keys, nested_items):
data[k] = i
for tokey, fromkey in self.nested_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
data[key] = next(group_counts[key])
yield self._event_ctor(**data)
return self.__class__(
self._tree,
self._mapper,
index_chain=self._index_chain + [item],
keymap=self._keymap,
subbranchmaps=self._subbranchmaps,
awkward_cache=self._awkward_cache,
)
def __next__(self):
return next(self._events)
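The chunkwise zip pattern used in `_event_generator` (iterate the top-level branch and the nested branches in parallel chunks, then yield one constructed event per entry) can be sketched without uproot; all names here are illustrative:

```python
from collections import namedtuple

Event = namedtuple("Event", ["t", "hits"])

def event_generator(times, hits_per_event, step_size=2):
    # read chunkwise, then yield one Event per entry, mirroring the
    # zip-based pattern in _event_generator
    for start in range(0, len(times), step_size):
        t_chunk = times[start:start + step_size]
        h_chunk = hits_per_event[start:start + step_size]
        for t, hits in zip(t_chunk, h_chunk):
            yield Event(t=t, hits=hits)

events = list(event_generator([10, 20, 30], [[1, 2], [3], []]))
print(events[0])  # → Event(t=10, hits=[1, 2])
```

A larger `step_size` reads more entries per chunk, trading memory for fewer I/O round trips, which is exactly the trade-off documented for the `step_size` parameter.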
def __len__(self):
if not self._index_chain:
return len(self._branch)
return self._fobj[self.event_path].num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
if len(self._index_chain) == 1:
try:
return len(self[:])
except IndexError:
return 1
# TODO: not sure why this is needed at all, it's too late...
return 1
# try:
# return len(self[:])
# except IndexError:
# return 1
return 1
else:
# ignore the usual index magic and access `id` directly
return len(
unfold_indices(
self._branch[self._keymap["id"]].lazyarray(
basketcache=BASKET_CACHE
),
self._index_chain,
self._fobj[self.event_path]["id"].array(), self._index_chain
)
)
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._fobj[self.event_path]["id"].array())
def __repr__(self):
length = len(self)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} events)"
@property
def is_single(self):
"""Returns True when a single branch is selected."""
if len(self._index_chain) > 0:
if isinstance(self._index_chain[0], (int, np.int32, np.int64)):
return True
return False
def __iter__(self):
self._iterator_index = 0
def uuid(self):
return self._uuid
def close(self):
self._fobj.close()
def __enter__(self):
return self
def __next__(self):
idx = self._iterator_index
self._iterator_index += 1
if idx >= len(self):
raise StopIteration
return self[idx]
def __exit__(self, *args):
self.close()
def __str__(self):
length = len(self)
return "{} ({}) with {} element{}".format(
self.__class__.__name__,
self._mapper.name,
length,
"s" if length > 1 else "",
class Branch:
"""Helper class for nested branches likes tracks/hits"""
def __init__(self, branch, fields, aliases, index_chain):
self._branch = branch
self.fields = fields
self._aliases = aliases
self._index_chain = index_chain
def __dir__(self):
"""Tab completion in IPython"""
return list(self.fields)
def __getattr__(self, attr):
if attr not in self._aliases:
raise AttributeError(
f"No field named {attr}. Available fields: {self.fields}"
)
key = self._aliases[attr]
if self._index_chain:
idx0 = self._index_chain[0]
if isinstance(idx0, (int, np.int32, np.int64)):
# optimise single-element and slice lookups
start = idx0
stop = idx0 + 1
arr = ak.flatten(
self._branch[key].array(entry_start=start, entry_stop=stop)
)
return unfold_indices(arr, self._index_chain[1:])
if isinstance(idx0, slice):
if idx0.step is None or idx0.step == 1:
start = idx0.start
stop = idx0.stop
arr = self._branch[key].array(entry_start=start, entry_stop=stop)
return unfold_indices(arr, self._index_chain[1:])
return unfold_indices(self._branch[key].array(), self._index_chain)
def __getitem__(self, key):
return self.__class__(
self._branch, self.fields, self._aliases, self._index_chain + [key]
)
def __len__(self):
if not self._index_chain:
return self._branch.num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
if len(self._index_chain) == 1:
return 1
# try:
# return len(self[:])
# except IndexError:
# return 1
return 1
else:
# ignore the usual index magic and access `id` directly
return len(self.id)
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._branch[self._aliases["id"]].array())
def __repr__(self):
length = len(self)
return "<{}[{}]: {} element{}>".format(
self.__class__.__name__,
self._mapper.name,
length,
"s" if length > 1 else "",
)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} {self._branch.name})"
@property
def ndim(self):
if not self._index_chain:
return 2
elif any(isinstance(i, (int, np.int32, np.int64)) for i in self._index_chain):
return 1
return 2
#!/usr/bin/env python3
from collections import namedtuple
import numba as nb
import numpy as np
import awkward as ak
......@@ -133,34 +134,19 @@ def fitinf(fitparam, tracks):
----------
fitparam : int
the fit parameter key according to fitparameters defined in
KM3NeT-Dataformat.
tracks : km3io.offline.OfflineBranch
the tracks class. both full tracks branch or a slice of the
tracks branch (example tracks[:, 0]) work.
KM3NeT-Dataformat (see km3io.definitions.fitparameters).
tracks : ak.Array or km3io.rootio.Branch
reconstructed tracks with .fitinf attribute
Returns
-------
awkward1.Array
awkward array of the values of the fit parameter requested.
awkward array of the values of the fit parameter requested. Missing
values are set to NaN.
"""
fit = tracks.fitinf
index = fitparam
if tracks.is_single and len(tracks) != 1:
params = fit[count_nested(fit, axis=1) > index]
out = params[:, index]
if tracks.is_single and len(tracks) == 1:
out = fit[index]
else:
if len(tracks[0]) == 1: # case of tracks slice with 1 track per event.
params = fit[count_nested(fit, axis=1) > index]
out = params[:, index]
else:
params = fit[count_nested(fit, axis=2) > index]
out = ak.Array([i[:, index] for i in params])
return out
nonempty = ak.num(fit, axis=-1) > 0
return ak.fill_none(fit.mask[nonempty][..., fitparam], np.nan)
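The missing-value semantics (NaN for tracks without fit information) can be sketched in plain Python, with made-up numbers rather than a real file:

```python
import math

# fitinf lists for three tracks; the last one has no fit information
fitinf_per_track = [[0.05, 12.0], [0.10, 9.5], []]
param = 0  # index of the requested fit parameter

values = [
    fi[param] if len(fi) > param else math.nan
    for fi in fitinf_per_track
]
print(values)  # → [0.05, 0.1, nan]
```

Filling with NaN instead of dropping entries keeps the result aligned with the input tracks, so it can be used directly as a column next to other track attributes.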
def count_nested(arr, axis=0):
......@@ -206,12 +192,9 @@ def get_multiplicity(tracks, rec_stages):
awkward1.Array
tracks multiplicity.
"""
masked_tracks = tracks[mask(tracks, stages=rec_stages)]
masked_tracks = tracks[mask(tracks.rec_stages, sequence=rec_stages)]
if tracks.is_single:
out = count_nested(masked_tracks.rec_stages, axis=0)
else:
out = count_nested(masked_tracks.rec_stages, axis=1)
out = count_nested(masked_tracks.rec_stages, axis=tracks.ndim - 1)
return out
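Conceptually, the multiplicity is the per-event count of tracks whose rec_stages match the requested sequence; a plain-Python sketch with hand-written stage lists:

```python
rec_stages = [
    [[1, 3, 5, 4], [1, 4]],  # event 0: two tracks
    [[1, 3, 5, 4]],          # event 1: one track
]
target = [1, 3, 5, 4]

multiplicity = [
    sum(stages == target for stages in event) for event in rec_stages
]
print(multiplicity)  # → [1, 1]
```

In the real implementation the masking is done by the numba kernels below and the counting by `count_nested` along `tracks.ndim - 1`, so it works for both single events and full files.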
......@@ -221,548 +204,223 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
Parameters
----------
tracks : km3io.offline.OfflineBranch
Array of tracks or jagged array of tracks (multiple events).
tracks : awkward.Array
A list of tracks or doubly nested tracks, usually from
OfflineReader.events.tracks or subarrays of that, containing reconstructed
tracks.
startend: tuple(int, int), optional
The required first and last stage in tracks.rec_stages.
The required first and last stage in tracks.rec_stages.
minmax: tuple(int, int), optional
The range (minimum and maximum) value of rec_stages to take into account.
The range (minimum and maximum) value of rec_stages to take into account.
stages : list or set, optional
- list: the order of the rec_stages is respected.
- set: a subset of required stages; the order is irrelevant.
- list: the order of the rec_stages is respected.
- set: a subset of required stages; the order is irrelevant.
Returns
-------
km3io.offline.OfflineBranch
The best tracks based on the selection.
awkward.Array or namedtuple
Be aware that the dimensions are kept, which means that the final
track attributes are nested when multiple events are passed in.
If a single event (just a list of tracks) is provided, a named tuple
with a single track and flat attributes is created.
Raises
------
ValueError
- too many inputs specified.
- no inputs are specified.
When invalid inputs are specified.
"""
inputs = (stages, startend, minmax)
if all(v is None for v in inputs):
if sum(v is not None for v in inputs) != 1:
raise ValueError("either stages, startend or minmax must be specified.")
if stages is not None and (startend is not None or minmax is not None):
raise ValueError("Please specify either a range or a set of rec stages.")
if stages is not None and startend is None and minmax is None:
selected_tracks = tracks[mask(tracks, stages=stages)]
if startend is not None and minmax is None and stages is None:
selected_tracks = tracks[mask(tracks, startend=startend)]
if minmax is not None and startend is None and stages is None:
selected_tracks = tracks[mask(tracks, minmax=minmax)]
return _max_lik_track(_longest_tracks(selected_tracks))
if stages is not None:
if isinstance(stages, list):
m1 = mask(tracks.rec_stages, sequence=stages)
elif isinstance(stages, set):
m1 = mask(tracks.rec_stages, atleast=list(stages))
else:
raise ValueError("stages must be a list or a set of integers")
def _longest_tracks(tracks):
"""Select the longest reconstructed track"""
if tracks.is_single:
stages_nesting_level = 1
tracks_nesting_level = 0
if startend is not None:
m1 = mask(tracks.rec_stages, startend=startend)
else:
stages_nesting_level = 2
tracks_nesting_level = 1
if minmax is not None:
m1 = mask(tracks.rec_stages, minmax=minmax)
len_stages = count_nested(tracks.rec_stages, axis=stages_nesting_level)
longest = tracks[len_stages == ak.max(len_stages, axis=tracks_nesting_level)]
try:
original_ndim = tracks.ndim
except AttributeError:
original_ndim = 1
axis = 1 if original_ndim == 2 else 0
return longest
tracks = tracks[m1]
rec_stage_lengths = ak.num(tracks.rec_stages, axis=-1)
max_rec_stage_length = ak.max(rec_stage_lengths, axis=axis)
m2 = rec_stage_lengths == max_rec_stage_length
tracks = tracks[m2]
def _max_lik_track(tracks):
"""Select the track with the highest likelihood """
if tracks.is_single:
tracks_nesting_level = 0
else:
tracks_nesting_level = 1
m3 = ak.argmax(tracks.lik, axis=axis, keepdims=True)
return tracks[tracks.lik == ak.max(tracks.lik, axis=tracks_nesting_level)]
out = tracks[m3]
if original_ndim == 1:
if isinstance(out, ak.Record):
return out[:, 0]
return out[0]
return out[:, 0]
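After the rec-stage mask is applied, the remaining selection is "longest rec_stages first, then highest likelihood"; sketched on plain tuples with illustrative values:

```python
# hypothetical (track_id, n_rec_stages, lik) triples for one event
tracks = [(0, 4, 120.0), (1, 4, 135.5), (2, 2, 140.0)]

# keep only the tracks with the maximum number of reconstruction stages
max_stages = max(n for _, n, _ in tracks)
longest = [t for t in tracks if t[1] == max_stages]

# among those, pick the one with the highest likelihood
best = max(longest, key=lambda t: t[2])
print(best)  # → (1, 4, 135.5)
```

Note that track 2 has the highest likelihood overall but is rejected first because it went through fewer reconstruction stages; the length criterion takes precedence.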
def mask(tracks, stages=None, startend=None, minmax=None):
"""Create a mask for tracks.rec_stages.
def mask(arr, sequence=None, startend=None, minmax=None, atleast=None):
"""Return a boolean mask which mask each nested sub-array for a condition.
Parameters
----------
tracks : km3io.offline.OfflineBranch
tracks, or one track, or slice of tracks, or slice of one track.
stages : list or set
reconstruction stages of interest:
- list: the order of rec_stages is respected.
- set: the order of rec_stages is irrelevant.
arr : awkward.Array with ndim>=2
The array to mask.
startend: tuple(int, int), optional
The required first and last stage in tracks.rec_stages.
True for entries where the first and last element are matching the tuple.
minmax: tuple(int, int), optional
The range (minimum and maximum) value of rec_stages to take into account.
Returns
-------
awkward1.Array(bool)
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
Raises
------
ValueError
- too many inputs specified.
- no inputs are specified.
True for entries where each element is within the min-max-range.
sequence : list(int), optional
True for entries which contain the exact same elements (in that specific
order)
atleast : list(int), optional
True for entries where at least the provided elements are present.
An extensive discussion about this implementation can be found at:
https://github.com/scikit-hep/awkward-1.0/issues/580
Many thanks to Jim for the fruitful discussion and the final implementation.
"""
inputs = (stages, startend, minmax)
if all(v is None for v in inputs):
raise ValueError("either stages, startend or minmax must be specified.")
if stages is not None and (startend is not None or minmax is not None):
raise ValueError("Please specify either a range or a set of rec stages.")
if stages is not None and startend is None and minmax is None:
if isinstance(stages, list):
# order of stages is conserved
return _mask_explicit_rec_stages(tracks, stages)
if isinstance(stages, set):
# order of stages is no longer conserved
return _mask_rec_stages_in_set(tracks, stages)
if startend is not None and minmax is None and stages is None:
return _mask_rec_stages_between_start_end(tracks, *startend)
if minmax is not None and startend is None and stages is None:
return _mask_rec_stages_in_range_min_max(tracks, *minmax)
def _mask_rec_stages_between_start_end(tracks, start, end):
"""Mask tracks.rec_stages that start exactly with start and end exactly
with end. ie [start, a, b ...,z , end]"""
builder = ak.ArrayBuilder()
if tracks.is_single:
_find_between_single(tracks.rec_stages, start, end, builder)
return (builder.snapshot() == 1)[0]
else:
_find_between(tracks.rec_stages, start, end, builder)
return builder.snapshot() == 1
@nb.jit(nopython=True)
def _find_between(rec_stages, start, end, builder):
"""Find tracks.rec_stages where rec_stages[0] == start and rec_stages[-1] == end."""
for s in rec_stages:
builder.begin_list()
for i in s:
num_stages = len(i)
if num_stages != 0:
if (i[0] == start) and (i[-1] == end):
builder.append(1)
else:
builder.append(0)
inputs = (sequence, startend, minmax, atleast)
if sum(v is not None for v in inputs) != 1:
raise ValueError(
"either sequence, startend, minmax or atleast must be specified."
)
def recurse(layout):
if layout.purelist_depth == 2:
if startend is not None:
np_array = _mask_startend(ak.Array(layout), *startend)
elif minmax is not None:
np_array = _mask_minmax(ak.Array(layout), *minmax)
elif sequence is not None:
np_array = _mask_sequence(ak.Array(layout), np.array(sequence))
elif atleast is not None:
np_array = _mask_atleast(ak.Array(layout), np.array(atleast))
return ak.layout.NumpyArray(np_array)
elif isinstance(
layout,
(
ak.layout.ListArray32,
ak.layout.ListArrayU32,
ak.layout.ListArray64,
),
):
if len(layout.stops) == 0:
content = recurse(layout.content)
else:
builder.append(0)
builder.end_list()
content = recurse(layout.content[: np.max(layout.stops)])
return type(layout)(layout.starts, layout.stops, content)
elif isinstance(
layout,
(
ak.layout.ListOffsetArray32,
ak.layout.ListOffsetArrayU32,
ak.layout.ListOffsetArray64,
),
):
content = recurse(layout.content[: layout.offsets[-1]])
return type(layout)(layout.offsets, content)
elif isinstance(layout, ak.layout.RegularArray):
content = recurse(layout.content)
return ak.layout.RegularArray(content, layout.size)
@nb.jit(nopython=True)
def _find_between_single(rec_stages, start, end, builder):
"""Find tracks.rec_stages where rec_stages[0] == start and
rec_stages[-1] == end in a single track."""
builder.begin_list()
for s in rec_stages:
num_stages = len(s)
if num_stages != 0:
if (s[0] == start) and (s[-1] == end):
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
raise NotImplementedError(repr(arr))
layout = ak.to_layout(arr, allow_record=True, allow_other=False)
return ak.Array(recurse(layout))
def _mask_explicit_rec_stages(tracks, stages):
"""Mask explicit rec_stages .
Parameters
----------
tracks : km3io.offline.OfflineBranch
tracks or one track, or slice of tracks.
stages : list
reconstruction stages of interest. The order of stages is conserved.
Returns
-------
awkward1.Array
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
"""
builder = ak.ArrayBuilder()
if tracks.is_single:
_find_single(tracks.rec_stages, ak.Array(stages), builder)
return (builder.snapshot() == 1)[0]
else:
_find(tracks.rec_stages, ak.Array(stages), builder)
return builder.snapshot() == 1
@nb.njit
def _mask_startend(arr, start, end):
out = np.empty(len(arr), np.bool_)
for i, subarr in enumerate(arr):
out[i] = len(subarr) > 0 and subarr[0] == start and subarr[-1] == end
return out
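Without the `@nb.njit` compilation the kernel above is plain Python, so its behaviour is easy to check on hand-written stage lists (a pure-Python twin, not the compiled function itself):

```python
import numpy as np

def mask_startend(arr, start, end):
    # pure-Python twin of the numba kernel: True where a sub-array
    # begins with `start` and ends with `end`
    out = np.empty(len(arr), np.bool_)
    for i, subarr in enumerate(arr):
        out[i] = len(subarr) > 0 and subarr[0] == start and subarr[-1] == end
    return out

stages = [[1, 3, 5, 4], [1, 4], [], [5, 4]]
print(mask_startend(stages, start=1, end=4))  # → [ True  True False False]
```

The `len(subarr) > 0` short-circuit is what keeps empty stage lists from raising an IndexError; they simply yield False.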
@nb.jit(nopython=True)
def _find(rec_stages, stages, builder):
"""construct an awkward1 array with the same structure as tracks.rec_stages.
When stages are found, the Array is filled with value 1, otherwise it is filled
with value 0.
Parameters
----------
rec_stages : awkward1.Array
tracks.rec_stages from multiple events.
stages : awkward1.Array
reconstruction stages of interest.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
for s in rec_stages:
builder.begin_list()
for i in s:
num_stages = len(i)
if num_stages == len(stages):
found = 0
for j in range(num_stages):
if i[j] == stages[j]:
found += 1
if found == num_stages:
builder.append(1)
else:
builder.append(0)
@nb.njit
def _mask_minmax(arr, min, max):
out = np.empty(len(arr), np.bool_)
for i, subarr in enumerate(arr):
if len(subarr) == 0:
out[i] = False
else:
for el in subarr:
if el < min or el > max:
out[i] = False
break
else:
builder.append(0)
builder.end_list()
out[i] = True
return out
@nb.jit(nopython=True)
def _find_single(rec_stages, stages, builder):
"""Construct an awkward1 array with the same structure as tracks.rec_stages.
@nb.njit
def _mask_sequence(arr, sequence):
out = np.empty(len(arr), np.bool_)
n = len(sequence)
for i, subarr in enumerate(arr):
if len(subarr) != n:
out[i] = False
else:
for j in range(n):
if subarr[j] != sequence[j]:
out[i] = False
break
else:
out[i] = True
return out
When stages are found, the Array is filled with value 1, otherwise it is filled
with value 0.
Parameters
----------
rec_stages : awkward1.Array
tracks.rec_stages from a SINGLE event.
stages : awkward1.Array
reconstruction stages of interest.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
builder.begin_list()
for s in rec_stages:
num_stages = len(s)
if num_stages == len(stages):
found = 0
for j in range(num_stages):
if s[j] == stages[j]:
found += 1
if found == num_stages:
builder.append(1)
else:
builder.append(0)
@nb.njit
def _mask_atleast(arr, atleast):
out = np.empty(len(arr), np.bool_)
for i, subarr in enumerate(arr):
for req_el in atleast:
if req_el not in subarr:
out[i] = False
break
else:
builder.append(0)
builder.end_list()
out[i] = True
return out
def best_jmuon(tracks):
    """Select the best JMUON track.

    Parameters
    ----------
    tracks : km3io.offline.OfflineBranch
        tracks, or one track, or slice of tracks, or slices of tracks.

    Returns
    -------
    km3io.offline.OfflineBranch
        the longest + highest likelihood track reconstructed with JMUON.
    """
    return best_track(tracks, minmax=(krec.JMUONBEGIN, krec.JMUONEND))
def best_jshower(tracks):
    """Select the best JSHOWER track.

    Parameters
    ----------
    tracks : km3io.offline.OfflineBranch
        tracks, or one track, or slice of tracks, or slices of tracks.

    Returns
    -------
    km3io.offline.OfflineBranch
        the longest + highest likelihood track reconstructed with JSHOWER.
    """
    return best_track(tracks, minmax=(krec.JSHOWERBEGIN, krec.JSHOWEREND))
def best_aashower(tracks):
    """Select the best AASHOWER track.

    Parameters
    ----------
    tracks : km3io.offline.OfflineBranch
        tracks, or one track, or slice of tracks, or slices of tracks.

    Returns
    -------
    km3io.offline.OfflineBranch
        the longest + highest likelihood track reconstructed with AASHOWER.
    """
    return best_track(tracks, minmax=(krec.AASHOWERBEGIN, krec.AASHOWEREND))
def best_dusjshower(tracks):
    """Select the best DUSJSHOWER track.

    Parameters
    ----------
    tracks : km3io.offline.OfflineBranch
        tracks, or one track, or slice of tracks, or slices of tracks.

    Returns
    -------
    km3io.offline.OfflineBranch
        the longest + highest likelihood track reconstructed with DUSJSHOWER.
    """
    return best_track(tracks, minmax=(krec.DUSJSHOWERBEGIN, krec.DUSJSHOWEREND))
def _mask_rec_stages_in_range_min_max(tracks, min_stage=None, max_stage=None):
"""Mask tracks where rec_stages are withing the range(min, max).
Parameters
----------
tracks : km3io.offline.OfflineBranch
tracks, or one track, or slice of tracks, or slices of tracks.
min_stage : int
minimum value of rec_stages.
max_stage : int
maximum value of rec_stages.
Returns
-------
awkward1.Array
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
"""
if (min_stage is not None) and (max_stage is not None):
builder = ak.ArrayBuilder()
if tracks.is_single:
_find_in_range_single(tracks.rec_stages, min_stage, max_stage, builder)
return (builder.snapshot() == 1)[0]
else:
_find_in_range(tracks.rec_stages, min_stage, max_stage, builder)
return builder.snapshot() == 1
else:
raise ValueError("please provide min_stage and max_stage.")
@nb.jit(nopython=True)
def _find_in_range(rec_stages, min_stage, max_stage, builder):
"""Construct an awkward1 array with the same structure as tracks.rec_stages.
When stages are within the range(min, max), the Array is filled with
value 1, otherwise it is filled with value 0.
Parameters
----------
rec_stages : awkward1.Array
        tracks.rec_stages of MULTIPLE events.
min_stage: int
minimum value of rec_stages.
max_stage: int
        maximum value of rec_stages.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
for s in rec_stages:
builder.begin_list()
for i in s:
num_stages = len(i)
if num_stages != 0:
found = 0
for j in i:
if min_stage <= j <= max_stage:
found += 1
if found == num_stages:
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
@nb.jit(nopython=True)
def _find_in_range_single(rec_stages, min_stage, max_stage, builder):
"""Construct an awkward1 array with the same structure as tracks.rec_stages.
When stages are within the range(min, max), the Array is filled with
value 1, otherwise it is filled with value 0.
Parameters
----------
rec_stages : awkward1.Array
tracks.rec_stages of a SINGLE event.
min_stage: int
minimum value of rec_stages.
max_stage: int
        maximum value of rec_stages.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
builder.begin_list()
for s in rec_stages:
num_stages = len(s)
if num_stages != 0:
found = 0
for i in s:
if min_stage <= i <= max_stage:
found += 1
if found == num_stages:
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
def _mask_rec_stages_in_set(tracks, stages):
"""Mask tracks where rec_stages are withing the range(min, max).
Parameters
----------
tracks : km3io.offline.OfflineBranch
tracks, or one track, or slice of tracks, or slices of tracks.
stages : set
set of stages to look for in tracks.rec_stages.
Returns
-------
awkward1.Array
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
"""
if isinstance(stages, set):
builder = ak.ArrayBuilder()
if tracks.is_single:
_find_in_set_single(tracks.rec_stages, stages, builder)
return (builder.snapshot() == 1)[0]
else:
_find_in_set(tracks.rec_stages, stages, builder)
return builder.snapshot() == 1
else:
raise ValueError("stages must be a set")
@nb.jit(nopython=True)
def _find_in_set(rec_stages, stages, builder):
"""Construct an awkward1 array with the same structure as tracks.rec_stages.
When all stages are found in rec_stages, the Array is filled with
value 1, otherwise it is filled with value 0.
Parameters
----------
rec_stages : awkward1.Array
        tracks.rec_stages of MULTIPLE events.
stages : set
set of stages.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
n = len(stages)
for s in rec_stages:
builder.begin_list()
for i in s:
num_stages = len(i)
if num_stages != 0:
found = 0
for j in i:
if j in stages:
found += 1
if found == n:
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
@nb.jit(nopython=True)
def _find_in_set_single(rec_stages, stages, builder):
"""Construct an awkward1 array with the same structure as tracks.rec_stages.
When all stages are found in rec_stages, the Array is filled with
value 1, otherwise it is filled with value 0.
Parameters
----------
rec_stages : awkward1.Array
tracks.rec_stages of a SINGLE event.
stages : set
set of stages.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
n = len(stages)
builder.begin_list()
for s in rec_stages:
num_stages = len(s)
if num_stages != 0:
found = 0
for j in s:
if j in stages:
found += 1
if found == n:
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
"""Select the best DISJSHOWER track."""
return best_track(tracks, minmax=(krec.DUSJSHOWERBEGIN, krec.DUSJSHOWEREND))
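All four `best_*` helpers delegate to `best_track`, which first restricts the tracks by reconstruction stage and then picks the longest track with the highest likelihood. The core of that two-step selection, sketched in plain Python with made-up numbers:

```python
# Hypothetical tracks: likelihood and reconstructed length per track
tracks = [
    {"lik": 294.6, "len": 120.0},
    {"lik": 310.2, "len": 95.0},   # higher likelihood, but not among the longest
    {"lik": 280.1, "len": 120.0},
]

# Step 1: keep only the longest tracks
longest = max(t["len"] for t in tracks)
candidates = [t for t in tracks if t["len"] == longest]

# Step 2: among those, take the highest likelihood
best = max(candidates, key=lambda t: t["lik"])
print(best["lik"])  # 294.6
```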
def is_cc(fobj):
......@@ -788,32 +446,49 @@ def is_cc(fobj):
    w2list = fobj.events.w2list
    len_w2lists = ak.num(w2list, axis=1)

    # According to: https://wiki.km3net.de/index.php/Simulations/The_gSeaGen_code#Physics_event_entries
    # the interaction types are defined as follows:
    #
    # INTER   Interaction type
    # 1       EM
    # 2       Weak[CC]
    # 3       Weak[NC]
    # 4       Weak[CC+NC+interference]
    # 5       NucleonDecay

    if all(len_w2lists <= 7):  # old nu files have a w2list of length 7.
        # Check the `cc` value in the usr data of the first mc_tracks,
        # which are the primary neutrinos and carry the event property.
        # This has been changed in 2020 to be a property in the w2list.
        # See https://git.km3net.de/common/km3net-dataformat/-/issues/23
        return usr(fobj.events.mc_tracks[:, 0], "cc") == 2
    else:
        # TODO: to be tested with newly generated files with the updated
        # w2list definition.
        if "gseagen" in program.lower():
            return w2list[:, kw2gen.W2LIST_GSEAGEN_CC] == 2
        if "genhen" in program.lower():
            return w2list[:, kw2gen.W2LIST_GENHEN_CC] == 2
        raise NotImplementedError(
            f"don't know how to determine the CC-ness of {program} files."
        )
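The interaction codes quoted in the comments reduce the CC check to a single comparison with 2. A minimal sketch with made-up per-event codes:

```python
# gSeaGen interaction-type codes (from the KM3NeT wiki table quoted above)
INTERACTION_TYPES = {
    1: "EM",
    2: "Weak[CC]",
    3: "Weak[NC]",
    4: "Weak[CC+NC+interference]",
    5: "NucleonDecay",
}

w2list_cc_values = [2, 3, 2, 4]  # hypothetical per-event interaction codes
is_cc_flags = [code == 2 for code in w2list_cc_values]
print(is_cc_flags)  # [True, False, True, False]
```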
def usr(objects, field):
"""Return the usr-data for a given field.
Parameters
----------
objects : awkward.Array
Events, tracks, hits or whatever objects which have usr and usr_names
fields (e.g. OfflineReader().events).
"""
if len(unique(ak.num(objects.usr_names))) > 1:
# let's do it the hard way
return ak.flatten(objects.usr[objects.usr_names == field])
available_fields = objects.usr_names[0].tolist()
idx = available_fields.index(field)
return objects.usr[:, idx]
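When every object carries identical `usr_names`, the fast path above reduces to one index lookup into each `usr` row. Schematically, with made-up data shaped like the usr-sample test values:

```python
usr_names = ["CoC", "DeltaPosZ"]  # assumed identical for every event
usr_data = [
    [118.63, 37.52],   # event 0
    [44.34, -10.28],   # event 1
    [99.94, 13.68],    # event 2
]

idx = usr_names.index("CoC")       # position of the requested field
coc = [row[idx] for row in usr_data]
print(coc)  # [118.63, 44.34, 99.94]
```

Only when the name lists differ per object does the function fall back to the slower flatten-and-match path.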
......@@ -3,4 +3,5 @@ numba>=0.50
awkward>=1.0.0rc2
awkward0
uproot3>=3.11.1
uproot>=4.0.0rc5
setuptools_scm
import os
import re
import unittest
import inspect
import awkward as ak
from km3net_testdata import data_path
......@@ -8,18 +10,22 @@ from km3io.gseagen import GSGReader
GSG_READER = GSGReader(data_path("gseagen/gseagen.root"))
AWKWARD_STR_CLASSES = [
s[1] for s in inspect.getmembers(ak.behaviors.string, inspect.isclass)
]
class TestGSGHeader(unittest.TestCase):
def setUp(self):
self.header = GSG_READER.header
def test_str_byte_type(self):
assert isinstance(self.header["gSeaGenVer"], str)
assert isinstance(self.header["GenieVer"], str)
assert isinstance(self.header["gSeaGenVer"], str)
assert isinstance(self.header["InpXSecFile"], str)
assert isinstance(self.header["Flux1"], str)
assert isinstance(self.header["Flux2"], str)
assert type(self.header["gSeaGenVer"]) in AWKWARD_STR_CLASSES
assert type(self.header["GenieVer"]) in AWKWARD_STR_CLASSES
assert type(self.header["gSeaGenVer"]) in AWKWARD_STR_CLASSES
assert type(self.header["InpXSecFile"]) in AWKWARD_STR_CLASSES
assert type(self.header["Flux1"]) in AWKWARD_STR_CLASSES
assert type(self.header["Flux2"]) in AWKWARD_STR_CLASSES
def test_values(self):
assert self.header["RunNu"] == 1
......@@ -55,11 +61,6 @@ class TestGSGHeader(unittest.TestCase):
assert self.header["NNu"] == 2
self.assertListEqual(self.header["NuList"].tolist(), [-14, 14])
def test_unsupported_header(self):
f = GSGReader(data_path("online/km3net_online.root"))
with self.assertWarns(UserWarning):
f.header
class TestGSGEvents(unittest.TestCase):
def setUp(self):
......
import unittest
import numpy as np
from pathlib import Path
import uuid
import awkward as ak
from km3net_testdata import data_path
from km3io import OfflineReader
from km3io.offline import Header
from km3io.tools import usr
OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root"))
OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root"))
......@@ -32,7 +35,7 @@ class TestOfflineReader(unittest.TestCase):
assert self.n_events == len(self.r.events)
def test_uuid(self):
assert self.r.uuid == "0001b192d888fcc711e9b4306cf09e86beef"
assert str(self.r.uuid) == "b192d888-fcc7-11e9-b430-6cf09e86beef"
class TestHeader(unittest.TestCase):
......@@ -147,23 +150,21 @@ class TestOfflineEvents(unittest.TestCase):
def test_len(self):
assert self.n_events == len(self.events)
def test_attributes_available(self):
for key in self.events._keymap.keys():
getattr(self.events, key)
def test_attributes(self):
assert self.n_events == len(self.events.det_id)
self.assertListEqual(self.det_id, list(self.events.det_id))
self.assertListEqual(self.n_hits, list(self.events.n_hits))
self.assertListEqual(self.n_tracks, list(self.events.n_tracks))
self.assertListEqual(self.t_sec, list(self.events.t_sec))
self.assertListEqual(self.t_ns, list(self.events.t_ns))
def test_keys(self):
assert np.allclose(self.n_hits, self.events["n_hits"])
assert np.allclose(self.n_tracks, self.events["n_tracks"])
assert np.allclose(self.t_sec, self.events["t_sec"])
assert np.allclose(self.t_ns, self.events["t_ns"])
assert np.allclose(self.n_hits, self.events["n_hits"].tolist())
assert np.allclose(self.n_tracks, self.events["n_tracks"].tolist())
assert np.allclose(self.t_sec, self.events["t_sec"].tolist())
assert np.allclose(self.t_ns, self.events["t_ns"].tolist())
def test_slicing(self):
s = slice(2, 8, 2)
......@@ -176,20 +177,28 @@ class TestOfflineEvents(unittest.TestCase):
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
assert np.allclose(
self.events[s].n_hits.tolist(), self.events.n_hits[s].tolist()
)
def test_index_consistency(self):
for i in [0, 2, 5]:
assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
def test_index_chaining(self):
assert np.allclose(
self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist()
)
assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0])
def test_index_chaining_on_nested_branches_aka_records(self):
assert np.allclose(
self.events[3:5].hits[1].dom_id[4],
self.events.hits[3:5][1].dom_id[4],
)
assert np.allclose(
self.events.hits[3:5][1].dom_id[4],
self.events[3:5][1].hits.dom_id[4],
)
def test_fancy_indexing(self):
......@@ -207,8 +216,24 @@ class TestOfflineEvents(unittest.TestCase):
assert 10 == i
def test_iteration_2(self):
n_hits = [len(e.hits.id) for e in self.events]
assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist())
def test_iteration_over_slices(self):
ids = [e.id for e in self.events[2:5]]
self.assertListEqual([3, 4, 5], ids)
def test_iteration_over_slices_raises_when_stepsize_not_supported(self):
with self.assertRaises(NotImplementedError):
[e.id for e in self.events[2:8:2]]
def test_iteration_over_slices_raises_when_single_item(self):
with self.assertRaises(NotImplementedError):
[e.id for e in self.events[0]]
def test_iteration_over_slices_raises_when_multiple_slices(self):
with self.assertRaises(NotImplementedError):
[e.id for e in self.events[2:8][2:4]]
def test_str(self):
assert str(self.n_events) in str(self.events)
......@@ -274,16 +299,14 @@ class TestOfflineHits(unittest.TestCase):
],
}
def test_fields_work_as_keys_and_attributes(self):
for key in self.hits.fields:
getattr(self.hits, key)
self.hits[key]
def test_channel_ids(self):
self.assertTrue(all(c >= 0 for c in ak.min(self.hits.channel_id, axis=1)))
self.assertTrue(all(c < 31 for c in ak.max(self.hits.channel_id, axis=1)))
def test_repr(self):
assert str(self.n_hits) in repr(self.hits)
......@@ -292,7 +315,7 @@ class TestOfflineHits(unittest.TestCase):
for idx, dom_id in self.dom_id.items():
self.assertListEqual(dom_id, list(self.hits.dom_id[idx][: len(dom_id)]))
for idx, t in self.t.items():
assert np.allclose(t, self.hits.t[idx][: len(t)].tolist())
def test_slicing(self):
s = slice(2, 8, 2)
......@@ -306,28 +329,39 @@ class TestOfflineHits(unittest.TestCase):
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
for idx in range(3):
assert np.allclose(
self.hits.dom_id[idx][s].tolist(), self.hits[idx].dom_id[s].tolist()
)
assert np.allclose(
OFFLINE_FILE.events[idx].hits.dom_id[s].tolist(),
self.hits.dom_id[idx][s].tolist(),
)
def test_index_consistency(self):
for idx, dom_ids in self.dom_id.items():
assert np.allclose(
self.hits[idx].dom_id[: self.n_hits].tolist(), dom_ids[: self.n_hits]
)
assert np.allclose(
OFFLINE_FILE.events[idx].hits.dom_id[: self.n_hits].tolist(),
dom_ids[: self.n_hits],
)
for idx, ts in self.t.items():
assert np.allclose(
self.hits[idx].t[: self.n_hits].tolist(), ts[: self.n_hits]
)
assert np.allclose(
OFFLINE_FILE.events[idx].hits.t[: self.n_hits].tolist(),
ts[: self.n_hits],
)
def test_fields(self):
assert "dom_id" in self.hits.fields
assert "channel_id" in self.hits.fields
assert "t" in self.hits.fields
assert "tot" in self.hits.fields
assert "trig" in self.hits.fields
assert "id" in self.hits.fields
class TestOfflineTracks(unittest.TestCase):
......@@ -337,16 +371,24 @@ class TestOfflineTracks(unittest.TestCase):
self.tracks_numucc = OFFLINE_NUMUCC
self.n_events = 10
def test_fields(self):
for field in [
"id",
"pos_x",
"pos_y",
"pos_z",
"dir_x",
"dir_y",
"dir_z",
"t",
"E",
"len",
"lik",
"rec_type",
"rec_stages",
"fitinf",
]:
getattr(self.tracks, field)
def test_item_selection(self):
self.assertListEqual(
......@@ -354,23 +396,23 @@ class TestOfflineTracks(unittest.TestCase):
)
def test_repr(self):
assert " 10 " in repr(self.tracks)
assert "10" in repr(self.tracks)
def test_slicing(self):
tracks = self.tracks
self.assertEqual(10, len(tracks)) # 10 events
self.assertEqual(56, len(tracks[0].id)) # number of tracks in first event
track_selection = tracks[2:7]
assert 5 == len(track_selection)
track_selection_2 = tracks[1:3]
assert 2 == len(track_selection_2)
for _slice in [
slice(0, 0),
slice(0, 1),
slice(0, 2),
slice(1, 5),
slice(3, -2),
]:
print(f"checking {_slice}")
self.assertListEqual(
list(tracks.E[:, 0][_slice]), list(tracks[_slice].E[:, 0])
)
......@@ -382,15 +424,7 @@ class TestOfflineTracks(unittest.TestCase):
)
self.assertAlmostEqual(
self.f.events.tracks.fitinf[3:5][1][9][2],
self.f.events[3:5][1].tracks.fitinf[9][2],
)
......@@ -398,18 +432,28 @@ class TestBranchIndexingMagic(unittest.TestCase):
def setUp(self):
self.events = OFFLINE_FILE.events
def test_slicing_magic(self):
self.assertEqual(318, self.events[2:4].n_hits[0])
assert np.allclose(
self.events[3].tracks.dir_z[10], self.events.tracks.dir_z[3, 10]
)
assert np.allclose(
self.events[3:6].tracks.pos_y[:, 0].tolist(),
self.events.tracks.pos_y[3:6, 0].tolist(),
)
def test_selecting_specific_items_via_a_list(self):
# test selecting with a list
self.assertEqual(3, len(self.events[[0, 2, 3]]))
def test_selecting_specific_items_via_a_numpy_array(self):
        # test selecting with a numpy array
self.assertEqual(3, len(self.events[np.array([0, 2, 3])]))
def test_selecting_specific_items_via_a_awkward_array(self):
        # test selecting with an awkward array
self.assertEqual(3, len(self.events[ak.Array([0, 2, 3])]))
class TestUsr(unittest.TestCase):
def setUp(self):
......@@ -439,27 +483,7 @@ class TestUsr(unittest.TestCase):
"NGeometryVetoHits",
"ClassficationScore",
],
self.f.events.usr_names[0].tolist(),
)
......@@ -471,11 +495,11 @@ class TestMcTrackUsr(unittest.TestCase):
n_tracks = len(self.f.events)
for i in range(3):
self.assertListEqual(
[b"bx", b"by", b"ichan", b"cc"],
["bx", "by", "ichan", "cc"],
self.f.events.mc_tracks.usr_names[i][0].tolist(),
)
self.assertListEqual(
[b"energy_lost_in_can"],
["energy_lost_in_can"],
self.f.events.mc_tracks.usr_names[i][1].tolist(),
)
......@@ -488,8 +512,3 @@ class TestMcTrackUsr(unittest.TestCase):
assert np.allclose(
[0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001
)
......@@ -18,7 +18,6 @@ from km3io.tools import (
uniquecount,
fitinf,
count_nested,
mask,
best_track,
get_w2list_param,
......@@ -28,9 +27,16 @@ from km3io.tools import (
best_aashower,
best_dusjshower,
is_cc,
usr,
)
OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root"))
OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root"))
OFFLINE_MC_TRACK_USR = OfflineReader(
data_path(
"offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root"
)
)
GENHEN_OFFLINE_FILE = OfflineReader(
data_path("offline/mcv5.1.genhen_anumuNC.sirene.jte.jchain.aashower.sample.root")
)
......@@ -58,18 +64,12 @@ class TestFitinf(unittest.TestCase):
assert beta[1] == self.best_fit[1][0]
assert beta[2] == self.best_fit[2][0]
class TestBestTrackSelection(unittest.TestCase):
def setUp(self):
self.events = OFFLINE_FILE.events
self.one_event = OFFLINE_FILE.events[0]
@unittest.skip
def test_best_track_selection_from_multiple_events_with_explicit_stages_in_list(
self,
):
......@@ -102,58 +102,46 @@ class TestBestTrackSelection(unittest.TestCase):
assert best3.rec_stages[2] is None
assert best3.rec_stages[3] is None
def test_best_track_selection_from_multiple_events_with_a_set_of_stages(
self,
):
best = best_track(self.events.tracks, stages={1, 3, 4, 5})
assert len(best) == 10
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
assert best.rec_stages[3].tolist() == [1, 3, 5, 4]
# test with a shorter set of rec_stages
best2 = best_track(self.events.tracks, stages={1, 3})
assert len(best2) == 10
for rec_stages in best2.rec_stages:
for stage in {1, 3}:
assert stage in rec_stages
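The difference the tests above exercise between passing `stages` as a list and as a set is order sensitivity: a set only demands containment, while a list demands the exact sequence. In plain Python terms:

```python
rec_stages = [1, 3, 5, 4]

print({1, 3}.issubset(rec_stages))  # True: a set only requires containment
print(rec_stages == [1, 3])         # False: a list requires the exact sequence
```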
def test_best_track_selection_from_multiple_events_with_start_end(self):
best = best_track(self.events.tracks, startend=(1, 4))
assert len(best) == 10
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
assert best.rec_stages[3].tolist() == [1, 3, 5, 4]
# test with shorter stages
best2 = best_track(self.events.tracks, startend=(1, 3))
assert len(best2) == 10
assert best2.rec_stages[0].tolist() == [1, 3]
assert best2.rec_stages[1].tolist() == [1, 3]
assert best2.rec_stages[2].tolist() == [1, 3]
assert best2.rec_stages[3].tolist() == [1, 3]
# test the importance of start as a real start of rec_stages
best3 = best_track(self.events.tracks, startend=(0, 3))
......@@ -179,23 +167,20 @@ class TestBestTrackSelection(unittest.TestCase):
# stages as a list
best = best_track(self.one_event.tracks, stages=[1, 3, 5, 4])
assert len(best) == 1
assert best.lik == ak.max(self.one_event.tracks.lik)
assert np.allclose(best.rec_stages.tolist(), [1, 3, 5, 4])
# stages as a set
best2 = best_track(self.one_event.tracks, stages={1, 3, 4, 5})
assert len(best2) == 1
assert best2.lik == ak.max(self.one_event.tracks.lik)
assert np.allclose(best2.rec_stages.tolist(), [1, 3, 5, 4])
# stages with start and end
best3 = best_track(self.one_event.tracks, startend=(1, 4))
assert len(best3) == 1
assert best3.lik == ak.max(self.one_event.tracks.lik)
assert np.allclose(best3.rec_stages.tolist(), [1, 3, 5, 4])
def test_best_track_on_slices_one_event(self):
tracks_slice = self.one_event.tracks[self.one_event.tracks.rec_type == 4000]
......@@ -203,63 +188,90 @@ class TestBestTrackSelection(unittest.TestCase):
# test stages with list
best = best_track(tracks_slice, stages=[1, 3, 5, 4])
assert len(best) == 1
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages.tolist() == [1, 3, 5, 4]
# test stages with set
best2 = best_track(tracks_slice, stages={1, 3, 4, 5})
assert len(best2) == 1
assert best2.lik == ak.max(tracks_slice.lik)
assert best2.rec_stages.tolist() == [1, 3, 5, 4]
def test_best_track_on_slices_with_start_end_one_event(self):
tracks_slice = self.one_event.tracks[0:5]
best = best_track(tracks_slice, startend=(1, 4))
assert len(best) == 1
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0] == 1
assert best.rec_stages[-1] == 4
def test_best_track_on_slices_with_explicit_rec_stages_one_event(self):
tracks_slice = self.one_event.tracks[0:5]
best = best_track(tracks_slice, stages=[1, 3, 5, 4])
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0] == 1
assert best.rec_stages[-1] == 4
    def test_best_track_on_slices_multiple_events(self):
        tracks_slice = self.events[0:5].tracks

        # stages in list
        best = best_track(tracks_slice, stages=[1, 3, 5, 4])

        assert len(best) == 5
        assert np.allclose(
            best.lik.tolist(),
            [
                294.6407542676734,
                96.75133289411137,
                560.2775306614813,
                278.2872951665753,
                99.59098153341449,
            ],
        )
        for i in range(len(best)):
            assert best.rec_stages[i].tolist() == [1, 3, 5, 4]

        # stages in set
        best = best_track(tracks_slice, stages={3, 4, 5})

        assert len(best) == 5
        assert np.allclose(
            best.lik.tolist(),
            [
                294.6407542676734,
                96.75133289411137,
                560.2775306614813,
                278.2872951665753,
                99.59098153341449,
            ],
        )
        for i in range(len(best)):
            assert best.rec_stages[i].tolist() == [1, 3, 5, 4]

        # using start and end
        start, end = (1, 4)
        best = best_track(tracks_slice, startend=(start, end))

        assert len(best) == 5
        assert np.allclose(
            best.lik.tolist(),
            [
                294.6407542676734,
                96.75133289411137,
                560.2775306614813,
                278.2872951665753,
                99.59098153341449,
            ],
        )
        for i in range(len(best)):
            rs = best.rec_stages[i].tolist()
            assert rs[0] == start
            assert rs[-1] == end
def test_best_track_raises_when_unknown_stages(self):
with self.assertRaises(ValueError):
......@@ -276,10 +288,10 @@ class TestBestJmuon(unittest.TestCase):
assert len(best) == 10
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
assert best.rec_stages[3].tolist() == [1, 3, 5, 4]
assert best.lik[0] == ak.max(OFFLINE_FILE.events.tracks.lik[0])
assert best.lik[1] == ak.max(OFFLINE_FILE.events.tracks.lik[1])
......@@ -378,28 +390,18 @@ class TestRecStagesMasks(unittest.TestCase):
self.tracks = OFFLINE_FILE.events.tracks
def test_mask_with_explicit_rec_stages_in_list_with_multiple_events(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(self.tracks.rec_stages, sequence=stages)
assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages))
assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
assert masks[0][1] == False
def test_mask_with_atleast_on_multiple_events(self):
stages = [1, 3, 4, 5]
masks = mask(self.tracks.rec_stages, atleast=stages)
tracks = self.tracks[masks]
assert 1 in tracks.rec_stages[0][0]
......@@ -410,7 +412,7 @@ class TestRecStagesMasks(unittest.TestCase):
def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(self.tracks.rec_stages, startend=(1, 4))
assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages))
assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
......@@ -420,7 +422,7 @@ class TestRecStagesMasks(unittest.TestCase):
rec_stages = self.tracks.rec_stages[0][0]
stages = [1, 3, 5, 4]
track = self.tracks[0]
masks = mask(track.rec_stages, startend=(1, 4))
assert track[masks].rec_stages[0][0] == 1
assert track[masks].rec_stages[0][-1] == 4
......@@ -429,20 +431,37 @@ class TestRecStagesMasks(unittest.TestCase):
rec_stages = self.tracks.rec_stages[0][0]
stages = [1, 3]
track = self.tracks[0]
masks = mask(track.rec_stages, sequence=stages)
assert track[masks].rec_stages[0][0] == stages[0]
assert track[masks].rec_stages[0][1] == stages[1]
class TestMask(unittest.TestCase):
def test_minmax_2dim_mask(self):
arr = ak.Array([[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]])
m = mask(arr, minmax=(1, 4))
self.assertListEqual(m.tolist(), [True, False, False])
def test_minmax_3dim_mask(self):
arr = ak.Array([[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]])
m = mask(arr, minmax=(1, 4))
self.assertListEqual(m.tolist(), [[True, False, False], [True]])
def test_minmax_4dim_mask(self):
arr = ak.Array(
[[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]]
)
m = mask(arr, minmax=(1, 4))
self.assertListEqual(
m.tolist(), [[[True, False, False], [True]], [[False, True]]]
)
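The dimension-agnostic behaviour these tests check can be mimicked with a small recursive helper on nested Python lists (a sketch of the semantics, not the awkward/numba implementation):

```python
def minmax_mask(arr, lo, hi):
    """Recurse to the innermost lists, then test containment in [lo, hi]."""
    if arr and isinstance(arr[0], list):
        return [minmax_mask(sub, lo, hi) for sub in arr]
    # an empty innermost list yields False, matching _mask_minmax above
    return bool(arr) and all(lo <= x <= hi for x in arr)


print(minmax_mask([[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], 1, 4))
# [True, False, False]
```

The nesting depth determines how many list levels the result keeps, so the same call covers the 2-, 3-, and 4-dimensional cases.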
class TestUnique(unittest.TestCase):
def run_random_test_with_dtype(self, dtype):
max_range = 100
......@@ -543,3 +562,22 @@ class TestIsCC(unittest.TestCase):
all(NC_file) == True
) # this test fails because the CC flags are not reliable in old files
self.assertTrue(all(CC_file) == True)
class TestUsr(unittest.TestCase):
def test_event_usr(self):
assert np.allclose(
[118.6302815337638, 44.33580521344907, 99.93916717621543],
usr(OFFLINE_USR.events, "CoC").tolist(),
)
assert np.allclose(
[37.51967774166617, -10.280346193553832, 13.67595659707355],
usr(OFFLINE_USR.events, "DeltaPosZ").tolist(),
)
def test_mc_tracks_usr(self):
assert np.allclose(
[0.0487],
usr(OFFLINE_MC_TRACK_USR.mc_tracks[0], "bx").tolist(),
atol=0.0001,
)