import os
import json
import time
import numpy as np
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter

from utils.misc import set_random_seed
from utils.logger import write_to_record_file, print_progress, timeSince
from utils.distributed import init_distributed, is_default_gpu
from utils.distributed import all_gather, merge_dist_results

from utils.data import ImageFeaturesDB
from r2r.data_utils import construct_instrs
from r2r.env import R2RNavBatch
from r2r.parser import parse_args

from models.vlnbert_init import get_tokenizer
from r2r.agent import GMapNavAgent


def build_dataset(args, rank=0, is_test=False):
    """Create the R2R navigation environments used by this script.

    Args:
        args: parsed command-line namespace (from ``parse_args``).
        rank: index of this process; added to ``args.seed`` to form the seed
            handed to every environment.
        is_test: forwarded unchanged to ``construct_instrs`` for every split.

    Returns:
        A ``(train_env, val_envs, aug_env)`` tuple. ``val_envs`` is a dict keyed
        by split name ('val_unseen', plus 'test' when ``args.submit`` is set and
        the dataset is not 'r4r'); ``aug_env`` is ``None`` unless ``args.aug``
        names an augmentation split.
    """
    # NOTE(review): the returned tokenizer is never used in this function; the
    # call is kept in case get_tokenizer has side effects — confirm before removing.
    get_tokenizer(args)

    main_db = ImageFeaturesDB(args.img_ft_file, args.image_feat_size, args.img_ft_file_sd, args.aug_prob)

    env_cls = R2RNavBatch

    def load_instrs(splits):
        """Load instructions for ``splits`` using this run's tokenizer settings."""
        return construct_instrs(
            args.anno_dir, args.dataset, splits,
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
            is_test=is_test
        )

    def make_env(db, instr_data, env_name, data_idxs=None):
        """Wrap ``instr_data`` in a navigation env that shares this run's settings."""
        # No distributed sampler is used, so each process shuffles its data with
        # a rank-dependent seed in order to see different training examples.
        return env_cls(
            db, instr_data, args.connectivity_dir,
            batch_size=args.batch_size, angle_feat_size=args.angle_feat_size,
            seed=args.seed + rank, sel_data_idxs=data_idxs, name=env_name,
        )

    aug_env = None
    if args.aug is not None:
        aug_data = load_instrs([args.aug])
        aug_db = ImageFeaturesDB(
            args.img_ft_file, args.image_feat_size, args.img_ft_file_sd, args.aug_prob,
            aug_dataset=args.sepenv,
        )
        aug_env = make_env(aug_db, aug_data, 'aug')

    train_env = make_env(main_db, load_instrs(['train']), 'train')

    split_names = ['val_unseen']
    if args.submit and args.dataset != 'r4r':
        split_names.append('test')

    val_envs = {}
    for split_name in split_names:
        split_data = load_instrs([split_name])
        # With two or more processes each val env receives (rank, world_size)
        # as sel_data_idxs — presumably so every process evaluates its own share.
        shard = (rank, args.world_size) if args.world_size >= 2 else None
        val_envs[split_name] = make_env(main_db, split_data, split_name, shard)

    return train_env, val_envs, aug_env


def _list_checkpoints(ckpt_path):
    """Return the checkpoint paths to evaluate for ``ckpt_path``.

    A path to a regular file is returned as a one-element list. Otherwise
    ``ckpt_path`` is treated as a directory: every entry whose name contains
    ``".pt"`` is returned joined with the directory. Entries are sorted by name
    so that the evaluation order and the printed summary are reproducible,
    because ``os.listdir`` order is arbitrary.
    """
    if os.path.isfile(ckpt_path):
        return [ckpt_path]
    return [
        os.path.join(ckpt_path, name)
        for name in sorted(os.listdir(ckpt_path))
        if ".pt" in name
    ]


