import torch
import transformers


class InferencePipeline:
    """Wraps a Hugging Face text-generation pipeline configured from a dict."""

    def __init__(self, conf, api_key):
        self.conf = conf
        self.token = api_key
        self.pipeline = self.get_model()

    def get_model(self):
        # Build a text-generation pipeline; bfloat16 halves memory use
        # relative to float32 on supported hardware.
        pipeline = transformers.pipeline(
            "text-generation",
            model=self.conf["model"]["model_name"],
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=self.conf["model"]["device_map"],
            token=self.token,
        )
        return pipeline

    def infer(self, prompt):
        outputs = self.pipeline(
            prompt,
            max_new_tokens=self.conf["model"]["max_new_tokens"],
        )
        # With a chat-style prompt (a list of role/content messages), the
        # pipeline returns the full conversation under "generated_text";
        # take the last message (the assistant's reply) and return its text.
        outputs = outputs[0]["generated_text"][-1]
        return outputs["content"]
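

# A minimal usage sketch, assuming a conf dict shaped like the keys read
# above; the model ID, token value, and prompt are placeholder assumptions,
# not values from the original code.
if __name__ == "__main__":
    conf = {
        "model": {
            "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical
            "device_map": "auto",
            "max_new_tokens": 256,
        }
    }
    pipe = InferencePipeline(conf, api_key="hf_...")  # hypothetical token

    # infer() indexes generated_text[-1]["content"], which matches the
    # chat-message output format, so pass a list of role/content messages.
    messages = [{"role": "user", "content": "Explain bfloat16 in one sentence."}]
    print(pipe.infer(messages))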