import random
import json
import jsonlines
import ray
import mmap
import time
import sys
from ray.util.queue import Queue
sys.path.append(".")
from utils.misc import execute
from utils.ray_tools import ProgressBar
from tqdm import tqdm
import pathlib
import glob
import subprocess
import os
import traceback
from Bio.PDB import *
import numpy as np
from rdkit import Chem

# Minimum TM-score (parsed from line 0 of <MSA_id>_TMscore.txt) to keep a hit.
TMscore_threshold = 0.4
# Minimum fraction of pocket columns whose residue matches the aligned sequence.
Match_rate_threshold = 0.4
# CPUs requested by each Ray worker task.
N_CPU_PER_THREAD = 1
# Number of Ray worker tasks launched to drain the shared job queue.
n_thread = 200
# Data locations (placeholders -- set before running). PDBBind_dir is globbed
# as PDBBind_dir + '*/', so it should end with "/".
PDBBind_dir = '/path/to/dir'
MSA_dir = "/path/to/dir"
AF2DB_dir = "/path/to/dir"


def get_MSA_ids(PDBBind_instance_dir):
    """Return the MSA hit ids that have a TM-score file for this instance.

    Each file ``<PDBBind_instance_dir>/rotation_matrix/<MSA_id>_TMscore.txt``
    contributes ``<MSA_id>``. Only the exact ``_TMscore.txt`` suffix is
    stripped, so ids that themselves contain underscores survive intact and
    ``f"{MSA_id}_TMscore.txt"`` reconstructs the original file name.

    Args:
        PDBBind_instance_dir: Path of one PDBBind instance directory.

    Returns:
        Sorted list of MSA ids (empty if no matching files exist).
    """
    suffix = "_TMscore.txt"
    pattern = PDBBind_instance_dir + '/rotation_matrix/*' + suffix
    # glob guarantees every match ends with `suffix`, so slicing it off is safe.
    return sorted(
        os.path.basename(path)[:-len(suffix)] for path in glob.glob(pattern)
    )

def calc_match_rate(pocket_position, Aligned_seq):
    """Return the fraction of pocket columns where the aligned sequence agrees.

    The two strings are compared column by column. A column is a pocket column
    when ``pocket_position`` holds anything other than ``"-"`` there; it is a
    match when ``Aligned_seq`` holds the same character in that column.
    Columns past the end of the shorter string are ignored.

    Args:
        pocket_position: Per-column pocket marker string ("-" = not pocket).
        Aligned_seq: Aligned sequence to compare against it.

    Returns:
        ``match_cnt / total_cnt`` in [0, 1], or 0.0 when there are no pocket
        columns (avoids a ZeroDivisionError; such hits then fail any positive
        match-rate threshold).
    """
    total_cnt = 0
    match_cnt = 0
    for pocket_char, aligned_char in zip(pocket_position, Aligned_seq):
        if pocket_char == "-":
            continue
        total_cnt += 1
        if pocket_char == aligned_char:
            match_cnt += 1
    if total_cnt == 0:
        return 0.0
    return match_cnt / total_cnt

def get_rotate_matrix(rotate_matrix_file):
    """Read a rotation matrix and translation vector from a text file.

    Lines 2, 3 and 4 (0-based) are split on single spaces, and the non-empty
    tokens are parsed as floats. For each such line, token 1 is one component
    of the translation vector and tokens 2 onward form one row of the
    rotation matrix.

    Args:
        rotate_matrix_file: Path of the matrix text file.

    Returns:
        Tuple ``(u, t)`` of numpy arrays: ``u`` holds one row per parsed line,
        and ``t`` holds the matching translation components.
    """
    with open(rotate_matrix_file, "r") as handle:
        lines = handle.readlines()

    parsed_rows = [
        [float(token) for token in lines[row_idx].split(" ") if token != ""]
        for row_idx in (2, 3, 4)
    ]
    translation = np.array([row[1] for row in parsed_rows])
    rotation = np.array([row[2:] for row in parsed_rows])
    return rotation, translation

