from typing import Optional, Tuple, List
import subprocess

import numpy as np
from tensorflow.keras import Model
from sacred import Ingredient

import rinokeras as rk
from rinokeras.models.transformer import TransformerInputEmbedding, TransformerEncoder


# Sacred ingredient named 'transformer'. It is used below both to declare the
# default hyperparameters (via ``.config``) and to wrap
# ``Transformer.__init__`` (via ``.capture``).
transformer_hparams = Ingredient('transformer')


@transformer_hparams.config
def configure_transformer():
    """Declare the default hyperparameters of the ``transformer`` ingredient.

    This is a sacred config scope: the local variables assigned below are the
    config entries, so their names must stay in sync with the keyword
    parameters of ``Transformer.__init__`` (wrapped with
    ``transformer_hparams.capture``). ``d_filter`` is derived from ``d_model``
    and must therefore be assigned after it. The ``noqa: F841`` markers
    silence the "assigned but never used" lint for these locals.
    """
    n_layers = 12  # noqa: F841
    n_heads = 8  # noqa: F841
    d_model = 512  # noqa: F841
    d_filter = 4 * d_model  # noqa: F841
    dropout = 0.1  # noqa: F841
    layer_dropout = 0.  # noqa: F841
    kernel_regularizer = None  # noqa: F841


class Transformer(Model):
    """Keras model that encodes a discrete sequence with a rinokeras Transformer.

    The model embeds the input symbols with ``TransformerInputEmbedding`` and
    feeds them through a ``TransformerEncoder``. Constructor arguments that are
    not passed explicitly are expected to be filled in from the ``transformer``
    sacred ingredient, since ``__init__`` is wrapped with
    ``transformer_hparams.capture``.
    """

    @transformer_hparams.capture
    def __init__(self,
                 n_symbols: int,
                 n_layers: int = 12,
                 n_heads: int = 8,
                 d_model: int = 512,
                 d_filter: int = 2048,
                 dropout: Optional[float] = 0.1,
                 layer_dropout: Optional[float] = None,
                 kernel_regularizer: Optional[str] = None) -> None:
        """Build the input embedding and the encoder stack.

        Args:
            n_symbols: Size of the discrete input vocabulary, forwarded to the
                input embedding.
            n_layers: Number of encoder layers.
            n_heads: Number of attention heads, forwarded to the encoder.
            d_model: Hidden size of the model.
            d_filter: Filter size forwarded to the encoder.
            dropout: Dropout value forwarded to both the input embedding and
                the encoder.
            layer_dropout: Layer-dropout value forwarded to the encoder.
            kernel_regularizer: Stored on the instance only; it is not
                forwarded to any layer created here.
        """
        print("Creating Transformer with {} layers, {} hidden size, {} filter size".format(n_layers, d_model, d_filter))
        super().__init__()
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_filter = d_filter
        self.kernel_regularizer = kernel_regularizer

        input_embedding = TransformerInputEmbedding(
            d_model, discrete=True, n_symbols=n_symbols, dropout=dropout,
            concat_position_encoding=True, reproject_position_encoding=True)

        self.encoder = TransformerEncoder(
            input_embedding, n_layers, n_heads, d_model, d_filter, dropout, layer_dropout)

    def call(self, inputs):
        """Encode ``inputs['sequence']`` and attach the result to ``inputs``.

        Args:
            inputs: Mapping with at least the following entries:
                'sequence': tf.Tensor[int32] - Amino acid sequence,
                    a padded tensor with shape [batch_size, MAX_PROTEIN_LENGTH]
                'protein_length': tf.Tensor[int32] - Length of each protein in
                    the sequence, a tensor with shape [batch_size]

        Returns:
            The same ``inputs`` object, modified in place, with an added
            'encoder_output' entry: tf.Tensor[float32] - embedding of each
            amino acid, a tensor with shape
            [batch_size, MAX_PROTEIN_LENGTH, d_model]
        """
        sequence = inputs['sequence']
        protein_length = inputs['protein_length']

        # The mask is built from the per-example lengths and passed to the
        # encoder together with the padded sequence.
        attention_mask = rk.utils.convert_to_attention_mask(sequence, protein_length)

        encoder_output = self.encoder(sequence, mask=attention_mask)
        inputs['encoder_output'] = encoder_output
        return inputs

    @property
    def boundaries(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return bucket boundaries and per-bucket sizes scaled to GPU memory.

        Every access runs ``nvidia-smi`` and prints the computed sizes.

        Returns:
            A ``(bounds_array, sizes_array)`` pair of NumPy arrays.
            ``bounds_array`` holds the boundary values from the table below.
            ``sizes_array`` is int32, has one more entry than ``bounds_array``
            and contains no value below 1. Presumably these are
            sequence-length bucket boundaries and matching batch sizes -
            confirm against the caller.

        Raises:
            Whatever ``subprocess.check_output`` raises when ``nvidia-smi``
            is missing or fails, and ``IndexError``/``ValueError`` when its
            output does not contain the expected memory tokens.
        """
        nvidia_smi = subprocess.check_output('nvidia-smi')
        mib_tokens = [word for word in nvidia_smi.decode().split() if 'MiB' in word]
        # The second 'MiB' token is used - presumably the total memory of the
        # first GPU (the first token being the used memory); verify against
        # the nvidia-smi output format in use.
        memsize = int(mib_tokens[1][:-3]) // 1000  # number of gigabytes on gpu

        # (boundary, base size) pairs.
        boundaries = [
            (100, 4),
            (200, 3),
            (300, 2),
            (400, 1.5),
            (500, 1),
            (600, 0.9),
            (700, 0.9),
            (800, 0.8),
            (900, 0.65),
            (1000, 0.6),
            (1100, 0.5),
            (1200, 0.5),
            (1300, 0.4),
            (1400, 0.3),
            (1500, 0.3),
            (1600, 0.2),
            (1700, 0.2),
            (1800, 0.1),
            (2000, 0.1)]

        bounds_array = np.array([bound for bound, _ in boundaries])
        # One extra trailing size (0, later clamped to 1) follows the sizes
        # from the table, so there is one size more than there are boundaries.
        sizes_array = np.array([size for _, size in boundaries] + [0])

        # Scale by GPU memory and inversely by depth relative to 12 layers,
        # with a 0.8 factor. The multiply is done out of place so it does not
        # depend on the table's dtype (an in-place ``*=`` with a float would
        # fail on an integer array).
        sizes_array = sizes_array * ((memsize * 12 / self.n_layers) * 0.8)

        # Truncate to integers and make sure no size drops below 1.
        sizes_array = np.asarray(sizes_array, np.int32)
        sizes_array[sizes_array <= 0] = 1

        print(sizes_array)

        return bounds_array, sizes_array
