# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional


class ClusterType(Enum):
    """Compute clusters this module knows how to configure."""

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"


def _guess_cluster_type() -> ClusterType:
    """Infer the current cluster from the host's uname information.

    Returns ``ClusterType.AWS`` for Linux kernels whose release ends with
    ``-aws``, ``ClusterType.RSC`` for Linux hosts whose name starts with
    ``rsc``, and ``ClusterType.FAIR`` for everything else (including
    platforms where ``os.uname`` is unavailable).
    """
    uname_fn = getattr(os, "uname", None)
    if uname_fn is None:
        # os.uname only exists on Unix; use the same default as any other
        # unrecognised host instead of raising AttributeError.
        return ClusterType.FAIR
    uname = uname_fn()
    if uname.sysname == "Linux":
        if uname.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        elif uname.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC
    return ClusterType.FAIR


def get_cluster_type(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[ClusterType]:
    """Return ``cluster_type`` when given, otherwise guess it from the host."""
    if cluster_type is None:
        return _guess_cluster_type()
    return cluster_type


# Checkpoint directory (relative to the filesystem root) for each cluster.
_CHECKPOINT_DIRNAMES: Dict[ClusterType, str] = {
    ClusterType.AWS: "checkpoints",
    ClusterType.FAIR: "checkpoint",
    ClusterType.RSC: "checkpoint/dino",
}

# Default Slurm partition for each cluster.
_SLURM_PARTITIONS: Dict[ClusterType, str] = {
    ClusterType.AWS: "learnlab",
    ClusterType.FAIR: "learnlab",
    ClusterType.RSC: "learn",
}


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for ``cluster_type``.

    Args:
        cluster_type: target cluster; guessed from the host when ``None``.

    Returns:
        An absolute path such as ``/checkpoint``, or ``None`` if the cluster
        type could not be resolved.
    """
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        # The current guesser always returns a value; this keeps the
        # Optional contract of get_cluster_type honoured.
        return None
    return Path("/") / _CHECKPOINT_DIRNAMES[cluster_type]


def get_user_checkpoint_path(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[Path]:
    """Return the current user's subdirectory of the cluster checkpoint directory.

    The user name is taken from the ``USER`` environment variable.

    Args:
        cluster_type: target cluster; guessed from the host when ``None``.

    Returns:
        ``<checkpoint root>/<USER>``, or ``None`` if no checkpoint root exists.

    Raises:
        RuntimeError: if ``USER`` is unset or empty.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None
    username = os.environ.get("USER")
    if not username:
        # Explicit check rather than ``assert`` so it still fires under
        # ``python -O``. It also stops an empty USER from silently
        # resolving to the shared checkpoint root.
        raise RuntimeError(
            "Environment variable USER is not set; cannot build the user checkpoint path"
        )
    return checkpoint_path / username


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the default Slurm partition name for ``cluster_type``.

    Args:
        cluster_type: target cluster; guessed from the host when ``None``.

    Returns:
        The partition name, or ``None`` if the cluster type could not be
        resolved.
    """
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None
    return _SLURM_PARTITIONS[cluster_type]


def get_slurm_executor_parameters(
    nodes: int,
    num_gpus_per_node: int,
    cluster_type: Optional[ClusterType] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Build a dict of Slurm job parameters for a multi-node GPU job.

    Args:
        nodes: number of nodes to request.
        num_gpus_per_node: GPUs per node; one task is launched per GPU.
        cluster_type: target cluster; guessed from the host when ``None``.
        **kwargs: extra parameters, applied last so they override defaults.

    Returns:
        A new dict of parameters.
    """
    # Resolve (or guess) the cluster once, so the partition and the
    # per-cluster adjustments below are based on the same answer.
    cluster_type = get_cluster_type(cluster_type)

    # create default parameters
    params: Dict[str, Any] = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }

    # apply cluster-specific adjustments
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        # No memory request is sent on AWS.
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12

    # set additional parameters / apply overrides
    params.update(kwargs)
    return params