import os
os.chdir(os.path.dirname(__file__))
import sys
sys.path.append('../..')
from fitandgraph import fit, purple, red, heather, sky
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
plt.style.use('../../vinstyle.mplstyle')
import numpy as np
import qutip as qp
import lmfit
from scipy.special import genlaguerre
from uncertainties import ufloat
import pandas as pd

# Physical constants (SI)
h = 6.6260700e-34     # Planck's constant in m^2 kg / s
hbar = 1.0545718e-34  # reduced Planck's constant in m^2 kg / s
kB = 1.38064852e-23   # Boltzmann's constant in m^2 kg s^-2 K^-1

# Detuning grid and trap parameters (frequencies in kHz)
dets = np.linspace(-50, 50, 100)  # detunings kHz
w = 107                           # trap freq kHz
eta = (3.77/w)**.5                # Lamb-Dicke parameter
w2 = 163                          # second trap freq kHz
eta2 = (3.77/w2)**.5              # Lamb-Dicke parameter for w2

# Thermal spread of the carrier Rabi frequency: 32.4 scaled by
# exp(-eta^2/2) L_n(eta^2) for levels n = 0..99, weighted by Boltzmann factors
# exp(-n*h*w/(kB*T)) normalised over those 100 levels.
T = 15  # uK
rabis = 32.4 * np.array(
    [np.exp(-eta**2/2)*genlaguerre(n, 0)(eta**2) for n in range(100)])
_boltzmann = [np.exp(-n*h*1e3*w / kB/T/1e-6) for n in range(100)]
weights = np.array(_boltzmann) / sum(_boltzmann)  # partition sum computed once
del _boltzmann
average = np.average(rabis, weights=weights)
# Weighted standard deviation of the level-resolved Rabi frequencies.
print(np.sqrt(np.dot(weights, (rabis - average)**2) / weights.sum()))

#%%

def RamanHamiltonians(w, eta, Or, det, T, N_ph):
    """Build the Raman Rabi-oscillation Hamiltonians and the thermal initial state.

    w:    harmonic oscillator frequency (kHz)
    eta:  Lamb-Dicke parameter
    Or:   carrier Rabi frequency of the Raman beams (kHz)
    det:  detuning from two-photon resonance (kHz)
    T:    temperature (uK)
    N_ph: Hilbert-space dimension for the external motion

    All frequencies are converted to angular units 2*pi*kHz = rad/ms, so the
    returned operators are meant to be evolved with times in ms.

    Returns (Hb, HR, psi0):
      Hb   -- det/2 * sigma_z plus the oscillator energy w * n,
      HR   -- Raman coupling Or/2 * sigmap * exp(-i*eta*(a + a^dag)) plus its
              Hermitian conjugate,
      psi0 -- thermal motional density matrix (occupation from qp.n_thermal at
              temperature T) tensored with the internal state |0><0|.
    """
    # Rebind instead of using *=, which would scale a caller's numpy array in
    # place. 2*pi*kHz gives rad/ms.
    w = w * (2*np.pi)
    Or = Or * (2*np.pi)
    det = det * (2*np.pi)
    a = qp.tensor(qp.destroy(N_ph), qp.qeye(2))  # motional annihilation operator
    # qp.sigmap() is |0><1| in QuTiP's basis ordering (takes internal state 1
    # to 0); the Hermitian conjugate added below supplies the reverse process.
    sm_g2g1 = qp.tensor(qp.qeye(N_ph), qp.sigmap())
    # Bare Hamiltonian: internal detuning plus oscillator energy; qp.num(N_ph)
    # is the number operator diag(0, 1, ..., N_ph-1).
    Hb = (det / 2 * qp.tensor(qp.qeye(N_ph), qp.sigmaz())
          + w * qp.tensor(qp.num(N_ph), qp.qeye(2)))
    HR = Or / 2 * sm_g2g1 * (-1j*eta*(a + a.dag())).expm()  # Raman Hamiltonian
    HR += HR.dag()  # add Hermitian conjugate
    # kB*T/hbar in rad/s, times 1e-3 -> rad/ms, with T in uK (another 1e-6):
    # same units as the converted w.
    nbar = qp.n_thermal(w, T*1e-9*kB/hbar)
    psi0 = qp.tensor(qp.thermal_dm(N_ph, nbar, method='operator'), qp.fock_dm(2,0))
    return Hb, HR, psi0

