import numpy as np

# Palette: colour names bound to their integer codes 0-9 (black == 0).
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Drop each column's non-black pixels to the bottom of the grid.

    For every column that contains at least one non-black pixel, the output
    column is black except for its bottom ``k`` cells, where ``k`` is the
    number of non-black pixels in that input column. All ``k`` cells are
    painted with the colour of the *topmost* non-black pixel of the input
    column; any other colours in the same column are not preserved.

    Args:
        input_grid: 2-D array of colour codes, with ``black`` as background.
            It is not modified.

    Returns:
        A new array with the same shape and dtype as ``input_grid``.
    """
    # zeros_like yields an all-black grid because black == 0.
    output_grid = np.zeros_like(input_grid)

    non_black = input_grid != black
    counts = np.count_nonzero(non_black, axis=0)  # non-black pixels per column

    for col in np.flatnonzero(counts):
        # argmax on a boolean column gives the first True, i.e. the topmost
        # non-black row. It is only evaluated for columns with at least one
        # non-black pixel, so it never sees an empty or all-False column.
        top_row = np.argmax(non_black[:, col])
        output_grid[-counts[col]:, col] = input_grid[top_row, col]

    return output_grid