Update README.md
README.md CHANGED
````diff
@@ -107,10 +107,7 @@ contexts = [
 
 ## convert query into a format as follows:
 ## user: {user}\nagent: {agent}\nuser: {user}
-formatted_query = ""
-for turn in query:
-    formatted_query += turn['role'] + ": " + turn['content'] + "\n"
-formatted_query = formatted_query.strip()
+formatted_query = '\n'.join([turn['role'] + ": " + turn['content'] for turn in messages]).strip()
 
 ## get query and context embeddings
 query_input = tokenizer(formatted_query, return_tensors='pt')
@@ -118,9 +115,11 @@ ctx_input = tokenizer(contexts, padding=True, return_tensors='pt')
 query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
 ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]
 
-
-
-
+## Compute similarity scores using dot product
+similarities = query_emb.matmul(ctx_emb.transpose(0, 1)) # (1, num_ctx)
+
+## rank the similarity (from highest to lowest)
+ranked_results = torch.argsort(similarities, dim=-1, descending=True) # (1, num_ctx)
 ```
 
 ## License
````
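For reference, the updated snippet can be exercised end to end roughly as follows. This is a minimal sketch, not part of the commit: the model paths and the example `messages`/`contexts` are placeholders, and the encoders are assumed to be a standard Hugging Face query/context bi-encoder pair whose first-token (CLS) embedding is used as the representation.

```python
import torch
from transformers import AutoTokenizer, AutoModel

## placeholder checkpoints; substitute the actual query/context encoder paths
tokenizer = AutoTokenizer.from_pretrained('path/to/query-encoder')
query_encoder = AutoModel.from_pretrained('path/to/query-encoder')
context_encoder = AutoModel.from_pretrained('path/to/context-encoder')

## example multi-turn query and candidate contexts (illustrative only)
messages = [
    {'role': 'user', 'content': 'How are relevant passages found?'},
    {'role': 'agent', 'content': 'They are retrieved with a query encoder and a context encoder.'},
    {'role': 'user', 'content': 'And how are they ranked?'}
]
contexts = [
    'Passages are ranked by the similarity between the query and context embeddings.',
    'An unrelated passage about something else entirely.'
]

## convert query into a format as follows:
## user: {user}\nagent: {agent}\nuser: {user}
formatted_query = '\n'.join([turn['role'] + ": " + turn['content'] for turn in messages]).strip()

## get query and context embeddings (first-token / CLS representation)
query_input = tokenizer(formatted_query, return_tensors='pt')
ctx_input = tokenizer(contexts, padding=True, return_tensors='pt')
query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]

## Compute similarity scores using dot product
similarities = query_emb.matmul(ctx_emb.transpose(0, 1))  # (1, num_ctx)

## rank the similarity (from highest to lowest)
ranked_results = torch.argsort(similarities, dim=-1, descending=True)  # (1, num_ctx)
print(ranked_results)
```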