def RabiSimulation(tlist=np.linspace(0,0.5,500), w=140, eta=0.22, Or=31, det=140, T=10, N_ph=20):
    """Simulate coherent Raman Rabi oscillations of a thermal motional state.

    tlist: evolution times (ms); the Hamiltonians are in rad/ms
    w:    Harmonic oscillator frequency (kHz)
    eta:  Lamb-Dicke parameter
    Or:   Rabi frequency for the carrier transition of the raman beams (kHz)
    det:  detuning from 2-photon resonance (kHz) (+ve for blue sideband)
    T:    Temperature (uK)
    N_ph: Hilbert space dimension for the external motion

    Returns (tlist*1e3, i.e. times in us, population of internal state 1 at
    each time). The system starts in internal state 0.
    """
    Hb, HR, psi0 = RamanHamiltonians(w, eta, Or, det, T, N_ph)
    proj_1 = qp.tensor(qp.qeye(N_ph), qp.fock_dm(2,1))  # internal-state-1 projector
    # e_ops given as a list (QuTiP's documented form, as in FluctRabiSimulation);
    # result.expect[0] is then the projector's expectation value over tlist.
    result = qp.mesolve([Hb, HR], psi0, tlist, e_ops=[proj_1])
    return tlist*1e3, result.expect[0]

def RabiFit(t, OR, T, eta, delta):
    """Use a thermal distribution of Rabi frequencies to fit carrier Rabi oscillations.

    Each motional level n = 0..79 flops at OR_n = OR*exp(-eta^2/2)*L_n(eta^2),
    with effective frequency sqrt(OR_n^2 + delta^2) off resonance. Levels are
    weighted by a Boltzmann distribution at the module-level trap frequency
    ``w`` (kHz), truncated and renormalised over those 80 levels.

    t:     pulse durations (us)
    OR:    carrier Rabi frequency (kHz)
    T:     temperature (uK)
    eta:   Lamb-Dicke parameter
    delta: detuning from resonance in the same units as OR; scalar or 1D array

    Returns an array of shape (len(t), size of delta): the transferred
    population at each time (and detuning).
    """
    from scipy.special import eval_genlaguerre

    levels = np.arange(80)  # truncation of the motional ladder
    # Ratio of successive Boltzmann weights exp(-w/w_th), with w_th = kB*T/h in
    # kHz (T in uK). This equals nbar/(1+nbar) for nbar = qp.n_thermal(w, w_th);
    # for T <= 0 it is 0, so only n = 0 is populated.
    w_th = T*1e-9*kB/h
    boltz_factor = np.exp(-w/w_th) if w_th > 0 else 0.0
    # The normalisation is loop-invariant, so compute all populations once.
    pops = np.power(boltz_factor, levels)
    pops = pops / pops.sum()
    # genlaguerre(n, 0)(x) evaluates eval_genlaguerre(n, 0, x) internally, but
    # only after constructing a new orthopoly1d (with quadrature roots) for
    # every n; calling the ufunc directly gives the same values much faster.
    rabi_n = OR * np.exp(-eta**2/2) * eval_genlaguerre(levels, 0, eta**2)
    t_ms = np.asarray(t) * 1e-3  # us -> ms, matching frequencies in kHz
    P = 0
    for p_n, OR_n in zip(pops, rabi_n):
        Oeff = (OR_n**2 + delta**2)**0.5
        P += p_n * OR_n**2/Oeff**2 * np.sin(2*np.pi*np.outer(t_ms, Oeff)/2)**2
    return P

