import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.interpolate import splrep, BSpline
from scipy import interpolate
# import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import re

from BaselineRemoval import BaselineRemoval
from matplotlib.cm import get_cmap
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import SymLogNorm
from matplotlib.colors import LogNorm
import matplotlib.gridspec as gridspec
# Font-size presets (all three share the same value here).
SMALL_SIZE = MEDIUM_SIZE = BIGGER_SIZE = 10
n = 9

# Use Helvetica for every piece of text in the figure.
plt.rc('font', family='Helvetica')

# One figure holding two side-by-side panels that share their x-axis.
fig, (ax_left, ax_right) = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(7, 4))

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=256):
    """Return a new colormap built from a sub-range of ``cmap``.

    ``cmap`` is sampled at ``n`` evenly spaced points between ``minval`` and
    ``maxval``; the sampled colours are passed to
    ``LinearSegmentedColormap.from_list`` under the name
    ``truncated_<name of cmap>``.
    """
    sample_points = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_points)
    return LinearSegmentedColormap.from_list(f"truncated_{cmap.name}", sampled_colors)
# Start from the 'hot' colormap and keep only its 0 .. 0.66 range.
cmap_hot = plt.get_cmap('hot')
cmap_truncated = truncate_colormap(cmap_hot, minval=0, maxval=0.66)

# Discrete hex palette running from dark blue to dark red.
coloe_list = [
    "#03045E",
    "#007786",
    "#00B4D8",
    "#90E0EF",
    "#BCD2E8",
    "#f2c85b",
    "#F86E51",
    "#EE3E38",
    "#D1193E",
    "#8B0000",
]

# Colormap used to colour the curves drawn by see_force.
cmap = cmap_truncated

# Module-level result lists; see_force appends to `length`, `a` and `F`.
fitted_slope, fitted_intercept = [], []
length = []
a = []
F = []
k = 0

