Spaces:
Sleeping
Sleeping
File size: 5,666 Bytes
ede461a 1d02824 b551628 e6ef189 11208e8 6c6da17 fe1b9ba e6ef189 da99fed e6ef189 ede461a e6ef189 0d3784b e6ef189 0d3784b e6ef189 0d3784b e6ef189 0d3784b e6ef189 0d3784b e6ef189 f12f776 e6ef189 fe1b9ba e6ef189 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload
from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult
# A dataset row: a mapping from column name to an arbitrary (possibly nested) value.
Row = dict[str, Any]
# Generic element type used in the signatures of the `batched` helper.
T = TypeVar("T")
# Number of rows grouped into each batch by `presidio_scan_entities`.
BATCH_SIZE = 1
# The text extracted from each scanned column is cut to this many characters
# before being analyzed (see `presidio_scan_entities`).
MAX_TEXT_LENGTH = 500
# Module-level Presidio engines, instantiated once when the module is imported.
# `batch_analyzer` wraps `analyzer` and is the engine handed to `analyze`.
analyzer = AnalyzerEngine()
batch_analyzer = BatchAnalyzerEngine(analyzer)
class PresidioEntity(TypedDict):
    """One entity reported by the analyzer for a given row and column."""

    # The matched substring (after passing through `mask`).
    text: str
    # The recognizer's entity type for the match.
    type: str
    # Index of the row the match was found in.
    row_idx: int
    # Name of the scanned column the match was found in.
    column_name: str
@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Lazily split ``it`` into consecutive lists of at most ``n`` items.

    Args:
        it: Any iterable; it is consumed lazily, one batch at a time.
        n: Maximum number of items per batch. The last batch may be shorter.
        with_indices: When true, each yielded value is a pair
            ``(positions, batch)`` where ``positions`` lists the 0-based
            positions of the batch items within ``it``.

    Yields:
        Either ``batch`` or ``(positions, batch)`` depending on
        ``with_indices``. Nothing is yielded once a batch comes back empty.
    """
    source = iter(it)
    next_position = 0
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            # An empty chunk means the source is exhausted (or n is 0).
            return
        if with_indices:
            positions = list(range(next_position, next_position + len(chunk)))
            next_position += len(chunk)
            yield positions, chunk
        else:
            yield chunk
def mask(text: str) -> str:
    """Return ``text`` unchanged.

    Masking is intentionally disabled for the demo. The disabled
    implementation is kept below for reference: for each space-separated
    word it keeps a short prefix and replaces the remaining ASCII
    alphanumeric characters with ``*``.
    """
    # Disabled masking implementation:
    # return " ".join(
    #     word[: min(2, len(word) - 1)] + re.sub("[A-Za-z0-9]", "*", word[min(2, len(word) - 1) :])
    #     for word in text.split(" ")
    # )
    return text
def get_strings(row_content: Any) -> str:
    """Collect every string found in a (possibly nested) cell value.

    Strings are returned as-is. Lists are walked recursively, and dicts are
    walked through their values, except that a dict containing a ``"src"``
    key is skipped entirely (it could be an image or audio entry). The
    non-empty strings found are joined with newlines. Any other kind of
    value contributes an empty string.
    """
    if isinstance(row_content, str):
        return row_content

    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""  # could be image or audio
        children = list(row_content.values())
    elif isinstance(row_content, list):
        children = row_content
    else:
        return ""

    pieces = []
    for child in children:
        piece = get_strings(child)
        if piece:
            # Empty results are dropped so they do not add blank lines.
            pieces.append(piece)
    return "\n".join(pieces)
def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Analyze ``texts``, reusing results remembered from the previous call.

    Args:
        batch_analyzer: Engine whose ``analyze_iterator`` is called on the
            texts that are not present in ``cache``.
        texts: Texts to analyze. Any iterable is accepted, including one-shot
            iterators, because it is materialized into a list first.
        language: Forwarded to ``analyze_iterator``.
        score_threshold: Forwarded to ``analyze_iterator``.
        cache: Mapping from text to its results. It is mutated in place: after
            the call it holds exactly the texts of this call and their results.

    Returns:
        One list of recognizer results per input text, in input order.
    """
    # `texts` is traversed three times below (filtering, assembling, caching).
    # Materialize it so a generator argument is not exhausted by the first pass.
    texts = list(texts)
    not_cached_results = iter(
        batch_analyzer.analyze_iterator(
            (text for text in texts if text not in cache), language=language, score_threshold=score_threshold
        )
    )
    # Walk the inputs in order, taking cached results where available and the
    # next freshly computed result otherwise.
    results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
    # cache the last results
    cache.clear()
    cache.update(zip(texts, results))
    return results
def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Run the analyzer over a batch of rows and list the entities found.

    For every row and every scanned column, a text is built from a short
    header naming the column description followed by the column's value.
    Results that start inside that header are discarded, so only matches in
    the actual value are reported.

    Args:
        batch_analyzer: Engine used to analyze the generated texts.
        batch: Rows, each mapping scanned column names to their text.
        indices: Row indices, paired positionally with ``batch``.
        scanned_columns: Names of the columns to analyze in each row.
        columns_descriptions: Descriptions paired positionally with
            ``scanned_columns``.
        cache: Optional result cache passed through to
            ``_simple_analyze_iterator_cache``; a fresh dict is used if omitted.

    Returns:
        One ``PresidioEntity`` per retained recognizer result, ordered by row,
        then column, then analyzer output order.
    """
    if cache is None:
        cache = {}
    columns_per_row = len(scanned_columns)

    # Texts are laid out row-major: all columns of row 0, then row 1, and so on.
    texts = []
    for example in batch:
        for column_name, columns_description in zip(scanned_columns, columns_descriptions):
            texts.append(f"The following is {columns_description} data:\n\n{example[column_name] or ''}")

    all_results = _simple_analyze_iterator_cache(
        batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache
    )

    entities: list[PresidioEntity] = []
    rows = zip(indices, batched(all_results, columns_per_row))
    for i, (row_idx, recognizer_row_results) in enumerate(rows):
        columns = zip(scanned_columns, columns_descriptions, recognizer_row_results)
        for j, (column_name, columns_description, recognizer_results) in enumerate(columns):
            header_length = len(f"The following is {columns_description} data:\n\n")
            for recognizer_result in recognizer_results:
                if recognizer_result.start < header_length:
                    # The match begins inside the header, not the column value.
                    continue
                matched = texts[i * columns_per_row + j][recognizer_result.start : recognizer_result.end]
                entities.append(
                    PresidioEntity(
                        text=mask(matched),
                        type=recognizer_result.entity_type,
                        row_idx=row_idx,
                        column_name=column_name,
                    )
                )
    return entities
def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily scan ``rows`` and yield the entities found in the scanned columns.

    Each row is reduced to its scanned columns, whose values are flattened to
    text with ``get_strings`` and cut to ``MAX_TEXT_LENGTH`` characters. Rows
    are then analyzed in batches of ``BATCH_SIZE`` with the module-level
    ``batch_analyzer``. A single result cache is shared across all batches.

    Args:
        rows: Rows to scan; consumed lazily.
        scanned_columns: Names of the columns to analyze in each row.
        columns_descriptions: Descriptions paired positionally with
            ``scanned_columns``.

    Yields:
        The entities returned by ``analyze`` for each batch, in order.
    """

    def keep_scanned_columns(row: Row) -> dict[str, str]:
        # Flatten each scanned cell to text and truncate it.
        return {name: get_strings(row[name])[:MAX_TEXT_LENGTH] for name in scanned_columns}

    shared_cache: dict[str, list[RecognizerResult]] = {}
    prepared_rows = map(keep_scanned_columns, rows)
    for indices, batch in batched(prepared_rows, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=batch,
            indices=indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=shared_cache,
        )
def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names of the columns whose feature contains a string value.

    A column qualifies when visiting its feature with ``_visit`` encounters at
    least one ``Value`` whose dtype is ``"string"``. Names are returned as
    ``str`` in the iteration order of ``features``.
    """

    def contains_string(feature: FeatureType) -> bool:
        hits: list[bool] = []

        def record(node: FeatureType) -> None:
            # Returns None so the visitor leaves the feature untouched.
            if isinstance(node, Value) and node.dtype == "string":
                hits.append(True)

        _visit(feature, record)
        return bool(hits)

    return [str(column) for column, feature in features.items() if contains_string(feature)]
def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column by its name and the keys of any dict-typed sub-features.

    Returns ``column_name`` alone when visiting ``feature`` encounters no
    dict; otherwise returns ``"<column_name> (with <key>, <key>, ...)"`` with
    the keys in the order the visitor reported them.
    """
    field_names: list[str] = []

    def collect_keys(node: FeatureType) -> None:
        # Only dict nodes carry named sub-fields.
        if isinstance(node, dict):
            field_names.extend(node)

    _visit(feature, collect_keys)
    if not field_names:
        return column_name
    return f"{column_name} (with {', '.join(field_names)})"
|