def RabiFit2(t, OR, T, eta, eta2, delta):
    """Use a thermal distribution of Rabi frequencies to fit carrier Rabi oscillations (2D).

    Like RabiFit, but each level's Rabi frequency carries Laguerre factors for
    two directions, using e^(-ik.r) = e^(-ikx)e^(-iky):
        OR_n = OR * exp(-eta^2/2) L_n(eta^2) * exp(-eta2^2/2) L_n(eta2^2).
    Boltzmann weights use the mean of the module-level trap frequencies ``w``
    and ``w2`` (kHz), truncated and renormalised over n = 0..79.

    NOTE(review): both directions share the same index n, i.e. only states
    with n_x = n_y = n are summed, weighted by a single 1D Boltzmann factor.
    Independent thermal occupations would need the double sum used in
    RabiFit_2temperatures. Confirm this approximation is intended.

    t:     pulse durations (us)
    OR:    carrier Rabi frequency (kHz)
    T:     temperature (uK)
    eta:   Lamb-Dicke parameter along the first direction
    eta2:  Lamb-Dicke parameter along the second direction
    delta: detuning from resonance in the same units as OR; scalar or 1D array

    Returns an array of shape (len(t), size of delta).
    """
    from scipy.special import eval_genlaguerre

    levels = np.arange(80)  # truncation of the motional ladder
    # exp(-w_mean/w_th) with w_th = kB*T/h in kHz; equals nbar/(1+nbar) for
    # nbar = qp.n_thermal(w_mean, w_th), and 0 (ground state only) for T <= 0.
    w_th = T*1e-9*kB/h
    boltz_factor = np.exp(-((w + w2)/2)/w_th) if w_th > 0 else 0.0
    pops = np.power(boltz_factor, levels)
    pops = pops / pops.sum()  # normalisation computed once, not per level
    # eval_genlaguerre gives the same values as genlaguerre(n, 0)(x) without
    # building a polynomial object per level.
    rabi_n = (OR * np.exp(-eta**2/2) * eval_genlaguerre(levels, 0, eta**2)
              * np.exp(-eta2**2/2) * eval_genlaguerre(levels, 0, eta2**2))
    t_ms = np.asarray(t) * 1e-3  # us -> ms, matching frequencies in kHz
    P = 0
    for p_n, OR_n in zip(pops, rabi_n):
        Oeff = (OR_n**2 + delta**2)**0.5
        P += p_n * OR_n**2/Oeff**2 * np.sin(2*np.pi*np.outer(t_ms, Oeff)/2)**2
    return P

def RabiFit_2temperatures(t, OR, T1,T2, eta, eta2, delta):
    """Use a thermal distribution of Rabi frequencies to fit carrier Rabi oscillations (2D).

    Sums over independent motional levels (n, m), n, m = 0..49, using
    e^(-ik.r) = e^(-ikx)e^(-iky):
        OR_nm = OR * exp(-eta^2/2) L_n(eta^2) * exp(-eta2^2/2) L_m(eta2^2).
    Level n is Boltzmann-weighted at temperature T1 and the module-level trap
    frequency ``w``; level m at T2 and ``w2`` (kHz). Each distribution is
    truncated and renormalised over its 50 levels.

    t:      pulse durations (us)
    OR:     carrier Rabi frequency (kHz)
    T1, T2: temperatures (uK) of the w and w2 modes
    eta:    Lamb-Dicke parameter of the w mode
    eta2:   Lamb-Dicke parameter of the w2 mode
    delta:  detuning from resonance in the same units as OR; scalar or 1D array

    Returns an array of shape (len(t), size of delta).
    """
    from scipy.special import eval_genlaguerre

    levels = np.arange(50)  # truncation of each motional ladder

    def _thermal_pops(freq, temp):
        # Normalised Boltzmann populations; exp(-freq/w_th) equals
        # nbar/(1+nbar) for nbar = qp.n_thermal(freq, w_th), and is 0
        # (ground state only) for temp <= 0.
        w_th = temp*1e-9*kB/h
        ratio = np.exp(-freq/w_th) if w_th > 0 else 0.0
        pops = np.power(ratio, levels)
        return pops / pops.sum()

    # Populations and per-mode Rabi factors are loop-invariant: compute them
    # once instead of 2500 times. eval_genlaguerre gives the same values as
    # genlaguerre(n, 0)(x) without building a polynomial object per level.
    pops1 = _thermal_pops(w, T1)
    pops2 = _thermal_pops(w2, T2)
    factor1 = np.exp(-eta**2/2) * eval_genlaguerre(levels, 0, eta**2)
    factor2 = np.exp(-eta2**2/2) * eval_genlaguerre(levels, 0, eta2**2)
    t_ms = np.asarray(t) * 1e-3  # us -> ms, matching frequencies in kHz
    P = 0
    for p1, f1 in zip(pops1, factor1):
        for p2, f2 in zip(pops2, factor2):
            ORn = OR * f1 * f2
            Oeff = (ORn**2 + delta**2)**0.5
            P += p1*p2 * ORn**2/Oeff**2 * np.sin(2*np.pi*np.outer(t_ms, Oeff)/2)**2
    return P


