|
import unittest |
|
|
|
from detectron2.solver.build import _expand_param_groups, reduce_param_groups |
|
|
|
|
|
class TestOptimizer(unittest.TestCase):
    """Tests for the parameter-group helpers imported from detectron2.solver.build."""

    def testExpandParamsGroups(self):
        """Every parameter gets its own group; later groups override earlier settings.

        p1 appears in the first and third input groups, so it keeps lr=1.0 from the
        first but takes weight_decay=4.0 from the third. p2 and p3 take lr=2.0 and
        momentum=2.0 from the second group while keeping weight_decay=3.0 from the
        first. p5 only appears in the second group and so has no weight_decay.
        """
        input_groups = [
            dict(params=["p1", "p2", "p3", "p4"], lr=1.0, weight_decay=3.0),
            dict(params=["p2", "p3", "p5"], lr=2.0, momentum=2.0),
            dict(params=["p1"], weight_decay=4.0),
        ]
        expected = [
            {"params": ["p1"], "lr": 1.0, "weight_decay": 4.0},
            {"params": ["p2"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p3"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p4"], "lr": 1.0, "weight_decay": 3.0},
            {"params": ["p5"], "lr": 2.0, "momentum": 2.0},
        ]

        expanded = _expand_param_groups(input_groups)

        self.assertEqual(expanded, expected)

    def testReduceParamGroups(self):
        """Groups with identical hyperparameters are merged, in first-seen order.

        The p2/p6 group and the p3 group share lr=2.0, weight_decay=3.0 and
        momentum=2.0, so they collapse into one group whose params are
        ["p2", "p6", "p3"]. Every other group stays separate.
        """
        # NOTE(review): the p2/p6 and p3 entries list their keys in the same order;
        # keep it that way in case reduce_param_groups groups by ordered
        # hyperparameters — confirm against its implementation.
        input_groups = [
            {"params": ["p1"], "lr": 1.0, "weight_decay": 4.0},
            {"params": ["p2", "p6"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p3"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p4"], "lr": 1.0, "weight_decay": 3.0},
            {"params": ["p5"], "lr": 2.0, "momentum": 2.0},
        ]
        reduced = reduce_param_groups(input_groups)

        expected = [
            dict(lr=1.0, weight_decay=4.0, params=["p1"]),
            dict(lr=2.0, weight_decay=3.0, momentum=2.0, params=["p2", "p6", "p3"]),
            dict(lr=1.0, weight_decay=3.0, params=["p4"]),
            dict(lr=2.0, momentum=2.0, params=["p5"]),
        ]
        self.assertEqual(reduced, expected)
|
|