import numpy as np
import scipy.signal as ssg

from functools import partial

"""
    Calculates distance(s) between a point 'p' and a(n array of) point(s) 'q'
    as the norm of the difference vector(s) 'pq'
    m = number of values    
    n = number of coordinates
    p ~ (1, n) - q ~ (m, n) => pq ~ (m, n)
"""
def dist(p, q):
    """Return the Euclidean distance(s) between ``p`` and the point(s) ``q``.

    Both arguments may be any array-like (tuple, list or ndarray) whose
    shapes broadcast against each other; the coordinates are expected on the
    last axis. A single point against an ``(m, n)`` array of points yields
    ``m`` distances.
    """
    # Coerce first so that plain tuples/lists can be subtracted as vectors.
    pq = np.asarray(p) - np.asarray(q)
    # The norm is always taken over the last (coordinate) axis.
    return np.linalg.norm(pq, axis=-1)


def window(t, t1):
    """Envelope for the chirp: ``exp(-(2*t/t1)**100)``.

    The very high exponent makes this an almost rectangular window: it is
    close to 1 while ``|t| < t1/2`` and falls off steeply to 0 outside.
    """
    steep_term = (2 * t / t1) ** 100
    return np.exp(-steep_term)

def wchirp(t, f0, t1, f1):
    """Windowed chirp: ``scipy.signal.chirp`` multiplied by the envelope.

    ``t`` is the time (array); ``f0``, ``t1`` and ``f1`` are forwarded
    positionally to ``ssg.chirp``, and ``t1`` also sets the envelope width.
    """
    envelope = window(t, t1)
    sweep = ssg.chirp(t, f0, t1, f1)
    return envelope * sweep

def xcorr(u1, u2):
    """Cross-correlate ``u1`` with ``u2``.

    Implemented as a convolution of ``u1`` with the time-reversed ``u2``;
    the 'same' mode keeps the centred part of the full result.
    """
    flipped = u2[::-1]
    return np.convolve(u1, flipped, mode='same')

def chirp_template(f0, t1, f1):
    """Return ``wchirp`` with its chirp parameters frozen.

    The result is a ``functools.partial`` that only needs the time argument:
    ``chirp_template(f0, t1, f1)(t)`` evaluates ``wchirp(t, f0=f0, t1=t1, f1=f1)``.
    """
    frozen = {'f0': f0, 't1': t1, 'f1': f1}
    return partial(wchirp, **frozen)

class AcouEnv:
    """Constants of the acoustic environment, shared through inheritance."""

    # Origin of the coordinate system.
    O = (0, 0, 0)
    # Speed of sound [m/s].
    c0 = 1500
    # NOTE(review): meaning and unit are not visible in this file --
    # presumably a depth/height in metres; confirm before relying on it.
    h = 2500

class AcouStreamer(AcouEnv):
    """Simulate the signals an array of receivers records from one beacon.

    For each receiver position in ``footprint`` the path length to ``beacon``
    is compared with the path length from the origin ``O`` to the beacon;
    dividing by the speed of sound ``c0`` gives a per-receiver delay relative
    to the origin. ``sim_signals`` adds a delayed copy of a template to each
    (optionally noisy) stream, and ``calibrate_delays`` recovers the delays
    again by cross-correlating every stream with the undelayed reference.
    """

    def __init__(self, f_s, footprint, beacon, noise_level=0.0, t_margin=0.0):
        """Set up geometry, time base and (noise-only) streams.

        f_s         -- sampling frequency; the sample spacing is 1/f_s
        footprint   -- receiver positions, array-like of shape (m, n)
        beacon      -- beacon position, array-like of n coordinates
        noise_level -- standard deviation of the Gaussian noise in the streams
        t_margin    -- extra time added to the largest delay when sizing
                       the time window
        """
        self.f_s = f_s
        self.beacon = beacon
        self.footprint = footprint
        self.calc_delays(footprint)
        self.init_timebase(t_margin)
        self.reset_noise(noise_level)

    def calc_delays(self, footprint):
        """Store the path-length differences (receiver vs. origin) to the beacon.

        Despite the name this stores ``self.distances``; the ``delays``
        property converts them to time on demand.
        """
        # Coerce here so that plain lists/tuples are accepted as geometry.
        beacon = np.asarray(self.beacon)
        receivers = np.asarray(footprint)
        self.distances = dist(receivers, beacon) - dist(self.O, beacon)

    def init_timebase(self, t_margin):
        """Build the time axis and the slice selecting its central section.

        The full axis holds 4N+1 samples over [-2T, 2T] with spacing 1/f_s;
        ``self.sect`` selects the central 2N+1 samples covering [-T, T]. The
        padding on both sides leaves room to shift a section by up to N
        samples in either direction.
        """
        # Delays are relative to the origin and can be negative, so the
        # window is sized from the largest magnitude, not the largest value.
        largest_delay = np.max(np.abs(self.delays))
        self.T = 1.5 * (largest_delay + t_margin)
        self.N = int(np.ceil(self.T * self.f_s))
        # Snap T to a whole number of samples.
        self.T = self.N / self.f_s
        self.n_s = 1 + 4 * self.N
        self.timebase = np.linspace(-2 * self.T, 2 * self.T, self.n_s)
        self.sect = slice(self.N, 1 + 3 * self.N)

    def reset_noise(self, noise_level=0.0):
        """(Re)create the streams as zero-mean Gaussian noise, one row per receiver.

        With ``noise_level == 0.0`` the streams are all zeros. Any signal
        added earlier by ``sim_signals`` is discarded.
        """
        shape = (len(self.delays), self.n_s)
        self.streams = np.random.normal(0.0, noise_level, shape)

    def sim_signals(self, template):
        """Add a delayed copy of ``template`` to the central section of each stream.

        ``template`` is a callable of time (e.g. from ``chirp_template``). The
        undelayed template over the full time base is kept as
        ``self.reference``. The signal is added to whatever the streams
        already hold, so repeated calls accumulate.
        """
        self.reference = template(self.timebase)
        t_section = self.timebase[self.sect]
        for stream, delay in zip(self.streams, self.delays):
            # ``stream`` is a row view, so this updates self.streams in place.
            stream[self.sect] += template(t_section - delay)

    def calibrate_delays(self):
        """Estimate each stream's delay in samples; result in ``self.rec_shifts``.

        Requires ``sim_signals`` to have been called (it sets the reference).
        """
        self.rec_shifts = np.empty_like(self.delays, dtype=int)
        reference_section = self.reference[self.sect]
        for i, stream in enumerate(self.streams):
            cross_correlation = xcorr(stream[self.sect], reference_section)
            # Both inputs have 2N+1 samples, so zero lag sits at index N of
            # the 'same'-mode output -- and N equals self.sect.start.
            self.rec_shifts[i] = np.argmax(cross_correlation) - self.sect.start

    @property
    def t(self):
        """Time axis of the central section."""
        return self.timebase[self.sect]

    @property
    def s(self):
        """Central section of every stream (one row per receiver)."""
        return self.streams[:, self.sect]

    @property
    def delays(self):
        """Per-receiver delay relative to the origin (distance / c0)."""
        return self.distances / self.c0

    @property
    def calibrated_distances(self):
        """Receiver-to-beacon distances rebuilt from the recovered shifts.

        Only available after ``calibrate_delays`` has been called.
        """
        origin_range = dist(self.O, np.asarray(self.beacon))
        return origin_range + (self.rec_shifts / self.f_s) * self.c0
    
class BeamFormer(AcouStreamer):
    """Delay-and-sum beamformer working on the streams of ``AcouStreamer``.

    The ``delays`` property is inherited unchanged from ``AcouStreamer``.
    """

    def beamform(self, nominal_footprint, probe):
        """Steer the array towards ``probe`` and sum the aligned streams.

        nominal_footprint -- assumed receiver positions, one per stream
        probe             -- position the array is steered to

        Returns ``(W, X)``: the sum of the shifted stream sections and its
        cross-correlation with the reference section. ``sim_signals`` must
        have been called before, since it provides ``self.reference``.

        Raises ValueError when a required shift does not fit inside the
        padding around the central section.
        """
        section = self.sect
        # Path-length difference receiver->probe versus origin->probe; the
        # scalar 0 broadcasts to the origin for any number of coordinates.
        path_diff = dist(probe, nominal_footprint) - dist(probe, 0)
        # Round to the nearest sample: plain integer casting truncates
        # towards zero and would bias every shift by up to one sample.
        shifts = np.rint(self.f_s * (path_diff / self.c0)).astype(int)

        # A shifted section must stay inside the stream. Negative slice
        # bounds would otherwise wrap around silently, and a clipped slice
        # would only fail later with an opaque broadcasting error.
        n_samples = self.streams.shape[1]
        too_early = np.any(section.start + shifts < 0)
        too_late = np.any(section.stop + shifts > n_samples)
        if too_early or too_late:
            raise ValueError(
                "beamform: required shifts exceed the padding of the time "
                "base; increase t_margin or move the probe"
            )

        W = np.zeros_like(self.timebase[section])
        for stream, shift in zip(self.streams, shifts):
            W += stream[section.start + shift:section.stop + shift]
        X = xcorr(W, self.reference[section])
        return W, X
    

    
    
def gen_probegrid(center, size, steps):
    """Meshgrid helper: square x/y grid of half-width ``size`` around ``center``.

    center -- sequence whose first two entries are the x and y of the centre
    size   -- half-width of the grid; each axis runs from -size to +size
    steps  -- number of grid points along each axis

    Returns a list ``[PX, PY]`` of two ``(steps, steps)`` coordinate arrays.
    """
    axis = np.linspace(-size, +size, steps)
    # np.meshgrid returns a tuple on NumPy >= 2.0, so its items cannot be
    # re-assigned in place; unpack and build the shifted arrays instead.
    grid_x, grid_y = np.meshgrid(axis, axis)
    return [grid_x + center[0], grid_y + center[1]]
def make_grid(beacon, size_deg, steps):
    """Generate a meshgrid around a position, sized by an angle.

    beacon   -- position the grid is centred on
    size_deg -- angular size in degrees, as seen from the origin (0, 0, 0)
    steps    -- number of steps along each grid axis

    The angle is turned into a length with the arc-length relation
    (distance to the origin times the angle in radians); that length is
    passed to ``gen_probegrid`` as its ``size`` argument.
    """
    origin = (0, 0, 0)
    beacon_range = dist(beacon, origin)
    extent_m = beacon_range * np.deg2rad(size_deg)
    return gen_probegrid(beacon, extent_m, steps)


class BeamScanner:
    """Evaluate a beamformer on a grid of probe positions.

    After ``scan`` the peak of the cross-correlation for every grid point is
    available in ``self.Z``; the raw beamformer outputs are appended to
    ``self.W`` and ``self.X`` in row-major scan order.
    """

    def __init__(self, beamformer, nominal_footprint):
        """Keep the beamformer and the footprint passed on to ``beamform``."""
        self.beamformer = beamformer
        self.footprint = nominal_footprint
        self.X = []
        self.W = []

    def scan(self, probe_grid, z_grid):
        """Beamform towards every point of ``probe_grid`` at height ``z_grid``.

        probe_grid -- pair ``(PX, PY)`` of equally shaped 2-D coordinate arrays
        z_grid     -- z coordinate used for every probe position

        ``self.Z`` is replaced on every call, while ``self.W`` and ``self.X``
        keep growing across calls.
        """
        # tqdm is only used for progress bars; fall back to plain iteration
        # when it is not installed.
        try:
            from tqdm import tqdm as progress
        except ImportError:
            def progress(iterable, **_kwargs):
                return iterable

        PX, PY = (np.asarray(plane) for plane in probe_grid)
        # Take both loop bounds from the grid shape so that non-square
        # grids are covered completely.
        n_rows, n_cols = PX.shape
        self.Z = np.zeros(PX.shape, dtype=float)
        for i in progress(range(n_rows), desc='X'):
            for j in progress(range(n_cols), desc='Y', leave=False):
                probe = np.array([PX[i, j], PY[i, j], z_grid])
                W, X = self.beamformer.beamform(self.footprint, probe)
                self.W.append(W)
                self.X.append(X)
                self.Z[i, j] = np.max(X)
               
                
def mse(P, locations, distances):
    """Mean squared error between measured and geometric distances.

    For every ``(location, distance)`` pair the distance from ``P`` to the
    location is compared with the given distance; the squared differences
    are summed and divided by ``len(locations)``.
    """
    squared_errors = (
        (dist(P, location) - distance) ** 2
        for location, distance in zip(locations, distances)
    )
    total = sum(squared_errors, 0.0)
    return total / len(locations)