import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from scipy.optimize import curve_fit

def mse_model(data, A, B, alpha):
    """Evaluate the power-law scaling curve ``A + B / data**alpha``.

    Parameters
    ----------
    data : scalar or array-like
        Independent variable (here: percentage of training data used).
    A : float
        Asymptotic offset the curve approaches as *data* grows (for alpha > 0).
    B : float
        Scale of the decaying term.
    alpha : float
        Power-law exponent.

    Returns
    -------
    Same kind as *data*: the model value(s).
    """
    denominator = data ** alpha
    return A + B / denominator

# One result record: "traffic_<interto> ... percentIs<pct> ... mse:<float>".
# DOTALL lets the lazy ".*?" gaps span line breaks inside a record.
_RESULT_PATTERN = re.compile(
    r"traffic_(\d+).*?percentIs(\d*).*?mse:(\d+\.\d+)", re.DOTALL
)


def parse_text(text, min_interto=390):
    """Extract ``(interto, percentage, mse)`` records from a results log.

    Every non-overlapping match of ``traffic_<int> ... percentIs<int> ...
    mse:<float>`` in *text* is converted to a tuple of
    ``(int, int, float)``. Only records whose ``interto`` value is at least
    *min_interto* are returned, in the order they appear in *text*.

    Parameters
    ----------
    text : str
        Full contents of the results log.
    min_interto : int, optional
        Inclusive lower bound on the ``interto`` field (default 390, the
        previously hard-coded cutoff).

    Returns
    -------
    list of tuple(int, int, float)

    Raises
    ------
    ValueError
        If a matched ``percentIs`` field has no digits (``int('')``).
    """
    # Convert every match first (so malformed records always raise),
    # then apply the threshold filter.
    records = [
        (int(interto), int(percent), float(mse))
        for interto, percent, mse in _RESULT_PATTERN.findall(text)
    ]
    return [record for record in records if record[0] >= min_interto]

def read_data(file_path):
    """Load a results log and return its parsed records as a DataFrame.

    The whole file at *file_path* is read as text and handed to
    ``parse_text``; each resulting ``(interto, percentage, mse)`` tuple
    becomes one row with columns ``Interto``, ``Percentage`` and ``MSE``.
    """
    column_names = ['Interto', 'Percentage', 'MSE']
    with open(file_path, 'r') as handle:
        raw_log = handle.read()
    return pd.DataFrame(parse_text(raw_log), columns=column_names)

def _fit_power_law(percents, mean_mse):
    """Fit ``mse_model`` to mean MSE as a function of data percentage.

    Parameters
    ----------
    percents : array-like
        Distinct data percentages (x values).
    mean_mse : array-like
        Mean MSE observed at each percentage (y values), same length.

    Returns
    -------
    tuple or None
        ``(popt, std_alpha, r_squared)``:
        - ``popt`` is ``[A, B, alpha]``.
        - ``std_alpha`` is the one-sigma uncertainty of ``alpha``, taken from
          the covariance diagonal.
        - ``r_squared`` is the coefficient of determination on the mean
          points. It is NaN when the means have zero variance.

        Returns ``None`` when no fit is possible. That happens with fewer
        points than the three free parameters, or when ``curve_fit`` does not
        converge within ``maxfev``.
    """
    x = np.asarray(percents, dtype=float)
    y = np.asarray(mean_mse, dtype=float)
    # curve_fit needs at least as many points as parameters (A, B, alpha).
    if x.size < 3:
        print(f"Skipping fit: only {x.size} distinct percentage(s)")
        return None
    try:
        # sigma=y makes curve_fit minimise relative (not absolute) residuals.
        popt, pcov = curve_fit(mse_model, x, y, maxfev=10000,
                               p0=[0.2, 2.0, 1.0], sigma=y)
    except RuntimeError as err:
        print(f"Skipping fit: {err}")
        return None
    std_alpha = np.sqrt(np.diag(pcov))[2]
    predicted = mse_model(x, *popt)
    ssr = np.sum((y - predicted) ** 2)
    sst = np.sum((y - y.mean()) ** 2)
    # Guard against identical means, where SST == 0 and R^2 is undefined.
    r_squared = 1 - ssr / sst if sst > 0 else float("nan")
    return popt, std_alpha, r_squared


def plot_data(df, output_path="newresult_iTF_traffic_horizonXDatascaling___byData.png"):
    """Plot mean MSE vs. data percentage per ``Interto`` value, with power-law fits.

    For every distinct ``Interto`` value in *df*:
    - The MSE is averaged per ``Percentage`` and drawn as markers.
    - ``mse_model`` is fitted to those means and drawn as a smooth curve.
      Its legend entry reports the fitted exponent, its standard error, and
      R^2.
    - Groups that cannot be fitted are plotted as markers only.

    Colours come from a log-normalised viridis map over the ``Interto``
    range. The figure is saved to *output_path* and then closed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``'Interto'``, ``'Percentage'`` and ``'MSE'`` columns
        (as built by ``read_data``). ``Interto`` values should be positive
        because the colour normalisation is logarithmic.
    output_path : str, optional
        Destination image file. Defaults to the original file name.
    """
    norm = mcolors.LogNorm(vmin=df['Interto'].min(), vmax=df['Interto'].max())
    scalar_map = cm.ScalarMappable(norm=norm, cmap=cm.viridis)

    fig = plt.figure(figsize=(12, 8))
    for interto, group in df.groupby('Interto'):
        print(group)

        # groupby sorts its keys, so the means are already ordered by
        # Percentage; the group frame itself is never mutated.
        mean_mse = group.groupby('Percentage')['MSE'].mean()
        percents = mean_mse.index
        color = scalar_map.to_rgba(interto)

        # NOTE(review): yerr=0.0 draws markers without error bars; pass the
        # per-percentage std of MSE here if error bars are wanted.
        plt.errorbar(percents, mean_mse, yerr=0.0, label=f'traffic: {interto} to 192',
                     fmt='o', capsize=5, capthick=2, color=color)

        fit = _fit_power_law(percents, mean_mse)
        if fit is None:
            continue
        popt, std_alpha, r_squared = fit
        print(popt)
        alpha = popt[2]

        # Smooth curve spanning the observed percentage range.
        smooth_data = np.linspace(percents.min(), percents.max(), 500)
        smooth_mse = mse_model(smooth_data, *popt)
        plt.plot(smooth_data, smooth_mse, color=color,
                 label=f'Fit: α={alpha:.3f}±{std_alpha:.3f},R^2={r_squared:.3f}')

    plt.title('MSE by Data Percentage with Nonlinear Fit')
    plt.xlabel('Percentage of Data Used')
    plt.ylabel('MSE')
    plt.legend()
    plt.grid(True)
    plt.savefig(output_path)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
    # Parse the results log, then render and save the scaling plot.
    plot_data(read_data('newresult_iTF_traffic_HorizonXData.txt'))
