"""Find near-duplicates using minhash.
# Code forked from
# https://github.com/bigcode-project/bigcode-dataset/blob/main/near_deduplication/minhash_deduplication.py
# under the Apache 2.0 License.
"""
import gc
import hashlib
import re
import struct
from collections import defaultdict
from itertools import tee
from typing import Iterable, List

import numpy as np
from scipy.integrate import quad as integrate
from tqdm import tqdm

SEED = 42
WHITESPACE = re.compile(r'\s+')
RNG = np.random.RandomState(SEED)
# Hash values are kept in the 32-bit range; the Mersenne prime 2^61 - 1 is the
# modulus of the universal hash family used for the permutations.
MAX_HASH = np.uint64((1 << 32) - 1)
MERSENNE_PRIME = np.uint64((1 << 61) - 1)


def _ngrams(sequence: List[str], n: int, min_ngram_size: int) -> Iterable:
    """Directly taken from the nltk package to avoid a dependency.

    Args:
        sequence: The sequence of items to be n-grammed.
        n: The order of the n-grams to be extracted.
        min_ngram_size: The minimum sequence length required to emit any n-grams.

    Returns:
        The n-grams generated from the sequence.
    """
if len(sequence) < min_ngram_size:
return []
ngram_size = min(n, len(sequence))
iterables = tee(sequence, ngram_size)
for i, sub_iterable in enumerate(iterables):
for _ in range(i):
next(sub_iterable, None)
return zip(*iterables)
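
# Example: _ngrams(['a', 'b', 'c', 'd'], n=3, min_ngram_size=1) yields the
# trigrams ('a', 'b', 'c') and ('b', 'c', 'd'); a sequence shorter than
# min_ngram_size yields nothing.
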
def _sha1_hash32(data: bytes) -> int:
"""Directly taken from datasketch package to avoid dependency."""
return struct.unpack('<I', hashlib.sha1(data).digest()[:4])[0]
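
# Example: _sha1_hash32(b'abc') is deterministic and falls in [0, 2**32),
# since it unpacks the first four bytes of the SHA-1 digest as a
# little-endian unsigned int.
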
def _embed_func(
content: str,
num_perm: int,
ngram_size: int,
hashranges: list[tuple[int, int]],
permutations: np.ndarray,
min_ngram_size: int,
) -> list[bytes]:
"""Combined with some datasketch code to better parallelize computation.
Args:
content The content to be embedded.
idx The index of the content.
num_perm The number of permutations.
ngram_size The size of n-grams.
hashranges The ranges of hash values.
permutations The permutations for the minhash.
min_ngram_size The minimum size of n-grams.
Returns
The hash values in each range and the index.
"""
    # Start every slot of the signature at the maximum possible hash value.
    hashvalues = np.ones(num_perm, dtype=np.uint64) * MAX_HASH
    tokens = {' '.join(t) for t in _ngrams(WHITESPACE.split(content), ngram_size, min_ngram_size)}
    hv = np.array([_sha1_hash32(token.encode('utf-8')) for token in tokens], dtype=np.uint64)
    # Apply the universal hash h(x) = (a * x + b) % MERSENNE_PRIME for every
    # permutation and keep the per-permutation minimum over all tokens.
    a, b = permutations
    phv = np.bitwise_and(((hv * np.tile(a, (len(hv), 1)).T).T + b) % MERSENNE_PRIME, MAX_HASH)
    hashvalues = np.vstack([phv, hashvalues]).min(axis=0)
    # Pack the signature into one byte string per LSH band.
    Hs: list[bytes] = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges]
    return Hs
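
# Example: with num_perm=256 split into 32 bands of 8 rows, _embed_func
# returns 32 byte strings, each packing 8 uint64 minima; two documents
# collide in a band exactly when the corresponding byte strings are equal.
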
def _optimal_param(threshold: float,
num_perm: int,
false_positive_weight: float = 0.5,
false_negative_weight: float = 0.5) -> tuple[int, int]:
"""Find optimal `MinHashLSH` parameter that minimizes the weighted sum of false pos and false neg.
Taken from datasketch.
Args
threshold The threshold for similarity.
num_perm The number of permutations.
false_positive_weight The weight of false positive.
false_negative_weight The weight of false negative.
Returns
The optimal `b` and `r` parameters.
The number of bands, and the number of rows per band respectively.
"""
def false_positive_probability(threshold: float, b: int, r: int) -> float:
"""Source: `datasketch.lsh`."""
def proba(s: float) -> float:
return 1 - (1 - s**float(r))**float(b)
a, _ = integrate(proba, 0.0, threshold)
return a
def false_negative_probability(threshold: float, b: int, r: int) -> float:
"""Source: `datasketch.lsh`."""
def proba(s: float) -> float:
return 1 - (1 - (1 - s**float(r))**float(b))
a, _ = integrate(proba, threshold, 1.0)
return a
min_error = float('inf')
opt = (0, 0)
for b in range(1, num_perm + 1):
max_r = int(num_perm / b)
for r in range(1, max_r + 1):
fp = false_positive_probability(threshold, b, r)
fn = false_negative_probability(threshold, b, r)
error = fp * false_positive_weight + fn * false_negative_weight
if error < min_error:
min_error = error
opt = (b, r)
return opt
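
# Example: _optimal_param(0.5, 128) searches every (b, r) with b * r <= 128;
# raising false_negative_weight shifts the optimum toward more bands
# (fewer false negatives at the cost of more false positives).
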
class UnionFind:
"""Union find data structure."""
def __init__(self) -> None:
self.parent: dict[int, int] = {}
def find(self, x: int) -> int:
"""Find the parent of the node."""
if x not in self.parent:
self.parent[x] = x
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]
def union(self, x: int, y: int) -> None:
"""Union two nodes."""
px = self.find(x)
py = self.find(y)
self.parent[px] = self.parent[py] = min(px, py)
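
# Example: after uf.union(0, 1) and uf.union(1, 2), uf.find(2) == 0, since
# every member of a cluster resolves to the smallest id in that cluster.
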
def find_clusters(data: Iterable[str],
ngram_size: int = 5,
num_perm: int = 256,
threshold: float = 0.7,
min_ngram_size: int = 1) -> Iterable[int]:
"""Deduplicates documents and returns cluster ids."""
uf = UnionFind()
B, R = _optimal_param(threshold, num_perm)
HASH_RANGES: list[tuple[int, int]] = [(i * R, (i + 1) * R) for i in range(B)]
HASH_TABLES: list[dict[bytes, set[int]]] = [defaultdict(set) for _ in range(B)]
    # Draw one random (a, b) pair per permutation for the universal hash family.
PERMUTATIONS = np.array(
[(
RNG.randint(1, MERSENNE_PRIME, dtype=np.uint64),
RNG.randint(0, MERSENNE_PRIME, dtype=np.uint64),
) for _ in range(num_perm)],
dtype=np.uint64,
).T
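    # PERMUTATIONS has shape (2, num_perm): row 0 holds the multipliers `a`
    # and row 1 the offsets `b` unpacked inside _embed_func.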
# Fingerprinting.
embedded: list[tuple[int, list[bytes]]] = []
for key, content in tqdm(enumerate(data), dynamic_ncols=True, desc='Fingerprinting...'):
hashes = _embed_func(
content,
num_perm=num_perm,
hashranges=HASH_RANGES,
ngram_size=ngram_size,
permutations=PERMUTATIONS,
min_ngram_size=min_ngram_size)
embedded.append((key, hashes))
    # Batching here only chunks the progress reporting; every (key, Hs) pair
    # still lands in all of the band hash tables.
    batch_size: int = 10000
for i in tqdm(
range(0, len(embedded), batch_size), dynamic_ncols=True, desc='Computing hash collisions...'):
batch = embedded[i:i + batch_size]
for (key, Hs) in batch:
for H, hashtable in zip(Hs, HASH_TABLES):
hashtable[H].add(key)
for table in tqdm(HASH_TABLES, dynamic_ncols=True, desc='Clustering...'):
for cluster in table.values():
if len(cluster) <= 1:
continue
idx = min(cluster)
for x in cluster:
uf.union(x, idx)
    # Freeze current objects and disable the cyclic GC while the recursive
    # find() calls run, avoiding collector overhead on a large, stable heap.
    gc.freeze()
    gc.disable()
cluster_ids = [uf.find(i) for i in range(len(embedded))]
gc.enable()
gc.collect()
return cluster_ids
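

if __name__ == '__main__':
    # Minimal usage sketch (assumed entry point, not part of the forked
    # module): exact duplicates always produce identical signatures, so the
    # first two documents are guaranteed to share a cluster id, while the
    # third should stay in its own cluster.
    docs = [
        'the quick brown fox jumps over the lazy dog',
        'the quick brown fox jumps over the lazy dog',
        'a completely different document about nothing in particular',
    ]
    print(find_clusters(docs))  # expected: [0, 0, 2]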