gomoku / DI-engine /ding /torch_utils /tests /test_ckpt_helper.py
zjowowen's picture
init space
079c32c
raw
history blame
5.46 kB
import os
import time
import pytest
import torch
import torch.nn as nn
import uuid
from ding.torch_utils.checkpoint_helper import auto_checkpoint, build_checkpoint_helper, CountVar
from ding.utils import read_file, save_file
class DstModel(nn.Module):
    """Destination-side toy network used as the target of checkpoint loading.

    Holds three independent linear layers and defines no ``forward``:
    ``fc1`` (3 -> 3), ``fc2`` (3 -> 8) and ``fc_dst`` (3 -> 6).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order; names fc1/fc2 also exist on SrcModel.
        self.fc1 = nn.Linear(in_features=3, out_features=3)
        self.fc2 = nn.Linear(in_features=3, out_features=8)
        self.fc_dst = nn.Linear(in_features=3, out_features=6)
class SrcModel(nn.Module):
    """Source-side toy network whose parameters are written into checkpoints.

    Holds three independent linear layers and defines no ``forward``:
    ``fc1`` (3 -> 3), ``fc2`` (3 -> 8) and ``fc_src`` (3 -> 7).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order; names fc1/fc2 also exist on DstModel.
        self.fc1 = nn.Linear(in_features=3, out_features=3)
        self.fc2 = nn.Linear(in_features=3, out_features=8)
        self.fc_src = nn.Linear(in_features=3, out_features=7)
class HasStateDict:
    """Stand-in for any object exposing ``state_dict`` / ``load_state_dict``.

    The "state" is a plain string made of ``name`` followed by a random UUID,
    so the origin of a state can be recognised from its prefix.
    """

    def __init__(self, name):
        self._name = name
        self._state_dict = self._fresh_state()

    def _fresh_state(self):
        """Build a new unique state string carrying this object's name as prefix."""
        return self._name + str(uuid.uuid4())

    def state_dict(self):
        """Return the current state and immediately replace it with a fresh one."""
        current, self._state_dict = self._state_dict, self._fresh_state()
        return current

    def load_state_dict(self, state_dict):
        """Store ``state_dict`` as the current state, unchanged."""
        self._state_dict = state_dict
