svjack's picture
Upload with huggingface_hub
f7a83c6
raw
history blame
1.55 kB
class Obj:
    """Bundle a generative model with its tokenizer for text-to-text prediction.

    The wrapper only relies on duck-typed behavior:

    * ``model`` must provide ``.to(device)`` and ``.generate(**kwargs)``.
    * ``tokenizer`` must provide ``.encode(...)`` and ``.decode(...)``.

    NOTE(review): the keyword names used below (``return_tensors="pt"``,
    ``num_beams``, ``repetition_penalty`` ...) match the Hugging Face
    ``transformers`` API, so that is presumably the intended backend --
    confirm against the caller.
    """

    def __init__(self, model, tokenizer, device="cpu"):
        """Store the tokenizer and device, and move ``model`` onto ``device``.

        Args:
            model: Object exposing ``.to(device)`` and ``.generate(...)``.
            tokenizer: Object exposing ``.encode(...)`` and ``.decode(...)``.
            device: Target passed verbatim to ``model.to`` and, later, to the
                encoded input's ``.to``. Defaults to ``"cpu"``.
        """
        self.tokenizer = tokenizer
        self.device = device
        # Keep whatever ``.to`` returns rather than assuming it works in place.
        self.model = model.to(self.device)

    def predict(
        self,
        source_text: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ) -> list:
        """Generate output text(s) for ``source_text``.

        The text is encoded with the tokenizer, moved to ``self.device``,
        passed to ``self.model.generate`` and every returned sequence is
        decoded back to a string.

        Args:
            source_text: Input text handed to ``tokenizer.encode``.
            max_length: Forwarded to ``generate`` as ``max_length``.
            num_return_sequences: Forwarded to ``generate``.
            num_beams: Forwarded to ``generate``.
            top_k: Forwarded to ``generate`` only when ``do_sample`` is true.
            top_p: Forwarded to ``generate`` only when ``do_sample`` is true.
            do_sample: Forwarded to ``generate``; also gates ``top_k``/``top_p``.
            repetition_penalty: Forwarded to ``generate``.
            length_penalty: Forwarded to ``generate``.
            early_stopping: Forwarded to ``generate``.
            skip_special_tokens: Forwarded to ``tokenizer.decode``.
            clean_up_tokenization_spaces: Forwarded to ``tokenizer.decode``.

        Returns:
            A list with one decoded entry per sequence yielded by
            ``model.generate``.
        """
        input_ids = self.tokenizer.encode(
            source_text, return_tensors="pt", add_special_tokens=True
        ).to(self.device)

        generate_kwargs = {
            "input_ids": input_ids,
            "num_beams": num_beams,
            "max_length": max_length,
            "repetition_penalty": repetition_penalty,
            "length_penalty": length_penalty,
            "early_stopping": early_stopping,
            "num_return_sequences": num_return_sequences,
            "do_sample": do_sample,
        }
        if do_sample:
            # top_k / top_p only influence sampling-based decoding. Passing
            # them alongside do_sample=False is flagged by recent transformers
            # releases, so they are omitted in the non-sampling case.
            generate_kwargs["top_k"] = top_k
            generate_kwargs["top_p"] = top_p

        generated_ids = self.model.generate(**generate_kwargs)

        return [
            self.tokenizer.decode(
                sequence,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )
            for sequence in generated_ids
        ]