"""Streamlit demo for running inference with fine-tuned Stable Diffusion LoRA models."""

import threading

import streamlit as st
import torch
from huggingface_hub import model_info
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# Each entry is "<Hugging Face repo id> : <sample prompt>". Only the repo id is
# used for loading; the text after " : " is a prompt hint shown to the user.
MODEL_OPTIONS = [
    "asrimanth/person-thumbs-up-plain-lora : Tom Cruise thumbs up",
    "asrimanth/srimanth-thumbs-up-lora-plain : srimanth thumbs up",
    "asrimanth/person-thumbs-up-lora : #thumbsup",
    "asrimanth/person-thumbs-up-lora-no-cap : #thumbsup",
]

# Number of columns in the result image grid.
GRID_COLUMNS = 3


@st.cache_resource(max_entries=1, show_spinner="Loading model...")
def _load_pipeline(model):
    """Build a Stable Diffusion pipeline with the LoRA weights of *model* applied.

    The result is cached so the base weights are not re-downloaded/re-loaded on
    every form submission. Only the most recently used model is kept
    (``max_entries=1``) because each pipeline holds a full float32 copy of the
    base model.

    Args:
        model: Hugging Face repo id of the LoRA weights. Its model card must
            declare a ``base_model``.

    Returns:
        The configured ``StableDiffusionPipeline``.

    Raises:
        ValueError: if the model card does not declare a ``base_model``.
    """
    card = model_info(model).cardData
    model_base = card.get("base_model") if card is not None else None
    if not model_base:
        raise ValueError(f"Model card of '{model}' does not declare a 'base_model'.")

    pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float32)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.load_attn_procs(model)
    return pipe


@st.cache_resource
def _inference_lock():
    """Return a lock shared by all sessions that guards the cached pipeline.

    The cached pipeline (including its scheduler, which each call mutates) is
    shared across browser sessions, so concurrent runs must be serialised. The
    lock itself is cached because Streamlit re-executes this script on every
    rerun, which would recreate a plain module-level lock each time.
    """
    return threading.Lock()


def _render_grid(images):
    """Lay out *images* in a GRID_COLUMNS-wide grid, filling row by row."""
    columns = st.columns(GRID_COLUMNS)
    for idx, image in enumerate(images):
        columns[idx % GRID_COLUMNS].image(image, caption=f"Image - {idx + 1}")


def inference(prompt, model, n_images, seed, n_inference_steps):
    """Generate *n_images* images for *prompt* using the LoRA model *model*.

    Image ``k`` (0-based) uses seed ``seed + k``, so a given seed reproduces
    the same batch. The progress bar and the image grid are refreshed after
    every generated image.

    Args:
        prompt: text prompt passed to the pipeline.
        model: Hugging Face repo id of the LoRA weights.
        n_images: number of images to generate.
        seed: seed used for the first image.
        n_inference_steps: number of denoising steps per image.
    """
    try:
        pipe = _load_pipeline(model)
    except ValueError as err:
        st.error(str(err))
        return

    # Create the progress bar once and update it in place; clearing its
    # parent placeholder would remove the element being updated.
    progress_bar = st.progress(0, text=f"Performing inference on {n_images} images...")
    image_grid_ui = st.empty()

    print(f"Inferencing '{prompt}' for {n_images} images.")
    result_images = []
    with _inference_lock():
        for done, image_seed in enumerate(range(seed, seed + n_images), start=1):
            generator = torch.Generator().manual_seed(image_seed)
            result = pipe(
                prompt,
                generator=generator,
                num_inference_steps=n_inference_steps,
            ).images[0]
            result_images.append(result)

            progress_bar.progress(
                done / n_images,
                text=f"{done} out of {n_images} images processed.",
            )
            # Writing a container into an st.empty placeholder replaces its
            # previous contents, so the grid is redrawn rather than appended.
            with image_grid_ui.container():
                _render_grid(result_images)


def main():
    """Render the parameter form and run inference when it is submitted."""
    st.title("Finetune LoRA inference")
    with st.form(key='form_parameters'):
        current_model = st.selectbox("Choose a model", options=MODEL_OPTIONS)
        model = current_model.partition(" : ")[0]
        prompt = st.text_input("Enter the prompt: (sample prompts in dropdown)")

        col1_inp, col2_inp, col_3_inp = st.columns(3)
        with col1_inp:
            n_images = int(st.number_input(
                "Enter the number of images", value=3, min_value=1, max_value=50))
        with col2_inp:
            # The scheduler needs at least one step.
            n_inference_steps = int(st.number_input(
                "Enter the number of inference steps", value=5, min_value=1))
        with col_3_inp:
            seed_input = int(st.number_input(
                "Enter the seed (default=25)", value=25, min_value=0))
        submitted = st.form_submit_button("Predict")

    if submitted:
        if not prompt.strip():
            st.warning("Please enter a prompt.")
        else:
            inference(prompt, model, n_images, seed_input, n_inference_steps)


if __name__ == "__main__":
    main()