import numpy as np
import sys
from matplotlib import pyplot as plt
from exisolver import init_dipole
import matplotlib as mpl
from matplotlib.colors import ListedColormap


# Typography shared by every panel this script draws.
plt.rcParams.update({'font.family': 'Latin Modern Roman', 'font.size': 18})

# Copy of matplotlib's 'Reds' colormap whose first 25 of 256 entries ramp
# linearly from opaque white up to entry 25.
# NOTE(review): new_reds is not referenced elsewhere in this file.
reds = mpl.colormaps['Reds']
reds_dat = reds(np.linspace(0, 1, 256))
reds_dat[:25, :] = np.linspace([1, 1, 1, 1], reds_dat[25, :], 25)
new_reds = ListedColormap(reds_dat)

# Two stacked panels sharing the x (energy) axis: ax1 (top, 30% height) and
# ax2 (bottom, 70% height).
fig, (ax1, ax2) = plt.subplots(
    2, 1,
    sharex=True,
    height_ratios=[0.3, 0.7],
    figsize=(13, 4),
)

def get_em(x):
    """Return the third '_'-separated field of the file name in *x* as a float.

    Only the part after the last '/' is inspected, so directory names may
    contain underscores. In this script the value is used as the ax2
    y-coordinate (axis labelled E_M) and as the sort key for input files.

    Raises IndexError if the file name has fewer than three '_' fields and
    ValueError if that field is not a number.
    """
    _, _, basename = x.rpartition('/')
    fields = basename.split('_')
    return float(fields[2])

def absorption(wp, sus):
    """Return 2 * (wp * 5.0677307) * Im(sqrt(1 + sus)).

    sqrt(1 + sus) plays the role of a complex refractive index and
    wp * 5.0677307 that of a wavenumber.
    NOTE(review): 5.0677307 is presumably 1/(hbar*c) in eV^-1 um^-1, i.e. it
    converts a photon energy in eV to a wavenumber in um^-1 — confirm units.

    The square root is always taken on the complex branch. For a purely real
    ``sus < -1``, plain ``np.sqrt`` would return NaN (with a RuntimeWarning),
    and ``np.imag`` of a real value is identically 0. The result would then be
    a silent zero instead of the correct non-zero value. Complex inputs are
    unaffected.

    Works element-wise for scalar or array *wp* / *sus*.
    """
    refractive_index = np.sqrt(np.asarray(1 + sus, dtype=complex))
    return 2*(wp*5.0677307) * np.imag(refractive_index)

def weights_gen(floq_eig_vals, right_eig_vecs, left_eig_vecs, d, Ep):
    """Return the weight of every Floquet state with respect to vector *d*.

    With r_i = column i of *right_eig_vecs* and l_i = column i of
    *left_eig_vecs*, the weight of state i is

        w_i = (l_i^H d) * (d^H r_i)

    computed for all i at once with two matrix-vector products.

    Parameters
    ----------
    floq_eig_vals : unused; kept so existing call sites keep working.
    right_eig_vecs, left_eig_vecs : 2-D arrays whose columns are the right /
        left eigenvectors. They are assumed to have the same number of
        columns — TODO confirm.
    d : 1-D array of length ``right_eig_vecs.shape[0]``. It may be complex,
        real or boolean; a boolean mask counts as 1 where True.
    Ep : unused; kept so existing call sites keep working.

    Returns
    -------
    numpy.ndarray
        Complex array with one weight per eigenvector column.
    """
    d = np.asarray(d, dtype=complex)
    left_proj = np.asarray(left_eig_vecs).conj().T @ d    # l_i^H d for all i
    right_proj = d.conj() @ np.asarray(right_eig_vecs)    # d^H r_i for all i
    return left_proj * right_proj

# ---------------------------------------------------------------------------
# Command line:  <script> LASER_ENERGY FILE.floq_eig_vals [FILE ...]
#   LASER_ENERGY  probe-laser energy, on the same axis as the Floquet
#                 eigenvalues (the x axis is labelled in eV)
#   FILE          Floquet eigenvalue files; '<stem>_floq_l_vecs.npy' and
#                 '<stem>_floq_r_vecs.npy' must exist alongside each one.
# ---------------------------------------------------------------------------
if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} LASER_ENERGY FILE.floq_eig_vals [FILE ...]')
laser_energy = float(sys.argv[1])
files = sorted(sys.argv[2:], key = get_em)

sd_and_usd_thresholds = np.genfromtxt('sd_and_usd_thresholds.txt', delimiter=' ')
ex = np.genfromtxt('cu2o_230807.ex', delimiter=' ')
mat_el = np.genfromtxt('cu2o_230807.mat', delimiter=' ')

target_em = 688   # only the file with this field value gets an absorption curve
N = 31            # number of Floquet blocks; nrange labels them -(N//2) ... N//2
nd=1
nrange = np.arange(-int(N/2), -int(N/2)+N)
Ep = 1.6e-8
dipole = init_dipole(ex[:,:2], mat_el)
dim = len(ex)

