from abc import ABC
import math
import yaml

import torch
import torch.distributions as dist
import numpy as np

from .graph_priors import ErdosRenyi, PermuteGenerate, ScaleFree
from .sergio_vectorized_pytorch import (
    sim_sergio_pytorch,
    outlier_effect,
    lib_size_effect,
    dropout_indicator,
)
from src import Uniform, RandInt, Beta
import igraph as ig


EPS = 2**-22


class CausalExperimentModel(ABC):
    """
    Common base class for the experiment models in this module.

    Subclasses are expected to provide ``rsample``, ``sample_prior``,
    ``log_prob`` and ``reset``, and to set ``var_dim`` / ``var_names``
    before ``sanity_check`` is called.
    """

    def __init__(self):
        # Small numerical constant, exposed to subclasses as a tensor.
        self.epsilon = torch.tensor(EPS)

    def sanity_check(self):
        """Assert that the model declares a positive dimension and some variables."""
        assert self.var_dim > 0
        assert len(self.var_names) > 0

    def reset(self, n_parallel):
        """Re-initialise the model for ``n_parallel`` experiments (subclass hook)."""
        raise NotImplementedError

    def run_experiment(self, design, theta):
        """
        Simulate the outcome of ``design`` under parameters ``theta``.

        The number of samples requested from ``rsample`` is
        ``design.shape[-2]``. Tensor outcomes are detached from the autograd
        graph and copied; any other return value of ``rsample`` is passed
        through unchanged.
        """
        outcome = self.rsample(design, theta, design.shape[-2])
        if not isinstance(outcome, torch.Tensor):
            return outcome
        return outcome.detach().clone()

    def get_likelihoods(self, y, design, thetas):
        """Return ``self.log_prob(y, design, thetas)``."""
        return self.log_prob(y, design, thetas)

    def sample_theta(self, num_theta, zero_bias):
        """Draw ``num_theta`` parameter sets via ``sample_prior``."""
        return self.sample_prior(num_theta, zero_bias=zero_bias)


