yano0 commited on
Commit
eef3369
1 Parent(s): 9ae5537

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -5
README.md CHANGED
@@ -59,21 +59,30 @@ pip install -U sentence-transformers
59
  Then you can load this model and run inference.
60
  ```python
61
  from sentence_transformers import SentenceTransformer
 
62
 
63
  # Download from the 🤗 Hub
64
- model = SentenceTransformer("pkshatech/RoSEtta-base")
 
65
 
66
  # Don't forget to add the prefix "query: " for query-side or "passage: " for passage-side texts.
67
  sentences = [
68
- 'query: PKSHAはどんな会社ですか?'
69
- 'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。'
 
 
70
  ]
71
- embeddings = model.encode(sentences)
72
  print(embeddings.shape)
73
- # [2, 768]
74
 
75
  # Get the similarity scores for the embeddings
76
  similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
 
 
 
 
 
77
  ```
78
 
79
  <!--
 
59
  Then you can load this model and run inference.
60
  ```python
61
  from sentence_transformers import SentenceTransformer
62
+ import torch.nn.functional as F
63
 
64
  # Download from the 🤗 Hub
65
+ # The argument "trust_remote_code=True" is required to load the model
66
+ model = SentenceTransformer("pkshatech/RoSEtta-base-ja",trust_remote_code=True)
67
 
68
  # Don't forget to add the prefix "query: " for query-side or "passage: " for passage-side texts.
69
  sentences = [
70
+ 'query: PKSHAはどんな会社ですか?',
71
+ 'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
72
+ 'query: 日本で一番高い山は?',
73
+ 'passage: 富士山(ふじさん)は、標高3776.12 m、日本最高峰(剣ヶ峰)の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。',
74
  ]
75
+ embeddings = model.encode(sentences,convert_to_tensor=True)
76
  print(embeddings.shape)
77
+ # [4, 768]
78
 
79
  # Get the similarity scores for the embeddings
80
  similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
81
+ print(similarities)
82
+ # tensor([[1.0000, 0.5910, 0.4332, 0.5421],
83
+ # [0.5910, 1.0000, 0.4977, 0.6969],
84
+ # [0.4332, 0.4977, 1.0000, 0.7475],
85
+ # [0.5421, 0.6969, 0.7475, 1.0000]])
86
  ```
87
 
88
  <!--