# -*- coding: utf-8 -*-
import os
# Work from the script's own directory so the relative paths used below
# (style sheet, sys.path entry, data files, saved SVGs) resolve no matter
# where the script is launched from.  abspath() guards against a bare
# relative __file__ (e.g. 'script.py' on Python < 3.9): dirname() then
# returns '' and os.chdir('') raises FileNotFoundError.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('../../vinstyle.mplstyle')
import sys
# Make ../.. importable (presumably where the fitandgraph module lives).
sys.path.append('../..')
from fitandgraph import fit
import uncertainties as unc
import lmfit
import pandas as pd


def get_res(a1, a1err, a2, a2err):
    """Return (n, P) from two amplitudes with uncertainties.

    The amplitudes are wrapped as ``uncertainties`` ufloats, and A is the
    ratio of the larger to the smaller one.  The results are:

    * the mean motional level, n = 1 / (A - 1);
    * the ground-state probability, P = 1 - 1 / A.

    Both are ufloats with linearly propagated errors.  Equal nominal
    amplitudes make A - 1 zero, so n is then undefined.
    """
    first = unc.ufloat(a1, a1err)
    second = unc.ufloat(a2, a2err)
    ratio = max(first, second) / min(first, second)
    mean_level = 1/(ratio-1)
    ground_prob = 1 - 1/ratio
    return mean_level, ground_prob

def broken_axis3(xlim1, xlim2):
    """Build a row of three y-sharing axes that imitates a twice-broken x axis.

    The left panel shows x in [-xlim2, -xlim1] and the right panel shows
    [xlim1, xlim2].  The middle panel's x range is left to autoscale.
    The spines facing each break are hidden, and short diagonal marks are
    drawn at every break.

    Returns (fig, ax1, ax2, ax3) = (figure, left, middle, right).
    """
    fig, (left, middle, right) = plt.subplots(1, 3, sharey=True)
    fig.subplots_adjust(wspace=0.1)  # narrow gaps between the panels

    # Each outer panel shows one mirrored window of the data.
    left.set_xlim(-xlim2, -xlim1)
    right.set_xlim(xlim1, xlim2)

    # Hide every spine that faces a break.
    for panel, side in ((left, 'right'), (middle, 'left'),
                        (middle, 'right'), (right, 'left')):
        panel.spines[side].set_visible(False)
    left.yaxis.tick_left()
    middle.tick_params(left=False, right=False)
    right.yaxis.tick_right()

    # Break marks, in axes-fraction coordinates (0..1 spans the panel).
    d = .015  # half-size of each diagonal mark
    at_left_edge = (-d, +d)
    at_right_edge = (1-d, 1+d)
    at_bottom = (-d, +d)
    at_top = (1-d, 1+d)
    marks = [
        (left, at_right_edge, at_bottom),
        (left, at_right_edge, at_top),
        (middle, at_left_edge, at_top),
        (middle, at_left_edge, at_bottom),
        (middle, at_right_edge, at_bottom),
        (middle, at_right_edge, at_top),
        (right, at_left_edge, at_top),
        (right, at_left_edge, at_bottom),
    ]
    for panel, xs, ys in marks:
        # clip_on=False lets the marks extend past the panel border.
        panel.plot(xs, ys, transform=panel.transAxes, color='k', clip_on=False)
    return fig, left, middle, right

# Cs axial sideband scan.
# NOTE(review): presumably the initial guesses for fitandgraph's own fit
# (applyFit below); the lmfit fit uses its own make_params values — confirm.
f = fit(param=[0, 17, 5,0.5,0.1,0])
f.load('ROI0_Re_Measure6_CsAxial.dat')
# Shift the scan axis by the reference frequency and rescale by 1e3
# (presumably MHz -> kHz, matching the 'kHz' detuning axis label).
f.x = (f.x-100.063)*1e3

