import io
import logging
import os
import tarfile
import urllib.request
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Sequence

import numpy as np
import pandas as pd

# Readability aliases for numpy arrays. The trailing comment on each alias gives
# its intended shape; numpy itself does not check it.
Vector = np.ndarray     # shape [4,]
Positions = np.ndarray  # shape [..., 4]
Forces = np.ndarray     # shape [..., 3]


@dataclass
class Configuration:
    """A set of four-component rows plus optional per-configuration data.

    Attributes:
        positions: array of shape [..., 4]; as built by ``config_from_particles``
            it is [n_rows, 4], one row per object.
        signal: optional extra value for the whole configuration.
        attributes: optional values aligned with the rows of ``positions``
            (``config_from_particles`` uses 1 for particle rows, 0 for beam rows).
    """
    positions: Positions  # shape [..., 4]
    # NOTE(review): config_from_particles stores a single element of its input
    # row here (a numpy scalar), not an array — confirm the intended type.
    signal: Optional[np.ndarray] = None
    attributes: Optional[np.ndarray] = None


Configurations = List[Configuration]


def random_train_valid_split(items: Sequence, valid_fraction: float, seed: int) -> Tuple[List, List]:
    """Shuffle *items* deterministically and cut them into (train, valid) lists.

    The validation part holds ``int(valid_fraction * len(items))`` items (rounded
    down); the rest go to the training part. The same *seed* always yields the
    same split.

    Raises:
        AssertionError: if *valid_fraction* is not strictly between 0 and 1.
    """
    assert 0.0 < valid_fraction < 1.0

    n_items = len(items)
    n_valid = int(valid_fraction * n_items)
    cut = n_items - n_valid

    order = list(range(n_items))
    np.random.default_rng(seed).shuffle(order)

    train = [items[idx] for idx in order[:cut]]
    valid = [items[idx] for idx in order[cut:]]
    return train, valid


def download_url(url: str, save_path: str) -> None:
    """Download *url* and store the response body at *save_path*.

    The body is streamed in chunks (not read into memory at once) into
    ``save_path + '.part'`` and moved onto *save_path* only after the transfer
    has finished. A failed or interrupted download therefore never leaves a
    truncated file at *save_path*, which a caller that skips downloading when the
    file already exists would otherwise mistake for a complete one.

    Raises:
        OSError (including urllib.error.URLError): if the URL cannot be opened
            or the file cannot be written.
    """
    import contextlib
    import shutil

    tmp_path = os.fspath(save_path) + '.part'
    with urllib.request.urlopen(url) as download_file:
        try:
            with open(tmp_path, 'wb') as out_file:
                shutil.copyfileobj(download_file, out_file)
            os.replace(tmp_path, save_path)
        finally:
            # No-op after a successful replace; removes the partial file otherwise.
            with contextlib.suppress(FileNotFoundError):
                os.remove(tmp_path)


def fetch_archive(path: str, url: str, force_download: bool = False) -> None:
    """Ensure a local copy of *url* exists at *path*.

    Downloads when *path* does not exist yet, or unconditionally when
    *force_download* is true; otherwise the existing file is left untouched.
    """
    # The original condition (`not exists and not force_download`) meant that
    # force_download=True never downloaded, even when the file was missing.
    if force_download or not os.path.exists(path):
        logging.info('Downloading %s to %s', url, path)
        download_url(url=url, save_path=path)
    else:
        logging.info('File %s exists', path)


def config_from_particles(particles: np.ndarray, top_n_objs: int = 64) -> Configuration:
    """Build a Configuration from one flat row of particle data.

    The row is read as a run of 4-component particle slots (zero-padded),
    followed by 6 trailing entries, the last of which is stored as ``signal``.
    At most ``top_n_objs`` slots are read. Slots whose four components are all
    zero are treated as padding and dropped. Two fixed beam vectors,
    ``[2, 0, 0, 1]`` and ``[2, 0, 0, -1]``, are appended after the particles.

    Args:
        particles: one event row (assumed 1-D — TODO confirm with callers).
        top_n_objs: maximum number of particle slots to read (non-negative).

    Returns:
        Configuration whose ``positions`` has shape [n_particles + 2, 4], whose
        ``signal`` is the row's last element, and whose ``attributes`` holds 1.0
        for each particle row and 0.0 for the two beam rows.
    """
    padded_positions = particles[:-6]  # the trailing 6 entries are not particle slots
    n_slots = min(top_n_objs, padded_positions.shape[0] // 4)
    slots = padded_positions[:n_slots * 4].reshape(n_slots, 4)
    # Drop padding per slot, not per element: a real particle can have a single
    # zero component, and an element-wise `!= 0` filter would remove it and
    # misalign (or break) the reshape into 4-vectors.
    positions_particles = slots[np.any(slots != 0, axis=1)]
    num_particles = positions_particles.shape[0]
    signal = particles[-1]
    beam_positions = np.array([[2, 0, 0, 1], [2, 0, 0, -1]])
    positions = np.concatenate((positions_particles, beam_positions))
    attributes = np.concatenate((np.ones(num_particles), np.zeros(2)))
    return Configuration(positions=positions, signal=signal, attributes=attributes)

def unpack_configs_from_hdf5(path: str, dataset: str = 'table', top_n_objs: int = 64) -> Configurations:
    """Load a table from an HDF5 file and convert each row into a Configuration.

    Args:
        path: path to the HDF5 file.
        dataset: key of the object in the file, passed to ``pd.read_hdf``.
        top_n_objs: forwarded to ``config_from_particles`` for every row.

    Returns:
        A list with one Configuration per table row, in row order. (The return
        annotation previously said ``Dict[str, Configurations]``, but a list has
        always been returned.)
    """
    frame = pd.read_hdf(path, key=dataset)
    return [config_from_particles(row, top_n_objs) for row in frame.to_numpy()]

def minkowski_norm(vector: np.ndarray) -> np.ndarray:
    """Return v0**2 - (v1**2 + v2**2 + v3**2) along the last axis.

    This is the Minkowski inner product of each vector with itself, using the
    (+, -, -, -) signature with component 0 carrying the positive sign.

    Args:
        vector: array (or array-like) of shape [..., 4]; any leading batch
            dimensions are allowed.

    Returns:
        The values with the last axis reduced, i.e. shape ``vector.shape[:-1]``.
    """
    squared = np.square(np.asarray(vector))
    # Subtract the spatial part directly. The equivalent 2*v0**2 - sum(v**2)
    # form adds an extra rounding step for floating-point inputs.
    return squared[..., 0] - squared[..., 1:].sum(axis=-1)