import numpy as np
import subprocess as subp
import random
import os,shutil

# Physical constants and unit conversions used by the script.
# NOTE(review): these look like a mix of CODATA 2006 and 2010 values -- confirm
# which release is intended before comparing results with other codes.
Kelvin_to_Hartree = 3.1668153222559e-6        # 1 K expressed in Hartree
Atomic_Mass_Unit = 1.660538921e-27            # unified atomic mass unit, kg
electron_mass = 9.10938215e-31                # electron rest mass, kg
amu_to_au = Atomic_Mass_Unit / electron_mass  # 1 u in units of the electron mass
Bohr_to_Angstrom = 0.52917720859              # 1 bohr in angstrom


# Atomic masses in unified atomic mass units (u) for the supported species.
# The values match the isotopes 164Dy, 166Er, 174Yb and 88Sr.
_ATOM_MASSES = {
    "Dy": 163.9291748e0,
    "Er": 165.9302931e0,
    "Yb": 173.9388621e0,
    "Sr": 87.9056121e0,
}


def find_mass(A):
    """Return the tabulated mass (in u) of the atom with chemical symbol *A*.

    Raises ValueError for a symbol that is not in the table.  (The previous
    version printed "Invalid atom" and then forced a ZeroDivisionError with
    ``0/0``, which hid the real cause of the failure.)
    """
    if A not in _ATOM_MASSES:
        raise ValueError("Invalid atom: {!r}".format(A))
    return _ATOM_MASSES[A]

def replace(file_name, flag, value):
    """Substitute a placeholder in a template file, in place.

    On every line of *file_name*, the first occurrence of the literal text
    *flag* is replaced by ``"{}".format(value)``.  This matches what the
    earlier ``sed -i "s/flag/value/"`` call did for plain placeholders such
    as "$FLDMIN".  Doing it in Python avoids depending on GNU ``sed -i``
    syntax and on *flag*/*value* being free of sed metacharacters
    ("/", "&", "\\", ".").

    Errors such as a missing file are reported with a message, not raised,
    as before.
    """
    replacement = "{}".format(value)
    try:
        with open(file_name, "r") as handle:
            lines = handle.readlines()
        with open(file_name, "w") as handle:
            handle.writelines(line.replace(flag, replacement, 1) for line in lines)
    except EnvironmentError as err:
        print("Failed, replace: {} {} {} {}".format(file_name, flag, value, err))

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
run_no = 0  # incremented per MOLSCAT run; used to make unique file names

# Atom pair.  Alternatives used previously: sys_A = "Er", sys_B = "Sr".
sys_A = "Dy"
sys_B = "Yb"

A = sys_A
B = sys_B

m_A = find_mass(A)
m_B = find_mass(B)
ured = m_A * m_B / (m_A + m_B)  # reduced mass, same units as find_mass

# Energy limits for the bound-state code (not referenced in the code below).
bd_E_min = -250.0
bd_E_max = -0.001

# Collision energies written to the MOLSCAT and FIELD inputs ($ENERGY).
mol_E = 1e-7
fld_E = 1e-7

# Field range for the FIELD run ($FLDMIN/$FLDMAX).
# Earlier settings: B_max = 200.0, B_min = 90.0.
B_min = 0.1
B_max = 500.0
d_B = 0.1

# Base names of the template (.ref), input (.in) and output (.out) files.
mol_name = "m_S_L"
bd_name = "b_S_L"
fld_name = "f_S_L"

ref_name = ".ref"
in_name = ".in"
out_name = ".out"

script_name = "seq.slurm"

# Executable locations (the doubled "/" is harmless and kept as before).
top_dir = "../"
exe_dir = "/".join([top_dir, "exe"])
fld_exe = "/".join([exe_dir, "field-1S_L_at"])
mol_exe = "/".join([exe_dir, "molscat-1S_L_at"])
bd_exe = "/".join([exe_dir, "bound-1S_L_at"])

# Boundary value written to the FIELD input ($BCYOMX).
# k = sqrt(2 * mu * E) in atomic units (E = fld_E K -> Hartree, mu = ured u
# -> electron masses), converted from 1/bohr to 1/angstrom.  rmax is
# presumably in angstrom -- confirm against the FIELD input template.
rmax = 2000
k = np.sqrt(2*fld_E*Kelvin_to_Hartree*ured*amu_to_au) / Bohr_to_Angstrom
boundary = -k * np.tan(k * rmax)

