# -*- coding: utf-8 -*-
# Fits Cs axial and radial sideband spectra recorded before and after RSC
# (presumably Raman sideband cooling), estimates the mean motional occupation
# and ground-state probability of each mode, and saves the plots as SVG.
import numpy as np
import matplotlib.pyplot as plt
import os
# Work from this script's directory so the data files and the relative style /
# module paths below resolve regardless of the launch directory. abspath()
# guards against __file__ being a bare filename: dirname() would then return ''
# and os.chdir('') raises FileNotFoundError.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
plt.style.use('../../vinstyle.mplstyle')
import sys
# Make the project-local `fitandgraph` module (two directories up) importable.
sys.path.append('../..')
from fitandgraph import fit, broken_xaxis
import uncertainties as unc
import lmfit

def get_res(a1, a1err, a2, a2err):
    """Mean motional level and ground-state probability from two sideband amplitudes.

    Both amplitudes are wrapped as ``uncertainties`` values (nominal, 1-sigma
    error) so their errors propagate. Only the ratio ``r = larger / smaller``
    is used, compared by nominal value:

        n = 1 / (r - 1)        P = 1 - 1 / r

    Returns ``(n, P)`` as ``uncertainties`` values. Equal nominal amplitudes
    give ``r == 1`` and therefore a division by zero.
    """
    amplitudes = (unc.ufloat(a1, a1err), unc.ufloat(a2, a2err))
    ratio = max(amplitudes) / min(amplitudes)
    mean_level = 1 / (ratio - 1)
    ground_prob = 1 - 1 / ratio
    return mean_level, ground_prob

#%%
# Cs axial spectrum before RSC. The 11 starting values presumably map, in
# order, onto quintgauss's 11 fit parameters (xc, w, Ac, gc, A1, g1, A2, A3,
# g3, A4, y0).
f = fit(param = [0, 17, 0.5, 5, 0.5, 5, 0.5, 0.5, 5, 0.5, 0]) 
f.load('ROI0_Re_Measure3_CsAxial_before.dat')
# Shift by the reference value and scale by 1e3 so x is the detuning in kHz
# used on the plot axis (presumably the raw x column is in MHz — confirm).
f.x = (f.x-100.04)*1e3

def quintgauss(x, xc, w, Ac, gc, A1, g1, A2, A3, g3, A4, y0):
    """Five-peak model: a central peak plus two symmetric sideband pairs.

    The central peak sits at ``xc`` (amplitude Ac, width gc). The first pair
    sits at ``xc -/+ w`` (amplitudes A1/A2, shared width g1) and the second at
    ``xc -/+ 2w`` (amplitudes A3/A4, shared width g3). It uses whatever global
    ``f`` is bound when the function is called.

    NOTE(review): every term is passed ``y0``. If the ``fit`` helpers each add
    the offset, it is counted three times. Confirm against ``fitandgraph``.
    """
    central = f.offGauss(x, Ac, xc, gc, y0)
    first_pair = f.doubleGauss(x, A1, xc - w, g1, A2, xc + w, g1, y0)
    second_pair = f.doubleGauss(x, A3, xc - w * 2, g3, A4, xc + w * 2, g3, y0)
    return central + first_pair + second_pair

def trigauss(x, xc, AC, WC, WS, x0, A1, A2, y0):
    """Three-peak model: a central peak plus one symmetric sideband pair.

    The central peak sits at ``xc`` (amplitude AC, width WC). The pair sits at
    ``xc + x0`` (A1) and ``xc - x0`` (A2), both with width WS. It reads the
    global ``f`` at call time and is not used in the visible part of the file.
    """
    central = f.offGauss(x, AC, xc, WC, y0)
    sideband_pair = f.doubleGauss(x, A1, xc + x0, WS, A2, xc - x0, WS, y0)
    return central + sideband_pair

