florentgbelidji (HF staff) committed
Commit: baa2ff5
Parent: 993e825

Updating docstring, loading local model weights and adding parameters

Files changed (1)
  pipeline.py  +21 -11
pipeline.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import base64
 import os
 from io import BytesIO
-from blip import blip_decoder
+from models.blip_decoder import blip_decoder
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
@@ -13,10 +13,15 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 print(device)
 
 class PreTrainedPipeline():
-    def __init__(self, path=""):
+    def __init__(self):
         # load the optimized model
-        self.model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
-        self.model = blip_decoder(pretrained=self.model_url, image_size=384, vit='large',med_config=os.path.join(path, 'configs/med_config.json'))
+        self.model_path = 'model_base_capfilt_large.pth'
+        self.model = blip_decoder(
+            pretrained=self.model_path,
+            image_size=384,
+            vit='large',
+            med_config=os.path.join(path, 'configs/med_config.json')
+        )
         self.model.eval()
         self.model = self.model.to(device)
 
@@ -29,23 +34,28 @@ class PreTrainedPipeline():
 
 
 
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+    def __call__(self, data: Any) -> Dict[str]:
         """
         Args:
             data (:obj:):
                 includes the input data and the parameters for the inference.
         Return:
-            A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
-                - "label": A string representing what the label/class is. There can be multiple labels.
-                - "score": A score between 0 and 1 describing how confident the model is for this label/class.
+            A :obj:`dict`:. The object returned should be a dict of one list like [[{"label": 0.9939950108528137}]] containing :
+                - "caption": A string corresponding to the generated caption.
         """
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
+        parameters = data.pop("parameters", {})
 
         # decode base64 image to PIL
         image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
         image = self.transform(image).unsqueeze(0).to(device)
         with torch.no_grad():
-            caption = self.model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
+            caption = self.model.generate(
+                image,
+                sample=parameters.get('sample',True),
+                top_p=parameters.get('top_p',0.9),
+                max_length=parameters.get('max_length',20),
+                min_length=parameters.get('min_length',5)
+            )
         # postprocess the prediction
-        return caption
+        return {"caption": caption}