def FluctRabiSimulation(event, tlist=np.linspace(0,0.5,500), w=140, eta=0.22, Or=31, det=140, T=10, N_ph=20, nsim=10):
    """Average Rabi simulations over runs with a randomly fluctuating detuning.

    Intended behaviour: the detuning is randomly varied throughout the
    simulation with s.d. 0.3*det, and nsim is the number of runs to average
    over (estimated error 2% for 100 runs).

    event: object whose .set() is called after all runs finish (presumably a
           threading.Event signalling completion -- confirm with the caller).
    tlist: evolution times (ms); other parameters as for RabiSimulation.

    Returns (tlist*1e3, mean over the nsim runs of the internal-state-1
    projector expectation).

    NOTE(review): the implementation does not appear to match the intent
    above; confirm before relying on its output:
      * the Hamiltonian list given to qp.sesolve is [Hb, <noise array>, HR];
        QuTiP expects each time-dependent term as an [operator, coefficient]
        pair, so the noise array is not attached to any operator;
      * the noise samples have s.d. 0.3 and are not multiplied by det here;
      * qp.n_thermal is called with w in kHz but w_th = kB*T/hbar in rad/ms
        (a factor 2*pi mismatch); RamanHamiltonians converts w to rad/ms first;
      * psi0 uses the thermal populations th[i] as ket amplitudes (not
        sqrt(th[i])), giving an unnormalised coherent superposition rather
        than the thermal mixture used by RabiSimulation;
      * the psi0 returned by RamanHamiltonians is overwritten and never used.
    """
    Hb, HR, psi0 = RamanHamiltonians(w, eta, Or, det, T, N_ph)
    # Diagonal of the thermal density matrix: motional level populations.
    th = np.diag(qp.thermal_dm(N_ph, qp.n_thermal(w, T*1e-9*kB/hbar), method='operator'))
    # Pure motional ket sum_i th[i]|i> in internal state 0 (see NOTE above).
    psi0 = sum(qp.tensor(qp.fock(N_ph,i)*th[i], qp.fock(2,0)) for i in range(N_ph))
    ys = []
    for i in range(nsim):
        # A fresh noise record is drawn for every run.
        result = qp.sesolve([Hb, np.random.normal(0,0.3,len(tlist)), HR], psi0, 
                tlist, e_ops=[qp.tensor(qp.qeye(N_ph), qp.fock_dm(2,1))], 
                progress_bar=qp.ui.EnhancedTextProgressBar())
        ys.append(result.expect[0])
    event.set()  # signal that all runs have finished
    return tlist*1e3, np.mean(ys, axis=0)

def DampedRabiSimulation(tlist, w=140, eta=0.22, Or=31, det=140, T=10, N_ph=20, stdev=1):
    """Coherent Rabi simulation multiplied by a Gaussian decay envelope.

    tlist: evolution times (ms), as for RabiSimulation
    stdev: width of the Gaussian envelope exp(-0.5*(2*pi*stdev*t)^2). With t in
           us and the 1e-6 factor below this is a frequency in Hz.
           NOTE(review): confirm the intended units of stdev.
    Other parameters are as for RabiSimulation.

    Returns the damped internal-state-1 population at each time in tlist.
    """
    times_us, population = RabiSimulation(tlist,w,eta,Or,det,T,N_ph)
    # Use the simulation's own time axis (us), so the envelope is evaluated at
    # exactly the sampled times rather than at an unrelated global array.
    return population*np.exp(-0.5*(times_us*2*np.pi*stdev*1e-6)**2)


p = np.zeros((3, len(dets), 300))  # simulated populations: [dataset, detuning, time]

