import torch
from pytorch3d.transforms import so3_exponential_map as rodrigues
import numpy as np
# from utils.evaluate import PCK_3d
import torch.nn.functional as F
import torch.nn as nn


class AverageMeter(object):
    """Track the latest observed value and a count-weighted running mean.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as having been observed ``n`` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against division by zero when only n=0 updates were made.
        if self.count == 0:
            self.avg = 0
        else:
            self.avg = self.sum / self.count

def p3d_no_scale(p3d):
    """Normalise a batch of 16-joint 3D poses by their mean bone length.

    ``p3d`` is reshaped to ``(B, 3, 16)``. The 15 bones given by the parent
    table below (joint 6 is the root and has no parent) are measured, and each
    pose is divided by its own mean bone length. A 1e-9 term is added to every
    squared coordinate difference before the square root, which keeps the value
    and gradient finite for zero-length bones.

    Returns:
        Tensor of shape ``(B, 3, 16)`` whose mean bone length is ~1.
    """
    pose = p3d.reshape(-1, 3, 16)
    parents = [6, 0, 1, 6, 3, 4, -1, 6, 7, 8, 7, 10, 11, 7, 13, 14]
    lengths = [
        ((pose[:, :, parent] - pose[:, :, child]) ** 2 + 1e-9).sum(-1) ** 0.5
        for child, parent in enumerate(parents)
        if parent != -1
    ]
    # Lengths are held in the default floating dtype on the input's device.
    bone_len = torch.stack(lengths, dim=1).to(torch.get_default_dtype())
    return pose / bone_len.mean(1)[:, None, None]

def p3d_no_scale_np(p3d):
    """NumPy pose normaliser: divide each 16-joint pose by one reference bone.

    ``p3d`` is reshaped to ``(B, 3, 16)``. The scale is only the length of the
    bone from joint 0 to joint 6 (the first entry of the parent table
    ``[6, 0, 1, 6, 3, 4, -1, ...]``), not an average over all bones. No epsilon
    is added, so a zero-length reference bone yields inf/nan.

    Returns:
        ``np.ndarray`` of shape ``(B, 3, 16)``. Because the scale is held in
        float64, float32 input produces float64 output.
    """
    pose = p3d.reshape(-1, 3, 16)
    ref_len = (((pose[:, :, 6] - pose[:, :, 0]) ** 2).sum(-1) ** 0.5).astype(np.float64)
    return pose / ref_len[:, None, None]

def get_bone_rate(p3d):
    """Return the 15 bone lengths of each 16-joint pose in a batch.

    ``p3d`` is reshaped to ``(B, 3, 16)``. Bones are taken from the parent table
    below (joint 6 is the root and is skipped), in table order.

    Returns:
        Tensor of shape ``(B, 15)`` in the default floating dtype, on the same
        device as ``p3d``.
    """
    p3d = p3d.reshape(-1, 3, 16)
    bone_inx = [6, 0, 1, 6, 3, 4, -1, 6, 7, 8, 7, 10, 11, 7, 13, 14]

    # Allocate on the input's device so CPU and GPU inputs both work.
    bone_lenth = torch.zeros((p3d.shape[0], 15), device=p3d.device)
    n = 0
    for i, j in enumerate(bone_inx):
        if j == -1:
            continue
        bone_lenth[:, n] = ((p3d[:, :, j] - p3d[:, :, i]) ** 2).sum(-1) ** 0.5
        n += 1
    return bone_lenth

def get_pose_features(p3d):
    """Compute bone lengths and depth-ordering signs for a batch of 17-joint poses.

    ``p3d`` is reshaped to ``(B, 3, 17)``. For each of the 16 bones (child ``i``
    with parent ``bone_inx[i]``; joint 0 is the root and is skipped):

    * the bone length ``|P_parent - P_child|``. A 1e-9 term is added to each
      squared coordinate difference, keeping the sqrt differentiable at zero.
    * the sign of ``z_child - t_mid``. The camera ray through the child is
      ``t * (x/z, y/z, 1)``, and ``t_mid`` is the depth of the ray point closest
      to the parent. The two ray points at bone-length distance from the parent
      lie symmetrically around ``t_mid``. So ``+1`` means the child is the deeper
      one, ``-1`` the nearer one, and ``0`` means exactly at ``t_mid``.

    Child depths ``p3d[:, 2]`` must be non-zero (they are divided by).

    Returns:
        ``(bone_lenth, depth_sign)``, each of shape ``(B, 16)``, in the default
        floating dtype and on ``p3d``'s device.
    """
    p3d = p3d.reshape(-1, 3, 17)
    bone_inx = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15]

    # Allocate on the input's device so CPU and GPU inputs both work.
    bone_lenth = torch.zeros((p3d.shape[0], 16), device=p3d.device)
    depth_sign = torch.zeros((p3d.shape[0], 16), device=p3d.device)

    n = 0
    for i, j in enumerate(bone_inx):
        if j == -1:
            continue
        child = p3d[:, :, i]
        parent = p3d[:, :, j]
        bone_lenth[:, n] = ((parent - child) ** 2 + 1e-9).sum(-1) ** 0.5

        # Normalised ray direction through the child joint (z component = 1).
        xf = child[:, 0] / child[:, 2]
        yf = child[:, 1] / child[:, 2]
        a = xf ** 2 + yf ** 2 + 1
        b = xf * parent[:, 0] + yf * parent[:, 1] + parent[:, 2]
        mid = b / a  # depth of the ray point closest to the parent
        depth_sign[:, n] = torch.sign(child[:, 2] - mid)
        n += 1

    return bone_lenth, depth_sign


def get_pose_from_features(pose_lenth, pose_sign, p3d, p2d):
    """Rebuild a 16-joint 3D pose from bone lengths, depth signs and 2D keypoints.

    The root joint (index 6) is copied from ``p3d``. The focal length is
    recovered from the root's x projection as ``f = u_root * Z_root / X_root``,
    so ``p3d[:, 0, 6]`` must be non-zero.

    Every other joint is then placed, in parent-table order, on its camera ray
    ``t * (u / f, v / f, 1)`` at distance ``pose_lenth[:, n]`` from its
    already-placed parent. The ray meets that sphere at
    ``t = (b +/- sqrt(b^2 - a c)) / a``, and ``pose_sign[:, n]`` picks the root
    (+1 gives the larger depth). A negative discriminant is clamped to 0, and
    1e-9 is added before the square root.

    Args:
        pose_lenth: bone lengths indexed ``[:, n]``, one column per non-root
            joint in parent-table order (15 columns).
        pose_sign: intersection selectors with the same layout as ``pose_lenth``.
        p3d: pose tensor indexed ``[:, xyz, joint]``; only root joint 6 is read.
        p2d: 2D keypoints indexed ``[:, uv, joint]``.

    Returns:
        Tensor shaped like ``p3d`` holding the reconstructed joints.
    """
    joint_parent = [6, 0, 1, 6, 3, 4, -1, 6, 7, 8, 7, 10, 11, 7, 13, 14]
    root = 6

    rebuilt = torch.zeros_like(p3d)
    rebuilt[:, :, root] = p3d[:, :, root]

    # u = f * X / Z at the root  =>  f = u * Z / X.
    focal = p2d[:, 0, root] * p3d[:, 2, root] / p3d[:, 0, root]

    bone = 0
    for child, parent in enumerate(joint_parent):
        if parent == -1:
            continue
        ray_x = p2d[:, 0, child] / focal
        ray_y = p2d[:, 1, child] / focal
        # Clone: ``rebuilt`` is written in place below, and autograd needs the
        # parent values as they were when this bone was solved.
        anchor = rebuilt[:, :, parent].clone()

        a = ray_x ** 2 + ray_y ** 2 + 1
        b = ray_x * anchor[:, 0] + ray_y * anchor[:, 1] + anchor[:, 2]
        c = anchor[:, 0] ** 2 + anchor[:, 1] ** 2 + anchor[:, 2] ** 2 - pose_lenth[:, bone] ** 2
        root_term = torch.sqrt(torch.relu(b ** 2 - a * c) + 1e-9)
        depth = (b + pose_sign[:, bone] * root_term) / a

        rebuilt[:, :, child] = torch.stack([ray_x * depth, ray_y * depth, depth], dim=1)
        bone += 1

    return rebuilt
    



def pred_to_gt_scale(p3d, gt):
    """Rescale predicted 16-joint poses so their mean bone length matches ``gt``.

    Both inputs are reshaped to ``(B, 3, 16)``. The mean length of the 14 bones
    in the table below is computed per sample for each input, and the result is
    ``pred / mean_pred * mean_gt``.

    Previously the prediction's lengths went into a 15-column buffer although
    the table has only 14 bones. The unused zero column was averaged in, which
    inflated every output by 15/14.

    NOTE(review): this table has two excluded joints (6 and 9), and joint 7's
    parent is 8 and 8's parent is 9, unlike the tables used elsewhere in this
    file. Confirm that this is intended.

    Returns:
        Tensor of shape ``(B, 3, 16)``.
    """
    bone_inx = [6, 0, 1, 6, 3, 4, -1, 8, 9, -1, 7, 10, 11, 7, 13, 14]
    pairs = [(i, j) for i, j in enumerate(bone_inx) if j != -1]

    def mean_bone_length(pose):
        # Per-sample mean over exactly the bones in ``pairs``, shaped (B, 1, 1).
        pose = pose.reshape(-1, 3, 16)
        lengths = torch.stack(
            [((pose[:, :, j] - pose[:, :, i]) ** 2).sum(-1) ** 0.5 for i, j in pairs],
            dim=1,
        )
        return lengths.mean(-1)[:, None, None]

    pred = p3d.reshape(-1, 3, 16)
    return pred / mean_bone_length(pred) * mean_bone_length(gt)

def rand_position(poses, rott=None, trans=None):
    """Randomly rotate and translate 17-joint poses, then project them to 2D.

    Steps:

    1. The poses are detached and made root-relative (joint 0 at the origin).
    2. They are rotated by ``rott`` and shifted by ``trans``.
    3. If any joint in the whole batch is closer than depth 1000, the entire
       batch is pushed back by one common offset so its minimum depth is 1000.
    4. Each pose is projected as ``p2d = (X / Z, Y / Z) * f``, with a random
       per-sample focal factor ``f`` in ``[1, 2)``.

    All random tensors are created on ``poses``' device.

    Args:
        poses: tensor reshapeable to ``(B, 3, 17)``.
        rott: optional rotation matrices broadcastable to ``(B, 3, 3)``. When
            ``None``, a random rotation about the y axis (angle uniform in
            ``[0, 2*pi)``) is composed with fixed 180-degree rotations about
            the x and z axes.
        trans: optional translations broadcastable against ``(B, 3, 17)``,
            e.g. ``(B, 3, 1)``. When ``None``, x and y are drawn from
            ``N(0, 1000^2)`` and z uniformly from ``[1000, 11000)``.

    Returns:
        ``(p3d, p2d, rott, trans)``. ``p3d`` has shape ``(B, 3, 17)`` and
        ``p2d`` has shape ``(B, 2, 17)``. The focal factor is not returned.
    """
    poses = poses.detach().reshape(-1, 3, 17)
    batch, device = poses.shape[0], poses.device

    poses = poses - poses[:, :, :1]

    if rott is None:
        axis_x = torch.zeros((batch, 3), device=device)
        axis_x[:, 0] = 1
        axis_y = torch.zeros((batch, 3), device=device)
        axis_y[:, 1] = 1
        axis_z = torch.zeros((batch, 3), device=device)
        axis_z[:, 2] = 1

        yaw = torch.rand((batch, 1), device=device) * torch.pi * 2
        rot_y = rodrigues(axis_y * yaw)
        # Fixed half-turns (these were never randomised).
        rot_p = rodrigues(axis_x * torch.pi)
        rot_w = rodrigues(axis_z * torch.pi)
        rott = rot_y @ rot_p @ rot_w

    poses = rott @ poses

    if trans is None:
        trans = torch.rand((batch, 3, 1), device=device)
        trans[:, 0] = torch.randn((batch, 1), device=device) * 1000
        trans[:, 1] = torch.randn((batch, 1), device=device) * 1000
        trans[:, 2] = trans[:, 2] * 10000 + 1000
    poses = poses + trans

    # One global offset for the batch (not per sample): minimum depth >= 1000.
    depth_shift = (1000 - poses[:, 2].min()).clamp(min=0)
    poses[:, 2] += depth_shift

    f = torch.rand((batch, 1, 1), device=device) + 1
    p2d = poses[:, :2] / poses[:, 2][:, None] * f

    return poses, p2d, rott, trans

def human_model(pred):
    """Build a batch of 17-joint 3D skeletons from shape/pose parameters.

    A rest skeleton is laid out along the unit axes ``x``, ``y``, ``z`` using the
    base segment lengths in ``len``. Each length is scaled, and each joint is
    rotated, according to entries of ``pred``. Rotations are axis-angle vectors
    turned into 3x3 matrices by ``rodrigues`` (pytorch3d ``so3_exponential_map``).
    Finally the model frame is remapped so that the output rows are ``(x, -z, -y)``.

    Args:
        pred: parameters indexed as ``pred[:, k]`` for ``k`` up to 33, so at
            least 34 per sample. Assumes a 2-D ``(B, K)`` tensor (TODO confirm).
            NOTE(review): values presumably lie in [-1, 1]. The angle formulas
            ``(pred + 1) / 2 * range + offset`` map that interval onto joint
            limits, and the shoulder ``torch.arcsin`` yields NaN once
            ``(pred + 1) * 0.4`` exceeds 1.

    Returns:
        Tensor of shape ``(B, 3, 17)``.

    Notes:
        Working tensors are created with ``.cuda()``, so a CUDA device is
        required. Joints 0 and 7 are never assigned and stay at the origin.
        ``len`` shadows the builtin inside this function.
    """
    # A length parameter of +/-1 changes its segment by +/-1/lenth_rate.
    lenth_rate = 10

    # (B, K) -> (B, K, 1) so pred[:, k] broadcasts against (B, 3) axis vectors.
    pred = pred.unsqueeze(-1)

    # Base segment lengths (x270). Usage below: [0] root->1 and root->4,
    # [1] root->8, [2] 8->9, [3] 9->10, [4] 8->11 and 8->14, [5] 11->12 and
    # 14->15, [6] 15->16, [7] 1->2 and 4->5, [8] 2->3 and 5->6.
    len = np.array([0.49108774, 1.80588788, 0.43735805, 0.4342996 , 0.55776042,
       1.03019458, 0.92972383, 1.63597411, 1.6777138 ]) * 270
    # Joint positions (B, 17, xyz), all starting at the origin.
    p = torch.zeros((pred.shape[0], 17, 3)).cuda()
    # Unit axis vectors repeated per sample, each (B, 3).
    x = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([1, 0, 0]).cuda()
    y = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([0, 1, 0]).cuda()
    z = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([0, 0, 1]).cuda()

    # Joints 1 / 4: offset along -x / +x from the root.
    p[:, 1] = p[:, 1] - x * len[0] * (1 + pred[:, 0] / lenth_rate)
    p[:, 4] = p[:, 4] + x * len[0] * (1 + pred[:, 0] / lenth_rate)

    # Joint 8: straight up +z from the root.
    p[:, 8] = p[:, 8] + z * len[1] * (1 + pred[:, 1] / lenth_rate)

    # Joint 9 lies head_dot from 8 and joint 10 lies head_top from 9. Each is
    # offset along y by +nose / -nose, and z makes up the remaining length.
    head_dot = len[2] * (1 + pred[:, 32] / lenth_rate)
    head_top = len[3] * (1 + pred[:, 33] / lenth_rate)
    nose = (head_dot / 2 + head_top / 2) * 7 / 12 * (0.8 + abs(pred[:, 2]) / 5)
    p[:, 9] = p[:, 8] + y * nose + z * (head_dot ** 2 - nose ** 2) ** 0.5
    p[:, 10] = p[:, 9] - y * nose + z * (head_top ** 2 - nose ** 2) ** 0.5

    # Rotate joints 9-10 about joint 8: twist about z (+/-45 deg), tilt about
    # -x (-50..70 deg), lean about y (+/-35 deg).
    head_rot = rodrigues(z * ((pred[:, 8]) * 45) * torch.pi / 180)
    head_front = rodrigues(-x * (pred[:, 9] * 60 + 10) * torch.pi / 180)
    head_side = rodrigues(y * ((pred[:, 10]) * 35) * torch.pi / 180)
    p[:, 9:11] = torch.transpose(
        head_front @ head_side @ head_rot @ torch.transpose(p[:, 9:11] - p[:, 8][:, None, :], 1, 2), 1, 2) + p[:, 8][:,
                                                                                                             None, :]

    # Joints 11 / 14: shoulder segment from joint 8 along +x / -x, each tilted
    # about y by arcsin((pred + 1) * 0.4).
    # NOTE(review): the shoulder length uses a /3 rate rather than lenth_rate.
    shod_l_rot = rodrigues(y * torch.arcsin((pred[:, 11] + 1) * 0.8 / 2))
    shod_len = (x * len[4] * (1 + pred[:, 3] / 3))[:, :, None]
    shod_l = shod_l_rot.matmul(shod_len).squeeze(-1)
    p[:, 11] = p[:, 8] + shod_l
    shod_r_rot = rodrigues(-y * torch.arcsin((pred[:, 12] + 1) * 0.8 / 2))
    shod_r = shod_r_rot.matmul(shod_len).squeeze(-1)
    p[:, 14] = p[:, 8] - shod_r

    # Joints 12 / 15 hang straight down (-z) from 11 / 14.
    p[:, 12] = p[:, 11] - z * len[5] * (1 + pred[:, 4] / lenth_rate)
    p[:, 15] = p[:, 14] - z * len[5] * (1 + pred[:, 4] / lenth_rate)

    # Joint 13: bend at joint 12 about x (0..145 deg), then about z (-40..110 deg).
    elbow_l_rot = rodrigues(x * (pred[:, 13] + 1) / 2 * 3.14 * 145 / 180)

    elbow_l_rot2 = rodrigues(z * ((pred[:, 14] + 1) / 2 * 150 - 40) * 3.14 / 180)

    elbow_l = (elbow_l_rot2 @ elbow_l_rot).matmul(z[:, :, None]).squeeze(-1)
    # NOTE(review): the length is hard-coded (251.7, /3 rate), whereas the
    # mirrored joint 16 uses len[6] and lenth_rate. Confirm this is intended.
    p[:, 13] = p[:, 12] - elbow_l * 251.7 * (1 + pred[:, 5] / 3)

    # Swing joints 12-13 about joint 11: first about x by front_angel
    # (-45..180 deg), then about the x-rotated y axis by side_angel (-120..50 deg).
    # side_angel = ((pred[:, 15] + 1) / 2 * 140 - 90)
    side_angel = ((pred[:, 15] + 1) / 2 * 170 - 120)
    front_angel = ((pred[:, 16] + 1) / 2 * 225 - 45)
    # side_angel = torch.where(front_angel < 0, (side_angel + 130) * 9 / 19, side_angel)
    # side_angel = torch.where(front_angel >= 165, side_angel * 0, side_angel)
    # side_angel = torch.where(abs(front_angel) <= 15, side_angel * 0, side_angel)

    shod_l_front = rodrigues(x * front_angel * 3.14 / 180)

    shod_vec = (shod_l_front @ y[:, :, None])[:, :, 0]

    shod_l_side = rodrigues(shod_vec * side_angel * 3.14 / 180)

    p[:, 12:14] = torch.transpose(
        shod_l_side @ shod_l_front @ torch.transpose(p[:, 12:14] - p[:, 11][:, None, :], 1, 2), 1, 2) + p[:, 11][:,
                                                                                                        None, :]

    # Joint 16: mirror of joint 13 (z rotation negated), bending at joint 15.
    elbow_r_rot = rodrigues(x * (pred[:, 17] + 1) / 2 * 3.14 * 145 / 180)

    elbow_r_rot2 = rodrigues(-z * ((pred[:, 18] + 1) / 2 * 150 - 40) * 3.14 / 180)

    elbow_r = (elbow_r_rot2 @ elbow_r_rot).matmul(z[:, :, None]).squeeze(-1)
    p[:, 16] = p[:, 15] - elbow_r * len[6] * (1 + pred[:, 5] / lenth_rate)

    # Swing joints 15-16 about joint 14: about x, then about the x-rotated -y axis.
    # side_angel = ((pred[:, 19] + 1) / 2 * 140 - 90)
    side_angel = ((pred[:, 19] + 1) / 2 * 170 - 120)
    front_angel = ((pred[:, 20] + 1) / 2 * 225 - 45)
    # side_angel = torch.where(front_angel < 0, (side_angel + 130) * 9 / 19, side_angel)
    # side_angel = torch.where(front_angel >= 165, side_angel * 0, side_angel)
    # side_angel = torch.where(abs(front_angel) <= 15, side_angel * 0, side_angel)

    shod_r_front = rodrigues(x * front_angel * 3.14 / 180)

    shod_vec = (shod_r_front @ -y[:, :, None])[:, :, 0]

    shod_r_side = rodrigues(shod_vec * side_angel * 3.14 / 180)

    p[:, 15:17] = torch.transpose(
        shod_r_side @ shod_r_front @ torch.transpose(p[:, 15:17] - p[:, 14][:, None, :], 1, 2), 1, 2) + p[:, 14][:,
                                                                                                        None, :]
    # Joints 2 / 5 hang straight down (-z) from joints 1 / 4.
    p[:, 2] = p[:, 1] - z * len[7] * (1 + pred[:, 6] / lenth_rate)
    p[:, 5] = p[:, 4] - z * len[7] * (1 + pred[:, 6] / lenth_rate)
    # Joint 3: bend at joint 2 about -x (0..135 deg) and about -z (+/-45 deg).
    knee_r_rot = rodrigues(-x * (pred[:, 21] + 1) / 2 * torch.pi * 135 / 180)
    knee_r_rot2 = rodrigues(-z * (pred[:, 22] * 45) * torch.pi / 180)
    knee_r = knee_r_rot2.matmul(knee_r_rot.matmul(z[:, :, None])).squeeze(-1)
    p[:, 3] = p[:, 2] - knee_r * len[8] * (1 + pred[:, 7] / lenth_rate)
    # Swing joints 1-3 about joint 1: about y (-25..45 deg), then x (-30..110 deg).
    knee_r_side = rodrigues(y * ((pred[:, 23] + 1) / 2 * 70 - 25) * torch.pi / 180)

    knee_r_front = rodrigues(x * ((pred[:, 24] + 1) / 2 * 140 - 30) * torch.pi / 180)
    p[:, 1:4] = torch.transpose(knee_r_front @ knee_r_side @ torch.transpose(p[:, 1:4] - p[:, 1][:, None, :], 1, 2), 1,
                                2) + p[:, 1][:, None, :]

    # Joint 6 and the swing of joints 5-6 about joint 4 mirror the right leg
    # (y and z rotation axes negated).
    knee_l_rot = rodrigues(-x * (pred[:, 25] + 1) / 2 * torch.pi * 135 / 180)
    knee_l_rot2 = rodrigues(z * (pred[:, 26] * 45) * torch.pi / 180)
    knee_l = knee_l_rot2.matmul(knee_l_rot.matmul(z[:, :, None])).squeeze(-1)
    p[:, 6] = p[:, 5] - knee_l * len[8] * (1 + pred[:, 7] / lenth_rate)
    knee_l_side = rodrigues(-y * ((pred[:, 27] + 1) / 2 * 70 - 25) * torch.pi / 180)
    # knee_l_front = rodrigues(x * ((pred[:, 28] + 1) / 2 * 135 - 45) * torch.pi / 180)
    knee_l_front = rodrigues(x * ((pred[:, 28] + 1) / 2 * 140 - 30) * torch.pi / 180)
    p[:, 5:7] = torch.transpose(knee_l_front @ knee_l_side @ torch.transpose(p[:, 5:7] - p[:, 4][:, None, :], 1, 2), 1,
                                2) + p[:, 4][:, None, :]

    # Rotate the upper body (joints 7..16) about the origin: twist about z
    # (+/-30 deg), lean about y (+/-35 deg), bend about -x (-30..75 deg).
    spine_rot = rodrigues(z * ((pred[:, 29]) * 30) * torch.pi / 180)
    spine_front = rodrigues(-x * ((pred[:, 30] + 1) / 2 * 105 - 30) * torch.pi / 180)
    spine_side = rodrigues(y * ((pred[:, 31]) * 35) * torch.pi / 180)
    p[:, 7:] = torch.transpose(spine_front @ spine_side @ spine_rot @ torch.transpose(p[:, 7:] * 1, 1, 2), 1, 2)
    # p[:, :7] = torch.transpose(spine_front @ torch.transpose(p[:, :7] * 1, 1, 2), 1, 2)

    # (B, 17, 3) -> (B, 3, 17), then remap axes to rows (x, -z, -y).
    pp = torch.transpose(p, 1, 2)
    p3d = torch.zeros_like(pp)
    p3d[:, 0] = pp[:, 0]
    p3d[:, 1] = -pp[:, 2]
    p3d[:, 2] = -pp[:, 1]
    return p3d
def human_model3(pred):
    """Build a batch of 17-joint 3D skeletons from shape/pose parameters.

    Variant of the base parametric model with two differences:

    * a re-fitted segment-length table (x282);
    * a different shoulder parameterisation. The arm is first rotated about the
      side (y) axis, then swung about the correspondingly rotated x axis.

    Rotations are axis-angle vectors turned into 3x3 matrices by ``rodrigues``
    (pytorch3d ``so3_exponential_map``). The model frame is finally remapped so
    that the output rows are ``(x, -z, -y)``.

    Args:
        pred: parameters in [-1, 1], indexed as ``pred[:, k]`` for ``k`` up to
            33 (at least 34 per sample). Assumes a 2-D ``(B, K)`` tensor (TODO
            confirm).

    Returns:
        Tensor of shape ``(B, 3, 17)``.

    Notes:
        Working tensors are created with ``.cuda()``, so a CUDA device is
        required. Joints 0 and 7 are never assigned and stay at the origin.
        ``len`` shadows the builtin inside this function.
    """
    # A length parameter of +/-1 changes its segment by +/-1/lenth_rate.
    lenth_rate = 10

    # (B, K) -> (B, K, 1) so pred[:, k] broadcasts against (B, 3) axis vectors.
    pred = pred.unsqueeze(-1)

    # Base segment lengths:
    # [pelvis, spine, neck, head, shoulder, upper arm, forearm, thigh, shin].
    len = np.array([0.5128, 1.7349, 0.4170, 0.4124 , 0.5871,
       1.0064, 0.8158, 1.6585, 1.6545 ]) * 282

    # Joint positions (B, 17, xyz), all starting at the origin.
    p = torch.zeros((pred.shape[0], 17, 3)).cuda()
    # Unit axis vectors repeated per sample, each (B, 3).
    x = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([1, 0, 0]).cuda()
    y = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([0, 1, 0]).cuda()
    z = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([0, 0, 1]).cuda()

    # Hips: joints 1 / 4 along -x / +x of the root.
    p[:, 1] = p[:, 1] - x * len[0] * (1 + pred[:, 0] / lenth_rate)
    p[:, 4] = p[:, 4] + x * len[0] * (1 + pred[:, 0] / lenth_rate)

    # Spine: joint 8 straight up +z from the root.
    p[:, 8] = p[:, 8] + z * len[1] * (1 + pred[:, 1] / lenth_rate)

    # Neck/head: joint 9 lies head_dot from 8 and joint 10 lies head_top from
    # 9. Each is offset along y by +nose / -nose.
    head_dot = len[2] * (1 + pred[:, 32] / lenth_rate)
    head_top = len[3] * (1 + pred[:, 33] / lenth_rate)
    nose = (head_dot / 2 + head_top / 2) * 7 / 12 * (0.8 + abs(pred[:, 2]) / 5)
    p[:, 9] = p[:, 8] + y * nose + z * (head_dot ** 2 - nose ** 2) ** 0.5
    p[:, 10] = p[:, 9] - y * nose + z * (head_top ** 2 - nose ** 2) ** 0.5

    # Rotate joints 9-10 about joint 8: twist about z, tilt about -x, lean about y.
    head_rot = rodrigues(z * ((pred[:, 8]) * 45) * torch.pi / 180)
    head_front = rodrigues(-x * (pred[:, 9] * 60 + 10) * torch.pi / 180)
    head_side = rodrigues(y * ((pred[:, 10]) * 35) * torch.pi / 180)
    p[:, 9:11] = torch.transpose(
        head_front @ head_side @ head_rot @ torch.transpose(p[:, 9:11] - p[:, 8][:, None, :], 1, 2), 1, 2
    ) + p[:, 8][:, None, :]

    # Shoulders: joints 11 / 14 from joint 8 along +x / -x, tilted about y.
    shod_l_rot = rodrigues(y * torch.arcsin((pred[:, 11] + 1) * 0.8 / 2))
    shod_len = (x * len[4] * (1 + pred[:, 3] / lenth_rate))[:, :, None]
    shod_l = shod_l_rot.matmul(shod_len).squeeze(-1)
    p[:, 11] = p[:, 8] + shod_l
    shod_r_rot = rodrigues(-y * torch.arcsin((pred[:, 12] + 1) * 0.8 / 2))
    shod_r = shod_r_rot.matmul(shod_len).squeeze(-1)
    p[:, 14] = p[:, 8] - shod_r

    # Upper arms: joints 12 / 15 hang straight down (-z).
    p[:, 12] = p[:, 11] - z * len[5] * (1 + pred[:, 4] / lenth_rate)
    p[:, 15] = p[:, 14] - z * len[5] * (1 + pred[:, 4] / lenth_rate)

    # Left forearm: bend at joint 12 about x (0..145 deg), then about z.
    elbow_l_rot = rodrigues(x * (pred[:, 13] + 1) / 2 * 3.14 * 145 / 180)
    elbow_l_rot2 = rodrigues(z * ((pred[:, 14] + 1) / 2 * 150 - 40) * 3.14 / 180)
    elbow_l = (elbow_l_rot2 @ elbow_l_rot).matmul(z[:, :, None]).squeeze(-1)
    # NOTE(review): the length is hard-coded (251.7, /3 rate), whereas the right
    # forearm uses len[6] and lenth_rate. Confirm this is intended.
    p[:, 13] = p[:, 12] - elbow_l * 251.7 * (1 + pred[:, 5] / 3)

    # Left arm swing about joint 11: side about y (-170..20 deg), then front
    # about the y-rotated x axis (-45..180 deg).
    side_angel = ((pred[:, 15] + 1)/ 2 * 190 - 170)
    front_angel = ((pred[:, 16] + 1) / 2 * 225 - 45)
    shod_l_side = rodrigues(y * side_angel * 3.14 / 180)
    shod_vec = (shod_l_side @ x[:, :, None])[:, :, 0]
    shod_l_front = rodrigues(shod_vec * front_angel * 3.14 / 180)
    p[:, 12:14] = torch.transpose(
        shod_l_front @ shod_l_side @ torch.transpose(p[:, 12:14] - p[:, 11][:, None, :], 1, 2), 1, 2
    ) + p[:, 11][:, None, :]

    # Right forearm: mirror of the left one (z rotation negated).
    elbow_r_rot = rodrigues(x * (pred[:, 17] + 1) / 2 * 3.14 * 145 / 180)
    elbow_r_rot2 = rodrigues(-z * ((pred[:, 18] + 1) / 2 * 150 - 40) * 3.14 / 180)
    elbow_r = (elbow_r_rot2 @ elbow_r_rot).matmul(z[:, :, None]).squeeze(-1)
    p[:, 16] = p[:, 15] - elbow_r * len[6] * (1 + pred[:, 5] / lenth_rate)

    # Right arm swing about joint 14: side about -y, then front about the
    # correspondingly rotated x axis.
    side_angel = ((pred[:, 19] + 1) / 2 * 190 - 170)
    front_angel = ((pred[:, 20] + 1) / 2 * 225 - 45)
    shod_r_side = rodrigues(-y * side_angel * 3.14 / 180)
    shod_vec = (shod_r_side @ x[:, :, None])[:, :, 0]
    shod_r_front = rodrigues(shod_vec * front_angel * 3.14 / 180)
    p[:, 15:17] = torch.transpose(
        shod_r_front @ shod_r_side @ torch.transpose(p[:, 15:17] - p[:, 14][:, None, :], 1, 2), 1, 2
    ) + p[:, 14][:, None, :]

    # Thighs: joints 2 / 5 hang straight down (-z) from the hips.
    p[:, 2] = p[:, 1] - z * len[7] * (1 + pred[:, 6] / lenth_rate)
    p[:, 5] = p[:, 4] - z * len[7] * (1 + pred[:, 6] / lenth_rate)

    # Right shin: bend at joint 2, then swing the whole leg about joint 1.
    knee_r_rot = rodrigues(-x * (pred[:, 21] + 1) / 2 * torch.pi * 135 / 180)
    knee_r_rot2 = rodrigues(-z * (pred[:, 22] * 45) * torch.pi / 180)
    knee_r = knee_r_rot2.matmul(knee_r_rot.matmul(z[:, :, None])).squeeze(-1)
    p[:, 3] = p[:, 2] - knee_r * len[8] * (1 + pred[:, 7] / lenth_rate)
    knee_r_side = rodrigues(y * ((pred[:, 23] + 1) / 2 * 70 - 25) * torch.pi / 180)
    knee_r_front = rodrigues(x * ((pred[:, 24] + 1) / 2 * 140 - 30) * torch.pi / 180)
    p[:, 1:4] = torch.transpose(
        knee_r_front @ knee_r_side @ torch.transpose(p[:, 1:4] - p[:, 1][:, None, :], 1, 2), 1, 2
    ) + p[:, 1][:, None, :]

    # Left shin: mirror of the right leg (y and z rotation axes negated).
    knee_l_rot = rodrigues(-x * (pred[:, 25] + 1) / 2 * torch.pi * 135 / 180)
    knee_l_rot2 = rodrigues(z * (pred[:, 26] * 45) * torch.pi / 180)
    knee_l = knee_l_rot2.matmul(knee_l_rot.matmul(z[:, :, None])).squeeze(-1)
    p[:, 6] = p[:, 5] - knee_l * len[8] * (1 + pred[:, 7] / lenth_rate)
    knee_l_side = rodrigues(-y * ((pred[:, 27] + 1) / 2 * 70 - 25) * torch.pi / 180)
    knee_l_front = rodrigues(x * ((pred[:, 28] + 1) / 2 * 140 - 30) * torch.pi / 180)
    p[:, 5:7] = torch.transpose(
        knee_l_front @ knee_l_side @ torch.transpose(p[:, 5:7] - p[:, 4][:, None, :], 1, 2), 1, 2
    ) + p[:, 4][:, None, :]

    # Upper body (joints 7..16) rotated about the origin: twist, lean, bend.
    spine_rot = rodrigues(z * ((pred[:, 29]) * 30) * torch.pi / 180)
    spine_front = rodrigues(-x * ((pred[:, 30] + 1) / 2 * 105 - 30) * torch.pi / 180)
    spine_side = rodrigues(y * ((pred[:, 31]) * 35) * torch.pi / 180)
    p[:, 7:] = torch.transpose(spine_front @ spine_side @ spine_rot @ torch.transpose(p[:, 7:] * 1, 1, 2), 1, 2)

    # (B, 17, 3) -> (B, 3, 17), then remap axes to rows (x, -z, -y).
    pp = torch.transpose(p, 1, 2)
    p3d = torch.zeros_like(pp)
    p3d[:, 0] = pp[:, 0]
    p3d[:, 1] = -pp[:, 2]
    p3d[:, 2] = -pp[:, 1]
    return p3d

def human_model_for_h36m(pred):
    """Build 17-joint 3D skeletons using Human3.6M per-subject segment lengths.

    For every sample, one of seven segment-length tables (subjects S1, S5, S6,
    S7, S8, S9, S11) is drawn with ``torch.randint``, using the global CPU RNG.
    ``pred`` then modulates those lengths and sets the joint angles.

    Compared with the base parametric model, the spine has two segments
    (root -> 7 -> 8), the front spine bend is applied about joint 7, and
    ``lenth_rate`` is 5.

    Rotations are axis-angle vectors turned into 3x3 matrices by ``rodrigues``
    (pytorch3d ``so3_exponential_map``).

    Args:
        pred: parameters indexed as ``pred[:, k]`` for ``k`` up to 33 (at least
            34 per sample). Presumably in [-1, 1]; assumes a 2-D ``(B, K)``
            tensor (TODO confirm).

    Returns:
        Tensor of shape ``(B, 3, 17)`` with rows ``(x, -z, -y)`` of the model frame.

    Notes:
        Working tensors are created with ``.cuda()``, so a CUDA device is
        required. Joint 0 is never assigned and stays at the origin.
        ``len`` shadows the builtin inside this function.
    """
    # A length parameter of +/-1 changes its segment by +/-1/lenth_rate.
    lenth_rate = 5

    # (B, K) -> (B, K, 1) so pred[:, k] broadcasts against (B, 3) axis vectors.
    pred = pred.unsqueeze(-1)

    # Per-subject segment lengths, ordered [pelvis, lower spine, neck, head,
    # shoulder, upper arm, forearm, thigh, shin, upper spine].
    len1 = torch.tensor([132.9485889,  233.47550133, 121.1349384,  115.00222509, 151.03422411,
        278.88276868, 251.73344834, 442.89460875, 454.20644352, 257.07767603]) # S1
    len2 = torch.tensor([119.31363074, 224.31715423, 117.11775943, 114.99853789, 143.09625478,
        264.58479855, 248.62032941, 428.28616939, 442.44419415, 254.05558418]) # S5
    len3 = torch.tensor([142.61378145, 262.2150067,  119.39894189, 115.00002887, 149.37477976,
        301.00832809, 257.91450733, 486.5709849,  461.49364563, 260.00926677]) # S6
    len4 = torch.tensor([135.87891883, 226.24669822, 107.14753056, 115.00316109, 139.71766197,
        275.56358157, 247.29862765, 448.61460293, 438.01023965, 255.4075579 ]) # S7
    len5 = torch.tensor([146.5378281,  261.21454096, 120.45857404, 115.00028777, 169.31633827,
        289.90848878, 244.17701696, 452.14748513, 438.63330443, 251.02422458]) # S8
    len6 = torch.tensor([124.0746,256.2813228709519, 120.5530, 114.9971, 183.2926, 296.3204, 249.0150, 472.1931, 468.6243, 250.3569269232421]) # S9
    len7 = torch.tensor([138.184, 254.377569435607,  111.0849, 114.9990, 166.3306, 283.1711, 248.3148, 461.9405, 460.2236, 249.9474227831257]) # S11
    len_set = torch.stack([len1, len2, len3, len4, len5, len6, len7])
    # Keep the index on the CPU, next to ``len_set``. Indexing a CPU tensor
    # with CUDA indices raises in recent PyTorch. The gathered rows are moved
    # to the GPU afterwards.
    len_index = torch.randint(0, 7, (pred.shape[0],))
    # (B, 10, 1): len[:, k] is a (B, 1) column that broadcasts with (B, 3).
    len = len_set[len_index].unsqueeze(-1).cuda()

    # Joint positions (B, 17, xyz), all starting at the origin.
    p = torch.zeros((pred.shape[0], 17, 3)).cuda()
    # Unit axis vectors repeated per sample, each (B, 3).
    x = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([1, 0, 0]).cuda()
    y = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([0, 1, 0]).cuda()
    z = torch.zeros((pred.shape[0], 3)).cuda() + torch.tensor([0, 0, 1]).cuda()

    # Hips: joints 1 / 4 along -x / +x of the root.
    p[:, 1] = p[:, 1] - x * len[:,0] * (1 + pred[:, 0] / lenth_rate)
    p[:, 4] = p[:, 4] + x * len[:,0] * (1 + pred[:, 0] / lenth_rate)

    # Two-segment spine along +z: root -> 7 (lower) -> 8 (upper); both use pred[:, 1].
    p[:, 7] = p[:, 7] + z * len[:,1] * (1 + pred[:, 1] / lenth_rate)
    p[:, 8] = p[:, 7] + z * len[:,-1] * (1 + pred[:, 1] / lenth_rate)

    # Neck/head: joint 9 lies head_dot from 8 and joint 10 lies head_top from
    # 9. Each is offset along y by +nose / -nose.
    head_dot = len[:,2] * (1 + pred[:, 32] / lenth_rate)
    head_top = len[:,3] * (1 + pred[:, 33] / lenth_rate)
    nose = (head_dot / 2 + head_top / 2) * 7 / 12 * (0.8 + abs(pred[:, 2]) / 5)
    p[:, 9] = p[:, 8] + y * nose + z * (head_dot ** 2 - nose ** 2) ** 0.5
    p[:, 10] = p[:, 9] - y * nose + z * (head_top ** 2 - nose ** 2) ** 0.5

    # Rotate joints 9-10 about joint 8: twist about z, tilt about -x, lean about y.
    head_rot = rodrigues(z * ((pred[:, 8]) * 45) * torch.pi / 180)
    head_front = rodrigues(-x * (pred[:, 9] * 60 + 10) * torch.pi / 180)
    head_side = rodrigues(y * ((pred[:, 10]) * 35) * torch.pi / 180)
    p[:, 9:11] = torch.transpose(
        head_front @ head_side @ head_rot @ torch.transpose(p[:, 9:11] - p[:, 8][:, None, :], 1, 2), 1, 2
    ) + p[:, 8][:, None, :]

    # Shoulders: joints 11 / 14 from joint 8 along +x / -x, tilted about y.
    # NOTE(review): the shoulder length uses a /3 rate rather than lenth_rate.
    shod_l_rot = rodrigues(y * torch.arcsin((pred[:, 11] + 1) * 0.8 / 2))
    shod_len = (x * len[:,4] * (1 + pred[:, 3] / 3))[:, :, None]
    shod_l = shod_l_rot.matmul(shod_len).squeeze(-1)
    p[:, 11] = p[:, 8] + shod_l
    shod_r_rot = rodrigues(-y * torch.arcsin((pred[:, 12] + 1) * 0.8 / 2))
    shod_r = shod_r_rot.matmul(shod_len).squeeze(-1)
    p[:, 14] = p[:, 8] - shod_r

    # Upper arms: joints 12 / 15 hang straight down (-z).
    p[:, 12] = p[:, 11] - z * len[:,5] * (1 + pred[:, 4] / lenth_rate)
    p[:, 15] = p[:, 14] - z * len[:,5] * (1 + pred[:, 4] / lenth_rate)

    # Left forearm: bend at joint 12 about x (0..145 deg), then about z.
    elbow_l_rot = rodrigues(x * (pred[:, 13] + 1) / 2 * 3.14 * 145 / 180)
    elbow_l_rot2 = rodrigues(z * ((pred[:, 14] + 1) / 2 * 150 - 40) * 3.14 / 180)
    elbow_l = (elbow_l_rot2 @ elbow_l_rot).matmul(z[:, :, None]).squeeze(-1)
    # NOTE(review): the length is hard-coded (251.7, /3 rate) and not taken from
    # the sampled subject, unlike the right forearm (len[:, 6], lenth_rate).
    p[:, 13] = p[:, 12] - elbow_l * 251.7 * (1 + pred[:, 5] / 3)

    # Left arm swing about joint 11: side about y (-170..20 deg), then front
    # about the y-rotated x axis (-45..180 deg).
    side_angel = ((pred[:, 15] + 1)/ 2 * 190 - 170)
    front_angel = ((pred[:, 16] + 1) / 2 * 225 - 45)
    shod_l_side = rodrigues(y * side_angel * 3.14 / 180)
    shod_vec = (shod_l_side @ x[:, :, None])[:, :, 0]
    shod_l_front = rodrigues(shod_vec * front_angel * 3.14 / 180)
    p[:, 12:14] = torch.transpose(
        shod_l_front @ shod_l_side @ torch.transpose(p[:, 12:14] - p[:, 11][:, None, :], 1, 2), 1, 2
    ) + p[:, 11][:, None, :]

    # Right forearm: mirror of the left one (z rotation negated).
    elbow_r_rot = rodrigues(x * (pred[:, 17] + 1) / 2 * 3.14 * 145 / 180)
    elbow_r_rot2 = rodrigues(-z * ((pred[:, 18] + 1) / 2 * 150 - 40) * 3.14 / 180)
    elbow_r = (elbow_r_rot2 @ elbow_r_rot).matmul(z[:, :, None]).squeeze(-1)
    p[:, 16] = p[:, 15] - elbow_r * len[:,6] * (1 + pred[:, 5] / lenth_rate)

    # Right arm swing about joint 14: side about -y, then front about the
    # correspondingly rotated x axis.
    side_angel = ((pred[:, 19] + 1) / 2 * 190 - 170)
    front_angel = ((pred[:, 20] + 1) / 2 * 225 - 45)
    shod_r_side = rodrigues(-y * side_angel * 3.14 / 180)
    shod_vec = (shod_r_side @ x[:, :, None])[:, :, 0]
    shod_r_front = rodrigues(shod_vec * front_angel * 3.14 / 180)
    p[:, 15:17] = torch.transpose(
        shod_r_front @ shod_r_side @ torch.transpose(p[:, 15:17] - p[:, 14][:, None, :], 1, 2), 1, 2
    ) + p[:, 14][:, None, :]

    # Thighs: joints 2 / 5 hang straight down (-z) from the hips.
    p[:, 2] = p[:, 1] - z * len[:,7] * (1 + pred[:, 6] / lenth_rate)
    p[:, 5] = p[:, 4] - z * len[:,7] * (1 + pred[:, 6] / lenth_rate)

    # Right shin: bend at joint 2, then swing the whole leg about joint 1.
    knee_r_rot = rodrigues(-x * (pred[:, 21] + 1) / 2 * torch.pi * 135 / 180)
    knee_r_rot2 = rodrigues(-z * (pred[:, 22] * 45) * torch.pi / 180)
    knee_r = knee_r_rot2.matmul(knee_r_rot.matmul(z[:, :, None])).squeeze(-1)
    p[:, 3] = p[:, 2] - knee_r * len[:,8] * (1 + pred[:, 7] / lenth_rate)
    knee_r_side = rodrigues(y * ((pred[:, 23] + 1) / 2 * 70 - 25) * torch.pi / 180)
    knee_r_front = rodrigues(x * ((pred[:, 24] + 1) / 2 * 140 - 30) * torch.pi / 180)
    p[:, 1:4] = torch.transpose(
        knee_r_front @ knee_r_side @ torch.transpose(p[:, 1:4] - p[:, 1][:, None, :], 1, 2), 1, 2
    ) + p[:, 1][:, None, :]

    # Left shin: mirror of the right leg (y and z rotation axes negated).
    knee_l_rot = rodrigues(-x * (pred[:, 25] + 1) / 2 * torch.pi * 135 / 180)
    knee_l_rot2 = rodrigues(z * (pred[:, 26] * 45) * torch.pi / 180)
    knee_l = knee_l_rot2.matmul(knee_l_rot.matmul(z[:, :, None])).squeeze(-1)
    p[:, 6] = p[:, 5] - knee_l * len[:,8] * (1 + pred[:, 7] / lenth_rate)
    knee_l_side = rodrigues(-y * ((pred[:, 27] + 1) / 2 * 70 - 25) * torch.pi / 180)
    knee_l_front = rodrigues(x * ((pred[:, 28] + 1) / 2 * 140 - 30) * torch.pi / 180)
    p[:, 5:7] = torch.transpose(
        knee_l_front @ knee_l_side @ torch.transpose(p[:, 5:7] - p[:, 4][:, None, :], 1, 2), 1, 2
    ) + p[:, 4][:, None, :]

    # Spine: bend joints 8..16 forward about joint 7, then twist and lean the
    # whole upper body (7..16) about the origin.
    spine_rot = rodrigues(z * ((pred[:, 29]) * 30) * torch.pi / 180)
    spine_front = rodrigues(-x * ((pred[:, 30] + 1) / 2 * 105 - 30) * torch.pi / 180)
    spine_side = rodrigues(y * ((pred[:, 31]) * 35) * torch.pi / 180)
    p[:, 8:] = torch.transpose(spine_front @ torch.transpose(p[:, 8:] - p[:, 7][:, None, :], 1, 2), 1, 2) + p[:, 7][:, None, :]
    p[:, 7:] = torch.transpose(spine_side @ spine_rot @ torch.transpose(p[:, 7:] * 1, 1, 2), 1, 2)

    # (B, 17, 3) -> (B, 3, 17), then remap axes to rows (x, -z, -y).
    pp = torch.transpose(p, 1, 2).reshape(-1, 3, 17)
    p3d = torch.zeros_like(pp)
    p3d[:, 0] = pp[:, 0]
    p3d[:, 1] = -pp[:, 2]
    p3d[:, 2] = -pp[:, 1]
    return p3d

def _segment_distance(p1, p2, q1, q2):
    """Approximate distance between the 3D segments [p1, p2] and [q1, q2], batched.

    Each argument is a (batch, 3) tensor; the result has shape (batch,).

    For non-parallel segments, the closest-approach parameters of the two infinite
    lines are clamped to [0, 1] independently. The result is therefore the distance
    between two points lying on the segments. It can over-estimate the true segment
    distance when the closest approach falls beyond a segment end.

    For exactly parallel segments (a*c - b*b == 0), it is the distance from ``p1`` to
    the infinite line through ``q1`` and ``q2``.
    """
    u = p2 - p1
    v = q2 - q1
    w0 = p1 - q1
    a = (u * u).sum(dim=1, keepdim=True)
    b = (u * v).sum(dim=1, keepdim=True)
    c = (v * v).sum(dim=1, keepdim=True)
    d = (u * w0).sum(dim=1, keepdim=True)
    e = (v * w0).sum(dim=1, keepdim=True)
    f = (w0 * w0).sum(dim=1, keepdim=True)
    denom = a * c - b * b

    # Clamped line parameters along u (from p1) and along v (from q1).
    # Their values are irrelevant (and may be NaN/inf) where denom == 0.
    s = ((b * e - c * d) / denom).clamp(0, 1)
    t = ((a * e - b * d) / denom).clamp(0, 1)
    skew = torch.linalg.norm((q1 + t * v) - (p1 + s * u), dim=1)

    # Point-to-line distance sqrt(|w0|^2 - (v.w0)^2 / |v|^2). Rounding can push the
    # radicand slightly below zero, so clamp it before the square root.
    parallel = (f - e * e / c).clamp_min(0).sqrt().squeeze(-1)
    return torch.where(denom.squeeze(-1) == 0, parallel, skew)


def get_distance(pose):
    """Flag poses whose non-adjacent body segments keep a minimum clearance.

    Thirteen segments (arms, shoulders, legs, torso, pelvis, head) are checked
    pairwise. Pairs that share a joint are skipped, and so are two explicitly
    allowed pairs. A pair "collides" when its distance is <= the larger of the two
    segments' clearance thresholds, which are derived from the pose's own bone
    lengths.

    Args:
        pose: tensor indexed as ``pose[:, :, joint]`` with the 3 coordinates on
            dim 1 and at least 17 joints (indices 0-16 are used), i.e. (batch, 3, 17).

    Returns:
        Tensor of shape (batch, 1) in the default float dtype. It is 1.0 when no
        checked pair collides and 0.0 otherwise.
    """
    import itertools

    def joint(idx):
        return pose[:, :, idx]

    def seg_len(i, j):
        return torch.linalg.norm(joint(i) - joint(j), dim=1)

    torso = seg_len(0, 4) * 0.9
    up_arm = seg_len(14, 15) / 16
    down_arm = seg_len(15, 16) / 16
    up_leg = seg_len(1, 2) / 16
    down_leg = seg_len(2, 3) / 16
    # Twice the distance of joint 9 from the line through joints 8 and 10
    # (|cross| / |base|). dim=1 must be explicit: without it torch.cross uses the
    # first size-3 dimension, which is the batch dimension when batch_size == 3.
    head_cross = torch.cross(joint(8) - joint(9), joint(8) - joint(10), dim=1)
    head = torch.linalg.norm(head_cross, dim=1) / seg_len(8, 10) * 2

    # Clearance per segment kind. Shoulders reuse the upper-arm value and the pelvis
    # reuses the upper-leg value.
    clearance = {
        'shoulder': up_arm, 'up_arm': up_arm, 'down_arm': down_arm,
        'up_leg': up_leg, 'down_leg': down_leg,
        'torso': torso, 'pelvis': up_leg, 'head': head,
    }
    bones = {
        'R_shoulder': [8, 14], 'R_up_arm': [14, 15], 'R_down_arm': [15, 16],
        'L_shoulder': [8, 11], 'L_up_arm': [11, 12], 'L_down_arm': [12, 13],
        'R_up_leg': [1, 2], 'R_down_leg': [2, 3], 'L_up_leg': [4, 5], 'L_down_leg': [5, 6],
        'torso': [0, 8], 'pelvis': [1, 4], 'head': [8, 10],
    }
    # Non-adjacent pairs that are never checked.
    allowed_pairs = {('torso', 'pelvis'), ('R_down_arm', 'L_down_arm')}

    def kind(bone_name):
        # 'R_up_arm' -> 'up_arm'; unprefixed names map to themselves.
        return bone_name[2:] if bone_name[0] in 'RL' else bone_name

    clear = torch.ones(pose.shape[0], dtype=torch.bool, device=pose.device)
    for (name1, seg1), (name2, seg2) in itertools.combinations(bones.items(), 2):
        if set(seg1) & set(seg2) or (name1, name2) in allowed_pairs:
            continue
        dist = _segment_distance(joint(seg1[0]), joint(seg1[1]), joint(seg2[0]), joint(seg2[1]))
        limit = torch.maximum(clearance[kind(name1)], clearance[kind(name2)])
        # Written as `~(dist <= limit)` rather than `dist > limit`, so a NaN distance
        # (e.g. a zero-length segment) does not count as a collision.
        clear &= ~(dist <= limit)

    return torch.where(clear, 1., 0.).reshape(-1, 1)

def cycle(input):
    """Fold tensor values into [-1, 1] by reflecting them at the boundaries.

    A value above 1 is mirrored about 1 (x -> 2 - x), and a value below -1 is
    mirrored about -1 (x -> -2 - x). This repeats until every finite entry lies in
    [-1, 1], which gives a triangle wave of period 4 (e.g. 1.5 -> 0.5, 3.5 -> -0.5).

    Non-finite entries (NaN, +/-inf) are returned unchanged, because reflection can
    never bring them into range. They also do not prevent the finite entries from
    being folded.

    Args:
        input: tensor of any shape.

    Returns:
        Tensor of the same shape. When nothing needs folding, this is ``input`` itself.
    """
    out = input
    while True:
        finite = torch.isfinite(out)
        if not (finite & (out.abs() > 1)).any():
            return out
        out = torch.where(finite & (out > 1), 2 - out, out)
        out = torch.where(finite & (out < -1), -2 - out, out)

def freepose_(batch_size, frame):
    """Generate a random motion clip of ``frame`` frames with ``human_model3``.

    Each sample gets 34 parameters drawn uniformly in [-1, 1). In every subsequent
    frame, columns 8:32 drift by a fixed per-sample Gaussian step (std 0.1) and are
    folded back into [-1, 1] with ``cycle``. Every frame is built by ``human_model3``
    and placed by ``rand_position``; the ``rot``/``trans`` sampled for the first frame
    are reused for all later frames.

    Returns:
        (p3d, p2d, distance):
        - p3d: the 3D pose of the middle frame (index ``frame // 2``).
        - p2d: all 2D frames stacked on dim 1 and reshaped to (-1, frame * 32).
          NOTE(review): ``freepose`` uses 34 per frame; confirm 32 matches
          ``human_model3``'s output.
        - distance: ``get_distance`` of the middle pose.
    """
    params = torch.rand((batch_size, 34), device='cuda') * 2 - 1
    pose3d = human_model3(params)
    pose3d, pose2d, rot, trans = rand_position(pose3d)
    step = torch.randn((batch_size, 24), device='cuda') / 10

    # The first frame is stored as a copy (`* 1`); later frames are stored as returned.
    frames_3d = [pose3d * 1]
    frames_2d = [pose2d * 1]
    for _ in range(frame - 1):
        params[:, 8:32] += step
        params[:, 8:32] = cycle(params[:, 8:32])
        pose3d, pose2d, _, _ = rand_position(human_model3(params), rot, trans)
        frames_3d.append(pose3d)
        frames_2d.append(pose2d)

    clip_2d = torch.stack(frames_2d, dim=1).reshape(-1, frame * 32)
    mid_pose = frames_3d[frame // 2]
    mid_distance = get_distance(mid_pose)
    return mid_pose, clip_2d, mid_distance


def freepose(batch_size, frame, aug=True):
    """Rejection-sample ``batch_size`` motion clips that pass ``get_distance``.

    Each round draws ``batch_size`` candidate clips, using the same scheme as
    ``freepose_`` but with ``human_model_for_h36m`` and 34 values per 2D frame. It
    keeps the candidates whose middle-frame 3D pose gets ``get_distance == 1``.
    Everything runs under ``torch.no_grad()``.

    Args:
        batch_size: number of accepted samples to return.
        frame: number of frames per clip.
        aug: accepted but not used by this function.

    Returns:
        (p3d_set, p2d_set, distance):
        - p3d_set: accepted middle-frame 3D poses, truncated to ``batch_size``.
        - p2d_set: the matching 2D clips, shape (-1, frame * 34).
        - distance: ``get_distance`` of the LAST candidate batch generated, not of
          the returned samples.

    NOTE(review): the loop has no iteration cap. It never terminates if every
    candidate is rejected.
    """
    accepted_3d = []
    accepted_2d = []
    n_accepted = 0
    with torch.no_grad():  # pure data generation: no autograd graph needed
        while n_accepted < batch_size:
            params = torch.rand((batch_size, 34), device='cuda') * 2 - 1
            pose3d = human_model_for_h36m(params)
            pose3d, pose2d, rot, trans = rand_position(pose3d)
            step = torch.randn((batch_size, 24), device='cuda') / 10

            frames_3d = [pose3d * 1]
            frames_2d = [pose2d * 1]
            for _ in range(frame - 1):
                params[:, 8:32] += step
                params[:, 8:32] = cycle(params[:, 8:32])
                pose3d, pose2d, _, _ = rand_position(human_model_for_h36m(params), rot, trans)
                frames_3d.append(pose3d)
                frames_2d.append(pose2d)

            clip_2d = torch.stack(frames_2d, dim=1).reshape(-1, frame * 34)
            mid_pose = frames_3d[frame // 2]
            distance = get_distance(mid_pose)
            keep = torch.nonzero(distance == 1, as_tuple=True)[0]
            accepted_3d.append(mid_pose[keep])
            accepted_2d.append(clip_2d[keep])
            n_accepted += keep.shape[0]

        p3d_set = torch.cat(accepted_3d, dim=0)[:batch_size].detach()
        p2d_set = torch.cat(accepted_2d, dim=0)[:batch_size].detach()

    return p3d_set, p2d_set, distance

def freepose_for_dis(batch_size):
    """Sample random poses together with a 0/1 validity label.

    Parameters are drawn uniformly in [-1.5, 1.5). Columns :8 and 32: are folded
    into [-1, 1] with ``cycle``; columns 8:32 are left as drawn. The label is 1 only
    when every parameter lies in [-1, 1] (so columns 8:32 were in range) AND
    ``get_distance`` accepts the generated pose. Otherwise it is 0.

    Returns:
        (p3d, distance): the placed 3D poses from ``rand_position`` and a label
        tensor of shape (batch_size, 1).
    """
    params = torch.rand((batch_size, 34), device='cuda') * 3 - 1.5
    params[:, :8] = cycle(params[:, :8])
    params[:, 32:] = cycle(params[:, 32:])

    any_outside = ((params > 1) | (params < -1)).any(dim=1, keepdim=True)
    distance = -(torch.where(any_outside, 1., 0.) - 1)

    pose3d = human_model_for_h36m(params)
    pose3d, _, _, _ = rand_position(pose3d)

    distance *= get_distance(pose3d)
    return pose3d, distance

def best_of_mloss(pred, gt, TOP_K=5):
    """Best-of-M loss: per sample, the errors of the TOP_K hypotheses closest to ``gt``.

    Every pose is first centred on the mean of its 16 joints (per coordinate).
    Hypotheses are ranked by mean per-joint Euclidean error. For each of the TOP_K
    best, the loss is its mean per-joint error against ``gt``. To keep the square
    root differentiable at zero, 1e-9 is added to every absolute coordinate
    difference.

    Args:
        pred: predictions, shape [batch, hypo, 3 * 16].
        gt: ground truth, shape [batch, 3 * 16].
        TOP_K: number of best hypotheses kept (must not exceed ``hypo``).

    Returns:
        Tensor of shape [batch, TOP_K], with the best hypothesis first.
    """
    batch_size, n_hypo = pred.shape[0], pred.shape[1]

    centred = pred.reshape(-1, 3, 16)
    centred = centred - centred.mean(-1, keepdim=True)
    target = gt.reshape(-1, 3, 16)
    target = target - target.mean(-1, keepdim=True)

    # Hypothesis-major layout: (hypo, batch, 3, 16).
    hypos = centred.reshape(batch_size, n_hypo, 3, 16).transpose(0, 1)

    # Mean per-joint error (x1000), shape (hypo, batch); used only for ranking.
    ranking_err = 1000 * torch.sqrt(((target.unsqueeze(0) - hypos) ** 2).sum(dim=2)).mean(dim=2)
    order = torch.argsort(ranking_err, dim=0, descending=False).transpose(0, 1)

    ranked = hypos.transpose(0, 1)[torch.arange(batch_size).unsqueeze(1), order]
    top = ranked[:, :TOP_K].reshape(batch_size, TOP_K, 3, 16)

    gap = (top - target.unsqueeze(1)).abs() + 1e-9
    return ((gap ** 2).sum(2) ** 0.5).mean(-1)

def best_of_mloss_jointwise(pred, gt, TOP_K=5):
    """
    Compute a best-of-M loss that selects the best hypotheses for each joint independently.

    Both poses are centred on the mean of their 16 joints (per coordinate), not on a
    root joint. For every joint, the TOP_K hypotheses with the smallest Euclidean
    error at that joint are kept, and their errors are averaged.

    Args:
    pred (torch.Tensor): Predicted poses of shape [batch, hypo, 3 * 16]
    gt (torch.Tensor): Ground truth poses of shape [batch, 3 * 16]
    TOP_K (int): Number of top hypotheses to consider for each joint; a value above
        ``hypo`` keeps all hypotheses.

    Returns:
    torch.Tensor: Shape [batch]. The mean over joints and selected hypotheses of the
        per-joint Euclidean error. 1e-9 is added to every absolute coordinate
        difference so the square root stays differentiable at zero.
    """
    batch_size, n_hypo, _ = pred.shape
    pred = pred.reshape(batch_size, n_hypo, 3, 16)
    gt = gt.reshape(batch_size, 3, 16)

    # Centre every pose on its per-coordinate mean over the 16 joints.
    pred = pred - pred.mean(-1, keepdim=True)
    gt = gt - gt.mean(-1, keepdim=True)

    # Per-joint Euclidean error of each hypothesis, shape [batch, hypo, 16]
    # (scaled by 1000; only used for ranking).
    joint_errors = 1000 * torch.sqrt(torch.sum((gt.unsqueeze(1) - pred) ** 2, dim=2))

    # Independently for every joint, the indices of its TOP_K closest hypotheses: [batch, K, 16].
    top_k_indices = torch.argsort(joint_errors, dim=1)[:, :TOP_K, :]
    top_k_indices_expanded = top_k_indices.unsqueeze(2).expand(-1, -1, 3, -1)
    top_k_predictions = torch.gather(pred, 1, top_k_indices_expanded)  # [batch, K, 3, 16]

    gap = (top_k_predictions - gt.unsqueeze(1)).abs() + 1e-9
    # Euclidean error per (kept hypothesis, joint) -> mean over joints -> mean over hypotheses.
    loss_mb = ((gap ** 2).sum(2) ** 0.5).mean(-1).mean(1)
    return loss_mb

def best_of_hypo(pred, gt):
    """Return, per sample, the hypothesis closest to the ground truth after scale alignment.

    Every pose is centred on the mean of its 16 joints (per coordinate) and cast to
    float64. For each sample, one scale factor is fitted by least squares between the
    mean hypothesis and ``gt``, and it is applied to all of that sample's hypotheses.
    The hypothesis with the lowest mean per-joint Euclidean error is returned.

    Args:
        pred: predictions of shape [batch, hypo, 3 * 16].
        gt: ground truth of shape [batch, 3 * 16].

    Returns:
        float64 tensor of shape [batch, 3, 16], centred and scaled. The inputs are
        not modified.
    """
    batch_size, n_hypo = pred.shape[0], pred.shape[1]

    # All ops are out-of-place. `.to(torch.float64)` returns the very same tensor when
    # the input is already float64, so in-place `-=` / `*=` would write into the
    # caller's `pred` / `gt`.
    hypos = pred.permute(1, 0, 2).reshape(n_hypo, batch_size, 3, 16).to(torch.float64)
    hypos = hypos - hypos.mean(-1, keepdim=True)
    target = gt.reshape(batch_size, 3, 16).to(torch.float64)
    target = target - target.mean(-1, keepdim=True)

    # Least-squares scale s minimising |s * mean_pose - target|^2 for each sample.
    mean_pose = hypos.mean(0)
    pred_dot_pred = torch.einsum('nkc,nkc->n', mean_pose, mean_pose)
    pred_dot_gt = torch.einsum('nkc,nkc->n', mean_pose, target)
    scale_factor = pred_dot_gt / pred_dot_pred  # NaN/inf if the mean pose is all zeros
    hypos = hypos * scale_factor[None, :, None, None]

    # Mean per-joint Euclidean error, shape [hypo, batch].
    errors = torch.sqrt(((target - hypos) ** 2).sum(dim=2)).mean(dim=2)
    best = errors.argmin(dim=0)  # [batch]
    return hypos[best, torch.arange(batch_size, device=hypos.device)]

def add_joint(p3d):
    """Accumulate per-joint values along a fixed 16-joint parent chain.

    ``p3d`` is viewed as rows of 16 values (``reshape(-1, 16)``). In each row:
    - the value of joint 6 is set to 0;
    - joint 3 becomes ``row[6] - row[0]``, ignoring its own input;
    - every other joint becomes ``row[parent] + row[joint]``, visiting joints in
      index order so each parent is already updated.
    The result is reshaped back to the input's shape.

    NOTE: whenever ``reshape`` returns a view (e.g. a contiguous input), the writes
    go straight into the caller's tensor, i.e. the input is modified in place.
    """
    original_shape = p3d.shape
    rows = p3d.reshape(-1, 16)
    parent_of = [6, 0, 1, 6, 3, 4, None, 6, 7, 8, 7, 10, 11, 7, 13, 14]

    rows[:, 6] = 0
    for child, parent in enumerate(parent_of):
        if parent is None:
            continue
        if child == 3:
            rows[:, child] = rows[:, parent] - rows[:, 0]
        else:
            rows[:, child] = rows[:, parent] + rows[:, child]

    return rows.reshape(original_shape)


def get_skeleton_sign(p3d):
    """Rebuild joint values along the parent chain, then normalise each sample by its maximum.

    The reconstruction is the same as ``add_joint``: rows of 16 values, joint 6 set
    to 0, joint 3 = ``row[6] - row[0]``, and every other joint = ``row[parent] + row[joint]``.
    Its writes go into ``p3d`` in place whenever ``reshape`` returns a view.

    Afterwards the values are grouped as (-1, 3, 16) and re-centred on joint 6,
    which is already zero at this point. They are then flattened to (-1, 48) and
    divided by each sample's maximum value. A non-positive maximum yields inf/NaN
    or flipped signs.
    """
    rows = p3d.reshape(-1, 16)
    parent_of = [6, 0, 1, 6, 3, 4, None, 6, 7, 8, 7, 10, 11, 7, 13, 14]

    rows[:, 6] = 0
    for child, parent in enumerate(parent_of):
        if parent is None:
            continue
        rows[:, child] = rows[:, parent] - rows[:, 0] if child == 3 else rows[:, parent] + rows[:, child]

    joints = rows.reshape(-1, 3, 16)
    joints = joints - joints[:, :, 6].unsqueeze(-1)
    flat = joints.reshape(-1, 48)
    return flat / flat.max(-1).values.unsqueeze(-1)


def cosine_similarity(x1, x2):
    """Return the cosine similarity between ``x1`` and ``x2``, computed along dim 1."""
    similarity = F.cosine_similarity(x1, x2, dim=1)
    return similarity

def diversity_promotion_loss(features, eps=1e-8):
    """
    Mean pairwise cosine dissimilarity (1 - cos) between distinct samples, computed on the GPU-friendly matrix form.

    :param features: feature vectors of the generated samples, shape (N, feature_size)
    :param eps: small constant added to the pair count for numerical stability
        (it keeps N == 1 from dividing by zero)
    :return: scalar tensor, the average of ``1 - cos(f_i, f_j)`` over the
        N * (N - 1) ordered pairs with ``i != j``

    NOTE(review): a larger value means more diverse features. Confirm that the
    caller maximises this value (or negates it) rather than minimising it.
    """
    N = features.size(0)

    # L2-normalise each feature vector so the Gram matrix holds cosine similarities.
    normalized_features = F.normalize(features, p=2, dim=1)
    similarity_matrix = torch.matmul(normalized_features, normalized_features.T)

    # Mask self-pairs out of the DISSIMILARITY matrix. Zeroing the similarity
    # diagonal instead would leave 1 - 0 = 1 on every diagonal entry and bias the
    # mean upward by 1 / (N - 1).
    self_pairs = torch.eye(N, dtype=torch.bool, device=features.device)
    dissimilarity_matrix = (1 - similarity_matrix).masked_fill(self_pairs, 0)

    loss = dissimilarity_matrix.sum() / (N * (N - 1) + eps)
    return loss

def get_bone(p3d):
    """Return the lengths of 10 fixed bones for a batch of skeletons.

    ``p3d`` is indexed as ``p3d[:, joint]`` with the coordinates on dim 2, e.g. a
    (batch, joints, 3) tensor. The result has the bone index on dim 1; for a
    (batch, joints, 3) input its shape is (batch, 10). Column k holds the Euclidean
    distance between the k-th joint pair below.
    """
    pairs = ((0, 1), (0, 7), (8, 9), (9, 10), (8, 11), (11, 12), (12, 13), (1, 2), (2, 3), (7, 8))
    starts = [first for first, _ in pairs]
    ends = [second for _, second in pairs]
    # Gather all bone endpoints at once, then take the norm over the coordinate dim.
    return torch.linalg.norm(p3d[:, starts] - p3d[:, ends], dim=2)