# Work inside the folder that holds the line-tension simulation output.
folder_path = '/AFM_simulation_line_tension (Fig_3)/line_tension'
os.chdir(folder_path)
def see_force(input_file, k, shift_by_final, data_from_correct, ax):
    """Fit one simulated energy curve and plot minus its first derivative.

    The file is read as two tab-separated columns (position, energy). The
    energy is shifted, a smoothing spline is fitted to the samples up to and
    including the cut index listed for this file in ``sp_fit_source.txt``
    (read from the current working directory), and minus the first derivative
    of that spline is drawn on ``ax``. Both plotted coordinates are divided by
    ``simulation_blunt_length``.

    Parameters
    ----------
    input_file
        Path of the data file. The number after the last ``_`` of the name
        (before ``.txt``) is divided by 3.6, multiplied by 1e-6 and used as
        the legend label.
    k
        Curve index; the line colour is ``cmap(k * (1 / 6))``.
    shift_by_final
        The string ``'TRUE'`` subtracts the last energy sample; any other
        value subtracts the mean of the final third of the samples.
    data_from_correct
        Constant added to every position value.
    ax
        Matplotlib axes to draw on.

    Side effects
    ------------
    Appends to the module-level lists ``a``, ``length`` and ``F``.

    Raises
    ------
    ValueError
        If ``sp_fit_source.txt`` has no two-column row for ``input_file``.
    """
    curve_full_name = str(input_file)
    extracted_value = float(str(curve_full_name.split('.txt')[0]).split('_')[-1])
    calculated_value = extracted_value / 3.6 * 1e-6
    curve_name = f"{calculated_value:.1e}"  # Format as scientific notation

    print('strat processing', curve_full_name)

    # Read both columns in one pass; blank lines (e.g. a trailing newline)
    # are skipped instead of raising on the missing second column.
    with open(input_file, "r") as data_file:
        rows = [line.split('\t') for line in data_file if line.strip()]
    position_corrected = [float(row[0].strip()) + data_from_correct for row in rows]
    energy_corrected = [float(row[1].strip()) for row in rows]

    print('position', position_corrected)
    print('energy', energy_corrected)

    print('position', len(position_corrected))
    print('energy', len(energy_corrected))

    array_energy = np.array(energy_corrected)
    if shift_by_final == 'TRUE':
        shifted_energy = array_energy - array_energy[-1]
    else:
        shifted_energy = array_energy - np.average(array_energy[-int(len(array_energy) / 3)::])
    array_position = np.array(position_corrected)
    print("lenchopped_position", len(array_position[:-1]))
    print(array_position)
    print(array_energy)

    # Look up the spline cut index for THIS file. The row name is compared
    # with the requested file name (not with itself), so a row belonging to
    # another file is never picked up.
    sp_fit_cut = None
    with open("sp_fit_source.txt", "r") as file_source:
        for line in file_source:
            parts = line.split()
            if len(parts) == 2 and parts[0] == curve_full_name:
                sp_fit_cut = int(parts[1])
                print(f"The number associated with '{curve_full_name}' is: {sp_fit_cut}")
                print('the type is', type(sp_fit_cut))
                break
    if sp_fit_cut is None:
        raise ValueError(f"No entry found for '{curve_full_name}'")

    print('hi')
    s_fit = 1
    tck = interpolate.splrep(array_position[:sp_fit_cut+1], shifted_energy[:sp_fit_cut+1], s=s_fit)
    spline = BSpline(*tck)

    print("here", array_position[sp_fit_cut])

    xnew = np.linspace(array_position[0], array_position[sp_fit_cut], 100)
    print('xnew', xnew[0])

    print('black', array_position[sp_fit_cut])
    print('red', xnew[-1])

    simulation_blunt_length = 0.0075 * np.tan(np.deg2rad(float(30) / 2))

    def format_scientific_latex(value):
        """Render ``value`` as a mathtext label such as ``$1.2 \\times 10^{-5}$``."""
        if value == 0:
            return "0"
        coefficient, exponent = f"{value:.1e}".split('e')
        exponent = int(exponent)  # drops the sign padding, e.g. "-05" -> -5
        if exponent == 0:
            return f"${coefficient}$"
        return f"${coefficient} \\times 10^{{{exponent}}}$"

    formatted_label = format_scientific_latex(float(curve_name))
    print('simulation_blunt_length', simulation_blunt_length)

    curve_color = cmap(k * (1 / 6))
    # Second positional argument of a BSpline call is the derivative order.
    ax.plot(xnew / simulation_blunt_length, -spline(xnew, 1) / simulation_blunt_length, '-',
            color=curve_color, label=formatted_label, linewidth=1.5)

    # Drop from the end of the fitted curve to zero and continue flat at zero.
    x_flat = np.array([xnew[-1], array_position[sp_fit_cut], array_position[-1], 1])
    y_flat = np.array([-float(spline(xnew[-1], 1)), 0, 0, 0])
    ax.plot(x_flat / simulation_blunt_length, y_flat / simulation_blunt_length, '-',
            color=curve_color, zorder=5, linewidth=1.5, alpha=0.8)

    a.append(-spline(xnew, 3))
    length.append(array_position[sp_fit_cut]+0.0025)
    F.append(spline(xnew[0], 1))

 # List all the .txt files in the folder
# Keep only the '.txt' outputs whose name starts with '0723_condition_0.0002'.
txt_files = [
    entry
    for entry in os.listdir(folder_path)
    if entry.startswith('0723_condition_0.0002') and entry.endswith('.txt')
]
print(txt_files)

def myFunc(file_name):
    """Sort key: the number after the last ``_`` (before ``.txt``), divided by 3.6 and scaled by 1e-6."""
    stem = file_name.partition('.txt')[0]
    last_token = stem.rpartition('_')[2]
    return float(last_token) / 3.6 * 1e-6

# Sort the files by the numeric value encoded in their names (see myFunc).
txt_files.sort(key=myFunc)

print('file_list', txt_files)

# Draw every curve on the right panel; the running index picks the colour.
for k, txt_file in enumerate(txt_files):
    see_force(txt_file, k, shift_by_final='TRUE', data_from_correct=0, ax=ax_right)

