diff --git a/.coveragerc b/.coveragerc index 39f1ba7c73884d8a5b4884522bd91f025709899e..f1263edf116d2e9f372c2e11a8ef333d779559a8 100644 --- a/.coveragerc +++ b/.coveragerc @@ -17,3 +17,5 @@ exclude_lines = if self.debug: if settings.DEBUG def __repr__ + @njit + @nb.njit diff --git a/README.rst b/README.rst index c365dea99d5b8c2d782f4b05bb5ed5d0e7189fc4..ac7cac55b8fc3fbaa26de4a6a8eef630b93531d2 100644 --- a/README.rst +++ b/README.rst @@ -43,169 +43,133 @@ If you have a question about km3io, please proceed as follows: - Haven't you found an answer to your question in the documentation, post a git issue with your question showing us an example of what you have tried first, and what you would like to do. - Have you noticed a bug, please post it in a git issue, we appreciate your contribution. -Tutorial -======== -**Table of contents:** +Introduction +------------ -* `Introduction <#introduction>`__ +Most of km3net data is stored in root files. These root files are created using the `KM3NeT Dataformat library <https://git.km3net.de/common/km3net-dataformat>`__. +A ROOT file created with +`Jpp <https://git.km3net.de/common/jpp>`__ is an "online" file and all other software usually produces "offline" files. - * `Overview of online files <#overview-of-online-files>`__ +km3io is a Python package that provides a set of classes: ``OnlineReader``, ``OfflineReader`` and a special class to read gSeaGen files. All of these ROOT files can be read without installing any other software like Jpp, aanet or ROOT. - * `Overview of offline files <#overview-of-offline-files>`__ +Data in km3io is returned as ``awkward.Array`` which is an advanced Numpy-like container type to store +contiguous data for high performance computations. +Such an ``awkward.Array`` supports any level of nested arrays and records which can have different lengths, in contrast to Numpy where everything has to be rectangular. 
-* `Online files reader <#online-files-reader>`__ +The example is shown below shows the array which contains the ``dir_z`` values +of each track of the first 4 events. The type ``4 * var * float64`` means that +it has 4 subarrays with variable lengths of type ``float64``: - * `Reading Events <#reading-events>`__ +.. code-block:: python3 - * `Reading SummarySlices <#reading-summaryslices>`__ + >>> import km3io + >>> from km3net_testdata import data_path + >>> f = km3io.OfflineReader(data_path("offline/numucc.root")) + >>> f[:4].tracks.dir_z + <Array [[0.213, 0.213, ... 0.229, 0.323]] type='4 * var * float64'> - * `Reading Timeslices <#reading-timeslices>`__ +The same concept applies to everything, including ``hits``, ``mc_hits``, +``mc_tracks``, ``t_sec`` etc. -* `Offline files reader <#offline-file-reader>`__ +Offline files reader +-------------------- - * `reading events data <#reading-events-data>`__ +In general an offline file has two methods to fetch data: the header and the events. Let's start with the header. - * `reading usr data of events <#reading-usr-data-of-events>`__ +Reading the file header +""""""""""""""""""""""" - * `reading hits data <#reading-hits-data>`__ +To read an offline file start with opening it with an OfflineReader: - * `reading tracks data <#reading-tracks-data>`__ +.. code-block:: python3 - * `reading mc hits data <#reading-mc-hits-data>`__ + >>> import km3io + >>> from km3net_testdata import data_path + >>> f = km3io.OfflineReader(data_path("offline/numucc.root")) - * `reading mc tracks data <#reading-mc-tracks-data>`__ +Calling the header can be done with: +.. code-block:: python3 + >>> f.header + <km3io.offline.Header at 0x7fcd81025990> -Introduction ------------- +and provides lazy access. In offline files the header is unique and can be printed -Most of km3net data is stored in root files. These root files are either created with `Jpp <https://git.km3net.de/common/jpp>`__ or `aanet <https://git.km3net.de/common/aanet>`__ software. 
A root file created with -`Jpp <https://git.km3net.de/common/jpp>`__ is often referred to as "a Jpp root file". Similarly, a root file created with `aanet <https://git.km3net.de/common/aanet>`__ is often referred to as "an aanet file". In km3io, an aanet root file will always be reffered to as an ``offline file``, while a Jpp ROOT file will always be referred to as a ``online file``. +.. code-block:: python3 -km3io is a Python package that provides a set of classes (``OnlineReader`` and ``OfflineReader``) to read both online ROOT files and offline ROOT files without any dependency to aanet, Jpp or ROOT. + >>> print(f.header) + MC Header: + DAQ(livetime=394) + PDF(i1=4, i2=58) + can(zmin=0, zmax=1027, r=888.4) + can_user: can_user(field_0=0.0, field_1=1027.0, field_2=888.4) + coord_origin(x=0, y=0, z=0) + cut_in(Emin=0, Emax=0, cosTmin=0, cosTmax=0) + cut_nu(Emin=100, Emax=100000000.0, cosTmin=-1, cosTmax=1) + cut_primary(Emin=0, Emax=0, cosTmin=0, cosTmax=0) + cut_seamuon(Emin=0, Emax=0, cosTmin=0, cosTmax=0) + decay: decay(field_0='doesnt', field_1='happen') + detector: NOT + drawing: Volume + genhencut(gDir=2000, Emin=0) + genvol(zmin=0, zmax=1027, r=888.4, volume=2649000000.0, numberOfEvents=100000) + kcut: 2 + livetime(numberOfSeconds=0, errorOfSeconds=0) + model(interaction=1, muon=2, scattering=0, numberOfEnergyBins=1, field_4=12) + ngen: 100000.0 + norma(primaryFlux=0, numberOfPrimaries=0) + nuflux: nuflux(field_0=0, field_1=3, field_2=0, field_3=0.5, field_4=0.0, field_5=1.0, field_6=3.0) + physics(program='GENHEN', version='7.2-220514', date=181116, time=1138) + seed(program='GENHEN', level=3, iseed=305765867, field_3=0, field_4=0) + simul(program='JSirene', version=11012, date='11/17/18', time=7) + sourcemode: diffuse + spectrum(alpha=-1.4) + start_run(run_id=1) + target: isoscalar + usedetfile: false + xlat_user: 0.63297 + xparam: OFF + zed_user: zed_user(field_0=0.0, field_1=3450.0) -Data in km3io is often returned as a "lazyarray", a "jagged lazyarray" 
or a `Numpy <https://docs.scipy.org/doc/numpy>`__ array. A lazyarray is an array-like object that reads data on demand! In a lazyarray, only the first and the last chunks of data are read in memory. A lazyarray can be used with all Numpy's universal `functions <https://docs.scipy.org/doc/numpy/reference/ufuncs.html>`__. Here is how a lazyarray looks like: +To read the values in the header one can call them directly: .. code-block:: python3 - # <ChunkedArray [5971 5971 5971 ... 5971 5971 5971] at 0x7fb2341ad810> + >>> f.header.DAQ.livetime + 394 + >>> f.header.cut_nu.Emin + 100 + >>> f.header.genvol.numberOfEvents + 100000 -A jagged array, is a 2+ dimentional array with different arrays lengths. In other words, a jagged array is an array of arrays of different sizes. So a jagged lazyarray is simply a jagged array of lazyarrays with different sizes. Here is how a jagged lazyarray looks like: +Reading events +"""""""""""""" +To start reading events call the events method on the file: .. code-block:: python3 - # <JaggedArray [[102 102 102 ... 11517 11518 11518] [] [101 101 102 ... 11518 11518 11518] ... [101 101 102 ... 11516 11516 11517] [] [101 101 101 ... 11517 11517 11518]] at 0x7f74b0ef8810> - - -Overview of Online files -"""""""""""""""""""""""" -Online files are written by the DataWriter (part of Jpp) and contain events, timeslices and summary slices. - - -Overview of offline files -""""""""""""""""""""""""" - -Offline files contain data about events, hits and tracks. Based on aanet version 2.0.0 documentation, the following tables show the definitions, the types and the units of the branches founds in the events, hits and tracks trees. A description of the file header are also displayed. - -.. 
csv-table:: events keys definitions and units - :header: "type", "name", "definition" - :widths: 20, 20, 80 - - "int", "id", "offline event identifier" - "int", "det_id", "detector identifier from DAQ" - "int", "mc_id", "identifier of the MC event (as found in ascii or antcc file)" - "int", "run_id", "DAQ run identifier" - "int", "mc_run_id", "MC run identifier" - "int", "frame_index", "from the raw data" - "ULong64_t", "trigger_mask", "trigger mask from raw data (i.e. the trigger bits)" - "ULong64_t", "trigger_counter", "trigger counter" - "unsigned int", "overlays", "number of overlaying triggered events" - "TTimeStamp", "t", "UTC time of the start of the timeslice the event came from" - "vec Hit", "hits", "list of hits" - "vec Trk", "trks", "list of reconstructed tracks (can be several because of prefits,showers, etc)" - "vec double", "w", "MC: Weights w[0]=w1 & w[1]=w2 & w[2]]=w3" - "vec double", "w2list", "MC: factors that make up w[1]=w2" - "vec double", "w3list", "MC: atmospheric flux information" - "double", "mc_t", "MC: time of the mc event" - "vec Hit", "mc_hits", "MC: list of MC truth hits" - "vec Trk", "mc_trks", "MC: list of MC truth tracks" - "string", "comment", "user can use this as he/she likes" - "int", "index", "user can use this as he/she likes" - - -.. csv-table:: hits keys definitions and units - :header: "type", "name", "definition" - :widths: 20, 20, 80 - - "int", "id", "hit id" - "int", "dom_id", "module identifier from the data (unique in the detector)" - "unsigned int", "channel_id", "PMT channel id {0,1, .., 31} local to module" - "unsigned int", "tdc", "hit tdc (=time in ns)" - "unsigned int", "tot", "tot value as stored in raw data (int for pyroot)" - "int", "trig", "non-zero if the hit is a trigger hit" - "int", "pmt_id", "global PMT identifier as found in evt files" - "double", "t", "hit time (from calibration or MC truth)" - "double", "a", "hit amplitude (in p.e.)" - "vec", "pos", "hit position" - "vec", "dir", "hit direction i.e. 
direction of the PMT" - "double", "pure_t", "photon time before pmt simultion (MC only)" - "double", "pure_a", "amptitude before pmt simution (MC only)" - "int", "type", "particle type or parametrisation used for hit (mc only)" - "int", "origin", "track id of the track that created this hit" - "unsigned", "pattern_flags", "some number that you can use to flag the hit" - - -.. csv-table:: tracks keys definitions and units - :header: "type", "name", "definition" - :widths: 20, 20, 80 - - "int", "id", "track identifier" - "vec", "pos", "position of the track at time t" - "vec", "dir", "track direction" - "double", "t", "track time (when particle is at pos)" - "double", "E", "Energy (either MC truth or reconstructed)" - "double", "len", "length if applicable" - "double", "lik", "likelihood or lambda value (for aafit: lambda)" - "int", "type", "MC: particle type in PDG encoding" - "int", "rec_type", "identifyer for the overall fitting algorithm/chain/strategy" - "vec int", "rec_stages", "list of identifyers of succesfull fitting stages resulting in this track" - "int", "status", "MC status code" - "int", "mother_id", "MC id of the parent particle" - "vec double", "fitinf", "place to store additional fit info for jgandalf see FitParameters.csv" - "vec int", "hit_ids", "list of associated hit-ids (corresponds to Hit::id)" - "vec double", "error_matrix", "(5x5) error covariance matrix (stored as linear vector)" - "string", "comment", "user comment" - - -.. 
csv-table:: offline file header definitions - :header: "name", "definition" - :widths: 40, 80 - - "DAQ", "livetime" - "cut_primary cut_seamuon cut_in cut_nu", "Emin Emax cosTmin cosTmax" - "generator physics simul", "program version date time" - "seed", "program level iseed" - "PM1_type_area", "type area TTS" - "PDF", "i1 i2" - "model", "interaction muon scattering numberOfEnergyBins" - "can", "zmin zmax r" - "genvol", "zmin zmax r volume numberOfEvents" - "merge", "time gain" - "coord_origin", "x y z" - "translate", "x y z" - "genhencut", "gDir Emin" - "k40", "rate time" - "norma", "primaryFlux numberOfPrimaries" - "livetime", "numberOfSeconds errorOfSeconds" - "flux", "type key file_1 file_2" - "spectrum", "alpha" - "fixedcan", "xcenter ycenter zmin zmax radius" - "start_run", "run_id" + >>> f + OfflineReader (10 events) + >>> f.keys() + {'comment', 'det_id', 'flags', 'frame_index', 'hits', 'id', 'index', + 'mc_hits', 'mc_id', 'mc_run_id', 'mc_t', 'mc_tracks', 'mc_trks', + 'n_hits', 'n_mc_hits', 'n_mc_tracks', 'n_mc_trks', 'n_tracks', + 'n_trks', 'overlays', 'run_id', 't_ns', 't_sec', 'tracks', + 'trigger_counter', 'trigger_mask', 'trks', 'usr', 'usr_names', + 'w', 'w2list', 'w3list'} + +Like the online reader lazy access is used. Using <TAB> completion gives an overview of available data. Alternatively the method `keys` can be used on events and it's data members containing a structure to see what is available for reading. + +Reading the reconstructed values like energy and direction of an event can be done with: + +.. code-block:: python3 + >>> f.events.tracks.E + <Array [[117, 117, 0, 0, 0, ... 0, 0, 0, 0, 0]] type='10 * var * float64'> Online files reader ------------------- @@ -332,105 +296,3 @@ channel, time and ToT: -Offline files reader --------------------- - -In general an offline file has two methods to fetch data: the header and the events. Let's start with the header. 
- -Reading the file header -""""""""""""""""""""""" - -To read an offline file start with opening it with an OfflineReader: - -.. code-block:: python3 - - import km3io - f = km3io.OfflineReader("mcv5.0.gsg_elec-CC_1-500GeV.sirene.jte.jchain.jsh.aanet.1.root") - -Calling the header can be done with: - -.. code-block:: python3 - - >>> f.header - <km3io.offline.Header at 0x7fcd81025990> - -and provides lazy access. In offline files the header is unique and can be printed - -.. code-block:: python3 - - >>> print(f.header) - MC Header: - DAQ(livetime=35.5) - XSecFile: /project/antares/public_student_software/genie/v3.00.02-hedis/Generator/genie_xsec/gSeaGen/G18_02a_00_000/gxspl-seawater.xml - coord_origin(x=457.8, y=574.3, z=0) - cut_nu(Emin=1, Emax=500, cosTmin=-1, cosTmax=1) - drawing: surface - fixedcan(xcenter=457.8, ycenter=574.3, zmin=0, zmax=475.6, radius=308.2) - genvol(zmin=0, zmax=475.6, r=308.2, volume=148000000.0, numberOfEvents=1000000.0) - simul(program='gSeaGen', version='dev', date=200616, time=223726) - simul_1: simul_1(field_0='GENIE', field_1='3.0.2', field_2=200616, field_3=223726) - simul_2: simul_2(field_0='GENIE_REWEIGHT', field_1='1.0.0', field_2=200616, field_3=223726) - simul_3: simul_3(field_0='JSirene', field_1='13.0.0-alpha.5-113-gaa686a6a-D', field_2='06/17/20', field_3=0) - spectrum(alpha=-3) - start_run(run_id=1) - tgen: 31556900.0 - -An overview of the values in a the header are given in the `Overview of offline files <#overview-of-offline-files>`__. -To read the values in the header one can call them directly: - -.. code-block:: python3 - - >>> f.header.DAQ.livetime - 35.5 - >>> f.header.cut_nu.Emin - 1 - >>> f.header.genvol.numberOfEvents - 1000000.0 - - -Reading events -"""""""""""""" - -To start reading events call the events method on the file: - -.. code-block:: python3 - - >>> f.events - <OfflineBranch[events]: 355 elements> - -Like the online reader lazy access is used. Using <TAB> completion gives an overview of available data. 
Alternatively the method `keys` can be used on events and it's data members containing a structure to see what is available for reading. - -.. code-block:: python3 - - >>> f.events.keys() - dict_keys(['w2list', 'frame_index', 'overlays', 'comment', 'id', 'w', 'run_id', 'mc_t', 'mc_run_id', 'det_id', 'w3list', 'trigger_mask', 'mc_id', 'flags', 'trigger_counter', 'index', 't_sec', 't_ns', 'n_hits', 'n_mc_hits', 'n_tracks', 'n_mc_tracks']) - >>> f.events.tracks.keys() - dict_keys(['mother_id', 'status', 'lik', 'error_matrix', 'dir_z', 'len', 'rec_type', 'id', 't', 'dir_x', 'rec_stages', 'dir_y', 'fitinf', 'pos_z', 'hit_ids', 'comment', 'type', 'any', 'E', 'pos_y', 'usr_names', 'pos_x']) - -Reading the reconstructed values like energy and direction of an event can be done with: - -.. code-block:: python3 - - >>> f.events.tracks.E - <ChunkedArray [[3.8892237665736844 0.0 0.0 ... 0.0 0.0 0.0] [2.2293441683824318 5.203533524801224 6.083598278897039 ... 0.0 0.0 0.0] [3.044857858677666 3.787165776302862 4.5667729757360656 ... 0.0 0.0 0.0] ... [2.205652079790387 2.120769181474425 1.813066579943641 ... 0.0 0.0 0.0] [2.1000775068170343 3.939512272391431 3.697537355163539 ... 0.0 0.0 0.0] [4.213600763523154 1.7412855636388889 1.6657605276356036 ... 0.0 0.0 0.0]] at 0x7fcd5acb0950> - >>> f.events.tracks.E[12] - array([ 4.19391543, 15.3079374 , 10.47125863, ..., 0. , - 0. , 0. ]) - >>> f.events.tracks.dir_z - <ChunkedArray [[0.7855203887479368 0.7855203887479368 0.7855203887479368 ... -0.5680647731737454 1.0 1.0] [0.9759269228630431 0.2677622006758061 -0.06664626796127045 ... -2.3205103555187022e-08 1.0 1.0] [-0.12332041078454238 0.09537382569575953 0.09345521875272474 ... -0.6631226836266504 -0.6631226836266504 -0.6631226836266504] ... [-0.1396584943602339 -0.08400681020109765 -0.014562067998281832 ... 1.0 1.0 1.0] [0.011997491147399564 -0.08496327394947281 -0.12675279061755318 ... 0.12053665899140412 1.0 1.0] [0.6548114607791208 0.8115427935470209 0.9043563059276946 ... 
1.0 1.0 1.0]] at 0x7fcd73746410> - >>> f.events.tracks.dir_z[12] - array([ 2.39745910e-01, 3.45008838e-01, 4.81870447e-01, 4.55139657e-01, ..., - -2.32051036e-08, 1.00000000e+00]) - -Since reconstruction stages can be done multiple times and events can have multiple reconstructions, the vectors of reconstructed values can have variable length. Other data members like the header are always the same size. The definitions of data members can be found in the `definitions <https://git.km3net.de/km3py/km3io/-/tree/master/km3io/definitions>`__ folder. The definitions contain fit parameters, header information, reconstruction information, generator output and can be expaneded to include more. - -To use the definitions imagine the following: the user wants to read out the MC value of the Bjorken-Y of event 12 that was generated with gSeaGen. This can be found in the `gSeaGen definitions <https://git.km3net.de/km3py/km3io/-/blob/master/km3io/definitions/w2list_gseagen.py>`__: `"W2LIST_GSEAGEN_BY": 8,` - -This value is saved into `w2list`, so if an event is generated with gSeaGen the value can be fetched like: - -.. code-block:: python3 - - >>> f.events.w2list[12][8] - 0.393755 - -Note that w2list can also contain other values if the event is generated with another generator. diff --git a/examples/plot_offline_hits.py b/examples/plot_offline_hits.py deleted file mode 100644 index 05972cd0dd50cb2ac2818e7c7f963439b9646916..0000000000000000000000000000000000000000 --- a/examples/plot_offline_hits.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Reading Offline hits -==================== - -The following example shows how to access hits data in an offline ROOT file, which is -written by aanet software. - -Note: the offline file used here has MC offline data and was intentionaly reduced -to 10 events. 
-""" -import km3io as ki -from km3net_testdata import data_path - -##################################################### -# To access offline hits/mc_hits data: - -mc_hits = ki.OfflineReader(data_path("offline/numucc.root")).events.mc_hits -hits = ki.OfflineReader(data_path("offline/km3net_offline.root")).events.hits - - -##################################################### -# Note that not all data is loaded in memory, so printing -# hits will only return how many elements (events) were found in -# the hits branch of the file. - -print(hits) - -##################################################### -# same for mc hits - -print(mc_hits) - -##################################################### -# Accessing the hits/mc_hits keys -# ------------------------------- -# to explore the hits keys: - -keys = hits.keys() -print(keys) - -##################################################### -# to explore the mc_hits keys: - -mc_keys = mc_hits.keys() -print(mc_keys) - -##################################################### -# Accessing hits data -# ------------------------- -# to access data in dom_id: - -dom_ids = hits.dom_id -print(dom_ids) - -##################################################### -# to access the channel ids: - -channel_ids = hits.channel_id -print(channel_ids) - -##################################################### -# That's it! you can access any key of your interest in the hits -# keys in the exact same way. 
- -##################################################### -# Accessing the mc_hits data -# -------------------------- -# similarly, you can access mc_hits data in any key of interest by -# following the same procedure as for hits: - -mc_pmt_ids = mc_hits.pmt_id -print(mc_pmt_ids) - - -##################################################### -# to access the mc_hits time: -mc_t = mc_hits.t -print(mc_t) - -##################################################### -# item selection in hits data -# --------------------------- -# hits data can be selected as you would select an item from a numpy array. -# for example, to select DOM ids in the hits corresponding to the first event: - -print(hits[0].dom_id) - -##################################################### -# or: - -print(hits.dom_id[0]) - -##################################################### -# slicing of hits -# --------------- -# to select a slice of hits data: - -print(hits[0:3].channel_id) - -##################################################### -# or: - -print(hits.channel_id[0:3]) - -##################################################### -# you can apply masks to hits data as you would do with numpy arrays: - -mask = hits.channel_id > 10 - -print(hits.channel_id[mask]) - -##################################################### -# or: - -print(hits.dom_id[mask]) diff --git a/examples/plot_offline_tracks.py b/examples/plot_offline_tracks.py index 8f0c4b0224a80f5c64f43fc305542901b09657e9..98cebbfa6d729d6f2e6af6f65a7b69c75f2e5429 100644 --- a/examples/plot_offline_tracks.py +++ b/examples/plot_offline_tracks.py @@ -2,8 +2,7 @@ Reading Offline tracks ====================== -The following example shows how to access tracks data in an offline ROOT file, which is -written by aanet software. +The following example shows how to access tracks data in an offline ROOT file. Note: the offline files used here were intentionaly reduced to 10 events. 
""" @@ -11,107 +10,100 @@ import km3io as ki from km3net_testdata import data_path ##################################################### -# To access offline tracks/mc_tracks data: +# We open the file using the OfflineReader: + +f = ki.OfflineReader(data_path("offline/numucc.root")) -mc_tracks = ki.OfflineReader(data_path("offline/numucc.root")).events.mc_tracks -tracks = ki.OfflineReader(data_path("offline/km3net_offline.root")).events.tracks +##################################################### +# To access offline tracks/mc_tracks data: +f.tracks +f.mc_tracks ##################################################### -# Note that not all data is loaded in memory, so printing -# tracks will only return how many elements (events) were found in -# the tracks branch of the file. +# Note that no data is loaded in memory at this point, so printing +# tracks will only return how many sub-branches (corresponding to +# events) were found. -print(tracks) +f.tracks ##################################################### # same for mc hits -print(mc_tracks) +f.mc_tracks ##################################################### # Accessing the tracks/mc_tracks keys # ----------------------------------- -# to explore the tracks keys: +# to explore the reconstructed tracks fields: -keys = tracks.keys() -print(keys) +f.tracks.fields ##################################################### -# to explore the mc_tracks keys: +# the same for MC tracks -mc_keys = mc_tracks.keys() -print(mc_keys) +f.mc_tracks.fields ##################################################### # Accessing tracks data # --------------------- -# to access data in `E` (tracks energy): +# each field will return a nested `awkward.Array` and load everything into +# memory, so be careful if you are working with larger files. 
-E = tracks.E -print(E) +f.tracks.E -##################################################### -# to access the likelihood: +###################################################### +# The z direction of all reconstructed tracks -likelihood = tracks.lik -print(likelihood) +f.tracks.dir_z -##################################################### -# That's it! you can access any key of your interest in the tracks -# keys in the exact same way. +###################################################### +# The likelihoods -##################################################### -# Accessing the mc_tracks data -# ---------------------------- -# similarly, you can access mc_tracks data in any key of interest by -# following the same procedure as for tracks: - -cos_zenith = mc_tracks.dir_z -print(cos_zenith) +f.tracks.lik ##################################################### -# or: +# To select just a single event or a subset of events, use the indices or slices. +# The following will access all tracks and their fields +# of the third event (0 is the first): -dir_y = mc_tracks.dir_y -print(dir_y) +f[2].tracks +###################################################### +# The z direction of all tracks in the third event: -##################################################### -# item selection in tracks data -# ----------------------------- -# tracks data can be selected as you would select an item from a numpy array. -# for example, to select E (energy) in the tracks corresponding to the first event: +f[2].tracks.dir_z -print(tracks[0].E) ##################################################### -# or: +# while here, we select the first 3 events. Notice that all fields will return +# nested arrays, as we have seen above where all events were selected. 
-print(tracks.E[0]) +f[:3] +###################################################### +# All tracks for the first three events -##################################################### -# slicing of tracks -# ----------------- -# to select a slice of tracks data: +f[:3].tracks -print(tracks[0:3].E) +###################################################### +# The z directions of all tracks of the first three events + +f[:3].tracks.dir_z ##################################################### -# or: +# or events from 3 to 5 (again, 0 indexing): -print(tracks.E[0:3]) -##################################################### -# you can apply masks to tracks data as you would do with numpy arrays: +f[2:5] -mask = tracks.lik > 100 +###################################################### +# the tracks of those events -print(tracks.lik[mask]) +f[2:5].tracks -##################################################### -# or: +###################################################### +# and just the z directions of those -print(tracks.dir_z[mask]) +f[2:5].tracks.dir_z diff --git a/examples/plot_offline_usr.py b/examples/plot_offline_usr.py index 9d7959be5ef06fdcba25e0245b20477137fea521..8bf3b2d16fa9686fc7e9502b48930480f35130c5 100644 --- a/examples/plot_offline_usr.py +++ b/examples/plot_offline_usr.py @@ -18,19 +18,12 @@ r = ki.OfflineReader(data_path("offline/usr-sample.root")) ##################################################### -# Accessing the usr data: +# Accessing the usr fields: -usr = r.events.usr -print(usr) +print(r.events.usr_names.tolist()) ##################################################### -# to access data of a specific key, you can either do: +# to access data of a specific key: -print(usr.DeltaPosZ) - - -##################################################### -# or - -print(usr["RecoQuality"]) +print(ki.tools.usr(r.events, "DeltaPosZ")) diff --git a/km3io/__init__.py b/km3io/__init__.py index 52ba6348e74fb7334c04ebe5a1a1b1025869c31a..1d7af3825e9df7cfb878432cd64ec60197e94f20 
100644 --- a/km3io/__init__.py +++ b/km3io/__init__.py @@ -5,4 +5,3 @@ version = get_distribution(__name__).version from .offline import OfflineReader from .online import OnlineReader from .gseagen import GSGReader -from . import patches diff --git a/km3io/gseagen.py b/km3io/gseagen.py index 35f8b8c58faeb9c1a1e6a762a58d1e2c6bf99933..f7c570b79713f15110933a478b8c4a415d075f8c 100644 --- a/km3io/gseagen.py +++ b/km3io/gseagen.py @@ -3,46 +3,30 @@ # Filename: gseagen.py # Author: Johannes Schumann <jschumann@km3net.de> -import uproot3 -import numpy as np import warnings -from .rootio import Branch, BranchMapper +from .rootio import EventReader from .tools import cached_property -MAIN_TREE_NAME = "Events" - -class GSGReader: +class GSGReader(EventReader): """reader for gSeaGen ROOT files""" - def __init__(self, file_path=None, fobj=None): - """GSGReader class is a gSeaGen ROOT file wrapper - - Parameters - ---------- - file_path : file path or file-like object - The file handler. It can be a str or any python path-like object - that points to the file. 
- """ - self._fobj = uproot3.open(file_path) + header_key = "Header" + event_path = "Events" + skip_keys = [header_key] @cached_property def header(self): - header_key = "Header" - if header_key in self._fobj: + if self.header_key in self._fobj: header = {} - for k, v in self._fobj[header_key].items(): + for k, v in self._fobj[self.header_key].items(): v = v.array()[0] if isinstance(v, bytes): try: v = v.decode("utf-8") except UnicodeDecodeError: pass - header[k.decode("utf-8")] = v + header[k] = v return header else: warnings.warn("Your file header has an unsupported format") - - @cached_property - def events(self): - return Branch(self._fobj, BranchMapper(name="Events", key="Events")) diff --git a/km3io/offline.py b/km3io/offline.py index e2f251676b8b7402778282f7c87f0799de73ef4c..c59302474b2f2aa722e0f0d400cfcb32e1db767a 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -1,220 +1,93 @@ -import binascii from collections import namedtuple -import uproot3 +import logging import warnings -import numba as nb +import uproot +import numpy as np +import awkward as ak -from .definitions import mc_header, fitparameters, reconstruction +from .definitions import mc_header from .tools import cached_property, to_num, unfold_indices -from .rootio import Branch, BranchMapper +from .rootio import EventReader -MAIN_TREE_NAME = "E" -EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"] +log = logging.getLogger("offline") -# 110 MB based on the size of the largest basket found so far in km3net -BASKET_CACHE_SIZE = 110 * 1024 ** 2 -BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) - -def _nested_mapper(key): - """Maps a key in the ROOT file to another key (e.g. 
trks.pos.x -> pos_x)""" - return "_".join(key.split(".")[1:]) - - -EVENTS_MAP = BranchMapper( - name="events", - key="Evt", - extra={"t_sec": "t.fSec", "t_ns": "t.fNanoSec"}, - exclude=EXCLUDE_KEYS, - update={ - "n_hits": "hits", - "n_mc_hits": "mc_hits", - "n_tracks": "trks", - "n_mc_tracks": "mc_trks", - }, -) - -SUBBRANCH_MAPS = [ - BranchMapper( - name="tracks", - key="trks", - extra={}, - exclude=EXCLUDE_KEYS - + ["trks.usr_data", "trks.usr", "trks.fUniqueID", "trks.fBits"], - attrparser=_nested_mapper, - flat=False, - toawkward=["fitinf", "rec_stages"], - ), - BranchMapper( - name="mc_tracks", - key="mc_trks", - exclude=EXCLUDE_KEYS - + [ - "mc_trks.rec_stages", - "mc_trks.fitinf", - "mc_trks.fUniqueID", - "mc_trks.fBits", - ], - attrparser=_nested_mapper, - toawkward=["usr", "usr_names"], - flat=False, - ), - BranchMapper( - name="hits", - key="hits", - exclude=EXCLUDE_KEYS - + [ - "hits.usr", - "hits.pmt_id", - "hits.origin", - "hits.a", - "hits.pure_a", - "hits.fUniqueID", - "hits.fBits", - ], - attrparser=_nested_mapper, - flat=False, - ), - BranchMapper( - name="mc_hits", - key="mc_hits", - exclude=EXCLUDE_KEYS - + [ - "mc_hits.usr", - "mc_hits.dom_id", - "mc_hits.channel_id", - "mc_hits.tdc", - "mc_hits.tot", - "mc_hits.trig", - "mc_hits.fUniqueID", - "mc_hits.fBits", - ], - attrparser=_nested_mapper, - flat=False, - ), -] - - -class OfflineBranch(Branch): - @cached_property - def usr(self): - return Usr(self._mapper, self._branch, index_chain=self._index_chain) - - -class Usr: - """Helper class to access AAObject `usr` stuff (only for events.usr)""" - - def __init__(self, mapper, branch, index_chain=None): - self._mapper = mapper - self._name = mapper.name - self._index_chain = [] if index_chain is None else index_chain - self._branch = branch - self._usr_names = [] - self._usr_idx_lookup = {} - - self._usr_key = "usr" if mapper.flat else mapper.key + ".usr" - - self._initialise() - - def _initialise(self): - try: - self._branch[self._usr_key] - # This 
will raise a KeyError in old aanet files - # which has a different strucuter and key (usr_data) - # We do not support those (yet) - except (KeyError, IndexError): - print( - "The `usr` fields could not be parsed for the '{}' branch.".format( - self._name - ) - ) - return - - self._usr_names = [ - n.decode("utf-8") - for n in self._branch[self._usr_key + "_names"].lazyarray()[0] - ] - self._usr_idx_lookup = { - name: index for index, name in enumerate(self._usr_names) - } - - data = self._branch[self._usr_key].lazyarray() - - if self._index_chain: - data = unfold_indices(data, self._index_chain) - - self._usr_data = data - - for name in self._usr_names: - setattr(self, name, self[name]) - - def __getitem__(self, item): - if self._index_chain: - return unfold_indices(self._usr_data, self._index_chain)[ - :, self._usr_idx_lookup[item] - ] - else: - return self._usr_data[:, self._usr_idx_lookup[item]] - - def keys(self): - return self._usr_names - - def __str__(self): - entries = [] - for name in self.keys(): - entries.append("{}: {}".format(name, self[name])) - return "\n".join(entries) - - def __repr__(self): - return "<{}[{}]>".format(self.__class__.__name__, self._name) - - -class OfflineReader: +class OfflineReader(EventReader): """reader for offline ROOT files""" - def __init__(self, file_path=None): - """OfflineReader class is an offline ROOT file wrapper - - Parameters - ---------- - file_path : path-like object - Path to the file of interest. It can be a str or any python - path-like object that points to the file. 
- - """ - self._fobj = uproot3.open(file_path) - self._filename = file_path - self._tree = self._fobj[MAIN_TREE_NAME] - self._uuid = binascii.hexlify(self._fobj._context.uuid).decode("ascii") - - @property - def uuid(self): - return self._uuid - - def close(self): - self._fobj.close() - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - @cached_property - def events(self): - """The `E` branch, containing all offline events.""" - return OfflineBranch( - self._tree, mapper=EVENTS_MAP, subbranchmaps=SUBBRANCH_MAPS - ) + event_path = "E/Evt" + item_name = "OfflineEvent" + skip_keys = ["t", "AAObject"] + aliases = { + "t_sec": "t/t.fSec", + "t_ns": "t/t.fNanoSec", + "usr": "AAObject/usr", + "usr_names": "AAObject/usr_names", + } + nested_branches = { + "hits": { + "id": "hits.id", + "channel_id": "hits.channel_id", + "dom_id": "hits.dom_id", + "t": "hits.t", + "tot": "hits.tot", + "trig": "hits.trig", # non-zero if the hit is a triggered hit + }, + "mc_hits": { + "id": "mc_hits.id", + "pmt_id": "mc_hits.pmt_id", + "t": "mc_hits.t", # hit time (MC truth) + "a": "mc_hits.a", # hit amplitude (in p.e.) 
+ "origin": "mc_hits.origin", # track id of the track that created this hit + "pure_t": "mc_hits.pure_t", # photon time before pmt simultion + "pure_a": "mc_hits.pure_a", # amplitude before pmt simution, + "type": "mc_hits.type", # particle type or parametrisation used for hit + }, + "trks": { + "id": "trks.id", + "pos_x": "trks.pos.x", + "pos_y": "trks.pos.y", + "pos_z": "trks.pos.z", + "dir_x": "trks.dir.x", + "dir_y": "trks.dir.y", + "dir_z": "trks.dir.z", + "t": "trks.t", + "E": "trks.E", + "len": "trks.len", + "lik": "trks.lik", + "rec_type": "trks.rec_type", + "rec_stages": "trks.rec_stages", + "fitinf": "trks.fitinf", + }, + "mc_trks": { + "id": "mc_trks.id", + "pos_x": "mc_trks.pos.x", + "pos_y": "mc_trks.pos.y", + "pos_z": "mc_trks.pos.z", + "dir_x": "mc_trks.dir.x", + "dir_y": "mc_trks.dir.y", + "dir_z": "mc_trks.dir.z", + "E": "mc_trks.E", + "t": "mc_trks.t", + "len": "mc_trks.len", + # "status": "mc_trks.status", # TODO: check this + # "mother_id": "mc_trks.mother_id", # TODO: check this + "pdgid": "mc_trks.type", + "hit_ids": "mc_trks.hit_ids", + "usr": "mc_trks.usr", # TODO: trouble with uproot4 + "usr_names": "mc_trks.usr_names", # TODO: trouble with uproot4 + }, + } + nested_aliases = { + "tracks": "trks", + "mc_tracks": "mc_trks", + } @cached_property def header(self): """The file header""" if "Head" in self._fobj: - header = {} - for n, x in self._fobj["Head"]._map_3c_string_2c_string_3e_.items(): - header[n.decode("utf-8")] = x.decode("utf-8").strip() - return Header(header) + return Header(self._fobj["Head"].tojson()["map<string,string>"]) else: warnings.warn("Your file header has an unsupported format") diff --git a/km3io/patches.py b/km3io/patches.py deleted file mode 100644 index 8d3be710146cb7db031f7e7e64578d924fa2b0be..0000000000000000000000000000000000000000 --- a/km3io/patches.py +++ /dev/null @@ -1,17 +0,0 @@ -import awkward0 as ak0 -import awkward as ak1 - -# to avoid infinite recursion -old_getitem = ak0.ChunkedArray.__getitem__ - - 
-def new_getitem(self, item): - """Monkey patch the getitem in awkward.ChunkedArray to apply - awkward1.Array masks on awkward.ChunkedArray""" - if isinstance(item, (ak1.Array, ak0.ChunkedArray)): - return ak1.Array(self)[item] - else: - return old_getitem(self, item) - - -ak0.ChunkedArray.__getitem__ = new_getitem diff --git a/km3io/rootio.py b/km3io/rootio.py index 3445f59715bf82ca1753b7a6742fa3aee0e21290..a816ad1198dc23b7d23327b0c22848f2402f4cba 100644 --- a/km3io/rootio.py +++ b/km3io/rootio.py @@ -1,244 +1,392 @@ #!/usr/bin/env python3 +from collections import namedtuple import numpy as np import awkward as ak -import uproot3 +import uproot from .tools import unfold_indices -# 110 MB based on the size of the largest basket found so far in km3net -BASKET_CACHE_SIZE = 110 * 1024 ** 2 -BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) - - -class BranchMapper: - """ - Mapper helper for keys in a ROOT branch. - - Parameters - ---------- - name: str - The name of the mapper helper which is displayed to the user - key: str - The key of the branch in the ROOT tree. - exclude: ``None``, ``list(str)`` - Keys to exclude from parsing. - update: ``None``, ``dict(str: str)`` - An update map for keys which are to be presented with a different - key to the user e.g. ``{"n_hits": "hits"}`` will rename the ``hits`` - key to ``n_hits``. - extra: ``None``, ``dict(str: str)`` - An extra mapper for hidden object, primarily nested ones like - ``t.fSec``, which can be revealed and mapped to e.g. ``t_sec`` - via ``{"t_sec", "t.fSec"}``. - attrparser: ``None``, ``function(str) -> str`` - The function to be used to create attribute names. This is only - needed if unsupported characters are present, like ``.``, which - would prevent setting valid Python attribute names. 
- toawkward: ``None``, ``list(str)`` - List of keys to convert to awkward arrays (recommended for - doubly ragged arrays) - """ +import logging - def __init__( - self, - name, - key, - extra=None, - exclude=None, - update=None, - attrparser=None, - flat=True, - interpretations=None, - toawkward=None, - ): - self.name = name - self.key = key +log = logging.getLogger("km3io.rootio") - self.extra = {} if extra is None else extra - self.exclude = [] if exclude is None else exclude - self.update = {} if update is None else update - self.attrparser = (lambda x: x) if attrparser is None else attrparser - self.flat = flat - self.interpretations = {} if interpretations is None else interpretations - self.toawkward = [] if toawkward is None else toawkward +class EventReader: + """reader for offline ROOT files""" -class Branch: - """Branch accessor class""" + event_path = None + item_name = "Event" + skip_keys = [] # ignore these subbranches, even if they exist + aliases = {} # top level aliases -> {fromkey: tokey} + nested_branches = {} + nested_aliases = {} def __init__( self, - tree, - mapper, + f, index_chain=None, - subbranchmaps=None, - keymap=None, - awkward_cache=None, + step_size=2000, + keys=None, + aliases=None, + nested_branches=None, + event_ctor=None, ): - self._tree = tree - self._mapper = mapper - self._index_chain = [] if index_chain is None else index_chain - self._keymap = None - self._branch = tree[mapper.key] - self._subbranches = [] - self._subbranchmaps = subbranchmaps - # FIXME preliminary cache to improve performance. Hopefully uproot4 - # will fix this automatically! - self._awkward_cache = {} if awkward_cache is None else awkward_cache - + """EventReader base class + + Parameters + ---------- + f : str or uproot4.reading.ReadOnlyDirectory (from uproot4.open) + Path to the file of interest or uproot4 filedescriptor. + step_size : int, optional + Number of events to read into the cache when iterating. 
+ Choosing higher numbers may improve the speed but also increases + the memory overhead. + index_chain : list, optional + Keeps track of index chaining. + keys : list or set, optional + Branch keys. + aliases : dict, optional + Branch key aliases. + event_ctor : class or namedtuple, optional + Event constructor. + + """ + if isinstance(f, str): + self._fobj = uproot.open(f) + self._filepath = f + elif isinstance(f, uproot.reading.ReadOnlyDirectory): + self._fobj = f + self._filepath = f._file.file_path + else: + raise TypeError("Unsupported file descriptor.") + self._step_size = step_size + self._uuid = self._fobj._file.uuid self._iterator_index = 0 + self._keys = keys + self._event_ctor = event_ctor + self._index_chain = [] if index_chain is None else index_chain - if keymap is None: - self._initialise_keys() # - else: - self._keymap = keymap - - if subbranchmaps is not None: - for mapper in subbranchmaps: - subbranch = self.__class__( - self._tree, - mapper=mapper, - index_chain=self._index_chain, - awkward_cache=self._awkward_cache, - ) - self._subbranches.append(subbranch) - for subbranch in self._subbranches: - setattr(self, subbranch._mapper.name, subbranch) + if aliases is not None: + self.aliases = aliases + if nested_branches is not None: + self.nested_branches = nested_branches + + if self._keys is None: + self._initialise_keys() + + if self._event_ctor is None: + self._event_ctor = namedtuple( + self.item_name, + set( + list(self.keys()) + + list(self.aliases) + + list(self.nested_branches) + + list(self.nested_aliases) + ), + ) def _initialise_keys(self): - """Create the keymap and instance attributes for branch keys""" - # TODO: this could be a cached property - keys = set(k.decode("utf-8") for k in self._branch.keys()) - set( - self._mapper.exclude + skip_keys = set(self.skip_keys) + all_keys = set(self._fobj[self.event_path].keys()) + toplevel_keys = set(k.split("/")[0] for k in all_keys) + valid_aliases = {} + for fromkey, tokey in 
self.aliases.items(): + if tokey in all_keys: + valid_aliases[fromkey] = tokey + self.aliases = valid_aliases + keys = (toplevel_keys - skip_keys).union( + list(valid_aliases) + list(self.nested_aliases) ) - self._keymap = { - **{self._mapper.attrparser(k): k for k in keys}, - **self._mapper.extra, - } - self._keymap.update(self._mapper.update) - for k in self._mapper.update.values(): - del self._keymap[k] - - for key in self._keymap.keys(): - setattr(self, key, None) + for key in list(self.nested_branches) + list(self.nested_aliases): + keys.add("n_" + key) + # self._grouped_branches = {k for k in toplevel_keys - skip_keys if isinstance(self._fobj[self.event_path][k].interpretation, uproot.AsGrouped)} + valid_nested_branches = {} + for nested_key, aliases in self.nested_branches.items(): + if nested_key in toplevel_keys: + valid_nested_branches[nested_key] = {} + subbranch_keys = self._fobj[self.event_path][nested_key].keys() + for fromkey, tokey in aliases.items(): + if tokey in subbranch_keys: + valid_nested_branches[nested_key][fromkey] = tokey + self.nested_branches = valid_nested_branches + self._keys = keys + + def __dir__(self): + """Tab completion in IPython""" + return list(self.keys()) + ["header"] def keys(self): - return self._keymap.keys() + """Returns all accessible branch keys, without the skipped ones.""" + return self._keys - def __getattribute__(self, attr): - if attr.startswith("_"): # let all private and magic methods pass - return object.__getattribute__(self, attr) - - if attr in self._keymap.keys(): # intercept branch key lookups - return self.__getkey__(attr) + @property + def events(self): + # TODO: deprecate this, since `self` is already the container type + return iter(self) + + def _keyfor(self, key): + """Return the correct key for a given alias/key""" + return self.nested_aliases.get(key, key) + + def __getattr__(self, attr): + attr = self._keyfor(attr) + # if attr in self.keys() or (attr.startswith("n_") and 
self._keyfor(attr.split("n_")[1]) in self._grouped_branches): + if attr in self.keys(): + return self.__getitem__(attr) + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{attr}'" + ) - return object.__getattribute__(self, attr) + def __getitem__(self, key): + # indexing + # TODO: maybe just propagate everything to awkward and let it deal + # with the type? + if isinstance( + key, (slice, int, np.int32, np.int64, list, np.ndarray, ak.Array) + ): + if isinstance(key, (int, np.int32, np.int64)): + key = int(key) + return self.__class__( + self._fobj, + index_chain=self._index_chain + [key], + step_size=self._step_size, + aliases=self.aliases, + nested_branches=self.nested_branches, + keys=self.keys(), + event_ctor=self._event_ctor, + ) + # group counts, for e.g. n_events, n_hits etc. + if isinstance(key, str) and key.startswith("n_"): + key = self._keyfor(key.split("n_")[1]) + arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4")) + return unfold_indices(arr, self._index_chain) + + key = self._keyfor(key) + branch = self._fobj[self.event_path] + # These are special branches which are nested, like hits/trks/mc_trks + # We are explicitly grabbing just a predefined set of subbranches + # and also alias them to be backwards compatible (and attribute-accessible) + if key in self.nested_branches: + fields = [] + # some fields are not always available, like `usr_names` + for to_field, from_field in self.nested_branches[key].items(): + if from_field in branch[key].keys(): + fields.append(to_field) + log.debug(fields) + return Branch( + branch[key], fields, self.nested_branches[key], self._index_chain + ) + else: + return unfold_indices( + branch[self.aliases.get(key, key)].array(), self._index_chain + ) - def __getkey__(self, key): - interpretation = self._mapper.interpretations.get(key) + def __iter__(self, chunkwise=False): + self._events = self._event_generator(chunkwise=chunkwise) + return self - if key == "usr_names": - # TODO 
this will be fixed soon in uproot, - # see https://github.com/scikit-hep/uproot/issues/465 - interpretation = uproot3.asgenobj( - uproot3.SimpleArray(uproot3.STLVector(uproot3.STLString())), - self._branch[self._keymap[key]]._context, - 6, + def _get_iterator_limits(self): + """Determines start and stop, used for event iteration""" + if len(self._index_chain) > 1: + raise NotImplementedError( + "iteration is currently not supported with nested slices" ) - - if key == "usr": - # triple jagged array is wrongly parsed in uproot3 - interpretation = uproot3.asgenobj( - uproot3.SimpleArray(uproot3.STLVector(uproot3.asdtype(">f8"))), - self._branch[self._keymap[key]]._context, - 6, + if self._index_chain: + s = self._index_chain[0] + if not isinstance(s, slice): + raise NotImplementedError("iteration is only supported with slices") + if s.step is None or s.step == 1: + start = s.start + stop = s.stop + else: + raise NotImplementedError( + "iteration is only supported with single steps" + ) + else: + start = None + stop = None + return start, stop + + def _event_generator(self, chunkwise=False): + start, stop = self._get_iterator_limits() + + if chunkwise: + raise NotImplementedError("iterating over chunks is not implemented yet") + + events = self._fobj[self.event_path] + group_count_keys = set( + k for k in self.keys() if k.startswith("n_") + ) # extra keys to make it easy to count subbranch lengths + log.debug("group_count_keys: %s", group_count_keys) + keys = set( + list( + set(self.keys()) + - set(self.nested_branches.keys()) + - set(self.nested_aliases) + - group_count_keys ) - - out = self._branch[self._keymap[key]].lazyarray( - interpretation=interpretation, basketcache=BASKET_CACHE + + list(self.aliases.keys()) + ) # all top-level keys for regular branches + log.debug("keys: %s", keys) + log.debug("aliases: %s", self.aliases) + events_it = events.iterate( + keys, + aliases=self.aliases, + step_size=self._step_size, + entry_start=start, + entry_stop=stop, ) - if 
self._index_chain is not None and key in self._mapper.toawkward: - cache_key = self._mapper.name + "/" + key - if cache_key not in self._awkward_cache: - if len(out) > 20000: # It will take more than 10 seconds - print("Creating cache for '{}'.".format(cache_key)) - self._awkward_cache[cache_key] = ak.from_iter(out) - out = self._awkward_cache[cache_key] - return unfold_indices(out, self._index_chain) - - def __getitem__(self, item): - """Slicing magic""" - if isinstance(item, str): - return self.__getkey__(item) - - if isinstance(item, (np.int32, np.int64)): - item = int(item) - - # if item.__class__.__name__ == "ChunkedArray": - # item = np.array(item) + nested = [] + nested_keys = ( + self.nested_branches.keys() + ) # dict-key ordering is an implementation detail + log.debug("nested_keys: %s", nested_keys) + for key in nested_keys: + nested.append( + events[key].iterate( + self.nested_branches[key].keys(), + aliases=self.nested_branches[key], + step_size=self._step_size, + entry_start=start, + entry_stop=stop, + ) + ) + group_counts = {} + for key in group_count_keys: + group_counts[key] = iter(self[key]) + + log.debug("group_counts: %s", group_counts) + for event_set, *nested_sets in zip(events_it, *nested): + for _event, *nested_items in zip(event_set, *nested_sets): + data = {} + for k in keys: + data[k] = _event[k] + for (k, i) in zip(nested_keys, nested_items): + data[k] = i + for tokey, fromkey in self.nested_aliases.items(): + data[tokey] = data[fromkey] + for key in group_counts: + data[key] = next(group_counts[key]) + yield self._event_ctor(**data) - return self.__class__( - self._tree, - self._mapper, - index_chain=self._index_chain + [item], - keymap=self._keymap, - subbranchmaps=self._subbranchmaps, - awkward_cache=self._awkward_cache, - ) + def __next__(self): + return next(self._events) def __len__(self): if not self._index_chain: - return len(self._branch) + return self._fobj[self.event_path].num_entries elif isinstance(self._index_chain[-1], 
(int, np.int32, np.int64)): if len(self._index_chain) == 1: - try: - return len(self[:]) - except IndexError: - return 1 + # TODO: not sure why this is needed at all, it's too late... + return 1 + # try: + # return len(self[:]) + # except IndexError: + # return 1 return 1 else: + # ignore the usual index magic and access `id` directly return len( unfold_indices( - self._branch[self._keymap["id"]].lazyarray( - basketcache=BASKET_CACHE - ), - self._index_chain, + self._fobj[self.event_path]["id"].array(), self._index_chain ) ) + def __actual_len__(self): + """The raw number of events without any indexing/slicing magic""" + return len(self._fobj[self.event_path]["id"].array()) + + def __repr__(self): + length = len(self) + actual_length = self.__actual_len__() + return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} events)" + @property - def is_single(self): - """Returns True when a single branch is selected.""" - if len(self._index_chain) > 0: - if isinstance(self._index_chain[0], (int, np.int32, np.int64)): - return True - return False - - def __iter__(self): - self._iterator_index = 0 + def uuid(self): + return self._uuid + + def close(self): + self._fobj.close() + + def __enter__(self): return self - def __next__(self): - idx = self._iterator_index - self._iterator_index += 1 - if idx >= len(self): - raise StopIteration - return self[idx] + def __exit__(self, *args): + self.close() - def __str__(self): - length = len(self) - return "{} ({}) with {} element{}".format( - self.__class__.__name__, - self._mapper.name, - length, - "s" if length > 1 else "", + +class Branch: + """Helper class for nested branches likes tracks/hits""" + + def __init__(self, branch, fields, aliases, index_chain): + self._branch = branch + self.fields = fields + self._aliases = aliases + self._index_chain = index_chain + + def __dir__(self): + """Tab completion in IPython""" + return list(self.fields) + + def __getattr__(self, attr): + if 
attr not in self._aliases: + raise AttributeError( + f"No field named {attr}. Available fields: {self.fields}" + ) + key = self._aliases[attr] + + if self._index_chain: + idx0 = self._index_chain[0] + if isinstance(idx0, (int, np.int32, np.int64)): + # optimise single-element and slice lookups + start = idx0 + stop = idx0 + 1 + arr = ak.flatten( + self._branch[key].array(entry_start=start, entry_stop=stop) + ) + return unfold_indices(arr, self._index_chain[1:]) + if isinstance(idx0, slice): + if idx0.step is None or idx0.step == 1: + start = idx0.start + stop = idx0.stop + arr = self._branch[key].array(entry_start=start, entry_stop=stop) + return unfold_indices(arr, self._index_chain[1:]) + + return unfold_indices(self._branch[key].array(), self._index_chain) + + def __getitem__(self, key): + return self.__class__( + self._branch, self.fields, self._aliases, self._index_chain + [key] ) + def __len__(self): + if not self._index_chain: + return self._branch.num_entries + elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)): + if len(self._index_chain) == 1: + return 1 + # try: + # return len(self[:]) + # except IndexError: + # return 1 + return 1 + else: + # ignore the usual index magic and access `id` directly + return len(self.id) + + def __actual_len__(self): + """The raw number of events without any indexing/slicing magic""" + return len(self._branch[self._aliases["id"]].array()) + def __repr__(self): length = len(self) - return "<{}[{}]: {} element{}>".format( - self.__class__.__name__, - self._mapper.name, - length, - "s" if length > 1 else "", - ) + actual_length = self.__actual_len__() + return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} {self._branch.name})" + + @property + def ndim(self): + if not self._index_chain: + return 2 + elif any(isinstance(i, (int, np.int32, np.int64)) for i in self._index_chain): + return 1 + return 2 diff --git a/km3io/tools.py b/km3io/tools.py index 
c42841919e0a5265aee6d41276f1ed65a7358e03..9a50408947301436ee428788bad1c52ee1a92265 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from collections import namedtuple import numba as nb import numpy as np import awkward as ak @@ -133,34 +134,19 @@ def fitinf(fitparam, tracks): ---------- fitparam : int the fit parameter key according to fitparameters defined in - KM3NeT-Dataformat. - tracks : km3io.offline.OfflineBranch - the tracks class. both full tracks branch or a slice of the - tracks branch (example tracks[:, 0]) work. + KM3NeT-Dataformat (see km3io.definitions.fitparameters). + tracks : ak.Array or km3io.rootio.Branch + reconstructed tracks with .fitinf attribute Returns ------- awkward1.Array - awkward array of the values of the fit parameter requested. + awkward array of the values of the fit parameter requested. Missing + values are set to NaN. """ fit = tracks.fitinf - index = fitparam - if tracks.is_single and len(tracks) != 1: - params = fit[count_nested(fit, axis=1) > index] - out = params[:, index] - - if tracks.is_single and len(tracks) == 1: - out = fit[index] - - else: - if len(tracks[0]) == 1: # case of tracks slice with 1 track per event. - params = fit[count_nested(fit, axis=1) > index] - out = params[:, index] - else: - params = fit[count_nested(fit, axis=2) > index] - out = ak.Array([i[:, index] for i in params]) - - return out + nonempty = ak.num(fit, axis=-1) > 0 + return ak.fill_none(fit.mask[nonempty][..., 0], np.nan) def count_nested(arr, axis=0): @@ -206,12 +192,9 @@ def get_multiplicity(tracks, rec_stages): awkward1.Array tracks multiplicty. 
""" - masked_tracks = tracks[mask(tracks, stages=rec_stages)] + masked_tracks = tracks[mask(tracks.rec_stages, sequence=rec_stages)] - if tracks.is_single: - out = count_nested(masked_tracks.rec_stages, axis=0) - else: - out = count_nested(masked_tracks.rec_stages, axis=1) + out = count_nested(masked_tracks.rec_stages, axis=tracks.ndim - 1) return out @@ -221,548 +204,223 @@ def best_track(tracks, startend=None, minmax=None, stages=None): Parameters ---------- - tracks : km3io.offline.OfflineBranch - Array of tracks or jagged array of tracks (multiple events). + tracks : awkward.Array + A list of tracks or doubly nested tracks, usually from + OfflineReader.events.tracks or subarrays of that, containing recunstructed + tracks. startend: tuple(int, int), optional - The required first and last stage in tracks.rec_stages. + The required first and last stage in tracks.rec_stages. minmax: tuple(int, int), optional - The range (minimum and maximum) value of rec_stages to take into account. + The range (minimum and maximum) value of rec_stages to take into account. stages : list or set, optional - - list: the order of the rec_stages is respected. - - set: a subset of required stages; the order is irrelevant. + - list: the order of the rec_stages is respected. + - set: a subset of required stages; the order is irrelevant. Returns ------- - km3io.offline.OfflineBranch - The best tracks based on the selection. + awkward.Array or namedtuple + Be aware that the dimensions are kept, which means that the final + track attributes are nested when multiple events are passed in. + If a single event (just a list of tracks) is provided, a named tuple + with a single track and flat attributes is created. Raises ------ ValueError - - too many inputs specified. - - no inputs are specified. + When invalid inputs are specified. 
""" inputs = (stages, startend, minmax) - if all(v is None for v in inputs): + if sum(v is not None for v in inputs) != 1: raise ValueError("either stages, startend or minmax must be specified.") - if stages is not None and (startend is not None or minmax is not None): - raise ValueError("Please specify either a range or a set of rec stages.") - - if stages is not None and startend is None and minmax is None: - selected_tracks = tracks[mask(tracks, stages=stages)] - - if startend is not None and minmax is None and stages is None: - selected_tracks = tracks[mask(tracks, startend=startend)] - - if minmax is not None and startend is None and stages is None: - selected_tracks = tracks[mask(tracks, minmax=minmax)] - - return _max_lik_track(_longest_tracks(selected_tracks)) - + if stages is not None: + if isinstance(stages, list): + m1 = mask(tracks.rec_stages, sequence=stages) + elif isinstance(stages, set): + m1 = mask(tracks.rec_stages, atleast=list(stages)) + else: + raise ValueError("stages must be a list or a set of integers") -def _longest_tracks(tracks): - """Select the longest reconstructed track""" - if tracks.is_single: - stages_nesting_level = 1 - tracks_nesting_level = 0 + if startend is not None: + m1 = mask(tracks.rec_stages, startend=startend) - else: - stages_nesting_level = 2 - tracks_nesting_level = 1 + if minmax is not None: + m1 = mask(tracks.rec_stages, minmax=minmax) - len_stages = count_nested(tracks.rec_stages, axis=stages_nesting_level) - longest = tracks[len_stages == ak.max(len_stages, axis=tracks_nesting_level)] + try: + original_ndim = tracks.ndim + except AttributeError: + original_ndim = 1 + axis = 1 if original_ndim == 2 else 0 - return longest + tracks = tracks[m1] + rec_stage_lengths = ak.num(tracks.rec_stages, axis=-1) + max_rec_stage_length = ak.max(rec_stage_lengths, axis=axis) + m2 = rec_stage_lengths == max_rec_stage_length + tracks = tracks[m2] -def _max_lik_track(tracks): - """Select the track with the highest likelihood """ - if 
tracks.is_single: - tracks_nesting_level = 0 - else: - tracks_nesting_level = 1 + m3 = ak.argmax(tracks.lik, axis=axis, keepdims=True) - return tracks[tracks.lik == ak.max(tracks.lik, axis=tracks_nesting_level)] + out = tracks[m3] + if original_ndim == 1: + if isinstance(out, ak.Record): + return out[:, 0] + return out[0] + return out[:, 0] -def mask(tracks, stages=None, startend=None, minmax=None): - """Create a mask for tracks.rec_stages. +def mask(arr, sequence=None, startend=None, minmax=None, atleast=None): + """Return a boolean mask which mask each nested sub-array for a condition. Parameters ---------- - tracks : km3io.offline.OfflineBranch - tracks, or one track, or slice of tracks, or slice of one track. - stages : list or set - reconstruction stages of interest: - - list: the order of rec_stages in respected. - - set: the order of rec_stages in irrelevant. + arr : awkward.Array with ndim>=2 + The array to mask. startend: tuple(int, int), optional - The required first and last stage in tracks.rec_stages. + True for entries where the first and last element are matching the tuple. minmax: tuple(int, int), optional - The range (minimum and maximum) value of rec_stages to take into account. - - Returns - ------- - awkward1.Array(bool) - an awkward1 Array mask where True corresponds to the positions - where stages were found. False otherwise. - - Raises - ------ - ValueError - - too many inputs specified. - - no inputs are specified. + True for entries where each element is within the min-max-range. + sequence : list(int), optional + True for entries which contain the exact same elements (in that specific + order) + atleast : list(int), optional + True for entries where at least the provided elements are present. + + An extensive discussion about this implementation can be found at: + https://github.com/scikit-hep/awkward-1.0/issues/580 + Many thanks for Jim for the fruitful discussion and the final implementation. 
""" - inputs = (stages, startend, minmax) - - if all(v is None for v in inputs): - raise ValueError("either stages, startend or minmax must be specified.") - - if stages is not None and (startend is not None or minmax is not None): - raise ValueError("Please specify either a range or a set of rec stages.") - - if stages is not None and startend is None and minmax is None: - if isinstance(stages, list): - # order of stages is conserved - return _mask_explicit_rec_stages(tracks, stages) - if isinstance(stages, set): - # order of stages is no longer conserved - return _mask_rec_stages_in_set(tracks, stages) - - if startend is not None and minmax is None and stages is None: - return _mask_rec_stages_between_start_end(tracks, *startend) - - if minmax is not None and startend is None and stages is None: - return _mask_rec_stages_in_range_min_max(tracks, *minmax) - - -def _mask_rec_stages_between_start_end(tracks, start, end): - """Mask tracks.rec_stages that start exactly with start and end exactly - with end. ie [start, a, b ...,z , end]""" - builder = ak.ArrayBuilder() - if tracks.is_single: - _find_between_single(tracks.rec_stages, start, end, builder) - return (builder.snapshot() == 1)[0] - else: - _find_between(tracks.rec_stages, start, end, builder) - return builder.snapshot() == 1 - - -@nb.jit(nopython=True) -def _find_between(rec_stages, start, end, builder): - """Find tracks.rec_stages where rec_stages[0] == start and rec_stages[-1] == end.""" - - for s in rec_stages: - builder.begin_list() - for i in s: - num_stages = len(i) - if num_stages != 0: - if (i[0] == start) and (i[-1] == end): - builder.append(1) - else: - builder.append(0) + inputs = (sequence, startend, minmax, atleast) + + if sum(v is not None for v in inputs) != 1: + raise ValueError( + "either sequence, startend, minmax or atleast must be specified." 
+ ) + + def recurse(layout): + if layout.purelist_depth == 2: + if startend is not None: + np_array = _mask_startend(ak.Array(layout), *startend) + elif minmax is not None: + np_array = _mask_minmax(ak.Array(layout), *minmax) + elif sequence is not None: + np_array = _mask_sequence(ak.Array(layout), np.array(sequence)) + elif atleast is not None: + np_array = _mask_atleast(ak.Array(layout), np.array(atleast)) + + return ak.layout.NumpyArray(np_array) + + elif isinstance( + layout, + ( + ak.layout.ListArray32, + ak.layout.ListArrayU32, + ak.layout.ListArray64, + ), + ): + if len(layout.stops) == 0: + content = recurse(layout.content) else: - builder.append(0) - builder.end_list() + content = recurse(layout.content[: np.max(layout.stops)]) + return type(layout)(layout.starts, layout.stops, content) + + elif isinstance( + layout, + ( + ak.layout.ListOffsetArray32, + ak.layout.ListOffsetArrayU32, + ak.layout.ListOffsetArray64, + ), + ): + content = recurse(layout.content[: layout.offsets[-1]]) + return type(layout)(layout.offsets, content) + + elif isinstance(layout, ak.layout.RegularArray): + content = recurse(layout.content) + return ak.layout.RegularArray(content, layout.size) - -@nb.jit(nopython=True) -def _find_between_single(rec_stages, start, end, builder): - """Find tracks.rec_stages where rec_stages[0] == start and - rec_stages[-1] == end in a single track.""" - - builder.begin_list() - for s in rec_stages: - num_stages = len(s) - if num_stages != 0: - if (s[0] == start) and (s[-1] == end): - builder.append(1) - else: - builder.append(0) else: - builder.append(0) - builder.end_list() + raise NotImplementedError(repr(arr)) + layout = ak.to_layout(arr, allow_record=True, allow_other=False) + return ak.Array(recurse(layout)) -def _mask_explicit_rec_stages(tracks, stages): - """Mask explicit rec_stages . - - Parameters - ---------- - tracks : km3io.offline.OfflineBranch - tracks or one track, or slice of tracks. - stages : list - reconstruction stages of interest. 
@nb.njit
def _mask_startend(arr, start, end):
    """Build a boolean mask flagging subarrays with given first/last elements.

    Parameters
    ----------
    arr : awkward.Array
        2D array of subarrays (e.g. ``tracks.rec_stages``); iterated via
    start : int
        Value the first element of a matching subarray must have.
    end : int
        Value the last element of a matching subarray must have.

    Returns
    -------
    numpy.ndarray of bool
        One flag per subarray; empty subarrays never match.
    """
    out = np.empty(len(arr), np.bool_)
    for i, subarr in enumerate(arr):
        out[i] = len(subarr) > 0 and subarr[0] == start and subarr[-1] == end
    return out


