from __future__ import division
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 13:47:26 2019
@author: Rahul

02/09/2019: Vincent
Adapted for fitting to data
"""

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import os
import pandas as pd
import lmfit
clear_output(wait=True)  # IPython display helper; wait=True defers the clear until new output arrives
cwd = os.getcwd()  # working directory at start-up; not referenced elsewhere in this script
###########################################################################
# Optional: apply a custom matplotlib style sheet.
# plt.style.use('Alex_Style')

#### Physical constants (SI units) ####
kB = 1.38e-23           # Boltzmann constant, J/K
amu = 1.66053e-27       # atomic mass unit, kg (scaled by the mass number per species)
h_plank = 6.626e-34     # Planck constant, J s
g = -9.81               # gravitational acceleration, m/s^2 (negative: acts along -z)
epsilon_0 = 8.854e-12   # vacuum permittivity, F/m
c = 299792458           # speed of light in vacuum, m/s
bohr_rad = 5.29e-11     # Bohr radius a0, m


def adjustSF(vals, errs):
    """Format a value and its uncertainty in concise notation, e.g. '19.1(4)'.

    The uncertainty is rounded to one significant figure, and the value is
    rounded to the same decimal place. Trailing zeros of the value are kept,
    so the bracketed digit always refers to the last printed digit
    (20.0 +/- 0.5 -> '20.0(5)'). When the uncertainty is 10 or larger, the
    value is rounded to that place and the full uncertainty is shown in
    brackets (123.4 +/- 12 -> '120(10)').

    Parameters
    ----------
    vals : float
        The measured value.
    errs : float
        Its (absolute) uncertainty; the sign is ignored.

    Returns
    -------
    str
    """
    # Round the error first: rounding can bump its exponent (0.96 -> 1), and
    # the decimal place of the result must follow the *rounded* error.
    errs = float("%.1g" % abs(errs))
    err_str = "%e" % errs
    sfe = int(err_str.split("e")[1])   # decimal exponent of the rounded error
    if sfe <= 0:
        # Error digit sits at 10**sfe: show exactly -sfe decimals of the value.
        return "%.*f(%s)" % (-sfe, vals, err_str[0])
    # Error is 10 or more: round the value to that place, print the error in full.
    return "%.0f(%.0f)" % (round(vals, -sfe), errs)

def arraySorter(*args):
    """Reorder every array so that the first one is in ascending order.

    All inputs are copied into the columns of one float matrix (so they share
    a common length, with scalars/length-1 inputs broadcast), the rows are
    ordered by the first column, and each column is returned as a float array.

    Returns
    -------
    list of numpy.ndarray
        One sorted array per input, in the order the inputs were given.
    """
    n_cols = len(args)
    table = np.zeros([len(args[0]), n_cols])
    for col, arr in enumerate(args):
        table[:, col] = arr
    order = table[:, 0].argsort()
    return [table[order, col] for col in range(n_cols)]

class TemperatureMonteCarlo():
    """Monte Carlo release-and-recapture model for an atom in a Gaussian-beam trap.

    Atoms are drawn from a thermal distribution in the harmonic approximation
    of the trap. They fly ballistically under gravity (along z) for a release
    time. An atom counts as recaptured when its kinetic energy is below the
    trap depth at its final position. The simulated recapture curve can be
    fitted to measured data to extract the temperature.
    """

    # Mass numbers of the supported species (mass = number * amu).
    _MASS_NUMBERS = {'Cs': 133, 'Rb': 87}

    def __init__(self, species, num_sims, T, lm, wr, power, polarisibility):
        """
        species        : 'Cs' or 'Rb'.
        num_sims       : default number of simulated atoms per release time.
        T              : initial temperature, K.
        lm             : trapping-light wavelength, m.
        wr             : beam waist (focus radius), m.
        power          : beam power, mW.
        polarisibility : polarisability in a0^3.

        Raises ValueError for an unsupported species (previously this surfaced
        later as an obscure missing-attribute error).
        """
        self.num_sims         = num_sims
        self.T                = T  # Initial temperature, Kelvin.
        self.lm               = lm  # Light wavelength, m.
        self.wr               = wr  # Beam focus size, m.
        self.power            = power  # Beam power in mW
        self.polarisibility   = polarisibility  # Pol in a0^3

        if species not in self._MASS_NUMBERS:
            raise ValueError("Unsupported species %r; expected one of %s"
                             % (species, sorted(self._MASS_NUMBERS)))
        self.m = self._MASS_NUMBERS[species] * amu
        self.initialize()

    def potential(self, x, y, z):
        """Return the local trap depth (J, positive) at position (x, y, z) in m.

        Gaussian beam: the 1/e^2 radius grows along z as
        w(z) = wr*sqrt(1 + (z/z_R)^2), with Rayleigh range z_R = pi*wr^2/lm.
        Accepts scalars or numpy arrays.
        """
        wrz = self.wr*np.sqrt(1 + (z*self.lm/(np.pi*self.wr**2))**2)
        return self.U0*((self.wr/wrz)**2)*np.exp(-2*(x**2+y**2)/wrz**2)

    def randx(self, T):
        """Draw radial (x or y) positions from the thermal distribution.

        Harmonic approximation: sigma_r = sqrt(kB*T / (m*omega_r^2)).
        ``T`` may be a scalar or an array; the result has the same shape.
        """
        sigma_r = np.sqrt(kB*T/(self.m*self.omega_r**2))
        return np.random.normal(loc=0.0, scale=sigma_r, size=None)

    def randz(self, T):
        """Draw axial (z) positions; sigma_z = sqrt(kB*T / (m*omega_z^2))."""
        sigma_z = np.sqrt(kB*T/(self.m*self.omega_z**2))
        return np.random.normal(loc=0.0, scale=sigma_z, size=None)

    def randv(self, T):
        """Draw one velocity component from the Maxwell-Boltzmann distribution."""
        return np.random.normal(loc=0.0, scale=np.sqrt(kB*T/self.m), size=None)

    def collision_sim(self, T, nsims, time):
        """Simulate ``nsims`` atoms released for ``time`` seconds at temperature T.

        Returns [fraction recaptured, fraction lost].
        """
        Ts = np.ones(nsims)*T
        x0 = self.randx(Ts)
        y0 = self.randx(Ts)
        z0 = self.randz(Ts)
        vx0 = self.randv(Ts)
        vy0 = self.randv(Ts)
        vz0 = self.randv(Ts)
        # Ballistic flight. Gravity acts along z, and position and velocity
        # must use the same sign of g (the old code used -g/2*t^2 with +g*t).
        xt = x0 + vx0*time
        yt = y0 + vy0*time
        zt = z0 + vz0*time + 0.5*g*time**2
        vzt = vz0 + g*time
        # Recaptured when the kinetic energy is below the local trap depth.
        kinetic = 0.5*self.m*(vx0**2 + vy0**2 + vzt**2)
        nin = np.count_nonzero(self.potential(xt, yt, zt) > kinetic)
        return [nin/nsims, (nsims-nin)/nsims]

    def getTrapDepth(self, waist, power, polarisibility):
        """Compute the trap depth of a Gaussian beam and store it on the instance.

        waist          : beam waist in m (SI; called with self.wr).
        power          : beam power in mW.
        polarisibility : polarisability in a0^3, converted to SI via 4*pi*eps0*a0^3.

        Sets self.U0 (J), self.trapDepth_MHz and self.trapDepth_mK.
        """
        intensity = 2 * power * 1e-3 / (np.pi * waist**2)  # peak intensity, W/m^2
        polarisibilitySI = polarisibility*4*np.pi*epsilon_0*bohr_rad**3
        trapDepth_SI = np.abs(-0.5*polarisibilitySI*intensity / (c * epsilon_0))
        self.trapDepth_MHz = trapDepth_SI / h_plank / 1e6
        self.trapDepth_mK = trapDepth_SI / kB * 1000
        self.U0 = trapDepth_SI

    def initialize(self):
        """Derive the trap depth and the radial/axial trap frequencies (rad/s)."""
        self.getTrapDepth(self.wr, self.power, self.polarisibility)
        # Geometric factor (omega_z/omega_r)^2 = lm^2 / (2*pi^2*wr^2).
        self.gf = (self.lm**2)/(2*(np.pi*self.wr)**2)
        self.omega_r = np.sqrt(self.U0 * 4 / (self.m*(self.wr)**2))
        self.omega_z = np.sqrt(self.gf)*self.omega_r

    def runSimulation(self, a, T, times, num_sims=None, plot=False):
        """Return the simulated recapture probability for each release time.

        a        : amplitude multiplying the simulated recapture fraction.
        T        : temperature, K.
        times    : iterable of release times in microseconds.
        num_sims : atoms per release time (default: self.num_sims).
        plot     : if true, plot the curve.

        The result is also stored in self.MCresults.
        """
        if num_sims is None:
            num_sims = self.num_sims

        prob_array = np.array([self.collision_sim(T, num_sims, t*1e-6)[0]
                               for t in times]) * a

        if plot:
            plt.plot(times, prob_array, 'r')
            plt.ylabel('Recapture Probability')
            plt.xlabel(r'Release Time ($\mu$s)')
            plt.show()

        self.MCresults = prob_array
        return prob_array

    def getMeasuredData(self, filename, plot=False, omit=None):
        """Load release times, recapture probabilities and errors from a results file.

        The file is parsed by pandas.read_csv after skipping two lines and must
        have the columns 'User variable', 'Loading probability' and
        'Error in Loading probability'. Rows whose 0-based position is listed
        in ``omit`` are dropped; ``omit=None`` keeps every row.

        Returns (user_variable, P_recapture, error) as numpy arrays.
        """
        omit = set() if omit is None else set(omit)
        df = pd.read_csv(filename, skiprows=2)
        keep = np.array([i not in omit for i in range(len(df))], dtype=bool)
        user_variable = np.asarray(df['User variable'])[keep]
        P_recapture = np.asarray(df['Loading probability'])[keep]
        error = np.asarray(df['Error in Loading probability'])[keep]
        if plot:
            plt.errorbar(user_variable, P_recapture, yerr=error, marker='o', linestyle='none')
            plt.show()
        return (user_variable, P_recapture, error)

    def residual(self, params, times, data, errs):
        """Return the error-weighted residuals (model - data)/errs for lmfit.

        Each call runs a fresh 6000-atom simulation, so the residual is
        stochastic from call to call.
        """
        amp = params['amp'].value
        T = params['T'].value
        model = self.runSimulation(amp, T, times, num_sims=6000)
        return (model - data)/errs

    def fitMC(self, filename, omit=None, T_guess=19e-6, amp_guess=1, T_vary=True, amp_vary=True):
        """Fit the measured recapture curve with the Monte Carlo model.

        filename          : results file read by getMeasuredData (times in us).
        omit              : 0-based row positions to exclude (default: none).
        T_guess           : starting temperature, K (bounded to 0-500 uK).
        amp_guess         : starting amplitude (bounded to 0.1-1.2).
        T_vary, amp_vary  : whether each parameter is free in the fit.

        Plots the data with the best-fit curve, prints the lmfit report and the
        fitted temperature, and returns the lmfit result object.
        """
        if omit is None:  # avoid a shared mutable default
            omit = []
        print(omit)
        params = lmfit.Parameters()
        params.add('T', value=T_guess, min=0e-6, max=500e-6, vary=T_vary)
        params.add('amp', value=amp_guess, min=0.1, max=1.2, vary=amp_vary)

        times, data, errs = self.getMeasuredData(filename, omit=omit)

        plt.ylim(0.07, 1.1)
        plt.errorbar(times, data, yerr=errs, marker='o', linestyle='none')
        mini = lmfit.Minimizer(self.residual, params, fcn_args=(times, data, errs))
        # The residual is stochastic, which is presumably why a derivative-free
        # method with a loose tolerance is used.
        results = mini.minimize(method='nelder', tol=0.5)

        print(lmfit.fit_report(results))

        opt = results.params
        xfit = np.linspace(0, np.max(times) + 3, 100)
        yfit = self.runSimulation(opt['amp'].value, opt['T'].value, xfit, num_sims=7000)

        plt.close('all')
        # plt.title("P = %.4g mW, waist = "%self.power+str(self.wr*1e6)+" $\mu$m",fontsize=12)
        plt.errorbar(times, data, yerr=errs, marker='o', linestyle='none')
        plt.plot(xfit, yfit)
        plt.xlabel(r'$\mathrm{Release \ time \ ( \mu s)}$')
        plt.ylabel(r'$\mathrm{Recapture \ probability}$')
        # plt.savefig('RRtemperatureMC.svg')
        plt.show()

        optT = np.around(opt['T'].value*1e6, 3)
        stderr = opt['T'].stderr
        if stderr is None:
            # lmfit may not estimate uncertainties for Nelder-Mead.
            print('T = ', optT, ' uK (no uncertainty estimate)')
        else:
            errT = np.around(stderr*1e6, 1)
            print('T = ', adjustSF(optT, errT), ' uK')
        return results





if __name__ == "__main__":
    # Example trap configurations. Units: T in K, wavelength lm and waist wr
    # in m, power in mW, polarisability in a0^3. Each power is a measured
    # value times a scaling factor (0.726 / 0.626 -- presumably an optical
    # transmission or calibration factor; confirm).
    Cs_938 = TemperatureMonteCarlo(
        species='Cs', num_sims=1000, T=20e-6,
        lm=938e-9, wr=1160e-9,
        power=4.1*0.726,        # mW
        polarisibility=2880,    # a0^3
    )
    Rb_938 = TemperatureMonteCarlo(
        species='Rb', num_sims=1000, T=20e-6,
        lm=938e-9, wr=1160e-9,
        power=14.2*0.726,       # mW
        polarisibility=1020,    # a0^3
    )
    Rb_814 = TemperatureMonteCarlo(
        species='Rb', num_sims=1000, T=20e-6,
        lm=814e-9, wr=1020e-9,
        power=2.55*0.626,       # mW
        polarisibility=4760,    # a0^3
    )
    Cs_1064 = TemperatureMonteCarlo(
        species='Cs', num_sims=1000, T=20e-6,
        lm=1064e-9, wr=1350e-9,
        power=15.3*0.726,       # mW
        polarisibility=1200,    # a0^3
    )

    # Fit the Cs (938 nm) recapture data with the amplitude held fixed.
    r = Cs_938.fitMC("ROI0_Re_Measure1_RR.dat",
                     amp_guess=0.97, T_vary=True, amp_vary=False)
