import numpy as np
from matplotlib import pyplot as plt
import sys
from scipy.interpolate import interp1d
from scipy.optimize import minimize
from matplotlib.ticker import ScalarFormatter
import os


# Figure styling and the single axes that every later plot call draws on.
plt.rcParams['font.family'] = 'Latin Modern Roman'
fig, ax = plt.subplots()

# The pump energy is the leading '<value>eV' token ('_'-separated) of the
# experimental file's base name, passed as the first command-line argument.
wp = float(sys.argv[1].split('/')[-1].split('_')[0].removesuffix('eV'))
peak = wp
sb_data = {}
em_list = []

# 0 -> analyse the red sidebands, 1 -> the blue sidebands.
red0blue1 = 0
primary_colour = 'Blue' if red0blue1 else 'Red'

# Experimental table (comma separated, two header lines). Column 1 is the
# field strength used as the x axis; the sideband intensity/error columns for
# the blue side sit one column to the left of the red ones.
exp_data = np.genfromtxt(sys.argv[1], delimiter=',', skip_header=2)
exp_car, car_err = exp_data.T[2], exp_data.T[8]
exp_sb2, exp_sb4, sb2_err, sb4_err = (
    exp_data.T[column - red0blue1] for column in (5, 7, 11, 13)
)
twoerr, fourerr = sb2_err, sb4_err

# Best fit found so far across all run directories.
best_residual, best_n = float('inf'), 0

def residuals(em, theory, sb, x_min=23):
    """Return the sum of squared differences between theory and experiment.

    The theory curve ``theory(em)`` is linearly interpolated onto the
    experimental field strengths (column 1 of ``exp_data``) and compared with
    the measured intensity of sideband order ``sb``. Only experimental points
    whose intensity exceeds their error bar and whose field strength is
    strictly greater than ``x_min`` contribute.

    Parameters
    ----------
    em : array_like
        Field strengths at which ``theory`` is sampled.
    theory : array_like
        Theoretical (already scaled) intensity at each ``em``.
    sb : int
        Sideband order; must be 2 or 4.
    x_min : float, optional
        Experimental points with field strength <= ``x_min`` are ignored.

    Returns
    -------
    float
        Sum of squared residuals over the selected points.

    Raises
    ------
    ValueError
        If ``sb`` is not 2 or 4, or (from ``interp1d``) if a selected
        experimental field strength lies outside the range of ``em``.
    """
    if sb not in (2, 4):
        raise ValueError(f'unsupported sideband order {sb!r}; expected 2 or 4')
    exp_y, err = (exp_sb2, twoerr) if sb == 2 else (exp_sb4, fourerr)

    # Drop points whose intensity does not exceed their error bar.
    significant = exp_y > err
    exp_x = exp_data.T[1][significant]
    exp_y = exp_y[significant]

    in_range = exp_x > x_min
    interp = interp1d(em, theory)
    return np.sum((interp(exp_x[in_range]) - exp_y[in_range]) ** 2)



def _run_number(path):
    """Return the run label: the text after the last '_' of *path*, ignoring a trailing '/'."""
    return path.strip('/').split('_')[-1]


def _field_strength(spec_name):
    """Return the field strength encoded as the third '_'-separated token of a .spec file name."""
    return float(spec_name.split('/')[-1].split('_')[2])


def _value_nearest(table, target):
    """Return column 1 of the row whose column-0 value is closest to *target*.

    Ties resolve to the first such row, so exactly one value is returned per
    file; NaN entries in column 0 are ignored.
    """
    table = np.atleast_2d(table)  # a one-row file loads as a 1-D array
    return table[np.nanargmin(np.abs(table[:, 0] - target)), 1]


# Scale factor mapping the theoretical values onto counts per second.
# Previously fitted values, kept for reference:
#   normalised over carrier: 1st file red 217418.01015936, 2nd file red 934809.9520327,
#                            1st file blue 513003.42938655, 2nd file blue 1108848.98522876
#   no carrier:              1st file red 1.77494758e+12, 2nd file red 7.67618738e+12,
#                            1st file blue 4.42554116e+12, 2nd file blue 9.22848173e+12
scale = 7e12

