# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config for network architecture.

In practice this config is only used by the high-level components.
"""

# Observation / action settings.
direct_obs = True
obs_type = 'direct'
action_type = 'perfect'
res = 64  # side length of the [res, res, 3] input used when not direct_obs

# Observation shape: order-invariant inputs take precedence over the
# direct-vs-image choice.
if obs_type == 'order_invariant':
  input_dim = [None, 8]
elif direct_obs:
  input_dim = [10]
else:
  input_dim = [res, res, 3]

# Action dimension: a discrete action space uses 800 entries, otherwise 40.
if action_type == 'discrete':
  ac_dim = [800]
else:
  ac_dim = [40]

# Layer sizes. Each conv entry is a 3-int tuple (presumably
# (filters, kernel_size, stride) -- confirm against the model code).
conv_layer_config = [
    (48, 8, 2),
    (128, 5, 2),
    (192, 3, 1),
    (64, 3, 1),
]
dense_layer_config = [256, 512, 1024]
encoder_n_unit = 32
vocab_size = 36
embedding_size = 8
MAX_LENGTH = 21


class Config(object):
  """Network-architecture config populated from the module-level constants.

  List-valued settings are copied per instance, so mutating one config
  (e.g. ``cfg.dense_layer_config.append(...)``) cannot change the
  module-level defaults or any other Config instance.

  Attributes:
    direct_obs: Bool; when False the module builds an image-shaped
      ``[res, res, 3]`` input_dim instead of ``[10]``.
    obs_type: Observation type string ('direct' by default; the module also
      checks for 'order_invariant').
    action_type: Action type string; 'discrete' selects ac_dim ``[800]``,
      anything else ``[40]``.
    input_dim: Observation shape, as a list.
    ac_dim: Action dimension, as a one-element list.
    conv_layer_config: List of 3-int tuples, one per conv layer (presumably
      (filters, kernel_size, stride) -- confirm against the model code).
    dense_layer_config: List with one unit count per dense layer.
    encoder_n_unit: Number of encoder units.
    vocab_size: Vocabulary size.
    embedding_size: Embedding dimension.
    max_len: Maximum sequence length, taken from MAX_LENGTH.
  """

  def __init__(self):
    self.direct_obs = direct_obs
    self.obs_type = obs_type
    self.action_type = action_type
    # Copy the lists: assigning the module-level objects directly aliases
    # them, so an in-place edit through one instance would leak into the
    # module defaults and every config created afterwards.
    self.input_dim = list(input_dim)
    self.ac_dim = list(ac_dim)
    self.conv_layer_config = list(conv_layer_config)
    self.dense_layer_config = list(dense_layer_config)
    self.encoder_n_unit = encoder_n_unit
    self.vocab_size = vocab_size
    self.embedding_size = embedding_size
    self.max_len = MAX_LENGTH

# =======================================================================
# Settings for order-invariant (per-object) inputs.
max_num_shape = 1
max_num_size = 0
max_num_texture = 0
max_num_color = 5
# Width of one object's feature vector: 2 leading entries plus the sum of the
# max_num_* counts above. The meaning of the leading 2 is not stated here.
single_obj_feature_length = 2 + sum(
    (max_num_shape, max_num_size, max_num_texture, max_num_color))
descriptor_length = 64
inner_product_length = 32


class VariableInputConfig(Config):
  """Config variant for order-invariant observations.

  Starts from every base Config value, then overrides the observation type,
  input shape and action dimension, and adds descriptor / inner-product
  sizes.
  """

  def __init__(self):
    super(VariableInputConfig, self).__init__()
    # Observation: rows of single-object feature vectors; the leading None
    # dimension is presumably the (variable) number of objects.
    self.obs_type = 'order_invariant'
    self.input_dim = [None, single_obj_feature_length]
    # NOTE(review): ac_dim is a plain int here but a one-element list in the
    # base Config; confirm consumers accept both forms before unifying.
    self.ac_dim = 8
    self.des_len = descriptor_length
    self.inner_len = inner_product_length


def get_config(order_invariant=False):
  """Create a new config object.

  Args:
    order_invariant: If truthy, build a VariableInputConfig; otherwise build
      a plain Config.

  Returns:
    A freshly constructed Config or VariableInputConfig instance.
  """
  config_cls = VariableInputConfig if order_invariant else Config
  return config_cls()
