import torch
import yaml
from collections import OrderedDict
from labeler.model.chexgpt_labeler import Model

HEAD_CONFIG_PATH = "./configs/head/evaluation_chexgpt.yaml"
PRETRAIN_PATH = "./checkpoint/model_mixed.ckpt"

# Read the head configuration. `safe_load` restricts parsing to standard YAML
# tags, which is all a plain configuration file needs. An explicit encoding
# keeps the read independent of the platform's locale.
with open(HEAD_CONFIG_PATH, encoding="utf-8") as f:
    head_cfg = yaml.safe_load(f)
# Mapping of head name -> {attribute name -> attribute info}. It is passed to
# `Model` and is iterated the same way when the predictions are printed.
label_map = head_cfg["label_map"]

# Load the model: build the labeler from the label map and switch it to
# evaluation mode. Then load the trained weights onto the CPU.
# NOTE: torch.load unpickles the checkpoint file; only load checkpoints you trust.
model = Model(label_map).eval()
ckpt = torch.load(PRETRAIN_PATH, map_location="cpu")

# The checkpoint's state-dict keys are expected to carry a `model.` prefix
# (presumably the attribute name of the network in the training wrapper).
# Strip it so the keys match `Model`'s own parameter names. The prefix is only
# removed when present: blindly slicing off six characters would corrupt any
# key without it, and the strict load would then report a misleading name.
STATE_DICT_PREFIX = "model."
new_state_dict = OrderedDict(
    (k[len(STATE_DICT_PREFIX):] if k.startswith(STATE_DICT_PREFIX) else k, v)
    for k, v in ckpt["state_dict"].items()
)
model.load_state_dict(new_state_dict, strict=True)
tokenizer = model.get_tokenizer()

# inference: example report sentences to label. Each sentence is one element
# of the batch that is tokenized and run through the model below.
inputs = [
    "Focal consolidation is identified.",
    "There is likely some atelectasis at the right base.",
    "Displaced left seventh posterior rib fracture.",
    "The lungs are hyperinflated compatible with underlying COPD.",
    "The heart is enlarged.",
]

# Alternative example inputs (uncomment to use):
# inputs = ["No pleural effusion or pneumothorax is seen.",
#           "There is no pneumothorax or pleural effusion.",
#           "The extent of the pleural effusion is reduced.",
#           "The extent of the pleural effusion remains constant.",
#           "Interval enlargement of pleural effusion."]

# Alternative example input: a single long multi-sentence report (uncomment to use):
# inputs = [
#     "Widespread subcutaneous emphysema and pneumomediastinum are new compared to the prior radiograph. The presence of this extent of subcutaneous emphysema reduces the sensitivity for detecting pneumothorax. At the time of this dictation, the patient has undergone a chest CT, which more clearly delineates the abnormal gas collections. Please see separately dictated CT under clip ___. Right pigtail pleural catheter has changed in position compared to prior radiograph, and is more fully evaluated on the subsequent CT as well. Moderate-to-large right pleural effusion has increased in size since the prior radiograph. Worsening opacity in the right mid and lower lung may reflect atelectasis or aspiration."]

# Tokenize the whole batch in one call. The options are:
#   - return_tensors="pt": return PyTorch tensors.
#   - add_special_tokens=True: add the tokenizer's special tokens.
#   - truncation=True: cut sequences down to the tokenizer's maximum length.
#   - padding=True, pad_to_multiple_of=8: pad every sentence to a common
#     length, rounded up to a multiple of 8.
tokenized = tokenizer(
    inputs,
    return_tensors="pt",
    add_special_tokens=True,
    truncation=True,
    padding=True,
    pad_to_multiple_of=8,
)

# Forward pass with autograd tracking disabled.
with torch.inference_mode():
    outputs = model(tokenized["input_ids"], tokenized["attention_mask"])

# Display outputs: one line per (head, attribute) pair. Each line shows that
# attribute's predicted label text for every input sentence. Only the
# attribute names are needed, so iterate the attribute mapping's keys directly.
for head_name, attrs in label_map.items():
    for attr_name in attrs:
        predictions = outputs[head_name][attr_name]["prediction_text"]
        print(f"{head_name} | {attr_name} | {predictions}")


# Example output for the five default inputs above:
# atelectasis | status | ['not_exist', 'exist', 'not_exist', 'not_exist', 'not_exist']
# consolidation | status | ['exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# effusion | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# fracture | status | ['not_exist', 'not_exist', 'exist', 'not_exist', 'not_exist']
# hyperinflation | status | ['not_exist', 'not_exist', 'not_exist', 'exist', 'not_exist']
# lung opacity | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# nodule | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# pleural lesion | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# pneumothorax | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# pulmonary edema | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# subcutaneous emphysema | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# subdiaphragmatic gas | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'not_exist']
# widened mediastinal silhouette | status | ['not_exist', 'not_exist', 'not_exist', 'not_exist', 'exist']
