from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns

# Render all figure text in the Times New Roman font family.
matplotlib.rcParams['font.family'] = 'Times New Roman'
# Raw results: one entry per model, 10 entries per list.
# NOTE(review): the metric lists are presumably index-aligned with `models`;
# `models` itself is not referenced by the plotting code in this file.
models = ['ResNet32', 'ResNet44', 'ResNet56', 'MBNv2-x0.5', 'MBNv2-x0.75',
          'MBNv2-x1.0', 'MBNv2-x1.4', 'RepVGG A0', 'RepVGG A1', 'RepVGG A2']
# Accuracy values in percent (plotted with the y label 'Accuracy (%)').
baseline_acc = [70.16, 71.63, 72.63, 70.88, 73.61, 74.20, 75.98, 75.22, 76.12, 77.18]
calibrated_acc = [70.51, 72.63, 73.47, 71.80, 74.92, 75.44, 76.34, 76.22, 76.84, 77.45]
# TED metric values (plotted with the y label 'TED').
# NOTE(review): the TED/MCS/tree-kernel metric definitions are not given in
# this file — confirm their meaning and units with the producing code.
baseline_ted = [8.86, 8.81, 8.83, 8.87, 8.76, 8.73, 8.65, 8.74, 8.72, 8.70]
calibrated_ted = [8.84, 8.77, 8.71, 8.76, 8.63, 8.61, 8.62, 8.66, 8.58, 8.59]
# MCS and tree-kernel values are defined here but not plotted by the code below.
baseline_mcs = [25.39, 25.97, 26.47, 25.50, 26.71, 27.09, 27.41, 26.84, 27.43, 27.99]
calibrated_mcs = [25.76, 26.28, 27.15, 26.36, 27.57, 27.83, 27.69, 27.54, 28.43, 28.17]
baseline_tree_kernel = [48.58, 49.79, 50.17, 49.08, 51.21, 51.47, 52.91, 52.01, 52.81, 53.04]
calibrated_tree_kernel = [49.13, 50.73, 51.43, 50.34, 52.49, 52.82, 53.24, 52.89, 53.81, 53.71]

# Category labels in x-axis order. The mean annotations are placed at x
# positions 0, 1, ... so this order must match the order passed to seaborn.
METHODS = ('Baseline', 'Calibrated')


def plot_violin_comparison(baseline, calibrated, metric, ylabel, out_path,
                           colors=('#FFC857', '#3CAEA3')):
    """Draw a Baseline-vs-Calibrated violin plot for one metric and save it.

    Each method gets a violin of its per-model values, with the individual
    values overlaid as black points and a "Mean: x.xx" annotation placed
    next to the violin at the height of its mean.

    Args:
        baseline: Per-model values for the baseline method.
        calibrated: Per-model values for the calibrated method. May differ
            in length from ``baseline``.
        metric: Column name given to the values in the plotted DataFrame.
        ylabel: Text for the y-axis label.
        out_path: Destination file for ``Figure.savefig``. A relative path is
            resolved against the current working directory; missing parent
            directories are created.
        colors: Fill colors for the Baseline and Calibrated violins.

    Raises:
        ValueError: If either input sequence is empty.
    """
    if len(baseline) == 0 or len(calibrated) == 0:
        raise ValueError('baseline and calibrated must both be non-empty')

    # Long-form table with one row per value. Group sizes come from the
    # inputs themselves (not from some other metric's lists).
    df = pd.DataFrame({
        metric: np.concatenate([baseline, calibrated]),
        'Method': np.repeat(METHODS, [len(baseline), len(calibrated)]),
    })

    fig, ax = plt.subplots(figsize=(5, 5))
    # NOTE(review): seaborn >= 0.13 deprecates `palette` without `hue`;
    # kept as-is for compatibility with older seaborn releases.
    sns.violinplot(data=df, x='Method', y=metric, order=list(METHODS),
                   palette=list(colors), ax=ax)
    sns.stripplot(data=df, x='Method', y=metric, order=list(METHODS),
                  color='black', size=4, alpha=0.7, ax=ax)

    ax.set_ylabel(ylabel, fontsize=22)
    # Use mathtext for the y-axis offset/exponent text if matplotlib adds
    # one; this does not by itself force scientific notation.
    ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax.yaxis.offsetText.set_fontsize(10)
    ax.tick_params(axis='y', labelsize=10)
    # Enlarge the category labels seaborn already placed, rather than
    # overwriting them with a hard-coded list that could fall out of order.
    ax.tick_params(axis='x', labelsize=22)

    sns.despine(ax=ax, trim=True)

    # Label each violin with its mean, offset 40pt right and 15pt down.
    for pos, values in enumerate((baseline, calibrated)):
        mean_value = np.mean(values)
        ax.annotate(f'Mean: {mean_value:.2f}', xy=(pos, mean_value),
                    xytext=(40, -15), textcoords='offset points',
                    ha='center', fontsize=14, color='black')

    fig.tight_layout()
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


plot_violin_comparison(baseline_acc, calibrated_acc, metric='Accuracy',
                       ylabel='Accuracy (%)',
                       out_path='tools/plot/cifar100_acc_calibrated.pdf')
plot_violin_comparison(baseline_ted, calibrated_ted, metric='TED',
                       ylabel='TED',
                       out_path='tools/plot/cifar100_ted_calibrated.pdf')


