import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from matplotlib.cm import get_cmap
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import SymLogNorm
# Use the Helvetica font family for all text drawn by matplotlib.
plt.rc('font',family='Helvetica')
# Calibration factors: every processing function below multiplies the raw
# samples by Sensitivity * Spring_constant (after dividing by 100) and divides
# by Spring_constant again for the bending correction of the x axis.
Sensitivity = 76.26 # NOTE(review): units not stated -- presumably the deflection sensitivity; confirm
Spring_constant = 1.58 # NOTE(review): units not stated -- presumably the cantilever spring constant; confirm
# NOTE(review): the two names below are not read anywhere in the visible part
# of this file (ylg is simply assigned again further down) -- confirm they are still needed.
blunt_length= 1
ylg =1
# Create a colormap

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=256):
    """
    Build a new colormap from a sub-range of an existing one.

    Parameters:
    - cmap: Source colormap; it is called with an array of positions and
      must expose a ``name`` attribute
    - minval: Position in the source colormap where sampling starts
    - maxval: Position in the source colormap where sampling ends
    - n: Number of evenly spaced sample positions between minval and maxval

    Returns:
    - A LinearSegmentedColormap named "truncated_<cmap.name>" built from the
      sampled colours
    """
    truncated_name = f"truncated_{cmap.name}"
    # Evenly spaced positions inside the requested sub-range, then the
    # colours the source colormap assigns to them.
    sample_positions = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_positions)
    return LinearSegmentedColormap.from_list(truncated_name, sampled_colors)
# Base colormaps. NOTE(review): cmap_hot is not referenced anywhere else in
# the visible part of this file; it is kept so existing users of the name
# keep working.
cmap_hot = plt.get_cmap('hot')
cmap_viridis = plt.get_cmap('viridis')

# Colormap used to colour the individual curves. minval=0 / maxval=1 keeps the
# full viridis range, resampled through truncate_colormap. (An earlier
# 'hot' 0.2-0.9 truncation assigned to this same name was always overwritten
# by this line, so it has been dropped.)
cmap_truncated = truncate_colormap(cmap_viridis, minval=0, maxval=1)  # You can choose any colormap you prefer

def extract_number_from_curve_name(curve_name):
    """
    Return the first run of digits found in a curve name as an int.

    Parameters:
    - curve_name: String to search

    Returns:
    - The first group of consecutive digits converted to int, or
      float('inf') when the name contains no digit (so such names compare
      larger than any numbered name).
    """
    found = re.search(r'(\d+)', curve_name)
    if found is None:
        return float('inf')
    return int(found.group())


def process_input_file(input_file, k):
    """
    Read one curve from a text file and draw it on the current matplotlib axes.

    The file must contain the samples, one number per line, between a
    ``BEGIN`` line and an ``END`` line, plus a line starting with
    ``X SetScale/P`` whose second comma-separated field is the x spacing
    between samples. (NOTE(review): this layout looks like an Igor Pro
    text-wave export -- confirm.)

    Each raw sample is converted with ``raw / 100 * Sensitivity *
    Spring_constant``. The x axis is ``index * spacing`` plus the bending
    correction ``value / Spring_constant``. Both arrays are then multiplied
    by 1e9. The curve is shifted along x using the last point where it still
    changes steeply, and plotted with the colour
    ``cmap_truncated(k * (1 / 21))``. If no such point exists nothing is
    plotted.

    Parameters:
    - input_file: Path of the text file to read
    - k: Running curve index, used only to pick the line colour

    Returns:
    - None; the curve is added to the current axes as a side effect

    Raises:
    - ValueError: if the file contains no line starting with "X SetScale/P"
      (or a data line cannot be parsed as a float)

    Note: a later definition in this file rebinds ``process_input_file`` to a
    one-argument version; this two-argument version only serves calls that
    run before that point.
    """
    print('strat processing', input_file)

    # Read the file once; the header scan and the data scan both use these lines.
    with open(input_file, 'r') as file:
        lines = file.readlines()

    # Keep the LAST line that starts with 'X SetScale/P'; it carries the x spacing.
    info_line = None
    for line in lines:
        if line.startswith('X SetScale/P'):
            info_line = line
    if info_line is None:
        # Without the spacing the x axis cannot be built, so fail with a clear
        # message instead of an unbound-variable error further down.
        raise ValueError(
            f'No line starting with "X SetScale/P" found in the input file: {input_file}')
    print('info_line:', info_line)
    position_interval = float(info_line.split(',')[1])
    print('position interval', position_interval)

    # Collect the samples listed between the BEGIN and END markers.
    deflection_data = []
    data_started = False
    for line in lines:
        line = line.strip()
        if data_started:
            if line == "END":
                break
            deflection_data.append(float(line))
        elif line == "BEGIN":
            data_started = True

    # Scale the raw samples with the module-level calibration factors.
    deflection_data_array = np.array(deflection_data) / 100 * Sensitivity * Spring_constant

    # One x value per sample, starting at 0. Using integer indices times the
    # spacing guarantees exactly len(data) points (a float-step arange does not).
    position_array = np.arange(len(deflection_data_array)) * position_interval

    # Bending correction of the x axis, then scale both axes by 1e9.
    position_array = position_array + deflection_data_array / Spring_constant
    position_array = position_array * 1e9
    deflection_data_array = deflection_data_array * 1e9

    def find_plateau_point_dynamic(y, min_slope_threshold=0.01, window_size=3):
        """
        Return ``i + window_size`` for the largest ``i`` where
        ``|y[i + window_size] - y[i]| >= min_slope_threshold``, or None when
        no window reaches the threshold.
        """
        n_windows = max(len(y) - window_size, 0)
        slopes = y[window_size:window_size + n_windows] - y[:n_windows]

        # Walk backwards from the end of the curve to the last steep window.
        for i in reversed(range(n_windows)):
            if np.abs(slopes[i]) >= min_slope_threshold:
                return i + window_size

        return None

    plateau_point = find_plateau_point_dynamic(deflection_data_array, min_slope_threshold=3, window_size=3)

    if plateau_point is not None:
        print('find')
        # Shift the curve so the detected point sits at a fixed x position.
        # NOTE(review): 4.21 and 40.851 are hard-coded offsets whose origin is
        # not documented in this file -- confirm.
        plt.plot(position_array+(4.21-position_array[plateau_point])+40.851, deflection_data_array,
                 # label=curve_name,
                 linewidth=0.9,
                 alpha=0.7, color=cmap_truncated(k * (1 / 21)))
    else:
        print("No plateau point found.")

# Folder holding the individual curve files. It is also made the working
# directory so the bare file names collected below can be opened directly.
folder_path = '/Users/thornbird/Library/CloudStorage/OneDrive-DurhamUniversity/PhD/Tip/Data_nanoparticle_adhesion/Repeatability_test_SI/individual_curve'
os.chdir(folder_path)

# Names of every entry in that folder whose name ends in '.txt'.
txt_files = [entry for entry in os.listdir(folder_path) if entry.endswith('.txt')]

def myFunc(file_name):
    """
    Sort key for curve files: the second-to-last underscore-separated field
    of the file name, converted to float.

    Parameters:
    - file_name: File name string containing at least one underscore

    Returns:
    - The second-to-last '_'-separated field as a float
    """
    fields = file_name.split('_')
    return float(fields[-2])

# Sort the files in ascending order of the numeric field returned by myFunc
# (the second-to-last underscore-separated part of the file name).
txt_files.sort(key=myFunc)

print('file_list', txt_files)


# Create a single figure with one axes (3.5 x 3 inches).
fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))

# Draw every curve on the shared axes; the running index k only selects the
# curve's colour inside process_input_file.
for k, txt_file in enumerate(txt_files):
    process_input_file(txt_file, k)

# Adjust the figure margins (top and bottom).
plt.subplots_adjust(top=0.95, bottom=0.05)

plt.xlabel('Position (nm)', fontsize=10)
# Raw string: '\i' is not a valid Python escape sequence, so the backslash
# must be kept literally for the mathtext italic command.
plt.ylabel(r'$\it{F}$'+' (nN)',fontsize=10)


plt.xlim(-12,100)
# plt.ylim(-34,3)
plt.xticks([0,50,100],fontsize=10)
plt.yticks([-12,-8,-4,0],fontsize=10)

# Colour bar for the curve index, normalised from 0 to 20.
# NOTE(review): the curves themselves are coloured with k * (1 / 21), so the
# bar and the line colours differ very slightly -- confirm this is intended.
sm = ScalarMappable(cmap=cmap_truncated, norm=plt.Normalize(vmin=0, vmax=20))
sm.set_array([])  # Dummy array

# The mappable is not attached to any artist, so the axes to take space from
# must be named explicitly.
cbar = fig.colorbar(sm, ax=ax, orientation='vertical')

# Integer tick labels on the colour bar.
ticks = [0, 5, 10, 15, 20]
cbar.set_ticks(ticks)
cbar.set_ticklabels([str(tick) for tick in ticks])


cbar.set_label("Curve number", fontsize=10)  # You can change the label as needed

ylg = 1

# Named plot colours: 8-bit RGB components scaled to the 0-1 range.
red_color = (177 / 255.0, 24 / 255.0, 45 / 255.0)
blue_color = (36 / 255.0, 100 / 255.0, 171 / 255.0)
orange_color = (241 / 255.0, 108 / 255.0, 35 / 255.0)
green_color = (27 / 255.0, 124 / 255.0, 61 / 255.0)
error_color = (71 / 255.0, 146 / 255.0, 196 / 255.0)
linewidth_plot = 0.6


# Module-level result lists; they start empty.
adhesion_force_array = []
bridge_length_array = []
speed_array = []



def extract_number_from_curve_name(curve_name):
    """
    Return the first run of digits in a curve name as an int.

    Names without any digit yield float('inf') so that they compare larger
    than every numbered name.

    Parameters:
    - curve_name: String to search

    Returns:
    - int built from the first group of consecutive digits, or float('inf')
    """
    digits = re.search(r'(\d+)', curve_name)
    return int(digits.group()) if digits is not None else float('inf')

def process_error(input_file):
    """
    Read an error curve from a text file and return it with its x axis.

    The file must contain the samples, one number per line, between a
    ``BEGIN`` line and an ``END`` line, plus a line starting with
    ``X SetScale/P``. In that line the fourth space-separated token of the
    part before the first comma is the first x value, and the second
    comma-separated field is the x spacing.

    Each raw sample is converted with ``raw / 100 * Spring_constant *
    Sensitivity``. The x axis is ``first_x + index * spacing`` plus the
    bending correction ``value / Spring_constant``. Both arrays are
    multiplied by 1e9 before being returned.

    Parameters:
    - input_file: Path of the text file to read

    Returns:
    - (error_data_array, error_position_array): two numpy arrays of equal
      length, one entry per sample in the file

    Raises:
    - ValueError: if the file contains no line starting with "X SetScale/P"
      (or a data line cannot be parsed as a float)
    """
    print('strat processing', input_file)

    # Read the file once; the header scan and the data scan both use these lines.
    with open(input_file, 'r') as file:
        lines = file.readlines()

    # Keep the LAST line that starts with 'X SetScale/P'; it carries the x scaling.
    info_line = None
    for line in lines:
        if line.startswith('X SetScale/P'):
            info_line = line
    if info_line is None:
        # Without the scaling the x axis cannot be built, so fail with a clear
        # message instead of an unbound-variable error further down.
        raise ValueError(
            f'No line starting with "X SetScale/P" found in the input file: {input_file}')
    print('info_line:', info_line)
    first_x = float(info_line.split(',')[0].split(' ')[3])
    position_interval = float(info_line.split(',')[1])
    print('position interval', position_interval)
    print('first_x', first_x)

    # Collect the samples listed between the BEGIN and END markers.
    error_deflection_data = []
    data_started = False
    for line in lines:
        line = line.strip()
        if data_started:
            if line == "END":
                break
            error_deflection_data.append(float(line))
        elif line == "BEGIN":
            data_started = True

    # Scale the raw samples with the module-level calibration factors.
    error_data_array = np.array(error_deflection_data) / 100 * Spring_constant * Sensitivity

    # One x value per sample, starting at first_x. Building the axis from
    # integer indices keeps it exactly as long as the data even when first_x
    # is not zero, so no sample is silently dropped.
    error_position_array = first_x + np.arange(len(error_data_array)) * position_interval
    print('x_len', len(error_position_array))
    print('error_len', len(error_data_array))

    # Bending correction of the x axis, then scale both axes by 1e9.
    error_position_array = error_position_array + error_data_array / Spring_constant
    error_position_array = error_position_array * 1e9
    error_data_array = error_data_array * 1e9

    return error_data_array, error_position_array

def process_input_file(input_file):
    """
    Read one curve from a text file and return it with its x axis.

    The file must contain the samples, one number per line, between a
    ``BEGIN`` line and an ``END`` line, plus a line starting with
    ``X SetScale/P`` whose second comma-separated field is the x spacing.

    Each raw sample is converted with ``raw / 100 * Sensitivity *
    Spring_constant``. The x axis is ``index * spacing`` plus the bending
    correction ``value / Spring_constant``. Both arrays are multiplied by
    1e9 before being returned.

    Parameters:
    - input_file: Path of the text file to read

    Returns:
    - (deflection_data_array, position_array): two numpy arrays of equal
      length, one entry per sample in the file

    Raises:
    - ValueError: if the file contains no line starting with "X SetScale/P"
      (or a data line cannot be parsed as a float)
    """
    print('strat processing', input_file)

    # Read the file once; the header scan and the data scan both use these lines.
    with open(input_file, 'r') as file:
        lines = file.readlines()

    # Keep the LAST line that starts with 'X SetScale/P'; it carries the x scaling.
    info_line = None
    for line in lines:
        if line.startswith('X SetScale/P'):
            info_line = line
    if info_line is None:
        # Without the scaling the x axis cannot be built, so fail with a clear
        # message instead of an unbound-variable error further down.
        raise ValueError(
            f'No line starting with "X SetScale/P" found in the input file: {input_file}')
    print('info_line:', info_line)
    # NOTE(review): first_x is parsed and printed but the x axis below still
    # starts at 0, unlike process_error -- confirm which one is intended.
    first_x = float(info_line.split(',')[0].split(' ')[3])
    position_interval = float(info_line.split(',')[1])
    print('position interval', position_interval)
    print('first_x', first_x)

    # Collect the samples listed between the BEGIN and END markers.
    deflection_data = []
    data_started = False
    for line in lines:
        line = line.strip()
        if data_started:
            if line == "END":
                break
            deflection_data.append(float(line))
        elif line == "BEGIN":
            data_started = True

    # Scale the raw samples with the module-level calibration factors.
    deflection_data_array = np.array(deflection_data) / 100 * Sensitivity * Spring_constant

    # One x value per sample, starting at 0. Using integer indices times the
    # spacing guarantees exactly len(data) points (a float-step arange does not).
    position_array = np.arange(len(deflection_data_array)) * position_interval

    # Bending correction of the x axis, then scale both axes by 1e9.
    position_array = position_array + deflection_data_array / Spring_constant
    position_array = position_array * 1e9
    deflection_data_array = deflection_data_array * 1e9

    # The baseline-removal / adhesion-measurement code that used to follow the
    # return statement could never run and referenced names that are not
    # defined in this file (BaselineRemoval, curve_name); it has been removed.
    return deflection_data_array, position_array


# Switch to the parent folder that holds the averaged curve files.
folder_path = '/Users/thornbird/Library/CloudStorage/OneDrive-DurhamUniversity/PhD/Tip/Data_nanoparticle_adhesion/Repeatability_test_SI'
os.chdir(folder_path)


# Averaged curve: drawn in red above the individual curves (zorder=5) and
# shifted by 8.51 along x.
deflection_data, position_array = process_input_file('section_B2T9_100_avg_Defl.txt')
show_from = 0
average_x = position_array[show_from:] + 8.51
average_y = deflection_data[show_from:]
plt.plot(average_x, average_y, linewidth=1.0, color='r', alpha=0.8, zorder=5, label='Average')

# Companion error curve (NOTE(review): presumably the standard deviation of
# the average, judging by the file name -- confirm).
error_data, error_position = process_error('section_B2T9_100_avgSDV_Defl.txt')


# Average plus / minus the error over the first 12536 samples.
# NOTE(review): these bounds are computed but never drawn in the visible
# code -- confirm whether a shaded band was intended.
upper_bound = deflection_data[:12536] + error_data[:12536]
lower_bound = deflection_data[:12536] - error_data[:12536]

# Legend in the lower-right corner (loc=4), without a frame.
plt.legend(loc=4, frameon=False)
plt.tight_layout()

# Write the figure as EPS, then show it on screen.
save_name = 'repeat measurement.eps'
save_path = '/Users/thornbird/Library/CloudStorage/OneDrive-DurhamUniversity/PhD/Tip/Figure paper/'
plt.savefig(os.path.join(save_path, save_name), format='eps', bbox_inches='tight', dpi=1200)
plt.show()