#### Packages and project directory ####

# The CDO (Climate Data Operators) command-line tools must also be installed.
library(foreach)
library(doParallel)
library(patchwork)
library(fdasrvf)
library(tidyverse)
library(maps)
# Machine-specific locations: outputs live under project_dir, inputs under
# data_dir. The working directory is switched to the project folder.
project_dir <- '/mnt/r/wasserstein/'
data_dir <- '/mnt/r/historical/'
setwd(project_dir)

#### Slicing helper functions ####

# Compactly supported Wendland kernel, vectorised over d:
#   k(d) = (1/3) * (1 - d/r)^6 * (35 (d/r)^2 + 18 (d/r) + 3)  for d < r,
#   k(d) = 0                                                  for d >= r.
# It equals 1 at d = 0. r is the support radius, in the same units as d.
wendland_kernel <- function(d, r) {
  ratio <- d / r
  inside <- d < r
  inside * (1 / 3) * ((1 - ratio)^6) * (35 * ratio^2 + 18 * ratio + 3)
}

# Great-circle (arc-length) distance between points on a sphere, using the
# spherical law of cosines. Despite the name, this is NOT the chord length;
# the chordal variant is kept as a comment at the end of the function.
#
# x, y: matrices or data frames whose first column is latitude and second
#   column is longitude, both in radians. Rows are recycled against each
#   other, so a one-row x gives distances from one point to every row of y.
# r: sphere radius. The result is in the same units (6371 ~ Earth in km).
distChordal <- function(x, y, r = 6371) {
  cos_angle <- sin(x[, 1]) * sin(y[, 1]) +
    cos(x[, 1]) * cos(y[, 1]) * cos(x[, 2] - y[, 2])
  # Rounding can push the cosine a hair outside [-1, 1] (e.g. for identical
  # points), where acos() returns NaN; clamp to acos()'s domain.
  r * acos(pmin(pmax(cos_angle, -1), 1))
  # Chordal alternative: r * sqrt(2 - 2 * cos_angle)
}

# Build a regular latitude/longitude grid in degrees.
#
# Latitudes are n_lats evenly spaced values strictly between 90 and -90
# (both poles excluded). Longitudes are n_longs evenly spaced values starting
# at 0 (360 excluded). Returns a data frame with columns dlat and dlong,
# rounded to 3 decimals, where latitude varies fastest.
grid_lat_long <- function(n_lats = 48, n_longs = 96) {
  lat_seq <- seq(90, -90, length.out = n_lats + 2)
  long_seq <- seq(0, 360, length.out = n_longs + 1)
  interior_lats <- lat_seq[2:(n_lats + 1)]
  leading_longs <- long_seq[1:n_longs]
  grid <- expand.grid(dlat = interior_lats, dlong = leading_longs)
  round(grid, 3)
}

# Fill all-NA rows and columns of a matrix with copies of the nearest
# non-NA row/column. It is used for points outside the data's extent near
# the poles.
#
# m: matrix in which whole rows (detected via column 1) and whole columns
#   (detected via row 1, after the rows are filled) may be NA.
# n_lat_up: row count used to split NA rows into a top block
#   (row < n_lat_up / 2) and a bottom block (row > n_lat_up / 2).
#
# Assumes top NA rows are contiguous from row 1 and bottom NA rows are
# contiguous to the last row. Top rows are copied from the row just below
# the block; bottom rows from the row just above it. Each NA column is
# copied from the nearest non-NA column; ties go to the left column.
nearest_outside_extent <- function(m, n_lat_up = 320) {
  half <- n_lat_up / 2
  missing_rows <- which(is.na(m[, 1]))
  upper_rows <- missing_rows[missing_rows < half]
  lower_rows <- missing_rows[missing_rows > half]

  if (length(upper_rows) > 0) {
    donor <- max(upper_rows) + 1
    m[upper_rows, ] <- m[rep(donor, length(upper_rows)), ]
  }
  if (length(lower_rows) > 0) {
    donor <- min(lower_rows) - 1
    m[lower_rows, ] <- m[rep(donor, length(lower_rows)), ]
  }

  missing_cols <- which(is.na(m[1, ]))
  if (length(missing_cols) > 0) {
    # NA out the missing columns so which.min() only picks filled ones.
    candidate_cols <- seq_len(ncol(m))
    candidate_cols[missing_cols] <- NA
    donor_cols <- vapply(
      missing_cols,
      function(k) which.min(abs(k - candidate_cols)),
      integer(1)
    )
    m[, missing_cols] <- m[, donor_cols]
  }

  m
}

