import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
plt.style.use(r'~/.matplotlib/vinstyle.mplstyle')
import os
# Work from the script's own directory so the relative paths used below
# (the input CSV files and the output PDF) resolve next to this file.
# abspath() guards against a bare relative __file__ (e.g. `python script.py`
# on Python < 3.9), for which dirname() returns '' and os.chdir('') raises.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
import scipy.constants as constants

# Physical constants (SI values as provided by scipy.constants).
a0 = constants.value('Bohr radius')     # Bohr radius
au = constants.value('Hartree energy')  # Hartree energy (atomic unit of energy)
u, h, c = constants.u, constants.h, constants.c  # atomic mass constant, Planck constant, speed of light
hbar = constants.hbar                   # reduced Planck constant, h / (2*pi)


# Figure layout: one tall panel on the left and two stacked panels on the
# right; the left column is twice as wide as the right one.
fig, axs = plt.subplot_mosaic(
    [['left', 'upper right'],
     ['left', 'lower right']],
    figsize=(9, 4),
    gridspec_kw={'width_ratios': [2, 1]},
    constrained_layout=True,
)

### potential curves
# Left panel (a): potential energy curves read from CSV files, overlaid with
# the pump and Stokes transitions and the three levels |F>, |E> and |G>.
files = ["1S","3S","3P"]
# Raw strings: mathtext backslashes such as \Sigma are invalid Python escape
# sequences otherwise (SyntaxWarning on Python 3.12+).
labels = [r"$^1\Sigma^+$", r"$^3\Sigma^+$", r"$^3\Pi$"]
ax0 = axs['left']

# Transition frequencies in THz: c [m/s] / (wavelength [nm] * 1e3).
pump_THz = constants.c/1557e3
stokes_THz = constants.c/977e3

for fn, label in zip(files, labels):
    df = pd.read_csv(fn+'.csv')
    x = 1e9*df['R']*a0 #convert Bohr to nm
    for j in range(1, 5):
        # Columns are named '(1)1S', '(2)1S', ...; a file may hold fewer than
        # four curves, so stop at the first missing column.  Only the lookup
        # is guarded, so genuine plotting errors are no longer swallowed.
        try:
            energy = df[f'({j}){fn}']
        except KeyError:
            break
        #convert atomic units (hartree) to THz
        # NOTE(review): 0.29661 hartree presumably shifts the energy zero to
        # the dissociation threshold used in the figure -- confirm.
        ax0.plot(x, 1e-12*(energy+0.29661)*au/h, color='k',
                 label=label if j==1 else None, lw=1)

# Stokes transition: arrow from the pump level down by the Stokes frequency.
# The label previously read '938 nm', which contradicted the 977 nm
# wavelength that sets the arrow length (no legend is drawn, so the label is
# not rendered).
ax0.arrow(0.44, pump_THz, 0, -stokes_THz, ls='-', color='C4', fill=True,
          length_includes_head=True, head_width=0.02, head_length=20, label='977 nm')
ax0.annotate('Stokes', (.41,-30), va='center', ha='center', rotation=90, color='C4')
# Pump transition: arrow from zero energy up by the pump frequency.
ax0.arrow(0.55, 0, 0, pump_THz, ls='-', color='C3', fill=True,
          length_includes_head=True, head_width=0.02, head_length=20, label='')
ax0.annotate('Pump', (.52,100), va='center', ha='center', rotation=90, color='C3')

# Names of the electronic states next to their curves.
ax0.annotate(r'X$^1\Sigma^+$', (.33,-110), va='center', ha='center')
ax0.annotate(r'a$^3\Sigma^+$', (.53,-30), va='center', ha='center')
ax0.annotate(r'b$^3\Pi$', (.5,220), va='center', ha='center')
ax0.annotate(r'A$^1\Sigma^+$', (.75,220), va='center', ha='center')

