from io import StringIO
import math
from fractions import Fraction
import random
import numpy as np

from io import StringIO
from pysmt.smtlib.parser import SmtLibParser, SmtLibScript, SmtLibCommand
from pysmt.exceptions import PysmtSyntaxError, PysmtTypeError
from pysmt.shortcuts import Symbol, Int, Real, NotEquals, Equals, Plus, Minus, Times, Div, ToReal, GT, to_smtlib
from pysmt.typing import INT, REAL, BV1

from MathGym.utils import gpt, simplify, solve, utils

# When True, mcmc_mutate / enum_mutate / check_state print intermediate
# solver queries, results and rejected states to stdout.
DEBUG = False

class mutater:
    """Generate a variant of an SMT-LIB problem by fuzzing its numeric equalities.

    ``compile`` rewrites every Int/Real equality assertion with a fresh
    auxiliary variable ``auxN`` (combined with a constant side via +/-, or
    with a non-constant right-hand side via +, -, * or /), starting at a value
    that leaves the equality unchanged. ``mcmc_mutate`` / ``enum_mutate`` then
    ask ``solve.solve`` for another model that keeps the recorded sign of each
    variable. Finally the aux variables are either replaced by their model
    values (``rm_aux=True``) or kept as ``varN`` pinned by an assert
    (``rm_aux=False``).
    """

    # Multiples of pi tried by `convert` for trig rewrites. They are kept as
    # exact Fractions so that e.g. 1/6 is emitted as (/ 1 6), not as the
    # numerator/denominator of its binary floating-point approximation.
    _PI_FRACTIONS = tuple(Fraction(num, den) for num, den in (
        (1, 6), (1, 4), (1, 3), (1, 2), (2, 3), (3, 4), (5, 6),
        (-1, 6), (-1, 4), (-1, 3), (-1, 2), (-2, 3), (-3, 4), (-5, 6),
    ))

    # SMT-LIB constraint template for each variable state recorded in self.vars.
    _STATE_TEMPLATES = {
        'GT': "(assert (> %s 0))",
        'GE': "(assert (>= %s 0))",
        'LT': "(assert (< %s 0))",
        'False': "(assert (= %s false))",
        'True': "(assert (= %s true))",
    }

    def __init__(self, tactic, rm_aux=True, real_aux=True):
        """
        tactic: how to search for a new solution: 'mcmc' (random local
            search, more efficient) or 'enum' (enumerate models, more effective)
        rm_aux: True to substitute aux variables by their values in the
            result, False to keep them as declared varN variables
        real_aux: whether Real-typed aux variables are allowed (otherwise Int
            aux variables wrapped in to_real are used for Real terms)
        """
        self.id_counter = 0                  # suffix N of the next auxN variable
        self.vars = {}                       # name -> {'name', 'type', ['value'], ['state']}
        self.aux_vars = []                   # pysmt symbols of the created aux variables
        self.mutate_script = SmtLibScript()  # rewritten problem being built by compile
        self.rm_aux = rm_aux
        self.real_aux = real_aux
        self.tactic = tactic

    def mutate(self, statement):
        """Run the strategy selected by ``self.tactic``; return ``(ok, statement)``.

        Raises ValueError for a tactic other than 'mcmc' or 'enum'.
        """
        if self.tactic == 'mcmc':
            return self.mcmc_mutate(statement)
        if self.tactic == 'enum':
            return self.enum_mutate(statement)
        raise ValueError(f"unknown tactic {self.tactic!r}; expected 'mcmc' or 'enum'")

    def compile(self, statement):
        """Parse `statement` into ``self.mutate_script``, fuzzing numeric equalities.

        Side effects:
        - sets ``self.statement``;
        - sets ``self.get_value`` to the first non-comment line containing
          ``get-value``;
        - registers declared non-function variables in ``self.vars``;
        - drops check-sat/get-value commands, which ``subst``/``add`` re-append.

        Returns False if no get-value line is available or pysmt reports a
        syntax/type error while parsing, True otherwise.
        """
        self.statement = statement
        for line in statement.split('\n'):
            line = line.strip()
            if "get-value" in line and not line.startswith(';'):
                self.get_value = line
                break
        if not hasattr(self, 'get_value'):  # get-value could be simplified away in some cases
            return False
        smt_parser = SmtLibParser()
        try:
            # The command generator parses lazily, so errors surface while
            # iterating: the loop itself must be inside the try.
            for cmd in smt_parser.get_command_generator(StringIO(self.statement)):
                self._compile_command(cmd)
        except (PysmtSyntaxError, PysmtTypeError):
            return False
        return True

    def _compile_command(self, cmd):
        """Append one parsed command (fuzzed if applicable) to ``self.mutate_script``."""
        if cmd.name in ("declare-fun", "declare-const"):
            symbol = cmd.args[0]  # cmd.args[0] is the declared symbol fnode
            name, stype = symbol.symbol_name(), symbol.symbol_type()
            if not stype.is_function_type():
                self.vars[name] = {'name': name, 'type': stype}
            if cmd.name == "declare-fun":
                self.mutate_script.add_command(cmd)
            else:  # declare-const is re-emitted as declare-fun
                self.mutate_script.add("declare-fun", cmd.args)
        elif cmd.name == "assert":
            self.mutate_script.add_command(self._fuzz_assert(cmd))
        elif cmd.name in ("check-sat", "get-value"):
            return  # re-appended by subst/add
        else:
            self.mutate_script.add_command(cmd)

    def _fuzz_assert(self, cmd):
        """Return `cmd`, or a fuzzed copy when it asserts an Int/Real equality."""
        fnode = cmd.args[0]
        if not fnode.is_equals():  # skip the non-equal node
            return cmd
        lhs, rhs = fnode.arg(0), fnode.arg(1)
        ftype = lhs.get_type()
        if not (ftype.is_int_type() or ftype.is_real_type()):
            return cmd  # only numeric equalities can be fuzzed
        if lhs.is_constant() and rhs.is_constant():
            return cmd
        if lhs.is_constant():
            lhs = self.constant_fuzzing(lhs)
        elif rhs.is_constant():
            rhs = self.constant_fuzzing(rhs)
        else:
            rhs = self.operator_fuzzing(rhs)
        return SmtLibCommand(name="assert", args=(Equals(lhs, rhs),))

    def save_state(self, sols):
        """Record the value and sign state of known variables from a solver model.

        `sols` is the list of ``"name := value"`` strings returned by
        ``solve.solve``. For non-bool values the state is 'GT' (> 0), 'GE'
        (== 0) or 'LT' (< 0). Names not in ``self.vars`` are ignored (the
        solver may add extra aux vars). Int/Real variables still without a
        value default to 0 / 0.0. Returns ``self.vars``.
        """
        for sol in sols:
            key, svalue = sol.split(' := ')
            val = utils.execute(svalue)
            if key not in self.vars:
                continue
            entry = self.vars[key]
            if not isinstance(val, bool):  # ignore bool
                if val > 0:
                    entry['state'] = "GT"
                elif val >= 0:
                    entry['state'] = "GE"
                elif val < 0:
                    entry['state'] = "LT"
            vtype = entry['type']
            if vtype.is_int_type():
                entry['value'] = int(val)
            elif vtype.is_real_type():
                try:
                    entry['value'] = Fraction(svalue)  # keep rationals exact
                except ValueError:
                    entry['value'] = val
            elif vtype.is_bool_type():
                entry['value'] = val

        for entry in self.vars.values():
            if 'value' not in entry:
                if entry['type'].is_int_type():
                    entry['value'] = 0
                elif entry['type'].is_real_type():
                    entry['value'] = 0.0
        return self.vars

    def check_state(self, sols):
        """Return True iff every known variable in `sols` keeps its recorded sign.

        Variables without a recorded state are not checked.
        """
        violated = {
            "GT": lambda v: v <= 0,
            "GE": lambda v: v < 0,
            "LT": lambda v: v >= 0,
        }
        for sol in sols:
            key, svalue = sol.split(' := ')
            val = utils.execute(svalue)
            if key not in self.vars:
                continue  # z3 may add more aux vars
            state = self.vars[key].get('state')
            if state in violated and violated[state](val):
                if DEBUG: print('incorrect state', self.vars[key], state, val)
                return False
        return True

    def convert(self, value, ttype=None):
        """Rewrite numeric `value` as an equivalent SMT-LIB expression of kind `ttype`.

        Supported kinds are 'log', 'sqrt', 'power', 'sin'/'cos' (which may also
        yield tan), 'gcd' and 'lcm'. Bases and functions are picked with
        np.random. When no exact rewrite applies, or `value` is 0, `value` is
        returned unchanged.
        """
        def near_int(x):
            # within 1e-5 of its truncation toward zero
            return abs(x - int(x)) < 1e-5

        if value == 0:
            return value
        if ttype == "log":
            if value > 0 and near_int(value):
                exponent = int(value)
                # A Python int base gives exact big-int powers (no int64 overflow).
                base = int(np.random.choice([2, 3, 5], 1)[0])
                return "(log %s %s)" % (base, base ** exponent)
        elif ttype == "sqrt":
            if value > 0 and near_int(value):  # sqrt is non-negative
                root = int(value)
                return "(sqrt %s)" % (root ** 2)
        elif ttype == "power":
            if value > 0:  # math.log is undefined for value <= 0
                for base in np.random.permutation([2, 3, 5, math.e]):
                    log_value = math.log(value, base)
                    exponent = round(log_value)
                    if abs(log_value - exponent) < 1e-5:
                        return "(power %s %s)" % (base, exponent)
        elif ttype == "sin" or ttype == "cos":
            if -1 <= value <= 1:
                inverse = {"sin": math.asin, "cos": math.acos, "tan": math.atan}
                for f in np.random.permutation(["sin", "cos", "tan"]):
                    turns = inverse[str(f)](value) / math.pi
                    for frac in self._PI_FRACTIONS:
                        if abs(turns - frac) < 1e-5:
                            t = "(/ %s %s)" % (frac.numerator, frac.denominator)
                            return "(%s (* %s %s))" % (f, t, "pi")
        elif ttype == "gcd":
            if value > 0 and near_int(value):
                g = int(value)
                x = utils.prime_gen([0, 100]) * g
                y = utils.prime_gen([0, 100]) * g
                if math.gcd(x, y) == g:
                    return "(gcd %s %s)" % (x, y)
        elif ttype == "lcm":
            if value > 0 and near_int(value):
                m = int(value)
                x = utils.divisor_gen(m)
                y = m // x
                if math.lcm(x, y) == m:
                    return "(lcm %s %s)" % (x, y)
        return value

    def subst(self, statement):
        """Inline aux values into `statement` and re-append check-sat/get-value.

        Declarations of ``auxN`` are dropped. Each line using ``auxN`` (N
        counted in order of appearance) gets the variable replaced by its
        current value. If the line already contains log/power/sin/cos/lcm/gcd,
        the value may be rewritten in that form by ``convert``.
        """
        aux_id = 0
        lines = []
        for line in statement.split('\n'):
            aux_name = f"aux{aux_id}"
            if aux_name in line:
                if "declare-fun" in line:
                    continue
                value = self.vars[aux_name]['value']
                candidates = ['log', 'power', 'sin', 'cos', 'lcm', 'gcd']
                random.shuffle(candidates)
                ttype = next((t for t in candidates if t in line), None)
                if ttype:
                    value = self.convert(value, ttype)
                lines.append(line.replace(aux_name, str(value)))
                aux_id += 1
            elif "get-value" in line:  # skip the temp get-value command
                continue
            else:
                lines.append(line)
        return '\n'.join(lines) + "(check-sat)\n" + self.get_value

    def add(self, statement):
        """Keep aux variables as ``varN`` pinned to their values; re-append check-sat/get-value.

        Each ``auxN`` is renamed to ``varN``. Its declaration is emitted once,
        and each line using it is preceded by ``(assert (= varN value))``.
        """
        aux_id = 0
        lines = []
        for line in statement.split('\n'):
            aux_name = f"aux{aux_id}"
            if aux_name in line:
                var_name = f"var{aux_id}"
                line = line.replace(aux_name, var_name)
                if "declare-fun" not in line:
                    value = self.vars[aux_name]['value']
                    lines.append(f"(assert (= {var_name} {value}))")
                    aux_id += 1
                lines.append(line)
            elif "get-value" in line:  # skip the temp get-value command
                continue
            else:
                lines.append(line)
        return '\n'.join(lines) + "(check-sat)\n" + self.get_value

    def _serialize_script(self):
        """Serialize ``self.mutate_script`` (non-daggified) to a string."""
        buf_out = StringIO()
        self.mutate_script.serialize(buf_out, daggify=False)
        return buf_out.getvalue()

    def _aux_pin_asserts(self):
        """Return asserts fixing every aux variable to its current value, one per line."""
        return "\n".join(
            [f"(assert (= {a} {self.vars[str(a)]['value']}))" for a in self.aux_vars])

    def _state_constraints(self, var_table):
        """Return SMT-LIB asserts preserving each variable's recorded state."""
        return [self._STATE_TEMPLATES[entry['state']] % str(entry['name'])
                for entry in var_table.values()
                if entry.get('state', "") in self._STATE_TEMPLATES]

    def _render(self, base_statement):
        """Build the final statement from the current aux values, honoring ``self.rm_aux``."""
        if self.rm_aux == True:
            return self.subst(base_statement)
        return self.add(base_statement)

    def mcmc_mutate(self, statement, stop_steps=10):
        """Random local search for a new model, trying at most `stop_steps` steps.

        Each step perturbs one randomly chosen variable and pins it to the
        new value. The solver must find a model that also keeps the recorded
        sign of every variable. Returns ``(ok, new_statement)``; on failure
        the input `statement` is returned unchanged.
        """
        ok = self.compile(statement)
        if ok == False:
            return False, statement  # check & stop
        base_statement = self._serialize_script()
        auxaug_statement = base_statement + self._aux_pin_asserts() \
            + "\n(check-sat)\n(get-model)"  # aux augmented statement
        ok, sols = solve.solve(auxaug_statement)
        if ok == False or (len(sols) == 1 and sols[0] == ""):  # may no vars to be solved
            return False, statement  # check & stop
        var_table = self.save_state(sols)
        key_list = list(var_table.keys())
        if not key_list:
            return False, statement  # nothing to perturb
        aug_statement = base_statement + "\n".join(self._state_constraints(var_table))
        for _ in range(stop_steps):
            # random permutation: the first key is perturbed, the rest stay free
            s, *_rest = np.random.choice(key_list, len(key_list), replace=False)
            entry = var_table[s]
            if 'value' not in entry:
                ok = False  # e.g. a bool absent from the model: nothing to flip
                continue
            vtype = entry['type']
            if vtype.is_int_type():
                number = self.int_fuzzing(entry['value'])
            elif vtype.is_real_type():
                number = self.real_fuzzing(entry['value'])
            elif vtype.is_bool_type():
                number = 'false' if entry['value'] else 'true'
            else:
                ok = False  # unsupported sort (e.g. bit-vector): skip this step
                continue
            if DEBUG: print(f"change {s} from {entry['value']} to {number}")
            origin_number = entry['value']
            entry['value'] = number
            pin = f"(assert (= {str(entry['name'])} {entry['value']}))"
            auxaug_statement = aug_statement + '\n' + pin \
                + "\n(check-sat)\n(get-model)"  # aux augmented statement
            ok, sols = solve.solve(auxaug_statement)
            if DEBUG: print(f"currect version {'='*100}\n{auxaug_statement}\n{sols}")
            if ok == True:
                var_table = self.save_state(sols)
                break
            entry['value'] = origin_number
        if ok == True:
            statement = self._render(base_statement)
        return ok, statement

    def enum_mutate(self, statement, stop_steps=10):
        """Enumerate up to `stop_steps` models differing in every variable; pick one.

        Each round adds ``distinct`` constraints against the latest model, on
        top of the sign-state constraints. A random collected model is then
        used to build the result. Returns ``(ok, new_statement)``; on failure
        the input `statement` is returned unchanged.
        """
        ok = self.compile(statement)
        if ok == False:
            return False, statement  # check & stop
        base_statement = self._serialize_script()
        auxaug_statement = base_statement + self._aux_pin_asserts() \
            + "\n(check-sat)\n(get-model)"  # aux augmented statement
        ok, sols = solve.solve(auxaug_statement, sympy=True)
        if ok == False or (len(sols) == 1 and sols[0] == ""):  # may no vars to be solved
            return False, statement  # check & stop
        var_table = self.save_state(sols)
        key_list = list(var_table.keys())
        aug_statement = base_statement + "\n".join(self._state_constraints(var_table))
        solutions = []
        for _ in range(stop_steps):
            distinct = "\n".join(
                [f"(assert (distinct {str(var_table[key]['name'])} {var_table[key]['value']}))"
                 for key in key_list if 'value' in var_table[key]])
            aug_statement = aug_statement + '\n' + distinct
            tmp_statement = aug_statement + "\n(check-sat)\n(get-model)"  # aux augmented statement
            ok, sols = solve.solve(tmp_statement, sympy=True)
            if DEBUG: print(f"currect version {'='*100}\n{aug_statement}\n{sols}")
            if ok == True:
                var_table = self.save_state(sols)
                solutions.append(sols)
            if ok == False:
                break
        if not solutions:
            return False, statement
        self.save_state(random.choice(solutions))
        return True, self._render(base_statement)

    def _new_aux(self, ttype, value):
        """Create, record and declare the next ``auxN`` symbol of sort `ttype`."""
        name = "aux%s" % self.id_counter
        aux_var = Symbol(name, ttype)
        self.vars[name] = {'name': name, 'type': ttype, 'value': value}
        self.aux_vars.append(aux_var)
        self.mutate_script.add_command(SmtLibCommand(name="declare-fun", args=(aux_var,)))  # add a new variable
        self.id_counter += 1
        return aux_var

    def constant_fuzzing(self, constant):
        """Return ``constant +/- auxN`` with auxN initialised to 0.

        Raises ValueError if `constant` is neither an Int nor a Real constant.
        """
        op = random.choice([Plus, Minus])
        if constant.is_real_constant() and self.real_aux == True:
            return op(constant, self._new_aux(REAL, 0))
        if constant.is_real_constant() and self.real_aux == False:
            return op(constant, ToReal(self._new_aux(INT, 0)))
        if constant.is_int_constant():
            return op(constant, self._new_aux(INT, 0))
        raise ValueError("the constant is not int or real")

    def operator_fuzzing(self, formula):
        """Combine `formula` with a fresh ``auxN`` via +, -, * or / (no / for Int).

        auxN starts at the operator's identity (0 for +/-, 1 for * and /), so
        the rewritten term initially equals `formula`. Raises ValueError for
        formulas that are neither Int nor Real.
        """
        ftype = formula.get_type()
        if ftype.is_real_type():
            op = random.choice([Plus, Minus, Times, Div])
            multiplicative = op in (Times, Div)
            if self.real_aux == True:
                aux_var = self._new_aux(REAL, 1.0 if multiplicative else 0.0)
                return op(formula, aux_var)
            aux_var = self._new_aux(INT, 1 if multiplicative else 0)
            return op(formula, ToReal(aux_var))
        if ftype.is_int_type():
            op = random.choice([Plus, Minus, Times])  # div is not allowed for int_type
            aux_var = self._new_aux(INT, 1 if op is Times else 0)
            return op(formula, aux_var)
        raise ValueError("the formula is not int or real")

    def int_fuzzing(self, number):
        """Perturb `number` by a random amount of its own order of magnitude.

        A non-zero number moves by up to ``10 ** int(log10(|number|))`` in a
        random direction and is truncated to int. 0 is returned unchanged,
        because its sign factor is 0.
        """
        if abs(number) == 0:
            magnitude = 10
            perturbation = random.uniform(0, magnitude)
            sign = 0
        else:
            magnitude = 10 ** (int(math.log10(abs(number))))
            perturbation = random.uniform(0, magnitude)
            sign = random.choice([-1, 1])
        return int(number + sign * perturbation)

    def real_fuzzing(self, number):
        """Perturb a real `number`, keeping values in (0, 1) / (-1, 0) in that interval.

        Other values are fuzzed on their integer part (truncated toward zero)
        with ``int_fuzzing``, plus a random fractional part with 2 decimals.
        """
        if 0 < number < 1:
            return round(random.uniform(0, 1), 2)
        if -1 < number < 0:
            return -round(random.uniform(0, 1), 2)
        whole = self.int_fuzzing(int(number))
        return whole + round(random.uniform(0, 1), 2)
    