# Before-RSC axial data and its five-peak fit, plotted relative to the fitted
# centre. Assumes f.ps / f.perrs follow quintgauss's parameter order, so
# f.ps[0] is xc.
plt.figure(figsize=(5,3))
x, y = f.applyFit(quintgauss)
# NOTE(review): the legend says 'Partial RSC' while the data file and the
# printout below say "before" RSC. Confirm which is intended.
plt.errorbar(f.x-f.ps[0], f.y, yerr=f.yerr, fmt='^', color='C7', label='Partial RSC')
plt.plot(x-f.ps[0], y, color='C7')
# Plot arrays (data x, y, yerr, model x, model y). The after-RSC section
# extends this list, but nothing in the visible code reads it.
data = [f.x-f.ps[0], f.y, f.yerr, x-f.ps[0], y]

# First-order sideband amplitudes A1 (f.ps[4]) and A2 (f.ps[6]).
n, P = get_res(f.ps[4], f.perrs[4], f.ps[6], f.perrs[6])
print('Cs Axial before RSC:')
print('Mean motional level\t---\t', n)
print('GS probability\t\t---\t', P)

# Cs axial spectrum after RSC. This spectrum is fitted directly with lmfit so
# that lmfit.conf_interval can be used for the errors. The lmfit fit below does
# not read the `param` starting values passed to fit().
f = fit(param = [0,0.5,10,10,17,0.5,0.1,0])
f.load('ROI0_Re_Measure0_CsAxial.dat')
f.x = (f.x-100.043143)*1e3

# Model: f.doubleGauss with non-negative amplitudes. The two peaks are tied
# symmetrically about zero detuning (x1 = -x2), and the offset y0 is held
# fixed at its starting value.
fmodel = lmfit.Model(f.doubleGauss)
fmodel.set_param_hint('A1', min=0)
fmodel.set_param_hint('A2', min=0)
pars = fmodel.make_params(A1=0.1,x1=17, FWHM1=5, A2=0.5, x2=17, FWHM2=5, y0=0.017)
pars['x1'].set(expr='-x2')
pars['y0'].set(vary=False)
def residual(p):
    """Error-weighted residuals of the double-Gaussian model, for lmfit.Minimizer.

    ``p`` is the lmfit Parameters object. The data come from the global ``f``.
    """
    model = f.doubleGauss(f.x,
                          p['A1'], p['x1'], p['FWHM1'],
                          p['A2'], p['x2'], p['FWHM2'],
                          p['y0'])
    return (f.y - model) / f.yerr

# Fit, then compute lmfit confidence intervals for every varying parameter.
mini = lmfit.Minimizer(residual, pars)
rslt = mini.minimize()
ps = rslt.params
ci = lmfit.conf_interval(mini, rslt)
# lmfit.printfuncs.report_ci(ci)

# Draw the fitted model only over the two sideband windows.
x = np.linspace(-25, -10, 200)
y = fmodel.eval(ps, x=x)
plt.plot(x, y, color='C0')
x = np.linspace(10, 28, 200)
plt.plot(x, fmodel.eval(ps, x=x), color='C0')

# ci[name] is a list of (confidence level, value) pairs running from the
# -3 sigma to the +3 sigma bound (default sigmas=(1, 2, 3)). Index 3 is the
# best fit; 2 and 4 are the -1 and +1 sigma bounds. Each (i, j) moves one
# amplitude to a 1-sigma bound while the other stays at its best fit, and
# the shifts are then added in quadrature.
# NOTE: in the P1 line the first index selects A2, so the A1/A2 suffixes of the
# P1 names are swapped relative to the n1 names. This is harmless because only
# their max/min and quadrature sums are used.
n1, n1A1up, n1A2up, n1A1lo, n1A2lo = [1/(ci['A1'][i][1]/ci['A2'][j][1] - 1) for (i,j) in [(3,3), (2,3), (3,4), (4,3),(3,2)]]
P1, P1A1up, P1A2up, P1A1lo, P1A2lo= [1 - ci['A2'][i][1]/ci['A1'][j][1] for (i,j) in [(3,3), (2,3), (3,4), (4,3),(3,2)]]
# Upper/lower bound values (not errors) of P1, used for the total GS probability.
P1up = max(P1A1up, P1A2up)
P1lo = min(P1A1lo, P1A2lo)
print('Cs Axial:')
n1uperr = ((n1A1up-n1)**2+(n1A2up-n1)**2)**0.5
n1loerr = ((n1-n1A1lo)**2+(n1-n1A2lo)**2)**0.5
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n1,n1uperr,n1loerr))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P1, ((P1A1up-P1)**2+(P1A2up-P1)**2)**0.5, ((P1A1lo-P1)**2+(P1A2lo-P1)**2)**0.5))

