#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import json
import os
import os.path as osp
import pickle
import random
import shutil

import numpy as np
import torch

from utils.system_utils import searchForMaxIteration
from scene.gaussian_model import GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
from scene.dataset_readers import readCTSceneInfo, fetchPly


class Scene:
    """CT scene: cameras, ground-truth volume, and the Gaussian model to train.

    Only scenes whose ``source_path`` contains ``meta_data.json`` are
    recognized; they are read with ``readCTSceneInfo``.
    """

    gaussians: GaussianModel

    def __init__(
        self,
        args: ModelParams,
        gaussians: GaussianModel,
        load_iteration=None,
        shuffle=True,
        init_from="pcd",
    ):
        """Read the scene and either restore or freshly initialize ``gaussians``.

        :param args: Model parameters. ``model_path``, ``source_path``,
            ``eval`` and ``init_path`` are read here, and ``args`` is
            forwarded to ``cameraList_from_camInfos``.
        :param gaussians: Gaussian model that is loaded or initialized in place.
        :param load_iteration: Saved iteration to restore; ``-1`` picks the
            value returned by ``searchForMaxIteration`` on
            ``<model_path>/point_cloud``. A falsy value means "start fresh".
        :param shuffle: If true, shuffle the train and test camera lists in
            place (via ``random.shuffle``) before building camera objects.
        :param init_from: Source for a fresh initialization:
            ``"pcd"`` reads the PLY at ``scene_info.ply_path``;
            ``"pickle"`` unpickles the ``.pickle`` file next to that PLY;
            ``"random<sep><N>"`` (e.g. ``"random_50000"``) samples N points
            uniformly inside ``meta_data["bbox"]``; anything else passes
            ``None`` to ``create_from_pcd``.
        :raises ValueError: If the scene type is not recognized, or the random
            init spec does not carry a positive integer point count.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(
                    osp.join(self.model_path, "point_cloud")
                )
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        if not osp.exists(osp.join(args.source_path, "meta_data.json")):
            # Raise instead of ``assert``: asserts are stripped under
            # ``python -O``, which would leave ``scene_info`` unbound below.
            raise ValueError("Could not recognize scene type!")
        scene_info = readCTSceneInfo(
            args.source_path, args.eval, osp.join(args.source_path, args.init_path)
        )

        if not self.loaded_iter:
            # Exported before shuffling, so cameras.json keeps the reader's order.
            self._export_inputs(scene_info)

        if shuffle:
            # Multi-res consistent random shuffling
            random.shuffle(scene_info.train_cameras)
            random.shuffle(scene_info.test_cameras)

        print("Loading Training Cameras")
        self.train_cameras = cameraList_from_camInfos(scene_info.train_cameras, args)
        print("Loading Test Cameras")
        self.test_cameras = cameraList_from_camInfos(scene_info.test_cameras, args)

        self.meta_data = scene_info.meta_data

        # Ground-truth volume, loaded from the file named by meta_data["vol"]
        # and moved to the default CUDA device.
        self.vol_gt = torch.from_numpy(
            np.load(osp.join(args.source_path, self.meta_data["vol"]))
        ).cuda()

        if self.loaded_iter:
            # ! Need to change
            self.gaussians.load_ply(
                osp.join(
                    self.model_path,
                    "point_cloud",
                    "iteration_" + str(self.loaded_iter),
                    "point_cloud.ply",
                )
            )
        else:
            self._init_gaussians(scene_info, init_from)

    def _export_inputs(self, scene_info):
        """Copy the input PLY and dump camera metadata into ``model_path``.

        Writes ``input.ply`` (a byte copy of ``scene_info.ply_path``) and
        ``cameras.json`` (test cameras first, then train cameras, each
        serialized with ``camera_to_JSON``).
        """
        # copyfile streams in chunks and raises SameFileError instead of
        # truncating the source when both paths refer to the same file.
        shutil.copyfile(scene_info.ply_path, osp.join(self.model_path, "input.ply"))

        camlist = []
        if scene_info.test_cameras:
            camlist.extend(scene_info.test_cameras)
        if scene_info.train_cameras:
            camlist.extend(scene_info.train_cameras)
        json_cams = [camera_to_JSON(cam_id, cam) for cam_id, cam in enumerate(camlist)]
        with open(osp.join(self.model_path, "cameras.json"), "w") as file:
            json.dump(json_cams, file)

    def _init_gaussians(self, scene_info, init_from):
        """Build the initial point cloud selected by ``init_from`` and seed the model.

        See ``__init__`` for the accepted ``init_from`` values.
        """
        if init_from.startswith("random"):
            # Expected form: "random" + one separator character + count.
            try:
                n_point = int(init_from[7:])
            except ValueError as err:
                raise ValueError(
                    f"Invalid random init spec {init_from!r}; "
                    "expected e.g. 'random_50000'"
                ) from err
            if n_point <= 0:
                raise ValueError("Specify valid number of random points")
            # Presumably bbox is [[min_x, min_y, min_z], [max_x, max_y, max_z]]
            # (indexed as bbox[0] / bbox[1] below) — confirm against the dataset.
            bbox = torch.tensor(self.meta_data["bbox"])
            pcd = bbox[0] + (bbox[1] - bbox[0]) * torch.rand([n_point, 3])
            density = torch.rand([n_point, 1]) * 0.1  # uniform in [0, 0.1)
            point_cloud = torch.concat([pcd, density], dim=-1)  # (n_point, 4)
            init_from = "random"  # drop the count before handing off to the model
        elif init_from == "pcd":
            point_cloud = fetchPly(scene_info.ply_path)
            print(f"Initialize gaussians with pcd {scene_info.ply_path}")
        elif init_from == "pickle":
            # SECURITY: pickle.load can execute arbitrary code; only use this
            # path with pickle files produced by a trusted pipeline.
            with open(scene_info.ply_path[:-3] + "pickle", "rb") as handle:
                point_cloud = pickle.load(handle)
        else:
            point_cloud = None

        self.gaussians.create_from_pcd(point_cloud, 1.0, init_from)

    def save(self, iteration, queryfunc, pipe):
        """Save the Gaussians and the predicted / ground-truth volumes.

        Output goes to ``<model_path>/point_cloud/iteration_<iteration>/``:
        ``point_cloud.ply``, ``vol_pred.npy`` and ``vol_gt.npy``. Both volumes
        are clipped to [0, 1] before saving.

        :param iteration: Training iteration used to name the output folder.
        :param queryfunc: Callable ``(gaussians, offOrigin, nVoxel, sVoxel,
            pipe)`` returning a mapping with a ``"vol"`` tensor.
        :param pipe: Pipeline parameters forwarded to ``queryfunc``.
        """
        point_cloud_path = osp.join(
            self.model_path, "point_cloud", "iteration_{}".format(iteration)
        )
        # np.save below requires the directory to exist; don't rely on
        # save_ply creating it.
        os.makedirs(point_cloud_path, exist_ok=True)
        self.gaussians.save_ply(osp.join(point_cloud_path, "point_cloud.ply"))

        # Save volume
        scanner_cfg = self.meta_data["scanner"]
        query_pkg = queryfunc(
            self.gaussians,
            scanner_cfg["offOrigin"],
            scanner_cfg["nVoxel"],
            scanner_cfg["sVoxel"],
            pipe,
        )
        vol_pred = query_pkg["vol"].clip(0.0, 1.0)
        vol_gt = self.vol_gt.clip(0.0, 1.0)

        np.save(osp.join(point_cloud_path, "vol_gt.npy"), vol_gt.detach().cpu().numpy())
        np.save(
            osp.join(point_cloud_path, "vol_pred.npy"), vol_pred.detach().cpu().numpy()
        )

    def getTrainCameras(self):
        """Return the training cameras built in ``__init__``."""
        return self.train_cameras

    def getTestCameras(self):
        """Return the test cameras built in ``__init__``."""
        return self.test_cameras
