# in_channels =256
# out_channels = 128

# def output_size(H_in,stride,padding,kernel_size):
#     return (H_in-1)*stride - 2*padding + 1*(kernel_size-1) + 0 + 1
# H_in = 6 
# stride = 2
# padding = 0
# kernel_size = 4

# H_out1 = output_size(12,2,0,4)
# print(H_out1)
# H_out2 = output_size(H_out1,2,0,4)
# print(H_out2)
# H_out3 = output_size(H_out2,2,0,5)
# print(H_out3)
# H_out4 = output_size(H_out3,2,0,4)
# print(H_out4)
import torch
import torch.nn as nn

class VaeCNNEncoder(nn.Module):
    """Convolutional image encoder producing a ``latent_size``-dimensional vector.

    Four unpadded ``Conv2d(kernel_size=4, stride=2)`` + ReLU stages take the
    input from ``input_channel`` to 256 channels; the feature map is then
    flattened and projected by a single linear layer (``linear_mu``).

    Args:
        latent_size: Output dimension of ``linear_mu``.
        input_channel: Number of channels of the input images.
        input_size: Spatial size of the input images, either one int for a
            square image or a ``(height, width)`` pair. The default of 224
            gives the same linear layer as before (in_features = 256*12*12
            = 36864), so existing checkpoints keep loading.

    Raises:
        ValueError: If ``input_size`` is too small to pass through all four
            convolutions.
    """

    def __init__(self, latent_size = 512, input_channel = 3, input_size = 224):
        super().__init__()
        self.latent_size = latent_size
        # Attribute names / layer order define the state_dict keys
        # (main.0, main.2, main.4, main.6, linear_mu) — keep them stable.
        self.main = nn.Sequential(
            nn.Conv2d(input_channel, 32, 4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU()
        )
        # Derive the flattened feature size from the conv stack instead of
        # hard-coding it for one input resolution.
        self.linear_mu = nn.Linear(self._flat_features(input_size), latent_size)

    def _flat_features(self, input_size):
        """Return C*H*W of ``self.main``'s output for the given input size.

        Assumes every Conv2d in ``self.main`` uses padding 0 and dilation 1,
        which holds for the layers built in ``__init__``.
        """
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        channels = None
        for layer in self.main:
            if not isinstance(layer, nn.Conv2d):
                continue
            (k_h, k_w), (s_h, s_w) = layer.kernel_size, layer.stride
            if height < k_h or width < k_w:
                raise ValueError(
                    f"input_size {input_size!r} is too small for the "
                    f"VaeCNNEncoder convolution stack"
                )
            height = (height - k_h) // s_h + 1
            width = (width - k_w) // s_w + 1
            channels = layer.out_channels
        return channels * height * width

    def forward(self, x):
        """Encode a batch of images into ``(N, latent_size)`` vectors.

        Args:
            x: Batched image tensor ``(N, input_channel, H, W)`` whose ``(H, W)``
               matches the ``input_size`` given at construction (otherwise
               ``linear_mu`` gets the wrong number of features). It is
               divided by 255 first — presumably raw 0–255 pixel values;
               confirm with callers.
        """
        x = self.main(x / 255.0)
        # flatten(1) keeps the batch dim and, unlike view, also works on
        # non-contiguous feature maps (e.g. channels_last).
        x = x.flatten(1)
        return self.linear_mu(x)

# Example: load a pretrained encoder checkpoint and encode a random image.
# Guarded so that importing this module (e.g. to reuse VaeCNNEncoder) does not
# try to read a machine-specific checkpoint file.
if __name__ == "__main__":
    import sys

    latent_size = 32
    # Checkpoint path may be given as the first CLI argument; otherwise fall
    # back to the original hard-coded location.
    encoder_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/andykim0723/LUSR/checkpoints/encoder_main_ithor_cnn.pt"
    )
    model = VaeCNNEncoder(latent_size=latent_size)

    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    weights = torch.load(encoder_path, map_location=torch.device('cpu'))
    # Drop checkpoint entries the encoder does not have (presumably other
    # sub-networks saved alongside it — confirm). load_state_dict stays
    # strict, so a missing encoder key still raises.
    model_keys = set(model.state_dict())
    for key in [k for k in weights if k not in model_keys]:
        del weights[key]
    model.load_state_dict(weights)
    model.eval()

    rand = torch.rand((1, 3, 224, 224))
    # Call the module rather than .forward() so registered hooks run;
    # gradients are not needed for this demo.
    with torch.no_grad():
        output = model(rand)
    print(output.shape)