"""
Code to produce figures in "Velocity selection in a Doppler-broadened ensemble of atoms
interacting with a monochromatic laser beam"

Copyright 2017 J. Keaveney and I. G. Hughes

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


------------------------------------------------------------------------

All figures in the paper can be reproduced from this file. The only requirements are that
numpy and matplotlib be installed.

------------------------------------------------------------------------
"""

import numpy as np
import matplotlib.pyplot as plt

from matplotlib.patches import ConnectionPatch

## Durham Colour Palette ##
# Each entry is an [R, G, B] list with 0-255 channel values rescaled to 0-1.
# The trailing code is the palette's print reference for that colour.
d_black = [35/255., 31/255., 32/255.]         # Black        (ref. BlackC)
d_olive = [159/255., 161/255., 97/255.]       # Olive Green  (ref. 5835C)
d_blue = [0, 99/255., 136/255.]               # Blue         (ref. 634C)
d_red = [170/255., 43/255., 74/255.]          # Red          (ref. 201C)
d_midblue = [145/255., 184/255., 189/255.]    # Mid Blue     (ref. 5493C)

# Global matplotlib styling: Times New Roman serif text and enlarged label/tick/legend fonts
plt.rc('font', family='Serif', serif=['Times New Roman'])
params = {
	'axes.labelsize': 13,
	'xtick.labelsize': 12,
	'ytick.labelsize': 12,
	'legend.fontsize': 11,
}
plt.rcParams.update(params)


### PLOT PARAMETERS *****
# Lorentzian FWHM expressed as a fraction of the Gaussian (Doppler) FWHM
W = 1 / 20.
# Detunings for the four panels (a)-(d) of figure 2 (also marked on figure 3)
Ds = [0, 0.5, 1.5, 2.5]


def Lor(v, W, D):
	"""
	Lorentzian of full width at half maximum W centred on D, with unit peak value.

	Works elementwise when v (or D) is a numpy array.
	"""
	denominator = W**2 + 4*(v - D)**2
	return W**2 / denominator

def Gau(v, u):
	"""
	Gaussian centred on v=0 with full width at half maximum u and unit peak value.

	Works elementwise when v is a numpy array.
	"""
	# -4 ln 2 makes the value exactly 1/2 at v = +/- u/2
	scale = -4*np.log(2)
	return np.exp(scale * v**2/u**2)

def ten_ninety():
	"""
	Diagnostic plot of the 10/90 width of a Lorentzian, for direct comparison with its FWHM.

	The top panel shows a unit-FWHM Lorentzian. The bottom panel shows its normalised
	cumulative sum. Vertical lines mark the first and last samples where the cumulative
	sum lies strictly between 0.1 and 0.9. The width computed by TenNinetyWidth is
	printed, and the figure is shown.

	Returns
	-------
	float
		The 10/90 width returned by TenNinetyWidth.
	"""
	dets = np.linspace(-20,20,25001)

	# Lorentzian with unit FWHM, evaluated at v=0 as a function of detuning
	L = Lor(0, 1, dets)

	fig = plt.figure()
	ax1 = fig.add_subplot(211)
	ax2 = fig.add_subplot(212,sharex=ax1)

	ax1.plot(dets, L)

	# cumulative sum, normalised so that it ends at 1
	CS = np.cumsum(L)
	CS /= CS.max()

	ax2.plot(dets, CS)

	# samples where the cumulative distribution lies between the 10% and 90% levels
	v_between = dets[(CS > 0.1) & (CS < 0.9)]

	for ax in fig.axes:
		ax.axvline(v_between[-1],color='k')
		ax.axvline(v_between[0],color='k')

	ten_ninety_width = TenNinetyWidth(dets, L)
	# function-call form is valid in both Python 2 and Python 3
	print(ten_ninety_width)

	plt.show()
	return ten_ninety_width
	
def _hide_frame(ax):
	"""Hide all four spines of *ax* and remove its x and y ticks, leaving only its labels."""
	for side in ('top', 'bottom', 'left', 'right'):
		ax.spines[side].set_color('none')
	ax.set_xticks([])
	ax.set_yticks([])


