Cross-Encoder for Natural Language Inference
This model was trained using the SentenceTransformers CrossEncoder class.
Training Data
The model was trained on the SNLI and MultiNLI datasets. For a given sentence pair, it outputs three scores, corresponding to the labels contradiction, entailment, and neutral (in that order).
Performance
For evaluation results, see SBERT.net - Pretrained Cross-Encoder.
Usage
Pre-trained models can be used like this:
```python
from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/nli-roberta-base')
scores = model.predict([
    ('A man is eating pizza', 'A man eats something'),
    ('A black race car starts up in front of a crowd of people.', 'A man is driving down a lonely road.'),
])
# Convert scores to labels
label_mapping = ['contradiction', 'entailment', 'neutral']
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
```
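The label is simply the argmax of each row of scores. Assuming `predict` returns raw (unnormalized) scores in the label order given above, one way to also get per-label probabilities is to apply a softmax to each row. The sketch below does this in plain Python; the helper names `softmax` and `scores_to_predictions` and the example scores are hypothetical, not part of the SentenceTransformers API or real model output.

```python
import math

# Label order of the model's outputs, as given by label_mapping in the usage example.
LABEL_MAPPING = ["contradiction", "entailment", "neutral"]


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scores_to_predictions(scores):
    """Map each row of raw scores to (predicted label, {label: probability})."""
    predictions = []
    for row in scores:
        probs = softmax(list(row))  # list() also accepts a NumPy row
        best = max(range(len(probs)), key=probs.__getitem__)
        predictions.append((LABEL_MAPPING[best], dict(zip(LABEL_MAPPING, probs))))
    return predictions


# Hypothetical raw scores for two sentence pairs (made up for illustration).
example_scores = [[-2.1, 3.4, 0.2], [4.0, -3.5, -0.8]]
for label, probs in scores_to_predictions(example_scores):
    print(label, {k: round(v, 3) for k, v in probs.items()})
```

Because softmax is monotonic, the predicted label is the same as with `argmax` on the raw scores; the probabilities only add a normalized view of the model's confidence.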
Usage with Transformers AutoModel
You can also use the model directly with the Transformers library (without the SentenceTransformers library):
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/nli-roberta-base')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/nli-roberta-base')

# Premises go in the first list, hypotheses in the second; each index forms one sentence pair
features = tokenizer(['A man is eating pizza', 'A black race car starts up in front of a crowd of people.'], ['A man eats something', 'A man is driving down a lonely road.'], padding=True, truncation=True, return_tensors="pt")

model.eval()
with torch.no_grad():
    scores = model(**features).logits

label_mapping = ['contradiction', 'entailment', 'neutral']
labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).tolist()]
print(labels)
```
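The `CrossEncoder` example passes a list of `(premise, hypothesis)` tuples, while the tokenizer call takes the premises and the hypotheses as two parallel lists. If your data is already stored as pairs, a small sketch of converting between the two formats is shown below; `split_pairs` is a hypothetical helper, not part of either library.

```python
# Pairs in the format used by CrossEncoder.predict: (premise, hypothesis).
pairs = [
    ("A man is eating pizza", "A man eats something"),
    ("A black race car starts up in front of a crowd of people.",
     "A man is driving down a lonely road."),
]


def split_pairs(pairs):
    """Split a list of (premise, hypothesis) tuples into two parallel lists."""
    if not pairs:
        return [], []
    premises, hypotheses = zip(*pairs)
    return list(premises), list(hypotheses)


premises, hypotheses = split_pairs(pairs)
# These lists can be passed as the first and second arguments to the tokenizer,
# e.g. tokenizer(premises, hypotheses, padding=True, truncation=True, return_tensors="pt")
print(premises)
print(hypotheses)
```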
Zero-Shot Classification
This model can also be used for zero-shot classification:
```python
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-roberta-base')
sent = "Apple just announced the newest iPhone X"
candidate_labels = ["technology", "sports", "politics"]
res = classifier(sent, candidate_labels)
print(res)
```
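Roughly speaking, the zero-shot pipeline turns each candidate label into a hypothesis (by default with the template "This example is {}."), scores every (sentence, hypothesis) pair with the NLI model, and compares the entailment scores. The sketch below reproduces only the final scoring step in plain Python for the single-label case, where the entailment logits are softmaxed across the candidate labels and the labels are sorted by score. The logits are made up for illustration, and `build_hypotheses` and `rank_labels` are hypothetical helpers, not part of the `transformers` API.

```python
import math

HYPOTHESIS_TEMPLATE = "This example is {}."  # assumed to match the pipeline's default


def build_hypotheses(labels, template=HYPOTHESIS_TEMPLATE):
    """Turn each candidate label into an NLI hypothesis."""
    return [template.format(label) for label in labels]


def rank_labels(labels, entailment_logits):
    """Single-label scoring: softmax the entailment logits across labels, sort descending."""
    m = max(entailment_logits)
    exps = [math.exp(x - m) for x in entailment_logits]
    total = sum(exps)
    scored = sorted(zip(labels, (e / total for e in exps)), key=lambda p: p[1], reverse=True)
    return {"labels": [label for label, _ in scored], "scores": [s for _, s in scored]}


candidate_labels = ["technology", "sports", "politics"]
print(build_hypotheses(candidate_labels))
# Hypothetical entailment logits for each (sentence, hypothesis) pair; not real model output.
result = rank_labels(candidate_labels, [3.2, -1.5, -0.7])
print(result)
```

Because the scores are normalized across labels, they always sum to 1 in this mode, which suits tasks where exactly one label applies.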