# NOTE: "Spaces: Sleeping" is a Hugging Face Spaces status banner captured
# with the page, not part of the application code.
import streamlit as st
from transformers import (
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)
# Display name (sidebar label) -> Hugging Face Hub repository id,
# ordered from the smallest model to the largest.
model_dict = {
    "NanoTranslator-XS": "Mxode/NanoTranslator-XS",
    "NanoTranslator-S": "Mxode/NanoTranslator-S",
    "NanoTranslator-M": "Mxode/NanoTranslator-M",
    "NanoTranslator-M2": "Mxode/NanoTranslator-M2",
    "NanoTranslator-L": "Mxode/NanoTranslator-L",
    "NanoTranslator-XL": "Mxode/NanoTranslator-XL",
    "NanoTranslator-XXL": "Mxode/NanoTranslator-XXL",
    "NanoTranslator-XXL2": "Mxode/NanoTranslator-XXL2",
}
# initialize model
@st.cache_resource
def load_model(model_path: str):
    """Load the causal LM and its fast tokenizer from *model_path*.

    Cached with ``st.cache_resource`` so that Streamlit reruns (every widget
    interaction re-executes this script top to bottom) do not re-download and
    re-instantiate the model each time; a given ``model_path`` is loaded once
    per server process.

    Returns:
        (model, tokenizer) tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
    return model, tokenizer
def translate(text: str, model, tokenizer: PreTrainedTokenizerBase, **kwargs):
    """Generate a translation of *text* with the given model and tokenizer.

    Recognised sampling options (``max_new_tokens``, ``do_sample``,
    ``temperature``, ``top_p``, ``top_k``) may be overridden via keyword
    arguments; any remaining keyword arguments are forwarded to
    ``model.generate`` unchanged.
    """
    gen_kwargs = {
        "max_new_tokens": kwargs.pop("max_new_tokens", 64),
        "do_sample": kwargs.pop("do_sample", True),
        "temperature": kwargs.pop("temperature", 0.55),
        "top_p": kwargs.pop("top_p", 0.8),
        "top_k": kwargs.pop("top_k", 40),
    }
    gen_kwargs.update(kwargs)

    # Wrap the source text in the special tokens the model expects as a prompt.
    prompt = "<|im_start|>" + text + "<|endoftext|>"
    encoded = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(encoded.input_ids, **gen_kwargs)

    # Drop the echoed prompt tokens so only the newly generated tail is decoded.
    completions = []
    for prompt_ids, full_ids in zip(encoded.input_ids, outputs):
        completions.append(full_ids[len(prompt_ids):])
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
st.title("NanoTranslator Demo")

# --- Sidebar: model selection and sampling controls ---------------------
st.sidebar.title("Options")
model_names = list(model_dict.keys())  # compute once; reused for the default index
model_choice = st.sidebar.selectbox(
    "Model", model_names, index=model_names.index("NanoTranslator-XXL2")
)
do_sample = st.sidebar.checkbox("do_sample", value=True)
max_new_tokens = st.sidebar.slider(
    "max_new_tokens", min_value=1, max_value=256, value=64
)
temperature = st.sidebar.slider(
    "temperature", min_value=0.01, max_value=1.5, value=0.55, step=0.01
)
top_p = st.sidebar.slider("top_p", min_value=0.01, max_value=1.0, value=0.8, step=0.01)
top_k = st.sidebar.slider("top_k", min_value=1, max_value=100, value=40, step=1)

# Load (or fetch from cache) the model matching the sidebar selection.
model_path = model_dict[model_choice]
model, tokenizer = load_model(model_path)

input_text = st.text_area(
    "Please input the text to be translated (Currently supports only English to Chinese):",
    "Each step of the cell cycle is monitored by internal.",
)

if st.button("translate"):
    if input_text.strip():
        with st.spinner("Translating..."):
            translation = translate(
                input_text,
                model,
                tokenizer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
            st.success("Translated successfully!")
            st.write(translation)
    else:
        st.warning("Please input text before translation!")