import numpy as np

# Color palette: each name is bound to its integer cell value (black == 0).
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Build a two-row alternating pattern from the first row of *input_grid*.

    Only the first row of the input is read; every other input row is ignored.
    The result has the same shape as the input and is laid out as follows:

    * Row 0 is a copy of the input's first row.
    * Row 1 is black except that each non-black cell of the first row writes
      its color into the cells immediately left and right of it (where those
      columns exist). Cells are processed left to right, so when two colored
      cells share a neighbor the right-hand one wins.
    * Rows 2, 4, ... repeat row 0 and rows 3, 5, ... repeat row 1.
    * A final pass then walks rows 2 and below and, for every colored
      first-row column, copies the value directly above that column into its
      left and right neighbors on the current row.

    Grids with fewer than two rows are handled explicitly: an empty grid is
    returned unchanged in shape, and a single-row grid yields a copy of that
    row.

    Args:
        input_grid: 2-D array of color values.

    Returns:
        A new 2-D array with the same shape as ``input_grid``. It is created
        with ``np.full(..., black)``, so its dtype is not taken from the input.

    Raises:
        ValueError: If ``input_grid`` is not two-dimensional (the shape cannot
            be unpacked into rows and columns).
    """
    rows, cols = input_grid.shape

    output_grid = np.full((rows, cols), black)
    if rows == 0:
        # Nothing to read a pattern from.
        return output_grid

    first_row = input_grid[0]
    output_grid[0] = first_row
    if rows == 1:
        # There is no second row to paint the neighbor pattern into.
        return output_grid

    # Columns of the first row that carry a color, in left-to-right order.
    colored_cols = [
        col for col, value in enumerate(first_row) if value != black
    ]

    # Row 1: spread each colored cell into its horizontal neighbors.
    for col in colored_cols:
        if col > 0:
            output_grid[1][col - 1] = first_row[col]
        if col < cols - 1:
            output_grid[1][col + 1] = first_row[col]

    # Repeat the two pattern rows down the rest of the grid.
    output_grid[2::2] = output_grid[0]
    output_grid[3::2] = output_grid[1]

    # Propagation pass. When no two colored first-row cells are adjacent this
    # rewrites values that are already in place; it only alters the result
    # when colored cells touch, so it is kept to preserve that behavior.
    for row in range(2, rows):
        for col in colored_cols:
            above = output_grid[row - 1][col]
            if col > 0:
                output_grid[row][col - 1] = above
            if col < cols - 1:
                output_grid[row][col + 1] = above

    return output_grid


