import time
from collections import deque
from contextlib import nullcontext
from typing import Any, Callable, Deque, Dict, Optional

import torch
from lightning import Callback, Fabric, LightningModule, Trainer
from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins import (
    BitsandbytesPrecision,
    DoublePrecision,
    FSDPPrecision,
    HalfPrecision,
    MixedPrecision,
    Precision,
    TransformerEnginePrecision,
    XLAPrecision,
)
from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only
from lightning.pytorch.plugins import (
    DoublePrecisionPlugin,
    FSDPPrecisionPlugin,
    HalfPrecisionPlugin,
    MixedPrecisionPlugin,
    XLAPrecisionPlugin,
)
from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only
from torch.utils.flop_counter import FlopCounterMode

from tsai_gpt import GPT
from tsai_gpt.utils import num_parameters

# Peak FLOP/s per GPU and dtype, from vendor datasheets. Tensor-core figures that
# NVIDIA publishes with a 2x sparsity factor are halved here to reflect dense compute.
GPU_AVAILABLE_FLOPS = {
    "h100-sxm": {
        torch.float64: 67e12,
        torch.float32: 67e12,
        torch.bfloat16: 1.979e15 / 2,
        torch.float16: 1.979e15 / 2,
        torch.int8: 3.958e15 / 2,
    },
    "h100-pcie": {
        torch.float64: 51e12,
        torch.float32: 51e12,
        torch.bfloat16: 1.513e15 / 2,
        torch.float16: 1.513e15 / 2,
        torch.int8: 3.026e15 / 2,
    },
    "a100": {torch.float64: 19.5e12, torch.float32: 19.5e12, torch.bfloat16: 312e12, torch.float16: 312e12},
    "a10g": {torch.float32: 31.2e12, torch.bfloat16: 125e12, torch.float16: 125e12},
    "v100-sxm": {torch.float64: 7.8e12, torch.float32: 15.7e12, torch.float16: 125e12},
    "v100-pcie": {torch.float64: 7e12, torch.float32: 14e12, torch.float16: 112e12},
    "v100s-pcie": {torch.float64: 8.2e12, torch.float32: 16.4e12, torch.float16: 130e12},
    "t4": {torch.float32: 8.1e12, torch.float16: 65e12, torch.int8: 130e12},
    "quadro rtx 5000": {torch.float32: 11.2e12, torch.float16: 89.2e12},
}

# Peak FLOP/s per TPU generation. A single figure per generation suffices since TPUs
# perform matrix operations in bfloat16.
TPU_AVAILABLE_FLOPS = {
    "v2": 45e12,
    "v3": 123e12,
    "v4": 275e12,
    "v5litepod": 197e12,
}

def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]:
    if device.type == "cuda":
        device_name = torch.cuda.get_device_name(device).lower()
        if "h100" in device_name and "hbm3" in device_name:
            device_name = "h100-sxm"
        elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
            device_name = "h100-pcie"
        elif "a100" in device_name:
            device_name = "a100"
        elif "a10g" in device_name:
            device_name = "a10g"
        elif "v100-sxm" in device_name:
            device_name = "v100-sxm"
        elif "v100s-pcie" in device_name:
            device_name = "v100s-pcie"
        elif "v100-pcie" in device_name:
            device_name = "v100-pcie"
        elif "t4" in device_name:
            device_name = "t4"
        elif "quadro rtx 5000" in device_name:
            device_name = "quadro rtx 5000"
        else:
            device_name = None

        if device_name is not None:
            try:
                return int(GPU_AVAILABLE_FLOPS[device_name][dtype])
            except KeyError:
                raise KeyError(
                    f"flop count not found for {device_name} with dtype: {dtype}; "
                    "MFU cannot be calculated and reported."
                )
    elif device.type == "xla":
        if _XLA_GREATER_EQUAL_2_1:
            from torch_xla._internal import tpu
        else:
            from torch_xla.experimental import tpu

        device_name = tpu.get_tpu_env()["TYPE"].lower()
        try:
            return int(TPU_AVAILABLE_FLOPS[device_name])
        except KeyError:
            raise KeyError(
                f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported."
            )

    return None
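
# Usage sketch: querying the peak throughput for the current accelerator. On an
# unrecognized device this returns None and MFU is simply not reported.
#
#     peak_flops = get_flops_available(torch.device("cuda"), torch.bfloat16)
#     if peak_flops is not None:
#         print(f"peak throughput: {peak_flops / 1e12:.0f} TFLOP/s")
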
class SpeedMonitorBase:
    """Logs the training throughput and utilization.

    +-------------------------------------+-----------------------------------------------------------+
    | Key                                 | Logged data                                               |
    +=====================================+===========================================================+
    |                                     | Rolling average (over `window_size` most recent           |
    | `throughput/batches_per_sec`        | batches) of the number of batches processed per second    |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | Rolling average (over `window_size` most recent           |
    | `throughput/samples_per_sec`        | batches) of the number of samples processed per second    |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | Rolling average (over `window_size` most recent           |
    | `throughput/tokens_per_sec`         | batches) of the number of tokens processed per second.    |
    |                                     | This may include padding depending on the dataset         |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | Estimates flops by `flops_per_batch * batches_per_sec`    |
    | `throughput/flops_per_sec`          |                                                           |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size        |
    +-------------------------------------+-----------------------------------------------------------+
    | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size        |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | `throughput/tokens_per_sec` divided by world size. This   |
    | `throughput/device/tokens_per_sec`  | may include pad tokens depending on the dataset           |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | `throughput/flops_per_sec` divided by world size. Only    |
    | `throughput/device/flops_per_sec`   | logged when model has attribute `flops_per_batch`         |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | `throughput/device/flops_per_sec` divided by the peak     |
    | `throughput/device/mfu`             | flops available to the device (model flops utilization)   |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    | `time/train`                        | Total elapsed training time                               |
    +-------------------------------------+-----------------------------------------------------------+
    | `time/val`                          | Total elapsed validation time                             |
    +-------------------------------------+-----------------------------------------------------------+
    | `time/total`                        | Total elapsed time (time/train + time/val)                |
    +-------------------------------------+-----------------------------------------------------------+

    Notes:
        - The implementation assumes that devices are homogeneous as it normalizes by the world size.
        - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using
          samples/sec or batches/sec to measure throughput under this circumstance.
        - Be careful when comparing MFU numbers across projects, as this will highly depend on the
          ``flops_per_batch``. There is no widespread, realistic, and reliable implementation to compute
          them. We suggest using our ``measure_flops`` function, but many other projects use estimated
          FLOPs (see ``estimate_flops``), which will almost always overestimate compared to the true value.

    Args:
        window_size (int, optional): Number of batches to use for a rolling average of throughput.
            Defaults to 100.
        time_unit (str, optional): Time unit to use for `time` logging. Can be one of
            'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
    """

    def __init__(
        self,
        flops_available: Optional[float],
        log_dict: Callable[[Dict, int], None],
        window_size: int = 100,
        time_unit: str = "hours",
    ):
        self.flops_available = flops_available
        self.log_dict = log_dict

        # track samples and wall-clock time over a window of batches to compute throughput
        self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
        self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
        self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
        self.history_flops: Deque[int] = deque(maxlen=window_size + 1)

        self.divider = 1
        if time_unit == "seconds":
            self.divider = 1
        elif time_unit == "minutes":
            self.divider = 60
        elif time_unit == "hours":
            self.divider = 60 * 60
        elif time_unit == "days":
            self.divider = 60 * 60 * 24
        else:
            raise ValueError(
                f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
            )

        # keep track of time spent evaluating
        self.total_eval_wct = 0.0
        self.step = -1

    def on_train_batch_end(
        self,
        samples: int,  # total samples seen (per device)
        train_elapsed: float,  # total training time (seconds)
        world_size: int,
        flops_per_batch: Optional[int] = None,  # (per device)
        lengths: Optional[int] = None,  # total length of the samples seen (per device)
    ) -> None:
        self.step += 1
        step = self.step
        metrics = {}

        self.history_samples.append(samples)
        if lengths is not None:
            self.history_lengths.append(lengths)
            # if lengths are passed, there should be as many values as samples
            assert len(self.history_samples) == len(self.history_lengths)
        self.history_wct.append(train_elapsed)
        if len(self.history_wct) == self.history_wct.maxlen:
            elapsed_batches = len(self.history_samples) - 1
            elapsed_samples = self.history_samples[-1] - self.history_samples[0]
            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
            samples_per_sec = elapsed_samples * world_size / elapsed_wct
            dev_samples_per_sec = elapsed_samples / elapsed_wct
            metrics.update(
                {
                    "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
                    "throughput/samples_per_sec": samples_per_sec,
                    "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
                    "throughput/device/samples_per_sec": dev_samples_per_sec,
                }
            )
            if lengths is not None:
                elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
                avg_length = elapsed_lengths / elapsed_batches
                metrics.update(
                    {
                        "throughput/tokens_per_sec": samples_per_sec * avg_length,
                        "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
                    }
                )

        if flops_per_batch is not None:
            # sum of flops per batch times batches
            self.history_flops.append(flops_per_batch * world_size)
        if len(self.history_flops) == self.history_flops.maxlen:
            elapsed_flops = sum(self.history_flops) - self.history_flops[0]
            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
            flops_per_sec = elapsed_flops / elapsed_wct
            device_flops_per_sec = flops_per_sec / world_size
            metrics.update(
                {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
            )
            if self.flops_available:
                metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available

        metrics.update(
            {
                "time/train": train_elapsed / self.divider,
                "time/val": self.total_eval_wct / self.divider,
                "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
                "samples": samples,
            }
        )

        self.log_dict(metrics, step)

    def eval_end(self, eval_elapsed: float) -> None:
        self.total_eval_wct += eval_elapsed  # seconds
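
# Usage sketch for driving the base monitor by hand (outside Fabric/Trainer). The
# names `my_log_dict`, `dataloader`, `batch_size`, `block_size`, and `flops` below
# are placeholders; any `log_dict(metrics, step)` callable works:
#
#     monitor = SpeedMonitorBase(flops_available=312e12, log_dict=my_log_dict)
#     train_t0 = time.perf_counter()
#     for i, batch in enumerate(dataloader):
#         ...  # forward, backward, optimizer step
#         monitor.on_train_batch_end(
#             samples=(i + 1) * batch_size,
#             train_elapsed=time.perf_counter() - train_t0,
#             world_size=1,
#             flops_per_batch=flops,
#             lengths=(i + 1) * batch_size * block_size,
#         )
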

def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype:
    if isinstance(plugin, BitsandbytesPrecision):
        return plugin.dtype
    if isinstance(plugin, (HalfPrecision, MixedPrecision, HalfPrecisionPlugin)):
        return plugin._desired_input_dtype
    if isinstance(plugin, MixedPrecisionPlugin):
        return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
    if isinstance(plugin, (DoublePrecision, DoublePrecisionPlugin)):
        return torch.double
    if isinstance(plugin, (XLAPrecision, XLAPrecisionPlugin)):
        return plugin._desired_dtype
    if isinstance(plugin, TransformerEnginePrecision):
        return torch.int8
    if isinstance(plugin, (FSDPPrecision, FSDPPrecisionPlugin)):
        return plugin.mixed_precision_config.reduce_dtype
    if isinstance(plugin, Precision):
        return torch.float32
    raise NotImplementedError(plugin)
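
# For example, with a bf16-mixed Fabric instance (a sketch; relies on the plugin
# being exposed as `fabric.strategy.precision`):
#
#     fabric = Fabric(precision="bf16-mixed")
#     plugin_to_compute_dtype(fabric.strategy.precision)  # -> torch.bfloat16
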

class SpeedMonitorFabric(SpeedMonitorBase):
    def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
        dtype = plugin_to_compute_dtype(fabric.strategy.precision)
        flops_available = get_flops_available(fabric.device, dtype)
        super().__init__(flops_available, fabric.log_dict, *args, **kwargs)

    @fabric_rank_zero_only
    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
        super().on_train_batch_end(*args, **kwargs)
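
# Usage sketch with Fabric (the loop body is elided; `CSVLogger` comes from
# `lightning.fabric.loggers` and is only needed so `fabric.log_dict` has
# somewhere to write):
#
#     fabric = Fabric(devices=1, precision="bf16-mixed", loggers=CSVLogger("logs"))
#     fabric.launch()
#     speed_monitor = SpeedMonitorFabric(fabric, window_size=50)
#     ...
#     speed_monitor.on_train_batch_end(samples, train_elapsed, fabric.world_size)
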

class SpeedMonitorCallback(Callback):
    def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
        super().__init__()
        self.speed_monitor: Optional[SpeedMonitorBase] = None
        self.speed_monitor_kwargs = kwargs
        self.length_fn = length_fn
        self.batch_size = batch_size
        self.eval_t0: float = 0.0
        self.train_t0: float = 0.0
        self.total_lengths: int = 0

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if self.speed_monitor is not None:
            return  # already set up
        dtype = plugin_to_compute_dtype(trainer.precision_plugin)
        flops_available = get_flops_available(trainer.strategy.root_device, dtype)
        self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)

    @trainer_rank_zero_only
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # start the timer unconditionally; skipping this when gradients are being
        # accumulated would leave `train_t0` at zero and corrupt all elapsed times
        self.train_t0 = time.perf_counter()

    @trainer_rank_zero_only
    def on_train_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        self.total_lengths += self.length_fn(batch)
        if trainer.fit_loop._should_accumulate():
            return
        train_elapsed = time.perf_counter() - self.train_t0
        assert self.speed_monitor is not None
        iter_num = trainer.fit_loop.total_batch_idx
        assert (measured_flops := pl_module.measured_flops) is not None
        self.speed_monitor.on_train_batch_end(
            (iter_num + 1) * self.batch_size,
            train_elapsed,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            trainer.world_size,
            flops_per_batch=measured_flops,
            lengths=self.total_lengths,
        )

    @trainer_rank_zero_only
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.eval_t0 = time.perf_counter()

    @trainer_rank_zero_only
    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        eval_elapsed = time.perf_counter() - self.eval_t0
        assert self.speed_monitor is not None
        self.speed_monitor.eval_end(eval_elapsed)
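
# Usage sketch with the Trainer. The `length_fn` below assumes each batch is a token
# tensor of shape (batch, seq_len); `pl_module` must expose a `measured_flops`
# attribute (see `measure_flops` below):
#
#     speed_cb = SpeedMonitorCallback(length_fn=lambda batch: batch.numel(), batch_size=8)
#     trainer = Trainer(callbacks=[speed_cb], ...)
#     trainer.fit(model, datamodule)
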

def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
    flops_per_token = 2 * n_params  # each parameter participates in one multiply-accumulate (2 FLOPs) per token
    # this assumes that all samples have a fixed length equal to the max sequence length,
    # which is most likely false during finetuning
    flops_per_seq = flops_per_token * max_seq_length
    # the attention scores and the attention-over-values matmuls each cost 2 * n_embd * seq_len^2 FLOPs per layer
    attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return flops_per_seq + attn_flops_per_seq
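
# Worked example with hypothetical GPT-2-small-like shapes (12 layers, n_embd=768,
# 124M parameters, 1024-token sequences):
#
#     flops_per_token = 2 * 124e6                        # ~2.5e8
#     flops_per_seq = flops_per_token * 1024             # ~2.5e11
#     attn_flops_per_seq = 12 * 2 * 2 * (768 * 1024**2)  # ~3.9e10
#     # total: ~2.9e11 FLOPs per sequence
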

def estimate_flops(model: GPT) -> int:
    """Measures estimated FLOPs for MFU.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # counting all parameters is a naive overestimation, but a reasonable approximation
    # for our purposes
    n_trainable_params = num_parameters(model, requires_grad=True)
    trainable_flops = flops_per_param(
        model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
    )
    # forward + backward + gradients (assumes no gradient accumulation)
    ops_per_step = 3 if model.training else 1
    n_frozen_params = num_parameters(model, requires_grad=False)
    frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
    # forward + backward only, since no weight gradients are computed for frozen parameters
    frozen_ops_per_step = 2 if model.training else 1
    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops

def measure_flops(model: GPT, x: torch.Tensor) -> int:
    """Measures real FLOPs for HFU (hardware FLOPs utilization)."""
    flop_counter = FlopCounterMode(model, display=False)
    ctx = nullcontext() if model.training else torch.no_grad()
    with ctx, flop_counter:
        y = model(x)
        if model.training:
            y.sum().backward()
    return flop_counter.get_total_flops()
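
# Usage sketch: measure FLOPs on a meta-device replica so that no real memory is
# allocated and no kernels are launched (`micro_batch_size` is illustrative):
#
#     with torch.device("meta"):
#         meta_model = GPT(model.config)
#         x = torch.randint(0, 1, (micro_batch_size, meta_model.max_seq_length))
#         measured_flops = measure_flops(meta_model, x)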