fletch1300 committed on
Commit be62e65
1 Parent(s): 9c3fdba

Update handler.py

Files changed (1)
  1. handler.py +7 -26
handler.py CHANGED
@@ -4,7 +4,6 @@ import torch
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-
 dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
 
 class EndpointHandler:
@@ -18,45 +17,27 @@ class EndpointHandler:
             torch_dtype=dtype,
             trust_remote_code=True,
         )
-
+
         generation_config = self.model.generation_config
         generation_config.max_new_tokens = 200
-        generation_config.temperature = 0.8
+        generation_config.temperature = 0.4
         generation_config.top_p = 0.8
         generation_config.num_return_sequences = 1
         generation_config.pad_token_id = self.tokenizer.eos_token_id
         generation_config.eos_token_id = self.tokenizer.eos_token_id
-        generation_config.early_stopping = True
         self.generate_config = generation_config
 
         self.pipeline = transformers.pipeline(
             "text-generation", model=self.model, tokenizer=self.tokenizer
         )
 
-
-    def _ensure_token_limit(self, text):
-        """Ensure text is within the model's token limit."""
-        tokens = self.tokenizer.tokenize(text)
-        if len(tokens) > 2048:
-            # Remove tokens from the beginning until the text fits
-            tokens = tokens[-2048:]
-            return self.tokenizer.decode(tokens)
-        return text
-
-
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         user_prompt = data.pop("inputs", data)
 
-        # Permanent context
-        permanent_context = "<context>: You are a life coaching bot..."
-        structured_prompt = f"{permanent_context}\<bot> response:"
-
-        result = self.pipeline(structured_prompt, generation_config=self.generate_config)
+        # Add the permanent context to the user's prompt
+        permanent_context = "<context>: You are a life coaching bot with the goal of improving understanding, reducing suffering and improving life. Learn about the user in order to provide guidance without making assumptions or adding information not provided by the user."
+        combined_prompt = f"{permanent_context}\n<human>: {user_prompt}"
 
-        # Ensure _extract_response is defined and works as intended
-        response_text = self._extract_response(result[0]['generated_text'])
+        result = self.pipeline(combined_prompt, generation_config=self.generate_config)
 
-        # Trimming response
-        response_text = response_text.rsplit("[END", 1)[0].strip()
-
-        return {"response": response_text}
+        return result
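Note that after this change __call__ returns the raw output of the text-generation pipeline (a list of dicts with a generated_text key) instead of the old {"response": ...} dict, so the caller extracts the text itself. Below is a minimal local sketch of how the updated handler might be exercised; it assumes the conventional Inference Endpoints __init__(self, path="") signature, which this diff does not show, and the payload text and print step are illustrative only.

# Hypothetical smoke test for the updated EndpointHandler (not part of this commit).
from handler import EndpointHandler

# Assumes the handler follows the usual Inference Endpoints convention of
# receiving the model directory path in __init__; that part is not shown in the diff.
handler = EndpointHandler(path=".")

payload = {"inputs": "I feel stuck in my career. Where should I start?"}
result = handler(payload)

# The text-generation pipeline returns a list like [{"generated_text": "..."}],
# with the combined prompt echoed at the start of generated_text by default.
print(result[0]["generated_text"])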