import re

from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload

from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult

Row = dict[str, Any]
T = TypeVar("T")
BATCH_SIZE = 1  # number of rows sent to Presidio per analysis batch
MAX_TEXT_LENGTH = 500  # each cell is truncated to this many characters before scanning

analyzer = AnalyzerEngine()
batch_analyzer = BatchAnalyzerEngine(analyzer)


class PresidioEntity(TypedDict):
    text: str
    type: str
    row_idx: int
    column_name: str


@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    it, indices = iter(it), count()
    while batch := list(islice(it, n)):
        yield (list(islice(indices, len(batch))), batch) if with_indices else batch
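
# Example: with with_indices=True, batched yields (indices, batch) pairs, e.g.
#   list(batched("abcde", 2, with_indices=True))
#   -> [([0, 1], ["a", "b"]), ([2, 3], ["c", "d"]), ([4], ["e"])]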


def mask(text: str) -> str:
    return text  # don't apply mask for demo
    # return " ".join(
    #     word[: min(2, len(word) - 1)] + re.sub("[A-Za-z0-9]", "*", word[min(2, len(word) - 1) :])
    #     for word in text.split(" ")
    # )


def get_strings(row_content: Any) -> str:
    if isinstance(row_content, str):
        return row_content
    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""  # could be image or audio
        row_content = list(row_content.values())
    if isinstance(row_content, list):
        str_items = (get_strings(row_content_item) for row_content_item in row_content)
        return "\n".join(str_item for str_item in str_items if str_item)
    return ""


def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    # only send texts that were not analyzed in the previous call to Presidio
    not_cached_results = iter(
        batch_analyzer.analyze_iterator(
            (text for text in texts if text not in cache), language=language, score_threshold=score_threshold
        )
    )
    results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
    # cache the last results
    cache.clear()
    cache.update(dict(zip(texts, results)))
    return results


def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    cache = {} if cache is None else cache
    # prefix each cell with its column description to give Presidio some context
    texts = [
        f"The following is {columns_description} data:\n\n{example[column_name] or ''}"
        for example in batch
        for column_name, columns_description in zip(scanned_columns, columns_descriptions)
    ]
    return [
        PresidioEntity(
            text=mask(texts[i * len(scanned_columns) + j][recognizer_result.start : recognizer_result.end]),
            type=recognizer_result.entity_type,
            row_idx=row_idx,
            column_name=column_name,
        )
        for i, row_idx, recognizer_row_results in zip(
            count(),
            indices,
            batched(
                _simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache),
                len(scanned_columns),
            ),
        )
        for j, column_name, columns_description, recognizer_results in zip(
            count(), scanned_columns, columns_descriptions, recognizer_row_results
        )
        for recognizer_result in recognizer_results
        # drop matches that fall inside the added prefix
        if recognizer_result.start >= len(f"The following is {columns_description} data:\n\n")
    ]


def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    cache: dict[str, list[RecognizerResult]] = {}
    rows_with_scanned_columns_only = (
        {column_name: get_strings(row[column_name])[:MAX_TEXT_LENGTH] for column_name in scanned_columns}
        for row in rows
    )
    for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=batch,
            indices=indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=cache,
        )


def get_columns_with_strings(features: Features) -> list[str]:
    columns_with_strings: list[str] = []
    for column, feature in features.items():
        str_column = str(column)
        with_string = False

        def classify(feature: FeatureType) -> None:
            nonlocal with_string
            if isinstance(feature, Value) and feature.dtype == "string":
                with_string = True

        _visit(feature, classify)
        if with_string:
            columns_with_strings.append(str_column)
    return columns_with_strings


def get_column_description(column_name: str, feature: FeatureType) -> str:
    nested_fields: list[str] = []

    def get_nested_field_names(feature: FeatureType) -> None:
        nonlocal nested_fields
        if isinstance(feature, dict):
            nested_fields += list(feature)

    _visit(feature, get_nested_field_names)
    return f"{column_name} (with {', '.join(nested_fields)})" if nested_fields else column_name