import numpy as np
import matplotlib.pyplot as plt
plt.style.use(r'../../../vinstyle.mplstyle')
from matplotlib.collections import LineCollection
from matplotlib import colors
from scipy.special import gamma
from scipy.constants import hbar, physical_constants
from scipy.interpolate import interp1d
from mpmath import hyp2f1
import os
os.chdir(os.path.dirname(__file__))

########### make colormaps
# Base colours from the style sheet's colour cycle, as RGBA numpy arrays.
# c0: purple, c7: lilac, c8: pink, c9: grey
c0, c7, c8, c9 = (np.array(colors.to_rgba(f'C{k}')) for k in (0, 7, 8, 9))

# 256 rows of per-channel blend weights: RGB columns ramp 0 -> 1, alpha is 1.
x = np.column_stack([np.linspace(0, 1, 256)] * 3 + [np.ones(256)])


def _ramp(start, end, t, scale=1):
    """Return the colour blend ``(end - start) * t * scale + start``.

    ``t`` holds per-row, per-channel weights and ``scale`` multiplies them,
    so e.g. ``scale=2`` on weights in [0, 0.5] spans the full start -> end.
    """
    return (end - start) * t * scale + start


cmap08 = colors.ListedColormap(_ramp(c0, c8, x))     # purple -> pink
cmap89f = colors.ListedColormap(_ramp(c8, c9, x))    # pink -> grey
cmap70 = colors.ListedColormap(_ramp(c7, c0, x))     # lilac -> purple
# lilac -> purple over the first half, then purple -> halfway to pink
cmap708 = colors.ListedColormap(np.concatenate((
    _ramp(c7, c0, x[:128], 2),
    _ramp(c0, c8, x[:128]))))
# halfway purple->pink -> pink, then pink -> 3/4 of the way to grey
cmap089 = colors.ListedColormap(np.concatenate((
    _ramp(c0, c8, x[128:]),
    _ramp(c8, c9, x[:128], 1.5))))
# pink -> grey, starting 40% of the way along
cmap89 = colors.ListedColormap(_ramp(c8, c9, x*.6 + .4))

########### equations
# Physical parameters. Lengths are later expressed in units of d and
# energies in units of hbar*wz (both built from mu and wz below).
amu = physical_constants['atomic mass constant'][0]  # kg
bohr = physical_constants['Bohr radius'][0]          # m

eta = 6    # anisotropy wr / wz
lw = 4     # linewidth for plotting
m1 = 133   # mass Cs (amu)
m2 = 87    # mass Rb (amu)
mu = m1*m2/(m1+m2) * amu        # reduced mass (kg)
wz = 2*np.pi * 13e3             # trap frequency in axial direction (rad/s)
d = np.sqrt(hbar/mu/wz)         # oscillator length: scaling for distance units
dB = 0.09                       # width of FR (presumably in gauss -- confirm)
abg = 680 * bohr                # background scattering length (m)
# background binding energy in units of hbar*wz, i.e. (d/abg)**2
# NOTE(review): this is hbar^2/(mu*abg^2) without a factor 1/2 -- confirm
# this matches the convention assumed in aB.
Eb = hbar/mu/(abg)**2/wz
Em = 4.2e6*2*np.pi/wz           # gradient of dE/dB at FR, in units of hbar*wz

def aB(B, e):
    """Energy-dependent scattering length near the Feshbach resonance.

    Returns ``abg/d * (1 - dB*(1 + e/Eb) / (B + e/Em - dB*e/Eb))``, i.e. the
    background scattering length ``abg`` in units of ``d``, modified by the
    resonance of width ``dB``.

    B : field detuning, in the same units as ``dB``.
    e : energy (scalar or numpy array), in the units of ``Eb`` and ``Em``.
    """
    numerator = dB*(1 + e/Eb)
    denominator = B + e/Em - dB*e/Eb
    return abg/d * (1 - numerator/denominator)

def F(xs, n=int(eta)):
    """Evaluate F at every point of ``xs``.

    F(x) = -2*gamma(x)/gamma(x-1/2)
           + gamma(x)/gamma(x+1/2) * sum_{m=1}^{n-1} 2F1(1, x; x+1/2; exp(2*pi*i*m/n))

    Parameters
    ----------
    xs : iterable of float
        Arguments of F. This script passes ``-E/2`` with E an energy in
        trap units.
    n : int
        Trap frequency ratio; must be an integer because it sets how many
        roots of unity enter the sum. Defaults to ``int(eta)``.

    Returns
    -------
    numpy.ndarray of complex
        One entry per element of ``xs``. ``hyp2f1`` is mpmath's, so its
        results are converted to Python complex by the final array call.
    """
    # the n-1 non-trivial n-th roots of unity, identical for every x
    roots = [np.exp(2j*np.pi*m/n) for m in range(1, n)]
    values = []
    for x in xs:
        g = gamma(x)
        root_sum = sum(hyp2f1(1, x, x+0.5, z) for z in roots)
        values.append(-2*g/gamma(x-0.5) + g/gamma(x+0.5) * root_sum)
    return np.array(values, dtype=complex)

