# -*- coding: utf-8 -*-
import os
os.chdir(os.path.dirname(__file__))
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('../../vinstyle.mplstyle')
import sys
sys.path.append('../..')
from fitandgraph import fit
import uncertainties as unc
import lmfit
import pandas as pd

def get_res(a1, a1err, a2, a2err):
    """Mean motional level and ground-state probability from two amplitudes.

    Each amplitude is turned into an ``uncertainties`` ufloat. With
    ``A = larger / smaller`` this returns ``n = 1/(A - 1)`` and
    ``P = 1 - 1/A``, with uncertainties propagated by the package.

    Degenerate cases:
      * either ordered amplitude has nominal value 0 -> ``(0+/-0, 1+/-0)``;
      * equal nominal amplitudes -> ``n`` diverges and is returned as
        ``inf+/-0``, while ``P`` is evaluated normally (nominal 0).

    Parameters
    ----------
    a1, a1err : float
        First amplitude and its 1-sigma uncertainty.
    a2, a2err : float
        Second amplitude and its 1-sigma uncertainty.

    Returns
    -------
    (n, P) : tuple of ufloat
    """
    Ap = unc.ufloat(a1, a1err)
    Am = unc.ufloat(a2, a2err)
    # Order on the plain nominal inputs. max()/min() on the ufloats return the
    # *same* object for both when the nominals tie, which forced A = Ap/Ap = 1
    # and sent the equal-amplitude case into the zero-division fallback.
    big, small = (Ap, Am) if a1 >= a2 else (Am, Ap)
    if big.nominal_value == 0 or small.nominal_value == 0:
        # Division by a zero amplitude: keep the historical (n=0, P=1) result.
        return unc.ufloat(0, 0), unc.ufloat(1, 0)
    A = big / small
    P = 1 - 1/A
    if A.nominal_value == 1:
        # Equal amplitudes: n = 1/(A - 1) diverges, P is ~0.
        return unc.ufloat(float('inf'), 0), P
    n = 1/(A - 1)
    return n, P
    

# Constant term passed as the last argument of f.doubleGauss by the fit
# models below (doubleGauss passes it whole; quadgauss passes offset/2 to each
# of its two calls). NOTE(review): presumably a background level of the
# measured population -- confirm.
offset = 0.03

def doubleGauss(x, xc, w, FWHM, A1, A2):
    """Two peaks of common width placed symmetrically at ``xc +/- w``.

    Evaluates ``f.doubleGauss`` with amplitude ``A1`` paired with centre
    ``xc + w`` and ``A2`` paired with ``xc - w``, both using ``FWHM``,
    followed by the module-level ``offset``.

    ``f`` is the module-level ``fit`` object created in the main loop, so
    this may only be called once it exists. Assumes ``f.doubleGauss`` takes
    (amplitude, centre, width) per peak followed by a constant term -- TODO
    confirm against fitandgraph.

    The parameter names matter: ``lmfit.Model`` builds its parameters from
    this signature.
    """
    plus_centre = xc + w
    minus_centre = xc - w
    return f.doubleGauss(x, A1, plus_centre, FWHM, A2, minus_centre, FWHM, offset)

def quadgauss(x, xc, A1, g1, A2, w1, A3, g3, A4, w3):
    """Sum of two symmetric peak pairs about a shared centre ``xc``.

    First pair: ``A1`` at ``xc - w1`` and ``A2`` at ``xc + w1``, width ``g1``.
    Second pair: ``A3`` at ``xc - w3`` and ``A4`` at ``xc + w3``, width ``g3``.

    Each pair is evaluated with ``f.doubleGauss``. ``f`` is the module-level
    ``fit`` object, which must already exist. Each pair receives half of the
    module-level ``offset``, presumably so the summed constant term equals
    ``offset``.

    Parameter names are read by ``lmfit.Model``.
    """
    half_offset = offset / 2
    first_pair = f.doubleGauss(x, A1, xc - w1, g1, A2, xc + w1, g1, half_offset)
    second_pair = f.doubleGauss(x, A3, xc - w3, g3, A4, xc + w3, g3, half_offset)
    return first_pair + second_pair

