syf2023 commited on
Commit
49d2066
1 Parent(s): ddbc45a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +30 -0
handler.py CHANGED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline
3
+ import holidays
4
+
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path=""):
8
+ self.pipeline = pipeline("text-classification", model=path)
9
+ self.holidays = holidays.US()
10
+
11
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
+ """
13
+ data args:
14
+ inputs (:obj: `str`)
15
+ date (:obj: `str`)
16
+ Return:
17
+ A :obj:`list` | `dict`: will be serialized and returned
18
+ """
19
+ # get inputs
20
+ inputs = data.pop("inputs", data)
21
+ # get additional date field
22
+ date = data.pop("date", None)
23
+
24
+ # check if date exists and if it is a holiday
25
+ if date is not None and date in self.holidays:
26
+ return [{"label": "happy", "score": 1}]
27
+
28
+ # run normal prediction
29
+ prediction = self.pipeline(inputs)
30
+ return prediction