import matplotlib.pyplot as plt
import numpy as np
import ot
from itertools import combinations
import operator as op
from functools import reduce
import random
import math
import time
from torch.autograd import grad
from torch.autograd import Variable
from von_mises_fisher import VonMisesFisher
import torch
from random import choices
from imageio import imread
from sklearn.neighbors import KernelDensity
from vonmiseskde import VonMisesKDE


def load_image(fname):
    """Read an image file as an inverted grayscale intensity map.

    The file is read in grayscale, reversed along its first axis, divided
    by 255 and subtracted from 1, so low source values come out close to 1.

    Args:
        fname: Path (or other source accepted by ``imread``) of the image.

    Returns:
        The array ``1 - flipped_image / 255``.
    """
    gray = imread(fname, as_gray=True)
    flipped = gray[::-1, :]
    return 1 - flipped / 255.
def draw_samples(fname, n, dtype=torch.FloatTensor):
    """Draw ``n`` jittered 2-D points from the density encoded by an image.

    The image returned by ``load_image`` is treated as an unnormalised
    density over a regular grid on the unit square. Grid cells are sampled
    with probability proportional to their value, then Gaussian jitter with
    standard deviation ``0.5 / A.shape[0]`` is added to every coordinate.

    Args:
        fname: Image path forwarded to ``load_image``.
        n: Number of points to draw.
        dtype: Torch tensor type of the returned samples.

    Returns:
        A tensor of shape ``(n, 2)`` holding ``(x, y)`` coordinates.
    """
    A = load_image(fname)
    n_rows, n_cols = A.shape
    # x runs along columns and y along rows. With the default 'xy' indexing
    # meshgrid returns arrays of shape (len(y), len(x)), i.e. exactly A's
    # shape, so the ravelled grid lines up with A.ravel() for non-square
    # images as well (the axis lengths were previously swapped).
    xg, yg = np.meshgrid(np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))

    grid = list(zip(xg.ravel(), yg.ravel()))
    dens = A.ravel() / A.sum()
    # reshape keeps the (n, 2) layout even when n == 0.
    dots = np.array(choices(grid, dens, k=n), dtype=float).reshape(-1, 2)
    dots += (.5 / A.shape[0]) * np.random.standard_normal(dots.shape)

    return torch.from_numpy(dots).type(dtype)

def rand_projections(dim, num_projections=1000,device='cpu',require_grad=False):
    """Sample random projection directions with unit Euclidean norm.

    Each row is drawn from a standard normal distribution and then divided
    by its own L2 norm.

    Args:
        dim: Length of each direction vector.
        num_projections: Number of directions (rows) to draw.
        device: Device on which the tensor is created.
        require_grad: Value passed as ``requires_grad`` for the new tensor.

    Returns:
        A tensor of shape ``(num_projections, dim)`` with unit-norm rows.
    """
    directions = torch.randn(
        (num_projections, dim), device=device, requires_grad=require_grad
    )
    # The rescaling goes through ``.data`` rather than the tensor itself.
    squared_norms = torch.sum(directions.data ** 2, dim=1, keepdim=True)
    directions.data = directions.data / torch.sqrt(squared_norms)
    return directions


def one_dimensional_Wasserstein_prod(X_prod,Y_prod,p):
    """Sum of p-th powers of sorted differences, one value per column.

    Both inputs are flattened to 2-D (first axis kept), sorted along the
    first axis, and subtracted elementwise. The absolute differences are
    raised to the power ``p`` and summed (not averaged) over the first axis.

    Args:
        X_prod: Projected samples of the first set, samples on the first axis.
        Y_prod: Projected samples of the second set, same layout.
        p: Exponent applied to the absolute differences.

    Returns:
        A 1-D tensor with one entry per column (projection).
    """
    x_flat = X_prod.view(X_prod.shape[0], -1)
    y_flat = Y_prod.view(Y_prod.shape[0], -1)
    x_sorted, _ = torch.sort(x_flat, dim=0)
    y_sorted, _ = torch.sort(y_flat, dim=0)
    gaps = torch.abs(x_sorted - y_sorted)
    return torch.pow(gaps, p).sum(dim=0)

