import time, os, torch
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceStatus)
from vllm.logger import init_logger
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.scheduler import SchedulerOutputs, Scheduler, ScheduledSequenceGroup
from filelock import FileLock
import shutil
import socket, pickle

# Module logger created through vLLM's init_logger helper.
logger = init_logger(__name__)
from vllm_inject.utils import *

def can_append_slots_status(seq_group, func, model_name, *, host=None,
                            port=11455):
    """Ask the coordination server whether ``seq_group`` may get new slots.

    ``func(seq_group)`` is evaluated locally and the pickled tuple
    ``(model_name, status, 0)`` is sent to the server at ``host:port``; the
    unpickled reply is returned in place of the local status. A 0 status is
    meant to take precedence across models (presumably the server replies
    with the minimum of all reported statuses -- confirm against the server).

    Args:
        seq_group: Sequence group passed to ``func``.
        func: Local check, e.g. ``Scheduler._can_append_slots``.
        model_name: Identifier of this model, sent along with the status.
        host: Server host; ``None`` means ``socket.gethostname()``.
        port: Server TCP port.

    Returns:
        The server's reply, used by callers as a truthy/falsy status.
    """
    now_status = func(seq_group)
    if host is None:
        host = socket.gethostname()
    now_status_data = pickle.dumps((model_name, now_status, 0))
    # The context manager closes the socket even if connect/send/recv raise.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket:
        client_socket.connect((host, port))
        # sendall: a bare send() may transmit only part of the payload.
        client_socket.sendall(now_status_data)
        result_data = client_socket.recv(4096)
    # NOTE(review): pickle.loads executes arbitrary code for a malicious peer;
    # this is only safe while the coordination server is a trusted process.
    return pickle.loads(result_data)
def _sync_new_add_num(model_name, new_add_num, num_running, *, host=None,
                      port=11455, retry_delay=2.0):
    """Exchange this scheduler's admission count with the coordination server.

    Sends the pickled tuple ``(model_name, new_add_num, num_running)`` to the
    server at ``host:port`` (default: this machine's hostname, port 11455) and
    returns the unpickled reply, which ``_schedule`` uses as the number of
    locally selected waiting groups it may actually admit this step.

    Every attempt opens a fresh connection, so a transport failure on one
    attempt cannot poison the next one. Failed attempts are logged and
    retried after ``retry_delay`` seconds, without limit.

    Args:
        model_name: Identifier of this model, as configured on the scheduler.
        new_add_num: Number of waiting groups this scheduler selected locally.
        num_running: Number of sequence groups currently in ``self.running``.
        host: Server host; ``None`` means ``socket.gethostname()``.
        port: Server TCP port.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        The object the server pickles back (used by ``_schedule`` as an int).
    """
    if host is None:
        host = socket.gethostname()
    payload = pickle.dumps((model_name, new_add_num, num_running))
    while True:
        try:
            with socket.socket(socket.AF_INET,
                               socket.SOCK_STREAM) as client_socket:
                client_socket.connect((host, port))
                # sendall: a bare send() may transmit only part of the payload.
                client_socket.sendall(payload)
                result_data = client_socket.recv(4096)
            # An empty reply (peer closed early) raises EOFError here.
            # NOTE(review): pickle.loads trusts the peer completely; only safe
            # while the coordination server is a trusted local process.
            return pickle.loads(result_data)
        except (OSError, EOFError, pickle.UnpicklingError) as exc:
            # NOTE(review): if the request reached the server before the
            # failure, resending may count this model twice -- confirm the
            # server tolerates duplicate submissions.
            logger.warning(f"Coordinator exchange failed ({exc!r}); "
                           f"try to reconnect")
            time.sleep(retry_delay)


