import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
# Work relative to this script's folder so the style sheet, the fitandgraph
# import path ('../..') and the data paths below resolve regardless of the
# directory the script is launched from.
# abspath() matters: if __file__ is a bare filename (e.g. Python < 3.9 run as
# `python script.py`), dirname() returns '' and os.chdir('') raises.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
plt.style.use('../../vinstyle.mplstyle')
import sys
sys.path.append('../..')
from fitandgraph import fit

### Load data
# Every scan file shares the same column layout, so the read-and-fit step is
# factored out here instead of repeating the three column names per scan.
def _fit_loading(scan, p0):
    """Fit one scan's loading probability against its user variable.

    scan: DataFrame with 'User variable', 'Loading probability' and
        'Error in Loading probability' columns.
    p0: initial parameter guess, forwarded unchanged to ``fit`` (its meaning
        is defined by fitandgraph.fit -- confirm there).

    Returns whatever ``fit`` returns (the object whose .x/.y/.yerr/.applyFit
    are used when plotting).
    """
    return fit(scan['User variable'], scan['Loading probability'],
               scan['Error in Loading probability'], p0)


### Zernike Optimisation
zk0 = pd.read_csv('ZK_optimisation/ROI0_Re_Measure0.dat', skiprows=2)
f0 = _fit_loading(zk0, [-0.5, 25, 2, 1])
zk40 = pd.read_csv('ZK_optimisation/ROI0_Re_Measure40.dat', skiprows=2)
f40 = _fit_loading(zk40, [-0.5, 31, 2, 1])

### 4 trap array
# One measurement file per ROI, all fitted with the same initial guess
# (a fresh list literal per call, as before, in case fit mutates it).
fits = [
    _fit_loading(pd.read_csv(f'4traparray/ROI{i}_Re_Measure8.dat', skiprows=2),
                 [-0.5, 35, 2, 1])
    for i in range(4)
]


#%% plot

# Two-row layout: one wide panel on top, two side-by-side panels below.
fig, axs = plt.subplot_mosaic([['upper', 'upper'], ['lower left', 'lower right']],
                              gridspec_kw=dict(height_ratios=[1, 2], hspace=.3, wspace=.1),
                              figsize=(9, 4.3))

# Panel (a): synthetic trap-intensity trace built from segments, not loaded data.
# The segments are flat, then a 10% sinusoidal ripple, flat again, a short dip
# to 0.3, and flat to the end.
trace = np.concatenate((np.ones(100), 1 + 0.1*np.sin(np.linspace(0, 115, 200)),
                        np.ones(20), .3*np.ones(20), np.ones(40)))
axs['upper'].plot(trace)
axs['upper'].set_xticks([])
axs['upper'].set_yticks([0, 1])
axs['upper'].set_ylabel('Trap \nIntensity')
axs['upper'].annotate('(a)', (2, 0.6))
# The x-range follows the trace length (380 samples here), so editing the
# segments above cannot leave a gap or cut off the end.
axs['upper'].set_xlim(0, len(trace))
axs['upper'].set_xlabel('Time')

ax0 = axs['lower left']
ax1 = axs['lower right']

# Panel (b): one trap's survival vs. modulation frequency, without and with the
# optimised correction. Markers are the data; lines are the curves returned by
# applyFit(offGauss).
ax0.errorbar(f0.x, f0.y, yerr=f0.yerr, color='C0', fmt='o', label='None Applied')
ax0.plot(*f0.applyFit(f0.offGauss), color='C0')
# Zernike amplitudes (n,m) presumably used for the 'Optimised' scan:
# (4,0)=-0.129; (2,2)=-0.2; (2,-2)=-0.076; (4,4)=0.08
ax0.errorbar(f40.x, f40.y, yerr=f40.yerr, color='C1', fmt='^', label='Optimised')
ax0.plot(*f40.applyFit(f40.offGauss), color='C1')
ax0.set_ylabel('Survival Probability')
# Optional legend title matching the coefficient list above:
# title='$A (^4_0, ^2_2, ^2_{-2}, ^4_4)$'
ax0.legend()
ax0.annotate('(b)', (34, 0.95), ha='center', va='center')
ax0.set_ylim(-.1, 1.07)
ax0.set_yticks([0.1, 0.4, 0.7, 1.])

# Panel (c): one survival spectrum per ROI of the 4-trap array.
ss = ['o', '^', 's', 'x']
for i, f in enumerate(fits):
    # Markers cycle so that more than len(ss) ROIs do not raise IndexError;
    # 'C{i}' colors already wrap around matplotlib's color cycle.
    ax1.errorbar(f.x, f.y, yerr=f.yerr, fmt=ss[i % len(ss)], color=f'C{i}',
                 label=f'ROI{i}')
    ax1.plot(*f.applyFit(f.offGauss), color=f'C{i}')
ax1.legend()
ax1.annotate('(c)', (36.2, 0.95), ha='center', va='center')
# Same y-range and ticks as panel (b). Tick labels are hidden because this panel
# sits beside panel (b), which carries the y-axis label.
ax1.set_ylim(-.1, 1.07)
ax1.set_yticks([0.1, 0.4, 0.7, 1.])
ax1.set_yticklabels([])
# Shared x-axis label for both lower panels, placed just below the figure's
# bottom edge (figure coordinates). bbox_inches='tight' enlarges the saved area
# so the label is not clipped.
fig.text(0.5, -0.01, 'Modulation Frequency (kHz)', ha='center', va='center')
fig.savefig('parametric_heating.pdf', bbox_inches='tight')