import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from matplotlib.cm import get_cmap
from matplotlib import gridspec
from scipy.interpolate import interp1d

# --- Cantilever calibration -------------------------------------------------
# Sensitivity is not referenced by the code visible in this file
# (original note: "already input").
Sensitivity = 80  # nm/V
Spring_constant = 1.97  # nN/nm

# --- Shared line / marker styling -------------------------------------------
set_line_width = 1       # linewidth passed to every curve plot
set_markersize = 0.7     # markersize passed to every curve plot
set_hvline_width = 0.5   # linewidth of the dashed horizontal reference lines

# --- Curve colours ----------------------------------------------------------
# 8-bit RGB triplets scaled into the 0-1 range.
red_color = (159 / 255.0, 31 / 255.0, 49 / 255.0)
blue_color = (40 / 255.0, 133 / 255.0, 169 / 255.0)
green_color = (27 / 255.0, 124 / 255.0, 61 / 255.0)

# --- Font sizes (points) for axis labels, tick labels and legends -----------
font_size_axis = font_size_tick = font_size_legend = 10


# Create a colormap
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the supported call and returns the same colormap object.
# NOTE(review): the `from matplotlib.cm import get_cmap` import at the top of
# the file still has to be dropped by hand for Matplotlib >= 3.9.
# `cmap` is not referenced by the code visible in this file.
cmap = plt.get_cmap("viridis")  # You can choose any colormap you prefer

def extract_number_from_curve_name(curve_name):
    """Return the first run of decimal digits in *curve_name* as an ``int``.

    Names that contain no digit at all yield ``float('inf')``, which compares
    greater than every number this function can extract.
    """
    found = re.search(r'(\d+)', curve_name)
    if found is None:
        # Use a large number for names without numbers
        return float('inf')
    return int(found.group(1))

def _read_wave_file(input_file):
    """Parse one exported wave text file into a position axis and its samples.

    The file must contain one numeric value per line between a ``BEGIN`` line
    and an ``END`` line, and a line starting with ``X SetScale/P`` whose second
    comma-separated field is the spacing between consecutive samples.
    (The layout looks like an Igor Pro text-wave export -- TODO confirm.)

    Args:
        input_file: Path of the text file to read.

    Returns:
        ``(position, values)``: two float arrays of equal length, in the
        file's own units.  ``position[i]`` is ``i * interval`` and ``values``
        holds the numbers found between ``BEGIN`` and ``END``.

    Raises:
        ValueError: If the scale line is missing or cannot be parsed, if a
            data line is not a number, or if the file holds no samples.
    """
    print('strat processing', input_file)

    info_line = None
    samples = []
    in_data = False
    data_done = False

    # Single pass over the file: the scale line may sit before or after the
    # data section, so keep scanning for it once END has been seen.
    with open(input_file, 'r') as file:
        for raw_line in file:
            # If several scale lines exist, the last one wins.
            if raw_line.startswith('X SetScale/P'):
                info_line = raw_line
            if data_done:
                continue
            text = raw_line.strip()
            if in_data:
                if text == "END":
                    data_done = True  # only the first BEGIN..END section is read
                else:
                    samples.append(float(text))
            elif text == "BEGIN":
                in_data = True

    if info_line is None:
        raise ValueError(
            'No line starting with "X SetScale/P" found in the input file: '
            '{!r}'.format(input_file))
    print('info_line:', info_line)

    try:
        position_interval = float(info_line.split(',')[1])
    except (IndexError, ValueError) as err:
        raise ValueError(
            'Cannot read the position interval from {!r} in {!r}'.format(
                info_line.strip(), input_file)) from err
    print('position interval', position_interval)

    if not samples:
        raise ValueError(
            'No data found between BEGIN and END in {!r}'.format(input_file))

    values = np.array(samples)
    print()

    # Integer index times spacing gives exactly one position per sample and
    # avoids the length ambiguity of np.arange with a floating-point step.
    position = np.arange(len(values)) * position_interval
    return position, values


def _plot_curve(x, y, name, color):
    """Draw one labelled curve on the current axes with the shared style."""
    plt.plot(x, y, 'o-', label=name, markersize=set_markersize,
             linewidth=set_line_width, alpha=0.9, color=color)


