|
import os |
|
import time |
|
import math |
|
import pickle |
|
import inspect |
|
import json |
|
from contextlib import nullcontext |
|
from dataclasses import dataclass |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
import argparse |
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalisation whose additive shift can be disabled.

    Keeps its own scale vector (``weight``) and, when ``bias`` is truthy, a
    shift vector (``bias``). Normalisation is delegated to
    ``F.layer_norm`` with a fixed epsilon of 1e-5.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # Scale starts at 1 so the module begins as a plain normalisation.
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            # Shift starts at 0.
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # Normalise over the trailing dimensions described by the weight's shape.
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    One ``c_attn`` projection produces queries, keys and values. Each head
    attends only to the current and earlier positions, and ``c_proj`` maps
    the concatenated head outputs back to ``n_embd``. PyTorch's
    ``scaled_dot_product_attention`` is used when it exists; otherwise an
    explicit masked-softmax path is used.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Fused q/k/v input projection, then the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Dropout on attention weights (explicit path only) and on the block output.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

        # Lower-triangular matrix of ones shaped (1, 1, block_size, block_size) so it
        # broadcasts over batch and heads. It is registered under the name "bias"
        # even though it acts as a mask; forward() reads it only on the explicit path.
        size = config.block_size
        causal_mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal_mask)

    def forward(self, x):
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = map(split_heads, self.c_attn(x).split(self.n_embd, dim=2))

        if not self.flash:
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            # Positions above the diagonal (the future) are excluded from the softmax.
            future = self.bias[:, :, :seq_len, :seq_len] == 0
            weights = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
            y = self.attn_dropout(weights) @ v
        else:
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True,
            )

        # (B, n_head, T, head_dim) -> (B, T, C): concatenate the heads again.
        merged = y.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.resid_dropout(self.c_proj(merged))
|
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward network.

    Expands to ``4 * n_embd``, applies GELU, projects back to ``n_embd`` and
    applies dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))
|
|
|
|
|
class StyleAdapter(nn.Module):
    """Gate hidden states multiplicatively with a vector derived from a style embedding.

    NOTE(review): the gate is ``self.linear(style_emb)`` with no additive
    offset. Whenever that linear layer's outputs are close to zero (for
    example after a small weight initialisation), the adapter scales ``x``
    toward zero instead of acting as an identity. Confirm this is intended;
    a residual form such as ``x * (1 + gate)`` would start as the identity.
    """

    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x, style_emb):
        gate = self.linear(style_emb)
        # Insert a length-1 axis at dim 1 so the gate broadcasts against x
        # (presumably over the sequence axis of a (B, T, n_embd) input; TODO confirm).
        return x * gate.unsqueeze(1)
