import numpy as np

# Colour names bound to the integer codes 0-9 that appear as cell values in
# the grids handled below (e.g. `red` == 2, `black` == 0).
(
    black,
    blue,
    red,
    green,
    yellow,
    grey,
    pink,
    orange,
    teal,
    maroon,
) = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Fill the interior of the red frame with an upscaled copy of the colour pattern.

    Steps:

    1. The bounding box of all ``red`` cells is cropped out of ``input_grid``
       (the "frame", including its border).
    2. The bounding box of all cells that are neither ``red`` nor ``black`` is
       taken as the colour pattern. This search covers the whole grid, so the
       function presumably expects the pattern to lie outside the red frame.
    3. The frame's interior (the crop minus its outermost one-cell ring) is
       split into ``pattern.shape`` equal blocks using floor division. Leftover
       rows/columns at the bottom/right edge are left untouched.
    4. Every non-red cell of block ``(i, j)`` is painted with ``pattern[i, j]``.
       Red cells inside the interior are preserved.

    Args:
        input_grid: 2-D integer colour grid.

    Returns:
        A new array containing the painted frame crop. ``input_grid`` is not
        modified.

    Raises:
        ValueError: if the grid has no red cells, or no cells that are neither
            red nor black.
    """
    red_rows, red_cols = np.nonzero(input_grid == red)
    if red_rows.size == 0:
        raise ValueError("input grid contains no red boundary cells")
    top, bottom = red_rows.min(), red_rows.max()
    left, right = red_cols.min(), red_cols.max()

    # Basic slicing returns a view; copy so painting below does not write
    # through into the caller's grid.
    framed = input_grid[top:bottom + 1, left:right + 1].copy()

    pat_rows, pat_cols = np.nonzero((input_grid != red) & (input_grid != black))
    if pat_rows.size == 0:
        raise ValueError("input grid contains no non-red, non-black pattern cells")
    pattern = input_grid[pat_rows.min():pat_rows.max() + 1,
                         pat_cols.min():pat_cols.max() + 1]

    # View into `framed`, skipping the one-cell border ring; writes to it land in `framed`.
    interior = framed[1:-1, 1:-1]
    n_rows, n_cols = pattern.shape
    # Shape-based sizes stay >= 0 even when the frame is only 1-2 cells thick.
    block_h = interior.shape[0] // n_rows
    block_w = interior.shape[1] // n_cols

    for i in range(n_rows):
        for j in range(n_cols):
            cell = interior[i * block_h:(i + 1) * block_h,
                            j * block_w:(j + 1) * block_w]
            cell[cell != red] = pattern[i, j]

    return framed
    