@pytest.mark.unittest
class TestCkptHelper:
    """Tests for the helper object returned by ``build_checkpoint_helper``."""

    @staticmethod
    def _remove_paths(pattern):
        """Synchronously delete every file or directory matching the glob ``pattern``.

        Replaces the former ``os.popen('rm -rf ...')`` calls, which ran asynchronously
        in a shell, were not portable and left the pipe object unclosed.
        """
        # Local imports: keeps this block independent of the module-level import list.
        import glob
        import shutil

        for target in glob.glob(pattern):
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target, ignore_errors=True)
                continue
            try:
                os.remove(target)
            except FileNotFoundError:
                # Already gone (e.g. removed concurrently); nothing left to do.
                pass

    def test_load_model(self):
        """Save and load checkpoints through the helper and verify what gets restored.

        The checkpoint is written to ``model.pt`` in the current working directory.
        Any stale file is removed first, and everything matching ``model.pt*`` is
        removed afterwards, even when an assertion fails.
        """
        path = 'model.pt'
        # Removal is synchronous, so no sleep is needed before writing the file.
        self._remove_paths(path)
        try:
            self._check_save_and_load(path)
        finally:
            self._remove_paths(path + '*')

    def _check_save_and_load(self, path):
        """Run all save/load assertions against the checkpoint file at ``path``."""
        # --- Plain model checkpoint written with torch.save -------------------
        dst_model = DstModel()
        src_model = SrcModel()
        ckpt_state_dict = {'model': src_model.state_dict()}
        torch.save(ckpt_state_dict, path)

        ckpt_helper = build_checkpoint_helper({})
        # A strict load is expected to fail; presumably because SrcModel and
        # DstModel differ in their third layer (fc_src vs fc_dst).
        with pytest.raises(RuntimeError):
            ckpt_helper.load(path, dst_model, strict=True)

        # A non-strict load must succeed and copy the shared fc1 parameters.
        ckpt_helper.load(path, dst_model, strict=False)
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() < 1e-6
        assert torch.abs(dst_model.fc1.bias - src_model.fc1.bias).max() < 1e-6

        # --- Full checkpoint with auxiliary state ------------------------------
        # Fresh, independently initialised models: fc1 weights must differ again.
        dst_model = DstModel()
        src_model = SrcModel()
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6

        src_optimizer = HasStateDict('src_optimizer')
        dst_optimizer = HasStateDict('dst_optimizer')
        src_last_epoch = CountVar(11)
        dst_last_epoch = CountVar(5)
        src_last_iter = CountVar(110)
        dst_last_iter = CountVar(50)
        src_dataset = HasStateDict('src_dataset')
        dst_dataset = HasStateDict('dst_dataset')
        src_collector_info = HasStateDict('src_collect_info')
        dst_collector_info = HasStateDict('dst_collect_info')

        ckpt_helper.save(
            path,
            src_model,
            optimizer=src_optimizer,
            dataset=src_dataset,
            collector_info=src_collector_info,
            last_iter=src_last_iter,
            last_epoch=src_last_epoch,
            prefix_op='remove',
            prefix="f"
        )
        ckpt_helper.load(
            path,
            dst_model,
            dataset=dst_dataset,
            optimizer=dst_optimizer,
            last_iter=dst_last_iter,
            last_epoch=dst_last_epoch,
            collector_info=dst_collector_info,
            strict=False,
            state_dict_mask=['fc1'],
            prefix_op='add',
            prefix="f"
        )

        # The destination objects must now carry the state saved from the source ones.
        assert dst_dataset.state_dict().startswith('src')
        assert dst_optimizer.state_dict().startswith('src')
        assert dst_collector_info.state_dict().startswith('src')
        assert dst_last_iter.val == 110
        for param_name, _ in dst_model.named_parameters():
            assert param_name.startswith('fc')
        print('==dst', dst_model.fc2.weight)
        print('==src', src_model.fc2.weight)
        # fc2 is expected to be loaded while fc1 stays different; presumably
        # state_dict_mask=['fc1'] excludes fc1 from loading -- confirm in the helper.
        assert torch.abs(dst_model.fc2.weight - src_model.fc2.weight).max() < 1e-6
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6

        # --- Checkpoint with some entries missing ------------------------------
        checkpoint = read_file(path)
        checkpoint.pop('dataset')
        checkpoint.pop('optimizer')
        checkpoint.pop('last_iter')
        save_file(path, checkpoint)
        # Loading must not raise although these three entries were removed.
        ckpt_helper.load(
            path,
            dst_model,
            dataset=dst_dataset,
            optimizer=dst_optimizer,
            last_iter=dst_last_iter,
            last_epoch=dst_last_epoch,
            collector_info=dst_collector_info,
            strict=True,
            state_dict_mask=['fc1'],
            prefix_op='add',
            prefix="f"
        )

        # --- Error paths --------------------------------------------------------
        # NOTE(review): the keyword is spelled 'lr_schduler' here; presumably this
        # matches the helper's own parameter name -- do not "fix" it without checking.
        with pytest.raises(NotImplementedError):
            ckpt_helper.load(
                path,
                dst_model,
                strict=False,
                lr_schduler='lr_scheduler',
                last_iter=dst_last_iter,
            )

        # NOTE(review): both calls share one pytest.raises block, so load() is never
        # reached once save() raises KeyError. Consider splitting into two blocks
        # after confirming that load() also raises KeyError for an invalid prefix_op.
        with pytest.raises(KeyError):
            ckpt_helper.save(path, src_model, prefix_op='key_error', prefix="f")
            ckpt_helper.load(path, dst_model, strict=False, prefix_op='key_error', prefix="f")
@pytest.mark.unittest
def test_count_var():
    """Check ``CountVar``: ``add(5)`` on a zero counter yields 5, then ``update(3)`` sets it to 3."""
    counter = CountVar(0)

    counter.add(5)
    assert counter.val == 5

    counter.update(3)
    assert counter.val == 3
@pytest.mark.unittest
def test_auto_checkpoint():
    """Call a method wrapped by ``auto_checkpoint`` that raises part-way through.

    The test contains no assertions: it passes as long as ``start()`` returns
    without propagating the exception. Presumably the decorator intercepts the
    exception and calls ``save_checkpoint`` on the instance -- confirm against
    the ``auto_checkpoint`` implementation.
    """

    class AutoCkptCls:
        """Minimal object offering the ``save_checkpoint`` method next to the decorated one."""

        @auto_checkpoint
        def start(self):
            """Sleep 0.2 s for each of the first five iterations, then raise on the sixth."""
            for i in range(10):
                if i < 5:
                    time.sleep(0.2)
                else:
                    # The loop ends here; the former ``break`` after this raise was unreachable.
                    raise Exception("There is an exception")

        def save_checkpoint(self, ckpt_path):
            """Report the checkpoint path on stdout; nothing is written to disk."""
            print('Checkpoint is saved successfully in {}!'.format(ckpt_path))

    auto_ckpt = AutoCkptCls()
    auto_ckpt.start()
if __name__ == '__main__':
    # Running the file directly executes only the checkpoint save/load test, without pytest.
    TestCkptHelper().test_load_model()