diff --git a/src/hardware.jl b/src/hardware.jl
index 0633225767fe6bfb876bfc59971b8038915320aa..54fc88cf1f0b7a50430e4a0a6ba3c674414e3bf6 100644
--- a/src/hardware.jl
+++ b/src/hardware.jl
@@ -280,6 +280,7 @@ struct Detector
     locations::Dict{Tuple{Int, Int}, DetectorModule}
     strings::Vector{Int}
     comments::Vector{String}
+    _pmt_id_module_map::Dict{Int, DetectorModule}
 end
 """
 Return a vector of all modules of a given detector.
@@ -324,7 +325,15 @@ Return the detector module for a given location.
 """
 Return the `PMT` for a given hit.
 """
-@inline getpmt(d::Detector, hit) = getpmt(getmodule(d, hit.dom_id), hit.channel_id)
+@inline getpmt(d::Detector, hit::AbstractDAQHit) = getpmt(getmodule(d, hit.dom_id), hit.channel_id)
+"""
+Return the detector module for a given DAQ hit.
+"""
+@inline getmodule(d::Detector, hit::AbstractDAQHit) = getmodule(d, hit.dom_id)
+"""
+Return the detector module for a given MC hit.
+"""
+@inline getmodule(d::Detector, hit::AbstractMCHit) = d._pmt_id_module_map[hit.pmt_id]
 Base.getindex(d::Detector, string::Int, ::Colon) = sort!(filter(m->m.location.string == string, modules(d)))
 Base.getindex(d::Detector, string::Int, floors::T) where T<:Union{AbstractArray, UnitRange} = [d[string, floor] for floor in sort(floors)]
 Base.getindex(d::Detector, ::Colon, floor::Int) = sort!(filter(m->m.location.floor == floor, modules(d)))
@@ -422,6 +431,7 @@ function read_datx(io::IO)
     modules = Dict{Int32, DetectorModule}()
     locations = Dict{Tuple{Int, Int}, DetectorModule}()
     strings = Int[]
+    _pmt_id_module_map = Dict{Int, DetectorModule}()
     for _ in 1:n_modules
         module_id = read(io, Int32)
         location = Location(read(io, Int32), read(io, Int32))
@@ -443,10 +453,13 @@ function read_datx(io::IO)
             push!(pmts, PMT(pmt_id, pmt_pos, pmt_dir, pmt_tâ‚€, pmt_status))
         end
         m = DetectorModule(module_id, module_pos, location, n_pmts, pmts, q, module_status, module_tâ‚€)
+        for pmt in pmts
+            _pmt_id_module_map[pmt.id] = m
+        end
         modules[module_id] = m
         locations[(location.string, location.floor)] = m
     end
-    Detector(version, det_id, validity, utm_position, utm_ref_grid, n_modules, modules, locations, strings, comments)
+    Detector(version, det_id, validity, utm_position, utm_ref_grid, n_modules, modules, locations, strings, comments, _pmt_id_module_map)
 end
 @inline _readstring(io) = String(read(io, read(io, Int32)))
 
@@ -481,6 +494,7 @@ function read_detx(io::IO)
     modules = Dict{Int32, DetectorModule}()
     locations = Dict{Tuple{Int, Int}, DetectorModule}()
     strings = Int8[]
+    _pmt_id_module_map = Dict{Int, DetectorModule}()
 
     # a counter to work around the floor == -1 bug in some older DETX files
     floor_counter = 1
@@ -556,10 +570,14 @@ function read_detx(io::IO)
         m = DetectorModule(module_id, pos, Location(string, floor), n_pmts, pmts, q, status, tâ‚€)
         modules[module_id] = m
         locations[(string, floor)] = m
+        for pmt in pmts
+            _pmt_id_module_map[pmt.id] = m
+        end
+
         idx += n_pmts + 1
     end
 
-    Detector(version, det_id, validity, utm_position, utm_ref_grid, n_modules, modules, locations, strings, comments)
+    Detector(version, det_id, validity, utm_position, utm_ref_grid, n_modules, modules, locations, strings, comments, _pmt_id_module_map)
 end
 
 
@@ -607,6 +625,7 @@ The `version` parameter can be a version number or `:same`, which is the default
 and writes the same version as the provided detector has.
 """
 function write(filename::AbstractString, d::Detector; version=:same)
+    !endswith(filename, ".detx") && error("Only DETX is supported for detector writing.")
     isfile(filename) && @warn "File '$(filename)' already exists, overwriting."
     open(filename, "w") do fobj
         write(fobj, d; version=version)
diff --git a/src/root/offline.jl b/src/root/offline.jl
index 2bcf756d8208eae073950e4c9fbd1852e3649bf9..e591e748aedd9dee04a3fabed931dab6afe8ebf0 100644
--- a/src/root/offline.jl
+++ b/src/root/offline.jl
@@ -25,7 +25,7 @@ A calibrated MC hit of the offline dataformat. Caveat: the `position` and
 the offline format (one class for all).
 
 """
-struct CalibratedMCHit
+struct CalibratedMCHit <: AbstractCalibratedMCHit
     pmt_id::Int32
     t::Float64  # MC truth
     a::Float64  # amplitude (in p.e.)
diff --git a/src/types.jl b/src/types.jl
index 0f6d11260318b6ab61db42734a2717dfd3727bbb..97056347bb46d47ce1c7006a3700eb8687246576 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -50,6 +50,7 @@ abstract type AbstractHit end
 abstract type AbstractDAQHit<:AbstractHit end
 abstract type AbstractMCHit<:AbstractHit end
 abstract type AbstractCalibratedHit <: AbstractDAQHit end
+abstract type AbstractCalibratedMCHit <: AbstractMCHit end
 
 """
 
diff --git a/test/hardware.jl b/test/hardware.jl
index 7595fb55948297348ba3c11e95bf99950fa040d3..4244436083b87b883e216d2eab9d4d8f10d70d77 100644
--- a/test/hardware.jl
+++ b/test/hardware.jl
@@ -8,6 +8,9 @@ const SAMPLES_DIR = joinpath(@__DIR__, "samples")
 
 
 @testset "DETX parsing" begin
+    mchit_sample = KM3io.CalibratedMCHit(7636, 0, 0, 0, 0, Position(0.0, 0.0, 0.0), Direction(0.0, 0.0, 0.0))
+    daqhit_sample = KM3io.SnapshotHit(808966287, 0, 0, 0)
+
     for version ∈ 1:5
         d = Detector(joinpath(SAMPLES_DIR, "v$(version).detx"))
 
@@ -72,6 +75,8 @@ const SAMPLES_DIR = joinpath(@__DIR__, "samples")
         @test 817287557 == getmodule(d, 30, 18).id
         @test 817287557 == getmodule(d, (30, 18)).id
         @test 817287557 == getmodule(d, Location(30, 18)).id
+        @test 808966287 == getmodule(d, daqhit_sample).id
+        @test 808966287 == getmodule(d, mchit_sample).id
 
         @test 19 == length(d[:, 18])
         for m in d[:, 17]
@@ -153,7 +158,7 @@ end
     detx = Detector(datapath("detx", "KM3NeT_00000133_20221025.detx"))
     datx = Detector(datapath("datx", "KM3NeT_00000133_20221025.datx"))
     for field in fieldnames(Detector)
-        field == :modules && continue
+        field in (:modules, :_pmt_id_module_map) && continue
         if field == :locations
             detx_locs = getfield(detx, field)
             datx_locs = getfield(datx, field)