"""Streamlit app that rewrites medical conclusions as plain-language summaries."""

import json
import re

import accelerate
import bitsandbytes
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# NOTE(review): json, accelerate and bitsandbytes are not referenced by name
# in this file; they are kept as-is -- confirm whether they are only meant as
# an early "is the dependency installed" check before removing them.

MODEL_ID = "BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old"


def formatting_func(document):
    """Wrap *document* in the instruction prompt passed to the model.

    Args:
        document: The medical text (conclusion) to be summarized.

    Returns:
        The full prompt string, ending with the ``### Summary: `` cue.
    """
    instruction = "You are a model designed to rephrase medical summaries for a general audience. Please summarize the following article in such a way a normal person could understand it, while also ensuring the same factual accuracy. Replace any technical terms with their equivalents in ordinary language, and be concise (< 100 words) and approachable.\n---------\n"
    text = f"### {instruction} \n### Conclusion: {document} \n### Summary: "
    return text


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the 4-bit quantized model and its tokenizer once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the weights would be reloaded on each button press.
    """
    ft_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    eval_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_bos_token=True)
    ft_model.eval()
    return ft_model, eval_tokenizer


def generate(text, max_new_token):
    """Generate a plain-language summary of *text*.

    Args:
        text: The medical text to summarize.
        max_new_token: Upper bound on the number of tokens to generate.

    Returns:
        The generated continuation with the prompt removed and every run of
        ``#`` characters stripped out.
    """
    ft_model, eval_tokenizer = _load_model_and_tokenizer()
    eval_prompt = formatting_func(text)

    with torch.no_grad():
        model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")
        output_ids = ft_model.generate(
            **model_input, max_new_tokens=int(max_new_token)
        )[0]

    # The returned sequence starts with the prompt tokens. Dropping them by
    # position is more reliable than searching the decoded text for the
    # prompt, which fails whenever decoding does not reproduce it exactly.
    prompt_length = model_input["input_ids"].shape[1]
    response = eval_tokenizer.decode(
        output_ids[prompt_length:], skip_special_tokens=True
    )

    # Remove any "###" section markers the model echoes back.
    return re.sub(r'#+', '', response)


def main():
    """Render the two-column UI: input form on the left, summary on the right."""
    st.title('Medical Document Summarization')
    col1, col2 = st.columns(2)

    with col1:
        user_input = st.text_area("Enter your text here:", height=300)
        # min_value/step keep the widget returning a positive integer.
        max_new_token = st.number_input(
            'Max new tokens:', min_value=1, value=200, step=1
        )
        submit_button = st.button("Summarize")

    with col2:
        if submit_button:
            if not user_input.strip():
                # Nothing to summarize; skip the expensive generation call.
                st.warning("Please enter some text to summarize.")
            else:
                st.write("Model Response:")
                output = generate(user_input, max_new_token)
                print(output)
                st.markdown(output)


# Must run before any other Streamlit command emitted by main().
st.set_page_config(layout="wide")

if __name__ == "__main__":
    main()