@ray.remote(num_cpus=N_CPU_PER_THREAD)
def process_jobs(id, jobs_queue, actor):
    """Ray worker: drain ``jobs_queue``, running ``execute_one_job`` per item.

    A failing job is logged (its argument plus a traceback), and the worker
    moves on to the next one. After every job, successful or not, the progress
    actor is bumped by one on a best-effort basis.

    Args:
        id: Worker index; used only in the startup log line.
        jobs_queue: Shared ``ray.util.queue.Queue`` of job arguments.
        actor: Progress actor exposing ``update.remote(n)``.

    Returns:
        1 once the queue has been exhausted.
    """
    # Local import keeps this fix self-contained; get_nowait() raises Empty.
    from ray.util.queue import Empty

    print("start process", id)
    while True:
        # Non-blocking get: with many workers, checking empty() and then
        # calling a blocking get() races. Another worker can take the last job
        # in between, leaving this worker blocked forever.
        try:
            job = jobs_queue.get_nowait()
        except Empty:
            break
        try:
            execute_one_job(job)
        except Exception:
            print(f"failed: {job}")
            traceback.print_exc()
        try:
            actor.update.remote(1)
        except Exception:
            # Progress reporting is best-effort; never let it stop the worker.
            pass
    return 1


def _load_ligand_coords(mol2_file):
    """Parse ``mol2_file`` with RDKit and return its conformer's atom positions.

    Raises:
        ValueError: if RDKit cannot parse the file.
    """
    ligand = Chem.MolFromMol2File(mol2_file)
    if ligand is None:
        # RDKit signals a parse failure by returning None instead of raising.
        raise ValueError(f"could not parse ligand file: {mol2_file}")
    return np.asarray(ligand.GetConformer().GetPositions())


def _transform_chain(chain, rotation, translation):
    """Move every atom of ``chain`` in place to ``rotation @ coord + translation``."""
    for residue in chain:
        for atom in residue:
            coord = np.array(atom.get_coord())
            atom.set_coord(np.dot(rotation, coord) + translation)


def _strip_atoms_far_from(chain, ligand_coords, cutoff):
    """Detach (in place) every atom of ``chain`` farther than ``cutoff`` from all ligand atoms.

    Atoms are removed from their residues only; residues left with no atoms
    remain in the chain.
    """
    for residue in chain:
        atoms = list(residue)  # snapshot: atoms are detached from `residue` below
        if not atoms:
            continue
        coords = np.array([atom.get_coord() for atom in atoms])
        # One (n_atoms, n_ligand_atoms) distance matrix per residue replaces a
        # Python double loop over atoms x ligand atoms.
        dists = np.linalg.norm(coords[:, None, :] - ligand_coords[None, :, :], axis=-1)
        near = (dists <= cutoff).any(axis=1)
        for atom, is_near in zip(atoms, near):
            if not is_near:
                residue.detach_child(atom.id)


