ClaudiaIoana550 commited on
Commit
bc11d4b
1 Parent(s): a47fb80

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +46 -0
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Any, Dict, List
3
+ from langchain.llms import HuggingFacePipeline
4
+
5
+ import torch
6
+ import transformers
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
10
+
11
+ class EndpointHandler:
12
+ def __init__(self, model_path=""):
13
+ tokenizer=AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ model_path,
16
+ return_dict=True,
17
+ device_map="auto",
18
+ torch_dtype = dtype,
19
+ trust_remote_code=True
20
+ )
21
+
22
+ generation_config = model.generation_config
23
+ generation_config.max_new_tokens = 1700
24
+ generation_config.min_length = 20
25
+ generation_config.temperature = 1
26
+ generation_config.top_p = 0.7
27
+ generation_config.num_return_sequences = 1
28
+ generation_config.pad_token_id = tokenizer.eos_token_id
29
+ generation_config.eos_token_id = tokenizer.eos_token_id
30
+ generation_config.repetition_penalty = 1.1
31
+
32
+ gpipeline = transformers.pipeline(
33
+ model=model,
34
+ tokenizer=tokenizer,
35
+ return_full_text=True,
36
+ task="text-generation",
37
+ stopping_criteria=stopping_criteria,
38
+ generation_config=generation_config
39
+ )
40
+
41
+ self.llm = HuggingFacePipeline(pipeline=gpipeline)
42
+
43
+ def __call__(self, data:Dict[str, Any]) -> Dict[str, Any]:
44
+ prompt = data.pop("inputs", data)
45
+ result = self.llm(prompt)
46
+ return result