from typing import *

import torch

from .span import Span


def _tensor2span_batch(
    span_boundary: torch.Tensor,
    span_labels: torch.Tensor,
    parent_indices: torch.Tensor,
    num_spans: torch.Tensor,
    label_confidence: torch.Tensor,
    idx2label: Dict[int, str],
    label_ignore: List[int],
) -> Span:
    """
    Build the span tree for a single example of the batch.

    Only the first ``num_spans`` entries of each tensor are read. Spans whose label
    index is in ``label_ignore`` are dropped. A kept span is attached as a child of
    the span at original position ``parent_indices[i]``, but only if that parent
    appears earlier in the sequence and was itself kept.

    :param span_boundary: per-span (start, end) pairs; presumably shape (max_spans, 2) — TODO confirm.
    :param span_labels: per-span label indices, looked up in ``idx2label``.
    :param parent_indices: per-span position of the parent span in this same sequence.
    :param num_spans: scalar tensor giving how many leading entries are valid.
    :param label_confidence: per-span confidence, forwarded to ``Span(confidence=...)``.
    :param idx2label: mapping from label index to label name.
    :param label_ignore: label indices whose spans are skipped.
    :return: The first kept span, which is treated as the root of the tree.
    :raises IndexError: If no span is kept (``num_spans`` is 0 or every label is ignored).
    """
    n = int(num_spans)
    # Convert each tensor to plain Python values once, instead of calling
    # int()/float() on individual 0-d tensors inside the loop.
    boundaries = span_boundary[:n].tolist()
    parents = parent_indices[:n].tolist()
    labels = span_labels[:n].tolist()
    confidences = label_confidence[:n].tolist()
    ignore = set(label_ignore)

    # Keyed by ORIGINAL position: parent_indices refer to positions in the full
    # sequence. A compacted list would shift indices as soon as one span is
    # ignored. Only earlier, kept spans are present, so negative or
    # forward-pointing parent indices attach nothing.
    kept: Dict[int, Span] = {}
    root: Optional[Span] = None
    for position, ((start, end), parent, label, conf) in enumerate(
        zip(boundaries, parents, labels, confidences)
    ):
        label = int(label)
        if label in ignore:
            continue
        span = Span(int(start), int(end), idx2label[label], True, confidence=float(conf))
        parent = int(parent)
        if parent in kept:
            kept[parent].add_child(span)
        if root is None:
            root = span
        kept[position] = span

    if root is None:
        raise IndexError(
            f'No span left to use as root: num_spans={n}, all remaining labels ignored or none present.'
        )
    return root


def tensor2span(
    span_boundary: torch.Tensor,
    span_labels: torch.Tensor,
    parent_indices: torch.Tensor,
    num_spans: torch.Tensor,
    label_confidence: torch.Tensor,
    idx2label: Dict[int, str],
    label_ignore: Optional[List[int]] = None,
) -> List[Span]:
    """
    Convert batched span tensors into one ``Span`` tree per example.

    Every tensor is unbound along dim 0 (the batch dimension), and each example is
    converted independently. Refer to the model part for the meaning of these
    variables. If ``label_ignore`` is provided, spans with those label indices are
    dropped.

    :param span_boundary: batched (start, end) pairs.
    :param span_labels: batched label indices.
    :param parent_indices: batched parent positions.
    :param num_spans: number of valid spans per example.
    :param label_confidence: batched label confidences.
    :param idx2label: mapping from label index to label name.
    :param label_ignore: optional label indices to skip.
    :return: One root ``Span`` per example, in batch order.
    :raises IndexError: If an example ends up with no kept span.
    """
    label_ignore = label_ignore or []
    # .cpu() is a no-op for tensors already on the CPU. Move every tensor
    # rather than keying off span_boundary's device alone.
    span_boundary, span_labels, parent_indices, num_spans, label_confidence = (
        t.cpu() for t in (span_boundary, span_labels, parent_indices, num_spans, label_confidence)
    )
    return [
        _tensor2span_batch(*args, label_ignore=label_ignore, idx2label=idx2label)
        for args in zip(
            span_boundary.unbind(0),
            span_labels.unbind(0),
            parent_indices.unbind(0),
            num_spans.unbind(0),
            label_confidence.unbind(0),
        )
    ]