@nb.njit
def _mask_minmax(arr, lower, upper):
    """Build a boolean mask flagging subarrays whose elements all lie in a range.

    Parameters
    ----------
    arr : awkward.Array
        2D array of subarrays (e.g. ``tracks.rec_stages``).
    lower : int
        Inclusive lower bound.
    upper : int
        Inclusive upper bound.

    Returns
    -------
    numpy.ndarray of bool
        One flag per subarray; empty subarrays never match.

    Notes
    -----
    The parameters were renamed from ``min``/``max`` to avoid shadowing the
    builtins; all internal callers (``mask``) pass them positionally.
    """
    out = np.empty(len(arr), np.bool_)
    for i, subarr in enumerate(arr):
        if len(subarr) == 0:
            out[i] = False
        else:
            # assume a match until an out-of-range element is found
            out[i] = True
            for el in subarr:
                if el < lower or el > upper:
                    out[i] = False
                    break
    return out


@nb.njit
def _mask_sequence(arr, sequence):
    """Build a boolean mask flagging subarrays exactly equal to ``sequence``.

    Parameters
    ----------
    arr : awkward.Array
        2D array of subarrays (e.g. ``tracks.rec_stages``).
    sequence : numpy.ndarray
        The exact sequence (same length, same order) a subarray must equal.

    Returns
    -------
    numpy.ndarray of bool
        One flag per subarray.
    """
    out = np.empty(len(arr), np.bool_)
    n = len(sequence)
    for i, subarr in enumerate(arr):
        if len(subarr) != n:
            out[i] = False
        else:
            # assume a match until a mismatching element is found
            out[i] = True
            for j in range(n):
                if subarr[j] != sequence[j]:
                    out[i] = False
                    break
    return out


