Spaces:
Sleeping
Sleeping
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, visit | |
# https://nvlabs.github.io/stylegan2/license.html | |
"""Helpers for managing the run/training loop.""" | |
import datetime | |
import json | |
import os | |
import pprint | |
import time | |
import types | |
from typing import Any | |
from . import submit | |
# Singleton RunContext: the one live context, or None when no run is active.
_run_context = None


class RunContext(object):
    """Helper class for managing the run/training loop.

    The context hides the implementation details of a basic run/training loop.
    It sets things up, tells whether the run should be stopped, and cleans up
    afterwards. Callers should invoke update() periodically and poll
    should_stop() to decide when to end the run.

    Only one RunContext may be alive at a time. Creating one writes a
    "run.txt" file describing the run into submit_config.run_dir; closing it
    rewrites that file with the stop time added. The class is also a context
    manager whose __exit__ calls close().

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: (deprecated) The whole config module that is used for the current run.
    """

    def __init__(self, submit_config: "submit.SubmitConfig", config_module: types.ModuleType = None):
        global _run_context

        # Only a single RunContext can be alive
        assert _run_context is None, "another RunContext is still alive"

        self.submit_config = submit_config
        self.should_stop_flag = False
        self.has_closed = False
        # Take one timestamp so start_time and last_update_time agree exactly.
        now = time.time()
        self.start_time = now
        self.last_update_time = now
        self.last_update_interval = 0.0
        self.progress_monitor_file_path = None

        # vestigial config_module support just prints a warning
        if config_module is not None:
            print("RunContext.config_module parameter support has been removed.")

        # write out details about the run to a text file
        self.run_txt_data = {
            "task_name": submit_config.task_name,
            "host_name": submit_config.host_name,
            "start_time": datetime.datetime.now().isoformat(sep=" "),
        }
        self._write_run_txt()

        # Register as the singleton only once setup has fully succeeded. If
        # writing run.txt raised, no half-initialized instance is left
        # registered to block every later RunContext from being created.
        _run_context = self

    def __enter__(self) -> "RunContext":
        """Return the context itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Close the context; any in-flight exception is not suppressed."""
        self.close()

    def _write_run_txt(self) -> None:
        """Write run_txt_data to run.txt in the run directory, replacing its contents."""
        with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.

        Should be called often enough but not in a tight loop: every call
        checks the file system for an "abort.txt" file in the run directory
        and latches the stop flag if it exists.

        Args:
            loss: Accepted for interface compatibility; not used here.
            cur_epoch: Accepted for interface compatibility; not used here.
            max_epoch: Accepted for interface compatibility; not used here.
        """
        assert not self.has_closed, "update() called on a closed RunContext"

        # Use a single timestamp so no time is lost between measuring the
        # interval and recording the new update time.
        now = time.time()
        self.last_update_interval = now - self.last_update_time
        self.last_update_time = now

        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time (in seconds) has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time (in seconds) has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time (in seconds) passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.

        Adds the stop time to run.txt, marks the context closed and detaches
        it from the module-level singleton. Calling it again is a no-op.

        Raises:
            OSError: If run.txt cannot be written. The context is still
                marked closed and the singleton is still released.
        """
        global _run_context

        if self.has_closed:
            return

        try:
            # update the run.txt with stopping time
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            self._write_run_txt()
        finally:
            # Release the context even when the write fails; otherwise a
            # stale singleton would prevent any new RunContext from starting.
            self.has_closed = True

            # detach the global singleton
            if _run_context is self:
                _run_context = None
def get():
    """Return the live RunContext, creating one if none exists.

    When no context is currently registered, a new RunContext is built from
    ``dnnlib.submit_config`` and returned.
    """
    # NOTE(review): presumably imported here rather than at module top to
    # avoid a circular import with the dnnlib package -- confirm.
    import dnnlib

    context = _run_context
    if context is None:
        context = RunContext(dnnlib.submit_config)
    return context