import contextlib
import sys
import time

import torch
if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print time spent by CPU and GPU.

        Useful as a temporary context manager to find sweet spots of code
        suitable for async implementation.
        """
        # Without CUDA there is nothing to time on the GPU, so fall back
        # to a no-op context.
        if (not enabled) or not torch.cuda.is_available():
            yield
            return
        stream = stream if stream else torch.cuda.current_stream()
        end_stream = end_stream if end_stream else stream
        # CUDA events with timing enabled measure elapsed GPU time between
        # the points at which they are recorded on their streams.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        try:
            cpu_start = time.monotonic()
            yield
        finally:
            cpu_end = time.monotonic()
            end_stream.record_event(end)
            # Wait for the end event so elapsed_time() returns a valid value.
            end.synchronize()
            cpu_time = (cpu_end - cpu_start) * 1000
            gpu_time = start.elapsed_time(end)  # already in milliseconds
            msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
            msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
            print(msg, end_stream)
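
if __name__ == '__main__':
    # Minimal usage sketch (an illustration added here, not part of the
    # original utility): time a large matmul on the current CUDA stream.
    # Guarded the same way as the definition above; on a machine without
    # CUDA the demo is skipped, matching profile_time's no-op fallback.
    if sys.version_info >= (3, 7) and torch.cuda.is_available():
        a = torch.randn(1024, 1024, device='cuda')
        with profile_time('demo', 'matmul'):
            b = a @ a  # the work being timed on both CPU and GPU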