# Axis limits, ticks, legend and label of the right panel.
ax_right.set_xlim(-1, 8.1)
ax_right.set_xticks(np.arange(0, 10, 2))
ax_right.set_ylim(-41.2, 6.5)
ax_right.legend(fontsize=10, loc='lower right', frameon=False)
# Raw string: '\d' is not a valid escape sequence in a normal string literal.
ax_right.set_xlabel(r'$\dfrac{Position}{R}$', fontsize=10)

# Experiment
Sensitivity = 80  # nm/V, already input
Spring_constant = 1.97  # nN/nm
blunt_radius = 56

ylg = 0.02

# Plot colours: 0-255 channel values converted to 0-1 fractions.
red_color, blue_color, orange_color, green_color, error_color = (
    tuple(channel / 255.0 for channel in rgb)
    for rgb in (
        (177, 24, 45),
        (36, 100, 171),
        (241, 108, 35),
        (27, 124, 61),
        (71, 146, 196),
    )
)
linewidth_plot = 0.6

# Module-level result lists.
adhesion_force_array, bridge_length_array, speed_array = [], [], []

def extract_number_from_curve_name(curve_name):
    """Return the first run of digits in ``curve_name`` as an int.

    Names without any digit yield ``float('inf')``, which places them after
    every numbered name when the result is used as a sort key.
    """
    found = re.search(r'(\d+)', curve_name)
    if found is None:
        return float('inf')
    return int(found.group())

def process_error(input_file):
    """Read a deflection standard-deviation file and return scaled arrays.

    The file must contain a header line starting with ``X SetScale/P`` (the
    last such line is used; it supplies the first x value and the x spacing)
    and the samples, one per line, between a ``BEGIN`` line and an ``END``
    line.

    Returns
    -------
    tuple of numpy.ndarray
        ``(error_data_array, error_position_array)`` where, with ``d`` the
        raw samples and ``x_i = first_x + i * position_interval``:

        * ``error_data_array = d * Spring_constant * 1e9 / (blunt_radius / 2 * ylg)``
        * ``error_position_array = (x_i + d) * 1e9 / (blunt_radius / 2)``

    Raises
    ------
    ValueError
        If the file has no ``X SetScale/P`` line.
    """
    file_name = input_file
    print('strat processing', file_name)

    # Read the file once; header and data are both extracted from `lines`.
    with open(input_file, 'r') as file:
        lines = file.readlines()

    info_line = None
    for line in lines:
        if line.startswith('X SetScale/P'):
            info_line = line
    if info_line is None:
        # Without the header there is no x-axis; fail here with a clear
        # message instead of a NameError further down.
        raise ValueError(
            f'No line starting with "X SetScale/P" found in {input_file!r}.')

    print('info_line:', info_line)
    header_fields = info_line.split(',')
    first_x = float(header_fields[0].split(' ')[3])
    position_interval = float(header_fields[1])
    print('position interval', position_interval)
    print('first_x', first_x)

    # Collect the samples listed between BEGIN and END.
    error_deflection_data = []
    data_started = False
    for line in lines:
        line = line.strip()
        if data_started:
            if line == "END":
                break
            error_deflection_data.append(float(line))
        elif line == "BEGIN":
            data_started = True

    # put into an array in Force
    error_data_array = np.array(error_deflection_data) * Spring_constant

    # Build exactly one x value per sample. (A stop bound that ignores
    # first_x can produce fewer points than samples.)
    sample_index = np.arange(len(error_data_array))
    error_position_array = first_x + sample_index * position_interval
    # NOTE(review): the samples themselves are added to the x values here,
    # mirroring process_input_file — confirm this is intended for SDV data.
    error_position_array = error_position_array + error_data_array / Spring_constant

    print('x_len', len(error_position_array))
    print('error_len', len(error_data_array))

    error_position_array = error_position_array * 1e9 / (blunt_radius / 2)
    error_data_array = error_data_array * 1e9 / (blunt_radius / 2 * ylg)

    return error_data_array, error_position_array

