bilgeyucel committed
Commit
4397b1e
1 Parent(s): e5db11f

Update image_captioner.py

Files changed (1)
  1. image_captioner.py +8 -16
image_captioner.py CHANGED

@@ -37,20 +37,16 @@ class ImageCaptioner:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
 
-    @component.output_types(captions=List[str])
-    def run(self, image_file_paths: List[str]) -> List[Document]:
+    @component.output_types(caption=str)
+    def run(self, image_file_path: str) -> List[Document]:
 
-        images = []
-        for image_path in image_file_paths:
-            i_image = Image.open(image_path)
-            if i_image.mode != "RGB":
-                i_image = i_image.convert(mode="RGB")
-
-            images.append(i_image)
+        i_image = Image.open(image_file_path)
+        if i_image.mode != "RGB":
+            i_image = i_image.convert(mode="RGB")
 
         preds = []
         if self.model_name == "nlpconnect/vit-gpt2-image-captioning":
-            pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
+            pixel_values = self.feature_extractor(images=[i_image], return_tensors="pt").pixel_values
             pixel_values = pixel_values.to(self.device)
 
             output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
@@ -59,7 +55,7 @@ class ImageCaptioner:
             preds = [pred.strip() for pred in preds]
         else:
 
-            inputs = self.processor(images, return_tensors="pt")
+            inputs = self.processor([i_image], return_tensors="pt")
             output_ids = self.model.generate(**inputs)
             preds = self.processor.batch_decode(output_ids, skip_special_tokens=True)
             preds = [pred.strip() for pred in preds]
@@ -68,8 +64,4 @@ class ImageCaptioner:
         # for caption, image_file_path in zip(preds, image_file_paths):
         #     document = Document(content=caption, meta={"image_path": image_file_path})
         #     captions.append(document)
-        return {"captions": preds}
-
-# captioner = ImageCaptioner(model_name="Salesforce/blip-image-captioning-base")
-# result = captioner.run(image_file_paths=["selfie.png"])
-# print(result)
+        return {"caption": preds[0]}
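For reference, a minimal usage sketch of the component after this change, assuming the class is importable from image_captioner.py; the model name and image file mirror the commented-out example that this commit removes:

    from image_captioner import ImageCaptioner  # import path is an assumption

    # run() now takes a single image_file_path instead of image_file_paths=[...]
    captioner = ImageCaptioner(model_name="Salesforce/blip-image-captioning-base")
    result = captioner.run(image_file_path="selfie.png")
    print(result["caption"])  # output key changed from "captions" (list of str) to "caption" (str)

Note that in this commit run() still carries the -> List[Document] annotation even though it returns a dict, and the Document-building loop remains commented out.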