"""Lifetime of Feshbach molecule at different tweezer wavelength and intensities."""
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
os.chdir(os.path.dirname(__file__))
plt.style.use('../../vinstyle.mplstyle')
import sys
sys.path.append('../..')
from fitandgraph import fit, c, e, a0, hbar, weighted_std

# Per-wavelength analysis. Each folder (named by the laser wavelength in nm)
# holds a metadata.csv listing measure numbers and SLM powers. For every
# intensity an offset exponential decay is fitted to the survival probability
# vs hold time; the loss rates 1/tau are then fitted linearly vs intensity.

f0 = 281.791 # estimated PA resonance position in THz
# Converts ' Power before SLM (mW)' into the intensity in kW/cm^2 used below.
# NOTE(review): 0.86 and 1.19e-6 (m) are presumably a transmission factor and
# the beam radius -- confirm; the leading 2/2 evaluates to 1.
power_calibration = 2/2*0.86 /(np.pi* 1.19e-6**2)/1e10

fname = '/Comp0.Measure{0}.dat'
ylabel = 'Loss rate (ms$^{-1}$/(kW/cm$^2$))'
folders = ['1065.499', '1065.083', '1064.023', '1064.5']
rslts = pd.DataFrame(np.zeros((len(folders), 3)),
                     columns=['Laser Wavelength (nm)', ylabel, 'Error in ' + ylabel])
# One dict per (wavelength, intensity) measurement. `allrslts` is built from
# these after the loop so its length always matches the data (a fixed number
# of pre-allocated rows leaves spurious zero rows or drops measurements).
all_rows = []

for i, folder in enumerate(folders):
    wavelength = float(folder)
    # Use .loc: chained `rslts[col][i] = v` does not write back to the frame
    # under pandas copy-on-write.
    rslts.loc[i, 'Laser Wavelength (nm)'] = wavelength
    meta = pd.read_csv(folder + '/metadata.csv')
    intensity = meta[' Power before SLM (mW)'] * power_calibration
    tau = np.zeros(len(intensity))
    err_tau = np.zeros(len(intensity))
    plt.figure()
    for j in range(len(intensity)):
        raw = pd.read_csv(folder + fname.format(meta['# Measure'][j]), skiprows=2)
        x = raw['User variable'] + 1  # hold duration (ms), per the x-axis label
        y = raw['0 atom survival probability']
        ey = raw['Error in 0 atom survival probability']
        df = pd.DataFrame(np.array((x, y, ey)).T, columns=['x', 'y', 'yerr'])
        # Average repeated runs taken at the same hold time.
        # NOTE(review): weights=yerr gives *noisier* points more weight; an
        # inverse-variance mean (weights=1/yerr**2, matching the error formula
        # 1/sqrt(sum(1/yerr**2))) was probably intended -- confirm, and beware
        # yerr == 0 before switching.
        ave = df.groupby(df.x).apply(lambda v: np.average(v.y, weights=v.yerr))
        averr = df.groupby(df.x).apply(weighted_std)
        f = fit(ave.keys(), ave.values, erry=averr, param=[-0.5, 10, 0.7])

        I = intensity[j]
        colour = plt.errorbar(f.x, f.y, f.yerr, fmt='o', label='%.3g kW/cm$^2$' % I, ms=4)[0].get_color()
        plt.plot(*f.applyFit(f.offeDecay), color=colour)

        # Parameter 1 of the decay fit is the lifetime tau; loss rate = 1/tau.
        tau[j] = f.ps[1]
        err_tau[j] = f.perrs[1]

        all_rows.append({
            'I': I,
            'Freq': c / wavelength / 1e3,  # wavelength (nm) -> frequency (THz)
            'LR': 1 / f.ps[1],
            'err': f.perrs[1] / f.ps[1]**2,  # propagated error of 1/tau
        })

    plt.ylabel(r'P$_{2\rightarrow 0}$')
    plt.xlabel('Hold Duration (ms)')
    plt.legend()
    plt.show()

    # Loss rate vs intensity at this wavelength; the slope m of m*x is the
    # loss rate per unit intensity stored in rslts.
    g = fit(intensity, 1/tau, err_tau/tau**2)
    plt.figure(figsize=(2.5, 2))
    plt.errorbar(g.x, g.y, g.yerr, fmt='o')
    try:
        plt.plot(*g.applyFit(lambda x, m: m*x), color='C0')
        print(wavelength, '\n', g.report())
        plt.xlabel('Intensity (kW/cm$^2$)')
        plt.ylabel('Loss rate (ms$^{-1}$)')
        rslts.loc[i, ylabel] = g.ps[0]
        rslts.loc[i, 'Error in ' + ylabel] = g.perrs[0]
    except IndexError:
        # The linear fit could not be made (presumably too few intensities):
        # fall back to the last measured point with an ad-hoc 50% uncertainty.
        I_last = intensity.iloc[-1]
        rslts.loc[i, ylabel] = 1/tau[-1]/I_last
        rslts.loc[i, 'Error in ' + ylabel] = 0.5/tau[-1]/I_last
    # plt.savefig(folder+'_shift.svg')

