from transformers import Trainer, TrainingArguments | |
training_args = TrainingArguments( | |
output_dir="./results", # Куда сохранять модель | |
evaluation_strategy="epoch", # Как часто проверять на валидации | |
learning_rate=5e-5, | |
per_device_train_batch_size=4, | |
num_train_epochs=3, | |
save_steps=10_000, | |
save_total_limit=2, | |
) | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=dataset['train'], | |
eval_dataset=dataset['validation'], | |
) | |
trainer.train() |