def _twin_panel(fig, position, left_colour, right_colour, sharex=None):
	"""
	Add a data subplot at *position* of *fig* plus a twinx() partner; return (ax, ax_twin).

	The primary axes gets facecolor='none' (transparent background). The y ticks are
	coloured left_colour on the primary axes and right_colour on the twin.
	"""
	ax = fig.add_subplot(position, sharex=sharex, facecolor='none')
	ax_twin = ax.twinx()
	ax.tick_params('y', colors=left_colour)
	ax_twin.tick_params('y', colors=right_colour)
	return ax, ax_twin


def main():
	"""
	***
	Figure 2
	***

	2x2 panel figure with Gaussian and combination of Gaussian and Lorentzian with different detunings.

	Each panel shows two curves:
	- the area-normalised Gaussian G (dashed, left axis);
	- the product G*L, where L is a Lorentzian of FWHM W centred on detuning Ds[i]
	  (filled, right axis).

	In panels (c) and (d), G*L is scaled by 1e3 and 1e4 respectively. Panel (d) also has
	an inset zooming in on 2 <= v_z/u <= 3, scaled by 1e8. The figure is saved as
	two_by_two.png, two_by_two.pdf and two_by_two.eps.
	"""
	## abscissa axis
	v_over_u = np.linspace(-5,5,2501)
	dv = v_over_u[1] - v_over_u[0]

	# colours for lines
	col_G = d_red
	col_L = d_blue

	# Gaussian component (constant in all panels), normalised to unit area
	G = Gau(v_over_u, 1)
	G /= G.sum() * dv

	# Set up figure canvas
	fig = plt.figure(1,facecolor='w')
	fig.subplots_adjust(left=0.09,right=0.9,top=0.95,bottom=0.11,hspace=0.15,wspace=0.3)

	## Ghost axes spanning all panels, used only to hold centred axis labels
	ax_ghost = fig.add_subplot(111,frameon=False)
	_hide_frame(ax_ghost)
	# a second ghost axes at the same position carries the right-hand y label
	ax_ghost2 = fig.add_axes(ax_ghost.get_position(),frameon=False)
	_hide_frame(ax_ghost2)
	ax_ghost2.yaxis.set_label_position('right')
	# Add labels
	ax_ghost.set_xlabel('Atomic velocity, $v_z/u$',labelpad=20)
	ax_ghost.set_ylabel(r'Relative Weighting, $f(v_z/u)$ (Doppler)',labelpad=22,color=d_red)
	ax_ghost2.set_ylabel(r'Relative Weighting, $f(v_z/u)\times L(v_z/u, \Delta/ku)$',labelpad=22,color=d_blue)

	# Data axes: a left axis for G and a twin right axis for G*L in every panel.
	# (_twin_panel uses 'facecolor'; the old 'axisbg' keyword no longer exists in matplotlib.)
	ax1, ax1b = _twin_panel(fig, 221, d_red, d_blue)
	ax2, ax2b = _twin_panel(fig, 222, d_red, d_blue, sharex=ax1)
	ax3, ax3b = _twin_panel(fig, 223, d_red, d_blue, sharex=ax1)
	ax4, ax4b = _twin_panel(fig, 224, d_red, d_blue, sharex=ax1)

	# Panel (a)
	D = Ds[0] # Detuning value (defined at the top of the file)
	L = Lor(v_over_u, W, D)
	ax1b.plot(v_over_u, G*L, color=col_L, lw=2)
	ax1b.fill_between(v_over_u, G*L, 0, color=d_midblue, lw=1.5)
	ax1.plot(v_over_u, G, color=col_G, linestyle='dashed', lw=1.5)
	# draw the dashed Gaussian above the filled curve on the twin axes
	ax1.set_zorder(ax1b.get_zorder()+1)

	# Panel (b)
	D = Ds[1]
	L = Lor(v_over_u, W, D)
	ax2b.plot(v_over_u, G*L, color=col_L, lw=2)
	ax2b.fill_between(v_over_u, G*L, 0, color=d_midblue, lw=1.5)
	ax2.plot(v_over_u, G, color=col_G, linestyle='dashed', lw=1.5)
	ax2.set_zorder(ax2b.get_zorder()+1)
	ax2b.set_ylim(0,1)

	# Panel (c): G*L is plotted scaled by 1e3
	D = Ds[2]
	L = Lor(v_over_u, W, D)
	ax3b.plot(v_over_u, G*L*1e3, color=col_L, lw=2)
	ax3b.fill_between(v_over_u, G*L*1e3, 0, color=d_midblue, lw=1.5)
	ax3.plot(v_over_u, G, color=col_G, linestyle='dashed', lw=1.5)
	# push the twin behind (rather than raising ax3) so later-added axes keep their stacking
	ax3b.set_zorder(ax3.get_zorder()-1)
	ax3.text(3.25,0.88,r'$\times 10^{-3}$',ha='center',color=col_L,fontsize=14)

	# Panel (d): G*L is plotted scaled by 1e4
	D = Ds[3]
	L = Lor(v_over_u, W, D)
	ax4b.plot(v_over_u, G*L*1e4, color=col_L, lw=2)
	ax4b.fill_between(v_over_u, G*L*1e4, 0, color=d_midblue, lw=1.5)
	ax4.plot(v_over_u, G, color=col_G, linestyle='dashed', lw=1.5)
	ax4b.set_zorder(ax4.get_zorder()-1)
	ax4.text(3.25,0.88,r'$\times 10^{-4}$',ha='center',color=col_L,fontsize=14)
	ax4b.set_ylim(0,1.1)
	# Inset to panel (d): zoom on 2 <= v_z/u <= 3, G*L scaled by 1e8
	ax4c = fig.add_axes([0.75,0.22,0.11,0.11])
	ax4c.plot(v_over_u, G*L*1e8, color=col_L, lw=2)
	ax4c.fill_between(v_over_u, G*L*1e8, 0, color=d_midblue, lw=2)
	ax4c.set_xlim(2,3)
	ax4c.set_xticks([2,2.5,3])
	ax4c.yaxis.set_label_position('right')
	ax4c.set_ylim(0,4.25)
	ax4c.set_yticks([0,2,4])
	ax4c.text(2.7,3,r'$\times 10^{-8}$',ha='center',color=col_L,fontsize=12)

	# lines connecting the ends of the inset's x-range to the same points on panel (d)
	connector_col = np.array(d_black)*3 # lighter grey derived from the palette black
	for xy in [(2,0), (3,0)]:
		con = ConnectionPatch(xy, xy, 'data', 'data',
					arrowstyle="-", shrinkB=0,
					axesA=ax4, axesB=ax4c, mutation_scale=12,
					ec=connector_col,fc=connector_col,lw=0.5,alpha=1)
		artist = ax4.add_artist(con)
		artist.set_zorder(20)

	# global x-axis limit (same for all panels, via sharex)
	ax1.set_xlim(-2,4)

	# panel labels
	for ax, label in zip((ax1, ax2, ax3, ax4), ('(a)', '(b)', '(c)', '(d)')):
		ax.text(-1.5,0.88,label,ha='center',color=d_black,fontsize=14)

	# save figure
	fig.savefig('two_by_two.png')
	fig.savefig('two_by_two.pdf')
	fig.savefig('two_by_two.eps')

	# (call plt.show() here to display the figure in an interactive window)


