sproos commited on
Commit
c83bd3f
1 Parent(s): 3129965

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -2
handler.py CHANGED
@@ -62,14 +62,15 @@ class EndpointHandler:
62
 
63
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
64
  embedding = data.pop("embedding", None)
 
65
  max_length=200
66
  with torch.no_grad():
67
- outputs = self.model(ada_embedding=embedding, decoded_tkns=None)
68
  decoded_tkns = outputs.logits.argmax(dim=-1)
69
 
70
  for _ in range(max_length):
71
  with torch.no_grad():
72
- outputs = self.model(ada_embedding=embedding, decoded_tkns=decoded_tkns)
73
 
74
  # Get the most likely next token, sampled from top k
75
  logits = outputs.logits[:, -1]
 
62
 
63
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
64
  embedding = data.pop("embedding", None)
65
+ ada_embedding = torch.tensor(embedding).unsqueeze(0)
66
  max_length=200
67
  with torch.no_grad():
68
+ outputs = self.model(ada_embedding=ada_embedding, decoded_tkns=None)
69
  decoded_tkns = outputs.logits.argmax(dim=-1)
70
 
71
  for _ in range(max_length):
72
  with torch.no_grad():
73
+ outputs = self.model(ada_embedding=ada_embedding, decoded_tkns=decoded_tkns)
74
 
75
  # Get the most likely next token, sampled from top k
76
  logits = outputs.logits[:, -1]