from lpmm.functional import vectorwise_dequant, vectorwise_quant, _max_reduce_except_dim, \
    atom_quant, atom_dequant, prepare_quant_boundary, quant_scaling, create_dynamic_map, nonlinear_quant
from lpmm.utils import symmetric_atom_dequantize, symmetric_atom_quantize, sqnr, test_quant, test_dequant, relerr, abserr
import torch
from copy import deepcopy



shape = (1024, 4096)
signed = True

# Draw the Gaussian values first and the signs second, so the random number
# generator is consumed in a fixed order.
x = torch.randn(shape)
# randint yields values in {0, 1}; 2 * v - 1 maps them to {-1, +1}.
random_tensor = 2 * torch.randint(low=0, high=2, size=shape) - 1
sign = random_tensor
if signed:
    x = sign * x
x = x.to('cuda')
# Optional toggle for trying a different memory layout:
# x = x.to(memory_format=torch.channels_last)
# print(x)

# Parallel lists: entry i of each list belongs to the same quantization scheme.
quant_types = ['log-tensor', 'log-dim0', 'log-dim1', 'log-sm3', 'sm3', 'dim10']
quantizers = [test_quant] * 4 + [vectorwise_quant] * 2
dequantizers = [test_dequant] * 4 + [vectorwise_dequant] * 2

# for quant_type, quant, dequant in zip(quant_types, quantizers, dequantizers):
#     config = {
#         'b': 3 if quant is test_quant and signed else 4,
#         'quant_type': quant_type,
#         'round_type': 'nearest',
#         'gp_sz': None,
#         'shape': x.shape,

#         'transform': None,
#         'truncated_mode': None,
#         'truncated_factor': 0.95,
#         'truncated_global_factor': 0.9,
#         'signed': signed,

#         'sm3_history': None,
#         'res_scale': 1
#     }
#     qx, md = quant(x, **config)
#     # print(qx, md)
#     x_hat = dequant(qx, **md)
#     # print(x_hat)
#     print(f"relative error: {relerr(x, x_hat)}")
#     print(f"SQNR: {sqnr(x, x_hat)}")

def test_fn(qx):
    """Sweep scale/quant/round combinations and print the lowest relative error.

    Every combination is quantized with ``vectorwise_quant`` and restored with
    ``vectorwise_dequant``; the relative error against the input is printed
    per combination, followed by the best combination found.

    NOTE: the tensor being quantized is the module-level ``x``. The ``qx``
    argument is never read.

    Args:
        qx: unused; kept so existing call sites keep working.
    """
    qmap = {(4, signed): create_dynamic_map(True, 3)}
    print(qmap)

    scale_choices = ('tensor', 'dim01', 'dim10', 'sm3')
    quant_choices = ('linear', 'nonlinear')
    round_choices = ('sr', 'down', 'up', 'nearest', 'sr1')

    best_relerr, best_cfg = 1e9, None
    for scale_type in scale_choices:
        for quant_type in quant_choices:
            for round_type in round_choices:
                cfg = (scale_type, quant_type, round_type)
                # Build the keyword set from scratch for each combination.
                kwargs = dict(
                    gp_sz=2048,
                    scale_type=scale_type,
                    quant_type=quant_type,
                    round_type=round_type,
                    b=4,
                    signed=signed,
                    shape=x.shape,
                    qmap=qmap,
                )
                packed, md = vectorwise_quant(x, **kwargs)
                restored = vectorwise_dequant(packed, **md)
                rel = relerr(x, restored)
                print(cfg, rel)

                if rel < best_relerr:
                    best_relerr, best_cfg = rel, cfg

    print(f"best: {best_cfg, best_relerr}")

    # Scratch code kept for inspecting the scaling step on its own:
    # print(qx)
    # for scale_type in ['tensor', 'group', 'dim0', 'dim1', 'dim01', 'dim10', 'sm3']:
    #     sqx, md = quant_scaling(qx, scale_type, **config)
    #     print(sqx, md['max1'])

def test_real_fn(x):
    """Print the quantize/dequantize relative error of ``x`` for several scalings.

    For each scaling name, ``x`` is quantized with ``vectorwise_quant`` using
    4-bit nonlinear quantization and restored with ``vectorwise_dequant``.
    The relative error between ``x`` and the restored tensor is printed.

    Scaling names of the form ``group<N>`` select ``scale_type='group'`` with
    ``gp_sz=N``. Any other name is used directly as ``scale_type``.

    Args:
        x: tensor to round-trip. A clone is passed to the quantizer on every
            iteration.
    """
    print("test_real_fn run.")
    b = 4
    base_config = dict(
        gp_sz=256,
        scale_type='rank1-group',
        quant_type='nonlinear',
        round_type='real-nearest',
        b=b,
        signed=signed,
        shape=x.shape,
        qmap=create_dynamic_map(signed, b - 1).to('cuda'),
    )

    group_prefix = 'group'
    for scale in ['rank1-group', 'sm3', 'group2048', 'group128']:
        # Start from a fresh copy so settings from one scaling cannot leak
        # into the next iteration.
        qconfig = deepcopy(base_config)
        # Kept as in the original call; the scaling mode itself is selected
        # through 'scale_type' below.
        qconfig['scale'] = scale
        if scale.startswith(group_prefix):
            qconfig['gp_sz'] = int(scale[len(group_prefix):])
            qconfig['scale_type'] = 'group'
        else:
            # Previously only the 'scale' key was written here, so 'sm3'
            # silently reused the initial 'rank1-group' scale_type.
            qconfig['scale_type'] = scale

        qx, md = vectorwise_quant(x.clone(), **qconfig)
        # Dequantization receives the config merged with the metadata
        # returned by the quantizer.
        qconfig.update(md)
        x_hat = vectorwise_dequant(qx, **qconfig)
        # print(qx, x_hat)
        rel = relerr(x, x_hat)
        print(f"b: {b}, rel: {rel}")

def test_memory_format_fn(x):
    """Round-trip ``x`` through 8-bit group quantization and print layout info.

    Prints the stride, shape and layout of the input, then the relative and
    absolute error of the restored tensor, then the stride, shape and layout
    of the restored tensor.

    NOTE(review): dequantization here receives only the metadata returned by
    ``vectorwise_quant``, whereas ``test_real_fn`` merges that metadata into
    its config first. Confirm which calling convention the library expects.

    Args:
        x: tensor to quantize and restore.
    """
    print(f"test_memory_format_fn run.")
    b = 8
    quant_type = 'nonlinear'
    # The map is nested: first keyed by (bits, signed), then by quant type.
    qmap = {
        (b, signed): {
            quant_type: create_dynamic_map(signed, b - 1).to('cuda'),
        },
    }
    config = {
        'gp_sz': 2048,
        'scale_type': 'group',
        'quant_type': quant_type,
        'round_type': 'real-nearest',
        'b': b,
        'signed': signed,
        'qmap': qmap,
        'fp16_scale': False,
        'shape': x.shape,
    }

    print(x.stride(), x.shape, x.layout)
    packed, md = vectorwise_quant(x, **config)
    restored = vectorwise_dequant(packed, **md)
    rel_err = relerr(x, restored)
    abs_err = abserr(x, restored)
    print(f"rel: {rel_err}, abs: {abs_err}")
    print(restored.stride(), restored.shape, restored.layout)

if __name__ == "__main__":
    # Only the real-valued round-trip check runs by default.
    # Swap in the line below to run the full configuration sweep instead.
    # test_fn(x)
    test_real_fn(x)