import torch
import torch.nn as nn
import numpy as np
import scipy
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
import geotorch
import geomloss
import ot
from matplotlib import collections  as mc
from matplotlib.legend_handler import HandlerTuple
import argparse
from PIL import Image
from sklearn.decomposition import PCA
import os
import wandb

from scipy.stats import ortho_group
import itertools

from tqdm import tqdm
import sys
sys.path.append("..")

from typing import Callable, Tuple, Union


from src import distributions
from src.utils import Distrib2Sampler, Config

# models
from src.models2D import FullyConnectedMLP

# langevin sampling
from src.eot import sample_langevin_batch
from src.eot_utils import computePotGrad, evaluating
from src.utils import *
from src import bar_benchmark

################# CONFIG ###############
########################################
parser = argparse.ArgumentParser()
parser.add_argument('--DIM', type=int, default=2,
                    help='2,4,8,16,64 or 128')
ARGS = parser.parse_args()

DIM = ARGS.DIM
if DIM <= 1:
    # User input: validate through argparse (exits with a usage message) rather
    # than `assert`, which is stripped under `python -O`.
    parser.error(f'--DIM must be > 1, got {DIM}')

CLIP_GRADS_NORM = False
# Regularization coefficient; the Langevin cost term is scaled by NOISE**2 / HREG.
HREG = 0.1

LANGEVIN_THRESH = None
LANGEVIN_SAMPLING_NOISE = 0.1
ENERGY_SAMPLING_ITERATIONS = 700
LANGEVIN_DECAY = 1.0
LANGEVIN_SCORE_COEFFICIENT = 1.0
LANGEVIN_COST_COEFFICIENT = 1.0

# learning parameters
MAX_STEPS = 100000
PLOT_FREQ = 300
SCORE_FREQ = 300
BATCH_SIZE = 1000
BASIC_NOISE_VAR = 1.0

# Number of input marginals and their barycenter weights (one weight per marginal).
NUM = 3
ALPHAS = np.array([0.25, 0.25, 0.5])
assert NUM == len(ALPHAS)

CASE = {
    'type': 'EigWarp',
    'sampler': 'Gaussians',  # alternatives: 'SwissRoll' (DIM == 2 only), 'Rectangles'
    'params': {'num': NUM, 'alphas': ALPHAS, 'min_eig': .5, 'max_eig': 2},
}

SEED = 0xB00BA
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG of every device, CUDA included

# Fail fast: check the hardware requirement before creating directories or a
# wandb run.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is required to run this script')
DEVICE = 'cuda'
DEVICE_IDS = list(range(torch.cuda.device_count()))

EXP_NAME = f'EgNOTbary_{DIM}_{HREG}_{NUM}'
OUTPUT_PATH = f'../checkpoints/{EXP_NAME}/'
os.makedirs(OUTPUT_PATH, exist_ok=True)

config = dict(
    CASE=CASE['sampler'],
    SCORE_FREQ=SCORE_FREQ,
    BATCH_SIZE=BATCH_SIZE
)

wandb.init(name=EXP_NAME, project='egbarycenters', entity='gunsandroses', config=config)

############### Initializing distributions ##########
#######################################################

# Build the benchmark (input marginals + ground-truth barycenter) from CASE.
# Unknown settings raise immediately instead of leaving `sampler`/`benchmark`
# undefined and failing later with a NameError.
if CASE['type'] == 'EigWarp':
    if CASE['sampler'] == 'Gaussians':
        sampler = distributions.StandardNormalSampler(dim=DIM)
    elif CASE['sampler'] == 'SwissRoll':
        assert DIM == 2
        sampler = distributions.SwissRollSampler()
    elif CASE['sampler'] == 'Rectangles':
        sampler = distributions.CubeUniformSampler(dim=DIM, normalized=True, centered=True)
    else:
        raise ValueError(f"Unknown CASE['sampler']: {CASE['sampler']!r}")

    benchmark = bar_benchmark.EigenWarpBenchmark(sampler, **CASE['params'])
else:
    raise ValueError(f"Unknown CASE['type']: {CASE['type']!r}")
    
############## PLOTTERS #########################
################################################

