"""Gradio demo: translate an English prompt to Russian, then draw it with ruDALL-E."""
import html
import random
from functools import partialmethod

import gradio as gr
import numpy as np
import torch
from gradio.mix import Series
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images
from tqdm import tqdm
from transformers import FSMTForConditionalGeneration, FSMTTokenizer, pipeline

# Disable tqdm progress bars globally (patches the class), which silences the
# progress logging emitted by the rudalle pipeline.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used on the GPU: many PyTorch CPU kernels have no
# float16 implementation, so the CPU fallback would otherwise fail at runtime.
use_fp16 = device.type == "cuda"
translation_dtype = torch.float16 if use_fp16 else torch.float32

translation_model = FSMTForConditionalGeneration.from_pretrained(
    "facebook/wmt19-en-ru", torch_dtype=translation_dtype
).to(device=device, dtype=translation_dtype)
translation_tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-ru")

dalle = get_rudalle_model("Malevich", pretrained=True, fp16=use_fp16, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)


def translation_wrapper(text: str) -> str:
    """Translate ``text`` from English to Russian with the wmt19-en-ru FSMT model.

    This is the first stage of the ``Series``; its output becomes the prompt
    passed to ``dalle_wrapper``.

    Args:
        text: The prompt typed by the user.

    Returns:
        The decoded translation, with special tokens removed.
    """
    input_ids = translation_tokenizer.encode(text, return_tensors="pt").to(device)
    outputs = translation_model.generate(input_ids)
    # ``outputs`` holds integer token ids; decode them as-is rather than
    # casting to float (which relied on float ids matching int vocab keys).
    return translation_tokenizer.decode(outputs[0], skip_special_tokens=True)


def dalle_wrapper(prompt: str):
    """Generate a single ruDALL-E image for ``prompt``.

    A (top_k, top_p) sampling setting is drawn at random on every call, adding
    variety between repeated submissions of the same prompt.

    Args:
        prompt: Text forwarded to ``generate_images`` (the translated prompt
            when run inside the ``Series``).

    Returns:
        A ``(title, image)`` pair matching the HTML and Image outputs: the
        HTML-escaped prompt and the first element of the image list returned
        by ``generate_images``.
    """
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    images, _ = generate_images(
        prompt, tokenizer, dalle, vae, top_k=top_k, images_num=1, top_p=top_p
    )
    # The prompt is derived from user input and rendered by an HTML output
    # component, so escape it to avoid injecting markup into the page.
    title = html.escape(prompt)
    return title, images[0]


translator = gr.Interface(
    fn=translation_wrapper,
    inputs=[gr.inputs.Textbox(label='What would you like to see?')],
    outputs="text",
)
outputs = [
    gr.outputs.HTML(label=""),
    gr.outputs.Image(label=""),
]
generator = gr.Interface(fn=dalle_wrapper, inputs="text", outputs=outputs)

description = (
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom). "
    "This demo uses an English-Russian translation model to adapt the prompts. "
    "Try pressing [Submit] multiple times to generate new images!"
)
# NOTE(review): the link markup behind "GitHub" and "Article (in Russian)"
# appears to have been lost (the description promises links at the bottom);
# restore the anchor tags/URLs here. The line breaks are now explicit escapes
# so the literal is valid Python.
article = (
    "\n"
    "GitHub | "
    "Article (in Russian)"
    "\n"
)
examples = [
    ["A still life of grapes and a bottle of wine"],
    ["Город в стиле киберпанк"],
    ["A colorful photo of a coral reef"],
    ["A white cat sitting in a cardboard box"],
]

series = Series(
    translator,
    generator,
    title='Kinda-English ruDALL-E',
    description=description,
    article=article,
    layout='horizontal',
    theme='huggingface',
    examples=examples,
    allow_flagging=False,
    live=False,
    enable_queue=True,
)
series.launch()