"""
Library for reading adult.data or adult.test and extracting numerical features.
"""

import numpy

# Instances of Categorical, Continuous and Ignored are "column descriptions",
# which describe how to handle a single column of the CSV file. They can act as
# feature extractors using the "to_features" method, which takes the field's
# value (as a string) and turns it into 0 or more numerical feature values.
# Some columns are ignored, based on previous work [0] [1].
#
# [0] Incognito: Efficient Full-Domain K-Anonymity; https://www.cse.iitb.ac.in/infolab/Data/Courses/CS632/2017/2016/2014/2013/Papers/incognito.pdf on 2022-05-31T15:57:20Z.
# [1] l-Diversity: Privacy Beyond k-Anonymity; https://www.cs.uml.edu/~ge/pdf/papers_685-2-1/ldiversity-icde06.pdf on 2022-05-31T15:58:53Z.

class ColumnDescription:
    """
    Base class for column descriptions: objects that know how the values of
    one column of the Adult dataset are turned into a feature vector or into a
    single number. This base class only stores the column's name; to_features
    and to_number have to be provided by subclasses.
    """
    def __init__(self, name):
        """
        name: A short string identifying the column.
        """
        self._name = name

    def name(self):
        """
        Returns the name that was passed to the constructor.
        """
        return self._name

    def ignore(self):
        """
        Tells read_as_columns whether this column should be left out of its
        result. The default is False, i.e. the column is kept. For
        consistency, an override should return True exactly when to_features
        returns an empty sequence.
        """
        return False

    def to_features(self, field_value):
        """
        field_value is the column's value for one row, as a string taken
        directly from the dataset file. Returns a sequence of feature values.
        The length of that sequence mustn't depend on field_value: a given
        column always produces the same number of features.

        Not implemented in the base class.
        """
        raise NotImplementedError()

    def to_number(self, field_value):
        """
        field_value is the column's value for one row, as a string taken
        directly from the dataset file. Returns a single number representing
        that value.

        Not implemented in the base class.
        """
        raise NotImplementedError()

class Categorical(ColumnDescription):
    """
    Description of a column whose values come from a fixed set of class names.
    to_number maps a class name to its position in class_names (an integer
    from 0 to #categories-1); to_features produces the corresponding one-hot
    vector.
    """
    def __init__(self, class_names, **kwargs):
        """
        class_names: The sequence of allowed values; the position of a value
        in this sequence is the number it's converted to.
        kwargs: Passed on to ColumnDescription (i.e. name).
        """
        super().__init__(**kwargs)
        self._class_names = class_names
        self._num_classes = len(class_names)
        # Lookup table from class name to its position in class_names.
        self._class_to_index = dict(
            zip(class_names, range(self._num_classes)))

    def class_names(self):
        """
        Returns the class names exactly as given to the constructor.
        """
        return self._class_names

    def num_classes(self):
        """
        Returns the number of class names given to the constructor.
        """
        return self._num_classes

    def to_number(self, field_value):
        """
        Returns the index of field_value within the class names. Raises
        KeyError if field_value isn't one of them.
        """
        return self._class_to_index[field_value]

    def to_features(self, field_value):
        """
        Returns a list of num_classes() integers which is all zeros, except
        for a one at the index of field_value.
        """
        hot_index = self.to_number(field_value)
        one_hot = [0] * self._num_classes
        one_hot[hot_index] = 1
        return one_hot

class Continuous(ColumnDescription):
    """
    Description of a numerical column. The field is parsed as a float and
    used as-is, so the column contributes exactly one feature.
    """
    def to_number(self, field_value):
        """
        Parses field_value with float().
        """
        return float(field_value)

    def to_features(self, field_value):
        """
        Returns a 1-tuple holding the parsed number.
        """
        number = self.to_number(field_value)
        return (number,)

class Ignored(ColumnDescription):
    """
    Description of a column that isn't used: it yields no features and is
    flagged as ignored. to_number is deliberately left unimplemented.
    """
    def to_features(self, _field_value):
        """
        Always returns an empty tuple, whatever the field contains.
        """
        return tuple()

    def ignore(self):
        """
        Always True, consistent with to_features returning nothing.
        """
        return True

