|
import json |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
import streamlit as st |
|
import accelerate |
|
import re |
|
|
|
def formatting_func(document):
    """Build the instruction prompt that asks the model to summarize *document*.

    The returned string has the shape::

        ### <instruction> \\n### Conclusion: <document> \\n### Summary: 

    The exact layout, including the trailing spaces, presumably mirrors the
    format the model was fine-tuned on (NOTE(review): confirm against the
    training data). It is therefore kept byte-for-byte stable.

    Args:
        document: The source text to summarize. It is interpolated via
            ``format()``, so non-str values are rendered with their string form.

    Returns:
        The full prompt string, ending right after ``### Summary: `` so the
        model's continuation is the summary itself.
    """
    # Implicit literal concatenation; the joined text is identical to a single
    # long literal, just easier to read.
    task_description = (
        "You are a model designed to rephrase medical summaries for a general audience. "
        "Please summarize the following article in such a way a normal person could understand it, "
        "while also ensuring the same factual accuracy. "
        "Replace any technical terms with their equivalents in ordinary language, "
        "and be concise (< 100 words) and approachable.\n"
        "---------\n"
    )
    # Braces inside *document* are safe: only the template is parsed for fields.
    prompt_template = "### {head} \n### Conclusion: {body} \n### Summary: "
    return prompt_template.format(head=task_description, body=document)
|
|
|
MODEL_ID = "BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the 4-bit quantized fine-tuned model and its tokenizer once per server process.

    Streamlit re-executes the whole script on every interaction. Without
    ``st.cache_resource`` the checkpoint would be reloaded onto the GPU on
    every button click.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        # Preferred over the bare ``load_in_4bit=True`` kwarg.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    model.eval()
    # Tokenizers are device-agnostic, so no device_map here. add_bos_token
    # presumably matches the fine-tuning setup — TODO confirm.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_bos_token=True)
    return model, tokenizer


def generate(text, max_new_token):
    """Summarize *text* with the fine-tuned model and return only the generated summary.

    Args:
        text: Raw document text; it is wrapped in the instruction prompt by
            ``formatting_func``.
        max_new_token: Upper bound on the number of tokens to generate.

    Returns:
        The decoded continuation (the prompt is not included), with every
        ``#`` character removed and surrounding whitespace stripped.
    """
    ft_model, eval_tokenizer = _load_model_and_tokenizer()

    eval_prompt = formatting_func(text)
    model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to(ft_model.device)

    with torch.no_grad():
        output_ids = ft_model.generate(**model_input, max_new_tokens=int(max_new_token))

    # Decode only the tokens produced after the prompt. This is more reliable
    # than decoding everything and searching for the prompt string.
    # decode(encode(prompt)) is not guaranteed to reproduce the prompt exactly,
    # and a failed match would echo the instructions back to the user.
    prompt_len = model_input["input_ids"].shape[1]
    response = eval_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

    # Remove the '###' section markers the prompt format uses, in case the
    # model reproduces them.
    response = re.sub(r"#+", "", response)
    return response.strip()
|
|
|
def main():
    """Render the two-column summarization UI and run the model when the user submits.

    The left column collects the document text and generation length. The
    right column shows the model's summary after "Summarize" is clicked.
    """
    st.title('Medical Document Summarization')

    col1, col2 = st.columns(2)

    with col1:
        user_input = st.text_area("Enter your text here:", height=300)
        # Non-positive token budgets are not meaningful for generation, so the
        # widget enforces a floor of 1.
        max_new_token = st.number_input('Max new tokens:', value=200, min_value=1, step=1)
        submit_button = st.button("Summarize")

    with col2:
        if not submit_button:
            return
        # Avoid an expensive model call on an empty or whitespace-only prompt.
        if not user_input.strip():
            st.warning("Please enter some text to summarize.")
            return
        st.write("Model Response:")
        with st.spinner("Generating summary..."):
            output = generate(user_input, max_new_token)
        st.markdown(output)
|
|
|
# Page-level configuration. Streamlit has historically required this to
# precede any other st.* call in a run, so it sits at module level ahead of
# main(). The wide layout gives the two side-by-side columns in main() room.
st.set_page_config(layout="wide")


if __name__ == "__main__":
    # `streamlit run <file>` executes this script as __main__ on every rerun.
    main()
|
|
|
|
|
|