def process_input_file(input_file):
    """Read a deflection file and return scaled force and position arrays.

    The file must contain a header line starting with ``X SetScale/P`` (the
    last such line is used; it supplies the first x value and the x spacing)
    and the samples, one per line, between a ``BEGIN`` line and an ``END``
    line.

    Returns
    -------
    tuple of numpy.ndarray
        ``(deflection_data_array, position_array)`` where, with ``d`` the raw
        samples and ``x_i = first_x + i * position_interval``:

        * ``deflection_data_array = d * Spring_constant * 1e9 / (blunt_radius / 2 * ylg)``
        * ``position_array = (x_i + d) * 1e9 / (blunt_radius / 2)``

    Raises
    ------
    ValueError
        If the file has no ``X SetScale/P`` line.
    """
    file_name = input_file
    print('strat processing', file_name)

    # Read the file once; header and data are both extracted from `lines`.
    with open(input_file, 'r') as file:
        lines = file.readlines()

    info_line = None
    for line in lines:
        if line.startswith('X SetScale/P'):
            info_line = line
    if info_line is None:
        # Without the header there is no x-axis; fail here with a clear
        # message instead of a NameError further down.
        raise ValueError(
            f'No line starting with "X SetScale/P" found in {input_file!r}.')

    print('info_line:', info_line)
    header_fields = info_line.split(',')
    first_x = float(header_fields[0].split(' ')[3])
    position_interval = float(header_fields[1])
    print('position interval', position_interval)
    print('first_x', first_x)

    # Collect the samples listed between BEGIN and END.
    deflection_data = []
    data_started = False
    for line in lines:
        line = line.strip()
        if data_started:
            if line == "END":
                break
            deflection_data.append(float(line))
        elif line == "BEGIN":
            data_started = True

    # put into an array in Force
    deflection_data_array = np.array(deflection_data) * Spring_constant

    # Build exactly one x value per sample. (A stop bound that ignores
    # first_x can produce fewer points than samples.)
    sample_index = np.arange(len(deflection_data_array))
    position_array = first_x + sample_index * position_interval
    # Add the raw deflection to each x value.
    position_array = position_array + deflection_data_array / Spring_constant
    position_array = position_array * 1e9 / (blunt_radius / 2)
    deflection_data_array = deflection_data_array * 1e9 / (blunt_radius / 2 * ylg)

    print('x_len', len(position_array))
    print('y_len', len(deflection_data_array))

    return deflection_data_array, position_array


# Switch to the folder holding the AFM measurement files.
folder_path = '/AFM_simulation_line_tension (Fig_3)/AFM'
os.chdir(folder_path)

# Averaged deflection curve, drawn as faint blue dots on the left panel.
deflection_data, position_array = process_input_file('B2T6_100_avg_Defl.txt')
show_from = 32  # index of the first sample that is drawn
shown = slice(show_from, None)
ax_left.scatter(
    position_array[shown],
    deflection_data[shown],
    s=3.0,
    marker="o",
    alpha=0.1,
    facecolor='blue',
    edgecolor='blue',
    zorder=5,
)

# Standard-deviation curve that belongs to the averaged data.
error_data, error_position = process_error('B2T6_100_avgSDV_Defl.txt')

# Band of +/- one error value around the averaged curve.
upper_bound = deflection_data + error_data
lower_bound = deflection_data - error_data
ax_left.fill_between(
    position_array[shown],
    lower_bound[shown],
    upper_bound[shown],
    color=error_color,
    alpha=0.3,
    label='Experiment±SDV',
    zorder=6,
)

# Switch to the folder holding the simulation output.
folder_path = '/AFM_simulation_line_tension (Fig_3)/Simulation'
os.chdir(folder_path)

SMALL_SIZE = 15
MEDIUM_SIZE = BIGGER_SIZE = 20
n = 6
fitted_slope, fitted_intercept = [], []

# Curves collected by see_force_fill for the shaded simulation band.
To_fill_x_array, To_fill_y_array = [], []