#%%
xs = {}
ys = {}
eys = {}
# Per dataset: label, data colour, simulation colour, data file, detuning-noise
# widths (kHz), carrier Rabi frequency Or (kHz), temperature T (uK), SPAM
# amplitude A, and scale factor m applied to the file's time column (the x
# axis is plotted in us).
for i, (label, c1, c2, fn, flct, Or, T, A, m) in enumerate([
        ('Thermal', 'C0', 'C7', 'ROI1_Re_Measure18_thermal.dat', [6.9], 31.5, 19, 0.95, 1),
        ('814', 'C1', heather, 'ROI0_Re_Measure18_814.dat', [7.8], 20.2, 2.5, 0.95, 1e3),
        ('938', 'C2', sky, 'ROI0_Re_Measure3_938.dat', [3.7], 21.2, 1.5, 0.95, 1e3),
        ]):
    plt.figure(figsize=(5,4))
    f = flct[0]
    # Simulate every detuning over 8 carrier periods (tlist in ms: 8/Or for Or
    # in kHz); each curve fades with the Gaussian weight of its detuning.
    for j, d in enumerate(dets):
        times, pup = RabiSimulation(np.linspace(0,8/Or,np.shape(p)[-1]), w=w, eta=eta, Or=Or, det=d, T=T, N_ph=20) # w-Or**2/w/2+
        p[i,j,:] = pup*A  # account for SPAM 0.95
        plt.plot(times, p[i,j,:], alpha=np.exp(-d**2/f**2/2)*0.95, color=c2)
        print('.',end='')

    # add data
    hh = fit()
    hh.load(fn)
    x = np.array(hh.x) *m
    xs[label] = x
    y = np.array(hh.y)
    ys[label] = y
    ey = np.array(hh.yerr)
    eys[label] = ey
    plt.errorbar(x, y, yerr=ey, fmt='o', color=c1)
    # add theory: average the detuning-resolved simulations over a normalised
    # Gaussian of width f (kHz) and print the chi-squared against the data
    for f in flct:
        weights = np.exp(-dets**2/f**2/2)/np.sqrt(2*np.pi)/f * (dets[1]-dets[0])
        plt.plot(times, np.matmul(weights, p[i]), label=label, color=c1)
        yth = interp1d(times, np.matmul(weights, p[i]))
        print(f, np.sum((y-yth(x))**2/ey**2))

    # plt.legend()
    # Raw strings: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+).
    plt.xlabel(r'Pulse Duration ($\mu$s)')
    plt.ylabel(r'$|f=1 \rangle$ Population')
    plt.xlim(0,210) # 210
    plt.hlines(A, *plt.gca().get_xlim(), color='k', ls='--')  # SPAM-limited maximum
    plt.ylim(0,1)
    plt.savefig(label+'_radial.svg')
    plt.show()
    
# np.save(r'p.npy', p)
#%%
# Detuning-resolved simulations; presumably the [dataset, detuning, time] array
# built by the simulation cell -- confirm p.npy is up to date.
p = np.load(r'p.npy')
fig, axs = plt.subplots(1, 3, sharey=True, figsize=(11,3))
plt.subplots_adjust(wspace=0.07, top=0.95,bottom=0.22)
# Per dataset: label, data colour, (unused) simulation colour, data file,
# detuning-noise width (kHz), Or (kHz), T (uK), scale factor for the file's
# time column, target axes.
for i, (label, c1, c2, fn, f, Or, T, m, ax) in enumerate([
        ('Thermal', 'C0', 'C7', 'ROI1_Re_Measure18_thermal.dat', 6.9, 31.5, 19, 1, axs[0]),
        ('814', 'C1', heather, 'ROI0_Re_Measure18_814.dat', 7.8, 20.2, 2.5, 1e3, axs[1]),
        ('938', 'C2', sky, 'ROI0_Re_Measure3_938.dat', 3.7, 21.2, 1.5, 1e3, axs[2])]):
    # Theory time axis in us: 8 carrier periods (8e3/Or us for Or in kHz).
    times = np.linspace(0,8e3/Or,np.shape(p)[-1])
    # add data
    hh = fit()
    hh.load(fn)
    x = np.array(hh.x) *m
    y = np.array(hh.y)
    ey = np.array(hh.yerr)
    ax.errorbar(x, y, yerr=ey, fmt='o', color=c1)
    # add theory
    if i == 0:
        # Thermal dataset: fit the analytic 2D thermal-distribution model
        # (Rabi frequency and temperature), bounded to [0, 100].
        f = fit(x, y, erry=ey, param=[30,20])
        def rabifit2(t, OR, Temp):
            return np.sum(RabiFit2(t,OR,Temp,eta,eta2,0),axis=1)
        f.getBestFit(rabifit2, bounds=(0,100))
        f.args = ['Rabi Freq (kHz)', 'Temp (uK)']
        yth = rabifit2(times, *f.ps)
        ax.plot(times, yth, color=c1)
        print(f.report())
        # rabi = ufloat(f.ps[0], f.perrs[0])
        # stdv = ufloat(f.ps[1], f.perrs[1])
        # print('1/e time (ms): ', 2**0.5/2/np.pi/((rabi**2+stdv**2)**0.5-rabi))
    else:
        # Other datasets: damped sine with amplitude and offset fixed to
        # 0.95/2 (SPAM-limited), fitting frequency and decay time.
        f = fit(x, y, erry=ey, param=[20,100])
        def dampedfit(t, freq, tau):
            return f.dampedSinekHz(t, 0.95/2,freq,tau,-np.pi/2,0.95/2)
        f.getBestFit(dampedfit)
        yth = dampedfit(times, *f.ps)
        ax.plot(times, yth, color=c1)
    ax.set_xlim(0,210)
    ax.hlines(0.95, *ax.get_xlim(), color='k', ls='--')
    ax.set_ylim(0,1)
    # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+).
    ax.set_xlabel(r'Pulse Duration ($\mu$s)')

    # title = 'Fig7{}_data.csv'.format(['a','b','c'][i])
    # data = [x, y, ey, times, yth]
    # pd.DataFrame(data).transpose().to_csv(title, index=False,
    #    header=['# Pulse Duration (us) [data]','f=1 Population [data]','Error in f=1 Population [data]',
    #            '# Pulse Duration (us) [fit]','f=1 Population [fit]'])


