from collections import defaultdict
from itertools import combinations
import numpy as np
from dataset.workloads import downward_closure

def all_k_way(domain, k):
    """Return every k-element combination of ``domain`` as a list of tuples.

    The combinations come from ``itertools.combinations``, so they appear in
    the order implied by iterating over ``domain``.
    """
    k_subsets = combinations(domain, k)
    return [*k_subsets]

class ResidualFrontier:
    """Track which residual queries are currently eligible to be measured.

    Residuals are represented as frozensets of attributes.  The frontier
    starts with only the empty residual as a candidate; each call to
    ``measure`` may promote larger residuals into the candidate list.
    """

    def __init__(self, attributes, max_depth, initial_measurements=None, remeasure=False):
        """Create the frontier and replay any initial measurements.

        Args:
            attributes: iterable of attribute names.
            max_depth: largest residual size (number of attributes) allowed.
            initial_measurements: optional iterable of frozensets that have
                already been measured.  They are replayed smallest-first so
                that each one is a candidate by the time it is measured.
                The argument itself is never modified.
            remeasure: if True, measured residuals stay in the candidate
                list and measuring one again is not an error.

        Raises:
            ValueError: propagated from ``measure`` if an initial
                measurement is invalid.
        """
        self.attributes = frozenset(attributes)
        self.max_depth = max_depth
        # counts[s] = number of measured residuals that are exactly one
        # attribute short of s.
        self.counts = defaultdict(int)
        self.measurements = []
        self.candidates = [frozenset()]
        self.remeasure = remeasure

        if initial_measurements is None:
            initial_measurements = ()
        # sorted() builds a new list, so the caller's sequence is left
        # untouched and any iterable (not just a list) is accepted.
        for m in sorted(initial_measurements, key=len):
            self.measure(m)

    def measure(self, m):
        '''
        Expand the frontier based on having measured the residual query m.

        Whether or not the measured residual is removed from the candidate
        list depends on the attribute self.remeasure.  Measuring a residual
        that was already measured (only allowed when self.remeasure is True)
        changes nothing.

        Raises:
            ValueError: if m is deeper than self.max_depth, is not currently
                a candidate, or was already measured while self.remeasure
                is False.
        '''
        if len(m) > self.max_depth:
            raise ValueError(f'the residual {m} is deeper than the max depth {self.max_depth}')
        if m not in self.candidates:
            raise ValueError(f'the residual {m} is not in the candidate set')

        if m in self.measurements:
            if not self.remeasure:
                raise ValueError(f'the residual {m} has already been measured, but remeasure is set false')
            return

        self.measurements.append(m)
        if not self.remeasure:
            self.candidates = [c for c in self.candidates if c != m]
        if len(m) < self.max_depth:
            for att in self.attributes - m:
                potential_add = m | frozenset([att])
                self.counts[potential_add] += 1
                # A superset becomes a candidate once every residual obtained
                # by dropping one of its attributes has been measured.
                if self.counts[potential_add] == len(potential_add):
                    self.candidates.append(potential_add)

    def get_candidates(self):
        """Return the current list of candidate residuals."""
        return self.candidates

    
def attrMulti(candidate, domain):
    """Return the product of ``(n - 1) / n`` over the attributes in ``candidate``.

    ``n`` is ``domain[col]`` for each attribute ``col``.  An empty candidate
    yields the empty product, 1.0.
    """
    ratios = []
    for attr in candidate:
        size = domain[attr]
        ratios.append((size - 1) / size)
    return np.prod(ratios)

def attrQuot(candidate, domain):
    """Return the product of ``1 / n**2`` over the attributes in ``candidate``.

    ``n`` is ``domain[col]`` for each attribute ``col``.  An empty candidate
    yields the empty product, 1.0.

    Each size is converted to float before exponentiation: NumPy integer
    scalars raise ValueError for negative integer powers, whereas Python
    ints already evaluate ``n ** -2`` in floating point, so results for
    Python ints are unchanged.
    """
    return np.prod([float(domain[col]) ** -2 for col in candidate])

def domainSize(candidate, domain):
    """Return the product of ``domain[col]`` over the attributes in ``candidate``.

    An empty candidate yields the empty product (1).
    """
    sizes = []
    for attr in candidate:
        sizes.append(domain[attr])
    return np.prod(sizes)

def attrSubMQ(candidate, sub, domain):
    """Return the product of one per-attribute factor over ``candidate``.

    With ``n = domain[col]``, the factor for attribute ``col`` is
    ``(n - 1) / n`` when ``col`` is in ``sub`` and ``1 / n**2`` otherwise.
    These are the values ``attrMulti`` and ``attrQuot`` give for a single
    attribute.  An empty candidate yields the empty product, 1.0.

    The factors are computed directly from ``domain[col]``.  Passing
    ``(col)`` to a helper that iterates its argument hands over the bare
    attribute name rather than a one-element tuple, which fails for integer
    names and for string names longer than one character.
    """
    factors = []
    for col in candidate:
        size = domain[col]
        if col in sub:
            factors.append((size - 1) / size)
        else:
            # float() keeps NumPy integer sizes from raising on the
            # negative power; Python ints give the same value either way.
            factors.append(float(size) ** -2)
    return np.prod(factors)

def varSum(tau, workloads, domain):
    """Sum ``domainSize(w) * attrSubMQ(w, tau)`` over workloads containing ``tau``.

    Only workloads ``w`` for which every attribute of ``tau`` is present
    contribute.  If no workload qualifies the sum is empty (0.0).
    """
    required = set(tau)
    terms = []
    for wkload in workloads:
        if not required.issubset(wkload):
            continue
        size = domainSize(wkload, domain)
        terms.append(size * attrSubMQ(wkload, tau, domain))
    return np.sum(terms)

def sigma(candidate, domain, rho):
    """Return ``sqrt(attrMulti(candidate, domain) / (2 * rho))``.

    The value is evaluated as ``((1 / (2 * rho)) * attrMulti(...)) ** 0.5``.
    """
    inverse_budget = 1 / (2 * rho)
    scaled = inverse_budget * attrMulti(candidate, domain)
    return scaled ** 0.5

def getOptimalSigmasCF(marginals, rho, domain):
    """Compute a closed-form value for every element of the downward closure.

    With ``c = 2 * rho`` and, for each element ``w`` of
    ``downward_closure(marginals)``::

        p[w] = attrMulti(w, domain)
        v[w] = varSum(w, marginals, domain)
        T    = (sum_w sqrt(v[w] * p[w])) ** 2 / c

    the returned dict maps ``w`` to ``((T * p[w]) / (c * v[w])) ** 0.25``.

    NOTE(review): presumably these are the optimal noise scales under a
    privacy budget ``rho`` -- confirm against the derivation.
    """
    c = 2 * rho
    closure = downward_closure(marginals)

    # p: per-element attrMulti coefficient.
    p = {}
    for subset in closure:
        p[subset] = attrMulti(subset, domain)

    # v: per-element weight accumulated over the marginals containing it.
    v = {}
    for subset in closure:
        v[subset] = varSum(subset, marginals, domain)

    # T: squared sum of sqrt(v * p), divided by c.
    roots = [(v[subset] * p[subset]) ** 0.5 for subset in closure]
    T = (np.sum(roots) ** 2) / c

    # Fourth root of (T * p) / (c * v) for each element.
    optSigmas = {}
    for subset in closure:
        ratio = (T * p[subset]) / (c * v[subset])
        optSigmas[subset] = ratio ** 0.25

    return optSigmas