"""Stefan Spence 09.06.21
Plot the probability of making a transition on a given sideband
"""
import matplotlib.pyplot as plt
plt.style.use('../../vinstyle.mplstyle')
import numpy as np
from scipy.special import genlaguerre, factorial
from qutip import *
from scipy.constants import h, hbar
from scipy.constants import k as kB

# Simulation parameters. Frequencies are ordinary (not angular) frequencies in kHz;
# RamanHamiltonians applies the factor 2*pi itself.
N  = 75          # number of harmonic levels to include in simulation
w  = 20 # trap frequency (kHz)
Or = 8  # Rabi freq (kHz)
wr = 2.07 # Recoil frequency (kHz)
Temp=10 # initial temperature in uK
n = tensor(num(N), qeye(2)) # motional number operator, extended by the identity on the two-level spin
eta = np.sqrt(wr/w) # LD (Lamb-Dicke) parameter: sqrt(recoil frequency / trap frequency)
# Thermal occupation from qutip's n_thermal at temperature Temp. kB*T/h is expressed
# in kHz to match w: 1e-9 = 1e-6 (uK -> K) times 1e-3 (Hz -> kHz).
nth = n_thermal(w, Temp*1e-9*kB/h)

def RamanHamiltonians(w, eta, Or, det, T, N_ph, nbar=None):
    """Build the Hamiltonians and initial state for a Raman pulse on a trapped atom.

    The Hilbert space is tensor(motion with N_ph levels, two-level spin).

    Parameters
    ----------
    w : trap frequency. Together with the 1e-9 factor used for T below, this
        is consistent with w in kHz (ordinary frequency) and T in uK.
    eta : Lamb-Dicke parameter.
    Or : Raman Rabi frequency, same units as w.
    det : detuning entering the bare Hamiltonian as det/2 * sigmaz, same units as w.
    T : temperature used for the initial thermal state when nbar is not given.
    N_ph : number of motional levels kept.
    nbar : mean motional occupation of the initial state. When None (default)
        it is computed from T. An explicit 0 is honoured (motional ground state).

    Returns
    -------
    (Hb, HR, psi0) : bare Hamiltonian, Raman coupling Hamiltonian and the
        initial density matrix (thermal motion, spin in basis state 0).
    """
    # Ordinary -> angular frequency (2*pi * kHz, i.e. rad/ms for kHz inputs).
    # New local names are used so the caller's arguments are never modified
    # in place (an in-place `*=` would mutate a numpy array argument).
    w_ang = 2*np.pi * w
    Or_ang = 2*np.pi * Or
    det_ang = 2*np.pi * det

    # Phonon annihilation operator on the full space.
    a = qutip.tensor(qutip.destroy(N_ph), qutip.qeye(2))
    # Spin-flip operator (qutip's sigmap) on the full space; its Hermitian
    # conjugate is added to HR below so the coupling acts in both directions.
    spin_flip = qutip.tensor(qutip.qeye(N_ph), qutip.sigmap())

    # Bare Hamiltonian: internal detuning term plus the harmonic ladder.
    # w * num(N_ph) equals the sum over i of w*i*|i><i|.
    Hb = det_ang / 2 * qutip.tensor(qutip.qeye(N_ph), qutip.sigmaz())
    Hb = Hb + w_ang * qutip.tensor(qutip.num(N_ph), qutip.qeye(2))

    # Raman coupling: spin flip combined with the operator exp(-i*eta*(a + a^dag)).
    HR = Or_ang / 2 * spin_flip * (-1j*eta*(a + a.dag())).expm()
    HR = HR + HR.dag()  # add Hermitian conjugate

    if nbar is None:
        # hbar (not h) because w_ang is angular; 1e-9 = 1e-6 (uK -> K) * 1e-3 (per s -> per ms).
        nbar = qutip.n_thermal(w_ang, T*1e-9*kB/hbar)
    psi0 = qutip.tensor(qutip.thermal_dm(N_ph, nbar, method='operator'), qutip.fock_dm(2,0))
    return Hb, HR, psi0