## Weight by Latitude Helper
# cos(latitude) area weights for n_lats evenly spaced latitudes strictly
# between the poles. This uses the same latitude spacing as grid_lat_long(),
# but works in radians.
w_latitude <- function(n_lats) {
  lat_rad <- seq(pi / 2, -pi / 2, length.out = n_lats + 2)
  interior <- lat_rad[2:(n_lats + 1)]
  cos(interior)
}


#### Data processing helpers ###

# Expand a date span (used when parsing CMIP model output files) into every
# day between two endpoints, following the model's calendar.
#
# s, e: inclusive start and end dates as YYYYMMDD (character or numeric).
# calendar: CF-style calendar name.
#   "noleap" / "365_day"   - Gregorian dates with every 29 February dropped;
#   "360_day"              - twelve 30-day months in every year;
#   "all_leap" / "366_day" - Gregorian month lengths, every year has 29 Feb;
#   anything else (e.g. "standard", "gregorian") - ordinary Gregorian dates.
# Returns a numeric vector of YYYYMMDD values in chronological order.
# Stops with an error if an endpoint is not a valid date in that calendar.
ymd_range <- function(s, e, calendar = 'standard') {
  s_chr <- as.character(s)
  e_chr <- as.character(e)

  # Calendars with fixed month lengths are enumerated directly, because some
  # of their dates (e.g. 30 February in 360_day) are not valid R Dates.
  month_lengths <- switch(
    as.character(calendar),
    "360_day" = rep(30L, 12L),
    "all_leap" = ,
    "366_day" = c(31L, 29L, 31L, 30L, 31L, 30L, 31L, 31L, 30L, 31L, 30L, 31L),
    NULL
  )

  if (!is.null(month_lengths)) {
    years <- as.integer(substr(s_chr, 1, 4)):as.integer(substr(e_chr, 1, 4))
    month_of_day <- rep(1:12, times = month_lengths)
    day_of_month <- sequence(month_lengths)
    all_dates <- sort(as.numeric(unlist(lapply(years, function(yr) {
      yr * 10000 + month_of_day * 100 + day_of_month
    }))))
    first <- match(as.numeric(s_chr), all_dates)
    last <- match(as.numeric(e_chr), all_dates)
    if (is.na(first) || is.na(last)) {
      stop("ymd_range: ", s_chr, " or ", e_chr,
           " is not a valid date in the '", calendar, "' calendar",
           call. = FALSE)
    }
    return(all_dates[first:last])
  }

  start <- as.Date(s_chr, format = "%Y%m%d")
  end <- as.Date(e_chr, format = "%Y%m%d")
  if (is.na(start) || is.na(end)) {
    stop("ymd_range: ", s_chr, " or ", e_chr, " is not a valid YYYYMMDD date",
         call. = FALSE)
  }
  dates <- seq(start, end, by = "1 day")
  if (calendar %in% c("noleap", "365_day")) {
    # Test the Date itself rather than digits of a number.
    dates <- dates[format(dates, "%m%d") != "0229"]
  }
  as.numeric(format(dates, "%Y%m%d"))
}


#### SCWD and GMWD helper functions ####

# Root mean square of a numeric vector: sqrt(mean(x^2)).
sqrt_mean_sq <- function(x) {
  squared <- x^2
  sqrt(mean(squared))
}

