Spaces:
Sleeping
Sleeping
File size: 8,317 Bytes
d8a9bd8 827f4e7 d8a9bd8 827f4e7 d8a9bd8 827f4e7 d8a9bd8 1331fa9 d8a9bd8 1331fa9 d8a9bd8 b7a625a d8a9bd8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
# To run: funix main.py
import html
import os
import typing

from funix import funix
from funix.hint import HTML
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# Forwarded as `low_memory=` to `model.generate` by `__generate`.
low_memory = True # Set to True to run on mobile devices
# Hugging Face access token (None when HF_TOKEN is unset). It is passed only to the
# TURX/* repositories; the ku-nlp checkpoint is loaded without it.
hf_token = os.getenv("HF_TOKEN")

# Tokenizers, loaded left to right: modern (ku-nlp), classical (CHJ), waka.
ku_gpt_tokenizer, chj_gpt_tokenizer, wakagpt_tokenizer = (
    AutoTokenizer.from_pretrained("ku-nlp/gpt2-medium-japanese-char"),
    AutoTokenizer.from_pretrained("TURX/chj-gpt2", token=hf_token),
    AutoTokenizer.from_pretrained("TURX/wakagpt", token=hf_token),
)

# Causal-LM weights matching each tokenizer above, in the same order.
ku_gpt_model, chj_gpt_model, wakagpt_model = (
    AutoModelForCausalLM.from_pretrained("ku-nlp/gpt2-medium-japanese-char"),
    AutoModelForCausalLM.from_pretrained("TURX/chj-gpt2", token=hf_token),
    AutoModelForCausalLM.from_pretrained("TURX/wakagpt", token=hf_token),
)

print("Models loaded successfully.")
# UI label (the `model_type` choices offered by `prompt`) -> internal model identifier.
model_name_map = dict(
    zip(
        ("Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"),
        ("ku-gpt2", "chj-gpt2", "wakagpt"),
    )
)

# Lower-cased waka `type` choice -> tag placed before the poem text in the WakaGPT prompt.
# NOTE(review): these tag literals look mis-encoded (UTF-8 Japanese decoded through a Thai
# code page; presumably originally [仮名] / [原文] / [整形]) — restore from the original source.
waka_type_map = {
    "kana": "[ไปฎๅ]",
    "original": "[ๅๆ]",
    "aligned": "[ๆดๅฝข]",
}
@funix(
    title=" Home",
    description="""
<h1>Japanese Language Models</h1><hr>
Final Project, STAT 453 Spring 2024, University of Wisconsin-Madison<br>
Author: Ruixuan Tu (ruixuan@cs.wisc.edu, https://turx.tokyo)<hr>
Navigate the apps using the left sidebar.
""",
)
def home():
    # Landing page of the app: it takes no inputs and produces no output, so presumably
    # only the title/description configured in the decorator above is displayed.
    return None
@funix(disable=True)
def __generate(
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    prompt: str,
    do_sample: bool,
    num_beams: int,
    num_beam_groups: int,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> typing.List[str]:
    """Generate continuations of ``prompt`` with ``model`` and return them as text.

    Internal helper shared by the ``prompt`` and ``waka`` apps; it is decorated with
    ``@funix(disable=True)``, presumably so funix does not expose it as an app itself.

    The annotations name the ``Auto*`` factory classes, but the arguments are the
    loaded tokenizer/model instances created with ``from_pretrained`` at module level.
    All decoding options are forwarded unchanged to ``model.generate``, together with
    the module-level ``low_memory`` switch.

    Returns:
        A list with one decoded string per generated sequence, with special tokens
        removed. For these causal LMs the text presumably still begins with the prompt
        (the ``waka`` app slices it off) — confirm if relying on it.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids,
        # Reading a module-level name needs no `global` statement.
        low_memory=low_memory,
        do_sample=do_sample,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
@funix(
    title="Custom Prompt Japanese GPT-2",
    description="""
<h1>Japanese GPT-2</h1><hr>
Let a GPT-2 model complete a Japanese sentence for you.
""",
    # Each key appears once; a duplicated "max_new_tokens" entry used to be silently
    # overwritten, and "Max New Tokens" was the label actually shown.
    argument_labels={
        "prompt": "Prompt in Japanese",
        "model_type": "Model Type",
        "max_new_tokens": "Max New Tokens",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Number of Sequences to Return",
    },
    widgets={
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        "max_new_tokens": "slider[1,512,1]",
        # transformers rejects temperature <= 0 when sampling, so the slider starts above 0.
        "temperature": "slider[0.01,1.0,0.01]",
        # top_k is an integer count of candidate tokens; a fractional step produced invalid values.
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    },
)
def prompt(
    prompt: str = "ใใใซใกใฏใ",
    model_type: typing.Literal["Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"] = "Kyoto University GPT-2 (Modern)",
    do_sample: bool = True,
    num_beams: int = 1,
    num_beam_groups: int = 1,
    max_new_tokens: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    num_return_sequences: int = 1,
) -> HTML:
    # Complete `prompt` with the model selected by `model_type` and return every generated
    # sequence as its own HTML paragraph. The remaining arguments are decoding options that
    # are forwarded unchanged to `__generate`.
    #
    # Raises KeyError for a `model_type` missing from `model_name_map`, and
    # NotImplementedError for an identifier that has no loaded tokenizer/model pair.
    # NOTE(review): the default prompt literal looks mis-encoded (mojibake) — restore it
    # from the original source.
    model_name = model_name_map[model_type]
    backends = {
        "ku-gpt2": (ku_gpt_tokenizer, ku_gpt_model),
        "chj-gpt2": (chj_gpt_tokenizer, chj_gpt_model),
        "wakagpt": (wakagpt_tokenizer, wakagpt_model),
    }
    if model_name not in backends:
        raise NotImplementedError(f"Unsupported model: {model_name}")
    tokenizer, model = backends[model_name]

    generated = __generate(
        tokenizer, model, prompt, do_sample, num_beams, num_beam_groups, max_new_tokens,
        temperature, top_k, top_p, repetition_penalty, num_return_sequences,
    )
    # The generated text echoes the user's prompt, so escape it to display it as text
    # instead of letting the browser interpret it as markup.
    return HTML("".join(f"<p>{html.escape(text)}</p>" for text in generated))
@funix(
    title="WakaGPT Poem Composer",
    description="""
<h1>WakaGPT Poem Composer</h1><hr>
Generate a Japanese waka poem in 5-7-5-7-7 form using WakaGPT. A sample poem (Kokinshu 169) is provided below:<br>
Preface: ็ง็ซใคๆฅใใใ<br>
Author: ๆ่ก ่คๅๆ่กๆ่ฃ (018)<br>
Kana (Kana only with Separator): ใใใใฌใจโใใซใฏใใใใซโใฟใใญใจใโใใใฎใใจใซใโใใจใใใใฌใ<br>
Original (Kana + Kanji without Separator): ใใใใฌใจใใซใฏใใใใซ่ฆใใญใจใ้ขจใฎใใจใซใใใจใใใใฌใ<br>
Aligned (Kana + Kanji with Separator): ใใใใฌใจโใใซใฏใใใใซโ่ฆใใญใจใโ้ขจใฎใใจใซใโใใจใใใใฌใ
""",
    argument_labels={
        "preface": "Preface (Kotobagaki) in Japanese (optional)",
        "author": "Author Name in Japanese (optional)",
        "first_line": "First Line of Poem in Japanese (optional)",
        "type": "Waka Type",
        "remaining_lines": "Remaining Lines of Poem",
        "do_sample": "Do Sample",
        "num_beams": "Number of Beams",
        "num_beam_groups": "Number of Beam Groups",
        "temperature": "Temperature",
        "top_k": "Top K",
        "top_p": "Top P",
        "repetition_penalty": "Repetition Penalty",
        "num_return_sequences": "Number of Sequences to Return (at Maximum)",
    },
    widgets={
        "remaining_lines": "slider[1,5,1]",
        "num_beams": "slider[1,10,1]",
        "num_beam_groups": "slider[1,5,1]",
        # transformers rejects temperature <= 0 when sampling, so the slider starts above 0.
        "temperature": "slider[0.01,1.0,0.01]",
        # top_k is an integer count of candidate tokens; a fractional step produced invalid values.
        "top_k": "slider[1,100,1]",
        "top_p": "slider[0.0,1.0,0.01]",
        "repetition_penalty": "slider[1.0,2.0,0.01]",
        "num_return_sequences": "slider[1,5,1]",
    },
)
def waka(
    preface: str = "",
    author: str = "",
    first_line: str = "ใใใใฌใจโใใซใฏใใใใซโใฟใใญใจใ",
    type: typing.Literal["Kana", "Original", "Aligned"] = "Kana",
    remaining_lines: int = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    num_beam_groups: int = 1,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    num_return_sequences: int = 1,
) -> HTML:
    # Compose waka poems with WakaGPT.
    #
    # 1. Build a tagged prompt from the optional preface and author plus the opening of the
    #    poem (`first_line`).
    # 2. Generate `remaining_lines` more lines.
    # 3. In Kana mode, discard candidates whose lines are not exactly 5-7-5-7-7 characters.
    #
    # Returns an HTML report with the prompt, the token budget, the number of discarded
    # candidates and the surviving poems.
    #
    # `type` shadows the builtin but is kept because parameter names form the app's interface.
    waka_type = type.lower()  # one of the keys of waka_type_map
    token_counts = [5, 7, 5, 7, 7]  # characters per line in the 5-7-5-7-7 form
    # NOTE(review): this separator and the Japanese tag literals below look mis-encoded
    # (UTF-8 Japanese decoded through a Thai code page) — restore them from the original source.
    separator = "โ"

    waka_prompt = ""
    if preface:
        waka_prompt += "[่ฉๆธ] " + preface + "\n"
    if author:
        # NOTE(review): a mis-decoded byte had split this literal across two lines; it is
        # rejoined here, but a character was presumably lost (likely "[作者] ") — confirm.
        waka_prompt += "[ไฝ่] " + author + "\n"

    # The budget treats each character as one token: the characters of the lines still to come.
    max_new_tokens = sum(token_counts[-remaining_lines:])
    first_line = first_line.strip()
    if waka_type in ("kana", "aligned"):
        # These formats separate lines explicitly, so separators need budget too.
        if first_line == "":
            # A whole poem has 4 separators between its 5 lines.
            # NOTE(review): the base budget still covers only `remaining_lines` lines; presumably
            # users set remaining_lines=5 when first_line is empty — confirm.
            max_new_tokens += 4
        else:
            # End the given opening with exactly one separator. The previous conditional expression
            # appended the whole line again when it already ended with one.
            if not first_line.endswith(separator):
                first_line += separator
            max_new_tokens += remaining_lines - 1  # separators between the remaining lines
    waka_prompt += waka_type_map[waka_type] + " " + first_line

    info = f"""
Prompt: {html.escape(waka_prompt)}<br>
Max New Tokens: {max_new_tokens}<br>
"""
    generated = __generate(
        wakagpt_tokenizer, wakagpt_model, waka_prompt, do_sample, num_beams, num_beam_groups,
        max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_return_sequences,
    )

    # Each decoded sequence starts with the prompt: cut it off and re-attach only the opening lines.
    # NOTE(review): the `- 1` offset is kept as-is; it presumably compensates for the decoded prompt
    # being one character shorter than `waka_prompt` (e.g. a dropped space) — confirm.
    poems = [first_line + seq[len(waka_prompt) - 1:] for seq in generated]
    if waka_type == "kana":
        # Kana-only text has one character per mora, so line lengths must match 5-7-5-7-7 exactly.
        checked_generated = [
            poem for poem in poems
            if [len(part) for part in poem.split(separator)] == token_counts
        ]
    else:
        checked_generated = poems
    removed = len(poems) - len(checked_generated)

    # Escape user-supplied and generated text so it is rendered as text, not as markup.
    results = "".join(f"<p>{html.escape(poem)}</p>" for poem in checked_generated)
    return HTML(info + f"Removed Malformed: {removed}<br>Results:<br>{results}")
if __name__ == "__main__":
    # Manual check without the funix UI: ask the modern-Japanese model for five
    # continuations (five beams) of a short prompt and print the resulting HTML.
    # NOTE(review): the prompt literal looks mis-encoded (mojibake) — restore it from the original source.
    sample_html = prompt(
        "ใใใซใกใฏ",
        "Kyoto University GPT-2 (Modern)",
        num_beams=5,
        num_return_sequences=5,
    )
    print(sample_html)
|