|
import dataclasses

import torch
from transformers import Trainer, TrainingArguments

from my_custom_model import MyCustomModel
from my_dataset import MyDataset
|
|
|
|
|
# NOTE(review): the `...` values below are placeholders left in the original
# script; they must be replaced with a real tokenizer and real dataset/model
# constructor arguments before this can train. `answer_question` below calls
# `tokenizer.encode`/`tokenizer.decode`, so a Hugging Face-style tokenizer is
# presumably expected — confirm.
tokenizer = ...

train_dataset = MyDataset(...)
val_dataset = MyDataset(...)

model = MyCustomModel(...)

# `evaluation_strategy` was renamed to `eval_strategy` in transformers and the
# old keyword is rejected by recent releases. Select whichever name the
# installed TrainingArguments dataclass actually declares, so the script works
# on both older and newer versions.
_eval_strategy_kwarg = (
    'eval_strategy'
    if 'eval_strategy' in {f.name for f in dataclasses.fields(TrainingArguments)}
    else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir='./results',
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    # Run evaluation at the end of every epoch.
    **{_eval_strategy_kwarg: 'epoch'},
)
|
|
|
|
|
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
trainer.train()

# Persist the fine-tuned weights through the Trainer rather than calling
# `model.save_pretrained` directly: `save_model` writes only from the main
# process in distributed runs and unwraps DeepSpeed/FSDP-wrapped models before
# delegating to `save_pretrained`.
model_path = './trained_model'
trainer.save_model(model_path)
# NOTE(review): the tokenizer is not saved alongside the weights, so the
# checkpoint in `model_path` is not self-contained; inference below still
# relies on the module-level `tokenizer`.

# Reload the checkpoint so that inference below uses the saved artifact.
model = MyCustomModel.from_pretrained(model_path)
|
|
|
|
|
def answer_question(input_text, **generate_kwargs):
    """Generate an answer for ``input_text`` with the module-level model.

    Relies on the module-level ``tokenizer`` and ``model`` globals (the
    latter is the checkpoint reloaded after training).

    Args:
        input_text: Prompt / question text to feed to the model.
        **generate_kwargs: Extra keyword arguments forwarded unchanged to
            ``model.generate`` (e.g. ``max_new_tokens``). Omitting them keeps
            the original behavior.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Calling the tokenizer (instead of ``encode``) yields the same input ids
    # plus an attention mask, which ``generate`` uses to tell padding from
    # real tokens. Only input_ids/attention_mask are forwarded, because extra
    # fields such as token_type_ids are rejected by models that don't accept
    # them.
    encoded = tokenizer(input_text, return_tensors='pt')
    device = model.device  # keep inputs on the same device as the weights
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    answer_ids = model.generate(
        input_ids, attention_mask=attention_mask, **generate_kwargs
    )
    # NOTE(review): for a decoder-only model the output also contains the
    # prompt tokens; slice them off if only the continuation is wanted.
    return tokenizer.decode(answer_ids[0], skip_special_tokens=True)
|
|
|
|
|
# Example usage: swap the placeholder prompt for a real question.
input_text = "Your input text here"
answer = answer_question(input_text=input_text)
print(answer)
|
|