import numpy as np
import matplotlib.pyplot as plt
import os
# Work relative to this script's own folder so the relative data-file names
# and style-sheet path used below resolve regardless of the launch directory.
# abspath() first: if __file__ is a bare relative name (e.g. 'script.py'),
# dirname() returns '' and os.chdir('') raises FileNotFoundError.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
plt.style.use(r'../../../vinstyle.mplstyle')

#%% setup 
def parsefile(fname):
    """Parse a wavefunction text file into a data array plus header entries.

    The file is read in two passes over one open handle:

    1. Header pass. Lines with 0-based index 11..40 are parsed as
       ``key: int int ...`` entries. A line in that range without ``:``
       appends its integers to the previous entry. Outside that range:

       - the line containing ``NBASIS`` gives the column count (last integer
         after its final ``:``, plus one);
       - the line containing ``TOTAL NUMBER OF POINTS`` gives the row count
         (integer after ``IS``). This line ends the header pass.

    2. Data pass. Every remaining line without ``#`` is parsed as
       whitespace-separated floats and written row-major into the output.
       One row may therefore be wrapped over several text lines.

    Parameters
    ----------
    fname : str or path-like
        Path of the file to read.

    Returns
    -------
    data : numpy.ndarray
        Float array of shape ``(n_points, n_columns)``.
    keys : list of str
        Header entry names with spaces and ``#`` removed.
    items : list of list of int
        Integers belonging to each header entry, in the order of ``keys``.

    Raises
    ------
    IOError
        If a non-comment line that is neither the NBASIS nor the
        TOTAL NUMBER OF POINTS line appears outside the header range before
        the sizes are read, if a header continuation line has no preceding
        ``key:`` entry, or if either size line is missing.
    """
    # Sentinels so a missing size line produces a clear IOError rather than
    # an UnboundLocalError when the array is allocated.
    size0 = None  # number of rows (points)
    size1 = None  # number of columns
    keys = []
    items = []
    with open(fname) as dfile:
        for i, line in enumerate(dfile):
            if 10 < i < 41:
                if ':' in line:
                    parts = line.split(':')
                    keys.append(parts[0].replace(' ', '').strip('#'))
                    items.append(list(map(int, np.fromstring(parts[1], sep=' '))))
                else:
                    # Continuation of the previous entry's integer list.
                    if not items:
                        raise IOError('Header continuation line without a '
                                      'preceding "key:" entry.')
                    items[-1] += list(map(int, np.fromstring(line.strip('#'), sep=' ')))
            elif 'NBASIS' in line:
                size1 = np.fromstring(line.split(':')[-1],
                                      sep=' ', dtype=int)[-1] + 1
            elif 'TOTAL NUMBER OF POINTS' in line:
                size0 = int(line.split('IS')[-1])
                # The data pass below continues from this point in the file.
                break
            elif '#' not in line:
                raise IOError('Could not find the size of the data.')

        if size0 is None or size1 is None:
            raise IOError('Could not find the size of the data.')

        data = np.zeros((size0, size1))
        i = 0  # running count of values written, in row-major order
        for line in dfile:
            if '#' not in line:
                larr = np.fromstring(line, sep=' ')
                j = len(larr)
                data[i // size1, i % size1: i % size1 + j] = larr
                i += j
    return data, keys, items


#%% parse data
# Input files, one figure panel each, in panel order.
fnames = [
    'bound-RbCs-wave-193G-100kHz.txt',
    'bound-RbCs-wave-181-5G-2MHz-SIF.txt',
    'bound-RbCs-wave-196-7G-2MHz-SIF.txt',
]

# Collects one (data, states, ps) tuple per entry of fnames, in the same order.
alllist = []
for fname in fnames:
    data, labels, states = parsefile(fname)

    # Integer-halve every header row, then double the last row back.
    # NOTE(review): presumably the file stores doubled quantum numbers for all
    # but the last row — TODO confirm against the file format.
    halved = np.array(states, dtype=int) // 2
    halved[-1] *= 2

    # Append a derived row (4 minus the second-to-last row) and transpose, so
    # each row of `states` gathers one column of the header entries.
    derived = 4 - np.array(halved[-2])
    states = np.concatenate((halved, [derived])).T

    # Weight of each data column 1..48: sum of squared values, normalised so
    # that all weights add up to 1.
    weights = np.array([sum(data[:, col]**2) for col in range(1, 49)])
    ps = weights / sum(weights)

    alllist.append((data, states, ps))
    print('.', end='')  # progress marker, one dot per parsed file
    
#%% plot
letters = ['(i)', '(ii)', '(iii)']
# squeeze=False keeps `axs` a 1-D array of Axes even when only one file is
# plotted (plt.subplots would otherwise return a bare Axes for one row).
fig, axs = plt.subplots(len(fnames), 1, sharex=True, figsize=(7, 6), squeeze=False)
axs = axs[:, 0]
plt.subplots_adjust(hspace=0)
for ii, (data, states, ps) in enumerate(alllist):
    # Indices of the three columns with the largest normalised weight.
    inds = sorted(range(len(ps)), reverse=True, key=lambda i: ps[i])[:3]
    # alllist was built in fnames order, so fnames[ii] names this panel's file.
    # (The bare `fname` left over from the parse loop always held the last file.)
    print('\n', fnames[ii])
    # Total weight of rows whose first state entry (S) is 0, then 1.
    print('Singlet: ', sum(ps[np.where(states.T[0] == 0)[0]]))
    print('Triplet: ', sum(ps[np.where(states.T[0] == 1)[0]]))
    for i in inds:
        print(' '.join(states[i].astype(str)), '    ', ps[i])
        # Column 0 of `data` is the x axis; column i+1 pairs with ps[i].
        axs[ii].semilogx(data[:, 0], data[:, i + 1],
                         label=', '.join(states[i].astype(str)))
    is_top = ii == 0
    # Only the top panel carries the legend title naming the state labels.
    legend = axs[ii].legend(title='S, I, F, M$_F$, L, M$_L$' if is_top else None,
                            loc='upper left' if is_top else 'lower right',
                            title_fontsize=14)
    panel_label = axs[ii].text(250, axs[ii].get_ylim()[1] * 0.6, letters[ii])
    if is_top:
        # Hand-tuned layout for the top panel.
        axs[0].set_ylim(-0.026, 0.1)
        panel_label.set_y(0.08)
        legend.set_bbox_to_anchor([0, 1.3])
# sharex=True, so setting the limits on one axis applies to all panels.
axs[0].set_xlim(min(data[:, 0]), max(data[:, 0]))
axs[-1].set_xlabel('Internuclear Separation ($a_0$)')
# Middle panel carries the shared y label (axs[1] for three panels, as before).
axs[len(axs) // 2].set_ylabel('Amplitude')
plt.savefig('wavefunctions.svg', bbox_inches='tight')