# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
plt.style.use(r'Z:\Tweezer\People\Vincent\python snippets\vinstyle.mplstyle')
import os
os.chdir(os.path.dirname(__file__))
import time
import sys
sys.path.append(r'Z:\Tweezer\People\Stefan\general-python')
sys.path.append(r'Z:\Tweezer\Code\Python 3.5\PyDex')
sys.path.append(r'Z:\Tweezer\Code\Python 3.5\PyDex\imageanalysis')
from fitandgraph import fit
from histoHandler import histo_handler

plt.close('all')

# Cell 1: for every data file, plot 0-atom survival against trap frequency,
# optionally fit a loss model, and collect the fit results for the cells below.

# Directory holding one data file per magnetic field. Files are paired with
# `fields` by sorted filename, so that order must match the order of `fields`.
data_dir = r'Z:\Tweezer\Experimental Results\2021\December\1064 pair loss measures'

# Magnetic field of each data file, relative to 181.6 G.
fields = np.array([171.8, 180, 181, 181.9, 182.23, 182.6, 183.08, 184, 184.79]) - 181.6
# Fit model per file (one entry per field): 0 = no fit, 1 = f.offGauss,
# 2 = f.doubleGauss.
plot = [0, 0, 0, 1, 1, 1, 2, 2, 2]
fig, axs = plt.subplots(len(fields), 1, sharex=True, figsize=(6, 8))

xpeaks1 = []        # field of every fitted file (model 1 or 2)
xpeaks2 = []        # field of every double-Gaussian (model 2) file
peaks1 = [[], []]   # [values, errors] of fit parameter 1, plotted as the loss centre
peaks2 = [[], []]   # [values, errors] of fit parameter 4 (model 2 only)
offsets = [[], []]  # [values, errors] of the last fit parameter, plotted as the offset
mean = [[], []]     # [mean, std] of the survival probability in each file

# sorted(): os.listdir returns entries in arbitrary order, which would
# silently pair files with the wrong field values.
for ax, B, fn, plot_fit in zip(axs, fields, sorted(os.listdir(data_dir)), plot):
    # Join with data_dir: the cwd was changed to the script's folder at import
    # time, so a bare filename would not resolve to the data directory.
    df = pd.read_csv(os.path.join(data_dir, fn), skiprows=2)
    # User variable -> trap frequency (kHz), taken proportional to its sqrt.
    x = 0.5*(56+72)*np.sqrt(df['User variable'])
    y = df['0 atom survival probability']
    ey = df['Error in 0 atom survival probability']

    ax.errorbar(x, y, yerr=ey, fmt='o', label='%.1f G' % B, ms=4)
    mean[0].append(y.mean(axis=0))
    mean[1].append(y.std(axis=0))

    # `param` presumably holds the starting guesses for fit(); the centre
    # guesses scale linearly with B.
    if plot_fit == 1:
        f = fit(x, y, ey, param=[0.3, B*9.7+47, 5, 0.4])
        model = f.offGauss
    elif plot_fit == 2:
        f = fit(x, y, ey, param=[0.6, B*9.7+49, 5, 0.8, 29*B+36, 20, 0.25])
        model = f.doubleGauss
    else:
        model = None

    if model is not None:
        ax.plot(*f.applyFit(model))
        print(f.report())
        # Bookkeeping shared by both models.
        xpeaks1.append(B)
        peaks1[0].append(f.ps[1])
        peaks1[1].append(f.perrs[1])
        offsets[0].append(f.ps[-1])
        offsets[1].append(f.perrs[-1])
    if plot_fit == 2:
        # Only the double-Gaussian model has a second centre (parameter 4).
        xpeaks2.append(B)
        peaks2[0].append(f.ps[4])
        peaks2[1].append(f.perrs[4])

    ax.set_ylim(0.2, 0.9)
    ax.legend()


axs[len(axs)//2].set_ylabel('0 atom survival probability')
# Label the bottom subplot explicitly: the leftover loop variable is only the
# bottom subplot when every axis received a file (and is undefined if none did).
axs[-1].set_xlabel('Cs Trap Frequency (kHz)')
plt.tight_layout()
plt.subplots_adjust(hspace=0.0, top=0.98, bottom=0.1, left=0.15, right=0.95)
plt.show()

#%%
plt.figure()
# Cell 2: fitted loss centres against field, with a straight-line fit per
# series. C0 = parameter 1 of every fit, C1 = parameter 4 of double-Gaussian fits.
for xs, peaks, color in zip([xpeaks1, xpeaks2], [peaks1, peaks2], ['C0', 'C1']):
    if not xs:
        # No fits of this kind were made; fit() and max() would fail on an
        # empty series, so skip it.
        continue
    g = fit(xs, peaks[0], peaks[1])
    plt.errorbar(g.x, g.y, yerr=g.yerr, fmt='x', color=color)
    # Line drawn from B offset 0 up to the largest fitted field.
    xth = np.linspace(0, max(xs), 10)
    g.getBestFit(g.linear)
    plt.plot(xth, g.linear(xth, *g.ps), color=color)
    print(g.ps)

plt.xlabel('B field (G) + 181.6G')
plt.ylabel('Loss Centre (kHz)')
y0, y1 = plt.ylim()
# Secondary axis sharing the frequency scale but labelled in tweezer power:
# the load step uses f = 0.5*(56+72)*sqrt(V), so V = (2*f/(56+72))**2.
ax = plt.gca().twinx()
ax.set_ylim(y0, y1)
yax = np.linspace(y0, y1, 7)
ax.set_yticks(yax)
ax.set_yticklabels(['%.2f' % v for v in (yax*2/(56+72))**2])
ax.set_ylabel('1064 Tweezer Power (V)')
plt.tight_layout()
plt.show()

#%%
plt.figure()
# Cell 3: fitted offset (last fit parameter) of every fitted file, compared
# with the mean survival (+/- std) of the first three, unfitted, files.
g = fit(xpeaks1, offsets[0], offsets[1])
plt.errorbar(g.x, g.y, yerr=g.yerr, fmt='o')

n_unfitted = 3  # files with plot == 0 at the start of the list
survival_means, survival_stds = mean
plt.errorbar(fields[:n_unfitted], survival_means[:n_unfitted],
             yerr=survival_stds[:n_unfitted], fmt='o')

plt.xlabel('B field (G) + 181.6G')
plt.ylabel('Fitted offset')
plt.tight_layout()
plt.show()