# Source: gomoku / DI-engine / ding/torch_utils/nn_test_helper.py
# Author: zjowowen, commit 079c32c ("init space"), 1.84 kB
from typing import Union, List
import torch
def is_differentiable(
        loss: torch.Tensor, model: Union[torch.nn.Module, List[torch.nn.Module]], print_instead: bool = False
) -> None:
    """
    Overview:
        Judge whether the model/models are differentiable with respect to ``loss``. It runs in three steps:
        1. Check that every parameter's grad is None.
        2. Back-propagate ``loss``.
        3. Check that every parameter's grad is a ``torch.Tensor``.
    Arguments:
        - loss (:obj:`torch.Tensor`): Loss tensor of the model. ``loss.backward()`` is called with no explicit \
            gradient, so presumably this is a scalar (single-element) tensor.
        - model (:obj:`Union[torch.nn.Module, List[torch.nn.Module]]`): Model or list of models to be checked.
        - print_instead (:obj:`bool`): If True, print the name and grad of every parameter whose grad is not a \
            tensor after back-propagation, instead of raising. Default set to ``False``.
    Raises:
        - AssertionError: Raised in any of these cases (for the grad checks the message is the parameter name):
            * ``loss`` is not a tensor.
            * A list item is not an ``nn.Module``.
            * A parameter already has a grad before back-propagation.
            * ``print_instead`` is False and a parameter still has no tensor grad afterwards.
        - TypeError: If ``model`` is neither a list nor an ``nn.Module``.

    .. note::
        The computed gradients are left on the parameters. Reset them to None before calling this again \
        on the same model, otherwise the pre-backward check fails.
        The checks raise ``AssertionError`` explicitly instead of using ``assert`` statements, so they \
        still run when Python is started with ``-O``.
    """
    if not isinstance(loss, torch.Tensor):
        raise AssertionError()

    # Normalise both accepted input forms to one list of modules so they share a single code path.
    # ``list`` is tested first to keep the original dispatch order.
    if isinstance(model, list):
        modules = model
    elif isinstance(model, torch.nn.Module):
        modules = [model]
    else:
        raise TypeError('model must be list or nn.Module')

    # Before backward: a leftover grad from an earlier pass would make a parameter that
    # ``loss`` does not reach look differentiable, so require a clean slate.
    for m in modules:
        if not isinstance(m, torch.nn.Module):
            raise AssertionError()
        for k, p in m.named_parameters():
            if p.grad is not None:
                raise AssertionError(k)

    loss.backward()

    # After backward: a parameter whose grad is still not a tensor received no gradient from ``loss``.
    for m in modules:
        for k, p in m.named_parameters():
            if isinstance(p.grad, torch.Tensor):
                continue
            if print_instead:
                print(k, "grad is:", p.grad)
            else:
                raise AssertionError(k)