# connections/model.py
# Provenance (from Hugging Face file viewer): uploaded by tykiww,
# commit 590a47c (verified), 805 bytes.
import transformers
import torch
class InferencePipeline:
    """Thin wrapper around a Hugging Face text-generation pipeline.

    Builds the pipeline once at construction time from the supplied
    config and auth token, then exposes a single ``infer`` entry point.
    """

    def __init__(self, conf, api_key):
        # conf must hold a "model" section with "model_name",
        # "device_map" and "max_new_tokens" keys (read by get_model /
        # infer) -- TODO confirm the full schema with the caller.
        self.conf = conf
        self.token = api_key
        # Build the pipeline eagerly so model loading happens once,
        # up front, rather than on the first infer() call.
        self.pipeline = self.get_model()

    def get_model(self):
        """Construct and return the transformers text-generation pipeline."""
        pipeline = transformers.pipeline(
            "text-generation",
            model=self.conf["model"]["model_name"],
            # bfloat16 halves memory vs float32 while keeping its exponent range.
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=self.conf["model"]["device_map"],
            token=self.token,
        )
        return pipeline

    def infer(self, prompt, max_new_tokens=None):
        """Generate a completion for *prompt* and return the reply text.

        Args:
            prompt: Input forwarded unchanged to the pipeline (the
                indexing below assumes chat-style message output --
                TODO confirm prompt format with callers).
            max_new_tokens: Optional per-call override; defaults to
                conf["model"]["max_new_tokens"].

        Returns:
            The "content" field of the last generated message.
        """
        if max_new_tokens is None:
            max_new_tokens = self.conf["model"]["max_new_tokens"]
        outputs = self.pipeline(
            prompt,
            max_new_tokens=max_new_tokens,
        )
        # Pipeline output shape: [{"generated_text": [...messages...]}];
        # the last message is the model's reply dict with a "content" key.
        return outputs[0]["generated_text"][-1]["content"]