import os
import pickle
import shlex
from copy import deepcopy
from pathlib import Path
from subprocess import Popen
from tempfile import NamedTemporaryFile

import numpy as np
import torch
from pytorch_lightning.cli import LightningArgumentParser
from trimesh import Trimesh
from trimesh.sample import sample_surface
import pyvista as pv

import utils.BVH as BVH
from data_loaders.mret import (MRetDataModule, all_pose_to_body_pose,
                               all_static_to_body_static,
                               body_pose_to_all_pose, get_bone_ranges, indices2relative_coords)
from model.retnet import RetNet
from run.train_retnet import RetNetCLI
from utils.Animation import Animation
from utils.body_armatures import MixamoBodyArmature
from utils.lbs import Rotation2MixamoVerts, SkinnableSensor
from utils.Quaternions_old import Quaternions
from utils.rotation_conversions import matrix_to_quaternion, rotation_6d_to_matrix


class GenCLI(RetNetCLI):
    """Lightning CLI used for sampling from a trained RetNet.

    Extends the training CLI with generation options (checkpoint, sample
    count, output directory) and ties the data batch size to the number of
    requested samples.
    """

    def add_arguments_to_parser(self, parser: LightningArgumentParser):
        super().add_arguments_to_parser(parser)
        generation_options = (
            ('--ckpt_path', dict(type=str, required=True, help='Path to the checkpoint')),
            ('--num_samples', dict(type=int, default=10, help='Number of samples to generate')),
            ('--output_dir', dict(type=str, default='output', help='Output directory')),
        )
        for flag, options in generation_options:
            parser.add_argument(flag, **options)

    def before_instantiate_classes(self):
        # main() only consumes the first test batch, so the batch size is what
        # decides how many samples get generated.
        cfg = self.config
        cfg.data['batch_size'] = cfg.num_samples


