import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast
import pandas as pd


def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Load every tokenizer declared in a diffusers pipeline config.

    Args:
        model_id: Hugging Face Hub id of a diffusers pipeline.

    Returns:
        A list of exactly three entries: the pipeline's tokenizers in order,
        padded with ``None`` when it declares fewer than three.

    Raises:
        gr.Error: If the config declares fewer than 1 or more than 3 tokenizers.
    """
    config = DiffusionPipeline.load_config(model_id)
    # Count config keys that mention "tokenizer" (dicts iterate their keys).
    num_tokenizers = sum("tokenizer" in key for key in config)
    if not 1 <= num_tokenizers <= 3:
        raise gr.Error(f"Invalid number of tokenizers: {num_tokenizers}")
    # diffusers stores tokenizers in subfolders named
    # "tokenizer", "tokenizer_2", "tokenizer_3".
    tokenizers = [
        AutoTokenizer.from_pretrained(
            model_id, subfolder=f'tokenizer{"" if i == 0 else f"_{i + 1}"}'
        )
        for i in range(num_tokenizers)
    ]
    # Pad the list with None if there are fewer than 3 tokenizers
    tokenizers.extend([None] * (3 - num_tokenizers))
    return tokenizers


def _token_pairs(tokenizer, text: str) -> list[tuple[str, str]]:
    """Tokenize *text* and pair each decoded token with its id, both as str."""
    input_ids = tokenizer(
        text=text,
        truncation=False,
        return_length=False,
        return_overflowing_tokens=False,
    ).input_ids
    decoded_tokens = [tokenizer.decode(id_) for id_ in input_ids]
    return [
        (str(token), str(id_)) for token, id_ in zip(decoded_tokens, input_ids)
    ]


def _special_token_pairs(tokenizer) -> list[tuple[str, str]]:
    """Collect (name, token) string pairs from the tokenizer's special-token map."""
    return [
        (str(k), str(v))
        for k, v in tokenizer.special_tokens_map.items()
        # Skip the list-valued entry; only scalar special tokens are shown.
        if k != "additional_special_tokens"
    ]


def _details_frame(tokenizer) -> pd.DataFrame:
    """Build an Attribute/Value table describing the tokenizer."""
    return pd.DataFrame(
        [
            ("Type", tokenizer.__class__.__name__),
            ("Vocab Size", tokenizer.vocab_size),
            ("Model Max Length", tokenizer.model_max_length),
            ("Padding Side", tokenizer.padding_side),
            ("Truncation Side", tokenizer.truncation_side),
        ],
        columns=["Attribute", "Value"],
    )


@torch.no_grad()
def inference(model_id: str, text: str):
    """Tokenize *text* with every tokenizer of *model_id* and build UI updates.

    Returns nine Gradio components in fixed order: three HighlightedText
    token/id views, three HighlightedText special-token views, and three
    Dataframe detail tables. Slots without a tokenizer come back hidden.
    """
    tokenizers = load_tokenizers(model_id)

    text_pairs_components = []
    special_tokens_components = []
    tokenizer_details_components = []
    for i, tokenizer in enumerate(tokenizers):
        # Explicit None check: the list is padded with None, and tokenizer
        # truthiness would otherwise go through __len__ (vocab size).
        if tokenizer is not None:
            label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"
            # Create text / token-id pairs
            output_text_pair_component = gr.HighlightedText(
                label=label_text,
                value=_token_pairs(tokenizer, text),
                visible=True,
            )
            # Add the special tokens
            output_special_tokens_component = gr.HighlightedText(
                label=label_text,
                value=_special_token_pairs(tokenizer),
                visible=True,
            )
            # Add detailed tokenizer information
            output_tokenizer_details = gr.Dataframe(
                headers=["Attribute", "Value"],
                value=_details_frame(tokenizer),
                label=label_text,
                visible=True,
            )
        else:
            output_text_pair_component = gr.HighlightedText(visible=False)
            output_special_tokens_component = gr.HighlightedText(visible=False)
            output_tokenizer_details = gr.Dataframe(visible=False)

        text_pairs_components.append(output_text_pair_component)
        special_tokens_components.append(output_special_tokens_component)
        tokenizer_details_components.append(output_tokenizer_details)

    return (
        text_pairs_components
        + special_tokens_components
        + tokenizer_details_components
    )


if __name__ == "__main__":
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.emerald,
    )

    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            input_model_id = gr.Dropdown(
                label="Model ID",
                choices=[
                    "black-forest-labs/FLUX.1-dev",
                    "black-forest-labs/FLUX.1-schnell",
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
                    "stabilityai/japanese-stable-diffusion-xl",
                    "rinna/japanese-stable-diffusion",
                ],
                value="black-forest-labs/FLUX.1-dev",
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here",
            )
            with gr.Tab(label="Tokenization Outputs"):
                with gr.Column():
                    output_highlighted_text_1 = gr.HighlightedText()
                    output_highlighted_text_2 = gr.HighlightedText()
                    output_highlighted_text_3 = gr.HighlightedText()
            with gr.Tab(label="Special Tokens"):
                with gr.Column():
                    output_special_tokens_1 = gr.HighlightedText()
                    output_special_tokens_2 = gr.HighlightedText()
                    output_special_tokens_3 = gr.HighlightedText()
            with gr.Tab(label="Tokenizer Details"):
                with gr.Column():
                    output_tokenizer_details_1 = gr.Dataframe(
                        headers=["Attribute", "Value"]
                    )
                    output_tokenizer_details_2 = gr.Dataframe(
                        headers=["Attribute", "Value"]
                    )
                    output_tokenizer_details_3 = gr.Dataframe(
                        headers=["Attribute", "Value"]
                    )
            with gr.Row():
                clear_button = gr.ClearButton(components=[input_text])
                submit_button = gr.Button("Run", variant="primary")

        all_inputs = [input_model_id, input_text]
        # Output order must match inference()'s return order.
        all_output = [
            output_highlighted_text_1,
            output_highlighted_text_2,
            output_highlighted_text_3,
            output_special_tokens_1,
            output_special_tokens_2,
            output_special_tokens_3,
            output_tokenizer_details_1,
            output_tokenizer_details_2,
            output_tokenizer_details_3,
        ]
        submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output)

        examples = gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
            examples=[
                ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
                [
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    'cat holding sign saying "I am a cat"',
                ],
                ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
            ],
            cache_examples=True,
        )

    demo.queue().launch()