allrslts = pd.DataFrame(all_rows, columns=['I', 'Freq', 'LR', 'err'])
k = len(allrslts)  # number of measurements collected (final value of the old row counter)
    
#%%
def _decay_fit_181(measure):
    """Load one 181.4G measurement file, average repeats per hold time and fit an offset decay."""
    frame = pd.read_csv('181.4G' + fname.format(measure), skiprows=2)
    frame['x'] = frame['User variable'] + 1
    frame.rename(columns={'0 atom survival probability': 'y',
                          'Error in 0 atom survival probability': 'yerr'},
                 inplace=True)
    grouped = frame.groupby(frame['x'])
    mean = grouped.apply(lambda v: np.average(v.y, weights=v.yerr))
    mean_err = grouped.apply(weighted_std)
    fitted = fit(mean.keys(), mean.values, erry=mean_err, param=[-0.5, 10, 0.7])
    fitted.getBestFit(fitted.offeDecay)
    return fitted


# 181.4 G data: one decay fit per row of results.csv, converted into a loss
# rate per unit intensity (with the lifetime error and a fixed 0.7 intensity
# uncertainty added in quadrature), then written back into the same file.
rslts2 = pd.read_csv('181.4G/results.csv')
err_col = 'Error in ' + ylabel
rslts2[ylabel] = np.zeros(len(rslts2))
rslts2[err_col] = np.zeros(len(rslts2))
for i, row in rslts2.iterrows():
    intensity = row[' Power before SLM (mW)'] * power_calibration
    fr = _decay_fit_181(int(row['# Measure']))
    rate = 1/fr.ps[1]/intensity
    rel_err = ((fr.perrs[1]/fr.ps[1])**2 + (0.7/intensity)**2)**0.5
    rslts2.loc[i, ylabel] = rate
    rslts2.loc[i, err_col] = rate*rel_err

rslts2.to_csv('181.4G/results.csv', index=False)

#%% 
# 196.7G
# Loss rate per unit intensity vs laser frequency for the 196.7 G data.
freq_thz = c / rslts['Laser Wavelength (nm)'] / 1e3  # wavelength (nm) -> THz
h = fit(freq_thz, rslts[ylabel], rslts['Error in ' + ylabel], [281.8, 0.1])
plt.figure(figsize=(8, 4.3))
plt.errorbar(h.x, h.y, h.yerr, fmt='o', color='k', label='196.7')
plt.xlabel('Laser Frequency (THz)')
plt.ylabel('Imaginary polarisability \n(kHz/(kW/cm$^2$))')
def Rscat(f, f0, d):
    r"""Far-detuned scattering rate per unit intensity, Gamma*Omega^2/(4*Delta^2).

    The Rabi frequency is Omega = d.E/hbar. The formula assumes far detuning
    (Delta >> Gamma), so the Gamma^2 and saturation terms are neglected.

    Parameters
    ----------
    f : float or ndarray
        Laser frequency in THz.
    f0 : float
        Resonance frequency in THz.
    d : float
        Transition dipole moment in atomic units (multiples of e*a0).

    Returns
    -------
    float or ndarray
        Scattering rate per unit intensity. 376.730 is presumably the free-space
        impedance Z0 (ohms) and 5.2e6*2*pi the linewidth Gamma (rad/s); the 1e4
        factor appears to convert s^-1 per W/m^2 into ms^-1 per kW/cm^2,
        matching the plotted loss-rate units -- confirm.
    """
    return 376.730 * (d*e*a0/hbar)**2 / (2*np.pi*abs(f-f0)*1e12)**2 /4 *1e4 * 5.2e6*2*np.pi

# Fit only the dipole moment, with the resonance pinned at f_res (THz), the
# same centre used for the Stark-shift line drawn further below.
# (Earlier alternative: a free-resonance fit via h.applyFit(Rscat) with
# h.args = ['Resonance Frequency (THz)', 'Dipole Moment (a.u.)'].)
f_res = 281.7832
h.p0 = [0.1]
xth = np.linspace(281.36, 281.864, 500)
h.getBestFit(lambda f, d: Rscat(f, f_res, d))
plt.semilogy(xth, Rscat(xth, f_res, h.ps[0]), color='k')
h.args = ['Dipole Moment (a.u.)']
print(h.report())