sb_colours = {4: 'grey', 2: primary_colour}
# Latest plotted line per sideband order (legend handles); defined up front so
# the name exists even when no run directories are given.
th_lines = {4: 0, 2: 0}

for directory in sorted(sys.argv[2:], key=lambda d: float(_run_number(d))):
    run_n = _run_number(directory)
    spec_files = [f for f in os.listdir(directory) if '.spec' in f]
    sb_data = {}
    em_list = []

    for f in sorted(spec_files, key=_field_strength):
        if '_sb_' not in f:
            continue
        sb = int(f.removesuffix('.spec').split('_')[-1])
        # Keep only the selected side: positive orders (except +6) for blue,
        # negative orders (except -6) for red.
        if red0blue1:
            if sb <= 0 or sb == 6:
                continue
        elif sb >= 0 or sb == -6:
            continue

        em = _field_strength(f)
        if em not in em_list:
            em_list.append(em)
        temp = np.genfromtxt(os.path.join(directory, f), delimiter=' ')
        sb_data.setdefault(sb, []).append(_value_nearest(temp, wp))

    # NOTE(review): each sb_data[sb] lines up with em_list only if every field
    # strength has a file for every kept sideband order - confirm.
    em = np.array(em_list)
    for sb, y in sb_data.items():
        theory = scale * np.array(y)
        residual = residuals(em, theory, abs(sb))
        if residual < best_residual:
            best_residual = residual
            best_n = run_n

        th_lines[abs(sb)], = ax.plot(em, theory, c=sb_colours[abs(sb)], alpha=0.1, rasterized=True)

print(best_n)

side = 'blue' if red0blue1 else 'red'

# Experimental points on top of the theory curves; only points whose
# intensity exceeds their error bar are shown.
field = exp_data.T[1]
sig2 = exp_sb2 > twoerr
sig4 = exp_sb4 > fourerr
ex_line2 = ax.errorbar(field[sig2], exp_sb2[sig2], c=primary_colour, fmt='o', yerr=twoerr[sig2], markersize=8)
ex_line4 = ax.errorbar(field[sig4], exp_sb4[sig4], c='black', fmt='^', yerr=fourerr[sig4], markersize=8)

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim([23, 687])
ax.set_ylim([1e1, 1e5])

# Opaque proxy lines (drawn at x=1, outside the x limits) used as legend
# handles, since the real theory curves are drawn with alpha=0.1.
th_lines = {
    2: ax.plot(1, 0.0001, c=primary_colour)[0],
    4: ax.plot(1, 0.0001, c='grey')[0],
}
# Optional legend:
#ax.legend([th_lines[2], th_lines[4], ex_line2, ex_line4], ['Theory 2$^\\text{nd}$','Theory 4$^\\text{th}$', 'Exp. 2$^\\text{nd}$', 'Exp. 4$^\\text{th}$'], fontsize = 16, markerscale=0.5, frameon=False)

ax.set_xlabel('Microwave field strength, Vm$^{-1}$', fontsize=18)
ax.set_ylabel('Counts Per Second', fontsize=18)
ax.set_xticks([40, 60, 100, 200, 400, 600])
ax.xaxis.set_major_formatter(ScalarFormatter())
# Persistent tick-label size (also applies to tick labels created later).
ax.tick_params(axis='both', labelsize=18)

title = f'Energy = {peak} eV, {side} sidebands'
filename = f'{peak}_{side}_sidebands_vs_exp_1st_red_norm_band_clusters'
# Optional title:
#plt.title(title, fontsize = 18)

ax.text(30, 30000, '(a)', fontsize=18)

# Lay out last so the enlarged labels and tick labels are accounted for.
fig.tight_layout(pad=4)
fig.savefig(filename + '.eps')
