from models import vision_transformer as vits
import torch.nn as nn

class StudentModel(nn.Module):
    """Compose a backbone and a projector into a single module.

    The forward pass runs the backbone on the input and then feeds the
    backbone's output through the projector.
    """

    def __init__(self, backbone, projector):
        """Store the two sub-components.

        Args:
            backbone: Callable invoked first in ``forward``; it must accept the
                keyword options that ``forward`` forwards to it.
            projector: Callable applied to the backbone's output.
        """
        super().__init__()
        self.backbone = backbone
        self.projector = projector

    def forward(self, images, store_qkv=None, energy_align=False, align_strength=0.01, verbose=False, logger=None):
        """Run the backbone on ``images`` and project its output.

        Every option besides ``images`` is handed to the backbone unchanged as
        a keyword argument; this wrapper does not interpret any of them.

        Args:
            images: Input passed positionally to the backbone.
            store_qkv: Forwarded to the backbone as ``store_qkv``.
            energy_align: Forwarded to the backbone as ``energy_align``.
            align_strength: Forwarded to the backbone as ``align_strength``.
            verbose: Forwarded to the backbone as ``verbose``.
            logger: Forwarded to the backbone as ``logger``.

        Returns:
            Whatever the projector returns for the backbone's output.
        """
        features = self.backbone(
            images,
            store_qkv=store_qkv,
            energy_align=energy_align,
            align_strength=align_strength,
            verbose=verbose,
            logger=logger,
        )
        return self.projector(features)
    
def create_model(device='cuda', model_path=None, store_qkv_layers=None):
    """Build the ``vit_base`` backbone, optionally load weights, and move it to ``device``.

    Args:
        device: Target passed to the backbone's ``.to(...)`` call.
        model_path: When truthy, passed to the backbone's
            ``load_param_finetune`` method before the device move. When
            ``None`` or empty, no parameters are loaded.
        store_qkv_layers: Value forwarded to the ``vit_base`` constructor as
            ``store_qkv_layers``. ``None`` (the default) means a new empty
            list is created for this call.

    Returns:
        The result of calling ``.to(device)`` on the constructed backbone.
    """
    # A fresh list per call: a literal ``[]`` default would be one shared
    # object handed to every backbone built without an explicit argument.
    if store_qkv_layers is None:
        store_qkv_layers = []

    backbone = vits.__dict__['vit_base'](store_qkv_layers=store_qkv_layers)
    if model_path:
        backbone.load_param_finetune(model_path)

    return backbone.to(device)