import time

import torch
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
                          OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          XLNetLMHeadModel, XLNetTokenizer,
                          TransfoXLLMHeadModel, TransfoXLTokenizer,
                          CTRLLMHeadModel, CTRLTokenizer)

# Per-model metadata. "size" is the approximate memory footprint in MB,
# used to pack models onto the available GPUs.
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small",
    },
    "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-community/openai-gpt",
        "identifier": "gpt",
    },
    "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet",
    },
    "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp",
    },
    "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "gpt2/medium",
    },
    "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "gpt2/large",
    },
    "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small",
    },
    "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "Salesforce/ctrl",
        "identifier": "ctrl",
    },
    "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "openai-community/gpt2-xl",
        "identifier": "gpt2/xl",
    },
    "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            "options": {"output_hidden_states": True},
        },
    },
}

# Fixed per-GPU memory overhead (in MB) reserved for the CUDA context and buffers.
memory_overhead = 500


class GPU:
    def __init__(self, gpu_id):
        self.id = gpu_id
        self.models = []
        # Total device memory in MB, minus a 1 GB safety margin.
        self.total_memory = (
            torch.cuda.get_device_properties("cuda:{}".format(gpu_id)).total_memory / 1_000_000
            - 1_000
        )
        print("INIT GPU WITH DEVICE", "cuda:{}".format(gpu_id))

    def register_model(self, model, cached_path=None):
        # Accept the model only if it fits within the remaining memory budget.
        if self.total_memory_used() + model["size"] < self.total_memory:
            model["device"] = "cuda:{}".format(self.id)
            if cached_path:
                model["cached_path"] = cached_path
            self.models.append(model)
            return True
        return False

    def total_memory_used(self):
        return sum(model["size"] for model in self.models) + memory_overhead

    def __repr__(self):
        return str(
            [(model["checkpoint"], model["size"]) for model in self.models]
            + [str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"]
            + ["cuda:{}".format(self.id)]
        )


class GPUHandler:
    def __init__(self, ids, model_list, gpu_ids, cached_models=None):
        if cached_models is None:
            cached_models = {}

        self.gpus = [GPU(gpu_id) for gpu_id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))

        # Verify that every requested model can be placed before registering
        # anything on the real GPU objects.
        self.sanity_check([model_metadata[model] for model in model_list])
        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        # Greedily place the model on the first GPU with enough free memory.
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model, "in GPU", gpu)
                break
        else:
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        # Simulate the placement on throwaway GPU objects so a failure does
        # not leave the real GPUs partially populated.
        temp_gpus = [GPU(gpu.id) for gpu in self.gpus]

        for model in model_list:
            current_gpu_index = 0
            while current_gpu_index < len(temp_gpus):
                if not temp_gpus[current_gpu_index].register_model(model):
                    current_gpu_index += 1
                else:
                    break

            if current_gpu_index >= len(temp_gpus):
                raise RuntimeError("SANITY CHECK FAILED")

        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"
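

# Illustrative usage sketch (an assumption, not part of the original module):
# it presumes at least one visible CUDA device with enough memory for the
# chosen models. `ids` is passed as None because the handler never reads it;
# only `gpu_ids` drives device selection.
if __name__ == "__main__":
    handler = GPUHandler(
        ids=None,
        model_list=["gpt2/small", "distilgpt2/small"],
        gpu_ids=[0],
    )
    print(handler)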