anicolson committed
Commit b5aa720 · 1 Parent(s): 90834dc

Update README.md

Files changed (1)
  1. README.md +67 -0
README.md CHANGED
@@ -1,3 +1,70 @@
  ---
  license: apache-2.0
  ---
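+
+ The example below loads the checkpoint with `trust_remote_code=True`, downloads and preprocesses two example chest X-ray images, and generates the findings and impression sections of a radiology report for each image: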
+
+ ```python
+ import os, torch, transformers
+ import requests  # needed to download the example images below
+ from io import BytesIO  # needed to open the downloaded images with PIL
+ from PIL import Image
+ from torchvision import transforms
+
+ # These cache paths are specific to the original author's environment;
+ # adjust or remove them for your own setup.
+ os.environ['TRANSFORMERS_CACHE'] = '/scratch1/nic261/hf_cache'
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/scratch1/nic261/hf_cache'
+
+ ckpt_name = 'aehrc/mimic-cxr-report-gen-single'
+
+ # Load the encoder-decoder model, tokenizer, and image processor from the Hub.
+ encoder_decoder = transformers.AutoModel.from_pretrained(ckpt_name, trust_remote_code=True)
+ tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(ckpt_name)
+ image_processor = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name)
+
+ # Test-time preprocessing: resize, center crop, and normalise with the
+ # image processor's statistics.
+ test_transforms = transforms.Compose(
+     [
+         transforms.Resize(size=image_processor.size['shortest_edge']),
+         transforms.CenterCrop(size=[
+             image_processor.size['shortest_edge'],
+             image_processor.size['shortest_edge'],
+         ]),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=image_processor.image_mean,
+             std=image_processor.image_std,
+         ),
+     ]
+ )
+
+ # Download and preprocess two example chest X-ray images.
+ url = 'https://www.stritch.luc.edu/lumen/meded/radio/curriculum/IPM/PCM/86a_labelled.jpg'
+ response = requests.get(url)
+ image_a = Image.open(BytesIO(response.content))
+ image_a = image_a.convert('RGB')
+ image_a = test_transforms(image_a)
+
+ url = 'https://prod-images-static.radiopaedia.org/images/566180/d527ff6fc1482161c9225345c4ab42_big_gallery.jpg'
+ response = requests.get(url)
+ image_b = Image.open(BytesIO(response.content))
+ image_b = image_b.convert('RGB')
+ image_b = test_transforms(image_b)
+
+ # Stack the two preprocessed images into a mini-batch of shape (2, 3, size, size).
+ images = torch.stack([image_a, image_b], dim=0)
+
+ # Generate a report for each image with beam search.
+ outputs = encoder_decoder.generate(
+     pixel_values=images,
+     special_token_ids=[tokenizer.sep_token_id],
+     bos_token_id=tokenizer.bos_token_id,
+     eos_token_id=tokenizer.eos_token_id,
+     pad_token_id=tokenizer.pad_token_id,
+     return_dict_in_generate=True,
+     use_cache=True,
+     max_length=256,
+     num_beams=4,
+ )
+
+ # Split each generated sequence into its findings and impression sections
+ # and decode them to text.
+ findings, impression = encoder_decoder.split_and_decode_sections(
+     outputs.sequences,
+     [tokenizer.sep_token_id, tokenizer.eos_token_id],
+     tokenizer,
+ )
+
+ for i, j in zip(findings, impression):
+     print(f'Findings: {i}\nImpression: {j}\n')
+ ```
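+
+ Each iteration of the final loop prints the generated findings and impression sections for one image in the batch.
+
+ The same pipeline also works for an image on disk; a minimal sketch of building the mini-batch, reusing `test_transforms` from above (`chest_xray.jpg` is a placeholder file name) — the `generate` and `split_and_decode_sections` calls above then apply unchanged:
+
+ ```python
+ image = Image.open('chest_xray.jpg').convert('RGB')
+ images = test_transforms(image).unsqueeze(0)  # mini-batch of one image
+ ```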