# Approximate 2-Wasserstein distance between the empirical distributions of
# x and y. It is the root mean square difference of their quantiles at the
# probability levels q (a good approximation when q is an even grid on
# (0, 1)). Relies on sqrt_mean_sq() being defined.
wasserstein_1d <- function(x, y, q) {
  quantile_gap <- quantile(x, q) - quantile(y, q)
  sqrt_mean_sq(quantile_gap)
}

# Quantiles of every slice of a 3-d array. Dimension 1 holds the samples
# and dimensions 2-3 index the slices. NAs are ignored. The result's first
# dimension runs over the probability levels q (dropped when length(q) == 1).
slice_quantiles <- function(x_sliced, q) {
  per_slice <- function(v) quantile(v, probs = q, na.rm = TRUE)
  apply(x_sliced, c(2, 3), per_slice)
}

# Spherical Convolutional Wasserstein Distance (SCWD).
#
# xq, yq: arrays of per-slice quantiles with the same shape, e.g. as returned
#   by slice_quantiles(). Dimension 1 runs over quantile levels and
#   dimensions 2-3 index the slices.
# weights: one non-negative weight per slice (length prod(dim(xq)[2:3])).
#
# Each slice's 1-D Wasserstein distance is the RMS of its quantile
# differences (via sqrt_mean_sq(), which must be defined). The SCWD is the
# square root of the weighted mean of their squares.
# Returns list(scwd = overall distance, wd_vals = per-slice distances).
scwd <- function(xq, yq, weights) {
  quantile_diffs <- xq - yq
  per_slice <- apply(quantile_diffs, c(2, 3), sqrt_mean_sq)
  overall <- sqrt(weighted.mean(per_slice^2, weights))
  list(scwd = overall, wd_vals = per_slice)
}

# Demonstrate a single slice: project two stacks of gridded fields onto one
# spatial "slice", a Wendland-kernel window around `center`.
#
# x, y: 3-d arrays indexed (sample, lat, long), presumably time steps by
#   grid. Every field is resized with nearest-neighbour sampling
#   (EBImage::resize, filter = 'none') to n_lats x n_longs. x and y are
#   resized independently, so they may hold different numbers of samples.
# center: c(longitude, latitude) of the window centre, in degrees.
# range: kernel support radius, in the units of the 6371 sphere radius given
#   to distChordal (km for the Earth).
# n_lats, n_longs: size of the resampling grid (defaults give 320 x 640).
#
# Weights are kernel value * cos(latitude), normalised to sum to 1. Stops if
# they sum to zero or are not finite (e.g. `range` smaller than the spacing
# between the centre and every grid point).
# Returns list(w = n_lats x n_longs weight matrix,
#              x_slice, y_slice = weighted average of each x / y field).
convo_slice <- function(x, y, center, range, n_lats = 320, n_longs = 640) {
  # Resize every field of a (sample, lat, long) stack, using the stack's own
  # sample count.
  resize_stack <- function(a) {
    n_a <- dim(a)[1]
    out <- array(0, dim = c(n_a, n_lats, n_longs))
    for (i in seq_len(n_a)) {
      out[i, , ] <- EBImage::resize(a[i, , ], w = n_lats, h = n_longs, filter = 'none')
    }
    out
  }
  x <- resize_stack(x)
  y <- resize_stack(y)

  # Flatten each field to a row. Column k matches row k of grid_lat_long(),
  # where latitude varies fastest, as in R's column-major array layout.
  d <- n_lats * n_longs
  dim(x) <- c(dim(x)[1], d)
  dim(y) <- c(dim(y)[1], d)

  degree_latlong <- grid_lat_long(n_lats, n_longs)
  lat_w <- cos((degree_latlong$dlat) * pi / 180)

  # distChordal expects (lat, long) in radians, hence center[2:1].
  dist_all <- distChordal((pi / 180) * rbind(center[2:1]), (pi / 180) * degree_latlong, r = 6371)
  w <- wendland_kernel(dist_all, range) * lat_w
  w_total <- sum(w)
  if (!is.finite(w_total) || w_total <= 0) {
    stop("convo_slice: slice weights are zero or not finite; check `center` and `range`",
         call. = FALSE)
  }
  w <- w / w_total

  list(w = matrix(w, n_lats, n_longs),
       x_slice = as.numeric(x %*% w),
       y_slice = as.numeric(y %*% w))
}

