Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
K
km3io
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
km3py
km3io
Commits
20fd4765
Commit
20fd4765
authored
4 years ago
by
Tamas Gal
Browse files
Options
Downloads
Patches
Plain Diff
Getting ready
parent
ebc24b5c
No related branches found
No related tags found
1 merge request
!39
WIP: Resolve "uproot4 integration"
Pipeline
#16161
failed
4 years ago
Stage: test
Stage: coverage
Stage: doc
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
km3io/offline.py
+112
-43
112 additions, 43 deletions
km3io/offline.py
tests/test_offline.py
+26
-28
26 additions, 28 deletions
tests/test_offline.py
with
138 additions
and
71 deletions
km3io/offline.py
+
112
−
43
View file @
20fd4765
from
collections
import
namedtuple
import
uproot4
as
uproot
import
warnings
import
uproot4
as
uproot
import
numpy
as
np
import
awkward1
as
ak
from
.definitions
import
mc_header
from
.tools
import
cached_property
,
to_num
from
.tools
import
cached_property
,
to_num
,
unfold_indices
class
OfflineReader
:
...
...
@@ -70,46 +72,69 @@ class OfflineReader:
"
mc_tracks
"
:
"
mc_trks
"
,
}
def
__init__
(
self
,
f
ile_path
,
step_size
=
2000
):
def
__init__
(
self
,
f
,
index_chain
=
None
,
step_size
=
2000
,
keys
=
None
,
aliases
=
None
,
event_ctor
=
None
):
"""
OfflineReader class is an offline ROOT file wrapper
Parameters
----------
file_path : path-like object
Path to the file of interest. It can be a str or any python
path-like object that points to the file.
f: str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
Path to the file of interest or uproot4 filedescriptor.
step_size: int, optional
Number of events to read into the cache when iterating.
Choosing higher numbers may improve the speed but also increases
the memory overhead.
index_chain: list, optional
Keeps track of index chaining.
keys: list or set, optional
Branch keys.
aliases: dict, optional
Branch key aliases.
event_ctor: class or namedtuple, optional
Event constructor.
"""
self
.
_fobj
=
uproot
.
open
(
file_path
)
self
.
step_size
=
step_size
self
.
_filename
=
file_path
if
isinstance
(
f
,
str
):
self
.
_fobj
=
uproot
.
open
(
f
)
self
.
_filepath
=
f
elif
isinstance
(
f
,
uproot
.
reading
.
ReadOnlyDirectory
):
self
.
_fobj
=
f
self
.
_filepath
=
f
.
_file
.
file_path
else
:
raise
TypeError
(
"
Unsupported file descriptor.
"
)
self
.
_step_size
=
step_size
self
.
_uuid
=
self
.
_fobj
.
_file
.
uuid
self
.
_iterator_index
=
0
self
.
_keys
=
None
self
.
_grouped_counts
=
{}
# TODO: e.g. {"events": [3, 66, 34]}
if
"
E/Evt/AAObject/usr
"
in
self
.
_fobj
:
if
ak
.
count
(
f
[
"
E/Evt/AAObject/usr
"
].
array
())
>
0
:
self
.
aliases
.
update
({
"
usr
"
:
"
AAObject/usr
"
,
"
usr_names
"
:
"
AAObject/usr_names
"
,
})
self
.
_initialise_keys
()
self
.
_event_ctor
=
namedtuple
(
self
.
item_name
,
set
(
list
(
self
.
keys
())
+
list
(
self
.
aliases
)
+
list
(
self
.
special_branches
)
+
list
(
self
.
special_aliases
)
),
)
self
.
_keys
=
keys
self
.
_event_ctor
=
event_ctor
self
.
_index_chain
=
[]
if
index_chain
is
None
else
index_chain
if
aliases
is
not
None
:
self
.
aliases
=
aliases
else
:
# Check for usr-awesomeness backward compatibility crap
print
(
"
Found usr data
"
)
if
"
E/Evt/AAObject/usr
"
in
self
.
_fobj
:
if
ak
.
count
(
f
[
"
E/Evt/AAObject/usr
"
].
array
())
>
0
:
self
.
aliases
.
update
(
{
"
usr
"
:
"
AAObject/usr
"
,
"
usr_names
"
:
"
AAObject/usr_names
"
,
}
)
if
self
.
_keys
is
None
:
self
.
_initialise_keys
()
if
self
.
_event_ctor
is
None
:
self
.
_event_ctor
=
namedtuple
(
self
.
item_name
,
set
(
list
(
self
.
keys
())
+
list
(
self
.
aliases
)
+
list
(
self
.
special_branches
)
+
list
(
self
.
special_aliases
)
),
)
def
_initialise_keys
(
self
):
skip_keys
=
set
(
self
.
skip_keys
)
...
...
@@ -144,9 +169,23 @@ class OfflineReader:
)
def
__getitem__
(
self
,
key
):
if
key
.
startswith
(
"
n_
"
):
# group counts, for e.g. n_events, n_hits etc.
# indexing
if
isinstance
(
key
,
(
slice
,
int
,
np
.
int32
,
np
.
int64
)):
if
not
isinstance
(
key
,
slice
):
key
=
int
(
key
)
return
self
.
__class__
(
self
.
_fobj
,
index_chain
=
self
.
_index_chain
+
[
key
],
step_size
=
self
.
_step_size
,
aliases
=
self
.
aliases
,
keys
=
self
.
keys
(),
event_ctor
=
self
.
_event_ctor
)
if
isinstance
(
key
,
str
)
and
key
.
startswith
(
"
n_
"
):
# group counts, for e.g. n_events, n_hits etc.
key
=
self
.
_keyfor
(
key
.
split
(
"
n_
"
)[
1
])
return
self
.
_fobj
[
self
.
event_path
][
key
].
array
(
uproot
.
AsDtype
(
"
>i4
"
))
arr
=
self
.
_fobj
[
self
.
event_path
][
key
].
array
(
uproot
.
AsDtype
(
"
>i4
"
))
return
unfold_indices
(
arr
,
self
.
_index_chain
)
key
=
self
.
_keyfor
(
key
)
branch
=
self
.
_fobj
[
self
.
event_path
]
...
...
@@ -154,10 +193,13 @@ class OfflineReader:
# We are explicitly grabbing just a predefined set of subbranches
# and also alias them to be backwards compatible (and attribute-accessible)
if
key
in
self
.
special_branches
:
return
branch
[
key
].
arrays
(
out
=
branch
[
key
].
arrays
(
self
.
special_branches
[
key
].
keys
(),
aliases
=
self
.
special_branches
[
key
]
)
return
branch
[
self
.
aliases
.
get
(
key
,
key
)].
array
()
else
:
out
=
branch
[
self
.
aliases
.
get
(
key
,
key
)].
array
()
return
unfold_indices
(
out
,
self
.
_index_chain
)
def
__iter__
(
self
):
self
.
_iterator_index
=
0
...
...
@@ -167,13 +209,18 @@ class OfflineReader:
def
_event_generator
(
self
):
events
=
self
.
_fobj
[
self
.
event_path
]
group_count_keys
=
set
(
k
for
k
in
self
.
keys
()
if
k
.
startswith
(
"
n_
"
))
keys
=
set
(
list
(
set
(
self
.
keys
())
-
set
(
self
.
special_branches
.
keys
())
-
set
(
self
.
special_aliases
)
-
group_count_keys
)
+
list
(
self
.
aliases
.
keys
()))
events_it
=
events
.
iterate
(
keys
,
aliases
=
self
.
aliases
,
step_size
=
self
.
step_size
)
keys
=
set
(
list
(
set
(
self
.
keys
())
-
set
(
self
.
special_branches
.
keys
())
-
set
(
self
.
special_aliases
)
-
group_count_keys
)
+
list
(
self
.
aliases
.
keys
())
)
events_it
=
events
.
iterate
(
keys
,
aliases
=
self
.
aliases
,
step_size
=
self
.
_step_size
)
specials
=
[]
special_keys
=
(
self
.
special_branches
.
keys
()
...
...
@@ -183,7 +230,7 @@ class OfflineReader:
events
[
key
].
iterate
(
self
.
special_branches
[
key
].
keys
(),
aliases
=
self
.
special_branches
[
key
],
step_size
=
self
.
step_size
,
step_size
=
self
.
_
step_size
,
)
)
group_counts
=
{}
...
...
@@ -206,7 +253,29 @@ class OfflineReader:
return
next
(
self
.
_events
)
def
__len__
(
self
):
return
self
.
_fobj
[
self
.
event_path
].
num_entries
if
not
self
.
_index_chain
:
return
self
.
_fobj
[
self
.
event_path
].
num_entries
elif
isinstance
(
self
.
_index_chain
[
-
1
],
(
int
,
np
.
int32
,
np
.
int64
)):
if
len
(
self
.
_index_chain
)
==
1
:
return
1
# try:
# return len(self[:])
# except IndexError:
# return 1
return
1
else
:
# ignore the usual index magic and access `id` directly
return
len
(
self
.
_fobj
[
self
.
event_path
][
"
id
"
].
array
(),
self
.
_index_chain
)
def
__actual_len__
(
self
):
"""
The raw number of events without any indexing/slicing magic
"""
return
len
(
self
.
_fobj
[
self
.
event_path
][
"
id
"
].
array
())
def
__repr__
(
self
):
length
=
len
(
self
)
actual_length
=
self
.
__actual_len__
()
return
f
"
{
self
.
__class__
.
__name__
}
(
{
length
}{
'
/
'
+
str
(
actual_length
)
if
length
<
actual_length
else
''
}
events)
"
@property
def
uuid
(
self
):
...
...
This diff is collapsed.
Click to expand it.
tests/test_offline.py
+
26
−
28
View file @
20fd4765
...
...
@@ -149,12 +149,6 @@ class TestOfflineEvents(unittest.TestCase):
def
test_len
(
self
):
assert
self
.
n_events
==
len
(
self
.
events
)
@unittest.skip
def
test_attributes_available
(
self
):
for
key
in
self
.
events
.
_keymap
.
keys
():
print
(
f
"
checking
{
key
}
"
)
getattr
(
self
.
events
,
key
)
def
test_attributes
(
self
):
assert
self
.
n_events
==
len
(
self
.
events
.
det_id
)
self
.
assertListEqual
(
self
.
det_id
,
list
(
self
.
events
.
det_id
))
...
...
@@ -165,7 +159,6 @@ class TestOfflineEvents(unittest.TestCase):
self
.
assertListEqual
(
self
.
t_sec
,
list
(
self
.
events
.
t_sec
))
self
.
assertListEqual
(
self
.
t_ns
,
list
(
self
.
events
.
t_ns
))
@unittest.skip
def
test_keys
(
self
):
assert
np
.
allclose
(
self
.
n_hits
,
self
.
events
[
"
n_hits
"
].
tolist
())
assert
np
.
allclose
(
self
.
n_tracks
,
self
.
events
[
"
n_tracks
"
].
tolist
())
...
...
@@ -182,38 +175,37 @@ class TestOfflineEvents(unittest.TestCase):
self
.
assertListEqual
(
self
.
t_sec
[
s
],
list
(
s_events
.
t_sec
))
self
.
assertListEqual
(
self
.
t_ns
[
s
],
list
(
s_events
.
t_ns
))
@unittest.skip
def
test_slicing_consistency
(
self
):
for
s
in
[
slice
(
1
,
3
),
slice
(
2
,
7
,
3
)]:
assert
np
.
allclose
(
self
.
events
[
s
].
n_hits
.
tolist
(),
self
.
events
.
n_hits
[
s
].
tolist
()
)
@unittest.skip
def
test_index_consistency
(
self
):
for
i
in
[
0
,
2
,
5
]:
assert
np
.
allclose
(
self
.
events
[
i
].
n_hits
.
tolist
()
,
self
.
events
.
n_hits
[
i
]
.
tolist
()
self
.
events
[
i
].
n_hits
,
self
.
events
.
n_hits
[
i
]
)
@unittest.skip
def
test_index_chaining
(
self
):
assert
np
.
allclose
(
self
.
events
[
3
:
5
].
n_hits
.
tolist
(),
self
.
events
.
n_hits
[
3
:
5
].
tolist
()
)
assert
np
.
allclose
(
self
.
events
[
3
:
5
][
0
].
n_hits
.
tolist
()
,
self
.
events
.
n_hits
[
3
:
5
][
0
]
.
tolist
()
self
.
events
[
3
:
5
][
0
].
n_hits
,
self
.
events
.
n_hits
[
3
:
5
][
0
]
)
@unittest.skip
def
test_index_chaining_on_nested_branches_aka_records
(
self
):
assert
np
.
allclose
(
self
.
events
[
3
:
5
].
hits
[
1
].
dom_id
[
4
]
.
tolist
()
,
self
.
events
.
hits
[
3
:
5
][
1
][
4
].
dom_id
.
tolist
()
,
self
.
events
[
3
:
5
].
hits
[
1
].
dom_id
[
4
],
self
.
events
.
hits
[
3
:
5
][
1
][
4
].
dom_id
,
)
assert
np
.
allclose
(
self
.
events
.
hits
[
3
:
5
][
1
][
4
].
dom_id
.
tolist
(),
self
.
events
[
3
:
5
][
1
][
4
].
hits
.
dom_id
.
tolist
(),
)
@unittest.skip
def
test_fancy_indexing
(
self
):
mask
=
self
.
events
.
n_tracks
>
55
tracks
=
self
.
events
.
tracks
[
mask
]
...
...
@@ -305,9 +297,6 @@ class TestOfflineHits(unittest.TestCase):
self
.
assertTrue
(
all
(
c
>=
0
for
c
in
ak
.
min
(
self
.
hits
.
channel_id
,
axis
=
1
)))
self
.
assertTrue
(
all
(
c
<
31
for
c
in
ak
.
max
(
self
.
hits
.
channel_id
,
axis
=
1
)))
def
test_str
(
self
):
assert
str
(
self
.
n_hits
)
in
str
(
self
.
hits
)
def
test_repr
(
self
):
assert
str
(
self
.
n_hits
)
in
repr
(
self
.
hits
)
...
...
@@ -344,19 +333,24 @@ class TestOfflineHits(unittest.TestCase):
)
assert
np
.
allclose
(
OFFLINE_FILE
.
events
[
idx
].
hits
.
dom_id
[:
self
.
n_hits
].
tolist
(),
dom_ids
[:
self
.
n_hits
]
.
tolist
()
,
dom_ids
[:
self
.
n_hits
],
)
for
idx
,
ts
in
self
.
t
.
items
():
assert
np
.
allclose
(
self
.
hits
[
idx
].
t
[:
self
.
n_hits
].
tolist
(),
ts
[:
self
.
n_hits
]
.
tolist
()
self
.
hits
[
idx
].
t
[:
self
.
n_hits
].
tolist
(),
ts
[:
self
.
n_hits
]
)
assert
np
.
allclose
(
OFFLINE_FILE
.
events
[
idx
].
hits
.
t
[:
self
.
n_hits
].
tolist
(),
ts
[:
self
.
n_hits
]
.
tolist
()
,
ts
[:
self
.
n_hits
],
)
def
test_keys
(
self
):
assert
"
dom_id
"
in
self
.
hits
.
keys
()
def
test_fields
(
self
):
assert
"
dom_id
"
in
self
.
hits
.
fields
assert
"
channel_id
"
in
self
.
hits
.
fields
assert
"
t
"
in
self
.
hits
.
fields
assert
"
tot
"
in
self
.
hits
.
fields
assert
"
trig
"
in
self
.
hits
.
fields
assert
"
id
"
in
self
.
hits
.
fields
class
TestOfflineTracks
(
unittest
.
TestCase
):
...
...
@@ -366,9 +360,9 @@ class TestOfflineTracks(unittest.TestCase):
self
.
tracks_numucc
=
OFFLINE_NUMUCC
self
.
n_events
=
10
def
test_
attributes_available
(
self
):
for
key
in
self
.
tracks
.
_keymap
.
keys
()
:
getattr
(
self
.
tracks
,
key
)
def
test_
fields
(
self
):
for
field
in
[
'
id
'
,
'
pos_x
'
,
'
pos_y
'
,
'
pos_z
'
,
'
dir_x
'
,
'
dir_y
'
,
'
dir_z
'
,
'
t
'
,
'
E
'
,
'
len
'
,
'
lik
'
,
'
rec_type
'
,
'
rec_stages
'
,
'
fitinf
'
]
:
getattr
(
self
.
tracks
,
field
)
@unittest.skip
def
test_attributes
(
self
):
...
...
@@ -383,8 +377,9 @@ class TestOfflineTracks(unittest.TestCase):
)
def
test_repr
(
self
):
assert
"
10
"
in
repr
(
self
.
tracks
)
assert
"
10
*
"
in
repr
(
self
.
tracks
)
@unittest.skip
def
test_slicing
(
self
):
tracks
=
self
.
tracks
self
.
assertEqual
(
10
,
len
(
tracks
))
# 10 events
...
...
@@ -404,6 +399,7 @@ class TestOfflineTracks(unittest.TestCase):
list
(
tracks
.
E
[:,
0
][
_slice
]),
list
(
tracks
[
_slice
].
E
[:,
0
])
)
@unittest.skip
def
test_nested_indexing
(
self
):
self
.
assertAlmostEqual
(
self
.
f
.
events
.
tracks
.
fitinf
[
3
:
5
][
1
][
9
][
2
],
...
...
@@ -427,7 +423,7 @@ class TestBranchIndexingMagic(unittest.TestCase):
def
setUp
(
self
):
self
.
events
=
OFFLINE_FILE
.
events
def
test_
foo
(
self
):
def
test_
slicing_magic
(
self
):
self
.
assertEqual
(
318
,
self
.
events
[
2
:
4
].
n_hits
[
0
])
assert
np
.
allclose
(
self
.
events
[
3
].
tracks
.
dir_z
[
10
],
self
.
events
.
tracks
.
dir_z
[
3
,
10
]
...
...
@@ -437,6 +433,8 @@ class TestBranchIndexingMagic(unittest.TestCase):
self
.
events
.
tracks
.
pos_y
[
3
:
6
,
0
].
tolist
(),
)
@unittest.skip
def
test_selecting_specific_items_via_a_list
(
self
):
# test selecting with a list
self
.
assertEqual
(
3
,
len
(
self
.
events
[[
0
,
2
,
3
]]))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment