"""Extract centre of loss feature from plots of 2-body recapture vs B field at
different tweezer powers. Combine these to extract the LS, then see how the
LS changes at different tweezer wavelengths."""
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Work relative to this script's directory so that every relative path used
# below (the style sheet, the fitandgraph module on sys.path, and the data
# folders read later) resolves the same way regardless of the launch directory.
# abspath guards against dirname(__file__) being '' (os.chdir('') raises).
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# Must come after the chdir: the style-sheet path is relative to this script.
plt.style.use(r'../../vinstyle.mplstyle')
sys.path.append('../..')
from fitandgraph import fit, c, e, a0, hbar, weighted_std, uB  # noqa: E402


# file_dic maps data folder -> list of
#   (measure number, [1064 power before SLM, single (1) / double (2) Gaussian])
f0 = 281.782 # estimated PA resonance position in THz
# Differential magnetic moment between atomic/FB state in kHz/G.
# 3*uB/h ~= 4200 kHz/G; the value used is 2x that, which appears to match the
# Omega^2/(2*Delta) convention of the polarisability model fitted later
# (vs the Omega^2/(4*Delta) field shift). NOTE(review): confirm the factor 2.
uBkHz = 2*4200
file_dic = {'281.364THz197G': [(8,[15.63,1]), (10,[4.11,2]), 
                                           (11,[1.04,2]), (12,[31,1])],
            '281.624THz197G': [(16,[19,2]), (17,[33.5,2]), (18,[5.2,1])],
            '281.631THz197G': [(22,[33.5,1]), (23,[19,2]), (26,[26.4,1])],
            '281.7537THz197G': [(8,[19,1]), (9,[5.2,1]), (10,[12,1])],
            '281.864THz197G': [(23,[6.45,2]), (24,[15.6, 1]), (25,[1.8, 2]),
                               (26,[9.5, 2]), (27,[12.6, 1])]}

fname = '/Comp0.Measure{0}.dat'
ylabel = 'Differential Polarisability \n(kHz/(kW/cm$^2$))'
# One row per laser frequency (i.e. per folder in file_dic).
rslts = pd.DataFrame(np.zeros((len(file_dic),3)), columns=
     ['Laser Frequency (THz)',ylabel, 
      'Error in '+ylabel])
# One row per individual measurement; derived from file_dic so adding or
# removing a measurement cannot desynchronise the table size.
n_measurements = sum(len(measures) for measures in file_dic.values())
allrslts = pd.DataFrame(np.zeros((n_measurements,4)), columns=['Laser Frequency (THz)',
        'Intensity (kW/cm2)','Loss Centre (G)', 'Error in Loss Centre (G)'])
ii = 0  # running row index into allrslts
# Converts the tabulated power into intensity in kW/cm^2. Presumably peak
# intensity 2P/(pi w^2) with w = 1.19 um, a 0.86 transmission factor, and
# /1e10 converting mW/m^2 -> kW/cm^2. NOTE(review): confirm the 2/2 factor.
power_calibration = 2/2*0.86 /(np.pi* 1.19e-6**2)/1e10
# Linear calibrations from the 'User variable' control voltage to field (G).
Bfield_calibration1 = lambda V: 200.05 + 2.77*V
Bfield_calibration2 = lambda V: 194.97 + 2.143*V

def _inverse_variance_mean(group):
    """Return the inverse-variance weighted mean of ``group.y``.

    ``group`` is one ``groupby`` slice with columns ``y`` and ``yerr``. Points
    with larger errors get less weight (weights = 1/yerr**2). If any error is
    zero, negative or non-finite those weights are undefined, so fall back to
    the plain mean instead of silently producing NaN.
    """
    err = np.asarray(group.yerr, dtype=float)
    if np.all(np.isfinite(err)) and np.all(err > 0):
        return np.average(group.y, weights=1.0 / err**2)
    return np.mean(group.y)


# For each laser frequency: fit the loss feature at every intensity, then fit
# the loss centre vs intensity to get the shift per unit intensity.
for i, key in enumerate(file_dic):
    # Folder names start with the laser frequency, e.g. '281.364THz197G'.
    freq = float(key.split('THz')[0])
    # .loc writes into rslts itself; chained `rslts[col][i] = ...` does not
    # update rslts under pandas Copy-on-Write and warns in older versions.
    rslts.loc[i, 'Laser Frequency (THz)'] = freq
    folder = file_dic[key]
    intensity = np.array([v[1][0] for v in folder]) * power_calibration
    centre = np.zeros(len(intensity))
    err_centre = np.zeros(len(intensity))
    width = np.zeros(len(intensity))
    err_width = np.zeros(len(intensity))

    # squeeze=False keeps axs indexable even for a folder with one measurement.
    fig, axs = plt.subplots(len(intensity), 1, sharex=True, figsize=(6,6),
                            squeeze=False)
    axs = axs[:, 0]
    plt.subplots_adjust(hspace=0.0)
    # Panels ordered from highest to lowest intensity, top to bottom.
    inds = list(reversed(np.argsort(intensity)))

    for j, idx in enumerate(inds):
        measure = folder[idx]
        raw = pd.read_csv(key+fname.format(measure[0]), skiprows=2)
        # These two data sets use the second field calibration and B0 guess.
        if '281.364THz' in key or '281.864THz' in key:
            x = Bfield_calibration2(raw['User variable'])
            B0 = 197.1
        else: 
            x = Bfield_calibration1(raw['User variable'])
            B0 = 197.3
        y = raw['0 atom survival probability']
        ey = raw['Error in 0 atom survival probability']
        df = pd.DataFrame(np.array((x,y,ey)).T, columns=['x','y','yerr'])
        # Combine repeated points taken at the same field value.
        ave = df.groupby(df.x).apply(_inverse_variance_mean)
        averr = df.groupby(df.x).apply(weighted_std)

        I = intensity[idx]
        axs[j].errorbar(ave.keys(), ave.values, averr, fmt='o', label='%.3g kW/cm$^2$'%I, ms=4)

        # Initial guess for the loss centre: B0 plus a rough intensity- and
        # detuning-dependent shift (empirical prefactors).
        c0 = B0 + 0.0103*I*0.028/(f0-freq)
        print(measure, I, c0, j, idx)
        indx = 1  # index of the loss centre in f.ps; its width is at indx+1
        if measure[1][1] %2:
            # Flag 1: single Gaussian on an offset.
            f = fit(ave.keys(), ave.values, averr, param = [0.5,c0, 0.3,0.2]) 
            axs[j].plot(*f.applyFit(f.offGauss), color='C0')
        else:
            # Flag 2: double Gaussian; second centre is guessed further out.
            f = fit(ave.keys(), ave.values, averr, param = [0.5,c0,0.3,0.5,c0-0.35*(f0-freq)/0.3,0.1,0.2]) 
            axs[j].plot(*f.applyFit(f.doubleGauss), color='C0')
            # Keep whichever of the two peaks lies at higher field.
            if f.ps[4] > f.ps[1]:
                indx = 4

        axs[j].set_ylim(0.0, 0.9)
        axs[j].legend()
        centre[idx] = f.ps[indx]
        err_centre[idx] = f.perrs[indx]
        width[idx] = f.ps[indx+1]
        err_width[idx] = f.perrs[indx+1]
        allrslts.loc[ii, 'Laser Frequency (THz)'] = freq
        allrslts.loc[ii, 'Intensity (kW/cm2)'] = I
        allrslts.loc[ii, 'Loss Centre (G)'] = f.ps[indx]
        allrslts.loc[ii, 'Error in Loss Centre (G)'] = f.perrs[indx]
        ii += 1

    axs[len(axs)//2].set_ylabel(r'P$_{2\rightarrow 0}$')
    axs[-1].set_xlabel('B Field (G)')
    plt.tight_layout()
    plt.subplots_adjust(hspace=0.0, top=0.98, bottom=0.1, left=0.15, right=0.95)
    if '281.364THz' in key:
        plt.xlim(196,199)  # current axes; shared x applies it to all panels
    plt.show()

    # Linear fit of loss centre vs intensity; slope is in G/(kW/cm^2).
    g = fit(intensity, centre, err_centre)
    plt.figure(figsize=(2.5,2))
    plt.errorbar(g.x, g.y, g.yerr, fmt='o')
    plt.plot(*g.applyFit(g.linear), color='C0')
    print(freq, '\n', g.report())
    plt.xlabel('Intensity (kW/cm$^2$)')
    plt.ylabel('Loss Centre (G)')
    # Convert the slope from G/(kW/cm^2) to kHz/(kW/cm^2).
    rslts.loc[i, ylabel] = g.ps[0]*uBkHz
    rslts.loc[i, 'Error in '+ylabel] = g.perrs[0]*uBkHz
    plt.savefig(key+'_shift.svg')


#%%
# Global fit across every measurement: the independent variable is a 2-row
# array (laser frequency in THz, intensity in kW/cm^2); the dependent variable
# is the fitted loss centre. Initial guess: [f0 (THz), d (a.u.), B0 (G)].
x = np.vstack([allrslts[col].values
               for col in ('Laser Frequency (THz)', 'Intensity (kW/cm2)')])
hall = fit(x,
           allrslts['Loss Centre (G)'].values,
           allrslts['Error in Loss Centre (G)'],
           param=[281.783, 0.06, 197.1])
def LS(xy, f0, d, B0):
    r"""Model the loss-centre field vs laser frequency and intensity.

    Used as the model for the global fit ``hall``.

    Light shift is \Omega^2/2\Delta * hbar/2mu_ag,
    where mu_ag = 3uB, the Rabi frequency is \Omega = d.E/hbar and
    E^2 = 2*Z0*I (Z0 = 376.730 Ohm, the free-space impedance).

    Parameters
    ----------
    xy : pair ``(f, I)`` unpacked from the fit's x-data, i.e. laser frequency
        (THz) and intensity (kW/cm^2).
    f0 : resonance frequency (THz).
    d : dipole moment in atomic units (multiples of e*a0).
    B0 : unshifted loss-centre field (G).

    Returns
    -------
    B0 minus the light shift expressed as a magnetic field, in G.

    NOTE(review): assumes e, a0, hbar and uB from fitandgraph are SI values.
    In that case the trailing 1e11 combines kW/cm^2 -> W/m^2 (1e7) with
    T -> G (1e4), and /1e12 converts the THz detuning to Hz.
    """
    f, I = xy
    return B0 - 2*376.730 * (d*e*a0/hbar)**2*I /2 /2/np.pi/(f-f0)/1e12 *hbar/3/uB/2*1e11  # in G (I in kW/cm2)
# Run the global (frequency, intensity) light-shift fit and report it with
# readable parameter names (order matches LS's f0, d, B0 arguments).
hall.getBestFit(LS)
hall.args = [
    'Resonance Frequency (THz)',
    'Dipole moment (a.u.)',
    'Unshifted Field (G)',
]
print(hall.report())

#%%
# Fit dipole moment and resonance frequency to the per-frequency
# differential polarisabilities; initial guess is [f0 (THz), d (a.u.)].
h = fit(rslts['Laser Frequency (THz)'],
        rslts[ylabel],
        rslts['Error in ' + ylabel],
        [281.8, 0.1])

def LS(f, f0, d):
    r"""Model the differential polarisability vs laser frequency.

    Used as the model for the fit ``h``.

    Light shift is \Omega^2/2\Delta = \alpha E^2/2,
    where the Rabi frequency is \Omega = d.E/hbar and E^2 = 2*Z0*I
    (Z0 = 376.730 Ohm, the free-space impedance).

    Parameters
    ----------
    f : laser frequency (THz); array-like fit x-values.
    f0 : resonance frequency (THz).
    d : dipole moment in atomic units (multiples of e*a0).

    Returns
    -------
    Model value in kHz/(kW/cm^2). For d != 0 it is negative when f > f0.

    NOTE(review): assumes e, a0 and hbar from fitandgraph are SI values. In
    that case /1e12 converts the THz detuning to Hz, and the trailing 1e4
    combines kW/cm^2 -> W/m^2 (1e7) with Hz -> kHz (1e-3).
    """
    return -2*376.730 * (d*e*a0/hbar/2/np.pi)**2 /2 /(f-f0)/1e12 *1e4 # in kHz/(kW/cm2)

def _inverse_variance_mean(group):
    """Return the inverse-variance weighted mean of ``group.y``.

    ``group`` is one ``groupby`` slice with columns ``y`` and ``yerr``. Points
    with larger errors get less weight (weights = 1/yerr**2). If any error is
    zero, negative or non-finite those weights are undefined, so fall back to
    the plain mean instead of silently producing NaN.
    """
    err = np.asarray(group.yerr, dtype=float)
    if np.all(np.isfinite(err)) and np.all(err > 0):
        return np.average(group.y, weights=1.0 / err**2)
    return np.mean(group.y)


# Summary figure: left subfigure shows example loss spectra at one laser
# frequency; right subfigure shows polarisability vs laser frequency + LS fit.
fig = plt.figure(constrained_layout=True, figsize=(8,4))
fig1, fig2 = fig.subfigures(1,2)
ax1 = fig2.add_subplot()
ax1.errorbar(h.x , h.y, h.yerr, color='k', fmt='o')
ax1.set_ylabel(ylabel)
ax1.set_xlabel('Laser Frequency (THz)')
ax1.plot(*h.applyFit(LS), color='k')
h.args = ['Resonance Frequency (THz)', 'Dipole Moment (a.u.)']
print(h.report())
ax1.set_ylim(-50,130)
ax1.set_xlim(281.27, 281.88)

key = '281.7537THz197G'
freq = float(key.split('THz')[0])
folder = file_dic[key]
intensity = np.array([v[1][0] for v in folder]) * power_calibration
# Panels ordered from highest to lowest intensity, top to bottom.
inds = list(reversed(np.argsort(intensity)))
centre = np.zeros(len(intensity))
err_centre = np.zeros(len(intensity))

# squeeze=False keeps axs indexable even if the folder has one measurement.
axs = fig1.subplots(len(intensity), 1, sharex=True, sharey=True,
                    squeeze=False)[:, 0]

handles = []
labels = []
for j, idx in enumerate(inds):
    measure = folder[idx]
    raw = pd.read_csv(key+fname.format(measure[0]), skiprows=2)
    x = Bfield_calibration1(raw['User variable'])
    B0 = 197.3
    y = raw['0 atom survival probability']
    ey = raw['Error in 0 atom survival probability']
    df = pd.DataFrame(np.array((x,y,ey)).T, columns=['x','y','yerr'])
    # Combine repeated points taken at the same field value.
    ave = df.groupby(df.x).apply(_inverse_variance_mean)
    averr = df.groupby(df.x).apply(weighted_std)

    I = intensity[idx]
    handles.append(axs[j].errorbar(ave.keys(), ave.values, yerr=averr, color='C%s'%j,
                        fmt='o', label='%.3g kW/cm$^2$'%I, ms=4)[0])
    labels.append('%.3g kW/cm$^2$'%I)
    colour = handles[j].get_color()
    # Initial loss-centre guess (same empirical form as the main loop).
    c0 = B0 + 0.0103*I*0.028/(f0-freq)
    f = fit(ave.keys(), ave.values, averr, param = [0.5,c0, 0.3,0.2]) 
    axs[j].plot(*f.applyFit(f.offGauss), color=colour)

    # offGauss params are [amplitude, centre, width, offset], so the centre
    # and its error are both at index 1 (index 2 is the width).
    centre[idx] = f.ps[1]
    err_centre[idx] = f.perrs[1]

axs[-1].set_xlim(196.5, 203)
axs[-1].set_xticks([197, 199, 201, 203])
axs[-1].set_xlabel('B Field (G)')
axs[len(axs)//2].set_ylabel(r'$P_{2\rightarrow 0}$')

# NOTE(review): the intensity legend for the left-hand spectra is drawn on the
# right-hand polarisability axes (ax1); confirm this placement is intended.
ax1.legend(handles, labels)

fig.savefig('LaserFrequency_DifferentialPolarisability.svg')
rslts.to_csv('LaserFrequency_DifferentialPolarisability.csv')