Spaces:
Sleeping
Sleeping
File size: 805 Bytes
3857382 38f32e9 3857382 38f32e9 3857382 be0b5f6 3857382 38f32e9 3857382 590a47c 3857382 590a47c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import transformers
import torch
class InferencePipeline:
    """Thin wrapper around a Hugging Face text-generation pipeline.

    Builds the pipeline once from a config mapping and an access token,
    then exposes a single ``infer`` call that returns the generated
    reply text for a prompt.
    """

    def __init__(self, conf, api_key):
        # conf: nested mapping whose "model" section supplies model_name,
        # device_map and max_new_tokens. api_key: Hugging Face access token.
        self.conf = conf
        self.token = api_key
        self.pipeline = self.get_model()

    def get_model(self):
        """Construct and return the text-generation pipeline from config."""
        generator = transformers.pipeline(
            "text-generation",
            model=self.conf["model"]["model_name"],
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=self.conf["model"]["device_map"],
            token=self.token,
        )
        return generator

    def infer(self, prompt):
        """Generate a completion for *prompt* and return its text content.

        NOTE(review): assumes chat-style pipeline output where
        ``generated_text`` is a list of {"role", "content"} messages and
        the last entry is the model's reply — confirm for the configured
        model.
        """
        result = self.pipeline(
            prompt,
            max_new_tokens=self.conf["model"]["max_new_tokens"],
        )
        last_message = result[0]["generated_text"][-1]
        return last_message["content"]
|