File size: 1,836 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
from typing import Union, List
import torch
def is_differentiable(
        loss: torch.Tensor, model: Union[torch.nn.Module, List[torch.nn.Module]], print_instead: bool = False
) -> None:
    """
    Overview:
        Judge whether the model/models are differentiable. First check that every parameter's grad is None,
        then back-propagate ``loss``, finally check that every parameter's grad is a ``torch.Tensor``.
    Arguments:
        - loss (:obj:`torch.Tensor`): loss tensor of the model
        - model (:obj:`Union[torch.nn.Module, List[torch.nn.Module]]`): model or models to be checked
        - print_instead (:obj:`bool`): Whether to print a module's final grad result \
            instead of asserting. Default set to ``False``.
    Raises:
        - AssertionError: if ``loss`` is not a tensor, a parameter already holds a grad, \
            or (with ``print_instead=False``) a parameter received no grad after backward.
        - TypeError: if ``model`` is neither an ``nn.Module`` nor a list.
    """
    assert isinstance(loss, torch.Tensor)
    # Normalize to a list so the single-module and multi-module cases share one code path.
    if isinstance(model, list):
        models = model
    elif isinstance(model, torch.nn.Module):
        models = [model]
    else:
        raise TypeError('model must be list or nn.Module')
    # Pre-condition: no parameter may already hold a gradient — a stale grad would
    # make the post-backward check pass regardless of whether loss reaches it.
    for m in models:
        assert isinstance(m, torch.nn.Module)
        for k, p in m.named_parameters():
            assert p.grad is None, k
    loss.backward()
    # Post-condition: every parameter must have received a gradient from the backward pass.
    for m in models:
        for k, p in m.named_parameters():
            if print_instead:
                if not isinstance(p.grad, torch.Tensor):
                    print(k, "grad is:", p.grad)
            else:
                assert isinstance(p.grad, torch.Tensor), k
|