def fig1():
	"""
	Figure 1: 'cartoon' figure showing the hole in the ground-state velocity
	distribution and the velocity-selective excited-state population.

	The top panel shows a Lorentzian excited-state population centred on D.
	The bottom panel shows two ground-state curves: the Gaussian distribution
	(dashed) and the same distribution with a Lorentzian-shaped hole removed (solid).
	The bottom x-axis is drawn as an arrow. The figure is saved as fig1.png,
	fig1.eps and fig1.pdf.
	"""
	# setup figure
	fig = plt.figure(2,facecolor='w',figsize=(5,4))
	fig.subplots_adjust(left=0.105,right=0.95,top=0.95,bottom=0.13,hspace=0.15,wspace=0.3)

	# subplots: a one-row top panel above a two-row bottom panel
	yy = 3
	xx = 1
	ax = plt.subplot2grid((yy,xx),(0,0))
	ax2 = plt.subplot2grid((yy,xx),(1,0), rowspan=yy-1)

	# remove most of the default axes lines
	ax.spines['top'].set_color('none')
	ax.spines['bottom'].set_color('none')
	ax.spines['right'].set_color('none')
	ax2.spines['top'].set_color('none')
	ax2.spines['right'].set_color('none')

	# a single unlabelled tick at v_z = 0, and unlabelled ticks at 0 and 1 on y
	ax.set_xticks([0])
	ax.set_xticklabels([])
	ax2.set_xticks([0])
	ax.set_yticks([0,1])
	ax2.set_yticks([0,1])
	ax.set_yticklabels([])
	ax2.set_yticklabels([])
	ax.yaxis.set_ticks_position('left')
	ax2.yaxis.set_ticks_position('left')

	# abscissa axis
	v_over_u = np.linspace(-2,2,2501)

	# Gaussian component
	G = Gau(v_over_u, 1)

	# Lorentzian component (local width, wider than the module-level W for visibility)
	D = 0.35
	W = 1./10
	L = Lor(v_over_u, W, D)

	# Hole that's "burned" from the velocity distribution; sc sets the hole depth
	sc = 0.75
	Hole = G * (1 - sc*L)

	# Add curves to plot (scaled and offset slightly so they sit clear of the axes)
	ax2.plot(v_over_u, 0.95*Hole+0.04, color=d_blue, lw=2)
	ax2.plot(v_over_u, 0.95*G+0.04, color=d_black, linestyle='dashed', lw=1.5)
	ax.plot(v_over_u, 0.9*L+0.04, color=d_red, lw=2)

	# Add vertical dashed lines at 0
	ax.axvline(0,color=d_black, linestyle='dashed')
	ax2.axvline(0,color=d_black, linestyle='dashed')

	# Set axes limits
	ax.set_xlim(-2.5,2.5)
	ax2.set_xlim(ax.get_xlim())
	ax.set_ylim(0,1)
	ax2.set_ylim(0,1)

	# Axes labels
	ax.set_ylabel('Excited State\nPopulation')
	ax2.set_ylabel('Ground state\nPopulation')
	ax2.set_xlabel('Velocity, $v_z$')

	# x-axis arrow rather than the default line:
	# Reproduced with little modification from
	# http://stackoverflow.com/questions/17646247/how-to-make-fuller-axis-arrows-with-matplotlib

	xmin, xmax = ax2.get_xlim()
	ymin, ymax = ax2.get_ylim()

	# manual arrowhead width and length, as 1/20 of the axis ranges
	hw = 1./20.*(ymax-ymin)
	hl = 1./20.*(xmax-xmin)
	lw = 0.8 # axis line width
	ohg = 0.3 # arrow overhang

	# clip_on=False lets the arrowhead extend past the axes area
	ax2.arrow(xmin, 0, xmax-xmin, 0., fc=d_black, ec=d_black, lw = lw,
			 head_width=hw, head_length=hl, overhang = ohg,
			 length_includes_head= True, clip_on = False)

	# Save figures
	plt.savefig('fig1.png')
	plt.savefig('fig1.eps')
	plt.savefig('fig1.pdf')

	# (call plt.show() here to display the figure in an interactive window)

