tykiww commited on
Commit
38f32e9
1 Parent(s): aaea170

Update connections/model.py

Browse files
Files changed (1) hide show
  1. connections/model.py +3 -3
connections/model.py CHANGED
@@ -13,9 +13,9 @@ class InferencePipeline:
13
 
14
  pipeline = transformers.pipeline(
15
  "text-generation",
16
- model=conf["model"]["model_name"],
17
  model_kwargs={"torch_dtype": torch.bfloat16},
18
- device_map=conf["model"]["device_map"],
19
  token=self.token
20
  )
21
 
@@ -25,7 +25,7 @@ class InferencePipeline:
25
 
26
  outputs = pipeline(
27
  prompt,
28
- max_new_tokens=conf["model"]["max_new_tokens"],
29
  )
30
 
31
  return outputs[0]["generated_text"][-1]
 
13
 
14
  pipeline = transformers.pipeline(
15
  "text-generation",
16
+ model=self.conf["model"]["model_name"],
17
  model_kwargs={"torch_dtype": torch.bfloat16},
18
+ device_map=self.conf["model"]["device_map"],
19
  token=self.token
20
  )
21
 
 
25
 
26
  outputs = pipeline(
27
  prompt,
28
+ max_new_tokens=self.conf["model"]["max_new_tokens"],
29
  )
30
 
31
  return outputs[0]["generated_text"][-1]