#!/usr/bin/env python3

from jax import numpy as jnp


def _augment_feat_array(feat_array):
    """
    feat_array is a feature array indexed by example then dimension
    Return an array that has a column of 1s first
    """
    s = feat_array.shape

    if not len(s) in (1, 2):
        raise ValueError("feature array must have 1 or 2 dimensions")

    if len(s) == 1:
        feat_array = feat_array.reshape(s[0], 1)

    return jnp.hstack((jnp.ones((s[0], 1)), feat_array))
