HUANG-Stephanie's picture
Upload 88 files
9ff79dc verified
raw
history blame
373 Bytes
"""
Utility functions for interpretability.
"""
import torch
def get_torch_device() -> str:
    """Return the name of the preferred torch device available on this machine.

    The backends are probed in order of preference: the first CUDA GPU,
    then the MPS backend (Apple Silicon), and finally the CPU as a fallback.

    Returns:
        One of ``"cuda:0"``, ``"mps"`` or ``"cpu"``. Only the device string
        is returned; no dtype is selected here.
    """
    if torch.cuda.is_available():
        return "cuda:0"

    # Look the MPS backend up defensively: torch builds that predate it do
    # not expose ``torch.backends.mps``, and a plain attribute access would
    # raise AttributeError instead of falling back to the CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():  # for Apple Silicon
        return "mps"

    return "cpu"