Spaces:
Sleeping
Sleeping
File size: 373 Bytes
9ff79dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
"""
Utility functions for interpretability.
"""
import torch
def get_torch_device() -> str:
    """
    Return the name of the preferred device for torch tensors.

    The first available backend wins, in this order:

    1. ``"cuda:0"`` if CUDA is available (first CUDA device only).
    2. ``"mps"`` if the Apple Silicon Metal backend is available.
    3. ``"cpu"`` otherwise.

    Returns:
        A device string suitable for ``torch.device(...)`` or ``tensor.to(...)``.
        Only the device is returned; no dtype is selected here.
    """
    if torch.cuda.is_available():
        return "cuda:0"

    # Older PyTorch builds have no ``torch.backends.mps`` attribute at all;
    # look it up defensively so those builds fall back to CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():  # Apple Silicon
        return "mps"

    return "cpu"
|