sgugger brydon commited on
Commit
b82e3a1
1 Parent(s): 4b49e24

Update PyTorch example in README.md (#3)

Browse files

- Update PyTorch example in README.md (90fe6fb8e83feb4fb44837d76ae7bc8f6e726017)


Co-authored-by: Brydon Eastman <brydon@users.noreply.huggingface.co>

Files changed (1) hide show
  1. README.md +7 -3
README.md CHANGED
@@ -64,10 +64,10 @@ Answer: 'SQuAD dataset', score: 0.4704, start: 147, end: 160
64
  Here is how to use this model in PyTorch:
65
 
66
  ```python
67
- from transformers import DistilBertTokenizer, DistilBertModel
68
  import torch
69
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
70
- model = DistilBertModel.from_pretrained('distilbert-base-uncased-distilled-squad')
71
 
72
  question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
73
 
@@ -75,7 +75,11 @@ inputs = tokenizer(question, text, return_tensors="pt")
75
  with torch.no_grad():
76
  outputs = model(**inputs)
77
 
78
- print(outputs)
 
 
 
 
79
  ```
80
 
81
  And in TensorFlow:
 
64
  Here is how to use this model in PyTorch:
65
 
66
  ```python
67
+ from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
68
  import torch
69
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
70
+ model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')
71
 
72
  question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
73
 
 
75
  with torch.no_grad():
76
  outputs = model(**inputs)
77
 
78
+ answer_start_index = torch.argmax(outputs.start_logits)
79
+ answer_end_index = torch.argmax(outputs.end_logits)
80
+
81
+ predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
82
+ tokenizer.decode(predict_answer_tokens)
83
  ```
84
 
85
  And in TensorFlow: