File size: 2,150 Bytes
97c9b98
 
 
c59ccb6
 
 
 
97c9b98
 
 
 
 
 
c59ccb6
 
97c9b98
c59ccb6
97c9b98
 
 
 
 
c59ccb6
95073a7
97c9b98
 
c59ccb6
97c9b98
 
c59ccb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import streamlit as st
import accelerate
import bitsandbytes
import re

def formatting_func(document):
    """Wrap *document* in the summarization prompt template.

    The prompt has three "###"-prefixed sections: the fixed task
    instruction, the document under a "Conclusion:" label, and an empty
    "Summary:" label left open at the end of the string.

    Args:
        document: Text to be summarized; interpolated into the prompt as-is.

    Returns:
        The assembled prompt string.
    """
    task_description = "You are a model designed to rephrase medical summaries for a general audience. Please summarize the following article in such a way a normal person could understand it, while also ensuring the same factual accuracy. Replace any technical terms with their equivalents in ordinary language, and be concise (< 100 words) and approachable.\n---------\n"
    instruction_section = f"### {task_description} \n"
    conclusion_section = f"### Conclusion: {document} \n"
    return instruction_section + conclusion_section + "### Summary: "

def _load_model_and_tokenizer():
    """Load the 4-bit quantized summarizer model and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is placed on the CUDA
        device and switched to eval mode.
    """
    model_id = "BiswajitPadhi99/mistral-7b-finetuned-medical-summarizer-old"
    # Quantization options go through BitsAndBytesConfig; passing
    # load_in_4bit straight to from_pretrained is deprecated.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cuda",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    # The tokenizer is not placed on a device, so no device_map is passed.
    tokenizer = AutoTokenizer.from_pretrained(model_id, add_bos_token=True)
    model.eval()
    return model, tokenizer


def generate(text, max_new_token):
    """Generate a plain-language summary of *text* with the fine-tuned model.

    The model and tokenizer are loaded once and reused on later calls
    instead of being reloaded for every request.

    Args:
        text: Document to summarize; wrapped in the prompt template by
            ``formatting_func``.
        max_new_token: Maximum number of tokens to generate (coerced to int).

    Returns:
        The generated continuation only (prompt excluded), with every run
        of ``#`` characters removed.
    """
    # st.cache_resource keys its cache on the wrapped function, so wrapping
    # at call time still reuses the loaded model across Streamlit reruns.
    ft_model, eval_tokenizer = st.cache_resource(_load_model_and_tokenizer)()

    eval_prompt = formatting_func(text)
    model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        output_ids = ft_model.generate(
            **model_input, max_new_tokens=int(max_new_token)
        )[0]

    # Drop the prompt by token count rather than by substring match on the
    # decoded text: decoding does not always reproduce the prompt exactly,
    # in which case a string match would leave the whole prompt in the output.
    prompt_length = model_input["input_ids"].shape[1]
    response = eval_tokenizer.decode(
        output_ids[prompt_length:], skip_special_tokens=True
    )
    # Strip any "###" section markers the model echoes back.
    return re.sub(r'#+', '', response)

def main():
    """Render the summarization page and run the model on request.

    The left column collects the document text and the generation budget;
    the right column shows the model's summary after "Summarize" is
    clicked. Blank input produces a warning instead of a model call.
    """
    st.title('Medical Document Summarization')

    col1, col2 = st.columns(2)

    with col1:
        user_input = st.text_area("Enter your text here:", height=300)
        # A lower bound keeps zero or negative token budgets from reaching
        # the model's generate call.
        max_new_token = st.number_input('Max new tokens:', min_value=1, value=200)
        submit_button = st.button("Summarize")

    with col2:
        if submit_button:
            # Skip the expensive model call when there is nothing to summarize.
            if not user_input.strip():
                st.warning("Please enter some text to summarize.")
                return
            st.write("Model Response:")
            output = generate(user_input, max_new_token)
            # Echo the summary to the server console as well as the page.
            print(output)
            st.markdown(output)

# Page layout is configured at module level, so it runs before main() issues
# any other Streamlit call.
st.set_page_config(layout="wide")
# Build the UI only when this file is executed as the main script.
if __name__=="__main__":
    main()