def plot_distributions(ax1, ax2, n_samples=512):
    """Scatter the input marginals on ``ax1`` and the reference barycenter on ``ax2``.

    Only the first two coordinates of every sample are plotted.

    Args:
        ax1: axes receiving one scatter per entry of ``benchmark.samplers``.
        ax2: axes receiving samples of ``benchmark.gauss_bar_sampler``.
        n_samples: number of points drawn from each distribution.
    """
    cols = plt.get_cmap("Dark2").colors
    for i, distr in enumerate(benchmark.samplers):
        X = distr.sample(n_samples).detach().cpu().numpy()
        ax1.scatter(
            X[:, 0], X[:, 1],
            label=f"$x_{{{i + 1}}} \\sim \\mathbb{{P}}_{{{i + 1}}}$",
            edgecolors=alpha_color((0, 0, 0)), color=alpha_color(cols[i]), linewidth=.5,
        )

    Xgt = benchmark.gauss_bar_sampler.sample(n_samples).detach().cpu().numpy()
    ax2.scatter(
        Xgt[:, 0], Xgt[:, 1],
        label=r"$x \sim \mathbb{Q}_*$",
        edgecolors='black', color=cols[NUM + 1], linewidth=.5,
    )
    ax1.legend(ncol=2, loc="upper left", prop={"size": 12})
    ax2.legend(ncol=2, loc="upper left", prop={"size": 12})

def alpha_color(color_rgb, alpha=0.5):
    """Fade an RGB colour towards white, imitating transparency on a white canvas.

    Every channel ``c`` is mapped to ``1 - (1 - c) * alpha``: ``alpha=1`` keeps
    the colour as is, ``alpha=0`` yields pure white.

    Returns:
        A NumPy array with the same shape as ``color_rgb``.
    """
    distance_from_white = 1. - np.asanyarray(color_rgb)
    return 1. - distance_from_white * alpha

def plot_bary_i(f_pot, sampler, ax, i, n_samples=512, n_maps=0, n_arrows_per_map=1):
    """Plot conditional samples y given x ~ P_i, obtained by Langevin sampling.

    Draws ``n_samples`` points from ``benchmark.samplers[i]``, samples y for each
    with ``sample_langevin_mu_f(f_pot, ...)`` and scatters both clouds (first two
    coordinates) on ``ax``. When ``n_maps > 0``, ``n_maps`` extra source points
    are each repeated ``n_arrows_per_map`` times, highlighted, and joined to
    their sampled y by line segments.

    Args:
        f_pot: potential passed to ``sample_langevin_mu_f``.
        sampler: unused; kept for call-site compatibility (the marginal is
            read from ``benchmark.samplers[i]``).
        ax: target matplotlib axes.
        i: index of the marginal.
        n_samples: number of background points.
        n_maps: number of highlighted source points (0 disables them).
        n_arrows_per_map: conditional samples drawn per highlighted point.
    """
    n_arrows = n_maps * n_arrows_per_map
    X = benchmark.samplers[i].sample(n_samples).to(DEVICE)
    if n_maps > 0:
        Xm = benchmark.samplers[i].sample(n_maps).to(DEVICE)
        # Repeat each highlighted point so several y's are drawn for it.
        Xm = torch.tile(Xm, (n_arrows_per_map, 1))
        X = torch.cat((X, Xm), dim=0)

    Y_init = init_noise_sampler.sample(n_samples + n_arrows).to(DEVICE)
    # Inference only: sample under no_grad, as the training loop and the metric
    # do (cond_score re-enables grad locally to compute the score).
    with torch.no_grad():
        Y = sample_langevin_mu_f(f_pot, X, Y_init).to(DEVICE)
    X_np = X.detach().cpu().numpy()
    Y_np = Y.detach().cpu().numpy()

    # Background points come first, highlighted ones (if any) after them.
    # Slice at n_samples rather than with [-n_arrows:], which is [0:] -- i.e.
    # every point -- when n_arrows == 0.
    X_bg, Y_bg = X_np[:n_samples], Y_np[:n_samples]
    X_hl, Y_hl = X_np[n_samples:], Y_np[n_samples:]

    cols = plt.get_cmap("Dark2").colors
    col_bary = plt.get_cmap("tab10").colors[NUM]
    p_x = ax.scatter(
        X_bg[:, 0], X_bg[:, 1],
        edgecolors=alpha_color((0, 0, 0)), color=alpha_color(cols[i]), zorder=0, linewidth=.5,
    )
    p_y = ax.scatter(
        Y_bg[:, 0], Y_bg[:, 1],
        edgecolors=(0, 0, 0), color=col_bary, zorder=0, linewidth=.5,
    )
    x_handles, y_handles = (p_x,), (p_y,)

    if n_arrows > 0:
        p_x_hl = ax.scatter(
            X_hl[:, 0], X_hl[:, 1],
            linewidth=.5, edgecolors='black', color=cols[i], zorder=2,
        )
        p_y_hl = ax.scatter(
            Y_hl[:, 0], Y_hl[:, 1],
            linewidth=.5, edgecolors='black', color=cols[NUM + 2], zorder=2,
        )
        # Zero-size heads render the quiver as plain segments from x towards y.
        ax.quiver(
            X_hl[:, 0], X_hl[:, 1],
            Y_hl[:, 0] - X_hl[:, 0], Y_hl[:, 1] - X_hl[:, 1],
            angles='xy', scale_units='xy', scale=0.95, width=.005, zorder=1, headwidth=0.0, headlength=0.0,
        )
        x_handles = (p_x_hl, p_x)
        y_handles = (p_y, p_y_hl)

    ax.legend(
        [y_handles, x_handles],
        [
            f"$y \\sim \\pi^*_{{{i + 1}}}(\\cdot \\mid x_{{{i + 1}}})$",
            f"$x_{{{i + 1}}} \\sim \\mathbb{{P}}_{{{i + 1}}}$",
        ],
        handler_map={tuple: HandlerTuple(ndivide=None)},
        loc="upper left",
        prop={"size": 12},
    )

