# This script solves the large-scale QP problem for matching the user-level causal exposure and
# obtains the optimal p_{ij}.
# For details please see: TODO: Add link to the main wiki / doc.

# The inputs
# dataPath : Path to the full data containing all the information
# groupSize : The max size of consumer set for each run of the QP.
# maxIter : The total number of iterations for convergence


getPq <- function(fullData) {
  # Builds the objective pieces (P, q) of the optimization problem
  #   Min 0.5x'Px + x'q s.t. l <= Ax <= u
  # for the rows contained in fullData.
  #
  # Args:
  #   fullData: Table holding (at least) the columns producerId, consumerId, Aj, q and
  #     htildeFinal. It is sorted with data.table's `[order(...)]` syntax, so it is expected
  #     to be a data.table.
  #
  # Returns:
  #   A list with two elements:
  #     P: block-diagonal matrix with one block 2 * q^2 * (Aj Aj') per producer.
  #     q: the per-producer vectors -2 * q * htildeFinal * Aj, concatenated in producer order.

  # Fix the row order as (producerId, consumerId); getAlu sorts the same way so that the
  # variables of both outputs line up.
  sortedData <- fullData[order(producerId, consumerId)]

  # One row per producer: its Aj vector plus the (constant within producer) q and htildeFinal.
  perProducer <- summarize(
    group_by(sortedData, producerId),
    Aj = list(Aj), htildeFinal = first(htildeFinal), q = first(q)
  )

  quadScale <- 2 * perProducer$q ^ 2
  linScale <- - 2 * perProducer$q * perProducer$htildeFinal

  # Outer product Aj Aj' for every producer, then scale each block / vector by its coefficient.
  outerProducts <- lapply(perProducer$Aj, function(a) tcrossprod(matrix(a)))
  quadBlocks <- Map(function(block, scale) block * scale, outerProducts, quadScale)
  linPieces <- Map(function(a, scale) a * scale, perProducer$Aj, linScale)

  return(list(P = bdiag(quadBlocks), q = unlist(linPieces)))
}

getAlu <- function(fullData) {
  # Builds the constraint matrix A and the bound vectors l, u of the optimization problem
  #   Min 0.5x'Px + x'q s.t. l <= Ax <= u
  #
  # Args:
  #   fullData: Table holding (at least) the columns producerId, consumerId, tempIndex,
  #     lowerBound, upperBound, lj and uj. It is sorted with data.table's `[order(...)]`
  #     syntax, so it is expected to be a data.table. tempIndex is assumed to be the row
  #     position after sorting by (producerId, consumerId), as assigned by getHTildeFinal.
  #
  # Returns:
  #   A list with three elements:
  #     A: an identity block (one row per variable) stacked on top of a 0/1 block with one
  #        row per consumer that selects the variables (tempIndex) belonging to that consumer.
  #     l: c(lowerBound of every row, lj of every consumer).
  #     u: c(upperBound of every row, uj of every consumer).

  # Forcing an ordering to keep sync with getPq.
  fullData <- fullData[order(producerId, consumerId)]
  neighborData <- fullData %>%
    group_by(consumerId) %>%
    summarize(colIds = list(tempIndex), uj = first(uj), lj = first(lj))

  # Consumer k contributes one (row = k, col = tempIndex) entry per variable it owns.
  colIds <- unlist(neighborData$colIds)
  rowIds <- rep(seq_len(nrow(neighborData)), times = lengths(neighborData$colIds))
  values <- rep(1, length(rowIds))

  secondMat <- sparseMatrix(i = rowIds, j = colIds, x = values)
  firstMat <- Diagonal(n = ncol(secondMat))
  # rbind() dispatches on Matrix objects; the old rBind() helper is deprecated in Matrix.
  Afull <- rbind(firstMat, secondMat)

  lvec <- c(fullData$lowerBound, neighborData$lj)
  uvec <- c(fullData$upperBound, neighborData$uj)

  return(list(A = Afull, l = lvec, u = uvec))
}

getHTildeFinal <- function(finalData, jSet) {
  # Extracts the rows of the chosen consumers and attaches the htilde adjusted for that chunk.
  #
  # Args:
  #   finalData: data.table containing the full information; the columns used here are
  #     producerId, consumerId, htilde, Aj, pijCurrent and q.
  #   jSet: Set of consumer ids that make up the current chunk.
  #
  # Returns:
  #   The rows of finalData whose consumerId is in jSet, sorted by (producerId, consumerId),
  #   with two extra columns:
  #     htildeFinal: per producer, first(htilde) + sum(Aj * pijCurrent * q) over the selected rows.
  #     tempIndex:   position of the row in the sorted result.

  inChunk <- finalData$consumerId %in% jSet
  # Row subsetting produces a new table, so the := below does not touch finalData.
  chunkRows <- finalData[inChunk, ]
  chunkRows[, htildeFinal := first(htilde) + sum(Aj * pijCurrent * q), by = producerId]

  sortedRows <- chunkRows[order(producerId, consumerId)]
  sortedRows$tempIndex <- seq.int(nrow(sortedRows))

  return(sortedRows)
}