def _schedule(self) -> SchedulerOutputs:
    """Coordinated replacement for ``Scheduler._schedule``.

    Admits waiting prompts when nothing is swapped out. Otherwise it reserves
    new token slots for the running groups, preempting the lowest-priority
    groups when needed, and then swaps groups back in. Two decisions are
    synchronised with peer schedulers through a TCP coordination server:

    * Prefill: after the waiting groups that fit locally are selected, their
      count is sent via ``_sync_new_add_num``. Only the number returned by the
      server is allocated and moved to ``self.running``. The other selected
      groups go back to the front of ``self.waiting`` in their original order.
    * Decode: each ``_can_append_slots`` check is routed through
      ``can_append_slots_status``, which returns the server's answer instead
      of the local one.

    Returns:
        ``SchedulerOutputs`` describing a prompt run (when any group was
        admitted or ignored) or a decode run.
    """
    # Blocks that need to be swapped or copied before model execution.
    blocks_to_swap_in: Dict[int, int] = {}
    blocks_to_swap_out: Dict[int, int] = {}
    blocks_to_copy: Dict[int, List[int]] = {}

    # Fix the current time.
    now = time.time()

    # Join waiting sequences if possible.
    if not self.swapped:
        ignored_seq_groups: List[SequenceGroup] = []
        # Groups selected locally; nothing is allocated until peers agree.
        scheduled: List[ScheduledSequenceGroup] = []
        # The total number of sequences on the fly, including the
        # requests in the generation phase.
        num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                            for seq_group in self.running)
        curr_loras = set(
            seq_group.lora_int_id
            for seq_group in self.running) if self.lora_enabled else None

        # Optimization: We do not sort the waiting queue since the preempted
        # sequence groups are added to the front and the new sequence groups
        # are added to the back.
        leftover_waiting_sequences = deque()
        num_batched_tokens = 0
        while self._passed_delay(now) and self.waiting:
            seq_group = self.waiting[0]
            waiting_seqs = seq_group.get_seqs(
                status=SequenceStatus.WAITING)
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            # get_len includes output tokens if the request has been
            # preempted.
            num_prefill_tokens = waiting_seqs[0].get_len()
            if num_prefill_tokens > self.prompt_limit:
                logger.warning(
                    f"Input prompt ({num_prefill_tokens} tokens) is too "
                    f"long and exceeds limit of {self.prompt_limit}")
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                self.waiting.popleft()
                continue

            # If the sequence group cannot be allocated, stop.
            can_allocate = self.block_manager.can_allocate(seq_group)
            if can_allocate == AllocStatus.LATER:
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
                    f"Input prompt ({num_prefill_tokens} tokens) is too "
                    f"long and exceeds the capacity of block_manager")
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
                self.waiting.popleft()
                continue

            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                if (lora_int_id > 0 and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_waiting_sequences.appendleft(seq_group)
                    self.waiting.popleft()
                    continue

            # If the number of batched tokens exceeds the limit, stop.
            num_batched_tokens += num_prefill_tokens
            if (num_batched_tokens >
                    self.scheduler_config.max_num_batched_tokens):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            if (num_curr_seqs + num_new_seqs >
                    self.scheduler_config.max_num_seqs):
                break

            if lora_int_id > 0:
                curr_loras.add(lora_int_id)
            # Take the group off the queue but defer _allocate() and the move
            # to self.running until the peers agree on how many to admit.
            self.waiting.popleft()
            num_curr_seqs += num_new_seqs
            scheduled.append(
                ScheduledSequenceGroup(
                    seq_group=seq_group,
                    token_chunk_size=num_prefill_tokens))

        # Agree with the peers on how many of the selected groups to admit.
        # NOTE(review): this exchange only happens when nothing is swapped
        # out; if the server waits for every model, a peer that is in the
        # swapped path could stall it -- confirm.
        new_add_num = len(scheduled)
        new_add_num_min = _sync_new_add_num(
            self.scheduler_config.model_name, new_add_num, len(self.running))
        # Clamp so a bogus reply can neither admit groups that were not
        # selected nor produce a negative slice.
        new_add_num_min = max(0, min(new_add_num_min, new_add_num))

        # Return the groups that were not agreed on to the front of the
        # waiting queue. extendleft() inserts one item at a time, so feed it
        # in reverse to keep their original (FCFS) order.
        rejected = scheduled[new_add_num_min:]
        self.waiting.extendleft(
            scheduled_group.seq_group for scheduled_group in reversed(rejected))
        scheduled = scheduled[:new_add_num_min]
        for scheduled_group in scheduled:
            self._allocate(scheduled_group.seq_group)
            self.running.append(scheduled_group.seq_group)
        self.waiting.extendleft(leftover_waiting_sequences)
        # Report only the prompt tokens of groups that actually run.
        num_batched_tokens = sum(scheduled_group.token_chunk_size
                                 for scheduled_group in scheduled)

        if scheduled or ignored_seq_groups:
            self.prev_prompt = True
            scheduler_outputs = SchedulerOutputs(
                scheduled_seq_groups=scheduled,
                prompt_run=True,
                num_batched_tokens=num_batched_tokens,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                ignored_seq_groups=ignored_seq_groups,
                num_lookahead_slots=self._get_num_lookahead_slots(
                    is_prefill=True),
            )
            return scheduler_outputs

    # NOTE(woosuk): Preemption happens only when there is no available slot
    # to keep all the sequence groups in the RUNNING state.
    # In this case, the policy is responsible for deciding which sequence
    # groups to preempt.
    self.running = self.policy.sort_by_priority(now, self.running)
    # Reserve new token slots for the running sequence groups.
    running: Deque[SequenceGroup] = deque()
    preempted: List[SequenceGroup] = []
    model_name = self.scheduler_config.model_name
    while self.running:
        seq_group = self.running.popleft()
        # The server's shared answer decides, not only the local check.
        now_status = can_append_slots_status(
            seq_group, self._can_append_slots, model_name)
        while not now_status:
            if self.running:
                # Preempt the lowest-priority sequence groups.
                victim_seq_group = self.running.pop()
                self._preempt(victim_seq_group, blocks_to_swap_out)
                preempted.append(victim_seq_group)
                now_status = can_append_slots_status(
                    seq_group, self._can_append_slots, model_name)
            else:
                # No other sequence groups can be preempted.
                # Preempt the current sequence group.
                self._preempt(seq_group, blocks_to_swap_out)
                preempted.append(seq_group)
                break
        else:
            # Reached only when the inner loop ends without `break`, i.e. the
            # group can keep running: append new slots to it.
            self._append_slots(seq_group, blocks_to_copy)
            running.append(seq_group)
    self.running = running

    # Swap in the sequence groups in the SWAPPED state if possible.
    self.swapped = self.policy.sort_by_priority(now, self.swapped)
    if not preempted:
        num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                            for seq_group in self.running)
        curr_loras = set(
            seq_group.lora_int_id
            for seq_group in self.running) if self.lora_enabled else None

        leftover_swapped = deque()

        while self.swapped:
            seq_group = self.swapped[0]
            lora_int_id = 0
            if self.lora_enabled:
                lora_int_id = seq_group.lora_int_id
                if (lora_int_id > 0 and lora_int_id not in curr_loras
                        and len(curr_loras) >= self.lora_config.max_loras):
                    # We don't have a space for another LoRA, so
                    # we ignore this request for now.
                    leftover_swapped.appendleft(seq_group)
                    self.swapped.popleft()
                    continue

            # If the sequence group cannot be swapped in, stop.
            if not self._can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.get_max_num_running_seqs()
            if (num_curr_seqs + num_new_seqs >
                    self.scheduler_config.max_num_seqs):
                break

            if lora_int_id > 0:
                curr_loras.add(lora_int_id)
            self.swapped.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slots(seq_group, blocks_to_copy)
            num_curr_seqs += num_new_seqs
            self.running.append(seq_group)

        self.swapped.extendleft(leftover_swapped)

    # Each sequence in the generation phase only takes one token slot.
    # Therefore, the number of batched tokens is equal to the number of
    # sequences in the RUNNING state.
    num_batched_tokens = sum(
        seq_group.num_seqs(status=SequenceStatus.RUNNING)
        for seq_group in self.running)

    scheduler_outputs = SchedulerOutputs(
        scheduled_seq_groups=[
            ScheduledSequenceGroup(seq_group=running_group,
                                   token_chunk_size=1)
            for running_group in self.running
        ],
        prompt_run=False,
        num_batched_tokens=num_batched_tokens,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
        ignored_seq_groups=[],
        num_lookahead_slots=self._get_num_lookahead_slots(
            is_prefill=False),
    )
    return scheduler_outputs

# Monkey-patch vLLM's Scheduler so its `_schedule` method is the coordinated
# version defined in this module.
Scheduler._schedule = _schedule