from abc import ABC, abstractmethod
from typing import Callable

import pandas as pd

from ..events import RequestEvent
from ..state import WorldState
from .offer import Offer
from ..routing import get_route
from ..experiments import ExperimentPolicy


class PricingPolicy(ABC):
    """
    Base class for policies that turn a ride request into an Offer.

    Subclasses implement ``make_offer``. They may append dict records to
    ``self.log``; the accumulated records can be viewed via ``log_df``.
    """

    def __init__(self):
        # Each element becomes one row of the DataFrame built by log_df().
        self.log = list()

    def log_df(self):
        """
        Return the accumulated log records as a pandas DataFrame.
        """
        records = self.log
        return pd.DataFrame(records)

    @abstractmethod
    def make_offer(self, state: WorldState, event: RequestEvent) -> Offer:
        """
        Offer a price and an ETD guarantee to the customer behind ``event``.
        """
        raise NotImplementedError()


class ConstantFactorPricingPolicy(PricingPolicy):
    """
    Offers a price equal to a constant multiple of the estimated cost of
    fulfilling a request, and an ETD guarantee equal to a constant multiple
    of the estimated ETD.

    One record per offer is appended to ``self.log``.
    """

    # Values accepted for ``cost_basis`` (see ``make_offer``).
    _COST_BASES = ("solo", "greedy")

    def __init__(self, price_factor: float,
                 etd_factor: float,
                 cost_fn: Callable,
                 dispatcher=None,
                 cost_basis="solo",
                 pickup_eta=300):
        """
        :param price_factor: multiplier applied to the estimated cost to
            obtain the offered price.
        :param etd_factor: multiplier applied to the estimated ETD to obtain
            the guaranteed ETD.
        :param cost_fn: called with a route returned by ``get_route``; its
            result is used as the cost when ``cost_basis == "solo"``.
        :param dispatcher: object whose ``dispatch(state, offer, rider)``
            method is used when ``cost_basis == "greedy"`` (required then).
        :param cost_basis: ``"solo"`` or ``"greedy"``.
        :param pickup_eta: stored on the instance; not read by this class.
        """
        super().__init__()
        self.price_factor = price_factor
        self.etd_factor = etd_factor
        self.cost_fn = cost_fn
        self.dispatcher = dispatcher
        self.cost_basis = cost_basis
        self.pickup_eta = pickup_eta

    def make_offer(self, state: WorldState, event: RequestEvent) -> Offer:
        """
        Computes the ETD and cost to fulfill the request, and returns
        an offer with ETD and cost multiplied by a constant factor.

        `self.cost_basis` determines the type of (hypothetical) dispatch
        used to compute the cost of fulfilling the request:

        - "solo" dispatches the rider directly to their destination.
        - "greedy" finds the minimum cost dispatch given current supply.

        If no driver is near the pickup point (or the greedy dispatcher
        returns no dispatch), both ETD and cost are infinite.

        :raises NotImplementedError: if ``self.cost_basis`` is unsupported.
        :raises ValueError: if ``cost_basis`` is "greedy" and no dispatcher
            was configured.
        """
        # Validate before looking for drivers so a misconfigured policy
        # fails consistently, not only when a driver happens to be nearby.
        if self.cost_basis not in self._COST_BASES:
            raise NotImplementedError(
                "Invalid cost_basis %s" % (self.cost_basis,))
        if self.cost_basis == "greedy" and self.dispatcher is None:
            raise ValueError('cost_basis "greedy" requires a dispatcher')

        etd, cost = self._estimate(state, event)

        self.log.append(dict(
            ts=state.ts,
            rider_id=event.rider_id,
            # NOTE(review): this column holds the estimated cost, not the
            # cost_basis mode; the key is kept so log_df() columns are
            # unchanged for existing consumers.
            cost_basis=cost,
            etd=etd,
            price_factor=self.price_factor,
            etd_factor=self.etd_factor))

        return Offer(etd * self.etd_factor, cost * self.price_factor)

    def _estimate(self, state, event):
        """Return ``(etd, cost)`` for the request; infinite if unservable."""
        nn = state.get_nearest_drivers(event.rider.src, n=1)
        if len(nn) == 0:
            return float("Inf"), float("Inf")
        if self.cost_basis == "solo":
            return self._solo_estimate(state, event, nn[0])
        return self._greedy_estimate(state, event)

    def _solo_estimate(self, state, event, driver):
        """Route driver -> pickup -> drop-off and cost it with cost_fn."""
        route = get_route(state.ts,
                          [driver.latlng(state.ts),
                           event.rider.src,
                           event.rider.dest])
        return route.total_secs, self.cost_fn(route)

    def _greedy_estimate(self, state, event):
        """Use the dispatcher's first proposed dispatch for the rider."""
        # Infinite ETD with zero price; presumably an unconstrained offer so
        # the dispatcher considers every option — confirm against Offer.
        hypothetical_offer = Offer(float("Inf"), 0., dict())
        dispatches = self.dispatcher.dispatch(
            state, hypothetical_offer, event.rider)
        if len(dispatches) == 0:
            return float("Inf"), float("Inf")
        # Assumes the dispatcher returns its minimum-cost dispatch first.
        dispatch = dispatches[0]
        return dispatch.route.etd(event.rider.id), dispatch.insertion_cost


class PricingExperimentPolicy(ExperimentPolicy, PricingPolicy):
    """
    Pricing policy that delegates each request to the policy chosen for it
    by ``ExperimentPolicy.get_policy``.
    """

    def make_offer(self, state: WorldState, event: RequestEvent) -> Offer:
        """
        Forward the request to the policy selected for ``event`` and return
        that policy's offer unchanged.
        """
        chosen = self.get_policy(event)
        return chosen.make_offer(state, event)
