import numpy as np
# Integer colour codes 0-9 for grid cells; `black` (0) marks empty/background cells.
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def find_rectangles(grid, *, background=None):
    """Return every uniformly coloured rectangle in ``grid``.

    A rectangle qualifies when all of its cells hold the same colour and that
    colour is not ``background``. Overlapping and nested rectangles are all
    reported.

    Args:
        grid: 2-D array (or nested sequence) of colour codes.
        background: Colour to ignore. Defaults to ``black``.

    Returns:
        A list of ``(top, left, bottom, right, colour)`` tuples with inclusive
        bounds. The list is ordered by top-left corner (row-major), then by
        bottom row, then by right column.
    """
    if background is None:
        background = black
    grid = np.asarray(grid)
    n_rows, n_cols = grid.shape
    rectangles = []
    for top in range(n_rows):
        for left in range(n_cols):
            color = grid[top, left]
            if color == background:
                continue
            # Exclusive right edge such that rows top..bottom, columns
            # left..limit-1 are all `color`. Adding rows can only shrink it,
            # so each row is scanned no further than the previous limit. This
            # replaces re-checking every candidate sub-array from scratch.
            limit = n_cols
            for bottom in range(top, n_rows):
                right_end = left
                while right_end < limit and grid[bottom, right_end] == color:
                    right_end += 1
                if right_end == left:
                    # grid[bottom, left] differs from `color`, so every
                    # rectangle reaching this row or lower fails as well.
                    break
                limit = right_end
                rectangles.extend(
                    (top, left, bottom, right, color) for right in range(left, limit)
                )
    return rectangles

def main(grid):
    """Keep only the largest single-colour rectangle and blacken every other cell.

    Args:
        grid: 2-D array (or nested sequence) of colour codes.

    Returns:
        A new array with the same shape as ``grid``, filled with ``black``
        except for the largest-area rectangle reported by ``find_rectangles``,
        which keeps its colour. When several rectangles tie for the largest
        area, the one ``find_rectangles`` lists first wins, because ``max``
        returns the first maximal element. A grid with no non-black cells
        yields an all-black array.
    """
    grid = np.asarray(grid)
    output_grid = np.full(grid.shape, black)
    rectangles = find_rectangles(grid)
    if not rectangles:
        # Nothing to keep; max() would raise ValueError on an empty list.
        return output_grid
    top, left, bottom, right, color = max(
        rectangles, key=lambda r: (r[2] - r[0] + 1) * (r[3] - r[1] + 1)
    )
    output_grid[top:bottom + 1, left:right + 1] = color
    return output_grid