def plot_bary(potential_fns, benchmark, arrows=True):
    """Build a one-row figure with ``NUM + 2`` panels.

    Panel 0 shows the input marginals, panel 1 the reference barycenter sample,
    and each remaining panel the conditional samples produced with one
    potential from ``potential_fns`` (paired with ``benchmark.samplers``).

    Args:
        potential_fns: one potential per marginal.
        benchmark: object exposing ``samplers``.
        arrows: if True, highlight a few source points and link them to their
            sampled y.

    Returns:
        The matplotlib figure (the caller is responsible for closing it).
    """
    N_SAMPLES = 512
    N_MAPS = 5
    N_ARROWS_PER_MAP = 3
    PANEL_SIZE = 3.75  # inches per (square) panel

    n_maps = N_MAPS if arrows else 0
    n_panels = NUM + 2

    fig, axs = plt.subplots(
        ncols=n_panels,
        figsize=(PANEL_SIZE * n_panels, PANEL_SIZE),
        sharex=True, sharey=True,
        dpi=200,
    )

    plot_distributions(axs[0], axs[1])
    # The axes are shared, so these limits propagate to every panel.
    axs[0].set_xlim(-7, 7)
    axs[0].set_ylim(-7, 7)

    for i, (f_pot, sampler, ax) in enumerate(zip(potential_fns, benchmark.samplers, axs[2:])):
        plot_bary_i(f_pot, sampler, ax, i, N_SAMPLES, n_maps, N_ARROWS_PER_MAP)

    fig.tight_layout()
    return fig

################### PCA ########################
#################################################
pca = PCA(n_components=2)


class Identity:
    """No-op projector exposing the same ``transform`` call as a fitted PCA."""

    @staticmethod
    def transform(x):
        return x


# Fit the 2D projection on the barycenter sample, preferring the exact one.
# NOTE(review): `pca` is not used by the plotting helpers in this file, which
# plot raw coordinates 0 and 1 -- confirm whether it is meant to be applied.
if benchmark.bar_sampler is not None:
    pca.fit(benchmark.bar_sampler.sample(100000).cpu().detach().numpy())
elif benchmark.gauss_bar_sampler is not None:
    pca.fit(benchmark.gauss_bar_sampler.sample(100000).cpu().detach().numpy())
else:
    pca = Identity()

# No PCA for dim=2. The fit above is intentionally still performed first so the
# random draws it presumably consumes (and thus everything seeded afterwards)
# stay the same.
if DIM == 2:
    pca = Identity()
    
####################### Potentials setup ################
###########################################################

def make_f_pot(idx, nets):
    """Return the potential ``f_idx`` assembled from the shared networks ``nets``.

    With ``g_j = nets[j]``::

        f_idx(x) = g_idx(x) - sum_{j != idx} ALPHAS[j] * g_j(x) / (NUM - 1) / ALPHAS[idx]

    When ``len(nets) == len(ALPHAS) == NUM``, the weighted sum
    ``sum_k ALPHAS[k] * f_k`` cancels identically for every x.
    """
    def f_pot(x):
        terms = []
        for j, (net, alpha) in enumerate(zip(nets, ALPHAS)):
            out = net(x)
            if j != idx:
                out = -(alpha * out / (NUM - 1) / ALPHAS[idx])
            terms.append(out)
        return sum(terms, 0.0)
    return f_pot

