Biswajit Padhi committed on
Commit c59ccb6
1 Parent(s): 97c9b98

Update app.py

Files changed (2)
  1. app.py +31 -10
  2. requirements.txt +8 -2
app.py CHANGED
@@ -1,29 +1,50 @@
 import json
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+import streamlit as st
+import accelerate
+import bitsandbytes
+import re
 
 def formatting_func(document):
     instruction = "You are a model designed to rephrase medical summaries for a general audience. Please summarize the following article in such a way a normal person could understand it, while also ensuring the same factual accuracy. Replace any technical terms with their equivalents in ordinary language, and be concise (< 100 words) and approachable.\n---------\n"
     text = f"### {instruction} \n### Conclusion: {document} \n### Summary: "
     return text
 
-def genenrate(text):
-    ft_model = AutoModelForCausalLM.from_pretrained("BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer",
+def generate(text, max_new_token):
+    ft_model = AutoModelForCausalLM.from_pretrained("BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old",
                                                     device_map="cuda", load_in_4bit=True)
-    eval_tokenizer = AutoTokenizer.from_pretrained("BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer", add_bos_token=True,
+    eval_tokenizer = AutoTokenizer.from_pretrained("BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old", add_bos_token=True,
                                                    device_map="cuda")
     ft_model.eval()
     with torch.no_grad():
         eval_prompt = formatting_func(text)
         model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")
-        response = eval_tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=200)[0], skip_special_tokens=True)
+        response = eval_tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=max_new_token)[0], skip_special_tokens=True)
         if(eval_prompt in response):
             response = response.replace(eval_prompt, '')
+        response = re.sub(r'#+', '', response)
         return response
 
-input = st.text_input(label= "Input Text")
-if input is not None:
-    col = st.columns(1)
-    output = generate(input)
-    col.header("Summary")
-    col.write(output)
+def main():
+    st.title('Medical Document Summarization')
+
+    col1, col2 = st.columns(2)
+
+    with col1:
+        user_input = st.text_area("Enter your text here:", height=300)
+        max_new_token = st.number_input('Max new tokens:', value=200)
+        submit_button = st.button("Summarize")
+
+    with col2:
+        if submit_button:
+            st.write("Model Response:")
+            output = generate(user_input, max_new_token)
+            print(output)
+            st.markdown(output)
+
+st.set_page_config(layout="wide")
+if __name__ == "__main__":
+    main()
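
Note on the updated app.py: generate() reloads the 7B checkpoint from the Hub on every button press, and BitsAndBytesConfig is imported but never used because the 4-bit load goes through the load_in_4bit=True shortcut. Below is a minimal sketch of how the load could be cached and made explicit, assuming Streamlit's st.cache_resource is available in the deployed environment; the load_model helper name and the nf4/bfloat16 settings are illustrative choices, not part of this commit.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_ID = "BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old"

@st.cache_resource  # load the quantized model once per process instead of on every rerun
def load_model():
    # Explicit 4-bit config; nf4 quantization with bfloat16 compute is an assumed
    # default here, not a setting taken from the commit.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, quantization_config=bnb_config, device_map="cuda"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_bos_token=True)
    model.eval()
    return model, tokenizer

# Possible usage inside the existing generate()/main() flow:
#     model, tokenizer = load_model()
#     inputs = tokenizer(formatting_func(user_input), return_tensors="pt").to("cuda")
#     with torch.no_grad():
#         output_ids = model.generate(**inputs, max_new_tokens=max_new_token)
#     st.markdown(tokenizer.decode(output_ids[0], skip_special_tokens=True))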
requirements.txt CHANGED
@@ -1,2 +1,8 @@
-transformers
-torch
+torch
+git+https://github.com/huggingface/transformers.git
+git+https://github.com/huggingface/peft.git
+git+https://github.com/huggingface/accelerate.git
+bitsandbytes-cuda111
+datasets
+scipy
+ipywidgets
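
Note on the updated requirements.txt: transformers, peft, and accelerate are installed from the tip of their main branches and bitsandbytes comes from the CUDA-11.1-specific wheel, so the environment can drift between builds. A pinned variant one might use for reproducibility is sketched below; the version numbers are illustrative assumptions from roughly that period, not values taken from the repo.

torch
transformers==4.36.2
peft==0.7.1
accelerate==0.25.0
bitsandbytes==0.41.3
datasets
scipy
ipywidgets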