zaidmehdi commited on
Commit
79dc319
1 Parent(s): 4b51583

debug extract_hidden_state()

Browse files
Files changed (1) hide show
  1. src/utils.py +3 -3
src/utils.py CHANGED
@@ -6,10 +6,10 @@ import torch
6
 
7
 
8
  def extract_hidden_state(input_text, tokenizer, language_model):
9
- tokens = tokenizer(input_text, padding=True)
10
  with torch.no_grad():
11
- outputs = language_model(tokens)
12
-
13
  return outputs.last_hidden_state
14
 
15
 
 
6
 
7
 
8
  def extract_hidden_state(input_text, tokenizer, language_model):
9
+ tokens = tokenizer(input_text, padding=True, return_tensors="pt")
10
  with torch.no_grad():
11
+ outputs = language_model(**tokens)
12
+
13
  return outputs.last_hidden_state
14
 
15