import numpy as np

def main(input_grid: np.ndarray) -> np.ndarray:
    """Upscale a color grid by the number of distinct non-zero colors it holds.

    Every cell of ``input_grid`` becomes an ``n x n`` square of the same
    value in the result, where ``n`` is the count of distinct non-zero
    values found anywhere in the grid. Zero (black) cells therefore stay
    zero, just enlarged.

    Args:
        input_grid: 2-D grid of color values (array or nested sequence).
            Any ``rows x cols`` shape is accepted, not only 3x3.

    Returns:
        A new integer array of shape ``(rows * n, cols * n)``. When the grid
        contains no non-zero value, ``n`` is 0 and the result is empty with
        shape ``(0, 0)``.

    Raises:
        ValueError: If ``input_grid`` is not two-dimensional.
    """
    grid = np.asarray(input_grid)
    if grid.ndim != 2:
        raise ValueError(
            f"input_grid must be 2-dimensional, got {grid.ndim} dimension(s)"
        )

    # Scale factor: number of distinct colors, ignoring black (0).
    n = int(np.count_nonzero(np.unique(grid)))

    # Repeating each row and then each column n times turns every cell into
    # an n x n block; derive the size from the grid itself rather than
    # assuming a fixed 3x3 input.
    scaled = np.repeat(np.repeat(grid, n, axis=0), n, axis=1)

    # astype always copies, so the caller's array is never aliased, and the
    # result keeps the integer dtype callers have always received.
    return scaled.astype(int)