import numpy as np
import matplotlib.pyplot as plt
import os
os.chdir(os.path.dirname(__file__))
plt.style.use('../../vinstyle.mplstyle')
import sys
sys.path.append('../..')
from fitandgraph import fit
import lmfit

def prob(x, dif, eg):
    """Return the expected measured probability after `x` applied gates.

    Clifford gate error model from
    https://journals.aps.org/prl/pdf/10.1103/PhysRevLett.121.240501 :

        P(x) = 1/2 + (1 - dif)/2 * (1 - 2*eg)**x

    Parameters
    ----------
    x : number of gates applied (scalar or array; broadcast arithmetically).
    dif : offset term reducing the decay amplitude (reported as "SPAM"
        elsewhere in this script).
    eg : error per gate; each gate multiplies the decaying part by (1 - 2*eg).
    """
    amplitude = 0.5*(1-dif)
    decay_per_gate = 1 - 2*eg
    return 0.5 + amplitude * decay_per_gate**x

def get(c, key):
    """Print and return the best-fit value and 1-sigma errors of one parameter.

    Parameters
    ----------
    c : dict
        Result of ``lmfit.conf_interval``. With lmfit's default
        ``sigmas=(1, 2, 3)`` each entry is a list of (probability, value)
        tuples ordered -3s, -2s, -1s, best, +1s, +2s, +3s, so indices 2, 3 and 4
        are the lower 1-sigma bound, the best fit and the upper 1-sigma bound.
    key : str
        Parameter name to report.

    Returns
    -------
    tuple
        ``(best, upper_err, lower_err)``, where both errors are positive
        distances from the best-fit value.
    """
    intervals = c[key]
    lower_bound = intervals[2][1]
    best = intervals[3][1]
    upper_bound = intervals[4][1]
    upper_err = upper_bound - best
    lower_err = best - lower_bound
    # Quote both uncertainties to one significant figure (the lower error
    # was previously printed with %.3g, inconsistent with the upper one).
    print('%.3g + %.1g - %.1g' % (best, upper_err, lower_err))
    return best, upper_err, lower_err

# Build the initial parameter set for `prob` via an lmfit Model so the bounds
# travel with the parameters. lmfit treats the first argument of `prob` (x) as
# the independent variable, leaving `dif` and `eg` as fit parameters; both are
# constrained to be >= 0 and start at 0.01.
fmodel = lmfit.Model(prob)  # independent variable: x; parameters: dif, eg
fmodel.set_param_hint('dif', min=0)
fmodel.set_param_hint('eg', min=0)
pars = fmodel.make_params(dif=0.01, eg=0.01)


# Two side-by-side panels sharing the y axis:
# ax0 holds the pi-pulse data (a), ax1 holds the pushout data (b).
fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(8, 3))
# NOTE(review): plt.tight_layout() near the end of the script recomputes
# subplot spacing and may override this wspace setting.
plt.subplots_adjust(wspace=0.1)
# Project fitting helper, reloaded with each data file below; the meaning of
# param=[0.1, 0.1] (presumably initial guesses) is not visible here.
f = fit(param=[0.1, 0.1])

### pi pulse
# Panel (a): for each data file plot the measurements, overlay the project
# helper's fit, then refit `prob` with lmfit to report 1-sigma intervals.
for i, fn in enumerate(['ROI0_Re_Measure22.dat', 'ROI1_Re_Measure8.dat']):
    f.load(fn)
    colour = 'C%s' % (2 - i)
    species = 'Rb' if i else 'Cs'

    ax0.errorbar(f.x, f.y, f.yerr, fmt='o', color=colour, label=species)
    ax0.plot(*f.applyFit(f.Randomised_Benchmarking), color=colour)

    def residual(p):
        # Error-weighted residuals of `prob` against the data currently in f.
        return (f.y - prob(f.x, p['dif'], p['eg']))/f.yerr
    mini = lmfit.Minimizer(residual, pars)
    rslt = mini.minimize()
    ci = lmfit.conf_interval(mini, rslt)
    # Label the printout so results from the two files can be told apart.
    print('%s (%s):' % (species, fn))
    print('SPAM: ')
    get(ci, 'dif')
    print('Error/gate: ')
    get(ci, 'eg')

# Raw string: '\p' in a normal string literal is an invalid escape sequence
# (DeprecationWarning, and SyntaxWarning from Python 3.12).
ax0.set_xlabel(r'# $\pi$-pulses')
ax0.set_ylabel(r'|$\downarrow\rangle$ Probability')
ax0.text(15, 0.93, '(a)')


### pushout
# Panel (b): same procedure as panel (a) for the pushout-pulse data, then
# finish and save the figure.
for i, fn in enumerate(['ROI0_Re_Measure24.dat', 'ROI1_Re_Measure1.dat']):
    f.load(fn)
    colour = 'C%s' % (2 - i)
    species = 'Rb' if i else 'Cs'
    if i:
        # For the second file, points with x <= X are drawn faded and removed
        # from f before fitting. One mask and its complement guarantee that
        # a point at exactly x == X is still plotted (the previous `< X` /
        # `> X` pair silently dropped it from both the plot and the fit).
        X = 5
        keep = f.x > X
        ax1.errorbar(f.x[~keep], f.y[~keep], f.yerr[~keep], fmt='^',
                     color=colour, alpha=0.4)
        f.y = f.y[keep]
        f.yerr = f.yerr[keep]
        f.x = f.x[keep]
    ax1.errorbar(f.x, f.y, f.yerr, fmt='o', color=colour, label=species)
    ax1.plot(*f.applyFit(f.Randomised_Benchmarking), color=colour)

    def residual(p):
        # Error-weighted residuals of `prob` against the data currently in f.
        return (f.y - prob(f.x, p['dif'], p['eg']))/f.yerr
    mini = lmfit.Minimizer(residual, pars)
    rslt = mini.minimize()
    ci = lmfit.conf_interval(mini, rslt)
    # Label the printout so results from the two files can be told apart.
    print('%s (%s):' % (species, fn))
    print('SPAM: ')
    get(ci, 'dif')
    print('Error/gate: ')
    get(ci, 'eg')


ax1.set_xlabel('# Pushout Pulses')
ax1.legend()
ax1.text(13, 0.93, '(b)')
plt.tight_layout()
plt.savefig('spam_fidelity.svg')