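"""Streamlit app for running inference with LoRA fine-tuned Stable Diffusion models."""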
import streamlit as st
import torch
from huggingface_hub import model_info
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
def inference(prompt, model, n_images, seed, n_inference_steps):
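    """Generate `n_images` images for `prompt` using the LoRA weights at `model`.

    The progress bar and a three-column image grid are redrawn after every
    image so results appear incrementally. `seed` seeds one torch.Generator
    per image, and `n_inference_steps` is passed through to the pipeline.
    """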
    # Look up the base model from the LoRA checkpoint's model card, then build the pipeline
    info = model_info(model)
    model_base = info.cardData["base_model"]
    pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float32)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    # Apply the LoRA attention weights on top of the base model's UNet
    pipe.unet.load_attn_procs(model)
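    # NOTE: the `base_model` lookup above assumes the LoRA repo's model card
    # declares that field; `info.cardData["base_model"]` raises a KeyError otherwise.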
    # Placeholders for the progress bar and the image grid
    progress_bar_ui = st.empty()
    with progress_bar_ui.container():
        progress_bar = st.progress(0, text=f"Performing inference on {n_images} images...")
    image_grid_ui = st.empty()
    # Run inference, seeding one generator per image so results are reproducible
    result_images = []
    generators = [torch.Generator().manual_seed(s) for s in range(seed, n_images + seed)]
    print(f"Running inference on '{prompt}' for {n_images} images.")
    for i in range(n_images):
        result = pipe(prompt, generator=generators[i], num_inference_steps=n_inference_steps).images[0]
        result_images.append(result)
        # Clear both placeholders before redrawing them
        progress_bar_ui.empty()
        image_grid_ui.empty()
        # Update the progress bar with the completion fraction
        with progress_bar_ui.container():
            value = (i + 1) / n_images
            progress_bar.progress(value, text=f"{i+1} out of {n_images} images processed.")
        # Redraw the image grid, distributing images round-robin across three columns
        with image_grid_ui.container():
            col1, col2, col3 = st.columns(3)
            with col1:
                for j in range(0, len(result_images), 3):
                    st.image(result_images[j], caption=f"Image - {j+1}")
            with col2:
                for j in range(1, len(result_images), 3):
                    st.image(result_images[j], caption=f"Image - {j+1}")
            with col3:
                for j in range(2, len(result_images), 3):
                    st.image(result_images[j], caption=f"Image - {j+1}")
if __name__ == "__main__":
    # --- START UI ---
    st.title("Finetune LoRA inference")
    with st.form(key='form_parameters'):
        # Each option is "<repo_id> : <sample prompt>"
        model_options = [
            "asrimanth/person-thumbs-up-plain-lora : Tom Cruise thumbs up",
            "asrimanth/srimanth-thumbs-up-lora-plain : srimanth thumbs up",
            "asrimanth/person-thumbs-up-lora : <tom_cruise> #thumbsup",
            "asrimanth/person-thumbs-up-lora-no-cap : <tom_cruise> #thumbsup",
        ]
        current_model = st.selectbox("Choose a model", options=model_options)
        # Keep only the repo id; the sample prompt is just a hint for the user
        model, _ = current_model.split(" : ")
        prompt = st.text_input("Enter the prompt: (sample prompts in dropdown)")
        col1_inp, col2_inp, col3_inp = st.columns(3)
        with col1_inp:
            n_images = int(st.number_input("Enter the number of images", value=3, min_value=0, max_value=50))
        with col2_inp:
            n_inference_steps = int(st.number_input("Enter the number of inference steps", value=5, min_value=0))
        with col3_inp:
            seed_input = int(st.number_input("Enter the seed (default=25)", value=25, min_value=0))
        submitted = st.form_submit_button("Predict")
    if submitted:  # The form was submitted
        inference(prompt, model, n_images, seed_input, n_inference_steps)
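# Launch with: streamlit run app.py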