import os
import torch
import numpy as np
from torch import nn
from collections import OrderedDict
from args import get_args 
from data import get_dataset
from sparse_gradient_reconstruction import *
from PIL import Image
from utils import report_metrics

# Parse the run configuration and seed torch's and NumPy's global RNGs from it.
args = get_args()
torch.manual_seed(args.random_seed)
np.random.seed(args.random_seed)

# Data loaders plus the dataset metadata consumed further down: the sample
# shape and class count size the network, `inv_transform` is handed to
# `report_metrics`.
train_loader, test_loader, img_size, num_classes, inv_transform = get_dataset(args)

# `treshold` (sic -- the spelling is also the keyword name passed to `getQ`)
# is later multiplied by a row count to obtain the minimum allowed sparsity.
# NOTE(review): `get_tau` is not imported by name; presumably it comes from the
# star import of `sparse_gradient_reconstruction` -- confirm.
treshold = get_tau( args.pFN, args.W )
print(treshold, treshold*args.W)
def test(neptune, model, device, test_loader, pset):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        pset, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    if neptune:
        neptune[f'net_loss_{pset}'].log( test_loss )
        neptune[f'net_acc_{pset}'].log( 100. * correct / len(test_loader.dataset) )

class Net(nn.Module):
    """Fully connected ReLU classifier operating on flattened inputs.

    The layers live in ``self.model`` under the names
    ``fc1, relu1, ..., fc<num_layers>``; the last linear layer is not followed
    by an activation.
    """

    def __init__(self, num_layers, size, img_size, num_classes, bias=True):
        """Build the layer stack.

        Args:
            num_layers: Number of linear layers; must be at least 2.
            size: Width of every hidden layer.
            img_size: Shape of one input sample; its product is the input width.
            num_classes: Number of output logits.
            bias: Whether the layers after ``fc1`` carry a bias term. ``fc1``
                is always built with ``nn.Linear``'s default bias setting.
                NOTE(review): confirm that exempting ``fc1`` is intended.

        Raises:
            ValueError: If ``num_layers`` is smaller than 2. With a single
                layer the output layer would also be called ``fc1`` and
                silently replace the input layer in the ``OrderedDict``.
        """
        super().__init__()
        if num_layers < 2:
            raise ValueError(
                f'num_layers must be at least 2 (input and output layer), got {num_layers}')
        # np.prod returns a NumPy scalar; hand nn.Linear a plain Python int.
        in_features = int(np.prod(img_size))
        layers = [ ('fc1', nn.Linear(in_features, size)), ('relu1', nn.ReLU()) ]
        for idx in range(2, num_layers):
            layers.append( (f'fc{idx}', nn.Linear(size, size, bias=bias)) )
            layers.append( (f'relu{idx}', nn.ReLU()) )
        layers.append( (f'fc{num_layers}', nn.Linear(size, num_classes, bias=bias)) )
        self.model = nn.Sequential( OrderedDict( layers ) )

    def forward(self, x):
        """Flatten every sample of ``x`` to a vector and return the logits."""
        b = x.shape[0]
        x = x.reshape(b, -1)
        return self.model(x)

def _train_for_steps(model, optimizer, batches, max_steps, max_epochs=100):
    """Run optimizer updates on the GPU until more than ``max_steps`` were taken.

    Training also ends after ``max_epochs`` passes over ``batches``.
    Returns the number of updates performed.
    """
    taken = 0
    for _ in range(max_epochs):
        model.train()
        for inputs, labels in batches:
            inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            logits = model(inputs)
            torch.nn.functional.cross_entropy(logits, labels).backward()
            optimizer.step()
            taken += 1
            if taken > max_steps:
                return taken
        # Covers the case where the budget is already exceeded without the
        # inner loop having run (e.g. nothing to iterate over).
        if taken > max_steps:
            return taken
    return taken


# Build the network on the GPU and fit it with plain SGD.
net = Net(args.L, args.W, img_size, num_classes).cuda()
optim = torch.optim.SGD(net.parameters(), lr=0.03)
step = _train_for_steps(net, optim, train_loader, args.steps)

# Report how the trained network does on both splits.
print(f"Step {step}:")
for split_name, split_loader in (('Test', test_loader), ('Train', train_loader)):
    test(args.neptune, net, 'cuda', split_loader, split_name)