# Dashed level markers: |F> just below zero energy, |E> at the pump
# frequency, and |G> at pump minus Stokes (exactly where the Stokes arrow ends).
ax0.hlines(-0.1, 0.52, 0.9, colors='C0', linestyles='--')
ax0.annotate(r'$|F\rangle=|-6(2,4)d(2,4)\rangle$', (.57,15), color='C0')
ax0.hlines(pump_THz, 0.37, 0.56, colors='C1', linestyles='--')
ax0.annotate(r"$|E\rangle=|^3\Pi_1 v'=29, J'=1\rangle$", (.57,180), va='center', color='C1')
ax0.hlines(pump_THz - stokes_THz, 0.42, 0.47, colors='C2', linestyles='--')
ax0.annotate(r"$|G\rangle=|v''=0, N''=0\rangle$", (.57,-100), va='center', color='C2')

ax0.set_xlim(0.25,1.05)
ax0.set_ylim(-130,250)
ax0.set_xlabel('Internuclear separation (nm)')
ax0.set_ylabel('Energy / h (THz)')
ax0.annotate('(a)', (1,220), ha='center', va='center')

### Rabi frequencies
# Upper-right panel (b): schematic pulse envelopes on an 80-sample time axis
# (arbitrary units).  The Stokes envelope is on first and ramps down as cos^2
# over samples 20-59 while the pump envelope ramps up as sin^2 over the same
# samples.  Mathtext labels are raw strings because '\O' and '\m' are invalid
# Python escape sequences (SyntaxWarning on Python 3.12+).
ax1 = axs['upper right']
s = np.cos(np.linspace(0,np.pi/2,40))**2  # Stokes ramp-down, 1 -> 0
ax1.plot(np.concatenate((np.zeros(10), np.ones(10), s, np.zeros(20))),
             c='C4', label=r'$\Omega_S$')
p = np.sin(np.linspace(0,np.pi/2,40))**2  # pump ramp-up, 0 -> 1
ax1.plot(np.concatenate((np.zeros(20), p, np.ones(10), np.zeros(10))),
             c='C3', label=r'$\Omega_P$')
ax1.annotate(r'$\Omega_S$', (22,0.5),va='center', ha='center', color='C4')
ax1.annotate(r'$\Omega_P$', (56,0.5),va='center', ha='center', color='C3')
ax1.set_xticks([])  # time axis is schematic, so no tick values
ax1.set_xlabel('Time')
ax1.set_ylabel(r'$\Omega/\Omega_\mathrm{max}$')
ax1.annotate('(b)', (76,0.87), ha='center', va='center')

### Populations
# Lower-right panel (c): schematic state populations on the same 80-sample
# time axis as the pulse envelopes.  During the 40-sample pulse overlap the
# |F> and |G> populations are cos^2 and sin^2 of the mixing angle
# arctan(p/s), where p and s are the pump and Stokes ramps; |E> stays at zero.
ax2 = axs['lower right']
# arctan2 gives the same first-quadrant angle as arctan(p/s) without dividing
# by s, which is (numerically) zero at the end of the Stokes ramp.
theta = np.arctan2(p, s)
ax2.plot(np.zeros(80), color='C1')
ax2.plot(np.concatenate((np.ones(20), np.cos(theta)**2, np.zeros(20))), color='C0')
ax2.plot(np.concatenate((np.zeros(20), np.sin(theta)**2, np.ones(20))), color='C2')
ax2.annotate(r'$|F\rangle$', (22,0.7), va='center', ha='center', color='C0')
ax2.annotate(r'$|G\rangle$', (56,0.7), va='center', ha='center', color='C2')
ax2.annotate(r'$|E\rangle$', (22,0.15), va='center', ha='center', color='C1')
ax2.set_xticks([])  # time axis is schematic, so no tick values
ax2.set_xlabel('Time')
ax2.set_ylabel('Population')
ax2.annotate('(c)', (76,0.87), ha='center', va='center')
# Write the finished figure to the current working directory.
plt.savefig('STIRAP.pdf', bbox_inches='tight')