# M3D: Advancing 3D Medical Image Analysis with Multi-Modal Large Language Models
M3D is a pioneering and comprehensive series of work on multi-modal large language models for 3D medical image analysis, including:
- M3D-Data: the largest-scale open-source 3D medical dataset, consisting of 120K image-text pairs and 662K instruction-response pairs;
- M3D-LaMed: versatile multi-modal models built on the M3D-CLIP pretrained vision encoder, capable of tasks such as image-text retrieval, report generation, visual question answering, positioning, and segmentation;
- M3D-Bench: the most comprehensive automatic evaluation benchmark, covering 8 tasks.
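As a quick way to explore M3D-Data, its components are hosted on Hugging Face. The sketch below downloads a local copy with `huggingface_hub`; the `GoodBaiBai88/M3D-Cap` repo id is our assumption for the captioning subset, so substitute the component you need.

```python
# Minimal sketch: download one M3D-Data component for local inspection.
# The repo id is an assumption (captioning subset); swap in the component you need.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="GoodBaiBai88/M3D-Cap",
    repo_type="dataset",
)
print(f"Dataset files downloaded to: {local_dir}")
```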
## Notifications
- We found that the previous GoodBaiBai88/M3D-LaMed-Llama-2-7B model had a problem with the segmentation task. We have fixed it and will re-release the updated model in the next few days.
## Quickstart
The model can be used easily via Hugging Face Transformers:

```python
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import simple_slice_viewer as ssv  # pip install simple-slice-viewer
import SimpleITK as sikt  # pip install SimpleITK
device = torch.device('cuda')  # or 'cpu'
dtype = torch.bfloat16  # or float16, float32
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Llama-2-7B'
proj_out_num = 256  # number of image tokens after the vision projector
# Prepare your 3D medical image:
# 1. The image shape must be processed to 1*32*256*256 (C*D*H*W), e.g. by resizing.
# 2. The intensities must be normalized to [0, 1], e.g. with min-max normalization.
# 3. The image must be saved in .npy format.
# 4. Although we did not train on 2D images, a 2D image can in theory be interpolated
#    to the shape 1*32*256*256 for input.
# See the preprocessing sketch after this code block.
image_path = "./Data/data/examples/example_01.npy"
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    device_map='auto',
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)
model = model.to(device=device)
# question = "Can you provide a caption consists of findings for this medical image?"
question = "What is liver in this image? Please output the segmentation mask."
# question = "What is liver in this image? Please output the box."
# Prepend one <im_patch> placeholder per projected image token
image_tokens = "<im_patch>" * proj_out_num
input_txt = image_tokens + question
input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
# Load the preprocessed volume (1*32*256*256) and add a batch dimension
image_np = np.load(image_path)
image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
# For text-only outputs (e.g. captioning), generate without segmentation:
# generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
generation, seg_logit = model.generate(image_pt, input_id, seg_enable=True, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
seg_mask = (torch.sigmoid(seg_logit) > 0.5) * 1.0  # binarize logits at a 0.5 threshold
print('question', question)
print('generated_texts', generated_texts[0])
# Visualize the input volume and the predicted mask with simple_slice_viewer
image = sikt.GetImageFromArray(image_np)
ssv.display(image)
seg = sikt.GetImageFromArray(seg_mask.cpu().numpy()[0])
ssv.display(seg)
```
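For reference, here is a minimal preprocessing sketch that produces the expected 1*32*256*256, [0, 1]-normalized .npy input. It assumes a scan readable by SimpleITK (`your_scan.nii.gz` is a placeholder path) and uses plain min-max normalization with trilinear resizing; the exact preprocessing pipeline used during training may differ.

```python
# Minimal preprocessing sketch (assumes a NIfTI scan; the path is a placeholder).
import numpy as np
import torch
import torch.nn.functional as F
import SimpleITK as sikt

itk_image = sikt.ReadImage("./your_scan.nii.gz")
volume = sikt.GetArrayFromImage(itk_image).astype(np.float32)  # D*H*W

# Min-max normalize intensities to [0, 1]
volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)

# Trilinear resize to the expected 1*32*256*256 shape
tensor = torch.from_numpy(volume)[None, None]  # 1*1*D*H*W
tensor = F.interpolate(tensor, size=(32, 256, 256), mode="trilinear", align_corners=False)
image = tensor.squeeze(0).numpy()  # 1*32*256*256

np.save("./example_01.npy", image)
```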