import torch
import torch.nn as nn
import pytorch_lightning as pl

import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
from transformers import GPT2Model

from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config
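
# Sequence2AudioMAE: trains a pretrained GPT-2 backbone to autoregressively
# predict pooled AudioMAE tokens from conditioning sequences (e.g. CLAP or
# text embeddings) within the qa_mdt / AudioLDM training stack.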


class Sequence2AudioMAE(pl.LightningModule):
    """Autoregressive AudioMAE-token predictor built on a pretrained GPT-2.

    Conditioning sequences are projected to the GPT-2 hidden size, wrapped in
    learnable SOS/EOS tokens, and concatenated in front of the target AudioMAE
    token sequence; the model is trained with an L1 loss to predict each
    AudioMAE token from the previous ones.
    """

    def __init__(
        self,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        cond_stage_config,
        optimizer_type="AdamW",
        use_warmup=True,
        use_ar_gen_loss=False,
        use_audiomae_linear=False,
        target_tokens_mask_ratio=0.0,
        random_mask_ratio=False,
        **kwargs
    ):
        super().__init__()
        assert not use_audiomae_linear, "use_audiomae_linear is not supported"
        self.random_mask_ratio = random_mask_ratio
        self.learning_rate = base_learning_rate
        self.cond_stage_config = cond_stage_config
        self.use_audiomae_linear = use_audiomae_linear
        self.optimizer_type = optimizer_type
        self.use_warmup = use_warmup
        self.use_ar_gen_loss = use_ar_gen_loss

        self.mae_token_num = sequence_gen_length
        self.sequence_input_key = sequence_input_key
        self.sequence_input_embed_dim = sequence_input_embed_dim
        self.target_tokens_mask_ratio = target_tokens_mask_ratio

        # Learnable SOS/EOS embeddings, one pair per input-sequence key
        # (up to 32 keys).
        self.start_of_sequence_tokens = nn.Embedding(32, 768)
        self.end_of_sequence_tokens = nn.Embedding(32, 768)

        # One linear layer per conditioning modality, projecting its
        # embedding to the GPT-2 hidden size (768).
        self.input_sequence_embed_linear = nn.ModuleList([])
        self.initial_learning_rate = None
        for dim in self.sequence_input_embed_dim:
            self.input_sequence_embed_linear.append(nn.Linear(dim, 768))

        self.cond_stage_models = nn.ModuleList([])
        self.instantiate_cond_stage(cond_stage_config)
        self.initialize_param_check_toolkit()

        self.private_training_step = 0

        # Pretrained GPT-2 backbone used as the autoregressive sequence model.
        self.model = GPT2Model.from_pretrained("gpt2")

        self.loss_fn = nn.L1Loss()

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None
        self.logger_version = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

    def cfg_uncond(self, batch_size):
        """Build per-key unconditional conditionings for classifier-free guidance."""
        unconditional_conditioning = {}
        for key in self.cond_stage_model_metadata:
            model_idx = self.cond_stage_model_metadata[key]["model_idx"]
            unconditional_conditioning[key] = self.cond_stage_models[
                model_idx
            ].get_unconditional_condition(batch_size)
        assert (
            "crossattn_audiomae_pooled" in unconditional_conditioning.keys()
        ), "The module is not initialized with AudioMAE"
        unconditional_conditioning[
            "crossattn_clap_to_audiomae_feature"
        ] = unconditional_conditioning["crossattn_audiomae_pooled"]
        return unconditional_conditioning

    def configure_optimizers(self):
        lr = float(self.learning_rate)
        params = list(self.parameters())
        # Resolve the optimizer class by name from torch.optim (e.g. "AdamW")
        # instead of eval(), which would execute arbitrary config strings.
        opt = getattr(torch.optim, self.optimizer_type)(params, lr=lr)
        scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
        return [opt], [scheduler]

    def add_sos_eos_tokens(self, _id, sequence, attn_mask):
        """Wrap `sequence` in the learnable SOS/EOS embeddings for input key `_id`."""
        batchsize = sequence.size(0)

        new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device)
        key_id = torch.tensor([_id]).to(sequence.device)

        # The SOS/EOS positions are always attended to.
        new_attn_mask = torch.cat(
            [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1
        )

        sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
        eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
        new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1)
        return new_sequence, new_attn_mask
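
    # Shape note for add_sos_eos_tokens above: an input of shape (B, T, 768)
    # with mask (B, T) comes back as (B, T + 2, 768) with mask (B, T + 2).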

    def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
        if sequence.size(1) > max_len:
            print(
                "The input sequence length to the GPT-2 model is too long:",
                sequence.size(1),
            )
            return sequence[:, :max_len], mask[:, :max_len]
        else:
            return sequence, mask

    def get_input_sequence_and_mask(self, cond_dict):
        """Project, SOS/EOS-wrap, and concatenate all conditioning sequences.

        Returns the concatenated input embeddings, their attention mask, and
        the time index at which the conditioning prefix ends (i.e. where the
        generated AudioMAE tokens will start).
        """
        input_embeds = None
        input_embeds_attn_mask = None
        for _id, sequence_key in enumerate(self.sequence_input_key):
            assert sequence_key in cond_dict.keys(), (
                "Invalid sequence key %s" % sequence_key
            )
            cond_embed = cond_dict[sequence_key]
            if isinstance(cond_embed, list):
                assert (
                    len(cond_embed) == 2
                ), "The crossattn returned list should have length 2, including embed and attn_mask"
                item_input_embeds, item_attn_mask = cond_embed
                item_input_embeds = self.input_sequence_embed_linear[_id](
                    item_input_embeds
                )
                item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
                    _id, item_input_embeds, item_attn_mask
                )
            else:
                assert isinstance(cond_embed, torch.Tensor)
                cond_embed = self.input_sequence_embed_linear[_id](cond_embed)
                # Plain tensors carry no mask, so attend to every position.
                attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to(
                    cond_embed.device
                )
                item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
                    _id, cond_embed, attn_mask
                )

            if input_embeds is None and input_embeds_attn_mask is None:
                input_embeds, input_embeds_attn_mask = (
                    item_input_embeds,
                    item_attn_mask,
                )
            else:
                input_embeds = torch.cat([input_embeds, item_input_embeds], dim=1)
                input_embeds_attn_mask = torch.cat(
                    [input_embeds_attn_mask, item_attn_mask], dim=1
                )

        assert input_embeds is not None and input_embeds_attn_mask is not None

        # Leave room for the generated AudioMAE tokens within GPT-2's
        # 1024-token context window.
        input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
            input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num)
        )
        cond_sequence_end_time_idx = input_embeds.size(1)

        return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx

    def warmup_step(self):
        if self.initial_learning_rate is None:
            self.initial_learning_rate = float(self.learning_rate)

        # Linearly ramp the learning rate from 0 to its initial value over the
        # first 1000 global steps, then pin it at the initial value.
        if self.global_step <= 1000:
            if self.global_step == 0:
                print(
                    "Warming up learning rate start with %s"
                    % self.initial_learning_rate
                )
            self.trainer.optimizers[0].param_groups[0]["lr"] = (
                self.global_step / 1000
            ) * self.initial_learning_rate
        else:
            self.trainer.optimizers[0].param_groups[0][
                "lr"
            ] = self.initial_learning_rate

    def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
        """Randomly zero out target tokens (and their mask) during training."""
        time_seq_mask = None
        if self.target_tokens_mask_ratio > 1e-4:
            batchsize, time_seq_len, embed_dim = target_embeds.size()
            _, time_seq_len = target_embeds_attn_mask.size()

            # Either a fixed mask ratio or one sampled uniformly from
            # [0, target_tokens_mask_ratio).
            if self.random_mask_ratio:
                mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio
            else:
                mask_ratio = self.target_tokens_mask_ratio

            time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
                target_embeds.device
            )
            target_embeds = target_embeds * time_seq_mask.unsqueeze(-1)
            target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask
        return target_embeds, target_embeds_attn_mask, time_seq_mask

    def training_step(self, batch, batch_idx=None, cond_dict=None, return_output=False):
        if self.use_warmup:
            self.warmup_step()

        if cond_dict is None:
            cond_dict = self.get_input(batch)

        # The pooled AudioMAE tokens are the prediction target.
        target_embeds, target_embeds_attn_mask = (
            cond_dict["crossattn_audiomae_pooled"][0],
            cond_dict["crossattn_audiomae_pooled"][1],
        )

        (
            input_embeds,
            input_embeds_attn_mask,
            cond_sequence_end_time_idx,
        ) = self.get_input_sequence_and_mask(cond_dict)

        if "crossattn_audiomae_pooled_44" in cond_dict.keys():
            target_embeds = cond_dict["crossattn_audiomae_pooled_44"][0]

        # Teacher forcing: feed [conditioning prefix, target tokens] and read
        # the predictions one step ahead of each target position.
        final_input_embeds = torch.cat([input_embeds, target_embeds], dim=1)
        final_input_embeds_attn_mask = torch.cat(
            [input_embeds_attn_mask, target_embeds_attn_mask], dim=1
        )

        output_embeds = self.model(
            inputs_embeds=final_input_embeds,
            attention_mask=final_input_embeds_attn_mask,
        )["last_hidden_state"]

        # The hidden state at position t predicts the token at position t + 1,
        # hence the one-step shift when slicing the outputs.
        target = target_embeds
        output = output_embeds[:, cond_sequence_end_time_idx - 1 : -1]

        assert target.size(1) == self.mae_token_num

        loss = self.loss_fn(output, target)

        if self.use_ar_gen_loss:
            ar_gen_loss = self.calculate_ahead_k_step_loss(batch, batch_idx, cond_dict)
        else:
            ar_gen_loss = loss

        if self.private_training_step % 500 == 0:
            print(
                "AudioMAE prediction module:", "loss", loss, "ar_gen_loss", ar_gen_loss
            )

        try:
            learning_rate = self.trainer.optimizers[0].param_groups[0]["lr"]
            self.log(
                "train/lr_audiomae_pred",
                learning_rate,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=False,
                sync_dist=True,
            )
        except Exception:
            # The optimizer may not be attached yet (e.g. during sanity checks).
            pass

        self.log(
            "train/loss_clap_2_audiomae",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )
        self.log(
            "train/loss_ar_gen_loss",
            ar_gen_loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )
        self.log(
            "global_step_audiomae",
            float(self.global_step),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )

        self.private_training_step += 1
        if return_output:
            return loss + ar_gen_loss, output
        else:
            return loss + ar_gen_loss

    def calculate_ahead_k_step_loss(self, batch, batch_idx=None, cond_dict=None):
        """Roll the model forward autoregressively for a few steps and score it.

        A short segment of the target sequence is regenerated token by token
        (feeding the model its own outputs) and compared against the ground
        truth with the same L1 loss.
        """
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        target_embeds, target_embeds_attn_mask = (
            cond_dict["crossattn_audiomae_pooled"][0],
            cond_dict["crossattn_audiomae_pooled"][1],
        )

        assert (
            torch.sum(target_embeds_attn_mask < 0.1) < 1
        ), "This function only works for AudioMAE prediction, which should have an all-one attn_mask"

        (
            input_embeds,
            input_embeds_attn_mask,
            cond_sequence_end_time_idx,
        ) = self.get_input_sequence_and_mask(cond_dict)

        target_total_time_steps = target_embeds.size(1)

        # Sample how many steps to roll out (between 2 and 8) and where in the
        # target sequence to start.
        steps = min(round(torch.rand(1).item() * 8), target_total_time_steps)
        if steps < 2:
            steps = 2

        start_idx = max(
            0, round(torch.rand(1).item() * (target_total_time_steps - steps)) - 1
        )

        model_input = input_embeds
        model_input_mask = input_embeds_attn_mask
        target_embeds_ar_gen = target_embeds[:, start_idx : start_idx + steps, :]
        generation = []

        # Prime the model with the ground-truth tokens before start_idx.
        if start_idx > 0:
            model_input = torch.cat(
                [input_embeds, target_embeds[:, :start_idx, :]], dim=1
            )
            attention_mask_known_steps = torch.ones(
                (model_input_mask.size(0), start_idx)
            ).to(model_input.device)
            model_input_mask = torch.cat(
                [input_embeds_attn_mask, attention_mask_known_steps], dim=1
            )

        for _ in range(steps):
            output = self.model(
                inputs_embeds=model_input, attention_mask=model_input_mask
            )["last_hidden_state"]

            # Append the newly generated token and extend the attention mask.
            generation.append(output[:, -1:, :])
            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
            attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
                model_input.device
            )
            model_input_mask = torch.cat(
                [model_input_mask, attention_mask_new_step], dim=1
            )

        generation = torch.cat(generation, dim=1)
        return self.loss_fn(generation, target_embeds_ar_gen)

    def generate_partial(self, batch, cond_dict=None, no_grad=False):
        """Generate AudioMAE tokens with the first quarter given as a prompt."""
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        print("Generate partially prompted audio with in-context learning")

        target_embeds, target_embeds_attn_mask = (
            cond_dict["crossattn_audiomae_pooled"][0],
            cond_dict["crossattn_audiomae_pooled"][1],
        )

        target_time_steps = target_embeds.size(1)

        (
            input_embeds,
            input_embeds_attn_mask,
            cond_sequence_end_time_idx,
        ) = self.get_input_sequence_and_mask(cond_dict)

        # Prompt with the conditioning prefix plus the first quarter of the
        # ground-truth target tokens.
        model_input = torch.cat(
            [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1
        )
        model_input_mask = torch.cat(
            [
                input_embeds_attn_mask,
                target_embeds_attn_mask[:, : target_time_steps // 4],
            ],
            dim=1,
        )

        steps = self.mae_token_num

        # Autoregressively generate the remaining three quarters.
        for _ in range(3 * steps // 4):
            output = self.model(
                inputs_embeds=model_input, attention_mask=model_input_mask
            )["last_hidden_state"]

            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
            attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
                model_input.device
            )
            model_input_mask = torch.cat(
                [model_input_mask, attention_mask_new_step], dim=1
            )

        output = model_input[:, cond_sequence_end_time_idx:]
        return output, cond_dict

    def generate(self, batch, cond_dict=None, no_grad=False):
        """Autoregressively generate the full AudioMAE token sequence."""
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        (
            input_embeds,
            input_embeds_attn_mask,
            cond_sequence_end_time_idx,
        ) = self.get_input_sequence_and_mask(cond_dict)
        model_input = input_embeds
        model_input_mask = input_embeds_attn_mask

        steps = self.mae_token_num

        for _ in range(steps):
            output = self.model(
                inputs_embeds=model_input, attention_mask=model_input_mask
            )["last_hidden_state"]

            # Feed the newly predicted token back in and extend the mask.
            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
            attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
                model_input.device
            )
            model_input_mask = torch.cat(
                [model_input_mask, attention_mask_new_step], dim=1
            )

        return model_input[:, cond_sequence_end_time_idx:], cond_dict
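
    # Note on generate() / generate_partial() above: the `no_grad` flag is
    # accepted but never acted on in this module, so for inference callers may
    # want to disable autograd themselves, e.g.:
    #
    #     with torch.no_grad():
    #         tokens, cond = model.generate(batch)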

    def validation_step(self, batch, batch_idx):
        cond_dict = self.get_input(batch)

        target_embeds, target_embeds_attn_mask = (
            cond_dict["crossattn_audiomae_pooled"][0],
            cond_dict["crossattn_audiomae_pooled"][1],
        )

        (
            input_embeds,
            input_embeds_attn_mask,
            cond_sequence_end_time_idx,
        ) = self.get_input_sequence_and_mask(cond_dict)

        if "crossattn_audiomae_pooled_44" in cond_dict.keys():
            target_embeds = cond_dict["crossattn_audiomae_pooled_44"][0]

        final_input_embeds = torch.cat([input_embeds, target_embeds], dim=1)
        final_input_embeds_attn_mask = torch.cat(
            [input_embeds_attn_mask, target_embeds_attn_mask], dim=1
        )

        output_embeds = self.model(
            inputs_embeds=final_input_embeds,
            attention_mask=final_input_embeds_attn_mask,
        )["last_hidden_state"]

        # Same one-step shift as in training_step.
        target = target_embeds
        output = output_embeds[:, cond_sequence_end_time_idx - 1 : -1]

        loss = self.loss_fn(output, target)

        self.log(
            "val/loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            sync_dist=True,
            on_epoch=True,
        )

        # Also score a full autoregressive generation against the target.
        generation_output, _ = self.generate(batch)
        ar_gen_loss = self.loss_fn(generation_output, target)

        self.log(
            "val/ar_gen_loss",
            ar_gen_loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            sync_dist=True,
            on_epoch=True,
        )

        return {"loss": loss, "ar_gen_loss": ar_gen_loss}

    def get_input_item(self, batch, k):
        """Fetch one conditioning input (keyed by `k`) from the raw batch."""
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        ret = {}

        ret["fbank"] = (
            fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
        )
        ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
        ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
        ret["text"] = list(text)
        ret["fname"] = fname

        # Pass through any remaining batch entries unchanged.
        for key in batch.keys():
            if key not in ret.keys():
                ret[key] = batch[key]

        return ret[k]

    def get_input(self, batch):
        """Run every conditioning model on the batch and collect the results."""
        cond_dict = {}
        if len(self.cond_stage_model_metadata.keys()) > 0:
            unconditional_cfg = False
            for cond_model_key in self.cond_stage_model_metadata.keys():
                cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
                    "cond_stage_key"
                ]

                xc = self.get_input_item(batch, cond_stage_key)
                if isinstance(xc, torch.Tensor):
                    xc = xc.to(self.device)

                c = self.get_learned_conditioning(
                    xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
                )
                cond_dict[cond_model_key] = c

        return cond_dict

    def instantiate_cond_stage(self, config):
        self.cond_stage_model_metadata = {}
        for i, cond_model_key in enumerate(config.keys()):
            model = instantiate_from_config(config[cond_model_key])
            self.cond_stage_models.append(model)
            self.cond_stage_model_metadata[cond_model_key] = {
                "model_idx": i,
                "cond_stage_key": config[cond_model_key]["cond_stage_key"],
                "conditioning_key": config[cond_model_key]["conditioning_key"],
            }
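
    # For reference, `config` is expected to look roughly like the sketch
    # below; the key name, cond_stage_key value, and target path here are
    # hypothetical placeholders, not the project's shipped configuration:
    #
    #     {
    #         "crossattn_audiomae_pooled": {
    #             "target": "qa_mdt.audioldm_train...",  # class loaded by instantiate_from_config
    #             "cond_stage_key": "text",
    #             "conditioning_key": "crossattn",
    #             "params": {...},
    #         },
    #     }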

    def get_learned_conditioning(self, c, key, unconditional_cfg):
        assert key in self.cond_stage_model_metadata.keys()

        if not unconditional_cfg:
            c = self.cond_stage_models[
                self.cond_stage_model_metadata[key]["model_idx"]
            ](c)
        else:
            # Infer the batch size so the conditioning model can fabricate an
            # unconditional condition of matching shape.
            if isinstance(c, torch.Tensor):
                batchsize = c.size(0)
            elif isinstance(c, list):
                batchsize = len(c)
            else:
                raise NotImplementedError()
            c = self.cond_stage_models[
                self.cond_stage_model_metadata[key]["model_idx"]
            ].get_unconditional_condition(batchsize)

        return c

    def initialize_param_check_toolkit(self):
        self.tracked_steps = 0
        self.param_dict = {}

    def statistic_require_grad_tensor_number(self, module, name=None):
        """Count trainable parameters and return one of them for later diffing."""
        requires_grad_num = 0
        total_num = 0
        require_grad_tensor = None
        for p in module.parameters():
            if p.requires_grad:
                requires_grad_num += 1
                if require_grad_tensor is None:
                    require_grad_tensor = p
            total_num += 1
        print(
            "Module: [%s] has %s trainable parameters out of %s total parameters (%.2f)"
            % (name, requires_grad_num, total_num, requires_grad_num / total_num)
        )
        return require_grad_tensor

    def check_module_param_update(self):
        """Print which child modules' tracked parameters have changed.

        On the first call, one trainable tensor per child module is cached;
        every 5000 tracked steps the absolute difference against that cache
        is printed, which helps verify the intended modules are training.
        """
        if self.tracked_steps == 0:
            print("Sequence2AudioMAE")
            for name, module in self.named_children():
                try:
                    require_grad_tensor = self.statistic_require_grad_tensor_number(
                        module, name=name
                    )
                    if require_grad_tensor is not None:
                        self.param_dict[name] = require_grad_tensor.clone()
                    else:
                        print("==> %s does not require grad" % name)
                except Exception as e:
                    print("%s does not have trainable parameters: %s" % (name, e))
                    continue

        if self.tracked_steps % 5000 == 0:
            print("Sequence2AudioMAE")
            for name, module in self.named_children():
                try:
                    require_grad_tensor = self.statistic_require_grad_tensor_number(
                        module, name=name
                    )
                    if require_grad_tensor is not None:
                        print(
                            "===> Param diff %s: %s; Size: %s"
                            % (
                                name,
                                torch.sum(
                                    torch.abs(
                                        self.param_dict[name] - require_grad_tensor
                                    )
                                ),
                                require_grad_tensor.size(),
                            )
                        )
                    else:
                        print("%s does not require grad" % name)
                except Exception as e:
                    print("%s does not have trainable parameters: %s" % (name, e))
                    continue

        self.tracked_steps += 1
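

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The constructor values below are
# hypothetical placeholders chosen to satisfy the signatures above; the real
# ones come from the project's training configuration.
#
#     model = Sequence2AudioMAE(
#         base_learning_rate=1e-4,
#         sequence_gen_length=512,               # AudioMAE tokens to predict
#         sequence_input_key=["crossattn_clap_feature"],  # keys into cond_dict
#         sequence_input_embed_dim=[512],        # embedding dim per input key
#         cond_stage_config=cond_stage_config,   # see instantiate_cond_stage
#     )
#     loss = model.training_step(batch)          # teacher-forced L1 loss
#     tokens, cond = model.generate(batch)       # (B, 512, 768) AudioMAE tokens
# ---------------------------------------------------------------------------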