Skip to content
Snippets Groups Projects
Verified Commit 7d3ce525 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

More linalg optimisations

parent 6303719b
No related branches found
No related tags found
1 merge request!4Muonscanfit tuning
Pipeline #44502 passed
......@@ -60,6 +60,9 @@ struct HitR1 <: AbstractReducedHit
n::Int
weight::Float64
end
# TODO: rething the design of `time(hit)`. `time()` is called several times and slewing is anyways
# needed, so we can instead use the constructor below to pass `combined_hit.t - slew(combined_hit.tot)`
# as time and then use `hit.t` everywhere instead of `time(hit)`, see below
Base.isless(lhs::HitR1, rhs::HitR1) = lhs.dom_id == rhs.dom_id ? time(lhs) < time(rhs) : lhs.dom_id < rhs.dom_id
const HitR2 = HitR1
function HitR1(dom_id::Integer, hits::Vector{HitL0})
......@@ -67,7 +70,11 @@ function HitR1(dom_id::Integer, hits::Vector{HitL0})
h = first(hits)
count = weight = length(hits)
HitR1(dom_id, h.pos, combined_hit.t, combined_hit.tot, count, weight)
# TODO: something like this (and the inline function below)
# HitR1(dom_id, h.pos, combined_hit.t - slew(combined_hit.tot), combined_hit.tot, count, weight)
end
# @inline Base.time(hit::HitR1) = hit.t
# function HitR1(m::DetectorModule, hit::HitL1)
# count = weight = length(hit)
# HitR1(m.id, position(hit), time(hit), tot(hit), count, weight)
......@@ -496,12 +503,21 @@ mutable struct Match3B <: AbstractMatcher
end
Base.show(io::IO, m::Match3B) = print(io, "Match3B($(m.roadwidth), $(m.tmaxextra))")
function (m::Match3B)(hit1, hit2)
"""
The Match3B algorithm.
"""
(m::Match3B)(hit1, hit2) = m(hit1, hit2, time(hit1), time(hit2))
"""
Optimised version which avoids the usage of calling `time(hit)` multiple times.
"""
function (m::Match3B)(hit1, hit2, time1, time2)
m.x = hit1.pos.x - hit2.pos.x
m.y = hit1.pos.y - hit2.pos.y
m.z = hit1.pos.z - hit2.pos.z
m.d₂ = m.x * m.x + m.y * m.y + m.z * m.z
m.t = abs(time(hit1) - time(hit2))
m.t = abs(time1 - time2)
if (m.d₂ < m.D02)
m.dmax = √m.d₂ * KM3io.Constants.INDEX_OF_REFRACTION_WATER
......@@ -542,10 +558,14 @@ mutable struct Match1D <: AbstractMatcher
end
end
function (m::Match1D)(hit1, hit2)
(m::Match1D)(hit1, hit2) = m(hit1, hit2, time(hit1), time(hit2))
"""
Optimised version of Match1D which avoids the usage of calling `time(hit)` multiple times.
"""
function (m::Match1D)(hit1, hit2, time1, time2)
m.z = hit1.pos.z - hit2.pos.z
m.t = abs(time(hit1) - time(hit2) - m.z * KM3io.Constants.C_INVERSE)
m.t = abs(time1 - time2 - m.z * KM3io.Constants.C_INVERSE)
m.t > m.tmax && return false
......@@ -573,28 +593,32 @@ Clique clusterizer which takes a matcher algorithm like `Match3B` as input.
struct Clique{T<:AbstractMatcher}
match::T
weights::Vector{Float64}
Clique(m::T) where T = new{T}(m, Float64[])
times::Vector{Float64}
Clique(m::T) where T = new{T}(m, Float64[], Float64[])
end
"""
Applies the clique clusterization algorithm and leaves only the best matching
hits in the input array.
"""
function clusterize!(hits::AbstractArray{T}, m::AbstractMatcher) where T<:AbstractSpecialHit
c = Clique(m)
clusterize!(hits::AbstractArray{T}, m::AbstractMatcher) where T<:AbstractSpecialHit = clusterize!(hits, Clique(m))
function clusterize!(hits::AbstractArray{T}, c::Clique) where T<:AbstractSpecialHit
N = length(hits)
N == 0 && return hits
resize!(c.weights, N)
resize!(c.times, N)
times = c.times
@inbounds for i 1:N
c.weights[i] = weight(hits[i])
times[i] = time(hits[i])
end
@inbounds for i 1:N
@inbounds for j i:N
j == i && continue
if c.match(hits[i], hits[j])
if c.match(hits[i], hits[j], times[i], times[j])
c.weights[i] += weight(hits[j])
c.weights[j] += weight(hits[i])
end
......@@ -624,11 +648,12 @@ function clusterize!(hits::AbstractArray{T}, m::AbstractMatcher) where T<:Abstra
# Swap the selected hit to end.
swap!(hits, j, n)
swap!(c.weights, j, n)
swap!(times, j, n)
# Decrease weight of associated hits for each associated hit.
@inbounds for i 1:n
c.weights[n] <= weight(hits[n]) && break
if c.match(hits[i], hits[n])
if c.match(hits[i], hits[n], times[i], times[n])
c.weights[i] -= weight(hits[n])
c.weights[n] -= weight(hits[i])
end
......
......@@ -180,7 +180,7 @@ function estimate!(est::Line1ZEstimator, hits)
reset!(est)
y = y₁ = y₂ = 0.0
yvec = zeros(MVector{3})
hit₀ = first(hits)
xi = hit₀.pos.x - posx(lz)
yi = hit₀.pos.y - posy(lz)
......@@ -210,9 +210,9 @@ function estimate!(est::Line1ZEstimator, hits)
est.V[2, 3] += dy * dt
est.V[3, 3] += dt * dt
y += dx * y
y += dy * y
y += dt * y
yvec[1] += dx * y
yvec[2] += dy * y
yvec[3] += dt * y
xi = xj
yi = yj
......@@ -228,16 +228,21 @@ function estimate!(est::Line1ZEstimator, hits)
end
invert!(est.V, est.MINIMAL_SVD_WEIGHT)
# Hermitian is needed for typestability!
wvec, evecs = invert2!(Hermitian(est.V), est.MINIMAL_SVD_WEIGHT)
yvec2 = (evecs' * yvec)
yvec2 .*= wvec
mul!(yvec, evecs, yvec2)
#yvec = evecs * (diagm(wvec) * (evecs' * yvec))
@inbounds begin
est.model = Line1Z(
Position(
pos.x + est.V[1, 1] * y₀ + est.V[1, 2] * y₁ + est.V[1, 3] * y₂,
pos.y + est.V[2, 1] * y₀ + est.V[2, 2] * y₁ + est.V[2, 3] * y₂,
pos.x + yvec[1],
pos.y + yvec[2],
posz(lz)
),
(est.V[3, 1] * y₀ + est.V[3, 2] * y₁ + est.V[3, 3] * y₂) * KM3io.Constants.KAPPA_WATER * KM3io.Constants.C_INVERSE + t₀
yvec[3] * KM3io.Constants.KAPPA_WATER * KM3io.Constants.C_INVERSE + t₀
)
end
......@@ -264,6 +269,19 @@ function invert!(V, precision)
mul!(V, F.U, diagm(F.S) * F.Vt)
end
@inline function invert2!(V, precision)
evals, evecs = eigen(V)
abs(evals[2]) < precision * abs(evals[1]) && throw(SingularSVDException("$evals"))
w = maximum(abs, evals) * precision
wvec = ifelse.(abs.(evals) .>= w, inv.(evals), zero(float(eltype(evals))))
return wvec, evecs
#mul!(V, evecs, diagm(wvec) * evecs')
end
struct Variance <: FieldVector{4, Float64}
x::Float64
y::Float64
......@@ -386,9 +404,8 @@ function (s::XYTSolver)(hits::Vector{T}, dir::Direction{Float64}, α::Float64) w
# TODO: better name for this function
timeresvec!(s.timeresvec, s.est.model, hits)
V⁻¹ = inv(V)
Y = view(s.timeresvec, 1:n_final_hits) # only take the relevant part of the buffer
χ² = transpose(Y) * V⁻¹ * Y
χ² = dot(Y, V \ Y)
fit_pos = R \ s.est.model.pos
MuonScanfitCandidate(fit_pos, dir, s.est.model.t, quality(χ², N, NDF), NDF)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment