import torch 
import wandb 
import matplotlib.pyplot as plt 

import argparse 
import sys 
sys.path.append('.') 

# gflownets and general utils 
from streaming_gfn.gflownet import GFlowNet, SBGFlowNet  
from streaming_gfn.utils import compute_marginal_dist 

def get_argument_parser():
    """Build the command-line parser for the GFlowNet training script.

    Options are declared in a single ordered table and registered in that
    order, which also fixes the order in which ``--help`` lists them.

    Returns:
        argparse.ArgumentParser: parser with every option registered; call
        ``parse_args()`` on it to obtain the run configuration.
    """
    option_specs = (
        # --- GFlowNet architecture and training parameters ---
        ("--hidden_dim", dict(type=int, default=256, help="Hidden dimension of the policy network")),
        ("--num_layers", dict(type=int, default=2, help="Number of layers in the policy network")),
        ("--epochs_eval", dict(type=int, default=100, help="Number of epochs for evaluation")),
        ("--epochs_per_step", dict(type=int, default=9_000, help="number of epochs per step")),
        ("--num_steps", dict(type=int, default=25, help="number of steps")),
        ("--use_scheduler", dict(action="store_true", help="Use learning rate scheduler")),
        ("--criterion", dict(type=str, default="tb", help="Loss function for training",
                             choices=["iwae", "kl", "tb"])),
        ("--n_traj", dict(type=int, default=2, help="number of backward trajectories for iwae")),
        ("--device", dict(type=str, default="cpu", help="Device to use (cpu or cuda)")),
        ("--env", dict(type=str, default="sets", help="Target domain",
                       choices=["sets", "phylogenetics"])),
        # --- Environment parameters: generic ---
        ("--batch_size", dict(type=int, default=128, help="Batch size for training")),
        # --- Environment parameters: sets ---
        ("--set_size", dict(type=int, default=18, help="Number of elements in the set")),
        ("--src_size", dict(type=int, default=24, help="Number of source vectors")),
        # --- Environment parameters: phylogenetic inference ---
        ("--num_leaves", dict(type=int, default=7, help="number of biological species")),
        ("--num_nb", dict(type=int, default=4, help="number of nucleotides (hypothetical)")),
        ("--num_sites", dict(type=int, default=25, help="number of observed sites")),
        ("--temperature", dict(type=float, default=1.0, help="temperature of the target")),
        # --- Reward, seed and streaming ---
        ("--seed", dict(type=int, default=42, help="Random seed for reward generation")),
        ("--num_stm_upd", dict(type=int, default=2, help="number of streaming updates")),
        ("--alpha", dict(type=float, default=1.0, help="temperature of the target distribution")),
        # --- Evaluation / visualization ---
        ("--num_back_traj", dict(type=int, default=8, help="Number of back-trajectories for evaluation")),
        ("--use_progress_bar", dict(action="store_true", help="use progress bar")),
    )

    parser = argparse.ArgumentParser(description="GFlowNet Training Script")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser

def create_gfn(config, is_streaming=False): 
    match config.env: 
        case 'sets': 
            from streaming_gfn.policies.sets import ForwardPolicy, BackwardPolicy 
            pf = ForwardPolicy(config.src_size, config.hidden_dim, config.num_layers, device=config.device) 
            pb = BackwardPolicy(config.device)  
        case 'phylogenetics': 
            from streaming_gfn.policies.trees import ForwardPolicyMLP, BackwardPolicy 
            pf = ForwardPolicyMLP(config.hidden_dim, config.num_leaves, device=config.device) 
            pb = BackwardPolicy(config.device) 
        case _: 
            raise Exception(f'env: {config.env}')  
    
    if not is_streaming: 
        return GFlowNet(pf, pb, n_traj=config.n_traj, criterion=config.criterion, device=config.device)  
    else: 
        return SBGFlowNet(pf, pb, n_traj=config.n_traj, criterion=config.criterion, device=config.device) 

def create_env(config, log_reward=None): 
    match config.env: 
        case 'sets': 
            from streaming_gfn.gym.sets import Set 
            return Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
        case 'phylogenetics': 
            from streaming_gfn.gym.trees import Trees 
            return Trees(config.num_leaves, config.batch_size, log_reward, device=config.device) 
        
