ajaynagotha committed on
Commit f211efc · verified · 1 Parent(s): 881d53f

Update app.py

Files changed (1)
  app.py +13 -17
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 import torch
-from mlcroissant import Dataset
+from datasets import load_dataset
 import random
 
 # Load the DistilBERT model and tokenizer
@@ -10,27 +10,27 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForQuestionAnswering.from_pretrained(model_name)
 
 # Load the Bhagavad Gita dataset
-ds = Dataset(jsonld="https://huggingface.co/api/datasets/knowrohit07/gita_dataset/croissant")
-records = list(ds.records("default"))
+ds = load_dataset("knowrohit07/gita_dataset")
 
 def get_relevant_context(question):
     # Randomly select 5 records to form the context
-    selected_records = random.sample(records, 5)
-    context = " ".join([record["Text"] for record in selected_records])
+    selected_records = random.sample(ds['train'], 5)
+    context = " ".join([record['Text'] for record in selected_records])
     return context
 
 def generate_response(question):
     context = get_relevant_context(question)
 
     # Encode the question and context
-    inputs = tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="pt")
-    input_ids = inputs["input_ids"].tolist()[0]
-
+    inputs = tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="pt", max_length=512, truncation=True)
+
     # Get the answer
-    outputs = model(**inputs)
+    with torch.no_grad():
+        outputs = model(**inputs)
+
     answer_start = torch.argmax(outputs.start_logits)
     answer_end = torch.argmax(outputs.end_logits) + 1
-    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
+    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
 
     # If the model couldn't find an answer, provide a default response
     if answer == "" or answer == "[CLS]" or answer == "[SEP]":
@@ -41,13 +41,9 @@ def generate_response(question):
 
     return answer + disclaimer
 
-# Define the predict function for the API
-def predict(question):
-    return generate_response(question)
-
 # Create the Gradio interface
 iface = gr.Interface(
-    fn=predict,
+    fn=generate_response,
     inputs=gr.Textbox(lines=2, placeholder="Enter your question about the Bhagavad Gita here..."),
     outputs="text",
     title="Bhagavad Gita Q&A Assistant",
@@ -61,5 +57,5 @@ iface = gr.Interface(
     ]
 )
 
-# Launch the interface with sharing enabled
-iface.launch(share=True)
+# Launch the interface
+iface.launch()
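
For anyone who wants to sanity-check the updated code path outside the Space, a minimal sketch of the new flow follows. This is an illustration under stated assumptions, not part of the commit: the checkpoint name is a hypothetical stand-in for the app's model_name (whose value sits outside the diff), and it assumes the dataset's train split exposes a Text column, as the diff implies. The sketch samples indices rather than the split object itself, since random.sample requires a sequence and range always qualifies.

import random

import torch
from datasets import load_dataset
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Hypothetical stand-in for the app's model_name, which the diff does not show.
model_name = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Assumes a "train" split with a "Text" column, as the diff implies.
ds = load_dataset("knowrohit07/gita_dataset")
indices = random.sample(range(len(ds["train"])), 5)
context = " ".join(ds["train"][i]["Text"] for i in indices)

question = "What does Krishna say about duty?"
inputs = tokenizer(question, context, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():  # inference only, so skip gradient bookkeeping
    outputs = model(**inputs)

# Decode the highest-scoring answer span back into text.
start = torch.argmax(outputs.start_logits)
end = torch.argmax(outputs.end_logits) + 1
print(tokenizer.decode(inputs["input_ids"][0][start:end], skip_special_tokens=True))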