# NOTE(review): page-scrape artifact — original capture included the
# Hugging Face Spaces status banner ("Spaces: Sleeping"); kept as a comment
# so the file parses.
import torch
import transformers
class InferencePipeline:
    """Thin wrapper around a Hugging Face ``text-generation`` pipeline.

    Reads model settings from a nested config dict and exposes a single
    ``infer`` method that returns the generated assistant text.
    """

    def __init__(self, conf, api_key):
        """Store config and auth token, then eagerly build the pipeline.

        Args:
            conf: Nested config dict; this class reads
                ``conf["model"]["model_name"]``, ``conf["model"]["device_map"]``
                and ``conf["model"]["max_new_tokens"]``.
            api_key: Hugging Face Hub access token used to download the model.
        """
        self.conf = conf
        self.token = api_key
        # Model download/load happens here, at construction time.
        self.pipeline = self.get_model()

    def get_model(self):
        """Build and return the transformers text-generation pipeline.

        Loads the model in bfloat16 and places it per the configured
        ``device_map``.
        """
        pipeline = transformers.pipeline(
            "text-generation",
            model=self.conf["model"]["model_name"],
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=self.conf["model"]["device_map"],
            token=self.token,
        )
        return pipeline

    def infer(self, prompt):
        """Run generation on *prompt* and return the assistant's text.

        Args:
            prompt: Input passed straight to the pipeline. The extraction
                below expects the chat-messages calling convention (a list of
                role/content dicts), where ``generated_text`` comes back as a
                message list.

        Returns:
            The ``content`` string of the final generated message.
        """
        outputs = self.pipeline(
            prompt,
            max_new_tokens=self.conf["model"]["max_new_tokens"],
        )
        # NOTE(review): assumes chat-style output — ``generated_text`` is a
        # list of message dicts and ``[-1]`` is the newly generated assistant
        # message. With a plain-string prompt this would instead index the
        # last *character*; confirm callers pass chat messages.
        last_message = outputs[0]["generated_text"][-1]
        return last_message['content']