# Training.py, from the highdeff1 repository (commit 2c07569, "Upload 16 files").
import torch
from transformers import Trainer, TrainingArguments
from my_custom_model import MyCustomModel # import your custom model class here
from my_dataset import MyDataset # import your dataset class here
# Instantiate the tokenizer.
# NOTE(review): answer_question() calls tokenizer.encode/decode with
# return_tensors='pt' / skip_special_tokens=True, i.e. the Hugging Face
# tokenizer API -- confirm the tokenizer you plug in here provides it.
tokenizer = ...  # define your tokenizer here

# Load the dataset and preprocess it.
train_dataset = MyDataset(...)  # define your training dataset here
val_dataset = MyDataset(...)  # define your validation dataset here

# Define your custom model and the training arguments.
model = MyCustomModel(...)  # define your custom model here
training_args = TrainingArguments(
    output_dir='./results',  # where the Trainer writes its checkpoints
    # NOTE: newer transformers releases renamed this argument to
    # `eval_strategy` and dropped the old spelling; switch to the new name
    # if TrainingArguments rejects `evaluation_strategy` on your version.
    evaluation_strategy='epoch',  # evaluate on eval_dataset after every epoch
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
)

# Define the trainer and train the model.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
trainer.train()

# Save the trained model together with its tokenizer. Writing both into the
# same directory makes the checkpoint self-contained, so it can be reloaded
# (here or in a separate inference script) without rebuilding the tokenizer.
model_path = './trained_model'
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

# Reload the trained weights from disk; this rebinds the module-level `model`.
model = MyCustomModel.from_pretrained(model_path)
# Define your inference function
def answer_question(input_text, **generate_kwargs):
    """Generate an answer for ``input_text`` with the module-level model.

    Relies on the module-level ``tokenizer`` and ``model`` globals.

    Args:
        input_text: Prompt / question string to feed to the model.
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``model.generate`` (for example ``max_new_tokens=128``). With none
            given, generation uses the model's generation defaults; in
            transformers the stock default caps total length at 20 tokens
            unless the model's generation config overrides it.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Call the tokenizer directly instead of .encode() so the attention mask
    # is kept; without it generate() has to infer padding positions itself
    # and warns about possibly unexpected behavior. The input_ids are the
    # same ones .encode() would return.
    encoded = tokenizer(input_text, return_tensors='pt')
    input_ids = encoded['input_ids']

    # Some tokenizers are configured not to emit a mask; only pass it along
    # when it is actually present.
    mask_kwargs = {}
    if 'attention_mask' in encoded:
        mask_kwargs['attention_mask'] = encoded['attention_mask']

    answer_ids = model.generate(input_ids, **mask_kwargs, **generate_kwargs)
    # generate() returns a batch of sequences; only one prompt was passed in.
    return tokenizer.decode(answer_ids[0], skip_special_tokens=True)
# Smoke-test: push one example prompt through the reloaded model and show
# what comes back (binds `input_text` and `answer` at module level).
input_text = "Your input text here"
print(answer := answer_question(input_text))