# Hugging Face Space status at the time this file was captured: "Runtime error".
import gradio as gr | |
import torch | |
# torch.hub repository that provides both the model-name mapping and the
# StyleGAN3 video pipeline entry point.
_HUB_REPO = 'nateraw/image-generation:main'

# Mapping fetched from the hub repo; its keys become the choices of the
# "Pretrained Model" radio in the interface.
model_map = torch.hub.load(_HUB_REPO, 'model_map')
class InferenceWrapper:
    """Callable that owns one StyleGAN3 video pipeline loaded via torch.hub.

    The pipeline is reloaded only when a different pretrained model name is
    requested; calls with the currently loaded model reuse it.
    """

    # torch.hub repository and entry point used for every (re)load.
    _HUB_REPO = 'nateraw/image-generation:main'
    _HUB_ENTRY = 'styleganv3'

    def __init__(self, model):
        """Load the pipeline for ``model`` (passed to the hub as ``pretrained``)."""
        self.model = model
        self.pipe = self._load(model)

    def _load(self, model):
        """Fetch and return a video pipeline for ``model`` from torch.hub."""
        return torch.hub.load(self._HUB_REPO, self._HUB_ENTRY, pretrained=model, videos=True)

    def __call__(self, seed1, seed2, seed3, w_frames, model):
        """Render a video from three seeds with the requested model.

        Args:
            seed1, seed2, seed3: Seeds forwarded to the pipeline as a list.
            w_frames: Forwarded to the pipeline as ``w_frames``.
            model: Pretrained model name; triggers a reload if it differs
                from the currently loaded one.

        Returns:
            Whatever the pipeline returns for these seeds.

        Raises:
            Any exception from ``torch.hub.load`` when a reload fails; in that
            case the previously loaded model and pipeline are left intact.
        """
        if model != self.model:
            print(f"Loading model: {model}")
            # Load before touching state: if loading raises, ``self.model``
            # must not claim a model whose pipeline was never installed,
            # otherwise the next call would skip the reload and silently
            # use the old pipeline.
            new_pipe = self._load(model)
            self.model = model
            self.pipe = new_pipe
        else:
            print(f"Model '{model}' already loaded, reusing it.")
        return self.pipe([seed1, seed2, seed3], w_frames=w_frames)
# Checkpoint loaded once at import time; later calls reuse it unless the
# user picks a different model.
_DEFAULT_MODEL = 'stylegan3-t-afhqv2-512x512.pkl'
wrapper = InferenceWrapper(_DEFAULT_MODEL)
def fn(s1, s2, s3, w_frames, model):
    """Gradio callback: render a video from three seeds with ``model``.

    Gradio sliders may deliver numeric values as floats (e.g. ``0.0``) even
    with ``step=1``; seeds and the frame count are integral, so they are
    coerced to ``int`` before reaching the pipeline.
    NOTE(review): assumes the pipeline needs integer seeds — confirm against
    the hub's ``styleganv3`` implementation.
    """
    return wrapper(int(s1), int(s2), int(s3), int(w_frames), model)
# Build and launch the UI with the top-level component API. The original
# used ``gr.inputs.*``, ``default=`` and ``Interface(enable_queue=...)``,
# which are deprecated in Gradio 3 and removed in Gradio 4.
demo = gr.Interface(
    fn,
    inputs=[
        gr.Slider(minimum=0, maximum=999999999, step=1, value=0, label='Random Seed For Image 1'),
        gr.Slider(minimum=0, maximum=999999999, step=1, value=0, label='Random Seed For Image 2'),
        gr.Slider(minimum=0, maximum=999999999, step=1, value=0, label='Random Seed For Image 3'),
        gr.Radio([60, 120, 240], type="value", value=60, label='Frames'),
        gr.Radio(list(model_map), type="value", value='stylegan3-t-afhqv2-512x512.pkl', label='Pretrained Model'),
    ],
    outputs='video',
    # NOTE(review): 'landscapes-256' must be one of list(model_map) for the
    # example to pass the radio's choice check — confirm against the hub repo.
    examples=[[0, 1, 2, 60, 'landscapes-256']],
)
# Queueing replaces the removed ``enable_queue=True`` so long-running video
# generation is not cut off by request timeouts.
demo.queue().launch()