# Impossible_llm/train/train_accelerate.py
import argparse
import os
import sys

import torch

sys.path.append("..")  # make utils_llama (one directory up) importable

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from utils_llama import PERTURBATIONS

# import wandb
# Setup for Weights & Biases
# wandb.init(project="kallini", group="babylm-perturbation-experiments", name=run_id)
if __name__ == "__main__":
    # === CONFIGURATION SETTINGS ===
    parser = argparse.ArgumentParser(description="Training configuration.")
    parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
    parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
    parser.add_argument('--batch_size', type=int, default=4, help='Per-device batch size for training.')
    parser.add_argument('--epoch', type=int, default=20, help='Number of training epochs.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    args = parser.parse_args()
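    # Log the resolved configuration so individual runs are easy to tell apart in stdout logs.
    print(f"Run configuration: {vars(args)}")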
    # no_pos_encodings_underscore = ""  # Ex: "_nopos" if needed
    ckpt_path = "./checkpoints"
    # effective_bsz = 512
    model_name = "meta-llama/Llama-3.2-3B"
    model_save_name = "Llama-3.2-3B"

    # === FILE PATHS BASED ON CONFIGURATION ===
    run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    cache_dir = os.path.join(ckpt_path, model_save_name, run_id, "artifacts")
    run_dir = os.path.join(ckpt_path, model_save_name, run_id, "runs")
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)
    # === DATASET LOADING ===
    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset('babylm_dataset_llama.py', name=dataset_name, trust_remote_code=True)
    train_dataset = dataset['train']
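    # Quick sanity check: confirm the expected split loaded and report its size.
    print(f"Loaded {len(train_dataset)} training examples for '{dataset_name}'")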
    # === TOKENIZER & MODEL LOADING ===
    # model_name = f"gpt2{'' if no_pos_encodings_underscore == '' else '-no-pos'}-small-{perturbation}-{paren_model}"
    # tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
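    # Llama tokenizers usually ship without a pad token, which padding="max_length" below
    # and the collator both require. If the perturbation tokenizer already defines one,
    # this is a no-op (assumption: reusing EOS as PAD is acceptable for this setup).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token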
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        cache_dir=cache_dir,
    )
    # print("model:", model)
    # === TOKENIZATION ===
    def tokenize_function(examples):
        # Pad/truncate every example to a fixed 1024-token window; the collator below
        # derives the causal-LM labels from input_ids.
        return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)

    tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    # === DATA COLLATOR ===
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)  # causal LM, not masked LM
    # === TRAINING ARGUMENTS ===
    training_args = TrainingArguments(
        output_dir=run_dir,
        # evaluation_strategy="steps",
        evaluation_strategy="no",  # renamed to eval_strategy in newer transformers releases
        # per_device_train_batch_size=int(effective_bsz / 1),  # assuming 1 GPU
        per_device_train_batch_size=args.batch_size,
        logging_dir='./logs',
        logging_steps=1000,
        save_steps=1000,
        # save_total_limit=5,
        learning_rate=2e-5,
        num_train_epochs=args.epoch,
        seed=args.seed,
        # load_best_model_at_end=True,
        gradient_accumulation_steps=1,  # raise this to grow the effective batch size without extra GPU memory
        fp16=True,  # mixed-precision training
        report_to="none",
    )
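    # Sanity log of the effective batch size (per-device batch * accumulation steps * GPU count);
    # torch.cuda.device_count() is 0 on CPU-only machines, hence the max().
    n_gpus = max(torch.cuda.device_count(), 1)
    print(f"Effective batch size: {args.batch_size * training_args.gradient_accumulation_steps * n_gpus}")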
    # === TRAINER ===
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    # === TRAIN MODEL ===
    trainer.train()
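    # Persist the final weights and tokenizer (optional; intermediate checkpoints are already
    # written under run_dir every save_steps). The "final" subdirectory name is an arbitrary choice.
    trainer.save_model(os.path.join(run_dir, "final"))
    tokenizer.save_pretrained(os.path.join(run_dir, "final"))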
    # End logging
    # wandb.finish()