import torch
import unittest
from utils import *


class UtilsTest(unittest.TestCase):
    """Tests for the ellipse helpers ``get_ellipses`` and ``get_ellipse_info``."""

    def test_get_ellipses(self):
        """Exercise ``get_ellipses`` on fixed and randomised parameter rows.

        The first part is a smoke test on a few hand-picked rows (one of them
        with a very small second value); its results are only printed for
        manual inspection.

        The second part draws random rows whose first two columns are
        ``exp(u)`` with ``u`` in ``[-20, 20)`` and whose last two columns lie
        in ``[-pi, pi)``. For every ``(a, b, theta)`` triple yielded by
        ``get_ellipses`` it checks that ``a`` and ``b`` are non-zero, not NaN
        and never exceed the norm of the row's first two columns, and that
        ``theta`` is not NaN.
        """
        covs = torch.tensor(
            [[1, 2, 0, 0],
             [1, 2, torch.pi, 0],
             [0.0023367926478385925, 1.8339891028062993e-07,
              0.10819325596094131, -0.3256620168685913]])
        print(covs)
        print()
        for a, b, theta in get_ellipses(covs):
            print(torch.stack([a, b, theta], -1))

        n_samples = 256
        # Fixed, local generator: failures are reproducible and the global RNG
        # state seen by other tests is left untouched.
        gen = torch.Generator().manual_seed(0)
        covs = torch.rand(n_samples, 4, generator=gen)
        covs[:, :2] = (covs[:, :2] * 40 - 20).exp()  # a, b
        covs[:, 2:] = covs[:, 2:] * 2 * torch.pi - torch.pi
        norm = covs[:, :2].norm(dim=-1)

        for a, b, theta in get_ellipses(covs):
            for name, axis in (("a", a), ("b", b)):
                too_long = axis > norm
                self.assertFalse(
                    too_long.any().item(),
                    f"{name} > norm at {torch.where(too_long)}")
                is_zero = axis == 0
                self.assertFalse(
                    is_zero.any().item(),
                    f"{name} == 0 at {torch.where(is_zero)}")
                is_nan = axis.isnan()
                self.assertFalse(
                    is_nan.any().item(),
                    f"{name} is NaN at {torch.where(is_nan)}")

            theta_nan = theta.isnan()
            self.assertFalse(
                theta_nan.any().item(),
                f"theta is NaN at {torch.where(theta_nan)}")

    def test_get_ellipse_info(self):
        """Recover semi-axes and rotation from implicit conic coefficients.

        Each column of ``A``, ``B``, ``C``, ``F`` describes one curve
        ``A*x**2 + B*x*y + C*y**2 + F = 0``: a unit circle and two ellipses
        with semi-axes 2 and 1 whose cross terms have opposite signs.
        """
        # AX2 + BXY + CY2 + F = 0
        A = torch.tensor([ 1., 5/8,  5/8])
        B = torch.tensor([ 0., 3/4, -3/4])
        C = torch.tensor([ 1., 5/8,  5/8])
        F = torch.tensor([-1., -1.,  -1.])

        a = torch.tensor([1., 2., 2.])
        b = torch.tensor([1., 1., 1.])
        # NOTE(review): with x pointing right and y up, rotating the
        # (a=2, b=1) ellipse by +pi/4 yields B = -3/4, so these expected
        # angles assume theta is measured the other way (e.g. y pointing
        # down). Confirm this matches get_ellipse_info's convention.
        t = torch.tensor([0., torch.pi/4, 3*torch.pi/4])

        a_, b_, t_ = get_ellipse_info(A, B, C, F)

        torch.testing.assert_close(
            torch.as_tensor(a_, dtype=a.dtype), a, atol=1e-4, rtol=1e-4)
        torch.testing.assert_close(
            torch.as_tensor(b_, dtype=b.dtype), b, atol=1e-4, rtol=1e-4)

        # A half turn maps an ellipse onto itself, so compare angles modulo
        # pi (wrapped into [-pi/2, pi/2)). A circle (a == b) has no defined
        # orientation, so those rows are excluded from the angle check.
        t_got = torch.as_tensor(t_, dtype=t.dtype)
        t_diff = torch.remainder(t_got - t + torch.pi / 2, torch.pi) - torch.pi / 2
        oriented = a != b
        torch.testing.assert_close(
            t_diff[oriented], torch.zeros_like(t_diff[oriented]),
            atol=1e-4, rtol=0)


if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()