def read_as_columns(lines, columns_override = None):
    """
    Use like this:
        adult_by_column = read_as_columns(open("adult.data"))

    lines should be an iterator returning lines of text from either adult.data
    or adult.test (like a file object). Returns a dict mapping column names to
    numpy arrays of the same length (one entry per row of the dataset).
    Numerical columns are returned as-is, and categorical values are converted
    to integers, in both cases using ColumnDescription.to_number. Skips
    columns selected to be ignored (ColumnDescription.ignore).

    Every row must have exactly one field per column description; otherwise
    an AssertionError naming the offending row is raised, like in
    read_as_features_labels. If lines contains no rows at all, every
    non-ignored column is returned as an empty array.

    To get the column description corresponding to an entry in the returned
    dict, look up the entry with the same key in
    adult_column_descriptions_by_name.

    For testing, columns_override can be set to a list of ColumnDescription
    objects.

    >>> read_as_columns(
    ...     (("30, 12, Foo, USA"), ("31, 12, Bar, Canada")),
    ...     columns_override =
    ...         (Continuous(name = "age"),
    ...          Continuous(name = "years-education"),
    ...          Ignored(name = "something"),
    ...          Categorical(("Canada", "USA"), name = "country")))
    {'age': array([30., 31.]), 'years-education': array([12., 12.]), 'country': array([1, 0])}
    """
    column_descriptions = columns_override or adult_column_descriptions
    # Check each row individually. Transposing with zip(*rows) would silently
    # cut every row down to the length of the shortest one, so a row with
    # extra fields would go unnoticed.
    rows = []
    for fields in _read_dataset(lines):
        assert len(fields) == len(column_descriptions), \
            f"Wrong number of fields: {fields}"
        rows.append(fields)
    result = {}
    for index, column_description in enumerate(column_descriptions):
        if column_description.ignore():
            continue
        # Pick this column's field out of every row and convert it.
        result[column_description.name()] = numpy.array(
            tuple(column_description.to_number(row[index]) for row in rows))
    return result

def read_as_features_labels(lines, columns_override = None):
    """
    Use like this:
        training_set = read_as_features_labels(open("adult.data"))
        test_set     = read_as_features_labels(open("adult.test"))

    lines should be an iterator returning lines of text from either adult.data
    or adult.test (like a file object). Returns a list with one
    (features, label) pair per row. features is the list obtained by
    concatenating ColumnDescription.to_features over all columns but the last;
    label is ColumnDescription.to_number applied to the last column (0 or 1
    for the Adult dataset's income column).

    Every row must have exactly one field per column description; otherwise
    an AssertionError naming the offending row is raised.

    For testing, columns_override can be set to a list of ColumnDescription
    objects. If it's None (the default), a set of column descriptions which
    uses some of the Adult dataset's columns and ignores others will be used
    --- see adult_column_descriptions.

    >>> read_as_features_labels(
    ...     (("30, 12, Foo, USA"), ("31, 12, Bar, Canada")),
    ...     columns_override =
    ...         (Continuous(name = "age"),
    ...          Continuous(name = "years-education"),
    ...          Ignored(name = "something"),
    ...          Categorical(("Canada", "USA"), name = "country")))
    [([30.0, 12.0], 1), ([31.0, 12.0], 0)]
    """
    column_descriptions = columns_override or adult_column_descriptions
    # The last column is the label; everything before it describes features.
    feature_descriptions = column_descriptions[:-1]
    label_description = column_descriptions[-1]
    examples = []
    for fields in _read_dataset(lines):
        assert len(fields) == len(column_descriptions), \
            f"Wrong number of fields: {fields}"
        # Concatenate the features of each column, in column order.
        features = [
            feature
            for raw_value, description in zip(fields[:-1],
                                              feature_descriptions)
            for feature in description.to_features(raw_value)]
        examples.append((features, label_description.to_number(fields[-1])))
    return examples

