|
For example, we could turn off amp temporarily if it's |
|
enabled, after moving the original forward into a helper wrapper, like so: |
|
```python
|
def _forward(self, hidden_states):
    """Helper holding the original ``forward`` logic of the layer.

    Applies ``self.gelu_act`` to the output of ``self.wi_0``, multiplies the
    result by the output of ``self.wi_1``, then passes the product through
    ``self.dropout`` and ``self.wo``.

    Args:
        hidden_states: Input passed to both ``self.wi_0`` and ``self.wi_1``.

    Returns:
        The output of ``self.wo``.
    """
    # Two projections of the same input: one goes through the activation,
    # the other is used as-is, and the two are multiplied together.
    hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.wo(hidden_states)
    return hidden_states
|
import torch |
|
def forward(self, hidden_states):
    """Run ``self._forward`` with autocast (amp) temporarily turned off.

    If autocast is not currently enabled, ``self._forward`` is called
    directly. Otherwise the call is wrapped in
    ``torch.cuda.amp.autocast(enabled=False)`` so that autocast is disabled
    only for the duration of this call.

    Args:
        hidden_states: Input forwarded unchanged to ``self._forward``.

    Returns:
        Whatever ``self._forward`` returns.
    """
    if not torch.is_autocast_enabled():
        # Nothing to switch off: run the helper as-is.
        return self._forward(hidden_states)

    # Autocast is active: disable it just for this helper call.
    with torch.cuda.amp.autocast(enabled=False):
        return self._forward(hidden_states)
|
```
|
Since the automatic detector only reports on inputs and outputs of full frames, once you know where to look, you may |
|
want to analyse the intermediary stages of any specific forward function as well. |