# Raw string: '\m' is not a valid escape sequence (SyntaxWarning on Python
# 3.12+). The resulting label text is unchanged.
plt.errorbar(f.x, f.y, yerr=f.yerr, fmt='o', color='C0', label=r'$n_\mathrm{z} = %.1g^{+%.1g}_{-%.1g}$'%(n1,n1uperr,n1loerr))

data += [f.x, f.y, f.yerr, np.concatenate((np.linspace(-25, -10, 200), x)), np.concatenate((y, fmodel.eval(ps, x=x)))]

plt.xlabel('2-Photon Detuning (kHz)')
plt.ylabel(r'$|f=3\rangle$ Population')
plt.legend()
plt.ylim(0, max(plt.gca().get_ylim()))
plt.savefig('CsAxialSidebandsBefore+After.svg')
plt.show()


#%%

########         #######        #########

# Cs radial spectrum before RSC: two sideband pairs, one per radial mode.
# The 9 starting values presumably map, in order, onto quadgauss's parameters
# (A1, g1, A2, w1, A3, g3, A4, w3, y0), with initial positions w1=87, w3=130.
f = fit(param = [0.2,5,0.01,87,0.2,5,0.01,130,0])
f.load('ROI0_Re_Measure1_CsRadial_before.dat')
f.x = (f.x-100.003)*1e3

def quadgauss(x, A1, g1, A2, w1, A3, g3, A4, w3, y0):
    """Two sideband pairs placed symmetrically about zero detuning.

    Pair 1 sits at ``-/+w1`` with amplitudes A1/A2 and shared FWHM g1. Pair 2
    sits at ``-/+w3`` with amplitudes A3/A4 and shared FWHM g3. It uses the
    global ``f`` at call time.
    """
    pair_1 = f.doubleGauss(x, A1, -w1, g1, A2, w1, g1, y0)
    pair_2 = f.doubleGauss(x, A3, -w3, g3, A4, w3, g3, y0)
    return pair_1 + pair_2
# Fit and draw on a broken x axis that shows only the two sideband windows.
# f.ps / f.perrs follow quadgauss's parameter order (as the get_res calls in
# this file assume):
#   0 A1, 1 g1, 2 A2, 3 w1, 4 A3, 5 g3, 6 A4, 7 w3, 8 y0
x, y = f.applyFit(quadgauss)
fig, ax1, ax2 = broken_xaxis(x, y, (-135,-65), (65,135), wspace=0.1, color='C7')
ax1.errorbar(f.x, f.y, yerr=f.yerr, fmt='^', color='C7')
ax2.errorbar(f.x, f.y, yerr=f.yerr, fmt='^', color='C7', label='Before RSC')
# Mode 1: sideband pair at -/+w1, amplitudes A1 (index 0) and A2 (index 2).
n2, P2 = get_res(f.ps[0], f.perrs[0], f.ps[2], f.perrs[2])

data2 = [f.x, f.y, f.yerr, x, y]

print('Cs Radial1 before RSC:')
print('Mean motional level\t---\t', n2)
print('GS probability\t\t---\t', P2)