def SW(X, Y, L=30, p=2, device="cpu"):
    """Monte-Carlo sliced Wasserstein estimate between two sample sets.

    ``L`` random unit directions are drawn with ``rand_projections``, both
    sets are projected onto them, and ``one_dimensional_Wasserstein_prod``
    is evaluated per direction.

    Args:
        X: Samples of the first set, one row per sample.
        Y: Samples of the second set, one row per sample.
        L: Number of random projection directions.
        p: Exponent passed to the one-dimensional computation.
        device: Device on which the directions are created.

    Returns:
        A tuple ``(per_direction, estimate)``: the per-direction values and
        their mean raised to the power ``1 / p``.
    """
    directions = rand_projections(X.size(1), L, device)
    directions_t = directions.transpose(0, 1)
    per_direction = one_dimensional_Wasserstein_prod(
        torch.matmul(X, directions_t), torch.matmul(Y, directions_t), p=p
    )
    return per_direction, torch.pow(per_direction.mean(), 1./p)
# ---- Reproducibility and data ------------------------------------------------
# draw_samples selects grid cells with random.choices, which reads Python's
# `random` module, so that generator has to be seeded too; seeding only NumPy
# and torch leaves the sampled point clouds different on every run.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

# NOTE: the Gaussian-mixture samples built from these parameters are
# overwritten by the image-based samples a few lines below. The calls are kept
# because dropping them may shift NumPy's random stream, which draw_samples
# also uses for its jitter.
mu1_s = np.array([0, 0])
cov1_s = np.array([[1, 0], [0, 1]])

mu2_s = np.array([0, 1])
cov2_s = np.array([[1, -0.8], [-0.8, 1]])

mu1_t = np.array([0, 0])
cov1_t = np.array([[1, -0.8], [-0.8, 1]])

mu2_t = np.array([0, 0])
cov2_t = np.array([[1, 0], [0, 1]])

x = np.concatenate([ot.datasets.make_2D_samples_gauss(50, mu1_s, cov1_s),ot.datasets.make_2D_samples_gauss(50, mu2_s, cov2_s)],axis=0)
y = np.concatenate([ot.datasets.make_2D_samples_gauss(50, mu1_t, cov1_t),ot.datasets.make_2D_samples_gauss(50, mu2_t, cov2_t)],axis=0)

N, M =  (100, 100)

# The point clouds actually used: N and M points drawn from two image densities.
x = draw_samples("density_a.png", N, torch.FloatTensor).numpy()
y = draw_samples("density_b.png", M, torch.FloatTensor).numpy()

# ---- Panel 1: scatter plot of the two point clouds ------------------------------
n_row = 1
n_col = 3
fig, axs = plt.subplots(n_row, n_col, figsize=(18,5))
# Raw string: "\m" is not a valid escape sequence in an ordinary string literal.
axs[0].scatter(x[:, 0], x[:, 1], c="tab:blue", label=r"$\mu$")
axs[0].scatter(y[:, 0], y[:, 1], c="tab:red",marker='8', label=r"$\nu$")
axs[0].legend(fontsize = 14)
axs[0].set_xlabel(r'$x_1$')
axs[0].set_ylabel(r'$x_2$')
X = torch.from_numpy(x).float()
Y = torch.from_numpy(y).float()
d=X.shape[1]
# Additive offset applied to the projected distances in the f_1 panel.
epsilon=0





# ---- Panel 2: reference slicing distribution for f_1(x) = x -------------------
np.random.seed(1)
torch.manual_seed(1)
# Evaluate on a uniform grid of angles. Passing the angles through tan() before
# sin/cos (as was done previously) yields a non-uniform set of directions, so
# normalising by the plain sum was not comparable with the KDE curves below,
# which are normalised on a uniform grid of the same size.
angles = torch.linspace(-np.pi,np.pi,10000)
theta = torch.stack([torch.sin(angles),torch.cos(angles)],dim=1)
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))
distances =one_dimensional_Wasserstein_prod(X_prod,Y_prod,p=2)+epsilon
# Normalise the per-direction values so they sum to one over the grid.
distances = distances/torch.sum(distances)
# Same angle convention as the sampled estimates: atan2(first, second component).
angles = torch.atan2(theta[:,0],theta[:,1]).numpy()

order = np.argsort(angles)
# Running list of curve maxima, used later to set the y-axis limit.
maxs = [distances.max().item()]
axs[1].plot(angles[order], distances.numpy()[order],label=r'$\sigma_{\mu,\nu}(\theta;f_1)$',linestyle='solid')
axs[1].set_xlim([-np.pi,np.pi])

