import math
import os
from argparse import ArgumentParser
from pathlib import Path

import torch
import torchvision.models as models
import yaml
from torch.optim import AdamW

from models import model_from_kwargs
from functools import partial

def parse_args():
    """Parse the command line and return the options as a plain dict.

    Recognised options:
        --mode: required; forwarded as-is to ``main``.
        --device: required integer.
        --precision: one of ``fp32``, ``fp16``, ``float16``, ``float32``
            (default ``float32``).
        --start: integer, default 500.

    Returns:
        dict mapping option name (without the leading dashes) to its value.
    """
    # (flag, add_argument keyword arguments), registered in this order
    option_specs = (
        ("--mode", {"required": True}),
        ("--device", {"type": int, "required": True}),
        (
            "--precision",
            {
                "type": str,
                "default": "float32",
                "choices": ["fp32", "fp16", "float16", "float32"],
            },
        ),
        ("--start", {"type": int, "default": 500}),
    )
    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    namespace = cli.parse_args()
    return vars(namespace)


def clampm1to1(*args, **kwargs):
    """Draw standard-normal samples and clip them into ``[-1, 1]``.

    All positional and keyword arguments are forwarded unchanged to
    ``torch.randn``.
    """
    samples = torch.randn(*args, **kwargs)
    return torch.clamp(samples, min=-1, max=1)


def batch_idx(num_points, device, batch_size):
    """Return a per-point sample index for equally sized samples.

    Args:
        num_points: total number of points; must be a multiple of
            ``batch_size``.
        device: device on which the index tensor is created.
        batch_size: number of samples.

    Returns:
        1-D tensor of length ``num_points`` in which each value in
        ``0 .. batch_size - 1`` is repeated ``num_points // batch_size``
        times consecutively.
    """
    points_per_sample, leftover = divmod(num_points, batch_size)
    assert leftover == 0
    sample_ids = torch.arange(batch_size, device=device)
    return torch.repeat_interleave(sample_ids, points_per_sample)

def unbatch_idx(num_points, device, batch_size):
    """Return, for every point, the index of the sample it belongs to.

    Currently computes exactly the same tensor as ``batch_idx``.

    Args:
        num_points: total number of points; must be a multiple of
            ``batch_size``.
        device: device on which the index tensor is created.
        batch_size: number of samples.

    Returns:
        1-D tensor of length ``num_points`` with the values
        ``0 .. batch_size - 1``, each repeated ``num_points // batch_size``
        times in a row.
    """
    assert num_points % batch_size == 0
    repeats = num_points // batch_size
    ids = torch.arange(batch_size, device=device)
    return ids.repeat_interleave(repeats)

def unbatch_select(*_, device, batch_size):
    """Return the tensor ``[0, 1, ..., batch_size - 1]``.

    Positional arguments are accepted and ignored so the function can be
    called with the same shape arguments as the other input generators.

    Args:
        device: device on which the tensor is created (keyword-only).
        batch_size: number of indices to produce (keyword-only).
    """
    sample_ids = torch.arange(0, batch_size, device=device)
    return sample_ids


def meshtogrid32768to3586(size, two, device, batch_size):
    """Generate random index pairs of shape ``(size, 2)``.

    Column 0 is drawn uniformly from ``[0, 32768 * batch_size)`` and
    column 1 from ``[0, 3586 * batch_size)``.
    NOTE(review): judging by the name these are grid/query edge indices --
    confirm against the model that consumes them.

    Args:
        size: number of pairs (rows) to generate.
        two: accepted for call-signature compatibility; not used.
        device: device on which the tensors are created.
        batch_size: multiplier applied to both index ranges.
    """
    # column order matters: the first bound is sampled first
    upper_bounds = (32768 * batch_size, 3586 * batch_size)
    columns = [torch.randint(bound, size=(size,), device=device) for bound in upper_bounds]
    return torch.stack(columns, dim=1)

def meshtogrid110592to3586(size, two, device, batch_size):
    """Generate random index pairs of shape ``(size, 2)``.

    Column 0 is drawn uniformly from ``[0, 110592 * batch_size)`` and
    column 1 from ``[0, 3586 * batch_size)``.
    NOTE(review): judging by the name these are grid/query edge indices --
    confirm against the model that consumes them.

    Args:
        size: number of pairs (rows) to generate.
        two: accepted for call-signature compatibility; not used.
        device: device on which the tensors are created.
        batch_size: multiplier applied to both index ranges.
    """
    # column order matters: the first bound is sampled first
    upper_bounds = (110592 * batch_size, 3586 * batch_size)
    columns = [torch.randint(bound, size=(size,), device=device) for bound in upper_bounds]
    return torch.stack(columns, dim=1)