axs[0].set_ylabel(r'$|f=1 \rangle$ Population')
plt.savefig('All_radial.svg')
plt.show()


#%%
# detuning fluctuations
key = 'Thermal'  # dataset refitted below
T = 19           # uK; bound as FluctFit's default Temp when the def runs

def FluctFit(t, OR, fluct, Temp=T, A=0.95):
    """Thermal carrier model averaged over Gaussian detuning noise.

    RabiFit is evaluated at every detuning in ``dets`` (kHz). The results are
    combined with Gaussian weights of s.d. ``fluct`` (kHz), scaled by the grid
    spacing (a discretised average over the detuning distribution), and
    multiplied by the amplitude A.

    Returns a 1D array over t.
    """
    scaled = RabiFit(t, OR, Temp, eta, dets)*A
    weighted = scaled * np.exp(-dets**2/fluct**2/2)/np.sqrt(2*np.pi)/fluct * (dets[1] - dets[0])
    return weighted.sum(axis=1)

def rabifit2(t, OR, Temp):
    """On-resonance 2D thermal carrier model for fitting.

    Calls RabiFit2 with the module-level eta and eta2 and zero detuning, then
    collapses the singleton detuning axis to give a 1D array over t.
    """
    populations = RabiFit2(t, OR, Temp, eta, eta2, 0)
    return populations.sum(axis=1)

def rabifit1(t, OR, Temp):
    """On-resonance 1D thermal carrier model for fitting.

    Calls RabiFit with the module-level eta and zero detuning, then collapses
    the singleton detuning axis to give a 1D array over t.
    """
    populations = RabiFit(t, OR, Temp, eta, 0)
    return populations.sum(axis=1)

# Refit the stored dataset with the on-resonance 2D thermal model.
f = fit(xs[key], ys[key], erry=eys[key], param=[20, 10])
plt.errorbar(f.x, f.y, yerr=f.yerr, fmt='o')
# Dense time grid, only used by the commented-out RabiFit overlay below.
xth = np.linspace(min(f.x), max(f.x), 200)
# Alternative model: plt.plot(*f.applyFit(FluctFit, bounds=(0,100)))
fitted_curve = f.applyFit(rabifit2, bounds=(0, 100))
plt.plot(*fitted_curve)
# plt.plot(xth, RabiFit(xth, f.ps[0], T=T, eta=eta, delta=0)*0.95, '--')
# Labels for FluctFit instead: ['OR (kHz)', 'std dev (kHz)', 'Temp uK']
f.args = ['OR (kHz)', 'Temp uK']
print(f.report())
  