class GRNSergioModel(CausalExperimentModel):
    """
    Experiment model whose observations are expression data simulated by the
    vectorized SERGIO gene-regulatory-network simulator (``SergioGene``).

    ``sample_prior`` draws a dictionary of simulator parameters (``theta``)
    and ``rsample`` simulates expression data for an intervention design.

    Args:
        d: number of genes (nodes of the regulatory graph).
        n_parallel: number of parallel experiments; used as the second axis
            of the sampled parameter tensors.
        graph_prior: graph prior to use, one of ``"erdos_renyi"``,
            ``"scale_free"`` or ``"permute_generate"``.
        graph_args: extra keyword arguments for the graph prior constructor;
            ``n_parallel`` and ``num_nodes`` are always set by the model.
            ``None`` (the default) means no extra arguments.
        intervention_type: forwarded to the simulator as ``interv_type``.
        cell_types: number of cell types; size of the last axis of the
            sampled basal rates and number of simulator bins.
        b: distribution of the basal production rates.
        k_param: distribution of the interaction magnitudes (the absolute
            value of each draw is used).
        k_sign_p: distribution of the probability that an interaction sign is
            positive; one probability (and one sign) is drawn per row of ``k``.
        hill, decays, noise_params: constants broadcast into ``theta``.
        add_outlier_effect, add_lib_size_effect, add_dropout_effect: enable
            the corresponding technical-noise stage of the simulator.
        return_count_data: if True the simulator returns Poisson counts of
            the expression levels instead of the levels themselves.
        noise_config_type: top-level key of the YAML noise configuration.
        noise_config_file: path of the YAML noise configuration.
    """

    # Technical-noise parameters read from the YAML config, in sampling order
    # (``dropout_percentile`` is handled separately because it is cast to int32).
    _NOISE_KEYS = (
        "outlier_prob",
        "outlier_mean",
        "outlier_scale",
        "lib_size_mean",
        "lib_size_scale",
        "dropout_shape",
    )

    def __init__(
        self,
        d=2,
        n_parallel=1,
        graph_prior="erdos_renyi",
        graph_args=None,
        intervention_type="kout",
        cell_types=5,
        b=dist.uniform.Uniform(1.0, 3.0),
        k_param=dist.uniform.Uniform(1.0, 5.0),
        k_sign_p=dist.beta.Beta(0.5, 0.5),
        hill=1.0,
        decays=0.8,
        noise_params=1.0,
        add_outlier_effect=True,
        add_lib_size_effect=True,
        add_dropout_effect=True,
        return_count_data=True,
        noise_config_type="10x-chromium-mini",
        noise_config_file="pyro/models/noise_config.yaml",
    ):
        super().__init__()
        graph_priors = {
            "erdos_renyi": ErdosRenyi,
            "scale_free": ScaleFree,
            "permute_generate": PermuteGenerate,
        }
        self.graph_prior_init = graph_priors[graph_prior]
        self.d = d
        self.var_dim = d
        self.n_parallel = n_parallel
        # A fresh dict per instance instead of a shared mutable default.
        self.graph_args = {} if graph_args is None else graph_args
        self.var_names = [
            "graph",
            "k",
            "basal_rates",
            "hill",
            "decays",
            "noise_params",
            "outlier_prob",
            "outlier_mean",
            "outlier_scale",
            "lib_size_mean",
            "lib_size_scale",
            "dropout_shape",
            "dropout_percentile",
        ]
        self.intervention_type = intervention_type
        self.cell_types = cell_types
        self.b = b
        self.k_param = k_param
        self.k_sign_p = k_sign_p
        self.hill = hill
        self.decays = decays
        self.noise_params = noise_params
        self.add_outlier_effect = add_outlier_effect
        self.add_lib_size_effect = add_lib_size_effect
        self.add_dropout_effect = add_dropout_effect
        self.return_count_data = return_count_data
        self.noise_config_type = noise_config_type
        self.noise_config_file = noise_config_file
        self.reset(n_parallel)
        self.sanity_check()

    @staticmethod
    def _choose_from(values, shape, dtype=None):
        """Return a ``shape`` tensor drawn uniformly, with replacement, from ``values``."""
        pool = torch.tensor(values)
        if dtype is not None:
            pool = pool.to(dtype)
        return pool[torch.randint(0, len(pool), shape, dtype=torch.int64)]

    def _load_noise_config(self):
        """Read ``self.noise_config_file`` and return its ``self.noise_config_type`` section."""
        with open(self.noise_config_file, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)
        assert self.noise_config_type in config.keys(), (
            f"tech_noise_config `{self.noise_config_type}` "
            f"not in config keys: `{list(config.keys())}`"
        )
        return config[self.noise_config_type]

    def sample_prior(self, num_theta, n_parallel=None, zero_bias=False):
        """
        Draw ``num_theta`` simulator parameter sets.

        Args:
            num_theta: number of parameter sets (leading axis of each entry).
            n_parallel: if given and non-zero, ``reset`` the model to this many
                parallel experiments first.
            zero_bias: accepted for interface compatibility; unused here.

        Returns:
            dict keyed by the names in ``self.var_names``. Every entry except
            ``graph`` (whose shape comes from the graph prior) has leading axes
            ``(num_theta, n_parallel)``.
        """
        if n_parallel:
            self.reset(n_parallel)
        batch = (num_theta, self.n_parallel)
        full_graph = self.graph_prior(num_theta).squeeze((-3))
        k = torch.abs(self.k_param.sample(batch + (self.d, self.d)))
        # One Bernoulli(p) draw per row of k, mapped {0, 1} -> {-1, +1}.
        effect_sgn = (
            dist.binomial.Binomial(
                1,
                self.k_sign_p.sample(batch + (self.d, 1)),
            ).sample()
            * 2.0
            - 1.0
        )
        k = k * effect_sgn.to(torch.float32)
        # One basal rate per gene and cell type.
        basal_rates = self.b.sample(batch + (self.d, self.cell_types))

        noise_config = self._load_noise_config()
        # Each technical-noise parameter is drawn uniformly, with replacement,
        # from the list of values the config provides for it.
        tech_noise = {
            name: self._choose_from(noise_config[name], batch)
            for name in self._NOISE_KEYS
        }
        tech_noise["dropout_percentile"] = self._choose_from(
            noise_config["dropout_percentile"], batch, dtype=torch.int32
        )
        return {
            "graph": full_graph,
            "k": k,
            "basal_rates": basal_rates,
            "hill": self.hill * torch.ones(batch + (self.d, self.d)),
            "decays": self.decays * torch.ones(batch + (self.d,)),
            "noise_params": self.noise_params * torch.ones(batch + (self.d,)),
            **tech_noise,
        }

    def rsample(self, design, theta, n_samples=1):
        """
        Simulate ``n_samples`` expression samples for ``design`` under ``theta``.

        Args:
            design: intervention design, passed to ``SergioGene`` as its
                intervention mask (entries > 0.5 mark targeted genes).
            theta: parameter dict as returned by ``sample_prior``.
            n_samples: number of samples requested from the simulator.
        """
        return SergioGene(
            theta["graph"],
            theta["k"],
            theta["hill"],
            theta["decays"],
            theta["noise_params"],
            self.cell_types,
            theta["basal_rates"],
            design,
            add_outlier_effect=self.add_outlier_effect,
            add_lib_size_effect=self.add_lib_size_effect,
            add_dropout_effect=self.add_dropout_effect,
            # Previously stored but never forwarded, so False had no effect.
            return_count_data=self.return_count_data,
            outlier_prob=theta["outlier_prob"],
            outlier_mean=theta["outlier_mean"],
            outlier_scale=theta["outlier_scale"],
            lib_size_mean=theta["lib_size_mean"],
            lib_size_scale=theta["lib_size_scale"],
            dropout_shape=theta["dropout_shape"],
            dropout_percentile=theta["dropout_percentile"],
            interv_type=self.intervention_type,
        ).rsample((n_samples,))

    def log_prob(self, y, design, theta):
        """Not implemented for this simulator-based model; returns ``None``."""
        return None

    def reset(self, n_parallel):
        """Rebuild the graph prior for ``n_parallel`` parallel experiments."""
        self.n_parallel = n_parallel
        self.graph_prior = self.graph_prior_init(
            **{**self.graph_args, "n_parallel": self.n_parallel, "num_nodes": self.d}
        )


