import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
# Work relative to this script's directory so the style sheet, the
# `fitandgraph` helper and the .dat files resolve regardless of the caller's
# working directory. abspath() guards against dirname() returning '' (which
# os.chdir rejects) when __file__ is a bare filename.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
plt.style.use('../../vinstyle.mplstyle')
import sys
sys.path.append('../..')  # relative entry, resolved against the cwd set above
from fitandgraph import fit
import matplotlib
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap works on both older and current releases.
cmap = plt.get_cmap('plasma')

# Two side-by-side panels: (a) long Rabi oscillation, (b) detuning map.
fig, axs = plt.subplots(1,2, figsize=(8.3,3.4), sharey=False)
plt.subplots_adjust(wspace=0.25)

### long Rabi osc -- panel (a)
# Damped-sine fit of a single long Rabi scan; `param` presumably holds the
# initial guesses for the fit (TODO confirm against fitandgraph.fit).
f = fit(param=[0.5, 6e3, 1, -np.pi/2, 0.5])
f.load('ROI1_Re_Measure34_long.dat')
# Rescale pulse durations by 1e-3 so they plot in ms (raw units presumably us).
f.x /= 1e3
ax_a = axs[0]
ax_a.errorbar(f.x, f.y, f.yerr, fmt='o')
fit_curve = f.applyFit(f.dampedSinekHz)
ax_a.plot(*fit_curve, color='C0')
ax_a.set(xticks=[0, 1, 2, 3],
         xlabel='MW Pulse Duration (ms)',
         ylabel=r'$|\downarrow\rangle$ Probability')
ax_a.text(1.5, 0.97, '(a)', va='center')

### 2D map
def gen_rabi(xy, f0, w, A):
    """Rabi-oscillation model for a drive detuned from resonance.

    Computes ``A * (w/W)**2 * sin(pi*W*t)**2`` with the generalized Rabi
    frequency ``W = sqrt(w**2 + (f - f0)**2)``; ``w`` and ``W`` are ordinary
    (not angular) frequencies in units of ``1/t``.

    Parameters
    ----------
    xy : tuple (t, f)
        Pulse durations ``t`` and drive frequencies ``f``. Scalars, lists,
        arrays or meshgrids are accepted as long as NumPy can broadcast them.
    f0 : float
        Resonance (centre) frequency, same units as ``f``.
    w : float
        Rabi frequency on resonance.
    A : float
        Multiplicative amplitude (peak transfer probability).

    Returns
    -------
    numpy.ndarray or numpy scalar
        Model probability, broadcast over ``t`` and ``f``.

    Notes
    -----
    ``(w/W) * sin(pi*W*t)`` is evaluated as ``pi*w*t*sinc(W*t)``, where
    ``np.sinc`` is the normalized sinc. The two forms are mathematically
    identical, but the sinc form stays finite when ``W == 0`` (``w == 0``
    on resonance). The direct form gives 0/0 = nan there.
    """
    t, freq = xy
    # asarray lets plain Python lists (as assembled for the fit) broadcast too.
    t = np.asarray(t, dtype=float)
    detuning = np.asarray(freq, dtype=float) - f0
    w_eff = np.sqrt(w**2 + detuning**2)
    return A * (np.pi * w * t * np.sinc(w_eff * t))**2

   
# File-name template; Measure33..Measure42 are Rabi scans taken at the MW
# frequencies listed in `freqs`, in the same order.
fn = 'ROI1_Re_Measure%s.dat'
# Drive frequencies expressed as detunings from 71.93 (kHz per the axis label;
# 71.93 is presumably the expected resonance -- TODO confirm).
freqs = np.array([67, 69, 71, 73, 70, 64, 68, 75, 78, 81]) - 71.93
fsort = sorted(freqs)
# Row of `prob` each scan lands in, so rows are in ascending detuning.
inds = [fsort.index(det) for det in freqs]
# One row of 20 time-sorted probabilities per scan.
prob = np.zeros((len(freqs), 20))
first_run = 33
for offset, (row, det) in enumerate(zip(inds, freqs)):
    f.load(fn % (first_run + offset))
    by_time = np.argsort(f.x)
    zz = f.y[by_time]
    prob[row] = zz
    axs[1].scatter(sorted(f.x / 1e3), [det] * len(f.x), c=zz, s=zz * 200,
                   cmap=cmap, vmin=0, vmax=0.9)
    
# Time axis (ms) shared by every detuning row. This assumes all detuning scans
# used the same pulse durations as the last one loaded (`f` still holds it).
x = list(sorted(f.x*1e-3))*len(freqs)
maxx = max(x)*1e3  # longest pulse of those scans, back in the file's units
# Resonant scan: keep only pulses inside the same time window as the other rows.
f.load(fn%43)
keep = f.x <= maxx
xx = f.x[keep]/1e3
zz = f.y[keep]
# Measure43 was taken at 72; express it as a detuning from the same 71.93
# reference used for `freqs`. Both the scatter and the fit then use
# detuning units (0.07), not the absolute frequency.
res_det = 72 - 71.93
c = axs[1].scatter(xx, [res_det]*len(xx), c=zz, s=zz*200, cmap=cmap,vmin=0,vmax=0.9)
# Flattened (x, y, z) samples for the global fit: one row of `prob` per sorted
# detuning, followed by the resonant scan. The resonant points are sorted by
# time with ONE permutation applied to both x and z, so each probability stays
# paired with its own pulse duration.
order = np.argsort(xx)
y = [det for det in fsort for _ in range(prob.shape[1])] + [res_det]*len(xx)
z = list(prob.flatten()) + list(zz[order])
x += list(xx[order])

# Global fit of all (pulse duration, detuning) -> probability samples to
# gen_rabi(xy, f0, w, A); `param` presumably seeds f0, w and A.
f = fit((x, y), z, param=[0, 6, 0.95])
f.getBestFit(gen_rabi)
# NOTE(review): the third parameter is the multiplicative amplitude A in
# gen_rabi, although it is labelled 'spam offset' here.
f.args = ['centre freq (kHz)', 'Rabi freq (kHz)', 'spam offset']
fit_report = f.report()
print(fit_report)
# Evaluate the best-fit model on an 80 (time) x 100 (detuning) grid and draw
# it as a translucent background beneath the measured points of panel (b).
t_grid = np.linspace(min(x), max(x), 80)
det_lo = min(freqs) * 1.04
det_hi = max(freqs) * 1.04
det_grid = np.linspace(det_lo, det_hi, 100)
# One model row per detuning -> array of shape (100, 80).
model = np.vstack([gen_rabi((t_grid, det), *f.ps) for det in det_grid])
ax_b = axs[1]
# With imshow's default origin='upper', row 0 (smallest detuning) is drawn at
# the 'top' extent value, hence (max, min) for bottom/top; set_ylim below
# restores an ascending axis.
ax_b.imshow(model,
            extent=(min(t_grid), max(t_grid), max(det_grid), min(det_grid)),
            alpha=0.5, aspect='auto', cmap='Oranges')
ax_b.set_xlabel('MW Pulse Duration (ms)')
ax_b.set_ylabel('Detuning (kHz)')
ax_b.set_xlim(0, 0.48)
ax_b.set_ylim(det_lo, det_hi)
plt.colorbar(c, label=r'$|\downarrow\rangle$ Probability', ax=ax_b)
ax_b.text(0.21, 7.9, '(b)')
plt.savefig('mw.svg')

