import numpy as np

# Color names bound to the integer values 0-9 stored in the grids.
(black, blue, red, green, yellow,
 grey, pink, orange, teal, maroon) = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Recolor small grey shapes according to their geometry.

    Grey pixels are visited in row-major order. Each grey pixel that is not
    yet claimed by a group is tried as the top-left cell of the following
    candidate shapes, in priority order:

    1. a 2x2 square            -> teal
    2. a vertical 3x1 line     -> red
    3. a horizontal 1x3 line   -> red

    A candidate matches only if all of its cells lie inside the grid, are
    grey in ``input_grid``, and are not already claimed by an earlier group.
    As a result, no pixel is ever assigned to two groups. Grey pixels that
    end up in no group, and all non-grey pixels, are black (0) in the result.

    Args:
        input_grid: 2-D grid of color indices. It is not modified.

    Returns:
        A new array with the same shape and dtype as ``input_grid``.
    """
    n_rows, n_cols = input_grid.shape
    output_grid = np.zeros_like(input_grid)
    # Membership is tracked separately from the painted colors, so whether a
    # cell is "claimed" never depends on which color value a group received.
    claimed = np.zeros(input_grid.shape, dtype=bool)

    # (cell offsets relative to the top-left pixel, color), checked in order.
    candidate_shapes = (
        (((0, 0), (1, 0), (0, 1), (1, 1)), teal),  # 2x2 square
        (((0, 0), (1, 0), (2, 0)), red),  # vertical 3x1 line
        (((0, 0), (0, 1), (0, 2)), red),  # horizontal 1x3 line
    )

    # np.argwhere yields coordinates in row-major order, like np.where did.
    for x, y in np.argwhere(input_grid == grey):
        if claimed[x, y]:
            continue
        for offsets, color in candidate_shapes:
            cells = [(x + dx, y + dy) for dx, dy in offsets]
            # Offsets are non-negative, so only the upper bounds need
            # checking. The `and` chain short-circuits before any
            # out-of-range index is used.
            if all(
                r < n_rows
                and c < n_cols
                and input_grid[r, c] == grey
                and not claimed[r, c]
                for r, c in cells
            ):
                for r, c in cells:
                    output_grid[r, c] = color
                    claimed[r, c] = True
                break
        # A grey pixel matching no candidate stays black (0).

    return output_grid