---
license: apache-2.0
---

```python
import os
from io import BytesIO

# The Hugging Face libraries resolve their cache locations when they are
# first imported, so these variables must be set *before* `transformers`
# is imported; otherwise they are silently ignored.
os.environ['TRANSFORMERS_CACHE'] = '/scratch1/nic261/hf_cache'
os.environ['HUGGINGFACE_HUB_CACHE'] = '/scratch1/nic261/hf_cache'

import requests
import torch
import transformers
from PIL import Image
from torchvision import transforms

# Load the checkpoint. `trust_remote_code=True` executes the modelling code
# shipped in the checkpoint repository, so only use it with a trusted source.
ckpt_name = 'aehrc/mimic-cxr-report-gen-single'

encoder_decoder = transformers.AutoModel.from_pretrained(ckpt_name, trust_remote_code=True)
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(ckpt_name)
image_processor = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name)

# Preprocessing built from the checkpoint's image processor settings:
# resize to the configured shortest edge, centre-crop to a square of that
# size, convert to a tensor, and normalise with the configured mean/std.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=image_processor.size['shortest_edge']),
        transforms.CenterCrop(
            size=[
                image_processor.size['shortest_edge'],
                image_processor.size['shortest_edge'],
            ]
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        ),
    ]
)


def load_image(url):
    """Download the image at `url` and return it as a preprocessed tensor.

    The image is converted to RGB and passed through `test_transforms`.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.Timeout: if the server does not respond within 30 seconds.
    """
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert('RGB')
    return test_transforms(image)


image_a = load_image(
    'https://www.stritch.luc.edu/lumen/meded/radio/curriculum/IPM/PCM/86a_labelled.jpg'
)
image_b = load_image(
    'https://prod-images-static.radiopaedia.org/images/566180/d527ff6fc1482161c9225345c4ab42_big_gallery.jpg'
)

# Stack the two preprocessed images along a new leading dimension.
images = torch.stack([image_a, image_b], dim=0)
images.shape  # displayed when run in a notebook; has no effect in a script

# Generate token sequences with beam search (4 beams, at most 256 tokens).
outputs = encoder_decoder.generate(
    pixel_values=images,
    special_token_ids=[tokenizer.sep_token_id],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256,
    num_beams=4,
)

# Split each generated sequence into its two report sections and decode them.
# NOTE(review): presumably the separator and end-of-sequence tokens mark the
# end of the findings and impression sections respectively; confirm against
# the checkpoint's `split_and_decode_sections` implementation.
findings, impression = encoder_decoder.split_and_decode_sections(
    outputs.sequences,
    [tokenizer.sep_token_id, tokenizer.eos_token_id],
    tokenizer,
)

for i, j in zip(findings, impression):
    print(f'Findings: {i}\nImpression: {j}\n')
```