# ---------------------------------------------------------------------------
# FIELD: run once over [B_min, B_max] and scrape the state positions.
# Each scraped state is a dict {"N": ..., "B_res_fld": ...}; B_res_fld is
# None when FIELD reported FAILED for that state.
# ---------------------------------------------------------------------------
fld_infile_name = fld_name + in_name
fld_outfile_name = fld_name + out_name
shutil.copy(fld_name + ref_name, fld_infile_name)
replace(fld_infile_name, "$FLDMIN", B_min)
replace(fld_infile_name, "$FLDMAX", B_max)
replace(fld_infile_name, "$ENERGY", fld_E)
replace(fld_infile_name, "$BCYOMX", boundary)

with open(fld_infile_name, "r") as mstdin:
    with open(fld_outfile_name, "w") as mstdout:
        rtn1 = subp.call([fld_exe], stdin=mstdin, stdout=mstdout, stderr=mstdout)
if rtn1 != 0:
    print("Run failed: {}".format(rtn1))

states_list = []
N_expected = 0       # state count announced on the "MONOTONIC" line
found_state = False  # True between a "SEARCHING" line and its outcome line
reached_max = False  # not read anywhere below
with open(fld_outfile_name, "r") as fld_outfile:
    for line in fld_outfile:
        if "MONOTONIC" in line:
            data = line.split()
            N_expected = int(data[7])
        elif "SEARCHING" in line:
            data = line.split()
            found_state = True
            state_dict = {"N": int(data[5])}

        # Outcome lines only matter while a search is in progress.
        if not found_state:
            continue

        if "CONVERGED ON STATE" in line:
            data = line.split()
            state_dict["B_res_fld"] = float(data[10])
            states_list.append(state_dict)
            found_state = False
        elif "FAILED" in line:
            state_dict["B_res_fld"] = None
            states_list.append(state_dict)
            found_state = False
            print("FIELD Failed")
            print(line)
        elif "MXCALC" in line:
            print("Reached MXCALC")

os.remove(fld_infile_name)
os.remove(fld_outfile_name)

if len(states_list) != N_expected:
    print("Mismatch between states expected and scraped from outfile")
if not states_list:
    print("No states found from outfile")
print("states from FIELD:")
for A in states_list:
    print(A)

# Initial MOLSCAT search parameters for every FIELD state.
#
# "spread" is the half-width of the field window searched around B_res_fld.
# It starts at 0.01 and is reduced so that it is at most 1/20 of the distance
# to the nearest other FIELD position.
#
# States whose FIELD search FAILED (B_res_fld is None) have no position to
# refine.  Previously the distance arithmetic below raised TypeError on them.
# They are now skipped as neighbours and marked "given up" (Delta = 0)
# immediately, so the refinement loop never runs them.
for A in states_list:
    B_res = A["B_res_fld"]
    spread = 0.01
    if B_res is not None:
        for B in states_list:
            if A == B or B["B_res_fld"] is None:
                continue
            if abs(B_res - B["B_res_fld"]) < 20.0 * spread:
                spread = abs(B_res - B["B_res_fld"]) / 20.0
    A["spread"] = spread
    A["t_lo"] = -0.1
    A["t_hi"] = 1.0
    A["xi"] = 0.25
    A["dtol"] = 1e-8
    A["IFCONV"] = 1
    A["bg_offset"] = 0.0
    A["run"] = True
    if B_res is None:
        A["run"] = False
        A["convergence"] = "given up"
        A["why"] = "field failed"
        A["Delta"] = 0.0