getObjValue <- function(data, isFull = TRUE) {
  # Computes the objective value as the mean over producers of the squared residual.
  #
  # Args:
  #   data: Table containing the columns producerId, Aj, pijCurrent, q and either h
  #     (isFull = TRUE) or htildeFinal (isFull = FALSE).
  #   isFull: A boolean indicator that differentiates if we are using the full data or the
  #     prodSet data.
  #
  # Returns:
  #   mean_{producerIds} (h_i - \sum_{j in Ci} q_i p_{ij} Aj)^2            if isFull = TRUE
  #   mean_{producerIds} (htildeFinal_i - \sum_{j in Ci} q_i p_{ij} Aj)^2  if isFull = FALSE

  byProducer <- group_by(data, producerId)

  # The two branches differ only in the baseline column (h vs. htildeFinal).
  if (isFull) {
    producerGap <- summarize(byProducer, total = first(h) - sum(Aj * pijCurrent * q))
  } else {
    producerGap <- summarize(byProducer, total = first(htildeFinal) - sum(Aj * pijCurrent * q))
  }

  return(mean(producerGap$total ^ 2))
}

QPSolver <- function(finalData, groupSize = 100, maxIter = 10, convergenceLimit = 1e-5) {
  # Solves the large QP by repeatedly solving smaller QPs over random chunks of consumers.
  #
  # In every outer iteration the consumer ids are shuffled and split into chunks of at most
  # groupSize consumers. For each chunk the sub-problem is assembled (getHTildeFinal, getPq,
  # getAlu), solved with OSQP, and the solution is written back into pijCurrent, after which
  # htilde is recomputed for every producer.
  #
  # Args:
  #   finalData: Data frame with the columns producer, consumer, exposureDiff,
  #     baseLevelExposure, q, Aj, lj, uj, lowerBound and upperBound.
  #   groupSize: The max number of consumers in each run of the QP.
  #   maxIter: The maximum number of outer iterations.
  #   convergenceLimit: The loop stops as soon as either the relative change of pij between
  #     two outer iterations or the objective value is no larger than this limit.
  #
  # Returns:
  #   A data.table with the columns producerId, consumerId, pijBase and pijCurrent.

  finalData <- finalData %>%
    select(producerId = producer, consumerId = consumer, h = exposureDiff,
           pijBase = baseLevelExposure, q = q, Aj = Aj, lj = lj, uj = uj,
           lowerBound = lowerBound, upperBound = upperBound) %>%
    mutate(pijOld = pijBase, pijCurrent = pijBase) %>%
    group_by(producerId) %>%
    mutate(htilde = first(h) - sum(Aj * pijCurrent * q)) %>%
    ungroup()
  # Row id in the full table; used to write chunk solutions back to the right rows.
  finalData$index <- seq_len(nrow(finalData))
  finalData <- as.data.table(finalData)

  # Euclidean norm of a plain numeric vector.
  l2Norm <- function(v) sqrt(sum(v ^ 2))

  # The set of consumers never changes inside the loop.
  consumerIds <- unique(finalData$consumerId)

  error <- 1
  objVal <- 1
  count <- 0

  ##################### Starting the Execution ###############################

  while ((count < maxIter) && (error > convergenceLimit) && (objVal > convergenceLimit)) {

    # Setting Consumer Chunks.
    # Shuffle by index: sample(x) on a single numeric id n would return a permutation of 1:n.
    consumerSet <- consumerIds[sample.int(length(consumerIds))]
    consumerSetChunks <- split(consumerSet, ceiling(seq_along(consumerSet) / groupSize))
    chunkLen <- length(consumerSetChunks)

    setup.time <- 0
    qp.time <- 0
    update.time <- 0

    for (ct in seq_along(consumerSetChunks)) {

      jSet <- consumerSetChunks[[ct]]

      # Assemble the sub-problem for this chunk.
      startTime <- proc.time()
      prodSet <- getHTildeFinal(finalData, jSet)
      Pq <- getPq(prodSet)
      Alu <- getAlu(prodSet)
      endTime <- proc.time() - startTime
      setup.time <- setup.time + endTime[3]

      # Solve it.
      startTime <- proc.time()
      results <- solve_osqp(Pq$P, Pq$q, Alu$A, Alu$l, Alu$u,
                            osqpSettings(eps_abs = 1e-6, eps_rel = 1e-6, verbose = FALSE))
      endTime <- proc.time() - startTime
      qp.time <- qp.time + endTime[3]

      # Write the solution back (prodSet rows are in the solver's variable order) and refresh
      # htilde for all producers.
      startTime <- proc.time()
      finalData$pijCurrent[prodSet$index] <- results$x
      finalData <- finalData[, htilde := first(h) - sum(Aj * pijCurrent * q), by = producerId]
      endTime <- proc.time() - startTime
      update.time <- update.time + endTime[3]
    }

    # Relative change of pij over this outer iteration; fall back to the absolute change when
    # the previous iterate has zero norm so the loop condition never sees NaN.
    changeNorm <- l2Norm(finalData$pijOld - finalData$pijCurrent)
    baseNorm <- l2Norm(finalData$pijOld)
    error <- if (isTRUE(baseNorm > 0)) changeNorm / baseNorm else changeNorm
    finalData$pijOld <- finalData$pijCurrent

    objVal <- getObjValue(finalData)
    # Note: the value printed after "ChunkSize:" is the number of chunks in this iteration.
    print(paste("Iteration :", count, ", ChunkSize:", chunkLen))
    print(paste("Objective Value after full iteration: ", objVal))
    print(paste("Total Setup Time:", setup.time, ", Total QP time:", qp.time, ", Total Update Time:", update.time))

    count <- count + 1
  }

  output <- subset(finalData, select = c("producerId", "consumerId", "pijBase", "pijCurrent"))
  return(output)
}

