import copy

import parameterfree
import torch
import torch.nn as nn
import torch.nn.functional as F

class MnistNet(nn.Module):
    """MNIST-style CNN that carries its own COCOB optimizer.

    ``forward`` returns per-class log-probabilities and ``predict`` returns
    probabilities. The ``*learn`` methods each perform one step of
    ``self.optim`` and return the scalar loss as a Python float.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 * 12 * 12: a 1x28x28 input shrinks to 24x24 after the two
        # unpadded 3x3 convs and to 12x12 after the 2x2 max-pool.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.optim = parameterfree.COCOB(self.parameters())

    def clone(self):
        """Return an independent copy of the network and its optimizer state.

        The optimizer state dict is deep-copied before it is loaded. Recent
        PyTorch versions only shallow-copy inside ``Optimizer.load_state_dict``
        and reuse state tensors that already have the right device/dtype.
        Without the copy, the clone and the original could therefore share
        (and mutate) the same optimizer state tensors.
        """
        other = MnistNet()
        other.load_state_dict(self.state_dict())
        other.optim = parameterfree.COCOB(other.parameters())
        other.optim.load_state_dict(copy.deepcopy(self.optim.state_dict()))
        return other

    def forward(self, x):
        """Return per-class log-probabilities, shape (batch, 10)."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        # softmax, then take log
        output = F.log_softmax(x, dim=1)
        return output

    def predict(self, x):
        """Return class probabilities for ``x``.

        Switches the module to eval mode (disabling dropout) and leaves it
        there; the ``*learn`` methods switch back to train mode. Inference
        runs under ``torch.no_grad`` so no autograd graph is built and the
        result does not require grad.
        """
        self.eval()
        with torch.no_grad():
            return torch.exp(self(x))

    @staticmethod
    def _bandit_loss(chosen_logprobs, r, normalizer):
        """Binary cross-entropy between selected log-probabilities and rewards.

        Computes ``-(sum r*log p + sum (1-r)*log(1-p)) / normalizer`` for
        1-D ``chosen_logprobs`` and ``r``. ``r`` is converted to the dtype and
        device of ``chosen_logprobs`` because ``torch.dot`` requires both
        operands to have the same dtype.
        """
        r = torch.as_tensor(r, dtype=chosen_logprobs.dtype, device=chosen_logprobs.device)
        # Clamp before exponentiating so p <= exp(-1e-3) < 1 and log1p(-p)
        # stays finite. Note that clamp passes zero gradient outside the range.
        chosen_probs = torch.exp(torch.clamp(chosen_logprobs, min=-18, max=-1e-3))
        return -(
            torch.dot(chosen_logprobs, r)
            + torch.dot(torch.log1p(-chosen_probs), 1 - r)
        ) / normalizer

    def bandit_learn(self, x, a, r):
        """Take one optimizer step on bandit feedback and return the loss.

        Only the chosen action ``a[i]`` of each example is scored, with a
        binary cross-entropy on its probability against reward ``r[i]``. The
        sum is averaged over the batch.

        Args:
            x: input batch accepted by ``forward``.
            a: chosen action index per example (length = batch size).
            r: reward per example (length = batch size); presumably in
               [0, 1] — TODO confirm. Integer or float64 rewards are accepted.

        Returns:
            The loss as a Python float.
        """
        self.train()
        self.optim.zero_grad()
        logprobs = self(x)
        batch = logprobs.shape[0]
        chosen_logprobs = logprobs[range(batch), a]
        loss = self._bandit_loss(chosen_logprobs, r, batch)
        loss.backward()
        self.optim.step()
        return loss.item()

    def masked_bandit_learn(self, x, observed, r):
        """Take one optimizer step on the observed entries only.

        Args:
            x: input batch accepted by ``forward``.
            observed: mask tensor; entries > 0 select which log-probabilities
                (and the matching entries of ``r``) contribute. Presumably
                shaped like the (batch, 10) log-probabilities — TODO confirm.
            r: reward tensor indexed by the same mask as ``observed``.

        Returns:
            The loss (normalised by batch size, not by the number of observed
            entries) as a Python float, or ``None`` when nothing is observed,
            in which case no forward pass or optimizer step happens.
        """
        mask = observed > 0
        # Guard on the same mask used for selection so that an all-non-positive
        # ``observed`` skips the step instead of stepping on an empty loss.
        if not mask.any():
            return None
        self.train()
        self.optim.zero_grad()
        logprobs = self(x)
        loss = self._bandit_loss(
            logprobs[mask].flatten(), r[mask].flatten(), logprobs.shape[0]
        )
        loss.backward()
        self.optim.step()
        return loss.item()

    def learn(self, x, y):
        """Take one supervised optimizer step with NLL loss on labels ``y``."""
        self.train()
        self.optim.zero_grad()
        output = self(x)
        loss = F.nll_loss(output, y)
        loss.backward()
        self.optim.step()
        return loss.item()