# ---- Panel 2: resampling-based estimate (labelled "SIR") for f_1 --------------
# Draw 20000 random unit directions, weight each one by its normalised
# projected distance, then resample 10000 directions with replacement
# according to those weights.
np.random.seed(1)
torch.manual_seed(1)
start = time.time()
theta = rand_projections(d, 20000,'cpu')
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))
distances =one_dimensional_Wasserstein_prod(X_prod,Y_prod,p=2)+epsilon
weights = distances/torch.sum(distances)
inds = torch.multinomial(weights, 10000, replacement=True)
# Angle convention: atan2(first component, second component).
angles = torch.atan2(theta[:,0],theta[:,1])
angles = angles[inds]
print("SIR: {}".format(time.time()-start))
# NOTE(review): values/counts are not read anywhere in the visible script.
values, counts = np.unique(angles.numpy(), return_counts=True)

# Smooth the resampled angles with a kernel density estimate. The second
# argument is a vector of equal entries 1/n, presumably per-sample weights;
# confirm against VonMisesKDE's signature.
kde_sk = VonMisesKDE(angles.reshape(-1,),np.full(angles.reshape(-1).shape,1)/angles.reshape(-1).shape[0],kappa=20)
eval_points = torch.linspace(-np.pi,np.pi,10000)
y_sk = kde_sk.evaluate(eval_points)
# Normalise so the curve sums to one over the evaluation grid.
y_sk = y_sk / np.sum(y_sk)
maxs.append(np.max(y_sk))
axs[1].plot(eval_points, y_sk,label=r"$\hat{\sigma}_{\mu,\nu}^{SIR}(\theta;f_1)$")


# ---- Panel 2: Metropolis-Hastings chain with independent proposals ("IMH"), f_1
# Every proposal is a fresh random unit direction; it is accepted with
# probability min(1, distance_new / distance_old).
np.random.seed(1)
torch.manual_seed(1)
L=20000
start=time.time()
theta=rand_projections(d, 1, 'cpu')
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))
# Distance of the current state; updated only on acceptance so it is not
# recomputed from scratch in every iteration.
distance_old = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2)+epsilon
thetas=[theta]
for l in range(L):
    theta_prime = rand_projections(d, 1, 'cpu')
    X_prod = torch.matmul(X, theta_prime.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta_prime.transpose(0, 1))
    distance_new = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2)+epsilon
    # torch.clamp instead of np.min([1, tensor]): mixing a Python int with a
    # one-element tensor relies on NumPy building a ragged object array,
    # which recent NumPy releases refuse to do.
    acceptance_rate = torch.clamp(distance_new/distance_old, max=1.0)
    u = torch.rand(1)
    if u <= acceptance_rate:
        theta, distance_old = theta_prime, distance_new
    thetas.append(theta)
# Keep only the last 1000 states of the chain.
theta = torch.cat(thetas,dim=0)[-1000:]
print("IMH: {}".format(time.time()-start))
angles = torch.atan2(theta[:,0],theta[:,1])

# Kernel density estimate of the retained angles, normalised to sum to one
# over the evaluation grid.
kde_sk = VonMisesKDE(angles.reshape(-1,),np.full(angles.reshape(-1).shape,1)/angles.reshape(-1).shape[0],kappa=20)
eval_points = torch.linspace(-np.pi,np.pi,10000)
y_sk = kde_sk.evaluate(eval_points)
y_sk = y_sk / np.sum(y_sk)
maxs.append(np.max(y_sk))
axs[1].plot(eval_points, y_sk, label=r"$\hat{\sigma}_{\mu,\nu}^{IMH}(\theta;f_1)$")


# ---- Panel 2: Metropolis-Hastings chain with local proposals ("RMH"), f_1 -----
# Each proposal is sampled from a VonMisesFisher object built from the current
# direction; it is accepted with probability min(1, distance_new / distance_old).
np.random.seed(1)
torch.manual_seed(1)
start=time.time()
L=20000
theta=rand_projections(d, 1, 'cpu')
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))
# Distance of the current state; updated only on acceptance.
distance_old = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2) + epsilon
thetas=[theta]
for l in range(L):
    # Float fill value: an integer fill makes torch.full return an integer
    # tensor on current PyTorch versions.
    vmf = VonMisesFisher(theta.data, torch.full((1, 1), 1., device='cpu'))
    theta_prime= vmf.rsample(1).view(1, -1)
    X_prod = torch.matmul(X, theta_prime.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta_prime.transpose(0, 1))
    distance_new = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2)+epsilon
    # torch.clamp instead of np.min([1, tensor]): mixing a Python int with a
    # one-element tensor relies on NumPy building a ragged object array,
    # which recent NumPy releases refuse to do.
    acceptance_rate = torch.clamp(distance_new/distance_old, max=1.0)
    u = torch.rand(1)
    if u <= acceptance_rate:
        theta, distance_old = theta_prime, distance_new
    thetas.append(theta)