@nb.njit
def _mask_atleast(arr, atleast):
    """Build a boolean mask flagging subarrays containing all required elements.

    Parameters
    ----------
    arr : awkward.Array
        2D array of subarrays (e.g. ``tracks.rec_stages``).
    atleast : numpy.ndarray
        Elements that must all be present in a subarray (order irrelevant).
        An empty ``atleast`` matches every subarray.

    Returns
    -------
    numpy.ndarray of bool
        One flag per subarray.
    """
    out = np.empty(len(arr), np.bool_)
    for i, subarr in enumerate(arr):
        # assume a match until a required element is found to be missing
        out[i] = True
        for req_el in atleast:
            if req_el not in subarr:
                out[i] = False
                break
    return out


def best_jmuon(tracks):
    """Select the best JMUON track.

    Parameters
    ----------
    tracks : awkward.Array
        Tracks of one or more events.

    Returns
    -------
    awkward.Array
        The best track per event, reconstructed with JMUON
        (rec stages within [JMUONBEGIN, JMUONEND]).
    """
    return best_track(tracks, minmax=(krec.JMUONBEGIN, krec.JMUONEND))


def best_jshower(tracks):
    """Select the best JSHOWER track.

    Parameters
    ----------
    tracks : awkward.Array
        Tracks of one or more events.

    Returns
    -------
    awkward.Array
        The best track per event, reconstructed with JSHOWER
        (rec stages within [JSHOWERBEGIN, JSHOWEREND]).
    """
    return best_track(tracks, minmax=(krec.JSHOWERBEGIN, krec.JSHOWEREND))


