import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib import colors
from scipy.special import gamma
from scipy.constants import hbar, physical_constants
from scipy.interpolate import interp1d
from mpmath import hyp2f1

# Work from this script's own directory so that the relative style-sheet path
# below and the relative output filename passed to plt.savefig at the end both
# resolve against the script's location, not the directory it was launched
# from. The chdir must come *before* plt.style.use, otherwise the style path is
# resolved against the launch directory. abspath() guards against
# os.path.dirname() returning '' (os.chdir('') raises) when __file__ is a bare
# relative filename.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
plt.style.use(r'../../../vinstyle.mplstyle')

######## make colormaps
# RGBA base colours taken from the active property cycle
# (author's notes: C0 purple, C7 lilac, C8 pink, C9 grey).
c0, c7, c8, c9 = (np.array(colors.to_rgba(name)) for name in ('C0', 'C7', 'C8', 'C9'))

# 256 x 4 table of interpolation weights: the three RGB columns ramp linearly
# from 0 to 1 while the alpha column is fixed at 1.
x = np.column_stack([np.linspace(0, 1, 256)] * 3 + [np.ones(256)])


def _blend(start, stop, weights, scale=1):
    """Return the colour table ``(stop - start) * weights * scale + start``.

    *start* and *stop* are RGBA arrays and *weights* is an (N, 4) array of
    per-channel weights; the result has one RGBA row per row of *weights*.
    """
    return (stop - start) * weights * scale + np.outer(start, np.ones(len(weights))).T


cmap08 = colors.ListedColormap(_blend(c0, c8, x))   # C0 -> C8
cmap89f = colors.ListedColormap(_blend(c8, c9, x))  # C8 -> C9
cmap70 = colors.ListedColormap(_blend(c7, c0, x))   # C7 -> C0
# C7 -> (almost) C0 over the first 128 entries, then C0 about halfway to C8.
cmap708 = colors.ListedColormap(np.concatenate((
    _blend(c7, c0, x[:128], 2),
    _blend(c0, c8, x[:128]))))
# Upper half of C0 -> C8, then C8 about three quarters of the way to C9.
cmap089 = colors.ListedColormap(np.concatenate((
    _blend(c0, c8, x[128:]),
    _blend(c8, c9, x[:128], 1.5))))
# C8 -> C9, starting 40% of the way along.
cmap89 = colors.ListedColormap(_blend(c8, c9, x*.6 + .4))

########### equations
# ---- trap / species parameters ----
eta = 6    # trap anisotropy, omega_r / omega_z (F below needs it as an integer)
lw = 4     # default line width for plotting
m1, m2 = 133, 87  # masses of Cs and Rb in atomic mass units

# Reduced mass of the Cs-Rb pair, converted to kg.
mu = m1*m2/(m1+m2) * physical_constants['atomic mass constant'][0]
# Axial trap frequency (angular, 2*pi * 13 kHz).
wz = 2*np.pi * 13e3
# Length unit sqrt(hbar / (mu * wz)); the axis labels express -1/a in units of
# sqrt(mu * omega_z / hbar), i.e. 1/d.
d = np.sqrt(hbar/mu/wz)

# ---- disabled: model of a scattering length a(B, E) (kept for reference) ----
# dB = 0.09 # width of FR
# abg = 680 * physical_constants['Bohr radius'][0] # background scattering length
# Eb = hbar/mu/(abg)**2/wz # background binding energy
# Em = 4.2e6*2*np.pi/wz   # gradient of dE/dB at FR

# def aB(B, e):
#     return abg/d* (1 - dB*(1 + e/Eb) / (B + e/Em - dB*e/Eb))

