from dataclasses import dataclass
from typing import Callable, Mapping

import elegy as eg
import jax
import numpy as np
import pyvista as pv
from tensorboardX import SummaryWriter

from .config import Config
from .data import extract_features, gen_example_from_mesh, simplify
from .model import Mesh


@dataclass
class MeshValData:
    """Preprocessed validation mesh plus the model inputs derived from it.

    Instances are built by ``process_mesh`` and consumed by ``MeshValidator``
    and ``predict_and_save_mesh``.
    """

    # Cleaned volume mesh (zero-volume cells removed, triangulated); carries the
    # reference array "u" and receives the "u_hat"/"err" arrays on validation.
    mesh: pv.UnstructuredGrid
    # Reference field values, read as mesh["u"].
    u_true: np.ndarray
    # max(|u_true|); used to normalise u for contour plots.
    max_u: float
    # Features extracted (via extract_features) from the simplified boundary surface.
    mesh_features: Mesh
    # Query coordinates at which the model is evaluated: the mesh points.
    x: np.ndarray
    # Positional arguments for model.test_on_batch, from gen_example_from_mesh.
    val_inputs: tuple


def process_mesh(mesh: pv.UnstructuredGrid, mult: float = 1.0) -> MeshValData:
    """Turn a raw validation mesh into a ``MeshValData`` bundle.

    The mesh's cell data is converted to point data, cells whose computed
    volume is exactly zero are dropped, and the rest is triangulated. The
    boundary surface is simplified (``simplify(..., 1024)``) and used to
    derive both the mesh features and the ``test_on_batch`` inputs.

    Args:
        mesh: Raw mesh; must provide an array named ``"u"`` (reference field).
        mult: Forwarded unchanged to ``gen_example_from_mesh`` as ``mult``.

    Returns:
        The cleaned mesh together with its reference values, their maximum
        magnitude, the surface features, the query points and the model inputs.
    """
    point_mesh = mesh.ctp()
    cell_volumes = point_mesh.compute_cell_sizes()["Volume"]
    # Degenerate (zero-volume) cells are excluded before triangulating.
    clean = point_mesh.extract_cells(cell_volumes != 0).triangulate()

    boundary = simplify(clean.extract_surface(), 1024)

    u_ref = clean["u"]
    peak = np.max(np.abs(u_ref))
    features = extract_features(boundary)
    inputs = gen_example_from_mesh(
        boundary, 1024, random_augmentation=False, mult=mult
    )

    return MeshValData(
        mesh=clean,
        u_true=u_ref,
        max_u=peak,
        mesh_features=features,
        x=clean.points,
        val_inputs=inputs,
    )


def load_meshes(pat: str) -> Mapping[str, pv.UnstructuredGrid]:
    """Read every file matching the glob pattern ``pat`` with ``pv.read``.

    Files are visited in sorted path order, so the returned mapping (and hence
    the order in which meshes are validated and logged) is reproducible rather
    than depending on filesystem listing order.

    Args:
        pat: Glob pattern, e.g. ``"data/val/*.vtu"``.

    Returns:
        Dict keyed by file stem (file name without directory and final
        suffix). It is empty when nothing matches. If two files share a stem,
        the later one in sorted order silently replaces the earlier one.
    """
    from glob import glob
    from pathlib import Path

    return {Path(f).stem: pv.read(f) for f in sorted(glob(pat))}


@jax.jit
def predict(u, x, mesh):
    """Evaluate ``u(x, mesh)`` under ``jax.jit``.

    All three arguments are traced, so ``u`` must be a valid JAX pytree that
    is also callable (e.g. a module object or ``jax.tree_util.Partial``), not
    a bare Python function.
    """
    prediction = u(x, mesh)
    return prediction


def predict_and_save_mesh(u, name, mdata):
    """Predict on a validation mesh, attach the results to it, and save it.

    Args:
        u: Model field, evaluated through the jitted ``predict`` as
            ``u(mdata.x, mdata.mesh_features)``.
        name: Output path. ``".vtu"`` is appended unless the path already
            ends with it, so both ``"out"`` and ``"out.vtu"`` write ``out.vtu``.
        mdata: Preprocessed validation mesh (see ``MeshValData``).

    Returns:
        Per-point squared error ``(u_hat - u_true)**2`` divided by the mean of
        ``u_true**2``, shaped like ``mdata.u_true``.

    Side effects:
        Stores ``"u_hat"`` and ``"err"`` in ``mdata.mesh.point_data`` and
        writes the mesh to disk.
    """
    u_hat = np.asarray(predict(u, mdata.x, mdata.mesh_features))
    # The model output may not share u_true's shape (e.g. a trailing unit axis);
    # align it once so the error and the stored array are consistent.
    u_hat = u_hat.reshape(mdata.u_true.shape)
    err = np.square(u_hat - mdata.u_true) / np.square(mdata.u_true).mean()

    mdata.mesh.point_data["u_hat"] = u_hat
    mdata.mesh.point_data["err"] = err

    path = str(name)
    # Avoid writing "*.vtu.vtu" when the caller already supplied the suffix.
    if not path.endswith(".vtu"):
        path = f"{path}.vtu"
    mdata.mesh.save(path)
    return err