def lc_gradient(x, y, cmap, ls='solid', label=None):
    """Build a LineCollection tracing (x, y), coloured along its length.

    Parameters
    ----------
    x, y : array_like
        Curve coordinates, equal length.
    cmap : Colormap, colormap name, or colour spec
        If matplotlib accepts it as a colormap, segments are coloured from
        the start to the end of the colormap along the curve. Otherwise
        (LineCollection raises ValueError) it is used as one fixed colour.
    ls : str
        Line style.
    label : str or None
        Legend label.

    Returns
    -------
    LineCollection
        Not yet attached to any axes; add it with ``ax.add_collection``.
        The line width comes from the module-level ``lw``.
    """
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    # consecutive point pairs -> (len(x)-1, 2, 2) array of segments
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    try:
        lc = LineCollection(segments, cmap=cmap, linestyle=ls, label=label, lw=lw)
    except ValueError:
        # Not a colormap (e.g. 'C4'): draw a solid line. No scalar array is
        # attached, so the explicit colour is what gets drawn.
        return LineCollection(segments, colors=cmap, linestyle=ls, label=label, lw=lw)
    # One scalar per SEGMENT (there are len(x)-1 of them), so the last
    # segment reaches the end of the colormap. The norm autoscales, so only
    # the monotonic spacing matters, not the 0..250 range itself.
    lc.set_array(np.linspace(0, 250, len(segments)))
    return lc


#%%
########### plot

# At module level locals() is the globals dict itself, so names assigned
# later (e.g. E3) are visible through it; the eigenstate cell below checks
# it to skip recalculation when the cell is re-run in the same session.
vardict = locals()

fig = plt.figure(figsize=(8,4), constrained_layout=False)
# 1x4 grid: ax1 spans columns 0-2, ax2 column 3 (right part of broken x-axis)
ax1 = plt.subplot2grid((1,4),(0,0),1,3)
ax2 = plt.subplot2grid((1,4),(0,3),1,1)
ax1.sharey(ax2)
# hide the facing spines/ticks so the two panels read as one broken axis
ax1.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax1.tick_params(right=False, labelright=False)
ax2.yaxis.tick_right()
ax2.tick_params(labelright=False)
# Set the panel spacing first: the break marks below are sized from the
# final axes widths.
plt.subplots_adjust(wspace=0.4)

dl = .01 # half-size of the diagonal break marks, in ax1 axes coordinates
# ax1 is several times wider than ax2, so the same axes-coordinate offset
# covers more screen width on ax1. Scale ax2's horizontal offset by the
# width ratio so all four marks have the same on-screen slope.
dl2 = dl * ax1.get_position().width / ax2.get_position().width
# arguments to pass plot, just so we don't keep repeating them
kwargs = dict(transform=ax1.transAxes, color='k', clip_on=False)
ax1.plot((1-dl,1+dl), (-dl,+dl), **kwargs)      # bottom-right corner of ax1
ax1.plot((1-dl,1+dl),(1-dl,1+dl), **kwargs)     # top-right corner of ax1
kwargs.update(transform=ax2.transAxes)  # switch to the right-hand axes
ax2.plot((-dl2,+dl2), (1-dl,1+dl), **kwargs)    # top-left corner of ax2
ax2.plot((-dl2,+dl2), (-dl,+dl), **kwargs)      # bottom-left corner of ax2

############ calculate eigenstates
# Energy grids (trap units) for four branches, denser near the ends of each
# interval, and for each grid energy the scattering length -1/Re F(-E/2)
# at which it is an eigenvalue. F is slow (mpmath hyp2f1 per point), so the
# results are reused when this cell is re-run and E3 already exists.
if 'E3' not in vardict:
    print('recalculating')
    E0 = np.concatenate((np.linspace(-15, -5, 50),
                         np.linspace(-4.99, -4.13, 250)))
    E1 = np.concatenate((np.linspace(-4.09, -3.6, 150),
                         np.linspace(-3.59, 1, 40),
                         np.linspace(1.01, 1.288, 150)))
    E2 = np.concatenate((np.linspace(1.299, 1.8, 150),
                         np.linspace(1.801, 3., 40),
                         np.linspace(3.01, 3.32, 150)))
    E3 = np.concatenate((np.linspace(3.323, 3.45, 150),
                         np.linspace(3.46, 5, 40),
                         np.linspace(5.01, 5.27, 150)))
    # n=6 is hard-coded here (equal to eta); the grids above are tuned to it
    a0, a1, a2, a3 = (-1/F(-E/2, 6).real for E in (E0, E1, E2, E3))

