PoTaTo721's picture
Update to V1.4
28c720a
raw
history blame
287 Bytes
from contextlib import nullcontext
import torch
def autocast_exclude_mps(
device_type: str, dtype: torch.dtype
) -> nullcontext | torch.autocast:
return (
nullcontext()
if torch.backends.mps.is_available()
else torch.autocast(device_type, dtype)
)