bgaspra committed on
Commit 650aead
1 Parent(s): 26d55ba

Update app.py

Files changed (1)
  1. app.py +6 -5
app.py CHANGED
@@ -41,12 +41,13 @@ def get_image_embedding(image):
             padding=True
         ).to(device, torch_dtype)
 
-        # Generate decoder_input_ids
+        # Generate decoder_input_ids with adjusted parameters
         decoder_input_ids = model.generate(
             **inputs,
-            max_length=1,
+            max_new_tokens=20,  # Increased from max_length
             min_length=1,
             num_beams=1,
+            do_sample=False,
             pad_token_id=processor.tokenizer.pad_token_id,
             return_dict_in_generate=True,
         ).sequences
@@ -55,7 +56,6 @@ def get_image_embedding(image):
 
         with torch.no_grad():
             outputs = model(**inputs)
-        # Use the mean of the last hidden state as the embedding
         image_embeddings = outputs.last_hidden_state.mean(dim=1)
         return image_embeddings.cpu().numpy()
     except Exception as e:
@@ -75,12 +75,13 @@ def get_text_embedding(text):
             padding=True
         ).to(device, torch_dtype)
 
-        # Generate decoder_input_ids
+        # Generate decoder_input_ids with adjusted parameters
         decoder_input_ids = model.generate(
             **inputs,
-            max_length=1,
+            max_new_tokens=20,  # Using max_new_tokens instead of max_length
            min_length=1,
            num_beams=1,
+            do_sample=False,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences
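
For context, in transformers max_length bounds the total generated sequence length (including the decoder start token), so max_length=1 effectively halts generation immediately, whereas max_new_tokens=20 bounds only the newly generated tokens; that is the apparent motivation for this change. Below is a minimal sketch of how get_image_embedding might read after this commit, reconstructed from the hunks above — the processor call arguments, the globals (processor, model, device, torch_dtype), and the except body are assumptions, not part of the diff; get_text_embedding would follow the same pattern on the text path.

import torch

def get_image_embedding(image):
    try:
        # Assumed preprocessing call; the diff only shows padding=True and the .to(...) cast
        inputs = processor(
            images=image,
            return_tensors="pt",
            padding=True
        ).to(device, torch_dtype)

        # Generate decoder_input_ids with adjusted parameters (per this commit)
        decoder_input_ids = model.generate(
            **inputs,
            max_new_tokens=20,   # replaces the old max_length=1
            min_length=1,
            num_beams=1,
            do_sample=False,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences

        # Mean-pool the last hidden state as the embedding
        with torch.no_grad():
            outputs = model(**inputs)
        image_embeddings = outputs.last_hidden_state.mean(dim=1)
        return image_embeddings.cpu().numpy()
    except Exception as e:
        # Assumed error handling; the except body is not shown in the diff
        print(f"Error generating image embedding: {e}")
        return None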