import argparse
import math
from pathlib import Path

import numpy as np
import scipy.linalg
import scipy.sparse
import scipy.sparse.linalg as ssl
from scipy.special import sph_harm


def w(energies, gamma, wp, laser=True):
    """Complex detuning of the probe from each valence-band -> state transition.

    Returns ``(energies - energies[0]) - wp - i*gamma`` as a new array: every
    transition energy is measured from the first entry of ``energies``, and the
    width enters as a negative imaginary part. ``laser`` is accepted but unused.
    """
    transition = energies - energies[0]
    return np.array(transition - wp - 1j * gamma)


def A_diag(N, wc, energies, gamma, wp):
    """Diagonal entries of Floquet block ``N``: the probe detuning ``w`` shifted down by ``N*wc``."""
    detuning = w(energies, gamma, wp)
    photon_shift = N * wc
    return detuning - photon_shift


def M(nrange, wc, energies, gamma, omega, wp, dim):
    """Assemble the matrix of the N coupled optical Bloch (Floquet) equations.

    The result is block tri-diagonal, with one ``dim`` x ``dim`` block per Floquet
    index in ``nrange``:

    * diagonal block k:             ``diag(A_diag(nrange[k], wc, energies, gamma, wp))``
    * block (k, k+1), above the diagonal: ``-conj(omega).T / 2``
    * block (k+1, k), below the diagonal: ``-omega / 2``

    Returns
    -------
    ndarray of complex, shape (len(nrange) * dim, len(nrange) * dim)
    """
    n_blocks = len(nrange)
    above = (-0.5 * np.conjugate(omega.T)).reshape(dim, dim)
    below = (-0.5 * omega).reshape(dim, dim)

    # Ones on the first super-diagonal select the (k, k+1) blocks; the transpose
    # selects the (k+1, k) blocks. kron places a copy of the coupling in each one.
    shift_up = np.eye(n_blocks, k=1)
    floquet = np.kron(shift_up, above).astype(complex)
    floquet += np.kron(shift_up.T, below)

    diagonal = np.concatenate([A_diag(N, wc, energies, gamma, wp) for N in nrange])
    floquet += np.diag(diagonal)
    return floquet


def cft(cn, j, t, wc, dim, nrange):
    """'Coefficient Fourier transform': recombine the Floquet components of state ``j``.

    For the k-th Floquet index ``n`` in ``nrange``, the component is
    ``c_{j;n} = cn[k*dim + j]``.

    Returns
    -------
    (complex, ndarray)
        ``sum_n c_{j;n} * exp(-i n wc t)`` and the array of the ``c_{j;n}`` in
        ``nrange`` order.
    """
    pairs = [(n, cn[k * dim + j]) for k, n in enumerate(nrange)]
    total = sum((coef * np.exp(-1j * n * wc * t) for n, coef in pairs), 0 + 0j)
    components = np.array([coef for _, coef in pairs])
    return total, components



def small_d(l, m1, m2, beta):
    """Rotation-matrix element used to rotate between the probe and microwave bases.

    Evaluates

        sqrt((l+m1)! (l-m1)! (l+m2)! (l-m2)!)
          * sum_s (-1)^s i^(m2-m1) cos(beta/2)^(2l+m2-m1-2s) sin(beta/2)^(m1-m2+2s)
                  / ((l+m2-s)! s! (m1-m2+s)! (l-m1-s)!)

    where s runs over every value that keeps all factorial arguments non-negative.
    NOTE(review): the phase factor i^(m2-m1) differs from the textbook Wigner
    small-d phase (-1)^(m1-m2); presumably it encodes the intended rotation
    axis — confirm before relying on the phase.

    Parameters
    ----------
    l : int
        Angular momentum quantum number.
    m1, m2 : int
        Projections, with ``|m1|, |m2| <= l``.
    beta : float
        Rotation angle in radians.

    Returns
    -------
    complex
    """
    fact = math.factorial
    # math.sqrt on the exact Python int product; np.sqrt fails once the product
    # leaves the 64-bit integer range (already at l = 8, |m| = 8).
    prefac = math.sqrt(fact(l + m1) * fact(l - m1) * fact(l + m2) * fact(l - m2))
    phase = 1j ** (m2 - m1)
    half_cos = np.cos(beta / 2)
    half_sin = np.sin(beta / 2)

    tot = 0
    for s in range(max(0, m2 - m1), min(l + m2, l - m1) + 1):
        num = (-1) ** s * phase * half_cos ** (2 * l + m2 - m1 - 2 * s) * half_sin ** (m1 - m2 + 2 * s)
        den = fact(l + m2 - s) * fact(s) * fact(m1 - m2 + s) * fact(l - m1 - s)
        tot += num / den

    return prefac * tot



