# pip install html2image import base64 import random from io import BytesIO from html2image import Html2Image import os import pathlib import re import gradio as gr import requests from PIL import Image from gradio_client import Client import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, Pipeline HF_TOKEN = os.getenv("HF_TOKEN") if not HF_TOKEN: raise Exception("HF_TOKEN environment variable is required to call remote API.") API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta" headers = {"Authorization": f"Bearer {HF_TOKEN}"} client = Client("https://latent-consistency-super-fast-lcm-lora-sd1-5.hf.space") def init_speech_to_text_model() -> Pipeline: device = "cuda:0" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 model_id = "distil-whisper/distil-medium.en" model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True ) model.to(device) processor = AutoProcessor.from_pretrained(model_id) return pipeline( "automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, max_new_tokens=128, torch_dtype=torch_dtype, device=device, ) whisper_pipe = init_speech_to_text_model() def query(payload: dict): response = requests.post(API_URL, headers=headers, json=payload) return response.json() def generate_text(card_text: str, user_request: str) -> (str, str, str): # Prompt must apply the correct chat template for the model see: # https://huggingface.co/docs/transformers/main/en/chat_templating prompt = f"""<|system|> You create Magic the Gathering cards based on the user's request. # RULES - In your response always generate a new card. - Only generate one card, no other dialogue. - Surround card info in triple backticks (```). - Format the card text using headers like in the example below: ``` Name: Band of Brothers ManaCost: {{3}}{{W}}{{W}} Type: Creature — Phyrexian Human Soldier Rarity: rare Text: Vigilance {{W}}, {{T}}: Attach target creature you control to target creature. (Any number of attacking creatures with total power 5 or less can attack in a band. A band deals damage to that creature.) FlavorText: "This time we will be stronger." —Elder brotherhood blessing Power: 2 Toughness: 2 Color: ['W'] ``` <|user|> {user_request} <|assistant|> """ if card_text and card_text != starting_text: prompt = f"""<|system|> You edit Magic the Gathering cards based on the user's request. # RULES - In your response always generate a new card. - Only generate one card, no other dialogue. - Surround card info in triple backticks (```). - Format the card text using headers like in the example below: ``` Name: Band of Brothers ManaCost: {{3}}{{W}}{{W}} Type: Creature — Phyrexian Human Soldier Rarity: rare Text: Vigilance {{W}}, {{T}}: Attach target creature you control to target creature. (Any number of attacking creatures with total power 5 or less can attack in a band. A band deals damage to that creature.) FlavorText: "This time we will be stronger." —Elder brotherhood blessing Power: 2 Toughness: 2 Color: ['W'] ``` <|user|> # CARD TO EDIT ``` {card_text} ``` # EDIT REQUEST {user_request} <|assistant|> """ print(f"Calling API with prompt:\n{prompt}") params = {"max_new_tokens": 512} output = query({"inputs": prompt, "parameters": params}) if 'error' in output: print(f'Language model call failed: {output["error"]}') raise gr.Warning(f'Language model call failed: {output["error"]}') print(f'API RESPONSE SIZE: {len(output[0]["generated_text"])}') assistant_reply = output[0]["generated_text"].split('<|assistant|>')[1] print(f'ASSISTANT REPLY:\n{assistant_reply}') new_card_text = assistant_reply.split('```') if len(new_card_text) > 1: new_card_text = new_card_text[1].strip() + '\n' else: new_card_text = assistant_reply.split('\n\n') if len(new_card_text) < 2: return assistant_reply, card_text, None new_card_text = new_card_text[1].strip() + '\n' return assistant_reply, new_card_text, None def format_html(text, image_data): template = pathlib.Path("./card_template.html").read_text(encoding='utf-8') if "['U']" in text: template = template.replace("{card_color}", 'style="background-color:#5a73ab"') elif "['W']" in text: template = template.replace("{card_color}", 'style="background-color:#f0e3d0"') elif "['G']" in text: template = template.replace("{card_color}", 'style="background-color:#325433"') elif "['B']" in text: template = template.replace("{card_color}", 'style="background-color:#1a1b1e"') elif "['R']" in text: template = template.replace("{card_color}", 'style="background-color:#c2401c"') elif "Type: Land" in text: template = template.replace("{card_color}", 'style="background-color:#aa8c71"') elif "Type: Artifact" in text: template = template.replace("{card_color}", 'style="background-color:#9ba7bc"') else: template = template.replace("{card_color}", 'style="background-color:#edd99d"') pattern = re.compile('Name: (.*)') name = pattern.findall(text)[0] template = template.replace("{name}", name) pattern = re.compile('Mana.?Cost: (.*)') mana_cost = pattern.findall(text)[0] if mana_cost == "None": template = template.replace("{mana_cost}", '') else: symbols = [] for c in mana_cost: if c in {"{", "}"}: continue else: symbols.append(c.lower()) formatted_symbols = [] for s in symbols: formatted_symbols.append(f'') template = template.replace("{mana_cost}", "\n".join(formatted_symbols[::-1])) if not isinstance(image_data, (bytes, bytearray)): template = template.replace('{image_data}', f'{image_data}') else: template = template.replace('{image_data}', f'data:image/png;base64,{image_data.decode("utf-8")}') pattern = re.compile('Type: (.*)') card_type = pattern.findall(text)[0] template = template.replace("{card_type}", card_type) if len(card_type) > 30: template = template.replace("{type_size}", "16") else: template = template.replace("{type_size}", "18") pattern = re.compile('Rarity: (.*)') rarity = pattern.findall(text)[0] template = template.replace("{rarity}", f"ss-{rarity}") pattern = re.compile(r'^Text: (.*)\n\bFlavor.?Text|Power|Color\b', re.MULTILINE | re.DOTALL) card_text = pattern.findall(text)[0] text_lines = [] for line in card_text.splitlines(): line = line.replace('{T}', '') line = line.replace('{UT}', '') line = line.replace('{E}', '') line = re.sub(r"{(.*?)}", r''.lower(), line) line = re.sub(r"ms-(.)/(.)", r''.lower(), line) line = line.replace('(', '(').replace(')', ')') text_lines.append(f"
{line}
") template = template.replace("{card_text}", "\n".join(text_lines)) pattern = re.compile(r'Flavor.?Text: (.*?)\n^.*$', re.MULTILINE | re.DOTALL) flavor_text = pattern.findall(text) if flavor_text: flavor_text = flavor_text[0] flavor_text_lines = [] for line in flavor_text.splitlines(): flavor_text_lines.append(f"{line}
") template = template.replace("{flavor_text}", "" + "\n".join(flavor_text_lines) + "") else: template = template.replace("{flavor_text}", "") if len(card_text) + len(flavor_text or '') > 170 or len(text_lines) > 3: template = template.replace("{text_size}", '16') template = template.replace( 'ms-cost" style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;">', 'ms-cost" style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;">') else: template = template.replace("{text_size}", '18') pattern = re.compile('Power: (.*)') power = pattern.findall(text) if power: power = power[0] if not power: template = template.replace("{power_toughness}", "") pattern = re.compile('Toughness: (.*)') toughness = pattern.findall(text)[0] template = template.replace("{power_toughness}", f'