Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from huggingface_hub import model_info | |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
def inference(prompt, model, n_images, seed, n_inference_steps):
    """Generate images for ``prompt`` and render them incrementally in Streamlit.

    The base checkpoint is read from the ``base_model`` entry of the model
    card of ``model``; a ``StableDiffusionPipeline`` is built from it, its
    scheduler is swapped for ``DPMSolverMultistepScheduler``, and the
    attention processors stored in ``model`` are loaded into the UNet.
    After every generated image the progress bar and a three-column image
    grid are redrawn.

    Args:
        prompt: Text prompt passed to the pipeline.
        model: Hub repo id whose model card names the base model and which
            holds the attention-processor weights.
        n_images: Number of images to generate.
        seed: Seed of the first image; image ``k`` (0-based) uses ``seed + k``.
        n_inference_steps: Value forwarded as ``num_inference_steps``.

    Returns:
        None. The images are only drawn into the Streamlit page.
    """
    # Build the pipeline: base checkpoint -> scheduler swap -> extra weights.
    base_checkpoint = model_info(model).cardData["base_model"]
    pipe = StableDiffusionPipeline.from_pretrained(base_checkpoint, torch_dtype=torch.float32)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.load_attn_procs(model)

    # Placeholders that get cleared and refilled after each image.
    progress_slot = st.empty()
    with progress_slot.container():
        bar = st.progress(0, text=f"Performing inference on {n_images} images...")
    grid_slot = st.empty()

    # One generator per image, seeded seed, seed + 1, ..., seed + n_images - 1.
    rngs = [torch.Generator().manual_seed(seed + offset) for offset in range(n_images)]
    print(f"Inferencing '{prompt}' for {n_images} images.")

    images = []
    for done, rng in enumerate(rngs, start=1):
        output = pipe(prompt, generator=rng, num_inference_steps=n_inference_steps)
        images.append(output.images[0])

        # Wipe both placeholders before redrawing them.
        progress_slot.empty()
        grid_slot.empty()

        with progress_slot.container():
            bar.progress(done / n_images, text=f"{done} out of {n_images} images processed.")

        # Column c shows images c, c + 3, c + 6, ... so the grid fills row by row.
        with grid_slot.container():
            for first, column in enumerate(st.columns(3)):
                with column:
                    for idx in range(first, len(images), 3):
                        st.image(images[idx], caption=f"Image - {idx + 1}")
if __name__ == "__main__":
    # Script entry point: draw the parameter form and, once it is submitted,
    # run inference with the chosen values.
    # --- START UI ---
    st.title("Finetune LoRA inference")
    with st.form(key='form_parameters'):
        # Each option reads "<hub repo id> : <sample prompt>".
        model_options = [
            "asrimanth/person-thumbs-up-plain-lora : Tom Cruise thumbs up",
            "asrimanth/srimanth-thumbs-up-lora-plain : srimanth thumbs up",
            "asrimanth/person-thumbs-up-lora : <tom_cruise> #thumbsup",
            "asrimanth/person-thumbs-up-lora-no-cap : <tom_cruise> #thumbsup",
        ]
        current_model = st.selectbox("Choose a model", options=model_options)
        # Only the repo id is needed; the sample prompt is a hint for the user.
        model = current_model.partition(" : ")[0]
        prompt = st.text_input("Enter the prompt: (sample prompts in dropdown)")
        col1_inp, col2_inp, col_3_inp = st.columns(3)
        with col1_inp:
            # At least one image: zero would make "Predict" a silent no-op.
            n_images = int(st.number_input("Enter the number of images", value=3, min_value=1, max_value=50))
        with col2_inp:
            # At least one step: zero denoising steps is not a meaningful request.
            n_inference_steps = int(st.number_input("Enter the number of inference steps", value=5, min_value=1))
        with col_3_inp:
            seed_input = int(st.number_input("Enter the seed (default=25)", value=25, min_value=0))
        submitted = st.form_submit_button("Predict")
    if submitted:  # The form is submitted
        inference(prompt, model, n_images, seed_input, n_inference_steps)