Biswajit Padhi
Resolve merge conflicts
95073a7
raw
history blame
2.15 kB
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import streamlit as st
import accelerate
import bitsandbytes
import re
def formatting_func(document):
    """Build the prompt fed to the summarization model.

    The prompt has three ``###``-prefixed sections: the fixed rewriting
    instruction, a ``Conclusion:`` section holding *document*, and an open
    ``Summary:`` header that the model is expected to continue.

    Args:
        document: The medical text to summarize; it is interpolated into
            the prompt with normal f-string formatting.

    Returns:
        The complete prompt string.
    """
    instruction = "You are a model designed to rephrase medical summaries for a general audience. Please summarize the following article in such a way a normal person could understand it, while also ensuring the same factual accuracy. Replace any technical terms with their equivalents in ordinary language, and be concise (< 100 words) and approachable.\n---------\n"
    instruction_section = f"### {instruction} \n"
    document_section = f"### Conclusion: {document} \n"
    summary_header = "### Summary: "
    return instruction_section + document_section + summary_header
_MODEL_ID = "BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the 4-bit quantized model and its tokenizer once per server process.

    Streamlit re-executes the script on every interaction, so the load is
    wrapped in ``st.cache_resource`` to avoid re-reading the checkpoint on
    each button click.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is already in eval mode.
    """
    # Quantization is requested through BitsAndBytesConfig; passing
    # ``load_in_4bit`` straight to ``from_pretrained`` is deprecated.
    quant_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        device_map="cuda",
        quantization_config=quant_config,
    )
    # Tokenizers have no device placement, so no ``device_map`` here.
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, add_bos_token=True)
    model.eval()
    return model, tokenizer


def generate(text, max_new_token):
    """Generate a plain-language summary of *text* with the fine-tuned model.

    Args:
        text: The medical document to summarize.
        max_new_token: Upper bound on the number of tokens to generate;
            coerced to ``int`` before being passed to the model.

    Returns:
        The generated continuation as a string, with special tokens and
        every run of ``#`` characters removed.
    """
    ft_model, eval_tokenizer = _load_model_and_tokenizer()

    eval_prompt = formatting_func(text)
    model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")
    prompt_len = model_input["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = ft_model.generate(
            **model_input, max_new_tokens=int(max_new_token)
        )

    # The returned sequence starts with the prompt tokens. Drop them by
    # position rather than by substring match on the decoded text, which
    # fails whenever decoding does not reproduce the prompt verbatim.
    new_tokens = output_ids[0][prompt_len:]
    response = eval_tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Remove any '#' section markers the model echoes back.
    return re.sub(r"#+", "", response)
def main():
    """Render the summarization page.

    The left column collects the document text and the generation length;
    the right column shows the model's summary after "Summarize" is
    clicked. Blank input is rejected with a warning instead of being sent
    to the model.
    """
    st.title('Medical Document Summarization')
    col1, col2 = st.columns(2)

    with col1:
        user_input = st.text_area("Enter your text here:", height=300)
        # min_value keeps zero/negative lengths from reaching the model.
        max_new_token = st.number_input('Max new tokens:', min_value=1, value=200)
        submit_button = st.button("Summarize")

    with col2:
        if submit_button:
            if not user_input or not user_input.strip():
                st.warning("Please enter some text to summarize.")
            else:
                st.write("Model Response:")
                # Generation can take a while; show progress meanwhile.
                with st.spinner("Generating summary..."):
                    output = generate(user_input, max_new_token)
                st.markdown(output)
# Use the wide page layout so the two columns have room.
# NOTE(review): Streamlit's docs require set_page_config to be the first
# Streamlit command in a run, so keep this ahead of main() - confirm against
# the installed Streamlit version.
st.set_page_config(layout="wide")
# Build the UI only when the file is executed as a script.
if __name__=="__main__":
    main()