# Mode 2: sideband pair at -/+w3, amplitudes A3 (index 4) and A4 (index 6).
# (Indices 3 and 5 are w1 and g3, a position and a width, not amplitudes.)
n3, P3 = get_res(f.ps[4], f.perrs[4], f.ps[6], f.perrs[6])
print('Cs Radial2 before RSC:')
print('Mean motional level\t---\t', n3)
print('GS probability\t\t---\t', P3)


########         #######        #########
# Cs radial spectrum after RSC. The sideband positions are fixed at -/+w1 and
# -/+w2 (kHz detuning) instead of being fitted; only the amplitudes, widths and
# offset vary.
f = fit(param = [0.2,5,0.1,0.2,5,0.1,0])
f.load('ROI0_Re_Measure1_CsRadial.dat')
f.x = (f.x-99.993)*1e3

w1, w2 = 84, 123
# Alternative fixed positions, kept for reference:
# w1, w2 = 76, 117
def quadgauss(x, A1, g1, A2, A3, g3, A4, y0):
    """Two sideband pairs at the fixed positions -/+w1 and -/+w2.

    It reads the module-level ``w1``, ``w2`` and ``f`` at call time. Pair 1 has
    amplitudes A1/A2 and shared FWHM g1; pair 2 has A3/A4 and shared FWHM g3.
    """
    pair_1 = f.doubleGauss(x, A1, -w1, g1, A2, w1, g1, y0)
    pair_2 = f.doubleGauss(x, A3, -w2, g3, A4, w2, g3, y0)
    return pair_1 + pair_2
# Split the points at |detuning| = xsplit. Points inside belong to mode 1
# (sidebands at -/+w1) and points outside to mode 2 (-/+w2). The x0..x3
# grids draw the fitted curve over the same segments.
xsplit = 105
x0 = np.linspace(-xsplit,-65,200)
x1 = np.linspace(65,xsplit,200)
x2 = np.linspace(-135,-xsplit,200)
x3 = np.linspace(xsplit,135,200)
in_mode1 = (f.x > -xsplit) & (f.x < xsplit)
inds = in_mode1.astype(int)      # 1 inside the split, 0 outside (kept for compatibility)
i1 = np.flatnonzero(in_mode1)    # indices of mode-1 points, ascending
i2 = np.flatnonzero(~in_mode1)   # indices of all remaining points, ascending

# lmfit model: amplitudes bounded to [0, 1] (populations), widths to [0, 20],
# offset to [0, 0.1].
fmodel = lmfit.Model(quadgauss)
fmodel.set_param_hint('A1', min=0, max=1)
fmodel.set_param_hint('A2', min=0, max=1)
fmodel.set_param_hint('A3', min=0, max=1)
fmodel.set_param_hint('A4', min=0, max=1)
fmodel.set_param_hint('y0', min=0, max=0.1)
fmodel.set_param_hint('g1', min=0, max=20)
fmodel.set_param_hint('g3', min=0, max=20)
pars = fmodel.make_params(A1=0.5, g1=5, A2=0.1, A3=0.5, g3=5, A4=0.1, y0=0)
# pars['y0'].set(vary=False)
def residual(p):
    """Error-weighted residuals of the fixed-position quadgauss model, for lmfit.

    ``p`` is the lmfit Parameters object. The data come from the global ``f``.
    """
    names = ('A1', 'g1', 'A2', 'A3', 'g3', 'A4', 'y0')
    model = quadgauss(f.x, *(p[name] for name in names))
    return (f.y - model) / f.yerr

# Fit, then compute lmfit confidence intervals for every varying parameter.
mini = lmfit.Minimizer(residual, pars)
rslt = mini.minimize()
ps = rslt.params
ci = lmfit.conf_interval(mini, rslt)
# lmfit.printfuncs.report_ci(ci)