# Keep only the last 10000 states of the chain.
theta = torch.cat(thetas,dim=0)[-10000:]
print("RMH: {}".format(time.time()-start))
angles = torch.atan2(theta[:,0],theta[:,1])

# Kernel density estimate of the retained angles, normalised to sum to one
# over the evaluation grid.
kde_sk = VonMisesKDE(angles.reshape(-1,),np.full(angles.reshape(-1).shape,1)/angles.reshape(-1).shape[0],kappa=20)
eval_points = torch.linspace(-np.pi,np.pi,10000)
y_sk = kde_sk.evaluate(eval_points)
y_sk = y_sk / np.sum(y_sk)
maxs.append(np.max(y_sk))
axs[1].plot(eval_points, y_sk,label=r"$\hat{\sigma}_{\mu,\nu}^{RMH}(\theta;f_1)$")



##


# ---- Panel 2: optimised von Mises-Fisher slicing distribution ("vDSW") --------
np.random.seed(1)
torch.manual_seed(1)
L=10000
kappa=1.
p=2
start=time.time()
# NOTE: this rebinds `epsilon`, which until here held the scalar offset 0 added
# to the projected distances; from now on it is a (1, d) tensor that is
# optimised below.
epsilon = torch.randn((1, d), device='cpu', requires_grad=True)
# Rescale the initial value to unit norm.
epsilon.data = epsilon.data / torch.sqrt(torch.sum(epsilon.data ** 2, dim=1,keepdim=True))
optimizer = torch.optim.SGD([epsilon], lr=0.1)
# NOTE(review): X_detach / Y_detach are not read anywhere in the visible script.
X_detach = X.detach()
Y_detach = Y.detach()
for _ in range(100):
    # Presumably (location, concentration) -- confirm against VonMisesFisher.
    vmf = VonMisesFisher(epsilon, torch.full((1, 1), kappa, device='cpu'))
    theta = vmf.rsample(L).view(L, -1)
    X_prod = torch.matmul(X, theta.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta.transpose(0, 1))
    # Minimising the negative value maximises the sliced quantity w.r.t. epsilon.
    negative_sw = -torch.pow(one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2).mean(),1./p)
    optimizer.zero_grad()
    negative_sw.backward()
    optimizer.step()
    # Rescale epsilon back to unit norm after every SGD step.
    epsilon.data = epsilon.data / torch.sqrt(torch.sum(epsilon.data ** 2, dim=1,keepdim=True))
from torch.distributions.von_mises import VonMises
# One-dimensional von Mises distribution located at the angle of the optimised
# epsilon (same atan2 convention as the other curves) with concentration kappa.
vmf =VonMises(torch.atan2(epsilon[:,0],epsilon[:,1]).detach().view(-1,), torch.tensor([kappa]))
print("vDSW: {}".format(time.time()-start))
angles = torch.linspace(-np.pi,np.pi,10000)
probs = torch.exp(vmf.log_prob(angles)).detach().numpy()
# Normalise so the curve sums to one over the evaluation grid.
probs = probs/np.sum(probs)
maxs.append(np.max(probs))
# NOTE(review): `vonmises` is not used anywhere in the visible script.
from scipy.stats import vonmises

axs[1].plot(angles, probs, label = r"$vMF(\epsilon^\star,\kappa)$")
# The y-axis limit is the largest maximum among the curves recorded in maxs.
axs[1].set_ylim([0,np.max(maxs)+1e-6])







