import logging
from typing import Dict

import numpy as np
import torch

# Type alias: a dictionary mapping string keys to torch tensors
TensorDict = Dict[str, torch.Tensor]


def to_one_hot(indices: torch.Tensor,
               num_classes: int,
               device=None) -> torch.Tensor:
    """
    Generates one-hot encoding with <num_classes> classes from <indices>
    :param indices: (... x 1) int64 tensor of class indices, each in [0, num_classes)
    :param num_classes: number of classes
    :param device: torch device of the result; defaults to the device of <indices>
    :return: (... x num_classes) tensor of the default dtype, 1 at each index and 0 elsewhere
    """
    # scatter_ needs the target and the index on the same device, so follow
    # <indices> unless the caller explicitly asks for something else
    if device is None:
        device = indices.device

    # Keep all leading dimensions, replace the last one by the class dimension
    shape = tuple(indices.shape[:-1]) + (num_classes, )
    oh = torch.zeros(shape, device=device)

    # scatter_ is the in-place version of scatter
    oh.scatter_(dim=-1, index=indices, value=1)

    return oh


def count_parameters(module: torch.nn.Module) -> int:
    """
    Counts the scalar entries of all parameters of a module
    :param module: torch module whose parameters() are counted
    :return: total number of elements summed over all parameters
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    return int(total)


def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict:
    """
    Builds a new dictionary with every tensor moved to <device>
    :param td: dictionary mapping keys to tensors
    :param device: target torch device
    :return: new dictionary with the same keys; each value is the result of .to(device)
    """
    moved = {}
    for key, tensor in td.items():
        moved[key] = tensor.to(device)
    return moved


def set_seeds(seed: int) -> None:
    """
    Seeds the NumPy and torch random number generators
    :param seed: seed value passed to both libraries
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)


def to_numpy(t: torch.Tensor) -> np.ndarray:
    """
    Converts a tensor into a NumPy array
    :param t: torch tensor (may require grad and may live on any device)
    :return: NumPy array obtained from the detached CPU copy of <t>
    """
    detached = t.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()


def init_device(device_str: str) -> torch.device:
    """
    Selects and initialises the torch device to run on
    :param device_str: 'cuda' selects CUDA; any other value selects the CPU
    :return: torch.device('cuda') or torch.device('cpu')
    :raises AssertionError: if 'cuda' is requested but no CUDA device is available
    """
    if device_str != 'cuda':
        logging.info('Using CPU')
        return torch.device('cpu')

    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O; the exception type is kept for existing callers
    if not torch.cuda.is_available():
        raise AssertionError('No CUDA device available!')

    # Initialise CUDA before querying it for the log message
    torch.cuda.init()
    logging.info(
        f'CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}'
    )
    return torch.device('cuda')


# Lookup table from dtype name to the corresponding torch floating-point dtype
dtype_dict = {
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
}


def set_default_dtype(dtype: str) -> None:
    """
    Sets the torch default floating-point dtype by name
    :param dtype: key of dtype_dict naming the dtype to use
    :raises KeyError: if <dtype> is not a key of dtype_dict
    """
    torch_dtype = dtype_dict[dtype]
    torch.set_default_dtype(torch_dtype)


def get_complex_default_dtype() -> torch.dtype:
    """
    Returns the complex dtype matching the current torch default dtype
    :return: torch.complex32, torch.complex64 or torch.complex128 for a default
             dtype of float16, float32 or float64 respectively
    :raises ValueError: if the default dtype is none of the three above
    """
    dtype = torch.get_default_dtype()

    if dtype == torch.float16:
        return torch.complex32
    if dtype == torch.float32:
        return torch.complex64
    if dtype == torch.float64:
        return torch.complex128

    raise ValueError(f'Unknown dtype {dtype}')