class MeshValidator(eg.callbacks.Callback):
    """Elegy callback that evaluates the model on a fixed set of meshes.

    At the end of every epoch it does the following:

    * logs the optimizer state's ``hyperparams``;
    * writes each mesh, with prediction (``"u_hat"``) and normalised squared
      error (``"err"``), to ``meshes/epoch-{epoch}-{name}.vtu``;
    * logs ``test_on_batch`` metrics plus the mean error per mesh;
    * saves the model as ``"best"`` whenever the mean error over all meshes
      improves.
    """

    def __init__(
        self,
        meshes: Mapping[str, pv.UnstructuredGrid],
        val_writer_factory: Callable[[], SummaryWriter],
        cfg: Config,
        plot_meshes=False,
    ):
        """Preprocess the validation meshes.

        Args:
            meshes: Mapping of name to raw mesh. Each mesh is processed once
                with ``process_mesh`` using ``cfg.data.source_multiplier``.
            val_writer_factory: Zero-argument callable that creates the
                ``SummaryWriter``. It is called lazily on first use.
            cfg: Configuration; only ``cfg.data.source_multiplier`` is read.
            plot_meshes: If true, log contour meshes to TensorBoard. The
                ground truth is logged at epoch 0; the prediction and error
                are logged every 100 epochs.
        """
        super().__init__()
        self.meshes: Mapping[str, MeshValData] = {
            name: process_mesh(m, cfg.data.source_multiplier)
            for name, m in meshes.items()
        }
        self.val_writer_factory = val_writer_factory
        self._val_writer = None
        self.plot_meshes = plot_meshes
        self.best_err: float = np.inf

    @property
    def val_writer(self) -> SummaryWriter:
        """Return the summary writer, creating it on first access."""
        if self._val_writer is None:
            self._val_writer = self.val_writer_factory()
        return self._val_writer

    def on_epoch_end(self, epoch, logs):
        """Validate on every mesh, log the results and checkpoint the best model."""
        from os import makedirs

        makedirs("meshes", exist_ok=True)
        self._log_optimizer_hparams(epoch)

        total_err = 0.0
        for name, mdata in self.meshes.items():
            total_err += self._validate_mesh(epoch, name, mdata)

        # With no meshes (e.g. an empty glob) there is nothing to average;
        # skip best-model bookkeeping instead of dividing by zero.
        if not self.meshes:
            return

        mean_err = total_err / len(self.meshes)
        if mean_err < self.best_err:
            self.best_err = mean_err
            self.model.save("best")
        self.val_writer.add_scalar("best_err", self.best_err, epoch)

    def _log_optimizer_hparams(self, epoch) -> None:
        """Log each entry of the optimizer state's ``hyperparams`` as a scalar."""
        hparams = self.model.optimizer.opt_state.hyperparams
        for k, v in hparams.items():
            self.val_writer.add_scalar(f"optimizer/{k}", v, global_step=epoch)

    def _validate_mesh(self, epoch, name: str, mdata: MeshValData):
        """Evaluate, log and optionally plot one mesh; return its mean error."""
        if epoch == 0 and self.plot_meshes:
            self._plot_contour(
                f"{name}-u_true(x)",
                np.square(mdata.u_true / mdata.max_u),
                mdata,
                epoch,
            )

        with jax.default_device(jax.devices("cpu")[0]):
            # predict_and_save_mesh adds the ".vtu" suffix to the path itself.
            err = predict_and_save_mesh(
                self.model.module.u,
                f"meshes/epoch-{epoch}-{name}",
                mdata,
            )
            test_out = self.model.test_on_batch(*mdata.val_inputs)
            test_out["error"] = np.mean(err)

        for k, v in test_out.items():
            self.val_writer.add_scalar(f"{k}/{name}", v, global_step=epoch)

        if epoch % 100 == 0 and self.plot_meshes:
            # Read back the prediction that predict_and_save_mesh stored on the mesh.
            u_hat = np.asarray(mdata.mesh.point_data["u_hat"]).reshape(
                mdata.u_true.shape
            )
            self._plot_contour(
                f"{name}-u_hat(x)", np.square(u_hat / mdata.max_u), mdata, epoch
            )
            self._plot_contour(f"{name}-err", err, mdata, epoch)

        return test_out["error"]

    def _plot_contour(self, name: str, scalars, mdata: MeshValData, epoch) -> None:
        """Log contour surfaces of ``scalars`` (clipped to [0, 1]) as a TensorBoard mesh."""
        from colorcet import m_blues as cm

        scalars = np.clip(np.asarray(scalars), 0, 1)
        contours = mdata.mesh.contour(scalars=scalars)
        # Cap the logged geometry at roughly 512 points.
        if contours.n_points >= 512:
            contours = contours.decimate(1 - 512 / contours.n_points)
        # Assumes an all-triangle surface: each face record is [3, i, j, k].
        faces = contours.faces.reshape(1, -1, 4)[..., 1:]
        self.val_writer.add_mesh(
            tag=name,
            vertices=contours.points.reshape(1, -1, 3),
            # Append each face with reversed winding, presumably so the
            # surface is visible from both sides.
            faces=np.concatenate((faces, faces[..., ::-1]), axis=1),
            # NOTE(review): these colours are per point of mdata.mesh, not per
            # contour vertex, so their count generally differs from the number
            # of vertices logged above — confirm intended colouring.
            colors=cm(scalars, bytes=True).reshape(1, -1, 4)[..., :3],
            global_step=epoch,
        )