# Sparse-aware analogue of base::sweep(): combine the stored entries of a
# Matrix-package sparse matrix with a per-row (margin == 1) or per-column
# (any other margin) statistic, without densifying.
# Adapted from https://stackoverflow.com/questions/55407656/r-sweep-on-a-sparse-matrix
#
# Works for triplet (dgT), column-compressed (dgC) and row-compressed (dgR)
# storage. When the needed row/column index slot is absent, the index is
# rebuilt from the compressed pointers in @p.
# Only stored entries are modified, so the result matches a dense sweep()
# only when fun maps 0 to 0 (e.g. "*", or "/" by non-zero stats).
# NOTE(review): symmetric/triangular storage classes keep only one triangle;
# sweeping those by row/column stats is presumably not meaningful.
sweep_sparse <- function(x, margin, stats, fun = "*") {
  f <- match.fun(fun)
  if (margin == 1) {
    idx <- if (methods::.hasSlot(x, "i")) {
      x@i + 1L
    } else {
      rep.int(seq_len(x@Dim[1L]), diff(x@p))
    }
  } else {
    idx <- if (methods::.hasSlot(x, "j")) {
      x@j + 1L
    } else {
      rep.int(seq_len(x@Dim[2L]), diff(x@p))
    }
  }
  x@x <- f(x@x, stats[idx])
  x
}

# Plot a matrix-valued spatial field as a tile image.
#
# field: numeric matrix. It is melted with its two dimensions named "x" and
#   "y" (swapped when flip is TRUE).
# value: name of the value column and fill aesthetic; 'Temperature' gets a
#   '°C' legend title, anything else uses value itself.
# zlims: length-2 c(min, max) colour limits; any other value (default FALSE)
#   uses the range of field.
# guide: fill guide passed to the viridis 'turbo' scale.
# Returns a ggplot object (axes flipped, x reversed, void theme).
image.spatial <- function(field, value = 'Temperature', zlims = FALSE, flip = FALSE, guide = 'colorbar') {
  var_names <- if (flip == TRUE) c('y', 'x') else c('x', 'y')
  m_df <- reshape2::melt(field, var_names, value.name = value)

  fill_limits <- if (length(zlims) != 2) {
    c(min(field), max(field))
  } else {
    c(zlims[1], zlims[2])
  }
  units <- if (value == 'Temperature') '°C' else value

  # .data[[value]] replaces the deprecated aes_string() and also works when
  # value is not a syntactic R name.
  ggplot(data = m_df, aes(x = .data$x, y = .data$y, fill = .data[[value]])) +
    geom_tile() +
    coord_flip() +
    scale_x_reverse() +
    viridis::scale_fill_viridis(option = 'turbo', limits = fill_limits, guide = guide) +
    theme_void() +
    theme(aspect.ratio = 1 / 1.65) +
    labs(fill = units)
}

# Continental outline for figures, as coastline paths from rworldmap.
# A second copy is shifted by +360 degrees of longitude so outlines also
# cover the 0-360 longitude range used by the plots (assumes the source
# longitudes are in [-180, 180]; verify). The result is the global `mp`,
# read by isc() and isc2().
# Alternative source: fortify(maps::map(fill = TRUE, plot = FALSE))
mp1 <- fortify(rworldmap::coastsCoarse)
mp1$group <- as.numeric(mp1$group)