def integrand(v,W,D):
	"""
	Return the Doppler-weighted Lorentzian, Lor(v, W, D) * Gau(v, 1), elementwise.
	"""
	lorentzian = Lor(v, W, D)
	gaussian = Gau(v, 1)
	return lorentzian * gaussian

def TenNinetyWidth(v,Lineshape):
	"""
	Return the 10/90 width of a lineshape sampled on the abscissa v.

	The cumulative sum of the lineshape is normalised to its maximum. The width is
	the distance between the first and last samples whose normalised cumulative value
	lies strictly between 0.1 and 0.9. This is similar to a knife-edge measurement of
	a laser beam width.

	Parameters
	----------
	v : array_like
		Abscissa values; must be in increasing order.
	Lineshape : array_like
		Lineshape values on the same grid as v. Integer input is accepted.

	Returns
	-------
	float
		v_last - v_first over the 10%-90% window.
	"""
	# float copy: the in-place division below fails on integer arrays, and the
	# caller's array must not be modified
	LS = np.array(Lineshape, dtype=float)
	LS /= LS.max()

	# Cumulative sum, normalised so it ends at 1
	CS = np.cumsum(LS)
	CS /= CS.max()

	# keep samples strictly between the 10% and 90% levels
	v_above = np.asarray(v)[(CS > 0.1) & (CS < 0.9)]
	FW = v_above[-1] - v_above[0]
	return FW
	