# ---------------------------------------------------------------------------
# Iterative MOLSCAT refinement.
#
# Each pass (at most max_iter) has two phases.
#   1. Run MOLSCAT for every state with A["run"] set:
#      - fill in a fresh input from mol_name+ref_name and run mol_exe;
#      - scrape the pole fits from the output;
#      - record the outcome in A["convergence"] as one of "converged",
#        "close", "bg_offset", "decayed", "mislocated", "given up" or False.
#   2. Adjust each state's search parameters according to that outcome, and
#      set A["run"] again for states that need another MOLSCAT run.
# The loop stops early once no state asks for a rerun; `reruns` keeps the
# count from the last pass.
# ---------------------------------------------------------------------------
max_iter = 30
for n_iter in range(max_iter):
    print("iteration {}".format(n_iter))

    # --- Phase 1: run MOLSCAT for every state flagged for (re)running ------
    for A in states_list:
        if not A["run"]:
            continue
        run_no += 1
        print("Rerunning with altered paramaters:")
        print("run_no {}".format(run_no))
        print(A)
        mol_infile_name = mol_name + ".{}".format(run_no) + in_name
        mol_outfile_name = mol_name + ".{}".format(run_no) + out_name
        shutil.copy(mol_name + ref_name, mol_infile_name)
        # Field window: B_res_fld +/- spread.
        replace(mol_infile_name, "$FLDMIN", A["B_res_fld"] - A["spread"])
        replace(mol_infile_name, "$FLDMAX", A["B_res_fld"] + A["spread"])
        # $TLO/$THI are scaled by |bg_offset|/100 once that exceeds 1.
        offset_fac = max(1.0, abs(A["bg_offset"] / 100.0))
        replace(mol_infile_name, "$TLO", A["t_lo"] * offset_fac)
        replace(mol_infile_name, "$THI", A["t_hi"] * offset_fac)
        replace(mol_infile_name, "$XI", A["xi"])
        replace(mol_infile_name, "$DTOL", A["dtol"])
        replace(mol_infile_name, "$IFCONV", A["IFCONV"])
        replace(mol_infile_name, "$A_0", A["bg_offset"])
        replace(mol_infile_name, "$ENERGY", mol_E)
        replace(mol_infile_name, "$ICHAN", 1)

        with open(mol_infile_name, "r") as mstdin, \
                open(mol_outfile_name, "w") as mstdout:
            rtn1 = subp.call([mol_exe], stdin=mstdin, stdout=mstdout,
                             stderr=mstdout)
        if rtn1 != 0:
            print("Run failed: {}".format(rtn1))

        # Scrape the output.  The parser expects each fit to appear as a
        # "3-POINT POLE FORMULA" line (B_res), then a DELTA line, then an
        # A_BG line.  Fits whose B_res is within 1e-5 of the FIELD position
        # count as "close"; the rest count as "far".
        found_res = False
        a_min = np.inf
        a_max = -np.inf
        a_bg_min = np.inf
        a_bg_max = -np.inf
        Delta_min = np.inf
        Delta_max = -np.inf
        dist_max = 0.0
        N_close = 0
        N_far = 0
        found_params = False
        with open(mol_outfile_name, "r") as mol_outfile:
            for line in mol_outfile:
                if "3-POINT POLE FORMULA" in line:
                    found_params = True
                    data = line.split()
                    B_res = float(data[8])
                elif "CONVERGED ON RESONANCE" in line:
                    found_res = True
                    break
                elif "PREDICTED CORRECTION DUE TO" in line:
                    found_res = "decayed"
                    break
                elif "  1    0   1 " in line:
                    # presumably the scattering length at the latest field point
                    data = line.split()
                    a = float(data[4])
                elif not found_params:
                    continue
                elif "DELTA" in line:
                    data = line.split()
                    Delta = float(data[2])
                elif "A_BG" in line:
                    data = line.split()
                    a_bg = float(data[2])
                    if np.isnan(Delta) or np.isnan(a_bg) or np.isnan(B_res):
                        break
                    A["B_res_mol"] = B_res
                    A["Delta"] = Delta
                    A["a_bg"] = a_bg
                    if abs(B_res - A["B_res_fld"]) < 1e-5:
                        N_close += 1
                        a_max = max(a_max, a)
                        a_min = min(a_min, a)
                        a_bg_max = max(a_bg_max, a_bg)
                        a_bg_min = min(a_bg_min, a_bg)
                        Delta_max = max(Delta_max, Delta)
                        Delta_min = min(Delta_min, Delta)
                    else:
                        N_far += 1
                        dist_max = max(dist_max, abs(B_res - A["B_res_fld"]))
                    found_params = False
        os.remove(mol_infile_name)
        os.remove(mol_outfile_name)

        # Classify the run.
        if not found_res and N_close > 2 * N_far and N_close >= 3:
            # Relative scatter of the fitted background over the close fits.
            # A zero denominator used to raise ZeroDivisionError; it now
            # counts as unstable unless every fit gave the same a_bg.
            a_bg_scale = a_bg_max + a_bg_min
            if a_bg_scale != 0:
                bg_unstable = abs((a_bg_max - a_bg_min) / a_bg_scale) > 0.05
            else:
                bg_unstable = a_bg_max != a_bg_min
            if bg_unstable:
                A["convergence"] = "bg_offset"
            elif A["t_hi"] > 500:
                print("Too unstable, even with widely spaced points. Giving up and using final params.")
                A["convergence"] = "given up"
                A["why"] = "unstable"
            else:
                A["convergence"] = "close"
        elif found_res == "decayed":
            A["convergence"] = "decayed"
        elif not found_res:
            if (a_max - a_min) < 1e-6 and N_close >= 3:
                print("No signs of resonance observed: {} {} {}".format(a_min, a_max, N_close))
                print("Giving up and assigning Delta=0")
                A["convergence"] = "given up"
                A["why"] = "No signs"
                A["Delta"] = 0.0
            else:
                A["convergence"] = False
        elif abs(A["B_res_mol"] - A["B_res_fld"]) > 1e-5:
            print("pole appears to be in a different position to FIELD calculation")
            A["convergence"] = "mislocated"
        else:
            A["convergence"] = "converged"
        print(A)

    # --- Phase 2: adjust parameters and decide which states to rerun -------
    for A in states_list:
        A["run"] = False
        if A["convergence"] == "close":
            A["t_hi"] *= 10
            A["t_lo"] *= 10
            A["run"] = True
            continue
        elif A["convergence"] == "given up":
            continue
        elif A["convergence"] == "bg_offset":
            if A["bg_offset"] == 0:
                if A["a_bg"] > 0:
                    A["bg_offset"] = -200.0
                else:
                    A["bg_offset"] = 200.0
            else:
                A["bg_offset"] *= 5
            A["run"] = True
            continue
        elif A["convergence"] == "decayed":
            A["IFCONV"] = 2
            A["run"] = True
            continue
        elif A["convergence"] == "mislocated":
            A["spread"] /= 10
            if A["spread"] > 1e-7:
                A["run"] = True
            else:
                print("Giving up on this resonance, influence from nearby too great. Assigning Delta=0")
                A["Delta"] = 0.0
                A["convergence"] = "given up"
                A["why"] = "mislocated"
            continue
        elif not A["convergence"]:
            A["spread"] /= 10
            # Could also need to fiddle with t_lo and t_hi here, but let's see
            # how this works for now.
            if A["spread"] > 1e-7:
                A["run"] = True
                continue
            print("Giving up on this resonance, reasons unknown, assigning Delta=0")
            A["Delta"] = 0.0
            A["convergence"] = "given up"
            A["why"] = "unknown"

        # Only converged states go on to the neighbour damping below.
        # A state just given up above has Delta = 0, so damping would be a
        # no-op for it.  It may also never have received "B_res_mol", which
        # previously caused a KeyError here.
        if A["convergence"] != "converged":
            continue

        # Shrink t_lo/t_hi while a converged neighbouring pole would perturb
        # this one too much.  This part is kinda hokey and may not actually
        # be necessary, but at least it makes sense dimensionally.
        # NOTE(review): `a` is the last value scraped from whichever MOLSCAT
        # run happened most recently, not necessarily this state's -- confirm
        # what was intended.
        B_res = A["B_res_mol"]
        for B in states_list:
            if A == B:
                continue
            if B["convergence"] != "converged":
                continue
            separation = B_res - B["B_res_mol"]
            if separation == 0:
                # Coincident poles: the factor below would divide by zero.
                continue
            factor = abs(A["t_hi"] * A["Delta"] * B["Delta"] / separation**2)
            factor /= max(1.0, abs(a - B["Delta"] / separation))
            while factor > 0.005:
                A["t_hi"] /= 2
                A["t_lo"] /= 2
                A["run"] = True
                factor = abs(A["t_hi"] * A["Delta"] * B["Delta"] / separation**2)
                factor /= max(1.0, abs(a - B["Delta"] / separation))

    reruns = sum(1 for A in states_list if A["run"])
    if reruns == 0:
        print("No reruns requested, breaking")
        break
    print("{} reruns requested".format(reruns))

