File size: 2,150 Bytes
97c9b98 c59ccb6 97c9b98 c59ccb6 97c9b98 c59ccb6 97c9b98 c59ccb6 95073a7 97c9b98 c59ccb6 97c9b98 c59ccb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import streamlit as st
import accelerate
import bitsandbytes
import re
def formatting_func(document):
    """Wrap *document* in the instruction prompt fed to the summarizer model.

    The result has three ``###`` sections: a fixed plain-language rewriting
    instruction, the supplied document under ``Conclusion:``, and an empty
    ``Summary:`` slot that the model's generation continues from.

    Args:
        document: The text to summarize; interpolated with an f-string.

    Returns:
        The complete prompt string.
    """
    preamble = (
        "You are a model designed to rephrase medical summaries for a general "
        "audience. Please summarize the following article in such a way a normal "
        "person could understand it, while also ensuring the same factual "
        "accuracy. Replace any technical terms with their equivalents in "
        "ordinary language, and be concise (< 100 words) and approachable.\n"
        "---------\n"
    )
    header = f"### {preamble} "
    body = f"\n### Conclusion: {document} "
    tail = "\n### Summary: "
    return header + body + tail
_MODEL_ID = "BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the 4-bit fine-tuned model and its tokenizer once and reuse them.

    Streamlit re-executes the whole script on every widget interaction, so
    without ``st.cache_resource`` the 7B checkpoint would be reloaded from
    disk on every click of "Summarize".

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        device_map="cuda",
        # Passing ``load_in_4bit`` straight to ``from_pretrained`` is deprecated;
        # the supported route is a BitsAndBytesConfig via ``quantization_config``.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    model.eval()
    # Tokenizers have no device placement, so no ``device_map`` here.
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, add_bos_token=True)
    return model, tokenizer


def generate(text, max_new_token):
    """Summarize *text* in plain language with the fine-tuned model.

    Args:
        text: The document to summarize; wrapped by ``formatting_func``.
        max_new_token: Upper bound on the number of tokens to generate.

    Returns:
        Only the newly generated text (the prompt is not included), with every
        ``#`` character removed and surrounding whitespace stripped.
    """
    ft_model, eval_tokenizer = _load_model_and_tokenizer()
    eval_prompt = formatting_func(text)
    model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to(ft_model.device)
    with torch.no_grad():
        output_ids = ft_model.generate(**model_input, max_new_tokens=int(max_new_token))
    # Drop the prompt by token position rather than by searching the decoded
    # text for the prompt string: decoding does not always reproduce the
    # prompt exactly (e.g. whitespace around special tokens), in which case
    # the whole prompt would leak into the answer.
    prompt_length = model_input["input_ids"].shape[-1]
    response = eval_tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
    # Remove the '###' section markers the model tends to echo; note this
    # strips every '#' character from the output.
    response = re.sub(r"#+", "", response)
    # Leading indentation would otherwise render as a code block in st.markdown.
    return response.strip()
def main():
    """Render the two-column summarization UI and run the model on demand.

    The left column collects the document and a generation-length limit; the
    right column shows the model's summary after "Summarize" is clicked.
    """
    st.title('Medical Document Summarization')
    col1, col2 = st.columns(2)
    with col1:
        user_input = st.text_area("Enter your text here:", height=300)
        # min_value=1: zero or negative limits are not valid max_new_tokens
        # values and would only fail later inside generation.
        max_new_token = st.number_input('Max new tokens:', value=200, min_value=1, step=1)
        submit_button = st.button("Summarize")
    with col2:
        if submit_button:
            # Skip the (expensive) model call when there is nothing to summarize.
            if not user_input.strip():
                st.warning("Please enter some text to summarize.")
                return
            st.write("Model Response:")
            with st.spinner("Generating summary..."):
                output = generate(user_input, max_new_token)
            # The summary is shown in the UI only. The previous print(output)
            # copied user-submitted medical text into the server's stdout/logs.
            st.markdown(output)
# Configure the page layout at module level, ahead of main(), so it is applied
# before any element is drawn (older Streamlit releases reject
# set_page_config once another st.* command has run).
st.set_page_config(layout="wide")

if __name__ == "__main__":
    # `streamlit run` executes this script as __main__, so the UI is built
    # there but not when the module is merely imported.
    main()
|