def see_force_plot(input_file, k, shift_by_final, ax):
    """Fit one simulated energy curve and draw minus its first derivative in red.

    The file is read as two tab-separated columns (position, energy);
    0.0075 is added to every position. The energy is shifted, a smoothing
    spline is fitted to the samples before the cut index listed for this file
    in ``sp_fit_source.txt`` (read from the current working directory), and
    minus the first derivative of that spline is drawn on ``ax``. Both
    plotted coordinates are divided by ``simulation_blunt_length``, which is
    derived from the fourth-from-last ``_``-separated token of the file name.

    Parameters
    ----------
    input_file
        Path of the data file.
    k
        Unused; kept for signature compatibility.
    shift_by_final
        The string ``'TRUE'`` subtracts the last energy sample; any other
        value subtracts the mean of the final quarter of the samples.
    ax
        Matplotlib axes to draw on.

    Side effects
    ------------
    Appends to the module-level lists ``cone_angle``, ``quadratic_constants``
    and ``adhesion_force``.

    Raises
    ------
    ValueError
        If ``sp_fit_source.txt`` has no two-column row for ``input_file``.
    """
    curve_full_name = str(input_file)
    curve_name = str(curve_full_name.split('_')[-4])
    cone_angle.append(float(curve_name))
    print('strat processing', curve_full_name)

    # Read both columns in one pass; blank lines (e.g. a trailing newline)
    # are skipped instead of raising on the missing second column.
    with open(input_file, "r") as data_file:
        rows = [line.split('\t') for line in data_file if line.strip()]
    position_corrected = [float(row[0].strip()) for row in rows]
    energy_corrected = [float(row[1].strip()) for row in rows]

    print(position_corrected)
    print(energy_corrected)

    print('position', len(position_corrected))
    print('energy', len(energy_corrected))

    array_energy = np.array(energy_corrected)
    if shift_by_final == 'TRUE':
        shifted_energy = array_energy - array_energy[-1]
    else:
        shifted_energy = array_energy - np.average(array_energy[-int(len(array_energy) / 4)::])
    array_position = np.array(position_corrected) + 0.0075

    # Look up the spline cut index recorded for this file.
    target_file_name = curve_full_name
    sp_fit_cut = None
    with open("sp_fit_source.txt", "r") as file_source:
        for line in file_source:
            parts = line.split()
            if len(parts) == 2 and parts[0] == target_file_name:
                sp_fit_cut = int(parts[1])
                print(f"The number associated with '{target_file_name}' is: {sp_fit_cut}")
                print('the type is', type(sp_fit_cut))
                break
    if sp_fit_cut is None:
        # Fail here with a clear message rather than later, when a missing
        # or non-integer cut index is used as a slice bound.
        raise ValueError(f"No entry found for '{target_file_name}'")

    # Configure the axes that were passed in, not pyplot's current axes.
    ax.locator_params(axis='x', nbins=4)

    print('hi')
    s = 1
    # NOTE(review): the fit uses samples [0, sp_fit_cut) while xnew below
    # extends towards array_position[sp_fit_cut], so the last stretch is
    # evaluated slightly outside the fitted range — confirm this is intended.
    tck = interpolate.splrep(array_position[:sp_fit_cut], shifted_energy[:sp_fit_cut], s=s)
    spline = BSpline(*tck)

    # Store the third entry of the spline's coefficient array.
    # NOTE(review): tck[1] holds B-spline coefficients, which are not
    # polynomial coefficients — confirm this is the wanted "quadratic constant".
    coeffs = tck[1]
    quadratic_constants.append(coeffs[2])

    show_from_sim = 23  # index of the first sample that is drawn
    xnew = np.arange(array_position[show_from_sim], array_position[sp_fit_cut], 0.0001)

    simulation_blunt_length = 0.0075*np.tan(np.deg2rad(float(curve_name)/2))
    print('simulation_blunt_length', simulation_blunt_length)

    # Second positional argument of a BSpline call is the derivative order.
    negative_slope = -spline(xnew, 1)
    ax.plot(xnew/simulation_blunt_length, negative_slope/simulation_blunt_length, '-',
            color='r', linewidth=2, alpha=0.8, zorder=5)
    print('shade area in x', xnew[0]/simulation_blunt_length, xnew[-1]/simulation_blunt_length)
    adhesion_force.append(-negative_slope[0])

    # Drop from the end of the fitted curve to zero and continue flat at zero.
    x_flat = np.array([xnew[-1], array_position[sp_fit_cut], array_position[-1], 1])
    y_flat = np.array([-float(spline(xnew[-1], 1)), 0, 0, 0])
    ax.plot(x_flat/simulation_blunt_length, y_flat/simulation_blunt_length, '-',
            color='r', zorder=5, linewidth=2, alpha=0.8)