# ---- Panel 2: single optimised direction, drawn as a bar ----------------------
np.random.seed(1)
torch.manual_seed(1)
L=10000
kappa=15
p=2
start=time.time()
theta = torch.randn((1, d), device='cpu', requires_grad=True)
# Normalise theta itself. The numerator previously read `epsilon.data`, which
# silently replaced the random start with the direction optimised in the
# preceding section, divided by the norm of an unrelated random vector.
theta.data = theta.data / torch.sqrt(torch.sum(theta.data ** 2, dim=1,keepdim=True))
optimizer = torch.optim.SGD([theta], lr=0.1)
for _ in range(100):
    X_prod = torch.matmul(X, theta.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta.transpose(0, 1))
    # Minimising the negative value maximises the projected distance over theta.
    negative_sw = -torch.pow(one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2).mean(),1./p)
    optimizer.zero_grad()
    negative_sw.backward()
    optimizer.step()
    # Rescale theta back to unit norm after every SGD step.
    theta.data = theta.data / torch.sqrt(torch.sum(theta.data ** 2, dim=1,keepdim=True))
angles = torch.atan2(theta[:,0],theta[:,1]).detach().numpy()
# Height 1 exceeds the y-limit in force, so the bar spans the full panel height.
axs[1].bar(angles[0], 1,width=0.03,label=r"$\delta_{\theta^\star}$",color='black')
axs[1].legend(fontsize = 14)
axs[1].set_xlabel(r'$\theta$',fontsize = 14)
axs[1].set_ylabel(r'Density',fontsize = 14)
axs[1].set_title(r'$f_1(x) = x$',fontsize = 14)


# ---- Panel 3: exponential energy function f_e(x) = e^x -------------------------
# Fresh point clouds are drawn for this panel.
x = draw_samples("density_a.png", N, torch.FloatTensor).numpy()
y = draw_samples("density_b.png", M, torch.FloatTensor).numpy()
X = torch.from_numpy(x).float()
Y = torch.from_numpy(y).float()

np.random.seed(1)
torch.manual_seed(1)
# Evaluate on a uniform grid of angles. Passing the angles through tan() before
# sin/cos (as was done previously) yields a non-uniform set of directions, so
# the softmax normalisation was not comparable with the KDE curves below,
# which are normalised on a uniform grid of the same size.
angles = torch.linspace(-np.pi,np.pi,10000)
theta = torch.stack([torch.sin(angles),torch.cos(angles)],dim=1)
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))
distances =one_dimensional_Wasserstein_prod(X_prod,Y_prod,p=2)
# softmax = exp(distance) normalised to sum to one over the grid.
distances = torch.softmax(distances,dim=-1)
# Same angle convention as the sampled estimates: atan2(first, second component).
angles = torch.atan2(theta[:,0],theta[:,1]).numpy()

order = np.argsort(angles)
# Running list of curve maxima for this panel, used later for the y-axis limit.
maxs = [distances.max().item()]
axs[2].plot(angles[order], distances.numpy()[order],label=r'$\sigma_{\mu,\nu}(\theta;f_e)$',linestyle='solid')
axs[2].set_xlim([-np.pi,np.pi])

# ---- Panel 3: resampling-based estimate (labelled "SIR") for f_e --------------
# Draw 20000 random unit directions, weight each one by the softmax of its
# projected distance, then resample 10000 directions with replacement
# according to those weights.
np.random.seed(1)
torch.manual_seed(1)
start = time.time()
theta = rand_projections(d, 20000,'cpu')
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))

distances =one_dimensional_Wasserstein_prod(X_prod,Y_prod,p=2)

weights = torch.softmax(distances,dim=-1)
inds = torch.multinomial(weights, 10000, replacement=True)
# Angle convention: atan2(first component, second component).
angles = torch.atan2(theta[:,0],theta[:,1])
angles = angles[inds]
print("SIR: {}".format(time.time()-start))
# NOTE(review): values/counts are not read anywhere in the visible script.
values, counts = np.unique(angles.numpy(), return_counts=True)

