File size: 373 Bytes
9ff79dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""
Utility functions for interpretability.
"""

import torch


def get_torch_device() -> str:
    """
    Return the name of the preferred device for torch tensors.

    Devices are tried in this order: the first CUDA GPU (``"cuda:0"``),
    Apple's Metal Performance Shaders backend (``"mps"``), then ``"cpu"``.
    Only the device is chosen here; no dtype is selected.

    Returns:
        A device string suitable for ``torch.device(...)`` or ``tensor.to(...)``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    # ``torch.backends.mps`` does not exist on older torch releases; look it up
    # defensively so those installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():  # Apple Silicon
        return "mps"
    return "cpu"