LucasAguetai committed on
Commit
2974ccd
1 Parent(s): f234dc4

add score to squeezeBert

Browse files
Files changed (1) hide show
  1. modeles.py +25 -5
modeles.py CHANGED
@@ -7,8 +7,7 @@ def loadSqueeze():
7
  return tokenizer, model
8
 
9
  def squeezebert(context, question, model, tokenizer):
10
- # Define the specific model and tokenizer for SqueezeBERT
11
- # Tokenize the input question-context pair
12
  inputs = tokenizer.encode_plus(question, context, max_length=512, truncation=True, padding=True, return_tensors='pt')
13
 
14
  # Send inputs to the same device as your model
@@ -20,13 +19,34 @@ def squeezebert(context, question, model, tokenizer):
20
 
21
  # Extract the start and end positions of the answer in the tokens
22
  answer_start_scores, answer_end_scores = outputs.start_logits, outputs.end_logits
23
- answer_start_index = torch.argmax(answer_start_scores) # Most likely start of answer
24
- answer_end_index = torch.argmax(answer_end_scores) + 1 # Most likely end of answer; +1 for inclusive slicing
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Convert token indices to the actual answer text
27
  answer_tokens = inputs['input_ids'][0, answer_start_index:answer_end_index]
28
  answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
29
- return {"answer": answer, "start": answer_start_index.item(), "end": answer_end_index.item()}
 
 
 
 
 
 
 
 
30
 
31
 
32
 
 
7
  return tokenizer, model
8
 
9
  def squeezebert(context, question, model, tokenizer):
10
+ # Tokenize the input question-context pair
 
11
  inputs = tokenizer.encode_plus(question, context, max_length=512, truncation=True, padding=True, return_tensors='pt')
12
 
13
  # Send inputs to the same device as your model
 
19
 
20
  # Extract the start and end positions of the answer in the tokens
21
  answer_start_scores, answer_end_scores = outputs.start_logits, outputs.end_logits
22
+
23
+ # Calculate probabilities from logits
24
+ answer_start_prob = torch.softmax(answer_start_scores, dim=-1)
25
+ answer_end_prob = torch.softmax(answer_end_scores, dim=-1)
26
+
27
+ # Find the most likely start and end positions
28
+ answer_start_index = torch.argmax(answer_start_prob) # Most likely start of answer
29
+ answer_end_index = torch.argmax(answer_end_prob) + 1 # Most likely end of answer; +1 for inclusive slicing
30
+
31
+ # Extract the highest probability scores
32
+ start_score = answer_start_prob.max().item() # Highest probability of start
33
+ end_score = answer_end_prob.max().item() # Highest probability of end
34
+
35
+ # Combine the scores into a singular score
36
+ combined_score = (start_score * end_score) ** 0.5 # Geometric mean of start and end scores
37
 
38
  # Convert token indices to the actual answer text
39
  answer_tokens = inputs['input_ids'][0, answer_start_index:answer_end_index]
40
  answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
41
+
42
+ # Return the answer, its positions, and the combined score
43
+ return {
44
+ "answer": answer,
45
+ "start": answer_start_index.item(),
46
+ "end": answer_end_index.item(),
47
+ "score": combined_score
48
+ }
49
+
50
 
51
 
52