class GRNSergioModelNoisyIntervention(GRNSergioModel):
    """
    ``GRNSergioModel`` whose interventions are applied unreliably.

    In ``rsample`` every entry of the design independently gets a flip flag
    with probability ``intervention_noise``. The flag tensor is forwarded to
    ``SergioGene`` as ``flip_indices``, which toggles the targeted status of
    the flagged genes when the design targets at least one gene.

    Args:
        intervention_noise: per-entry flip probability.
        All other arguments are as for ``GRNSergioModel``; ``graph_args=None``
        (the default) means no extra graph prior arguments.
    """

    def __init__(
        self,
        d=2,
        n_parallel=1,
        graph_prior="erdos_renyi",
        graph_args=None,
        intervention_type="kout",
        cell_types=5,
        b=dist.uniform.Uniform(1.0, 3.0),
        k_param=dist.uniform.Uniform(1.0, 5.0),
        k_sign_p=dist.beta.Beta(0.5, 0.5),
        hill=1.0,
        decays=0.8,
        noise_params=1.0,
        add_outlier_effect=True,
        add_lib_size_effect=True,
        add_dropout_effect=True,
        return_count_data=True,
        noise_config_type="10x-chromium-mini",
        noise_config_file="pyro/models/noise_config.yaml",
        intervention_noise=0.1,
    ):
        super().__init__(
            d=d,
            n_parallel=n_parallel,
            graph_prior=graph_prior,
            # Fresh dict per instance instead of a shared mutable default.
            graph_args={} if graph_args is None else graph_args,
            intervention_type=intervention_type,
            cell_types=cell_types,
            b=b,
            k_param=k_param,
            k_sign_p=k_sign_p,
            hill=hill,
            decays=decays,
            noise_params=noise_params,
            add_outlier_effect=add_outlier_effect,
            add_lib_size_effect=add_lib_size_effect,
            add_dropout_effect=add_dropout_effect,
            return_count_data=return_count_data,
            noise_config_type=noise_config_type,
            noise_config_file=noise_config_file,
        )
        # Probability that any single design entry is flagged for flipping.
        self.intervention_noise = intervention_noise

    def rsample(self, design, theta, n_samples=1):
        """
        Like ``GRNSergioModel.rsample``, but with randomly flipped interventions.

        Args:
            design: intervention design (entries > 0.5 mark targeted genes).
            theta: parameter dict as returned by ``sample_prior``.
            n_samples: number of samples requested from the simulator.
        """
        design = design.clone()
        # Independent Bernoulli(intervention_noise) flag per design entry.
        flip_indices = (
            dist.bernoulli.Bernoulli(self.intervention_noise)
            .sample(design.shape)
            .bool()
        )
        return SergioGene(
            theta["graph"],
            theta["k"],
            theta["hill"],
            theta["decays"],
            theta["noise_params"],
            self.cell_types,
            theta["basal_rates"],
            design,
            add_outlier_effect=self.add_outlier_effect,
            add_lib_size_effect=self.add_lib_size_effect,
            add_dropout_effect=self.add_dropout_effect,
            # Previously stored but never forwarded, so False had no effect.
            return_count_data=self.return_count_data,
            outlier_prob=theta["outlier_prob"],
            outlier_mean=theta["outlier_mean"],
            outlier_scale=theta["outlier_scale"],
            lib_size_mean=theta["lib_size_mean"],
            lib_size_scale=theta["lib_size_scale"],
            dropout_shape=theta["dropout_shape"],
            dropout_percentile=theta["dropout_percentile"],
            interv_type=self.intervention_type,
            flip_indices=flip_indices,
        ).rsample((n_samples,))