def meshtogrid262144to3586(size, two, device, batch_size):
    """Generate random index pairs of shape ``(size, 2)``.

    Column 0 is drawn uniformly from ``[0, 262144 * batch_size)`` and
    column 1 from ``[0, 3586 * batch_size)``.
    NOTE(review): judging by the name these are grid/query edge indices --
    confirm against the model that consumes them.

    Args:
        size: number of pairs (rows) to generate.
        two: accepted for call-signature compatibility; not used.
        device: device on which the tensors are created.
        batch_size: multiplier applied to both index ranges.
    """
    # column order matters: the first bound is sampled first
    upper_bounds = (262144 * batch_size, 3586 * batch_size)
    columns = [torch.randint(bound, size=(size,), device=device) for bound in upper_bounds]
    return torch.stack(columns, dim=1)

def gridtoquery3586to32768(size, two, device, batch_size):
    """Generate random index pairs of shape ``(size, 2)``.

    Column 0 is drawn uniformly from ``[0, 3586 * batch_size)`` and
    column 1 from ``[0, 32768 * batch_size)``.
    NOTE(review): judging by the name these are query/grid edge indices --
    confirm against the model that consumes them.

    Args:
        size: number of pairs (rows) to generate.
        two: accepted for call-signature compatibility; not used.
        device: device on which the tensors are created.
        batch_size: multiplier applied to both index ranges.
    """
    # column order matters: the first bound is sampled first
    upper_bounds = (3586 * batch_size, 32768 * batch_size)
    columns = [torch.randint(bound, size=(size,), device=device) for bound in upper_bounds]
    return torch.stack(columns, dim=1)

def gridtoquery3586to110592(size, two, device, batch_size):
    """Generate random index pairs of shape ``(size, 2)``.

    Column 0 is drawn uniformly from ``[0, 3586 * batch_size)`` and
    column 1 from ``[0, 110592 * batch_size)``.
    NOTE(review): judging by the name these are query/grid edge indices --
    confirm against the model that consumes them.

    Args:
        size: number of pairs (rows) to generate.
        two: accepted for call-signature compatibility; not used.
        device: device on which the tensors are created.
        batch_size: multiplier applied to both index ranges.
    """
    # column order matters: the first bound is sampled first
    upper_bounds = (3586 * batch_size, 110592 * batch_size)
    columns = [torch.randint(bound, size=(size,), device=device) for bound in upper_bounds]
    return torch.stack(columns, dim=1)

def gridtoquery3586to262144(size, two, device, batch_size):
    """Generate random index pairs of shape ``(size, 2)``.

    Column 0 is drawn uniformly from ``[0, 3586 * batch_size)`` and
    column 1 from ``[0, 262144 * batch_size)``.
    NOTE(review): judging by the name these are query/grid edge indices --
    confirm against the model that consumes them.

    Args:
        size: number of pairs (rows) to generate.
        two: accepted for call-signature compatibility; not used.
        device: device on which the tensors are created.
        batch_size: multiplier applied to both index ranges.
    """
    # column order matters: the first bound is sampled first
    upper_bounds = (3586 * batch_size, 262144 * batch_size)
    columns = [torch.randint(bound, size=(size,), device=device) for bound in upper_bounds]
    return torch.stack(columns, dim=1)


