Skip to content
Snippets Groups Projects

Api update

Merged Zineb Aly requested to merge api-update into master
Compare and Show latest version
2 files
+ 231
103
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 136
86
import uproot
BASKET_CACHE_SIZE = 23 * 1024**2 # [byte] DON T FORGET TO INCLUDE THIS FOR LARGE DATA SET!!!
class AanetKeys:
"wrapper for aanet keys"
@@ -12,8 +14,11 @@ class AanetKeys:
self._mc_hits_keys = None
self._mc_tracks_keys = None
self._valid_keys = None
self._fit_keys = None
self._cut_hits_keys = None
self._cut_tracks_keys = None
def __repr__(self):
def __str__(self):
return '\n'.join([
"Events keys are:\n\t" + '\n\t'.join(self.events_keys),
"Hits keys are:\n\t" + '\n\t'.join(self.hits_keys),
@@ -22,8 +27,8 @@ class AanetKeys:
"Mc tracks keys are:\n\t" + '\n\t'.join(self.mc_tracks_keys)
])
def __str__(self):
return repr(self)
def __repr__(self):
return f'{self.__class__.__name__}("{self._file_path}")'
@property
def events_keys(self):
@@ -47,8 +52,10 @@ class AanetKeys:
@property
def tracks_keys(self):
if self._tracks_keys is None:
fake_branches = ['trks.usr_data',
'trks.usr_names'] # uproot can't read these
fake_branches = [
'trks.usr_data', # a solution can be tree['trks.usr_data'].array(uproot.asdtype(">i4"))
'trks.usr_names'
] # can be accessed using tree['trks.usr_names'].array()
tree = uproot.open(self._file_path)['E']['Evt']['trks']
self._tracks_keys = [
key.decode('utf8') for key in tree.keys()
@@ -66,8 +73,8 @@ class AanetKeys:
@property
def mc_tracks_keys(self):
if self._mc_tracks_keys is None:
fake_branches = ['mc_trks.usr_data',
'mc_trks.usr_names'] # uproot can't read these
fake_branches = ['mc_trks.usr_data', 'mc_trks.usr_names'
] # same solution as above can be used
tree = uproot.open(self._file_path)['E']['Evt']['mc_trks']
self._mc_tracks_keys = [
key.decode('utf8') for key in tree.keys()
@@ -89,8 +96,38 @@ class AanetKeys:
self.mc_hits_keys)
return self._valid_keys
@property
def fit_keys(self):
if self._fit_keys is None:
self._fit_keys = [
'JGANDALF_BETA0_RAD', 'JGANDALF_BETA1_RAD', 'JGANDALF_CHI2',
'JGANDALF_NUMBER_OF_HITS', 'JENERGY_ENERGY', 'JENERGY_CHI2',
'JGANDALF_LAMBDA', 'JGANDALF_NUMBER_OF_ITERATIONS',
'JSTART_NPE_MIP', 'JSTART_NPE_MIP_TOTAL',
'JSTART_LENGTH_METRES', 'JVETO_NPE', 'JVETO_NUMBER_OF_HITS',
'JENERGY_MUON_RANGE_METRES', 'JENERGY_NOISE_LIKELIHOOD',
'JENERGY_NDF', 'JENERGY_NUMBER_OF_HITS', 'JCOPY_Z_M'
]
return self._fit_keys
@property
def cut_hits_keys(self):
if self._cut_hits_keys is None:
self._cut_hits_keys = [
k.split('hits.')[1].replace('.', '_') for k in self.hits_keys
]
return self._cut_hits_keys
@property
def cut_tracks_keys(self):
if self._cut_tracks_keys is None:
self._cut_tracks_keys = [
k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys
]
return self._cut_tracks_keys
class Reader(AanetKeys):
class Reader:
"""Reader for one Aanet ROOT file"""
def __init__(self, file_path):
""" AanetReader class is a Aanet ROOT file wrapper
@@ -101,8 +138,9 @@ class Reader(AanetKeys):
Path to the file of interest. It can be a str or any python
path-like object that points to the file of ineterst.
"""
super().__init__(file_path)
self._file_path = file_path
self._data = uproot.open(self._file_path)['E'].lazyarrays()
self._keys = None
def __getitem__(self, key):
"""reads data stored in the branch of interest in an event tree.
@@ -124,18 +162,24 @@ class Reader(AanetKeys):
Raises
------
KeyEroor
KeyError
Some branches in an Aanet file structure are "fake branches" and do
not contain data. Therefore, the keys corresponding to these fake
branches are not read.
"""
if key not in self.valid_keys and not isinstance(key, int):
if key not in self.keys.valid_keys and not isinstance(key, int):
raise KeyError(
"'{}' is not a valid key or is a fake branch.".format(key))
return self._data[key]
@property
def keys(self):
if self._keys is None:
self._keys = AanetKeys(self._file_path)
return self._keys
class AanetReader(AanetKeys):
class AanetReader:
def __init__(self, file_path, data=None):
""" AanetReader class is a Aanet ROOT file wrapper
@@ -145,7 +189,7 @@ class AanetReader(AanetKeys):
Path to the file of interest. It can be a str or any python
path-like object that points to the file of ineterst.
"""
super().__init__(file_path)
self._file_path = file_path
if data is not None:
self._data = data
else:
@@ -155,52 +199,59 @@ class AanetReader(AanetKeys):
self._tracks = None
self._mc_hits = None
self._mc_tracks = None
self._keys = None
# def __getitem__(self, item):
# return AanetEvents(self._events_keys, [self._lazy_data[key] for key in self.events])
# return AanetEvents(self._events_keys, [self._data[key] for key in self.events])
def __getitem__(self, item):
return AanetReader(file_path=self._file_path,
data=self._data[item])
return AanetReader(file_path=self._file_path, data=self._data[item])
@property
def keys(self):
if self._keys is None:
self._keys = AanetKeys(self._file_path)
return self._keys
@property
def events(self):
if self._events is None:
self._events = AanetEvents(
self.events_keys,
[self._data[key] for key in self.events_keys])
self.keys.events_keys,
[self._data[key] for key in self.keys.events_keys])
return self._events
@property
def hits(self):
if self._hits is None:
self._hits = AanetHits(
self.hits_keys,
[self._data[key] for key in self.hits_keys])
self.keys.cut_hits_keys,
[self._data[key] for key in self.keys.hits_keys])
return self._hits
@property
def tracks(self):
if self._tracks is None:
self._tracks = AanetTracks(
self.tracks_keys,
[self._data[key] for key in self.tracks_keys])
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.tracks_keys],
fit_keys=self.keys.fit_keys)
return self._tracks
@property
def mc_hits(self):
if self._mc_hits is None:
self._mc_hits = AanetHits(
self.mc_hits_keys,
[self._data[key] for key in self.mc_hits_keys])
self.keys.cut_hits_keys,
[self._data[key] for key in self.keys.mc_hits_keys])
return self._mc_hits
@property
def mc_tracks(self):
if self._mc_tracks is None:
self._mc_tracks = AanetTracks(
self.mc_tracks_keys,
[self._data[key] for key in self.mc_tracks_keys])
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.mc_tracks_keys])
return self._mc_tracks
@@ -224,6 +275,9 @@ class AanetEvents:
def __str__(self):
return "Number of events: {}".format(len(self))
# def __repr__(self):
# return f'{self.__class__.__name__}({self._keys}, {self._values})'
def __repr__(self):
return str(self)
@@ -243,6 +297,9 @@ class AanetEvent:
for k, v in zip(self._keys, self._values)
])
# def __repr__(self):
# return f'{self.__class__.__name__}({self._keys}, {self._values})'
def __repr__(self):
return str(self)
@@ -254,7 +311,7 @@ class AanetHits:
self._keys = keys # list of keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k.split('hits.')[1].replace('.', '_'), v)
setattr(self, k, v)
def __getitem__(self, item):
# return self._values[item]
@@ -262,20 +319,21 @@ class AanetHits:
def __len__(self):
try:
return len(self._values[0])
return len(
self._values[0]
) # this is missleading and it sometimes prints the number of events ...
except IndexError:
return 0
def __str__(self):
# hits
if all(key.startswith('hits.') for key in self._keys):
return "Number of hits: {}".format(len(self))
# mc hits
if all(key.startswith('mc_hits.') for key in self._keys):
return "Number of MC hits: {}".format(
len(self))
return "Number of hits: {}".format(len(self))
# def __repr__(self):
# return f'{self.__class__.__name__}({self._keys}, {self._values})'
def __repr__(self):
return str(self)
return "<{}: {} parsed elements>".format(self.__class__.__name__,
len(self))
class AanetHit:
@@ -285,29 +343,21 @@ class AanetHit:
self._keys = keys # list of keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k.split('hits.')[1].replace('.', '_'), v)
setattr(self, k, v)
def __str__(self):
# hits
if all(key.startswith('hits.') for key in self._keys):
return "Aanet hit:\n\t" + "\n\t".join([
"{:15} {:^10} {:>10}".format(
k.split('hits.')[1].replace('.', '_'), ':', str(v))
for k, v in zip(self._keys, self._values)
])
# mc hits
if all(key.startswith('mc_hits.') for key in self._keys):
return "Aanet mc hit:\n\t" + "\n\t".join([
"{:15} {:^10} {:>10}".format(
k.split('mc_hits.')[1].replace('.', '_'), ':', str(v))
for k, v in zip(self._keys, self._values)
])
return "Aanet hit:\n\t" + "\n\t".join([
"{:15} {:^10} {:>10}".format(k, ':', str(v))
for k, v in zip(self._keys, self._values)
])
def __getitem__(self, item):
# return self._values[item]
return self._values[item]
# def __repr__(self):
# return f'{self.__class__.__name__}({self._keys}, {self._values})'
def __repr__(self):
return str(self)
@@ -315,64 +365,64 @@ class AanetHit:
class AanetTracks:
"wrapper for Aanet tracks, manages the display of all tracks in one event"
def __init__(self, keys, values): # values is a list of lists
def __init__(self, keys, values,
fit_keys=None): # values is a list of lists
self._keys = keys # list of keys
self._values = values
if fit_keys is not None:
self._fit_keys = fit_keys
for k, v in zip(self._keys, self._values):
setattr(self, k.split('trks.')[1].replace('.', '_'), v)
setattr(self, k, v)
def __getitem__(self, item):
# return self._values[item]
return AanetTrack(self._keys, [v[item] for v in self._values])
return AanetTrack(self._keys, [v[item] for v in self._values],
fit_keys=self._fit_keys)
def __len__(self):
return len(
self._values[0]
) # I don't like this being explicit, what if values is empty ...
try:
return len(self._values[0])
except IndexError:
return 0
def __str__(self):
# hits
if all(key.startswith('trks.') for key in self._keys):
return "Number of tracks in the selected event: {}".format(
len(self))
# mc hits
if all(key.startswith('mc_trks.') for key in self._keys):
return "Number of mc tracks in the selected event: {}".format(
len(self))
return "Number of tracks: {}".format(
len(self)) # this is not correct when reader.tracks is called
# def __repr__(self):
# return f'{self.__class__.__name__}({self._keys}, {self._values})'
def __repr__(self):
return str(self)
return "<{}: {} parsed elements>".format(self.__class__.__name__,
len(self))
class AanetTrack:
"wrapper for an Aanet track"
def __init__(self, keys, values): # both inputs are lists
def __init__(self, keys, values, fit_keys=None): # both inputs are lists
self._keys = keys # list of keys
self._values = values
if fit_keys is not None:
self._fit_keys = fit_keys
for k, v in zip(self._keys, self._values):
setattr(self, k.split('trks.')[1].replace('.', '_'), v)
setattr(self, k, v)
def __str__(self):
# hits
if all(key.startswith('trks.') for key in self._keys):
return "Aanet track:\n\t" + "\n\t".join([
"{:15} {:^10} {:>10}".format(
k.split('trks.')[1].replace('.', '_'), ':', str(v))
for k, v in zip(self._keys, self._values)
])
# mc hits
if all(key.startswith('mc_trks.') for key in self._keys):
return "Aanet mc track:\n\t" + "\n\t".join([
"{:15} {:^10} {:>10}".format(
k.split('trks.')[1].replace('.', '_'), ':', str(v))
for k, v in zip(self._keys, self._values)
])
return "Aanet track:\n\t" + "\n\t".join([
"{:30} {:^2} {:>26}".format(k, ':', str(v))
for k, v in zip(self._keys, self._values) if k not in ['fitinf']
]) + "\n\t" + "\n\t".join([
"{:30} {:^2} {:>26}".format(k, ':', str(v))
for k, v in zip(self._fit_keys, self._values[18]
) # I don't like 18 being explicit here
])
def __getitem__(self, item):
# return self._values[item]
return self._values[item]
# def __repr__(self):
# return f'{self.__class__.__name__}({self._keys}, {self._values})'
def __repr__(self):
return str(self)
\ No newline at end of file
return str(self)
Loading