zetavg
fix
ae567a3 unverified
raw
history blame
3.26 kB
import torch
import transformers
from .get_device import get_device
from .streaming_generation_utils import Iteratorize, Stream
def generate(
    # model
    model,
    tokenizer,
    # input
    prompt,
    generation_config,
    max_new_tokens,
    stopping_criteria=None,
    # output options
    stream_output=False
):
    """Run ``model.generate`` on ``prompt`` and yield the decoded result(s).

    This is a generator function: nothing runs until it is iterated.

    Args:
        model: Object whose ``generate(**kwargs)`` is called with the
            parameters assembled below; the result must expose
            ``.sequences``.
        tokenizer: Called as ``tokenizer(prompt, return_tensors="pt")`` and
            used via ``encode``, ``decode`` and ``name_or_path``.
        prompt: Text passed to the tokenizer.
        generation_config: Forwarded to ``model.generate``. For tokenizers
            whose ``name_or_path`` contains ``'/dolly'`` its
            ``eos_token_id`` attribute is replaced (on this same object)
            with a list that also contains the "### End" token id.
        max_new_tokens: Forwarded to ``model.generate``.
        stopping_criteria: Optional iterable of extra stopping criteria.
            ``None`` (the default) means no extra criteria.
        stream_output: When true, intermediate outputs are yielded as they
            are produced, followed by the final output.

    Yields:
        Tuples ``(decoded_output, output, is_final)``. ``is_final`` is
        ``False`` for intermediate streaming results and ``True`` for the
        last item, which is built from ``generation_output.sequences[0]``.
    """
    device = get_device()

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    generate_params = {
        "input_ids": input_ids,
        "generation_config": generation_config,
        "return_dict_in_generate": True,
        "output_scores": True,
        "max_new_tokens": max_new_tokens,
        # Always build a fresh StoppingCriteriaList so the caller's list is
        # never mutated by the streaming hook inserted further down.
        "stopping_criteria": transformers.StoppingCriteriaList(
            stopping_criteria or []
        ),
    }

    skip_special_tokens = True

    if '/dolly' in tokenizer.name_or_path:
        # dolly has additional_special_tokens as ['### End', '### Instruction:', '### Response:'], skipping them will break the prompter's reply extraction.
        skip_special_tokens = False
        # Ensure generation stops once it generates "### End"
        end_key_token_id = tokenizer.encode("### End")
        end_key_token_id = end_key_token_id[0]  # 50277

        # eos_token_id may be unset, a single id, or a sequence of ids.
        # Normalize to a new list so that a single id is not `.append`-ed to
        # (which would raise), id 0 is not mistaken for "unset", and repeated
        # calls do not keep growing a list shared with the caller.
        current_eos = generation_config.eos_token_id
        if current_eos is None:
            eos_token_ids = []
        elif isinstance(current_eos, (list, tuple)):
            eos_token_ids = list(current_eos)
        else:
            eos_token_ids = [current_eos]
        if end_key_token_id not in eos_token_ids:
            eos_token_ids.append(end_key_token_id)
        generation_config.eos_token_id = eos_token_ids

    if stream_output:
        # Stream the reply 1 token at a time.
        # This is based on the trick of using 'stopping_criteria' to create an iterator,
        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
        generation_output = None

        def generate_with_callback(callback=None, **kwargs):
            nonlocal generation_output
            # The Stream criterion goes first; it is handed the callback that
            # feeds the iterator consumed below.
            kwargs["stopping_criteria"].insert(
                0,
                Stream(callback_func=callback)
            )
            with torch.no_grad():
                generation_output = model.generate(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(
                generate_with_callback, kwargs, callback=None
            )

        with generate_with_streaming(**generate_params) as generator:
            for output in generator:
                decoded_output = tokenizer.decode(
                    output, skip_special_tokens=skip_special_tokens
                )
                yield decoded_output, output, False

        # generation_output stays None if model.generate never returned.
        if generation_output is not None:
            output = generation_output.sequences[0]
            decoded_output = tokenizer.decode(
                output, skip_special_tokens=skip_special_tokens
            )
            yield decoded_output, output, True

        return  # early return for stream_output

    # Without streaming
    with torch.no_grad():
        generation_output = model.generate(**generate_params)
    output = generation_output.sequences[0]
    decoded_output = tokenizer.decode(
        output, skip_special_tokens=skip_special_tokens
    )
    yield decoded_output, output, True