mp2 <- mp1
mp2$long <- mp2$long + 360
# Offset group ids so paths in the shifted copy never join the originals.
mp2$group <- mp2$group + max(mp2$group) + 1

mp <- rbind(mp1, mp2)

# image.spatial with continental outlines (rectangular projection).
#
# field: values to plot, one per row of coords (flattened with as.numeric).
# coords: data frame with dlong and dlat columns in degrees, e.g. from
#   grid_lat_long().
# value, zlims, guide: as in image.spatial(). Limits default to range(field).
# flip: accepted for interface compatibility with image.spatial() but has no
#   effect, because tiles are placed by coords, not by matrix layout.
# Draws the global coastline paths `mp`. xlim(10, 350) drops tiles and
# outline points outside that longitude window.
isc <- function(field, coords, value = 'Temperature', zlims = FALSE, flip = FALSE, guide = 'colorbar') {
  m_df <- data.frame(x = coords$dlong, y = coords$dlat)
  m_df[[value]] <- as.numeric(field)

  fill_limits <- if (length(zlims) != 2) {
    c(min(field), max(field))
  } else {
    c(zlims[1], zlims[2])
  }
  units <- if (value == 'Temperature') '°C' else value

  # .data pronoun replaces the deprecated aes_string().
  ggplot(data = m_df) +
    geom_tile(aes(x = .data$x, y = .data$y, fill = .data[[value]])) +
    geom_path(data = mp, mapping = aes(x = long, y = lat, group = group)) +
    xlim(10, 350) +
    viridis::scale_fill_viridis(option = 'turbo', limits = fill_limits, guide = guide) +
    theme_minimal() +
    theme(aspect.ratio = 1 / 1.65, axis.text = element_blank()) +
    labs(fill = units, x = NULL, y = NULL)
}

# image.spatial with continental outlines (Mollweide/oval projection).
#
# field: values to plot, one per row of coords (flattened with as.numeric).
# coords: data frame with dlong and dlat columns in degrees, e.g. from
#   grid_lat_long().
# value, zlims, guide: as in image.spatial(). Limits default to range(field).
# flip: accepted for interface compatibility with image.spatial() but has no
#   effect, because tiles are placed by coords.
# Tiles and the global coastline paths `mp` are both shifted east by half a
# grid cell, so tile centres sit mid-cell and the projection's x-limits trim
# exactly to whole cells.
isc2 <- function(field, coords, value = 'Temperature', zlims = FALSE, flip = FALSE, guide = 'colorbar') {
  # Mean longitude spacing; named to avoid shadowing base::diff().
  dlong_step <- mean(diff(sort(unique(coords$dlong))))
  half_step <- dlong_step / 2

  m_df <- data.frame(x = coords$dlong + half_step, y = coords$dlat)
  m_df[[value]] <- as.numeric(field)

  fill_limits <- if (length(zlims) != 2) {
    c(min(field), max(field))
  } else {
    c(zlims[1], zlims[2])
  }
  units <- if (value == 'Temperature') '°C' else value

  outline <- mp
  outline$long <- outline$long + half_step

  # .data pronoun replaces the deprecated aes_string(). The colour scale
  # matches the fill so thin tile borders do not show as seams.
  ggplot(data = m_df) +
    geom_tile(aes(x = .data$x, y = .data$y, fill = .data[[value]], color = .data[[value]]),
              linewidth = 0.1) +
    geom_path(data = outline, mapping = aes(x = long, y = lat, group = group), linewidth = 0.2) +
    xlim(0, 360) +
    viridis::scale_fill_viridis(option = 'turbo', limits = fill_limits, guide = guide) +
    viridis::scale_color_viridis(option = 'turbo', limits = fill_limits, guide = 'none') +
    theme_minimal() +
    theme(aspect.ratio = 1 / 1.65, axis.text = element_blank()) +
    labs(fill = units, x = NULL, y = NULL) +
    coord_map('moll', xlim = c(half_step, 360 - half_step))
}