def l2_grad_y(y, x):
    r"""Return \nabla_y c(x, y) for the quadratic cost c(x, y) = ||x - y||^2 / 2.

    The gradient is ``y - x``. It works for any operands supporting subtraction
    (tensors, arrays, scalars) and follows their broadcasting rules.
    """
    return y - x


# Cost gradient handed to cond_score by the Langevin sampler.
grad_fn = l2_grad_y

def cond_score(
        f: Callable[[torch.Tensor], torch.Tensor],
        cost_grad_y_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        y: torch.Tensor,
        x: torch.Tensor,
        ret_stats: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Drift used for Langevin sampling of y conditioned on x.

    drift = LANGEVIN_SCORE_COEFFICIENT * computePotGrad(y, f(y))
            - LANGEVIN_COST_COEFFICIENT * (LANGEVIN_SAMPLING_NOISE**2 / HREG) * cost_grad_y_fn(y, x)

    ``computePotGrad`` presumably returns grad_y f(y) -- confirm in src.eot_utils.
    Note that ``y`` is switched to ``requires_grad=True`` in place.

    Returns:
        ``drift``, or ``(drift, cost_part, score_part)`` when ``ret_stats`` is True.
    """
    with torch.enable_grad():
        # The potential gradient is needed even when called under no_grad.
        y.requires_grad_(True)
        pot_grad = computePotGrad(y, f(y))
        assert pot_grad.shape == y.shape

    cost_scale = LANGEVIN_COST_COEFFICIENT * (LANGEVIN_SAMPLING_NOISE ** 2 / HREG)
    cost_part = cost_grad_y_fn(y, x) * cost_scale
    score_part = pot_grad * LANGEVIN_SCORE_COEFFICIENT
    drift = score_part - cost_part
    if ret_stats:
        return drift, cost_part, score_part
    return drift

def sample_langevin_mu_f(
        f: Callable[[torch.Tensor], torch.Tensor],
        x: torch.Tensor,
        y_init: torch.Tensor
    ) -> torch.Tensor:
    """Run Langevin dynamics from ``y_init`` with the drift ``cond_score(f, grad_fn, ., x)``.

    Step count, decay, threshold and noise come from the module-level
    ``ENERGY_SAMPLING_ITERATIONS``, ``LANGEVIN_DECAY``, ``LANGEVIN_THRESH`` and
    ``LANGEVIN_SAMPLING_NOISE`` settings.

    Returns:
        The final samples returned by ``sample_langevin_batch``. The statistics it
        also returns (requested via ``compute_stats=True``) are discarded.
    """
    def conditional_score(y, ret_stats=False):
        # Bind the potential, cost gradient and conditioning points.
        return cond_score(f, grad_fn, y, x, ret_stats=ret_stats)

    langevin_out = sample_langevin_batch(
        conditional_score,
        y_init,
        n_steps=ENERGY_SAMPLING_ITERATIONS,
        decay=LANGEVIN_DECAY,
        thresh=LANGEVIN_THRESH,
        noise=LANGEVIN_SAMPLING_NOISE,
        data_projector=lambda z: z,
        compute_stats=True,
    )
    samples, _, _, _, _ = langevin_out
    return samples

# One scalar-output MLP per marginal; the potentials f_i are linear combinations
# of all of them (see make_f_pot), trained jointly by a single optimizer.
nets = [FullyConnectedMLP(DIM, [32, 32], 1).to(DEVICE) for _ in range(NUM)]
opt = torch.optim.Adam(
    itertools.chain.from_iterable(net.parameters() for net in nets),
    lr=1e-3,
)

f_pots = [make_f_pot(i, nets) for i in range(NUM)]

# Initial Langevin states ~ N(0, BASIC_NOISE_VAR * I). torch's Normal is
# parameterized by the standard deviation, hence the square root of the variance.
init_noise_sampler = Distrib2Sampler(torch.distributions.Normal(
    torch.zeros(DIM, device=DEVICE),
    torch.full((DIM,), BASIC_NOISE_VAR ** 0.5, device=DEVICE)))

################# Metrics  #########################
####################################################

def score_forward_maps(benchmark, f_pots, alphas, score_size=1024, hidden_size=1000):
    """Weighted L2-UVP (in percent) of the estimated barycentric projections.

    For every marginal n, ``score_size`` points x ~ ``benchmark.samplers[n]`` are
    drawn. Each is repeated ``hidden_size`` times, and Langevin sampling with
    ``f_pots[n]`` is run from one shared pool of initial states. The samples are
    averaged per x and compared with ``benchmark.gauss_bar_maps[n](x)``::

        L2_UVP_n = 100 * mean_x ||mean(y | x) - T_n(x)||^2 / benchmark.gauss_bar_sampler.var

    Args:
        benchmark: provides ``num``, ``samplers``, ``gauss_bar_maps``, ``gauss_bar_sampler``.
        f_pots: one potential per marginal.
        alphas: weights combining the per-marginal scores.
        score_size: number of evaluation points per marginal.
        hidden_size: conditional samples averaged per evaluation point.

    Returns:
        ``sum_n alphas[n] * L2_UVP_n`` as a Python float.
    """
    assert (benchmark.gauss_bar_maps is not None) and (benchmark.gauss_bar_sampler is not None)
    L2_UVP_arr = []
    # One pool of initial states, reused for every marginal.
    Y_init = init_noise_sampler.sample(hidden_size * score_size).to(DEVICE)
    for n in tqdm(range(benchmark.num)):
        X = benchmark.samplers[n].sample(score_size)
        # Rows [k * hidden_size, (k + 1) * hidden_size) all condition on X[k];
        # one vectorized call replaces a per-row repeat + concatenate loop.
        Xs = X.reshape(-1, DIM).repeat_interleave(hidden_size, dim=0).to(DEVICE)
        with torch.no_grad():
            X_push = sample_langevin_mu_f(f_pots[n], Xs, Y_init)
            X_push = X_push.view((score_size, hidden_size, DIM)).mean(dim=1)
            X_push_true = benchmark.gauss_bar_maps[n](X)
            L2_UVP_arr.append(
                100 * (((X_push - X_push_true) ** 2).sum(dim=1).mean() / benchmark.gauss_bar_sampler.var).item()
            )
    return sum(alpha * L2_UVP for alpha, L2_UVP in zip(alphas, L2_UVP_arr))

################# ALGORITHM ##########################
######################################################

for g in opt.param_groups:
    g['lr'] = 0.001

last_plot_it = -1
last_score_it = -1
best_L2_UVP = 1000

# NOTE(review): unfreeze/freeze/clear_output/fig2img are not imported explicitly
# here; presumably they come from `from src.utils import *` -- confirm.
for it in tqdm(range(MAX_STEPS)):
    # Fresh batch from every marginal plus initial Langevin states.
    Xs = [s.sample(BATCH_SIZE).to(DEVICE) for s in benchmark.samplers]
    Ys_init = [init_noise_sampler.sample(BATCH_SIZE).to(DEVICE) for _ in range(NUM)]

    # The scoring step below freezes the nets; re-enable training here.
    for net in nets:
        unfreeze(net)
    with torch.no_grad():
        Ys = [sample_langevin_mu_f(f, X, Y_init) for f, X, Y_init in zip(f_pots, Xs, Ys_init)]

    loss = sum(alpha * f(Y).mean() for alpha, f, Y in zip(ALPHAS, f_pots, Ys))
    wandb.log({'Loss': loss.item()}, step=it)
    opt.zero_grad()
    loss.backward()
    opt.step()

    if it - last_plot_it >= PLOT_FREQ:
        clear_output(wait=True)
        last_plot_it = it

        fig = plot_bary(f_pots, benchmark, arrows=False)  # already tight_layout-ed
        wandb.log({'Pca': [wandb.Image(fig2img(fig))]}, step=it)
        plt.show()
        plt.close(fig)

    if it - last_score_it >= SCORE_FREQ:
        last_score_it = it
        for net in nets:
            freeze(net)
        if benchmark.gauss_bar_sampler is not None:
            L2_UVP = score_forward_maps(benchmark, f_pots, ALPHAS, score_size=1024)
            wandb.log({'L2_UVP': L2_UVP}, step=it)

            if L2_UVP < best_L2_UVP:
                # Remember the best score so that only genuine improvements
                # overwrite the "best" checkpoint.
                best_L2_UVP = L2_UVP
                for k in range(benchmark.num):
                    torch.save(nets[k].state_dict(), OUTPUT_PATH + 'f_pots{}_best.pt'.format(k))
                np.savez(OUTPUT_PATH + 'metrics.npz', L2_UVP=L2_UVP)