def doubleGauss(x, xc, w, FWHM, A1, A2, y0):
    """Two equal-width peaks placed symmetrically about xc.

    The peak with amplitude A1 sits at xc + w and the peak with amplitude A2
    sits at xc - w.  Both use the same FWHM, and y0 is passed through
    (presumably a constant offset).

    The work is delegated to the module-level ``f.doubleGauss``, which is
    looked up at call time.  The parameter names and order are what
    ``lmfit.Model`` reads, so they must not change.
    """
    upper_centre = xc + w
    lower_centre = xc - w
    return f.doubleGauss(x, A1, upper_centre, FWHM, A2, lower_centre, FWHM, y0)


fig, ax1, ax2, ax3 = broken_axis3(105, 180)
# NOTE(review): the (x, y) returned here are overwritten below before use.
# The call is kept in case fitandgraph's applyFit has side effects — confirm
# whether it is still needed.
x, y = f.applyFit(doubleGauss)

# Error-weighted lmfit fit of the axial sideband pair; the hints only keep
# both amplitudes non-negative.
fmodel = lmfit.Model(doubleGauss)
fmodel.set_param_hint('A1', min=0)
fmodel.set_param_hint('A2', min=0)
pars = fmodel.make_params(A1=0.1,xc=0, FWHM=5, A2=0.5, w=17, y0=0.017)
def residual(p):
    # Residuals weighted by the measurement errors.
    return (f.y - doubleGauss(f.x, p['xc'], p['w'], p['FWHM'], p['A1'], p['A2'], p['y0']))/f.yerr
mini = lmfit.Minimizer(residual, pars)
rslt = mini.minimize()
ps = rslt.params
# ci[name] is a list of (sigma, value) tuples.  With lmfit's default
# sigmas=(1, 2, 3), index 3 is the best fit and indices 2 / 4 are the
# -1 sigma / +1 sigma limits used below.
ci = lmfit.conf_interval(mini, rslt)
# lmfit.printfuncs.report_ci(ci)

