ajaynagotha committed on
Commit
738d0f3
1 Parent(s): 5be946a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -25,6 +25,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
26
  logger.info(f"Model and tokenizer loaded successfully. Using device: {device}")
27
 
 
 
 
 
 
28
  def clean_answer(answer):
29
  special_tokens = set(tokenizer.all_special_tokens)
30
  cleaned_answer = ' '.join(token for token in answer.split() if token not in special_tokens)
@@ -33,14 +38,11 @@ def clean_answer(answer):
33
  def answer_question(question):
34
  logger.info(f"Received question: {question}")
35
  try:
36
- logger.info("Combining text from dataset")
37
- context = " ".join([item.get('Text', '') for item in ds['train']])
38
- logger.info(f"Combined context length: {len(context)} characters")
39
-
40
  # Implement sliding window approach
41
  max_length = 1024
42
  stride = 512
43
  answers = []
 
44
  for i in range(0, len(context), stride):
45
  chunk = context[i:i+max_length]
46
 
@@ -55,14 +57,9 @@ def answer_question(question):
55
 
56
  inputs = {k: v.to(device) for k, v in inputs.items()}
57
 
58
- logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
59
-
60
- logger.info("Getting model output")
61
  with torch.no_grad():
62
  outputs = model(**inputs)
63
 
64
- logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
65
-
66
  answer_start = torch.argmax(outputs.start_logits)
67
  answer_end = torch.argmax(outputs.end_logits) + 1
68
 
@@ -73,6 +70,10 @@ def answer_question(question):
73
  score = torch.max(outputs.start_logits) + torch.max(outputs.end_logits)
74
  answers.append((ans, score.item()))
75
 
 
 
 
 
76
  # Select best answer
77
  best_answer = max(answers, key=lambda x: x[1])[0]
78
 
 
25
  model.to(device)
26
  logger.info(f"Model and tokenizer loaded successfully. Using device: {device}")
27
 
28
+ # Preprocess the dataset
29
+ logger.info("Preprocessing the dataset")
30
+ context = " ".join([item.get('Text', '') for item in ds['train']])
31
+ logger.info(f"Combined context length: {len(context)} characters")
32
+
33
  def clean_answer(answer):
34
  special_tokens = set(tokenizer.all_special_tokens)
35
  cleaned_answer = ' '.join(token for token in answer.split() if token not in special_tokens)
 
38
  def answer_question(question):
39
  logger.info(f"Received question: {question}")
40
  try:
 
 
 
 
41
  # Implement sliding window approach
42
  max_length = 1024
43
  stride = 512
44
  answers = []
45
+
46
  for i in range(0, len(context), stride):
47
  chunk = context[i:i+max_length]
48
 
 
57
 
58
  inputs = {k: v.to(device) for k, v in inputs.items()}
59
 
 
 
 
60
  with torch.no_grad():
61
  outputs = model(**inputs)
62
 
 
 
63
  answer_start = torch.argmax(outputs.start_logits)
64
  answer_end = torch.argmax(outputs.end_logits) + 1
65
 
 
70
  score = torch.max(outputs.start_logits) + torch.max(outputs.end_logits)
71
  answers.append((ans, score.item()))
72
 
73
+ # Break if we have a good answer
74
+ if score > 10: # Adjust this threshold as needed
75
+ break
76
+
77
  # Select best answer
78
  best_answer = max(answers, key=lambda x: x[1])[0]
79