def execute_one_job(PDBBind_instance_dir, pocket_cutoff=6.0):
    """Write rotated AF2DB protein and pocket PDB files for one PDBBind instance.

    For every MSA id listed under ``rotation_matrix/``, the hit is kept when its
    TM-score is >= ``TMscore_threshold`` and its pocket match rate is >=
    ``Match_rate_threshold``. For a kept hit:

    * ``<AF2DB_dir>/<MSA_id>.pdb`` is parsed, and its first chain is
      transformed with ``rotation_matrix/<MSA_id>.txt``;
    * the whole structure is written to ``extend/<MSA_id>/<MSA_id>_protein.pdb``;
    * atoms of that chain farther than ``pocket_cutoff`` from every ligand atom
      are removed, and the result is written to ``<MSA_id>_pocket.pdb``.

    Args:
        PDBBind_instance_dir: Instance directory path ending with "/". The PDB
            id is the last path component, and file names are built by direct
            concatenation, so the trailing slash is required.
        pocket_cutoff: Distance threshold for pocket atoms, in the same units
            as the coordinates (default 6, the original hard-coded value).

    Returns:
        1 when processing finishes.

    Raises:
        FileNotFoundError: if the instance has no ``*.fasta`` file.
        ValueError: if the ligand cannot be parsed or an AF2DB model has no chain.
    """
    pdb_id = PDBBind_instance_dir.split("/")[-2]
    fasta_files = glob.glob(PDBBind_instance_dir + '/*.fasta')
    if not fasta_files:
        raise FileNotFoundError(f"no .fasta file in {PDBBind_instance_dir}")
    # Chain id = last character of the first fasta file's name stem.
    chain_id = fasta_files[0].split("/")[-1].split(".")[0][-1]

    mol2_file = PDBBind_instance_dir + pdb_id + '_ligand.mol2'
    pocket_position_file = PDBBind_instance_dir + pdb_id + chain_id + '_pocket_position.txt'
    with open(pocket_position_file, "r") as f:
        pocket_position = f.readline().strip()

    extend_dir = PDBBind_instance_dir + "/extend/"
    os.makedirs(extend_dir, exist_ok=True)

    ligand_coords = _load_ligand_coords(mol2_file)
    rotation_dir = PDBBind_instance_dir + "/rotation_matrix/"

    for MSA_id in get_MSA_ids(PDBBind_instance_dir):
        # TM-score sits after the last ":" of line 0; line 4 is the aligned sequence.
        with open(rotation_dir + f"{MSA_id}_TMscore.txt", "r") as f:
            data = f.readlines()
        tm_score = float(data[0].split(":")[-1])
        match_rate = calc_match_rate(pocket_position, data[4].strip())
        # Written as `not (... and ...)` so NaN scores are skipped, as before.
        if not (tm_score >= TMscore_threshold and match_rate >= Match_rate_threshold):
            continue

        extend_instance_dir = extend_dir + MSA_id + "/"
        os.makedirs(extend_instance_dir, exist_ok=True)

        MSA_pdb_file = AF2DB_dir + f"/{MSA_id}" + ".pdb"
        structure = PDBParser().get_structure(MSA_id, MSA_pdb_file)
        model = structure[0]
        # Only the first chain is transformed and trimmed.
        MSA_chain = next(iter(model), None)
        if MSA_chain is None:
            raise ValueError(f"no chain found in {MSA_pdb_file}")

        rotation, translation = get_rotate_matrix(rotation_dir + f"{MSA_id}.txt")
        _transform_chain(MSA_chain, rotation, translation)

        io = PDBIO()
        io.set_structure(structure)
        io.save(extend_instance_dir + f"{MSA_id}" + "_protein.pdb")

        # The full protein is already on disk, so trim the same structure in place.
        _strip_atoms_far_from(MSA_chain, ligand_coords, pocket_cutoff)
        io = PDBIO()
        io.set_structure(structure)
        io.save(extend_instance_dir + f"{MSA_id}" + "_pocket.pdb")

    print("finish: pdb_id:", pdb_id)
    return 1


# Driver: collect PDBBind instances, queue them, and fan out to Ray workers.
PDBBind_instance_dirs = glob.glob(PDBBind_dir + '*/')
print('Number of PDBBind instances: {}'.format(len(PDBBind_instance_dirs)))

# Keep only instances that have a .fasta file and a matching MSA file
# <MSA_dir>/<pdb_id><chain_id>.fasta. Other inputs read later (pocket-position
# file, ligand mol2) are NOT checked here; if one is missing, that job raises
# inside the worker and is reported as failed.
uncompleted_jobs = []
for PDBBind_instance_dir in PDBBind_instance_dirs:
    pdb_id = PDBBind_instance_dir.split("/")[-2]
    fasta_files = glob.glob(PDBBind_instance_dir + '/*.fasta')
    if not fasta_files:
        continue
    # Chain id = last character of the first fasta file's name stem.
    chain_id = fasta_files[0].split("/")[-1].split(".")[0][-1]
    MSA_file = MSA_dir + f"/{pdb_id}" + f"{chain_id}" + ".fasta"
    if not os.path.exists(MSA_file):
        continue
    uncompleted_jobs.append(PDBBind_instance_dir)
print("uncompleted jobs:", len(uncompleted_jobs))

job_queue = Queue()
for job in tqdm(uncompleted_jobs):
    job_queue.put(job)
print("job queue size:", job_queue.qsize())
pb = ProgressBar(len(uncompleted_jobs))
actor = pb.actor

# Launch n_thread workers; each drains the shared queue until it is empty.
job_refs = [process_jobs.remote(i, job_queue, actor) for i in range(n_thread)]
pb.print_until_done()
result = ray.get(job_refs)
print("Done!")