Kpenciler's picture
Upload 53 files
88435ed verified
raw
history blame
2.13 kB
from typing import Any, TypeVar
from neollm.utils.utils import cprint
Immutable = tuple[Any, ...] | str | int | float | bool
_T = TypeVar("_T")
_TD = TypeVar("_TD")
def _to_immutable(x: Any) -> Immutable:
"""list, dictをtupleに変換して, setに格納できるようにする
Args:
x (Any): 要素
Returns:
Immutable: Immutableな要素(dict, listはtupleに変換)
"""
if isinstance(x, list):
return tuple(map(_to_immutable, x))
if isinstance(x, dict):
return tuple((key, _to_immutable(value)) for key, value in sorted(x.items()))
if isinstance(x, (set, frozenset)):
return tuple(sorted(map(_to_immutable, x)))
if isinstance(x, (str, int, float, bool)):
return x
cprint("_to_immutable: not supported: 無理やりstr(*)", color="yellow", background=True)
return str(x)
def _remove_duplicate(arr: list[_T | None]) -> list[_T]:
    """Drop empty values and duplicates from *arr*, keeping first occurrences.

    An element is skipped when it is ``None`` or otherwise falsy (``""``, ``0``,
    ``[]``, ``{}``, ...). Duplicates are detected by comparing the
    ``_to_immutable`` form of each element, so structurally equal lists/dicts
    count as the same value.

    Args:
        arr (list[Any]): Input list.

    Returns:
        list[Any]: The remaining elements, in their original order.
    """
    kept: list[_T] = []
    seen: set[Immutable] = set()
    for item in arr:
        # ``None`` is falsy as well, so one truthiness check covers both cases.
        if not item:
            continue
        fingerprint = _to_immutable(item)
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        kept.append(item)
    return kept
def get_entity(arr: list[_T | None], default: _TD, index: int | None = None) -> _T | _TD:
    """Pick a single entity out of *arr*.

    *arr* is first cleaned with ``_remove_duplicate`` (empty values and
    duplicates dropped). Then:

    - no candidates left -> *default*
    - exactly one candidate -> that candidate (``index`` is ignored)
    - several candidates and ``index`` given -> ``candidates[index]``
      (raises ``IndexError`` when out of range)
    - several candidates and no ``index`` -> a warning and the candidate list
      are printed, and the first candidate is returned

    Args:
        arr (list[Any]): Candidate values.
        default (Any): Value returned when no candidate remains.
        index (int | None, optional): Which candidate to take when there are several. Defaults to None.

    Returns:
        Any: The selected entity, or *default*.
    """
    candidates = _remove_duplicate(arr)
    if not candidates:
        return default
    if len(candidates) > 1:
        if index is not None:
            return candidates[index]
        # Ambiguous result: surface it, then fall back to the first candidate.
        cprint("get_entity: not unique", color="yellow", background=True)
        cprint(candidates, color="yellow", background=True)
    return candidates[0]