# Smooth the resampled angles with a kernel density estimate. The second
# argument is a vector of equal entries 1/n, presumably per-sample weights;
# confirm against VonMisesKDE's signature.
kde_sk = VonMisesKDE(angles.reshape(-1,),np.full(angles.reshape(-1).shape,1)/angles.reshape(-1).shape[0],kappa=200)
eval_points = torch.linspace(-np.pi,np.pi,10000)
y_sk = kde_sk.evaluate(eval_points)
# Normalise so the curve sums to one over the evaluation grid.
y_sk = y_sk / np.sum(y_sk)
maxs.append(np.max(y_sk))
axs[2].plot(eval_points, y_sk,label=r"$\hat{\sigma}_{\mu,\nu}^{SIR}(\theta;f_e)$")
#
#
# ---- Panel 3: Metropolis-Hastings chain with independent proposals ("IMH"), f_e
# Every proposal is a fresh random unit direction; it is accepted with
# probability min(1, exp(distance_new - distance_old)).
np.random.seed(1)
torch.manual_seed(1)
L=20000
start=time.time()
theta=rand_projections(d, 1, 'cpu')
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))
# Distance of the current state; updated only on acceptance so it is not
# recomputed from scratch in every iteration.
distance_old = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2)
thetas=[theta]
for l in range(L):
    theta_prime = rand_projections(d, 1, 'cpu')
    X_prod = torch.matmul(X, theta_prime.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta_prime.transpose(0, 1))
    distance_new = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2)
    # torch.clamp instead of np.min([1, tensor]): mixing a Python int with a
    # one-element tensor relies on NumPy building a ragged object array,
    # which recent NumPy releases refuse to do.
    acceptance_rate = torch.clamp(torch.exp(distance_new-distance_old), max=1.0)
    u = torch.rand(1)
    if u <= acceptance_rate:
        theta, distance_old = theta_prime, distance_new
    thetas.append(theta)
# Keep only the last 1000 states of the chain.
theta = torch.cat(thetas,dim=0)[-1000:]
print("IMH: {}".format(time.time()-start))
angles = torch.atan2(theta[:,0],theta[:,1])

# Kernel density estimate of the retained angles, normalised to sum to one
# over the evaluation grid.
kde_sk = VonMisesKDE(angles.reshape(-1,),np.full(angles.reshape(-1).shape,1)/angles.reshape(-1).shape[0],kappa=200)
eval_points = torch.linspace(-np.pi,np.pi,10000)
y_sk = kde_sk.evaluate(eval_points)
y_sk = y_sk / np.sum(y_sk)
maxs.append(np.max(y_sk))
axs[2].plot(eval_points, y_sk,label=r"$\hat{\sigma}_{\mu,\nu}^{IMH}(\theta;f_e)$")


# ---- Panel 3: Metropolis-Hastings chain with local proposals ("RMH"), f_e -----
# Each proposal is sampled from a VonMisesFisher object built from the current
# direction; it is accepted with probability
# min(1, exp(distance_new - distance_old)).
np.random.seed(1)
torch.manual_seed(1)
start=time.time()
L=20000
theta=rand_projections(d, 1, 'cpu')
X_prod = torch.matmul(X, theta.transpose(0, 1))
Y_prod = torch.matmul(Y, theta.transpose(0, 1))
# Distance of the current state; updated only on acceptance.
distance_old = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2)
thetas=[theta]
for l in range(L):
    # Float fill value: an integer fill makes torch.full return an integer
    # tensor on current PyTorch versions.
    vmf = VonMisesFisher(theta.data, torch.full((1, 1), 1., device='cpu'))
    theta_prime= vmf.rsample(1).view(1, -1)
    X_prod = torch.matmul(X, theta_prime.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta_prime.transpose(0, 1))
    distance_new = one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2)
    # torch.clamp instead of np.min([1, tensor]): mixing a Python int with a
    # one-element tensor relies on NumPy building a ragged object array,
    # which recent NumPy releases refuse to do.
    acceptance_rate = torch.clamp(torch.exp(distance_new-distance_old), max=1.0)
    u = torch.rand(1)
    if u <= acceptance_rate:
        theta, distance_old = theta_prime, distance_new
    thetas.append(theta)
# Keep only the last 10000 states of the chain.
theta = torch.cat(thetas,dim=0)[-10000:]
print("RMH: {}".format(time.time()-start))
angles = torch.atan2(theta[:,0],theta[:,1])

# Kernel density estimate of the retained angles, normalised to sum to one
# over the evaluation grid.
kde_sk = VonMisesKDE(angles.reshape(-1,),np.full(angles.reshape(-1).shape,1)/angles.reshape(-1).shape[0],kappa=200)
eval_points = torch.linspace(-np.pi,np.pi,10000)
y_sk = kde_sk.evaluate(eval_points)
y_sk = y_sk / np.sum(y_sk)
maxs.append(np.max(y_sk))
axs[2].plot(eval_points, y_sk,label=r"$\hat{\sigma}_{\mu,\nu}^{RMH}(\theta;f_e)$")