def sh_integral(li, mi, lf, mf, beta_deg, n_theta=300):
    """Angular overlap <Y_lf,mf| cos(theta) |Y_li,mi> between rotated spherical harmonics.

    Each harmonic is first rotated through ``beta_deg`` degrees by expanding it
    over its l-manifold with ``small_d`` coefficients; the bra is then
    complex-conjugated. The integral over the unit sphere uses the midpoint rule
    on an ``n_theta`` x ``2*n_theta`` (theta x phi) grid.

    Parameters
    ----------
    li, mi : int
        Degree and order of the ket harmonic.
    lf, mf : int
        Degree and order of the bra harmonic.
    beta_deg : float
        Rotation (polarisation) angle in degrees.
    n_theta : int, optional
        Number of polar grid points (default 300). The azimuthal grid has twice as many.

    Returns
    -------
    complex
        The overlap integral.
    """
    beta_rad = np.pi * beta_deg / 180

    # Midpoint grids with step exactly pi/n_theta in both angles, so the step used
    # in the sum matches the grid and phi = 0 / phi = 2*pi are not both counted.
    dth = dph = np.pi / n_theta
    th = (np.arange(n_theta) + 0.5) * dth
    ph = (np.arange(2 * n_theta) + 0.5) * dph
    x, y = np.meshgrid(th, ph)  # x: polar angle theta, y: azimuthal angle phi

    # SciPy argument order: sph_harm(m, l, azimuthal, polar)
    rot_Yi = sum(small_d(li, m, mi, beta_rad) * sph_harm(m, li, y, x)
                 for m in np.arange(-li, li + 1))
    rot_Yf = sum(np.conjugate(small_d(lf, m, mf, beta_rad) * sph_harm(m, lf, y, x))
                 for m in np.arange(-lf, lf + 1))

    # The surface element on the unit sphere is sin(theta) dtheta dphi. Without the
    # sin(theta) weight even <Y00|Y00> would evaluate to pi/2 instead of 1.
    integrand = rot_Yf * np.cos(x) * rot_Yi * np.sin(x)
    return np.sum(integrand) * dth * dph



def init_dipole(state_labels, mat_el, spherical=False, pol_ang=0, percent_error = 0):

    # constructs the dipole matrix, given a list of states (nlm or nl) and an array of matrix elements < nl(m) | D | n'l'(m') >

    dim = len(state_labels)
    dipole = np.zeros((dim,dim), dtype=complex)

    if spherical:
        for i,[n1,l1,m1] in enumerate(state_labels):
            for j,[n2,l2,m2] in enumerate(state_labels):
                mask = (mat_el.T[0]==n1) & (mat_el.T[1]==l1) & (mat_el.T[2]==n2) & (mat_el.T[3]==l2)
                if any(mat_el.T[4][mask]):
                    if n1 and n2:
                        dipole[i][j] = abs(mat_el.T[4][mask]) * abs(sh_integral(l1,m1,l2,m2,pol_ang)) * (1+float(percent_error)/100)
                    elif n1==0 or n2==0:
                        if m2==init_m and not mixed_m:
                            dipole[i][j] = mat_el.T[4][mask] * (1+float(percent_error)/100)
                        elif mixed_m:
                            dipole[i][j] = abs(mat_el.T[4][mask])/np.sqrt(2*(l1 or l2)+1) * (1+float(percent_error)/100)

    else:
        for i,[n1,l1] in enumerate(state_labels):
            for j,[n2,l2] in enumerate(state_labels):
                mask = (mat_el.T[0]==n1) & (mat_el.T[1]==l1) & (mat_el.T[2]==n2) & (mat_el.T[3]==l2)
                if any(mat_el.T[4][mask]):

                    if n1 and n2:
                        dipole[i][j] = abs(mat_el.T[4][mask]) * (1+float(percent_error)/100)
                    elif n1==0 or n2==0:
                        dipole[i][j] = mat_el.T[4][mask] * (1+float(percent_error)/100)

    dipole += np.conjugate(dipole.T)

    return dipole