def see_force_fill(input_file, k, shift_by_final, ax):
    """Fit one simulated energy curve and store its scaled derivative for shading.

    The file is read as two tab-separated columns (position, energy);
    0.0075 is added to every position. The energy is shifted and a smoothing
    spline is fitted to the samples before the cut index listed for this file
    in ``sp_fit_source.txt`` (read from the current working directory).
    Nothing is drawn: the x values and minus the first derivative, both
    divided by ``simulation_blunt_length``, are appended to
    ``To_fill_x_array`` and ``To_fill_y_array``.

    Parameters
    ----------
    input_file
        Path of the data file.
    k
        Unused; kept for signature compatibility.
    shift_by_final
        The string ``'TRUE'`` subtracts the last energy sample; any other
        value subtracts the mean of the final quarter of the samples.
    ax
        Unused; kept for signature compatibility.

    Side effects
    ------------
    Appends to the module-level lists ``cone_angle``, ``To_fill_x_array``,
    ``To_fill_y_array`` and ``adhesion_force``.

    Raises
    ------
    ValueError
        If ``sp_fit_source.txt`` has no two-column row for ``input_file``.
    """
    curve_full_name = str(input_file)
    curve_name = str(curve_full_name.split('_')[-4])
    cone_angle.append(float(curve_name))
    print('strat processing', curve_full_name)

    # Read both columns in one pass; blank lines (e.g. a trailing newline)
    # are skipped instead of raising on the missing second column.
    with open(input_file, "r") as data_file:
        rows = [line.split('\t') for line in data_file if line.strip()]
    position_corrected = [float(row[0].strip()) for row in rows]
    energy_corrected = [float(row[1].strip()) for row in rows]

    print(position_corrected)
    print(energy_corrected)

    print('position', len(position_corrected))
    print('energy', len(energy_corrected))

    array_energy = np.array(energy_corrected)
    if shift_by_final == 'TRUE':
        shifted_energy = array_energy - array_energy[-1]
    else:
        shifted_energy = array_energy - np.average(array_energy[-int(len(array_energy) / 4)::])
    array_position = np.array(position_corrected) + 0.0075

    # Look up the spline cut index recorded for this file.
    target_file_name = curve_full_name
    sp_fit_cut = None
    with open("sp_fit_source.txt", "r") as file_source:
        for line in file_source:
            parts = line.split()
            if len(parts) == 2 and parts[0] == target_file_name:
                sp_fit_cut = int(parts[1])
                print(f"The number associated with '{target_file_name}' is: {sp_fit_cut}")
                print('the type is', type(sp_fit_cut))
                break
    if sp_fit_cut is None:
        # Fail here with a clear message rather than later, when a missing
        # or non-integer cut index is used as a slice bound.
        raise ValueError(f"No entry found for '{target_file_name}'")

    print('hi')
    s = 1
    tck = interpolate.splrep(array_position[:sp_fit_cut], shifted_energy[:sp_fit_cut], s=s)
    spline = BSpline(*tck)

    xnew = np.arange(array_position[0], array_position[sp_fit_cut], 0.0001)

    simulation_blunt_length = 0.0075*np.tan(np.deg2rad(float(curve_name)/2))
    print('simulation_blunt_length', simulation_blunt_length)

    # Second positional argument of a BSpline call is the derivative order.
    negative_slope = -spline(xnew, 1)
    To_fill_x_array.append(xnew/simulation_blunt_length)
    To_fill_y_array.append(negative_slope/simulation_blunt_length)
    adhesion_force.append(-negative_slope[0])