def create_log_reward(config, gflownet): 
    match config.env: 
        case 'sets': 
            from streaming_gfn.rewards.sets import LogReward 
            from streaming_gfn.gym.sets import Set 
            log_reward = LogReward(config.src_size, config.seed, device=config.device, alpha=config.alpha) 
            sets = Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
            sets = gflownet.sample(sets) 
            log_reward.shift = sets.log_reward().max()  
            return log_reward 
        case 'phylogenetics': 
            from streaming_gfn.rewards.trees import LogReward 
            from streaming_gfn.gym.trees import Trees 
            tree = Trees(config.num_leaves, batch_size=1, log_reward=None, device=config.device) 
            with gflownet.off_policy(): 
                tree = gflownet.sample(tree, seed=42) 
            # Simulate JC69 
            Q = 5e-1 * torch.ones((config.num_nb, config.num_nb), device=config.device) 
            Q[torch.arange(config.num_nb), torch.arange(config.num_nb)] -= Q.sum(dim=-1) 
            pi = torch.ones((config.num_nb,), device=config.device) / config.num_nb   
            sites = Trees.sample_from_phylogeny(tree, Q, config.num_sites, pi, device=config.device)
            sites = sites[:, :config.num_leaves] 
            # Tree's likelihood using Felsenstein's algorithm 
            log_reward = LogReward(pi, sites, Q, config.temperature) 
            # Shift the reward for enhanced numerical stability 
            env = Trees(config.num_leaves, batch_size=config.batch_size, log_reward=log_reward, device=config.device) 
            values = gflownet.sample(env).log_reward() 
            log_reward.shift = values.max() 
            return log_reward 

def create_opt(gfn, config, *, lr_pf=1e-3, lr_log_z=1e-1):
    """Create the Adam optimizer (and optional LR scheduler) for a GFlowNet.

    The forward policy ``gfn.pf`` and the ``gfn.log_z`` parameter live in two
    separate parameter groups so they can use different learning rates.

    Args:
        gfn: object exposing ``pf`` (with ``.parameters()``) and ``log_z``.
        config: reads ``use_scheduler`` and, when it is set, ``epochs_per_step``.
        lr_pf: learning rate for the forward-policy parameters (default 1e-3).
        lr_log_z: learning rate for ``log_z`` (default 1e-1).

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is None unless
        ``config.use_scheduler`` is truthy. Otherwise it is a linear
        (power=1) ``PolynomialLR`` that decays over ``config.epochs_per_step``
        iterations.
    """
    optimizer = torch.optim.Adam([
        {'params': gfn.pf.parameters(), 'lr': lr_pf},
        {'params': gfn.log_z, 'lr': lr_log_z},
    ])
    scheduler = None
    if config.use_scheduler:
        scheduler = torch.optim.lr_scheduler.PolynomialLR(
            optimizer, total_iters=config.epochs_per_step, power=1.
        )
    return optimizer, scheduler

def eval_step(config, gfn, create_env_func, plot, return_dist=False):
    """Evaluate ``gfn`` by the L1 distance between its learned and target marginals.

    Both distributions come from ``compute_marginal_dist``, which runs inside
    ``gfn.off_policy()``. The distance is also logged to wandb under ``'l1'``.

    Args:
        config: reads ``epochs_eval`` (passed as ``num_batches``),
            ``num_back_traj`` and ``use_progress_bar``.
        gfn: the GFlowNet to evaluate.
        create_env_func: environment factory forwarded to ``compute_marginal_dist``.
        plot: currently unused by this function. It is kept for interface
            compatibility.
        return_dist: if True, also return the learned and target distributions.

    Returns:
        The L1 distance ``sum(|learned - target|)``, or
        ``(l1, learned_dist, target_dist)`` when ``return_dist`` is True.
        Presumably a 0-dim torch tensor; confirm against compute_marginal_dist.
    """
    with gfn.off_policy():
        learned_dist, target_dist = compute_marginal_dist(
            gfn, create_env_func,
            num_batches=config.epochs_eval,
            num_back_traj=config.num_back_traj,
            use_progress_bar=config.use_progress_bar,
        )
        # Compute the distance once and reuse it for logging and the return value.
        l1 = (learned_dist - target_dist).abs().sum()
        wandb.log({'l1': l1})

    if return_dist:
        return l1, learned_dist, target_dist
    return l1
