# To run: funix main.py
import os
import typing

from funix import funix
from funix.hint import HTML
from transformers import AutoModelForCausalLM, AutoTokenizer

low_memory = True  # Set to True to reduce peak memory during generation (e.g., on mobile devices)

hf_token = os.environ.get("HF_TOKEN")
ku_gpt_tokenizer = AutoTokenizer.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_tokenizer = AutoTokenizer.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_tokenizer = AutoTokenizer.from_pretrained("TURX/wakagpt", token=hf_token)
ku_gpt_model = AutoModelForCausalLM.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_model = AutoModelForCausalLM.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_model = AutoModelForCausalLM.from_pretrained("TURX/wakagpt", token=hf_token)
print("Models loaded successfully.")
model_name_map = {
    "Kyoto University GPT-2 (Modern)": "ku-gpt2",
    "CHJ GPT-2 (Classical)": "chj-gpt2",
    "Waka GPT": "wakagpt",
}
waka_type_map = {
    "kana": "[仮名]",      # kana (phonetic) form
    "original": "[原文]",  # original text
    "aligned": "[整形]",   # aligned form
}
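# These bracket tags are prepended to WakaGPT prompts in waka() below to select which
# form of the poem the model should produce, e.g. "[仮名] あききぬと―…".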
def home():
    # Empty landing page for the Funix app.
    return
def __generate(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, prompt: str,
               do_sample: bool, num_beams: int, num_beam_groups: int, max_new_tokens: int,
               temperature: float, top_k: int, top_p: float, repetition_penalty: float,
               num_return_sequences: int) -> typing.List[str]:
    global low_memory
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, low_memory=low_memory, do_sample=do_sample, num_beams=num_beams,
                             num_beam_groups=num_beam_groups, max_new_tokens=max_new_tokens,
                             temperature=temperature, top_k=top_k, top_p=top_p,
                             repetition_penalty=repetition_penalty,
                             num_return_sequences=num_return_sequences)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
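# Illustrative direct call (the helper is normally invoked via prompt()/waka() below):
#   __generate(ku_gpt_tokenizer, ku_gpt_model, "こんにちは", True, 1, 1, 20, 0.7, 50, 0.95, 1.0, 1)
# returns a list of decoded continuations, one per requested sequence.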
def prompt(prompt: str = "こんにちは、",
           model_type: typing.Literal["Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"] = "Kyoto University GPT-2 (Modern)",
           do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, max_new_tokens: int = 100,
           temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95,
           repetition_penalty: float = 1.0, num_return_sequences: int = 1) -> HTML:
    model_name = model_name_map[model_type]
    if model_name == "ku-gpt2":
        tokenizer = ku_gpt_tokenizer
        model = ku_gpt_model
    elif model_name == "chj-gpt2":
        tokenizer = chj_gpt_tokenizer
        model = chj_gpt_model
    elif model_name == "wakagpt":
        tokenizer = wakagpt_tokenizer
        model = wakagpt_model
    else:
        raise NotImplementedError(f"Unsupported model: {model_name}")
    generated = __generate(tokenizer, model, prompt, do_sample, num_beams, num_beam_groups,
                           max_new_tokens, temperature, top_k, top_p, repetition_penalty,
                           num_return_sequences)
    return HTML("".join([f"<p>{i}</p>" for i in generated]))
def waka(preface: str = "", author: str = "",
         first_line: str = "あききぬと―めにはさやかに―みえねとも",
         type: typing.Literal["Kana", "Original", "Aligned"] = "Kana", remaining_lines: int = 2,
         do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1,
         temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95,
         repetition_penalty: float = 1.0, num_return_sequences: int = 1) -> HTML:
    waka_prompt = ""
    if preface:
        waka_prompt += "[詞書] " + preface + "\n"
    if author:
        waka_prompt += "[作者] " + author + "\n"
    token_counts = [5, 7, 5, 7, 7]  # kana counts of the five waka lines
    max_new_tokens = sum(token_counts[-remaining_lines:])
    first_line = first_line.strip()
    # add separators
    if type.lower() in ["kana", "aligned"]:
        if first_line == "":
            max_new_tokens += 4  # four "―" separators for a full five-line poem
        else:
            if first_line[-1] != "―":
                first_line += "―"
            max_new_tokens += remaining_lines - 1  # remaining separators
    waka_prompt += waka_type_map[type.lower()] + " " + first_line
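    # With the defaults (Kana form, three lines given, remaining_lines=2), waka_prompt is
    # "[仮名] あききぬと―めにはさやかに―みえねとも―" and max_new_tokens is 14 + 1 = 15
    # (7 + 7 kana for the last two lines plus one "―" separator).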
info = f""" | |
Prompt: {waka_prompt}<br> | |
Max New Tokens: {max_new_tokens}<br> | |
""" | |
generated = __generate(wakagpt_tokenizer, wakagpt_model, waka_prompt, do_sample, num_beams, num_beam_groups, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_return_sequences) | |
    removed = 0
    checked_generated = []
    if type.lower() == "kana":
        # Keep only completions that parse into five lines of exactly 5-7-5-7-7 kana.
        def check(seq):
            nonlocal removed
            poem = first_line + seq[len(waka_prompt) - 1:]
            parts = poem.split("―")
            if len(parts) == 5 and all(len(part) == token_counts[i] for i, part in enumerate(parts)):
                checked_generated.append(poem)
            else:
                removed += 1
        for i in generated:
            check(i)
    else:
        checked_generated = [first_line + i[len(waka_prompt) - 1:] for i in generated]
    generated = [f"<p>{i}</p>" for i in checked_generated]
    return info + f"Removed Malformed: {removed}<br>Results:<br>{''.join(generated)}"
if __name__ == "__main__":
    print(prompt("こんにちは", "Kyoto University GPT-2 (Modern)", num_beams=5, num_return_sequences=5))
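    # Illustrative: complete the remaining two lines of the default waka in kana form
    # (sampling is on, so output varies between runs).
    print(waka(num_return_sequences=3))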