import torch
import torch.distributed.nn
from torch import distributed as dist, nn
from torch.nn import functional as F
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
    audio_features,
    text_features,
    audio_features_mlp=None,
    text_features_mlp=None,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
    mlp_loss=False,
):
    """Gather audio/text features from all ranks so every process can build
    the full contrastive logit matrix."""
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        if gather_with_grad:
            all_audio_features = hvd.allgather(audio_features)
            all_text_features = hvd.allgather(text_features)
            if mlp_loss:
                all_audio_features_mlp = hvd.allgather(audio_features_mlp)
                all_text_features_mlp = hvd.allgather(text_features_mlp)
        else:
            with torch.no_grad():
                all_audio_features = hvd.allgather(audio_features)
                all_text_features = hvd.allgather(text_features)
                if mlp_loss:
                    all_audio_features_mlp = hvd.allgather(audio_features_mlp)
                    all_text_features_mlp = hvd.allgather(text_features_mlp)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_audio_features = list(
                    all_audio_features.chunk(world_size, dim=0)
                )
                gathered_text_features = list(
                    all_text_features.chunk(world_size, dim=0)
                )
                gathered_audio_features[rank] = audio_features
                gathered_text_features[rank] = text_features
                all_audio_features = torch.cat(gathered_audio_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
                if mlp_loss:
                    gathered_audio_features_mlp = list(
                        all_audio_features_mlp.chunk(world_size, dim=0)
                    )
                    gathered_text_features_mlp = list(
                        all_text_features_mlp.chunk(world_size, dim=0)
                    )
                    gathered_audio_features_mlp[rank] = audio_features_mlp
                    gathered_text_features_mlp[rank] = text_features_mlp
                    all_audio_features_mlp = torch.cat(
                        gathered_audio_features_mlp, dim=0
                    )
                    all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_audio_features = torch.cat(
                torch.distributed.nn.all_gather(audio_features), dim=0
            )
            all_text_features = torch.cat(
                torch.distributed.nn.all_gather(text_features), dim=0
            )
            if mlp_loss:
                all_audio_features_mlp = torch.cat(
                    torch.distributed.nn.all_gather(audio_features_mlp), dim=0
                )
                all_text_features_mlp = torch.cat(
                    torch.distributed.nn.all_gather(text_features_mlp), dim=0
                )
        else:
            gathered_audio_features = [
                torch.zeros_like(audio_features) for _ in range(world_size)
            ]
            gathered_text_features = [
                torch.zeros_like(text_features) for _ in range(world_size)
            ]
            dist.all_gather(gathered_audio_features, audio_features)
            dist.all_gather(gathered_text_features, text_features)
            if mlp_loss:
                gathered_audio_features_mlp = [
                    torch.zeros_like(audio_features_mlp) for _ in range(world_size)
                ]
                gathered_text_features_mlp = [
                    torch.zeros_like(text_features_mlp) for _ in range(world_size)
                ]
                dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
                dist.all_gather(gathered_text_features_mlp, text_features_mlp)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_audio_features[rank] = audio_features
                gathered_text_features[rank] = text_features
                if mlp_loss:
                    gathered_audio_features_mlp[rank] = audio_features_mlp
                    gathered_text_features_mlp[rank] = text_features_mlp
            all_audio_features = torch.cat(gathered_audio_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)
            if mlp_loss:
                all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
                all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)

    if mlp_loss:
        return (
            all_audio_features,
            all_text_features,
            all_audio_features_mlp,
            all_text_features_mlp,
        )
    else:
        return all_audio_features, all_text_features
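

# --- Example usage (illustrative sketch, not part of the original module) ---
# A minimal sketch of how gather_features might be called from a distributed
# training step. It assumes the default process group is already initialized
# (e.g. the script was launched with torchrun); the function name, shapes, and
# random tensors below are placeholders, not code from the original repo.
def _example_gather_features():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # stand-ins for the per-rank model outputs (L2-normalized embeddings)
    audio_features = F.normalize(torch.randn(8, 512), dim=-1)
    text_features = F.normalize(torch.randn(8, 512), dim=-1)
    all_audio, all_text = gather_features(
        audio_features,
        text_features,
        local_loss=True,  # pair local features against the gathered set
        gather_with_grad=False,  # plain all_gather; grads flow only through the local chunk
        rank=rank,
        world_size=world_size,
    )
    # both outputs have shape [8 * world_size, 512]
    return all_audio, all_text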


class ClipLoss(nn.Module):
    """CLIP-style symmetric contrastive loss between audio and text embeddings,
    with optional MLP-projected (feature-fusion) logits and an optional
    similarity-weighted variant controlled by weight_loss_kappa."""

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
        mlp_loss=False,
        weight_loss_kappa=0,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        self.mlp_loss = mlp_loss
        self.weighted_loss = bool(weight_loss_kappa != 0)
        self.weight_loss_kappa = weight_loss_kappa
        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def forward(
        self,
        audio_features,
        text_features,
        logit_scale_a,
        logit_scale_t=None,
        audio_features_mlp=None,
        text_features_mlp=None,
    ):
        device = audio_features.device
        if self.mlp_loss:
            if self.world_size > 1:
                (
                    all_audio_features,
                    all_text_features,
                    all_audio_features_mlp,
                    all_text_features_mlp,
                ) = gather_features(
                    audio_features=audio_features,
                    text_features=text_features,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                    local_loss=self.local_loss,
                    gather_with_grad=self.gather_with_grad,
                    rank=self.rank,
                    world_size=self.world_size,
                    use_horovod=self.use_horovod,
                    mlp_loss=self.mlp_loss,
                )
                if self.local_loss:
                    a_logits_per_audio = (
                        logit_scale_a * audio_features @ all_text_features_mlp.T
                    )
                    a_logits_per_text = (
                        logit_scale_a * text_features_mlp @ all_audio_features.T
                    )
                    t_logits_per_audio = (
                        logit_scale_t * audio_features_mlp @ all_text_features.T
                    )
                    t_logits_per_text = (
                        logit_scale_t * text_features @ all_audio_features_mlp.T
                    )
                else:
                    a_logits_per_audio = (
                        logit_scale_a * all_audio_features @ all_text_features_mlp.T
                    )
                    a_logits_per_text = a_logits_per_audio.T
                    t_logits_per_audio = (
                        logit_scale_t * all_audio_features_mlp @ all_text_features.T
                    )
                    t_logits_per_text = t_logits_per_audio.T
            else:
                a_logits_per_audio = (
                    logit_scale_a * audio_features @ text_features_mlp.T
                )
                a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
                t_logits_per_audio = (
                    logit_scale_t * audio_features_mlp @ text_features.T
                )
                t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T

            # calculate ground-truth and cache if enabled
            num_logits = a_logits_per_audio.shape[0]
            if self.prev_num_logits != num_logits or device not in self.labels:
                labels = torch.arange(num_logits, device=device, dtype=torch.long)
                if self.world_size > 1 and self.local_loss:
                    labels = labels + num_logits * self.rank
                if self.cache_labels:
                    self.labels[device] = labels
                    self.prev_num_logits = num_logits
            else:
                labels = self.labels[device]

            if not self.weighted_loss:
                total_loss = (
                    F.cross_entropy(a_logits_per_audio, labels)
                    + F.cross_entropy(a_logits_per_text, labels)
                    + F.cross_entropy(t_logits_per_audio, labels)
                    + F.cross_entropy(t_logits_per_text, labels)
                ) / 4
            else:
                audio_weight = (audio_features @ audio_features.T).detach()
                audio_weight = (
                    torch.exp(
                        torch.sum(audio_weight, axis=1)
                        / (self.weight_loss_kappa * len(audio_weight))
                    )
                ).detach()
                text_weight = (text_features @ text_features.T).detach()
                text_weight = (
                    torch.exp(
                        torch.sum(text_weight, axis=1)
                        / (self.weight_loss_kappa * len(text_features))
                    )
                ).detach()
                total_loss = (
                    F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
                    + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
                    + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
                    + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
                ) / 4
        else:
            if self.world_size > 1:
                all_audio_features, all_text_features = gather_features(
                    audio_features=audio_features,
                    text_features=text_features,
                    local_loss=self.local_loss,
                    gather_with_grad=self.gather_with_grad,
                    rank=self.rank,
                    world_size=self.world_size,
                    use_horovod=self.use_horovod,
                    mlp_loss=self.mlp_loss,
                )

                if self.local_loss:
                    logits_per_audio = (
                        logit_scale_a * audio_features @ all_text_features.T
                    )
                    logits_per_text = (
                        logit_scale_a * text_features @ all_audio_features.T
                    )
                else:
                    logits_per_audio = (
                        logit_scale_a * all_audio_features @ all_text_features.T
                    )
                    logits_per_text = logits_per_audio.T
            else:
                # single-process case: the "gathered" features are just the local
                # ones; aliasing them keeps the weighted-loss branch below
                # well-defined when world_size == 1
                all_audio_features, all_text_features = audio_features, text_features
                logits_per_audio = logit_scale_a * audio_features @ text_features.T
                logits_per_text = logit_scale_a * text_features @ audio_features.T

            # calculate ground-truth and cache if enabled
            num_logits = logits_per_audio.shape[0]
            if self.prev_num_logits != num_logits or device not in self.labels:
                labels = torch.arange(num_logits, device=device, dtype=torch.long)
                if self.world_size > 1 and self.local_loss:
                    labels = labels + num_logits * self.rank
                if self.cache_labels:
                    self.labels[device] = labels
                    self.prev_num_logits = num_logits
            else:
                labels = self.labels[device]

            if not self.weighted_loss:
                total_loss = (
                    F.cross_entropy(logits_per_audio, labels)
                    + F.cross_entropy(logits_per_text, labels)
                ) / 2
            else:
                audio_weight = (all_audio_features @ all_audio_features.T).detach()
                audio_weight = (
                    torch.exp(
                        torch.sum(audio_weight, axis=1)
                        / (self.weight_loss_kappa * len(all_audio_features))
                    )
                ).detach()
                text_weight = (all_text_features @ all_text_features.T).detach()
                text_weight = (
                    torch.exp(
                        torch.sum(text_weight, axis=1)
                        / (self.weight_loss_kappa * len(all_text_features))
                    )
                ).detach()
                total_loss = (
                    F.cross_entropy(logits_per_audio, labels, weight=text_weight)
                    + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
                ) / 2
        return total_loss
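

# --- Example usage (illustrative sketch, not part of the original module) ---
# A minimal single-process sketch of ClipLoss: random embeddings stand in for
# model outputs, and the logit scale is fixed at the common CLIP init value of
# 1/0.07 rather than being a learned parameter.
def _example_clip_loss():
    batch, dim = 8, 512
    audio_features = F.normalize(torch.randn(batch, dim), dim=-1)
    text_features = F.normalize(torch.randn(batch, dim), dim=-1)
    logit_scale_a = torch.tensor(1 / 0.07)  # CLIP-style temperature, here fixed
    loss_fn = ClipLoss(world_size=1, mlp_loss=False)
    return loss_fn(audio_features, text_features, logit_scale_a)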


def lp_gather_features(pred, target, world_size=1, use_horovod=False):
    """Gather linear-probe predictions and targets from all ranks for evaluation."""
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        with torch.no_grad():
            all_preds = hvd.allgather(pred)
            all_targets = hvd.allgather(target)
    else:
        gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
        gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]

        dist.all_gather(gathered_preds, pred)
        dist.all_gather(gathered_targets, target)
        all_preds = torch.cat(gathered_preds, dim=0)
        all_targets = torch.cat(gathered_targets, dim=0)

    return all_preds, all_targets


def get_map(pred, target):
    # mean average precision over classes (multi-label)
    pred = torch.sigmoid(pred).numpy()
    target = target.numpy()
    return np.mean(average_precision_score(target, pred, average=None))


def get_acc(pred, target):
    # top-1 accuracy against one-hot targets
    pred = torch.argmax(pred, 1).numpy()
    target = torch.argmax(target, 1).numpy()
    return accuracy_score(target, pred)


def get_mauc(pred, target):
    # mean ROC AUC over classes (multi-label)
    pred = torch.sigmoid(pred).numpy()
    target = target.numpy()
    return np.mean(roc_auc_score(target, pred, average=None))


class LPMetrics(object):
    def __init__(self, metric_names=["map", "acc", "mauc"]):
        self.metrics = []
        for name in metric_names:
            self.metrics.append(self.get_metric(name))
        self.metric_names = metric_names

    def get_metric(self, name):
        if name == "map":
            return get_map
        elif name == "acc":
            return get_acc
        elif name == "mauc":
            return get_mauc
        else:
            raise ValueError("the metric should be one of [map, acc, mauc]")

    def evaluate_mertics(self, pred, target):
        metric_dict = {}
        for i in range(len(self.metric_names)):
            metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
        return metric_dict
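

# --- Example usage (illustrative sketch, not part of the original module) ---
# A small sketch of the linear-probe metrics on synthetic multi-label data.
# Targets are derived from the logits' signs so that each class is very likely
# to contain both positive and negative examples (roc_auc_score raises an
# error otherwise); the shapes are placeholder assumptions.
def _example_lp_metrics():
    logits = torch.randn(16, 10)  # [num_samples, num_classes]
    target = (logits > 0).float()  # multi-hot ground truth
    metrics = LPMetrics(metric_names=["map", "mauc"])
    return metrics.evaluate_mertics(logits, target)  # {"map": ..., "mauc": ...}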


def calc_celoss(pred, target):
    # convert one-hot / multi-hot targets to class indices for CrossEntropyLoss
    target = torch.argmax(target, 1).long()
    return nn.CrossEntropyLoss()(pred, target)


class LPLoss(nn.Module):
    def __init__(self, loss_name):
        super().__init__()
        if loss_name == "bce":
            self.loss_func = nn.BCEWithLogitsLoss()
        elif loss_name == "ce":
            self.loss_func = calc_celoss
        elif loss_name == "mse":
            self.loss_func = nn.MSELoss()
        else:
            raise ValueError("the loss func should be one of [bce, ce, mse]")

    def forward(self, pred, target):
        loss = self.loss_func(pred, target)
        return loss
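

# --- Example usage (illustrative sketch, not part of the original module) ---
# A minimal sketch of LPLoss with the "bce" head on synthetic multi-label
# data; the shapes are placeholder assumptions.
def _example_lp_loss():
    pred = torch.randn(4, 10)  # raw logits
    target = (torch.rand(4, 10) > 0.5).float()  # multi-hot labels
    loss_fn = LPLoss("bce")
    return loss_fn(pred, target)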