import time
from math import prod
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from learning.fns import get_einsum_parser
from learning.fns import get_einsum_exp
from learning.fns import gen_cores
from learning.fns import gen_core_info
from learning.fns import find_locs_info
from learning.fns import gen_vec
from ops.operators import EinOpVec


def generate_data(seed=21, batch_size=100):
    """Generate a synthetic ``(x, y)`` pair from a random einsum operator.

    Builds a 4-core layout with ``gen_core_info``. It fills a per-location
    vector with 1s, then overwrites the locations returned by
    ``find_locs_info`` for ``targets`` with the target values. It then draws
    standard-normal cores for the resulting layout and wraps the inner cores
    (``cores[1:-1]``) in an ``EinOpVec``. Finally it maps a random input batch
    through that operator.

    Args:
        seed: Passed to ``torch.manual_seed`` before any random draw, so the
            cores and ``x`` are reproducible for a given seed.
        batch_size: Number of rows in ``x``. Defaults to 100, the value that
            was previously hard-coded.

    Returns:
        ``(x, y)``, where ``x`` has shape
        ``(batch_size, prod(cores[0].shape))`` and ``y = x @ E``.
    """
    torch.manual_seed(seed)

    core_info = gen_core_info(total_cores=4)
    vec = [1] * len(core_info)
    # (0/1 pattern over the 4 cores, value) pairs. Presumably the pattern
    # selects which cores a connection joins and the int is its size —
    # NOTE(review): confirm against learning.fns.find_locs_info.
    targets = [
        ((1, 1, 0, 0), 2),
        ((1, 0, 1, 0), 3),
        ((0, 1, 0, 1), 3),
        ((0, 0, 1, 1), 2),
        ((0, 1, 1, 0), 2),
    ]
    for idx, val in find_locs_info(targets, core_info):
        vec[idx] = val
    empty_cores, active_cores = gen_cores(vec, core_info)

    # Draw order matters for reproducibility: cores first, then x.
    cores = [torch.randn(core.shape) for core in empty_cores]
    ein_exp = get_einsum_exp(active_cores, core_info)
    # The outer cores only contribute their shapes; the inner cores form
    # the operator.
    shapes = (cores[0].shape, cores[-1].shape)
    E = EinOpVec(cores[1:-1], ein_exp, shapes)

    x = torch.randn((batch_size, prod(cores[0].shape)))
    y = x @ E
    return x, y


def main(args):
    """Fit learnable inner cores of an ``EinOpVec`` to synthetic data with Adam.

    Generates ``(x, y)`` via ``generate_data(seed=21)``. It builds an einsum
    expression for a 4-core layout whose end shapes are fixed below, and
    initialises the inner cores as ``nn.Parameter``s. It then runs
    ``args.max_iters`` Adam steps on the batch-mean L2 norm of the residual
    ``y - x @ A_hat``, printing the error each iteration.

    Args:
        args: Parsed CLI namespace. Reads ``wandb``, ``wandb_project``,
            ``idx``, ``lr`` and ``max_iters``.
    """
    t0 = time.time()

    if args.wandb:
        wandb.init(
            project=args.wandb_project,
            name=f"test_run_{args.idx}",
            config=vars(args),
        )

    x, y = generate_data(seed=21)
    core_info = gen_core_info(total_cores=4)
    # NOTE(review): these shapes are hard-coded. An earlier variant derived
    # them from x/y; presumably they must agree with the cores produced in
    # generate_data — confirm.
    shapes = [(3, 2), (3, 2)]
    vec = gen_vec(core_info, shapes)
    empty_cores, active_cores = gen_cores(vec, core_info)
    ein_exp = get_einsum_exp(active_cores, core_info)
    # Only the inner cores are learnable, mirroring generate_data, which
    # builds its operator from cores[1:-1].
    Ms = [nn.Parameter(torch.randn(M.shape)) for M in empty_cores[1:-1]]
    A_hat = EinOpVec(Ms, ein_exp, shapes)

    opt = optim.Adam(Ms, lr=args.lr)
    # Stays None when args.max_iters <= 0, so the logging below never touches
    # an unbound name.
    final_error = None
    for idx in range(args.max_iters):
        y_hat = x @ A_hat
        # Mean over the batch of each sample's L2 residual norm.
        error = torch.mean(torch.norm(y - y_hat, dim=-1))
        error.backward()
        opt.step()
        opt.zero_grad()
        # Detach to a plain float for printing and logging.
        final_error = error.item()
        print(f"Iter: {idx + 1} | Error: {final_error:1.5e}")

    if args.wandb:
        if final_error is not None:
            wandb.log({"error": final_error})
        wandb.log({"expr": ein_exp})
        # Close the run explicitly, so repeated calls (e.g. sweeps) do not
        # leave it open.
        wandb.finish()
    t1 = time.time()
    print(f"Took: {t1 - t0:1.5e} sec")


if __name__ == "__main__":
    # Build the project's CLI parser, read the command line, and run training.
    cli_args = get_einsum_parser().parse_args()
    main(cli_args)