# Probe vector: only the block with Floquet index 0 is populated, holding
# -Ep * dipole[j, 0] / 2 for each state j.
zero_block = np.where(nrange==0)[0][0]
d = np.zeros(dim*N, dtype=complex)
for j in range(dim):
    d[zero_block*dim+j] = -Ep * dipole[j,0]/2

# Susceptibility prefactor; independent of file and probe energy.
# NOTE(review): 8.854187e1 and -1.602176634 look like rescaled eps_0 and
# electron charge — document the unit system.
chi_prefactor = (-2/Ep) * nd/(8.854187e1*Ep) * (-1.602176634)

# Fixed normalisations mapping the weights onto scatter alphas. Update them
# from the 'max ... weight' values printed below if the inputs change.
GREY_NORM = 1.1197310677025312
RED_NORM = 8.074368047610392e-19

# Probe-energy window scanned for the absorption spectrum.
E_MIN, E_MAX = 2.169, 2.1718

max_red_weights = []
max_grey_weights = []
laser_range = np.linspace(E_MIN, E_MAX, 301)
absorptions = np.zeros((len(laser_range), len(files)))

for i_f, f in enumerate(files):
    em = get_em(f)
    floq_eig_vals = np.genfromtxt(f, delimiter=' ', dtype=complex)

    stem = f.removesuffix('.floq_eig_vals')
    left_eig_vecs = np.load(stem + '_floq_l_vecs.npy')
    right_eig_vecs = np.load(stem + '_floq_r_vecs.npy')

    # Neither weight set depends on the probe energy: compute each once per
    # file, not once per point of laser_range.
    red_base = weights_gen(floq_eig_vals, right_eig_vecs, left_eig_vecs, d, Ep)
    # Grey weights use only the sparsity pattern of d (d != 0), i.e. they
    # ignore the dipole amplitudes.
    grey_base = weights_gen(floq_eig_vals, right_eig_vecs, left_eig_vecs, d!=0, Ep)

    if em == target_em:
        for i_l, wp in enumerate(laser_range):
            susceptibility = chi_prefactor * np.sum(red_base/(wp - floq_eig_vals))
            absorptions[i_l,i_f] = absorption(wp, susceptibility)

    # Weights evaluated at the probe-laser energy, drawn as a row of points at y = em.
    red_weights = np.abs(np.imag(red_base/(laser_energy - floq_eig_vals)))
    grey_weights = np.abs(grey_base)
    max_red_weights.append(red_weights.max())
    max_grey_weights.append(grey_weights.max())

    # scatter() rejects alpha outside [0, 1]; clip so that a stale
    # normalisation saturates instead of aborting the plot.
    grey_alpha = np.clip(grey_weights/GREY_NORM, 0, 1)
    red_alpha = np.clip(red_weights/RED_NORM, 0, 1)
    x_vals = np.real(floq_eig_vals)
    y_vals = [em]*len(floq_eig_vals)
    ax2.scatter(x_vals, y_vals, alpha=grey_alpha, c='grey', rasterized=True, s=5)
    ax2.scatter(x_vals, y_vals, alpha=red_alpha, c='red', rasterized=True, s=5)

print('max red weight:', max(max_red_weights))
print('max grey weight:', max(max_grey_weights))

# Dashed vertical markers at the energies of rows of ex whose column 1 == 1
# (labelled "nP") that fall inside the probe window.
# NOTE(review): labels assume those rows are consecutive in n starting at
# ex[1, 0]; using ex[:, 0] of the selected rows would be more robust if
# column 0 holds n — confirm.
nPs = ex.T[2][ex.T[1]==1]
offset = 0
first_n = ex.T[0,1]
for i, nP in enumerate(nPs):
    if E_MIN < nP < E_MAX and i + first_n < 21:
        ax2.plot([nP+offset, nP+offset], [0, 1100], color = "k", lw = 1, linestyle='dashed', alpha=0.5)
        if i + first_n < 12:
            ax2.text(nP+offset - 2e-5, 630, str(i+int(first_n)) + "P", fontsize=18)

# Absorption curves (scaled by -100) for the files that received one.
for row in absorptions.T:
    if row.any():
        ax1.plot(laser_range, -100*row, c='blue')

ax2.plot([laser_energy, laser_energy], [0, 1170], color='k', lw=1)
ax2.text(laser_energy - 20e-5, 800, 'Probe Laser', fontsize=18)

ax2.plot(sd_and_usd_thresholds[:,3],sd_and_usd_thresholds[:,4],c='orange')
ax2.plot(sd_and_usd_thresholds[:,3],sd_and_usd_thresholds[:,5],c='green')

ax2.set_xlabel('Eigenenergy of Floquet States, eV', fontsize=18)
ax2.set_ylabel('$\\mathcal{E}_\\text{M}$, Vm$^{-1}$', fontsize=18)
ax1.set_ylabel('$\\Delta \\alpha L$', fontsize=18)
ax1.yaxis.tick_right()
ax1.yaxis.set_label_position("right")
ax1.set_yticks([0,1,2])
ax2.set_xlim([2.17, 2.1718])
ax2.set_ylim([0,688])
plt.yticks(fontsize=18)
plt.xticks(fontsize=18)
fig.subplots_adjust(wspace=0)
plt.savefig('test.pdf')