class SergioGene(object):
    """
    Single-call wrapper around the vectorized SERGIO simulator
    (``sim_sergio_pytorch``) followed by optional technical-noise stages.

    Args:
        graph: stack of adjacency matrices. ``rsample`` iterates over its
            first axis and builds one igraph graph per matrix to obtain a
            topological order.
        k: interaction strengths, forwarded to the simulator.
        hill: Hill coefficients, forwarded to the simulator.
        decays: decay rates, forwarded to the simulator.
        noise_params: noise scales, forwarded to the simulator.
        n_cell_types: number of cell types, used as the simulator's
            ``number_bins``.
        basal_rates: basal production rates, forwarded to the simulator.
        interv_mask: intervention design. Only index 0 along its
            second-to-last axis is used; entries > 0.5 mark targeted genes.
        noise_type: stored but not currently forwarded to the simulator.
            Default: `dpd`
        sampling_state: forwarded to the simulator. Default: 15
        dt: time increment forwarded to the simulator. Default: 0.01
        tech_noise_config: stored as ``noise_config_type``; not used here.
        add_outlier_effect: apply ``outlier_effect`` with ``outlier_prob``,
            ``outlier_mean`` and ``outlier_scale``.
        add_lib_size_effect: apply ``lib_size_effect`` with ``lib_size_mean``
            and ``lib_size_scale``.
        add_dropout_effect: multiply the expression by ``dropout_indicator``
            computed from ``dropout_shape`` and ``dropout_percentile``.
        return_count_data: replace expression levels by Poisson samples with
            those levels as means.
        interv_type: forwarded to the simulator as ``interv_type``.
        flip_indices: optional boolean tensor shaped like ``interv_mask``. If
            the mask targets at least one gene anywhere, the targeted status of
            each gene whose flag is True is toggled.
    """

    has_rsample = False

    def __init__(
        self,
        graph,
        k,
        hill,
        decays,
        noise_params,
        n_cell_types,
        basal_rates,
        interv_mask,
        noise_type="dpd",
        sampling_state=15,
        dt=0.01,
        tech_noise_config=None,
        add_outlier_effect=True,
        add_lib_size_effect=True,
        add_dropout_effect=True,
        return_count_data=True,
        outlier_prob=0.01,
        outlier_mean=3.0,
        outlier_scale=1.0,
        lib_size_mean=6.0,
        lib_size_scale=0.3,
        dropout_shape=8,
        dropout_percentile=45,
        interv_type="kout",
        flip_indices=None,
    ):

        self.graph = graph
        self.k = k
        self.hill = hill
        self.decays = decays
        self.noise_params = noise_params
        self.n_cell_types = n_cell_types
        self.basal_rates = basal_rates
        self.noise_type = noise_type
        self.sampling_state = sampling_state
        self.dt = dt
        self.noise_config_type = tech_noise_config
        self.add_outlier_effect = add_outlier_effect
        self.add_lib_size_effect = add_lib_size_effect
        self.add_dropout_effect = add_dropout_effect
        self.return_count_data = return_count_data
        # Targets are read from index 0 along the second-to-last axis only.
        self.interv_mask = (interv_mask[..., 0, :] > 0.5).cpu().squeeze()
        if flip_indices is not None and self.interv_mask.sum() > 0:
            # squeeze() may have dropped size-1 axes (e.g. a single parallel
            # experiment) that flip_indices still carries; reshape so the two
            # masks line up instead of failing on mismatched boolean indexing.
            flip = flip_indices[..., 0, :].to(torch.bool).cpu()
            flip = flip.reshape(self.interv_mask.shape)
            self.interv_mask = torch.logical_xor(self.interv_mask, flip)
        self.outlier_prob = outlier_prob
        self.outlier_mean = outlier_mean
        self.outlier_scale = outlier_scale
        self.lib_size_mean = lib_size_mean
        self.lib_size_scale = lib_size_scale
        self.dropout_shape = dropout_shape
        self.dropout_percentile = dropout_percentile
        self.interv_type = interv_type

    def rsample(self, sample_shape=torch.Size([])):
        """
        Simulate expression data and return ``sample_shape[0]`` samples.

        ``sample_shape`` must have at least one entry; only the first entry is
        used. The simulator output is flattened over everything after its
        first two axes and randomly permuted along that flattened axis. It is
        then truncated to ``sample_shape[0]`` entries and transposed so that
        the sample axis is second to last.
        """
        n_samples = sample_shape[0]

        # One topological order per adjacency matrix, stacked along axis 1.
        toporder = []
        for adj in self.graph:
            g = ig.Graph.Adjacency(adj.detach().cpu().numpy().tolist())
            toporder.append(torch.tensor(g.topological_sorting(mode="out")))
        toporder = torch.stack(toporder, dim=1)

        # NOTE(review): when any gene is targeted only number_sc=1 is requested,
        # which may yield fewer than n_samples columns below — confirm intent.
        if self.interv_mask.sum() == 0:
            number_sc = math.ceil(n_samples / self.n_cell_types)
        else:
            number_sc = 1

        expr = sim_sergio_pytorch(
            graph=self.graph,
            toporder=toporder,
            number_bins=self.n_cell_types,
            number_sc=number_sc,
            noise_params=self.noise_params,
            decays=self.decays,
            basal_rates=self.basal_rates,
            k=self.k,
            hill=self.hill,
            targets=self.interv_mask,
            interv_type=self.interv_type,
            sampling_state=self.sampling_state,
            dt=self.dt,
            safety_steps=2,
        )

        # 1) outliers
        if self.add_outlier_effect:
            expr = outlier_effect(
                expr, self.outlier_prob, self.outlier_mean, self.outlier_scale
            )

        # 2) library size
        if self.add_lib_size_effect:
            expr = lib_size_effect(expr, self.lib_size_mean, self.lib_size_scale)

        # 3) dropout
        if self.add_dropout_effect:
            binary_ind = dropout_indicator(
                expr, self.dropout_shape, self.dropout_percentile
            )
            expr *= binary_ind

        # 4) mRNA count data
        if self.return_count_data:
            expr = torch.poisson(expr)

        # Flatten everything after the first two axes, shuffle, and keep
        # the requested number of samples.
        x = expr.reshape(*expr.shape[:2], -1)
        x = x[..., torch.randperm(x.size(-1))]
        return x[..., :n_samples].transpose(-1, -2)

    def log_prob(self, y):
        """Not implemented; returns ``None``."""
        return None


if __name__ == "__main__":
    # Smoke test: draw one parameter set from the prior and simulate one
    # purely observational measurement (no genes targeted) for each of the
    # parallel experiments. The previous debug script referenced models that
    # are not defined in this module and stopped in a pdb breakpoint.
    n_parallel = 4
    d = 3
    model = GRNSergioModel(d=d, n_parallel=n_parallel, graph_prior="erdos_renyi")
    thetas = model.sample_theta(1, zero_bias=True)
    # Drop the leading num_theta axis (size 1).
    theta = {k: v[0] for k, v in thetas.items()}
    # run_experiment reads the number of samples from design.shape[-2];
    # assumes designs are laid out as (n_parallel, n_samples, d) — TODO confirm.
    design = torch.zeros(n_parallel, 1, d)
    sample = model.run_experiment(design, theta)
    print(sample.shape)