def LR(eta, n, m):
    """Return the ratio of sideband / Raman Rabi frequency for LD parameter eta,
    between motional levels n and m.

    The value is the real, signed amplitude

        exp(-eta^2/2) * sqrt(lo!/hi!) * eta^|n-m| * L_lo^{|n-m|}(eta^2)

    with lo = min(n, m), hi = max(n, m) and L the generalized Laguerre
    polynomial. It is symmetric in n and m. Take the absolute value when only
    the coupling strength is needed. n and m may be scalars or array-likes
    that broadcast against each other.
    """
    # Local import: these helpers are not among the file-level imports.
    from scipy.special import eval_genlaguerre, gammaln

    lo = np.minimum(n, m)
    hi = np.maximum(n, m)
    order = hi - lo  # |n - m|
    # sqrt(lo!/hi!) from log-gamma: forming the factorials explicitly
    # overflows to inf above level 170 and turns the ratio into nan.
    sqrt_factorial_ratio = np.exp(0.5 * (gammaln(lo + 1) - gammaln(hi + 1)))
    # eval_genlaguerre evaluates L_lo^order(eta^2) directly, instead of
    # constructing a polynomial object on every call.
    return (np.exp(-eta**2/2) * sqrt_factorial_ratio * eta**order
            * eval_genlaguerre(lo, order, eta**2))

results = []  # collects one mesolve result per simulated sideband
#%%
# Font sizes for the exported figures: axis labels larger than tick labels.
label_sizes = {'axes': 18, 'xtick': 15, 'ytick': 15}
for key in label_sizes:
    plt.rc(key, labelsize=label_sizes[key])
    
#%%
for i in range(5):
    # One figure per sideband order i.
    ax1 = plt.figure(figsize=(5,3.5)).gca()
    # Relative coupling strength between levels `level` and `level + i`, for
    # every level kept in the simulation. The comprehension variable is not
    # called `n`, to avoid shadowing the module-level number operator.
    OR = [np.abs(LR(eta, level, level+i)) for level in range(N)]
    nm = np.argmax(OR)  # level with the strongest coupling on this sideband
    coupling = LR(eta, nm, nm+i)  # signed value, computed once
    print(i, nm, eta/coupling)
    # Pulse length 1/(2 * Or * |coupling|): a pi pulse for level nm. The
    # magnitude is used because nm was selected by |LR|; a negative signed
    # value would give a negative end time and a bad number of time points.
    strength = np.abs(coupling)
    tlist = np.linspace(0, 1/2/Or/strength, int(10/strength)+50)
    # Detuning -w*i plus a shift Or**2/(2*w*i) (presumably the light shift
    # from the off-resonant carrier -- confirm); the carrier (i == 0) is
    # driven on resonance.
    detuning = -w*i+Or**2/2/(w*i) if i else 0
    Hb, HR, psi0 = RamanHamiltonians(w, eta, Or, detuning, Temp, N)
    result = mesolve([Hb, HR], psi0, tlist, progress_bar=ui.EnhancedTextProgressBar())
    results.append(result)
    # Use the local result rather than results[i], which is stale when
    # `results` already holds entries (e.g. after re-running this cell).
    # Even diagonal entries of the (motion x spin) basis are spin index 0,
    # the spin state psi0 is prepared in.
    spin0 = slice(0, N*2, 2)
    after = np.real(np.diag(result.states[-1])[spin0])
    before = np.real(np.diag(psi0)[spin0])
    levels = np.arange(0,N)
    plt.bar(levels, before, color='C7', edgecolor='k', width=0.95)
    plt.bar(levels, after, color='C0', edgecolor='k', width=0.95)
    # Highlight the initial population of the level the pulse was timed for.
    plt.bar(nm, before[nm], color='C4', edgecolor='k', width=0.95)
    ax1.set_ylabel(r'$|\uparrow; n\rangle$ Population', color='C0')
    ax1.set_xlabel('$n$')
    ax1.set_yscale('log')
    ax1.set_ylim((1e-5,1.3e-1))
    ax1.set_xlim(0,N-5)
    # Second y axis: relative coupling strength versus level.
    ax2 = ax1.twinx()
    ax2.plot(np.arange(N), OR, color='C1', lw=4)
    ax2.set_ylim(0, 0.6)
    ax2.set_ylabel(r'$\Omega_i(n+ \Delta n) / \Omega_R$', color='C1') #= |\langle n|e^{ik\cdot x}|n+\Delta n\rangle |
    plt.savefig(r'HigherSidebandPulse%s.svg'%i)

plt.show()