import glob
import os
import time
import uuid

import pytest
import torch
import torch.nn as nn

from ding.torch_utils.checkpoint_helper import auto_checkpoint, build_checkpoint_helper, CountVar
from ding.utils import read_file, save_file

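
# Toy models for the load tests: DstModel and SrcModel share fc1/fc2 shapes,
# but each has a unique head (fc_dst vs. fc_src), so a strict load must fail
# while a non-strict load copies only the overlapping parameters.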
class DstModel(nn.Module):

    def __init__(self):
        super(DstModel, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 8)
        self.fc_dst = nn.Linear(3, 6)


class SrcModel(nn.Module):

    def __init__(self):
        super(SrcModel, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 8)
        self.fc_src = nn.Linear(3, 7)

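
# Minimal stand-in for any object exposing the state_dict/load_state_dict
# interface (optimizer, dataset, collector info). Each state_dict() call
# returns the current value and rotates in a fresh one, so a successful
# round trip through the checkpoint helper is observable in the tests below.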
class HasStateDict(object):

    def __init__(self, name):
        self._name = name
        self._state_dict = name + str(uuid.uuid4())

    def state_dict(self):
        old = self._state_dict
        self._state_dict = self._name + str(uuid.uuid4())
        return old

    def load_state_dict(self, state_dict):
        self._state_dict = state_dict


@pytest.mark.unittest
class TestCkptHelper:

    def test_load_model(self):
        path = 'model.pt'
        if os.path.exists(path):
            os.remove(path)

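        # Save a checkpoint that contains only SrcModel's weights.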
        dst_model = DstModel()
        src_model = SrcModel()
        ckpt_state_dict = {'model': src_model.state_dict()}
        torch.save(ckpt_state_dict, path)

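        # strict=True must fail because fc_src in the checkpoint has no
        # counterpart in DstModel; strict=False copies only the parameters
        # whose names and shapes match.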
        ckpt_helper = build_checkpoint_helper({})
        with pytest.raises(RuntimeError):
            ckpt_helper.load(path, dst_model, strict=True)

        ckpt_helper.load(path, dst_model, strict=False)
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() < 1e-6
        assert torch.abs(dst_model.fc1.bias - src_model.fc1.bias).max() < 1e-6

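        # Fresh models plus stateful companions (optimizer, dataset, collector
        # info, iteration/epoch counters) for a full save/load round trip.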
        dst_model = DstModel()
        src_model = SrcModel()
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6
        src_optimizer = HasStateDict('src_optimizer')
        dst_optimizer = HasStateDict('dst_optimizer')
        src_last_epoch = CountVar(11)
        dst_last_epoch = CountVar(5)
        src_last_iter = CountVar(110)
        dst_last_iter = CountVar(50)
        src_dataset = HasStateDict('src_dataset')
        dst_dataset = HasStateDict('dst_dataset')
        src_collector_info = HasStateDict('src_collect_info')
        dst_collector_info = HasStateDict('dst_collect_info')
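        # Save with prefix_op='remove' (stripping prefix "f" from the checkpoint
        # keys), then load with prefix_op='add' so the two operations cancel out;
        # 'fc1' is excluded from the model load via state_dict_mask.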
        ckpt_helper.save(
            path,
            src_model,
            optimizer=src_optimizer,
            dataset=src_dataset,
            collector_info=src_collector_info,
            last_iter=src_last_iter,
            last_epoch=src_last_epoch,
            prefix_op='remove',
            prefix="f"
        )
        ckpt_helper.load(
            path,
            dst_model,
            dataset=dst_dataset,
            optimizer=dst_optimizer,
            last_iter=dst_last_iter,
            last_epoch=dst_last_epoch,
            collector_info=dst_collector_info,
            strict=False,
            state_dict_mask=['fc1'],
            prefix_op='add',
            prefix="f"
        )
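        # The companion objects now carry the src values; fc2 was loaded from
        # the checkpoint while the masked fc1 keeps its fresh initialization.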
        assert dst_dataset.state_dict().startswith('src')
        assert dst_optimizer.state_dict().startswith('src')
        assert dst_collector_info.state_dict().startswith('src')
        assert dst_last_iter.val == 110
        for k, _ in dst_model.named_parameters():
            assert k.startswith('fc')
        print('==dst', dst_model.fc2.weight)
        print('==src', src_model.fc2.weight)
        assert torch.abs(dst_model.fc2.weight - src_model.fc2.weight).max() < 1e-6
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6

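        # Remove some entries from the saved file; a strict load of the
        # remaining keys should still succeed.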
        checkpoint = read_file(path)
        checkpoint.pop('dataset')
        checkpoint.pop('optimizer')
        checkpoint.pop('last_iter')
        save_file(path, checkpoint)
        ckpt_helper.load(
            path,
            dst_model,
            dataset=dst_dataset,
            optimizer=dst_optimizer,
            last_iter=dst_last_iter,
            last_epoch=dst_last_epoch,
            collector_info=dst_collector_info,
            strict=True,
            state_dict_mask=['fc1'],
            prefix_op='add',
            prefix="f"
        )
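        # Loading an lr scheduler is not implemented yet; the misspelled
        # keyword (lr_schduler) matches the helper's own parameter name,
        # otherwise this call would raise TypeError instead.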
        with pytest.raises(NotImplementedError):
            ckpt_helper.load(
                path,
                dst_model,
                strict=False,
                lr_schduler='lr_scheduler',
                last_iter=dst_last_iter,
            )

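        # An unknown prefix_op is rejected with KeyError; the save call raises,
        # so the load call below it is never reached.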
        with pytest.raises(KeyError):
            ckpt_helper.save(path, src_model, prefix_op='key_error', prefix="f")
            ckpt_helper.load(path, dst_model, strict=False, prefix_op='key_error', prefix="f")

        for f in glob.glob(path + '*'):
            os.remove(f)


@pytest.mark.unittest
def test_count_var():
    var = CountVar(0)
    var.add(5)
    assert var.val == 5
    var.update(3)
    assert var.val == 3

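
# auto_checkpoint is expected to catch the exception raised inside start()
# and call the object's save_checkpoint method instead of propagating it,
# which is why start() below returns normally.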
@pytest.mark.unittest
def test_auto_checkpoint():

    class AutoCkptCls:

        def __init__(self):
            pass

        @auto_checkpoint
        def start(self):
            for i in range(10):
                if i < 5:
                    time.sleep(0.2)
                else:
                    raise Exception("There is an exception")

        def save_checkpoint(self, ckpt_path):
            print('Checkpoint is saved successfully in {}!'.format(ckpt_path))

    auto_ckpt = AutoCkptCls()
    auto_ckpt.start()


if __name__ == '__main__':
    test = TestCkptHelper()
    test.test_load_model()