import torch
import copy
import time
from auto_LiRPA.utils import stop_criterion_batch_any, multi_spec_keep_func_all

class AutomaticInferenceCut:
    '''
    Automatic inference of cutting planes during the Branch and Bound procedure.

    Cuts are built from the split histories of verified BaB domains. Neurons
    judged unimportant are dropped from a cut (their bounds are restored to the
    initial ones), and a shortened cut is kept only if its domain is still
    verified after re-bounding.

    To use it, make sure the neuron influence score heuristic is used and set
    the config as:
    cut:
        enabled: True
        auto_inference_cuts: True
        drop_ratio: 0.5
        bab_cut: True
        cplex_cuts: True/False # both True and False are acceptable
    solver:
        min_batch_size_ratio: 0 # required for the score to be calculated correctly
    bab:
        pruning_in_iteration: False # required for the score to be calculated correctly
        interm_transfer: True # otherwise the mapping doesn't work
        sort_domain_interval: 1 # not necessary until the new data structure is finished
    '''
    def __init__(self, ret, max_cut_number, max_iter=20, drop_ratio=0.5, cplex_cuts=False):
        """Store the settings and a snapshot of the initial bound results.

        Args:
            ret: result dict of the initial bound computation. Its 'lA',
                'lower_bounds' and 'upper_bounds' dicts are re-keyed by the
                position of each layer name, so cuts can refer to a layer by
                an integer index.
            max_cut_number: no new cuts are inferred once this many are stored.
            max_iter: cuts are inferred only while the BaB iteration index is
                at most this value (unless usage is enforced).
            drop_ratio: quantile of the influence scores used as the threshold
                for keeping a neuron in a cut.
            cplex_cuts: whether CPLEX cuts are also in use; this changes how
                inferred cuts are merged into ``net.cutter.cuts``.
        """
        # A domain counts as verified when its final lower bound exceeds this.
        self.decision_threshold = 0
        self.max_cut_number = max_cut_number
        self.max_iter = max_iter
        self.cplex_cuts = cplex_cuts
        self.aicp_cuts = []  # every unique cut inferred so far
        self.drop_ratio = drop_ratio
        # When True, no bound-based score is mixed into the influence score.
        self.pure_influence_score = True
        # Filled by bound_diff_score_cal(); initialised here so the scoring
        # helpers also work when that method has not been called yet.
        self.bounds_score_value = None

        self.key_mapping_lA = {key: index for index, key in enumerate(ret['lA'].keys())}
        self.key_mapping_lb = {key: index for index, key in enumerate(ret['lower_bounds'].keys())}

        self.lA_init = {self.key_mapping_lA[key]: value for key, value in ret['lA'].items()}
        self.lb_init = {self.key_mapping_lb[key]: value for key, value in ret['lower_bounds'].items()}
        self.ub_init = {self.key_mapping_lb[key]: value for key, value in ret['upper_bounds'].items()}

    def update_cut(self, d, net, ret, enforce_usage, heuristic=None, iter_idx=None):
        """Infer new cuts from the domains verified in the latest BaB iteration.

        Args:
            d: the batch of domains that was just bounded (its 'history' and
                'lower_bounds' are read, plus everything pick_d() copies).
            net: the solver object. New cuts go to ``net.cutter.cuts``, and the
                rebuilt cut module is attached to ``net.net`` and its relus.
            ret: bound results for ``d``. The last entry of 'lower_bounds'
                decides which domains are verified.
            enforce_usage: if True, ignore the iteration limit and the
                "some but not all verified" condition, and add the unshortened
                cuts of all verified domains.
            heuristic: 'neuron_influence_score', 'random_drop' or None.
                'sparse_opt' raises NotImplementedError.
            iter_idx: current BaB iteration index; required unless
                ``enforce_usage`` is True.
        """
        print('======================Cut inference begins======================')
        start_cut_time = time.time()  # record the start time
        inference_time = None
        _, lbs_final = list(ret['lower_bounds'].items())[-1]
        v_idx = torch.where(lbs_final > self.decision_threshold)[0]
        print(f"Number of Verified Splits: {len(v_idx)} of {len(lbs_final)}")

        # enforce_usage is checked first so iter_idx may be None when usage is enforced.
        if enforce_usage or iter_idx <= self.max_iter:

            if heuristic == 'neuron_influence_score':
                self.neuron_influence_score_cal(d['history'], d['lower_bounds'], lbs_final)
                if iter_idx == 1:
                    print("Neuron influence score heuristic used.")
                    self.bound_diff_score_cal()

            elif heuristic == 'random_drop':
                print("Warning: Random drop heuristic used, performance may be bad.")

            elif heuristic == 'sparse_opt':
                raise NotImplementedError("Sparse Optimization Heuristic is not implemented yet.")

            else:
                print("Warning: No heuristic is used, performance may be bad.")

            if self.inference_condition(lbs_final) or enforce_usage:
                if len(self.aicp_cuts) < self.max_cut_number:
                    inference_time = time.time()  # record the preprocessing time

                    if enforce_usage:
                        tmp_cuts = self.inference_original_cut(d, v_idx)
                    else:
                        tmp_cuts = self.reinforce_cuts(d, net, ret, v_idx, heuristic)

                    add_cuts_time = time.time()  # record the inference time
                    unique_number_of_cuts = 0
                    redundant_number_of_cuts = 0
                    for cut in tmp_cuts:
                        if cut not in self.aicp_cuts:
                            self.aicp_cuts.append(cut)
                            if not self.cplex_cuts:
                                net.cutter.cuts.append(cut)
                            unique_number_of_cuts += 1
                        else:
                            redundant_number_of_cuts += 1

                    print(f"{len(tmp_cuts)} cuts inferenced, {unique_number_of_cuts} unique cuts added, {redundant_number_of_cuts} redundant.")

                    if self.cplex_cuts:
                        if net.cutter.cuts is not None:
                            net.cutter.cuts += self.aicp_cuts
                        else:
                            # Store a copy: aliasing self.aicp_cuts here would make
                            # the `+=` above extend self.aicp_cuts with itself on
                            # the next call, duplicating every stored cut.
                            net.cutter.cuts = list(self.aicp_cuts)

                    if unique_number_of_cuts > 0:
                        print("Cuts are synchronized to the solver.")
                        cut_module = net.cutter.construct_cut_module()
                        net.net.cut_module = cut_module
                        for m in net.net.relus:
                            m.cut_module = cut_module

                    cut_analysis_time = time.time()  # record the analysis time
                else:
                    print("Stop inferencing: Too much cuts.")
            else:
                print("No cut inferenced: All or none verified.")
        else:
            print("Stop inferencing: Max iteration reached.")

        if net.cutter.cuts:
            self.cut_analysis(net.cutter.cuts)
        else:
            self.cut_analysis(self.aicp_cuts)

        stop_cut_time = time.time()
        print('All time:', stop_cut_time - start_cut_time)
        if inference_time:
            print('Preprocessing time:', inference_time - start_cut_time,
                  'Inference time:', add_cuts_time - inference_time,
                  'Add cuts time:', cut_analysis_time - add_cuts_time,
                  'Cut analysis time:', stop_cut_time - cut_analysis_time)
        print('======================Cut inference ends========================')

    def cut_analysis(self, cuts, max_cut_length=12):
        """Print the number of cuts and a coarse histogram of their lengths.

        A cut's length is its total number of coefficients. Cuts are counted in
        three buckets: length <= int(max_cut_length / 4), length <=
        int(max_cut_length / 2), and longer.
        """
        less_than_ = 0
        equal_to_ = 0
        more_than_ = 0
        total_length = 0
        l = int(max_cut_length / 4)
        e = int(max_cut_length / 2)
        for cut in cuts:
            length = (len(cut['arelu_coeffs']) + len(cut['pre_coeffs'])
                      + len(cut['x_coeffs']) + len(cut['relu_coeffs']))
            total_length += length
            if length <= l:
                less_than_ += 1
            elif length <= e:
                equal_to_ += 1
            else:
                more_than_ += 1
        print(f"Total number of valid cuts: {len(cuts)}.")
        print(f"#cuts len<{l}: {less_than_}, #{l}<cuts len<{e}: {equal_to_}, #{e}<cuts len<{max_cut_length}: {more_than_}")

    def inference_condition(self, lbs_final):
        """Return True when some, but not all, domains are verified."""
        verified = lbs_final > self.decision_threshold
        return not verified.all() and verified.any()

    def pick_d(self, v_idx, d):
        """Return a deep copy of the domains of ``d`` selected by ``v_idx``.

        Per-domain lists are filtered element-wise. Tensors are indexed along
        their batch dimension (dim 2 for the alpha tensors). Tensor entries
        whose batch dimension cannot hold every index in ``v_idx`` are skipped.
        Keys not handled below are not copied.
        """
        # Smallest batch size that contains every selected index (0 if none are selected).
        required = int(max(v_idx)) + 1 if len(v_idx) > 0 else 0
        d_new = {}
        for key, value in d.items():
            if key in ['history', 'betas', 'intermediate_betas', 'split_history', 'depths']:
                d_new[key] = [value[i] for i in v_idx if i < len(value)]
            if key in ['lAs', 'lower_bounds', 'upper_bounds']:
                d_new[key] = {k: v[v_idx] for k, v in value.items() if v.size(0) >= required}
            if key == 'alphas':
                d_new[key] = {}
                for sub_key, sub_nested_dict in value.items():
                    d_new[key][sub_key] = {}
                    for tensor_key, tensor in sub_nested_dict.items():
                        if tensor.size(2) >= required:
                            d_new[key][sub_key][tensor_key] = tensor[:, :, v_idx, :]
            if key in ['cs', 'thresholds']:
                d_new[key] = value[v_idx]
        return copy.deepcopy(d_new)

    def build_mask(self, net, d_revise):
        """Build the split masks for the (revised) bounds in ``d_revise``.

        For every layer of ``d_revise['lower_bounds']`` listed in
        ``net.net.split_activations``, the masks returned by each activation's
        ``get_split_mask`` are flattened from dim 1, moved to ``net.x.device``
        and OR-ed together. A listed layer with no activations gets an all-ones
        mask. Layers not listed are omitted.

        Returns:
            dict mapping layer name to its mask.
        """
        def _to(x, non_blocking=True):
            return x.to(device=net.x.device, non_blocking=non_blocking)
        new_masks = {}
        for k in d_revise['lower_bounds']:
            if k not in net.net.split_activations:
                continue
            mask = None
            for activation, index in net.net.split_activations[k]:
                mask_ = _to(
                    activation.get_split_mask(
                        d_revise['lower_bounds'][k], d_revise['upper_bounds'][k], index
                    ).flatten(1).float())
                mask = mask_ if mask is None else torch.logical_or(mask, mask_)
            if mask is None:
                mask = torch.ones_like(d_revise['lower_bounds'][k]).flatten(1)
            new_masks[k] = mask
        return new_masks

    def reinforce_cuts(self, d, net, ret, v_idx, heuristic):
        """Shorten the cuts of the verified domains and keep the ones still verified.

        The domains in ``v_idx`` are copied with the dropped neurons' bounds
        restored (see inference_cut) and bounded again. A shortened cut is kept
        only if its domain is still verified and the cut is strictly shorter
        than the domain's original split history. If re-bounding yields a mix
        of verified and unverified domains, the verified ones are processed
        again recursively without a heuristic. This terminates because each
        level works on a strict subset of the previous domains.

        Returns:
            list of cut dicts (see generate_cut).
        """
        d_revise, tmp_cuts, original_length = self.inference_cut(d, ret, v_idx, heuristic)
        d_revise['mask'] = self.build_mask(net, d_revise)

        ret_revise = net.update_bounds(d_revise, fix_interm_bounds=True,
                                       stop_criterion_func=stop_criterion_batch_any(d_revise['thresholds']),
                                       multi_spec_keep_func=multi_spec_keep_func_all,
                                       beta_bias=None)

        _, lbs_new = list(ret_revise['lower_bounds'].items())[-1]
        v_idx_new = torch.where(lbs_new > self.decision_threshold)[0]
        print(f"Number of Verified Cuts: {len(v_idx_new)} of {len(v_idx)}")

        cuts = []
        for idx in v_idx_new.tolist():
            # tmp_cuts and original_length are both indexed by position in v_idx,
            # so the same index must be used for both.
            if len(tmp_cuts[idx]['arelu_coeffs']) < original_length[idx]:
                cuts.append(tmp_cuts[idx])

        if self.inference_condition(lbs_new):
            print('\nReinforce cuts')
            cuts += self.reinforce_cuts(d_revise, net, ret_revise, v_idx_new, None)
        return cuts

    def inference_original_cut(self, d, v_idx):
        """Build one unshortened cut from the full split history of each domain in ``v_idx``.

        Every split neuron contributes its split status as coefficient. The cut
        bias is (number of neurons with positive status) - 1.
        """
        print("Original cuts are inferenced.")
        original_cuts = []
        for j in range(len(v_idx)):
            arelu_decision = []
            arelu_coeffs = []
            bias = 0
            for key, (relu_idx, relu_status, _, _, _) in d['history'][v_idx[j]].items():
                key_int = self.key_mapping_lb[key]
                for i in range(len(relu_idx)):
                    arelu_decision.append([key_int, relu_idx[i].item()])
                    arelu_coeffs.append(relu_status[i].item())
                    bias += relu_status[i].clamp(min=0).item()
            original_cut = self.generate_cut(arelu_decision=arelu_decision, arelu_coeffs=arelu_coeffs, b=bias-1)
            original_cuts.append(original_cut)
        return original_cuts

    def _convert_history_from_list(self, history):
        """Convert the history variables into tensors if they are lists.

        It is because some legacy code creates history as lists.
        """
        if isinstance(history[0], torch.Tensor):
            return history

        return (torch.tensor(history[0], dtype=torch.long),
                torch.tensor(history[1]),
                torch.tensor(history[2]),
                torch.tensor(history[3]),
                torch.tensor(history[4]))

    def inference_cut(self, d, ret, v_idx, heuristic):
        """Build a (possibly shortened) cut for every domain in ``v_idx``.

        A split neuron of domain ``v_idx[j]`` is kept in the cut if either:
        its beta in ``ret['betas']`` is positive and the matching initial
        ``lA`` coefficient is non-positive; or the heuristic selects it.
        A dropped neuron has its bounds in the copied domain reset to the
        initial bounds. For each layer, the copied history is rewritten to the
        kept neurons (when at least one is kept), and the copied betas are
        reset to zeros of the kept length.

        Note: the history entries of the selected domains in ``d`` are
        converted to tensors in place.

        Returns:
            (d_revise, tmp_cuts, original_length): the revised copy of the
            selected domains, one cut per selected domain, and the number of
            split neurons each domain had before dropping.
        """
        d_revise = self.pick_d(v_idx, d)  # deep copy containing only the verified domains
        original_length = []
        tmp_cuts = []

        for j in range(len(v_idx)):
            arelu_decision = []
            arelu_coeffs = []
            bias = 0
            cut_length = 0
            criterion = None
            hist = d['history'][v_idx[j]]
            betas_j = ret['betas'][v_idx[j]]

            if heuristic == 'neuron_influence_score':
                criterion = self.influence_criterian_get(hist)

            for key in hist:
                hist[key] = self._convert_history_from_list(hist[key])
            for key, (relu_idx, relu_status, relu_bias, relu_score, depths) in hist.items():
                key_int = self.key_mapping_lb[key]
                cut_length += len(relu_idx)
                hist_index = []
                hist_split = []
                hist_bias = []
                hist_score = []
                hist_depths = []

                for i in range(len(relu_idx)):

                    if heuristic == 'random_drop':
                        condition = self.random_drop()
                    elif heuristic == 'neuron_influence_score':
                        condition = self.neuron_influence_score(key_int, relu_idx[i], relu_score[i], criterion)
                    elif heuristic == 'sparse_opt':
                        raise NotImplementedError("Sparse Optimization Heuristic is not implemented yet.")
                    else:
                        condition = False

                    if (betas_j[key][i] > 0 and self.lA_init[key_int][0][0].flatten()[relu_idx[i]] <= 0) or condition:
                        # (beta > 0 and upper bound used) or the heuristic keeps it:
                        # add the neuron to the cut.
                        arelu_decision.append([key_int, relu_idx[i].item()])
                        arelu_coeffs.append(relu_status[i].item())
                        bias += relu_status[i].clamp(min=0).item()
                        # record the neuron and split status
                        hist_index.append(relu_idx[i])
                        hist_split.append(relu_status[i])
                        hist_bias.append(relu_bias[i])
                        hist_score.append(relu_score[i])
                        hist_depths.append(depths[i])
                    else:
                        # Drop the neuron and restore its initial bounds in the copy.
                        d_revise['lower_bounds'][key][j].flatten()[relu_idx[i]] = self.lb_init[key_int][0].flatten()[relu_idx[i]]
                        d_revise['upper_bounds'][key][j].flatten()[relu_idx[i]] = self.ub_init[key_int][0].flatten()[relu_idx[i]]

                # NOTE(review): when every neuron of this layer is dropped, the copied
                # history keeps the full original entry while the betas become empty.
                if hist_split:
                    d_revise['history'][j][key] = (torch.tensor(hist_index), torch.tensor(hist_split),
                                                   torch.tensor(hist_bias), torch.tensor(hist_score),
                                                   torch.tensor(hist_depths))
                if d_revise['betas'][j] is not None:
                    d_revise['betas'][j][key] = torch.zeros_like(torch.tensor(hist_split))

            original_length.append(cut_length)
            tmp_cut = self.generate_cut(arelu_decision=arelu_decision, arelu_coeffs=arelu_coeffs, b=bias-1)
            tmp_cuts.append(tmp_cut)

        return d_revise, tmp_cuts, original_length

    def random_drop(self):
        """Return True or False with equal probability (keep/drop decision)."""
        from random import choice
        return choice([True, False])

    def neuron_influence_score(self, key_int, idx, relu_score_i, criterian):
        """Return True if neuron ``idx`` of layer ``key_int`` should stay in the cut.

        The score is lmbda * relu_score_i + (1 - lmbda) * bound_score, where
        bound_score comes from ``self.bounds_score_value`` (0 if unset).
        ``criterian`` is the (threshold, lmbda) pair from influence_criterian_get().
        """
        crtn, lmbda = criterian
        if self.bounds_score_value:
            bound_score = self.bounds_score_value[key_int][0].flatten()[idx]
        else:
            bound_score = 0
        score = lmbda * relu_score_i + (1 - lmbda) * bound_score
        return score >= crtn

    def influence_criterian_get(self, hist, lmbda=1):
        """Compute the keep threshold for one domain's split history.

        The combined scores of all split neurons in ``hist`` are reduced to
        their ``self.drop_ratio`` quantile (midpoint interpolation).

        Returns:
            (threshold, lmbda)
        """
        bounds_score_all = []
        relu_score_all = []
        for key, (relu_idx, _, _, relu_score, _) in hist.items():
            relu_score_all += relu_score
            if self.bounds_score_value:
                key_int = self.key_mapping_lb[key]
                for i in range(len(relu_idx)):
                    bounds_score_all.append(self.bounds_score_value[key_int][0].flatten()[relu_idx[i]])
        b_score = torch.tensor(bounds_score_all).flatten() if self.bounds_score_value else 0
        n_score = torch.tensor(relu_score_all).flatten()
        score = lmbda * n_score + (1 - lmbda) * b_score
        return score.quantile(self.drop_ratio, interpolation='midpoint'), lmbda

    def bound_diff_score_cal(self):
        """Compute the per-neuron bound-based scores (unless pure influence scores are used).

        For neurons with initial bounds l < 0 < u, the score is
        |lA * (u * -l) / (u - l)|; it is 0 for all other neurons. The last
        ``lA`` entry is skipped.
        """
        if self.pure_influence_score:
            self.bounds_score_value = None
        else:
            self.bounds_score_value = copy.deepcopy(self.lA_init)
            for key in list(self.lA_init.keys())[:-1]:
                for i in range(len(self.lA_init[key].flatten())):
                    if self.lb_init[key].flatten()[i] < 0 and self.ub_init[key].flatten()[i] > 0:
                        scr = (self.lA_init[key][0][0].flatten()[i] *
                            ((self.ub_init[key].flatten()[i] * -self.lb_init[key].flatten()[i]) /
                                (self.ub_init[key].flatten()[i] - self.lb_init[key].flatten()[i]))).abs()
                        self.bounds_score_value[key].flatten()[i] = scr
                    else:
                        self.bounds_score_value[key].flatten()[i] = 0

    def neuron_influence_score_cal(self, d_hist, d_lbs, lbs_final):
        """Fill in influence scores for the split neurons of each domain, in place.

        A neuron keeps its existing score if it is non-zero (tensor scores
        only). Otherwise it receives the improvement of the domain's final
        lower bound (``lbs_final`` minus the last entry of ``d_lbs``), plus
        1e-1 if the domain is now verified (> 0) or 1e-8 otherwise.
        """
        _, d_lbs_final = list(d_lbs.items())[-1]
        d_lbs_final = d_lbs_final.to('cpu')
        lbs_score = lbs_final - d_lbs_final

        for j in range(len(d_lbs_final)):
            for key, (relu_idx, relu_status, relu_bias, relu_score, depths) in d_hist[j].items():
                hist_score = []
                for i in range(len(relu_idx)):
                    if isinstance(relu_score, torch.Tensor) and relu_score[i] != 0:
                        hist_score.append(relu_score[i])
                    else:
                        if lbs_final[j] > 0:
                            hist_score.append(lbs_score[j] + 1e-1)
                        else:
                            hist_score.append(lbs_score[j] + 1e-8)
                if hist_score:
                    d_hist[j][key] = (relu_idx, relu_status, relu_bias, torch.tensor(hist_score).flatten(), depths)
                else:
                    d_hist[j][key] = (relu_idx, relu_status, relu_bias, relu_score, depths)

    def generate_cut(self, x_decision=None, x_coeffs=None, relu_decision=None, relu_coeffs=None,
                     arelu_decision=None, arelu_coeffs=None, pre_decision=None, pre_coeffs=None,
                     b=0, c=-1):
        """Return a cut dict in the format stored in ``net.cutter.cuts``.

        Omitted decision/coefficient lists default to a new empty list per
        call. Shared mutable defaults would leak appends between cuts.
        """
        def _fresh(value):
            return [] if value is None else value
        return {
            'x_decision': _fresh(x_decision),
            'x_coeffs': _fresh(x_coeffs),
            'relu_decision': _fresh(relu_decision),
            'relu_coeffs': _fresh(relu_coeffs),
            'arelu_decision': _fresh(arelu_decision),
            'arelu_coeffs': _fresh(arelu_coeffs),
            'pre_decision': _fresh(pre_decision),
            'pre_coeffs': _fresh(pre_coeffs),
            'bias': b,
            'c': c
        }
