import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.signal import find_peaks
import os
os.chdir(os.path.dirname(__file__))
import arrayFitter as af
import json

# Load the acquisition parameters saved alongside the measurement.
# The path is assembled with os.path.join so the script also runs on
# non-Windows systems; JSON is UTF-8 by specification, so read it as such
# rather than relying on the platform default encoding.
with open(os.path.join('Measure10', 'Iteration0', 'params.json'), encoding='utf-8') as f:
    params = json.load(f)

# Camera ROI corners as stored in params.json (not used further below).
xmin, ymin, xmax, ymax = params['Camera ROI']
# NOTE(review): intended to be (xmax-xmin, ymax-ymin) but hard-coded here;
# confirm it matches the saved images.
imshape = 572, 124
# Fit ROI size; also used as the half-width of the ROI boxes drawn on the inset.
roi_size = 46
# Fitter for a (4, 1) array layout (params['Fit ROI size'] could supply roi_size).
o = af.imageArray((4, 1), roi_size)
# Empirical scale applied to the stored reference amplitude; the origin of this
# factor is not documented here - TODO confirm.
REF_CORRECTION = 1.056151665
o.ref = params['reference'][0] * REF_CORRECTION

#%
# Filename template for saved frames: {0} = iteration index, {1} = frame index.
# Built with os.path.join so it resolves on any OS (the placeholders survive).
imname = os.path.join('Measure10', 'Iteration{0}', '{1}.png')

# Average the 9 frames of iteration 0 (the "before" state) and fit them.
beforeim = np.zeros(imshape)
for i in range(9):
    beforeim += o.loadImage(imname.format(0, i), imshape) / 9
o._imvals = beforeim.copy()
o.fitImage()
o.getScaleFactors(2)
# Fitted amplitudes relative to the (corrected) reference.
start_amps = o.df['I0'] / o.ref

#%
# Average the 9 frames of iteration 2 (the "after" state), fit them, and
# report the sample standard deviation of the relative amplitudes for both
# the initial and the final state.
afterim = sum(
    (o.loadImage(imname.format(2, frame), imshape) / 9 for frame in range(9)),
    np.zeros(imshape),
)
o._imvals = afterim.copy()
o.fitImage()
o.getScaleFactors(2)
end_amps = o.df['I0'] / o.ref

for label, amps in (('Initial', start_amps), ('Final', end_amps)):
    print(f'{label} std dev : {np.std(amps, ddof=1):%}')

#%%
# Fit the frame-averaged image of each plotted iteration and collect the
# relative amplitudes (one row per iteration) plus their sample std dev.
# Only the iterations shown in the figure are fitted.
n_iterations = 3
ints = []
stds = []
for i in range(n_iterations):
    im = np.zeros(imshape)
    for j in range(9):
        im += o.loadImage(imname.format(i, j), imshape) / 9
    o._imvals = im.copy()
    o.fitImage()
    # NOTE(review): the other cells call getScaleFactors(2); confirm the
    # default argument used here is intended.
    o.getScaleFactors()
    amps = o.df['I0'] / o.ref
    ints.append(amps)
    stds.append(np.std(amps, ddof=1))

ints = np.array(ints)
stds = np.array(stds)
#%%
# Figure "Normalisation":
#   (a) relative power of every fitted spot per iteration, normalised to the
#       mean of the last plotted iteration, with a +/- 1 std band;
#   (b) line profile through the before/after images, with an inset of the
#       final image and the ROI boxes around the detected peaks.
fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 1]}, figsize=(9, 4))

n_iter, n_spots = ints.shape
norm = np.mean(ints[-1])
ints /= norm
# The spread has to be rescaled by the same factor as the values it describes,
# otherwise the band mixes normalised means with un-normalised deviations.
stds /= norm
means = np.mean(ints, axis=1)
for i in range(n_iter):
    ax1.scatter([i] * n_spots, ints[i], s=50)
ax1.fill_between(np.arange(n_iter), means - stds, means + stds, alpha=0.2, color='C4')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Power')
ax1.set_xticks(list(range(n_iter)))
ax1.text(-0.8, 1.1, '(a)')

profile_col = 60  # image column used for the line profiles and the peak search
# Normalise the profiles to the mean of the maxima of four 140-pixel segments
# of the final profile.
ref = np.mean([max(afterim[i*140:(i+1)*140, profile_col]) for i in range(4)])
ax2.plot(beforeim[:, profile_col] / ref, color='C0', label='0')
ax2.plot(afterim[:, profile_col] / ref, color='C2', label='2')
ax2.plot(np.ones(imshape[0]), '--', color='C2')
ax2.set_xlabel('Pixel')
ax2.set_ylabel('Intensity')
ax2.legend()
# Place the "(b)" label at the same relative offset as "(a)" on ax1.
x0, x1 = ax1.get_xlim()
x2, x3 = ax2.get_xlim()
ax2.text(-0.8 / (x1 - x0) * (x3 - x2), 1.18, '(b)')

# Inset: final image with a dashed box around each detected peak.
ax3 = plt.axes((0.18, 0.14, 0.3, 0.35))
ax3.xaxis.set_visible(False)
ax3.yaxis.set_visible(False)
for peak in find_peaks(afterim[:, profile_col], height=4, distance=50)[0]:
    ax3.add_patch(Rectangle((peak - roi_size, imshape[1] // 2 - roi_size),
                            roi_size * 2, roi_size * 2, facecolor='none',
                            linestyle='--', edgecolor='k'))
ax3.imshow(afterim.T)
plt.tight_layout()
plt.savefig('Normalisation.pdf')
#%%
# Greyscale view of the final averaged image, rotated by 90 degrees.
# Open a new figure explicitly: if the previous figure is still open (e.g. when
# the file runs as a script), plt.imshow would otherwise draw onto the current
# axes, which is the inset axes created for the Normalisation figure.
plt.figure()
plt.imshow(np.rot90(afterim, 1), cmap='Greys')
plt.gca().axis('off')

#%%
# Stacked views of the averaged images before (top) and after (bottom),
# each rotated by 90 degrees and drawn with a common colour ceiling of 130.
fig, (ax1, ax2) = plt.subplots(2, 1)
panels = ((ax1, beforeim, 'Blues'), (ax2, afterim, 'Purples'))
for axis, image, colormap in panels:
    axis.imshow(np.rot90(image, 1), cmap=colormap, vmax=130)
    axis.axis('off')


