---
license: mit
---

**NLI-Mixer** is an attempt to tackle the Natural Language Inference (NLI) task by mixing multiple datasets together.

The approach is simple:

1. Combine all available NLI data without any domain-dependent re-balancing or re-weighting.
2. Fine-tune several state-of-the-art (SOTA) transformers of different sizes (20M to 300M parameters) on the combined data.
3. Evaluate on challenging NLI datasets.
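
Step 1 amounts to plain concatenation followed by a single shuffle. The sketch below is a minimal illustration under stated assumptions, not the author's actual pipeline: the `mix_datasets` helper, the dataset names, and the `(premise, hypothesis, label)` tuple layout are all hypothetical.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical example layout: (premise, hypothesis, label).
Example = Tuple[str, str, str]


def mix_datasets(datasets: Dict[str, List[Example]], seed: int = 0) -> List[Example]:
    """Concatenate every dataset as-is, then shuffle once.

    No per-dataset sampling weights, up-sampling, or down-sampling are applied,
    so each dataset contributes in proportion to its size.
    """
    mixed = [example for examples in datasets.values() for example in examples]
    random.Random(seed).shuffle(mixed)  # seeded for a reproducible order
    return mixed


# Hypothetical toy datasets standing in for real NLI corpora.
toy = {
    "dataset_a": [
        ("A man is sleeping.", "A person rests.", "entailment"),
        ("A dog runs.", "A cat sleeps.", "contradiction"),
    ],
    "dataset_b": [("It is raining.", "The ground is wet.", "neutral")],
}
mixed = mix_datasets(toy)
print(len(mixed))  # 3
```

Because nothing is re-weighted, a corpus with ten times more pairs simply contributes ten times more training examples.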

This model was trained using the [SentenceTransformers](https://sbert.net) [Cross-Encoder](https://www.sbert.net/examples/applications/cross-encoder/README.html) class and is based on [microsoft/deberta-v3-base](https://huggingface.co/microsoft/deberta-v3-base).

### Data

Over 20 NLI datasets were combined to train a binary classification model: the `contradiction` and `neutral` labels were merged into a single `non-entailment` class. The model outputs one logit per (premise, hypothesis) pair, and a higher score means `entailment` is more likely.
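
The label merge can be sketched as a lookup table. This assumes string labels and a 1 = `entailment`, 0 = `non-entailment` encoding; both are illustrative assumptions (real datasets may use integer ids or other label names), and `to_binary` is a hypothetical helper.

```python
from typing import Iterable, List, Tuple

# Assumed three-way label names mapped onto the binary scheme:
# 1 = entailment, 0 = non-entailment (neutral and contradiction merged).
TO_BINARY = {"entailment": 1, "neutral": 0, "contradiction": 0}


def to_binary(examples: Iterable[Tuple[str, str, str]]) -> List[Tuple[str, str, int]]:
    """Map (premise, hypothesis, three-way label) to (premise, hypothesis, 0/1).

    An unknown label raises KeyError, so unexpected label names fail loudly.
    """
    return [(premise, hypothesis, TO_BINARY[label]) for premise, hypothesis, label in examples]


pairs = [
    ("A man is sleeping.", "A person rests.", "entailment"),
    ("It is raining.", "The ground is wet.", "neutral"),
]
print(to_binary(pairs))
# [('A man is sleeping.', 'A person rests.', 1), ('It is raining.', 'The ground is wet.', 0)]
```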

### Usage

In Transformers:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "ragarwal/deberta-v3-base-nli-mixer-binary"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

sentence = ("During its monthly call, the National Oceanic and Atmospheric Administration warned of "
            "increased temperatures and low precipitation")
labels = ["Computer", "Climate Change", "Tablet", "Football", "Artificial Intelligence", "Global Warming"]

# Pair the same premise with every candidate label, which acts as the hypothesis.
features = tokenizer([sentence] * len(labels), labels,
                     padding=True, truncation=True, return_tensors="pt").to(device)

model.eval()
with torch.no_grad():
    scores = model(**features).logits.cpu()  # shape (len(labels), 1): one logit per pair

print("Multi-Label:", torch.sigmoid(scores))  # Multi-label: independent entailment probability per label
print("Single-Label:", torch.softmax(scores, dim=0))  # Single-label: normalized across the candidate labels

# Multi-Label: tensor([[0.0412],[0.2436],[0.0394],[0.0020],[0.0050],[0.1424]])
# Single-Label: tensor([[0.0742],[0.5561],[0.0709],[0.0035],[0.0087],[0.2867]])
```
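
The two print statements differ only in how the logits are normalized. As a framework-free sketch, the hypothetical helper `pick_labels` below (not part of this model card or of any library) turns one set of logits into multi-label and single-label picks. It assumes the logits were first converted to a Python list, for example with `scores.squeeze(-1).tolist()`, and its 0.5 default threshold is an assumption to tune.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |z|).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(zs: Sequence[float]) -> List[float]:
    m = max(zs)  # subtract the max logit for numerical stability
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]


def pick_labels(labels: Sequence[str], logits: Sequence[float],
                threshold: float = 0.5) -> Tuple[List[str], str]:
    """Return (multi-label picks, single-label pick) for one premise.

    `threshold` is a hypothetical cut-off, not a value from the model card.
    """
    multi = [label for label, z in zip(labels, logits) if sigmoid(z) >= threshold]
    dist = softmax(logits)
    best = max(range(len(labels)), key=dist.__getitem__)
    return multi, labels[best]


print(pick_labels(["a", "b", "c"], [2.0, -1.0, 0.5]))  # (['a', 'c'], 'a')
```

In the example above the largest sigmoid score is 0.2436 (`Climate Change`), so the default 0.5 cut-off would select no labels, while the single-label pick would still be `Climate Change`. A lower, tuned cut-off is therefore probably needed for this kind of zero-shot use.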

In Sentence-Transformers:

```python
from sentence_transformers import CrossEncoder

model_name = "ragarwal/deberta-v3-base-nli-mixer-binary"
model = CrossEncoder(model_name, max_length=256)

sentence = ("During its monthly call, the National Oceanic and Atmospheric Administration warned of "
            "increased temperatures and low precipitation")
labels = ["Computer", "Climate Change", "Tablet", "Football", "Artificial Intelligence", "Global Warming"]

# One score per (premise, label) pair; these match the sigmoid values from the Transformers example.
scores = model.predict([[sentence, label] for label in labels])
print(scores)
# array([0.04118565, 0.2435827 , 0.03941465, 0.00203637, 0.00501176, 0.1423797], dtype=float32)
```
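
Because these scores match the `sigmoid` values from the Transformers example, `predict` appears to apply a sigmoid to this single-logit model. Under that assumption, the single-label distribution can be recovered from the CrossEncoder scores without re-running the model: since p = sigmoid(z), exp(z) = p / (1 - p), and a softmax just normalizes those odds. A minimal sketch (the function name `single_label_from_sigmoid` is hypothetical):

```python
from typing import List, Sequence


def single_label_from_sigmoid(probs: Sequence[float]) -> List[float]:
    """Turn per-pair sigmoid probabilities into a softmax over the candidate labels.

    With p = sigmoid(z), exp(z) = p / (1 - p), so
    softmax(z)_i = (p_i / (1 - p_i)) / sum_j (p_j / (1 - p_j)).
    Assumes every p lies strictly between 0 and 1.
    """
    odds = [p / (1.0 - p) for p in probs]
    total = sum(odds)
    return [o / total for o in odds]


# Scores printed by the CrossEncoder example above.
ce_scores = [0.04118565, 0.2435827, 0.03941465, 0.00203637, 0.00501176, 0.1423797]
print([round(x, 4) for x in single_label_from_sigmoid(ce_scores)])
```

Applied to the scores above, the result agrees with the `Single-Label` values from the Transformers example to about three decimal places.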