import copy
import os
import random

import numpy as np
import torch


def make_save_path(args):
    """Return the ``results`` directory, creating it if it does not exist.

    The directory is located two levels above the directory that contains
    this source file, i.e. ``dirname(dirname(dirname(__file__)))/results``.

    Args:
        args: Accepted for interface compatibility; not used.

    Returns:
        The path (as built by ``os.path.join``) of the results directory.
    """
    base_dir = __file__
    # Strip three path components: the file name, then two parent dirs.
    for _ in range(3):
        base_dir = os.path.dirname(base_dir)

    results_dir = os.path.join(base_dir, "results")
    os.makedirs(results_dir, exist_ok=True)
    return results_dir


def get_server_and_client(args, client_args, server_args):
    """Resolve the server/client classes for the method named in ``args.method``.

    The method name is matched case-insensitively. Only ``"foogd"`` is
    supported. For it, every entry of ``client_args`` receives its own
    ``copy.deepcopy`` of ``args.score_model`` under the ``"score_model"`` key,
    so no two clients share the same score-model object.

    Args:
        args: Object exposing ``method`` (str) and, for FOOGD, ``score_model``.
        client_args: Indexable by ``0 .. len(client_args) - 1``; each item is
            a mutable mapping. It is modified in place.
        server_args: Passed through unchanged.

    Returns:
        Tuple ``(server_cls, client_cls, client_args, server_args)``.

    Raises:
        NotImplementedError: If the method name is not recognised.
    """
    method_name = args.method.lower()
    if method_name != "foogd":
        raise NotImplementedError("method not support")

    # Imported lazily so the algorithm module is only loaded when selected.
    from src.algorithms.foogd import FOOGDClient, FOOGDServer

    # Index-based access keeps this working whether client_args is a list
    # or a dict keyed by integer client ids.
    for idx in range(len(client_args)):
        client_args[idx]["score_model"] = copy.deepcopy(args.score_model)

    return FOOGDServer, FOOGDClient, client_args, server_args


def set_seed(seed: int = 0, *, deterministic: bool = False) -> None:
    """Seed every random number generator used by the project.

    Seeds Python's ``random``, NumPy's global RNG, Torch's CPU RNG and the
    RNGs of all CUDA devices. ``torch.cuda.manual_seed_all`` is safe to call
    without a GPU.

    Args:
        seed: Seed value applied to all generators.
        deterministic: When True, also set ``torch.backends.cudnn.deterministic
            = True`` and ``torch.backends.cudnn.benchmark = False``. Seeding
            alone does not make cuDNN-backed GPU runs reproducible. The flags
            are left untouched by default to preserve previous behaviour and
            performance.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        # cuDNN autotuning (benchmark) may pick different kernels per run,
        # and some cuDNN kernels are nondeterministic unless forced.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
