"""Stefan Spence 02/11/18
for Python 3.5.2
Collection of useful fitting and graphing functions.
04/12/18
Added new functions to aid fitting and renamed graphing functions for contour plots
07/12/18
Added linear and arbitrary polynomial functions in fit
06/02/19
start adding functions to do batch fitting of CCD images.
24/02/19
running the script initiates the gaussian fitter
26/02/19
the gaussian fitter was moved to gaussianFitter.py
27/02/19
correct a missing minus sign in the fit.gauss() function
27/05/19
Allow labels in residual plot
03/03/20
Add in DU colour scheme
02/10/20
Add in arguments for fit report and extra functions
04/10/22 load using pandas"""
import numpy as np
import sys
import os
import time
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm
from scipy.optimize import curve_fit
from scipy.special import erfc, jv, spherical_jn, spherical_yn, legendre, factorial, genlaguerre
from scipy.signal import tukey
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
try:
    import qutip
except ImportError: pass

# ---- physical constants (SI units) ----
c = 2.99792458e8         # speed of light [m s^-1]
eps0 = 8.85419e-12       # vacuum permittivity [m^-3 kg^-1 s^4 A^2]
mu0 = 1.256637062e-6     # vacuum permeability [kg m s^-2 A^-2]
h = 6.6260700e-34        # Planck constant [m^2 kg s^-1]
hbar = 1.0545718e-34     # reduced Planck constant [m^2 kg s^-1]
a0 = 5.29177e-11         # Bohr radius [m]
e = 1.6021766208e-19     # magnitude of the electron charge [C]
me = 9.10938356e-31      # electron mass [kg]
kB = 1.38064852e-23      # Boltzmann constant [m^2 kg s^-2 K^-1]
amu = 1.6605390e-27      # atomic mass unit [kg]
uB = 9.2740100783e-24    # Bohr magneton [J T^-1]
# ---- derived atomic units ----
Eh = me * e**4 /(4. *np.pi *eps0 *hbar)**2   # Hartree energy [J]
au = e**2 * a0**2 / Eh                        # atomic unit of polarisability

# ===== Durham university official colour scheme =====
# -- primary palette --
purple = '#68246D'
yellow = '#FFD53A'
cyan = '#00AEEF'
red = '#BE1E2D'
gold = '#AFA961'
# -- secondary palette --
heather = '#CBA8B1'
stone = '#DACDA2'
sky = '#A5C8D0'
cedar = '#B6AAA7'
concrete = '#B3BDB1'
# -- text ("type") colours --
ink = '#002A41'
black = '#333132'

# ===== previous (old) Durham colour scheme =====
DUblack = '#231F20'
DUlight_purple = '#D8ACE0'
DUpalatinate = '#7E317B'
DUsea_blue = '#006388'
DUcherry_red = '#AA2B4A'
DUpastel_blue = '#CFDAD1'
DUsky_blue = '#C4E5FA'
DUolive = '#9FA161'
DUbeige = '#E8E391'
DUfull_pink = '#C43B8E'
DUwindow_blue = '#91B8BD'
DUwarm_grey = '#968E85'

class defaultColours:
    """Cycle through the Durham university official colour scheme.

    ind: starting value of the internal counter. Each argument-less call to
    c() increments the counter before picking, so the first such call
    returns colours[ind + 1]."""
    def __init__(self,ind=0):
        self.i = ind
        self.colours = [gold, purple, yellow, cyan, red]        # primary colours
        self.secondary = [heather, stone, sky, cedar, concrete]  # secondary colours
        self.type = [ink, black]                                 # text colours

    def c(self, ind=None):
        """Return a primary colour.

        With no argument, advance the internal counter and return the next
        colour in the cycle. With an integer ind, return colours[ind]
        (wrapping round the palette) without changing the counter."""
        if ind is None:
            self.i += 1
            ind = self.i
        return self.colours[ind % len(self.colours)]
        
def weighted_std(v):
    """Conservative uncertainty for a group of measurements.

    v must provide columns/attributes y (values) and yerr (their errors),
    e.g. a pandas DataFrame group.
    Returns the larger of:
      - the inverse-variance weighted standard deviation of y, and
      - the standard error of the weighted mean, 1/sqrt(sum(1/yerr^2)).
    A single measurement returns its own error; an empty group returns nan."""
    y = np.asarray(v.y, dtype=float)
    yerr = np.asarray(v.yerr, dtype=float)
    if y.size == 0:
        return np.nan
    if y.size == 1:
        return yerr[0]
    # inverse-variance weights: precise points count more (the previous
    # version weighted by yerr itself, which favoured the noisiest points)
    weights = 1. / yerr**2
    average = np.average(y, weights=weights, axis=0)
    variance = np.dot(weights, (y - average)**2) / weights.sum()
    return max(np.sqrt(variance), 1/np.sqrt(np.sum(weights)))
    
def getDAQPowers(fname, hh, col=1):
    """Extract data from the DAQ graph and match the times with shot numbers.

    fname: comma-delimited DAQ file. Column 0 is the file ID, column col is the
           voltage and the last column is a timestamp passed to time.gmtime.
    hh:    object whose stats dict has 'User variable', 'Start file #' and
           'End file #' entries, one per histogram.
    returns: 
         - Mean DAQ Voltage from column col for each histogram in hh
         - Start time of each histogram ('HH:MM:SS')
         - End time of each histogram ('HH:MM:SS')
         - DAQ data extracted from csv file (after groupArrSort)
    Raises IndexError if a histogram's start/end file number is beyond the
    last file ID in the DAQ data."""
    data = np.loadtxt(fname, delimiter=',')
    data = groupArrSort(*data.T)
    fid, p, times = data[:,0], data[:,col], data[:,-1]
    nhists = len(hh.stats['User variable'])
    volts = np.zeros(nhists) # DAQ voltage
    # 'U8' holds a full 'HH:MM:SS' string; dtype=str would give '<U1' and
    # silently truncate every time stamp to its first character
    start = np.zeros(nhists, dtype='U8') # start time
    end = np.zeros(nhists, dtype='U8') # end time
    for i in range(nhists):
        # first DAQ rows recorded after the histogram's start/end file
        j = np.where(fid > hh.stats['Start file #'][i])[0][0]
        k = np.where(fid > hh.stats['End file #'][i])[0][0]
        # drop 2 samples at each edge of the window
        volts[i] = np.mean(p[j+2:k-2])
        start[i] = time.strftime('%H:%M:%S', time.gmtime(times[j]))
        end[i] = time.strftime('%H:%M:%S', time.gmtime(times[k]))
    return volts, start, end, data

class fit:
    """Collection of common functions for theoretical fits.

    Data are held in x, y and yerr. Every model method takes the independent
    variable first and then its parameters, so it can be passed directly to
    getBestFit/applyFit (scipy.optimize.curve_fit). Calling a model method
    also sets self.args to its parameter labels, which report() uses."""
    def __init__(self, xdat=0, ydat=0, erry=None, param=None, bf=None, labels=None, w=80):
        self.x    = xdat   # independent variable
        self.y    = ydat   # measured dependent variable
        self.yerr = erry   # errors in dependent variable
        self.p0   = param  # guess of parameters for fit
        self.ps   = param  # best fit parameters
        self.perrs = None  # error on best fit parameters
        # list of labels for parameters (copied so instances never share one list)
        self.args = list(labels) if labels is not None else []
        self.bff  = bf     # function used for best fit

        self.w = w # trap frequency in kHz, used by thermalRabi
        
    def load(self, fname, ykey='Loading probability', sort=True):
        """Load data from a measure file using pandas.
        The first 2 rows of the csv are skipped. x is the 'User variable'
        column, y is the ykey column and yerr is the 'Error in '+ykey column.
        If sort is True the three columns are passed through groupArrSort."""
        df = pd.read_csv(fname, skiprows=2)
        if sort:
            self.x, self.y, self.yerr = groupArrSort(df['User variable'], 
                                             df[ykey], df['Error in '+ykey]).T
        else: 
            self.x, self.y, self.yerr = (df['User variable'], 
                                             df[ykey], df['Error in '+ykey])
         
    def estGaussParam(self):
        """Guess at the amplitude A, centre x0, width wx, and offset y0 of a 
        Gaussian. The guess is stored in self.p0 and the labels in self.args.
        The width guess is twice the distance from the peak to the first point
        that has dropped to half of the peak height above the minimum. The
        search goes towards larger index first, and towards smaller index
        if the peak is the last point.
        Adapted from Vincent Brooks' function_fits.py"""
        x = np.asarray(self.x, dtype=float)
        y = np.asarray(self.y, dtype=float)
        imax = int(np.argmax(y))
        ymin = np.min(y)
        half = (y[imax] - ymin) / 2.
        below = np.flatnonzero(y - ymin <= half)  # indices at or below half max
        right = below[below > imax]
        left = below[below < imax]
        if right.size:
            hwhm = x[right[0]] - x[imax]
        elif left.size:   # peak at the end of the data: look the other way
            hwhm = x[imax] - x[left[-1]]
        else:             # single point: no width information
            hwhm = 0.
        self.args = ['Amplitude', 'centre', '1/e^2 width', 'offset']
        # the FWHM (2*hwhm) is used as a rough guess for the 1/e^2 width
        self.p0 = [y[imax] - ymin, x[imax], 2 * hwhm, ymin]
        
    def estHyperbolaParam(self):
        """Guess at the beam waist w0 at position z0, and weighting zR
        (Rayleigh range for Gaussian beam propagation). Stores
        [min(y), x at min(y), pi*min(y)**2] in self.p0."""
        self.args = ['waist', 'centre', 'range']
        self.p0 = [np.min(self.y), self.x[np.argmin(self.y)], np.pi*np.min(self.y)**2]
        
    def linear(self, x, m, c):
        """Straight-line linear fit with gradient m and intercept c."""
        self.args = ['gradient', 'intercept']
        return m*x + c
        
    def polynomial(self, x, coeff):
        """A polynomial with the supplied coefficients starting from the 
        highest order in x. For example, coeff = [1,2,3] gives the quadratic
        x^2 + 2x + 3"""
        p = np.poly1d(coeff)
        return p(x)
                        
    def gauss(self, x, A, x0, sig):
        """Gaussian centred at x0 with amplitude A and standard deviation sig
        When A = 1 this is the normal distribution. Note that Gaussian beam
        propagation uses a 1/e^2 width(radius) wx = 2*sig."""
        self.args = ['Amplitude', 'Centre', 'Standard deviation']
        return A* np.exp(-(x-x0)**2 /2. /sig**2) /sig /np.sqrt(2*np.pi)
        
    def eDecay(self, x, A, t0):
        """Exponential decay with amplitude A and time constant t0"""
        self.args = ['Amplitude', '1/e time']
        return A * np.exp(-x/t0)

    @staticmethod
    def _lorentzian(x, A, x0, gam):
        """Area-normalised Lorentzian with FWHM gam centred at x0, scaled by A
        (shared by all the Lorentzian models)."""
        return A * gam / ((x - x0)**2 + gam**2/4.) / 2. /np.pi

    def Lorentz(self, x, A, x0, gam):
        """Lorentzian function centred at x0 with width gam, amplitude A"""
        self.args = ['Amplitude', 'Centre', 'Width']
        return self._lorentzian(x, A, x0, gam)
        
    def offLorentz(self, x, A, x0, gam, y0):
        """Lorentzian function centred at x0 with width gam, amplitude A and 
        offset y0"""
        self.args = ['Amplitude', 'Centre', 'Width', 'Offset']
        return self._lorentzian(x, A, x0, gam) + y0
        
    def DoubleLorentz(self, x, A, x0, gam,A2, x02, gam2, y0):
        """Sum of two Lorentzians (amplitude, centre, width each) with a
        common offset y0"""
        self.args = ['Amp1', 'Centre1', 'Width1','Amp2', 'Centre2', 'Width2', 'Offset']
        return self._lorentzian(x, A, x0, gam) + self._lorentzian(x, A2, x02, gam2) + y0
        
    def DoubleLorentzSameWidthAmp(self, x, A, x0, gam, x02, y0):
        """Sum of two Lorentzians sharing amplitude A and width gam, centred
        at x0 and x02, with common offset y0"""
        # one label per fitted parameter so report() lines up with the values
        self.args = ['Amplitude', 'Centre1', 'Width', 'Centre2', 'Offset']
        return self._lorentzian(x, A, x0, gam) + self._lorentzian(x, A, x02, gam) + y0
    
    def TripleLorentz(self, x, A, x0, gam,A2, x02, gam2,A3, x03, gam3, y0):
        """Sum of three Lorentzians (amplitude, centre, width each) with a
        common offset y0"""
        self.args = ['Amp1', 'Centre1', 'Width1','Amp2', 'Centre2', 'Width2','Amp3', 'Centre3', 'Width3', 'Offset']
        return (self._lorentzian(x, A, x0, gam) + self._lorentzian(x, A2, x02, gam2)
                + self._lorentzian(x, A3, x03, gam3) + y0)
    
    def QuadLorentz(self, x, A, x0, gam, A2, x02, gam2, A3, x03, gam3, A4, x04, gam4, y0):
        """Sum of four Lorentzians (amplitude, centre, width each) with a
        common offset y0"""
        self.args = ['Amp1', 'Centre1', 'Width1','Amp2', 'Centre2', 'Width2','Amp3', 'Centre3', 'Width3','Amp4', 'Centre4', 'Width4', 'Offset']
        return (self._lorentzian(x, A, x0, gam) + self._lorentzian(x, A2, x02, gam2)
                + self._lorentzian(x, A3, x03, gam3) + self._lorentzian(x, A4, x04, gam4) + y0)
    
    def LorentzianSidebands(self, x, xc, AC, WC, WS, x0, A1, A2):
        """3 Lorentzians, where the outer ones have the same width and separation."""
        self.args = ['Centre', 'Carrier Amp', 'Carrier width', 'Sideband width', 'Sideband detuning', 'Sideband Amp1', 'Sideband Amp2']    
        return AC*WC/((x-xc)**2+WC**2/4) + A1* WS/((x-xc-x0)**2+(WS/2)**2) + A2*WS/((x-xc+x0)**2+(WS/2)**2)
    
    def offGaussBeamFit(self, x, A, x0, wx, y0):
        """Gaussian function centred at x0 with amplitude A, 1/e^2 width wx
        and background offset y0"""
        self.args = ['Amplitude', 'centre', '1/e^2 width', 'offset']
        return A * np.exp( -2 * (x-x0)**2 /wx**2) + y0
    
    def offGauss(self, x, A, x0, FWHM, y0):
        """Gaussian function centred at x0 with amplitude A, Full width half maximum FWHM
        and background offset y0"""
        self.args = ['Amplitude', 'centre', 'FWHM', 'offset']
        sigma=FWHM/(2*np.sqrt(2*np.log(2)))
        return A * np.exp( - (x-x0)**2 /(2*(sigma**2))) + y0
    
    def sincSquare(self, x, A, x0, w, y0):
        """Squared sinc function, expected for FT of square pulse:
        A * sinc^2((x - x0)/w) + y0 with the normalised sinc sin(pi u)/(pi u),
        so the first zeros are at x0 +/- w."""
        self.args = ['Amplitude', 'Centre', 'Width', 'Offset']
        return A * np.sinc((x - x0) / w)**2 + y0
        
    def doubleGauss(self, x, A1, x1, FWHM1, A2, x2, FWHM2, y0):
        """Sum of two Gaussians (amplitude, centre, full width at half maximum
        each) with a common background offset y0"""
        self.args = ['Amplitude1', 'centre1', 'FWHM1', 'Amplitude2', 'centre2', 'FWHM2', 'offset']
        sigma1=FWHM1/(2*np.sqrt(2*np.log(2)))
        sigma2=FWHM2/(2*np.sqrt(2*np.log(2)))
        return A1 * np.exp( - (x-x1)**2 /(2*(sigma1**2))) +A2 * np.exp( - (x-x2)**2 /(2*(sigma2**2)))+ y0
        
    def offeDecay(self, x, A, t0,y0):
        """Exponential decay with amplitude A, time constant t0 and offset y0"""
        self.args = ['Amplitude', '1/e time','offset']
        return A * np.exp(-x/t0) + y0

    def OPphotons(self, x, p, A, y0):
        """Number of scattered OP photons, assuming x is OP time*scattering rate"""
        self.args = ['Polarisation purity', 'Amplitude', 'offset']
        return A * (1 - np.power(0.53333333/p**2, x)) + y0
        
    def airyDisk(self, x, I0, k, x0, y0):
        """Intensity of Fraunhofer diffraction from circular aperture is an 
        Airy disk with peak intensity I0. The constant k incorporates the 
        aperture diameter and wavelength of the light. y0 takes account of
        background and x0 is offset. At x == x0 the limit 2 J1(u)/u -> 1
        is used instead of producing nan."""
        self.args = ['Intensity', 'Length', 'Centre', 'Offset']
        u = k * (np.asarray(x, dtype=float) - x0)
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = np.where(u == 0, 1., 2 * jv(1, u) / u)
        return I0 * ratio**2 + y0

    def gauss2d(self, xy, A, wx, x0, wy, y0, bg):
        """Gaussian function centred at (x0, y0) with amplitude A, 
        1/e^2 width wx in the x direction and wy in the y direction, 
        and offset bg note that x and y should be given by a meshgrid"""
        self.args = ['Amplitude', 'X width', 'X centre', 'Y width', 'Y centre', 'Offset']
        x, y = xy       # unpack position
        return A * np.exp(-2*(x-x0)**2 /wx**2 - 2*(y-y0)**2 /wy**2) + bg
        
    def topHat(self, x, x0, tau, a, d, y0):
        """Flat top of height a + y0 for x0 - d/2 <= x <= x0 + d/2 (edges
        included), with exponential edges of decay length tau outside that
        range, decaying towards y0."""
        self.args=['x0', 'tau', 'a', 'd', 'y0']
        x = np.asarray(x, dtype=float)
        lo_edge, hi_edge = x0 - d/2, x0 + d/2
        y = np.full(x.shape, a + y0, dtype=float)
        above = x > hi_edge
        below = x < lo_edge
        y[above] = a*np.exp(-(x[above] - hi_edge) / tau) + y0
        y[below] = a*np.exp((x[below] - lo_edge) / tau) + y0
        return y
        
    def knifeEdge(self, x, x0, ptot, wx, bg):
        """Knife-edge measurement: the measured power when the knife is moved in
        the positive x direction for total beam power ptot, beam waist position
        x0, 1/e^2 beam radius wx, and background power bg."""
        self.args = ['Centre', 'Amplitude', 'Width', 'Offset']
        # NOTE: 1 - erfc(z) == erf(z), so without bg this spans -ptot/2..ptot/2;
        # the fitted bg therefore also absorbs a ptot/2 shift
        return ptot/2. * ( 1 - erfc(np.sqrt(2) * (x - x0) / wx) ) + bg
        
    def wz1(self, x, w0, z0, l):
        """Gaussian beam width propagation when M^2 = 1 for beam waist w0, 
        at position z0, for light of wavelength l"""
        self.args = ['Beam waist', 'Centre', 'Wavelength']
        return w0 * np.sqrt(1 + (x - z0)**2 / (np.pi*w0**2/l)**2)
        
    def wz(self, x, w0, z0, zR):
        """Gaussian beam propagation when the M^2 value is not known for beam
        waist w0 at position z0, with Rayleigh range zR = pi w0^2 / lambda / M^2"""
        self.args = ['Beam waist', 'Centre', 'Rayleigh range']
        return w0 * np.sqrt(1 + (x - z0)**2 / zR**2)
        
    def doublePoisson(self, x, mu1, mu2, A1, A2):
        """Two Poisson distributions with different means (mu1, mu2), as might 
        be seen in single atom image detection. The peaks have amplitudes A1, A2"""
        self.args = ['Mean1', 'Mean2', 'Amplitude1', 'Amplitude2']
        return (A1 * np.power(mu1,x) * np.exp(-mu1) + A2 * np.power(mu2,x) * np.exp(-mu2)) / factorial(x)

    def dampedSine(self, x, A, omega, tau, phi, c):
        """A damped sine with amplitude A, frequency omega in rad/s, phase
        shift phi in rad, decay time tau in s and offset c."""
        self.args = ['Amplitude', 'Frequency (rad/s)', 'Decay (s)', 'Phase', 'Offset']
        return A * np.sin(x * omega + phi) * np.exp(-x/tau) + c

    def dampedSinekHz(self, x, A, omega, tau, phi, c):
        """Assume x is time in us. Damped sine with amplitude A, 
        frequency omega in kHz, phase shift phi in rad, 
        decay time tau in us and offset c."""
        self.args = ['Amplitude', 'Frequency (kHz)', 'Decay (us)', 'Phase', 'Offset']
        return A * np.sin(x * 2*np.pi*1e-3 * omega + phi) * np.exp(-x/tau) + c
    
    def dampedSineGauss(self, x, A, omega, stdev, phi, c):
        """Assume x is time in us. Damped sine with amplitude A, 
        frequency omega in kHz, phase shift phi in rad, 
        Gaussian decay envelope set by the frequency standard deviation
        stdev in Hz, and offset c."""
        self.args = ['Amplitude', 'Frequency (kHz)', 'Freq stdev (Hz)', 'Phase', 'Offset']
        return A * np.sin(x * 2*np.pi*1e-3 * omega + phi) * np.exp(-0.5*(x*2*np.pi*stdev*1e-6)**2) + c
    
    def RamseyDecay(self, x, A, delta, T2, phi, c):
        """Assume x is time between pi/2 pulses in us. 
        Equation 5.20 Kuhr thesis. Cosine with amplitude A, 
        detuning from resonance delta in kHz, phase shift phi in rad (could include a phase shift term but have neglected here), 
        inhomogenous dephasing time (due to thermal motion) T2 in us and offset c."""
        self.args = ['Amplitude', 'Detuning (kHz)', 'T*2', 'Phase', 'Offset']
        return A *np.cos(x * 2*np.pi*1e-3 * delta + phi)* (1. + 2.79*(x/T2)**2)**(-3./4.) + c
    
    def SpinEchoDecay(self, x, A, stdev,  c):
        """
        Fit decay in visibility of spin echo. Using eq 5.32 from Stefan Kuhr PhD thesis.
        X is the time between pi/2 pulse and pi pulse (T/2 or t_pi) in ms. Amp is the contrast at zero time delay
        (should be between zero and one) stdev is the standard deviation of the transition frequency in Hz (sigma/2pi)
        and c is the offset
        """
        self.args = ['Amplitude', 'Freq stdev (Hz)', 'Offset']
        return A  * np.exp(-0.5*(x*(2*np.pi*stdev*1e-3))**2) + c
        
    def undampedSinekHz(self, x, A, omega, phi, c):
        """Assume x is time in us. Undamped sine with amplitude A, 
        frequency omega in kHz, phase shift phi in rad,  offset c."""
        self.args = ['Amplitude', 'Frequency (kHz)', 'Phase', 'Offset']
        return A * np.sin(x * 2*np.pi*1e-3 * omega + phi) + c
    
    def thermalRabi(self, x, T, eta, OR):
        """Use a thermal distribution of Rabi frequencies to fit Rabi oscillations.
        Assume x is time in us and T is in uK.
        Set the trap frequency as the property w (in kHz), defaults to 80 kHz.
        Assumes dephasing is due to thermal distribution of states
        (truncated at 100 motional levels)."""
        self.args = ['Temperature (uK)', 'Lamb-Dicke', 'Rabi Frequency (kHz)']
        # partition function of the truncated Boltzmann distribution (computed once)
        Z = sum(np.exp(-i*h*1e3*self.w / kB/T/1e-6) for i in range(100))
        return np.sum([np.exp(-n*h*1e3*self.w / kB/T/1e-6) / Z/2 
            * (1 - np.cos(x * OR*2*np.pi*1e-3 * np.exp(-eta**2/2)*genlaguerre(n,0)(eta**2))) for n in range(100)], axis=0)

    def thermalRabi_nbar(self, x, OR, nbar, eta):
        """Use a thermal distribution of Rabi frequencies to fit Rabi oscillations.
        Assume x is time in us.
        Fits the mean motional number n_bar = 1/(e^{hbar*omega/(k*T)}-1) eq 3.21 Kaufman thesis
        Assumes dephasing is due to thermal distribution of states."""
        self.args = ['Rabi Frequency (kHz)', 'Mean motional number', 'Lamb-Dicke']
        boltz_factor = nbar/(1. + nbar) #rearrange eq 3.21 to obtain e^{-hbar*omega/(k*T)}
        Z = sum(np.power(boltz_factor, range(80)))  # normalisation over 80 levels
        return np.sum([boltz_factor**n / Z/2 
            * (1 - np.cos(x * OR*2*np.pi*1e-3 * np.exp(-eta**2/2)*genlaguerre(n,0)(eta**2))) for n in range(80)], axis=0)
    
    def thermalDetunedRabi(self, x, OR, nbar, delta, eta, eta2=0):
        """Use a thermal distribution of Rabi frequencies to fit Rabi oscillations.
        Assume x is time in us. delta is detuning from resonance in same units as OR
        Include coupling to 2 radial axes using e^(-ik.r) = e^(-ikx)e^(-iky)"""
        self.args = ['Rabi Frequency (kHz)', 'Mean motional number', 'Detuning (kHz)',
                     'Lamb-Dicke_x', 'Lamb-Dicke_y']
        boltz_factor = nbar/(1. + nbar) #rearrange eq 3.21 to obtain e^{-hbar*omega/(k*T)}
        Z = sum(np.power(boltz_factor, range(80)))  # normalisation over 80 levels
        P = 0
        for n in range(80):
            ORn = OR * np.exp(-eta**2/2)*genlaguerre(n,0)(eta**2) * np.exp(-eta2**2/2)*genlaguerre(n,0)(eta2**2)
            Oeff = (ORn**2+delta**2)**0.5   # generalised Rabi frequency
            Pn = boltz_factor**n / Z
            P += Pn * ORn**2/Oeff**2 * np.sin(2*np.pi*Oeff*1e-3*x/2)**2
        return P
    
    def thermalRabi_Sideband(self, x, OR, nbar, eta):
        """Use a thermal distribution of Rabi frequencies to fit Rabi oscillations
        on the first sideband. Assume x is time in us.
        eta is the Lamb-Dicke parameter (fitted).
        Assumes dephasing is due to thermal distribution of states."""
        self.args = ['Rabi Frequency (kHz)', 'Mean Motional Number', 'LD parameter']
        boltz_factor = nbar/(1. + nbar) #rearrange eq 3.21 to obtain e^{-hbar*omega/(k*T)}
        Z = sum(np.power(boltz_factor, range(100)))  # normalisation over 100 levels
        return np.sum([boltz_factor**n / Z/2 
            * (1 - np.cos(x * OR*2*np.pi*1e-3 * np.exp(-eta**2/2)*genlaguerre(n,1)(eta**2)*eta/(n+1)**0.5)) for n in range(100)], axis=0)
            
    def RamanHamiltonians(self, w, eta, Or, det, T, N_ph, nbar=None):
        """Build the qutip operators for a two-level atom in a harmonic trap
        truncated at N_ph Fock states.
        w, Or, det are multiplied by 2*pi (angular frequencies); see the
        RabiSimulation docstring for their units. T is used (when nbar is
        None) to get the thermal occupation; otherwise nbar sets it directly.
        Returns (Hb, HR, psi0): bare Hamiltonian, Raman coupling Hamiltonian
        and initial thermal density matrix with the spin in fock_dm(2,0)."""
        w *= 2*np.pi # angular frequency
        Or *= 2*np.pi
        det *= 2*np.pi
        # phonon operators
        a = qutip.tensor(qutip.destroy(N_ph),qutip.qeye(2)) #anhilation operator
        #atomic operators here, #g1->0,g2->1,e->2
        sm_g2g1 = qutip.tensor(qutip.qeye(N_ph), qutip.sigmap()) # spin operator coupling the two levels
        #Hamiltonians
        Hb = det / 2 * qutip.tensor(qutip.qeye(N_ph), qutip.sigmaz()) # bare Hamiltonian, internal
        for i in range(N_ph): Hb += w*i*qutip.tensor(qutip.fock_dm(N_ph,i), qutip.qeye(2)) # number states
        HR = Or / 2 * sm_g2g1 * (-1j*eta*(a + a.dag())).expm() # Raman Hamiltonian
        HR += HR.dag() # add Hermitian conjugate
        if nbar is None:   # nbar=0 is a valid request for the motional ground state
            psi0 = qutip.tensor(qutip.thermal_dm(N_ph, qutip.n_thermal(w, T*1e-9*kB/hbar), method='operator'), qutip.fock_dm(2,0)) 
        else:
            psi0 = qutip.tensor(qutip.thermal_dm(N_ph, nbar, method='operator'), qutip.fock_dm(2,0)) 
        return Hb, HR, psi0
            
    def RabiSimulation(self, tlist=np.linspace(0,0.5,500), w=140, eta=0.22, Or=31, det=140, T=10, N_ph=20):
        """Master-equation simulation of the Raman Hamiltonian, returns
        (tlist*1e3, P(spin up)).
        w:    Harmonic oscillator frequency (kHz)
        eta:  Lamb-Dicke parameter
        Or:   Rabi frequency for the carrier transition of the raman beams (kHz)
        det:  detuning from 2-photon resonance (kHz) (+ve for blue sideband)
        T:    Temperature (uK)
        N_ph: Hilbert space dimension for the external motion"""
        Hb, HR, psi0 = self.RamanHamiltonians(w, eta, Or, det, T, N_ph)
        result = qutip.mesolve([Hb, HR], psi0, tlist, e_ops=qutip.tensor(qutip.qeye(N_ph), qutip.fock_dm(2,1)), progress_bar=qutip.ui.EnhancedTextProgressBar())
        return tlist*1e3, result.expect[0]

    def RabiSimulationMC(self, tlist=np.linspace(0,0.5,500), w=140, eta=0.22, Or=31, det=140, T=10, N_ph=20):
        """Monte Carlo Simulation of Raman Hamiltonian (same arguments as
        RabiSimulation)."""
        Hb, HR, psi0 = self.RamanHamiltonians(w, eta, Or, det, T, N_ph)
        # thermal occupation probabilities, using the same angular frequency
        # (2*pi*w) and temperature conversion as RamanHamiltonians
        th = np.diag(qutip.thermal_dm(N_ph, qutip.n_thermal(2*np.pi*w, T*1e-9*kB/hbar), method='operator'))
        # NOTE(review): this is a superposition weighted by the probabilities
        # th[i] (not sqrt(th[i])) - confirm this initial state is intended
        psi0 = sum(qutip.tensor(qutip.fock(N_ph,i)*th[i], qutip.fock(2,0)) for i in range(N_ph))
        result = qutip.mcsolve([Hb, HR], psi0, tlist, progress_bar=qutip.ui.EnhancedTextProgressBar())
        return tlist*1e3, qutip.expect(qutip.tensor(qutip.qeye(N_ph), qutip.fock_dm(2,1)), result.states)

    def RabiSimulationGauss(self, tlist=np.linspace(0,0.5,500), w=140, eta=0.22, Or=31, det=140, T=10, N_ph=20, dur=0.5):
        """Gaussian pulse shape for Raman Hamiltonian, returns P(spin up)
        w:    Harmonic oscillator frequency (kHz)
        eta:  Lamb-Dicke parameter
        Or:   Rabi frequency for the carrier transition of the raman beams (kHz)
        det:  detuning from 2-photon resonance (kHz) (+ve for blue sideband)
        T:    Temperature (uK)
        N_ph: Hilbert space dimension for the external motion
        dur:  equivalent duration of a square pulse (ms)"""
        Hb, HR, psi0 = self.RamanHamiltonians(w, eta, Or, det, T, N_ph)
        HRc = np.exp(- np.pi*((tlist-np.max(tlist)/2)/dur)**2)
        result = qutip.mesolve([Hb, [HR, HRc]], psi0, tlist, e_ops=qutip.tensor(qutip.qeye(N_ph), qutip.fock_dm(2,1)), progress_bar=qutip.ui.EnhancedTextProgressBar())
        return tlist*1e3, result.expect[0]

    def RabiSimulationTukey(self, tlist=np.linspace(0,0.5,500), w=140, eta=0.22, Or=31, det=140, T=10, N_ph=20, a=2/3):
        """Tukey pulse shape for Raman Hamiltonian, returns P(spin up)
        a:    curvature of Tukey edges
        The time axis is stretched by 1/(1-a/2) so the pulse area equals
        that of a square pulse; neither tlist nor the default is modified."""
        Hb, HR, psi0 = self.RamanHamiltonians(w, eta, Or, det, T, N_ph)
        HRc = tukey(len(tlist), a)
        # new array: an in-place '/=' would rescale the shared default tlist
        # (and the caller's array) again on every call
        tlist = np.asarray(tlist) / (1-a/2) # equal area to square pulse
        result = qutip.mesolve([Hb, [HR, HRc]], psi0, tlist, e_ops=qutip.tensor(qutip.qeye(N_ph), qutip.fock_dm(2,1)), progress_bar=qutip.ui.EnhancedTextProgressBar())
        return tlist*1e3, result.expect[0]  
    
    def Randomised_Benchmarking(self, x, dif, eg):
        """Clifford gate error from https://journals.aps.org/prl/pdf/10.1103/PhysRevLett.121.240501"""
        self.args = ['depolarisation', 'error per gate']
        return 0.5 + 0.5*(1-dif) * (1 - 2*eg)**x

    def getBestFit(self, fn, **kwargs):
        """Use scipy.optimize.curve_fit to get the best fit to the supplied data
        using the supplied function fn (extra kwargs go to curve_fit).
        Stores fn in self.bff, the parameters in self.ps and the square roots
        of the covariance diagonal in self.perrs.
        Returns tuple of best fit parameters and their errors"""
        self.bff = fn
        popt, pcov = curve_fit(fn, self.x, self.y, p0=self.p0, sigma=self.yerr,**kwargs)
        self.ps = popt
        self.perrs = np.sqrt(np.diag(pcov))
        return self.ps, self.perrs

    def applyFit(self, fn, **kwargs):
        """Get a best fit and then use it to make a curve (x,y) of 1000 points
        spanning the range of the data."""
        self.getBestFit(fn, **kwargs)
        newx = np.linspace(min(self.x), max(self.x), 1000)
        return newx, fn(newx, *self.ps)

    def report(self):
        """Report on the best fit: parameters and chisquared value.
        The reduced chi-squared divides by the degrees of freedom
        (number of points minus number of fitted parameters)."""
        d = data(self.ps, self.perrs)
        vals, errs = d.adjustSF()
        try:
            chisqval = np.sum((self.y - self.bff(self.x, *self.ps))**2 / self.yerr**2)
            dof = len(self.x) - np.size(self.ps)
            msg =  "Chi-squared value          --- %.5g\n"%chisqval
            msg += "Reduced chi-squared value  --- %.5g\n"%(chisqval/dof if dof > 0 else float('nan'))
        except Exception as err:
            msg = ""
            print('Chi-squared failed with exception: ', err, '\n')
        try:
            sl = max(list(map(len, self.args)))
            lines = [p + ' '*(sl-len(p)) + '\t --- \t' + str(v) + ' +/- ' + str(er) + '\n'
                     for p, v, er in zip(self.args, vals, errs)]
            msg += ''.join(lines)
        except Exception as err:
            # e.g. no labels: fall back to unlabelled values (keep chi-squared lines)
            print('Exception: ', err, '\n')
            msg += '\n'.join([str(v) + ' +/- ' + str(er) for v, er in zip(vals, errs)])
        return msg
        
    def chisq(self, fn, x=0, y=0, yerr=0, param=0):
        """Chi-squared value for measured data y with errors yerr fitted to a
        theoretical model fn with parameters param at independent variable values x.
        If the fit object holds data (x, y, non-None yerr) those values and
        self.ps are used; otherwise the arguments supplied here are used."""
        if (self.yerr is not None and np.size(self.x) != 0
                and np.size(self.y) != 0 and np.size(self.yerr) != 0):
            x, y, yerr, param = self.x, self.y, self.yerr, self.ps
        elif any(np.isscalar(v) and v == 0 for v in (x, y, yerr)):
            print("WARNING in chisq(): No input variables were supplied.")
        # the model is evaluated at the independent variable x (not at y)
        return np.sum((np.asarray(y) - fn(np.asarray(x), *param))**2 / np.asarray(yerr)**2)

    def _chisq_params(self, params):
        """Chi-squared of self.bff with the given parameters against the stored
        data; unit errors are used when self.yerr is None."""
        xs = np.asarray(self.x, dtype=float)
        ys = np.asarray(self.y, dtype=float)
        errs = 1. if self.yerr is None else np.asarray(self.yerr, dtype=float)
        return np.sum((ys - self.bff(xs, *params))**2 / errs**2)
        
    def NMmin(self, v, maxiter=500):
        """Nelder Mead descent of the chi-squared of self.bff on the stored data.
        v holds the initial vertices: one row per parameter and one column per
        vertex (3 vertices), i.e. [[1,2,3] for each parameter]. It
        needs to be a sensible guess, especially if the error surface is complicated
        Might need some fiddling to find the global minimum.
        Returns the parameter vector of the best vertex."""
        # see http://www.jasoncantarella.com/downloads/NelderMeadProof.pdf        
        v = np.array(v, dtype=float)   # float copy: int vertices would truncate updates
        chi = self._chisq_params
        for _ in range(maxiter):
            # chi-squared values for the current vertices
            chis = np.array([chi(p) for p in v.T])
            # re-order vertices as best (B), middle/good (G), worst (W)
            order = np.argsort(chis, kind='stable')[:3]
            chis = chis[order]
            v = v[:, order]
            
            # convergence criterion
            if abs(chis[0] - chis[2]) < 1e-11:
                break
            
            B, G, W = v[:, 0].copy(), v[:, 1].copy(), v[:, 2].copy()
            M = 0.5 * (B + G)   # midpoint of best and middle side
            R = 2 * M - W       # reflection through best-middle side
            fR = chi(R)
            
            if fR < chis[1]:
                if chis[0] < fR:              # reflect
                    v[:, 2] = R
                else:                         # extend in direction M->R
                    E = 2*R - M
                    v[:, 2] = E if chi(E) < chis[0] else R
            elif fR < chis[2]:                # reflect
                v[:, 2] = R
            else:
                C = 0.5 * (W + M)             # centre point of W-M
                if chi(C) < chis[2]:          # contract
                    v[:, 2] = C
                else:                         # shrink towards B
                    v[:, 2] = 0.5 * (B + W)
                    v[:, 1] = M
        
        # best fit values of parameters (vertex with the minimum chi-squared)
        best = int(np.argmin([chi(p) for p in v.T]))
        return v[:, best].copy()
        
        
class graph:
    """Collection of common graphing functions.

    An instance holds the data, error bars, labels, best-fit function and
    axis limits; each plotting method reads whichever of these it needs."""
    def __init__(self, xdata=0, ydata=0, xerrors=None, yerrors=None, xlabel='', 
            ylabel='',  figtitle='', labels=None, bestFitFunc=0, bestFitParam=0, 
            xlimits=None, ylimits=None):
        """Store the data and plot settings.

        xdata, ydata     -- data to plot (2D meshes for surface3DPlot).
        xerrors, yerrors -- error bars; an empty list if not given.
        xlabel, ylabel   -- axis labels.
        figtitle         -- title for the figure.
        labels           -- line labels [data label, fit label]; an empty
                            list if not given.
        bestFitFunc      -- callable used as bestFitFunc(x, *bestFitParam).
        bestFitParam     -- parameters passed to bestFitFunc.
        xlimits, ylimits -- (min, max) pairs used by contourFunction and
                            surfacePlot."""
        self.x = xdata
        self.y = ydata
        # None defaults so that every instance gets its own fresh list
        self.xerr = [] if xerrors is None else xerrors
        self.yerr = [] if yerrors is None else yerrors
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.xlims = xlimits
        self.ylims = ylimits
        self.title = figtitle
        self.linelabels = [] if labels is None else labels
        self.bff = bestFitFunc
        self.bfp = bestFitParam

    def gaussFitIm(self, gx, gy, imvals, units='mm', l=20):
        """Display a bmp image with Gaussian fits in the x and y directions.

        gx, gy -- fit objects with attributes x (positions, in converted
                  units), ps = (amplitude, centre, 1/e^2 width, background
                  offset), perrs (errors on ps) and a method
                  offGauss(x, *ps) giving the fitted curve.
        imvals -- 2D image array indexed [row (y), column (x)].
        units  -- (default 'mm') only used for the labels; the units should
                  already be converted in the fit objects.
        l      -- half the side length, in pixels, of the region around the
                  fitted centre to zoom in on.
        Code adapted from Vincent Brooks' batch_fitter.py"""
        # pixel indices closest to the fitted centres (first one on a tie)
        cx, cy = (int(np.argmin(np.abs(g.x - g.ps[1]))) for g in (gx, gy))

        # zoom window; clamp the start so a centre near the edge does not
        # produce a negative index, which would wrap around the image
        x0, x1 = max(cx - l, 0), cx + l
        y0, y1 = max(cy - l, 0), cy + l
        xslice = imvals[cy, x0:x1]      # central horizontal slice of image
        yslice = imvals[y0:y1, cx]      # central vertical slice of image
        zoom = imvals[y0:y1, x0:x1]     # zoom in on central region
        # positions relative to the fitted centres, aligned with the slices
        xraw, yraw = gx.x[x0:x1], gy.x[y0:y1]
        xpos, ypos = xraw - gx.ps[1], yraw - gy.ps[1]

        plt.figure()
        plt.subplots_adjust(left=0.24, right=0.83)
        ax1 = plt.subplot2grid((7,6), (1,1), rowspan=5, colspan=5)  # image
        ax2 = plt.subplot2grid((7,6), (1,0), rowspan=5, colspan=1)  # y slice
        ax3 = plt.subplot2grid((7,6), (0,1), rowspan=1, colspan=5)  # x slice
        ax4 = plt.subplot2grid((7,6), (6,0), rowspan=1, colspan=6)  # results

        # top panel: horizontal slice and its fit against centred x position.
        # The fit is evaluated at the raw positions because ps contains the
        # centre in the same (uncentred) units as g.x.
        xmax = np.around(np.max(xslice), -1)
        ax3.set_yticks((0, np.around(xmax/2), xmax))
        ax3.plot(xpos, xslice)
        ax3.plot(xpos, gx.offGauss(xraw, *gx.ps))
        ax3.xaxis.tick_top()
        ax3.set_xlabel('x position on CCD ('+units+')')
        ax3.xaxis.set_label_position('top')
        ax3.set_ylabel('Intensity')
        ax3.set_xlim(xpos[0], xpos[-1])
        ax3.set_ylim(np.min(xslice), np.max(xslice))

        # side panel: vertical slice rotated so position runs vertically,
        # intensity increasing to the left, away from the image
        ax2.set_xticks((0, np.around(np.max(yslice), -1)))
        ax2.plot(yslice, ypos)
        ax2.plot(gy.offGauss(yraw, *gy.ps), ypos)
        ax2.xaxis.tick_top()
        ax2.set_xlabel('Intensity')
        ax2.xaxis.set_label_position('top')
        ax2.set_ylabel('y position on CCD ('+units+')')
        ax2.set_xlim(np.max(yslice), np.min(yslice))  # inverted x axis
        ax2.set_ylim(ypos[-1], ypos[0])  # first row at the top, like imshow

        # show the centre points and widths with their errors:
        p2 = data([gx.ps[1], gx.ps[2], gy.ps[1], gy.ps[2]],
                [gx.perrs[1], gx.perrs[2], gy.perrs[1], gy.perrs[2]])
        (xc, wx, yc, wy), (exc, ewx, eyc, ewy) = p2.adjustSF()
        results = [('x_c', xc, exc), ('w_x', wx, ewx),
                   ('y_c', yc, eyc), ('w_y', wy, ewy)]
        resStr = [r'$%s = %s \pm %s$ %s' % (lab, val, err, units)
                  for lab, val, err in results]
        ax4.text(0.5, 0.5, resStr[0]+'\t\t'+resStr[1]+'\n'+resStr[2]+'\t\t'+resStr[3],
                ha = 'center', va = 'center', size=14)
        ax4.set_xticks([])
        ax4.set_yticks([])

        # extent is (left, right, bottom, top); with the default
        # origin='upper' the first image row is drawn at the top
        ax1.imshow(zoom, cmap='Blues', extent=[xpos[0], xpos[-1], ypos[-1], ypos[0]])
        ax1.yaxis.set_label_position('right')
        ax1.set_ylabel('y position on CCD ('+units+')')
        ax1.set_xlabel('x position on CCD ('+units+')')
        ax1.yaxis.tick_right()
        plt.show()

    def weightResPlot(self):
        """A standard plot with errorbars, best fit line and weighted residuals
        (residual / y error) in a subplot underneath.
        The first line label is used for the data, the second for the fit.
        Returns (fig, ax) with ax = [main axes, residual axes]."""
        fig, ax = plt.subplots(2, 1, gridspec_kw = {'height_ratios':[3, 1],'hspace':0}, sharex='all')
        # title on the top panel: plt.title would target the current
        # (bottom) axes and overlap the main plot since hspace=0
        ax[0].set_title(self.title)
        if not np.size(self.linelabels):
            self.linelabels=['','']
        newx = np.linspace(min(self.x), max(self.x), 200) 
        ax[0].errorbar(self.x, self.y, yerr=self.yerr, fmt='o', 
                        capsize=3, markersize=4, label=self.linelabels[0])
        ax[0].plot(newx, self.bff(newx, *self.bfp), '--', label=self.linelabels[1])
        ax[0].set_ylabel(self.ylabel)
        ax[0].tick_params(direction='in',bottom=True,top=True,left=True,right=True)
        ax[0].legend()

        wres = (self.y - self.bff(self.x, *self.bfp))/self.yerr
        ax[1].plot(self.x, wres, 'o')
        ax[1].plot([min(self.x),max(self.x)],[0,0],'--')
        ax[1].set_ylabel("Weighted \nResiduals")
        ax[1].set_ylim((-max(abs(wres))*1.1, max(abs(wres))*1.1))
        plt.xlabel(self.xlabel)
        return fig, ax

    def residualPlot(self):
        """A standard plot with data, best fit line and residuals in a 
        subplot underneath.
        The first line label is used for the data, the second for the fit.
        Returns (fig, ax) with ax = [main axes, residual axes]."""
        fig, ax = plt.subplots(2, 1, gridspec_kw = {'height_ratios':[3, 1],'hspace':0}, sharex='all')
        # title on the top panel rather than the current (bottom) axes
        ax[0].set_title(self.title)
        res = (self.y - self.bff(self.x, *self.bfp))
        ax[1].plot(self.x, res, 'o')
        ax[1].plot([min(self.x),max(self.x)],[0,0],'k--', alpha=0.4)
        ax[1].set_ylabel("Residuals")
        ax[1].set_ylim((-max(abs(res))*1.1, max(abs(res))*1.1))
        plt.xlabel(self.xlabel)
        # first label is data, second label is fit
        if not np.size(self.linelabels):
            self.linelabels=['','']
        newx = np.linspace(min(self.x), max(self.x), 200) 
        ax[0].plot(self.x, self.y, 'o', markersize=4, label=self.linelabels[0])
        ax[0].plot(newx, self.bff(newx, *self.bfp), '--', label=self.linelabels[1])
        ax[0].set_ylabel(self.ylabel)
        ax[0].tick_params(direction='in',bottom=True,top=True,left=True,right=True)
        ax[0].legend()
        return fig, ax

    @staticmethod
    def _gridValues(f, x_axis, y_axis, param=()):
        """Evaluate f(x, y, *param) at every grid point. Returns Z with
        Z[iy, ix] = f(x_axis[ix], y_axis[iy], *param): rows follow y and
        columns follow x, the layout imshow and contour expect."""
        values = np.zeros((len(y_axis), len(x_axis)))
        for iy, y in enumerate(y_axis):
            for ix, x in enumerate(x_axis):
                values[iy, ix] = f(x, y, *param)
        return values

    def contourFunction(self, f, param, npoints=100):
        """Use a function f(x, y, *param) to plot a 2D surface with contours
        over the space specified by the x and y limits, sampled on an
        npoints x npoints grid."""
        x0, x1 = self.xlims
        y0, y1 = self.ylims
        x_axis = np.linspace(x0, x1, npoints)
        y_axis = np.linspace(y0, y1, npoints)

        # values over this space, indexed [y, x]
        zvals = self._gridValues(f, x_axis, y_axis, param)

        # plot 2D graph colormap for the calculated data
        plt.figure()
        plt.title(self.title)
        im = plt.imshow( zvals,
                            extent = (x0, x1, y0, y1),
                            origin = 'lower',
                            cmap = matplotlib.cm.gray,
                            aspect = 'auto')
        plt.colorbar(im, orientation='vertical')

        CS = plt.contour(x_axis, y_axis, zvals, cmap=matplotlib.cm.cool)
        plt.clabel(CS, inline=1, fontsize=10)
        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)

    def surfacePlot(self, image, colourmap='Blues'):
        """Plot the given image in a 2D colourmap with a colourbar. The
        extent is set by the x and y limits, with the first image row at
        the bottom (origin='lower')."""
        x0, x1 = self.xlims
        y0, y1 = self.ylims

        plt.figure()
        plt.title(self.title)
        im = plt.imshow( image,
                            extent = (x0, x1, y0, y1),
                            origin = 'lower',
                            cmap = colourmap,
                            aspect = 'auto')
        plt.colorbar(im, orientation='vertical')

        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)

    def surface3DPlot(self, Z, zlabel='', colourmap='Blues'):
        """Plot the 2D array of Z values against the 2D arrays of X and Y values
        (self.x, self.y) on a 3D surface with colourmap.
        Returns (fig, ax)."""
        fig = plt.figure()
        # fig.gca(projection=...) was removed in Matplotlib 3.6
        ax = fig.add_subplot(projection='3d')
        ax.set_title(self.title)
        surface = ax.plot_surface(self.x, self.y, Z, cmap=colourmap)

        # NOTE(review): the colourbar is created then immediately removed;
        # kept to preserve the existing layout behaviour - confirm if needed
        cb = fig.colorbar(surface)
        cb.remove()
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)
        ax.set_zlabel(zlabel)
        plt.draw()

        return fig, ax
        

def broken_xaxis(xdata, ydata, xlim1, xlim2, wspace=0.1, **kwargs):
    """Plot ydata against xdata on two side-by-side axes that share the y
    axis, showing only the x ranges xlim1 (left) and xlim2 (right), with
    diagonal marks drawn where the x axis is broken.

    xdata, ydata -- data plotted identically on both axes.
    xlim1, xlim2 -- (min, max) x limits for the left and right axes.
    wspace       -- horizontal space between the two axes.
    kwargs       -- passed on to both ax.plot calls.
    Returns (fig, ax1, ax2)."""
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
    fig.subplots_adjust(wspace=wspace)  # adjust space between axes
    # plot the same data on both axes
    ax1.plot(xdata, ydata, **kwargs)
    ax2.plot(xdata, ydata, **kwargs)
    # zoom-in / limit the view to different portions of the data
    ax1.set_xlim(*xlim1)
    ax2.set_xlim(*xlim2)
    # hide the facing spines so the two axes read as one broken axis
    ax1.spines['right'].set_visible(False)
    ax2.spines['left'].set_visible(False)
    ax1.yaxis.tick_left()
    ax2.yaxis.tick_right()
    # tick labels only on the outer left edge; tick_params needs a bool
    # ('off' is a truthy string in current Matplotlib)
    ax2.tick_params(labelleft=False, labelright=False)
    d = .015  # how big to make the diagonal lines in axes coordinates
    # arguments for the diagonal marks (kept separate from the plot kwargs)
    diag_kw = dict(transform=ax1.transAxes, color='k', clip_on=False)
    ax1.plot((1-d, 1+d), (-d, +d), **diag_kw)     # bottom-right of left axes
    ax1.plot((1-d, 1+d), (1-d, 1+d), **diag_kw)   # top-right of left axes
    diag_kw.update(transform=ax2.transAxes)       # switch to the right axes
    ax2.plot((-d, +d), (1-d, 1+d), **diag_kw)     # top-left of right axes
    ax2.plot((-d, +d), (-d, +d), **diag_kw)       # bottom-left of right axes
    return fig, ax1, ax2

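# Example (kept for reference): add a second x axis on top with twiny(),
# labelling the diode current that corresponds to each output power.
# ax1 = plt.subplot(111)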
# ax2 = ax1.twiny()
# ax1.plot(power(currents), wavelengths, 'o-')
# ax1.set_xlabel('Power (mW)')
# ax1.set_ylabel('Wavelength (nm)')
# ax2.set_xlabel('Current (mA)')
# ax2.set_title('SLC23-2019-1')
# current_xaxis = np.array([currents[5*i] for i in range(len(currents)//5)])
# ax2.set_xlim(ax1.get_xlim())
# ax2.set_xticks(power(current_xaxis))
# ax2.set_xticklabels(current_xaxis)
# plt.show()

class data:
    """Collection of common data manipulation tools."""
    def __init__(self, results=0, errors=0):
        """results -- value(s) to report; errors -- their uncertainties,
        one per result."""
        self.r = results
        self.e = errors
        
    def adjustSF(self):
        """Change the values to the right number of significant figures.

        Each error is kept to 1 s.f. (2 s.f. if its leading digit is 1) and
        its result is rounded to the same decimal place as the error.
        Pairs containing a non-finite value or error (nan/inf, e.g. from a
        failed fit) are returned unchanged. Scalars are treated as
        1-element sequences.
        Note: this format won't show trailing zeros. The output is a float
        array, which may need further formatting if integer precision is
        required. Returns (vals, errs); the stored results keep full
        precision."""
        # float copies: originals are untouched, and integer inputs are not
        # truncated when the rounded values are written back
        vals = np.array(np.atleast_1d(self.r), dtype=float)
        errs = np.array(np.atleast_1d(self.e), dtype=float)

        for i in range(len(vals)):
            if not (np.isfinite(vals[i]) and np.isfinite(errs[i])):
                continue  # "%e" of nan/inf has no exponent to parse
            # use exponential notation to find the decimal exponents
            sfv = int(("%e" % vals[i]).split("e")[1])  # exponent of result
            errstr = "%e" % errs[i]
            sfe = int(errstr.split("e")[1])            # exponent of error
            # if the first s.f. on the error is 1, take it to 2 s.f.
            nerr = 2 if errstr[0] == "1" else 1

            if sfv < sfe:
                # value smaller than its error: round it to the error's last
                # significant decimal place. round() takes a number of
                # decimal places, i.e. minus the exponent.
                vals[i] = round(vals[i], -sfe + nerr - 1)
                sfv = sfe
            vals[i] = float(("%." + str(sfv - sfe + nerr) + "g") % vals[i])
            errs[i] = float(("%." + str(nerr) + "g") % errs[i])

        return (vals, errs)
        
def groupArrSort(*args):
    """From Vincent Brooks: stack the given arrays as the columns of a float
    table and reorder its rows so the first column is ascending.
    Returns the sorted table, shape (len(args[0]), len(args))."""
    table = np.empty((len(args[0]), len(args)))
    for column, values in enumerate(args):
        table[:, column] = values
    return table[np.argsort(table[:, 0])]
    