from __future__ import print_function
import parser
import torch
import torch.nn.functional as F
import torch.optim as optim
from utils import dataset_loader
from clients_attackers import *
from server import Server
import random 
import numpy as np 

def set_seed(seed):
    """Seed every random-number generator used by the experiment.

    Seeds, in order, Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG
    and PyTorch's CUDA RNG. Then it asks cuDNN to pick deterministic algorithms.

    Args:
        seed: value passed unchanged to each seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    # Trade some speed for run-to-run reproducibility on GPU.
    torch.backends.cudnn.deterministic = True
    
def _make_optimizer(model, args):
    """Build the local optimizer selected by ``args.optimizer`` for ``model``.

    Args:
        model: module whose ``parameters()`` the optimizer will update.
        args: parsed CLI namespace; reads ``optimizer``, ``lr`` and, for SGD,
            ``momentum``.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    if args.optimizer == 'SGD':
        return optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    if args.optimizer == 'Adam':
        return optim.Adam(model.parameters(), lr=args.lr)
    # Without this, an unknown name previously surfaced as an opaque
    # UnboundLocalError on `optimizer` when the first client was created.
    raise ValueError(f"Unsupported optimizer {args.optimizer!r}; expected 'SGD' or 'Adam'")


def _append_log_row(args, attacker_list_name, steps, acc, asr, path='log.csv'):
    """Append one comma-separated result row for a training round to ``path``.

    The file is opened in append mode and closed after every row. Results are
    therefore on disk even if a later round crashes, and no file handles leak.

    Args:
        args: parsed CLI namespace; reads ``dataset``, ``loader_type``, ``model``,
            ``optimizer``, ``defense``, ``AR`` and ``attacks``.
        attacker_list_name: '-'-joined attacker ids, or ``'clean'``.
        steps: 1-based round index.
        acc: main-task accuracy returned by ``server.test``.
        asr: metric returned by ``server.test_backdoor`` (presumably the attack
            success rate -- confirm against ``Server``).
        path: CSV file to append to.
    """
    row = '%s,%s,%s,%s,%s,%s,%s,%s,%d,%.5f,%.5f' % (
        args.dataset, args.loader_type, args.model, args.optimizer, args.defense,
        args.AR, args.attacks, attacker_list_name, steps, acc, asr)
    with open(path, 'a') as log_file:
        print(row, file=log_file)


def main(args):
    """Run one federated-learning experiment with optional text-backdoor attackers.

    Builds the data loaders and the server. It then attaches one client per
    index in ``range(args.num_clients)``: indices listed in ``args.attacker_list``
    become ``Attacker_Text`` clients and the rest become ``Client``. The initial
    model is evaluated at step 0. After that, ``args.epochs`` rounds of
    distribute -> train -> test run, and each round appends one row to ``log.csv``.

    Side effects: sets ``args.vocab`` and ``args.attacker_list_name``. When
    ``args.save_model_weights`` is set, it also creates ``./log/...`` and writes
    ``label.pt`` there.

    Raises:
        ValueError: if ``args.optimizer`` is not ``'SGD'`` or ``'Adam'``.
    """
    print(f'Aggregation Rule:\t{args.AR}\nData distribution:\t{args.loader_type}\nAttacks:\t\t{args.attacks} ')
    print("#" * 64)

    device = args.device

    trainData, testData, vocab = dataset_loader.get_dataloader(args)
    args.vocab = vocab
    Model = dataset_loader.Model
    criterion = F.cross_entropy
    # An empty (or missing) attacker list is tagged 'clean' in paths and logs.
    attacker_list_name = '-'.join(str(atk) for atk in args.attacker_list) if args.attacker_list else 'clean'
    args.attacker_list_name = attacker_list_name
    # Set for O(1) membership tests when assigning client roles below.
    attacker_ids = set(args.attacker_list or ())
    text_backdoor_utils = Text_Backdoor_Utils(args, vocab)

    # create server instance
    model0 = Model(args)
    server = Server(args, model0, testData, text_backdoor_utils, criterion, device)
    server.set_AR(args.AR)
    server.path_to_aggNet = f'./aaa/{args.dataset}_dirichlet_word_attention.pt'
    if args.save_model_weights:
        from pathlib import Path

        server.isSaveChanges = True
        server.savePath = f'./log/{args.dataset}_{args.attacks}_{args.loader_type}_{args.AR}_{attacker_list_name}'
        Path(server.savePath).mkdir(parents=True, exist_ok=True)
        # Ground-truth client labels: honest clients are 1, malicious clients are 0.
        label = torch.ones(args.num_clients)
        for i in attacker_ids:
            label[i] = 0
        torch.save(label, f'{server.savePath}/label.pt')

    # create client instances; each gets its own model and optimizer
    for i in range(args.num_clients):
        model = Model(args)
        optimizer = _make_optimizer(model, args)
        if i in attacker_ids:
            client_i = Attacker_Text(args, i, model, trainData[i], optimizer, text_backdoor_utils, criterion, device, args.inner_epochs)
        else:
            client_i = Client(args, i, model, trainData[i], optimizer, criterion, device, args.inner_epochs)
        server.attach(client_i)

    # Evaluate the initial global model at step 0 (return values unused here).
    steps = 0
    server.test(steps)
    server.test_backdoor(steps)

    group = range(args.num_clients)  # every client takes part in every round
    for j in range(args.epochs):
        steps = j + 1

        print('\n\n########EPOCH %d ########' % j)
        print('###Model distribution###\n')
        server.step = steps
        server.epochs = args.epochs
        server.distribute()
        server.train(group)

        _, acc = server.test(steps)
        _, asr = server.test_backdoor(steps)
        _append_log_row(args, attacker_list_name, steps, acc, asr)
           
if __name__ == "__main__":
    # Parse CLI options, fix all RNG seeds, echo the full configuration
    # inside a '#' banner, then launch the experiment.
    args = parser.parse_args()
    set_seed(args.seed)
    banner = "#" * 64
    print(banner)
    for option_name, option_value in vars(args).items():
        print(f"#{option_name:>30}: {str(option_value):<30}#")
    print(banner)
    main(args)