def main():
    """Generate retargeted samples for one test batch, report metrics and export them.

    Steps:
      1. Parse the CLI, load the RetNet checkpoint given by ``--ckpt_path`` and
         freeze it; take the first batch of the test dataloader (its size is
         ``--num_samples``, see ``GenCLI.before_instantiate_classes``).
      2. Predict target-character motion ``x_hat`` from source motion ``x``.
      3. Print contact metrics for the prediction and for the "copy" baseline
         (source motion reused unchanged on the target character), then the mean
         ``penetration_ratio()`` of both.
      4. For each clip write, under ``--output_dir``, ``.mp4`` and ``.bvh`` files:
         ``sample_{i}`` (prediction on the target character),
         ``sample_{i}_src`` (source motion on the source character) and
         ``sample_{i}_copy`` (source motion on the target character).
    """
    # run=False: only parse the config and instantiate model/datamodule.
    cli = GenCLI(RetNet, MRetDataModule, run=False)
    # Replace the freshly instantiated model with the checkpoint weights, on the same device.
    cli.model = RetNet.load_from_checkpoint(cli.config.ckpt_path, map_location=cli.model.device)
    cli.model.freeze()
    cli.datamodule.setup('test')
    dataloader = cli.datamodule.test_dataloader()

    output_dir = Path(cli.config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Only the first test batch is used.
    batch = next(iter(dataloader))

    with torch.inference_mode():
        x, y = batch['x'], batch['y']
        if cli.config.model.only_body:
            # Body-only model: reduce pose and static data to body joints, run the
            # model on consecutive 30-frame clips (dim 1 is time), then merge the
            # prediction back into the full pose.
            x_body = all_pose_to_body_pose(x)
            y_body = deepcopy(y)
            y_body['src_static'] = all_static_to_body_static(y_body['src_static'])
            y_body['tgt_static'] = all_static_to_body_static(y_body['tgt_static'])
            # NOTE(review): frames after the last multiple of 30 are never predicted when
            # seq_len % 30 != 0 — confirm body_pose_to_all_pose tolerates the shorter x_hat.
            num_clips = cli.config.data.seq_len // 30
            x_hat = []
            for i in range(num_clips):
                x_hat.append(cli.model(x_body[:, i*30:i*30+30], y_body))
            x_hat = torch.cat(x_hat, dim=1)
            # Smooth each clip seam b = 30*i: frames b-2 .. b+3 are overwritten with
            # means of adjacent-frame pairs, moving outward from the seam. Updates are
            # in place, so later k values read frames already smoothed by earlier ones.
            for i in range(1, num_clips):
                for k in range(3):
                    m_1 = x_hat[:, i*30-k:i*30-k+2].mean(dim=1)
                    m_2 = x_hat[:, i*30+k:i*30+k+2].mean(dim=1)
                    x_hat[:, i*30-k] = m_1
                    x_hat[:, i*30+k+1] = m_2
            x_hat = body_pose_to_all_pose(x, x_hat)
        else:
            x_hat = cli.model(x, y)

    # Unpack per-clip static data of the target and source characters.
    rot2verts = Rotation2MixamoVerts(x_hat.device)
    seq_mask = y['mask'].reshape(x.shape[0], x.shape[1])
    tgt_static = y['tgt_static']
    tgt_rest_verts = tgt_static['verts']
    tgt_faces = tgt_static['faces']
    tgt_joint_loc = tgt_static['joint_locations']
    parents = tgt_static['parents']
    tgt_lbs_weights = tgt_static['lbs_weights']
    src_static = y['src_static']
    src_rest_verts = src_static['verts']
    src_faces = src_static['faces']
    src_joint_loc = src_static['joint_locations']
    src_lbs_weights = src_static['lbs_weights']
    root_translation = batch['root_translation']
    # BVH joint names carry the Mixamo rig prefix.
    joint_names = MixamoBodyArmature._standard_joint_names
    joint_names = ['mixamorig:' + j for j in joint_names]

    # Contact metric for the prediction, then for the copy baseline (x reused on the target).
    contact_precision(cli.model, x, x_hat, y['src_static'], y['tgt_static'])
    contact_precision(cli.model, x, x, y['src_static'], y['tgt_static'])

    # Debug visualisation of left-arm/head sensor pairs (disabled):
    # for clip_idx, sample in enumerate(x_hat):
    #     point_pairs = get_point_pairs(sample, src_static['body_sensor_locations'][clip_idx], src_static['body_sensor_mask'][clip_idx], src_static['body_sensor_tns'][clip_idx], src_static['body_sensor_weights'][clip_idx], src_joint_loc[clip_idx], parents)
    #     posed_verts = rot2verts(sample, seq_mask[clip_idx], cli.model.data_rep, tgt_rest_verts[clip_idx], tgt_joint_loc[clip_idx], parents, tgt_lbs_weights[clip_idx])
    #     draw_img(sample, tgt_static['body_sensor_locations'][clip_idx], tgt_static['body_sensor_mask'][clip_idx], tgt_static['body_sensor_tns'][clip_idx], tgt_static['body_sensor_weights'][clip_idx], tgt_joint_loc[clip_idx], parents, posed_verts, tgt_faces[clip_idx], point_pairs)
    #     posed_verts = rot2verts(x[clip_idx], seq_mask[clip_idx], cli.model.data_rep, src_rest_verts[clip_idx], src_joint_loc[clip_idx], parents, src_lbs_weights[clip_idx])
    #     draw_img(x[clip_idx], src_static['body_sensor_locations'][clip_idx], src_static['body_sensor_mask'][clip_idx], src_static['body_sensor_tns'][clip_idx], src_static['body_sensor_weights'][clip_idx], src_joint_loc[clip_idx], parents, posed_verts, src_faces[clip_idx], point_pairs)

    # Penetration ratio on the target character: predicted motion vs. copied source motion.
    sample_pr, copy_pr = [], []
    for clip_idx, sample in enumerate(x_hat):
        armature = MixamoBodyArmature(MixamoBodyArmature._standard_joint_names, parents, tgt_rest_verts[clip_idx].detach().cpu().numpy(), tgt_faces[clip_idx].detach().cpu().numpy(), tgt_lbs_weights[clip_idx].detach().cpu().numpy(), tgt_joint_loc[clip_idx].detach().cpu().numpy())
        armature.joint_rotations = sample.unsqueeze(0)
        armature.root_locations = root_translation[clip_idx].reshape(1, -1, 3)
        sample_pr.append(armature.penetration_ratio())
        # Same armature and root trajectory, only the rotations are swapped for the source motion.
        armature.joint_rotations = x[clip_idx].unsqueeze(0)
        copy_pr.append(armature.penetration_ratio())

    print('Sample PR:', np.mean(sample_pr), 'Copy PR:', np.mean(copy_pr))

    # Export video + BVH per clip: prediction on target, source on source, source copied onto target.
    for clip_idx, sample in enumerate(x_hat):
        posed_verts = rot2verts(sample, seq_mask[clip_idx], cli.model.data_rep, tgt_rest_verts[clip_idx], tgt_joint_loc[clip_idx], parents, tgt_lbs_weights[clip_idx])
        posed_verts += root_translation[clip_idx]
        export_video(posed_verts, tgt_faces[clip_idx], output_dir / f'sample_{clip_idx}.mp4')
        export_bvh(joint_names, sample, tgt_joint_loc[clip_idx], root_translation[clip_idx], parents, output_dir / f'sample_{clip_idx}.bvh')
        posed_verts = rot2verts(x[clip_idx], seq_mask[clip_idx], cli.model.data_rep, src_rest_verts[clip_idx], src_joint_loc[clip_idx], parents, src_lbs_weights[clip_idx])
        posed_verts += root_translation[clip_idx]
        export_video(posed_verts, src_faces[clip_idx], output_dir / f'sample_{clip_idx}_src.mp4')
        export_bvh(joint_names, x[clip_idx], src_joint_loc[clip_idx], root_translation[clip_idx], parents, output_dir / f'sample_{clip_idx}_src.bvh')
        posed_verts = rot2verts(x[clip_idx], seq_mask[clip_idx], cli.model.data_rep, tgt_rest_verts[clip_idx], tgt_joint_loc[clip_idx], parents, tgt_lbs_weights[clip_idx])
        posed_verts += root_translation[clip_idx]
        export_video(posed_verts, tgt_faces[clip_idx], output_dir / f'sample_{clip_idx}_copy.mp4')
        export_bvh(joint_names, x[clip_idx], tgt_joint_loc[clip_idx], root_translation[clip_idx], parents, output_dir / f'sample_{clip_idx}_copy.bvh')


def export_video(posed_verts, faces, path):
    """Render a posed mesh sequence to a video using the ``run.play_mesh`` helper.

    ``(posed_verts, faces)`` is pickled to a temporary file and handed to a
    child process running ``run.play_mesh``. The temporary file is always
    removed afterwards, even if the child cannot be started.

    Args:
        posed_verts: Per-frame vertex positions, pickled as-is for the child
            (presumably a (T, V, 3) tensor — confirm against run.play_mesh).
        faces: Mesh face indices, pickled as-is.
        path: Output video path, forwarded as ``--output``.
    """
    import sys

    with NamedTemporaryFile('wb', delete=False) as temp_f:
        pickle.dump((posed_verts, faces), temp_f)
    # The file is closed here, so the child can open it on every platform.
    try:
        # An explicit argv list keeps paths containing spaces/quotes intact, and
        # sys.executable runs the child under the same interpreter/environment.
        cmd = [sys.executable, '-m', 'run.play_mesh',
               '--input', temp_f.name, '--output', str(path)]
        returncode = Popen(cmd, stdout=None, stderr=None).wait()
        if returncode != 0:
            print(f'Warning: run.play_mesh exited with code {returncode} while writing {path}')
    finally:
        os.unlink(temp_f.name)


def export_bvh(joint_names, rotations, joint_loc, root_translations, parents, output, frametime=1 / 30):
    """Write one motion clip to a BVH file.

    Args:
        joint_names: BVH joint names, one per joint.
        rotations: Per-frame joint rotations in the 6D representation; they are
            converted with ``rotation_6d_to_matrix`` then ``matrix_to_quaternion``.
        joint_loc: Rest-pose joint locations (presumably (J, 3) — confirm).
        root_translations: Per-frame root positions. They are concatenated along
            axis 1 with the (T, J-1, 3) child offsets, so they must be (T, 1, 3).
        parents: Parent index per joint, ``-1`` for the root. Joint 0 is assumed
            to be the root (its row is replaced by ``root_translations``).
        output: Destination path passed to ``BVH.save``.
        frametime: Seconds per frame written to the BVH header (default 1/30).
    """
    rotations = matrix_to_quaternion(rotation_6d_to_matrix(rotations))
    rotations = Quaternions(rotations.detach().cpu().numpy())
    root_positions = root_translations.detach().cpu().numpy()
    # Compute offsets on a host numpy copy: writing a (possibly CUDA) torch
    # tensor into a numpy array fails.
    rest_locs = joint_loc.detach().cpu().numpy()
    offsets = rest_locs.copy()  # the root keeps its absolute rest location
    for idx, p in enumerate(parents):
        p = int(p)
        if p == -1:
            continue
        # Read from the untouched rest locations, not from `offsets`, which is being rewritten.
        offsets[idx] = rest_locs[idx] - rest_locs[p]
    num_frames = root_positions.shape[0]
    # Non-root joints sit at their constant bone offsets on every frame.
    child_offsets = np.repeat(offsets[np.newaxis], num_frames, axis=0)[:, 1:]
    positions = np.concatenate([root_positions, child_offsets], axis=1)
    orients = Quaternions.id(offsets.shape[0])
    anim = Animation(rotations, positions, orients, offsets, parents)
    BVH.save(output, anim, joint_names, frametime)


def masked_l2(a, b, mask):
    """Mean squared error between ``a`` and ``b`` over the elements selected by ``mask``.

    ``mask`` is broadcast against ``(a - b) ** 2`` before both the weighted sum
    and the element count. A mask with singleton dimensions is therefore
    averaged over exactly the elements it selects. A small epsilon avoids
    division by zero, so an all-false mask yields 0.

    Args:
        a, b: Tensors of broadcast-compatible shapes.
        mask: Boolean (or weight) tensor broadcast-compatible with ``a - b``.

    Returns:
        A 0-dim tensor.
    """
    sq_err = (a - b) ** 2
    sq_err, mask = torch.broadcast_tensors(sq_err, mask)
    weighted_sum = (sq_err * mask.float()).sum()
    count = mask.sum()
    return weighted_sum / (count + 1e-8)


def contact_precision(retnet, motion_A, motion_B, static_A, static_B, close_ratio=0.5):
    """Print and return a per-group contact error between two motions.

    The sparse sensor pairs and their distances come from
    ``retnet.get_cond_inp`` applied to ``motion_A`` with ``static_A``. The same
    pairs are then measured on ``motion_B`` skinned onto character B. For each
    entry of ``retnet.group_pairs``, the source and predicted pair distances
    are compared with ``masked_l2``. Each distance is first divided by the
    limb radius on its own character. Only pairs that are valid in the
    prediction mask and closer than ``close_ratio`` times the source limb
    radius are counted.

    Args:
        retnet: Model exposing ``get_cond_inp`` and ``group_pairs``.
        motion_A: Source motion; dim 1 is taken as the time axis.
        motion_B: Motion evaluated on character B.
        static_A, static_B: Static data dicts of characters A and B.
        close_ratio: Fraction of the source limb radius below which a source
            pair counts as a contact (default 0.5).

    Returns:
        List with one 0-dim error tensor per entry of ``retnet.group_pairs``.
    """
    T = motion_A.shape[1]

    def over_time(t):
        # Repeat a per-character tensor along a new time axis (dim 1) without copying.
        return t.unsqueeze(1).expand(-1, T, *t.shape[1:])

    # Evaluation only: nothing here needs gradients.
    with torch.no_grad():
        skin_sensor_B = SkinnableSensor(
            over_time(static_B['normalized_body_sensor_locations']),
            over_time(static_B['body_sensor_tns']),
            over_time(static_B['normalized_joint_locations']),
            static_B['parents'],
            over_time(static_B['body_sensor_weights']),
        )
        posed_sensor_locations_B, posed_sensor_tns_B = skin_sensor_B.skin(motion_B)
        _, _, src_sparse_indices, _, _, src_sparse_dist, _ = retnet.get_cond_inp({'src_motion': motion_A, **static_A})
        pred_relative_coords, pred_sparse_mask = indices2relative_coords(
            retnet.group_pairs, posed_sensor_locations_B, posed_sensor_tns_B,
            static_B['body_sensor_mask'], src_sparse_indices)

        errors = []
        for gt_dist, pred_coords, cur_sparse_mask, cur_group_pair in zip(
                src_sparse_dist, pred_relative_coords, pred_sparse_mask, retnet.group_pairs):
            pred_dist = pred_coords.norm(dim=-1)
            # Limb key = text before the first '_' of the group's first sensor name
            # (presumably the observing limb — confirm against group_pairs' layout).
            limb_name = cur_group_pair[0][0][0].split('_')[0]
            src_limb_radius = static_A['normalized_limb_radius'][limb_name].reshape(-1, 1, 1, 1)
            tgt_limb_radius = static_B['normalized_limb_radius'][limb_name].reshape(-1, 1, 1, 1)
            close_mask = gt_dist < src_limb_radius * close_ratio
            errors.append(masked_l2(gt_dist / src_limb_radius, pred_dist / tgt_limb_radius,
                                    close_mask & cur_sparse_mask))
    print(errors)
    return errors


def draw_img(motion, sensor_locations, sensor_mask, sensor_tns, sensor_weights, joint_locations, parents,
             posed_verts, faces, point_pairs, frames=(90,), screenshot_path='test.png'):
    """Debug view of left-arm sensor frames, head sensors and their pairings.

    For each frame index in ``frames`` the posed body mesh is shown in an
    interactive PyVista window. The window draws:

    * left-arm sensor frames as arrows (columns 0/1/2 in red/blue/green),
    * head sensors as black points,
    * red lines for every ``(arm_idx, head_idx)`` pair in ``point_pairs[frame]``.

    After the window is closed, a transparent screenshot is written to
    ``screenshot_path``; it is overwritten for every frame. ``sensor_mask``
    is accepted for interface symmetry with ``get_point_pairs`` but is not
    used.
    """
    skin_sensor = SkinnableSensor(sensor_locations, sensor_tns, joint_locations, parents, sensor_weights)
    bone_ranges = get_bone_ranges(16)
    left_arm_range = bone_ranges['LeftArm']
    head_range = bone_ranges['Head']
    faces_np = faces.detach().cpu().numpy()
    for i in frames:
        plotter = pv.Plotter()
        body_mesh = pv.wrap(Trimesh(posed_verts[i].detach().cpu().numpy(), faces_np))
        plotter.add_mesh(body_mesh, color='lightgrey', opacity=1.0)
        posed_sensor_locations, posed_tns = skin_sensor.skin(motion[i])
        posed_sensor_locations = posed_sensor_locations.detach().cpu().numpy()
        posed_tns = posed_tns.detach().cpu().numpy()
        left_arm_sensor_locations = posed_sensor_locations[left_arm_range]
        left_arm_tns = posed_tns[left_arm_range]
        head_sensor_locations = posed_sensor_locations[head_range]
        for axis, color in enumerate(('red', 'blue', 'green')):
            plotter.add_arrows(left_arm_sensor_locations, left_arm_tns[..., axis], color=color, mag=2)
        plotter.add_points(head_sensor_locations, color='black', point_size=5)
        for pos_idx, idx in point_pairs[i]:
            plotter.add_lines(np.array([left_arm_sensor_locations[pos_idx], head_sensor_locations[idx]]),
                              color='red', width=2)
        plotter.camera_position = 'xy'
        plotter.image_scale = 4
        # By default show() closes the plotter, after which screenshot() raises.
        # Keep it open for the screenshot, then close explicitly.
        plotter.show(auto_close=False)
        plotter.screenshot(screenshot_path, transparent_background=True)
        plotter.close()


def get_point_pairs(motion, sensor_locations, sensor_mask, sensor_tns, sensor_weights, joint_locations, parents, k=8):
    """Per frame, pair each left-arm sensor with nearby head sensors.

    For every left-arm sensor, up to ``k`` nearest head sensors (Euclidean
    distance, in no particular order) are considered. A pair is kept when:

    * both sensors are enabled in ``sensor_mask``, and
    * the head sensor is not behind the arm sensor along column 1 of the arm
      sensor's frame, i.e. ``sum((head - arm) * tns[..., 1]) >= 0``.

    Args:
        motion: Per-frame motion; frames are iterated over dim 0.
        sensor_locations, sensor_tns, sensor_weights, joint_locations, parents:
            Passed to ``SkinnableSensor``.
        sensor_mask: Per-sensor enable flags.
        k: Number of nearest head sensors examined per arm sensor.

    Returns:
        A list with one entry per frame. Each entry is a list of
        ``(arm_idx, head_idx)`` tuples, with indices relative to the
        ``'LeftArm'`` and ``'Head'`` bone ranges.
    """
    skin_sensor = SkinnableSensor(sensor_locations, sensor_tns, joint_locations, parents, sensor_weights)
    sensor_mask = sensor_mask.detach().cpu().numpy()
    # Loop-invariant: bone layout and per-bone masks do not depend on the frame.
    bone_ranges = get_bone_ranges(16)
    left_arm_range = bone_ranges['LeftArm']
    head_range = bone_ranges['Head']
    arm_mask = sensor_mask[left_arm_range]
    head_mask = sensor_mask[head_range]
    point_pairs = []
    for i in range(motion.shape[0]):
        cur_pairs = []
        posed_sensor_locations, posed_tns = skin_sensor.skin(motion[i])
        posed_sensor_locations = posed_sensor_locations.detach().cpu().numpy()
        posed_tns = posed_tns.detach().cpu().numpy()
        left_arm_sensor_locations = posed_sensor_locations[left_arm_range]
        left_arm_tns = posed_tns[left_arm_range]
        head_sensor_locations = posed_sensor_locations[head_range]
        num_head = head_sensor_locations.shape[0]
        if num_head == 0:
            point_pairs.append(cur_pairs)
            continue
        # argpartition needs kth < n; with <= k head sensors this keeps all of them.
        kth = min(k, num_head - 1)
        for pos_idx, (pos, tns) in enumerate(zip(left_arm_sensor_locations, left_arm_tns)):
            if not arm_mask[pos_idx]:
                continue
            dists = np.linalg.norm(head_sensor_locations - pos, axis=1)
            for idx in np.argpartition(dists, kth)[:k]:
                if not head_mask[idx]:
                    continue
                obs_dir = head_sensor_locations[idx] - pos
                if np.sum(obs_dir * tns[..., 1]) < 0:
                    continue
                cur_pairs.append((pos_idx, idx))
        point_pairs.append(cur_pairs)
    return point_pairs


if __name__ == '__main__':
    # Script entry point: run generation, metrics and export for one test batch.
    main()