def F(xs, n=int(eta)):
    """Evaluate F at every point of *xs* and return a complex array.

    For each x:

        F(x) = -2 G(x)/G(x - 1/2)
               + G(x)/G(x + 1/2) * sum_{m=1}^{n-1} 2F1(1, x; x + 1/2; exp(2 pi i m / n))

    where G is the gamma function (scipy) and 2F1 is mpmath's hyp2f1.

    Parameters
    ----------
    xs : iterable of float
        Evaluation points, in units of energy.
    n : int
        Trap-frequency ratio; must be an integer because it sets the number
        of hypergeometric terms (n - 1 of them; none when n <= 1).

    Returns
    -------
    numpy.ndarray (dtype complex), one value per element of *xs*.
    """
    def _single(x):
        g = gamma(x)
        series = sum(
            hyp2f1(1, x, x+0.5, np.exp(2j*np.pi*m/n))
            for m in range(1, n)
        )
        return -2*g/gamma(x-0.5) + g/gamma(x+0.5) * series

    return np.array([_single(x) for x in xs], dtype=complex)

def lc_gradient(x, y, cmap, ls='solid', label=None, lw=lw):
    """Build a LineCollection that draws the curve (x, y) with a colour gradient.

    The curve is split into len(x) - 1 straight segments. When *cmap* is
    accepted as a colormap, the segments are coloured along it in order,
    from the first segment to the last. If matplotlib rejects *cmap* as a
    colormap with a ValueError (e.g. a colour string such as 'C2'), the whole
    curve is drawn in that single colour instead.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the curve's points (equal lengths).
    cmap : Colormap or colour
        Colormap for the gradient, or a plain colour for a solid line.
    ls : str
        Line style passed to LineCollection.
    label : str or None
        Legend label.
    lw : float
        Line width; defaults to the module-level ``lw`` at definition time.

    Returns
    -------
    matplotlib.collections.LineCollection, not yet added to any Axes.
    """
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    # Pair every point with its successor -> shape (len(x) - 1, 2, 2).
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    try:
        lc = LineCollection(segments, cmap=cmap, linestyle=ls, label=label, lw=lw)
    except ValueError:
        # Not a colormap: use it as a solid colour. No data array is attached,
        # since some matplotlib versions would then recolour the line with the
        # default colormap and ignore the explicit colour.
        return LineCollection(segments, colors=cmap, linestyle=ls, label=label, lw=lw)
    # One value per *segment* (not per point), so the last segment maps to the
    # top of the colormap. No norm is passed, so the default one autoscales to
    # this array's range; only the ordering of the values matters.
    lc.set_array(np.linspace(0, 250, len(segments)))
    return lc

# At module level locals() is the same dict as globals(). It is used below to
# detect results left over from an earlier run in the same interpreter
# session, so that the slow evaluation of F is not repeated.
vardict = locals()


######## ax1 plot

fig, (ax1, ax2) = plt.subplots(1, 2, sharey=False, constrained_layout=True,
            sharex=False, figsize=(8,4.2))

# Two energy windows (units of hbar*omega_z, per the y-axis label). Each one
# gets its own curve -1/a(E) = Re F(-E/2).
Es = (np.linspace(2.2,4,70), np.linspace(0.2,1.9,70))
# One -1/a array per window. Each window must be evaluated separately: a
# single cached array would plot the second window against the first
# window's values.
if 'invas' not in vardict: # don't repeat long calculation
    print('recalculating')
    invas = [F(-E/2, eta).real for E in Es]
for E, inva, cmap, label in zip(Es, invas, [cmap89f, cmap08],
                                [r'$|0,0\rangle$', r'$|2,0\rangle$']):
    ax1.add_collection(lc_gradient(inva, E, cmap))
    # ax1.plot(inva, E+1, color='C1', lw=1.5, alpha=1.2/max(E))
    # ax1.plot(inva, E+2, color='C2', lw=1.5, alpha=0.8/max(E))
# After this loop, E, inva and cmap hold the last window's values; the ax2
# panel reuses them.

# Solid horizontal reference lines at E = 1 (C2) and E = 3 (black).
for i, c in enumerate(['C2', 'k']): #['C0', 'C1', 'C8', 'C1']
    ax1.hlines(2*i+1, -15,15, colors=c, lw=lw/2) #0.9/(i/1.5+1)