# Training is over: clear the optimizer-managed gradients and continue with
# the model on the CPU.
optim.zero_grad()
net = net.cpu()

# Batches to reconstruct are drawn from the split named by `args.ds_type`;
# anything other than 'train' selects the test split.
loader = train_loader if args.ds_type == 'train' else test_loader

# Reconstruction loop: for each selected batch, compute the gradients of `net`,
# try to recover the batch from them and report metrics on the result.
for batch_idx, (example_data, example_targets) in enumerate(loader):
    # Only batches with index in [args.st, args.en) are processed.
    if batch_idx < args.st:
        continue
    if batch_idx >= args.en:
        break
    print(f'\n\n\nImage {batch_idx}\n\n\n')
    if args.neptune:
        args.neptune['step'].log(batch_idx)
    # Re-seed with the batch index before processing each batch.
    random_seed = batch_idx
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    # Populate the parameter gradients of `net` for this batch.
    l = torch.nn.functional.cross_entropy(net(example_data), example_targets)
    l.backward()

    B_true = example_targets.shape[0]
    # Pass the configured batch size only when args.true_B is set; otherwise
    # B=None is handed to get_layer_decomp, which returns B_est.
    B_est, params, grad_params, LR, LR_inv = get_layer_decomp(
        args.neptune, net, 'model.fc1', B=args.B if args.true_B else None, device='cuda')
    print( f'B_est: {B_est} vs B_true: {B_true}' )

    # Reference solution from the ground-truth inputs: least-squares Q with
    # Q @ X_true ~= LR[1]. The flattened GPU copy is built once and reused.
    X_true = example_data.reshape(B_true, -1).cuda()
    Q_opt = torch.linalg.lstsq( X_true.T, LR[1].T, driver='gels').solution.T.detach()
    Q_opt_error = ( Q_opt @ X_true - LR[1] ).abs().max().item()
    dZ_opt = LR[0] @ Q_opt
    # Per column: number of entries whose magnitude is below the tolerance.
    dZ_true_sparsity = ( dZ_opt.abs() < args.sparsity_tol ).sum(0)
    min_sparsity = dZ_true_sparsity.min().item()
    min_allowed_sparsity = LR[0].shape[0] * treshold
    print( f'Q_opt num error: {Q_opt_error}, Min sparsity: {min_sparsity}/{min_allowed_sparsity}' )
    if args.neptune:
        args.neptune['result/Q_opt_error'].log( Q_opt_error )
        args.neptune['result/min_sparsity'].log( min_sparsity )
        args.neptune['result/min_allowed_sparsity'].log( min_allowed_sparsity )
        args.neptune['parameters/treshold'].log( treshold )

    try:
        Q_rec, Q_rec_inv = getQ( args.neptune, params, grad_params, LR, LR_inv, Q_opt, device='cuda', N=args.N, par_SVD=args.par_SVD,
            treshold=treshold, cond=args.cond, sigma_tol=args.sigma_tol, sigma_treshold=args.sigma_treshold, sparsity_tol=args.sparsity_tol, count_hack=args.count_hack )
    except Exception as exc:
        # A failed reconstruction must not abort the whole run, but Ctrl-C and
        # SystemExit still propagate and the cause is no longer hidden.
        print(f'Reconstruction failed for batch {batch_idx}: {exc!r}')
        Q_rec, Q_rec_inv = None, None
        if args.neptune:
            args.neptune['result/failed'].log( batch_idx )

    if Q_rec is None:
        # All-zero placeholder, shaped by the real batch size so a batch
        # smaller than args.B cannot break the reshape.
        X_rec = torch.zeros( *example_data.shape ).reshape(B_true, -1)
    else:
        X_rec = (Q_rec_inv[0] @ LR[1]).cpu()

    vision_metrics = report_metrics( net, X_rec, example_data, example_targets,
                f'./result/{batch_idx}', f'result/rec/batch {batch_idx}', f'result/gt/batch {batch_idx}',
                img_size, inv_transform, args.neptune )

    # Drop this batch's gradients so the next backward pass starts clean.
    for p in net.parameters():
        p.grad = None
