from __future__ import print_function

import torch
import torch.nn as nn
import math
import numpy as np


# Lookup table: backbone name -> feature width fed into the classifier head.
# (Presumably the size of each backbone's pooled output; confirm against the
# encoder definitions.)
model_output_dim_dict = dict(
    resnet18=512,
    resnet34=512,
    resnet50=2048,
    wide_resnet50_2=2048,
    resnet101=2048,
)


class AttrLinearClassifier(nn.Module):
    """Classification head applied to flattened input features.

    Args:
        name: key into ``model_output_dim_dict`` that selects the input
            feature width of the head.
        head: ``'linear'`` for a single ``nn.Linear`` layer, or ``'mlp'`` for
            ``Linear -> ReLU -> Linear`` with a hidden layer as wide as the
            input.
        num_classes: number of output units produced by the head.

    Raises:
        KeyError: if ``name`` is not a key of ``model_output_dim_dict``.
        NotImplementedError: if ``head`` is neither ``'linear'`` nor ``'mlp'``.

    Note:
        The ``'linear'`` head maps to ``num_classes`` outputs, consistent
        with the ``'mlp'`` head. Weights saved from a variant whose linear
        head had ``dim_in`` outputs will not load into this layer.
    """
    def __init__(self, name='resnet18',  head='mlp', num_classes=1):
        super(AttrLinearClassifier, self).__init__()
        dim_in = model_output_dim_dict[name]
        if head == 'linear':
            # Project directly to the class logits so both head types
            # honour ``num_classes``.
            self.fc = nn.Linear(dim_in, num_classes)
        elif head == 'mlp':
            self.fc = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, num_classes)
            )
        else:
            # Fail at construction time rather than leaving ``self.fc``
            # undefined and crashing later inside ``forward``.
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, features):
        """Flatten every dimension after the first and apply the head.

        Args:
            features: tensor whose dimensions after the first flatten to the
                width selected by ``name``.

        Returns:
            The output of ``self.fc`` on the flattened features.
        """
        flat = torch.flatten(features, 1)
        return self.fc(flat)