def _load_deflection_curve(input_file, y_aligh_start_index_ratio):
    """Load a deflection file and return its corrected, zeroed plot arrays.

    The x values are ``position + deflection / Spring_constant`` (the bending
    correction of the original code).  The y values are the deflection minus
    the mean of its tail, where the tail starts at the fraction
    ``y_aligh_start_index_ratio`` of the samples.  Both arrays are scaled by
    1e9.

    Returns:
        ``(x_plot, y_plot)`` as float arrays of equal length.
    """
    position, deflection = _read_wave_file(input_file)

    # if we correct the bending
    position = position + deflection / Spring_constant

    tail_start = int(len(deflection) * y_aligh_start_index_ratio)
    y_shift = np.average(deflection[tail_start:] * 1e9)

    return position * 1e9, deflection * 1e9 - y_shift


def process_input_file_amplitude(input_file, name, color):
    """Plot the amplitude curve stored in *input_file* on the current axes.

    Position and amplitude are both multiplied by 1e9 before plotting
    (presumably metres to nanometres -- TODO confirm the file units).  The
    scaled position array is also printed.

    Args:
        input_file: Path of the amplitude text file.
        name: Legend label of the curve.
        color: Matplotlib colour of the curve.

    Raises:
        ValueError: If the file has no usable scale line or no data.
    """
    position, amplitude = _read_wave_file(input_file)

    # plot together
    _plot_curve(position * 1e9, amplitude * 1e9, name, color)
    print(position * 1e9)


def process_input_file_deflection(input_file, name, color, y_aligh_start_index_ratio):
    """Plot the bending-corrected, baseline-shifted deflection curve.

    Args:
        input_file: Path of the deflection text file.
        name: Legend label of the curve.
        color: Matplotlib colour of the curve.
        y_aligh_start_index_ratio: Fraction (0-1) of the samples after which
            the data are averaged to obtain the baseline that is subtracted.

    Raises:
        ValueError: If the file has no usable scale line or no data.
    """
    x_plot, y_plot = _load_deflection_curve(input_file, y_aligh_start_index_ratio)
    _plot_curve(x_plot, y_plot, name, color)


def process_input_file_subtract(input_file1, input_file2, name, color, y_aligh_start_index_ratio):
    """Plot the difference of two deflection curves, scaled by the spring constant.

    Both files are loaded, bending-corrected and baseline-shifted
    independently.  The two curves are linearly interpolated onto 2000 evenly
    spaced x values spanning their common x range, and
    ``-(curve1 - curve2) * Spring_constant`` is plotted on the current axes.

    Args:
        input_file1: Path of the first deflection file (minuend).
        input_file2: Path of the second deflection file (subtrahend).
        name: Legend label of the resulting curve.
        color: Matplotlib colour of the resulting curve.
        y_aligh_start_index_ratio: Fraction (0-1) of the samples after which
            each curve is averaged to obtain its own baseline.

    Raises:
        ValueError: If either file has no usable scale line or no data, or if
            the x ranges of the two curves do not overlap.
    """
    x1_plot, y1_plot = _load_deflection_curve(input_file1, y_aligh_start_index_ratio)
    x2_plot, y2_plot = _load_deflection_curve(input_file2, y_aligh_start_index_ratio)

    # Find the overlap area between x1 and x2
    x_min = max(x1_plot.min(), x2_plot.min())
    x_max = min(x1_plot.max(), x2_plot.max())
    if x_min > x_max:
        raise ValueError(
            'The x ranges of {!r} and {!r} do not overlap'.format(
                input_file1, input_file2))

    # Generate x3 within the overlap area
    x3 = np.linspace(x_min, x_max, 2000)

    # Interpolate both curves onto x3; interp1d sorts its input by default,
    # so the bending-corrected x values need not be monotonic.
    f1 = interp1d(x1_plot, y1_plot)
    f2 = interp1d(x2_plot, y2_plot)
    y3 = f1(x3) - f2(x3)

    _plot_curve(x3, -y3 * Spring_constant, name, color)



# Specify the folder containing the .txt files
# NOTE(review): this is an absolute path from the filesystem root and the
# working directory is changed to it; adjust per machine.
folder_path = '/B2T6_AFM_exp(Fig_2)'
os.chdir(folder_path)