# Second fit through fit.applyFit. The curves below come from the lmfit result
# `ps` (the applyFit variant is commented out), and `x` is rebound by the loop.
# NOTE(review): these width bounds (<= 30) differ from the lmfit hints (<= 20).
x, y = f.applyFit(quadgauss, bounds=[[0,0,0,0,0,0,0],[1,30,1,1,30,1,0.1]])
for ax, x, c in [[ax1, x0, 'C3'], [ax2, x1, 'C3'], [ax1, x2, 'C4'], [ax2, x3, 'C4']]:
    # ax.plot(x, quadgauss(x, *f.ps), color=c)
    ax.plot(x, fmodel.eval(ps, x=x), color=c)


# Raw strings: '\m' is not a valid escape sequence (SyntaxWarning on Python
# 3.12+). The label text is unchanged.
# NOTE(review): the legend values are hard-coded rather than taken from the
# confidence-interval results computed in the next cell. Update by hand if the data change.
ax1.errorbar(f.x[i1], f.y[i1], yerr=f.yerr[i1], fmt='o', color='C3')
ax2.errorbar(f.x[i1], f.y[i1], yerr=f.yerr[i1], fmt='o', color='C3', label=r'$n_\mathrm{x}=0.000^{+0.007}_{-0.000}$')
ax1.errorbar(f.x[i2], f.y[i2], yerr=f.yerr[i2], fmt='o', color='C4')
ax2.errorbar(f.x[i2], f.y[i2], yerr=f.yerr[i2], fmt='o', color='C4', label=r'$n_\mathrm{y}=0.02^{+0.02}_{-0.02}$')


fig.supxlabel('2-Photon Detuning (kHz)', size=16)
ax1.set_ylabel(r'$|f=3\rangle$ Population')
# Start both halves of the broken axis at zero, sharing the same top limit.
y0, y1 = ax1.get_ylim()
ax1.set_ylim(0,y1)
ax2.set_ylim(0,y1)
ax2.legend(fontsize=12)
fig.set_size_inches(5,3)
plt.subplots_adjust(bottom=0.2)
plt.savefig('CsRadialSidebandsBefore+After.svg')
plt.show()

#%%
# Per-mode estimates from the lmfit confidence intervals. In ci[...], index 3
# is the best fit and 2 / 4 are the -1 / +1 sigma bounds. Here BOTH amplitudes
# move to opposite 1-sigma bounds at once (unlike the one-at-a-time quadrature
# used for the axial mode), which generally gives a larger error estimate.
# These values replace the before-RSC get_res results held in n2/P2 and n3/P3.
n2, n2up, n2lo = [1/(ci['A1'][i][1]/ci['A2'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P2, P2up, P2lo = [1 - ci['A2'][i][1]/ci['A1'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
print('Cs Radial1:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n2,n2up-n2,n2-n2lo))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P2,P2up-P2,P2-P2lo))

# Second radial mode: sideband pair A3 / A4.
n3, n3up, n3lo = [1/(ci['A3'][i][1]/ci['A4'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P3, P3up, P3lo = [1 - ci['A4'][i][1]/ci['A3'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
print('Cs Radial2:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n3,n3up-n3,n3-n3lo))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P3,P3up-P3,P3-P3lo))


# Total ground-state probability: the product of the axial (P1) and two radial
# (P2, P3) probabilities, which presumes the modes are independent. The
# asymmetric errors come from shifting one factor at a time to its lower or
# upper bound value and adding the resulting changes in P0 in quadrature.
# Each term must shift exactly one factor so the upper and lower errors match.
P0 = P1*P2*P3
P0loerr = ((P1lo*P2*P3-P0)**2 + (P1*P2lo*P3-P0)**2 + (P1*P2*P3lo-P0)**2)**0.5
P0uperr = ((P1up*P2*P3-P0)**2 + (P1*P2up*P3-P0)**2 + (P1*P2*P3up-P0)**2)**0.5
print('\nGS Population = %.3g +%.1g -%.1g'%(P0, P0uperr, P0loerr))