|
|
|
class Block(nn.Module):
    """One pre-norm transformer layer.

    Applies attention and then the MLP. Each sub-layer sees a LayerNorm'd
    input, and its output is added back to the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))
|
|
|
|
|
@dataclass
class GPTConfig:
    """Hyper-parameters consumed by the GPT model and its sub-modules."""

    block_size: int = 1024  # maximum sequence length (rows of the position embedding)
    vocab_size: int = 50304  # GPT-2's 50257 rounded up for efficiency
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # attention heads per block
    n_embd: int = 768  # model width
    dropout: float = 0.0
    bias: bool = True  # whether Linear and LayerNorm layers carry bias terms
    n_styles: int = 4  # rows in the learned style-embedding table
    style_embd_dim: int = 64  # width of each style embedding
|
|
|
|
|
class GPT(nn.Module):
    """GPT language model with a soft style-conditioning path.

    Besides token/position embeddings, the transformer stack and an LM head
    whose weight is tied to the token embedding, the model owns:

    * ``style_embeddings``: a learnable ``(n_styles, style_embd_dim)`` table;
    * ``style_proj``: a projection from style space to ``n_embd``;
    * ``style_classifier``: a small MLP that scores styles from a hidden state;
    * ``style_adapters``: ``n_layer // 2`` StyleAdapter modules, applied after
      odd-indexed blocks other than the last one (see ``forward``).
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        # Creation order is significant: it fixes the order in which the
        # default initialisers draw random numbers.
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.n_embd),
                "wpe": nn.Embedding(config.block_size, config.n_embd),
                "drop": nn.Dropout(config.dropout),
                "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                "ln_f": LayerNorm(config.n_embd, bias=config.bias),
            }
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one tensor.
        self.transformer.wte.weight = self.lm_head.weight

        # Style-conditioning components.
        self.style_embeddings = nn.Parameter(torch.randn(config.n_styles, config.style_embd_dim))
        self.style_proj = nn.Linear(config.style_embd_dim, config.n_embd)
        self.style_classifier = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd),
            nn.ReLU(),
            nn.Linear(config.n_embd, config.n_styles),
        )
        # NOTE(review): forward() only uses adapters at indices (i // 2) for odd
        # i < n_layer - 1, so with an even n_layer the last adapter is never used.
        self.style_adapters = nn.ModuleList(
            [StyleAdapter(config) for _ in range(config.n_layer // 2)]
        )

        self.apply(self._init_weights)
        # Residual output projections get a depth-scaled initialisation.
        residual_std = 0.02 / math.sqrt(2 * config.n_layer)
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=residual_std)

        print(f"number of parameters: {self.get_num_params() / 1e6:.2f}M")

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.

        With ``non_embedding`` (the default), the position-embedding parameters
        are excluded. Token embeddings are still counted because they share
        storage with the LM head; ``self.parameters()`` yields the shared
        tensor only once.
        """
        total = sum(p.numel() for p in self.parameters())
        if not non_embedding:
            return total
        return total - self.transformer.wpe.weight.numel()

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token indices ``idx`` of shape (b, t).

        Returns ``(logits, loss, style_logits)``:

        * with ``targets``: logits for every position and the cross-entropy
          loss (targets equal to -1 are ignored);
        * without ``targets``: logits for the last position only (time axis
          kept with length 1) and ``loss=None``.

        ``style_logits`` come from the most recent style-classifier call, or
        are None when no adapter layer exists (fewer than 3 blocks).
        """
        _, seq_len = idx.size()
        assert (
            seq_len <= self.config.block_size
        ), f"Cannot forward sequence of length {seq_len}, block size is only {self.config.block_size}"
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)

        hidden = self.transformer.drop(
            self.transformer.wte(idx) + self.transformer.wpe(positions)
        )

        blocks = self.transformer.h
        last_index = len(blocks) - 1
        style_logits = None
        for layer_index, block in enumerate(blocks):
            hidden = block(hidden)
            # Style adaptation follows odd-indexed blocks, except the final block.
            if layer_index % 2 == 0 or layer_index >= last_index:
                continue
            # Classify style from the last position's hidden state, mix the style
            # table by the predicted probabilities, then gate the hidden states.
            style_logits = self.style_classifier(hidden[:, -1, :])
            style_weights = F.softmax(style_logits, dim=-1)
            style_emb = self.style_proj(style_weights @ self.style_embeddings)
            hidden = self.style_adapters[layer_index // 2](hidden, style_emb)

        hidden = self.transformer.ln_f(hidden)

        if targets is None:
            # The list index [-1] keeps the time axis (length 1).
            return self.lm_head(hidden[:, [-1], :]), None, style_logits

        logits = self.lm_head(hidden)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
        )
        return logits, loss, style_logits

    def crop_block_size(self, block_size):
        """Shrink the model's context length to ``block_size`` (must not grow it).

        Truncates the position-embedding table and each attention layer's
        causal-mask buffer.
        """
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, "bias"):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build an AdamW optimizer with two parameter groups.

        Trainable parameters with ``dim() >= 2`` get ``weight_decay``; all
        others (``dim() < 2``) get none. The fused AdamW kernel is requested
        when this PyTorch build supports it and ``device_type == "cuda"``.
        """
        trainable = [p for _, p in self.named_parameters() if p.requires_grad]
        decay_params = [p for p in trainable if p.dim() >= 2]
        nodecay_params = [p for p in trainable if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]

        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(
            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
        )
        print(
            f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
        )

        use_fused = (
            "fused" in inspect.signature(torch.optim.AdamW).parameters
            and device_type == "cuda"
        )
        extra_args = {"fused": True} if use_fused else {}
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")
        return optimizer

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively extend ``idx`` (LongTensor of shape (b, t)) by
        ``max_new_tokens`` sampled tokens and return the extended tensor.

        Logits are divided by ``temperature``. When ``top_k`` is given, only
        the ``top_k`` highest logits stay eligible. You will usually want the
        model in ``eval()`` mode when calling this.
        """
        window = self.config.block_size
        for _ in range(max_new_tokens):
            # Crop the running sequence to the last block_size tokens when it is too long.
            if idx.size(1) > window:
                context = idx[:, -window:]
            else:
                context = idx

            step_logits = self(context)[0][:, -1, :] / temperature

            if top_k is not None:
                k = min(top_k, step_logits.size(-1))
                threshold = torch.topk(step_logits, k).values[:, [-1]]
                step_logits[step_logits < threshold] = -float("Inf")

            next_token = torch.multinomial(F.softmax(step_logits, dim=-1), num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx
|
|
|
|
|
|
|
def train(dataset="shakespeare_char", out_dir="run_0", seed_offset=0):
    """Train a style-adapter GPT on ``dataset``, then sample from it and time generation.

    Hyper-parameters are hard-coded below, with per-dataset variants keyed on
    ``dataset == "shakespeare_char"``. Token data is read from
    ``../../../data/<dataset>/{train,val}.bin`` (uint16 memmaps) and the
    vocabulary from ``meta.pkl`` in the same directory.

    Args:
        dataset: Sub-directory of the data directory; also selects
            dataset-specific hyper-parameters.
        out_dir: Output directory. It receives checkpoints (disabled by
            default via ``never_save_checkpoint``) and
            ``final_info_<dataset>_<seed_offset>.json``.
        seed_offset: Added to the base seed 1337.

    Returns:
        ``(final_info, train_log_info, val_log_info)``:

        * ``final_info`` holds the last logged LM training loss, the best
          validation loss, total wall time, and average generation tokens/s;
        * ``train_log_info`` holds one record per log step (LM loss, style
          loss, step time in ms);
        * ``val_log_info`` holds one record per evaluation.
    """
    # ---------------------------------------------------------------- config
    gradient_accumulation_steps = 1
    batch_size = 64 if dataset == "shakespeare_char" else 32
    block_size = 256  # context length

    eval_interval = 250 if dataset == "shakespeare_char" else 1000
    log_interval = 10 if dataset == "shakespeare_char" else 100
    eval_iters = 200
    eval_only = False
    always_save_checkpoint = False
    never_save_checkpoint = True

    n_layer = 6
    n_head = 6
    n_embd = 384
    dropout = 0.2
    bias = False
    n_styles = 4
    style_embd_dim = 64
    style_loss_weight = 0.1

    learning_rate = 1e-3 if dataset == "shakespeare_char" else 5e-4
    max_iters = 5000 if dataset == "shakespeare_char" else 100000
    weight_decay = 1e-1
    beta1 = 0.9
    beta2 = 0.99
    grad_clip = 1.0

    decay_lr = True
    warmup_iters = 100 if dataset == "shakespeare_char" else 200
    lr_decay_iters = max_iters
    min_lr = 1e-4 if dataset == "shakespeare_char" else 5e-5

    device = "cuda"
    dtype = (
        "bfloat16"
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else "float16"
    )
    use_compile = True

    # ----------------------------------------------------------------- setup
    master_process = True
    tokens_per_iter = gradient_accumulation_steps * batch_size * block_size
    print(f"tokens per iteration will be: {tokens_per_iter:,}")

    if master_process:
        os.makedirs(out_dir, exist_ok=True)
    torch.manual_seed(1337 + seed_offset)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device_type = "cuda" if "cuda" in device else "cpu"

    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    )

    data_dir = os.path.join("../../../data", dataset)

    def get_batch(split):
        """Return ``batch_size`` random (input, next-token target) windows from ``split``."""
        # The memmap is re-created on every call.
        filename = "train.bin" if split == "train" else "val.bin"
        data = np.memmap(os.path.join(data_dir, filename), dtype=np.uint16, mode="r")
        ix = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack(
            [torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]
        )
        y = torch.stack(
            [torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]
        )
        if device_type == "cuda":
            # Pinned host memory lets the device copy be issued non-blocking.
            x = x.pin_memory().to(device, non_blocking=True)
            y = y.pin_memory().to(device, non_blocking=True)
        else:
            x, y = x.to(device), y.to(device)
        return x, y

    iter_num = 0
    best_val_loss = 1e9

    meta_path = os.path.join(data_dir, "meta.pkl")
    meta_vocab_size = None
    if os.path.exists(meta_path):
        # NOTE: pickle.load can execute arbitrary code; meta.pkl must be trusted.
        with open(meta_path, "rb") as f:
            meta = pickle.load(f)
        meta_vocab_size = meta["vocab_size"]
        print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

    # ----------------------------------------------------------------- model
    model_args = dict(
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        block_size=block_size,
        bias=bias,
        vocab_size=None,
        dropout=dropout,
        n_styles=n_styles,
        style_embd_dim=style_embd_dim,
    )

    print("Initializing a new model from scratch")
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args["vocab_size"] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)

    if block_size < model.config.block_size:
        model.crop_block_size(block_size)
        model_args["block_size"] = block_size
    model.to(device)

    # The scaler is only active for float16; for other dtypes it passes values through.
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))

    optimizer = model.configure_optimizers(
        weight_decay, learning_rate, (beta1, beta2), device_type
    )

    if use_compile:
        print("compiling the model... (takes a ~minute)")
        model = torch.compile(model)

    @torch.no_grad()
    def estimate_loss():
        """Mean LM loss over ``eval_iters`` random batches for each of train/val."""
        out = {}
        model.eval()
        for split in ["train", "val"]:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                with ctx:
                    _, loss, _ = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()
        return out

    def get_lr(it):
        """Linear warmup, then cosine decay to ``min_lr`` at ``lr_decay_iters``, then flat."""
        if it < warmup_iters:
            return learning_rate * it / warmup_iters
        if it > lr_decay_iters:
            return min_lr
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (learning_rate - min_lr)

    # ------------------------------------------------------------- training
    val_log_info = []
    train_log_info = []

    X, Y = get_batch("train")
    og_t0 = time.time()
    t0 = time.time()
    raw_model = model
    # Last logged LM training loss; stays NaN if the loop exits before any log step.
    lossf = float("nan")
    while True:
        lr = get_lr(iter_num) if decay_lr else learning_rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if iter_num % eval_interval == 0 and master_process:
            losses = estimate_loss()
            print(
                f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            val_log_info.append(
                {
                    "iter": iter_num,
                    "train/loss": losses["train"].item(),
                    "val/loss": losses["val"].item(),
                    "lr": lr,
                }
            )
            if losses["val"] < best_val_loss or always_save_checkpoint:
                best_val_loss = losses["val"]
                if iter_num > 0 and not never_save_checkpoint:
                    checkpoint = {
                        "model": raw_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "model_args": model_args,
                        "iter_num": iter_num,
                        "best_val_loss": best_val_loss,
                    }
                    print(f"saving checkpoint to {out_dir}")
                    torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
        if iter_num == 0 and eval_only:
            break

        for micro_step in range(gradient_accumulation_steps):
            with ctx:
                logits, loss, style_logits = model(X, Y)
                if style_logits is None:
                    # No adapter layer fired (fewer than 3 blocks), so there is no style term.
                    style_loss = torch.zeros((), device=loss.device)
                else:
                    # NOTE(review): targets are drawn uniformly at random on every step,
                    # independently of the batch. In expectation this term only pushes the
                    # classifier toward a uniform prediction and carries no style
                    # supervision. It is kept as-is to preserve the experiment's objective.
                    style_targets = torch.randint(0, n_styles, (X.size(0),), device=device)
                    style_loss = F.cross_entropy(style_logits, style_targets)
                total_loss = (loss + style_loss_weight * style_loss) / gradient_accumulation_steps
            # Prefetch the next batch before the backward pass.
            X, Y = get_batch("train")
            scaler.scale(total_loss).backward()

        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if iter_num % log_interval == 0 and master_process:
            # Log the pure LM loss of the last micro-step. This is the same quantity
            # estimate_loss() reports, so it stays comparable with the validation loss
            # and with runs that have no style term. The style term is logged separately.
            lossf = loss.item()
            style_lossf = style_loss.item()
            print(
                f"iter {iter_num}: loss {lossf:.4f}, style loss {style_lossf:.4f}, time {dt*1000:.2f}ms"
            )
            train_log_info.append(
                {
                    "iter": iter_num,
                    "loss": lossf,
                    "style_loss": style_lossf,
                    "time": dt * 1000,
                }
            )
        iter_num += 1

        if iter_num > max_iters:
            break

    print("training done")
    print(f"Best validation loss: {best_val_loss}")
    print(f"Total train time: {(time.time() - og_t0) / 60:.2f} mins")

    final_info = {
        "final_train_loss": lossf,
        # float() accepts both the 0-d tensor stored after an evaluation and the
        # initial Python float (which has no .item()).
        "best_val_loss": float(best_val_loss),
        "total_train_time": time.time() - og_t0,
    }

    # ------------------------------------------------------------- sampling
    start = " "
    num_samples = 10
    max_new_tokens = 500
    temperature = 0.8
    top_k = 200

    assert os.path.exists(meta_path), "meta.pkl not found, please run training script first"
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]

    def encode(s):
        return [stoi[c] for c in s]

    def decode(tokens):
        return "".join([itos[i] for i in tokens])

    if start.startswith("FILE:"):
        with open(start[5:], "r", encoding="utf-8") as f:
            start = f.read()
    start_ids = encode(start)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    model.eval()
    results = []
    with torch.no_grad():
        with ctx:
            for k in range(num_samples):
                start_time = time.time()
                y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
                end_time = time.time()

                generated_text = decode(y[0].tolist())
                inference_time = end_time - start_time
                tokens_per_second = max_new_tokens / inference_time

                print(f"Sample {k+1}:")
                print(generated_text)
                print(f"Inference time: {inference_time:.2f} seconds")
                print(f"Tokens per second: {tokens_per_second:.2f}")
                print("---------------")

                results.append(
                    {
                        "sample_id": k + 1,
                        "generated_text": generated_text,
                        "inference_time": inference_time,
                        "tokens_per_second": tokens_per_second,
                    }
                )

    avg_tokens_per_second = sum(r["tokens_per_second"] for r in results) / len(results)
    print(f"Average tokens per second: {avg_tokens_per_second:.2f}")

    final_info["avg_inference_tokens_per_second"] = avg_tokens_per_second

    with open(os.path.join(out_dir, f"final_info_{dataset}_{seed_offset}.json"), "w") as f:
        json.dump(final_info, f)
    return final_info, train_log_info, val_log_info
