From 0181fbf84716286e28e211aa680fb6f4036af60f Mon Sep 17 00:00:00 2001
From: Tamas Gal <himself@tamasgal.com>
Date: Fri, 13 Oct 2023 10:29:30 +0200
Subject: [PATCH] Improve and extend detector indexing behaviour

---
 docs/src/api.md  |  2 ++
 src/KM3io.jl     |  2 +-
 src/hardware.jl  | 32 +++++++++++++++++++++++++++++++-
 test/hardware.jl | 31 +++++++++++++++++++++++++++++--
 4 files changed, 63 insertions(+), 4 deletions(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index f8ff5099..b7165da8 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -49,6 +49,8 @@ flush
 PMT
 DetectorModule
 Detector
+getmodule
+modules
 write(::AbstractString, ::Detector)
 write(::IO, ::Detector)
 Hydrophone
diff --git a/src/KM3io.jl b/src/KM3io.jl
index 78eb0dae..7c4cfe72 100644
--- a/src/KM3io.jl
+++ b/src/KM3io.jl
@@ -21,7 +21,7 @@ export ROOTFile
 export H5File, H5CompoundDataset, create_dataset, addmeta
 
 export Direction, Position, UTMPosition, Location, Quaternion, Track, AbstractCalibratedHit
-export Detector, DetectorModule, PMT, Tripod, Hydrophone, center, isbasemodule
+export Detector, DetectorModule, PMT, Tripod, Hydrophone, center, isbasemodule, getmodule, modules
 
 # Acoustics
 export Waveform, AcousticSignal, AcousticsTriggerParameter, piezoenabled, hydrophoneenabled
diff --git a/src/hardware.jl b/src/hardware.jl
index 16343f08..a0036f2d 100644
--- a/src/hardware.jl
+++ b/src/hardware.jl
@@ -266,8 +266,38 @@ function Base.iterate(d::Detector, state=(Int[], 1))
     end
     (d.modules[module_ids[count]], (module_ids, count + 1))
 end
-Base.getindex(d::Detector, module_id) = d.modules[module_id]
+Base.getindex(d::Detector, module_id::Integer) = d.modules[module_id]
+Base.getindex(d::Detector, string::Integer, floor::Integer) = d.locations[string, floor]
+"""
+Return the detector module for a given string and floor.
+"""
+@inline getmodule(d::Detector, string::Integer, floor::Integer) = d[string, floor]
+"""
+Return the detector module for a given string and floor (as `Tuple`).
+"""
+@inline getmodule(d::Detector, loc::Tuple{T, T}) where T<:Integer = d[loc...]
+"""
+Return the detector module for a given location.
+"""
+@inline getmodule(d::Detector, loc::Location) = d[loc.string, loc.floor]
+Base.getindex(d::Detector, string::Int, ::Colon) = sort!(filter(m->m.location.string == string, modules(d)))
+Base.getindex(d::Detector, string::Int, floors::T) where T<:Union{AbstractArray, UnitRange} = [d[string, floor] for floor in sort(floors)]
+Base.getindex(d::Detector, ::Colon, floor::Int) = sort!(filter(m->m.location.floor == floor, modules(d)))
+"""
+Return a vector of detector modules for a given range of floors on all strings.
 
+This can be useful if specific detector module layers of the detector are needed, e.g.
+the base modules (e.g. `detector[:, 0]`) or the top layer (e.g. `detector[:, 18]`).
+"""
+function Base.getindex(d::Detector, ::Colon, floors::UnitRange{T}) where T<:Integer
+    modules = DetectorModule[]
+    for string in d.strings
+        for floor in floors
+            push!(modules, d[string, floor])
+        end
+    end
+    sort!(modules)
+end
 
 """
 
diff --git a/test/hardware.jl b/test/hardware.jl
index 3c7de6b1..df7ceeab 100644
--- a/test/hardware.jl
+++ b/test/hardware.jl
@@ -22,9 +22,11 @@ const SAMPLES_DIR = joinpath(@__DIR__, "samples")
                 # no base modules in DETX version <4
                 @test 342 == length(d)
                 @test 342 == length(d.modules)
+                @test 342 == length(modules(d))
                 @test DetectorModule == eltype(d)
             else
                 @test 361 == length(mods)
+                @test 361 == length(modules(d))
                 @test 106.95 ≈ d.modules[808469291].pos.y  # base module
                 @test 97.3720395 ≈ d.modules[808974928].pos.z  # base module
                 @test 0.0 == d.modules[808469291].tâ‚€  # base module
@@ -35,8 +37,20 @@ const SAMPLES_DIR = joinpath(@__DIR__, "samples")
             if version > 3
                 @test Quaternion(1, 0, 0, 0) ≈ d.modules[808995481].q
                 @test 19 == length(collect(m for m ∈ d if isbasemodule(m)))
+                @test 19 == length(d[:, 0])
+                @test 19 == length(d[30, :])
+                @test 5 == length(d[30, 0:4])
+                for m in d[30, 0:4]
+                    @test m.location.floor in 0:4
+                end
             else
                 @test 0 == length(collect(m for m ∈ d if isbasemodule(m)))
+                @test 0 == length(d[:, 0])
+                @test 18 == length(d[30, :])
+                @test 4 == length(d[30, 1:4])
+                for m in d[30, 1:4]
+                    @test m.location.floor in 1:4
+                end
             end
 
             if version > 4
@@ -51,8 +65,21 @@ const SAMPLES_DIR = joinpath(@__DIR__, "samples")
             end
 
             @test 31 == d.modules[808992603].n_pmts
-            @test 30 ≈ d.modules[817287557].location.string
-            @test 18 ≈ d.modules[817287557].location.floor
+            @test 30 == d.modules[817287557].location.string
+            @test 18 == d.modules[817287557].location.floor
+            @test Location(30, 18) == d[817287557].location
+            @test 817287557 == d[30, 18].id
+            @test 817287557 == getmodule(d, 30, 18).id
+            @test 817287557 == getmodule(d, (30, 18)).id
+            @test 817287557 == getmodule(d, Location(30, 18)).id
+
+            @test 19 == length(d[:, 18])
+            for m in d[:, 17]
+                @test 17 == m.location.floor
+            end
+            for m in d[:, 1:4]
+                @test m.location.floor in 1:4
+            end
 
             @test 478392.31980645156 ≈ d.modules[808992603].t₀
 
-- 
GitLab