#### Descriptions for columns of the Adult dataset

# One description per column of the dataset files, in file order. For the
# categorical columns, "?" is listed as a class of its own wherever it appears
# below, so rows containing it are converted like any other value.
adult_column_descriptions = (
    Continuous(name = "age"),
    Categorical(
        (
            "Private", "Self-emp-not-inc", "Self-emp-inc", "Federal-gov",
            "Local-gov", "State-gov", "Without-pay", "Never-worked", "?",
        ),
        name = "workclass",
    ),
    # NOTE(review): the dataset's own documentation appears to spell this
    # column "fnlwgt" --- confirm. The name is left as it is because it's also
    # a key of adult_column_descriptions_by_name.
    Ignored(name = "fnlwght"),
    Categorical(
        (
            "Bachelors", "Some-college", "11th", "HS-grad", "Prof-school",
            "Assoc-acdm", "Assoc-voc", "9th", "7th-8th", "12th", "Masters",
            "1st-4th", "10th", "Doctorate", "5th-6th", "Preschool",
        ),
        name = "education",
    ),
    Ignored(name = "education-num"),
    Categorical(
        (
            "Married-civ-spouse", "Divorced", "Never-married", "Separated",
            "Widowed", "Married-spouse-absent", "Married-AF-spouse",
        ),
        name = "marital-status",
    ),
    Categorical(
        (
            "Tech-support", "Craft-repair", "Other-service", "Sales",
            "Exec-managerial", "Prof-specialty", "Handlers-cleaners",
            "Machine-op-inspct", "Adm-clerical", "Farming-fishing",
            "Transport-moving", "Priv-house-serv", "Protective-serv",
            "Armed-Forces", "?",
        ),
        name = "occupation",
    ),
    Ignored(name = "relationship"),
    Categorical(
        (
            "White", "Asian-Pac-Islander", "Amer-Indian-Eskimo", "Other",
            "Black",
        ),
        name = "race",
    ),
    Categorical(("Female", "Male"), name = "sex"),
    Ignored(name = "capital-gain"),
    Ignored(name = "capital-loss"),
    Ignored(name = "hours-per-week"),
    Categorical(
        (
            "United-States", "Cambodia", "England", "Puerto-Rico", "Canada",
            "Germany", "Outlying-US(Guam-USVI-etc)", "India", "Japan",
            "Greece", "South", "China", "Cuba", "Iran", "Honduras",
            "Philippines", "Italy", "Poland", "Jamaica", "Vietnam", "Mexico",
            "Portugal", "Ireland", "France", "Dominican-Republic", "Laos",
            "Ecuador", "Taiwan", "Haiti", "Columbia", "Hungary", "Guatemala",
            "Nicaragua", "Scotland", "Thailand", "Yugoslavia", "El-Salvador",
            "Trinadad&Tobago", "Peru", "Hong", "Holand-Netherlands", "?",
        ),
        name = "native-country",
    ),
    # Last column: read_as_features_labels uses it as the label.
    Categorical(("<=50K", ">50K"), name = "income"),
)

# The same descriptions, keyed by column name.
adult_column_descriptions_by_name = {
    description.name(): description
    for description in adult_column_descriptions}

#### Helpers

def _read_dataset(lines):
    """
    lines should be an iterator over the lines of adult.names or adult.test.
    E.g. open("path/to/adult.data"), or just sys.stdin if the file is on
    standard input.

    Returns an iterable over the rows of the dataset. Each element of the
    iterator is a sequence of column values, as strings. Includes some special
    processing to deal with quirks in the Adult dataset formatting.
    """
    for line in lines:
        stripped_line = line.strip().rstrip(".")
        if stripped_line in ("", "|1x3 Cross validator"):
            continue
        # Within a row, the fields are comma-separated. Split them apart and
        # remove whitespace.
        yield tuple(x.strip() for x in stripped_line.split(","))

#### doctest main

if __name__ == "__main__":
    # When this file is executed as a script, run the ">>>" examples embedded
    # in the docstrings above as tests.
    import doctest
    doctest.testmod()
