import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MarianMTModel, MarianTokenizer

# Load models and tokenizers
@st.cache_resource
def load_healthscribe_model():
    model_name = "har1/HealthScribe-Clinical_Note_Generator"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer

@st.cache_resource
def load_translation_model(model_name):
    model = MarianMTModel.from_pretrained(model_name)
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    return model, tokenizer

# Initialize the clinical note model once (cached across reruns)
healthscribe_model, healthscribe_tokenizer = load_healthscribe_model()

# Language selection options
language_options = {
    "English to French": ("en", "fr"),
    "French to English": ("fr", "en"),
    "English to Spanish": ("en", "es"),
    "Spanish to English": ("es", "en"),
    "English to German": ("en", "de"),
    "German to English": ("de", "en"),
    "English to Italian": ("en", "it"),
    "Italian to English": ("it", "en"),
}

# Streamlit UI setup
st.title("Multifunctional Text Processing App")
st.write("This app can generate clinical notes or translate text between languages.")

# Choose task
task = st.selectbox("Select a task:", ["Generate Clinical Note", "Translate Text"])

if task == "Generate Clinical Note":
    st.subheader("Clinical Note Generator")
    input_text = st.text_area("Enter patient information or medical notes:", height=200)

    if st.button("Generate Clinical Note"):
        if input_text.strip():
            # Tokenize and generate
            inputs = healthscribe_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            outputs = healthscribe_model.generate(inputs["input_ids"], max_length=512, num_beams=5, early_stopping=True)
            generated_note = healthscribe_tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Display the result
            st.subheader("Generated Clinical Note")
            st.write(generated_note)
        else:
            st.warning("Please enter some text to generate a clinical note.")

elif task == "Translate Text":
    st.subheader("Translation Tool")
    language_pair = st.selectbox("Select language pair", list(language_options.keys()))
    src_lang, tgt_lang = language_options[language_pair]

    # Load the corresponding translation model and tokenizer
    model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
    translation_model, translation_tokenizer = load_translation_model(model_name)

    # Input text to translate
    text = st.text_area("Enter text to translate:")

    if st.button("Translate"):
        if text.strip():
            # Prepare the input for the model
            inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)

            # Generate translation
            translation = translation_model.generate(**inputs)

            # Decode the output
            translated_text = translation_tokenizer.decode(translation[0], skip_special_tokens=True)

            # Display translation
            st.write("**Original Text**:", text)
            st.write("**Translated Text**:", translated_text)
        else:
            st.warning("Please enter some text to translate.")
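
# Usage (a minimal sketch, assuming this script is saved as app.py -- the
# filename is an assumption, not from the source). The Marian tokenizers
# require sentencepiece, and the models need a PyTorch install:
#
#   pip install streamlit transformers torch sentencepiece
#   streamlit run app.py
#
# Note: both models are downloaded from the Hugging Face Hub on first run,
# so the initial launch may take a while; subsequent runs reuse the
# st.cache_resource-cached instances.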