|
import streamlit as st |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MarianMTModel, MarianTokenizer |
|
|
|
|
|
@st.cache_resource
def load_healthscribe_model():
    """Load the HealthScribe clinical-note generator checkpoint.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects on later script reruns instead of loading them again.

    Returns:
        tuple: ``(model, tokenizer)`` — an ``AutoModelForSeq2SeqLM`` and
        the matching ``AutoTokenizer``.
    """
    checkpoint = "har1/HealthScribe-Clinical_Note_Generator"
    # Tokenizer is loaded before the model, matching the original order.
    hs_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return AutoModelForSeq2SeqLM.from_pretrained(checkpoint), hs_tokenizer
|
|
|
@st.cache_resource
def load_translation_model(model_name):
    """Load a MarianMT translation checkpoint together with its tokenizer.

    ``st.cache_resource`` caches the result per ``model_name`` argument, so
    returning to a previously selected language pair reuses the loaded objects.

    Args:
        model_name: Hugging Face model id, e.g. ``"Helsinki-NLP/opus-mt-en-fr"``.

    Returns:
        tuple: ``(model, tokenizer)`` — a ``MarianMTModel`` and ``MarianTokenizer``.
    """
    # Model is loaded before the tokenizer, matching the original order.
    marian_model = MarianMTModel.from_pretrained(model_name)
    return marian_model, MarianTokenizer.from_pretrained(model_name)
|
|
|
|
|
# Load the clinical-note model and tokenizer at the top level of the script, so
# this runs on every Streamlit rerun regardless of which task is selected.
# load_healthscribe_model is wrapped in st.cache_resource, so only the first
# call actually loads the checkpoint; later reruns get the cached objects.
# NOTE(review): the model is still loaded for users who only use "Translate
# Text"; consider loading it lazily inside the clinical-note branch instead.
healthscribe_model, healthscribe_tokenizer = load_healthscribe_model()
|
|
|
|
|
# Selectbox label -> (source, target) language codes. The codes are spliced
# into a "Helsinki-NLP/opus-mt-{src}-{tgt}" model id, and insertion order is
# the order in which the pairs are offered in the language-pair selectbox.
language_options = dict(
    (
        # French
        ("English to French", ("en", "fr")),
        ("French to English", ("fr", "en")),
        # Spanish
        ("English to Spanish", ("en", "es")),
        ("Spanish to English", ("es", "en")),
        # German
        ("English to German", ("en", "de")),
        ("German to English", ("de", "en")),
        # Italian
        ("English to Italian", ("en", "it")),
        ("Italian to English", ("it", "en")),
    )
)
|
|
|
|
|
st.title("Multifunctional Text Processing App")

st.write("This app can generate clinical notes or translate text between languages.")

# Task labels are used both as selectbox options and for the dispatch below;
# defining them once keeps a typo from silently making a branch unreachable.
TASK_CLINICAL_NOTE = "Generate Clinical Note"
TASK_TRANSLATE = "Translate Text"

# Input-token budget for the clinical-note model; longer inputs are truncated.
CLINICAL_NOTE_MAX_INPUT_TOKENS = 512


def _warn_if_truncated(tokenizer, text, max_tokens):
    """Show a Streamlit warning when *text* tokenizes to more than *max_tokens*.

    The model calls below tokenize with ``truncation=True``, which would
    otherwise drop the tail of long inputs without telling the user.
    """
    # Tokenize without truncation purely to measure the full length.
    token_count = len(tokenizer(text)["input_ids"])
    if token_count > max_tokens:
        st.warning(
            f"Input is {token_count} tokens long; only the first "
            f"{max_tokens} tokens will be used."
        )


task = st.selectbox("Select a task:", [TASK_CLINICAL_NOTE, TASK_TRANSLATE])

if task == TASK_CLINICAL_NOTE:
    st.subheader("Clinical Note Generator")
    input_text = st.text_area("Enter patient information or medical notes:", height=200)

    if st.button("Generate Clinical Note"):
        if input_text.strip():
            _warn_if_truncated(healthscribe_tokenizer, input_text, CLINICAL_NOTE_MAX_INPUT_TOKENS)

            inputs = healthscribe_tokenizer(
                input_text,
                return_tensors="pt",
                max_length=CLINICAL_NOTE_MAX_INPUT_TOKENS,
                truncation=True,
            )

            # Beam search can take a while; give the user visible feedback.
            with st.spinner("Generating clinical note..."):
                # Forward the tokenizer's attention mask along with the ids
                # (.get: None lets generate() infer it if the tokenizer omits it).
                outputs = healthscribe_model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_length=512,
                    num_beams=5,
                    early_stopping=True,
                )

            generated_note = healthscribe_tokenizer.decode(outputs[0], skip_special_tokens=True)

            st.subheader("Generated Clinical Note")
            st.write(generated_note)
        else:
            st.warning("Please enter some text to generate a clinical note.")

elif task == TASK_TRANSLATE:
    st.subheader("Translation Tool")
    language_pair = st.selectbox("Select language pair", list(language_options.keys()))
    src_lang, tgt_lang = language_options[language_pair]

    # One Helsinki-NLP opus-mt checkpoint per (source, target) pair; the
    # loader is cached per model name.
    model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
    translation_model, translation_tokenizer = load_translation_model(model_name)

    text = st.text_area("Enter text to translate:")

    if st.button("Translate"):
        if text.strip():
            # truncation=True without max_length truncates to the tokenizer's
            # model_max_length, so warn against that same limit.
            _warn_if_truncated(translation_tokenizer, text, translation_tokenizer.model_max_length)

            inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)

            with st.spinner("Translating..."):
                translation = translation_model.generate(**inputs)

            translated_text = translation_tokenizer.decode(translation[0], skip_special_tokens=True)

            st.write("**Original Text**:", text)
            st.write("**Translated Text**:", translated_text)
        else:
            st.warning("Please enter some text to translate.")
|
|