import numpy as np

# Color palette: each name is bound to the integer code used for that
# color inside a grid (0 through 9).
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def main(input_grid: np.ndarray) -> np.ndarray:
    """Extend every non-black cell downward to the bottom of its column.

    Cells that are non-black on entry are visited in row-major order, and
    the color currently stored in each one is written into every cell
    beneath it in the same column. Because upper rows are visited first,
    the top-most non-black cell of a column ends up deciding the color of
    everything below it.

    Args:
        input_grid: 2-D array of color codes in which ``black`` marks an
            empty cell. The array is modified in place.

    Returns:
        The same array object that was passed in, after filling.
    """
    # Snapshot the coordinates before any writes so that cells painted by
    # this function are not themselves treated as new sources.
    rows, cols = np.where(input_grid != black)

    for row, col in zip(rows, cols):
        # Slice to the end of the column instead of a fixed row count, so
        # grids of any height are handled; the slice is simply empty when
        # the source cell is already on the last row.
        input_grid[row + 1:, col] = input_grid[row, col]

    return input_grid
    