Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,737 Bytes
8866a87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Optional, Union
import torch
Device = Union[str, torch.device]


def make_device(device: Device) -> torch.device:
    """
    Normalizes a device specification into a concrete torch.device.

    A string is converted with torch.device(...); a torch.device is used
    as given. When the result is a `cuda` device that carries no index,
    it is replaced by the device with the index reported by
    torch.cuda.current_device().

    Args:
        device: Device (as str or torch.device)

    Returns:
        A matching torch.device object
    """
    if isinstance(device, str):
        resolved = torch.device(device)
    else:
        resolved = device

    # Only an index-less cuda device needs further resolution.
    lacks_cuda_index = resolved.type == "cuda" and resolved.index is None
    if not lacks_cuda_index:
        return resolved

    # Bare "cuda" means "the current cuda device"; pin it to that index.
    current_index = torch.cuda.current_device()
    return torch.device(f"cuda:{current_index}")
def get_device(x, device: Optional[Device] = None) -> torch.device:
    """
    Determines which torch.device to use for the variable x.

    If an explicit device is given it takes precedence and is converted
    with make_device. Otherwise the device of x is returned when x is a
    tensor, and the CPU device is returned for any other kind of x.

    Args:
        x: a torch.Tensor to get the device from or another type
        device: optional Device (as str or torch.device) that overrides
            the device derived from x

    Returns:
        A matching torch.device object
    """
    if device is None:
        # No override: a tensor reports its own device, anything else is CPU.
        return x.device if torch.is_tensor(x) else torch.device("cpu")

    # The caller's explicit choice wins over whatever x carries.
    return make_device(device)
|