class Obj:
    """Bundle a generation model with its tokenizer for text prediction.

    The model is moved to ``device`` once at construction time; every
    call to :meth:`predict` moves the encoded input to the same device.
    """

    def __init__(self, model, tokenizer, device="cpu"):
        """Store the tokenizer and device, and place the model on ``device``.

        Args:
            model: Object exposing ``to(device)`` and ``generate(...)``.
            tokenizer: Object exposing ``encode(...)`` and ``decode(...)``.
            device: Device identifier handed to ``model.to`` and to the
                encoded input's ``to``. Defaults to ``"cpu"``.
        """
        self.tokenizer = tokenizer
        self.device = device
        # Keep whatever ``to`` returns rather than assuming an in-place move.
        self.model = model.to(device)

    def predict(
        self,
        source_text: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """Generate output sequences for ``source_text`` and decode them.

        Args:
            source_text: Text to encode and feed to the model.
            max_length: Forwarded to ``model.generate``.
            num_return_sequences: Forwarded to ``model.generate``.
            num_beams: Forwarded to ``model.generate``.
            top_k: Forwarded to ``model.generate``.
            top_p: Forwarded to ``model.generate``.
            do_sample: Forwarded to ``model.generate``.
            repetition_penalty: Forwarded to ``model.generate``.
            length_penalty: Forwarded to ``model.generate``.
            early_stopping: Forwarded to ``model.generate``.
            skip_special_tokens: Forwarded to ``tokenizer.decode``.
            clean_up_tokenization_spaces: Forwarded to ``tokenizer.decode``.

        Returns:
            A list with one ``tokenizer.decode`` result per sequence
            yielded by ``model.generate``.
        """
        # ``return_tensors="pt"`` requests framework tensors from the
        # tokenizer (presumably PyTorch -- confirm against the tokenizer
        # in use); they must live on the same device as the model.
        input_ids = self.tokenizer.encode(
            source_text,
            return_tensors="pt",
            add_special_tokens=True,
        ).to(self.device)

        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
            do_sample=do_sample,
        )

        return [
            self.tokenizer.decode(
                sequence,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )
            for sequence in generated_ids
        ]