chendl's picture
update cap
e770d90
raw
history blame
5.61 kB
from typing import Dict, Sequence, Tuple
import re
import numpy as np
import torch
def postprocess_classification_generation(predictions: str) -> str:
    """Trim generated text at the first "Prompt" or "Completion" marker.

    Args:
        predictions: Raw generated text for a single sample.

    Returns:
        The part of ``predictions`` that precedes the first occurrence of
        either marker, or the whole string when neither marker is present.
    """
    # ``maxsplit`` is passed by keyword: supplying it positionally is
    # deprecated since Python 3.13.
    return re.split(r"Prompt|Completion", predictions, maxsplit=1)[0]
def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
    """Return the fraction of records whose prediction matches its label.

    Each record must provide the keys "prediction" and "class_label". The
    two strings are compared after lower-casing, so the match is
    case-insensitive but otherwise exact.
    """
    matches = []
    for record in predictions:
        predicted = record["prediction"].lower()
        expected = record["class_label"].lower()
        matches.append(predicted == expected)
    # ``.item()`` converts the NumPy scalar into a plain Python float.
    accuracy = np.mean(matches)
    return accuracy.item()
def compute_shifted_logits_and_labels(
    logits: torch.Tensor, encodings, tokenizer, eoc_token_id
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build next-token logits and labels with the prompt prefix masked.

    The labels start as a copy of ``encodings["input_ids"]``. Padding tokens,
    EOC tokens, and every token up to and including the last
    ``tokenizer.eos_token_id`` of each row are replaced with -100 (the
    default ``ignore_index`` of ``torch.nn.functional.cross_entropy``), so
    only the tokens after the separator remain as targets. Logits and labels
    are then shifted by one position, so that element ``n`` of
    ``shift_logits`` is scored against token ``n + 1`` of the original
    sequence.

    Args:
        logits: Model outputs; the last two dimensions are treated as
            (sequence position, vocabulary).
        encodings: Mapping holding an "input_ids" tensor, indexed here as
            [batch, sequence].
        tokenizer: Object exposing ``pad_token_id`` and ``eos_token_id``.
        eoc_token_id: Token id that is excluded from the labels.

    Returns:
        Tuple containing two elements:
        shift_logits: ``logits`` without its last sequence position, i.e.
            shape [batch_size, seq_len - 1, vocab_size].
        shift_labels: an integer Tensor of shape [batch_size, seq_len - 1].

    Raises:
        ValueError: If a row contains no ``eos_token_id`` once padding and
            EOC tokens have been masked (this also happens when the pad or
            EOC id equals the EOS id, because those are masked first).
    """
    labels = encodings["input_ids"].clone()

    # Convert padding and EOC tokens to -100 so they are ignored in loss.
    labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == eoc_token_id] = -100

    # Convert all tokens in the prefix, up to and including the separator,
    # to -100 so they are ignored in loss.
    eos_token_id = tokenizer.eos_token_id
    for idx, row in enumerate(labels.tolist()):
        # Search *from the right*: the first non-padding token of the
        # sequence will also be eos_token (bos_token and eos_token are the
        # same for the tokenizer), so the separator is the last occurrence.
        try:
            offset_from_end = row[::-1].index(eos_token_id)
        except ValueError:
            raise ValueError(
                f"No eos token (id {eos_token_id}) found in batch element {idx}."
            ) from None
        # Use a non-negative position. With a negative index, a separator in
        # the final position would give the slice ``[:0]`` and mask nothing.
        last_eos = len(row) - 1 - offset_from_end
        labels[idx, : last_eos + 1] = -100

    # Shift so that tokens < n predict n.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return shift_logits, shift_labels
def compute_per_sample_probs(
    encodings, tokenizer, logits: torch.Tensor, eoc_token_id
) -> torch.Tensor:
    """Compute, per batch element, the probability of its target tokens.

    Assumes <eos token> is used to separate inputs from targets in the
    prompt text. The logits are turned into probabilities with a softmax
    over the token dimension, and the probabilities of all unmasked target
    tokens of a sample are multiplied together.

    Args:
        encodings: Mapping holding the "input_ids" tensor of the batch.
        tokenizer: Object exposing ``pad_token_id`` and ``eos_token_id``.
        logits: Model outputs for the batch; the softmax is taken over
            dimension 2.
        eoc_token_id: Token id that is excluded from the targets.

    Returns:
        A tensor with one probability per batch element, on the same device
        as ``logits``.

    Raises:
        AssertionError: If any batch element has no unmasked target token.
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )
    batch_size = len(shift_labels)

    # Tuple of tensors for unmasked label tokens. The first element of the
    # tuple contains the batch indices; the second element contains the
    # sequence indices.
    unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
    # Tensor where the i^th element is the token_id corresponding to the i^th
    # element of unmasked_indices.
    unmasked_token_ids = shift_labels[unmasked_indices]

    # 2-D tensor with one row [batch_idx, sequence_position, token_id] per
    # unmasked token.
    target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
    target_idxs = target_idxs.to(shift_logits.device)

    # Sanity check that every element in batch has at least one unmasked
    # target token. ``minlength`` is required: without it the count vector
    # stops at the largest batch index present, so trailing elements with no
    # targets would go unnoticed. Raised explicitly (rather than via
    # ``assert``) so the check is not stripped under ``python -O``.
    targets_per_sample = torch.bincount(target_idxs[:, 0], minlength=batch_size)
    if not torch.all(targets_per_sample != 0):
        raise AssertionError(
            "At least one element in batch has no unmasked target tokens."
        )

    # Renormalize over tokens to make sure they are proper probabilities via
    # softmax over the token dimension.
    shift_probs = torch.nn.functional.softmax(shift_logits, 2)

    # Compute the probability of the target sequence (as the product of the
    # probability of the individual tokens in the sequence).
    target_probs = torch.ones(batch_size, device=shift_logits.device)
    for i, j, k in target_idxs:
        target_probs[i] *= shift_probs[i, j, k]
    return target_probs
def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
    """Compute the mean cross-entropy over target tokens per batch element.

    Assumes <eos token> is used to separate inputs from targets in the
    prompt text.

    Args:
        encodings: Mapping holding the "input_ids" tensor of the batch.
        tokenizer: Object exposing ``pad_token_id`` and ``eos_token_id``.
        logits: Model outputs for the batch; the last dimension is the
            vocabulary.
        eoc_token_id: Token id that is excluded from the targets.

    Returns:
        A CPU tensor of shape [batch_size] with the loss summed over the
        unmasked tokens of each element and divided by their count. An
        element with no unmasked token yields NaN (0 / 0).
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )
    device = shift_logits.device

    # Loss is computed token-wise, on Tensors of shape
    # [batch_size * (seq_len - 1), vocab_size]
    # and returns a loss tensor of shape
    # [batch_size * (seq_len - 1)]. Most of the tokens will be masked
    # in this computation.
    loss = torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1).to(device),
        reduction="none",
    )

    # Reshape to [batch_size, seq_len - 1]
    loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()

    # target_mask is True for tokens we want included in the loss, and False
    # for tokens that should be ignored. It is moved to the CPU once and
    # reused for both the masking and the token count, so every operand of
    # the final division lives on the same device as ``loss``.
    target_mask = (shift_labels != -100).cpu()
    loss = loss * target_mask

    # Compute per-element loss: sum loss over all (unmasked) tokens and
    # divide by the number of unmasked tokens to obtain a tensor of
    # shape [batch_size,]
    return loss.sum(dim=1) / target_mask.sum(dim=1).float()