import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.fft import fft, fftfreq
# Window functions live in scipy.signal.windows; the scipy.signal aliases
# were removed in SciPy 1.13, where `from scipy.signal import blackman` fails.
from scipy.signal.windows import blackman

# Switch to this script's directory *before* any relative path is used, so
# the style sheet, the fitandgraph module and the T00xxCH2.csv data files are
# found no matter which directory the script is launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

plt.style.use('../../../vinstyle.mplstyle')
sys.path.append(os.path.abspath('../../..'))
from fitandgraph import fit  # needs the sys.path entry added just above

#%%
# Per-trace analysis: cut one sweep out of each scope trace, take its
# spectrum, and record the dominant oscillation period.
#
# `labels` gives the duration (ms) each cut-out sweep is mapped onto.
# NOTE(review): `traces` covers five files (T0054..T0058) but `labels` has
# only four entries, so zip() silently skips T0058 -- confirm intended.
traces = range(54,59)
labels = [1.5,0.5, 4, 10]
periods = []

# CH2 level (V) whose crossings delimit a sweep, and the number of samples
# skipped after the first crossing before searching for the next one.
THRESHOLD = 0.15
SKIP = 1000

for trace, label in zip(traces, labels):
    df = pd.read_csv('T00{}CH2.csv'.format(trace), skiprows=15)

    # First sample above threshold, then the next one at least SKIP samples
    # later. np.where on the sliced series returns positions relative to the
    # slice start (i0 + SKIP), so that offset is added back; slicing at
    # i0 + i1 alone would drop the last SKIP samples of the sweep.
    i0 = np.where(df['CH2'] > THRESHOLD)[0][0]
    i1 = np.where(df['CH2'].iloc[i0 + SKIP:] > THRESHOLD)[0][0]
    df = df.iloc[i0:i0 + SKIP + i1]
    time = np.linspace(0, label, len(df))
    signal = df['CH2'].to_numpy()

    # Keep only frequencies above 12 cycles per sweep: this drops DC, the
    # negative half of the spectrum and the slowest components.
    f0 = 12/label
    freqs = fftfreq(len(signal), time[1] - time[0])
    inds = freqs > f0
    freqs = freqs[inds]
    A = np.abs(fft(signal))[inds]

    ind1 = np.argmax(A)
    print('Osc. Freq. 1 (kHz):\t', freqs[ind1])
    periods.append(1/freqs[ind1])

    # The signal and spectrum columns have different lengths. A dict of
    # 0-based Series aligns them on a shared index and pads the shorter ones
    # with NaN. (A *list* of Series would make each one a row, and with these
    # column names the table would come out all-NaN.)
    out = pd.DataFrame({
        'Time (ms)': pd.Series(time),
        'PD Signal (V)': pd.Series(signal),
        'Frequency (kHz)': pd.Series(freqs),
        'Fourier Transform': pd.Series(A),
    })
    out.to_csv('PD_FT_%sms.csv'%label)

    # Second-largest spectral bin, found on a copy so `A` itself is left
    # untouched (it is plotted again after this loop).
    A_masked = A.copy()
    A_masked[ind1] = 0
    print('Osc. Freq. 2 (kHz):\t', freqs[np.argmax(A_masked)])
#%%
# Fit oscillation period vs. 1/sweep duration and draw the summary figure.
#
# NOTE: the inset axes below reuse variables left over from the LAST
# iteration of the per-trace loop above (time, df, label, freqs, A, f0),
# so they show the last trace only.

# Scale factor between the two period series plotted and saved below.
# NOTE(review): the origin of this value is not documented here -- confirm.
FSR_RATIO = 1.1053


def _set_fontsize(ax, size=14):
    """Set the title, axis-label and tick-label font size of *ax*."""
    for text in (ax.title, ax.xaxis.label, ax.yaxis.label):
        text.set_fontsize(size)
    # tick_params also applies to ticks matplotlib creates later (e.g. after
    # set_xlim/set_ylim), unlike resizing the current tick Text objects.
    ax.tick_params(labelsize=size)


f = fit([1/x for x in labels], periods)
f.args = ['FSR (kHz)']

# Measured periods and their fit (C2).
plt.scatter(f.x, f.y, color='C2')
plt.plot(*f.applyFit(lambda x,A: 1e-3*A/x), color='C2', alpha=0.6)
print(f.report())

# Same periods divided by FSR_RATIO and refitted (C1).
f.y = np.array(f.y)/FSR_RATIO
plt.scatter(f.x, f.y, color='C1')
plt.plot(*f.applyFit(lambda x,A: 1e-3*A/x), color='C1', alpha=0.6)
print(f.report())
plt.xlabel('Sweep Rate (GHz/s)')
plt.ylabel('Oscillation Period (ms)')

# 'Period 1' is the scaled series; 'Period 2' undoes the division above.
out = pd.DataFrame(np.array([f.x, f.y, f.y*FSR_RATIO]).T,
                   columns=['Sweep Rate (GHz/s)', 'Period 1 (ms)', 'Period 2 (ms)'])
out.to_csv('FSR.csv')

# Inset 1: photodiode signal of the last trace.
ax1 = plt.axes([.42,.61,.45,.13])
ax1.plot(time, df['CH2'], linewidth=1, label=str(label))
ax1.set_xlabel(r'Time (ms)')
ax1.set_ylabel('Signal (V)')
ax1.set_ylim(-0.13,0.13)
ax1.xaxis.set_label_position('top')
ax1.xaxis.tick_top()
_set_fontsize(ax1)

# Inset 2: spectrum of the last trace.
ax2 = plt.axes([.42,.42,.45,.14])
ax2.plot(freqs, A)
ax2.set_xlabel('Frequency (kHz)')
ax2.set_ylabel('FT')
ax2.set_xlim(f0, 5*f0)
# Mark the spectral bins nearest 1/period for the last trace: the scaled
# period (C1) and the unscaled one (C2), matching the main plot's colours.
# NOTE(review): assumes fit() keeps input order so f.y[-1] is the last trace.
last_period = f.y[-1]
for target, color in ((1/last_period, 'C1'), (1/last_period/FSR_RATIO, 'C2')):
    find = np.argmin(abs(freqs-target))
    ax2.scatter(freqs[find], A[find], color=color)
_set_fontsize(ax2)

plt.savefig('FSR.svg')
plt.show()