def main():

    # parsing arguments into the program
    parser = argparse.ArgumentParser(description='Calculate the absorption and generated sidebands for a laser incident on a system of excitons in the presence of a strong microwave field')
    parser.add_argument('filename', metavar='filename', type=str, nargs=1, help='The name (without file extension!) of the files containing the exciton series energies and widths (.ex), and the matrix elements (.mat)')
    args = parser.parse_args()


    filename = args.filename[0]


    # reading in parameters from the .param file
    with open(filename + '.param','r+') as f:
        d = {}
        for line in f:
            if line[0] == '#':
                continue
            if line.strip() == '':
                continue
            parts = [s.strip() for s in line.split(':')]
            parts[0] = parts[0].lower()
            d[parts[0]] = parts[1]

    # reading in exciton spectrum energies and widths
    ex = np.genfromtxt(d['excitons'], delimiter=' ', dtype=float)
    energies = ex.T[2]
    gamma = ex.T[3]

    # reading in the nL -> n'L' dipole elements
    mat_el = np.genfromtxt(d['mat_el'], delimiter=' ',dtype=complex)

    # reading in width asymmtries, optional
    if 'asymmetry' in d:
        asym = np.genfromtxt(d['asymmetry'], delimiter=' ', dtype=float)
    else:
        asym = np.zeros(len(energies))

    # initialising the laser frequency axis
    if 'probe_freq_range' in d:
        freq_limits = np.float_(np.array(d['probe_freq_range'].split(',')))
        freq_points = int(d['probe_freq_points'])
        freqrange = np.linspace(freq_limits[0],freq_limits[1],freq_points)
    elif 'probe_freq_list' in d:
        freqrange = np.array([float(f) for f in d['probe_freq_list'].split(',')], dtype=float)
        freq_points = len(freqrange)
    # point_dist() isn't finished yet
    #freqrange = point_dist(freq_limits[0],freq_limits[1],freq_points,energies)

    # how many Floquet components to include
    N = int(d['floquet_n'])
    nrange = np.arange(-int(N/2), -int(N/2)+N)

    # max microwave field strength [V/um]
    if 'em' in d:
        Em = float(d['em'])
    else:
        Em = 0

    # max probe field strength [V/um]
    Ep = float(d['ep'])

    # effective number density of ground states excitons
    nd = float(d['nd'])

    # microwave field frequency [eV]
    wc = float(d['microwave_freq'])

    # depth of the sample
    L = float(d['l'])

    # polarisation angle from microwaves to probe
    if 'polarisation_angle' in d:
        pol_ang = float(d['polarisation_angle'])
    else:
        pol_ang = 0

    # which sidebands to write to
    if 'sidebands' in d:
        sidebands = [int(s) for s in d['sidebands'].split(',')]
    else:
        sidebands = [0]

    # spherical harmonics toggle
    spherical = True
    if 'spherical' in d:
        if d['spherical'].lower() == 'false':
            spherical = False

    # which m states to create excitons in
    mixed_m = False
    init_m = 0
    if 'init_m' in d:
        if d['init_m'].strip('-').isnumeric():
            init_m = int(d['init_m'])
        elif d['init_m'].lower() == 'mixed':
            mixed_m = True

    # time, t, for evaluating coefficients of wavefunction
    if 'time' in d:
        time = float(d['time'])
    else:
        time = 1e5

    # percentage error that can be added to the matrix elements for testing sensitivity
    mat_el_error = 0
    if 'mat_el_error' in d:
        mat_el_error = float(d['mat_el_error'])

    # for debugging, you may want to output c_j_N at various energies
    if 'c_j_N output' in d:
        c_j_N_output_energies = np.array([float(f) for f in d['c_j_N output'].split(',')], dtype=float)

    # the user can request Floquet eigenstates instead of steady state coupling to the valence band
    floquet_eig = False
    if 'floquet eigenstates' in d:
        if d['floquet eigenstates'].lower() == 'true':
            floquet_eig = True



    # initialise arrays of various quantities ordered by nlm

    # energies widths and asymmetries
    m_energies = []
    m_gamma = []
    m_asym = []

    # ordered state labels
    nlm = [[],[],[]]
    nl = ex[:,:2]

    # populate them
    for i,l in enumerate(ex.T[1]):
        l = int(l)
        m_energies += [energies[i]]*(2*l+1)
        m_gamma += [gamma[i]]*(2*l+1)
        m_asym += [asym[i]]*(2*l+1)
        nlm[0] += [int(ex.T[0][i])]*(2*l+1)
        nlm[1] += [l]*(2*l+1)
        nlm[2] += list(range(-l,l+1))


    m_energies = np.array(m_energies)
    m_gamma = np.array(m_gamma)
    m_asym = np.array(m_asym)
    nlm = np.array(nlm).T

    if spherical:
        dim = len(nlm)
    else:
        dim = len(nl)

    # TODO replace nlm and nl with a variable called state labels
    if spherical:
        dipole = init_dipole(nlm, mat_el, spherical=True, pol_ang=pol_ang, percent_error=mat_el_error)
    else:
        dipole = init_dipole(nl, mat_el, percent_error=mat_el_error)

    omega = np.zeros(dipole.shape, dtype=complex)
    omega[1:,1:] = -dipole[1:,1:]*Em
    omega[0,:] = -dipole[0,:]*Ep
    omega[:,0] = -dipole[:,0]*Ep




    # solve the system for the Floquet eigenstates (no coupling to the valence band)
    if floquet_eig:
        if spherical:
            floquet_hamiltonian = M(nrange,wc,m_energies,m_gamma,omega,0,dim)
        else:
            floquet_hamiltonian = M(nrange,wc,energies,gamma,omega,0,dim)

        eigvals, lvecs, rvecs = scipy.linalg.eig(floquet_hamiltonian, left=True, right=True)

        norms = np.diag(np.matmul(np.conj(lvecs.T), rvecs))
        lvecs = lvecs/norms

        with open(filename + '.floq_eig_vals','w+') as f:
            for e in eigvals:
                f.write(str(e) + '\n')

        # these arrays are too big to practically store as CSVs, so I store them as .npy binaries. To open them, use np.load()
        np.save(filename + '_floq_l_vecs', lvecs)
        np.save(filename + '_floq_r_vecs', rvecs)




    # the vector d in the system to be solved: Mc = d; this gives a coupling to the valence band
    d = np.zeros(dim*N, dtype=complex)
    for j in range(dim):
        d[np.where(nrange==0)[0][0]*dim+j] = omega[j,0]/2


    # solve the system for each laser energy requested
    if not floquet_eig:
        for wp in freqrange:
            floquet_hamiltonian = M(nrange,wc,m_energies,m_gamma,omega,wp,dim)

            # solve the system in the inhomogenous case (coupling to the valence band)
            c_j_N = ssl.spsolve(floquet_hamiltonian,d)

            # for each sideband calculate the susceptibility and absorption
            for sideband in sidebands:
                c_j_s = c_j_N[np.where(nrange==sideband)[0][0]*dim : np.where(nrange==sideband)[0][0]*dim+dim]
                # TODO write out these factors in full so they are easier to read
                susceptibility = nd/(8.854187e1*Ep) * (-1.602176634) * sum(dipole[0] * c_j_s * (1-1j*m_asym.T[2]))
                absorption = 2*(wp*5.0677307) * np.imag(np.sqrt(1+susceptibility))


                # open a file for putting primary data
                with open(filename + f'_sb_{sideband}.spec', 'a+') as f:
                    # if the user wants carrier data, write alpha L
                    if sideband == 0:
                        f.write(str(wp) + ' ' + str(absorption*L) + '\n')
                    # if the user wants sideband data, write |chi|^2
                    else:
                        f.write(str(wp) + ' ' +  str(np.absolute(susceptibility)**2) + '\n')


                # at requested energies save the entire state of the system
                for cjne in c_j_N_output_energies:
                    if abs(wp - cjne) == min(abs(freqrange - cjne)):
                        with open(filename + f'_{cjne}.cjn','w+') as f:
                            for i in c_j_N:
                                f.write(str(i) + '\n')


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
