from typing import Dict, Optional, Union, Tuple

import torch
from mmcv.cnn import build_conv_layer

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


# Implements a simple single-convolution connector (1x1 kernel by default)
# for knowledge distillation.
@MODELS.register_module()
class SingleConvConnector(BaseConnector):
    """Connector that applies one convolution layer to a feature.

    The layer is created with ``build_conv_layer`` from ``conv_cfg``. When a
    sequence of features is given, only its first element is passed through
    the convolution; the remaining elements are forwarded unchanged.

    Args:
        conv_cfg (Dict): Config handed to ``build_conv_layer`` to select and
            build the convolution layer.
        in_channels (int): ``in_channels`` argument of the built layer.
        out_channels (int): ``out_channels`` argument of the built layer.
        kernel_size (int): ``kernel_size`` argument of the built layer.
            Defaults to 1.
        init_cfg (Dict, optional): Passed through to ``BaseConnector``.
            Defaults to None.
    """

    def __init__(
        self,
        conv_cfg: Dict,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        init_cfg: Optional[Dict] = None,
    ):
        super().__init__(init_cfg)
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )

    def forward_train(
        self, feature: Union[torch.Tensor, Tuple[torch.Tensor]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
        """Apply the convolution and return the result as a tuple.

        Args:
            feature: Either a single input for ``self.conv``, or a tuple/list
                whose first element is the input for ``self.conv``.

        Returns:
            A tuple. For a tuple/list input it holds the convolved first
            element followed by the remaining elements unchanged; otherwise
            it is a one-element tuple holding the convolved input.

        Raises:
            IndexError: If ``feature`` is an empty tuple or list.
        """
        if isinstance(feature, (tuple, list)):
            # Build a fresh tuple instead of assigning into ``feature`` so a
            # caller-owned list is never modified in place.
            return (self.conv(feature[0]), *feature[1:])

        return (self.conv(feature), )
