Update image_captioner.py

#2
Files changed (2)
  1. app.py +6 -6
  2. image_captioner.py +8 -16
app.py CHANGED
@@ -24,28 +24,28 @@ prompt_template = """
 You will receive a descriptive text of a photo.
 Try to generate a nice Instagram caption with a phrase rhyming with the text. Include emojis in the caption.
 
-Descriptive text: {{captions[0]}};
+Descriptive text: {{caption}};
 Instagram Caption:
 """
 
 hf_api_key = os.environ["HF_API_KEY"]
 
-def generate_caption(image_file_paths, model_name):
+def generate_caption(image_file_path, model_name):
     image_to_text = ImageCaptioner(
         model_name="nlpconnect/vit-gpt2-image-captioning",
     )
     prompt_builder = PromptBuilder(template=prompt_template)
-    generator = HuggingFaceTGIGenerator(model=model_name, token=Secret.from_token(hf_api_key))
+    generator = HuggingFaceTGIGenerator(model=model_name, token=Secret.from_token(hf_api_key), generation_kwargs={"max_new_tokens":50})
     captioning_pipeline = Pipeline()
     captioning_pipeline.add_component("image_to_text", image_to_text)
     captioning_pipeline.add_component("prompt_builder", prompt_builder)
     captioning_pipeline.add_component("generator", generator)
 
-    captioning_pipeline.connect("image_to_text.captions", "prompt_builder.captions")
+    captioning_pipeline.connect("image_to_text.caption", "prompt_builder.caption")
     captioning_pipeline.connect("prompt_builder", "generator")
 
-    result = captioning_pipeline.run({"image_to_text":{"image_file_paths":image_file_paths}})
-    return result["generator"][0]
+    result = captioning_pipeline.run({"image_to_text":{"image_file_path":image_file_path}})
+    return result["generator"]["replies"][0]
 
 with gr.Blocks(theme="soft") as demo:
     gr.Markdown(value=description)
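
For context, a minimal sketch of how the updated generate_caption would be invoked after this change; the image path and model id below are illustrative placeholders, not values taken from this PR. The new return statement relies on Haystack's HuggingFaceTGIGenerator putting its generated strings under a "replies" key in the pipeline result.

# Hypothetical usage of the updated function (placeholder path and model id):
caption = generate_caption(
    image_file_path="selfie.png",                     # a single path now, not a list of paths
    model_name="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder: any TGI-served model id
)
print(caption)  # first entry of result["generator"]["replies"]
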
image_captioner.py CHANGED
@@ -37,20 +37,16 @@ class ImageCaptioner:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
 
-    @component.output_types(captions=List[str])
-    def run(self, image_file_paths: List[str]) -> List[Document]:
+    @component.output_types(caption=str)
+    def run(self, image_file_path: str) -> List[Document]:
 
-        images = []
-        for image_path in image_file_paths:
-            i_image = Image.open(image_path)
-            if i_image.mode != "RGB":
-                i_image = i_image.convert(mode="RGB")
-
-            images.append(i_image)
+        i_image = Image.open(image_file_path)
+        if i_image.mode != "RGB":
+            i_image = i_image.convert(mode="RGB")
 
         preds = []
         if self.model_name == "nlpconnect/vit-gpt2-image-captioning":
-            pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
+            pixel_values = self.feature_extractor(images=[i_image], return_tensors="pt").pixel_values
             pixel_values = pixel_values.to(self.device)
 
             output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
@@ -59,7 +55,7 @@ class ImageCaptioner:
             preds = [pred.strip() for pred in preds]
         else:
 
-            inputs = self.processor(images, return_tensors="pt")
+            inputs = self.processor([i_image], return_tensors="pt")
             output_ids = self.model.generate(**inputs)
             preds = self.processor.batch_decode(output_ids, skip_special_tokens=True)
             preds = [pred.strip() for pred in preds]
@@ -68,8 +64,4 @@ class ImageCaptioner:
         # for caption, image_file_path in zip(preds, image_file_paths):
         #     document = Document(content=caption, meta={"image_path": image_file_path})
         #     captions.append(document)
-        return {"captions": preds}
-
-# captioner = ImageCaptioner(model_name="Salesforce/blip-image-captioning-base")
-# result = captioner.run(image_file_paths=["selfie.png"])
-# print(result)
+        return {"caption": preds[0]}
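
For reference, a minimal sketch of the commented-out local test that this commit deletes, adapted to the new single-image interface; it assumes a file named selfie.png sits next to the script, as in the original snippet.

# Standalone check of the updated component, mirroring the removed test lines:
# run() now takes one path and returns {"caption": <str>} instead of {"captions": [<str>, ...]}.
captioner = ImageCaptioner(model_name="Salesforce/blip-image-captioning-base")
result = captioner.run(image_file_path="selfie.png")
print(result["caption"])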