import numpy as np

# Colour names mapped to the integer values stored in grid cells (0-9).
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def main(input_grid: np.ndarray) -> np.ndarray:
    """Combine the 3x3 neighbourhoods of every grey pixel into one 3x3 grid.

    For each grey cell in ``input_grid`` the 3x3 window centred on it is
    extracted; all windows are summed element-wise and the centre of the
    result is then set to ``grey``.

    Cells outside the input grid are treated as ``black``, so a grey pixel on
    the border still contributes a full 3x3 window (previously the slice used
    a wrapping negative index or was truncated, and the sum failed).
    Windows are *added*, not overlaid: non-black cells that fall on the same
    offset in several windows sum together.

    Args:
        input_grid: 2-D grid of colour values (anything ``np.asarray``
            accepts).

    Returns:
        A 3x3 array with ``grey`` in the centre. If the input has no grey
        pixels, every other cell is 0 (``black``).
    """
    grid = np.asarray(input_grid)

    # One-cell black border: a window around an edge pixel never needs a
    # negative (wrapping) index and is never truncated below 3x3.
    padded = np.pad(grid, 1, mode="constant", constant_values=black)

    # Coordinates refer to the unpadded grid. In `padded` the same pixel sits
    # at (row + 1, col + 1), so its centred 3x3 window starts at (row, col).
    windows = [
        padded[row:row + 3, col:col + 3]
        for row, col in np.argwhere(grid == grey)
    ]

    if windows:
        stacked = np.stack(windows)
    else:
        # No grey pixels: summing an empty stack yields an all-zero 3x3
        # (instead of the scalar 0.0 that np.sum([]) would return).
        stacked = np.zeros((0, 3, 3), dtype=grid.dtype)

    # np.sum keeps its usual accumulator dtype promotion for small ints.
    output_grid = np.sum(stacked, axis=0)
    output_grid[1, 1] = grey

    return output_grid