from torch.profiler import profile, record_function, ProfilerActivity
import torch
from fixation_prediction.tasks import Ecoset100PreTrainedNoisyRetinaBlurS2500WRandomScalesWClickmeFixationPredictionXResNet2x18
import sys
from importlib import import_module

def get_task_class_from_str(s):
    """Resolve a dotted path such as ``package.module.ClassName`` to an object.

    Everything before the last dot is imported as a module; the part after
    the last dot is looked up as an attribute of that module.

    Args:
        s: Dotted path whose final component is the attribute to fetch.

    Returns:
        The attribute named by the final component of ``s``.

    Raises:
        ValueError: If ``s`` has no module part (no dot, or only a leading dot).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_path, _, attr_name = s.rpartition('.')
    if not module_path:
        # import_module('') would also raise ValueError, but with the opaque
        # message "Empty module name"; say what format is expected instead.
        raise ValueError(
            f"expected a dotted path like 'package.module.ClassName', got {s!r}"
        )
    module = import_module(module_path)
    return getattr(module, attr_name)

# Profile one forward pass of the model built by the task class whose dotted
# path is given as the first command-line argument.
if len(sys.argv) < 2:
    # Exit with a usage message rather than a bare IndexError on sys.argv[1].
    sys.exit(f"usage: {sys.argv[0]} <dotted.path.to.TaskClass>")

# Resolve the task class first so a bad path fails before any CUDA work.
task = get_task_class_from_str(sys.argv[1])

# Random input of shape (5, 3, 224, 224), moved to the GPU.
inputs = torch.randn(5, 3, 224, 224).cuda()

# The task instance provides the model parameters; `p.cls` is called with the
# parameters object itself to build the model, which is then moved to the GPU.
p = task().get_model_params()
m = p.cls(p).cuda()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
    m(inputs)

# Alternative report: group by the top 5 stack frames, ranked by self CUDA time.
# print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=5))
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