def broken_axis3(xlim1, xlim2, fig, roi):
    """Create three adjacent axes that together form one broken x-axis.

    The axes occupy slots ``3*roi + 1`` .. ``3*roi + 3`` of a 2x6 subplot grid
    on *fig*. The left panel shows ``[-xlim2, -xlim1]`` and the right panel
    ``[xlim1, xlim2]``; the middle panel keeps its automatic x-limits.

    Spines at the breaks are hidden and short black diagonal marks are drawn
    there. Y ticks remain only on the outer edges, and the right panel's y
    tick labels are switched off.

    Returns
    -------
    (left, middle, right) : tuple of Axes
    """
    left, middle, right = (fig.add_subplot(2, 6, slot + 3*roi) for slot in (1, 2, 3))
    fig.subplots_adjust(wspace=0.1)  # keep the three panels tight together

    # Outer panels show mirrored windows on either side of zero.
    left.set_xlim(-xlim2, -xlim1)
    right.set_xlim(xlim1, xlim2)

    # Remove the spines along each break so the panels read as a single axis.
    for panel, side in ((left, 'right'), (middle, 'left'),
                        (middle, 'right'), (right, 'left')):
        panel.spines[side].set_visible(False)

    left.yaxis.tick_left()
    middle.tick_params(left=False, labelleft=False, right=False, labelright=False)
    right.yaxis.tick_right()
    right.tick_params(labelright=False)

    size = .015  # half-length of each break mark, in axes coordinates
    near0 = (-size, +size)
    near1 = (1-size, 1+size)
    # (x-span, y-span) of each mark, in the order they are drawn per panel.
    break_marks = (
        (left, [(near1, near0), (near1, near1)]),
        (middle, [(near0, near1), (near0, near0), (near1, near0), (near1, near1)]),
        (right, [(near0, near1), (near0, near0)]),
    )
    for panel, segments in break_marks:
        for xs, ys in segments:
            # clip_on=False lets the marks poke outside the axes frame.
            panel.plot(xs, ys, transform=panel.transAxes, color='k', clip_on=False)
    return left, middle, right

fig = plt.figure()
axs = []  # one (left, middle, right) axes triple per trap, filled by the loop

# Dense x grids for drawing the four-peak (quadgauss) model. x0/x1 span
# 90 <= |x| <= xsplit, and x2/x3 span xsplit <= |x| <= 160. The split point
# also separates the two data colour groups.
xsplit = 130
x0, x1, x2, x3 = (np.linspace(lo, hi, 200)
                  for lo, hi in ((-xsplit, -90), (90, xsplit),
                                 (-160, -xsplit), (xsplit, 160)))

# Per-trap P(n=0) labels for the optional figure annotation (see the
# commented text in the loop). oldGSprobs is presumably an earlier set.
oldGSprobs = ['0.75(5,6)', '0.89(3,5)', '0.60(8,13)', '0.67(7,11)']
GSprobs = ['0.72(5,5)', '0.70(07,10)', '0.48(08,12)', '0.67(7,9)']
# centres = [-2.17, -4.16, -2.34, -0.35]

# Row indices into the tables returned by lmfit.conf_interval. With the
# default three sigma levels, each ci[name] is a list of 7 (probability, value)
# pairs. They run from the widest lower limit, through the best fit (row 3), to
# the widest upper limit; rows 2 and 4 are the innermost (~1 sigma) limits.
CI_LO, CI_BEST, CI_HI = 2, 3, 4

# (numerator row, denominator row) combinations for the amplitude ratios.
# Axial: best fit, then one amplitude at a time moved to a ~1 sigma limit.
# The order fixes which result lands in which variable below.
_ONE_AT_A_TIME = [(CI_BEST, CI_BEST), (CI_LO, CI_BEST), (CI_BEST, CI_HI),
                  (CI_HI, CI_BEST), (CI_BEST, CI_LO)]
