# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

import numpy as np
import pandas as pd

def get_Edesign(X, iterations=500, threshold=0, warm_start=None):
    """Run Frank-Wolfe iterations over design weights on the rows of ``X``.

    Judging by the name this is meant to produce an E-optimal design; the
    quantity reported is the smallest singular value of the weighted design
    matrix ``A(lambda) = sum_i lambda_i * x_i x_i^T``.

    Args:
        X: 2-D array whose rows ``x_i`` are the candidate design points.
        iterations: Upper bound of the loop counter; the loop runs for
            ``k = 1, ..., iterations - 1``.
        threshold: Currently unused; kept for signature compatibility.
        warm_start: Optional initial weight vector (one weight per row of
            ``X``). It is copied, never modified. Defaults to uniform weights.

    Returns:
        ``(lambda_vec, val)`` where ``lambda_vec`` holds the final weights and
        ``val`` is ``np.linalg.norm(A, -2)`` of the design matrix built from
        the weights in effect at the start of the last iteration (i.e. before
        the final Frank-Wolfe step). When no iteration runs
        (``iterations <= 1``), ``val`` is evaluated at the initial weights.
    """
    X = np.asarray(X)
    if warm_start is None:
        lambda_vec = np.ones(len(X)) / len(X)
    else:
        # Copy as float: the in-place updates below must neither touch the
        # caller's array nor fail on integer-typed input.
        lambda_vec = np.array(warm_start, dtype=float)

    # outers[i] is the rank-one matrix x_i x_i^T.
    outers = np.matmul(X[:, :, np.newaxis], X[:, np.newaxis, :])

    def _design_matrix(weights):
        # Weighted sum of the rank-one matrices.
        return np.sum(outers * weights[:, np.newaxis, np.newaxis], axis=0)

    A_lambda = None
    for k in range(1, iterations):
        A_lambda = _design_matrix(lambda_vec)
        B = np.linalg.pinv(A_lambda)
        # NOTE(review): ``B @ X`` multiplies a (d, d) matrix by the (n, d)
        # matrix X, so this line only works when X is square (n == d).
        # Confirm the intended derivative before using non-square X.
        lambda_derivative = -np.linalg.norm(B, -2) * np.diagonal(B @ X @ X.T @ B)

        # Frank-Wolfe update: shrink all weights by (1 - alpha) and move mass
        # alpha onto the coordinate with the smallest derivative.
        alpha = 2 / (k + 2)  # step size
        min_lambda_derivative_index = np.argmin(lambda_derivative)
        lambda_vec -= alpha * lambda_vec
        lambda_vec[min_lambda_derivative_index] += alpha

    if A_lambda is None:
        # The loop body never ran; evaluate the design at the initial weights
        # instead of failing on an unbound variable.
        A_lambda = _design_matrix(lambda_vec)

    val = np.linalg.norm(A_lambda, -2)
    return lambda_vec, val


