from collections import defaultdict

import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    evaluation,
    losses,
)
from sentence_transformers.models import Dense, Normalize, Pooling, Transformer


def to_triplets(dataset):
    """Convert an NLI dataset into (anchor, positive, negative) triplets.

    For each premise, the entailed hypothesis (label 0) becomes the positive
    and the contradicting hypothesis (label 2) becomes the negative. Premises
    lacking either label are dropped.
    """
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]

    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        # Only keep premises that have both an entailment and a contradiction.
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0])  # <- entailment
            negatives.append(sentences[2])  # <- contradiction

    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })


snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": to_triplets(snli_ds["train"]),
    "validation": to_triplets(snli_ds["validation"]),
    "test": to_triplets(snli_ds["test"]),
})

multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    "train": to_triplets(multi_nli_ds["train"]),
    "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
})

# AllNLI: SNLI and MultiNLI combined. MultiNLI has no test split, so the
# SNLI test split serves as the test set.
all_nli_ds = datasets.DatasetDict({
    "train": datasets.concatenate_datasets([snli_ds["train"], multi_nli_ds["train"]]),  # .select(range(10000)),
    "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),  # .select(range(1000)),
    "test": snli_ds["test"],
})

# STS Benchmark is used for evaluation only, never for training.
stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=42,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    metric_for_best_model="sts-dev_spearman_cosine",
    greater_is_better=True,
    # Without this flag, metric_for_best_model has no effect; reload the
    # checkpoint with the best STS-dev Spearman score after training.
    load_best_model_at_end=True,
)

# Build the model: bert-tiny encoder -> mean pooling -> dense projection to
# 256 dimensions -> L2 normalization of the final embeddings.
transformer = Transformer("prajjwal1/bert-tiny", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
dense = Dense(in_features=pooling.get_sentence_embedding_dimension(), out_features=256)
normalize = Normalize()
model = SentenceTransformer(modules=[transformer, pooling, dense, normalize])

# Ensure all tensors in the model are contiguous
for param in model.parameters():
    param.data = param.data.contiguous()

loss = losses.MultipleNegativesRankingLoss(model)
# loss = losses.MatryoshkaLoss(model, loss, [256, 128, 64, 32, 16, 8])

dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [score / 5 for score in stsb_dev["score"]],  # rescale 0..5 gold scores to 0..1
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    eval_dataset=all_nli_ds["validation"],
    loss=loss,
)
trainer.train()

test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
print(results)

model.push_to_hub("sentence-transformers-testing/all-nli-bert-tiny-dense", private=True)
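
# Optional sanity check (a minimal sketch, not part of the original training
# recipe): encode a few sentences with the trained in-memory model and compare
# them with cosine similarity. The example sentences below are illustrative
# assumptions; any short sentences would do. With a well-trained model, the
# two eating-related sentences should score higher against each other than
# against the weather sentence.
sentences = [
    "A man is eating food.",
    "A man is eating a piece of bread.",
    "The weather is cold today.",
]
embeddings = model.encode(sentences)  # shape: (3, 256), L2-normalized by the Normalize module
similarities = model.similarity(embeddings, embeddings)  # 3x3 cosine similarity matrix
print(similarities)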