# Radial: best fit, then both amplitudes moved together to opposite limits.
_JOINT = [(CI_BEST, CI_BEST), (CI_LO, CI_HI), (CI_HI, CI_LO)]


def _ratio_table(ci, num, den, pairs):
    """Return ``ci[num]/ci[den]`` values for each (num_row, den_row) in *pairs*."""
    return [ci[num][i][1] / ci[den][j][1] for i, j in pairs]


for roi in range(4):
    # ------------- Axial scan (Measure11): two peaks at xc +/- w -------------
    f = fit(param=[0, 22, 5, 0.5, 0.1])
    f.load('ROI%s_Re_Measure11.dat' % roi)
    # Detuning axis: shift by the reference and scale by 1e3. The axis label
    # says kHz, so the raw file is presumably in MHz -- confirm.
    f.x = (f.x - 99.967) * 1e3

    ax1, ax2, ax3 = broken_axis3(90, 180, fig, roi)

    # Quick fit with the project helper; its returned x grid is reused below
    # to draw the lmfit curve.
    x, y = f.applyFit(doubleGauss)  # zuplims=f.yerr/2, zlolims=f.yerr/2,
    # ax2.errorbar(f.x-f.ps[0], f.y, yerr=f.yerr, fmt='o', color='C0')
    # ax2.plot(x-f.ps[0], y, color='C0')
    # n1cs, P1cs = get_res(f.ps[-2], f.perrs[-2], f.ps[-1], f.perrs[-1])
    # print('Rb Axial trap freq: ', unc.ufloat(f.ps[1], f.perrs[1]))
    # print('Mean motional level\t---\t', n1cs)
    # print('GS probability\t\t---\t', P1cs)
    # # print(f.ps)

    fmodel = lmfit.Model(doubleGauss)  # xc, w, FWHM, A1, A2
    fmodel.set_param_hint('A1', min=0)
    fmodel.set_param_hint('A2', min=0)
    pars = fmodel.make_params(A1=0.1, xc=0, FWHM=5, A2=0.5, w=25)

    def residual(p):
        # Error-weighted residuals of the axial model.
        return (f.y - doubleGauss(f.x, p['xc'], p['w'], p['FWHM'], p['A1'], p['A2']))/f.yerr

    mini = lmfit.Minimizer(residual, pars)
    rslt = mini.minimize()
    ps = rslt.params
    ci = lmfit.conf_interval(mini, rslt)

    # Re-centre BOTH the data and the model curve on the fitted xc so that
    # they line up on the middle panel.
    xc_axial = ps['xc'].value
    axial_curve = fmodel.eval(ps, x=x)
    ax2.errorbar(f.x - xc_axial, f.y, yerr=f.yerr, fmt='o', color='C0', zorder=1)
    ax2.plot(x - xc_axial, axial_curve, color='k', lw=4, zorder=2)  # dark outline
    ax2.plot(x - xc_axial, axial_curve, color='C0', lw=3, zorder=3)
    data = [f.x - xc_axial, f.y, f.yerr, x - xc_axial, axial_curve]

    # n1 = 1/(A2/A1 - 1), P1 = 1 - A1/A2 (formulas assume A2 > A1). Each is
    # evaluated at the best fit, then with one amplitude moved at a time.
    # Suffix names the amplitude that moved and the direction of the result.
    n1, n1A2up, n1A1up, n1A2lo, n1A1lo = [
        1/(r - 1) for r in _ratio_table(ci, 'A2', 'A1', _ONE_AT_A_TIME)]
    P1, P1A1up, P1A2up, P1A1lo, P1A2lo = [
        1 - r for r in _ratio_table(ci, 'A1', 'A2', _ONE_AT_A_TIME)]
    # NOTE(review): P1's bounds take the larger single-amplitude shift, whereas
    # P2/P3 below move both amplitudes together -- confirm this is intended.
    P1up = max(P1A1up, P1A2up)
    P1lo = min(P1A1lo, P1A2lo)
    # print('Rb Axial:')
    n1uperr = ((n1A1up-n1)**2+(n1A2up-n1)**2)**0.5
    n1loerr = ((n1-n1A1lo)**2+(n1-n1A2lo)**2)**0.5
    # print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n1,n1uperr,n1loerr))
    # print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P1, ((P1A1up-P1)**2+(P1A2up-P1)**2)**0.5, ((P1A1lo-P1)**2+(P1A2lo-P1)**2)**0.5))

    # ---------- Radial scan (Measure10): two peak pairs about xc ------------
    f.load('ROI%s_Re_Measure10.dat' % roi)
    f.p0 = [0, 0.8, 5, 0.1, 109, 0.8, 5, 0.1, 145]
    f.x = (f.x - 99.934) * 1e3

    # Bounded fit with the project helper. Its parameters (f.ps) seed the
    # lmfit model below and fix its centre, widths and splittings.
    x, y = f.applyFit(quadgauss, bounds=[[-10, 0, 0, 0, 0, 0, 0, 0, 0],
                                         [10, 1, 30, 1, 200, 1, 30, 1, 200]])

    # Points with |x| < xsplit vs. the rest, used only to colour the data.
    # Presumably these are the w1 and w3 pairs (p0 has w1=109 < xsplit < w3=145).
    inner = (f.x > -xsplit) & (f.x < xsplit)
    outer = ~inner

    fmodel = lmfit.Model(quadgauss)
    for amp in ('A1', 'A2', 'A3', 'A4'):
        fmodel.set_param_hint(amp, min=0, max=1)
    # Only the four amplitudes are refined; shape parameters come from f.ps.
    for fixed in ('g1', 'g3', 'xc', 'w1', 'w3'):
        fmodel.set_param_hint(fixed, vary=False)
    pars = fmodel.make_params(xc=f.ps[0], A1=0.2, g1=f.ps[2], A2=0.5, A3=0.5,
                              g3=f.ps[6], A4=0.1, w1=f.ps[4], w3=f.ps[-1])

    def residual(p):
        # Error-weighted residuals of the radial model.
        return (f.y - quadgauss(f.x, p['xc'], p['A1'], p['g1'], p['A2'], p['w1'],
                                p['A3'], p['g3'], p['A4'], p['w3']))/f.yerr

    mini = lmfit.Minimizer(residual, pars)
    rslt = mini.minimize()
    ps = rslt.params
    ci = lmfit.conf_interval(mini, rslt)
    # lmfit.printfuncs.report_ci(ci)

    # Both outer panels receive all points; their x-limits select what shows.
    ax1.errorbar(f.x[inner], f.y[inner], yerr=f.yerr[inner], fmt='o', color='C3', zorder=1)
    ax1.errorbar(f.x[outer], f.y[outer], yerr=f.yerr[outer], fmt='o', color='C4', zorder=1)
    ax3.errorbar(f.x[outer], f.y[outer], yerr=f.yerr[outer], fmt='o', color='C4', zorder=1)
    ax3.errorbar(f.x[inner], f.y[inner], yerr=f.yerr[inner], fmt='o', color='C3', zorder=1)

    for grid, colour, panel in [(x0, 'C3', ax1), (x1, 'C3', ax3),
                                (x2, 'C4', ax1), (x3, 'C4', ax3)]:
        # ax.plot(x, quadgauss(x,*f.ps), color=c)
        curve = fmodel.eval(ps, x=grid)
        panel.plot(grid, curve, color='k', lw=4, zorder=2)  # dark outline
        panel.plot(grid, curve, color=colour, lw=3, zorder=3)

    ax3.text(100, 0.45, 'Trap '+str(roi), size=16, ha='center')  #+'\n$P(n=0)=$'+GSprobs[roi]
    axs.append((ax1, ax2, ax3))
    ########    #########   ###########     ##########

    print('\n', roi)

    # Radial 1: n2 = 1/(A1/A2 - 1), P2 = 1 - A2/A1, with both amplitudes
    # moved together for the bounds.
    n2, n2up, n2lo = [1/(r - 1) for r in _ratio_table(ci, 'A1', 'A2', _JOINT)]
    P2, P2up, P2lo = [1 - r for r in _ratio_table(ci, 'A2', 'A1', _JOINT)]
    # NOTE(review): only the upper bound is guarded; presumably conf_interval
    # can return an infinite limit here -- confirm.
    if np.isinf(P2up):
        P2up = 1
    # print('Rb Radial1:')
    # print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n2,n2up-n2,n2-n2lo))
    # print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P2,P2up-P2,P2-P2lo))
    # Radial 2: n3 = 1/(A3/A4 - 1), P3 = 1 - A4/A3.
    n3, n3up, n3lo = [1/(r - 1) for r in _ratio_table(ci, 'A3', 'A4', _JOINT)]
    P3, P3up, P3lo = [1 - r for r in _ratio_table(ci, 'A4', 'A3', _JOINT)]
    if np.isinf(P3up):
        P3up = 1
    # print('Rb Radial2:')
    # print('Mean motional level\t---\t%.3g +%.1g -%.1g'%(n3,n3up-n3,n3-n3lo))
    # print('GS probability\t\t---\t%.3g +%.1g -%.1g'%(P3,P3up-P3,P3-P3lo))

    # Overall ground-state population is the product of the three directions.
    # Its errors combine each factor's shift in quadrature.
    P0 = P1*P2*P3
    P0loerr = ((P1lo*P2*P3-P0)**2 + (P1*P2lo*P3-P0)**2 + (P1*P2*P3lo-P0)**2)**0.5
    P0uperr = ((P1up*P2*P3-P0)**2 + (P1*P2up*P3-P0)**2 + (P1*P2*P3up-P0)**2)**0.5
    print('\nGS Population = %.3g +%.3g -%.3g'%(P0, P0uperr, P0loerr))

    # n2cs, P2cs = get_res(f.ps[1], f.perrs[1], f.ps[3], f.perrs[3])
    # print('Rb Radial1 trap freq:  ', unc.ufloat(f.ps[4], f.perrs[4]))
    # print('Mean motional level\t---\t', n2cs)
    # print('GS probability\t\t---\t', P2cs)

    # n3cs, P3cs = get_res(f.ps[5], f.perrs[5], f.ps[7], f.perrs[7])
    # print('Rb Radial2:  ', unc.ufloat(f.ps[-1], f.perrs[-1]))
    # print('Mean motional level\t---\t', n3cs)
    # print('GS probability\t\t---\t', P3cs)

    # P0cs = unc.ufloat(P1, abs(0.5*(P1up+ P1lo)-P1))*P2cs*P3cs
    # print('GS Population = ', P0cs)

    ########    #########   ###########     ##########
    
# Common y-range on every panel of every trap.
for triple in axs:
    for panel in triple:
        panel.set_ylim(0, 0.8)

# broken_axis3 places traps 0,1 on the top row and traps 2,3 on the bottom
# row, with traps 1,3 in the right-hand column.
for triple in axs[:2]:  # top row: no x tick labels
    for panel in triple:
        panel.tick_params(labelbottom=False)
for triple in (axs[3], axs[1]):  # right column: no y tick labels
    for panel in triple:
        panel.tick_params(labelleft=False)

detuning_label = '2-Photon Detuning (kHz)'
population_label = r'$|F=1\rangle$ Population'
for trap in (2, 3):
    axs[trap][1].set_xlabel(detuning_label)
for trap in (0, 2):
    axs[trap][0].set_ylabel(population_label)

fig.set_size_inches(9, 6)
plt.subplots_adjust(top=0.95, right=0.97, hspace=0.08)
plt.savefig('4TrapSpectroscopy.svg')
plt.show()
