'''
Utility functions
'''
import torch
import numpy
import numpy as np
import os
import os.path as osp
import pickle
import argparse
from scipy.stats import ortho_group
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import pdb

#parse the configs from config file

def read_config():
    """Parse ``name = value`` settings from the file named ``config``.

    The file is opened relative to the current working directory. Blank
    lines, lines whose first non-blank character is ``#`` and lines without
    an ``=`` are ignored. Names and values are stripped of surrounding
    whitespace.

    Returns:
        dict mapping each setting name to its (string) value.

    Raises:
        Exception: if any of kahip_dir, data_dir, glove_dir or sift_dir is
            missing from the file.
    """
    name2config = {}
    with open('config', 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            # Skip blanks, comments and anything that is not an assignment.
            if not line or line.startswith('#') or '=' not in line:
                continue
            # Split on the first '=' only, so a value may itself contain '='.
            name, _, value = line.partition('=')
            name2config[name.strip()] = value.strip()

    required = ('kahip_dir', 'data_dir', 'glove_dir', 'sift_dir')
    if any(name not in name2config for name in required):
        raise Exception('Config must have kahip_dir, data_dir, glove_dir, and sift_dir')
    return name2config

# Settings are parsed once, at import time, from the ``config`` file.
name2config = read_config()

# Use the GPU whenever torch reports that CUDA is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Directories taken verbatim from the parsed settings.
kahip_dir = name2config['kahip_dir']
data_dir = name2config['data_dir']
glove_dir = name2config['glove_dir']
sift_dir = name2config['sift_dir']

graph_file = 'knn.graph'

# Paths derived from data_dir; the trailing '' component makes parts_path
# end with a path separator.
parts_path = osp.join(data_dir, 'partition', '')
dsnode_path = osp.join(data_dir, 'train_dsnode')


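'''
Memory-efficient nearest-neighbor ranking, computed in chunks.
Returns the ranks of the closest points, excluding the point itself by default.
Uses squared l2 distance, or the negated dot product (cosine) if opt.normalize_data is set.
Input:
-data_x, data_y: tensors or numpy arrays; data_y defaults to data_x.
-k: number of top results to return.
-largest: whether to pick the largest values when ranking.
-include_self: whether to include the point itself in the final ranking.
'''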
def dist_rank(data_x, k, data_y=None, largest=False, opt=None, include_self=False):
    """Rank, for every row of ``data_x``, the rows of ``data_y`` by score.

    Scores are computed in row chunks of ``data_x`` so the full score matrix
    is never held in memory at once. The score is the squared L2 distance,
    or the negated dot product when ``opt.normalize_data`` is truthy.

    Args:
        data_x: 2-D tensor or numpy array of query points.
        k: number of neighbours wanted. ``k + 1`` candidates are retrieved
            per row; ``k`` is clamped to ``len(data_y) - 1``.
        data_y: 2-D tensor or numpy array of candidates; defaults to
            ``data_x``.
        largest: select the largest scores instead of the smallest. Forced
            to True whenever ``opt.normalize_data`` is truthy.
        opt: optional options object; only its ``normalize_data`` and
            ``sift`` attributes are read.
        include_self: if False, column 0 of both results is dropped. If
            True, column 0 is kept and the last column of the index tensor
            is dropped when it has more than ``k`` columns.

    Returns:
        ``(act_dist, topk)``: the selected scores (float tensor, left on the
        module-level ``device``) and the matching row indices into
        ``data_y`` (long tensor, moved back to the original device of
        ``data_x``).
    """
    if isinstance(data_x, np.ndarray):
        data_x = torch.from_numpy(data_x)
    if data_y is None:
        data_y = data_x
    elif isinstance(data_y, np.ndarray):
        data_y = torch.from_numpy(data_y)

    k0 = k
    device_o = data_x.device
    data_x = data_x.to(device)
    data_y = data_y.to(device)

    data_x_len = data_x.size(0)
    data_y_len = data_y.size(0)

    # Rows of data_x processed per iteration; smaller when there are many
    # candidates so that one chunk of scores still fits in memory.
    chunk_sz = 600 if data_y_len > 990000 else 3000

    if k + 1 > data_y_len:
        k = data_y_len - 1

    # Allocate both outputs on the compute device. The float buffer used to
    # be a CUDA tensor even on the CPU path, which fails without a GPU.
    dist_mx = torch.empty(data_x_len, k + 1, dtype=torch.long, device=device)
    act_dist = torch.empty(data_x_len, k + 1, dtype=torch.float, device=device)

    data_normalized = bool(opt is not None and opt.normalize_data)
    # NOTE(review): with normalized data the score is the negated dot product
    # and `largest` is forced to True, i.e. the smallest dot products are
    # selected -- confirm this is the intended ordering.
    largest = bool(largest) or data_normalized

    y_t = data_y.t()
    if not data_normalized:
        y_norm = (data_y ** 2).sum(-1).view(1, -1)

    for base in range(0, data_x_len, chunk_sz):
        upto = min(base + chunk_sz, data_x_len)
        x = data_x[base:upto]

        if data_normalized:
            dist = -torch.mm(x, y_t)
        else:
            # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y ; the sum broadcasts.
            x_norm = (x ** 2).sum(-1).view(-1, 1)
            dist = x_norm + y_norm
            dist -= 2 * torch.mm(x, y_t)

        topk_d, topk_idx = torch.topk(dist, k=k + 1, dim=1, largest=largest)
        dist_mx[base:upto] = topk_idx
        act_dist[base:upto] = topk_d

    topk = dist_mx
    if k > 3 and opt is not None and opt.sift:
        # sift contains duplicate points, so a row's own index may show up at
        # rank 1-3 instead of rank 0. Overwrite it with the rank-0 index so
        # that dropping column 0 below does not leave the row's own index in
        # the result. Don't run this in general.
        identity_ranks = torch.arange(topk.size(0), device=topk.device)
        for col in (1, 2, 3):
            is_self = topk[:, col] == identity_ranks
            if is_self.any():
                topk[is_self, col] = topk[is_self, 0]

    if not include_self:
        topk = topk[:, 1:]
        act_dist = act_dist[:, 1:]
    elif topk.size(-1) > k0:
        # NOTE(review): only the index tensor is trimmed here; act_dist keeps
        # all k + 1 columns. Left unchanged for backward compatibility.
        topk = topk[:, :-1]

    return act_dist, topk.to(device_o)

'''
Memory-compatible. 
Input: 
-data: tensors
-data_y: if None take dist from data_x to itself
'''
def l2_dist(data_x, data_y=None):
    """Dispatch to the pairwise-distance helper matching the arguments.

    With only ``data_x`` the distances are taken from ``data_x`` to itself
    (``_l2_dist1``); otherwise from ``data_x`` to ``data_y`` (``_l2_dist2``).
    """
    if data_y is None:
        return _l2_dist1(data_x)
    return _l2_dist2(data_x, data_y)
   
'''
Memory-compatible, when insufficient GPU mem. To be combined with _l2_dist2 later.
Input: 
-data: tensor
'''
def _l2_dist1(data, max_entries=int(5e6)):
    """Return the matrix of squared L2 distances between all rows of ``data``.

    The result is filled in row chunks so that only about ``max_entries``
    distance values are computed per step.

    Args:
        data: 2-D tensor or numpy array, one point per row.
        max_entries: approximate number of distance values computed per
            chunk; the chunk height is ``max_entries // len(data)``, with a
            minimum of one row.

    Returns:
        CPU float tensor of shape ``(len(data), len(data))``; an empty
        ``(0, 0)`` tensor when ``data`` has no rows.
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    data_len, _ = data.size()

    # Clamp to at least one row per chunk: the plain quotient is 0 for very
    # large inputs and undefined for empty ones.
    chunk_sz = max(1, int(max_entries // max(data_len, 1)))
    dist_mx = torch.empty(data_len, data_len, dtype=torch.float)

    y_t = data.t()
    y_norm = (data ** 2).sum(-1).view(1, -1)

    for base in range(0, data_len, chunk_sz):
        upto = min(base + chunk_sz, data_len)
        x = data[base:upto]
        x_norm = (x ** 2).sum(-1).view(-1, 1)
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y ; the sum broadcasts.
        dist_mx[base:upto] = x_norm + y_norm - 2 * torch.mm(x, y_t)

    return dist_mx

'''
Memory-compatible.
Input: 
-data: tensor
'''
def _l2_dist2(data_x, data_y, max_entries=int(5e6)):
    """Return squared L2 distances from every row of ``data_x`` to every row of ``data_y``.

    The result is filled in row chunks of ``data_x`` so that only about
    ``max_entries`` distance values are computed per step.

    Args:
        data_x: 2-D tensor or numpy array, one point per row.
        data_y: 2-D tensor or numpy array, one point per row.
        max_entries: approximate number of distance values computed per
            chunk; the chunk height is ``max_entries // len(data_y)``, with
            a minimum of one row.

    Returns:
        CPU float tensor of shape ``(len(data_x), len(data_y))``.
    """
    if isinstance(data_x, np.ndarray):
        data_x = torch.from_numpy(data_x)
    if isinstance(data_y, np.ndarray):
        data_y = torch.from_numpy(data_y)

    data_x_len, _ = data_x.size()
    data_y_len = data_y.size(0)

    # Clamp to at least one row per chunk: the plain quotient is 0 for very
    # large data_y and undefined when data_y is empty.
    chunk_sz = max(1, int(max_entries // max(data_y_len, 1)))
    dist_mx = torch.empty(data_x_len, data_y_len, dtype=torch.float)

    y_t = data_y.t()
    y_norm = (data_y ** 2).sum(-1).view(1, -1)

    for base in range(0, data_x_len, chunk_sz):
        upto = min(base + chunk_sz, data_x_len)
        x = data_x[base:upto]
        x_norm = (x ** 2).sum(-1).view(-1, 1)
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y ; the sum broadcasts.
        dist_mx[base:upto] = x_norm + y_norm - 2 * torch.mm(x, y_t)

    return dist_mx

 
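'''
Convert a numpy array or a list of lists to a markdown table.
Input:
-mx: 2-D numpy array (or doubly nested list)
-row_label: labels for the rows, one per row of mx
-col_label: labels for the columns, one per column of mx
'''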
def mx2md(mx, row_label, col_label):
    """Render a 2-D matrix as pipe-separated markdown table text.

    The first line holds '-' followed by the column labels, the second line
    the '---' separators, and each following line a row label followed by
    that row's entries. Every cell is converted with ``str``.

    Raises:
        Exception: if the number of rows/columns of ``mx`` does not match
            the number of row/column labels.
    """
    n_rows = len(mx)
    n_cols = len(mx[0])
    if not (n_rows == len(row_label) and n_cols == len(col_label)):
        raise Exception('mx2md: height != len(row_label) or width != len(col_label)')

    header = '|'.join(['-'] + [str(label) for label in col_label])
    separator = '|'.join(['---'] * (n_cols + 1))
    body = [
        '|'.join([str(label)] + [str(cell) for cell in values])
        for label, values in zip(row_label, mx)
    ]
    return '\n'.join([header, separator] + body)


def load_lines(path):
    """Read the text file at ``path`` and return its lines without line terminators."""
    with open(path, mode='r') as handle:
        content = handle.read()
    return content.splitlines()

'''                            
Input: lines is list of objects, not newline-terminated yet.                                                                        
'''
def write_lines(lines, path):
    """Write each item of ``lines`` to the text file at ``path``, one per line.

    Args:
        lines: iterable of objects; each is converted with ``str`` and must
            not already be newline-terminated.
        path: destination file, overwritten if it exists.
    """
    with open(path, 'w') as file:
        # Use '\n', not os.linesep: text mode already translates '\n' to the
        # platform line ending, so os.linesep would double the carriage
        # return on Windows.
        file.writelines(str(line) + '\n' for line in lines)

def pickle_dump(obj, path):
    """Serialize ``obj`` with pickle into the file at ``path`` (overwriting it)."""
    with open(path, mode='wb') as sink:
        pickle.dump(obj, sink)

def pickle_load(path):
    """Return the object unpickled from the file at ``path``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path, mode='rb') as source:
        loaded = pickle.load(source)
    return loaded

    
if __name__ == '__main__':
    # Smoke test: render two small matrices as markdown tables.
    mx1 = np.zeros((2, 2))
    mx2 = np.ones((2, 2))

    row = ['1', '2']
    col = ['3', '4']

    # mx2md takes a single matrix (the previously called `mxs2md` is not
    # defined in this file), so render each matrix in turn.
    for mx in (mx1, mx2):
        print(mx2md(mx, row, col))

