"""
Nanotron Inference Script

Usage:

```
export CUDA_DEVICE_MAX_CONNECTIONS=1  # important for some distributed operations
torchrun --nproc_per_node=8 run_evals.py --checkpoint-config-path ./pretrained/Mistral-7B-v0.1/config.yaml \
    --lighteval-override ./lighteval_eval_config.yaml
```
"""
import argparse
import os
import random
import time
from dataclasses import asdict
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import HFSummaryWriter
from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
from lighteval.logging.info_loggers import (
    DetailsLogger,
)
from lighteval.models.model_loader import ModelInfo
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import get_config_from_file
from nanotron.logging import get_logger, log_rank
from nanotron.parallel.context import ParallelContext
from nanotron.utils import local_ranks_zero_first

from brrr.config import BrrrConfig
from brrr.experiment_loggers import flatten_dict, obj_to_markdown
from brrr.s3_checkpoints import fs_copy
from brrr.utils import check_env

from lighteval.models.brrr_models import BRRRModel

from modeling_mistral import MistralForTraining
from config_mistral import MistralConfig

logger = get_logger(__name__)

TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME", "/scratch")


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint-config-path",
        type=str,
        required=True,
        help="Path to the brrr checkpoint YAML or python config file, potentially on S3",
    )
    parser.add_argument(
        "--lighteval-override",
        type=str,
        help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)",
    )
    parser.add_argument(
        "--s5cmd-path",
        type=str,
        default="/admin/home/thomwolf/miniconda3/envs/b4r/bin/s5cmd",
        help="Path to s5cmd install",
    )
    parser.add_argument(
        "--s5cmd-numworkers",
        type=int,
        default=64,
        help="s5cmd num workers (optional)",
    )
    parser.add_argument(
        "--s5cmd-concurrency",
        type=int,
        default=10,
        help="s5cmd concurrency (optional)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default="",
        help="Cache directory",
    )

    return parser


def push_results_to_wandb(
    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
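    """Push aggregated evaluation results to Weights & Biases.

    Logs per-task metrics, per-suite averages, and (for tasks with at least two compiled
    details) a small table of per-example details, all at `config.general.step`.
    """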
    lighteval_config = config.lighteval
    try:
        global_step = config.general.step
    except ValueError:
        global_step = 0
    if config.lighteval.logging.tensorboard_metric_prefix is not None:
        prefix = config.lighteval.logging.tensorboard_metric_prefix
    else:
        prefix = "eval"
    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
    output_dir_tb.mkdir(parents=True, exist_ok=True)

    os.environ["WANDB_DISABLE_SERVICE"] = "True"
    import wandb

    wandb.tensorboard.patch(root_logdir=config.lighteval.logging.local_output_path)
    hlog("Starting wandb with WANDB_DISABLE_SERVICE=True")
    wandb.init(
        project=config.lighteval.wandb.wandb_project,
        entity=config.lighteval.wandb.wandb_entity,
        name=config.lighteval.wandb.wandb_run_name,
        config=config.as_dict(),
        resume=True,
    )
    wb_dict = {}
    bench_averages = {}
    for name, values in results.items():
        split_name = name.split("|")
        if len(split_name) == 3:
            _, task_name, _ = split_name
        else:
            task_name = name
        bench_suite = None
        if ":" in task_name:
            # Tasks named "<suite>:<subtask>" also contribute to a per-suite average
            bench_suite = task_name.split(":")[0]
            hlog(f"bench_suite {bench_suite} in {task_name}")
            for metric, value in values.items():
                if "stderr" in metric:
                    continue
                if bench_suite not in bench_averages:
                    bench_averages[bench_suite] = {}
                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
        hlog(f"Pushing {task_name} {values} to wandb")
        for metric, value in values.items():
            if "stderr" in metric:
                wb_dict[f"stderr_{metric}/{task_name}"] = value
            elif bench_suite is not None:
                wb_dict[f"{bench_suite}-{metric}/{task_name}"] = value
            else:
                wb_dict[f"{metric}/{task_name}"] = value

    for name, values in bench_averages.items():
        for metric, metric_values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(metric_values) / len(metric_values)} to wandb")
            wb_dict[f"{metric}/{name}"] = sum(metric_values) / len(metric_values)

    for task_name, task_details in details.items():
        if len(task_details) <= 1:
            continue
        # Log the first two compiled details of each task as a wandb table
        columns = list(flatten_dict(asdict(task_details[0])).keys())
        table = wandb.Table(columns=columns)
        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[0])).values()])
        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[1])).values()])
        wandb.log({f"eval_details_{task_name}": table}, step=global_step, commit=False)

    wandb.log(dict(wb_dict.items()), step=global_step, commit=True)

    hlog(f"Pushed to wandb at {output_dir_tb} and global_step {global_step}")


def push_results_to_tensorboard(
    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
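    """Push aggregated evaluation results to a tensorboard run synced to the Hugging Face Hub.

    Writes per-task scalars, per-suite averages, and markdown summaries via `HFSummaryWriter`,
    then renames the event files with the global step and triggers the upload.
    """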
    lighteval_config = config.lighteval
    try:
        global_step = config.general.step
    except ValueError:
        global_step = 0
    if config.lighteval.logging.tensorboard_metric_prefix is not None:
        prefix = config.lighteval.logging.tensorboard_metric_prefix
    else:
        prefix = "eval"
    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
    output_dir_tb.mkdir(parents=True, exist_ok=True)
    tb_context = HFSummaryWriter(
        logdir=str(output_dir_tb),
        repo_id=lighteval_config.logging.hub_repo_tensorboard,
        repo_private=True,
        path_in_repo="tb",
        commit_every=6000,  # large interval: we trigger the commit ourselves at the end
    )
    bench_averages = {}
    for name, values in results.items():
        split_name = name.split("|")
        if len(split_name) == 3:
            _, task_name, _ = split_name
        else:
            task_name = name
        bench_suite = None
        if ":" in task_name:
            # Tasks named "<suite>:<subtask>" also contribute to a per-suite average
            bench_suite = task_name.split(":")[0]
            hlog(f"bench_suite {bench_suite} in {task_name}")
            for metric, value in values.items():
                if "stderr" in metric:
                    continue
                if bench_suite not in bench_averages:
                    bench_averages[bench_suite] = {}
                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
        hlog(f"Pushing {task_name} {values} to tensorboard")
        for metric, value in values.items():
            if "stderr" in metric:
                tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step)
            elif bench_suite is not None:
                tb_context.add_scalar(f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step)
            else:
                tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step)

    for name, values in bench_averages.items():
        for metric, metric_values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(metric_values) / len(metric_values)} to tensorboard")
            tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(metric_values) / len(metric_values), global_step=global_step)

    tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step)

    for task_name, task_details in details.items():
        tb_context.add_text(
            f"eval_details_{task_name}",
            obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}),
            global_step=global_step,
        )

    # Flush pending writes, then prefix the event files with the zero-padded global step so that
    # files from different checkpoints sort in step order.
    tb_context.close()
    time.sleep(5)
    files = os.listdir(output_dir_tb)
    for file in files:
        os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}"))

    # Push the renamed event files to the hub
    tb_context.scheduler.trigger()
    hlog(
        f"Pushed to tensorboard at https://huggingface.co/tensorboard/{lighteval_config.logging.hub_repo_tensorboard}/"
        f" at {output_dir_tb} and global_step {global_step}"
    )


@htrack()
def main(args):
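    """Load the checkpoint config, build the BRRR model, run the configured lighteval tasks,
    and (on rank 0) compile, save, and push the results."""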
    cache_dir = args.cache_dir or CACHE_DIR
    check_env()

    dist.initialize_torch_distributed()

    with htrack_block("get config"):
        if not args.checkpoint_config_path.endswith(".yaml"):
            raise ValueError("The checkpoint path should point to a YAML file")
        local_config_path = args.checkpoint_config_path
        if args.checkpoint_config_path.startswith("s3:/"):
            # Mirror the S3 config locally; only local rank 0 downloads, the other ranks wait on it
            local_config_path = args.checkpoint_config_path.replace("s3:/", cache_dir)
            with local_ranks_zero_first():
                if os.environ.get("LOCAL_RANK", None) == "0":
                    os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
                    fs_copy(args.checkpoint_config_path, local_config_path)

        brrr_config: BrrrConfig = get_config_from_file(
            local_config_path, config_class=BrrrConfig, model_config_class=MistralConfig
        )

        if args.lighteval_override:
            local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
            if args.lighteval_override.startswith("s3:/"):
                with local_ranks_zero_first():
                    if os.environ.get("LOCAL_RANK", None) == "0":
                        os.makedirs(os.path.dirname(local_override_path), exist_ok=True)
                        fs_copy(args.lighteval_override, local_override_path)
            lighteval_brrr_config: BrrrConfig = get_config_from_file(local_override_path, config_class=BrrrConfig)
            lighteval_config = lighteval_brrr_config.lighteval
            brrr_config.lighteval = lighteval_config
        else:
            local_override_path = ""
            lighteval_config = brrr_config.lighteval

    parallel_context = ParallelContext(
        tensor_parallel_size=lighteval_config.parallelism.tp,
        pipeline_parallel_size=lighteval_config.parallelism.pp,
        data_parallel_size=lighteval_config.parallelism.dp,
    )

    evaluation_tracker = EvaluationTracker(token=TOKEN)
    evaluation_tracker.general_config_logger.log_args_info(
        num_fewshot_seeds=1,
        override_batch_size=None,
        max_samples=lighteval_config.tasks.max_samples,
        job_id=os.environ.get("SLURM_JOB_ID", None),
        config=brrr_config.as_dict(),
    )

    with htrack_block("Test all gather"):
        hlog("Test gather tensor")
        # Sanity-check NCCL collectives over the whole world group before loading the model
        log_rank(
            f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )
        test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
        test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
        dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
        dist.barrier()
        log_rank(
            f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )

        del test_tensor_list
        del test_tensor

    with htrack_block("Model loading"):
        # The checkpoint weights live next to the config file, so strip the trailing "config.yaml"
        model = BRRRModel(
            checkpoint_path=args.checkpoint_config_path.replace("config.yaml", ""),
            model_args=brrr_config.model,
            tokenizer=brrr_config.tokenizer,
            parallel_context=parallel_context,
            parallel_config=lighteval_config.parallelism,
            lighteval_config=lighteval_config,
            batch_size=lighteval_config.batch_size,
            cache_dir=os.environ.get("HF_HOME", "/scratch"),
            debug_one_layer_model=False,
            s5cmd_path=args.s5cmd_path,
            s5cmd_numworkers=args.s5cmd_numworkers,
            s5cmd_concurrency=args.s5cmd_concurrency,
            model_class=MistralForTraining,
        )
        model_info = ModelInfo(model_name=f"{brrr_config.general.run}/{brrr_config.general.step}")
        evaluation_tracker.general_config_logger.log_model_info(model_info)

    with htrack_block("Tasks loading"):
        with local_ranks_zero_first():
            tasks_selection = lighteval_config.tasks.tasks
            if lighteval_config.tasks.custom_tasks_file:
                _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
                if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
                    tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]

            task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
            task_dict = Registry(cache_dir=cache_dir).get_task_dict(
                task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
            )

            LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)

            evaluation_tracker.task_config_logger.log(task_dict)

            hlog("Loading documents, and requests")
            requests, docs = create_requests_from_tasks(
                task_dict=task_dict,
                fewshot_dict=few_shots_dict,
                num_fewshot_seeds=lighteval_config.tasks.num_fewshot_seeds or 1,
                lm=model,
                max_samples=lighteval_config.tasks.max_samples,
                evaluation_tracker=evaluation_tracker,
                use_chat_template=False,
            )

    with htrack_block("Setting seeds and waiting for all processes"):
        hlog(f"setting seed to {1234} for random and numpy")
        random.seed(1234)
        np.random.seed(1234)
        dist.barrier()

    with htrack_block("Evaluation"):
        hlog(f"Evaluate on {len(task_names_list)} tasks.")
        evaluation_tracker = evaluate(
            lm=model,
            requests_dict=requests,
            docs=docs,
            task_dict=task_dict,
            override_bs=lighteval_config.batch_size,
            evaluation_tracker=evaluation_tracker,
        )

    if dist.get_rank(parallel_context.world_pg) == 0:
        with htrack_block("Compiling and saving results"):
            evaluation_tracker.general_config_logger.log_end_time()
            evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
            evaluation_tracker.details_logger.aggregate()

            if lighteval_config.logging.local_output_path:
                evaluation_tracker.save(
                    output_dir=lighteval_config.logging.local_output_path,
                    push_results_to_hub=lighteval_config.logging.push_results_to_hub,
                    push_details_to_hub=lighteval_config.logging.push_details_to_hub,
                    public=False,
                    push_results_to_tensorboard=lighteval_config.logging.push_results_to_tensorboard,
                )

            if lighteval_config.logging.push_results_to_tensorboard:
                push_results_to_tensorboard(
                    config=brrr_config,
                    results=evaluation_tracker.metrics_logger.metric_aggregated,
                    details=evaluation_tracker.details_logger.details,
                )
            if lighteval_config.wandb is not None:
                push_results_to_wandb(
                    config=brrr_config,
                    results=evaluation_tracker.metrics_logger.metric_aggregated,
                    details=evaluation_tracker.details_logger.details,
                )

            final_dict = evaluation_tracker.generate_final_dict()

        hlog(make_results_table(final_dict))

        return final_dict


if __name__ == "__main__":
    parser = get_parser()
    args, unknowns = parser.parse_known_args()
    main(args)