# Create a single figure with GridSpec: three stacked panels of equal height
fig = plt.figure(figsize=(3.5, 6))
gs = gridspec.GridSpec(3, 1, height_ratios=[1,1,1], hspace=0.08)

# Bottom panel: subtraction of the two deflection curves
ax3 = plt.subplot(gs[2])
process_input_file_subtract('B2T6_100_0006_Defl_e.txt', 'B2T6_100_0006_Defl_r.txt', name='Result', color='g',y_aligh_start_index_ratio=0.7)
plt.legend(fontsize=font_size_legend, frameon=False)
plt.ylabel("F (nN)", fontsize = font_size_axis)
# NOTE(review): the x data plotted here are bending-corrected positions, yet
# the label reads "Deflection" -- confirm the intended axis title.
plt.xlabel("Deflection (nm)", fontsize = font_size_axis)
plt.yticks(fontsize=font_size_tick)
plt.axhline(y=0, color='black', linestyle='--', zorder=3,linewidth=set_hvline_width)
plt.tick_params(axis='y', which='both', right=False)  # Remove y ticks on the right
plt.xlim(-30,400)
# Label only every other x tick
ax3.set_xticks([0, 100, 200, 300, 400])
ax3.set_xticklabels([0, '', 200, '', 400], fontsize=font_size_tick)
plt.yticks([-16, -8, 0])


ax3.set_ylim(-19,2)

# Middle panel: extension / retraction deflection curves, x axis shared with ax3
ax2 = plt.subplot(gs[1],sharex=ax3)
process_input_file_deflection(input_file ='B2T6_100_0006_Defl_e.txt', name ='Extension', color = red_color, y_aligh_start_index_ratio = 0.7)
process_input_file_deflection(input_file = 'B2T6_100_0006_Defl_r.txt',name = 'Retraction', color = blue_color,y_aligh_start_index_ratio = 0.7)

plt.xticks(fontsize=font_size_tick)
plt.yticks(fontsize=font_size_tick)
plt.legend(fontsize=font_size_legend, frameon=False)
plt.axhline(y=0, color='black', linestyle='--', zorder=3,linewidth=set_hvline_width)
plt.tick_params(axis='y', which='both', right=False)  # Remove y ticks on the right
plt.yticks(np.arange(-8, 4, 4))
plt.ylabel("Deflection (nm)", fontsize = font_size_axis)
plt.setp(ax2.get_xticklabels(), visible=False)  # x tick labels only on the bottom panel
ax2.set_ylim(-11,2.5)


# Top panel: amplitude curves, x axis shared with ax3
ax1 = plt.subplot(gs[0],sharex=ax3)
process_input_file_amplitude(input_file ='B2T6_100_0006_Amp_e.txt', name ='Extension', color =red_color)
process_input_file_amplitude(input_file ='B2T6_100_0006_Amp_r.txt', name ='Retraction', color =blue_color)
plt.ylabel("Amplitude (nm)", fontsize = font_size_axis)
plt.yticks(fontsize=font_size_tick)
plt.axhline(y=1.2, color='black', linestyle='--', zorder=3, linewidth= set_hvline_width)
plt.locator_params(axis='y', nbins=6)
plt.locator_params(axis='x', nbins=4)
plt.yticks(np.arange(0.6, 1.21, 0.6))
plt.setp(ax1.get_xticklabels(), visible=False)  # x tick labels only on the bottom panel
ax1.set_ylim(0.4,1.4)


# Adjust vertical space between subplots
plt.tight_layout()

# Save the figure with a full path (saving is currently disabled)
save_path = '/Users/thornbird/Library/CloudStorage/OneDrive-DurhamUniversity/PhD/Tip/Figure paper/poster_exp_method.svg'
# plt.savefig(save_path, format='svg', bbox_inches='tight',dpi=1200)

# Re-apply the x tick positions and labels.  The current axes at this point is
# the top panel (ax1, the last subplot created), so it is addressed by name
# instead of rebinding `ax3` to plt.gca().  The x axis is shared with the other
# panels via sharex (matplotlib is expected to propagate the tick settings to
# them -- verify on the rendered figure).
ax1.set_xticks([0, 100, 200, 300, 400])
ax1.set_xticklabels([0, '', 200, '', 400], fontsize=font_size_tick)

plt.show()