####



# ---- Panel 3: optimised von Mises-Fisher slicing distribution ("vDSW") --------
np.random.seed(1)
torch.manual_seed(1)
L=10000
kappa=20.
p=2
start=time.time()
# (1, d) tensor optimised below; rescaled to unit norm before the first step.
epsilon = torch.randn((1, d), device='cpu', requires_grad=True)
epsilon.data = epsilon.data / torch.sqrt(torch.sum(epsilon.data ** 2, dim=1,keepdim=True))
optimizer = torch.optim.SGD([epsilon], lr=0.1)
# NOTE(review): X_detach / Y_detach are not read anywhere in the visible script.
X_detach = X.detach()
Y_detach = Y.detach()
for _ in range(100):
    # Presumably (location, concentration) -- confirm against VonMisesFisher.
    vmf = VonMisesFisher(epsilon, torch.full((1, 1), kappa, device='cpu'))
    theta = vmf.rsample(L).view(L, -1)
    X_prod = torch.matmul(X, theta.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta.transpose(0, 1))
    # Minimising the negative value maximises the sliced quantity w.r.t. epsilon.
    negative_sw = -torch.pow(one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2).mean(),1./p)
    optimizer.zero_grad()
    negative_sw.backward()
    optimizer.step()
    # Rescale epsilon back to unit norm after every SGD step.
    epsilon.data = epsilon.data / torch.sqrt(torch.sum(epsilon.data ** 2, dim=1,keepdim=True))
from torch.distributions.von_mises import VonMises
# One-dimensional von Mises distribution located at the angle of the optimised
# epsilon (same atan2 convention as the other curves) with concentration kappa.
vmf =VonMises(torch.atan2(epsilon[:,0],epsilon[:,1]).detach().view(-1,), torch.tensor([kappa]))
print("vDSW: {}".format(time.time()-start))
angles = torch.linspace(-np.pi,np.pi,10000)
probs = torch.exp(vmf.log_prob(angles)).detach().numpy()
# Normalise so the curve sums to one over the evaluation grid.
probs = probs/np.sum(probs)
maxs.append(np.max(probs))
# NOTE(review): `vonmises` is not used anywhere in the visible script.
from scipy.stats import vonmises

axs[2].plot(angles, probs, label = r"$vMF(\epsilon^\star,\kappa)$")


# ---- Panel 3: single optimised direction, drawn as a bar ----------------------
np.random.seed(1)
torch.manual_seed(1)
theta = torch.randn((1, d), device='cpu', requires_grad=True)
# Normalise theta itself. The numerator previously read `epsilon.data`, which
# silently replaced the random start with the direction optimised in the
# preceding section, divided by the norm of an unrelated random vector.
theta.data = theta.data / torch.sqrt(torch.sum(theta.data ** 2, dim=1,keepdim=True))
optimizer = torch.optim.SGD([theta], lr=0.1)
for _ in range(100):
    X_prod = torch.matmul(X, theta.transpose(0, 1))
    Y_prod = torch.matmul(Y, theta.transpose(0, 1))
    # Minimising the negative value maximises the projected distance over theta.
    negative_sw = -torch.pow(one_dimensional_Wasserstein_prod(X_prod, Y_prod, p=2).mean(),1./p)
    optimizer.zero_grad()
    negative_sw.backward()
    optimizer.step()
    # Rescale theta back to unit norm after every SGD step.
    theta.data = theta.data / torch.sqrt(torch.sum(theta.data ** 2, dim=1,keepdim=True))
angles = torch.atan2(theta[:,0],theta[:,1]).detach().numpy()
axs[2].bar(angles[0], 1,width=0.03,label=r"$\delta_{\theta^{\star}}$",color='black')

# The y-axis limit is the largest maximum among the curves recorded in maxs,
# so the height-1 bar spans the full panel height.
axs[2].set_ylim([0,np.max(maxs)+1e-6])

axs[2].legend(fontsize = 14)
axs[2].set_xlabel(r'$\theta$',fontsize = 14)
axs[2].set_ylabel(r'Density',fontsize = 14)
axs[2].set_title(r'$f_e(x) = e^x$',fontsize = 14)

plt.tight_layout()
plt.show()