# Field detunings (units of dB's field), excluding a small window around 0.
Bs = np.concatenate((np.linspace(-.3, -0.03, 50), np.linspace(0.03, .3, 50)))

# Elist[k][i]: energy on branch k where the branch's scattering length is
# within 0.1 of aB(Bs[i], E) -- the first such grid point, or the last one
# for branch 0.
# NOTE(review): raises IndexError if no grid point is within tolerance.
Elist = [np.zeros(len(Bs)) for _ in range(4)]
_branches = ((E0, a0, -1), (E1, a1, 0), (E2, a2, 0), (E3, a3, 0))
for i, B in enumerate(Bs):
    for k, (E_grid, a_grid, pick) in enumerate(_branches):
        hits = np.nonzero(abs(a_grid - aB(B, E_grid)) < 1e-1)[0]
        Elist[k][i] = E_grid[hits[pick]]


def _stitch(lower, lower_keep, upper, upper_keep):
    """Join the masked samples of two branches into one (B, E) curve."""
    return (np.concatenate((Bs[lower_keep], Bs[upper_keep])),
            np.concatenate((lower[lower_keep], upper[upper_keep])))


# each plotted curve: part of branch k above one energy cut followed by
# part of branch k+1 below another
Bs0, Es0 = _stitch(Elist[0], Elist[0] > -9, Elist[1], Elist[1] < 1)
Bs1, Es1 = _stitch(Elist[1], Elist[1] > 1, Elist[2], Elist[2] < 2.5)
Bs2, Es2 = _stitch(Elist[2], Elist[2] > 2.5, Elist[3], Elist[3] < 4.5)

# Resample each stitched curve on a uniform grid of (B-B0)/dB so the colour
# gradient advances evenly along it; the span matches ax1's x-limits.
Bx = np.linspace(-1,3,100)
for B, E_branch, cmap in zip([Bs0, Bs1, Bs2], [Es0, Es1, Es2], [cmap708, cmap089, cmap89]):
    # interp1d raises ValueError if Bx leaves the sampled B/dB range
    energy_of_B = interp1d(B/dB, E_branch)
    ax1.add_collection(lc_gradient(Bx, energy_of_B(Bx), cmap))

# raw string: '\h' and '\o' are invalid Python escapes (SyntaxWarning)
ax1.set_ylabel(r'Energy ($\hbar\omega_z$)')
# ax1.set_xlabel('$(B-B_0)/\Delta B$')
# ax2.set_yticklabels([])
ax1.set_ylim(-4,4.3)
ax1.set_xlim(-1,3)
# ax1.annotate('(b)', (-.8, 3.6))

ax1.annotate(r'$\psi_b$', (-.3, -3.5), va='center', ha='center')
ax1.annotate(r'$|0\rangle$', (1, 0.3), va='center', ha='center')
ax1.annotate(r'$|2\rangle$', (1, 2.1), va='center', ha='center')

############ plot broken axis
# Right-hand panel (x around 20): flat lines at E = 1 and E = 2.5, labelled
# with their state composition.
ax2.hlines(1,10,30,colors=cmap08(0.5), lw=lw)
ax2.hlines(2.5,10,30,colors=cmap89(0.3), lw=lw)
ax2.set_xlim(18,22)
# ax2.scatter(20, 1, marker='o', color='C4', s=50, zorder=10)
ax2.annotate(r'$\psi_{a,0}=c_1|0\rangle + c_2|2\rangle$', (18, 1.5), va='center', ha='center')
ax2.annotate(r'$\psi_{a,2}=c_3|2\rangle + c_4|4\rangle$', (18, 3), va='center', ha='center')

# shared x-label centred under both panels; raw string because '\D' is an
# invalid Python escape (SyntaxWarning) -- the text itself is unchanged
fig.text(0.5, -0.04, r'$(B-B_0)/\Delta B$', ha='center', va='center')
plt.savefig('Bfield_eigenstates.svg', bbox_inches='tight')