import argparse import abc import random from itertools import permutations from typing import Set import os import json import numpy as np from einops import rearrange, repeat import torch from torch.utils.data import IterableDataset from torch import nn, Tensor import math class AbstractDataset(abc.ABC): def __init__(self, group_elements1: Set, group_elements2: Set, frac_train: float): self.frac_train = frac_train self.group_elements1 = group_elements1 self.group_elements2 = group_elements2 self.ordered_group_elements1 = list(self.group_elements1) self.ordered_group_elements2 = list(self.group_elements2) self.idx2vocab = ["o", "="] + list(group_elements1.union(group_elements2)) self.vocab2idx = {vocab: idx for idx, vocab in enumerate(self.idx2vocab)} self.n_vocab = len(self.idx2vocab) self.n_out = len(group_elements1.union(group_elements2)) idxs = list(range(len(self.group_elements1) * len(self.group_elements2))) random.shuffle(idxs) self.train_pairs, self.val_pairs = ( idxs[: int(len(idxs) * frac_train)], idxs[int(len(idxs) * frac_train) :], ) @abc.abstractmethod def fetch_output(self, a, b): pass def encode(self, sequence): return [self.vocab2idx[item] for item in sequence] def decode(self, sequence): return [self.idx2vocab[item] for item in sequence] def form_equation(self, a, b, c): return [a, "o", b, "=", c] def fetch_example(self, idx): a = self.ordered_group_elements1[idx // len(self.group_elements2)] b = self.ordered_group_elements2[idx % len(self.group_elements2)] c = self.fetch_output(a, b) equation = self.form_equation(a, b, c) return self.encode(equation[:-1]), (self.vocab2idx[c] - 2), equation def fetch_train_example(self): idx = random.choice(self.train_pairs) return self.fetch_example(idx) def fetch_val_example(self): idx = random.choice(self.val_pairs) return self.fetch_example(idx) class ModSumDataset(AbstractDataset): def __init__(self, p, frac_train): super(ModSumDataset, self).__init__(set(range(p)), set(range(p)), frac_train) self.p = p def fetch_output(self, a, b): return (a + b) % self.p class ModSubtractDataset(AbstractDataset): def __init__(self, p, frac_train): super(ModSubtractDataset, self).__init__( set(range(p)), set(range(p)), frac_train ) self.p = p def fetch_output(self, a, b): return (a - b) % self.p class ModDivisonDataset(AbstractDataset): def __init__(self, p, frac_train): super(ModDivisonDataset, self).__init__( set(range(p)), set(range(1, p)), frac_train ) self.p = p def fetch_output(self, a, b): return (a * pow(b, self.p - 2, self.p)) % self.p class PermutationGroup(AbstractDataset): def __init__(self, k, frac_train): perms = set(map(tuple, permutations(list(range(k))))) super(PermutationGroup, self).__init__(perms, perms, frac_train) self.k = k def fetch_output(self, a, b): return tuple([a[b[i]] for i in range(len(b))]) class GroupDataset(IterableDataset): def __init__(self, dataset: AbstractDataset, split: str): super(GroupDataset, self).__init__() assert split in {"train", "val"} self.dataset = dataset self.split = split self.fetch_f = None if self.split == "train": self.fetch_f = self.dataset.fetch_train_example elif self.split == "val": self.fetch_f = self.dataset.fetch_val_example else: raise NotImplementedError def __iter__(self): return self def __next__(self): x, y, _ = self.fetch_f() return torch.tensor(x), torch.tensor(y) def operation_mod_p_data(operation: str, p: int, frac_train: float): """ x◦y (mod p) for 0 <= x < p, 1 <= y < p if operation in DIVISION_MODULO_OPERATIONS x◦y (mod p) for 0 <= x, y < p otherwise """ if operation == "x_plus_y": data = ModSumDataset(p=p, frac_train=frac_train) elif operation == "x_minus_y": data = ModSubtractDataset(p=p, frac_train=frac_train) elif operation == "x_div_y": data = ModDivisonDataset(p=p, frac_train=frac_train) elif operation == "permutation": data = PermutationGroup(k=5, frac_train=frac_train) return data def get_data(operation: str, prime: int, training_fraction: float, batch_size: int): dataset = operation_mod_p_data(operation, prime, training_fraction) train_dataset = GroupDataset(dataset, "train") val_dataset = GroupDataset(dataset, "val") train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size) return ( train_loader, val_loader, train_dataset.dataset.n_vocab, train_dataset.dataset.n_out, ) class DecoderBlock(torch.nn.Module): def __init__(self, dim_model: int, n_heads: int): super().__init__() self.self_attn = nn.MultiheadAttention(dim_model, n_heads) self.self_attn_norm = nn.LayerNorm(dim_model) self.ffn = nn.Sequential( nn.Linear(dim_model, dim_model * 4), nn.GELU(), nn.Linear(dim_model * 4, dim_model), ) self.ffn_norm = nn.LayerNorm(dim_model) def forward(self, x: Tensor): attn_mask = torch.full( (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype ) attn_mask = torch.triu(attn_mask, diagonal=1) a1, _ = self.self_attn(x, x, x, attn_mask=attn_mask) a1 = self.self_attn_norm(x + a1) a2 = self.ffn(a1) a2 = self.ffn_norm(a1 + a2) return a2 class Transformer(torch.nn.Module): def __init__( self, num_layers: int, dim_model: int, num_heads: int, vocab_size: int, output_size: int, seq_len: int, ): super().__init__() self.token_embeddings = nn.Embedding(vocab_size, dim_model) self.position_embeddings = nn.Embedding(seq_len, dim_model) self.model = nn.Sequential( *[DecoderBlock(dim_model, num_heads) for _ in range(num_layers)], nn.LayerNorm(dim_model), nn.Linear(dim_model, output_size), ) def forward(self, inputs: Tensor): batch_size, context_len = inputs.shape token_embedding = self.token_embeddings(inputs) positions = repeat( torch.arange(context_len, device=inputs.device), "p -> b p", b=batch_size ) position_embedding = self.position_embeddings(positions) embedding = token_embedding + position_embedding embedding = rearrange(embedding, "b s d -> s b d") return self.model(embedding) def train(model, train_loader, optimizer, scheduler, device, num_train_batches): # Set model to training mode model.train() criterion = torch.nn.CrossEntropyLoss() loss_total, correct = 0.0, 0.0 total = 0 # Loop over each batch from the training set count = 0 for batch in train_loader: count += 1 # Copy data to device if needed batch = tuple(t.to(device) for t in batch) # Unpack the batch from the loader inputs, labels = batch # Zero gradient buffers optimizer.zero_grad() # Forward pass output = model(inputs)[-1, :, :] loss = criterion(output, labels) correct += (torch.argmax(output, dim=1) == labels).sum() loss_total += loss * len(labels) total += len(labels) # Backward pass loss.backward() # Update weights optimizer.step() scheduler.step() if count >= num_train_batches: break acc = correct / total loss = loss_total / total metrics = { "train_accuracy": float(acc), "train_loss": float(loss), } return metrics def evaluate(model, val_loader, device, num_eval_batches): # Set model to evaluation mode model.eval() criterion = torch.nn.CrossEntropyLoss() correct = 0 loss = 0.0 total = 0 count = 0 # Loop over each batch from the validation set for batch in val_loader: # Copy data to device if needed batch = tuple(t.to(device) for t in batch) # Unpack the batch from the loader inputs, labels = batch # Forward pass with torch.no_grad(): output = model(inputs)[-1, :, :] correct += (torch.argmax(output, dim=1) == labels).sum() loss += criterion(output, labels) * len(labels) total += labels.shape[0] count += 1 if count >= num_eval_batches: break acc = correct / total loss = loss / total metrics = {"val_accuracy": float(acc), "val_loss": float(loss)} return metrics def estimate_mdl(model, threshold=1e-2): total_params = 0 non_zero_params = 0 for param in model.parameters(): total_params += param.numel() non_zero_params += torch.sum(torch.abs(param) > threshold).item() return non_zero_params def run(out_dir, dataset, seed_offset): os.makedirs(out_dir, exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.manual_seed(1337 + seed_offset) train_loader, val_loader, n_vocab, n_output = get_data( operation=dataset, prime=97, training_fraction=0.5, batch_size=512, ) model = Transformer( num_layers=2, dim_model=128, num_heads=4, vocab_size=n_vocab, output_size=n_output, seq_len=5, ).to(device) optimizer = torch.optim.AdamW( model.parameters(), lr=1e-3, betas=(0.9, 0.98), weight_decay=0.5, ) num_train_batches = 10 num_eval_batches = 8 num_total_updates = 7500 warmup_steps = 50 scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambda=lambda s: min(s / warmup_steps, 1) ) final_info, train_log_info, val_log_info, mdl_log_info = [], [], [], [] step_val_acc_99 = num_total_updates step_val_acc_95 = num_total_updates for ep in range(num_total_updates // num_train_batches): train_metrics = train( model, train_loader, optimizer, scheduler, device, num_train_batches, ) val_metrics = evaluate( model, val_loader, device, num_eval_batches, ) train_metrics["step"] = (ep + 1) * num_train_batches val_metrics["step"] = (ep + 1) * num_train_batches if step_val_acc_99 == num_total_updates and val_metrics["val_accuracy"] > 0.99: step_val_acc_99 = val_metrics["step"] if step_val_acc_95 == num_total_updates and val_metrics["val_accuracy"] > 0.95: step_val_acc_95 = val_metrics["step"] train_log_info.append(train_metrics) val_log_info.append(val_metrics) if (ep + 1) * num_train_batches % 500 == 0: mdl = estimate_mdl(model) mdl_log_info.append({"step": (ep + 1) * num_train_batches, "mdl": mdl}) final_info = { "final_train_loss": train_metrics["train_loss"], "final_val_loss": val_metrics["val_loss"], "final_train_acc": train_metrics["train_accuracy"], "final_val_acc": val_metrics["val_accuracy"], "step_val_acc_99": step_val_acc_99, "step_val_acc_95": step_val_acc_95, } print(final_info) with open( os.path.join(out_dir, f"final_info_{dataset}_{seed_offset}.json"), "w" ) as f: json.dump(final_info, f) return final_info, train_log_info, val_log_info, mdl_log_info parser = argparse.ArgumentParser(description="Run experiment") parser.add_argument("--out_dir", type=str, default="run_0", help="Output directory") args = parser.parse_args() if __name__ == "__main__": num_seeds = { "x_div_y": 3, "x_plus_y": 3, "x_minus_y": 3, "permutation": 3, } out_dir = args.out_dir all_results = {} final_infos = {} for dataset in ["x_div_y", "x_minus_y", "x_plus_y", "permutation"]: final_info_list = [] for seed_offset in range(num_seeds[dataset]): print(f"Running {dataset} with seed offset {seed_offset}") final_info, train_info, val_info, mdl_info = run(args.out_dir, dataset, seed_offset) all_results[f"{dataset}_{seed_offset}_final_info"] = final_info all_results[f"{dataset}_{seed_offset}_train_info"] = train_info all_results[f"{dataset}_{seed_offset}_val_info"] = val_info all_results[f"{dataset}_{seed_offset}_mdl_info"] = mdl_info final_info_list.append(final_info) final_info_dict = { k: [d[k] for d in final_info_list] for k in final_info_list[0].keys() } means = {f"{k}_mean": np.mean(v) for k, v in final_info_dict.items()} stderrs = { f"{k}_stderr": np.std(v) / len(v) for k, v in final_info_dict.items() } final_infos[dataset] = { "means": means, "stderrs": stderrs, "final_info_dict": final_info_dict, } with open(os.path.join(out_dir, "final_info.json"), "w") as f: json.dump(final_infos, f) with open(os.path.join(out_dir, "all_results.npy"), "wb") as f: np.save(f, all_results)