def best_aashower(tracks):
    """Select the best AASHOWER track.

    Parameters
    ----------
    tracks : awkward.Array
        Tracks of one or more events.

    Returns
    -------
    awkward.Array
        The best track per event, reconstructed with AASHOWER
        (rec stages within [AASHOWERBEGIN, AASHOWEREND]).
    """
    return best_track(tracks, minmax=(krec.AASHOWERBEGIN, krec.AASHOWEREND))


def best_dusjshower(tracks):
    """Select the best DUSJSHOWER track.

    Parameters
    ----------
    tracks : awkward.Array
        Tracks of one or more events.

    Returns
    -------
    awkward.Array
        The best track per event, reconstructed with DUSJSHOWER
        (rec stages within [DUSJSHOWERBEGIN, DUSJSHOWEREND]).
    """
    return best_track(tracks, minmax=(krec.DUSJSHOWERBEGIN, krec.DUSJSHOWEREND))
+ # According to: https://wiki.km3net.de/index.php/Simulations/The_gSeaGen_code#Physics_event_entries + # the interaction types are defined as follow: - else: - if "gseagen" in program.lower(): + # INTER Interaction type + # 1 EM + # 2 Weak[CC] + # 3 Weak[NC] + # 4 Weak[CC+NC+interference] + # 5 NucleonDecay - # According to: https://wiki.km3net.de/index.php/Simulations/The_gSeaGen_code#Physics_event_entries - # the interaction types are defined as follow: - # INTER Interaction type - # 1 EM - # 2 Weak[CC] - # 3 Weak[NC] - # 4 Weak[CC+NC+interference] - # 5 NucleonDecay + if all(len_w2lists <= 7): # old nu file have w2list of len 7. + # Checking the `cc` value in usr of the first mc_tracks, + # which are the primary neutrinos and carry the event property. + # This has been changed in 2020 to be a property in the w2list. + # See https://git.km3net.de/common/km3net-dataformat/-/issues/23 + return usr(fobj.events.mc_tracks[:, 0], "cc") == 2 - cc_flag = w2list[:, kw2gsg.W2LIST_GSEAGEN_CC] - out = cc_flag > 0 # to be tested with a newly generated nu file. + else: + # TODO: to be tested with a newly generated files with th eupdated + # w2list definitionn. + if "gseagen" in program.lower(): + return w2list[:, kw2gen.W2LIST_GSEAGEN_CC] == 2 if "genhen" in program.lower(): - cc_flag = w2list[:, kw2gen.W2LIST_GENHEN_CC] - out = cc_flag > 0 + return w2list[:, kw2gen.W2LIST_GENHEN_CC] == 2 else: - raise ValueError(f"simulation program {program} is not implemented.") + raise NotImplementedError( + f"don't know how to determine the CC-ness of {program} files." + ) - return out + +def usr(objects, field): + """Return the usr-data for a given field. + + Parameters + ---------- + objects : awkward.Array + Events, tracks, hits or whatever objects which have usr and usr_names + fields (e.g. OfflineReader().events). 
+ """ + if len(unique(ak.num(objects.usr_names))) > 1: + # let's do it the hard way + return ak.flatten(objects.usr[objects.usr_names == field]) + available_fields = objects.usr_names[0].tolist() + idx = available_fields.index(field) + return objects.usr[:, idx] diff --git a/requirements/install.txt b/requirements/install.txt index 127b6746a48426ffc4ad6a3ab0db28511bff126b..a9844c1814f4066cfb94dc54eed51fffa75359fa 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -3,4 +3,5 @@ numba>=0.50 awkward>=1.0.0rc2 awkward0 uproot3>=3.11.1 +uproot>=4.0.0rc5 setuptools_scm diff --git a/tests/test_gseagen.py b/tests/test_gseagen.py index 4b55e89ff63426635e6b93b6d873d361882b4834..85a41a2baa0c73fe2bddc66ce7bcf738bec77469 100644 --- a/tests/test_gseagen.py +++ b/tests/test_gseagen.py @@ -1,6 +1,8 @@ import os import re import unittest +import inspect +import awkward as ak from km3net_testdata import data_path @@ -8,18 +10,22 @@ from km3io.gseagen import GSGReader GSG_READER = GSGReader(data_path("gseagen/gseagen.root")) +AWKWARD_STR_CLASSES = [ + s[1] for s in inspect.getmembers(ak.behaviors.string, inspect.isclass) +] + class TestGSGHeader(unittest.TestCase): def setUp(self): self.header = GSG_READER.header def test_str_byte_type(self): - assert isinstance(self.header["gSeaGenVer"], str) - assert isinstance(self.header["GenieVer"], str) - assert isinstance(self.header["gSeaGenVer"], str) - assert isinstance(self.header["InpXSecFile"], str) - assert isinstance(self.header["Flux1"], str) - assert isinstance(self.header["Flux2"], str) + assert type(self.header["gSeaGenVer"]) in AWKWARD_STR_CLASSES + assert type(self.header["GenieVer"]) in AWKWARD_STR_CLASSES + assert type(self.header["gSeaGenVer"]) in AWKWARD_STR_CLASSES + assert type(self.header["InpXSecFile"]) in AWKWARD_STR_CLASSES + assert type(self.header["Flux1"]) in AWKWARD_STR_CLASSES + assert type(self.header["Flux2"]) in AWKWARD_STR_CLASSES def test_values(self): assert self.header["RunNu"] == 1 @@ 
-55,11 +61,6 @@ class TestGSGHeader(unittest.TestCase): assert self.header["NNu"] == 2 self.assertListEqual(self.header["NuList"].tolist(), [-14, 14]) - def test_unsupported_header(self): - f = GSGReader(data_path("online/km3net_online.root")) - with self.assertWarns(UserWarning): - f.header - class TestGSGEvents(unittest.TestCase): def setUp(self): diff --git a/tests/test_offline.py b/tests/test_offline.py index b99cb8b12b9ee4166b666c50bc18471d9195f27b..84d0cb6338d70b75e5d6e5126c35ea3ad4ab052c 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -1,11 +1,14 @@ import unittest import numpy as np from pathlib import Path +import uuid +import awkward as ak from km3net_testdata import data_path from km3io import OfflineReader -from km3io.offline import _nested_mapper, Header +from km3io.offline import Header +from km3io.tools import usr OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root")) @@ -32,7 +35,7 @@ class TestOfflineReader(unittest.TestCase): assert self.n_events == len(self.r.events) def test_uuid(self): - assert self.r.uuid == "0001b192d888fcc711e9b4306cf09e86beef" + assert str(self.r.uuid) == "b192d888-fcc7-11e9-b430-6cf09e86beef" class TestHeader(unittest.TestCase): @@ -147,23 +150,21 @@ class TestOfflineEvents(unittest.TestCase): def test_len(self): assert self.n_events == len(self.events) - def test_attributes_available(self): - for key in self.events._keymap.keys(): - getattr(self.events, key) - def test_attributes(self): assert self.n_events == len(self.events.det_id) self.assertListEqual(self.det_id, list(self.events.det_id)) + print(self.n_hits) + print(self.events.hits) self.assertListEqual(self.n_hits, list(self.events.n_hits)) self.assertListEqual(self.n_tracks, list(self.events.n_tracks)) self.assertListEqual(self.t_sec, list(self.events.t_sec)) self.assertListEqual(self.t_ns, list(self.events.t_ns)) def test_keys(self): - assert 
np.allclose(self.n_hits, self.events["n_hits"]) - assert np.allclose(self.n_tracks, self.events["n_tracks"]) - assert np.allclose(self.t_sec, self.events["t_sec"]) - assert np.allclose(self.t_ns, self.events["t_ns"]) + assert np.allclose(self.n_hits, self.events["n_hits"].tolist()) + assert np.allclose(self.n_tracks, self.events["n_tracks"].tolist()) + assert np.allclose(self.t_sec, self.events["t_sec"].tolist()) + assert np.allclose(self.t_ns, self.events["t_ns"].tolist()) def test_slicing(self): s = slice(2, 8, 2) @@ -176,20 +177,28 @@ class TestOfflineEvents(unittest.TestCase): def test_slicing_consistency(self): for s in [slice(1, 3), slice(2, 7, 3)]: - assert np.allclose(self.events[s].n_hits, self.events.n_hits[s]) + assert np.allclose( + self.events[s].n_hits.tolist(), self.events.n_hits[s].tolist() + ) def test_index_consistency(self): for i in [0, 2, 5]: assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) def test_index_chaining(self): - assert np.allclose(self.events[3:5].n_hits, self.events.n_hits[3:5]) + assert np.allclose( + self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist() + ) assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0]) + + def test_index_chaining_on_nested_branches_aka_records(self): assert np.allclose( - self.events[3:5].hits[1].dom_id[4], self.events.hits[3:5][1][4].dom_id + self.events[3:5].hits[1].dom_id[4], + self.events.hits[3:5][1].dom_id[4], ) assert np.allclose( - self.events.hits[3:5][1][4].dom_id, self.events[3:5][1][4].hits.dom_id + self.events.hits[3:5][1].dom_id[4], + self.events[3:5][1].hits.dom_id[4], ) def test_fancy_indexing(self): @@ -207,8 +216,24 @@ class TestOfflineEvents(unittest.TestCase): assert 10 == i def test_iteration_2(self): - n_hits = [e.n_hits for e in self.events] - assert np.allclose(n_hits, self.events.n_hits) + n_hits = [len(e.hits.id) for e in self.events] + assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist()) + + def 
test_iteration_over_slices(self): + ids = [e.id for e in self.events[2:5]] + self.assertListEqual([3, 4, 5], ids) + + def test_iteration_over_slices_raises_when_stepsize_not_supported(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[2:8:2]] + + def test_iteration_over_slices_raises_when_single_item(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[0]] + + def test_iteration_over_slices_raises_when_multiple_slices(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[2:8][2:4]] def test_str(self): assert str(self.n_events) in str(self.events) @@ -274,16 +299,14 @@ class TestOfflineHits(unittest.TestCase): ], } - def test_attributes_available(self): - for key in self.hits._keymap.keys(): + def test_fields_work_as_keys_and_attributes(self): + for key in self.hits.fields: getattr(self.hits, key) + self.hits[key] def test_channel_ids(self): - self.assertTrue(all(c >= 0 for c in self.hits.channel_id.min())) - self.assertTrue(all(c < 31 for c in self.hits.channel_id.max())) - - def test_str(self): - assert str(self.n_hits) in str(self.hits) + self.assertTrue(all(c >= 0 for c in ak.min(self.hits.channel_id, axis=1))) + self.assertTrue(all(c < 31 for c in ak.max(self.hits.channel_id, axis=1))) def test_repr(self): assert str(self.n_hits) in repr(self.hits) @@ -292,7 +315,7 @@ class TestOfflineHits(unittest.TestCase): for idx, dom_id in self.dom_id.items(): self.assertListEqual(dom_id, list(self.hits.dom_id[idx][: len(dom_id)])) for idx, t in self.t.items(): - assert np.allclose(t, self.hits.t[idx][: len(t)]) + assert np.allclose(t, self.hits.t[idx][: len(t)].tolist()) def test_slicing(self): s = slice(2, 8, 2) @@ -306,28 +329,39 @@ class TestOfflineHits(unittest.TestCase): def test_slicing_consistency(self): for s in [slice(1, 3), slice(2, 7, 3)]: for idx in range(3): - assert np.allclose(self.hits.dom_id[idx][s], self.hits[idx].dom_id[s]) assert np.allclose( - 
OFFLINE_FILE.events[idx].hits.dom_id[s], self.hits.dom_id[idx][s] + self.hits.dom_id[idx][s].tolist(), self.hits[idx].dom_id[s].tolist() + ) + assert np.allclose( + OFFLINE_FILE.events[idx].hits.dom_id[s].tolist(), + self.hits.dom_id[idx][s].tolist(), ) def test_index_consistency(self): for idx, dom_ids in self.dom_id.items(): assert np.allclose( - self.hits[idx].dom_id[: self.n_hits], dom_ids[: self.n_hits] + self.hits[idx].dom_id[: self.n_hits].tolist(), dom_ids[: self.n_hits] ) assert np.allclose( - OFFLINE_FILE.events[idx].hits.dom_id[: self.n_hits], + OFFLINE_FILE.events[idx].hits.dom_id[: self.n_hits].tolist(), dom_ids[: self.n_hits], ) for idx, ts in self.t.items(): - assert np.allclose(self.hits[idx].t[: self.n_hits], ts[: self.n_hits]) assert np.allclose( - OFFLINE_FILE.events[idx].hits.t[: self.n_hits], ts[: self.n_hits] + self.hits[idx].t[: self.n_hits].tolist(), ts[: self.n_hits] + ) + assert np.allclose( + OFFLINE_FILE.events[idx].hits.t[: self.n_hits].tolist(), + ts[: self.n_hits], ) - def test_keys(self): - assert "dom_id" in self.hits.keys() + def test_fields(self): + assert "dom_id" in self.hits.fields + assert "channel_id" in self.hits.fields + assert "t" in self.hits.fields + assert "tot" in self.hits.fields + assert "trig" in self.hits.fields + assert "id" in self.hits.fields class TestOfflineTracks(unittest.TestCase): @@ -337,16 +371,24 @@ class TestOfflineTracks(unittest.TestCase): self.tracks_numucc = OFFLINE_NUMUCC self.n_events = 10 - def test_attributes_available(self): - for key in self.tracks._keymap.keys(): - getattr(self.tracks, key) - - @unittest.skip - def test_attributes(self): - for idx, dom_id in self.dom_id.items(): - self.assertListEqual(dom_id, list(self.hits.dom_id[idx][: len(dom_id)])) - for idx, t in self.t.items(): - assert np.allclose(t, self.hits.t[idx][: len(t)]) + def test_fields(self): + for field in [ + "id", + "pos_x", + "pos_y", + "pos_z", + "dir_x", + "dir_y", + "dir_z", + "t", + "E", + "len", + "lik", + 
"rec_type", + "rec_stages", + "fitinf", + ]: + getattr(self.tracks, field) def test_item_selection(self): self.assertListEqual( @@ -354,23 +396,23 @@ class TestOfflineTracks(unittest.TestCase): ) def test_repr(self): - assert " 10 " in repr(self.tracks) + assert "10" in repr(self.tracks) def test_slicing(self): tracks = self.tracks self.assertEqual(10, len(tracks)) # 10 events - self.assertEqual(56, len(tracks[0])) # number of tracks in first event + self.assertEqual(56, len(tracks[0].id)) # number of tracks in first event track_selection = tracks[2:7] assert 5 == len(track_selection) track_selection_2 = tracks[1:3] assert 2 == len(track_selection_2) for _slice in [ - slice(0, 0), slice(0, 1), slice(0, 2), slice(1, 5), slice(3, -2), ]: + print(f"checking {_slice}") self.assertListEqual( list(tracks.E[:, 0][_slice]), list(tracks[_slice].E[:, 0]) ) @@ -382,15 +424,7 @@ class TestOfflineTracks(unittest.TestCase): ) self.assertAlmostEqual( self.f.events.tracks.fitinf[3:5][1][9][2], - self.f.events[3:5][1][9][2].tracks.fitinf, - ) - self.assertAlmostEqual( - self.f.events.tracks.fitinf[3:5][1][9][2], - self.f.events[3:5][1].tracks[9][2].fitinf, - ) - self.assertAlmostEqual( - self.f.events.tracks.fitinf[3:5][1][9][2], - self.f.events[3:5][1].tracks[9].fitinf[2], + self.f.events[3:5][1].tracks.fitinf[9][2], ) @@ -398,18 +432,28 @@ class TestBranchIndexingMagic(unittest.TestCase): def setUp(self): self.events = OFFLINE_FILE.events - def test_foo(self): + def test_slicing_magic(self): self.assertEqual(318, self.events[2:4].n_hits[0]) assert np.allclose( self.events[3].tracks.dir_z[10], self.events.tracks.dir_z[3, 10] ) assert np.allclose( - self.events[3:6].tracks.pos_y[:, 0], self.events.tracks.pos_y[3:6, 0] + self.events[3:6].tracks.pos_y[:, 0].tolist(), + self.events.tracks.pos_y[3:6, 0].tolist(), ) + def test_selecting_specific_items_via_a_list(self): # test selecting with a list self.assertEqual(3, len(self.events[[0, 2, 3]])) + def 
test_selecting_specific_items_via_a_numpy_array(self): + # test selecting with a list + self.assertEqual(3, len(self.events[np.array([0, 2, 3])])) + + def test_selecting_specific_items_via_a_awkward_array(self): + # test selecting with a list + self.assertEqual(3, len(self.events[ak.Array([0, 2, 3])])) + class TestUsr(unittest.TestCase): def setUp(self): @@ -439,27 +483,7 @@ class TestUsr(unittest.TestCase): "NGeometryVetoHits", "ClassficationScore", ], - self.f.events.usr.keys(), - ) - - def test_getitem_flat(self): - assert np.allclose( - [118.6302815337638, 44.33580521344907, 99.93916717621543], - self.f.events.usr["CoC"], - ) - assert np.allclose( - [37.51967774166617, -10.280346193553832, 13.67595659707355], - self.f.events.usr["DeltaPosZ"], - ) - - def test_attributes_flat(self): - assert np.allclose( - [118.6302815337638, 44.33580521344907, 99.93916717621543], - self.f.events.usr.CoC, - ) - assert np.allclose( - [37.51967774166617, -10.280346193553832, 13.67595659707355], - self.f.events.usr.DeltaPosZ, + self.f.events.usr_names[0].tolist(), ) @@ -471,11 +495,11 @@ class TestMcTrackUsr(unittest.TestCase): n_tracks = len(self.f.events) for i in range(3): self.assertListEqual( - [b"bx", b"by", b"ichan", b"cc"], + ["bx", "by", "ichan", "cc"], self.f.events.mc_tracks.usr_names[i][0].tolist(), ) self.assertListEqual( - [b"energy_lost_in_can"], + ["energy_lost_in_can"], self.f.events.mc_tracks.usr_names[i][1].tolist(), ) @@ -488,8 +512,3 @@ class TestMcTrackUsr(unittest.TestCase): assert np.allclose( [0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001 ) - - -class TestNestedMapper(unittest.TestCase): - def test_nested_mapper(self): - self.assertEqual("pos_x", _nested_mapper("trks.pos.x")) diff --git a/tests/test_tools.py b/tests/test_tools.py index d88a03523b122ce24f1e1302eeaf245c25b3dea7..6610a4b28264973af36ecdf09f8f757ec9c450f0 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -18,7 +18,6 @@ from km3io.tools import ( 
uniquecount, fitinf, count_nested, - _find, mask, best_track, get_w2list_param, @@ -28,9 +27,16 @@ from km3io.tools import ( best_aashower, best_dusjshower, is_cc, + usr, ) OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) +OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root")) +OFFLINE_MC_TRACK_USR = OfflineReader( + data_path( + "offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root" + ) +) GENHEN_OFFLINE_FILE = OfflineReader( data_path("offline/mcv5.1.genhen_anumuNC.sirene.jte.jchain.aashower.sample.root") ) @@ -58,18 +64,12 @@ class TestFitinf(unittest.TestCase): assert beta[1] == self.best_fit[1][0] assert beta[2] == self.best_fit[2][0] - def test_fitinf_from_one_event_and_one_track(self): - beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks[0][0]) - - assert beta == self.tracks[0][0].fitinf[0] - class TestBestTrackSelection(unittest.TestCase): def setUp(self): self.events = OFFLINE_FILE.events self.one_event = OFFLINE_FILE.events[0] - @unittest.skip def test_best_track_selection_from_multiple_events_with_explicit_stages_in_list( self, ): @@ -102,58 +102,46 @@ class TestBestTrackSelection(unittest.TestCase): assert best3.rec_stages[2] is None assert best3.rec_stages[3] is None - @unittest.skip - def test_best_track_selection_from_multiple_events_with_explicit_stages_in_set( + def test_best_track_selection_from_multiple_events_with_a_set_of_stages( self, ): best = best_track(self.events.tracks, stages={1, 3, 4, 5}) assert len(best) == 10 - assert best.rec_stages[0].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[1].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[2].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[3].tolist() == [[1, 3, 5, 4]] + assert best.rec_stages[0].tolist() == [1, 3, 5, 4] + assert best.rec_stages[1].tolist() == [1, 3, 5, 4] + assert best.rec_stages[2].tolist() == [1, 3, 5, 4] + assert best.rec_stages[3].tolist() == [1, 3, 5, 4] # test with a shorter set 
of rec_stages best2 = best_track(self.events.tracks, stages={1, 3}) assert len(best2) == 10 - assert best2.rec_stages[0].tolist() == [[1, 3]] - assert best2.rec_stages[1].tolist() == [[1, 3]] - assert best2.rec_stages[2].tolist() == [[1, 3]] - assert best2.rec_stages[3].tolist() == [[1, 3]] - - # test the irrelevance of order in rec_stages in sets - best3 = best_track(self.events.tracks, stages={3, 1}) - - assert len(best3) == 10 - - assert best3.rec_stages[0].tolist() == [[1, 3]] - assert best3.rec_stages[1].tolist() == [[1, 3]] - assert best3.rec_stages[2].tolist() == [[1, 3]] - assert best3.rec_stages[3].tolist() == [[1, 3]] + for rec_stages in best2.rec_stages: + for stage in {1, 3}: + assert stage in rec_stages def test_best_track_selection_from_multiple_events_with_start_end(self): best = best_track(self.events.tracks, startend=(1, 4)) assert len(best) == 10 - assert best.rec_stages[0].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[1].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[2].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[3].tolist() == [[1, 3, 5, 4]] + assert best.rec_stages[0].tolist() == [1, 3, 5, 4] + assert best.rec_stages[1].tolist() == [1, 3, 5, 4] + assert best.rec_stages[2].tolist() == [1, 3, 5, 4] + assert best.rec_stages[3].tolist() == [1, 3, 5, 4] # test with shorter stages best2 = best_track(self.events.tracks, startend=(1, 3)) assert len(best2) == 10 - assert best2.rec_stages[0].tolist() == [[1, 3]] - assert best2.rec_stages[1].tolist() == [[1, 3]] - assert best2.rec_stages[2].tolist() == [[1, 3]] - assert best2.rec_stages[3].tolist() == [[1, 3]] + assert best2.rec_stages[0].tolist() == [1, 3] + assert best2.rec_stages[1].tolist() == [1, 3] + assert best2.rec_stages[2].tolist() == [1, 3] + assert best2.rec_stages[3].tolist() == [1, 3] # test the importance of start as a real start of rec_stages best3 = best_track(self.events.tracks, startend=(0, 3)) @@ -179,23 +167,20 @@ class TestBestTrackSelection(unittest.TestCase): # 
stages as a list best = best_track(self.one_event.tracks, stages=[1, 3, 5, 4]) - assert len(best) == 1 assert best.lik == ak.max(self.one_event.tracks.lik) - assert np.allclose(best.rec_stages[0].tolist(), [1, 3, 5, 4]) + assert np.allclose(best.rec_stages.tolist(), [1, 3, 5, 4]) # stages as a set best2 = best_track(self.one_event.tracks, stages={1, 3, 4, 5}) - assert len(best2) == 1 assert best2.lik == ak.max(self.one_event.tracks.lik) - assert np.allclose(best2.rec_stages[0].tolist(), [1, 3, 5, 4]) + assert np.allclose(best2.rec_stages.tolist(), [1, 3, 5, 4]) # stages with start and end best3 = best_track(self.one_event.tracks, startend=(1, 4)) - assert len(best3) == 1 assert best3.lik == ak.max(self.one_event.tracks.lik) - assert np.allclose(best3.rec_stages[0].tolist(), [1, 3, 5, 4]) + assert np.allclose(best3.rec_stages.tolist(), [1, 3, 5, 4]) def test_best_track_on_slices_one_event(self): tracks_slice = self.one_event.tracks[self.one_event.tracks.rec_type == 4000] @@ -203,63 +188,90 @@ class TestBestTrackSelection(unittest.TestCase): # test stages with list best = best_track(tracks_slice, stages=[1, 3, 5, 4]) - assert len(best) == 1 - assert best.lik == ak.max(tracks_slice.lik) - assert best.rec_stages[0].tolist() == [1, 3, 5, 4] + assert best.rec_stages.tolist() == [1, 3, 5, 4] # test stages with set best2 = best_track(tracks_slice, stages={1, 3, 4, 5}) - assert len(best2) == 1 - assert best2.lik == ak.max(tracks_slice.lik) - assert best2.rec_stages[0].tolist() == [1, 3, 5, 4] + assert best2.rec_stages.tolist() == [1, 3, 5, 4] def test_best_track_on_slices_with_start_end_one_event(self): tracks_slice = self.one_event.tracks[0:5] best = best_track(tracks_slice, startend=(1, 4)) - assert len(best) == 1 assert best.lik == ak.max(tracks_slice.lik) - assert best.rec_stages[0][0] == 1 - assert best.rec_stages[0][-1] == 4 + assert best.rec_stages[0] == 1 + assert best.rec_stages[-1] == 4 def test_best_track_on_slices_with_explicit_rec_stages_one_event(self): 
tracks_slice = self.one_event.tracks[0:5] best = best_track(tracks_slice, stages=[1, 3, 5, 4]) assert best.lik == ak.max(tracks_slice.lik) - assert best.rec_stages[0][0] == 1 - assert best.rec_stages[0][-1] == 4 + assert best.rec_stages[0] == 1 + assert best.rec_stages[-1] == 4 - @unittest.skip def test_best_track_on_slices_multiple_events(self): - tracks_slice = self.events.tracks[0:5] + tracks_slice = self.events[0:5].tracks # stages in list best = best_track(tracks_slice, stages=[1, 3, 5, 4]) assert len(best) == 5 - assert best.lik == ak.max(tracks_slice.lik) - assert best.rec_stages[0].tolist() == [1, 3, 5, 4] + assert np.allclose( + best.lik.tolist(), + [ + 294.6407542676734, + 96.75133289411137, + 560.2775306614813, + 278.2872951665753, + 99.59098153341449, + ], + ) + for i in range(len(best)): + assert best.rec_stages[i].tolist() == [1, 3, 5, 4] # stages in set - best = best_track(tracks_slice, stages={1, 3, 4, 5}) + best = best_track(tracks_slice, stages={3, 4, 5}) assert len(best) == 5 - assert best.lik == ak.max(tracks_slice.lik) - assert best.rec_stages[0].tolist() == [1, 3, 5, 4] + assert np.allclose( + best.lik.tolist(), + [ + 294.6407542676734, + 96.75133289411137, + 560.2775306614813, + 278.2872951665753, + 99.59098153341449, + ], + ) + for i in range(len(best)): + assert best.rec_stages[i].tolist() == [1, 3, 5, 4] # using start and end - best = best_track(tracks_slice, startend=(1, 4)) + start, end = (1, 4) + best = best_track(tracks_slice, startend=(start, end)) assert len(best) == 5 - assert best.lik == ak.max(tracks_slice.lik) - assert best.rec_stages[0].tolist() == [1, 3, 5, 4] + assert np.allclose( + best.lik.tolist(), + [ + 294.6407542676734, + 96.75133289411137, + 560.2775306614813, + 278.2872951665753, + 99.59098153341449, + ], + ) + for i in range(len(best)): + rs = best.rec_stages[i].tolist() + assert rs[0] == start + assert rs[-1] == end def test_best_track_raises_when_unknown_stages(self): with self.assertRaises(ValueError): @@ -276,10 
+288,10 @@ class TestBestJmuon(unittest.TestCase): assert len(best) == 10 - assert best.rec_stages[0].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[1].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[2].tolist() == [[1, 3, 5, 4]] - assert best.rec_stages[3].tolist() == [[1, 3, 5, 4]] + assert best.rec_stages[0].tolist() == [1, 3, 5, 4] + assert best.rec_stages[1].tolist() == [1, 3, 5, 4] + assert best.rec_stages[2].tolist() == [1, 3, 5, 4] + assert best.rec_stages[3].tolist() == [1, 3, 5, 4] assert best.lik[0] == ak.max(OFFLINE_FILE.events.tracks.lik[0]) assert best.lik[1] == ak.max(OFFLINE_FILE.events.tracks.lik[1]) @@ -378,28 +390,18 @@ class TestRecStagesMasks(unittest.TestCase): self.tracks = OFFLINE_FILE.events.tracks - def test_find(self): - builder = ak.ArrayBuilder() - _find(self.nested, ak.Array([1, 2, 3]), builder) - labels = builder.snapshot() - - assert labels[0][0] == 1 - assert labels[0][1] == 1 - assert labels[0][2] == 0 - assert labels[1][0] == 0 - def test_mask_with_explicit_rec_stages_in_list_with_multiple_events(self): rec_stages = self.tracks.rec_stages stages = [1, 3, 5, 4] - masks = mask(self.tracks, stages=stages) + masks = mask(self.tracks.rec_stages, sequence=stages) assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages)) assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages)) assert masks[0][1] == False - def test_mask_with_explicit_rec_stages_in_set_with_multiple_events(self): - stages = {1, 3, 4, 5} - masks = mask(self.tracks, stages=stages) + def test_mask_with_atleast_on_multiple_events(self): + stages = [1, 3, 4, 5] + masks = mask(self.tracks.rec_stages, atleast=stages) tracks = self.tracks[masks] assert 1 in tracks.rec_stages[0][0] @@ -410,7 +412,7 @@ class TestRecStagesMasks(unittest.TestCase): def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self): rec_stages = self.tracks.rec_stages stages = [1, 3, 5, 4] - masks = mask(self.tracks, startend=(1, 4)) + masks = mask(self.tracks.rec_stages, 
startend=(1, 4)) assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages)) assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages)) @@ -420,7 +422,7 @@ class TestRecStagesMasks(unittest.TestCase): rec_stages = self.tracks.rec_stages[0][0] stages = [1, 3, 5, 4] track = self.tracks[0] - masks = mask(track, startend=(1, 4)) + masks = mask(track.rec_stages, startend=(1, 4)) assert track[masks].rec_stages[0][0] == 1 assert track[masks].rec_stages[0][-1] == 4 @@ -429,20 +431,37 @@ class TestRecStagesMasks(unittest.TestCase): rec_stages = self.tracks.rec_stages[0][0] stages = [1, 3] track = self.tracks[0] - masks = mask(track, stages=stages) + masks = mask(track.rec_stages, sequence=stages) assert track[masks].rec_stages[0][0] == stages[0] assert track[masks].rec_stages[0][1] == stages[1] - def test_mask_raises_when_too_many_inputs(self): - with self.assertRaises(ValueError): - mask(self.tracks, startend=(1, 4), stages=[1, 3, 5, 4]) - def test_mask_raises_when_no_inputs(self): with self.assertRaises(ValueError): mask(self.tracks) +class TestMask(unittest.TestCase): + def test_minmax_2dim_mask(self): + arr = ak.Array([[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]]) + m = mask(arr, minmax=(1, 4)) + self.assertListEqual(m.tolist(), [True, False, False]) + + def test_minmax_3dim_mask(self): + arr = ak.Array([[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]]) + m = mask(arr, minmax=(1, 4)) + self.assertListEqual(m.tolist(), [[True, False, False], [True]]) + + def test_minmax_4dim_mask(self): + arr = ak.Array( + [[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]] + ) + m = mask(arr, minmax=(1, 4)) + self.assertListEqual( + m.tolist(), [[[True, False, False], [True]], [[False, True]]] + ) + + class TestUnique(unittest.TestCase): def run_random_test_with_dtype(self, dtype): max_range = 100 @@ -543,3 +562,22 @@ class TestIsCC(unittest.TestCase): all(NC_file) == True ) # this test fails because the CC flags are not reliable in old files 
self.assertTrue(all(CC_file) == True) + + +class TestUsr(unittest.TestCase): + def test_event_usr(self): + assert np.allclose( + [118.6302815337638, 44.33580521344907, 99.93916717621543], + usr(OFFLINE_USR.events, "CoC").tolist(), + ) + assert np.allclose( + [37.51967774166617, -10.280346193553832, 13.67595659707355], + usr(OFFLINE_USR.events, "DeltaPosZ").tolist(), + ) + + def test_mc_tracks_usr(self): + assert np.allclose( + [0.0487], + usr(OFFLINE_MC_TRACK_USR.mc_tracks[0], "bx").tolist(), + atol=0.0001, + )