def get_XYdesign(X, Y, iterations=5000, threshold=0, warm_start=None):
    """Compute an XY design over the rows of ``X`` with Frank-Wolfe.

    Each iteration builds ``A(lambda) = sum_i lambda_i * x_i x_i^T``, inverts
    it, finds the row ``y`` of ``Y`` with the largest ``y^T A(lambda)^-1 y``
    and takes a Frank-Wolfe step on the weights with respect to that row.

    Args:
        X: 2-D array whose rows ``x_i`` are the candidate design points.
        Y: 2-D array whose rows are the directions to be controlled.
        iterations: Upper bound of the loop counter; the loop runs for
            ``k = 1, ..., iterations - 1`` unless it stops early.
        threshold: Stop once the relative change of the maximum value between
            consecutive iterations is strictly below this number. With the
            default ``0`` this test never triggers.
        warm_start: Optional initial weight vector (one weight per row of
            ``X``). It is copied, never modified. Defaults to uniform weights.

    Returns:
        ``(cov_A, y_max_val, lambda_vec)``. ``cov_A`` and ``y_max_val`` are
        the (pseudo-)inverse of the design matrix and the maximum of
        ``y^T cov_A y`` computed in the last executed iteration, i.e. for the
        weights *before* that iteration's update; ``lambda_vec`` holds the
        weights after it. When no iteration runs (``iterations <= 1``) both
        are evaluated at the initial weights.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if warm_start is None:
        lambda_vec = np.ones(len(X)) / len(X)
    else:
        # Copy as float: the in-place updates below must neither touch the
        # caller's array nor fail on integer-typed input.
        lambda_vec = np.array(warm_start, dtype=float)

    # outers[i] is the rank-one matrix x_i x_i^T.
    outers = np.matmul(X[:, :, np.newaxis], X[:, np.newaxis, :])

    def _evaluate(weights):
        # Return the inverse design matrix and y^T cov y for every row of Y.
        A_lambda = np.sum(outers * weights[:, np.newaxis, np.newaxis], axis=0)
        # Exact zero determinant selects the pseudo-inverse; anything else
        # (including nearly singular matrices) goes through the plain inverse.
        if np.linalg.det(A_lambda) == 0:
            cov = np.linalg.pinv(A_lambda)
        else:
            cov = np.linalg.inv(A_lambda)
        return cov, np.diag(Y @ cov @ Y.T)

    cov_A = None
    y_max_val = None
    old_y_max_val = 1

    for k in range(1, iterations):
        cov_A, pred_vars = _evaluate(lambda_vec)
        y_max_val = np.max(pred_vars)
        max_y = Y[np.argmax(pred_vars)]
        # One entry per row x_i of X: -(max_y^T cov_A x_i)^2.
        lambda_derivative = -(max_y.T @ cov_A @ X.T) ** 2

        # Frank-Wolfe update: shrink all weights by (1 - alpha) and move mass
        # alpha onto the coordinate with the smallest derivative.
        alpha = 2 / (k + 2)  # step size
        min_lambda_derivative_index = np.argmin(lambda_derivative)
        lambda_vec -= alpha * lambda_vec
        lambda_vec[min_lambda_derivative_index] += alpha

        # Checking y_max_val == 0 first also keeps the next iteration from
        # dividing by a zero old_y_max_val.
        if y_max_val == 0 or abs((old_y_max_val - y_max_val) / old_y_max_val) < threshold:
            break
        old_y_max_val = y_max_val

    if cov_A is None:
        # The loop body never ran; evaluate the design at the initial weights
        # instead of failing on unbound variables.
        cov_A, pred_vars = _evaluate(lambda_vec)
        y_max_val = np.max(pred_vars)

    return cov_A, y_max_val, lambda_vec


def get_value_given_design(X, Y, lambda_vec):
    """Evaluate a fixed design: the largest ``y^T A(lambda)^-1 y`` over rows of ``Y``.

    ``A(lambda)`` is the weighted sum ``sum_i lambda_vec[i] * x_i x_i^T`` over
    the rows ``x_i`` of ``X``. The pseudo-inverse is used when the determinant
    of ``A(lambda)`` is exactly zero, the regular inverse otherwise.
    """
    # One rank-one matrix x_i x_i^T per row of X.
    rank_one = np.matmul(X[:, :, np.newaxis], X[:, np.newaxis, :])
    design = np.sum(rank_one * lambda_vec[:, np.newaxis, np.newaxis], axis=0)

    invert = np.linalg.pinv if np.linalg.det(design) == 0 else np.linalg.inv
    covariance = invert(design)

    # Diagonal entries are y^T covariance y for each row y of Y.
    return np.max(np.diag(Y @ covariance @ Y.T))


def get_uniform_subset_design(X, Y, subset):
    """Evaluate the design that puts weight ``1 / len(subset)`` on ``subset``.

    Rows of ``X`` indexed by ``subset`` get equal weight, all others zero.
    The weighted design matrix is always inverted with the pseudo-inverse.

    Returns:
        ``(cov_A, y_max_val, lambda_vec)``: the pseudo-inverse of the design
        matrix, the largest ``y^T cov_A y`` over the rows ``y`` of ``Y``, and
        the weight vector that was used.
    """
    indicator = np.zeros(len(X))
    indicator[subset] = 1.
    lambda_vec = indicator / len(subset)

    # One rank-one matrix x_i x_i^T per row of X, combined with the weights.
    rank_one = np.matmul(X[:, :, np.newaxis], X[:, np.newaxis, :])
    design = np.sum(rank_one * lambda_vec[:, np.newaxis, np.newaxis], axis=0)
    cov_A = np.linalg.pinv(design)

    y_max_val = np.max(np.diag(Y @ cov_A @ Y.T))
    return cov_A, y_max_val, lambda_vec


def get_uniform_design(X, Y):
    """Evaluate the design that weights every row of ``X`` equally.

    Returns:
        ``(cov_A, y_max_val, lambda_vec)``: the inverse of the uniformly
        weighted design matrix (pseudo-inverse when its determinant is exactly
        zero), the largest ``y^T cov_A y`` over the rows ``y`` of ``Y``, and
        the uniform weight vector.
    """
    num_points = len(X)
    lambda_vec = np.ones(num_points) / num_points

    # One rank-one matrix x_i x_i^T per row of X, combined with the weights.
    rank_one = np.matmul(X[:, :, np.newaxis], X[:, np.newaxis, :])
    design = np.sum(rank_one * lambda_vec[:, np.newaxis, np.newaxis], axis=0)

    invert = np.linalg.pinv if np.linalg.det(design) == 0 else np.linalg.inv
    cov_A = invert(design)

    y_max_val = np.max(np.diag(Y @ cov_A @ Y.T))
    return cov_A, y_max_val, lambda_vec

def get_diff(W):
    """Return the distinct non-zero pairwise differences between rows of ``W``.

    For every pair ``i < j`` the difference ``W[i] - W[j]`` is formed;
    duplicates and all-zero rows are dropped. Rows come back in the order
    produced by ``np.unique(..., axis=0)``.
    """
    first, second = np.triu_indices(W.shape[0], 1)
    pairwise = W[first] - W[second]

    distinct = np.unique(pairwise, axis=0)
    is_zero_row = np.all(distinct == 0, axis=1)

    return np.vstack(distinct[~is_zero_row])


def get_diff_opt(W, theta):
    """Return gap-normalised differences between each row of ``W`` and the best row.

    The best row is ``argmax(W @ theta)``. For every row ``w`` whose value is
    strictly below the maximum, the result contains
    ``(w - w_best) / (max_value - w @ theta)``; duplicates are removed via
    ``np.unique(..., axis=0)``.

    Rows and gaps are filtered with one shared mask. Filtering them
    separately (rows by "difference is all zeros", gaps by "gap is zero")
    gets out of step when a different row ties the maximum, which either
    raises or silently divides rows by the wrong gap.

    Args:
        W: 2-D array with one candidate per row.
        theta: Parameter vector; any shape is flattened to 1-D.

    Returns:
        2-D array of unique normalised differences. It has shape
        ``(0, W.shape[1])`` when no row is strictly worse than the best one.
    """
    W = np.asarray(W)
    theta = np.ravel(theta)

    # Choose the row with respect to which differences will be calculated.
    values = W @ theta
    reference_row = np.argmax(values)
    gaps = np.max(values) - values

    # Rows tied with the maximum (including the reference row itself) have a
    # zero gap and cannot be normalised, so they are excluded.
    suboptimal = gaps > 0
    if not suboptimal.any():
        return np.empty((0, W.shape[1]))

    diff_matrix = W[suboptimal] - W[reference_row]
    diff_matrix = diff_matrix / gaps[suboptimal].reshape(-1, 1)

    return np.unique(diff_matrix, axis=0)
    


class GridProblem:
    """Sampler on the grid ``{0, 1/K, ..., (K-1)/K}``.

    A draw for index ``zidx`` perturbs ``A[zidx]`` with Gaussian noise ``u``
    (scale ``sigmau``), snaps the result to the nearest grid point ``xidx``
    and returns ``theta[xidx] + u`` plus independent Gaussian noise (scale
    ``sigmaepsilon``). The same ``u`` enters both the grid choice and the
    returned value.
    """

    def __init__(self, K, A, theta, sigmau, sigmaepsilon):
        self.K = K
        self.A = A
        self.theta = theta
        self.sigmau = sigmau
        self.sigmaepsilon = sigmaepsilon
        # Grid points 0, 1/K, ..., (K-1)/K.
        self.grid = [step/K for step in range(K)]

    def get_sample(self, zidx):
        """Draw one ``(xidx, r)`` pair for the single index ``zidx``."""
        shared_noise = self.sigmau*np.random.randn()
        perturbed = self.A[zidx] + shared_noise
        # Nearest grid point to the perturbed value.
        xidx = np.abs(perturbed - self.grid).argmin()
        outcome_noise = self.sigmaepsilon*np.random.normal()
        r = self.theta[xidx] + shared_noise + outcome_noise
        return xidx, r

    def get_sample_batch(self, zidx):
        """Draw ``(xidx, r)`` arrays, one entry per element of ``zidx``."""
        batch_size = len(zidx)
        shared_noise = self.sigmau*np.random.randn(batch_size)
        perturbed = (self.A[zidx] + shared_noise).reshape(-1, 1)
        # Broadcasting against the grid row gives a (batch, K) distance table.
        grid_row = np.array(self.grid).reshape(1, -1)
        xidx = np.abs(perturbed - grid_row).argmin(axis=1)
        outcome_noise = self.sigmaepsilon*np.random.normal(size=batch_size)
        r = np.array(self.theta)[xidx] + shared_noise + outcome_noise
        return xidx, r

def df_to_Gamma(df):
    """Estimate the matrix of conditional frequencies of ``x`` given ``z``.

    Entry ``[x, z]`` of the result is the fraction of rows with that ``z``
    whose ``x`` column equals ``x``. The matrix is ``K x K`` with
    ``K = max(max x, max z) + 1``; index values that never occur in the data
    get an all-zero row or column.

    Args:
        df: DataFrame with columns ``'x'`` and ``'z'``.
            # assumes both hold non-negative integer indices - TODO confirm

    Returns:
        ``K x K`` NumPy array, rows indexed by ``x`` and columns by ``z``.
    """
    # 1. Count occurrences of every observed (z, x) pair.
    counts = df.groupby(['z', 'x']).size().reset_index(name='counts')
    # 2. Total number of rows for each z.
    total_counts = df.groupby('z').size().reset_index(name='total_counts')
    merged_counts = pd.merge(counts, total_counts, on='z')
    # 3. Proportion of each x within its z.
    merged_counts['proportion'] = merged_counts['counts'] / merged_counts['total_counts']
    # 4. Pivot into an x-by-z table; unobserved pairs become 0.
    matrix = merged_counts.pivot(index='x', columns='z', values='proportion').fillna(0)
    # 5. Pad to K x K in sorted order. Reindexing fills missing rows and
    #    columns in one step; assigning fixed-length lists one label at a
    #    time fails when both a row and a column label are missing.
    K = int(max(df['x'].max(), df['z'].max())) + 1
    matrix = matrix.reindex(index=range(K), columns=range(K), fill_value=0)
    return matrix.to_numpy()


def get_Gamma(K, A, horizon, theta, sigmau, sigmaepsilon):
    """Simulate ``horizon`` draws from a ``GridProblem`` and estimate Gamma.

    Each round picks ``z`` with ``np.random.choice(K)``, draws ``(x, r)``
    from the sampler, and the collected ``(z, x, r)`` records are turned into
    a matrix by ``df_to_Gamma``.
    """
    problem = GridProblem(K, A, theta, sigmau, sigmaepsilon)

    z_draws, x_draws, r_draws = [], [], []
    for _ in range(horizon):
        z_index = np.random.choice(K)
        x_index, reward = problem.get_sample(z_index)
        z_draws.append(z_index)
        x_draws.append(x_index)
        r_draws.append(reward)

    samples = pd.DataFrame({'z': z_draws, 'x': x_draws, 'r': r_draws})
    return df_to_Gamma(samples)