ax1.set_xlim(-15,15)
ax1.set_ylim(-4,4.3)
ax1.set_xlabel(r'-1/a ($\sqrt{\mu\omega_z/\hbar}$)')
# ax1.set_xticks([])
ax1.set_ylabel(r'Energy ($\hbar \omega_z$)')
# ax1.set_yticks(list(range(5)))
# ax1.set_yticks([])
# ax1.set_yticklabels(['$E_0$'])
ax1.annotate('(a)', (-13, 3.6))
# ax1.scatter(-2, 1, marker='o', color='C4', s=50, zorder=10)

# Lowest branch, E < 0. Sampled sparsely on [-30, -2] and densely (200 pts)
# on [-2, -0.1].
E00 = np.concatenate((np.linspace(-30, -2, 50), np.linspace(-2, -.1, 200)))
if 'inva0' not in vardict: # don't repeat long calculation
    print('recalculating')
    inva0 = F(-E00/2, eta).real

ax1.add_collection(lc_gradient(inva0, E00, cmap70))
# ax1.plot(inva0, E0+2, color='C2', label=r'$|0,2\rangle$', lw=1.5)
# ax1.plot(inva0, E0+1, color='C1', label=r'$|0,1\rangle$', lw=1.5)
# Dashed horizontal reference lines at E = 0 (C0) and E = 2 (C8).
ax1.hlines(0,-20,20,colors='C0',linestyles='dashed')
ax1.hlines(2,-20,20,colors='C8',linestyles='dashed')

# Legend proxies: one line per colour, placed at (-50, -50) outside the axis
# limits so that only their legend entries ('3', '2', '1', '0') are visible.
handles = []
#[r'$|%s,0\rangle$'%i for i in reversed(range(4))]
for colour, label in zip(reversed(['C0', 'C2', 'C8', 'k']), [str(i) for i in reversed(range(4))]):
    handles.append(ax1.plot([-50],[-50], color=colour, label=label, lw=lw))
ax1.legend(title=r'$n_{\mathrm{rel},z}$', title_fontsize=14, # r'$|n_\mathrm{rel},n_\mathrm{com}\rangle$'
           loc='lower left')



# Panel (b). inva, E and cmap are the values left by the branch loop of
# panel (a), i.e. its last energy window.
ax2.add_collection(lc_gradient(inva, E, cmap))
# ax2.plot(inva, E, color='k', alpha=0.4, lw=lw)
# The lowest branch repeated with energy offsets i = 0..19 (units of
# hbar*omega_z), labelled n_com,z = i for i < 5.
for i in range(20):
    ax2.add_collection(lc_gradient(inva0, E00+i, cmap70, lw=lw/2)) # color='k', alpha=0.2, lw=lw/2
    if i == 0:
        ax2.annotate(r'$n_{\mathrm{com},z}=%s$'%i, (11, i-.25), va='center',
                     ha='center',fontsize=14)
    elif i < 5:
        ax2.annotate('%s'%i, (14.3, i-.25), va='center',
                     ha='center',fontsize=14)
# The lowest branch offset by eta and 2*eta (legend entries '1' and '2').
ax2.plot(inva0, E00 + eta*2, color='C1', label='2', lw=lw)
ax2.plot(inva0, E00 + eta, color='C2', label='1', lw=lw)
# Unshifted branch drawn again at full line width so it carries legend entry '0'.
ax2.add_collection(lc_gradient(inva0, E00, cmap70, label='0'))
# ax2.scatter(1.7033, 1.4136, marker='o', color='C2')
# ax2.vlines(0, -5,5, color='C0', linestyle='dashed', lw=1)
# ax2.scatter(-3.196, 0.8705, marker='o', color='C1')
# ax2.set_xticks([])
# ax2.set_xticklabels(['$B_0$'])
ax2.set_yticks([])
# ax2.set_yticklabels([])
ax2.set_xlim(-15,15)
ax2.set_xlabel(r'-1/a ($\sqrt{\mu\omega_z/\hbar}$)')
ax2.set_ylim(-4,4.3)
ax2.annotate('(b)', (-13, 3.6))
ax2.legend(title=r'$n_{\mathrm{com}, r}$',
           title_fontsize=14, loc='lower left')

# fig.text(0.5, -0.04, '-1/a', ha='center', va='center')
# Saved relative to the current working directory.
plt.savefig('confinement_eigenstates.svg', bbox_inches='tight')