def CoMass(v,Lineshape):
	"""
	Return the centre-of-mass position of a lineshape sampled on the abscissa v.

	The centre is taken as the first v at which the normalised cumulative sum of the
	lineshape exceeds one half.

	Parameters
	----------
	v : array_like
		Abscissa values; must be in increasing order.
	Lineshape : array_like
		Lineshape values on the same grid as v. Integer input is accepted.
	"""
	# float dtype so the in-place normalisation also works for integer input
	CS = np.cumsum(np.asarray(Lineshape, dtype=float))
	CS /= CS.max()

	# first sample past the halfway point of the cumulative distribution
	v_above = np.asarray(v)[CS > 0.5]

	return v_above[0]

def FWHM_analysis():
	"""
	****
	Figure 3
	****

	Look at the evolution of the combined Lorentz/Gauss distribution as a function of detuning.

	For 1000 detunings between 0 and 5 (in units of ku), compute the 10/90 width and the
	centre-of-mass velocity of integrand(v, W, detuning). Plot them in the top and bottom
	panels, mark the detunings Ds used in figure 2 with dashed lines, save the figure as
	FWHM_CoM.png/.pdf/.eps, and show it.
	"""

	# abscissa axis and detuning grid
	vs = np.linspace(-10,10,25001)
	dets = np.linspace(0,5,1000)

	# initialise arrays
	FWs = np.zeros_like(dets)
	CMs = np.zeros_like(dets)

	# calculate width and CoM for each detuning
	for i, det in enumerate(dets):
		LS = integrand(vs,W,det)

		FWs[i] = TenNinetyWidth(vs,LS)
		CMs[i] = CoMass(vs,LS)

	# set up figure
	fig = plt.figure(3,facecolor='w',figsize=(5,4))
	fig.subplots_adjust(left=0.16,right=0.95,top=0.9,bottom=0.12)

	# top panel: 10/90 width
	ax1 = fig.add_subplot(211)
	ax1.plot(dets,FWs,color=d_blue,lw=2)

	# bottom panel: centre-of-mass velocity
	ax2 = fig.add_subplot(212, sharex=ax1)
	ax2.plot(dets,CMs,color=d_red,lw=2)

	# raw string: '\D' is not a valid escape sequence and warns on Python 3
	ax2.set_xlabel(r'Detuning ($\Delta / ku$)')
	ax2.set_ylabel('Centre-of-mass\nvelocity ($v_z/u$)')
	ax1.set_ylabel('10/90 width ($v_z/u$)')

	# mark the figure-2 panel detunings; labels sit above the top panel (clip_on=False)
	marker_col = np.array(d_black)*2
	labels = ['(a)', '(b)', '(c)', '(d)']
	for D, lab in zip(Ds, labels):
		for ax in fig.axes:
			ax.axvline(D, color=marker_col, linestyle='dashed')
		ax1.text(D, 2.17, lab, color=marker_col, size=13, clip_on=False, ha='center')

	# save figure
	plt.savefig('FWHM_CoM.png')
	plt.savefig('FWHM_CoM.pdf')
	plt.savefig('FWHM_CoM.eps')

	# Show figure in interactive window
	plt.show()

if __name__ == '__main__':
	# Build Figure 1, Figure 2 and Figure 3 in turn; FWHM_analysis (Figure 3) ends with plt.show()
	for make_figure in (fig1, main, FWHM_analysis):
		make_figure()