Spaces:
Build error
Build error
Heiko Hotz
commited on
Commit
•
4b9c730
1
Parent(s):
833c58b
initial commit
Browse files- predict.py +3 -3
predict.py
CHANGED
@@ -13,8 +13,8 @@ from transformers import (
|
|
13 |
from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
|
14 |
from transformers.data.metrics.squad_metrics import compute_predictions_logits
|
15 |
|
|
|
16 |
def run_prediction(question_texts, context_text, model_path, n_best_size=1):
|
17 |
-
### Setting hyperparameters
|
18 |
max_seq_length = 512
|
19 |
doc_stride = 256
|
20 |
n_best_size = n_best_size
|
@@ -102,7 +102,7 @@ def run_prediction(question_texts, context_text, model_path, n_best_size=1):
|
|
102 |
print(all_results)
|
103 |
|
104 |
output_nbest_file = None
|
105 |
-
if n_best_size > 1:
|
106 |
output_nbest_file = "nbest.json"
|
107 |
|
108 |
timer = time.time()
|
@@ -123,4 +123,4 @@ def run_prediction(question_texts, context_text, model_path, n_best_size=1):
|
|
123 |
)
|
124 |
print(f'Logits converted to predictions in {time.time()-timer} seconds')
|
125 |
|
126 |
-
return final_predictions
|
|
|
13 |
from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
|
14 |
from transformers.data.metrics.squad_metrics import compute_predictions_logits
|
15 |
|
16 |
+
|
17 |
def run_prediction(question_texts, context_text, model_path, n_best_size=1):
|
|
|
18 |
max_seq_length = 512
|
19 |
doc_stride = 256
|
20 |
n_best_size = n_best_size
|
|
|
102 |
print(all_results)
|
103 |
|
104 |
output_nbest_file = None
|
105 |
+
if int(n_best_size) > 1:
|
106 |
output_nbest_file = "nbest.json"
|
107 |
|
108 |
timer = time.time()
|
|
|
123 |
)
|
124 |
print(f'Logits converted to predictions in {time.time()-timer} seconds')
|
125 |
|
126 |
+
return final_predictions
|