# Unigram-Watermark / gptwm.py
# Xuandong's picture
# init
# 6a20eb3
import hashlib
from typing import List, Tuple

import numpy as np
import torch
from scipy.stats import norm
from transformers import LogitsWarper
class GPTWatermarkBase:
    """
    Base class for watermarking distributions with fixed-group green-listed tokens.

    A single green list is drawn once, at construction time, by shuffling a
    boolean mask with a generator seeded from ``watermark_key``. The same key,
    fraction and vocabulary size therefore always yield the same green list.

    Args:
        fraction: The fraction of the distribution to be green-listed. Must lie in [0, 1].
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.

    Attributes set by ``__init__``:
        green_list_mask: 1-D float32 tensor of length ``vocab_size``; 1.0 marks a green token id, 0.0 a red one.
        strength: The ``strength`` argument, stored unchanged.
        fraction: The ``fraction`` argument, stored unchanged.

    Raises:
        ValueError: If ``fraction`` is not within [0, 1].
    """

    def __init__(self, fraction: float = 0.5, strength: float = 2.0, vocab_size: int = 50257, watermark_key: int = 0):
        # Outside [0, 1] one of the list multiplications below receives a
        # negative count (which Python treats as zero), so the mask would
        # silently end up with a length different from vocab_size.
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")

        rng = np.random.default_rng(self._hash_fn(watermark_key))
        num_green = int(fraction * vocab_size)
        # Keep this exact construction and shuffle: the resulting permutation
        # defines the green list, so changing it would invalidate existing keys.
        mask = np.array([True] * num_green + [False] * (vocab_size - num_green))
        rng.shuffle(mask)
        self.green_list_mask = torch.tensor(mask, dtype=torch.float32)
        self.strength = strength
        self.fraction = fraction

    @staticmethod
    def _hash_fn(x: int) -> int:
        """
        Map an integer key to an unsigned 32-bit seed.

        The key is converted to ``np.int64``, its in-memory bytes are hashed
        with SHA-256, and the first four digest bytes are read as a
        little-endian integer.

        solution from https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits
        """
        x = np.int64(x)
        return int.from_bytes(hashlib.sha256(x).digest()[:4], 'little')
class GPTWatermarkLogitsWarper(GPTWatermarkBase, LogitsWarper):
    """
    LogitsWarper for watermarking distributions with fixed-group green-listed tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        """
        Add the watermark to the logits and return new logits.

        Args:
            input_ids: Unused; the green list does not depend on the context.
            scores: Logits to bias. The 1-D mask is broadcast against them, so
                the last dimension must match the mask length (``vocab_size``).

        Returns:
            A new tensor equal to ``scores`` plus ``strength`` on every
            green-listed token, with the device and dtype of ``scores``.
        """
        # The mask is float32. Moving only the device would let type promotion
        # turn half-precision logits into float32, so the offset is cast to the
        # dtype of the scores as well.
        watermark = (self.strength * self.green_list_mask).to(device=scores.device, dtype=scores.dtype)
        return scores + watermark
class GPTWatermarkDetector(GPTWatermarkBase):
    """
    Class for detecting watermarks in a sequence of tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _z_score(num_green: int, total: int, fraction: float) -> float:
        """
        Calculate and return the z-score of the number of green tokens in a sequence.

        Args:
            num_green: Number of green tokens observed.
            total: Number of tokens scored.
            fraction: Expected proportion of green tokens without a watermark.

        Returns:
            ``(num_green - fraction * total) / sqrt(fraction * (1 - fraction) * total)``,
            or NaN when ``total`` is 0, where the score is undefined.
        """
        if total == 0:
            # 0/0 already evaluated to NaN here, but through a NumPy
            # RuntimeWarning; return the same value explicitly instead.
            return float("nan")
        return (num_green - fraction * total) / np.sqrt(fraction * (1 - fraction) * total)

    @staticmethod
    def _compute_tau(m: int, N: int, alpha: float) -> float:
        """
        Compute the threshold tau for the dynamic thresholding.

        The standard normal quantile at ``1 - alpha`` is scaled by the
        finite-population correction ``sqrt(1 - (m - 1) / (N - 1))``.

        Args:
            m: The number of unique tokens in the sequence.
            N: Vocabulary size. ``N == 1`` divides by zero.
            alpha: The false positive rate to control.

        Returns:
            The threshold tau.
        """
        factor = np.sqrt(1 - (m - 1) / (N - 1))
        tau = factor * norm.ppf(1 - alpha)
        return tau

    def detect(self, sequence: List[int]) -> Tuple[float, List[bool], int, int]:
        """
        Detect the watermark in a sequence of tokens.

        Args:
            sequence: Token ids to score; repeated ids are counted each time.

        Returns:
            A 4-tuple ``(z_score, green_tokens_mask, green_tokens, total)``:
            the z value, one boolean per token telling whether it is
            green-listed, the number of green tokens, and the sequence length.
        """
        # One pass yields both the per-token flags and, by summing them, the count.
        green_tokens_mask = [bool(self.green_list_mask[i]) for i in sequence]
        green_tokens = sum(green_tokens_mask)
        total = len(green_tokens_mask)
        return self._z_score(green_tokens, total, self.fraction), green_tokens_mask, green_tokens, total

    def unidetect(self, sequence: List[int]) -> float:
        """Detect the watermark in a sequence of tokens and return the z value. Just for unique tokens."""
        unique_tokens = set(sequence)
        green_tokens = sum(bool(self.green_list_mask[i]) for i in unique_tokens)
        return self._z_score(green_tokens, len(unique_tokens), self.fraction)

    def dynamic_threshold(self, sequence: List[int], alpha: float, vocab_size: int) -> Tuple[bool, float]:
        """
        Dynamic thresholding for watermark detection.

        Args:
            sequence: Token ids to score; only distinct ids are considered.
            alpha: The false positive rate to control.
            vocab_size: Vocabulary size used to compute the threshold.

        Returns:
            ``(is_watermarked, z_score)``. ``is_watermarked`` is True if the
            z-score over the unique tokens exceeds the threshold, False otherwise.
        """
        # Deduplicate once and reuse the set for both the score and the threshold.
        unique_tokens = set(sequence)
        z_score = self.unidetect(unique_tokens)
        tau = self._compute_tau(len(unique_tokens), vocab_size, alpha)
        # bool() turns the NumPy comparison result into a plain Python bool.
        return bool(z_score > tau), z_score