# Sample unit types and resource scales for the rule-based AI.
from collections import defaultdict
import random
import torch
import numpy as np


class RuleAISampler:
    """Sample unit types and resource scales for a rule-based AI.

    A decayed per-unit-type win-rate estimate is maintained from the rewards
    passed to ``feed``; each ``feed`` call then samples a new unit type for
    every entry of the batch together with a resource scale derived from the
    sampled type's win-rate estimate.

    Args:
        adversarial: if truthy, sample unit types with probability
            proportional to their win-rate estimate (uniformly when every
            estimate is zero); otherwise always sample uniformly.
        min_resource_scale: scale used when a type's win-rate estimate is 1.
        max_resource_scale: scale used when a type's win-rate estimate is 0.
            Estimates in between are linearly interpolated.
        win_rate_decay: decay factor of the moving average,
            ``w <- decay * w + (1 - decay) * max(r, 0)``.
        unit_types: unit-type names, indexed by unit-type id. The sequence is
            copied, so instances never share or alias the caller's list.
    """

    def __init__(self,
                 adversarial,
                 min_resource_scale,
                 max_resource_scale,
                 win_rate_decay,
                 unit_types=('SPEARMAN', 'SWORDMAN', 'CAVALRY', 'ARCHER', 'DRAGON')):
        self.adversarial = adversarial
        self.min_resource_scale = min_resource_scale
        self.max_resource_scale = max_resource_scale
        self.win_rate_decay = win_rate_decay
        # Private copy: an immutable default plus a copy avoids state shared
        # between instances through a mutable default argument.
        self.unit_types = list(unit_types)
        self.utype_idx = range(len(self.unit_types))

        # Every unit type's win-rate estimate starts at 1.0.
        self.win_rate_stats = [1.0 for _ in self.utype_idx]
        # "Recent" statistics reported by ``log``; both are cleared by ``reset``.
        self.recent_utype_count = defaultdict(int)
        self.recent_utype_scale = defaultdict(list)

    def get_resource_scale(self, win_rate):
        """Linearly map ``win_rate`` (0 -> max scale, 1 -> min scale) to a scale."""
        scale = (self.max_resource_scale
                 + (self.min_resource_scale - self.max_resource_scale) * win_rate)
        return scale

    def _sample_utype(self):
        """Draw one unit-type id, weighted by win rate when adversarial."""
        # random.choices rejects all-zero weights (ValueError on Python 3.9+),
        # which happens once every estimate has decayed to 0; fall back to
        # uniform sampling in that case.
        if self.adversarial and sum(self.win_rate_stats) > 0:
            return random.choices(self.utype_idx, weights=self.win_rate_stats)[0]
        return random.choices(self.utype_idx)[0]

    def feed(self, batch):
        """Update win-rate estimates from ``batch`` and sample new unit types.

        Args:
            batch: mapping with tensors ``'utype'`` (unit-type ids) and
                ``'r'`` (rewards). Both are squeezed along dim 1, so they are
                expected to have shape ``(B, 1)``; they must also support
                ``.numpy()`` (presumably CPU tensors -- TODO confirm callers).

        Returns:
            dict with ``'utype'``: ``(B, 1)`` int64 tensor of newly sampled
            unit-type ids, and ``'resource_scale'``: ``(B, 1)`` tensor (torch
            default float dtype) of the matching resource scales.
        """
        utype = batch['utype'].squeeze(1).numpy()
        reward = batch['r'].squeeze(1).numpy()

        new_utype = []
        resource_scale = []

        for ut, r in zip(utype, reward):
            ut = int(ut)
            # Negative rewards count as 0. Converting to a Python float keeps
            # the estimates (and hence the scales) plain floats, so the output
            # dtype does not depend on NumPy's scalar-promotion rules.
            r = max(float(r), 0.0)
            self.win_rate_stats[ut] *= self.win_rate_decay
            self.win_rate_stats[ut] += (1 - self.win_rate_decay) * r

            new_ut = self._sample_utype()
            new_utype.append(new_ut)
            self.recent_utype_count[new_ut] += 1

            scale = float(self.get_resource_scale(self.win_rate_stats[new_ut]))
            resource_scale.append(scale)
            self.recent_utype_scale[new_ut].append(scale)

        return {
            # Explicit dtype so an empty batch still yields an integer tensor.
            'utype': torch.tensor(new_utype, dtype=torch.long).unsqueeze(1),
            'resource_scale': torch.tensor(resource_scale).unsqueeze(1),
        }

    def log(self):
        """Print win rate, mean recent scale and recent sampling share per type."""
        print('unit type: win rate, scale, recent percent')
        count_sum = sum(self.recent_utype_count.values())
        for ut in self.utype_idx:
            name = self.unit_types[ut]
            win = self.win_rate_stats[ut]
            scales = self.recent_utype_scale.get(ut)
            # np.mean([]) warns and returns nan; report nan without the warning.
            scale = np.mean(scales) if scales else float('nan')
            percent = self.recent_utype_count[ut] / count_sum if count_sum else 0
            print('%s: %.2f, %.2f, %.2f' % (name, win, scale, percent))

    def reset(self):
        """Clear the recent statistics shown by ``log``.

        Win-rate estimates are kept. Scales are cleared together with counts
        so the per-type scale lists do not grow without bound.
        """
        self.recent_utype_count = defaultdict(int)
        self.recent_utype_scale = defaultdict(list)