# Candidate files: names starting with '1' and ending with 'angle_5.txt'.
txt_files = [
    entry
    for entry in os.listdir(folder_path)
    if entry.startswith('1') and entry.endswith('angle_5.txt')
]

selected_files_fill = []
selected_files_plot = []

# Route each file by the cone angle embedded in its name:
# 22 and 38 bound the shaded band, 30 is drawn as a line.
cone_angle_pattern = re.compile(r'_cone_angle_(\d+)_')
for filename in txt_files:
    match = cone_angle_pattern.search(filename)
    if match is None:
        continue
    cone_angle = int(match.group(1))
    if cone_angle in (22, 38):
        selected_files_fill.append(filename)
    elif cone_angle == 30:
        print('find 30')
        selected_files_plot.append(filename)

def myFunc(file_name):
    """Sort key: the fourth-from-last ``_``-separated token of the name, as a float."""
    tokens = file_name.rsplit('_', 4)
    return float(tokens[-4])

# Sort both selections by the value encoded in the file names (see myFunc).
selected_files_fill.sort(key=myFunc)
selected_files_plot.sort(key=myFunc)

print('selected_files_fill', selected_files_fill)

# Result lists filled by see_force_plot / see_force_fill.
quadratic_constants = []  # Array to store quadratic constants
cone_angle = []
adhesion_force = []

# Draw the line curves first, then collect the curves that bound the band.
k = 0
for txt_file in selected_files_plot:
    see_force_plot(txt_file, k, shift_by_final='TRUE', ax=ax_left)
    k += 1

for txt_file in selected_files_fill:
    see_force_fill(txt_file, k, shift_by_final='TRUE', ax=ax_left)
    k += 1

# Adjust the spacing between figures (increase top and bottom space)
plt.subplots_adjust(top=0.95, bottom=0.05)

# Shade the area between the first and last collected curves.
fill_x_first = To_fill_x_array[0]
fill_y_first = To_fill_y_array[0]
fill_x_last = To_fill_x_array[-1]
fill_y_last = To_fill_y_array[-1]

# Interpolate both curves so they can be evaluated on one common x grid.
f_y1 = interpolate.interp1d(fill_x_first, fill_y_first, kind='linear', fill_value='extrapolate')
f_y2 = interpolate.interp1d(fill_x_last, fill_y_last, kind='linear', fill_value='extrapolate')

# The shaded x-range is fixed by hand; a range derived from the data would
# only be overwritten by this assignment, so it is not computed.
common_x = np.linspace(-1.2440169358562927, 4.976067743425159, 100)

ax_left.fill_between(common_x, f_y1(common_x), f_y2(common_x), color='coral', alpha=0.15, edgecolor='none',
                     label='Simulation ' + r'$\alpha$' + ' = 30±8' + u'\N{DEGREE SIGN}', zorder=1, linewidth=2)
ax_left.set_xlim(-1, 8.1)
ax_left.set_ylim(-41.2, 6.5)
# Text keyword arguments to set_xticks/set_yticks are only valid together
# with explicit labels; the tick-label size is set through tick_params.
ax_left.set_xticks([0, 2, 4, 6, 8])
ax_left.set_yticks([-40, -30, -20, -10, 0])
ax_left.tick_params(axis='both', labelsize=10)
# Raw strings: '\d', '\,' and '\g' are not valid escape sequences.
ax_left.set_xlabel(r'$\dfrac{Position}{R}$', fontsize=10)
ax_left.set_ylabel(r'$\dfrac{F}{R\, \times \, \gamma_{LG}}$', fontsize=10)
ax_right.axvspan(-1, 0, facecolor='gray', alpha=0.15, edgecolor='none')
plt.tight_layout()
save_name = 'exp_sim_lt.svg'
save_path = '/Users/thornbird/Library/CloudStorage/OneDrive-DurhamUniversity/PhD/Tip/Figure paper/'

# plt.savefig(save_path+save_name, format='svg', bbox_inches='tight',dpi=1200)
plt.show()



