""" Defines an adversarial attack on PLNN's where we walk and project until
    we run into the decision boundary
"""
from _polytope_ import Polytope, Face
import utilities as utils
import torch
import copy
import numpy as np
import heapq
import matplotlib.pyplot as plt
import time
import plnn

class NaiveGeoCrawl(object):
    """ Naive geometric 'crawl' attack on a two-output PLNN.

    Starting from an input point, repeatedly follows the (piecewise constant)
    gradient of a linear loss on the logits, shooting a ray to the boundary
    of the current linear region and hopping into the neighbouring region,
    until the loss becomes non-negative. The final point is obtained by
    interpolating between the last two crawl points.
    """

    def __init__(self, net, verbose=True):
        """ Stores the network to attack.
        ARGS:
            net: plnn.PLNN instance (asserted); the attack additionally
                 requires it to produce exactly 2 output values
            verbose: bool - if True, progress information is printed
        """
        assert isinstance(net, plnn.PLNN)
        self.net = net # Better only have 2 outputs
        self.verbose = verbose

    def _verbose_print(self, *args):
        """ Forwards *args to print() only when self.verbose is truthy """
        if self.verbose:
            print(*args)


    def attack(self, example, label):
        """ Given an example and label, does the following 'attack':
            Let (x, y) be given and L(x') = f(x')[y'] -f(x')[y]
            where y' is the label that is not y. Notice that for PLNNs, the
            gradients of this loss are piecewise CONSTANT

        Then this attack does...
        Until hits decision_boundary:
            (i)   Compute gradient direction at working point
            (ii)  Rayshoot in that direction until hit polytope boundary
            (iii) Find new polytope to 'live in' and repeat

        ARGS:
            example: tensor fed to self.net; the net's output on it must
                     contain exactly 2 elements (asserted)
            label: single-element tensor (read through .item()); 1 selects
                   the loss f[0] - f[1] and 0 selects f[1] - f[0]
        RETURNS:
            tensor: the point between the last two crawl points at which the
            loss of the final linear region vanishes. If the loss is already
            non-negative at `example`, an unmodified clone of `example`.
        """

        ##################################################################
        #   Setup things needed for loop                                 #
        ##################################################################

        original_logit = self.net(example)
        assert original_logit.numel() == 2 # Better have 2

        # Loss function is a linear operator on the network output
        loss_operator = torch.Tensor([1, -1]) * (2 * label.item() - 1)
        loss_fxn = lambda x: self.net(x).squeeze().dot(loss_operator)
        running_example = example.clone()

        # Nothing to crawl towards if the loss is already non-negative.
        # Returning here also avoids interpolating over a zero-length
        # segment (which would divide by zero).
        if loss_fxn(running_example).item() >= 0:
            return running_example

        running_config = self.net.relu_config(running_example, False)

        ##################################################################
        #   Start loop                                                   #
        ##################################################################
        seen_configs = {utils.flatten_config(running_config)}
        iter_no = 0
        while True:
            prev_running_example = running_example
            running_example, running_config = self._next_running_example(
                                                                running_example,
                                                                running_config,
                                                                loss_operator)
            flat_running_config = utils.flatten_config(running_config)
            if flat_running_config in seen_configs:
                # Not fatal, but worth surfacing: the crawl re-entered a
                # linear region it has already visited
                self._verbose_print("Revisited a previously seen config")
            seen_configs.add(flat_running_config)

            # Evaluate once per step; reused for logging and termination
            current_loss = loss_fxn(running_example).item()
            self._verbose_print('-' * 50)
            self._verbose_print("ITERATION %02d" % iter_no)
            self._verbose_print("LOSS:", current_loss)
            self._verbose_print('-' * 50)

            if current_loss >= 0:
                break
            iter_no += 1

        ##################################################################
        #   Terminate loop and interpolate within this last polytope     #
        ##################################################################

        return self._interpolate_final_polytope(prev_running_example,
                                                running_example,
                                                loss_operator)


    def _next_running_example(self, x, current_config, loss_operator):
        """ Meat of the algorithm: takes in a point x and a loss function
            (which needs to be a linear operator)
        Does the following:
            (i) computes the direction to MAXIMIZE loss_operator
            (ii) rayshoots from current polytope to boundary and returns the
                 (a) rayshot point
                 (b) new polytope
        ARGS:
            x: current working point (tensor)
            current_config: ReLU configuration of the region containing x
            loss_operator: 1-D tensor applied to the network output
        RETURNS:
            (new_point, new_config) exactly as produced by self._rayshoot
        """

        ######################################################################
        #   Compute direction to maximize loss_operator                      #
        ######################################################################
        # If the loss is < loss_operator, f(x)> (where f is vector valued)
        # Then the gradient is (loss_operator)Grad(f(x))
        # And if f is linear at x, then Grad(f(x)) is just the matrix of the
        # linear map for the current config

        matrix = self.net.compute_matrix(current_config)
        grad_dir = torch.matmul(loss_operator, matrix)

        ######################################################################
        #   Now rayshoot to the next polytope boundary                       #
        ######################################################################
        return self._rayshoot(x, grad_dir, current_config)


    def _rayshoot(self, x, direction, config):
        """ First compute the polytope from the config
                (also assert x lives in this polytope!)
            And then walk in the specified direction until we hit the polytope
            boundary.
        Return the new point, and the new config

        ARGS:
            x: tensor - starting point
            direction: tensor - direction to walk in
            config: ReLU configuration describing the current polytope
        RETURNS:
            (rayshoot_point, new_config): the point x + t * direction on the
            polytope boundary and a deep copy of `config` with the entry of
            the newly tight constraint flipped
        RAISES:
            AssertionError: x violates some constraint by at least the
                            tolerance
            Exception: no constraint bounds the ray ("UNBOUNDED RAY SHOOT"),
                       or no new constraint became tight ("Backtrack????")
            Warning: more than one constraint became tight, or the new config
                     is not at Hamming distance 1 from `config`
        """
        toler = 1e-6

        ######################################################################
        #   Compute the polytope and rayshoot inside it                      #
        ######################################################################

        current_polytope = self.net.compute_polytope_config(config)
        # Don't need a polytope class
        poly_a = current_polytope['poly_a']
        poly_b = current_polytope['poly_b']
        dir_np = direction.data.cpu().numpy().squeeze()
        x_np = x.data.cpu().numpy().squeeze()

        # slack_i = b_i - <a_i, x>; non-negative iff constraint i holds at x
        initial_slack = poly_b - np.matmul(poly_a, x_np)
        if not np.all(initial_slack >= 0):
            print(initial_slack.min())

        # x may sit ON a face (|slack| < toler) but must not be outside the
        # polytope by more than the tolerance
        initial_tight = initial_slack < toler
        assert np.all(np.abs(initial_slack[initial_tight]) < toler)

        """ Quick'n'dirty rayshoot logic:
        Want to
            maximize t
            s.t. (Ax + tv) <= b (for v being the direction)
        This is equiv to
            maximize t
            s.t. t <= (b - Ax)_i / (Av)_i (for all i)
        Optimum is attained at the minimum of these upper bounds
        """

        # Double precision for everything
        poly_b_d = poly_b.astype(np.float64)
        poly_a_d = poly_a.astype(np.float64)
        x_np_d = x_np.astype(np.float64)
        dir_np_d = dir_np.astype(np.float64)

        slack_d = poly_b_d - np.matmul(poly_a_d, x_np_d)
        numerator = slack_d.squeeze()
        denominator = np.matmul(poly_a_d, dir_np_d).squeeze()

        # Only constraints we are moving TOWARDS (by more than toler) can
        # block the ray; the rest are in the wrong direction or really small
        blocking = np.flatnonzero(denominator > toler)
        step_bounds = numerator[blocking] / denominator[blocking]
        # Strict '< inf' also discards NaN bounds
        usable = step_bounds < float('inf')
        if not usable.any():
            raise Exception("UNBOUNDED RAY SHOOT")
        blocking = blocking[usable]
        step_bounds = step_bounds[usable]
        tightest = int(np.argmin(step_bounds)) # first minimum wins ties
        min_idx = int(blocking[tightest])
        min_val = float(step_bounds[tightest])

        rayshoot_point = x + min_val * direction


        ######################################################################
        #   Compute the new config                                           #
        ######################################################################

        # Flip the NEW TIGHT constraint
        initial_loose = slack_d > toler
        rayshoot_np = utils.as_numpy(rayshoot_point).squeeze()
        rayshoot_np_d = rayshoot_np.astype(np.float64)
        final_tight = (poly_b_d - np.matmul(poly_a_d, rayshoot_np_d)) <= toler
        new_tight = np.logical_and(initial_loose, final_tight)

        num_new_tight = int(np.count_nonzero(new_tight))
        if num_new_tight > 1:
            raise Warning("Measure zero event occurred!")
        elif num_new_tight < 1:
            # Dump the state of the constraint that was supposed to become
            # tight before giving up
            print("MINLEVELS",
                  poly_b[min_idx] - np.matmul(poly_a[min_idx], x_np))
            print("NEWTIGHT",
                  poly_b[min_idx] - np.matmul(poly_a[min_idx], rayshoot_np))
            print("INFINALTIGHT", final_tight[min_idx])
            raise Exception("Backtrack????")


        coords_to_flip = [utils.index_to_config_coord(config, idx)
                          for idx in np.flatnonzero(new_tight)]
        new_config = copy.deepcopy(config)
        for flip_i, flip_j in coords_to_flip:
            new_config[flip_i][flip_j] = int(1 - new_config[flip_i][flip_j])

        hamming_dist = utils.config_hamming_distance(config, new_config)
        if hamming_dist != 1:
            print("MIN VAL IS ", min_val)

            raise Warning("Hamming distance: ", hamming_dist)


        return rayshoot_point, new_config


    def _interpolate_final_polytope(self, prev, final, loss_operator):
        """ Finds the point on the segment prev -> final where the linearized
            loss is exactly zero.

        The linear region is taken to be the one containing the midpoint of
        the segment. With A = poly['total_a'], b = poly['total_b'] and
        c = loss_operator, this solves for t in
            <c, A (prev + t (final - prev)) + b> = 0
        (presumably A x + b is the network output on that region - this is
        how the formula below uses them; confirm against
        compute_polytope_config)

        ARGS:
            prev: tensor - crawl point before the last step
            final: tensor - crawl point after the last step
            loss_operator: 1-D tensor applied to the network output
        RETURNS:
            tensor: prev + t * (final - prev)
        """
        midpoint = (prev + final) / 2.0
        midpoint_config = self.net.relu_config(midpoint, False)
        poly = self.net.compute_polytope_config(midpoint_config)
        total_a = poly['total_a']
        total_b = poly['total_b']

        # Row vector c^T A: gradient of the loss inside this region
        loss_grad = torch.matmul(loss_operator, total_a)
        segment = final - prev

        numer = torch.matmul(-total_b, loss_operator) -\
                torch.matmul(loss_grad, prev.squeeze())
        denom = torch.matmul(loss_grad, segment.squeeze())
        return prev + numer / denom * segment





