import numpy as np
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> list:
    """Return the quadrant of ``input_grid`` whose non-black cell count stands out.

    The grid is split into four quadrants around a divider row and column at
    index 2.  The slicing is fixed, so it yields four 2x2 quadrants only for a
    5x5 grid; other sizes produce differently shaped pieces.  For each
    quadrant, its number of non-black cells is compared with the counts of the
    other three.  The first quadrant whose count equals neither the minimum nor
    the maximum of the others is selected.  Quadrants are checked in the order
    top-left, top-right, bottom-left, bottom-right.  When three quadrants share
    a count and one differs, this picks the odd one out.

    Args:
        input_grid: 2-D array of colour values (see the constants above).

    Returns:
        The selected quadrant as a nested Python list (via ``ndarray.tolist()``).

    Raises:
        ValueError: if no quadrant satisfies the criterion, e.g. when all four
            quadrants have the same non-black count.
    """
    # The four quadrants, skipping the divider row/column at index 2.
    quadrants = [
        input_grid[:2, :2],  # top-left
        input_grid[:2, 3:],  # top-right
        input_grid[3:, :2],  # bottom-left
        input_grid[3:, 3:],  # bottom-right
    ]
    # Count non-black cells once per quadrant instead of re-counting the
    # other quadrants on every iteration.
    counts = [int(np.count_nonzero(quadrant != black)) for quadrant in quadrants]

    for i, (quadrant, count) in enumerate(zip(quadrants, counts)):
        others = counts[:i] + counts[i + 1 :]
        if count != min(others) and count != max(others):
            return quadrant.tolist()

    # No quadrant stands out; fail with a descriptive error instead of
    # falling through to an unbound variable.
    raise ValueError(
        f"no quadrant has a distinctive non-black count (counts={counts})"
    )
