Spaces:
Runtime error
Runtime error
File size: 5,952 Bytes
2815e7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
#!/usr/bin/env python
from __future__ import annotations
import enum
import gradio as gr
from huggingface_hub import HfApi
from inference import InferencePipeline
from utils import find_exp_dirs
# Hub repo IDs offered when the "Sample" model source is selected; the first
# entry is used as the dropdown's initial value.
SAMPLE_MODEL_IDS = [
    "koala2/dreambooth-dog-v2",
    "lambdalabs/dreambooth-avatar",
]
class ModelSource(enum.Enum):
    """Where the inference demo looks for models to list in the dropdown.

    Each member's value doubles as the label shown in the "Model Source"
    radio button group.
    """

    SAMPLE = "Sample"  # the hard-coded SAMPLE_MODEL_IDS list
    HUB_LIB = "Hub (dreambooth-library)"  # models authored by dreambooth-library
    LOCAL = "Local"  # directories returned by find_exp_dirs()
class InferenceUtil:
    """Callbacks backing the model-selection widgets of the inference demo.

    The ``load_*`` / ``reload_*`` methods return ``gr.update(...)`` dicts that
    Gradio applies to the "Model ID" dropdown; ``load_model_info`` returns the
    instance prompt recorded in a model's card.
    """

    def __init__(self, hf_token: str | None):
        # Optional Hugging Face token, forwarded to HfApi and to model-card
        # lookups.
        self.hf_token = hf_token

    @staticmethod
    def load_sample_model_list():
        """Return a dropdown update listing the bundled sample models."""
        return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])

    def load_hub_model_list(self) -> dict:
        """Return a dropdown update listing models authored by dreambooth-library.

        The first model (if any) becomes the selected value; an empty list
        selects nothing.
        """
        api = HfApi(token=self.hf_token)
        # Newer huggingface_hub releases expose the repo id as ``id``; fall
        # back to the older ``modelId`` attribute for older releases.
        choices = [
            getattr(info, 'id', None) or info.modelId
            for info in api.list_models(author='dreambooth-library')
        ]
        return gr.update(choices=choices,
                         value=choices[0] if choices else None)

    @staticmethod
    def load_local_model_list() -> dict:
        """Return a dropdown update listing the entries from find_exp_dirs()."""
        choices = find_exp_dirs()
        return gr.update(choices=choices,
                         value=choices[0] if choices else None)

    def reload_model_list(self, model_source: str) -> dict:
        """Return the dropdown update for the given ``ModelSource`` value.

        Raises:
            ValueError: if ``model_source`` is not one of the ModelSource values.
        """
        if model_source == ModelSource.SAMPLE.value:
            return self.load_sample_model_list()
        if model_source == ModelSource.HUB_LIB.value:
            return self.load_hub_model_list()
        if model_source == ModelSource.LOCAL.value:
            return self.load_local_model_list()
        raise ValueError(f'Unknown model source: {model_source!r}')

    def load_model_info(self, model_id: str) -> str:
        """Return the ``instance_prompt`` stored in the model card, or ''.

        This is best-effort UI information: an empty id, a failed card lookup,
        or a card without the field all yield an empty string.
        """
        if not model_id:
            # Nothing selected (e.g. an empty model list); skip the lookup.
            return ''
        try:
            card = InferencePipeline.get_model_card(model_id, self.hf_token)
        except Exception:
            # Deliberately broad: any lookup failure just leaves the field blank.
            return ''
        return getattr(card.data, 'instance_prompt', None) or ''

    def reload_model_list_and_update_model_info(
            self, model_source: str) -> tuple[dict, str]:
        """Refresh the dropdown and the instance prompt of its first entry.

        Returns:
            ``(dropdown_update, instance_prompt)`` -- one value per output
            component wired to this callback.
        """
        model_list_update = self.reload_model_list(model_source)
        model_list = model_list_update['choices']
        instance_prompt = self.load_model_info(
            model_list[0] if model_list else '')
        # Return exactly two values. The previous ``*model_info`` splatted the
        # prompt string into one tuple element per character.
        return model_list_update, instance_prompt
def create_inference_demo(pipe: InferencePipeline,
                          hf_token: str | None = None) -> gr.Blocks:
    """Build the Gradio Blocks UI for generating images from a trained model.

    Args:
        pipe: Pipeline whose ``run`` method is called with the values of
            (model_id, prompt, seed, num_steps, guidance_scale); its return
            value is displayed in the "Result" image.
        hf_token: Optional Hugging Face token used to list Hub models and to
            fetch model cards.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    app = InferenceUtil(hf_token)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # gr.Box was removed in Gradio 4; gr.Group gives the same
                # visual grouping and is also available in Gradio 3.
                with gr.Group():
                    model_source = gr.Radio(
                        label='Model Source',
                        choices=[source.value for source in ModelSource],
                        value=ModelSource.SAMPLE.value)
                    reload_button = gr.Button('Reload Model List')
                    model_id = gr.Dropdown(label='Model ID',
                                           choices=SAMPLE_MODEL_IDS,
                                           value=SAMPLE_MODEL_IDS[0])
                    # Only the instance prompt is shown here, so the label no
                    # longer promises a base model.
                    with gr.Accordion(
                            label='Model info (instance prompt used for training)',
                            open=False):
                        with gr.Row():
                            instance_prompt_used_for_training = gr.Text(
                                label='Instance prompt', interactive=False)
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "A picture of a {}dog in a bucket"')
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 step=1,
                                 value=0)
                with gr.Accordion('Other Parameters', open=False):
                    num_steps = gr.Slider(label='Number of Steps',
                                          minimum=0,
                                          maximum=100,
                                          step=1,
                                          value=25)
                    guidance_scale = gr.Slider(label='CFG Scale',
                                               minimum=0,
                                               maximum=50,
                                               step=0.1,
                                               value=7.5)
                run_button = gr.Button('Generate')
                gr.Markdown('''
- After training, you can press "Reload Model List" button to load your trained model names.
''')
            with gr.Column():
                result = gr.Image(label='Result')

        # Changing the source or pressing reload refreshes the dropdown and
        # the instance prompt of its first entry; the callback returns one
        # value per component listed here.
        model_list_outputs = [model_id, instance_prompt_used_for_training]
        model_source.change(
            fn=app.reload_model_list_and_update_model_info,
            inputs=model_source,
            outputs=model_list_outputs)
        reload_button.click(
            fn=app.reload_model_list_and_update_model_info,
            inputs=model_source,
            outputs=model_list_outputs)
        model_id.change(fn=app.load_model_info,
                        inputs=model_id,
                        outputs=[instance_prompt_used_for_training])

        # NOTE(review): order presumably matches pipe.run's positional
        # parameters -- confirm against inference.py.
        inputs = [
            model_id,
            prompt,
            seed,
            num_steps,
            guidance_scale,
        ]
        prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
        run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
    return demo
if __name__ == '__main__':
    # Script entry point: read the optional Hub token from the environment,
    # build the demo around a pipeline, and serve it behind a bounded queue.
    import os

    hf_token = os.environ.get('HF_TOKEN')
    pipe = InferencePipeline(hf_token)
    demo = create_inference_demo(pipe, hf_token)
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(share=False)
|