# noinspection PyUnusedLocal
def main(mode, device, precision, start):
    """Binary-search the largest batch size that completes one training step.

    A model is built, moved to the requested CUDA device and paired with an
    AdamW optimizer. After a warmup step at batch size 1 (which initializes
    the optimizer state), batch sizes in ``[0, 2 * start]`` are probed with
    random inputs; each probe runs forward, backward and an optimizer step.
    A probe that raises is treated as "does not fit".

    Args:
        mode: ``"resnet18"`` for the torchvision model, otherwise the stem
            of a YAML file under ``yamls/profile_memory/``. The YAML must
            provide ``input_shape``, ``output_shape``, ``model``, ``inputs``
            and ``output_key``.
        device: CUDA device index.
        precision: ``"float16"``/``"fp16"`` or ``"float32"``/``"fp32"``;
            used as the autocast dtype.
        start: positive integer; the search space is capped at
            ``2 * start``.

    Raises:
        ValueError: if ``start`` is smaller than 1.
        NotImplementedError: for an unsupported precision, a mode without a
            config file, or an input with an unknown ``constraint``.
    """
    if start < 1:
        # a non-positive start gives an empty or inverted search space
        raise ValueError(f"start must be a positive integer, got {start}")

    # init device
    device = torch.device(f"cuda:{device}")
    print(f"device: {device}")

    # init precision
    if precision in ["float16", "fp16"]:
        precision = torch.float16
    elif precision in ["float32", "fp32"]:
        precision = torch.float32
    else:
        raise NotImplementedError(f"unsupported precision: {precision!r}")
    print(f"precision: {precision}")

    # init model and data
    if mode == "resnet18":
        model = models.resnet18()
        input_configs = dict(x=dict(shape=(3, 224, 224)))
        output_key = None
    else:
        config_uri = Path(f"yamls/profile_memory/{mode}.yaml")
        if not config_uri.exists():
            raise NotImplementedError(f"unknown mode {mode!r}: {config_uri} does not exist")
        with open(config_uri) as f:
            config = yaml.safe_load(f)
        model = model_from_kwargs(
            input_shape=config["input_shape"],
            output_shape=config["output_shape"],
            **config["model"],
        )
        input_configs = config["inputs"]
        output_key = config["output_key"]
    model = model.to(device)
    optim = AdamW(model.parameters())

    # Resolve generators outside the try block below, so a config mistake
    # fails loudly instead of being counted as "batch size does not fit".
    input_specs = _resolve_input_specs(input_configs)

    # binary search
    search_limit = start * 2
    lower_bound = 0
    upper_bound = search_limit
    largest_batchsize = 0
    is_warmup = True
    while True:
        # dev run
        if os.name == "nt" and not is_warmup:
            print("dev run -> exit after warmup")
            return
        # warmup for optim state initialization
        if is_warmup:
            print("warmup")
            best_guess = 1
        else:
            best_guess = lower_bound + math.ceil((upper_bound - lower_bound) / 2)
            print(f"search space: [{lower_bound} - {upper_bound}] -> best guess: {best_guess}")
        try:
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            inputs = _make_inputs(input_specs, best_guess, device)
            optim.zero_grad()
            with torch.autocast(device_type="cuda", dtype=precision):
                outputs = model(**inputs)
                if output_key is None:
                    output = outputs
                else:
                    output = outputs[output_key]
                loss = output.mean()
            loss.backward()
            optim.step()
        except Exception as err:
            # Only Exception is caught, so Ctrl-C and SystemExit still abort.
            # The reason is printed so non-memory failures are visible.
            print(f"step failed with {type(err).__name__}: {err}")
            # error -> decrease upper bound
            if not is_warmup:
                upper_bound = best_guess - 1
                print(f"{best_guess} is impossible")
        else:
            # success -> increase lower bound
            if not is_warmup:
                lower_bound = best_guess
                print(f"{best_guess} is possible")
                largest_batchsize = best_guess
        # drop references so the memory can be released before the next probe
        inputs = None
        output = None
        outputs = None
        loss = None
        if lower_bound == upper_bound:
            break
        is_warmup = False
    print(f"largest batchsize: {largest_batchsize}")
    if largest_batchsize == search_limit:
        print(f"note: search stopped at its limit of {search_limit}; rerun with a larger --start")
    total_memory = torch.cuda.get_device_properties(device).total_memory
    print(f"total_memory: {total_memory}")
    if largest_batchsize > 0:
        print(f"memory_per_sample: {total_memory // largest_batchsize / 1024 / 1024 / 1024:.2f} GB")
    else:
        # no probe succeeded, so there is nothing to divide by
        print("memory_per_sample: n/a (no batch size completed a training step)")
    print(f"num_parameters: {sum(p.numel() for p in model.parameters()) / 1000 / 1000:.2f}M")


def _resolve_input_specs(input_configs):
    """Validate the input configs once and pair each input with its generator.

    Returns a list of ``(name, shape, is_sparse, gen_fn, needs_batch_size)``
    tuples; raises NotImplementedError for an unknown ``constraint``.
    """
    # generators that take no batch_size keyword
    plain_gen_fns = {
        None: torch.randn,
        "clampm1to1": clampm1to1,
    }
    # generators that need the probed batch size bound as a keyword
    batched_gen_fns = {
        "batch_idx": batch_idx,
        "unbatch_idx": unbatch_idx,
        "unbatch_select": unbatch_select,
        "meshtogrid32768to3586": meshtogrid32768to3586,
        "meshtogrid110592to3586": meshtogrid110592to3586,
        "meshtogrid262144to3586": meshtogrid262144to3586,
        "gridtoquery3586to32768": gridtoquery3586to32768,
        "gridtoquery3586to110592": gridtoquery3586to110592,
        "gridtoquery3586to262144": gridtoquery3586to262144,
    }
    specs = []
    for input_name, input_config in input_configs.items():
        constraint = input_config.get("constraint", None)
        if constraint in plain_gen_fns:
            gen_fn = plain_gen_fns[constraint]
            needs_batch_size = False
        elif constraint in batched_gen_fns:
            gen_fn = batched_gen_fns[constraint]
            needs_batch_size = True
        else:
            raise NotImplementedError(f"unknown constraint {constraint!r} for input {input_name!r}")
        is_sparse = input_config.get("is_sparse", False)
        specs.append((input_name, input_config["shape"], is_sparse, gen_fn, needs_batch_size))
    return specs


def _make_inputs(input_specs, batch_size, device):
    """Generate one random tensor per input spec for the given batch size."""
    inputs = {}
    for input_name, shape, is_sparse, gen_fn, needs_batch_size in input_specs:
        if needs_batch_size:
            gen_fn = partial(gen_fn, batch_size=batch_size)
        if is_sparse:
            # sparse inputs fold the batch into the first dimension
            inputs[input_name] = gen_fn(batch_size * shape[0], *shape[1:], device=device)
        else:
            inputs[input_name] = gen_fn(batch_size, *shape, device=device)
    return inputs


if __name__ == "__main__":
    # Forward the parsed command-line options to main() as keyword arguments.
    cli_options = parse_args()
    main(**cli_options)
