import numpy as np
import matplotlib.pyplot as plt
import os
os.chdir(os.path.dirname(__file__))
plt.style.use('../../vinstyle.mplstyle')
import sys
sys.path.append('../..')
from fitandgraph import fit
import pandas as pd

def weighted_std(v):
    """Return the weighted spread of the ``y`` values in one group of rows.

    ``v`` is a group with ``y`` and ``yerr`` columns, e.g. one group handed
    over by ``DataFrame.groupby(...).apply``.

    * Two or more rows: the weighted (population) standard deviation of
      ``y`` about its weighted mean, using ``yerr`` as the weights.
    * Exactly one row: there is no spread to measure, so that row's own
      ``yerr`` is returned instead.
    * No rows: ``nan``.

    NOTE(review): the weights are ``yerr`` itself, so noisier points count
    *more*; inverse-variance weights (``1 / yerr**2``) are the conventional
    choice for averaging measurements. Kept as-is to preserve existing
    results — confirm this is intended.
    """
    y = np.asarray(v.y, dtype=float)
    w = np.asarray(v.yerr, dtype=float)
    if y.size == 0:
        # Explicit guard: the old `len(...) - 1` test was truthy for an
        # empty group and crashed inside the weighted average.
        return np.nan
    if y.size == 1:
        return w[0]
    mean = np.average(y, weights=w)
    variance = np.dot(w, (y - mean) ** 2) / w.sum()
    return np.sqrt(variance)

# Two-panel figure saved to axial_psf_overlap.pdf:
#   (a) the 1064 nm dataset against axial displacement;
#   (b) the 817 nm and 938 nm datasets against lens rotations.
# Each dataset is normalised to its peak, repeats at the same x are
# averaged, and the averages are fitted with fit.offGauss.
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(8.7, 3.5), sharey=True)

# (measurement number, colour, target axes, wavelength label in nm)
datasets = [
    (0, 'C0', ax0, '1064'),
    (9, 'C1', ax1, '817'),
    (10, 'C2', ax1, '938'),
]

for m, c, ax, label in datasets:
    # Load with a fresh object every time. `f` is rebound to the fitted
    # object below, so re-using it would call load() on the previous
    # dataset's fit rather than on a clean loader.
    raw = fit()
    raw.load('axial/ROI0.Im0.Measure%s.dat' % m, ykey='Separation')
    peak = np.max(raw.y)
    df = pd.DataFrame({'x': raw.x, 'y': raw.y / peak, 'yerr': raw.yerr / peak})

    # Select the value columns explicitly so apply() does not also operate
    # on the grouping column (deprecated in recent pandas).
    grouped = df.groupby('x')[['y', 'yerr']]
    # NOTE(review): repeats are weighted by yerr itself (noisier points count
    # more); inverse-variance weights are the usual choice — confirm intent.
    ave = grouped.apply(lambda v: np.average(v.y, weights=v.yerr))
    averr = grouped.apply(weighted_std).to_numpy()

    # Initial guess for fit.offGauss; only the second entry differs between
    # datasets (-16 for measurement 0). The parameter order is defined in
    # fitandgraph — presumably amplitude/centre/width/offset, confirm there.
    p0 = [1, -16 if m == 0 else 0, 10, .2]
    f = fit(ave.index.to_numpy(), ave.to_numpy(), averr, param=p0)
    ax.errorbar(f.x, f.y, f.yerr, fmt='^' if m == 9 else 'o', color=c,
                label=label + ' nm')
    ax.plot(*f.applyFit(f.offGauss), color=c)
    print(f.report())

# One legend call per axes, after all datasets have been drawn on it.
for ax in (ax0, ax1):
    ax.legend()

# Raw string: "\m" is not a valid escape sequence in a normal string literal.
ax0.set_xlabel(r'Axial Displacement ($\mu$m)')
ax0.set_ylabel('Normalised \nFluorescence Signal')
ax0.annotate('(a)', (-25, 0.9))
ax1.set_xlabel('Lens Rotations')
ax1.annotate('(b)', (-5, 0.9))
fig.savefig('axial_psf_overlap.pdf')