ajaynagotha committed
Commit • 738d0f3
1 Parent(s): 5be946a

Update app.py
app.py CHANGED
@@ -25,6 +25,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 logger.info(f"Model and tokenizer loaded successfully. Using device: {device}")
 
+# Preprocess the dataset
+logger.info("Preprocessing the dataset")
+context = " ".join([item.get('Text', '') for item in ds['train']])
+logger.info(f"Combined context length: {len(context)} characters")
+
 def clean_answer(answer):
     special_tokens = set(tokenizer.all_special_tokens)
     cleaned_answer = ' '.join(token for token in answer.split() if token not in special_tokens)
@@ -33,14 +38,11 @@ def clean_answer(answer):
 def answer_question(question):
     logger.info(f"Received question: {question}")
     try:
-        logger.info("Combining text from dataset")
-        context = " ".join([item.get('Text', '') for item in ds['train']])
-        logger.info(f"Combined context length: {len(context)} characters")
-
         # Implement sliding window approach
         max_length = 1024
         stride = 512
         answers = []
+
         for i in range(0, len(context), stride):
             chunk = context[i:i+max_length]
 
@@ -55,14 +57,9 @@ def answer_question(question):
 
             inputs = {k: v.to(device) for k, v in inputs.items()}
 
-            logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
-
-            logger.info("Getting model output")
             with torch.no_grad():
                 outputs = model(**inputs)
 
-            logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
-
             answer_start = torch.argmax(outputs.start_logits)
             answer_end = torch.argmax(outputs.end_logits) + 1
 
@@ -73,6 +70,10 @@ def answer_question(question):
             score = torch.max(outputs.start_logits) + torch.max(outputs.end_logits)
             answers.append((ans, score.item()))
 
+            # Break if we have a good answer
+            if score > 10:  # Adjust this threshold as needed
+                break
+
         # Select best answer
         best_answer = max(answers, key=lambda x: x[1])[0]
 
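Note: the first hunk hoists the dataset preprocessing out of answer_question, so the combined context string is built once at startup rather than on every request. A minimal sketch of the effect, with hypothetical rows standing in for ds['train'] (the real dataset is not shown in this diff):

# Hypothetical stand-in for ds['train']; only the 'Text' column matters here.
train_rows = [{"Text": "Chapter 1 text."}, {"Text": "Chapter 2 text."}]

# Built once at module load; answer_question() now reads this global
# instead of rebuilding it per call.
context = " ".join(row.get("Text", "") for row in train_rows)
print(len(context))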
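Note: the sliding-window loop kept in the second hunk walks the combined context in overlapping windows (1024 wide, stride 512) so each model call stays within budget. It slides over characters, not tokens, so the tokenizer's own truncation still has to handle long windows. A minimal sketch of the same windowing:

def sliding_chunks(text, max_length=1024, stride=512):
    """Yield overlapping character windows, as in the commit's loop."""
    for i in range(0, len(text), stride):
        yield text[i:i + max_length]

# Overlap means each position can appear in up to two windows.
print([len(c) for c in sliding_chunks("x" * 2000)])  # [1024, 1024, 976, 464]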
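Note: the per-chunk scoring retained by this commit follows the usual extractive-QA recipe: independent argmaxes over the start and end logits pick the span, and the sum of the two max logits serves as a confidence proxy. A self-contained sketch; the diff does not show how inputs or ans are built, so the tokenizer call, the decode step, and the checkpoint name below are assumptions:

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Any extractive-QA checkpoint illustrates the pattern; the Space's actual
# model is not visible in this diff.
name = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForQuestionAnswering.from_pretrained(name)

inputs = tokenizer("Who speaks?", "Krishna speaks to Arjuna.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

answer_start = torch.argmax(outputs.start_logits)
answer_end = torch.argmax(outputs.end_logits) + 1
ans = tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end])
score = torch.max(outputs.start_logits) + torch.max(outputs.end_logits)
print(ans, score.item())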
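Note: the last hunk adds an early exit: as soon as any chunk's score clears a hard-coded threshold, the loop stops and the remaining chunks are skipped. Since the score is a sum of raw, uncalibrated logits, the threshold of 10 is a heuristic, which is why the diff itself says "Adjust this threshold as needed". A small sketch of the control flow with made-up scores:

answers = []
for ans, score in [("span one", 4.2), ("span two", 11.7), ("span three", 3.0)]:
    answers.append((ans, score))
    if score > 10:  # early exit once a chunk looks good enough
        break

# Fallback when nothing clears the threshold: best score wins.
best_answer = max(answers, key=lambda x: x[1])[0]
print(best_answer)  # "span two"; "span three" is never evaluated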