def eval_all(args, train_env, val_envs, aug_env=None, rank=-1):
    """Evaluate one or more saved agents on every validation environment.

    ``args.bert_ckpt_file`` names either a directory of checkpoints or a single
    checkpoint file. In a directory, every entry containing ``".pt"`` is
    evaluated in sorted order. For each checkpoint, ``args.bert_ckpt_file`` is
    pointed at it and a fresh ``GMapNavAgent`` is constructed. Presumably the
    agent loads its weights from that attribute; confirm in ``GMapNavAgent``.
    The agent runs with ``feedback='argmax'`` and dropout disabled on each env
    in ``val_envs``. Predictions are gathered from all processes before
    ``env.eval_metrics`` scores them.

    Args:
        args: parsed command-line namespace. ``args.bert_ckpt_file`` is
            overwritten per checkpoint and restored before returning.
        train_env: environment handed to the agent constructor.
        val_envs: dict mapping split name to the environment to evaluate.
        aug_env: not used; kept for call-site compatibility.
        rank: process rank forwarded to the agent.

    Side effects (default GPU only): writes ``training_args.json`` and
    appends per-checkpoint metrics to ``train.txt`` in ``args.log_dir``,
    then prints the SR list, the SPL list and the checkpoint names.
    """
    default_gpu = is_default_gpu(args)
    writer = None
    record_file = None

    if default_gpu:
        with open(os.path.join(args.log_dir, 'training_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        # Nothing is logged through the writer here. It is still created, as
        # before, so log_dir contents are unchanged, and it is now closed below.
        writer = SummaryWriter(log_dir=args.log_dir)
        record_file = os.path.join(args.log_dir, 'train.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    # Parallel lists with one entry per (checkpoint, val split) pair.
    srs = []
    spls = []
    ckpt_names = []

    ckpt_path = args.bert_ckpt_file
    try:
        for ckpt_file in _list_checkpoints(ckpt_path):
            ckpt_name = os.path.basename(ckpt_file)
            args.bert_ckpt_file = ckpt_file
            listner = GMapNavAgent(args, train_env, rank=rank)

            loss_str = "validation before training"
            for env_name, env in val_envs.items():
                listner.env = env
                # Evaluate under test conditions: argmax actions, no dropout.
                listner.test(use_dropout=False, feedback='argmax', iters=None, future=args.future)
                # Gather predictions from every process before scoring.
                preds = merge_dist_results(all_gather(listner.get_results()))

                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)

                srs.append(score_summary['sr'])
                spls.append(score_summary['spl'])
                ckpt_names.append(ckpt_name)

            if default_gpu:
                write_to_record_file(loss_str, record_file)

            # Drop this agent before the next one is built, so whatever model or
            # GPU memory it holds can be freed first.
            del listner
    finally:
        # Undo the per-checkpoint mutation of args and release the writer's
        # file handle, even if an evaluation raised.
        args.bert_ckpt_file = ckpt_path
        if writer is not None:
            writer.close()

    if default_gpu:
        # Scores come from predictions gathered across processes, so a single
        # rank printing the summary avoids duplicated output.
        print("---------Summary----------")
        print(srs)
        print(spls)
        print(ckpt_names)

def main():
    """Run checkpoint evaluation end to end.

    Parses the command line. When ``world_size > 1``, it initialises distributed
    execution and binds this process to ``args.local_rank``'s CUDA device. It
    then calls ``set_random_seed`` with a rank-offset seed, builds the
    environments and evaluates the checkpoints.
    """
    cli_args = parse_args()

    is_distributed = cli_args.world_size > 1
    if not is_distributed:
        proc_rank = 0
    else:
        proc_rank = init_distributed(cli_args)
        torch.cuda.set_device(cli_args.local_rank)

    set_random_seed(cli_args.seed + proc_rank)

    envs = build_dataset(cli_args, rank=proc_rank, is_test=cli_args.test)
    train_env, val_envs, aug_env = envs
    eval_all(cli_args, train_env, val_envs, aug_env=aug_env, rank=proc_rank)


if __name__ == '__main__':
    main()