class Classifier(nn.Module):
    """Small two-conv CNN with its own COCOB optimizer and a bandit update.

    ``forward`` returns per-class log-probabilities; ``predict`` returns
    probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 32 * 7 * 7 matches a 28x28 input halved twice by the 2x2 pooling.
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.optim = parameterfree.COCOB(self.parameters())

    def forward(self, x):
        """Return per-class log-probabilities, shape (batch, 10)."""
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def predict(self, x):
        """Return class probabilities for ``x``.

        Leaves the module in eval mode. Inference runs under
        ``torch.no_grad`` so no autograd graph is built and the result does
        not require grad.
        """
        self.eval()
        with torch.no_grad():
            return torch.exp(self(x))

    @staticmethod
    def _bandit_loss(chosen_logprobs, r, normalizer):
        """Binary cross-entropy between selected log-probabilities and rewards.

        Computes ``-(sum r*log p + sum (1-r)*log(1-p)) / normalizer``. ``r`` is
        converted to the dtype/device of ``chosen_logprobs`` because
        ``torch.dot`` requires matching dtypes.
        """
        r = torch.as_tensor(r, dtype=chosen_logprobs.dtype, device=chosen_logprobs.device)
        # Clamp so p < 1 and log1p(-p) stays finite.
        chosen_probs = torch.exp(torch.clamp(chosen_logprobs, min=-18, max=-1e-3))
        return -(
            torch.dot(chosen_logprobs, r)
            + torch.dot(torch.log1p(-chosen_probs), 1 - r)
        ) / normalizer

    def bandit_learn(self, x, a, r):
        """Take one optimizer step on bandit feedback and return the loss.

        Args:
            x: input batch accepted by ``forward``.
            a: chosen action index per example (length = batch size).
            r: reward per example (length = batch size); presumably in
               [0, 1] — TODO confirm. Integer or float64 rewards are accepted.

        Returns:
            The batch-averaged binary cross-entropy of the chosen actions,
            as a Python float.
        """
        self.train()
        self.optim.zero_grad()
        logprobs = self(x)
        batch = logprobs.shape[0]
        # [bs]
        chosen_logprobs = logprobs[range(batch), a]
        loss = self._bandit_loss(chosen_logprobs, r, batch)
        loss.backward()
        self.optim.step()
        return loss.item()

class Offline_Cls(nn.Module):
    """Two-conv CNN classifier whose forward pass returns probabilities.

    No optimizer is attached; training is left to the caller.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation and state_dict keys stay identical.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map an image batch to per-class probabilities (rows sum to 1).

        The 32 * 7 * 7 flatten size implies a 28x28 input (halved twice by
        pooling) — presumably single-channel MNIST-sized images.
        """
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        hidden = F.relu(self.fc1(features.view(-1, 32 * 7 * 7)))
        return F.softmax(self.fc2(hidden), dim=1)

class Predictor(nn.Module):
    """Two-conv CNN over a pair of images, returning class probabilities.

    ``x`` and ``y`` are concatenated along the channel dimension, and
    ``conv1`` expects 2 input channels in total — presumably one channel
    each (TODO confirm against callers).
    """

    def __init__(self):
        super().__init__()
        # Same creation order as before: keeps seeded init and state_dict keys.
        self.conv1 = nn.Conv2d(2, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x, y):
        """Return per-class probabilities (rows sum to 1) for the pair (x, y).

        The 32 * 7 * 7 flatten size implies 28x28 spatial inputs.
        """
        features = torch.cat((x, y), dim=1)
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        hidden = F.relu(self.fc1(features.view(-1, 32 * 7 * 7)))
        return F.softmax(self.fc2(hidden), dim=1)
