jessicayjm
commited on
Commit
•
713cf03
1
Parent(s):
f798a31
update example
Browse files
README.md
CHANGED
@@ -63,8 +63,8 @@ del state_dict['model.embeddings.position_ids']
|
|
63 |
model.load_state_dict(state_dict)
|
64 |
|
65 |
# use the model
|
66 |
-
target = ['
|
67 |
-
observer = ['
|
68 |
|
69 |
target_encodings = tokenizer(target, padding=True, truncation=True)
|
70 |
target_input_ids = torch.LongTensor(target_encodings['input_ids']).to('cuda')
|
@@ -75,5 +75,5 @@ observer_attention_mask = torch.LongTensor(observer_encodings['attention_mask'])
|
|
75 |
|
76 |
model.eval()
|
77 |
output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
|
78 |
-
print(output) # [0.
|
79 |
```
|
|
|
63 |
model.load_state_dict(state_dict)
|
64 |
|
65 |
# use the model
|
66 |
+
target = ["I'm so sad that my cat died yesterday."]
|
67 |
+
observer = ["It's ok to feel sad."]
|
68 |
|
69 |
target_encodings = tokenizer(target, padding=True, truncation=True)
|
70 |
target_input_ids = torch.LongTensor(target_encodings['input_ids']).to('cuda')
|
|
|
75 |
|
76 |
model.eval()
|
77 |
output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
|
78 |
+
print(output) # [0.5755]
|
79 |
```
|