import os

import numpy as np
import matplotlib.pyplot as plt

# Work relative to this script's own directory so that the relative paths used
# in this script (style sheet, input .bmp images, output .svg) resolve the same
# way regardless of where it is launched from. abspath() guards against
# __file__ being a bare filename: dirname() would then return '' and
# os.chdir('') raises FileNotFoundError.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# The style sheet path is relative, so it is loaded *after* the chdir above;
# loading it first would resolve it against the caller's working directory.
plt.style.use('../../../vinstyle.mplstyle')

# Project-local fitting module; imported after the chdir (as before) so that it
# can also be found via the working directory in interactive sessions.
import arrayFitter as af

# Fit each image with a single, reused fitter object and record per image:
#   fits -> the fitter's 'I0' results column
#   ims  -> a snapshot of the fitter's image array (o._imvals)
#   sums -> ratio of summed intensity outside vs. inside a column window,
#           computed on a band of rows of that image
# NOTE(review): meaning of imageArray((4, 1), 50) and getScaleFactors(2)
# is defined in arrayFitter — confirm there.
o = af.imageArray((4, 1), 50)

imname = r'{}_0.bmp'  # one image per scale value, e.g. '1.4_0.bmp'
scales = [1, 1.4, 1.5, 0.8, 1.2]
rms = [107, 135, 140, 94, 121]  # RMS voltage (mV) paired with each entry of `scales`

# Region for the outside/inside ratio: rows in ROW_BAND are summed into a 1-D
# column profile; columns [IN_LO, IN_HI) count as "inside", the rest "outside".
ROW_BAND = slice(500, 700)
IN_LO, IN_HI = 403, 820

ims = []
fits = []
sums = []
for s in scales:
    o.fitImage(imname.format(s))
    o.getScaleFactors(2)
    # The same fitter object is reused for every image, so store independent
    # copies (as already done for the image) rather than references into its
    # state that a later fit could overwrite in place.
    fits.append(o.df['I0'].copy())
    ims.append(o._imvals.copy())
    flat = ims[-1][ROW_BAND].sum(axis=0)
    sums.append((flat[:IN_LO].sum() + flat[IN_HI:].sum()) / flat[IN_LO:IN_HI].sum())

#%%
from matplotlib import colors

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6.5, 4.5))

# Panel (a): column profiles (image summed over rows, columns 200..1019) of the
# first three images, normalised to the peak of the second image's profile.
ref = np.max(ims[1].sum(axis=0))
red = np.array(colors.to_rgba('C1')[:-1])  # RGB of style colour C1, alpha dropped
for s, im, shade in zip(rms[:3], ims, [1.4, 0.8, 0.]):
    # Scale C1 to get lighter/darker variants. Clip because a factor > 1 can
    # push a channel above 1.0, which matplotlib rejects as an invalid colour.
    ax1.plot(im.sum(axis=0)[200:1020] / ref, label=str(s),
             color=np.clip(red * shade, 0, 1))
# Raw string: '\m' in a plain literal is an invalid escape sequence
# (DeprecationWarning; SyntaxWarning since Python 3.12).
ax1.legend(title=r'$V_\mathrm{RMS}$ (mV)')
ax1.set_xlabel('Pixel')
ax1.set_ylabel('Intensity')
ax1.text(-20, 0.8, '(a)')

# Panel (b): per-trap power, 'Before' vs 'After' (legend labels). The dashed
# line marks the level 1; the shaded band spans mean +/- one sample standard
# deviation (ddof=1) of the 'After' values.
trap_idx = np.arange(4)
start_amps = [1.051354, 1.086596, 0.926873, 0.892107]
end_amps = [1.004925, 1.001554, 1.007807, 0.985714]

mean = np.mean(end_amps)
std = np.std(end_amps, ddof=1)
band_hi = mean + std
band_lo = mean - std

for values, tag, colour in ((start_amps, 'Before', 'C1'),
                            (end_amps, 'After', 'C2')):
    ax2.scatter(trap_idx, values, label=tag, color=colour)
ax2.hlines(1, trap_idx[0], trap_idx[-1], ls='dashed', colors=['C2'])
ax2.fill_between(trap_idx, band_hi, band_lo, color='C2', alpha=0.3)
ax2.legend()

# Wide x-range leaves empty space beside the four traps.
ax2.set_xlabel('Trap Index')
ax2.set_xlim(-3.2, 6.2)
ax2.set_xticks(trap_idx)
ax2.set_ylabel('Power')
ax2.set_ylim(0.85, 1.1)
ax2.text(-3, 1.045, '(b)')
plt.subplots_adjust(hspace=0.4)

# Small inset axes at figure coordinates (0.19, 0.18), size 0.17 x 0.15,
# showing the multi-tone waveform sum_k a_k * sin(f_k * t + phi_k), sampled
# at 1024 points over t in [0, 2*pi). Both axes are hidden, so no ticks or
# axis labels are drawn for the inset.
ax3 = plt.axes((0.19, 0.18, 0.17, 0.15))
ax3.xaxis.set_visible(False)
ax3.yaxis.set_visible(False)
amps = [0.31997227, 0.31316984, 0.37274244, 0.37482698]
phases = [-140.72214774353284, -262.1750857605246, -488.07519516563497, -816.7873979982915]  # degrees
freqs = [94, 98, 102, 106]
t = np.arange(1024) * 2 * np.pi / 1024
data = np.zeros(1024)
for f, phi, a in zip(freqs, phases, amps):
    data += a * np.sin(f * t + np.deg2rad(phi))
ax3.plot(data, color='C2')

plt.savefig('TrapArrayBalance.svg')
plt.show()