def mutate(statement, tactic='mcmc', allow_real=False, rm_aux=True, simplify=True, comment=True, rename=False):
    """Mutate an SMT-LIB `statement`, optionally renaming and simplifying it.

    Mutation is first tried with Int aux variables. If that fails and
    `allow_real` is True, it is retried with Real aux variables. When
    mutation fails, the (possibly renamed) input is kept.

    Args:
        tactic: 'mcmc' or 'enum', forwarded to `mutater`.
        allow_real: retry with Real aux variables after an Int failure.
        rm_aux: forwarded to `mutater` (inline aux values vs. keep varN).
        simplify: run ``simplify.simplify`` on the result.
        comment: forwarded to ``simplify.simplify``.
        rename: run ``simplify.rename`` on the input first.

    Returns:
        The resulting statement string.
    """
    # The boolean parameter `simplify` shadows the module of the same name,
    # so the module is bound under a different local name.
    from MathGym.utils import simplify as simplify_utils

    if rename:
        statement = simplify_utils.rename(statement)
    ### using INT first
    m = mutater(rm_aux=rm_aux, real_aux=False, tactic=tactic)
    ok, fuzz_statement = m.mutate(statement)
    ### then, using REAL
    if ok == False and allow_real == True:  # try real type again
        m = mutater(rm_aux=rm_aux, real_aux=True, tactic=tactic)
        ok, fuzz_statement = m.mutate(statement)
    if ok == True:
        statement = fuzz_statement
    if simplify == True:
        statement = simplify_utils.simplify(statement, comment=comment)
    return statement


