from datasets import load_dataset
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

# Load the tokenizer and set the padding token
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS

# Tokenize function with padding and truncation
def tokenize_function(examples):
    return tokenizer(
        examples['Question'],   # Column name in the synCAI_144kda dataset
        padding='max_length',   # Pad every example to the same length
        truncation=True,        # Truncate anything longer than max_length
        max_length=128,         # Maximum sequence length
    )

# Load the dataset
dataset = load_dataset('InnerI/synCAI_144kda')

# Tokenize the dataset with batched processing
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Load the model
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')

# Data collator for causal language modeling; with mlm=False it copies
# input_ids into labels instead of applying masked-LM corruption
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # False = standard left-to-right language modeling
)

# Training arguments: output directory, checkpointing, and batch settings
training_args = TrainingArguments(
    output_dir="InnerI/synCAI-144k-gpt2.5",  # Directory for checkpoints
    overwrite_output_dir=True,
    num_train_epochs=1,              # Number of passes over the training set
    per_device_train_batch_size=4,   # Batch size per device
    save_steps=10_000,               # Save a checkpoint every 10,000 steps
    save_total_limit=2,              # Keep at most 2 checkpoints
    prediction_loss_only=True,       # Track only the loss during training
)

# Initialize the Trainer with model, arguments, and collator
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],  # Tokenized train split
)

# Start training the model
trainer.train()

# Save the fine-tuned model to the specified directory
trainer.save_model("CAI-gpt2.5")
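
# --- Optional: quick generation check ---
# A minimal sketch (not part of the original script) showing how the saved
# model could be loaded back and sampled from. It assumes the model was saved
# to "CAI-gpt2.5" as above and reuses the tokenizer configured earlier
# (save_model does not write the tokenizer); the prompt text is hypothetical
# and should match whatever format the dataset's questions use.
from transformers import pipeline

generator = pipeline(
    'text-generation',
    model='CAI-gpt2.5',   # Directory the fine-tuned model was saved to
    tokenizer=tokenizer,  # Reuse the tokenizer object from this script
)

sample = generator(
    "Question:",          # Hypothetical prompt; adjust to the dataset's format
    max_new_tokens=50,    # Cap the length of the generated continuation
    do_sample=True,       # Sample rather than greedy-decode
    top_p=0.9,            # Nucleus sampling threshold
)
print(sample[0]['generated_text'])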