|
|
|
parser = argparse.ArgumentParser(description="Run experiment")
parser.add_argument("--out_dir", type=str, default="run_0", help="Output directory")


if __name__ == "__main__":
    # Parse only when run as a script, so importing this module never reads
    # (or exits on) an unrelated sys.argv.
    args = parser.parse_args()

    # Number of seeds to run per dataset.
    num_seeds = {
        "shakespeare_char": 3,
        "enwik8": 1,
        "text8": 1,
    }

    out_dir = args.out_dir
    all_results = {}
    final_infos = {}
    for dataset in ["shakespeare_char", "enwik8", "text8"]:
        final_info_list = []
        for seed_offset in range(num_seeds[dataset]):
            final_info, train_info, val_info = train(dataset, out_dir, seed_offset)
            all_results[f"{dataset}_{seed_offset}_final_info"] = final_info
            all_results[f"{dataset}_{seed_offset}_train_info"] = train_info
            all_results[f"{dataset}_{seed_offset}_val_info"] = val_info
            final_info_list.append(final_info)

        # Transpose the per-seed dicts into {metric: [value per seed]}.
        final_info_dict = {k: [d[k] for d in final_info_list] for k in final_info_list[0].keys()}
        means = {f"{k}_mean": np.mean(v) for k, v in final_info_dict.items()}
        # Standard error of the mean is std / sqrt(n). The previous version divided by n.
        stderrs = {
            f"{k}_stderr": np.std(v) / np.sqrt(len(v)) for k, v in final_info_dict.items()
        }
        final_infos[dataset] = {
            "means": means,
            "stderrs": stderrs,
            "final_info_dict": final_info_dict,
        }

    with open(os.path.join(out_dir, "final_info.json"), "w") as f:
        json.dump(final_infos, f)

    with open(os.path.join(out_dir, "all_results.npy"), "wb") as f:
        np.save(f, all_results)
|
|