#### other data
# plt.errorbar([281.473], [1/17/19.3], [7/17**2/19.3], fmt='x', label='196', color='C2')
# plt.errorbar([281.7537], [1/9/1.5], [4/9**2/1.5], fmt='s', label='194', color='C3')
# 185 G points plotted as 1/(tau*I) with error dtau/(tau^2*I); presumably tau
# and dtau in ms and I in kW/cm^2 -- confirm.
freqs185 = [281.7537, 281.473]
tau185, dtau185, int185 = (60, 90), (20, 60), (1.5, 19.3)
plt.errorbar(freqs185,
             [1/t/p for t, p in zip(tau185, int185)],
             [dt/t**2/p for t, dt, p in zip(tau185, dtau185, int185)],
             fmt='^', label='185', color='C4')

# 181.4G
h2 = fit(rslts2[' Laser Frequency (THz)'], rslts2[ylabel], rslts2['Error in ' + ylabel], [281.8, 0.05])
plt.errorbar(h2.x, h2.y, h2.yerr, fmt='o', color='C1', label='181.4')

### plot the line from stark shifted FR measurements, shading its uncertainty
### (resonance 281.7815-281.7849 THz, dipole 0.062-0.066 a.u.)
plt.semilogy(xth, Rscat(xth, f_res, 0.064), color='C1')
band_edges = [((281.7815, 0.062), (281.7849, 0.062)),
              ((281.7815, 0.066), (281.7849, 0.066)),
              ((281.7815, 0.062), (281.7815, 0.066)),
              ((281.7849, 0.062), (281.7849, 0.066))]
for lower, upper in band_edges:
    plt.fill_between(xth, Rscat(xth, *lower), Rscat(xth, *upper), color='C1', alpha=0.1)

plt.legend(title='B Field (G)', loc=(0.67, 0.05))
plt.ylim(1e-4, 1)

# Inset: loss rate vs intensity from `g`.
# NOTE(review): `g` is whatever the last iteration of the wavelength loop left
# behind, while the commented export below is named for 1065.5 nm -- confirm
# which wavelength the inset should show.
subax = plt.axes([0.25, 0.57, 0.24, 0.27])
subax.errorbar(g.x, g.y, g.yerr, fmt='o')
subax.plot(*g.applyFit(lambda x, m: m*x), color='C0')
subax.set_xlabel('Intensity (kW/cm$^2$)', fontsize=14)
subax.set_ylabel('Loss rate (ms$^{-1}$)', fontsize=14)

plt.savefig('LossRate_wavelength.svg', bbox_inches='tight')
# dfg = pd.DataFrame([g.x, g.y, g.yerr], columns=['Intensity (kW/cm2)','Loss rate (1/ms)'])
# dfg.to_csv('Inset_lossrate_1065.5.csv')
rslts.to_csv('LossRate_wavelength.csv')

#%%
G = 5.2e6*2*np.pi  # presumably the linewidth Gamma in rad/s (2*pi x 5.2 MHz) -- confirm


def Rsc(fI, f0, d):
    r"""Two-level steady-state scattering rate including saturation.

    R = Gamma*Omega^2 / (Gamma^2 + 4*Delta^2 + 2*Omega^2), with Rabi frequency
    Omega = d.E/hbar and detuning Delta = 2*pi*(f - f0). Far detuning is NOT
    assumed; for Delta >> Gamma, Omega this reduces to Gamma*Omega^2/(4*Delta^2).

    Parameters
    ----------
    fI : sequence of two arrays
        (f, I): laser frequency in THz and intensity in kW/cm^2, e.g. the rows
        of the np.vstack built for the global fit.
    f0 : float
        Resonance frequency in THz.
    d : float
        Transition dipole moment in atomic units (multiples of e*a0).

    Returns
    -------
    Scattering rate in ms^-1, comparable with the fitted loss rates 1/tau.
    """
    f, I = fI
    # Omega^2 in (rad/s)^2: 376.730 (presumably the free-space impedance Z0)
    # * I * (d e a0/hbar)^2, with I converted from kW/cm^2 to W/m^2 (x1e7).
    O2 = 376.730 * (d*e*a0/hbar)**2 * I * 1e7
    # Evaluate the rate in s^-1 with consistent units, then convert to ms^-1.
    # Folding the /1e3 into Omega^2 (I*1e4) would shrink the saturation term
    # 2*Omega^2 by 1000x relative to Gamma^2 and 4*Delta^2.
    return G*O2 / (G*G + 4*(2*np.pi*abs(f-f0)*1e12)**2 + 2*O2) / 1e3

# Global fit of every (frequency, intensity) loss-rate point to the saturated
# two-level model Rsc, fitting the resonance frequency and the dipole moment.
freq_and_intensity = np.vstack([allrslts['Freq'].values, allrslts['I'].values])
hall = fit(freq_and_intensity, allrslts['LR'].values, allrslts['err'],
           param=[281.783, 0.16])
hall.getBestFit(Rsc)
hall.args = ['Resonance Frequency (THz)', 'Dipole moment (a.u.)']
print(hall.report())
