from torch import nn

from .embedding import Embedding
from .combiner import Combiner


class TextEncoder(nn.Module):
    """Encode token ids by chaining an ``Embedding`` layer into a ``Combiner``.

    The two sub-modules are registered as ``self.embedding`` and
    ``self.combiner`` (in that order). Those attribute names determine the
    keys of this module's ``state_dict``.
    """

    def __init__(self, config):
        """Build the encoder's sub-modules.

        Args:
            config: object exposing ``config.embedding`` (handed to
                ``Embedding``) and ``config.combiner`` (handed to
                ``Combiner``).
        """
        super().__init__()

        # Registration order matters for parameters()/state_dict ordering,
        # so the embedding is created and attached before the combiner.
        self.embedding = Embedding(config.embedding)
        self.combiner = Combiner(config.combiner)

    def forward(self, input_ids, word_mask, use_dropout=True):
        """Embed ``input_ids`` and combine the result under ``word_mask``.

        Args:
            input_ids: token ids for the embedding layer. Presumably a
                (batch, seq_len) integer tensor; confirm against ``Embedding``.
            word_mask: mask passed to the combiner alongside the embeddings.
                Presumably marks valid (non-padding) positions; verify in
                ``Combiner``.
            use_dropout: flag passed positionally to the embedding layer.
                Its effect is defined by ``Embedding``, which is not shown here.

        Returns:
            Whatever ``self.combiner`` returns for these embeddings and mask.
        """
        return self.combiner(self.embedding(input_ids, use_dropout), word_mask)
