Edit model card

RaDialog

๐Ÿ“ Paper โ€ข ๐Ÿ–ฅ๏ธ Github โ€ข ๐Ÿ—‚๏ธDataset โ€ข ๐ŸŒ๏ธProject Page

Get Started

Clone the repository and change into it:

git clone https://huggingface.co/ChantalPellegrini/RaDialog-interactive-radiology-report-generation
cd RaDialog-interactive-radiology-report-generation

Install the requirements:

conda create -n llava_hf python=3.10
conda activate llava_hf
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install -r requirements.txt

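Run RaDialog inference (run the script from the root of the cloned repository so that the `LLAVA_Biovil` and `utils` imports resolve):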

from pathlib import Path

import io

import requests
import torch
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download

from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
from LLAVA_Biovil.llava.model.builder import load_pretrained_model
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1

from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor


def load_model_from_huggingface(repo_id, *, revision="main", model_base='liuhaotian/llava-v1.5-7b',
                                model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000",
                                load_8bit=False, load_4bit=False):
    """Download a RaDialog checkpoint from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id that holds the RaDialog weights.
        revision: Branch, tag or commit hash to download. Pin a commit hash
            to make results reproducible; defaults to ``"main"``.
        model_base: Base model id forwarded to ``load_pretrained_model``.
        model_name: Model name forwarded to ``load_pretrained_model``.
            NOTE(review): the loader presumably uses this name to choose its
            loading path (it contains "lora") — confirm before changing it.
        load_8bit: Forwarded unchanged to ``load_pretrained_model``.
        load_4bit: Forwarded unchanged to ``load_pretrained_model``.

    Returns:
        The ``(tokenizer, model, image_processor, context_len)`` tuple
        produced by ``load_pretrained_model``.
    """
    # snapshot_download returns the local snapshot directory as a string.
    model_path = Path(snapshot_download(repo_id=repo_id, revision=revision))

    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path,
        model_base=model_base,
        model_name=model_name,
        load_8bit=load_8bit,
        load_4bit=load_4bit,
    )
    return tokenizer, model, image_processor, context_len



def _load_xray_image(url, timeout=30):
    """Download an X-ray image and return it as an 8-bit grayscale PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of handing an HTML error page to PIL.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image = remap_to_uint8(np.array(image))
    return Image.fromarray(image).convert("L")


def _predict_findings(image, cp_model, cp_class_names, cp_transforms, threshold=0.5):
    """Run the CheXpert classifier and return its positive findings.

    Returns:
        The names of all classes whose sigmoid probability exceeds
        ``threshold``, joined with ", " and lower-cased.
    """
    cp_image = cp_transforms(image)
    # No gradients are needed for inference; this also saves memory.
    with torch.inference_mode():
        # NOTE(review): assumes the CheXpert predictor lives on CUDA in fp16,
        # as the original script did — confirm before running on CPU.
        logits = cp_model(cp_image[None].half().cuda())
        preds = torch.sigmoid(logits) > threshold
    pred = preds[0].cpu().numpy()
    findings = cp_class_names[pred].tolist()
    return ', '.join(findings).lower().strip()


def _generate_reply(model, tokenizer, conv, image_tensor, max_new_tokens=300):
    """Greedily generate the model's answer to the conversation's current prompt.

    The prompt, its token ids and the stopping criterion are all rebuilt from
    ``conv`` on every call, so they always describe the same turn.
    """
    text_input = conv.get_prompt()
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # The criterion is constructed from this prompt's input_ids, so a fresh one
    # is built per prompt instead of reusing one from an earlier turn.
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=max_new_tokens,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens, dropping the prompt.
    return tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")


if __name__ == '__main__':
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"
    image = _load_xray_image(sample_img_path)

    # NOTE(review): the clone instructions use the "ChantalPellegrini/..." repo
    # id while this uses "Chantal/..." — confirm which id is canonical.
    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
    cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

    model.config.tokenizer_padding_side = "left"

    findings = _predict_findings(image, cp_model, cp_class_names, cp_transforms)

    # Preprocess the image for the vision encoder (resize 512, center crop 448).
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    image_tensor = vis_transforms_biovil(image).unsqueeze(0).to(model.device, dtype=torch.bfloat16)

    conv = conv_vicuna_v1.copy()

    # Turn 1: generate a report conditioned on the image and predicted findings.
    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
    print("USER: ", REPORT_GEN_PROMPT)
    conv.append_message("USER", REPORT_GEN_PROMPT)
    conv.append_message("ASSISTANT", None)
    pred = _generate_reply(model, tokenizer, conv, image_tensor)
    print("ASSISTANT: ", pred)

    # Replace the empty assistant slot with the generated report so the next
    # turn sees it as conversation history.
    conv.messages.pop()
    conv.append_message("ASSISTANT", pred)

    # Turn 2: ask for a patient-friendly version of the report.
    EASY_LANGUAGE_PROMPT = "Translate this report to easy language for a patient to understand."
    conv.append_message("USER", EASY_LANGUAGE_PROMPT)
    conv.append_message("ASSISTANT", None)
    print("USER: ", EASY_LANGUAGE_PROMPT)
    pred = _generate_reply(model, tokenizer, conv, image_tensor)
    print("ASSISTANT: ", pred)

โœ๏ธ Citation

@article{pellegrini2023radialog,
  title={RaDialog: A Large Vision-Language Model for Radiology Report Generation and Conversational Assistance},
  author={Pellegrini, Chantal and {\"O}zsoy, Ege and Busam, Benjamin and Navab, Nassir and Keicher, Matthias},
  journal={arXiv preprint arXiv:2311.18681},
  year={2023}
}
Downloads last month
94
Inference Examples
Inference API (serverless) has been turned off for this model.