Update handler.py
Browse files
- handler.py +3 -2
handler.py
CHANGED
@@ -62,14 +62,15 @@ class EndpointHandler:
|
|
62 |
|
63 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
64 |
embedding = data.pop("embedding", None)
|
|
|
65 |
max_length=200
|
66 |
with torch.no_grad():
|
67 |
-
outputs = self.model(ada_embedding=
|
68 |
decoded_tkns = outputs.logits.argmax(dim=-1)
|
69 |
|
70 |
for _ in range(max_length):
|
71 |
with torch.no_grad():
|
72 |
-
outputs = self.model(ada_embedding=
|
73 |
|
74 |
# Get the most likely next token, sampled from top k
|
75 |
logits = outputs.logits[:, -1]
|
|
|
62 |
|
63 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
64 |
embedding = data.pop("embedding", None)
|
65 |
+
ada_embedding = torch.tensor(embedding).unsqueeze(0)
|
66 |
max_length=200
|
67 |
with torch.no_grad():
|
68 |
+
outputs = self.model(ada_embedding=ada_embedding, decoded_tkns=None)
|
69 |
decoded_tkns = outputs.logits.argmax(dim=-1)
|
70 |
|
71 |
for _ in range(max_length):
|
72 |
with torch.no_grad():
|
73 |
+
outputs = self.model(ada_embedding=ada_embedding, decoded_tkns=decoded_tkns)
|
74 |
|
75 |
# Get the most likely next token, sampled from top k
|
76 |
logits = outputs.logits[:, -1]
|