# Plot data and model shifted so the fitted centre xc sits at zero.
x = np.linspace(-30, 30, 200)
y = fmodel.eval(ps, x=x)
ax2.errorbar(f.x-ps['xc'].value, f.y, yerr=f.yerr, fmt='o', color='C0')
# Raw string: '\m' is an invalid escape in a plain literal (SyntaxWarning on
# Python 3.12+); the value is unchanged.
# NOTE(review): the legend value is hard-coded — keep it in sync with the
# printed result below.
ax3.errorbar(f.x-ps['xc'].value, f.y, yerr=f.yerr, fmt='o', color='C0', label=r'$n_\mathrm{z}=0.09^{+0.05}_{-0.04}$')
ax2.plot(x-ps['xc'].value, y, color='C0')
# Mean motional level n = 1/(A2/A1 - 1) and ground-state probability
# P = 1 - A1/A2.  The (2, 4) / (4, 2) index pairs combine opposite 1-sigma
# limits of A1 and A2 to bound the ratio from either side.
n1cs, n1upcs, n1locs = [1/(ci['A2'][i][1]/ci['A1'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P1cs, P1upcs, P1locs = [1 - ci['A1'][i][1]/ci['A2'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
print('Cs Axial:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n1cs,n1upcs-n1cs,n1cs-n1locs))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P1cs,P1upcs-P1cs,P1cs-P1locs))

data = [f.x-ps['xc'].value, f.y, f.yerr, x-ps['xc'].value, y]


########    #########   ###########     ##########

# Cs radial sideband scan.
# NOTE(review): this 9-element param list does not match quadgauss' 7 fit
# parameters, and the lmfit fit below uses its own make_params starting
# values — confirm whether fit(param=...) still matters here.
f = fit(param=[0.2,5,0.01,130,0.2,5,0.01,159,0])
f.load('ROI0_Re_Measure5_CsRadial.dat')
# Shift by the reference frequency and rescale (presumably MHz -> kHz).
f.x = (f.x-100.0153)*1e3

def quadgauss(x, A1, g1, A2, A3, g3, A4, y0):
    """Four-peak model for the two Cs radial sideband pairs.

    The first pair sits at x = -131 / +131 (amplitudes A1 / A2, shared FWHM
    g1).  The second pair sits at x = -159.2 / +159.2 (amplitudes A3 / A4,
    shared FWHM g3).

    Each pair is evaluated with the module-level ``f.doubleGauss``, which is
    looked up at call time.  The parameter names are what ``lmfit.Model``
    reads, so they must not change.

    NOTE(review): y0 is passed to both doubleGauss calls.  If doubleGauss adds
    it as an offset, the model's baseline is 2*y0 — confirm.
    """
    pair_131 = f.doubleGauss(x, A1, -131, g1, A2, 131, g1, y0)
    pair_159 = f.doubleGauss(x, A3, -159.2, g3, A4, 159.2, g3, y0)
    return pair_131 + pair_159

# Split the scan at |x| = xsplit: points strictly inside belong to the inner
# (+/-131) sideband pair and all other points to the outer (+/-159.2) pair.
# x0..x3 are the model-evaluation grids for each half of each pair.
xsplit = 145
x0 = np.linspace(-xsplit,-105,200)
x1 = np.linspace(105,xsplit,200)
x2 = np.linspace(-180,-xsplit,200)
x3 = np.linspace(xsplit,180,200)
inds = (f.x > -xsplit) & (f.x < xsplit)
i1 = np.flatnonzero(inds)   # indices of inner-pair points
i2 = np.flatnonzero(~inds)  # indices of outer-pair points

# Bounded, error-weighted fit of all four sidebands at once.
fmodel = lmfit.Model(quadgauss)
fmodel.set_param_hint('A1', min=0, max=1)
fmodel.set_param_hint('A2', min=0, max=1)
fmodel.set_param_hint('A3', min=0, max=1)
fmodel.set_param_hint('A4', min=0, max=1)
fmodel.set_param_hint('y0', min=0, max=0.1)
fmodel.set_param_hint('g1', min=0, max=20)
fmodel.set_param_hint('g3', min=0, max=20)
pars = fmodel.make_params(A1=0.5, g1=5, A2=0.1, A3=0.5, g3=5, A4=0.1, y0=0.019)
def residual(p):
    # Residuals weighted by the measurement errors.
    return (f.y - quadgauss(f.x, p['A1'], p['g1'], p['A2'], p['A3'], p['g3'], p['A4'], p['y0']))/f.yerr

mini = lmfit.Minimizer(residual, pars)
rslt = mini.minimize()
ps = rslt.params
# With lmfit's default sigmas=(1, 2, 3): ci[name][3] is the best fit and
# ci[name][2] / ci[name][4] are the -1 sigma / +1 sigma limits.
ci = lmfit.conf_interval(mini, rslt)
# lmfit.printfuncs.report_ci(ci)

# Model curves on the outer panels.  The loop variable is not named x, so it
# does not overwrite the global x grid from the axial section.
for xfit, c, ax in [[x0,'C3', ax1], [x1, 'C3', ax3], [x2, 'C4', ax1], [x3, 'C4', ax3]]:
    ax.plot(xfit, fmodel.eval(ps, x=xfit), color=c)
ax1.errorbar(f.x[i1], f.y[i1], yerr=f.yerr[i1], fmt='o', color='C3')
ax1.errorbar(f.x[i2], f.y[i2], yerr=f.yerr[i2], fmt='o', color='C4')
# Raw strings: '\m' is an invalid escape in a plain literal (SyntaxWarning on
# Python 3.12+); the values are unchanged.
# NOTE(review): the legend values are hard-coded — keep them in sync with the
# printed results.
ax3.errorbar(f.x[i2], f.y[i2], yerr=f.yerr[i2], fmt='o', color='C4', label=r'$n_\mathrm{y}=0.05^{+0.04}_{-0.03}$')
ax3.errorbar(f.x[i1], f.y[i1], yerr=f.yerr[i1], fmt='o', color='C3', label=r'$n_\mathrm{x}=0.13^{+0.06}_{-0.05}$')


ax2.set_xlabel('2-Photon Detuning (kHz)')
ax1.set_ylabel(r'$|F=3\rangle$ Population')
# Pin the common y axis to start at zero.
y0, y1 = ax1.get_ylim()
ax1.set_ylim(0,y1)
ax2.set_ylim(0,y1)
ax3.set_ylim(0,y1)
ax3.legend()
plt.savefig('CsSidebands_brokenaxes.svg')
plt.show()


#%%
# Cs radial results.  Mode 1 comes from the -/+131 pair (A1 / A2) and mode 2
# from the -/+159.2 pair (A3 / A4).  Index 3 of ci[...] is the best fit; the
# (2, 4) / (4, 2) pairs combine opposite 1-sigma limits of the two amplitudes.
n2cs, n2upcs, n2locs = [1/(ci['A1'][i][1]/ci['A2'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P2cs, P2upcs, P2locs = [1 - ci['A2'][i][1]/ci['A1'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
print('Cs Radial1:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n2cs,n2upcs-n2cs,n2cs-n2locs))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P2cs,P2upcs-P2cs,P2cs-P2locs))
n3cs, n3upcs, n3locs = [1/(ci['A3'][i][1]/ci['A4'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P3cs, P3upcs, P3locs = [1 - ci['A4'][i][1]/ci['A3'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
# Second radial mode: previously mislabelled 'Cs Radial1:' (copy-paste).
print('Cs Radial2:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n3cs,n3upcs-n3cs,n3cs-n3locs))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P3cs,P3upcs-P3cs,P3cs-P3locs))

# Cs 3D ground-state probability (axial x radial1 x radial2).  The upper and
# lower errors shift one factor at a time to its limit and add the resulting
# changes in quadrature.
P0cs = P1cs*P2cs*P3cs
P0uperrcs = np.sqrt((P1upcs*P2cs*P3cs-P0cs)**2 + (P1cs*P2upcs*P3cs-P0cs)**2 +(P1cs*P2cs*P3upcs-P0cs)**2)
P0loerrcs = np.sqrt((P1locs*P2cs*P3cs-P0cs)**2 + (P1cs*P2locs*P3cs-P0cs)**2 +(P1cs*P2cs*P3locs-P0cs)**2)
print('\nGS Population = %.3g +%.1g -%.1g'%(P0cs, P0uperrcs, P0loerrcs))

########    #########   ###########     ##########

#%%

# Rb axial sideband scan, fitted with the same doubleGauss model as the Cs
# axial scan.  doubleGauss reads the module-level f at call time, so it now
# uses this scan.
f = fit(param = [0, 17, 5,0.5,0.1,0])
f.load('ROI1_Re_Measure19_RbAxial.dat')
# Shift by the reference frequency and rescale (presumably MHz -> kHz).
f.x = (f.x-99.986)*1e3

fig, ax1, ax2, ax3 = broken_axis3(105, 180)
fmodel = lmfit.Model(doubleGauss)
fmodel.set_param_hint('A1', min=0)
fmodel.set_param_hint('A2', min=0)
pars = fmodel.make_params(A1=0.1,xc=0, FWHM=5, A2=0.5, w=17, y0=0.017)
def residual(p):
    # Residuals weighted by the measurement errors.
    return (f.y - doubleGauss(f.x, p['xc'], p['w'], p['FWHM'], p['A1'], p['A2'], p['y0']))/f.yerr
mini = lmfit.Minimizer(residual, pars)
rslt = mini.minimize()
ps = rslt.params
# With lmfit's default sigmas=(1, 2, 3): ci[name][3] is the best fit and
# ci[name][2] / ci[name][4] are the -1 sigma / +1 sigma limits.
ci = lmfit.conf_interval(mini, rslt)

ax2.errorbar(f.x-ps['xc'].value, f.y, yerr=f.yerr, fmt='o', color='C0')
# Raw string: '\m' is an invalid escape in a plain literal (SyntaxWarning on
# Python 3.12+); the value is unchanged.
# NOTE(review): the legend value is hard-coded — keep it in sync with the
# printed result below.
ax3.errorbar(f.x-ps['xc'].value, f.y, yerr=f.yerr, fmt='o', color='C0', label=r'$n_\mathrm{z}=0.09^{+0.04}_{-0.03}$')
# Evaluate the model once and reuse it for the plot and for `data`.
x = np.linspace(-30, 30, 200)
y = fmodel.eval(ps, x=x)
ax2.plot(x-ps['xc'].value, y, color='C0')

data = [f.x-ps['xc'].value, f.y, f.yerr, x-ps['xc'].value, y]

# Mean motional level n = 1/(A2/A1 - 1) and ground-state probability
# P = 1 - A1/A2, bounded by opposite 1-sigma limits of A1 and A2.
n1rb, n1uprb, n1lorb = [1/(ci['A2'][i][1]/ci['A1'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P1rb, P1uprb, P1lorb = [1 - ci['A1'][i][1]/ci['A2'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
print('rb Axial:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n1rb,n1uprb-n1rb,n1rb-n1lorb))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P1rb,P1uprb-P1rb,P1rb-P1lorb))


########    #########   ###########     ##########

# Rb radial sideband scan.
# NOTE(review): this 9-element param list does not match quadgauss' 7 fit
# parameters, and the lmfit fit below uses its own make_params starting
# values — confirm whether fit(param=...) still matters here.
f = fit(param = [0.2,5,0.01,130,0.2,5,0.01,158,0])
f.load('ROI1_Re_Measure1_RbRadial.dat')
# Shift by the reference frequency and rescale (presumably MHz -> kHz).
f.x = (f.x-99.93)*1e3

def quadgauss(x, A1, g1, A2, A3, g3, A4, y0):
    """Four-peak model for the two Rb radial sideband pairs.

    The first pair sits at x = -129.9 / +129.9 (amplitudes A1 / A2, shared
    FWHM g1).  The second pair sits at x = -157.99 / +157.99 (amplitudes
    A3 / A4, shared FWHM g3).

    This redefinition replaces the Cs version for the Rb fit.  Each pair is
    evaluated with the module-level ``f.doubleGauss``, which is looked up at
    call time.  The parameter names are what ``lmfit.Model`` reads, so they
    must not change.

    NOTE(review): y0 is passed to both doubleGauss calls.  If doubleGauss adds
    it as an offset, the model's baseline is 2*y0 — confirm.
    """
    inner_pair = f.doubleGauss(x, A1, -129.9, g1, A2, 129.9, g1, y0)
    outer_pair = f.doubleGauss(x, A3, -157.99, g3, A4, 157.99, g3, y0)
    return inner_pair + outer_pair

# Split the scan at |x| = xsplit: points strictly inside belong to the inner
# (+/-129.9) sideband pair and all other points to the outer (+/-157.99) pair.
# x0..x3 are the model-evaluation grids for each half of each pair.
xsplit = 145
x0 = np.linspace(-xsplit,-105,200)
x1 = np.linspace(105,xsplit,200)
x2 = np.linspace(-180,-xsplit,200)
x3 = np.linspace(xsplit,180,200)
inds = (f.x > -xsplit) & (f.x < xsplit)
i1 = np.flatnonzero(inds)   # indices of inner-pair points
i2 = np.flatnonzero(~inds)  # indices of outer-pair points

# Bounded, error-weighted fit of all four sidebands at once.
fmodel = lmfit.Model(quadgauss)
fmodel.set_param_hint('A1', min=0, max=1)
fmodel.set_param_hint('A2', min=0, max=1)
fmodel.set_param_hint('A3', min=0, max=1)
fmodel.set_param_hint('A4', min=0, max=1)
fmodel.set_param_hint('y0', min=0, max=0.1)
fmodel.set_param_hint('g1', min=0, max=20)
fmodel.set_param_hint('g3', min=0, max=20)
pars = fmodel.make_params(A1=0.5, g1=5, A2=0.1, A3=0.5, g3=5, A4=0.1, y0=0.019)
def residual(p):
    # Residuals weighted by the measurement errors.
    return (f.y - quadgauss(f.x, p['A1'], p['g1'], p['A2'], p['A3'], p['g3'], p['A4'], p['y0']))/f.yerr

mini = lmfit.Minimizer(residual, pars)
rslt = mini.minimize()
ps = rslt.params
# With lmfit's default sigmas=(1, 2, 3): ci[name][3] is the best fit and
# ci[name][2] / ci[name][4] are the -1 sigma / +1 sigma limits.
ci = lmfit.conf_interval(mini, rslt)

# Model curves on the outer panels.  The loop variable is not named x, so it
# does not overwrite the global x grid from the axial section.
for xfit, c, ax in [[x0,'C3', ax1], [x1, 'C3', ax3], [x2, 'C4', ax1], [x3, 'C4', ax3]]:
    ax.plot(xfit, fmodel.eval(ps, x=xfit), color=c)
ax1.errorbar(f.x[i1], f.y[i1], yerr=f.yerr[i1], fmt='o', color='C3')
ax1.errorbar(f.x[i2], f.y[i2], yerr=f.yerr[i2], fmt='o', color='C4')
# Raw strings: '\m' is an invalid escape in a plain literal (SyntaxWarning on
# Python 3.12+); the values are unchanged.
# NOTE(review): the legend values are hard-coded — keep them in sync with the
# printed results.
ax3.errorbar(f.x[i2], f.y[i2], yerr=f.yerr[i2], fmt='o', color='C4', label=r'$n_\mathrm{y}=0.02^{+0.02}_{-0.02}$')
ax3.errorbar(f.x[i1], f.y[i1], yerr=f.yerr[i1], fmt='o', color='C3', label=r'$n_\mathrm{x}=0.08^{+0.03}_{-0.03}$')


ax2.set_xlabel('2-Photon Detuning (kHz)')
ax1.set_ylabel(r'$|F=1\rangle$ Population')
# Pin the common y axis to start at zero.
y0, y1 = ax1.get_ylim()
ax1.set_ylim(0,y1)
ax2.set_ylim(0,y1)
ax3.set_ylim(0,y1)
ax3.legend()
plt.savefig('RbSidebands_brokenaxes.svg')
plt.show()


#%%

# Rb radial results.  Mode 1 comes from the -/+129.9 pair (A1 / A2) and mode 2
# from the -/+157.99 pair (A3 / A4).  Index 3 of ci[...] is the best fit; the
# (2, 4) / (4, 2) pairs combine opposite 1-sigma limits of the two amplitudes.
n2rb, n2uprb, n2lorb = [1/(ci['A1'][i][1]/ci['A2'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P2rb, P2uprb, P2lorb = [1 - ci['A2'][i][1]/ci['A1'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
print('rb Radial1:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n2rb,n2uprb-n2rb,n2rb-n2lorb))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P2rb,P2uprb-P2rb,P2rb-P2lorb))
n3rb, n3uprb, n3lorb = [1/(ci['A3'][i][1]/ci['A4'][j][1] - 1) for (i,j) in [(3,3), (2,4), (4,2)]]
P3rb, P3uprb, P3lorb = [1 - ci['A4'][i][1]/ci['A3'][j][1] for (i,j) in [(3,3), (2,4), (4,2)]]
print('rb Radial2:')
print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n3rb,n3uprb-n3rb,n3rb-n3lorb))
print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P3rb,P3uprb-P3rb,P3rb-P3lorb))

# Rb 3D ground-state probability (axial x radial1 x radial2).  The upper and
# lower errors shift one factor at a time to its limit and add the resulting
# changes in quadrature.
P0rb = P1rb*P2rb*P3rb
P0uperrrb = np.sqrt((P1uprb*P2rb*P3rb-P0rb)**2 + (P1rb*P2uprb*P3rb-P0rb)**2 +(P1rb*P2rb*P3uprb-P0rb)**2)
P0loerrrb = np.sqrt((P1lorb*P2rb*P3rb-P0rb)**2 + (P1rb*P2lorb*P3rb-P0rb)**2 +(P1rb*P2rb*P3lorb-P0rb)**2)
print('\nGS Population = %.3g +%.1g -%.1g'%(P0rb, P0uperrrb, P0loerrrb))


def Pnrel(Pcs, ncs, Prb, nrb):
    """Return Pcs*Prb divided by a weighted occupation correction.

    The denominator is
    1 - 87/220 * nrb/(nrb+1) - 133/220 * ncs/(ncs+1).
    The weights 87/220 and 133/220 are presumably the 87Rb / 133Cs mass
    fractions (confirm).  Pcs and Prb are ground-state probabilities; ncs and
    nrb are mean motional levels.
    """
    mass_sum = 133 + 87
    correction = 1 - 87/mass_sum*nrb/(nrb+1) - 133/mass_sum*ncs/(ncs+1)
    joint = Pcs*Prb
    return joint / correction

# Combined Cs x Rb ground-state probability.  The one-sided errors of each
# species' P0 are propagated by shifting one factor at a time and adding the
# resulting changes in quadrature.
P0uperr = np.sqrt(((P0uperrcs+P0cs)*P0rb-P0cs*P0rb)**2 + (P0cs*(P0rb+P0uperrrb)-P0cs*P0rb)**2)
P0loerr = np.sqrt(((P0loerrcs+P0cs)*P0rb-P0cs*P0rb)**2 + (P0cs*(P0rb+P0loerrrb)-P0cs*P0rb)**2)
print('\nP0Cs*P0Rb = %.3g + %.1g - %.1g'%(P0cs*P0rb, P0uperr, P0loerr))

#%%
# def Pnrel(Pcs, Prb):
#     return Pcs*Prb / (1 - 87/(133+87)*(1-Prb) - 133/(87+133)*(1-Pcs))

# P(n_rel = 0): product of Pnrel over mode 1 (axial) and modes 2 / 3 (radial).
Pnrelmean = Pnrel(P1cs, n1cs, P1rb, n1rb) * Pnrel(P2cs, n2cs, P2rb, n2rb) * Pnrel(P3cs, n3cs, P3rb, n3rb)
# Inputs in the argument order of the three consecutive Pnrel calls, with
# their upper and lower limits at the same positions.
means = [P1cs, n1cs, P1rb, n1rb, P2cs, n2cs, P2rb, n2rb, P3cs, n3cs, P3rb, n3rb]
ups = [P1upcs, n1upcs, P1uprb, n1uprb, P2upcs, n2upcs, P2uprb, n2uprb, P3upcs, n3upcs, P3uprb, n3uprb]
los = [P1locs, n1locs, P1lorb, n1lorb, P2locs, n2locs, P2lorb, n2lorb, P3locs, n3locs, P3lorb, n3lorb]
uperrs = np.zeros(len(means))
loerrs = np.zeros(len(means))
# One-at-a-time propagation: move a single input to its upper / lower limit,
# record the change in the product, then add all changes in quadrature.
# NOTE(review): for the best-fit values P == 1/(n+1) by construction (both
# come from the same amplitude ratio), so n/(n+1) inside Pnrel equals 1 - P.
# Varying P and n independently here may double count that uncertainty;
# confirm this is the intended error model.
for i in range(len(means)):
    vals = means.copy()
    vals[i] = ups[i]
    uperrs[i] = Pnrel(*vals[:4]) * Pnrel(*vals[4:8]) * Pnrel(*vals[8:]) - Pnrelmean
    vals[i] = los[i]
    loerrs[i] = Pnrelmean - Pnrel(*vals[:4]) * Pnrel(*vals[4:8]) * Pnrel(*vals[8:])
print('P(nrel=0) = %.3g +%.1g -%.1g'%(Pnrelmean, sum(uperrs**2)**0.5, sum(loerrs**2)**0.5))