from dataclasses import dataclass
import sys


@dataclass
class XftConfig:
    """Generation and runtime settings for an xFasterTransformer model."""
    max_seq_len: int = 4096
    beam_width: int = 1
    eos_token_id: int = -1
    pad_token_id: int = -1
    num_return_sequences: int = 1
    is_encoder_decoder: bool = False
    padding: bool = True
    early_stopping: bool = False
    data_type: str = "bf16_fp16"  # dtype string passed to xFasterTransformer (hybrid BF16/FP16 by default)


class XftModel:
    """Thin wrapper pairing a loaded xFasterTransformer model with its XftConfig."""

    def __init__(self, xft_model, xft_config):
        self.model = xft_model
        self.config = xft_config


def load_xft_model(model_path, xft_config: XftConfig):
    """Load an xFasterTransformer model and its tokenizer from model_path."""
    try:
        import xfastertransformer
        from transformers import AutoTokenizer
    except ImportError as e:
        print(f"Error: Failed to load xFasterTransformer. {e}")
        sys.exit(-1)

    # Fall back to the default hybrid dtype if none was specified.
    if xft_config.data_type is None or xft_config.data_type == "":
        data_type = "bf16_fp16"
    else:
        data_type = xft_config.data_type
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, padding_side="left", trust_remote_code=True
    )
    xft_model = xfastertransformer.AutoModel.from_pretrained(
        model_path, dtype=data_type
    )
    model = XftModel(xft_model=xft_model, xft_config=xft_config)
    # In a multi-rank (distributed) run, only rank 0 drives generation; the other
    # ranks stay in this loop serving generate() calls and never return.
    if model.model.rank > 0:
        while True:
            model.model.generate()
    return model, tokenizer
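

# A minimal usage sketch, not part of the original module: the model directory below is a
# hypothetical placeholder and must contain weights already converted to the
# xFasterTransformer format. The generate() call and its max_length argument are assumed
# to follow xFT's HF-style API; on a multi-rank run only rank 0 reaches this point
# (see load_xft_model above).
if __name__ == "__main__":
    config = XftConfig(data_type="bf16_fp16")
    model, tokenizer = load_xft_model("/path/to/xft-converted-model", config)
    input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
    output_ids = model.model.generate(input_ids, max_length=config.max_seq_len)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))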