# Copyright (c) 2017-present, XXX, Inc.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Optional CUDA support. When the (legacy) FindCUDA module locates a toolkit:
#   - HAVE_CUDA is defined for targets in this directory. add_definitions is
#     directory-scoped, so it also applies to targets created later here.
#   - USE_CUDA and USE_NCCL are set as plain (non-cache) variables in this
#     directory scope.
# NOTE(review): FindCUDA is deprecated since CMake 3.10; consider migrating to
# find_package(CUDAToolkit) (3.17+) once the minimum CMake version allows it.
find_package(CUDA)
if(CUDA_FOUND)
  add_definitions(-DHAVE_CUDA)
  set(USE_CUDA ON)
  set(USE_NCCL ON)
endif()

# MSVC-only toolchain tweaks, applied before the `common` target is created:
#   - header search paths for glog (including its Windows port headers) and gflags;
#   - _CRT_SECURE_NO_WARNINGS to silence the CRT "unsafe function" deprecations;
#   - a global C++ warning level of /W1;
#   - ${PROJECT_SOURCE_DIR}/deps on the linker search path (link_directories only
#     affects targets created after this call).
if(MSVC)
  include_directories(
    "${GLOG_ROOT_DIR}/include"
    "${GLOG_ROOT_DIR}/src/windows"
    "${GFLAGS_INCLUDE_DIR}"
  )
  add_definitions(-D_CRT_SECURE_NO_WARNINGS)

  # Rewrite an existing /W0../W4 switch to /W1 in place. If none is present,
  # append /W1 so exactly one warning-level switch ends up on the command line.
  if(CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
    string(REGEX REPLACE "/W[0-4]" "/W1" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
  else()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W1")
  endif()

  link_directories("${PROJECT_SOURCE_DIR}/deps")
endif()

# Gather the library sources at configure time. SRCS holds every .cpp directly
# in this directory, followed by every .cpp under autograd/ (also kept in
# AGSRCS). New files are only picked up when CMake is re-run.
file(GLOB SRCS "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
file(GLOB AGSRCS "${CMAKE_CURRENT_SOURCE_DIR}/autograd/*.cpp")
list(APPEND SRCS ${AGSRCS})

# The `common` library. The library type follows BUILD_SHARED_LIBS.
# It is compiled as position-independent code (presumably so it can be linked
# into shared libraries). CHERRYPI_WARNINGS is expected to be defined by a
# parent CMakeLists; it applies only to this target's own compilation.
add_library(common "${SRCS}")
target_compile_options(common PRIVATE "${CHERRYPI_WARNINGS}")
set_target_properties(common PROPERTIES POSITION_INDEPENDENT_CODE ON)

# When CUDA was found, publish the toolkit headers to `common` and its
# consumers. They are marked SYSTEM, which on supporting compilers suppresses
# warnings originating in those headers.
if(CUDA_FOUND)
  target_include_directories(common SYSTEM PUBLIC ${CUDA_TOOLKIT_INCLUDE})
endif()

# Link dependencies of `common`, in order: Torch, glog, fmt, then the zstd
# library (if ZSTD_LIBRARY is set).
# NOTE(review): this uses the keyword-less signature, which is transitive to
# consumers. CMake forbids mixing it with PUBLIC/PRIVATE on the same target,
# so moving to `PUBLIC` must be done for every target_link_libraries(common ...)
# call in the project at once.
target_link_libraries(common Torch glog fmt ${ZSTD_LIBRARY})

# Bundled third-party header roots. They are PUBLIC so targets linking
# `common` see them too.
target_include_directories(common
  PUBLIC
    "${PROJECT_SOURCE_DIR}/3rdparty/include"
    "${PROJECT_SOURCE_DIR}/3rdparty/range-v3/include"
    "${PROJECT_SOURCE_DIR}/3rdparty/cereal/include"
    "${PROJECT_SOURCE_DIR}/3rdparty/fmt/include"
)

# MSVC only: extend the directory-wide include path with this directory, the
# project root, bwapilib, the bundled 3rdparty headers and the tc/include
# directory under the build tree.
# include_directories also updates targets already defined in this
# CMakeLists, so these paths reach `common` even though it was created above.
# NOTE(review): 3rdparty/include is also a PUBLIC include of `common`. The
# duplicate here still matters for other targets in this directory.
if(MSVC)
  include_directories(
    "${CMAKE_CURRENT_SOURCE_DIR}"
    "${PROJECT_SOURCE_DIR}"
    "${PROJECT_SOURCE_DIR}/3rdparty/bwapilib"
    "${PROJECT_SOURCE_DIR}/3rdparty/include"
    "${CMAKE_BINARY_DIR}/tc/include"
  )
endif()