# ---------------------------------------------------------------------------
# Final report.
# Echo every state, then the states that did not converge.  Write one row
# per state to summary.out: N, position, width term, background.
#   - "converged":            B_res_mol, Delta*a_bg, a_bg+bg_offset
#   - "given up", unstable:   B_res_mol, 0.0, a_bg+bg_offset
#   - "given up", otherwise:  B_res_fld, 0.0, 0.0
#   - any other status:       N only (still unconverged after max_iter)
# Every field is followed by a single space.
# ---------------------------------------------------------------------------
if reruns != 0:
    print("")
    print("not converged in {} iterations".format(max_iter))
print("")
print("Full final results after refinement:")
for A in states_list:
    print(A)
print("")
print("Remaining problem states:")
for A in states_list:
    if A["convergence"] != "converged":
        print(A)

with open("summary.out", "w") as summary_file:
    for A in states_list:
        fields = [A["N"]]
        if A["convergence"] == "converged":
            fields += [A["B_res_mol"],
                       A["Delta"] * A["a_bg"],
                       A["a_bg"] + A["bg_offset"]]
        elif A["convergence"] == "given up":
            if A["why"] == "unstable":
                fields += [A["B_res_mol"], 0.0, A["a_bg"] + A["bg_offset"]]
            else:
                fields += [A["B_res_fld"], 0.0, 0.0]
        line = "".join("{} ".format(field) for field in fields) + "\n"
        summary_file.write(line)
