Spaces:
Runtime error
Runtime error
Push
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env +0 -6
- .gitattributes +0 -35
- .gitignore +1 -1
- Dockerfile +3 -0
- data/concept/lilac/profanity/concept.json +0 -0
- data/concept/lilac/profanity/sbert.pkl +0 -3
- data/concept/lilac/toxicity/cohere.pkl +0 -3
- data/concept/lilac/toxicity/concept.json +0 -0
- data/concept/lilac/toxicity/openai.pkl +0 -3
- data/concept/lilac/toxicity/sbert.pkl +0 -3
- data/concept/local/outerspace/cohere.pkl +0 -3
- data/concept/local/outerspace/concept.json +0 -188
- data/concept/local/outerspace/openai.pkl +0 -3
- data/concept/local/outerspace/sbert.pkl +0 -3
- data/datasets/local/spotify/data-00000-of-00001.parquet +0 -3
- data/datasets/local/spotify/manifest.json +0 -27
- data/datasets/local/spotify/text/.concepts/local/aliens/sbert-neg-100.pkl +0 -3
- data/datasets/local/spotify/text/.concepts/local/outer_space/sbert-neg-100.pkl +0 -3
- data/datasets/local/spotify/text/.concepts/local/outerspace/sbert-neg-100.pkl +0 -3
- data/datasets/local/spotify/text/.concepts/local/phone_addiction/sbert-neg-100.pkl +0 -3
- data/datasets/local/spotify/text/sbert/data-00000-of-00001.parquet +0 -3
- data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/data-00000-of-00001.parquet +0 -3
- data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/signal_manifest.json +0 -64
- data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.keys.pkl +0 -3
- data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.npy +0 -3
- data/datasets/local/spotify/text/sbert/signal_manifest.json +0 -37
- requirements.txt +1 -0
- src/concepts/concept.py +17 -8
- src/concepts/concept_test.py +0 -84
- src/concepts/db_concept_test.py +0 -606
- src/data/dataset_compute_signal_chain_test.py +0 -255
- src/data/dataset_compute_signal_test.py +0 -669
- src/data/dataset_duckdb.py +15 -13
- src/data/dataset_select_groups_test.py +0 -317
- src/data/dataset_select_rows_filter_test.py +0 -200
- src/data/dataset_select_rows_schema_test.py +0 -551
- src/data/dataset_select_rows_search_test.py +0 -393
- src/data/dataset_select_rows_sort_test.py +0 -904
- src/data/dataset_select_rows_udf_test.py +0 -404
- src/data/dataset_stats_test.py +0 -125
- src/data/dataset_test.py +0 -860
- src/data/dataset_utils.py +68 -34
- src/data/dataset_utils_test.py +0 -114
- src/data/sources/csv_source_test.py +0 -42
- src/data/sources/huggingface_source_test.py +0 -170
- src/data/sources/json_source_test.py +0 -74
- src/data/sources/pandas_source_test.py +0 -91
- src/data/sources/source_registry_test.py +0 -55
- src/data_loader_test.py +0 -74
- src/embeddings/embedding.py +18 -6
.env
CHANGED
@@ -26,9 +26,3 @@ DUCKDB_USE_VIEWS=0
|
|
26 |
# HF_USERNAME=
|
27 |
# The default repo to deploy to for a staging demo. Can be overridden by a command line flag.
|
28 |
# HF_STAGING_DEMO_REPO='HF_ORG/HF_REPO_NAME'
|
29 |
-
|
30 |
-
# HuggingFace demos: HuggingFace machine that runs the demo.
|
31 |
-
|
32 |
-
# To read private uploaded data from the server (running on HF spaces) for the demo.
|
33 |
-
# Get a token from https://huggingface.co/settings/tokens
|
34 |
-
# HF_ACCESS_TOKEN=
|
|
|
26 |
# HF_USERNAME=
|
27 |
# The default repo to deploy to for a staging demo. Can be overridden by a command line flag.
|
28 |
# HF_STAGING_DEMO_REPO='HF_ORG/HF_REPO_NAME'
|
|
|
|
|
|
|
|
|
|
|
|
.gitattributes
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
**/*.pyc
|
3 |
**/*.pyo
|
4 |
**/*.pyd
|
|
|
1 |
+
__pycache__/
|
2 |
**/*.pyc
|
3 |
**/*.pyo
|
4 |
**/*.pyd
|
Dockerfile
CHANGED
@@ -22,6 +22,9 @@ COPY /web/blueprint/build ./web/blueprint/build
|
|
22 |
# Copy python files.
|
23 |
COPY /src ./src/
|
24 |
|
|
|
|
|
|
|
25 |
CMD [ \
|
26 |
"gunicorn", "src.server:app", \
|
27 |
"--bind", "0.0.0.0:5432", \
|
|
|
22 |
# Copy python files.
|
23 |
COPY /src ./src/
|
24 |
|
25 |
+
# Copy the data files. We use glob so docker copy won't fail if the directory doesn't exist.
|
26 |
+
COPY /dat[a] ./data/
|
27 |
+
|
28 |
CMD [ \
|
29 |
"gunicorn", "src.server:app", \
|
30 |
"--bind", "0.0.0.0:5432", \
|
data/concept/lilac/profanity/concept.json
DELETED
The diff for this file is too large to render.
See raw diff
|
|
data/concept/lilac/profanity/sbert.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:647280d255e1a1fabff691683926fbb49dfaffe2f8151cf9913ec98816eef473
|
3 |
-
size 844427
|
|
|
|
|
|
|
|
data/concept/lilac/toxicity/cohere.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:670e81b8448ab0ee5161a42b523410b3af80c6ccce8003cae78edebb9d0981c4
|
3 |
-
size 9720631
|
|
|
|
|
|
|
|
data/concept/lilac/toxicity/concept.json
DELETED
The diff for this file is too large to render.
See raw diff
|
|
data/concept/lilac/toxicity/openai.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e15e8235c2152b1412a8e2dee3dcb94b23e95f1fde6fb60f01b876a832e46404
|
3 |
-
size 3678199
|
|
|
|
|
|
|
|
data/concept/lilac/toxicity/sbert.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:8ac8b304760c88242eb6c567e1af87fd87731a192308df8cf43b253e24d2b0ec
|
3 |
-
size 959111
|
|
|
|
|
|
|
|
data/concept/local/outerspace/cohere.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:30afc472c4beb1aabb24d5b8e633df6039ec532fd704d8318755e083592221f3
|
3 |
-
size 331736
|
|
|
|
|
|
|
|
data/concept/local/outerspace/concept.json
DELETED
@@ -1,188 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"namespace": "local",
|
3 |
-
"concept_name": "outerspace",
|
4 |
-
"type": "text",
|
5 |
-
"data": {
|
6 |
-
"da77c67f82524ce1a276593471fd530f": {
|
7 |
-
"label": true,
|
8 |
-
"text": "Fly me to the moon and let me play among the stars.",
|
9 |
-
"id": "da77c67f82524ce1a276593471fd530f"
|
10 |
-
},
|
11 |
-
"f73feff4be50410ab1ac468752d0301a": {
|
12 |
-
"label": true,
|
13 |
-
"text": "Space may be the final frontier but it's made in a Hollywood basement.",
|
14 |
-
"id": "f73feff4be50410ab1ac468752d0301a"
|
15 |
-
},
|
16 |
-
"0f0815ed04604209842d9e7b1e3538f8": {
|
17 |
-
"label": true,
|
18 |
-
"text": "We're just a speck of dust within the galaxy.",
|
19 |
-
"id": "0f0815ed04604209842d9e7b1e3538f8"
|
20 |
-
},
|
21 |
-
"2e41f637061e4ecb8b0d4e35abab9b63": {
|
22 |
-
"label": true,
|
23 |
-
"text": "In the darkest night, the stars shine bright and guide me to the moonlight.",
|
24 |
-
"id": "2e41f637061e4ecb8b0d4e35abab9b63"
|
25 |
-
},
|
26 |
-
"fb65845f4dc84da1b276de30967592e3": {
|
27 |
-
"label": true,
|
28 |
-
"text": "We'll be shooting star through time and space\r\n\r\n",
|
29 |
-
"id": "fb65845f4dc84da1b276de30967592e3"
|
30 |
-
},
|
31 |
-
"075534e3095b421687039291439b5524": {
|
32 |
-
"label": true,
|
33 |
-
"text": "Dreaming of love while cruising at high altitude \r\nDreaming of making love with you the way we should \r\nCloser to heaven. We're thirty thousand feet, up in the sky \r\nHere among the stars, our spirits will fly \r\n \r\nLeave all your worries as we soar over the clouds \r\nJet lag that's making you appear far from the crowd \r\nWhile we're suspended, locked in each others, sweet embrace \r",
|
34 |
-
"id": "075534e3095b421687039291439b5524"
|
35 |
-
},
|
36 |
-
"4bb656032d0d4f449bac8aa5f23c3e48": {
|
37 |
-
"label": true,
|
38 |
-
"text": " \r\nI don't know where I don't know why \r\nBut somehow back in time again \r\nI'm on the edge that you can see \r\n \r\nI'm not particular at night \r\nA single party calling me \r\nYou won't be tracking me by sight \r\n \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r\nAt the speed of light \r\n \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r",
|
39 |
-
"id": "4bb656032d0d4f449bac8aa5f23c3e48"
|
40 |
-
},
|
41 |
-
"4a6dda9001ea487991a1264e6a6c021b": {
|
42 |
-
"label": true,
|
43 |
-
"text": "Load redeem me, am I pure? \r\nAs pure as pure as heaven \r\nSent you money sent you flowers \r\nCould worship you for hours \r\nIn whose hands are we anyway? \r\n \r\nGo waiting for the stars \r\nTo come showering down \r\nFrom Moscow to Mars \r\nUniverse falling down \r\n \r\nYou got to look real hard \r\nIs it in your heart? \r\nYeah it's in there somewhere \r\nThe power wrapped in your palm \r",
|
44 |
-
"id": "4a6dda9001ea487991a1264e6a6c021b"
|
45 |
-
},
|
46 |
-
"9aacce9311d24cb182aee783ca313c58": {
|
47 |
-
"label": true,
|
48 |
-
"text": "Growth is our future resource. \r\n \r\nJoin the state of the universe, \r\nUnited state of peace. \r\nJoin the state of the universe, \r\nUnited state of peace. \r\n \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r",
|
49 |
-
"id": "9aacce9311d24cb182aee783ca313c58"
|
50 |
-
},
|
51 |
-
"313b8f9ce9164791b04ead82e6adb40f": {
|
52 |
-
"label": false,
|
53 |
-
"text": " \r\nEven I could see a light if it wasn't for the nights \r\n(Even I could see a light I think that I could make it) \r\nGuess my future would look bright if it wasn't for the nights\r\n\r\n",
|
54 |
-
"id": "313b8f9ce9164791b04ead82e6adb40f"
|
55 |
-
},
|
56 |
-
"b9c587b74f084ef4917e7a52cd5c5cbe": {
|
57 |
-
"label": true,
|
58 |
-
"text": "Yea I think I know \r\nI really hate it when it gets too slow \r\nI gotta try and keep myself amused \r\nI love the way my rocket purrs \r\nI like it best when I see blurs \r\nYou gotta love to watch me light my fuse \r\n \r\nNo more lookin' back to yesterday \r\nI got the thing to blow us both away \r\nAll I need is you to navigate \r\nSo come and ride my Rocket 88 \r\n \r",
|
59 |
-
"id": "b9c587b74f084ef4917e7a52cd5c5cbe"
|
60 |
-
},
|
61 |
-
"6f844600cc024117a22287557130a17b": {
|
62 |
-
"label": false,
|
63 |
-
"text": "They came flying from far away, now I'm under their spell \r\nI love hearing the stories that they tell \r\nThey've seen places beyond my land and they've found new horizons \r\nThey speak strangely but I understand \r\n \r\nAnd I dream I'm an eagle \r\nAnd I dream I can spread my wings \r\nFlying high, high, I'm a bird in the sky \r\nI'm an eagle that rides on the breeze \r",
|
64 |
-
"id": "6f844600cc024117a22287557130a17b"
|
65 |
-
},
|
66 |
-
"8cddcff34f894743872ecc02262c2375": {
|
67 |
-
"label": true,
|
68 |
-
"text": "Fire! I can see it burning so brightly \r\nFire! I can feel it calling out to me \r\nAnd as the sun goes down \r\nIt starts to paint a picture \r\n \r\nOf an ancient town \r\nSo far away, across the endless sea \r\nLead me to the light \r\nAnd take me to the edge of heaven \r\n \r\nI'm standing in the night \r\nLooking for the edge of heaven \r\nWe'll be touching the edge of heaven \r\nTime \r\n \r",
|
69 |
-
"id": "8cddcff34f894743872ecc02262c2375"
|
70 |
-
},
|
71 |
-
"3d044718f379452ab3c1e4d00c99f8f3": {
|
72 |
-
"label": false,
|
73 |
-
"text": "Fire! I can see it burning so brightly \r\nFire! I can feel it calling out to me \r\nAnd as the sun goes down \r\nIt starts to paint a picture \r\n \r\nOf an ancient town \r\nSo far away, across the endless sea \r\nLead me to the light \r\nAnd take me to the edge of heaven \r\n \r\nI'm standing in the night \r\nLooking for the edge of heaven \r\nWe'll be touching the edge of heaven \r\nTime \r\n \r",
|
74 |
-
"id": "3d044718f379452ab3c1e4d00c99f8f3"
|
75 |
-
},
|
76 |
-
"d233250a91d44f13aac58eb5fa43afe6": {
|
77 |
-
"label": true,
|
78 |
-
"text": "Star \r\nWe go waiting for the stars \r\nTo come showering down \r\nFrom Moscow to Mars \r\nUniverse falling down \r\n \r\nYou got to look real hard \r\nThere's a fiery star \r\nHidden out there somewhere \r\nNot the satellite of love \r\nBut a laser \r\nShooting out it's shiny tongue there \r\n \r\nGod is love, God is war \r\nTV-preacher tell me more \r\nLoad redeem me, am I pure? \r",
|
79 |
-
"id": "d233250a91d44f13aac58eb5fa43afe6"
|
80 |
-
},
|
81 |
-
"a30c9a5c63a2456f8f53a9177a522d7a": {
|
82 |
-
"label": false,
|
83 |
-
"text": "Tell me do you want to be free \r\n \r\nWell your love falls down you know \r\nAnd your heart might fall to pieces \r\nAnd I saw your soul get lost along the way \r\n \r\nAll these songs now they used to make you shine \r\nThey are just lullabies for your nightmares \r\nAnd Ill sing them softly now \r\n \r\nLately I've felt the warmth \r\nOf the one who tore down my walls \r\nBut then I look at you \r",
|
84 |
-
"id": "a30c9a5c63a2456f8f53a9177a522d7a"
|
85 |
-
},
|
86 |
-
"89ce6961ff064f719212e68058bb2013": {
|
87 |
-
"label": false,
|
88 |
-
"text": "I Left Them Niggas Needin'Path \r\nAnd Ya'll Probly Won't Live To See This Weekend, \r\nGotta Go, Gotta Go, FUckin Mash Out \r\nI Hit The Dro' A Lil More And Then I Pass Out \r\nCrashin' The H2, Bitches I Hate You \r\nNow you Keep Talkin Shit, I Kidnap And Ducktape You \r\nLet Them Faggots Rape You \r\nThen It's Back To Mississippi, If Ya Boys Want Revenge \r\nTell Them Bitches Come And Get Me \r",
|
89 |
-
"id": "89ce6961ff064f719212e68058bb2013"
|
90 |
-
},
|
91 |
-
"6de1b38adc9b4f48ac15609dad02faa0": {
|
92 |
-
"label": true,
|
93 |
-
"text": "In heaven's eye \r\n \r\nYes, this is our star. \r\nYes, this is our star. \r\nOur star our star.\r\n\r\n",
|
94 |
-
"id": "6de1b38adc9b4f48ac15609dad02faa0"
|
95 |
-
},
|
96 |
-
"52ccd98280b849f498d838b6230285a7": {
|
97 |
-
"label": false,
|
98 |
-
"text": "Tell Them Bitches Come And Get Me \r\n'cause I Was Born In This Bitch To Die \r\nI'm In Queens, In Your 'Lac, With Your Bitch, Gettin' High \r\n \r\nYoung Buck: \r\nGold Grills, Coupe' Devilles Sittin On 22's \r\nThe Dirty, Dirty Baby \r\nShow 'Em How The South Do \r\nWe Pop Pills, Shoot To Kill, You Know What We 'Bout \r\nAnd On Behalf Of G-Unit, Welcome To The South \r\n \r\nLil Flip: \r",
|
99 |
-
"id": "52ccd98280b849f498d838b6230285a7"
|
100 |
-
},
|
101 |
-
"866a61ec0ab04a54ade2532b7825c858": {
|
102 |
-
"label": false,
|
103 |
-
"text": "I Swear On The Soul's Of Our Dead Cousin's \r\nI Ain't Fuckin, Man I'm Commin Ak 40's Bustin', \r\n7's And Mack 11's \r\nI Told 'Em All I Ain't No Hoe \r\nBut Niggas Don't Listen Till You Kick A Nigga, \r\nSmack Him With That Callico \r\nI'm Tryin To Stay In Gods Plan \r\nBut I Hadta Show These Faggots That Your Fuckin With A Man, Ya Bitch! \r\nI Left Them Niggas Needin'Path \r",
|
104 |
-
"id": "866a61ec0ab04a54ade2532b7825c858"
|
105 |
-
},
|
106 |
-
"0a2dbf3ee6cd46ae9f71ecb65e02674e": {
|
107 |
-
"label": true,
|
108 |
-
"text": "And filling up the space \r\nMen and women boys and girls \r\nThere are so many people in the world \r\nThinkin' about the world \r\nAnd all the people in it \r\nAnd I'm staring at the stars \r\nAnd into the infinite \r\nIn a world without a world \r\nOn a planet that's \r\nDriftin' in a space \r\n \r\nSeconds into minutes and minutes \r\nInto hour and hours into days \r",
|
109 |
-
"id": "0a2dbf3ee6cd46ae9f71ecb65e02674e"
|
110 |
-
},
|
111 |
-
"fff7748b4c384cb49ae18f96df719aa8": {
|
112 |
-
"label": false,
|
113 |
-
"text": "And the way things ought to be \r\n \r\nWhat kind of difference \r\nCan on person make? \r\nCut to the chase\r\n\r\n",
|
114 |
-
"id": "fff7748b4c384cb49ae18f96df719aa8"
|
115 |
-
},
|
116 |
-
"54971cdd9be0444096cacd2637a50ce4": {
|
117 |
-
"label": false,
|
118 |
-
"text": "With bar lights and pretty girls \r\nBut most nights I stay straight and think about my mom \r\nOh god, I miss her so much \r\n \r\nAnd there are people on the street \r\nThey're coming up to me \r\nThey're telling me that they like what I do now \r\nAnd so I tried my best when I took the fall \r\nTo get right back up, back in your arms \r\nIf you're out here why do I miss you so much \r\n \r",
|
119 |
-
"id": "54971cdd9be0444096cacd2637a50ce4"
|
120 |
-
},
|
121 |
-
"048e4f04661d4f71a48d48f216b30975": {
|
122 |
-
"label": true,
|
123 |
-
"text": " \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r\nAt the speed of light \r\n \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r\nWe slip into the night \r\n \r\nI'll say a mass for you and wave \r\nShooting plasma from my grave \r\n \r\nEvent horizon lost in space \r\nRunning in a human race \r\n \r\nI don't know where I don't know why \r",
|
124 |
-
"id": "048e4f04661d4f71a48d48f216b30975"
|
125 |
-
},
|
126 |
-
"f4ee9e97357c4f2fa0ed627a6983e4de": {
|
127 |
-
"label": false,
|
128 |
-
"text": "I am here to tell you we can never meet again \r\nSimple really, isn't it, a word or two and then \r\nA lifetime of not knowing where or how or why or when \r\nYou think of me or speak of me or wonder what befell \r\nThe someone you once loved so long ago so well \r\n \r\nNever wonder what I'll feel as living shuffles by \r\nYou don't have to ask me and I need not reply \r",
|
129 |
-
"id": "f4ee9e97357c4f2fa0ed627a6983e4de"
|
130 |
-
},
|
131 |
-
"797514b7375f4ef8bfbd3320936b266a": {
|
132 |
-
"label": false,
|
133 |
-
"text": " \r\nThe last time that I saw him he was trying hard to get \r\nA woman's education but he's not a woman yet \r\nAnd the last time that I saw her she was living with some boy \r\nWho gives her soul an empty room and gives her body joy. \r\n \r\nSo the great affair is over but whoever would have guessed \r\nIt would leave us all so vacant and so deeply unimpressed \r",
|
134 |
-
"id": "797514b7375f4ef8bfbd3320936b266a"
|
135 |
-
},
|
136 |
-
"56663fdf792a4820b7ae2e4344542cfa": {
|
137 |
-
"label": true,
|
138 |
-
"text": "Yeah we'll find our star \r\nBut maybe that's another world \r\n \r\nFar away from where we are \r\nYeah we'll find our star \r\nBut maybe that's another world\r\n\r\n",
|
139 |
-
"id": "56663fdf792a4820b7ae2e4344542cfa"
|
140 |
-
},
|
141 |
-
"d522d97e7d44430e945e40720d54e98d": {
|
142 |
-
"label": false,
|
143 |
-
"text": "The silly people just like you and better too. \r\nHow can you keep turning when the overture is burning in the faces \r\nOf the people in the churches of the land. \r\n \r\nThat's all it seems, there is only one dream. \r\nThe day has come at last.\r\n\r\n",
|
144 |
-
"id": "d522d97e7d44430e945e40720d54e98d"
|
145 |
-
},
|
146 |
-
"761a17d5909d4c7c9cd0cd1ac8c2db76": {
|
147 |
-
"label": false,
|
148 |
-
"text": "Ah the man she wanted all her life was hanging by a thread \r\n\"I never even knew how much I wanted you,\" she said. \r\nHis muscles they were numbered and his style was obsolete. \r\n\"O baby, I have come too late.\" She knelt beside his feet. \r\n\"I'll never see a face like yours in years of men to come \r\nI'll never see such arms again in wrestling or in love.\" \r",
|
149 |
-
"id": "761a17d5909d4c7c9cd0cd1ac8c2db76"
|
150 |
-
},
|
151 |
-
"ffc68f626c7d41be8661babedf589778": {
|
152 |
-
"label": true,
|
153 |
-
"text": "let us make computations of the stars. \r\n \r\nOlder, wiser, sadder, blinder, watch us run: \r\nfaster, longer, harder, stronger, now it comes: \r\ncolour blisters, image splinters gravitate \r\ntowards the centre, in final splendour disintegrate, \r\nThe universe now beckons \r\nand Man, too, must take His place... \r\njust a few last fleeting seconds \r\nto wander in the waste, \r",
|
154 |
-
"id": "ffc68f626c7d41be8661babedf589778"
|
155 |
-
},
|
156 |
-
"8e8ffd440c2f48ebb5ae04810be5d090": {
|
157 |
-
"label": false,
|
158 |
-
"text": "And boy you'll see \r\nIt's an illusion shining down in front of me, \r\n \r\nAnd then you'll say \r\nEven in time we shall control the day, \r\nWhen what you'll see \r\nDeep inside base controlling you and me. \r\n \r\nAnd one peculiar point I see, \r\nAs one of many ones of me. \r\nAs truth is gathered, I rearrange, \r\nInside out, outside in, inside out, outside in, \r\nPerpetual change. \r\n \r",
|
159 |
-
"id": "8e8ffd440c2f48ebb5ae04810be5d090"
|
160 |
-
},
|
161 |
-
"c61414a653bb4a9482f341dbfbea4a47": {
|
162 |
-
"label": false,
|
163 |
-
"text": "While there's still time to choose \r\n \r\nEvery day of my life I discover \r\nSomeone murdering my sisters and brothers \r\nIn the name of some god or another \r\nWhat do you know \r\n \r\nFor the first precious few it's time to go \r\nWhat might have been we'll never know \r\nAll those bad ideas became the law \r\nOh yes, we've forgotten what we're looking for \r",
|
164 |
-
"id": "c61414a653bb4a9482f341dbfbea4a47"
|
165 |
-
},
|
166 |
-
"3a325e1d3789416584ad836e2d32df05": {
|
167 |
-
"label": true,
|
168 |
-
"text": "Earth is the third planet from the Sun and the only place known in the universe where life has originated and found habitability. This is enabled by Earth being a water world, the only one in the Solar System sustaining liquid surface water. Almost all of Earth's water is contained in its global ocean, spanning 70.8% of Earth's surface. The other 29.2% are spanned by land, consisting of continents",
|
169 |
-
"id": "3a325e1d3789416584ad836e2d32df05"
|
170 |
-
},
|
171 |
-
"44e9840483164b6b97e06f909e25b8dc": {
|
172 |
-
"label": false,
|
173 |
-
"text": "Human geography\nToggle Human geography subsection\nCultural and historical viewpoint\nSee also\nNotes\nReferences\nExternal links\nEarth",
|
174 |
-
"id": "44e9840483164b6b97e06f909e25b8dc"
|
175 |
-
},
|
176 |
-
"bcf625326bc64c6ca6d37fb59bffa5ba": {
|
177 |
-
"label": true,
|
178 |
-
"text": "When the ebbing tide retreats along the rocky shoreline\nIt leaves a trail of tide pools in a short-lived galaxy\nEach microcosmic planet, a complete society\nA simple kind of mirror to reflect upon our own\nAll the busy little creatures chasing out their destinies\nLiving in their pools, they soon forget about the sea\nWheel within wheels in a spiral array\nA pattern so grand and complex",
|
179 |
-
"id": "bcf625326bc64c6ca6d37fb59bffa5ba"
|
180 |
-
},
|
181 |
-
"7c2be4b17d8f49069f6179c5256acc5e": {
|
182 |
-
"label": true,
|
183 |
-
"text": "Beneath my dreams and wishes \nI long for thy caresses. \n \nA bridal bed awaits us both, \nAfter the landscape of death I cross. \nBefore my sorrows I must die, \nNightwish I send through the starlit sky. \n \n\"Passed away in silence \nThe flute from the realm unseen \nEmpties it's heart \nMaking love to me \nWith it's enchanting melody. \nLight of Orion, \nShadow of Andromeda, ",
|
184 |
-
"id": "7c2be4b17d8f49069f6179c5256acc5e"
|
185 |
-
}
|
186 |
-
},
|
187 |
-
"version": 34
|
188 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/concept/local/outerspace/openai.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:7ea2acd96a43d1c678273e7ec297b1758a3d09d1137f0325ac3058ca9a110112
|
3 |
-
size 126895
|
|
|
|
|
|
|
|
data/concept/local/outerspace/sbert.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:9916794dbe5526af5103019735188b637f9975a5326a21713380058034e13525
|
3 |
-
size 34935
|
|
|
|
|
|
|
|
data/datasets/local/spotify/data-00000-of-00001.parquet
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:32224657332b09187a737c73ab634f9d14c9ba9a240bd105f1b9819cde2afcef
|
3 |
-
size 37128682
|
|
|
|
|
|
|
|
data/datasets/local/spotify/manifest.json
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"files": [
|
3 |
-
"data-00000-of-00001.parquet"
|
4 |
-
],
|
5 |
-
"data_schema": {
|
6 |
-
"fields": {
|
7 |
-
"artist": {
|
8 |
-
"dtype": "string"
|
9 |
-
},
|
10 |
-
"song": {
|
11 |
-
"dtype": "string"
|
12 |
-
},
|
13 |
-
"link": {
|
14 |
-
"dtype": "string"
|
15 |
-
},
|
16 |
-
"text": {
|
17 |
-
"dtype": "string"
|
18 |
-
},
|
19 |
-
"__line_number__": {
|
20 |
-
"dtype": "int64"
|
21 |
-
},
|
22 |
-
"__rowid__": {
|
23 |
-
"dtype": "string"
|
24 |
-
}
|
25 |
-
}
|
26 |
-
}
|
27 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/.concepts/local/aliens/sbert-neg-100.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:93f390fafd0d0db4ae6ae80d30bfbf8eb0a80fa9332f77f30449d40a11df0936
|
3 |
-
size 183363
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/.concepts/local/outer_space/sbert-neg-100.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:3fc9ac4c9b8b8588e48ebabbe34598edb4431985d20e018225b84546b96ce2ea
|
3 |
-
size 166637
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/.concepts/local/outerspace/sbert-neg-100.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:f3432ea5dcfbe7f7a17c94a4cc0c09e3317b8a690456fdf3af3efa0dcaa6f4fc
|
3 |
-
size 188685
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/.concepts/local/phone_addiction/sbert-neg-100.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:f795fb8b5d52650bd9aa5c871ff5d480e95413cd0afb65822a634c02f6674825
|
3 |
-
size 163242
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/sbert/data-00000-of-00001.parquet
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:9796beb630cc3503f3c2ac9db8f71e4c1604570836d78bbf364e801cd427c39e
|
3 |
-
size 2709987
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/data-00000-of-00001.parquet
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:d1ba0fe68cc02849b0a20b7f72047c8e9cb8e5ef5b57b0cd642fa0b0be8a6e06
|
3 |
-
size 3340135
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/signal_manifest.json
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"files": [
|
3 |
-
"data-00000-of-00001.parquet"
|
4 |
-
],
|
5 |
-
"parquet_id": "local/outerspace/v34(text.sbert.*.embedding)",
|
6 |
-
"data_schema": {
|
7 |
-
"fields": {
|
8 |
-
"__rowid__": {
|
9 |
-
"dtype": "string"
|
10 |
-
},
|
11 |
-
"text": {
|
12 |
-
"fields": {
|
13 |
-
"sbert": {
|
14 |
-
"repeated_field": {
|
15 |
-
"fields": {
|
16 |
-
"embedding": {
|
17 |
-
"fields": {
|
18 |
-
"local/outerspace/v34": {
|
19 |
-
"dtype": "float32",
|
20 |
-
"signal": {
|
21 |
-
"signal_name": "concept_score",
|
22 |
-
"embedding": "sbert",
|
23 |
-
"namespace": "local",
|
24 |
-
"concept_name": "outerspace",
|
25 |
-
"draft": "main",
|
26 |
-
"num_negative_examples": 100
|
27 |
-
},
|
28 |
-
"bins": [
|
29 |
-
[
|
30 |
-
"Not in concept",
|
31 |
-
null,
|
32 |
-
0.5
|
33 |
-
],
|
34 |
-
[
|
35 |
-
"In concept",
|
36 |
-
0.5,
|
37 |
-
null
|
38 |
-
]
|
39 |
-
]
|
40 |
-
}
|
41 |
-
}
|
42 |
-
}
|
43 |
-
}
|
44 |
-
}
|
45 |
-
}
|
46 |
-
}
|
47 |
-
}
|
48 |
-
}
|
49 |
-
},
|
50 |
-
"signal": {
|
51 |
-
"signal_name": "concept_score",
|
52 |
-
"embedding": "sbert",
|
53 |
-
"namespace": "local",
|
54 |
-
"concept_name": "outerspace",
|
55 |
-
"draft": "main",
|
56 |
-
"num_negative_examples": 100
|
57 |
-
},
|
58 |
-
"enriched_path": [
|
59 |
-
"text",
|
60 |
-
"sbert",
|
61 |
-
"*",
|
62 |
-
"embedding"
|
63 |
-
]
|
64 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.keys.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:d5df43291782b8c731d4ce56537946654c642a01dc9a4e37de394836362f6b45
|
3 |
-
size 3727400
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.npy
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:94e10c23d7229541e1f60b791a659d13673b10a03649abf0ae092e0e18c5aee3
|
3 |
-
size 170446976
|
|
|
|
|
|
|
|
data/datasets/local/spotify/text/sbert/signal_manifest.json
DELETED
@@ -1,37 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"files": [
|
3 |
-
"data-00000-of-00001.parquet"
|
4 |
-
],
|
5 |
-
"parquet_id": "sbert(text)",
|
6 |
-
"data_schema": {
|
7 |
-
"fields": {
|
8 |
-
"__rowid__": {
|
9 |
-
"dtype": "string"
|
10 |
-
},
|
11 |
-
"text": {
|
12 |
-
"fields": {
|
13 |
-
"sbert": {
|
14 |
-
"repeated_field": {
|
15 |
-
"fields": {
|
16 |
-
"embedding": {
|
17 |
-
"dtype": "embedding"
|
18 |
-
}
|
19 |
-
},
|
20 |
-
"dtype": "string_span"
|
21 |
-
},
|
22 |
-
"signal": {
|
23 |
-
"signal_name": "sbert"
|
24 |
-
}
|
25 |
-
}
|
26 |
-
}
|
27 |
-
}
|
28 |
-
}
|
29 |
-
},
|
30 |
-
"signal": {
|
31 |
-
"signal_name": "sbert"
|
32 |
-
},
|
33 |
-
"enriched_path": [
|
34 |
-
"text"
|
35 |
-
],
|
36 |
-
"embedding_filename_prefix": "embeddings-00000-of-00001"
|
37 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -18,6 +18,7 @@ cytoolz==0.12.1 ; python_version >= "3.9" and python_version < "3.10"
|
|
18 |
dask==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
|
19 |
datasets==2.13.1 ; python_version >= "3.9" and python_version < "3.10"
|
20 |
decorator==5.1.1 ; python_version >= "3.9" and python_version < "3.10"
|
|
|
21 |
dill==0.3.6 ; python_version >= "3.9" and python_version < "3.10"
|
22 |
distributed==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
|
23 |
duckdb==0.8.1 ; python_version >= "3.9" and python_version < "3.10"
|
|
|
18 |
dask==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
|
19 |
datasets==2.13.1 ; python_version >= "3.9" and python_version < "3.10"
|
20 |
decorator==5.1.1 ; python_version >= "3.9" and python_version < "3.10"
|
21 |
+
detect-secrets==1.4.0 ; python_version >= "3.9" and python_version < "3.10"
|
22 |
dill==0.3.6 ; python_version >= "3.9" and python_version < "3.10"
|
23 |
distributed==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
|
24 |
duckdb==0.8.1 ; python_version >= "3.9" and python_version < "3.10"
|
src/concepts/concept.py
CHANGED
@@ -162,7 +162,7 @@ class LogisticEmbeddingModel:
|
|
162 |
def __post_init__(self) -> None:
|
163 |
# See `notebooks/Toxicity.ipynb` for an example of training a concept model.
|
164 |
self._model = LogisticRegression(
|
165 |
-
class_weight=None, C=30, tol=1e-5, warm_start=True, max_iter=
|
166 |
|
167 |
def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
|
168 |
"""Get the scores for the provided embeddings."""
|
@@ -175,11 +175,12 @@ class LogisticEmbeddingModel:
|
|
175 |
return np.random.rand(len(embeddings))
|
176 |
|
177 |
def _setup_training(
|
178 |
-
self, X_train: np.ndarray,
|
179 |
implicit_negatives: Optional[np.ndarray]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
180 |
-
num_pos_labels = len([y for y in
|
181 |
-
num_neg_labels = len([y for y in
|
182 |
-
sample_weights = [(1.0 / num_pos_labels if y else 1.0 / num_neg_labels) for y in
|
|
|
183 |
|
184 |
if implicit_negatives is not None:
|
185 |
num_implicit_labels = len(implicit_negatives)
|
@@ -191,7 +192,14 @@ class LogisticEmbeddingModel:
|
|
191 |
# Normalize sample weights to sum to the number of training examples.
|
192 |
weights = np.array(sample_weights)
|
193 |
weights *= (X_train.shape[0] / np.sum(weights))
|
194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
def fit(self, embeddings: np.ndarray, labels: list[bool],
|
197 |
implicit_negatives: Optional[np.ndarray]) -> None:
|
@@ -337,11 +345,12 @@ class ConceptModel:
|
|
337 |
|
338 |
embedding_items = list(embedding.compute(examples))
|
339 |
result_items: list[Item] = []
|
|
|
340 |
for item in embedding_items:
|
341 |
if not isinstance(item, list):
|
342 |
raise ValueError('Item from embedding is not a list.')
|
343 |
-
embeddings = np.array([np.
|
344 |
-
scores =
|
345 |
|
346 |
item_result: list[Item] = []
|
347 |
for embedding_item, score in zip(item, scores):
|
|
|
162 |
def __post_init__(self) -> None:
|
163 |
# See `notebooks/Toxicity.ipynb` for an example of training a concept model.
|
164 |
self._model = LogisticRegression(
|
165 |
+
class_weight=None, C=30, tol=1e-5, warm_start=True, max_iter=5_000, n_jobs=-1)
|
166 |
|
167 |
def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
|
168 |
"""Get the scores for the provided embeddings."""
|
|
|
175 |
return np.random.rand(len(embeddings))
|
176 |
|
177 |
def _setup_training(
|
178 |
+
self, X_train: np.ndarray, labels: list[bool],
|
179 |
implicit_negatives: Optional[np.ndarray]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
180 |
+
num_pos_labels = len([y for y in labels if y])
|
181 |
+
num_neg_labels = len([y for y in labels if not y])
|
182 |
+
sample_weights = [(1.0 / num_pos_labels if y else 1.0 / num_neg_labels) for y in labels]
|
183 |
+
y_train = np.array(labels)
|
184 |
|
185 |
if implicit_negatives is not None:
|
186 |
num_implicit_labels = len(implicit_negatives)
|
|
|
192 |
# Normalize sample weights to sum to the number of training examples.
|
193 |
weights = np.array(sample_weights)
|
194 |
weights *= (X_train.shape[0] / np.sum(weights))
|
195 |
+
|
196 |
+
# Shuffle the data in unison.
|
197 |
+
p = np.random.permutation(len(X_train))
|
198 |
+
X_train = X_train[p]
|
199 |
+
y_train = y_train[p]
|
200 |
+
weights = weights[p]
|
201 |
+
|
202 |
+
return X_train, y_train, weights
|
203 |
|
204 |
def fit(self, embeddings: np.ndarray, labels: list[bool],
|
205 |
implicit_negatives: Optional[np.ndarray]) -> None:
|
|
|
345 |
|
346 |
embedding_items = list(embedding.compute(examples))
|
347 |
result_items: list[Item] = []
|
348 |
+
logistic_model = self._get_logistic_model(draft)
|
349 |
for item in embedding_items:
|
350 |
if not isinstance(item, list):
|
351 |
raise ValueError('Item from embedding is not a list.')
|
352 |
+
embeddings = np.array([np.reshape(res[EMBEDDING_KEY], -1) for res in item])
|
353 |
+
scores = logistic_model.score_embeddings(embeddings).tolist()
|
354 |
|
355 |
item_result: list[Item] = []
|
356 |
for embedding_item, score in zip(item, scores):
|
src/concepts/concept_test.py
DELETED
@@ -1,84 +0,0 @@
|
|
1 |
-
"""Tests for concept."""
|
2 |
-
|
3 |
-
from ..schema import SignalInputType
|
4 |
-
from .concept import DRAFT_MAIN, Concept, Example, draft_examples
|
5 |
-
|
6 |
-
|
7 |
-
def test_draft_examples_main() -> None:
|
8 |
-
concept = Concept(
|
9 |
-
namespace='test_namespace',
|
10 |
-
concept_name='test_name',
|
11 |
-
type=SignalInputType.TEXT,
|
12 |
-
data={
|
13 |
-
'0': Example(id='0', label=True, text='hello'),
|
14 |
-
'1': Example(id='1', label=False, text='world'),
|
15 |
-
},
|
16 |
-
version=0)
|
17 |
-
|
18 |
-
assert draft_examples(concept, DRAFT_MAIN) == {
|
19 |
-
'0': Example(id='0', label=True, text='hello'),
|
20 |
-
'1': Example(id='1', label=False, text='world'),
|
21 |
-
}
|
22 |
-
|
23 |
-
|
24 |
-
def test_draft_examples_simple_draft() -> None:
|
25 |
-
concept = Concept(
|
26 |
-
namespace='test_namespace',
|
27 |
-
concept_name='test_name',
|
28 |
-
type=SignalInputType.TEXT,
|
29 |
-
data={
|
30 |
-
'0': Example(id='0', label=True, text='hello'),
|
31 |
-
'1': Example(id='1', label=False, text='world'),
|
32 |
-
'2': Example(id='2', label=True, text='hello draft 1', draft='draft1'),
|
33 |
-
'3': Example(id='3', label=False, text='world draft 1', draft='draft1'),
|
34 |
-
'4': Example(id='4', label=True, text='hello draft 2', draft='draft2'),
|
35 |
-
'5': Example(id='5', label=False, text='world draft 2', draft='draft2'),
|
36 |
-
},
|
37 |
-
version=0)
|
38 |
-
|
39 |
-
assert draft_examples(concept, DRAFT_MAIN) == {
|
40 |
-
'0': Example(id='0', label=True, text='hello'),
|
41 |
-
'1': Example(id='1', label=False, text='world'),
|
42 |
-
}
|
43 |
-
|
44 |
-
assert draft_examples(concept, 'draft1') == {
|
45 |
-
'0': Example(id='0', label=True, text='hello'),
|
46 |
-
'1': Example(id='1', label=False, text='world'),
|
47 |
-
'2': Example(id='2', label=True, text='hello draft 1', draft='draft1'),
|
48 |
-
'3': Example(id='3', label=False, text='world draft 1', draft='draft1'),
|
49 |
-
}
|
50 |
-
|
51 |
-
assert draft_examples(concept, 'draft2') == {
|
52 |
-
'0': Example(id='0', label=True, text='hello'),
|
53 |
-
'1': Example(id='1', label=False, text='world'),
|
54 |
-
'4': Example(id='4', label=True, text='hello draft 2', draft='draft2'),
|
55 |
-
'5': Example(id='5', label=False, text='world draft 2', draft='draft2'),
|
56 |
-
}
|
57 |
-
|
58 |
-
|
59 |
-
def test_draft_examples_draft_dedupe() -> None:
|
60 |
-
concept = Concept(
|
61 |
-
namespace='test_namespace',
|
62 |
-
concept_name='test_name',
|
63 |
-
type=SignalInputType.TEXT,
|
64 |
-
data={
|
65 |
-
'0': Example(id='0', label=True, text='hello'),
|
66 |
-
'1': Example(id='1', label=False, text='world'),
|
67 |
-
# Duplicate text.
|
68 |
-
'2': Example(id='2', label=False, text='hello', draft='draft'),
|
69 |
-
'3': Example(id='3', label=False, text='world draft', draft='draft'),
|
70 |
-
},
|
71 |
-
version=0)
|
72 |
-
|
73 |
-
assert draft_examples(concept, DRAFT_MAIN) == {
|
74 |
-
'0': Example(id='0', label=True, text='hello'),
|
75 |
-
'1': Example(id='1', label=False, text='world'),
|
76 |
-
}
|
77 |
-
|
78 |
-
assert draft_examples(concept, 'draft') == {
|
79 |
-
# 0 is deduplicated with 2.
|
80 |
-
'1': Example(id='1', label=False, text='world'),
|
81 |
-
# 2 overrides 0's label.
|
82 |
-
'2': Example(id='2', label=False, text='hello', draft='draft'),
|
83 |
-
'3': Example(id='3', label=False, text='world draft', draft='draft'),
|
84 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/concepts/db_concept_test.py
DELETED
@@ -1,606 +0,0 @@
|
|
1 |
-
"""Tests for the the database concept."""
|
2 |
-
|
3 |
-
from pathlib import Path
|
4 |
-
from typing import Generator, Iterable, Optional, Type, cast
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import pytest
|
8 |
-
from pytest_mock import MockerFixture
|
9 |
-
from typing_extensions import override
|
10 |
-
|
11 |
-
from ..config import CONFIG
|
12 |
-
from ..data.dataset_duckdb import DatasetDuckDB
|
13 |
-
from ..data.dataset_utils import lilac_embedding
|
14 |
-
from ..db_manager import set_default_dataset_cls
|
15 |
-
from ..schema import Item, RichData, SignalInputType
|
16 |
-
from ..signals.signal import TextEmbeddingSignal, clear_signal_registry, register_signal
|
17 |
-
from .concept import (
|
18 |
-
DRAFT_MAIN,
|
19 |
-
Concept,
|
20 |
-
ConceptModel,
|
21 |
-
DraftId,
|
22 |
-
Example,
|
23 |
-
ExampleIn,
|
24 |
-
LogisticEmbeddingModel,
|
25 |
-
)
|
26 |
-
from .db_concept import (
|
27 |
-
ConceptDB,
|
28 |
-
ConceptInfo,
|
29 |
-
ConceptModelDB,
|
30 |
-
ConceptUpdate,
|
31 |
-
DiskConceptDB,
|
32 |
-
DiskConceptModelDB,
|
33 |
-
)
|
34 |
-
|
35 |
-
ALL_CONCEPT_DBS = [DiskConceptDB]
|
36 |
-
ALL_CONCEPT_MODEL_DBS = [DiskConceptModelDB]
|
37 |
-
|
38 |
-
|
39 |
-
@pytest.fixture(autouse=True)
|
40 |
-
def set_data_path(tmp_path: Path, mocker: MockerFixture) -> None:
|
41 |
-
mocker.patch.dict(CONFIG, {'LILAC_DATA_PATH': str(tmp_path)})
|
42 |
-
|
43 |
-
|
44 |
-
EMBEDDING_MAP: dict[str, list[float]] = {
|
45 |
-
'not in concept': [1.0, 0.0, 0.0],
|
46 |
-
'in concept': [0.9, 0.1, 0.0],
|
47 |
-
'a new data point': [0.1, 0.2, 0.3],
|
48 |
-
'a true draft point': [0.4, 0.5, 0.6],
|
49 |
-
'a false draft point': [0.7, 0.8, 0.9],
|
50 |
-
}
|
51 |
-
|
52 |
-
|
53 |
-
class TestEmbedding(TextEmbeddingSignal):
|
54 |
-
"""A test embed function."""
|
55 |
-
name = 'test_embedding'
|
56 |
-
|
57 |
-
@override
|
58 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
59 |
-
"""Embed the examples, use a hashmap to the vector for simplicity."""
|
60 |
-
for example in data:
|
61 |
-
if example not in EMBEDDING_MAP:
|
62 |
-
raise ValueError(f'Example "{str(example)}" not in embedding map')
|
63 |
-
yield [lilac_embedding(0, len(example), np.array(EMBEDDING_MAP[cast(str, example)]))]
|
64 |
-
|
65 |
-
|
66 |
-
@pytest.fixture(scope='module', autouse=True)
|
67 |
-
def setup_teardown() -> Generator:
|
68 |
-
set_default_dataset_cls(DatasetDuckDB)
|
69 |
-
register_signal(TestEmbedding)
|
70 |
-
|
71 |
-
# Unit test runs.
|
72 |
-
yield
|
73 |
-
|
74 |
-
# Teardown.
|
75 |
-
clear_signal_registry()
|
76 |
-
|
77 |
-
|
78 |
-
@pytest.mark.parametrize('db_cls', ALL_CONCEPT_DBS)
|
79 |
-
class ConceptDBSuite:
|
80 |
-
|
81 |
-
def test_create_concept(self, db_cls: Type[ConceptDB]) -> None:
|
82 |
-
db = db_cls()
|
83 |
-
db.create(namespace='test', name='test_concept', type=SignalInputType.TEXT)
|
84 |
-
|
85 |
-
assert db.list() == [
|
86 |
-
ConceptInfo(
|
87 |
-
namespace='test', name='test_concept', type=SignalInputType.TEXT, drafts=[DRAFT_MAIN])
|
88 |
-
]
|
89 |
-
|
90 |
-
# Make sure list with drafts relects the drafts.
|
91 |
-
train_data = [
|
92 |
-
ExampleIn(label=False, text='not in concept', draft='test_draft'),
|
93 |
-
ExampleIn(label=True, text='in concept', draft='test_draft')
|
94 |
-
]
|
95 |
-
db.edit('test', 'test_concept', ConceptUpdate(insert=train_data))
|
96 |
-
|
97 |
-
assert db.list() == [
|
98 |
-
ConceptInfo(
|
99 |
-
namespace='test',
|
100 |
-
name='test_concept',
|
101 |
-
type=SignalInputType.TEXT,
|
102 |
-
drafts=[DRAFT_MAIN, 'test_draft'])
|
103 |
-
]
|
104 |
-
|
105 |
-
def test_add_example(self, db_cls: Type[ConceptDB]) -> None:
|
106 |
-
db = db_cls()
|
107 |
-
namespace = 'test'
|
108 |
-
concept_name = 'test_concept'
|
109 |
-
train_data = [
|
110 |
-
ExampleIn(label=False, text='not in concept'),
|
111 |
-
ExampleIn(label=True, text='in concept')
|
112 |
-
]
|
113 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
114 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
115 |
-
|
116 |
-
concept = db.get(namespace, concept_name)
|
117 |
-
|
118 |
-
assert concept is not None
|
119 |
-
|
120 |
-
keys = list(concept.data.keys())
|
121 |
-
assert concept == Concept(
|
122 |
-
namespace=namespace,
|
123 |
-
concept_name=concept_name,
|
124 |
-
type=SignalInputType.TEXT,
|
125 |
-
data={
|
126 |
-
keys[0]: Example(id=keys[0], label=False, text='not in concept'),
|
127 |
-
keys[1]: Example(id=keys[1], label=True, text='in concept')
|
128 |
-
},
|
129 |
-
version=1)
|
130 |
-
|
131 |
-
# Add a draft labels.
|
132 |
-
db.edit(
|
133 |
-
namespace, concept_name,
|
134 |
-
ConceptUpdate(insert=[
|
135 |
-
ExampleIn(label=False, text='really not in concept', draft='test_draft'),
|
136 |
-
ExampleIn(label=True, text='really in concept', draft='test_draft')
|
137 |
-
]))
|
138 |
-
|
139 |
-
concept = db.get(namespace, concept_name)
|
140 |
-
assert concept is not None
|
141 |
-
|
142 |
-
keys = list(concept.data.keys())
|
143 |
-
assert concept == Concept(
|
144 |
-
namespace=namespace,
|
145 |
-
concept_name=concept_name,
|
146 |
-
type=SignalInputType.TEXT,
|
147 |
-
data={
|
148 |
-
keys[0]: Example(id=keys[0], label=False, text='not in concept'),
|
149 |
-
keys[1]: Example(id=keys[1], label=True, text='in concept'),
|
150 |
-
keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'),
|
151 |
-
keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'),
|
152 |
-
},
|
153 |
-
version=2)
|
154 |
-
|
155 |
-
def test_update_concept(self, db_cls: Type[ConceptDB]) -> None:
|
156 |
-
db = db_cls()
|
157 |
-
namespace = 'test'
|
158 |
-
concept_name = 'test_concept'
|
159 |
-
train_data = [
|
160 |
-
ExampleIn(label=False, text='not in concept'),
|
161 |
-
ExampleIn(label=True, text='in concept'),
|
162 |
-
ExampleIn(label=False, text='really not in concept', draft='test_draft'),
|
163 |
-
ExampleIn(label=True, text='really in concept', draft='test_draft')
|
164 |
-
]
|
165 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
166 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
167 |
-
|
168 |
-
concept = db.get(namespace, concept_name)
|
169 |
-
assert concept is not None
|
170 |
-
|
171 |
-
keys = list(concept.data.keys())
|
172 |
-
# Edit the first example.
|
173 |
-
db.edit(
|
174 |
-
namespace, concept_name,
|
175 |
-
ConceptUpdate(update=[Example(id=keys[0], label=False, text='not in concept, updated')]))
|
176 |
-
concept = db.get(namespace, concept_name)
|
177 |
-
|
178 |
-
assert concept == Concept(
|
179 |
-
namespace=namespace,
|
180 |
-
concept_name=concept_name,
|
181 |
-
type=SignalInputType.TEXT,
|
182 |
-
data={
|
183 |
-
# The first example should be updated alone.
|
184 |
-
keys[0]: Example(id=keys[0], label=False, text='not in concept, updated'),
|
185 |
-
keys[1]: Example(id=keys[1], label=True, text='in concept'),
|
186 |
-
# Drafts are untouched.
|
187 |
-
keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'),
|
188 |
-
keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'),
|
189 |
-
},
|
190 |
-
version=2)
|
191 |
-
|
192 |
-
# Edit the second example on the draft.
|
193 |
-
db.edit(
|
194 |
-
namespace, concept_name,
|
195 |
-
ConceptUpdate(update=[
|
196 |
-
Example(id=keys[3], label=True, text='really in concept, updated', draft='test_draft')
|
197 |
-
]))
|
198 |
-
concept = db.get(namespace, concept_name)
|
199 |
-
|
200 |
-
assert concept == Concept(
|
201 |
-
namespace=namespace,
|
202 |
-
concept_name=concept_name,
|
203 |
-
type=SignalInputType.TEXT,
|
204 |
-
data={
|
205 |
-
# Main remains the same.
|
206 |
-
keys[0]: Example(id=keys[0], label=False, text='not in concept, updated'),
|
207 |
-
keys[1]: Example(id=keys[1], label=True, text='in concept'),
|
208 |
-
keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'),
|
209 |
-
keys[3]: Example(
|
210 |
-
id=keys[3], label=True, text='really in concept, updated', draft='test_draft'),
|
211 |
-
},
|
212 |
-
version=3)
|
213 |
-
|
214 |
-
def test_remove_concept(self, db_cls: Type[ConceptDB]) -> None:
|
215 |
-
db = db_cls()
|
216 |
-
namespace = 'test'
|
217 |
-
concept_name = 'test_concept'
|
218 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
219 |
-
|
220 |
-
train_data = [
|
221 |
-
ExampleIn(label=False, text='not in concept'),
|
222 |
-
ExampleIn(label=True, text='in concept')
|
223 |
-
]
|
224 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
225 |
-
concept = db.get(namespace, concept_name)
|
226 |
-
|
227 |
-
db.remove(namespace, concept_name)
|
228 |
-
|
229 |
-
concept = db.get(namespace, concept_name)
|
230 |
-
|
231 |
-
assert concept is None
|
232 |
-
|
233 |
-
def test_remove_concept_examples(self, db_cls: Type[ConceptDB]) -> None:
|
234 |
-
db = db_cls()
|
235 |
-
namespace = 'test'
|
236 |
-
concept_name = 'test_concept'
|
237 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
238 |
-
|
239 |
-
train_data = [
|
240 |
-
ExampleIn(label=False, text='not in concept'),
|
241 |
-
ExampleIn(label=True, text='in concept')
|
242 |
-
]
|
243 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
244 |
-
concept = db.get(namespace, concept_name)
|
245 |
-
assert concept is not None
|
246 |
-
|
247 |
-
keys = list(concept.data.keys())
|
248 |
-
|
249 |
-
db.edit(namespace, concept_name, ConceptUpdate(remove=[keys[0]]))
|
250 |
-
concept = db.get(namespace, concept_name)
|
251 |
-
|
252 |
-
assert concept == Concept(
|
253 |
-
namespace=namespace,
|
254 |
-
concept_name=concept_name,
|
255 |
-
type=SignalInputType.TEXT,
|
256 |
-
data={
|
257 |
-
# key_0 was removed.
|
258 |
-
keys[1]: Example(id=keys[1], label=True, text='in concept')
|
259 |
-
},
|
260 |
-
version=2)
|
261 |
-
|
262 |
-
def test_remove_concept_examples_draft(self, db_cls: Type[ConceptDB]) -> None:
|
263 |
-
db = db_cls()
|
264 |
-
namespace = 'test'
|
265 |
-
concept_name = 'test_concept'
|
266 |
-
train_data = [
|
267 |
-
ExampleIn(label=False, text='not in concept'),
|
268 |
-
ExampleIn(label=True, text='in concept'),
|
269 |
-
ExampleIn(label=False, text='really not in concept', draft='test_draft'),
|
270 |
-
ExampleIn(label=True, text='really in concept', draft='test_draft')
|
271 |
-
]
|
272 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
273 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
274 |
-
concept = db.get(namespace, concept_name)
|
275 |
-
assert concept is not None
|
276 |
-
|
277 |
-
keys = list(concept.data.keys())
|
278 |
-
|
279 |
-
db.edit(namespace, concept_name, ConceptUpdate(remove=[keys[2]]))
|
280 |
-
concept = db.get(namespace, concept_name)
|
281 |
-
|
282 |
-
assert concept == Concept(
|
283 |
-
namespace=namespace,
|
284 |
-
concept_name=concept_name,
|
285 |
-
type=SignalInputType.TEXT,
|
286 |
-
data={
|
287 |
-
keys[0]: Example(id=keys[0], label=False, text='not in concept'),
|
288 |
-
keys[1]: Example(id=keys[1], label=True, text='in concept'),
|
289 |
-
# The first draft example is removed.
|
290 |
-
keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'),
|
291 |
-
},
|
292 |
-
version=2)
|
293 |
-
|
294 |
-
def test_remove_invalid_id(self, db_cls: Type[ConceptDB]) -> None:
|
295 |
-
db = db_cls()
|
296 |
-
namespace = 'test'
|
297 |
-
concept_name = 'test_concept'
|
298 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
299 |
-
|
300 |
-
train_data = [
|
301 |
-
ExampleIn(label=False, text='not in concept'),
|
302 |
-
ExampleIn(label=True, text='in concept'),
|
303 |
-
ExampleIn(label=False, text='really not in concept', draft='test_draft'),
|
304 |
-
ExampleIn(label=True, text='really in concept', draft='test_draft')
|
305 |
-
]
|
306 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
307 |
-
|
308 |
-
with pytest.raises(ValueError, match='Example with id "invalid_id" does not exist'):
|
309 |
-
db.edit(namespace, concept_name, ConceptUpdate(remove=['invalid_id']))
|
310 |
-
|
311 |
-
def test_edit_before_creation(self, db_cls: Type[ConceptDB]) -> None:
|
312 |
-
db = db_cls()
|
313 |
-
namespace = 'test'
|
314 |
-
concept_name = 'test_concept'
|
315 |
-
|
316 |
-
with pytest.raises(
|
317 |
-
ValueError, match='Concept with namespace "test" and name "test_concept" does not exist'):
|
318 |
-
db.edit(namespace, concept_name,
|
319 |
-
ConceptUpdate(insert=[
|
320 |
-
ExampleIn(label=False, text='not in concept'),
|
321 |
-
]))
|
322 |
-
|
323 |
-
def test_edit_invalid_id(self, db_cls: Type[ConceptDB]) -> None:
|
324 |
-
db = db_cls()
|
325 |
-
namespace = 'test'
|
326 |
-
concept_name = 'test_concept'
|
327 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
328 |
-
|
329 |
-
train_data = [
|
330 |
-
ExampleIn(label=False, text='not in concept'),
|
331 |
-
ExampleIn(label=True, text='in concept')
|
332 |
-
]
|
333 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
334 |
-
|
335 |
-
with pytest.raises(ValueError, match='Example with id "invalid_id" does not exist'):
|
336 |
-
db.edit(namespace, concept_name,
|
337 |
-
ConceptUpdate(update=[Example(id='invalid_id', label=False, text='not in concept')]))
|
338 |
-
|
339 |
-
def test_merge_draft(self, db_cls: Type[ConceptDB]) -> None:
|
340 |
-
db = db_cls()
|
341 |
-
namespace = 'test'
|
342 |
-
concept_name = 'test_concept'
|
343 |
-
db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
344 |
-
|
345 |
-
train_data = [
|
346 |
-
ExampleIn(label=True, text='hello'),
|
347 |
-
ExampleIn(label=False, text='world'),
|
348 |
-
ExampleIn(label=True, text='hello draft 1', draft='draft1'),
|
349 |
-
ExampleIn(label=False, text='world draft 1', draft='draft1'),
|
350 |
-
# Duplicate of main.
|
351 |
-
ExampleIn(label=False, text='hello', draft='draft2'),
|
352 |
-
ExampleIn(label=True, text='world draft 2', draft='draft2'),
|
353 |
-
]
|
354 |
-
db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
355 |
-
|
356 |
-
db.merge_draft(namespace, concept_name, 'draft1')
|
357 |
-
|
358 |
-
concept = db.get(namespace, concept_name)
|
359 |
-
assert concept is not None
|
360 |
-
keys = list(concept.data.keys())
|
361 |
-
|
362 |
-
assert concept.dict() == Concept(
|
363 |
-
namespace='test',
|
364 |
-
concept_name='test_concept',
|
365 |
-
type=SignalInputType.TEXT,
|
366 |
-
data={
|
367 |
-
keys[0]: Example(id=keys[0], label=True, text='hello'),
|
368 |
-
keys[1]: Example(id=keys[1], label=False, text='world'),
|
369 |
-
# Draft examples are merged.
|
370 |
-
keys[2]: Example(id=keys[2], label=True, text='hello draft 1'),
|
371 |
-
keys[3]: Example(id=keys[3], label=False, text='world draft 1'),
|
372 |
-
# Draft 2 is untouched.
|
373 |
-
keys[4]: Example(id=keys[4], label=False, text='hello', draft='draft2'),
|
374 |
-
keys[5]: Example(id=keys[5], label=True, text='world draft 2', draft='draft2'),
|
375 |
-
},
|
376 |
-
version=2).dict()
|
377 |
-
|
378 |
-
db.merge_draft(namespace, concept_name, 'draft2')
|
379 |
-
|
380 |
-
concept = db.get(namespace, concept_name)
|
381 |
-
assert concept is not None
|
382 |
-
|
383 |
-
assert concept == Concept(
|
384 |
-
namespace='test',
|
385 |
-
concept_name='test_concept',
|
386 |
-
type=SignalInputType.TEXT,
|
387 |
-
data={
|
388 |
-
# The first example is a duplicate of the label from the draft, so it is removed.
|
389 |
-
keys[1]: Example(id=keys[1], label=False, text='world'),
|
390 |
-
# Draft examples are merged.
|
391 |
-
keys[2]: Example(id=keys[2], label=True, text='hello draft 1'),
|
392 |
-
keys[3]: Example(id=keys[3], label=False, text='world draft 1'),
|
393 |
-
# Draft examples are merged.
|
394 |
-
keys[4]: Example(id=keys[4], label=False, text='hello'),
|
395 |
-
keys[5]: Example(id=keys[5], label=True, text='world draft 2'),
|
396 |
-
},
|
397 |
-
version=3)
|
398 |
-
|
399 |
-
|
400 |
-
def _make_test_concept_model(
|
401 |
-
concept_db: ConceptDB,
|
402 |
-
logistic_models: dict[DraftId, LogisticEmbeddingModel] = {}) -> ConceptModel:
|
403 |
-
namespace = 'test'
|
404 |
-
concept_name = 'test_concept'
|
405 |
-
concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
|
406 |
-
|
407 |
-
train_data = [
|
408 |
-
ExampleIn(label=False, text='not in concept'),
|
409 |
-
ExampleIn(label=True, text='in concept')
|
410 |
-
]
|
411 |
-
concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
|
412 |
-
model = ConceptModel(
|
413 |
-
namespace='test', concept_name='test_concept', embedding_name='test_embedding')
|
414 |
-
model._logistic_models = logistic_models
|
415 |
-
return model
|
416 |
-
|
417 |
-
|
418 |
-
class TestLogisticModel(LogisticEmbeddingModel):
|
419 |
-
|
420 |
-
@override
|
421 |
-
def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
|
422 |
-
"""Get the scores for the provided embeddings."""
|
423 |
-
return np.array([.1])
|
424 |
-
|
425 |
-
@override
|
426 |
-
def fit(self, embeddings: np.ndarray, labels: list[bool],
|
427 |
-
implicit_negatives: Optional[np.ndarray]) -> None:
|
428 |
-
pass
|
429 |
-
|
430 |
-
|
431 |
-
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
|
432 |
-
@pytest.mark.parametrize('model_db_cls', ALL_CONCEPT_MODEL_DBS)
|
433 |
-
class ConceptModelDBSuite:
|
434 |
-
|
435 |
-
def test_save_and_get_model(self, concept_db_cls: Type[ConceptDB],
|
436 |
-
model_db_cls: Type[ConceptModelDB]) -> None:
|
437 |
-
concept_db = concept_db_cls()
|
438 |
-
model_db = model_db_cls(concept_db)
|
439 |
-
model = _make_test_concept_model(concept_db)
|
440 |
-
model_db.sync(model)
|
441 |
-
retrieved_model = model_db.get(
|
442 |
-
namespace='test', concept_name='test_concept', embedding_name='test_embedding')
|
443 |
-
if not retrieved_model:
|
444 |
-
retrieved_model = model_db.create(
|
445 |
-
namespace='test', concept_name='test_concept', embedding_name='test_embedding')
|
446 |
-
assert retrieved_model.namespace == model.namespace
|
447 |
-
assert retrieved_model.concept_name == model.concept_name
|
448 |
-
assert retrieved_model.embedding_name == model.embedding_name
|
449 |
-
assert retrieved_model.version == model.version
|
450 |
-
assert retrieved_model.column_info == model.column_info
|
451 |
-
|
452 |
-
def test_sync_model(self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB],
|
453 |
-
mocker: MockerFixture) -> None:
|
454 |
-
|
455 |
-
concept_db = concept_db_cls()
|
456 |
-
model_db = model_db_cls(concept_db)
|
457 |
-
logistic_model = TestLogisticModel()
|
458 |
-
score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
|
459 |
-
fit_mock = mocker.spy(TestLogisticModel, 'fit')
|
460 |
-
|
461 |
-
model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
|
462 |
-
|
463 |
-
assert model_db.in_sync(model) is False
|
464 |
-
assert score_embeddings_mock.call_count == 0
|
465 |
-
assert fit_mock.call_count == 0
|
466 |
-
|
467 |
-
model_db.sync(model)
|
468 |
-
|
469 |
-
assert model_db.in_sync(model) is True
|
470 |
-
assert score_embeddings_mock.call_count == 0
|
471 |
-
assert fit_mock.call_count == 1
|
472 |
-
|
473 |
-
def test_out_of_sync_model(self, concept_db_cls: Type[ConceptDB],
|
474 |
-
model_db_cls: Type[ConceptModelDB], mocker: MockerFixture) -> None:
|
475 |
-
concept_db = concept_db_cls()
|
476 |
-
model_db = model_db_cls(concept_db)
|
477 |
-
score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
|
478 |
-
fit_mock = mocker.spy(TestLogisticModel, 'fit')
|
479 |
-
logistic_model = TestLogisticModel()
|
480 |
-
model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
|
481 |
-
model_db.sync(model)
|
482 |
-
assert model_db.in_sync(model) is True
|
483 |
-
assert score_embeddings_mock.call_count == 0
|
484 |
-
assert fit_mock.call_count == 1
|
485 |
-
|
486 |
-
(called_model, called_embeddings, called_labels,
|
487 |
-
called_implicit_negatives) = fit_mock.call_args_list[-1].args
|
488 |
-
assert called_model == logistic_model
|
489 |
-
np.testing.assert_array_equal(
|
490 |
-
called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
|
491 |
-
assert called_labels == [False, True]
|
492 |
-
assert called_implicit_negatives is None
|
493 |
-
|
494 |
-
# Edit the concept.
|
495 |
-
concept_db.edit('test', 'test_concept',
|
496 |
-
ConceptUpdate(insert=[ExampleIn(label=False, text='a new data point')]))
|
497 |
-
|
498 |
-
# Make sure the model is out of sync.
|
499 |
-
assert model_db.in_sync(model) is False
|
500 |
-
assert score_embeddings_mock.call_count == 0
|
501 |
-
assert fit_mock.call_count == 1
|
502 |
-
|
503 |
-
model_db.sync(model)
|
504 |
-
assert model_db.in_sync(model) is True
|
505 |
-
assert score_embeddings_mock.call_count == 0
|
506 |
-
assert fit_mock.call_count == 2
|
507 |
-
# Fit is called again with new points on main only.
|
508 |
-
(called_model, called_embeddings, called_labels,
|
509 |
-
called_implicit_negatives) = fit_mock.call_args_list[-1].args
|
510 |
-
assert called_model == logistic_model
|
511 |
-
np.testing.assert_array_equal(
|
512 |
-
called_embeddings,
|
513 |
-
np.array([
|
514 |
-
EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept'],
|
515 |
-
EMBEDDING_MAP['a new data point']
|
516 |
-
]))
|
517 |
-
assert called_labels == [False, True, False]
|
518 |
-
assert called_implicit_negatives is None
|
519 |
-
|
520 |
-
def test_out_of_sync_draft_model(self, concept_db_cls: Type[ConceptDB],
|
521 |
-
model_db_cls: Type[ConceptModelDB],
|
522 |
-
mocker: MockerFixture) -> None:
|
523 |
-
concept_db = concept_db_cls()
|
524 |
-
model_db = model_db_cls(concept_db)
|
525 |
-
score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
|
526 |
-
fit_mock = mocker.spy(TestLogisticModel, 'fit')
|
527 |
-
main_model = TestLogisticModel()
|
528 |
-
draft_model = TestLogisticModel()
|
529 |
-
model = _make_test_concept_model(
|
530 |
-
concept_db, logistic_models={
|
531 |
-
DRAFT_MAIN: main_model,
|
532 |
-
'test_draft': draft_model
|
533 |
-
})
|
534 |
-
model_db.sync(model)
|
535 |
-
assert model_db.in_sync(model) is True
|
536 |
-
assert score_embeddings_mock.call_count == 0
|
537 |
-
assert fit_mock.call_count == 1
|
538 |
-
|
539 |
-
# Make sure drafts cause the model to be out of sync.
|
540 |
-
concept_db.edit(
|
541 |
-
'test',
|
542 |
-
'test_concept',
|
543 |
-
ConceptUpdate(insert=[
|
544 |
-
ExampleIn(label=True, text='a true draft point', draft='test_draft'),
|
545 |
-
ExampleIn(label=False, text='a false draft point', draft='test_draft'),
|
546 |
-
# This point exists in main, but we switched the label.
|
547 |
-
ExampleIn(label=False, text='in concept', draft='test_draft'),
|
548 |
-
]))
|
549 |
-
|
550 |
-
# Make sure the model is out of sync.
|
551 |
-
assert model_db.in_sync(model) is False
|
552 |
-
assert score_embeddings_mock.call_count == 0
|
553 |
-
assert fit_mock.call_count == 1
|
554 |
-
|
555 |
-
model_db.sync(model)
|
556 |
-
assert model_db.in_sync(model) is True
|
557 |
-
assert score_embeddings_mock.call_count == 0
|
558 |
-
assert fit_mock.call_count == 3 # Fit is called on both the draft, and main.
|
559 |
-
|
560 |
-
# Fit is called again with the same points.
|
561 |
-
((called_model, called_embeddings, called_labels, called_implicit_negatives),
|
562 |
-
(called_draft_model, called_draft_embeddings, called_draft_labels,
|
563 |
-
called_draft_implicit_negatives)) = (
|
564 |
-
c.args for c in fit_mock.call_args_list[-2:])
|
565 |
-
|
566 |
-
# The draft model is called with the data from main, and the data from draft.
|
567 |
-
assert called_draft_model == draft_model
|
568 |
-
np.testing.assert_array_equal(
|
569 |
-
called_draft_embeddings,
|
570 |
-
np.array([
|
571 |
-
EMBEDDING_MAP['a true draft point'], EMBEDDING_MAP['a false draft point'],
|
572 |
-
EMBEDDING_MAP['in concept'], EMBEDDING_MAP['not in concept']
|
573 |
-
]))
|
574 |
-
assert called_draft_labels == [
|
575 |
-
True,
|
576 |
-
False,
|
577 |
-
# This was overriden by the draft.
|
578 |
-
False,
|
579 |
-
False
|
580 |
-
]
|
581 |
-
assert called_draft_implicit_negatives is None
|
582 |
-
|
583 |
-
# The main model was fit without the data from the draft.
|
584 |
-
assert called_model == main_model
|
585 |
-
np.testing.assert_array_equal(
|
586 |
-
called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
|
587 |
-
assert called_labels == [False, True]
|
588 |
-
assert called_implicit_negatives is None
|
589 |
-
|
590 |
-
def test_embedding_not_found_in_map(self, concept_db_cls: Type[ConceptDB],
|
591 |
-
model_db_cls: Type[ConceptModelDB]) -> None:
|
592 |
-
concept_db = concept_db_cls()
|
593 |
-
model_db = model_db_cls(concept_db)
|
594 |
-
model = _make_test_concept_model(concept_db)
|
595 |
-
model_db.sync(model)
|
596 |
-
|
597 |
-
# Edit the concept.
|
598 |
-
concept_db.edit('test', 'test_concept',
|
599 |
-
ConceptUpdate(insert=[ExampleIn(label=False, text='unknown text')]))
|
600 |
-
|
601 |
-
# Make sure the model is out of sync.
|
602 |
-
assert model_db.in_sync(model) is False
|
603 |
-
|
604 |
-
with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
|
605 |
-
model_db.sync(model)
|
606 |
-
model_db.sync(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_compute_signal_chain_test.py
DELETED
@@ -1,255 +0,0 @@
|
|
1 |
-
"""Tests for dataset.compute_signal() when signals are chained."""
|
2 |
-
|
3 |
-
import re
|
4 |
-
from typing import Iterable, List, Optional, cast
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import pytest
|
8 |
-
from pytest_mock import MockerFixture
|
9 |
-
from typing_extensions import override
|
10 |
-
|
11 |
-
from ..embeddings.vector_store import VectorStore
|
12 |
-
from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field, schema
|
13 |
-
from ..signals.signal import (
|
14 |
-
TextEmbeddingModelSignal,
|
15 |
-
TextEmbeddingSignal,
|
16 |
-
TextSignal,
|
17 |
-
TextSplitterSignal,
|
18 |
-
clear_signal_registry,
|
19 |
-
register_signal,
|
20 |
-
)
|
21 |
-
from .dataset import DatasetManifest
|
22 |
-
from .dataset_test_utils import (
|
23 |
-
TEST_DATASET_NAME,
|
24 |
-
TEST_NAMESPACE,
|
25 |
-
TestDataMaker,
|
26 |
-
enriched_embedding_span,
|
27 |
-
enriched_embedding_span_field,
|
28 |
-
enriched_item,
|
29 |
-
)
|
30 |
-
from .dataset_utils import lilac_embedding, lilac_span
|
31 |
-
|
32 |
-
SIMPLE_ITEMS: list[Item] = [{
|
33 |
-
UUID_COLUMN: '1',
|
34 |
-
'str': 'a',
|
35 |
-
'int': 1,
|
36 |
-
'bool': False,
|
37 |
-
'float': 3.0
|
38 |
-
}, {
|
39 |
-
UUID_COLUMN: '2',
|
40 |
-
'str': 'b',
|
41 |
-
'int': 2,
|
42 |
-
'bool': True,
|
43 |
-
'float': 2.0
|
44 |
-
}, {
|
45 |
-
UUID_COLUMN: '3',
|
46 |
-
'str': 'b',
|
47 |
-
'int': 2,
|
48 |
-
'bool': True,
|
49 |
-
'float': 1.0
|
50 |
-
}]
|
51 |
-
|
52 |
-
EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
|
53 |
-
('hello2.', [1.0, 1.0, 0.0]),
|
54 |
-
('hello world.', [1.0, 1.0, 1.0]),
|
55 |
-
('hello world2.', [2.0, 1.0, 1.0])]
|
56 |
-
|
57 |
-
STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
|
58 |
-
|
59 |
-
|
60 |
-
class TestSplitter(TextSplitterSignal):
|
61 |
-
"""Split documents into sentence by splitting on period."""
|
62 |
-
name = 'test_splitter'
|
63 |
-
|
64 |
-
@override
|
65 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
66 |
-
for text in data:
|
67 |
-
if not isinstance(text, str):
|
68 |
-
raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
|
69 |
-
sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence]
|
70 |
-
yield [
|
71 |
-
lilac_span(text.index(sentence),
|
72 |
-
text.index(sentence) + len(sentence)) for sentence in sentences
|
73 |
-
]
|
74 |
-
|
75 |
-
|
76 |
-
class TestEmbedding(TextEmbeddingSignal):
|
77 |
-
"""A test embed function."""
|
78 |
-
name = 'test_embedding'
|
79 |
-
|
80 |
-
@override
|
81 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
82 |
-
"""Call the embedding function."""
|
83 |
-
for example in data:
|
84 |
-
yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
|
85 |
-
|
86 |
-
|
87 |
-
class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
|
88 |
-
"""Sums the embeddings to return a single floating point value."""
|
89 |
-
name = 'test_embedding_sum'
|
90 |
-
|
91 |
-
@override
|
92 |
-
def fields(self) -> Field:
|
93 |
-
return field('float32')
|
94 |
-
|
95 |
-
@override
|
96 |
-
def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
|
97 |
-
# The signal just sums the values of the embedding.
|
98 |
-
embedding_sums = vector_store.get(keys).sum(axis=1)
|
99 |
-
for embedding_sum in embedding_sums.tolist():
|
100 |
-
yield embedding_sum
|
101 |
-
|
102 |
-
|
103 |
-
@pytest.fixture(scope='module', autouse=True)
|
104 |
-
def setup_teardown() -> Iterable[None]:
|
105 |
-
# Setup.
|
106 |
-
register_signal(TestSplitter)
|
107 |
-
register_signal(TestEmbedding)
|
108 |
-
register_signal(TestEmbeddingSumSignal)
|
109 |
-
register_signal(NamedEntity)
|
110 |
-
# Unit test runs.
|
111 |
-
yield
|
112 |
-
# Teardown.
|
113 |
-
clear_signal_registry()
|
114 |
-
|
115 |
-
|
116 |
-
def test_manual_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
|
117 |
-
dataset = make_test_data([{
|
118 |
-
UUID_COLUMN: '1',
|
119 |
-
'text': 'hello.',
|
120 |
-
}, {
|
121 |
-
UUID_COLUMN: '2',
|
122 |
-
'text': 'hello2.',
|
123 |
-
}])
|
124 |
-
|
125 |
-
embed_mock = mocker.spy(TestEmbedding, 'compute')
|
126 |
-
|
127 |
-
embedding_signal = TestEmbedding()
|
128 |
-
dataset.compute_signal(embedding_signal, 'text')
|
129 |
-
embedding_sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
|
130 |
-
dataset.compute_signal(embedding_sum_signal, 'text')
|
131 |
-
|
132 |
-
# Make sure the embedding signal is not called twice.
|
133 |
-
assert embed_mock.call_count == 1
|
134 |
-
|
135 |
-
assert dataset.manifest() == DatasetManifest(
|
136 |
-
namespace=TEST_NAMESPACE,
|
137 |
-
dataset_name=TEST_DATASET_NAME,
|
138 |
-
data_schema=schema({
|
139 |
-
UUID_COLUMN: 'string',
|
140 |
-
'text': field(
|
141 |
-
'string',
|
142 |
-
fields={
|
143 |
-
'test_embedding': field(
|
144 |
-
signal=embedding_signal.dict(),
|
145 |
-
fields=[
|
146 |
-
enriched_embedding_span_field(
|
147 |
-
{'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
|
148 |
-
])
|
149 |
-
}),
|
150 |
-
}),
|
151 |
-
num_items=2)
|
152 |
-
|
153 |
-
result = dataset.select_rows()
|
154 |
-
expected_result = [{
|
155 |
-
UUID_COLUMN: '1',
|
156 |
-
'text': enriched_item(
|
157 |
-
'hello.', {'test_embedding': [enriched_embedding_span(0, 6, {'test_embedding_sum': 1.0})]})
|
158 |
-
}, {
|
159 |
-
UUID_COLUMN: '2',
|
160 |
-
'text': enriched_item(
|
161 |
-
'hello2.', {'test_embedding': [enriched_embedding_span(0, 7, {'test_embedding_sum': 2.0})]})
|
162 |
-
}]
|
163 |
-
assert list(result) == expected_result
|
164 |
-
|
165 |
-
|
166 |
-
def test_auto_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
|
167 |
-
dataset = make_test_data([{
|
168 |
-
UUID_COLUMN: '1',
|
169 |
-
'text': 'hello.',
|
170 |
-
}, {
|
171 |
-
UUID_COLUMN: '2',
|
172 |
-
'text': 'hello2.',
|
173 |
-
}])
|
174 |
-
|
175 |
-
embed_mock = mocker.spy(TestEmbedding, 'compute')
|
176 |
-
|
177 |
-
# The embedding is automatically computed from the TestEmbeddingSumSignal.
|
178 |
-
embedding_sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
|
179 |
-
dataset.compute_signal(embedding_sum_signal, 'text')
|
180 |
-
|
181 |
-
# Make sure the embedding signal is not called twice.
|
182 |
-
assert embed_mock.call_count == 1
|
183 |
-
|
184 |
-
assert dataset.manifest() == DatasetManifest(
|
185 |
-
namespace=TEST_NAMESPACE,
|
186 |
-
dataset_name=TEST_DATASET_NAME,
|
187 |
-
data_schema=schema({
|
188 |
-
UUID_COLUMN: 'string',
|
189 |
-
'text': field(
|
190 |
-
'string',
|
191 |
-
fields={
|
192 |
-
'test_embedding': field(
|
193 |
-
signal=embedding_sum_signal._embedding_signal.dict(),
|
194 |
-
fields=[
|
195 |
-
enriched_embedding_span_field(
|
196 |
-
{'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
|
197 |
-
])
|
198 |
-
}),
|
199 |
-
}),
|
200 |
-
num_items=2)
|
201 |
-
|
202 |
-
result = dataset.select_rows()
|
203 |
-
expected_result = [{
|
204 |
-
UUID_COLUMN: '1',
|
205 |
-
'text': enriched_item(
|
206 |
-
'hello.', {'test_embedding': [enriched_embedding_span(0, 6, {'test_embedding_sum': 1.0})]})
|
207 |
-
}, {
|
208 |
-
UUID_COLUMN: '2',
|
209 |
-
'text': enriched_item(
|
210 |
-
'hello2.', {'test_embedding': [enriched_embedding_span(0, 7, {'test_embedding_sum': 2.0})]})
|
211 |
-
}]
|
212 |
-
assert list(result) == expected_result
|
213 |
-
|
214 |
-
|
215 |
-
ENTITY_REGEX = r'[A-Za-z]+@[A-Za-z]+'
|
216 |
-
|
217 |
-
|
218 |
-
class NamedEntity(TextSignal):
|
219 |
-
"""Find special entities."""
|
220 |
-
name = 'entity'
|
221 |
-
|
222 |
-
@override
|
223 |
-
def fields(self) -> Field:
|
224 |
-
return field(fields=['string_span'])
|
225 |
-
|
226 |
-
@override
|
227 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[List[Item]]]:
|
228 |
-
for text in data:
|
229 |
-
if not isinstance(text, str):
|
230 |
-
yield None
|
231 |
-
continue
|
232 |
-
yield [lilac_span(m.start(0), m.end(0)) for m in re.finditer(ENTITY_REGEX, text)]
|
233 |
-
|
234 |
-
|
235 |
-
def test_entity_on_split_signal(make_test_data: TestDataMaker) -> None:
|
236 |
-
text = 'Hello nik@test. Here are some other entities like pii@gmail and all@lilac.'
|
237 |
-
dataset = make_test_data([{UUID_COLUMN: '1', 'text': text}])
|
238 |
-
entity = NamedEntity()
|
239 |
-
dataset.compute_signal(TestSplitter(), 'text')
|
240 |
-
dataset.compute_signal(entity, ('text', 'test_splitter', '*'))
|
241 |
-
|
242 |
-
result = dataset.select_rows(['text'])
|
243 |
-
assert list(result) == [{
|
244 |
-
UUID_COLUMN: '1',
|
245 |
-
'text': enriched_item(
|
246 |
-
text, {
|
247 |
-
'test_splitter': [
|
248 |
-
lilac_span(0, 15, {'entity': [lilac_span(6, 14)]}),
|
249 |
-
lilac_span(16, 74, {'entity': [
|
250 |
-
lilac_span(50, 59),
|
251 |
-
lilac_span(64, 73),
|
252 |
-
]}),
|
253 |
-
]
|
254 |
-
})
|
255 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_compute_signal_test.py
DELETED
@@ -1,669 +0,0 @@
|
|
1 |
-
"""Tests for dataset.compute_signal()."""
|
2 |
-
|
3 |
-
from typing import Iterable, Optional, Union, cast
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import pytest
|
7 |
-
from typing_extensions import override
|
8 |
-
|
9 |
-
from ..concepts.concept import ExampleIn
|
10 |
-
from ..concepts.db_concept import ConceptUpdate, DiskConceptDB
|
11 |
-
from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, SignalInputType, field, schema
|
12 |
-
from ..signals.concept_scorer import ConceptScoreSignal
|
13 |
-
from ..signals.signal import (
|
14 |
-
TextEmbeddingSignal,
|
15 |
-
TextSignal,
|
16 |
-
TextSplitterSignal,
|
17 |
-
clear_signal_registry,
|
18 |
-
register_signal,
|
19 |
-
)
|
20 |
-
from .dataset import Column, DatasetManifest, GroupsSortBy, SortOrder, val
|
21 |
-
from .dataset_test_utils import (
|
22 |
-
TEST_DATASET_NAME,
|
23 |
-
TEST_NAMESPACE,
|
24 |
-
TestDataMaker,
|
25 |
-
enriched_embedding_span_field,
|
26 |
-
enriched_item,
|
27 |
-
)
|
28 |
-
from .dataset_utils import lilac_embedding, lilac_span
|
29 |
-
|
30 |
-
SIMPLE_ITEMS: list[Item] = [{
|
31 |
-
UUID_COLUMN: '1',
|
32 |
-
'str': 'a',
|
33 |
-
'int': 1,
|
34 |
-
'bool': False,
|
35 |
-
'float': 3.0
|
36 |
-
}, {
|
37 |
-
UUID_COLUMN: '2',
|
38 |
-
'str': 'b',
|
39 |
-
'int': 2,
|
40 |
-
'bool': True,
|
41 |
-
'float': 2.0
|
42 |
-
}, {
|
43 |
-
UUID_COLUMN: '3',
|
44 |
-
'str': 'b',
|
45 |
-
'int': 2,
|
46 |
-
'bool': True,
|
47 |
-
'float': 1.0
|
48 |
-
}]
|
49 |
-
|
50 |
-
|
51 |
-
class TestInvalidSignal(TextSignal):
|
52 |
-
name = 'test_invalid_signal'
|
53 |
-
|
54 |
-
@override
|
55 |
-
def fields(self) -> Field:
|
56 |
-
return field('int32')
|
57 |
-
|
58 |
-
@override
|
59 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
60 |
-
# Return an invalid output that doesn't match the input length.
|
61 |
-
return []
|
62 |
-
|
63 |
-
|
64 |
-
class TestSparseSignal(TextSignal):
|
65 |
-
name = 'test_sparse_signal'
|
66 |
-
|
67 |
-
@override
|
68 |
-
def fields(self) -> Field:
|
69 |
-
return field('int32')
|
70 |
-
|
71 |
-
@override
|
72 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
73 |
-
for text in data:
|
74 |
-
if text == 'hello':
|
75 |
-
# Skip this input.
|
76 |
-
yield None
|
77 |
-
else:
|
78 |
-
yield len(text)
|
79 |
-
|
80 |
-
|
81 |
-
class TestSparseRichSignal(TextSignal):
|
82 |
-
"""Find personally identifiable information (emails, phone numbers, etc)."""
|
83 |
-
name = 'test_sparse_rich_signal'
|
84 |
-
|
85 |
-
@override
|
86 |
-
def fields(self) -> Field:
|
87 |
-
return field(fields={'emails': ['string']})
|
88 |
-
|
89 |
-
@override
|
90 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
91 |
-
for text in data:
|
92 |
-
if text == 'hello':
|
93 |
-
# Skip this input.
|
94 |
-
yield None
|
95 |
-
else:
|
96 |
-
yield {'emails': ['test1@hello.com', 'test2@hello.com']}
|
97 |
-
|
98 |
-
|
99 |
-
class TestParamSignal(TextSignal):
|
100 |
-
name = 'param_signal'
|
101 |
-
param: str
|
102 |
-
|
103 |
-
def fields(self) -> Field:
|
104 |
-
return field('string')
|
105 |
-
|
106 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
107 |
-
for text_content in data:
|
108 |
-
yield f'{str(text_content)}_{self.param}'
|
109 |
-
|
110 |
-
|
111 |
-
class TestSignal(TextSignal):
|
112 |
-
name = 'test_signal'
|
113 |
-
|
114 |
-
@override
|
115 |
-
def fields(self) -> Field:
|
116 |
-
return field(fields={'len': 'int32', 'flen': 'float32'})
|
117 |
-
|
118 |
-
@override
|
119 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
120 |
-
return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data]
|
121 |
-
|
122 |
-
|
123 |
-
class TestSplitSignal(TextSplitterSignal):
|
124 |
-
"""Split documents into sentence by splitting on period, generating entities."""
|
125 |
-
name = 'test_split'
|
126 |
-
|
127 |
-
@override
|
128 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
129 |
-
for text in data:
|
130 |
-
if not isinstance(text, str):
|
131 |
-
raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
|
132 |
-
sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence]
|
133 |
-
yield [
|
134 |
-
lilac_span(text.index(sentence),
|
135 |
-
text.index(sentence) + len(sentence)) for sentence in sentences
|
136 |
-
]
|
137 |
-
|
138 |
-
|
139 |
-
EMBEDDINGS: list[tuple[str, Union[list[float], list[list[float]]]]] = [
|
140 |
-
('hello.', [1.0, 0.0, 0.0]),
|
141 |
-
# This embedding has an outer dimension of 1.
|
142 |
-
('hello2.', [[1.0, 1.0, 0.0]]),
|
143 |
-
('hello3.', [[0, 0, 1.]])
|
144 |
-
]
|
145 |
-
|
146 |
-
STR_EMBEDDINGS: dict[str, Union[list[float], list[list[float]]]] = {
|
147 |
-
text: embedding for text, embedding in EMBEDDINGS
|
148 |
-
}
|
149 |
-
|
150 |
-
|
151 |
-
class TestEmbedding(TextEmbeddingSignal):
|
152 |
-
"""A test embed function."""
|
153 |
-
name = 'test_embedding'
|
154 |
-
|
155 |
-
@override
|
156 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
157 |
-
"""Call the embedding function."""
|
158 |
-
for example in data:
|
159 |
-
example = cast(str, example)
|
160 |
-
yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[example]))]
|
161 |
-
|
162 |
-
|
163 |
-
class ComputedKeySignal(TextSignal):
|
164 |
-
name = 'computed_key'
|
165 |
-
|
166 |
-
@override
|
167 |
-
def fields(self) -> Field:
|
168 |
-
return field('int64')
|
169 |
-
|
170 |
-
@override
|
171 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
172 |
-
for text in data:
|
173 |
-
yield 1
|
174 |
-
|
175 |
-
def key(self, is_computed_signal: Optional[bool] = False) -> str:
|
176 |
-
return f'key_{is_computed_signal}'
|
177 |
-
|
178 |
-
|
179 |
-
@pytest.fixture(scope='module', autouse=True)
|
180 |
-
def setup_teardown() -> Iterable[None]:
|
181 |
-
# Setup.
|
182 |
-
register_signal(TestSparseSignal)
|
183 |
-
register_signal(TestSparseRichSignal)
|
184 |
-
register_signal(TestParamSignal)
|
185 |
-
register_signal(TestSignal)
|
186 |
-
register_signal(TestSplitSignal)
|
187 |
-
register_signal(TestEmbedding)
|
188 |
-
register_signal(ComputedKeySignal)
|
189 |
-
register_signal(ConceptScoreSignal)
|
190 |
-
|
191 |
-
# Unit test runs.
|
192 |
-
yield
|
193 |
-
# Teardown.
|
194 |
-
clear_signal_registry()
|
195 |
-
|
196 |
-
|
197 |
-
def test_signal_output_validation(make_test_data: TestDataMaker) -> None:
|
198 |
-
signal = TestInvalidSignal()
|
199 |
-
|
200 |
-
dataset = make_test_data([{
|
201 |
-
UUID_COLUMN: '1',
|
202 |
-
'text': 'hello',
|
203 |
-
}, {
|
204 |
-
UUID_COLUMN: '2',
|
205 |
-
'text': 'hello world',
|
206 |
-
}])
|
207 |
-
|
208 |
-
with pytest.raises(
|
209 |
-
ValueError, match='The signal generated 0 values but the input data had 2 values.'):
|
210 |
-
dataset.compute_signal(signal, 'text')
|
211 |
-
|
212 |
-
|
213 |
-
def test_sparse_signal(make_test_data: TestDataMaker) -> None:
|
214 |
-
dataset = make_test_data([{
|
215 |
-
UUID_COLUMN: '1',
|
216 |
-
'text': 'hello',
|
217 |
-
}, {
|
218 |
-
UUID_COLUMN: '2',
|
219 |
-
'text': 'hello world',
|
220 |
-
}])
|
221 |
-
|
222 |
-
dataset.compute_signal(TestSparseSignal(), 'text')
|
223 |
-
|
224 |
-
result = dataset.select_rows(['text'])
|
225 |
-
assert list(result) == [{
|
226 |
-
UUID_COLUMN: '1',
|
227 |
-
'text': enriched_item('hello', {'test_sparse_signal': None})
|
228 |
-
}, {
|
229 |
-
UUID_COLUMN: '2',
|
230 |
-
'text': enriched_item('hello world', {'test_sparse_signal': 11})
|
231 |
-
}]
|
232 |
-
|
233 |
-
|
234 |
-
def test_sparse_rich_signal(make_test_data: TestDataMaker) -> None:
|
235 |
-
dataset = make_test_data([{
|
236 |
-
UUID_COLUMN: '1',
|
237 |
-
'text': 'hello',
|
238 |
-
}, {
|
239 |
-
UUID_COLUMN: '2',
|
240 |
-
'text': 'hello world',
|
241 |
-
}])
|
242 |
-
|
243 |
-
dataset.compute_signal(TestSparseRichSignal(), 'text')
|
244 |
-
|
245 |
-
result = dataset.select_rows(['text'])
|
246 |
-
assert list(result) == [{
|
247 |
-
UUID_COLUMN: '1',
|
248 |
-
'text': enriched_item('hello', {'test_sparse_rich_signal': None})
|
249 |
-
}, {
|
250 |
-
UUID_COLUMN: '2',
|
251 |
-
'text': enriched_item(
|
252 |
-
'hello world',
|
253 |
-
{'test_sparse_rich_signal': {
|
254 |
-
'emails': ['test1@hello.com', 'test2@hello.com']
|
255 |
-
}})
|
256 |
-
}]
|
257 |
-
|
258 |
-
|
259 |
-
def test_source_joined_with_signal(make_test_data: TestDataMaker) -> None:
|
260 |
-
dataset = make_test_data(SIMPLE_ITEMS)
|
261 |
-
assert dataset.manifest() == DatasetManifest(
|
262 |
-
namespace=TEST_NAMESPACE,
|
263 |
-
dataset_name=TEST_DATASET_NAME,
|
264 |
-
data_schema=schema({
|
265 |
-
UUID_COLUMN: 'string',
|
266 |
-
'str': 'string',
|
267 |
-
'int': 'int32',
|
268 |
-
'bool': 'boolean',
|
269 |
-
'float': 'float32',
|
270 |
-
}),
|
271 |
-
num_items=3)
|
272 |
-
|
273 |
-
test_signal = TestSignal()
|
274 |
-
dataset.compute_signal(test_signal, 'str')
|
275 |
-
|
276 |
-
# Check the enriched dataset manifest has 'text' enriched.
|
277 |
-
assert dataset.manifest() == DatasetManifest(
|
278 |
-
namespace=TEST_NAMESPACE,
|
279 |
-
dataset_name=TEST_DATASET_NAME,
|
280 |
-
data_schema=schema({
|
281 |
-
UUID_COLUMN: 'string',
|
282 |
-
'str': field(
|
283 |
-
'string',
|
284 |
-
fields={
|
285 |
-
'test_signal': field(
|
286 |
-
signal=test_signal.dict(), fields={
|
287 |
-
'len': 'int32',
|
288 |
-
'flen': 'float32'
|
289 |
-
}),
|
290 |
-
}),
|
291 |
-
'int': 'int32',
|
292 |
-
'bool': 'boolean',
|
293 |
-
'float': 'float32',
|
294 |
-
}),
|
295 |
-
num_items=3)
|
296 |
-
|
297 |
-
result = dataset.select_rows(['str'])
|
298 |
-
assert list(result) == [{
|
299 |
-
UUID_COLUMN: '1',
|
300 |
-
'str': enriched_item('a', {'test_signal': {
|
301 |
-
'len': 1,
|
302 |
-
'flen': 1.0
|
303 |
-
}}),
|
304 |
-
}, {
|
305 |
-
UUID_COLUMN: '2',
|
306 |
-
'str': enriched_item('b', {'test_signal': {
|
307 |
-
'len': 1,
|
308 |
-
'flen': 1.0
|
309 |
-
}}),
|
310 |
-
}, {
|
311 |
-
UUID_COLUMN: '3',
|
312 |
-
'str': enriched_item('b', {'test_signal': {
|
313 |
-
'len': 1,
|
314 |
-
'flen': 1.0
|
315 |
-
}}),
|
316 |
-
}]
|
317 |
-
|
318 |
-
# Select a specific signal leaf test_signal.flen with val('str').
|
319 |
-
result = dataset.select_rows([val('str'), ('str', 'test_signal', 'flen')])
|
320 |
-
|
321 |
-
assert list(result) == [{
|
322 |
-
UUID_COLUMN: '1',
|
323 |
-
f'str.{VALUE_KEY}': 'a',
|
324 |
-
'str.test_signal.flen': 1.0
|
325 |
-
}, {
|
326 |
-
UUID_COLUMN: '2',
|
327 |
-
f'str.{VALUE_KEY}': 'b',
|
328 |
-
'str.test_signal.flen': 1.0
|
329 |
-
}, {
|
330 |
-
UUID_COLUMN: '3',
|
331 |
-
f'str.{VALUE_KEY}': 'b',
|
332 |
-
'str.test_signal.flen': 1.0
|
333 |
-
}]
|
334 |
-
|
335 |
-
# Select a specific signal leaf test_signal.flen and the whole 'str' subtree.
|
336 |
-
result = dataset.select_rows(['str', ('str', 'test_signal', 'flen')])
|
337 |
-
|
338 |
-
assert list(result) == [{
|
339 |
-
UUID_COLUMN: '1',
|
340 |
-
'str': enriched_item('a', {'test_signal': {
|
341 |
-
'len': 1,
|
342 |
-
'flen': 1.0
|
343 |
-
}}),
|
344 |
-
'str.test_signal.flen': 1.0
|
345 |
-
}, {
|
346 |
-
UUID_COLUMN: '2',
|
347 |
-
'str': enriched_item('b', {'test_signal': {
|
348 |
-
'len': 1,
|
349 |
-
'flen': 1.0
|
350 |
-
}}),
|
351 |
-
'str.test_signal.flen': 1.0
|
352 |
-
}, {
|
353 |
-
UUID_COLUMN: '3',
|
354 |
-
'str': enriched_item('b', {'test_signal': {
|
355 |
-
'len': 1,
|
356 |
-
'flen': 1.0
|
357 |
-
}}),
|
358 |
-
'str.test_signal.flen': 1.0
|
359 |
-
}]
|
360 |
-
|
361 |
-
# Select multiple signal leafs with aliasing.
|
362 |
-
result = dataset.select_rows([
|
363 |
-
val('str'),
|
364 |
-
Column(('str', 'test_signal', 'flen'), alias='flen'),
|
365 |
-
Column(('str', 'test_signal', 'len'), alias='len')
|
366 |
-
])
|
367 |
-
|
368 |
-
assert list(result) == [{
|
369 |
-
UUID_COLUMN: '1',
|
370 |
-
f'str.{VALUE_KEY}': 'a',
|
371 |
-
'flen': 1.0,
|
372 |
-
'len': 1
|
373 |
-
}, {
|
374 |
-
UUID_COLUMN: '2',
|
375 |
-
f'str.{VALUE_KEY}': 'b',
|
376 |
-
'flen': 1.0,
|
377 |
-
'len': 1
|
378 |
-
}, {
|
379 |
-
UUID_COLUMN: '3',
|
380 |
-
f'str.{VALUE_KEY}': 'b',
|
381 |
-
'flen': 1.0,
|
382 |
-
'len': 1
|
383 |
-
}]
|
384 |
-
|
385 |
-
|
386 |
-
def test_parameterized_signal(make_test_data: TestDataMaker) -> None:
|
387 |
-
dataset = make_test_data([{
|
388 |
-
UUID_COLUMN: '1',
|
389 |
-
'text': 'hello'
|
390 |
-
}, {
|
391 |
-
UUID_COLUMN: '2',
|
392 |
-
'text': 'everybody'
|
393 |
-
}])
|
394 |
-
test_signal_a = TestParamSignal(param='a')
|
395 |
-
test_signal_b = TestParamSignal(param='b')
|
396 |
-
dataset.compute_signal(test_signal_a, 'text')
|
397 |
-
dataset.compute_signal(test_signal_b, 'text')
|
398 |
-
|
399 |
-
assert dataset.manifest() == DatasetManifest(
|
400 |
-
namespace=TEST_NAMESPACE,
|
401 |
-
dataset_name=TEST_DATASET_NAME,
|
402 |
-
data_schema=schema({
|
403 |
-
UUID_COLUMN: 'string',
|
404 |
-
'text': field(
|
405 |
-
'string',
|
406 |
-
fields={
|
407 |
-
'param_signal(param=a)': field('string', test_signal_a.dict()),
|
408 |
-
'param_signal(param=b)': field('string', test_signal_b.dict()),
|
409 |
-
}),
|
410 |
-
}),
|
411 |
-
num_items=2)
|
412 |
-
|
413 |
-
result = dataset.select_rows(['text'])
|
414 |
-
assert list(result) == [{
|
415 |
-
UUID_COLUMN: '1',
|
416 |
-
'text': enriched_item('hello', {
|
417 |
-
'param_signal(param=a)': 'hello_a',
|
418 |
-
'param_signal(param=b)': 'hello_b',
|
419 |
-
})
|
420 |
-
}, {
|
421 |
-
UUID_COLUMN: '2',
|
422 |
-
'text': enriched_item('everybody', {
|
423 |
-
'param_signal(param=a)': 'everybody_a',
|
424 |
-
'param_signal(param=b)': 'everybody_b',
|
425 |
-
})
|
426 |
-
}]
|
427 |
-
|
428 |
-
|
429 |
-
def test_split_signal(make_test_data: TestDataMaker) -> None:
|
430 |
-
dataset = make_test_data([{
|
431 |
-
UUID_COLUMN: '1',
|
432 |
-
'text': '[1, 1] first sentence. [1, 1] second sentence.',
|
433 |
-
}, {
|
434 |
-
UUID_COLUMN: '2',
|
435 |
-
'text': 'b2 [2, 1] first sentence. [2, 1] second sentence.',
|
436 |
-
}])
|
437 |
-
|
438 |
-
signal = TestSplitSignal()
|
439 |
-
dataset.compute_signal(signal, 'text')
|
440 |
-
|
441 |
-
assert dataset.manifest() == DatasetManifest(
|
442 |
-
namespace=TEST_NAMESPACE,
|
443 |
-
dataset_name=TEST_DATASET_NAME,
|
444 |
-
data_schema=schema({
|
445 |
-
UUID_COLUMN: 'string',
|
446 |
-
'text': field(
|
447 |
-
'string', fields={'test_split': field(signal=signal.dict(), fields=[field('string_span')])})
|
448 |
-
}),
|
449 |
-
num_items=2)
|
450 |
-
|
451 |
-
result = dataset.select_rows(['text'])
|
452 |
-
expected_result = [{
|
453 |
-
UUID_COLUMN: '1',
|
454 |
-
'text': enriched_item('[1, 1] first sentence. [1, 1] second sentence.',
|
455 |
-
{'test_split': [lilac_span(0, 22), lilac_span(23, 46)]})
|
456 |
-
}, {
|
457 |
-
UUID_COLUMN: '2',
|
458 |
-
'text': enriched_item('b2 [2, 1] first sentence. [2, 1] second sentence.',
|
459 |
-
{'test_split': [
|
460 |
-
lilac_span(0, 25),
|
461 |
-
lilac_span(26, 49),
|
462 |
-
]})
|
463 |
-
}]
|
464 |
-
assert list(result) == expected_result
|
465 |
-
|
466 |
-
|
467 |
-
def test_signal_on_repeated_field(make_test_data: TestDataMaker) -> None:
|
468 |
-
dataset = make_test_data([{
|
469 |
-
UUID_COLUMN: '1',
|
470 |
-
'text': ['hello', 'everybody'],
|
471 |
-
}, {
|
472 |
-
UUID_COLUMN: '2',
|
473 |
-
'text': ['hello2', 'everybody2'],
|
474 |
-
}])
|
475 |
-
test_signal = TestSignal()
|
476 |
-
# Run the signal on the repeated field.
|
477 |
-
dataset.compute_signal(test_signal, ('text', '*'))
|
478 |
-
|
479 |
-
# Check the enriched dataset manifest has 'text' enriched.
|
480 |
-
assert dataset.manifest() == DatasetManifest(
|
481 |
-
namespace=TEST_NAMESPACE,
|
482 |
-
dataset_name=TEST_DATASET_NAME,
|
483 |
-
data_schema=schema({
|
484 |
-
UUID_COLUMN: 'string',
|
485 |
-
'text': field(fields=[
|
486 |
-
field(
|
487 |
-
'string',
|
488 |
-
fields={
|
489 |
-
'test_signal': field(
|
490 |
-
signal=test_signal.dict(), fields={
|
491 |
-
'len': 'int32',
|
492 |
-
'flen': 'float32'
|
493 |
-
})
|
494 |
-
})
|
495 |
-
])
|
496 |
-
}),
|
497 |
-
num_items=2)
|
498 |
-
|
499 |
-
result = dataset.select_rows([('text', '*')])
|
500 |
-
|
501 |
-
assert list(result) == [{
|
502 |
-
UUID_COLUMN: '1',
|
503 |
-
'text.*': [
|
504 |
-
enriched_item('hello', {'test_signal': {
|
505 |
-
'len': 5,
|
506 |
-
'flen': 5.0
|
507 |
-
}}),
|
508 |
-
enriched_item('everybody', {'test_signal': {
|
509 |
-
'len': 9,
|
510 |
-
'flen': 9.0
|
511 |
-
}})
|
512 |
-
]
|
513 |
-
}, {
|
514 |
-
UUID_COLUMN: '2',
|
515 |
-
'text.*': [
|
516 |
-
enriched_item('hello2', {'test_signal': {
|
517 |
-
'len': 6,
|
518 |
-
'flen': 6.0
|
519 |
-
}}),
|
520 |
-
enriched_item('everybody2', {'test_signal': {
|
521 |
-
'len': 10,
|
522 |
-
'flen': 10.0
|
523 |
-
}})
|
524 |
-
]
|
525 |
-
}]
|
526 |
-
|
527 |
-
|
528 |
-
def test_text_splitter(make_test_data: TestDataMaker) -> None:
|
529 |
-
dataset = make_test_data([{
|
530 |
-
UUID_COLUMN: '1',
|
531 |
-
'text': '[1, 1] first sentence. [1, 1] second sentence.',
|
532 |
-
}, {
|
533 |
-
UUID_COLUMN: '2',
|
534 |
-
'text': 'b2 [2, 1] first sentence. [2, 1] second sentence.',
|
535 |
-
}])
|
536 |
-
|
537 |
-
dataset.compute_signal(TestSplitSignal(), 'text')
|
538 |
-
|
539 |
-
result = dataset.select_rows(['text'])
|
540 |
-
expected_result = [{
|
541 |
-
UUID_COLUMN: '1',
|
542 |
-
'text': enriched_item('[1, 1] first sentence. [1, 1] second sentence.',
|
543 |
-
{'test_split': [
|
544 |
-
lilac_span(0, 22),
|
545 |
-
lilac_span(23, 46),
|
546 |
-
]}),
|
547 |
-
}, {
|
548 |
-
UUID_COLUMN: '2',
|
549 |
-
'text': enriched_item('b2 [2, 1] first sentence. [2, 1] second sentence.',
|
550 |
-
{'test_split': [
|
551 |
-
lilac_span(0, 25),
|
552 |
-
lilac_span(26, 49),
|
553 |
-
]}),
|
554 |
-
}]
|
555 |
-
assert list(result) == expected_result
|
556 |
-
|
557 |
-
|
558 |
-
def test_embedding_signal(make_test_data: TestDataMaker) -> None:
|
559 |
-
dataset = make_test_data([{
|
560 |
-
UUID_COLUMN: '1',
|
561 |
-
'text': 'hello.',
|
562 |
-
}, {
|
563 |
-
UUID_COLUMN: '2',
|
564 |
-
'text': 'hello2.',
|
565 |
-
}])
|
566 |
-
|
567 |
-
embedding_signal = TestEmbedding()
|
568 |
-
dataset.compute_signal(embedding_signal, 'text')
|
569 |
-
|
570 |
-
assert dataset.manifest() == DatasetManifest(
|
571 |
-
namespace=TEST_NAMESPACE,
|
572 |
-
dataset_name=TEST_DATASET_NAME,
|
573 |
-
data_schema=schema({
|
574 |
-
UUID_COLUMN: 'string',
|
575 |
-
'text': field(
|
576 |
-
'string',
|
577 |
-
fields={
|
578 |
-
'test_embedding': field(
|
579 |
-
signal=embedding_signal.dict(), fields=[enriched_embedding_span_field()])
|
580 |
-
}),
|
581 |
-
}),
|
582 |
-
num_items=2)
|
583 |
-
|
584 |
-
result = dataset.select_rows()
|
585 |
-
|
586 |
-
# Embeddings are replaced with "None".
|
587 |
-
expected_result = [{
|
588 |
-
UUID_COLUMN: '1',
|
589 |
-
'text': enriched_item('hello.', {'test_embedding': [lilac_embedding(0, 6, None)]})
|
590 |
-
}, {
|
591 |
-
UUID_COLUMN: '2',
|
592 |
-
'text': enriched_item('hello2.', {'test_embedding': [lilac_embedding(0, 7, None)]})
|
593 |
-
}]
|
594 |
-
assert list(result) == expected_result
|
595 |
-
|
596 |
-
|
597 |
-
def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None:
|
598 |
-
dataset = make_test_data([{
|
599 |
-
UUID_COLUMN: '1',
|
600 |
-
'text': 'hello.',
|
601 |
-
}, {
|
602 |
-
UUID_COLUMN: '2',
|
603 |
-
'text': 'hello2.',
|
604 |
-
}])
|
605 |
-
|
606 |
-
signal = ComputedKeySignal()
|
607 |
-
dataset.compute_signal(signal, 'text')
|
608 |
-
|
609 |
-
assert dataset.manifest() == DatasetManifest(
|
610 |
-
namespace=TEST_NAMESPACE,
|
611 |
-
dataset_name=TEST_DATASET_NAME,
|
612 |
-
data_schema=schema({
|
613 |
-
UUID_COLUMN: 'string',
|
614 |
-
'text': field('string', fields={'key_True': field('int64', signal=signal.dict())}),
|
615 |
-
}),
|
616 |
-
num_items=2)
|
617 |
-
|
618 |
-
result = dataset.select_rows()
|
619 |
-
|
620 |
-
# Embeddings are replaced with "None".
|
621 |
-
expected_result = [{
|
622 |
-
UUID_COLUMN: '1',
|
623 |
-
'text': enriched_item('hello.', {'key_True': 1})
|
624 |
-
}, {
|
625 |
-
UUID_COLUMN: '2',
|
626 |
-
'text': enriched_item('hello2.', {'key_True': 1})
|
627 |
-
}]
|
628 |
-
assert list(result) == expected_result
|
629 |
-
|
630 |
-
|
631 |
-
def test_concept_signal_with_select_groups(make_test_data: TestDataMaker) -> None:
|
632 |
-
dataset = make_test_data([{
|
633 |
-
UUID_COLUMN: '1',
|
634 |
-
'text': 'hello.',
|
635 |
-
}, {
|
636 |
-
UUID_COLUMN: '2',
|
637 |
-
'text': 'hello2.',
|
638 |
-
}, {
|
639 |
-
UUID_COLUMN: '3',
|
640 |
-
'text': 'hello3.',
|
641 |
-
}])
|
642 |
-
|
643 |
-
embedding_signal = TestEmbedding()
|
644 |
-
dataset.compute_signal(embedding_signal, 'text')
|
645 |
-
|
646 |
-
concept_db = DiskConceptDB()
|
647 |
-
concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT)
|
648 |
-
concept_db.edit(
|
649 |
-
'test_namespace', 'test_concept',
|
650 |
-
ConceptUpdate(insert=[
|
651 |
-
ExampleIn(label=False, text='hello.'),
|
652 |
-
ExampleIn(label=True, text='hello2.'),
|
653 |
-
ExampleIn(label=False, text='hello3.')
|
654 |
-
]))
|
655 |
-
|
656 |
-
concept_signal = ConceptScoreSignal(
|
657 |
-
namespace='test_namespace', concept_name='test_concept', embedding='test_embedding')
|
658 |
-
|
659 |
-
dataset.compute_signal(concept_signal, 'text')
|
660 |
-
|
661 |
-
concept_key = concept_signal.key(is_computed_signal=True)
|
662 |
-
result = dataset.select_groups(f'text.test_embedding.*.embedding.{concept_key}')
|
663 |
-
assert result.counts == [('Not in concept', 2), ('In concept', 1)]
|
664 |
-
|
665 |
-
result = dataset.select_groups(
|
666 |
-
f'text.test_embedding.*.embedding.{concept_key}',
|
667 |
-
sort_by=GroupsSortBy.COUNT,
|
668 |
-
sort_order=SortOrder.ASC)
|
669 |
-
assert result.counts == [('In concept', 1), ('Not in concept', 2)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_duckdb.py
CHANGED
@@ -6,7 +6,7 @@ import os
|
|
6 |
import re
|
7 |
import shutil
|
8 |
import threading
|
9 |
-
from typing import Any, Iterable, Optional, Sequence, Type, Union, cast
|
10 |
|
11 |
import duckdb
|
12 |
import numpy as np
|
@@ -93,6 +93,7 @@ from .dataset_utils import (
|
|
93 |
read_embedding_index,
|
94 |
replace_embeddings_with_none,
|
95 |
schema_contains_path,
|
|
|
96 |
unflatten,
|
97 |
wrap_in_dicts,
|
98 |
write_item_embeddings_to_disk,
|
@@ -686,7 +687,7 @@ class DatasetDuckDB(Dataset):
|
|
686 |
star_in_cols = any(col.path == ('*',) for col in cols)
|
687 |
if not cols or star_in_cols:
|
688 |
# Select all columns.
|
689 |
-
cols.extend([Column(name) for name in schema.fields.keys()])
|
690 |
if star_in_cols:
|
691 |
cols = [col for col in cols if col.path != ('*',)]
|
692 |
return cols
|
@@ -941,8 +942,9 @@ class DatasetDuckDB(Dataset):
|
|
941 |
# The input is an embedding.
|
942 |
embedding_signal = cast(TextEmbeddingModelSignal, signal)
|
943 |
vector_store = self.get_vector_store(embedding_signal.embedding, udf_col.path)
|
944 |
-
flat_keys = flatten_keys(df[UUID_COLUMN], input)
|
945 |
-
signal_out =
|
|
|
946 |
# Add progress.
|
947 |
if task_step_id is not None:
|
948 |
signal_out = progress(
|
@@ -953,8 +955,9 @@ class DatasetDuckDB(Dataset):
|
|
953 |
df[signal_column] = unflatten(signal_out, input)
|
954 |
else:
|
955 |
num_rich_data = count_primitives(input)
|
956 |
-
flat_input = cast(
|
957 |
-
signal_out =
|
|
|
958 |
# Add progress.
|
959 |
if task_step_id is not None:
|
960 |
signal_out = progress(
|
@@ -962,22 +965,21 @@ class DatasetDuckDB(Dataset):
|
|
962 |
task_step_id=task_step_id,
|
963 |
estimated_len=num_rich_data,
|
964 |
step_description=f'Computing {signal.key()}...')
|
965 |
-
|
966 |
-
|
967 |
if signal_column in temp_column_to_offset_column:
|
968 |
offset_column_name, field = temp_column_to_offset_column[signal_column]
|
969 |
-
nested_spans:
|
970 |
flat_spans = list(flatten(nested_spans))
|
971 |
-
for span, item in zip(flat_spans,
|
972 |
_offset_any_span(cast(int, span[VALUE_KEY][TEXT_SPAN_START_FEATURE]), item, field)
|
973 |
|
974 |
-
if len(
|
975 |
raise ValueError(
|
976 |
-
f'The signal generated {len(
|
977 |
f"{num_rich_data} values. This means the signal either didn't generate a "
|
978 |
'"None" for a sparse output, or generated too many items.')
|
979 |
|
980 |
-
df[signal_column] = unflatten(
|
981 |
|
982 |
signal.teardown()
|
983 |
|
|
|
6 |
import re
|
7 |
import shutil
|
8 |
import threading
|
9 |
+
from typing import Any, Iterable, Iterator, Optional, Sequence, Type, Union, cast
|
10 |
|
11 |
import duckdb
|
12 |
import numpy as np
|
|
|
93 |
read_embedding_index,
|
94 |
replace_embeddings_with_none,
|
95 |
schema_contains_path,
|
96 |
+
sparse_to_dense_compute,
|
97 |
unflatten,
|
98 |
wrap_in_dicts,
|
99 |
write_item_embeddings_to_disk,
|
|
|
687 |
star_in_cols = any(col.path == ('*',) for col in cols)
|
688 |
if not cols or star_in_cols:
|
689 |
# Select all columns.
|
690 |
+
cols.extend([Column((name,)) for name in schema.fields.keys()])
|
691 |
if star_in_cols:
|
692 |
cols = [col for col in cols if col.path != ('*',)]
|
693 |
return cols
|
|
|
942 |
# The input is an embedding.
|
943 |
embedding_signal = cast(TextEmbeddingModelSignal, signal)
|
944 |
vector_store = self.get_vector_store(embedding_signal.embedding, udf_col.path)
|
945 |
+
flat_keys = list(flatten_keys(df[UUID_COLUMN], input))
|
946 |
+
signal_out = sparse_to_dense_compute(
|
947 |
+
iter(flat_keys), lambda keys: signal.vector_compute(keys, vector_store))
|
948 |
# Add progress.
|
949 |
if task_step_id is not None:
|
950 |
signal_out = progress(
|
|
|
955 |
df[signal_column] = unflatten(signal_out, input)
|
956 |
else:
|
957 |
num_rich_data = count_primitives(input)
|
958 |
+
flat_input = cast(Iterator[Optional[RichData]], flatten(input))
|
959 |
+
signal_out = sparse_to_dense_compute(
|
960 |
+
flat_input, lambda x: signal.compute(cast(Iterable[RichData], x)))
|
961 |
# Add progress.
|
962 |
if task_step_id is not None:
|
963 |
signal_out = progress(
|
|
|
965 |
task_step_id=task_step_id,
|
966 |
estimated_len=num_rich_data,
|
967 |
step_description=f'Computing {signal.key()}...')
|
968 |
+
signal_out_list = list(signal_out)
|
|
|
969 |
if signal_column in temp_column_to_offset_column:
|
970 |
offset_column_name, field = temp_column_to_offset_column[signal_column]
|
971 |
+
nested_spans: Iterator[Item] = df[offset_column_name]
|
972 |
flat_spans = list(flatten(nested_spans))
|
973 |
+
for span, item in zip(flat_spans, signal_out_list):
|
974 |
_offset_any_span(cast(int, span[VALUE_KEY][TEXT_SPAN_START_FEATURE]), item, field)
|
975 |
|
976 |
+
if len(signal_out_list) != num_rich_data:
|
977 |
raise ValueError(
|
978 |
+
f'The signal generated {len(signal_out_list)} values but the input data had '
|
979 |
f"{num_rich_data} values. This means the signal either didn't generate a "
|
980 |
'"None" for a sparse output, or generated too many items.')
|
981 |
|
982 |
+
df[signal_column] = unflatten(signal_out_list, input)
|
983 |
|
984 |
signal.teardown()
|
985 |
|
src/data/dataset_select_groups_test.py
DELETED
@@ -1,317 +0,0 @@
|
|
1 |
-
"""Tests for dataset.select_groups()."""
|
2 |
-
|
3 |
-
import re
|
4 |
-
|
5 |
-
import pytest
|
6 |
-
from pytest_mock import MockerFixture
|
7 |
-
|
8 |
-
from ..schema import UUID_COLUMN, Item, field, schema
|
9 |
-
from . import dataset as dataset_module
|
10 |
-
from .dataset import BinaryOp
|
11 |
-
from .dataset_test_utils import TestDataMaker
|
12 |
-
|
13 |
-
|
14 |
-
def test_flat_data(make_test_data: TestDataMaker) -> None:
|
15 |
-
items: list[Item] = [
|
16 |
-
{
|
17 |
-
'name': 'Name1',
|
18 |
-
'age': 34,
|
19 |
-
'active': False
|
20 |
-
},
|
21 |
-
{
|
22 |
-
'name': 'Name2',
|
23 |
-
'age': 45,
|
24 |
-
'active': True
|
25 |
-
},
|
26 |
-
{
|
27 |
-
'age': 17,
|
28 |
-
'active': True
|
29 |
-
}, # Missing "name".
|
30 |
-
{
|
31 |
-
'name': 'Name3',
|
32 |
-
'active': True
|
33 |
-
}, # Missing "age".
|
34 |
-
{
|
35 |
-
'name': 'Name4',
|
36 |
-
'age': 55
|
37 |
-
} # Missing "active".
|
38 |
-
]
|
39 |
-
dataset = make_test_data(items)
|
40 |
-
|
41 |
-
result = dataset.select_groups(leaf_path='name')
|
42 |
-
assert result.counts == [('Name1', 1), ('Name2', 1), (None, 1), ('Name3', 1), ('Name4', 1)]
|
43 |
-
|
44 |
-
result = dataset.select_groups(leaf_path='age', bins=[20, 50, 60])
|
45 |
-
assert result.counts == [('1', 2), ('0', 1), (None, 1), ('2', 1)]
|
46 |
-
|
47 |
-
result = dataset.select_groups(leaf_path='active')
|
48 |
-
assert result.counts == [
|
49 |
-
(True, 3),
|
50 |
-
(False, 1),
|
51 |
-
(None, 1), # Missing "active".
|
52 |
-
]
|
53 |
-
|
54 |
-
|
55 |
-
def test_result_counts(make_test_data: TestDataMaker) -> None:
|
56 |
-
items: list[Item] = [
|
57 |
-
{
|
58 |
-
'active': False
|
59 |
-
},
|
60 |
-
{
|
61 |
-
'active': True
|
62 |
-
},
|
63 |
-
{
|
64 |
-
'active': True
|
65 |
-
},
|
66 |
-
{
|
67 |
-
'active': True
|
68 |
-
},
|
69 |
-
{} # Missing "active".
|
70 |
-
]
|
71 |
-
dataset = make_test_data(items, schema=schema({UUID_COLUMN: 'string', 'active': 'boolean'}))
|
72 |
-
|
73 |
-
result = dataset.select_groups(leaf_path='active')
|
74 |
-
assert result.counts == [(True, 3), (False, 1), (None, 1)]
|
75 |
-
|
76 |
-
|
77 |
-
def test_list_of_structs(make_test_data: TestDataMaker) -> None:
|
78 |
-
items: list[Item] = [{
|
79 |
-
'list_of_structs': [{
|
80 |
-
'name': 'a'
|
81 |
-
}, {
|
82 |
-
'name': 'b'
|
83 |
-
}]
|
84 |
-
}, {
|
85 |
-
'list_of_structs': [{
|
86 |
-
'name': 'c'
|
87 |
-
}, {
|
88 |
-
'name': 'a'
|
89 |
-
}, {
|
90 |
-
'name': 'd'
|
91 |
-
}]
|
92 |
-
}, {
|
93 |
-
'list_of_structs': [{
|
94 |
-
'name': 'd'
|
95 |
-
}]
|
96 |
-
}]
|
97 |
-
dataset = make_test_data(items)
|
98 |
-
|
99 |
-
result = dataset.select_groups(leaf_path='list_of_structs.*.name')
|
100 |
-
assert result.counts == [('a', 2), ('d', 2), ('b', 1), ('c', 1)]
|
101 |
-
|
102 |
-
|
103 |
-
def test_nested_lists(make_test_data: TestDataMaker) -> None:
|
104 |
-
items: list[Item] = [{
|
105 |
-
'nested_list': [[{
|
106 |
-
'name': 'a'
|
107 |
-
}], [{
|
108 |
-
'name': 'b'
|
109 |
-
}]]
|
110 |
-
}, {
|
111 |
-
'nested_list': [[{
|
112 |
-
'name': 'c'
|
113 |
-
}, {
|
114 |
-
'name': 'a'
|
115 |
-
}], [{
|
116 |
-
'name': 'd'
|
117 |
-
}]]
|
118 |
-
}, {
|
119 |
-
'nested_list': [[{
|
120 |
-
'name': 'd'
|
121 |
-
}]]
|
122 |
-
}]
|
123 |
-
dataset = make_test_data(items)
|
124 |
-
|
125 |
-
result = dataset.select_groups(leaf_path='nested_list.*.*.name')
|
126 |
-
assert result.counts == [('a', 2), ('d', 2), ('b', 1), ('c', 1)]
|
127 |
-
|
128 |
-
|
129 |
-
def test_nested_struct(make_test_data: TestDataMaker) -> None:
|
130 |
-
items: list[Item] = [
|
131 |
-
{
|
132 |
-
'nested_struct': {
|
133 |
-
'struct': {
|
134 |
-
'name': 'c'
|
135 |
-
}
|
136 |
-
}
|
137 |
-
},
|
138 |
-
{
|
139 |
-
'nested_struct': {
|
140 |
-
'struct': {
|
141 |
-
'name': 'b'
|
142 |
-
}
|
143 |
-
}
|
144 |
-
},
|
145 |
-
{
|
146 |
-
'nested_struct': {
|
147 |
-
'struct': {
|
148 |
-
'name': 'a'
|
149 |
-
}
|
150 |
-
}
|
151 |
-
},
|
152 |
-
]
|
153 |
-
dataset = make_test_data(items)
|
154 |
-
|
155 |
-
result = dataset.select_groups(leaf_path='nested_struct.struct.name')
|
156 |
-
assert result.counts == [('c', 1), ('b', 1), ('a', 1)]
|
157 |
-
|
158 |
-
|
159 |
-
def test_named_bins(make_test_data: TestDataMaker) -> None:
|
160 |
-
items: list[Item] = [{
|
161 |
-
'age': 34,
|
162 |
-
}, {
|
163 |
-
'age': 45,
|
164 |
-
}, {
|
165 |
-
'age': 17,
|
166 |
-
}, {
|
167 |
-
'age': 80
|
168 |
-
}, {
|
169 |
-
'age': 55
|
170 |
-
}, {
|
171 |
-
'age': float('nan')
|
172 |
-
}]
|
173 |
-
dataset = make_test_data(items)
|
174 |
-
|
175 |
-
result = dataset.select_groups(
|
176 |
-
leaf_path='age',
|
177 |
-
bins=[
|
178 |
-
('young', None, 20),
|
179 |
-
('adult', 20, 50),
|
180 |
-
('middle-aged', 50, 65),
|
181 |
-
('senior', 65, None),
|
182 |
-
])
|
183 |
-
assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]
|
184 |
-
|
185 |
-
|
186 |
-
def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
|
187 |
-
items: list[Item] = [{
|
188 |
-
'age': 34,
|
189 |
-
}, {
|
190 |
-
'age': 45,
|
191 |
-
}, {
|
192 |
-
'age': 17,
|
193 |
-
}, {
|
194 |
-
'age': 80
|
195 |
-
}, {
|
196 |
-
'age': 55
|
197 |
-
}, {
|
198 |
-
'age': float('nan')
|
199 |
-
}]
|
200 |
-
data_schema = schema({
|
201 |
-
UUID_COLUMN: 'string',
|
202 |
-
'age': field(
|
203 |
-
'float32',
|
204 |
-
bins=[
|
205 |
-
('young', None, 20),
|
206 |
-
('adult', 20, 50),
|
207 |
-
('middle-aged', 50, 65),
|
208 |
-
('senior', 65, None),
|
209 |
-
])
|
210 |
-
})
|
211 |
-
dataset = make_test_data(items, data_schema)
|
212 |
-
|
213 |
-
result = dataset.select_groups(leaf_path='age')
|
214 |
-
assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]
|
215 |
-
|
216 |
-
|
217 |
-
def test_filters(make_test_data: TestDataMaker) -> None:
|
218 |
-
items: list[Item] = [
|
219 |
-
{
|
220 |
-
'name': 'Name1',
|
221 |
-
'age': 34,
|
222 |
-
'active': False
|
223 |
-
},
|
224 |
-
{
|
225 |
-
'name': 'Name2',
|
226 |
-
'age': 45,
|
227 |
-
'active': True
|
228 |
-
},
|
229 |
-
{
|
230 |
-
'age': 17,
|
231 |
-
'active': True
|
232 |
-
}, # Missing "name".
|
233 |
-
{
|
234 |
-
'name': 'Name3',
|
235 |
-
'active': True
|
236 |
-
}, # Missing "age".
|
237 |
-
{
|
238 |
-
'name': 'Name4',
|
239 |
-
'age': 55
|
240 |
-
} # Missing "active".
|
241 |
-
]
|
242 |
-
dataset = make_test_data(items)
|
243 |
-
|
244 |
-
# active = True.
|
245 |
-
result = dataset.select_groups(leaf_path='name', filters=[('active', BinaryOp.EQUALS, True)])
|
246 |
-
assert result.counts == [('Name2', 1), (None, 1), ('Name3', 1)]
|
247 |
-
|
248 |
-
# age < 35.
|
249 |
-
result = dataset.select_groups(leaf_path='name', filters=[('age', BinaryOp.LESS, 35)])
|
250 |
-
assert result.counts == [('Name1', 1), (None, 1)]
|
251 |
-
|
252 |
-
# age < 35 and active = True.
|
253 |
-
result = dataset.select_groups(
|
254 |
-
leaf_path='name', filters=[('age', BinaryOp.LESS, 35), ('active', BinaryOp.EQUALS, True)])
|
255 |
-
assert result.counts == [(None, 1)]
|
256 |
-
|
257 |
-
|
258 |
-
def test_invalid_leaf(make_test_data: TestDataMaker) -> None:
|
259 |
-
items: list[Item] = [
|
260 |
-
{
|
261 |
-
'nested_struct': {
|
262 |
-
'struct': {
|
263 |
-
'name': 'c'
|
264 |
-
}
|
265 |
-
}
|
266 |
-
},
|
267 |
-
{
|
268 |
-
'nested_struct': {
|
269 |
-
'struct': {
|
270 |
-
'name': 'b'
|
271 |
-
}
|
272 |
-
}
|
273 |
-
},
|
274 |
-
{
|
275 |
-
'nested_struct': {
|
276 |
-
'struct': {
|
277 |
-
'name': 'a'
|
278 |
-
}
|
279 |
-
}
|
280 |
-
},
|
281 |
-
]
|
282 |
-
dataset = make_test_data(items)
|
283 |
-
|
284 |
-
with pytest.raises(
|
285 |
-
ValueError, match=re.escape("Leaf \"('nested_struct',)\" not found in dataset")):
|
286 |
-
dataset.select_groups(leaf_path='nested_struct')
|
287 |
-
|
288 |
-
with pytest.raises(
|
289 |
-
ValueError, match=re.escape("Leaf \"('nested_struct', 'struct')\" not found in dataset")):
|
290 |
-
dataset.select_groups(leaf_path='nested_struct.struct')
|
291 |
-
|
292 |
-
with pytest.raises(
|
293 |
-
ValueError,
|
294 |
-
match=re.escape("Leaf \"('nested_struct', 'struct', 'wrong_name')\" not found in dataset")):
|
295 |
-
dataset.select_groups(leaf_path='nested_struct.struct.wrong_name')
|
296 |
-
|
297 |
-
|
298 |
-
def test_too_many_distinct(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
|
299 |
-
too_many_distinct = 5
|
300 |
-
mocker.patch(f'{dataset_module.__name__}.TOO_MANY_DISTINCT', too_many_distinct)
|
301 |
-
|
302 |
-
items: list[Item] = [{'feature': str(i)} for i in range(too_many_distinct + 10)]
|
303 |
-
dataset = make_test_data(items)
|
304 |
-
|
305 |
-
res = dataset.select_groups('feature')
|
306 |
-
assert res.too_many_distinct is True
|
307 |
-
assert res.counts == []
|
308 |
-
|
309 |
-
|
310 |
-
def test_auto_bins_for_float(make_test_data: TestDataMaker) -> None:
|
311 |
-
items: list[Item] = [{'feature': float(i)} for i in range(5)] + [{'feature': float('nan')}]
|
312 |
-
dataset = make_test_data(items)
|
313 |
-
|
314 |
-
res = dataset.select_groups('feature')
|
315 |
-
assert res.counts == [('0', 1), ('3', 1), ('7', 1), ('11', 1), ('14', 1), (None, 1)]
|
316 |
-
assert res.too_many_distinct is False
|
317 |
-
assert res.bins
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_select_rows_filter_test.py
DELETED
@@ -1,200 +0,0 @@
|
|
1 |
-
"""Tests for dataset.select_rows(filters=[...])."""
|
2 |
-
|
3 |
-
import pytest
|
4 |
-
|
5 |
-
from ..schema import UUID_COLUMN, Item, schema
|
6 |
-
from .dataset import BinaryFilterTuple, BinaryOp, ListFilterTuple, ListOp, UnaryOp
|
7 |
-
from .dataset_test_utils import TestDataMaker
|
8 |
-
|
9 |
-
TEST_DATA: list[Item] = [{
|
10 |
-
UUID_COLUMN: '1',
|
11 |
-
'str': 'a',
|
12 |
-
'int': 1,
|
13 |
-
'bool': False,
|
14 |
-
'float': 3.0
|
15 |
-
}, {
|
16 |
-
UUID_COLUMN: '2',
|
17 |
-
'str': 'b',
|
18 |
-
'int': 2,
|
19 |
-
'bool': True,
|
20 |
-
'float': 2.0
|
21 |
-
}, {
|
22 |
-
UUID_COLUMN: '3',
|
23 |
-
'str': 'b',
|
24 |
-
'int': 2,
|
25 |
-
'bool': True,
|
26 |
-
'float': 1.0
|
27 |
-
}, {
|
28 |
-
UUID_COLUMN: '4',
|
29 |
-
'float': float('nan')
|
30 |
-
}]
|
31 |
-
|
32 |
-
|
33 |
-
def test_filter_by_ids(make_test_data: TestDataMaker) -> None:
|
34 |
-
dataset = make_test_data(TEST_DATA)
|
35 |
-
|
36 |
-
id_filter: BinaryFilterTuple = (UUID_COLUMN, BinaryOp.EQUALS, '1')
|
37 |
-
result = dataset.select_rows(filters=[id_filter])
|
38 |
-
|
39 |
-
assert list(result) == [{UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}]
|
40 |
-
|
41 |
-
id_filter = (UUID_COLUMN, BinaryOp.EQUALS, '2')
|
42 |
-
result = dataset.select_rows(filters=[id_filter])
|
43 |
-
|
44 |
-
assert list(result) == [{UUID_COLUMN: '2', 'str': 'b', 'int': 2, 'bool': True, 'float': 2.0}]
|
45 |
-
|
46 |
-
id_filter = (UUID_COLUMN, BinaryOp.EQUALS, b'f')
|
47 |
-
result = dataset.select_rows(filters=[id_filter])
|
48 |
-
|
49 |
-
assert list(result) == []
|
50 |
-
|
51 |
-
|
52 |
-
def test_filter_greater(make_test_data: TestDataMaker) -> None:
|
53 |
-
dataset = make_test_data(TEST_DATA)
|
54 |
-
|
55 |
-
id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER, 2.0)
|
56 |
-
result = dataset.select_rows(filters=[id_filter])
|
57 |
-
|
58 |
-
assert list(result) == [{UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}]
|
59 |
-
|
60 |
-
|
61 |
-
def test_filter_greater_equal(make_test_data: TestDataMaker) -> None:
|
62 |
-
dataset = make_test_data(TEST_DATA)
|
63 |
-
|
64 |
-
id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER_EQUAL, 2.0)
|
65 |
-
result = dataset.select_rows(filters=[id_filter])
|
66 |
-
|
67 |
-
assert list(result) == [{
|
68 |
-
UUID_COLUMN: '1',
|
69 |
-
'str': 'a',
|
70 |
-
'int': 1,
|
71 |
-
'bool': False,
|
72 |
-
'float': 3.0
|
73 |
-
}, {
|
74 |
-
UUID_COLUMN: '2',
|
75 |
-
'str': 'b',
|
76 |
-
'int': 2,
|
77 |
-
'bool': True,
|
78 |
-
'float': 2.0
|
79 |
-
}]
|
80 |
-
|
81 |
-
|
82 |
-
def test_filter_less(make_test_data: TestDataMaker) -> None:
|
83 |
-
dataset = make_test_data(TEST_DATA)
|
84 |
-
|
85 |
-
id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS, 2.0)
|
86 |
-
result = dataset.select_rows(filters=[id_filter])
|
87 |
-
|
88 |
-
assert list(result) == [{UUID_COLUMN: '3', 'str': 'b', 'int': 2, 'bool': True, 'float': 1.0}]
|
89 |
-
|
90 |
-
|
91 |
-
def test_filter_less_equal(make_test_data: TestDataMaker) -> None:
|
92 |
-
dataset = make_test_data(TEST_DATA)
|
93 |
-
|
94 |
-
id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS_EQUAL, 2.0)
|
95 |
-
result = dataset.select_rows(filters=[id_filter])
|
96 |
-
|
97 |
-
assert list(result) == [{
|
98 |
-
UUID_COLUMN: '2',
|
99 |
-
'str': 'b',
|
100 |
-
'int': 2,
|
101 |
-
'bool': True,
|
102 |
-
'float': 2.0
|
103 |
-
}, {
|
104 |
-
UUID_COLUMN: '3',
|
105 |
-
'str': 'b',
|
106 |
-
'int': 2,
|
107 |
-
'bool': True,
|
108 |
-
'float': 1.0
|
109 |
-
}]
|
110 |
-
|
111 |
-
|
112 |
-
def test_filter_not_equal(make_test_data: TestDataMaker) -> None:
|
113 |
-
dataset = make_test_data(TEST_DATA)
|
114 |
-
|
115 |
-
id_filter: BinaryFilterTuple = ('float', BinaryOp.NOT_EQUAL, 2.0)
|
116 |
-
result = dataset.select_rows(filters=[id_filter])
|
117 |
-
|
118 |
-
assert list(result) == [
|
119 |
-
{
|
120 |
-
UUID_COLUMN: '1',
|
121 |
-
'str': 'a',
|
122 |
-
'int': 1,
|
123 |
-
'bool': False,
|
124 |
-
'float': 3.0
|
125 |
-
},
|
126 |
-
{
|
127 |
-
UUID_COLUMN: '3',
|
128 |
-
'str': 'b',
|
129 |
-
'int': 2,
|
130 |
-
'bool': True,
|
131 |
-
'float': 1.0
|
132 |
-
},
|
133 |
-
# NaNs are not counted when we are filtering a field.
|
134 |
-
]
|
135 |
-
|
136 |
-
|
137 |
-
def test_filter_by_list_of_ids(make_test_data: TestDataMaker) -> None:
|
138 |
-
dataset = make_test_data(TEST_DATA)
|
139 |
-
|
140 |
-
id_filter: ListFilterTuple = (UUID_COLUMN, ListOp.IN, ['1', '2'])
|
141 |
-
result = dataset.select_rows(filters=[id_filter])
|
142 |
-
|
143 |
-
assert list(result) == [{
|
144 |
-
UUID_COLUMN: '1',
|
145 |
-
'str': 'a',
|
146 |
-
'int': 1,
|
147 |
-
'bool': False,
|
148 |
-
'float': 3.0
|
149 |
-
}, {
|
150 |
-
UUID_COLUMN: '2',
|
151 |
-
'str': 'b',
|
152 |
-
'int': 2,
|
153 |
-
'bool': True,
|
154 |
-
'float': 2.0
|
155 |
-
}]
|
156 |
-
|
157 |
-
|
158 |
-
def test_filter_by_exists(make_test_data: TestDataMaker) -> None:
|
159 |
-
items: list[Item] = [{
|
160 |
-
UUID_COLUMN: '1',
|
161 |
-
'name': 'A',
|
162 |
-
'info': {
|
163 |
-
'lang': 'en'
|
164 |
-
},
|
165 |
-
'ages': []
|
166 |
-
}, {
|
167 |
-
UUID_COLUMN: '2',
|
168 |
-
'info': {
|
169 |
-
'lang': 'fr'
|
170 |
-
},
|
171 |
-
}, {
|
172 |
-
UUID_COLUMN: '3',
|
173 |
-
'name': 'C',
|
174 |
-
'ages': [[1, 2], [3, 4]]
|
175 |
-
}]
|
176 |
-
dataset = make_test_data(
|
177 |
-
items,
|
178 |
-
schema=schema({
|
179 |
-
UUID_COLUMN: 'string',
|
180 |
-
'name': 'string',
|
181 |
-
'info': {
|
182 |
-
'lang': 'string'
|
183 |
-
},
|
184 |
-
'ages': [['int32']]
|
185 |
-
}))
|
186 |
-
|
187 |
-
exists_filter = ('name', UnaryOp.EXISTS)
|
188 |
-
result = dataset.select_rows(['name'], filters=[exists_filter])
|
189 |
-
assert list(result) == [{UUID_COLUMN: '1', 'name': 'A'}, {UUID_COLUMN: '3', 'name': 'C'}]
|
190 |
-
|
191 |
-
exists_filter = ('info.lang', UnaryOp.EXISTS)
|
192 |
-
result = dataset.select_rows(['name'], filters=[exists_filter])
|
193 |
-
assert list(result) == [{UUID_COLUMN: '1', 'name': 'A'}, {UUID_COLUMN: '2', 'name': None}]
|
194 |
-
|
195 |
-
exists_filter = ('ages.*.*', UnaryOp.EXISTS)
|
196 |
-
result = dataset.select_rows(['name'], filters=[exists_filter])
|
197 |
-
assert list(result) == [{UUID_COLUMN: '3', 'name': 'C'}]
|
198 |
-
|
199 |
-
with pytest.raises(ValueError, match='Unable to filter on path'):
|
200 |
-
dataset.select_rows(['name'], filters=[('info', UnaryOp.EXISTS)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_select_rows_schema_test.py
DELETED
@@ -1,551 +0,0 @@
|
|
1 |
-
"""Tests for `db.select_rows_schema()`."""
|
2 |
-
|
3 |
-
from typing import Iterable, Optional, cast
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import pytest
|
7 |
-
from typing_extensions import override
|
8 |
-
|
9 |
-
from ..embeddings.vector_store import VectorStore
|
10 |
-
from ..schema import PATH_WILDCARD, UUID_COLUMN, Field, Item, RichData, VectorKey, field, schema
|
11 |
-
from ..signals.concept_labels import ConceptLabelsSignal
|
12 |
-
from ..signals.concept_scorer import ConceptScoreSignal
|
13 |
-
from ..signals.semantic_similarity import SemanticSimilaritySignal
|
14 |
-
from ..signals.signal import (
|
15 |
-
EMBEDDING_KEY,
|
16 |
-
TextEmbeddingModelSignal,
|
17 |
-
TextEmbeddingSignal,
|
18 |
-
TextSignal,
|
19 |
-
TextSplitterSignal,
|
20 |
-
clear_signal_registry,
|
21 |
-
register_signal,
|
22 |
-
)
|
23 |
-
from ..signals.substring_search import SubstringSignal
|
24 |
-
from .dataset import (
|
25 |
-
Column,
|
26 |
-
ConceptQuery,
|
27 |
-
KeywordQuery,
|
28 |
-
Search,
|
29 |
-
SearchResultInfo,
|
30 |
-
SelectRowsSchemaResult,
|
31 |
-
SelectRowsSchemaUDF,
|
32 |
-
SemanticQuery,
|
33 |
-
SortOrder,
|
34 |
-
SortResult,
|
35 |
-
)
|
36 |
-
from .dataset_test_utils import (
|
37 |
-
TEST_DATASET_NAME,
|
38 |
-
TEST_NAMESPACE,
|
39 |
-
TestDataMaker,
|
40 |
-
enriched_embedding_span_field,
|
41 |
-
)
|
42 |
-
from .dataset_utils import lilac_embedding, lilac_span
|
43 |
-
|
44 |
-
TEST_DATA: list[Item] = [{
|
45 |
-
UUID_COLUMN: '1',
|
46 |
-
'erased': False,
|
47 |
-
'people': [{
|
48 |
-
'name': 'A',
|
49 |
-
'zipcode': 0,
|
50 |
-
'locations': [{
|
51 |
-
'city': 'city1',
|
52 |
-
'state': 'state1'
|
53 |
-
}, {
|
54 |
-
'city': 'city2',
|
55 |
-
'state': 'state2'
|
56 |
-
}]
|
57 |
-
}]
|
58 |
-
}, {
|
59 |
-
UUID_COLUMN: '2',
|
60 |
-
'erased': True,
|
61 |
-
'people': [{
|
62 |
-
'name': 'B',
|
63 |
-
'zipcode': 1,
|
64 |
-
'locations': [{
|
65 |
-
'city': 'city3',
|
66 |
-
'state': 'state3'
|
67 |
-
}, {
|
68 |
-
'city': 'city4'
|
69 |
-
}, {
|
70 |
-
'city': 'city5'
|
71 |
-
}]
|
72 |
-
}, {
|
73 |
-
'name': 'C',
|
74 |
-
'zipcode': 2,
|
75 |
-
'locations': [{
|
76 |
-
'city': 'city1',
|
77 |
-
'state': 'state1'
|
78 |
-
}]
|
79 |
-
}]
|
80 |
-
}]
|
81 |
-
|
82 |
-
|
83 |
-
class TestSplitter(TextSplitterSignal):
|
84 |
-
"""Split documents into sentence by splitting on period."""
|
85 |
-
name = 'test_splitter'
|
86 |
-
|
87 |
-
@override
|
88 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
89 |
-
for text in data:
|
90 |
-
if not isinstance(text, str):
|
91 |
-
raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
|
92 |
-
sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence]
|
93 |
-
yield [
|
94 |
-
lilac_span(text.index(sentence),
|
95 |
-
text.index(sentence) + len(sentence)) for sentence in sentences
|
96 |
-
]
|
97 |
-
|
98 |
-
|
99 |
-
EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
|
100 |
-
('hello2.', [1.0, 1.0, 0.0]),
|
101 |
-
('hello world.', [1.0, 1.0, 1.0]),
|
102 |
-
('hello world2.', [2.0, 1.0, 1.0])]
|
103 |
-
|
104 |
-
STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
|
105 |
-
|
106 |
-
|
107 |
-
class TestEmbedding(TextEmbeddingSignal):
|
108 |
-
"""A test embed function."""
|
109 |
-
name = 'test_embedding'
|
110 |
-
|
111 |
-
@override
|
112 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
113 |
-
"""Call the embedding function."""
|
114 |
-
for example in data:
|
115 |
-
yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
|
116 |
-
|
117 |
-
|
118 |
-
class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
|
119 |
-
"""Sums the embeddings to return a single floating point value."""
|
120 |
-
name = 'test_embedding_sum'
|
121 |
-
|
122 |
-
@override
|
123 |
-
def fields(self) -> Field:
|
124 |
-
return field('float32')
|
125 |
-
|
126 |
-
@override
|
127 |
-
def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
|
128 |
-
# The signal just sums the values of the embedding.
|
129 |
-
embedding_sums = vector_store.get(keys).sum(axis=1)
|
130 |
-
for embedding_sum in embedding_sums.tolist():
|
131 |
-
yield embedding_sum
|
132 |
-
|
133 |
-
|
134 |
-
@pytest.fixture(scope='module', autouse=True)
|
135 |
-
def setup_teardown() -> Iterable[None]:
|
136 |
-
# Setup.
|
137 |
-
register_signal(LengthSignal)
|
138 |
-
register_signal(AddSpaceSignal)
|
139 |
-
register_signal(TestSplitter)
|
140 |
-
register_signal(TestEmbedding)
|
141 |
-
register_signal(TestEmbeddingSumSignal)
|
142 |
-
|
143 |
-
# Unit test runs.
|
144 |
-
yield
|
145 |
-
|
146 |
-
# Teardown.
|
147 |
-
clear_signal_registry()
|
148 |
-
|
149 |
-
|
150 |
-
class LengthSignal(TextSignal):
|
151 |
-
name = 'length_signal'
|
152 |
-
|
153 |
-
def fields(self) -> Field:
|
154 |
-
return field('int32')
|
155 |
-
|
156 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
157 |
-
for text_content in data:
|
158 |
-
yield len(text_content)
|
159 |
-
|
160 |
-
|
161 |
-
class AddSpaceSignal(TextSignal):
|
162 |
-
name = 'add_space_signal'
|
163 |
-
|
164 |
-
def fields(self) -> Field:
|
165 |
-
return field('string')
|
166 |
-
|
167 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
168 |
-
for text_content in data:
|
169 |
-
yield cast(str, text_content) + ' '
|
170 |
-
|
171 |
-
|
172 |
-
def test_simple_schema(make_test_data: TestDataMaker) -> None:
|
173 |
-
dataset = make_test_data(TEST_DATA)
|
174 |
-
result = dataset.select_rows_schema(combine_columns=True)
|
175 |
-
assert result == SelectRowsSchemaResult(
|
176 |
-
data_schema=schema({
|
177 |
-
UUID_COLUMN: 'string',
|
178 |
-
'erased': 'boolean',
|
179 |
-
'people': [{
|
180 |
-
'name': 'string',
|
181 |
-
'zipcode': 'int32',
|
182 |
-
'locations': [{
|
183 |
-
'city': 'string',
|
184 |
-
'state': 'string'
|
185 |
-
}]
|
186 |
-
}]
|
187 |
-
}))
|
188 |
-
|
189 |
-
|
190 |
-
def test_subselection_with_combine_cols(make_test_data: TestDataMaker) -> None:
|
191 |
-
dataset = make_test_data(TEST_DATA)
|
192 |
-
|
193 |
-
result = dataset.select_rows_schema([('people', '*', 'zipcode'),
|
194 |
-
('people', '*', 'locations', '*', 'city')],
|
195 |
-
combine_columns=True)
|
196 |
-
assert result == SelectRowsSchemaResult(
|
197 |
-
data_schema=schema({
|
198 |
-
UUID_COLUMN: 'string',
|
199 |
-
'people': [{
|
200 |
-
'zipcode': 'int32',
|
201 |
-
'locations': [{
|
202 |
-
'city': 'string'
|
203 |
-
}]
|
204 |
-
}]
|
205 |
-
}))
|
206 |
-
|
207 |
-
result = dataset.select_rows_schema([('people', '*', 'name'), ('people', '*', 'locations')],
|
208 |
-
combine_columns=True)
|
209 |
-
assert result == SelectRowsSchemaResult(
|
210 |
-
data_schema=schema({
|
211 |
-
UUID_COLUMN: 'string',
|
212 |
-
'people': [{
|
213 |
-
'name': 'string',
|
214 |
-
'locations': [{
|
215 |
-
'city': 'string',
|
216 |
-
'state': 'string'
|
217 |
-
}]
|
218 |
-
}]
|
219 |
-
}))
|
220 |
-
|
221 |
-
result = dataset.select_rows_schema([('people', '*')], combine_columns=True)
|
222 |
-
assert result == SelectRowsSchemaResult(
|
223 |
-
namespace=TEST_NAMESPACE,
|
224 |
-
dataset_name=TEST_DATASET_NAME,
|
225 |
-
data_schema=schema({
|
226 |
-
UUID_COLUMN: 'string',
|
227 |
-
'people': [{
|
228 |
-
'name': 'string',
|
229 |
-
'zipcode': 'int32',
|
230 |
-
'locations': [{
|
231 |
-
'city': 'string',
|
232 |
-
'state': 'string'
|
233 |
-
}]
|
234 |
-
}]
|
235 |
-
}))
|
236 |
-
|
237 |
-
|
238 |
-
def test_udf_with_combine_cols(make_test_data: TestDataMaker) -> None:
|
239 |
-
dataset = make_test_data(TEST_DATA)
|
240 |
-
|
241 |
-
length_signal = LengthSignal()
|
242 |
-
result = dataset.select_rows_schema([('people', '*', 'locations', '*', 'city'),
|
243 |
-
Column(('people', '*', 'name'), signal_udf=length_signal)],
|
244 |
-
combine_columns=True)
|
245 |
-
assert result == SelectRowsSchemaResult(
|
246 |
-
data_schema=schema({
|
247 |
-
UUID_COLUMN: 'string',
|
248 |
-
'people': [{
|
249 |
-
'name': {
|
250 |
-
'length_signal': field('int32', length_signal.dict())
|
251 |
-
},
|
252 |
-
'locations': [{
|
253 |
-
'city': 'string'
|
254 |
-
}]
|
255 |
-
}],
|
256 |
-
}),
|
257 |
-
udfs=[
|
258 |
-
SelectRowsSchemaUDF(path=('people', '*', 'name', length_signal.key())),
|
259 |
-
],
|
260 |
-
)
|
261 |
-
|
262 |
-
|
263 |
-
def test_embedding_udf_with_combine_cols(make_test_data: TestDataMaker) -> None:
|
264 |
-
dataset = make_test_data(TEST_DATA)
|
265 |
-
|
266 |
-
add_space_signal = AddSpaceSignal()
|
267 |
-
path = ('people', '*', 'name')
|
268 |
-
dataset.compute_signal(add_space_signal, path)
|
269 |
-
result = dataset.select_rows_schema([path, Column(path, signal_udf=add_space_signal)],
|
270 |
-
combine_columns=True)
|
271 |
-
assert result == SelectRowsSchemaResult(
|
272 |
-
data_schema=schema({
|
273 |
-
UUID_COLUMN: 'string',
|
274 |
-
'people': [{
|
275 |
-
'name': field(
|
276 |
-
'string', fields={'add_space_signal': field('string', signal=add_space_signal.dict())})
|
277 |
-
}],
|
278 |
-
}),
|
279 |
-
udfs=[
|
280 |
-
SelectRowsSchemaUDF(path=(*path, add_space_signal.key())),
|
281 |
-
],
|
282 |
-
)
|
283 |
-
|
284 |
-
|
285 |
-
def test_udf_chained_with_combine_cols(make_test_data: TestDataMaker) -> None:
|
286 |
-
dataset = make_test_data([{
|
287 |
-
UUID_COLUMN: '1',
|
288 |
-
'text': 'hello. hello2.',
|
289 |
-
}, {
|
290 |
-
UUID_COLUMN: '2',
|
291 |
-
'text': 'hello world. hello world2.',
|
292 |
-
}])
|
293 |
-
|
294 |
-
test_splitter = TestSplitter()
|
295 |
-
dataset.compute_signal(test_splitter, ('text'))
|
296 |
-
add_space_signal = AddSpaceSignal()
|
297 |
-
result = dataset.select_rows_schema(
|
298 |
-
[('text'), Column(('text'), signal_udf=add_space_signal)], combine_columns=True)
|
299 |
-
|
300 |
-
assert result == SelectRowsSchemaResult(
|
301 |
-
data_schema=schema({
|
302 |
-
UUID_COLUMN: 'string',
|
303 |
-
'text': field(
|
304 |
-
'string',
|
305 |
-
fields={
|
306 |
-
'add_space_signal': field('string', add_space_signal.dict()),
|
307 |
-
'test_splitter': field(signal=test_splitter.dict(), fields=['string_span'])
|
308 |
-
})
|
309 |
-
}),
|
310 |
-
udfs=[
|
311 |
-
SelectRowsSchemaUDF(path=('text', add_space_signal.key())),
|
312 |
-
],
|
313 |
-
)
|
314 |
-
|
315 |
-
|
316 |
-
def test_udf_embedding_chained_with_combine_cols(make_test_data: TestDataMaker) -> None:
|
317 |
-
dataset = make_test_data([{
|
318 |
-
UUID_COLUMN: '1',
|
319 |
-
'text': 'hello. hello2.',
|
320 |
-
}, {
|
321 |
-
UUID_COLUMN: '2',
|
322 |
-
'text': 'hello world. hello world2.',
|
323 |
-
}])
|
324 |
-
|
325 |
-
test_splitter = TestSplitter()
|
326 |
-
dataset.compute_signal(test_splitter, 'text')
|
327 |
-
test_embedding = TestEmbedding()
|
328 |
-
dataset.compute_signal(test_embedding, ('text', 'test_splitter', '*'))
|
329 |
-
|
330 |
-
embedding_sum_signal = TestEmbeddingSumSignal(embedding='test_embedding')
|
331 |
-
udf_col = Column(('text', 'test_splitter', '*'), signal_udf=embedding_sum_signal)
|
332 |
-
result = dataset.select_rows_schema([('text'), udf_col], combine_columns=True)
|
333 |
-
|
334 |
-
expected_schema = schema({
|
335 |
-
UUID_COLUMN: 'string',
|
336 |
-
'text': field(
|
337 |
-
'string',
|
338 |
-
fields={
|
339 |
-
'test_splitter': field(
|
340 |
-
signal=test_splitter.dict(),
|
341 |
-
fields=[
|
342 |
-
field(
|
343 |
-
'string_span',
|
344 |
-
fields={
|
345 |
-
'test_embedding': field(
|
346 |
-
signal=test_embedding.dict(),
|
347 |
-
fields=[
|
348 |
-
enriched_embedding_span_field(
|
349 |
-
{'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
|
350 |
-
])
|
351 |
-
})
|
352 |
-
])
|
353 |
-
})
|
354 |
-
})
|
355 |
-
output_path = ('text', 'test_splitter', '*', 'test_embedding', '*', 'embedding',
|
356 |
-
'test_embedding_sum')
|
357 |
-
assert result == SelectRowsSchemaResult(
|
358 |
-
data_schema=expected_schema,
|
359 |
-
udfs=[SelectRowsSchemaUDF(path=output_path)],
|
360 |
-
)
|
361 |
-
|
362 |
-
# Alias the udf.
|
363 |
-
udf_col.alias = 'udf1'
|
364 |
-
result = dataset.select_rows_schema([('text'), udf_col], combine_columns=True)
|
365 |
-
assert result == SelectRowsSchemaResult(
|
366 |
-
data_schema=expected_schema,
|
367 |
-
udfs=[SelectRowsSchemaUDF(path=output_path, alias='udf1')],
|
368 |
-
)
|
369 |
-
|
370 |
-
|
371 |
-
def test_search_keyword_schema(make_test_data: TestDataMaker) -> None:
|
372 |
-
dataset = make_test_data([{
|
373 |
-
UUID_COLUMN: '1',
|
374 |
-
'text': 'hello world',
|
375 |
-
'text2': 'hello world2',
|
376 |
-
}])
|
377 |
-
query_world = 'world'
|
378 |
-
query_hello = 'hello'
|
379 |
-
|
380 |
-
result = dataset.select_rows_schema(
|
381 |
-
searches=[
|
382 |
-
Search(path='text', query=KeywordQuery(type='keyword', search=query_world)),
|
383 |
-
Search(path='text2', query=KeywordQuery(type='keyword', search=query_hello)),
|
384 |
-
],
|
385 |
-
combine_columns=True)
|
386 |
-
|
387 |
-
expected_world_signal = SubstringSignal(query=query_world)
|
388 |
-
expected_hello_signal = SubstringSignal(query=query_hello)
|
389 |
-
|
390 |
-
assert result == SelectRowsSchemaResult(
|
391 |
-
data_schema=schema({
|
392 |
-
UUID_COLUMN: 'string',
|
393 |
-
'text': field(
|
394 |
-
'string',
|
395 |
-
fields={
|
396 |
-
expected_world_signal.key(): field(
|
397 |
-
signal=expected_world_signal.dict(), fields=['string_span'])
|
398 |
-
}),
|
399 |
-
'text2': field(
|
400 |
-
'string',
|
401 |
-
fields={
|
402 |
-
expected_hello_signal.key(): field(
|
403 |
-
signal=expected_hello_signal.dict(), fields=['string_span'])
|
404 |
-
})
|
405 |
-
}),
|
406 |
-
search_results=[
|
407 |
-
SearchResultInfo(
|
408 |
-
search_path=('text',),
|
409 |
-
result_path=('text', expected_world_signal.key(), PATH_WILDCARD),
|
410 |
-
),
|
411 |
-
SearchResultInfo(
|
412 |
-
search_path=('text2',),
|
413 |
-
result_path=('text2', expected_hello_signal.key(), PATH_WILDCARD),
|
414 |
-
)
|
415 |
-
],
|
416 |
-
udfs=[
|
417 |
-
SelectRowsSchemaUDF(path=('text', expected_world_signal.key())),
|
418 |
-
SelectRowsSchemaUDF(path=('text2', expected_hello_signal.key())),
|
419 |
-
],
|
420 |
-
)
|
421 |
-
|
422 |
-
|
423 |
-
def test_search_semantic_schema(make_test_data: TestDataMaker) -> None:
|
424 |
-
dataset = make_test_data([{
|
425 |
-
UUID_COLUMN: '1',
|
426 |
-
'text': 'hello world.',
|
427 |
-
}])
|
428 |
-
query_world = 'world'
|
429 |
-
|
430 |
-
test_embedding = TestEmbedding()
|
431 |
-
dataset.compute_signal(test_embedding, ('text'))
|
432 |
-
|
433 |
-
result = dataset.select_rows_schema(
|
434 |
-
searches=[
|
435 |
-
Search(
|
436 |
-
path='text',
|
437 |
-
query=SemanticQuery(type='semantic', search=query_world, embedding='test_embedding')),
|
438 |
-
],
|
439 |
-
combine_columns=True)
|
440 |
-
|
441 |
-
test_embedding = TestEmbedding()
|
442 |
-
expected_world_signal = SemanticSimilaritySignal(query=query_world, embedding='test_embedding')
|
443 |
-
|
444 |
-
similarity_score_path = ('text', 'test_embedding', PATH_WILDCARD, EMBEDDING_KEY,
|
445 |
-
expected_world_signal.key())
|
446 |
-
assert result == SelectRowsSchemaResult(
|
447 |
-
data_schema=schema({
|
448 |
-
UUID_COLUMN: 'string',
|
449 |
-
'text': field(
|
450 |
-
'string',
|
451 |
-
fields={
|
452 |
-
'test_embedding': field(
|
453 |
-
signal=test_embedding.dict(),
|
454 |
-
fields=[
|
455 |
-
enriched_embedding_span_field(
|
456 |
-
{expected_world_signal.key(): field('float32', expected_world_signal.dict())})
|
457 |
-
])
|
458 |
-
})
|
459 |
-
}),
|
460 |
-
udfs=[SelectRowsSchemaUDF(path=similarity_score_path)],
|
461 |
-
search_results=[SearchResultInfo(search_path=('text',), result_path=similarity_score_path)],
|
462 |
-
sorts=[SortResult(path=similarity_score_path, order=SortOrder.DESC, search_index=0)])
|
463 |
-
|
464 |
-
|
465 |
-
def test_search_concept_schema(make_test_data: TestDataMaker) -> None:
|
466 |
-
dataset = make_test_data([{
|
467 |
-
UUID_COLUMN: '1',
|
468 |
-
'text': 'hello world.',
|
469 |
-
}])
|
470 |
-
|
471 |
-
test_embedding = TestEmbedding()
|
472 |
-
dataset.compute_signal(test_embedding, ('text'))
|
473 |
-
|
474 |
-
result = dataset.select_rows_schema(
|
475 |
-
searches=[
|
476 |
-
Search(
|
477 |
-
path='text',
|
478 |
-
query=ConceptQuery(
|
479 |
-
type='concept',
|
480 |
-
concept_namespace='test_namespace',
|
481 |
-
concept_name='test_concept',
|
482 |
-
embedding='test_embedding')),
|
483 |
-
],
|
484 |
-
combine_columns=True)
|
485 |
-
|
486 |
-
test_embedding = TestEmbedding()
|
487 |
-
expected_world_signal = ConceptScoreSignal(
|
488 |
-
namespace='test_namespace', concept_name='test_concept', embedding='test_embedding')
|
489 |
-
expected_labels_signal = ConceptLabelsSignal(
|
490 |
-
namespace='test_namespace', concept_name='test_concept')
|
491 |
-
|
492 |
-
concept_score_path = ('text', 'test_embedding', PATH_WILDCARD, EMBEDDING_KEY,
|
493 |
-
expected_world_signal.key())
|
494 |
-
concept_labels_path = ('text', expected_labels_signal.key())
|
495 |
-
assert result == SelectRowsSchemaResult(
|
496 |
-
data_schema=schema({
|
497 |
-
UUID_COLUMN: 'string',
|
498 |
-
'text': field(
|
499 |
-
'string',
|
500 |
-
fields={
|
501 |
-
'test_embedding': field(
|
502 |
-
signal=test_embedding.dict(),
|
503 |
-
fields=[
|
504 |
-
enriched_embedding_span_field({
|
505 |
-
expected_world_signal.key(): field(
|
506 |
-
'float32',
|
507 |
-
expected_world_signal.dict(),
|
508 |
-
bins=[('Not in concept', None, 0.5), ('In concept', 0.5, None)])
|
509 |
-
})
|
510 |
-
]),
|
511 |
-
'test_namespace/test_concept/labels': field(
|
512 |
-
fields=[field('string_span', fields={
|
513 |
-
'label': 'boolean',
|
514 |
-
'draft': 'string'
|
515 |
-
})],
|
516 |
-
signal=expected_labels_signal.dict())
|
517 |
-
})
|
518 |
-
}),
|
519 |
-
udfs=[
|
520 |
-
SelectRowsSchemaUDF(path=concept_labels_path),
|
521 |
-
SelectRowsSchemaUDF(path=concept_score_path)
|
522 |
-
],
|
523 |
-
search_results=[
|
524 |
-
SearchResultInfo(search_path=('text',), result_path=concept_labels_path),
|
525 |
-
SearchResultInfo(search_path=('text',), result_path=concept_score_path)
|
526 |
-
],
|
527 |
-
sorts=[SortResult(path=concept_score_path, order=SortOrder.DESC, search_index=0)])
|
528 |
-
|
529 |
-
|
530 |
-
def test_search_sort_override(make_test_data: TestDataMaker) -> None:
|
531 |
-
dataset = make_test_data([{
|
532 |
-
UUID_COLUMN: '1',
|
533 |
-
'text': 'hello world.',
|
534 |
-
}])
|
535 |
-
query_world = 'world'
|
536 |
-
|
537 |
-
test_embedding = TestEmbedding()
|
538 |
-
dataset.compute_signal(test_embedding, ('text'))
|
539 |
-
|
540 |
-
result = dataset.select_rows_schema(
|
541 |
-
searches=[
|
542 |
-
Search(
|
543 |
-
path='text',
|
544 |
-
query=SemanticQuery(type='semantic', search=query_world, embedding='test_embedding')),
|
545 |
-
],
|
546 |
-
# Explicit sort by overrides the semantic search.
|
547 |
-
sort_by=[('text',)],
|
548 |
-
sort_order=SortOrder.DESC,
|
549 |
-
combine_columns=True)
|
550 |
-
|
551 |
-
assert result.sorts == [SortResult(path=('text',), order=SortOrder.DESC)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_select_rows_search_test.py
DELETED
@@ -1,393 +0,0 @@
|
|
1 |
-
"""Tests for dataset.select_rows(searches=[...])."""
|
2 |
-
|
3 |
-
from typing import Iterable, cast
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import pytest
|
7 |
-
from pytest import approx
|
8 |
-
from pytest_mock import MockerFixture
|
9 |
-
from sklearn.preprocessing import normalize
|
10 |
-
from typing_extensions import override
|
11 |
-
|
12 |
-
from ..concepts.concept import ExampleIn, LogisticEmbeddingModel
|
13 |
-
from ..concepts.db_concept import ConceptUpdate, DiskConceptDB
|
14 |
-
from ..db_manager import set_default_dataset_cls
|
15 |
-
from ..schema import UUID_COLUMN, Item, RichData, SignalInputType
|
16 |
-
from ..signals.concept_scorer import ConceptScoreSignal
|
17 |
-
from ..signals.semantic_similarity import SemanticSimilaritySignal
|
18 |
-
from ..signals.signal import TextEmbeddingSignal, clear_signal_registry, register_signal
|
19 |
-
from ..signals.substring_search import SubstringSignal
|
20 |
-
from .dataset import ConceptQuery, KeywordQuery, ListOp, Search, SemanticQuery, SortOrder
|
21 |
-
from .dataset_duckdb import DatasetDuckDB
|
22 |
-
from .dataset_test_utils import TestDataMaker, enriched_embedding_span, enriched_item
|
23 |
-
from .dataset_utils import lilac_embedding, lilac_span
|
24 |
-
|
25 |
-
TEST_DATA: list[Item] = [{
|
26 |
-
UUID_COLUMN: '1',
|
27 |
-
'text': 'hello world',
|
28 |
-
'text2': 'again hello world',
|
29 |
-
}, {
|
30 |
-
UUID_COLUMN: '2',
|
31 |
-
'text': 'looking for world in text',
|
32 |
-
'text2': 'again looking for world in text',
|
33 |
-
}, {
|
34 |
-
UUID_COLUMN: '3',
|
35 |
-
'text': 'unrelated text',
|
36 |
-
'text2': 'again unrelated text'
|
37 |
-
}]
|
38 |
-
|
39 |
-
EMBEDDINGS: list[tuple[str, list[float]]] = [
|
40 |
-
('hello.', [1.0, 0.0, 0.0]),
|
41 |
-
('hello2.', [1.0, 1.0, 0.0]),
|
42 |
-
('hello world.', [1.0, 1.0, 1.0]),
|
43 |
-
('hello world2.', [2.0, 1.0, 1.0]),
|
44 |
-
('random negative 1', [0, 0, 0.3]),
|
45 |
-
('random negative 2', [0, 0, 0.4]),
|
46 |
-
('random negative 3', [0, 0.1, 0.5]),
|
47 |
-
('random negative 4', [0.1, 0, 0.4]),
|
48 |
-
]
|
49 |
-
|
50 |
-
STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
|
51 |
-
|
52 |
-
|
53 |
-
@pytest.fixture(scope='module', autouse=True)
|
54 |
-
def setup_teardown() -> Iterable[None]:
|
55 |
-
# Setup.
|
56 |
-
set_default_dataset_cls(DatasetDuckDB)
|
57 |
-
register_signal(TestEmbedding)
|
58 |
-
|
59 |
-
# Unit test runs.
|
60 |
-
yield
|
61 |
-
|
62 |
-
# Teardown.
|
63 |
-
clear_signal_registry()
|
64 |
-
|
65 |
-
|
66 |
-
def test_search_keyword(make_test_data: TestDataMaker) -> None:
|
67 |
-
dataset = make_test_data(TEST_DATA)
|
68 |
-
|
69 |
-
query = 'world'
|
70 |
-
result = dataset.select_rows(
|
71 |
-
searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
|
72 |
-
combine_columns=True)
|
73 |
-
|
74 |
-
expected_signal_udf = SubstringSignal(query=query)
|
75 |
-
assert list(result) == [{
|
76 |
-
UUID_COLUMN: '1',
|
77 |
-
'text': enriched_item('hello world', {expected_signal_udf.key(): [lilac_span(6, 11)]}),
|
78 |
-
'text2': 'again hello world'
|
79 |
-
}, {
|
80 |
-
UUID_COLUMN: '2',
|
81 |
-
'text': enriched_item('looking for world in text',
|
82 |
-
{expected_signal_udf.key(): [lilac_span(12, 17)]}),
|
83 |
-
'text2': 'again looking for world in text',
|
84 |
-
}]
|
85 |
-
|
86 |
-
|
87 |
-
def test_search_keyword_special_chars(make_test_data: TestDataMaker) -> None:
|
88 |
-
dataset = make_test_data([{
|
89 |
-
UUID_COLUMN: '1',
|
90 |
-
'text': 'This is 100%',
|
91 |
-
}, {
|
92 |
-
UUID_COLUMN: '2',
|
93 |
-
'text': 'This has _underscore_',
|
94 |
-
}])
|
95 |
-
|
96 |
-
query = '100%'
|
97 |
-
result = dataset.select_rows(
|
98 |
-
searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
|
99 |
-
combine_columns=True)
|
100 |
-
|
101 |
-
expected_signal_udf = SubstringSignal(query=query)
|
102 |
-
assert list(result) == [{
|
103 |
-
UUID_COLUMN: '1',
|
104 |
-
'text': enriched_item('This is 100%', {expected_signal_udf.key(): [lilac_span(8, 12)]}),
|
105 |
-
}]
|
106 |
-
|
107 |
-
query = '_underscore_'
|
108 |
-
result = dataset.select_rows(
|
109 |
-
searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
|
110 |
-
combine_columns=True)
|
111 |
-
|
112 |
-
expected_signal_udf = SubstringSignal(query=query)
|
113 |
-
assert list(result) == [{
|
114 |
-
UUID_COLUMN: '2',
|
115 |
-
'text': enriched_item('This has _underscore_',
|
116 |
-
{expected_signal_udf.key(): [lilac_span(9, 21)]}),
|
117 |
-
}]
|
118 |
-
|
119 |
-
|
120 |
-
def test_search_keyword_multiple(make_test_data: TestDataMaker) -> None:
|
121 |
-
dataset = make_test_data(TEST_DATA)
|
122 |
-
|
123 |
-
query_world = 'world'
|
124 |
-
query_looking_world = 'looking for world'
|
125 |
-
expected_world_udf = SubstringSignal(query=query_world)
|
126 |
-
expected_again_looking_udf = SubstringSignal(query=query_looking_world)
|
127 |
-
|
128 |
-
result = dataset.select_rows(
|
129 |
-
searches=[
|
130 |
-
Search(path='text', query=KeywordQuery(type='keyword', search=query_world)),
|
131 |
-
Search(path='text2', query=KeywordQuery(type='keyword', search=query_looking_world)),
|
132 |
-
],
|
133 |
-
combine_columns=True)
|
134 |
-
|
135 |
-
assert list(result) == [{
|
136 |
-
UUID_COLUMN: '2',
|
137 |
-
'text': enriched_item('looking for world in text', {
|
138 |
-
expected_world_udf.key(): [lilac_span(12, 17)],
|
139 |
-
}),
|
140 |
-
'text2': enriched_item('again looking for world in text',
|
141 |
-
{expected_again_looking_udf.key(): [lilac_span(6, 23)]})
|
142 |
-
}]
|
143 |
-
|
144 |
-
|
145 |
-
def test_search_keyword_with_filters(make_test_data: TestDataMaker) -> None:
|
146 |
-
dataset = make_test_data(TEST_DATA)
|
147 |
-
|
148 |
-
query = 'world'
|
149 |
-
result = dataset.select_rows(
|
150 |
-
filters=[(UUID_COLUMN, ListOp.IN, ['1', '3'])],
|
151 |
-
searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
|
152 |
-
combine_columns=True)
|
153 |
-
|
154 |
-
expected_signal_udf = SubstringSignal(query=query)
|
155 |
-
assert list(result) == [
|
156 |
-
{
|
157 |
-
UUID_COLUMN: '1',
|
158 |
-
'text': enriched_item('hello world', {expected_signal_udf.key(): [lilac_span(6, 11)]}),
|
159 |
-
'text2': 'again hello world'
|
160 |
-
},
|
161 |
-
# The second row doesn't match the UUID filter.
|
162 |
-
]
|
163 |
-
|
164 |
-
|
165 |
-
class TestEmbedding(TextEmbeddingSignal):
|
166 |
-
"""A test embed function."""
|
167 |
-
name = 'test_embedding'
|
168 |
-
|
169 |
-
@override
|
170 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
171 |
-
"""Call the embedding function."""
|
172 |
-
for example in data:
|
173 |
-
embedding = np.array(STR_EMBEDDINGS[cast(str, example)])
|
174 |
-
embedding = normalize([embedding])[0]
|
175 |
-
yield [lilac_embedding(0, len(example), embedding)]
|
176 |
-
|
177 |
-
|
178 |
-
def test_semantic_search(make_test_data: TestDataMaker) -> None:
|
179 |
-
dataset = make_test_data([{
|
180 |
-
UUID_COLUMN: '1',
|
181 |
-
'text': 'hello world.',
|
182 |
-
}, {
|
183 |
-
UUID_COLUMN: '2',
|
184 |
-
'text': 'hello world2.',
|
185 |
-
}])
|
186 |
-
|
187 |
-
test_embedding = TestEmbedding()
|
188 |
-
dataset.compute_signal(test_embedding, ('text'))
|
189 |
-
|
190 |
-
query = 'hello2.'
|
191 |
-
result = dataset.select_rows(
|
192 |
-
searches=[
|
193 |
-
Search(
|
194 |
-
path='text', query=SemanticQuery(type='semantic', search=query, embedding='test_embedding'))
|
195 |
-
],
|
196 |
-
combine_columns=True)
|
197 |
-
expected_signal_udf = SemanticSimilaritySignal(query=query, embedding='test_embedding')
|
198 |
-
assert list(result) == [
|
199 |
-
# Results are sorted by score desc.
|
200 |
-
{
|
201 |
-
UUID_COLUMN: '2',
|
202 |
-
'text': enriched_item(
|
203 |
-
'hello world2.', {
|
204 |
-
test_embedding.key():
|
205 |
-
[enriched_embedding_span(0, 13, {expected_signal_udf.key(): approx(0.916, 1e-3)})]
|
206 |
-
})
|
207 |
-
},
|
208 |
-
{
|
209 |
-
UUID_COLUMN: '1',
|
210 |
-
'text': enriched_item(
|
211 |
-
'hello world.', {
|
212 |
-
test_embedding.key():
|
213 |
-
[enriched_embedding_span(0, 12, {expected_signal_udf.key(): approx(0.885, 1e-3)})]
|
214 |
-
})
|
215 |
-
},
|
216 |
-
]
|
217 |
-
|
218 |
-
|
219 |
-
def test_concept_search(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
|
220 |
-
concept_model_mock = mocker.spy(LogisticEmbeddingModel, 'fit')
|
221 |
-
|
222 |
-
dataset = make_test_data([{
|
223 |
-
UUID_COLUMN: '1',
|
224 |
-
'text': 'hello world.',
|
225 |
-
}, {
|
226 |
-
UUID_COLUMN: '2',
|
227 |
-
'text': 'hello world2.',
|
228 |
-
}, {
|
229 |
-
UUID_COLUMN: '3',
|
230 |
-
'text': 'random negative 1',
|
231 |
-
}, {
|
232 |
-
UUID_COLUMN: '4',
|
233 |
-
'text': 'random negative 2',
|
234 |
-
}, {
|
235 |
-
UUID_COLUMN: '5',
|
236 |
-
'text': 'random negative 3',
|
237 |
-
}, {
|
238 |
-
UUID_COLUMN: '6',
|
239 |
-
'text': 'random negative 4',
|
240 |
-
}])
|
241 |
-
|
242 |
-
test_embedding = TestEmbedding()
|
243 |
-
dataset.compute_signal(test_embedding, ('text'))
|
244 |
-
|
245 |
-
concept_db = DiskConceptDB()
|
246 |
-
concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT)
|
247 |
-
concept_db.edit(
|
248 |
-
'test_namespace', 'test_concept',
|
249 |
-
ConceptUpdate(insert=[
|
250 |
-
ExampleIn(label=False, text='hello world.'),
|
251 |
-
ExampleIn(label=True, text='hello world2.')
|
252 |
-
]))
|
253 |
-
|
254 |
-
result = dataset.select_rows(
|
255 |
-
searches=[
|
256 |
-
Search(
|
257 |
-
path='text',
|
258 |
-
query=ConceptQuery(
|
259 |
-
type='concept',
|
260 |
-
concept_namespace='test_namespace',
|
261 |
-
concept_name='test_concept',
|
262 |
-
embedding='test_embedding'))
|
263 |
-
],
|
264 |
-
filters=[(UUID_COLUMN, ListOp.IN, ['1', '2'])],
|
265 |
-
combine_columns=True)
|
266 |
-
expected_signal_udf = ConceptScoreSignal(
|
267 |
-
namespace='test_namespace', concept_name='test_concept', embedding='test_embedding')
|
268 |
-
|
269 |
-
assert list(result) == [
|
270 |
-
# Results are sorted by score desc.
|
271 |
-
{
|
272 |
-
UUID_COLUMN: '2',
|
273 |
-
'text': enriched_item(
|
274 |
-
'hello world2.', {
|
275 |
-
test_embedding.key():
|
276 |
-
[enriched_embedding_span(0, 13, {expected_signal_udf.key(): approx(0.75, abs=0.25)})],
|
277 |
-
'test_namespace/test_concept/labels': [lilac_span(0, 13, {'label': True})]
|
278 |
-
})
|
279 |
-
},
|
280 |
-
{
|
281 |
-
UUID_COLUMN: '1',
|
282 |
-
'text': enriched_item(
|
283 |
-
'hello world.', {
|
284 |
-
test_embedding.key():
|
285 |
-
[enriched_embedding_span(0, 12, {expected_signal_udf.key(): approx(0.25, abs=0.25)})],
|
286 |
-
'test_namespace/test_concept/labels': [lilac_span(0, 12, {'label': False})]
|
287 |
-
})
|
288 |
-
},
|
289 |
-
]
|
290 |
-
|
291 |
-
(_, embeddings, labels, _) = concept_model_mock.call_args_list[-1].args
|
292 |
-
assert embeddings.shape == (2, 3)
|
293 |
-
assert labels == [
|
294 |
-
# Explicit labels.
|
295 |
-
False,
|
296 |
-
True
|
297 |
-
]
|
298 |
-
|
299 |
-
|
300 |
-
def test_sort_override_search(make_test_data: TestDataMaker) -> None:
|
301 |
-
dataset = make_test_data([{
|
302 |
-
UUID_COLUMN: '1',
|
303 |
-
'text': 'hello world.',
|
304 |
-
'value': 10
|
305 |
-
}, {
|
306 |
-
UUID_COLUMN: '2',
|
307 |
-
'text': 'hello world2.',
|
308 |
-
'value': 20
|
309 |
-
}])
|
310 |
-
|
311 |
-
test_embedding = TestEmbedding()
|
312 |
-
dataset.compute_signal(test_embedding, ('text'))
|
313 |
-
|
314 |
-
query = 'hello2.'
|
315 |
-
search = Search(
|
316 |
-
path='text', query=SemanticQuery(type='semantic', search=query, embedding='test_embedding'))
|
317 |
-
|
318 |
-
expected_signal_udf = SemanticSimilaritySignal(query=query, embedding='test_embedding')
|
319 |
-
expected_item_1 = {
|
320 |
-
UUID_COLUMN: '1',
|
321 |
-
'text': enriched_item(
|
322 |
-
'hello world.', {
|
323 |
-
test_embedding.key():
|
324 |
-
[enriched_embedding_span(0, 12, {expected_signal_udf.key(): approx(0.885, 1e-3)})]
|
325 |
-
}),
|
326 |
-
'value': 10
|
327 |
-
}
|
328 |
-
expected_item_2 = {
|
329 |
-
UUID_COLUMN: '2',
|
330 |
-
'text': enriched_item(
|
331 |
-
'hello world2.', {
|
332 |
-
test_embedding.key():
|
333 |
-
[enriched_embedding_span(0, 13, {expected_signal_udf.key(): approx(0.916, 1e-3)})]
|
334 |
-
}),
|
335 |
-
'value': 20
|
336 |
-
}
|
337 |
-
|
338 |
-
sort_order = SortOrder.ASC
|
339 |
-
result = dataset.select_rows(
|
340 |
-
searches=[search], sort_by=[('value',)], sort_order=sort_order, combine_columns=True)
|
341 |
-
assert list(result) == [
|
342 |
-
# Results are sorted by score ascending.
|
343 |
-
expected_item_1,
|
344 |
-
expected_item_2
|
345 |
-
]
|
346 |
-
|
347 |
-
sort_order = SortOrder.DESC
|
348 |
-
result = dataset.select_rows(
|
349 |
-
searches=[search], sort_by=[('text',)], sort_order=sort_order, combine_columns=True)
|
350 |
-
assert list(result) == [
|
351 |
-
# Results are sorted by score descending.
|
352 |
-
expected_item_2,
|
353 |
-
expected_item_1
|
354 |
-
]
|
355 |
-
|
356 |
-
|
357 |
-
def test_search_keyword_and_semantic(make_test_data: TestDataMaker) -> None:
|
358 |
-
dataset = make_test_data([{
|
359 |
-
UUID_COLUMN: '1',
|
360 |
-
'text': 'hello world.',
|
361 |
-
}, {
|
362 |
-
UUID_COLUMN: '2',
|
363 |
-
'text': 'hello world2.',
|
364 |
-
}])
|
365 |
-
|
366 |
-
test_embedding = TestEmbedding()
|
367 |
-
dataset.compute_signal(test_embedding, ('text'))
|
368 |
-
|
369 |
-
query = 'hello2.'
|
370 |
-
keyword_query = 'rld2'
|
371 |
-
result = dataset.select_rows(
|
372 |
-
searches=[
|
373 |
-
Search(
|
374 |
-
path='text', query=SemanticQuery(type='semantic', search=query,
|
375 |
-
embedding='test_embedding')),
|
376 |
-
Search(path='text', query=KeywordQuery(type='keyword', search=keyword_query))
|
377 |
-
],
|
378 |
-
combine_columns=True)
|
379 |
-
expected_semantic_signal = SemanticSimilaritySignal(query=query, embedding='test_embedding')
|
380 |
-
expected_keyword_signal = SubstringSignal(query=keyword_query)
|
381 |
-
assert list(result) == [
|
382 |
-
# Results are sorted by score desc.
|
383 |
-
{
|
384 |
-
UUID_COLUMN: '2',
|
385 |
-
'text': enriched_item(
|
386 |
-
'hello world2.', {
|
387 |
-
test_embedding.key():
|
388 |
-
[enriched_embedding_span(0, 13, {expected_semantic_signal.key(): approx(0.916, 1e-3)})],
|
389 |
-
expected_keyword_signal.key(): [lilac_span(8, 12)],
|
390 |
-
})
|
391 |
-
},
|
392 |
-
# UUID '1' is not returned because it does not match the keyword query.
|
393 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_select_rows_sort_test.py
DELETED
@@ -1,904 +0,0 @@
|
|
1 |
-
"""Tests for dataset.select_rows(sort_by=...)."""
|
2 |
-
|
3 |
-
from typing import Iterable, Optional, Sequence, cast
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import pytest
|
7 |
-
from typing_extensions import override
|
8 |
-
|
9 |
-
from ..embeddings.vector_store import VectorStore
|
10 |
-
from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field
|
11 |
-
from ..signals.signal import (
|
12 |
-
TextEmbeddingModelSignal,
|
13 |
-
TextEmbeddingSignal,
|
14 |
-
TextSignal,
|
15 |
-
clear_signal_registry,
|
16 |
-
register_signal,
|
17 |
-
)
|
18 |
-
from .dataset import BinaryOp, Column, SortOrder
|
19 |
-
from .dataset_test_utils import TestDataMaker, enriched_item
|
20 |
-
from .dataset_utils import lilac_embedding
|
21 |
-
|
22 |
-
|
23 |
-
class TestSignal(TextSignal):
|
24 |
-
name = 'test_signal'
|
25 |
-
|
26 |
-
def fields(self) -> Field:
|
27 |
-
return field(fields={'len': 'int32', 'is_all_cap': 'boolean'})
|
28 |
-
|
29 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
30 |
-
for text_content in data:
|
31 |
-
yield {'len': len(text_content), 'is_all_cap': text_content.isupper()}
|
32 |
-
|
33 |
-
|
34 |
-
class TestPrimitiveSignal(TextSignal):
|
35 |
-
name = 'primitive_signal'
|
36 |
-
|
37 |
-
def fields(self) -> Field:
|
38 |
-
return field('int32')
|
39 |
-
|
40 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
41 |
-
for text_content in data:
|
42 |
-
yield len(text_content) + 1
|
43 |
-
|
44 |
-
|
45 |
-
class NestedArraySignal(TextSignal):
|
46 |
-
name = 'nested_array'
|
47 |
-
|
48 |
-
def fields(self) -> Field:
|
49 |
-
return field(fields=[['int32']])
|
50 |
-
|
51 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
52 |
-
for text_content in data:
|
53 |
-
yield [[len(text_content) + 1], [len(text_content)]]
|
54 |
-
|
55 |
-
|
56 |
-
@pytest.fixture(scope='module', autouse=True)
|
57 |
-
def setup_teardown() -> Iterable[None]:
|
58 |
-
# Setup.
|
59 |
-
register_signal(TestSignal)
|
60 |
-
register_signal(TestPrimitiveSignal)
|
61 |
-
register_signal(NestedArraySignal)
|
62 |
-
register_signal(TopKEmbedding)
|
63 |
-
# Unit test runs.
|
64 |
-
yield
|
65 |
-
# Teardown.
|
66 |
-
clear_signal_registry()
|
67 |
-
|
68 |
-
|
69 |
-
def test_sort_by_source_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
|
70 |
-
dataset = make_test_data([{
|
71 |
-
UUID_COLUMN: '1',
|
72 |
-
'erased': True,
|
73 |
-
'score': 4.1,
|
74 |
-
'document': {
|
75 |
-
'num_pages': 4,
|
76 |
-
'header': {
|
77 |
-
'title': 'c'
|
78 |
-
}
|
79 |
-
}
|
80 |
-
}, {
|
81 |
-
UUID_COLUMN: '2',
|
82 |
-
'erased': False,
|
83 |
-
'score': 3.5,
|
84 |
-
'document': {
|
85 |
-
'num_pages': 5,
|
86 |
-
'header': {
|
87 |
-
'title': 'b'
|
88 |
-
}
|
89 |
-
},
|
90 |
-
}, {
|
91 |
-
UUID_COLUMN: '3',
|
92 |
-
'erased': True,
|
93 |
-
'score': 3.7,
|
94 |
-
'document': {
|
95 |
-
'num_pages': 3,
|
96 |
-
'header': {
|
97 |
-
'title': 'a'
|
98 |
-
}
|
99 |
-
},
|
100 |
-
}])
|
101 |
-
|
102 |
-
# Sort by bool.
|
103 |
-
result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['erased'], sort_order=SortOrder.ASC)
|
104 |
-
assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
|
105 |
-
result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['erased'], sort_order=SortOrder.DESC)
|
106 |
-
assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}]
|
107 |
-
|
108 |
-
# Sort by float.
|
109 |
-
result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['score'], sort_order=SortOrder.ASC)
|
110 |
-
assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '1'}]
|
111 |
-
result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['score'], sort_order=SortOrder.DESC)
|
112 |
-
assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}]
|
113 |
-
|
114 |
-
# Sort by nested int.
|
115 |
-
result = dataset.select_rows(
|
116 |
-
columns=[UUID_COLUMN], sort_by=['document.num_pages'], sort_order=SortOrder.ASC)
|
117 |
-
assert list(result) == [{UUID_COLUMN: '3'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}]
|
118 |
-
result = dataset.select_rows(
|
119 |
-
columns=[UUID_COLUMN], sort_by=['document.num_pages'], sort_order=SortOrder.DESC)
|
120 |
-
assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
|
121 |
-
|
122 |
-
# Sort by double nested string.
|
123 |
-
result = dataset.select_rows(
|
124 |
-
columns=[UUID_COLUMN], sort_by=['document.header.title'], sort_order=SortOrder.ASC)
|
125 |
-
assert list(result) == [{UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}]
|
126 |
-
result = dataset.select_rows(
|
127 |
-
columns=[UUID_COLUMN], sort_by=['document.header.title'], sort_order=SortOrder.DESC)
|
128 |
-
assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}, {UUID_COLUMN: '3'}]
|
129 |
-
|
130 |
-
|
131 |
-
def test_sort_by_signal_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
|
132 |
-
dataset = make_test_data([{
|
133 |
-
UUID_COLUMN: '1',
|
134 |
-
'text': 'HEY'
|
135 |
-
}, {
|
136 |
-
UUID_COLUMN: '2',
|
137 |
-
'text': 'everyone'
|
138 |
-
}, {
|
139 |
-
UUID_COLUMN: '3',
|
140 |
-
'text': 'HI'
|
141 |
-
}])
|
142 |
-
|
143 |
-
dataset.compute_signal(TestSignal(), 'text')
|
144 |
-
|
145 |
-
# Sort by `signal.len`.
|
146 |
-
result = dataset.select_rows(
|
147 |
-
columns=[UUID_COLUMN], sort_by=['text.test_signal.len'], sort_order=SortOrder.ASC)
|
148 |
-
assert list(result) == [{UUID_COLUMN: '3'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}]
|
149 |
-
result = dataset.select_rows(
|
150 |
-
columns=[UUID_COLUMN], sort_by=['text.test_signal.len'], sort_order=SortOrder.DESC)
|
151 |
-
assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
|
152 |
-
|
153 |
-
# Sort by `signal.is_all_cap`.
|
154 |
-
result = dataset.select_rows(
|
155 |
-
columns=[UUID_COLUMN], sort_by=['text.test_signal.is_all_cap'], sort_order=SortOrder.ASC)
|
156 |
-
assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
|
157 |
-
result = dataset.select_rows(
|
158 |
-
columns=[UUID_COLUMN], sort_by=['text.test_signal.is_all_cap'], sort_order=SortOrder.DESC)
|
159 |
-
assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}]
|
160 |
-
|
161 |
-
|
162 |
-
def test_sort_by_signal_alias_no_repeated(make_test_data: TestDataMaker) -> None:
|
163 |
-
dataset = make_test_data([{
|
164 |
-
UUID_COLUMN: '1',
|
165 |
-
'text': 'HEY'
|
166 |
-
}, {
|
167 |
-
UUID_COLUMN: '2',
|
168 |
-
'text': 'everyone'
|
169 |
-
}, {
|
170 |
-
UUID_COLUMN: '3',
|
171 |
-
'text': 'HI'
|
172 |
-
}])
|
173 |
-
|
174 |
-
dataset.compute_signal(TestSignal(), 'text')
|
175 |
-
|
176 |
-
# Sort by `signal.len`.
|
177 |
-
signal_alias = Column('text.test_signal', alias='signal')
|
178 |
-
result = dataset.select_rows(
|
179 |
-
columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.ASC)
|
180 |
-
assert list(result) == [{
|
181 |
-
UUID_COLUMN: '3',
|
182 |
-
'signal': {
|
183 |
-
'len': 2,
|
184 |
-
'is_all_cap': True
|
185 |
-
}
|
186 |
-
}, {
|
187 |
-
UUID_COLUMN: '1',
|
188 |
-
'signal': {
|
189 |
-
'len': 3,
|
190 |
-
'is_all_cap': True
|
191 |
-
}
|
192 |
-
}, {
|
193 |
-
UUID_COLUMN: '2',
|
194 |
-
'signal': {
|
195 |
-
'len': 8,
|
196 |
-
'is_all_cap': False
|
197 |
-
}
|
198 |
-
}]
|
199 |
-
result = dataset.select_rows(
|
200 |
-
columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.DESC)
|
201 |
-
assert list(result) == [{
|
202 |
-
UUID_COLUMN: '2',
|
203 |
-
'signal': {
|
204 |
-
'len': 8,
|
205 |
-
'is_all_cap': False
|
206 |
-
}
|
207 |
-
}, {
|
208 |
-
UUID_COLUMN: '1',
|
209 |
-
'signal': {
|
210 |
-
'len': 3,
|
211 |
-
'is_all_cap': True
|
212 |
-
}
|
213 |
-
}, {
|
214 |
-
UUID_COLUMN: '3',
|
215 |
-
'signal': {
|
216 |
-
'len': 2,
|
217 |
-
'is_all_cap': True
|
218 |
-
}
|
219 |
-
}]
|
220 |
-
|
221 |
-
|
222 |
-
def test_sort_by_enriched_alias_no_repeated(make_test_data: TestDataMaker) -> None:
|
223 |
-
dataset = make_test_data([{
|
224 |
-
UUID_COLUMN: '1',
|
225 |
-
'text': 'HEY'
|
226 |
-
}, {
|
227 |
-
UUID_COLUMN: '2',
|
228 |
-
'text': 'everyone'
|
229 |
-
}, {
|
230 |
-
UUID_COLUMN: '3',
|
231 |
-
'text': 'HI'
|
232 |
-
}])
|
233 |
-
|
234 |
-
dataset.compute_signal(TestSignal(), 'text')
|
235 |
-
|
236 |
-
# Sort by `document.test_signal.is_all_cap` where 'document' is an alias to 'text'.
|
237 |
-
text_alias = Column('text', alias='document')
|
238 |
-
result = dataset.select_rows(
|
239 |
-
columns=[text_alias], sort_by=['document.test_signal.is_all_cap'], sort_order=SortOrder.ASC)
|
240 |
-
assert list(result) == [{
|
241 |
-
UUID_COLUMN: '2',
|
242 |
-
'document': enriched_item('everyone', {'test_signal': {
|
243 |
-
'len': 8,
|
244 |
-
'is_all_cap': False
|
245 |
-
}})
|
246 |
-
}, {
|
247 |
-
UUID_COLUMN: '1',
|
248 |
-
'document': enriched_item('HEY', {'test_signal': {
|
249 |
-
'len': 3,
|
250 |
-
'is_all_cap': True
|
251 |
-
}})
|
252 |
-
}, {
|
253 |
-
UUID_COLUMN: '3',
|
254 |
-
'document': enriched_item('HI', {'test_signal': {
|
255 |
-
'len': 2,
|
256 |
-
'is_all_cap': True
|
257 |
-
}})
|
258 |
-
}]
|
259 |
-
|
260 |
-
result = dataset.select_rows(
|
261 |
-
columns=[text_alias], sort_by=['document.test_signal.is_all_cap'], sort_order=SortOrder.DESC)
|
262 |
-
assert list(result) == [{
|
263 |
-
UUID_COLUMN: '1',
|
264 |
-
'document': enriched_item('HEY', {'test_signal': {
|
265 |
-
'len': 3,
|
266 |
-
'is_all_cap': True
|
267 |
-
}})
|
268 |
-
}, {
|
269 |
-
UUID_COLUMN: '3',
|
270 |
-
'document': enriched_item('HI', {'test_signal': {
|
271 |
-
'len': 2,
|
272 |
-
'is_all_cap': True
|
273 |
-
}})
|
274 |
-
}, {
|
275 |
-
UUID_COLUMN: '2',
|
276 |
-
'document': enriched_item('everyone', {'test_signal': {
|
277 |
-
'len': 8,
|
278 |
-
'is_all_cap': False
|
279 |
-
}})
|
280 |
-
}]
|
281 |
-
|
282 |
-
|
283 |
-
def test_sort_by_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
|
284 |
-
dataset = make_test_data([{
|
285 |
-
UUID_COLUMN: '1',
|
286 |
-
'text': 'HEY'
|
287 |
-
}, {
|
288 |
-
UUID_COLUMN: '2',
|
289 |
-
'text': 'everyone'
|
290 |
-
}, {
|
291 |
-
UUID_COLUMN: '3',
|
292 |
-
'text': 'HI'
|
293 |
-
}])
|
294 |
-
|
295 |
-
# Equivalent to: SELECT `TestSignal(text) AS udf`.
|
296 |
-
text_udf = Column('text', signal_udf=TestSignal(), alias='udf')
|
297 |
-
# Sort by `udf.len`, where `udf` is an alias to `TestSignal(text)`.
|
298 |
-
result = dataset.select_rows(['*', text_udf], sort_by=['udf.len'], sort_order=SortOrder.ASC)
|
299 |
-
assert list(result) == [{
|
300 |
-
UUID_COLUMN: '3',
|
301 |
-
'text': 'HI',
|
302 |
-
'udf': {
|
303 |
-
'len': 2,
|
304 |
-
'is_all_cap': True
|
305 |
-
}
|
306 |
-
}, {
|
307 |
-
UUID_COLUMN: '1',
|
308 |
-
'text': 'HEY',
|
309 |
-
'udf': {
|
310 |
-
'len': 3,
|
311 |
-
'is_all_cap': True
|
312 |
-
}
|
313 |
-
}, {
|
314 |
-
UUID_COLUMN: '2',
|
315 |
-
'text': 'everyone',
|
316 |
-
'udf': {
|
317 |
-
'len': 8,
|
318 |
-
'is_all_cap': False
|
319 |
-
}
|
320 |
-
}]
|
321 |
-
|
322 |
-
|
323 |
-
def test_sort_by_udf_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
|
324 |
-
dataset = make_test_data([{
|
325 |
-
UUID_COLUMN: '1',
|
326 |
-
'text': 'HEY'
|
327 |
-
}, {
|
328 |
-
UUID_COLUMN: '2',
|
329 |
-
'text': 'everyone'
|
330 |
-
}, {
|
331 |
-
UUID_COLUMN: '3',
|
332 |
-
'text': 'HI'
|
333 |
-
}])
|
334 |
-
|
335 |
-
text_udf = Column('text', signal_udf=TestSignal())
|
336 |
-
# Sort by `text.test_signal.len`, produced by executing the udf `TestSignal(text)`.
|
337 |
-
result = dataset.select_rows(['*', text_udf],
|
338 |
-
sort_by=[('text', 'test_signal', 'len')],
|
339 |
-
sort_order=SortOrder.ASC,
|
340 |
-
combine_columns=True)
|
341 |
-
assert list(result) == [{
|
342 |
-
UUID_COLUMN: '3',
|
343 |
-
'text': enriched_item('HI', {'test_signal': {
|
344 |
-
'len': 2,
|
345 |
-
'is_all_cap': True
|
346 |
-
}}),
|
347 |
-
}, {
|
348 |
-
UUID_COLUMN: '1',
|
349 |
-
'text': enriched_item('HEY', {'test_signal': {
|
350 |
-
'len': 3,
|
351 |
-
'is_all_cap': True
|
352 |
-
}}),
|
353 |
-
}, {
|
354 |
-
UUID_COLUMN: '2',
|
355 |
-
'text': enriched_item('everyone', {'test_signal': {
|
356 |
-
'len': 8,
|
357 |
-
'is_all_cap': False
|
358 |
-
}}),
|
359 |
-
}]
|
360 |
-
|
361 |
-
# Sort descending.
|
362 |
-
result = dataset.select_rows(['*', text_udf],
|
363 |
-
sort_by=[('text', 'test_signal', 'len')],
|
364 |
-
sort_order=SortOrder.DESC,
|
365 |
-
combine_columns=True)
|
366 |
-
assert list(result) == [{
|
367 |
-
UUID_COLUMN: '2',
|
368 |
-
'text': enriched_item('everyone', {'test_signal': {
|
369 |
-
'len': 8,
|
370 |
-
'is_all_cap': False
|
371 |
-
}}),
|
372 |
-
}, {
|
373 |
-
UUID_COLUMN: '1',
|
374 |
-
'text': enriched_item('HEY', {'test_signal': {
|
375 |
-
'len': 3,
|
376 |
-
'is_all_cap': True
|
377 |
-
}}),
|
378 |
-
}, {
|
379 |
-
UUID_COLUMN: '3',
|
380 |
-
'text': enriched_item('HI', {'test_signal': {
|
381 |
-
'len': 2,
|
382 |
-
'is_all_cap': True
|
383 |
-
}}),
|
384 |
-
}]
|
385 |
-
|
386 |
-
|
387 |
-
def test_sort_by_primitive_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
|
388 |
-
dataset = make_test_data([{
|
389 |
-
UUID_COLUMN: '1',
|
390 |
-
'text': 'HEY'
|
391 |
-
}, {
|
392 |
-
UUID_COLUMN: '2',
|
393 |
-
'text': 'everyone'
|
394 |
-
}, {
|
395 |
-
UUID_COLUMN: '3',
|
396 |
-
'text': 'HI'
|
397 |
-
}])
|
398 |
-
|
399 |
-
# Equivalent to: SELECT `TestPrimitiveSignal(text) AS udf`.
|
400 |
-
text_udf = Column('text', signal_udf=TestPrimitiveSignal(), alias='udf')
|
401 |
-
# Sort by the primitive value returned by the udf.
|
402 |
-
result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.ASC)
|
403 |
-
assert list(result) == [{
|
404 |
-
UUID_COLUMN: '3',
|
405 |
-
'text': 'HI',
|
406 |
-
'udf': 3
|
407 |
-
}, {
|
408 |
-
UUID_COLUMN: '1',
|
409 |
-
'text': 'HEY',
|
410 |
-
'udf': 4
|
411 |
-
}, {
|
412 |
-
UUID_COLUMN: '2',
|
413 |
-
'text': 'everyone',
|
414 |
-
'udf': 9
|
415 |
-
}]
|
416 |
-
|
417 |
-
|
418 |
-
def test_sort_by_source_non_leaf_errors(make_test_data: TestDataMaker) -> None:
|
419 |
-
dataset = make_test_data([{
|
420 |
-
UUID_COLUMN: '1',
|
421 |
-
'vals': [7, 1]
|
422 |
-
}, {
|
423 |
-
UUID_COLUMN: '2',
|
424 |
-
'vals': [3, 4]
|
425 |
-
}, {
|
426 |
-
UUID_COLUMN: '3',
|
427 |
-
'vals': [9, 0]
|
428 |
-
}])
|
429 |
-
|
430 |
-
# Sort by repeated.
|
431 |
-
with pytest.raises(ValueError, match='Unable to sort by path'):
|
432 |
-
dataset.select_rows(columns=[UUID_COLUMN], sort_by=['vals'], sort_order=SortOrder.ASC)
|
433 |
-
|
434 |
-
|
435 |
-
def test_sort_by_source_no_alias_repeated(make_test_data: TestDataMaker) -> None:
|
436 |
-
dataset = make_test_data([{
|
437 |
-
UUID_COLUMN: '1',
|
438 |
-
'vals': [[{
|
439 |
-
'score': 7
|
440 |
-
}, {
|
441 |
-
'score': 1
|
442 |
-
}], [{
|
443 |
-
'score': 1
|
444 |
-
}, {
|
445 |
-
'score': 7
|
446 |
-
}]]
|
447 |
-
}, {
|
448 |
-
UUID_COLUMN: '2',
|
449 |
-
'vals': [[{
|
450 |
-
'score': 3
|
451 |
-
}, {
|
452 |
-
'score': 4
|
453 |
-
}]]
|
454 |
-
}, {
|
455 |
-
UUID_COLUMN: '3',
|
456 |
-
'vals': [[{
|
457 |
-
'score': 9
|
458 |
-
}, {
|
459 |
-
'score': 0
|
460 |
-
}]]
|
461 |
-
}])
|
462 |
-
|
463 |
-
# Sort by repeated 'vals'.
|
464 |
-
result = dataset.select_rows(
|
465 |
-
columns=[UUID_COLUMN, 'vals'], sort_by=['vals.*.*.score'], sort_order=SortOrder.ASC)
|
466 |
-
assert list(result) == [{
|
467 |
-
UUID_COLUMN: '3',
|
468 |
-
'vals': [[{
|
469 |
-
'score': 9
|
470 |
-
}, {
|
471 |
-
'score': 0
|
472 |
-
}]]
|
473 |
-
}, {
|
474 |
-
UUID_COLUMN: '1',
|
475 |
-
'vals': [[{
|
476 |
-
'score': 7
|
477 |
-
}, {
|
478 |
-
'score': 1
|
479 |
-
}], [{
|
480 |
-
'score': 1
|
481 |
-
}, {
|
482 |
-
'score': 7
|
483 |
-
}]]
|
484 |
-
}, {
|
485 |
-
UUID_COLUMN: '2',
|
486 |
-
'vals': [[{
|
487 |
-
'score': 3
|
488 |
-
}, {
|
489 |
-
'score': 4
|
490 |
-
}]]
|
491 |
-
}]
|
492 |
-
|
493 |
-
result = dataset.select_rows(
|
494 |
-
columns=[UUID_COLUMN, 'vals'], sort_by=['vals.*.*.score'], sort_order=SortOrder.DESC)
|
495 |
-
assert list(result) == [{
|
496 |
-
UUID_COLUMN: '3',
|
497 |
-
'vals': [[{
|
498 |
-
'score': 9
|
499 |
-
}, {
|
500 |
-
'score': 0
|
501 |
-
}]]
|
502 |
-
}, {
|
503 |
-
UUID_COLUMN: '1',
|
504 |
-
'vals': [[{
|
505 |
-
'score': 7
|
506 |
-
}, {
|
507 |
-
'score': 1
|
508 |
-
}], [{
|
509 |
-
'score': 1
|
510 |
-
}, {
|
511 |
-
'score': 7
|
512 |
-
}]]
|
513 |
-
}, {
|
514 |
-
UUID_COLUMN: '2',
|
515 |
-
'vals': [[{
|
516 |
-
'score': 3
|
517 |
-
}, {
|
518 |
-
'score': 4
|
519 |
-
}]]
|
520 |
-
}]
|
521 |
-
|
522 |
-
|
523 |
-
def test_sort_by_source_alias_repeated(make_test_data: TestDataMaker) -> None:
|
524 |
-
dataset = make_test_data([{
|
525 |
-
UUID_COLUMN: '1',
|
526 |
-
'vals': [[7, 1], [1, 7]]
|
527 |
-
}, {
|
528 |
-
UUID_COLUMN: '2',
|
529 |
-
'vals': [[3], [11]]
|
530 |
-
}, {
|
531 |
-
UUID_COLUMN: '3',
|
532 |
-
'vals': [[9, 0]]
|
533 |
-
}])
|
534 |
-
|
535 |
-
# Sort by repeated 'vals'.
|
536 |
-
result = dataset.select_rows(
|
537 |
-
columns=[UUID_COLUMN, Column('vals', alias='scores')],
|
538 |
-
sort_by=['scores.*.*'],
|
539 |
-
sort_order=SortOrder.ASC)
|
540 |
-
assert list(result) == [{
|
541 |
-
UUID_COLUMN: '3',
|
542 |
-
'scores': [[9, 0]]
|
543 |
-
}, {
|
544 |
-
UUID_COLUMN: '1',
|
545 |
-
'scores': [[7, 1], [1, 7]]
|
546 |
-
}, {
|
547 |
-
UUID_COLUMN: '2',
|
548 |
-
'scores': [[3], [11]]
|
549 |
-
}]
|
550 |
-
|
551 |
-
result = dataset.select_rows(
|
552 |
-
columns=[UUID_COLUMN, Column('vals', alias='scores')],
|
553 |
-
sort_by=['scores.*.*'],
|
554 |
-
sort_order=SortOrder.DESC)
|
555 |
-
assert list(result) == [{
|
556 |
-
UUID_COLUMN: '2',
|
557 |
-
'scores': [[3], [11]]
|
558 |
-
}, {
|
559 |
-
UUID_COLUMN: '3',
|
560 |
-
'scores': [[9, 0]]
|
561 |
-
}, {
|
562 |
-
UUID_COLUMN: '1',
|
563 |
-
'scores': [[7, 1], [1, 7]]
|
564 |
-
}]
|
565 |
-
|
566 |
-
|
567 |
-
def test_sort_by_udf_alias_repeated(make_test_data: TestDataMaker) -> None:
|
568 |
-
dataset = make_test_data([{
|
569 |
-
UUID_COLUMN: '1',
|
570 |
-
'text': 'HEY'
|
571 |
-
}, {
|
572 |
-
UUID_COLUMN: '2',
|
573 |
-
'text': 'everyone'
|
574 |
-
}, {
|
575 |
-
UUID_COLUMN: '3',
|
576 |
-
'text': 'HI'
|
577 |
-
}])
|
578 |
-
|
579 |
-
# Equivalent to: SELECT `NestedArraySignal(text) AS udf`.
|
580 |
-
text_udf = Column('text', signal_udf=NestedArraySignal(), alias='udf')
|
581 |
-
# Sort by `udf.*.*`, where `udf` is an alias to `NestedArraySignal(text)`.
|
582 |
-
result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.ASC)
|
583 |
-
assert list(result) == [{
|
584 |
-
UUID_COLUMN: '3',
|
585 |
-
'text': 'HI',
|
586 |
-
'udf': [[3], [2]]
|
587 |
-
}, {
|
588 |
-
UUID_COLUMN: '1',
|
589 |
-
'text': 'HEY',
|
590 |
-
'udf': [[4], [3]]
|
591 |
-
}, {
|
592 |
-
UUID_COLUMN: '2',
|
593 |
-
'text': 'everyone',
|
594 |
-
'udf': [[9], [8]]
|
595 |
-
}]
|
596 |
-
result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.DESC)
|
597 |
-
assert list(result) == [{
|
598 |
-
UUID_COLUMN: '2',
|
599 |
-
'text': 'everyone',
|
600 |
-
'udf': [[9], [8]]
|
601 |
-
}, {
|
602 |
-
UUID_COLUMN: '1',
|
603 |
-
'text': 'HEY',
|
604 |
-
'udf': [[4], [3]]
|
605 |
-
}, {
|
606 |
-
UUID_COLUMN: '3',
|
607 |
-
'text': 'HI',
|
608 |
-
'udf': [[3], [2]]
|
609 |
-
}]
|
610 |
-
|
611 |
-
|
612 |
-
def test_sort_by_complex_signal_udf_alias_called_on_repeated(make_test_data: TestDataMaker) -> None:
|
613 |
-
dataset = make_test_data([{
|
614 |
-
UUID_COLUMN: '1',
|
615 |
-
'texts': [{
|
616 |
-
'text': 'eardrop'
|
617 |
-
}, {
|
618 |
-
'text': 'I'
|
619 |
-
}]
|
620 |
-
}, {
|
621 |
-
UUID_COLUMN: '2',
|
622 |
-
'texts': [{
|
623 |
-
'text': 'hey'
|
624 |
-
}, {
|
625 |
-
'text': 'CARS'
|
626 |
-
}]
|
627 |
-
}, {
|
628 |
-
UUID_COLUMN: '3',
|
629 |
-
'texts': [{
|
630 |
-
'text': 'everyone'
|
631 |
-
}, {
|
632 |
-
'text': ''
|
633 |
-
}]
|
634 |
-
}])
|
635 |
-
|
636 |
-
# Equivalent to: SELECT `TestSignal(texts.*.text) AS udf`.
|
637 |
-
texts_udf = Column('texts.*.text', signal_udf=TestSignal(), alias='udf')
|
638 |
-
# Sort by `udf.len`, where `udf` is an alias to `TestSignal(texts.*.text)`.
|
639 |
-
result = dataset.select_rows(['*', texts_udf],
|
640 |
-
sort_by=['udf.len'],
|
641 |
-
sort_order=SortOrder.ASC,
|
642 |
-
combine_columns=True)
|
643 |
-
assert list(result) == [{
|
644 |
-
UUID_COLUMN: '3',
|
645 |
-
'texts': [{
|
646 |
-
'text': enriched_item('everyone', {'test_signal': {
|
647 |
-
'len': 8,
|
648 |
-
'is_all_cap': False
|
649 |
-
}})
|
650 |
-
}, {
|
651 |
-
'text': enriched_item('', {'test_signal': {
|
652 |
-
'len': 0,
|
653 |
-
'is_all_cap': False
|
654 |
-
}})
|
655 |
-
}]
|
656 |
-
}, {
|
657 |
-
UUID_COLUMN: '1',
|
658 |
-
'texts': [{
|
659 |
-
'text': enriched_item('eardrop', {'test_signal': {
|
660 |
-
'len': 7,
|
661 |
-
'is_all_cap': False
|
662 |
-
}})
|
663 |
-
}, {
|
664 |
-
'text': enriched_item('I', {'test_signal': {
|
665 |
-
'len': 1,
|
666 |
-
'is_all_cap': True
|
667 |
-
}})
|
668 |
-
}]
|
669 |
-
}, {
|
670 |
-
UUID_COLUMN: '2',
|
671 |
-
'texts': [{
|
672 |
-
'text': enriched_item('hey', {'test_signal': {
|
673 |
-
'len': 3,
|
674 |
-
'is_all_cap': False
|
675 |
-
}})
|
676 |
-
}, {
|
677 |
-
'text': enriched_item('CARS', {'test_signal': {
|
678 |
-
'len': 4,
|
679 |
-
'is_all_cap': True
|
680 |
-
}})
|
681 |
-
}]
|
682 |
-
}]
|
683 |
-
|
684 |
-
|
685 |
-
def test_sort_by_primitive_signal_udf_alias_called_on_repeated(
|
686 |
-
make_test_data: TestDataMaker) -> None:
|
687 |
-
dataset = make_test_data([{
|
688 |
-
UUID_COLUMN: '1',
|
689 |
-
'texts': [{
|
690 |
-
'text': 'eardrop'
|
691 |
-
}, {
|
692 |
-
'text': 'I'
|
693 |
-
}]
|
694 |
-
}, {
|
695 |
-
UUID_COLUMN: '2',
|
696 |
-
'texts': [{
|
697 |
-
'text': 'hey'
|
698 |
-
}, {
|
699 |
-
'text': 'CARS'
|
700 |
-
}]
|
701 |
-
}, {
|
702 |
-
UUID_COLUMN: '3',
|
703 |
-
'texts': [{
|
704 |
-
'text': 'everyone'
|
705 |
-
}, {
|
706 |
-
'text': ''
|
707 |
-
}]
|
708 |
-
}])
|
709 |
-
|
710 |
-
# Equivalent to: SELECT `TestPrimitiveSignal(texts.*.text) AS udf`.
|
711 |
-
texts_udf = Column('texts.*.text', signal_udf=TestPrimitiveSignal(), alias='udf')
|
712 |
-
# Sort by `udf`, where `udf` is an alias to `TestPrimitiveSignal(texts.*.text)`.
|
713 |
-
result = dataset.select_rows(['*', texts_udf],
|
714 |
-
sort_by=['udf'],
|
715 |
-
sort_order=SortOrder.ASC,
|
716 |
-
combine_columns=True)
|
717 |
-
assert list(result) == [{
|
718 |
-
UUID_COLUMN: '3',
|
719 |
-
'texts': [{
|
720 |
-
'text': enriched_item('everyone', {'primitive_signal': 9})
|
721 |
-
}, {
|
722 |
-
'text': enriched_item('', {'primitive_signal': 1})
|
723 |
-
}]
|
724 |
-
}, {
|
725 |
-
UUID_COLUMN: '1',
|
726 |
-
'texts': [{
|
727 |
-
'text': enriched_item('eardrop', {'primitive_signal': 8})
|
728 |
-
}, {
|
729 |
-
'text': enriched_item('I', {'primitive_signal': 2})
|
730 |
-
}]
|
731 |
-
}, {
|
732 |
-
UUID_COLUMN: '2',
|
733 |
-
'texts': [{
|
734 |
-
'text': enriched_item('hey', {'primitive_signal': 4})
|
735 |
-
}, {
|
736 |
-
'text': enriched_item('CARS', {'primitive_signal': 5})
|
737 |
-
}]
|
738 |
-
}]
|
739 |
-
result = dataset.select_rows(['*', texts_udf],
|
740 |
-
sort_by=['udf'],
|
741 |
-
sort_order=SortOrder.DESC,
|
742 |
-
combine_columns=True)
|
743 |
-
assert list(result) == [{
|
744 |
-
UUID_COLUMN: '3',
|
745 |
-
'texts': [{
|
746 |
-
'text': enriched_item('everyone', {'primitive_signal': 9})
|
747 |
-
}, {
|
748 |
-
'text': enriched_item('', {'primitive_signal': 1})
|
749 |
-
}]
|
750 |
-
}, {
|
751 |
-
UUID_COLUMN: '1',
|
752 |
-
'texts': [{
|
753 |
-
'text': enriched_item('eardrop', {'primitive_signal': 8})
|
754 |
-
}, {
|
755 |
-
'text': enriched_item('I', {'primitive_signal': 2})
|
756 |
-
}]
|
757 |
-
}, {
|
758 |
-
UUID_COLUMN: '2',
|
759 |
-
'texts': [{
|
760 |
-
'text': enriched_item('hey', {'primitive_signal': 4})
|
761 |
-
}, {
|
762 |
-
'text': enriched_item('CARS', {'primitive_signal': 5})
|
763 |
-
}]
|
764 |
-
}]
|
765 |
-
|
766 |
-
|
767 |
-
class TopKEmbedding(TextEmbeddingSignal):
|
768 |
-
"""A test embed function."""
|
769 |
-
name = 'topk_embedding'
|
770 |
-
|
771 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
772 |
-
"""Call the embedding function."""
|
773 |
-
for example in data:
|
774 |
-
example = cast(str, example)
|
775 |
-
emb_spans: list[Item] = []
|
776 |
-
for i, score in enumerate(example.split('_')):
|
777 |
-
start, end = i * 2, i * 2 + 1
|
778 |
-
vector = np.array([int(score)])
|
779 |
-
emb_spans.append(lilac_embedding(start, end, vector))
|
780 |
-
yield emb_spans
|
781 |
-
|
782 |
-
|
783 |
-
class TopKSignal(TextEmbeddingModelSignal):
|
784 |
-
"""Compute scores along a given concept for documents."""
|
785 |
-
name = 'topk_signal'
|
786 |
-
|
787 |
-
_query = np.array([1])
|
788 |
-
|
789 |
-
def fields(self) -> Field:
|
790 |
-
return field('float32')
|
791 |
-
|
792 |
-
@override
|
793 |
-
def vector_compute(self, keys: Iterable[VectorKey],
|
794 |
-
vector_store: VectorStore) -> Iterable[Optional[Item]]:
|
795 |
-
text_embeddings = vector_store.get(keys)
|
796 |
-
dot_products = text_embeddings.dot(self._query).reshape(-1)
|
797 |
-
return dot_products.tolist()
|
798 |
-
|
799 |
-
@override
|
800 |
-
def vector_compute_topk(
|
801 |
-
self,
|
802 |
-
topk: int,
|
803 |
-
vector_store: VectorStore,
|
804 |
-
keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]:
|
805 |
-
return vector_store.topk(self._query, topk, keys)
|
806 |
-
|
807 |
-
|
808 |
-
def test_sort_by_topk_embedding_udf(make_test_data: TestDataMaker) -> None:
|
809 |
-
dataset = make_test_data([{
|
810 |
-
UUID_COLUMN: '1',
|
811 |
-
'scores': '8_1',
|
812 |
-
}, {
|
813 |
-
UUID_COLUMN: '2',
|
814 |
-
'scores': '3_5'
|
815 |
-
}, {
|
816 |
-
UUID_COLUMN: '3',
|
817 |
-
'scores': '9_7'
|
818 |
-
}])
|
819 |
-
|
820 |
-
dataset.compute_signal(TopKEmbedding(), 'scores')
|
821 |
-
|
822 |
-
# Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
|
823 |
-
text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')
|
824 |
-
# Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`.
|
825 |
-
result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.DESC, limit=3)
|
826 |
-
assert list(result) == [{
|
827 |
-
UUID_COLUMN: '3',
|
828 |
-
'scores': enriched_item(
|
829 |
-
'9_7', {'topk_embedding': [lilac_embedding(0, 1, None),
|
830 |
-
lilac_embedding(2, 3, None)]}),
|
831 |
-
'udf': [9.0, 7.0]
|
832 |
-
}, {
|
833 |
-
UUID_COLUMN: '1',
|
834 |
-
'scores': enriched_item(
|
835 |
-
'8_1', {'topk_embedding': [lilac_embedding(0, 1, None),
|
836 |
-
lilac_embedding(2, 3, None)]}),
|
837 |
-
'udf': [8.0, 1.0]
|
838 |
-
}]
|
839 |
-
|
840 |
-
# Same but set limit to 4.
|
841 |
-
result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.DESC, limit=4)
|
842 |
-
assert list(result) == [{
|
843 |
-
UUID_COLUMN: '3',
|
844 |
-
'scores': enriched_item(
|
845 |
-
'9_7', {'topk_embedding': [lilac_embedding(0, 1, None),
|
846 |
-
lilac_embedding(2, 3, None)]}),
|
847 |
-
'udf': [9.0, 7.0]
|
848 |
-
}, {
|
849 |
-
UUID_COLUMN: '1',
|
850 |
-
'scores': enriched_item(
|
851 |
-
'8_1', {'topk_embedding': [lilac_embedding(0, 1, None),
|
852 |
-
lilac_embedding(2, 3, None)]}),
|
853 |
-
'udf': [8.0, 1.0]
|
854 |
-
}, {
|
855 |
-
UUID_COLUMN: '2',
|
856 |
-
'scores': enriched_item(
|
857 |
-
'3_5', {'topk_embedding': [lilac_embedding(0, 1, None),
|
858 |
-
lilac_embedding(2, 3, None)]}),
|
859 |
-
'udf': [3.0, 5.0]
|
860 |
-
}]
|
861 |
-
|
862 |
-
|
863 |
-
def test_sort_by_topk_udf_with_filter(make_test_data: TestDataMaker) -> None:
|
864 |
-
dataset = make_test_data([{
|
865 |
-
UUID_COLUMN: '1',
|
866 |
-
'scores': '8_1',
|
867 |
-
'active': True
|
868 |
-
}, {
|
869 |
-
UUID_COLUMN: '2',
|
870 |
-
'scores': '3_5',
|
871 |
-
'active': True
|
872 |
-
}, {
|
873 |
-
UUID_COLUMN: '3',
|
874 |
-
'scores': '9_7',
|
875 |
-
'active': False
|
876 |
-
}])
|
877 |
-
|
878 |
-
dataset.compute_signal(TopKEmbedding(), 'scores')
|
879 |
-
|
880 |
-
# Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
|
881 |
-
text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')
|
882 |
-
# Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`.
|
883 |
-
result = dataset.select_rows(['*', text_udf],
|
884 |
-
sort_by=['udf'],
|
885 |
-
filters=[('active', BinaryOp.EQUALS, True)],
|
886 |
-
sort_order=SortOrder.DESC,
|
887 |
-
limit=2)
|
888 |
-
# We make sure that '3' is not in the result, because it is not active, even though it has the
|
889 |
-
# highest topk score.
|
890 |
-
assert list(result) == [{
|
891 |
-
UUID_COLUMN: '1',
|
892 |
-
'active': True,
|
893 |
-
'scores': enriched_item(
|
894 |
-
'8_1', {'topk_embedding': [lilac_embedding(0, 1, None),
|
895 |
-
lilac_embedding(2, 3, None)]}),
|
896 |
-
'udf': [8.0, 1.0]
|
897 |
-
}, {
|
898 |
-
UUID_COLUMN: '2',
|
899 |
-
'active': True,
|
900 |
-
'scores': enriched_item(
|
901 |
-
'3_5', {'topk_embedding': [lilac_embedding(0, 1, None),
|
902 |
-
lilac_embedding(2, 3, None)]}),
|
903 |
-
'udf': [3.0, 5.0]
|
904 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_select_rows_udf_test.py
DELETED
@@ -1,404 +0,0 @@
|
|
1 |
-
"""Tests for dataset.select_rows(udf_col)."""
|
2 |
-
|
3 |
-
from typing import Iterable, Optional, cast
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import pytest
|
7 |
-
from typing_extensions import override
|
8 |
-
|
9 |
-
from ..embeddings.vector_store import VectorStore
|
10 |
-
from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, VectorKey, field
|
11 |
-
from ..signals.signal import (
|
12 |
-
TextEmbeddingModelSignal,
|
13 |
-
TextEmbeddingSignal,
|
14 |
-
TextSignal,
|
15 |
-
TextSplitterSignal,
|
16 |
-
clear_signal_registry,
|
17 |
-
register_signal,
|
18 |
-
)
|
19 |
-
from .dataset import BinaryFilterTuple, BinaryOp, Column, val
|
20 |
-
from .dataset_test_utils import TestDataMaker, enriched_item
|
21 |
-
from .dataset_utils import lilac_embedding, lilac_span
|
22 |
-
|
23 |
-
EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
|
24 |
-
('hello2.', [1.0, 1.0, 0.0]),
|
25 |
-
('hello world.', [1.0, 1.0, 1.0]),
|
26 |
-
('hello world2.', [2.0, 1.0, 1.0])]
|
27 |
-
|
28 |
-
STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
|
29 |
-
|
30 |
-
|
31 |
-
class TestEmbedding(TextEmbeddingSignal):
|
32 |
-
"""A test embed function."""
|
33 |
-
name = 'test_embedding'
|
34 |
-
|
35 |
-
@override
|
36 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
37 |
-
"""Call the embedding function."""
|
38 |
-
for example in data:
|
39 |
-
yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
|
40 |
-
|
41 |
-
|
42 |
-
class LengthSignal(TextSignal):
|
43 |
-
name = 'length_signal'
|
44 |
-
|
45 |
-
_call_count: int = 0
|
46 |
-
|
47 |
-
def fields(self) -> Field:
|
48 |
-
return field('int32')
|
49 |
-
|
50 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
51 |
-
for text_content in data:
|
52 |
-
self._call_count += 1
|
53 |
-
yield len(text_content)
|
54 |
-
|
55 |
-
|
56 |
-
class TestSignal(TextSignal):
|
57 |
-
name = 'test_signal'
|
58 |
-
|
59 |
-
@override
|
60 |
-
def fields(self) -> Field:
|
61 |
-
return field(fields={'len': 'int32', 'flen': 'float32'})
|
62 |
-
|
63 |
-
@override
|
64 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
65 |
-
return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data]
|
66 |
-
|
67 |
-
|
68 |
-
class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
|
69 |
-
"""Sums the embeddings to return a single floating point value."""
|
70 |
-
name = 'test_embedding_sum'
|
71 |
-
|
72 |
-
@override
|
73 |
-
def fields(self) -> Field:
|
74 |
-
return field('float32')
|
75 |
-
|
76 |
-
@override
|
77 |
-
def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
|
78 |
-
# The signal just sums the values of the embedding.
|
79 |
-
embedding_sums = vector_store.get(keys).sum(axis=1)
|
80 |
-
for embedding_sum in embedding_sums.tolist():
|
81 |
-
yield embedding_sum
|
82 |
-
|
83 |
-
|
84 |
-
class ComputedKeySignal(TextSignal):
|
85 |
-
name = 'computed_key'
|
86 |
-
|
87 |
-
@override
|
88 |
-
def fields(self) -> Field:
|
89 |
-
return field('int64')
|
90 |
-
|
91 |
-
@override
|
92 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
93 |
-
for text in data:
|
94 |
-
yield 1
|
95 |
-
|
96 |
-
def key(self, is_computed_signal: Optional[bool] = False) -> str:
|
97 |
-
return f'key_{is_computed_signal}'
|
98 |
-
|
99 |
-
|
100 |
-
@pytest.fixture(scope='module', autouse=True)
|
101 |
-
def setup_teardown() -> Iterable[None]:
|
102 |
-
# Setup.
|
103 |
-
register_signal(LengthSignal)
|
104 |
-
register_signal(TestSplitter)
|
105 |
-
register_signal(TestEmbedding)
|
106 |
-
register_signal(TestSignal)
|
107 |
-
register_signal(TestEmbeddingSumSignal)
|
108 |
-
register_signal(ComputedKeySignal)
|
109 |
-
|
110 |
-
# Unit test runs.
|
111 |
-
yield
|
112 |
-
# Teardown.
|
113 |
-
clear_signal_registry()
|
114 |
-
|
115 |
-
|
116 |
-
def test_udf(make_test_data: TestDataMaker) -> None:
|
117 |
-
dataset = make_test_data([{
|
118 |
-
UUID_COLUMN: '1',
|
119 |
-
'text': 'hello'
|
120 |
-
}, {
|
121 |
-
UUID_COLUMN: '2',
|
122 |
-
'text': 'everybody'
|
123 |
-
}])
|
124 |
-
|
125 |
-
signal_col = Column('text', signal_udf=TestSignal())
|
126 |
-
result = dataset.select_rows(['text', signal_col])
|
127 |
-
|
128 |
-
assert list(result) == [{
|
129 |
-
UUID_COLUMN: '1',
|
130 |
-
'text': 'hello',
|
131 |
-
'test_signal(text)': {
|
132 |
-
'len': 5,
|
133 |
-
'flen': 5.0
|
134 |
-
}
|
135 |
-
}, {
|
136 |
-
UUID_COLUMN: '2',
|
137 |
-
'text': 'everybody',
|
138 |
-
'test_signal(text)': {
|
139 |
-
'len': 9,
|
140 |
-
'flen': 9.0
|
141 |
-
}
|
142 |
-
}]
|
143 |
-
|
144 |
-
|
145 |
-
def test_udf_with_filters(make_test_data: TestDataMaker) -> None:
|
146 |
-
dataset = make_test_data([{
|
147 |
-
UUID_COLUMN: '1',
|
148 |
-
'text': 'hello'
|
149 |
-
}, {
|
150 |
-
UUID_COLUMN: '2',
|
151 |
-
'text': 'everybody'
|
152 |
-
}])
|
153 |
-
|
154 |
-
signal_col = Column('text', signal_udf=TestSignal())
|
155 |
-
# Filter by source feature.
|
156 |
-
filters: list[BinaryFilterTuple] = [('text', BinaryOp.EQUALS, 'everybody')]
|
157 |
-
result = dataset.select_rows(['text', signal_col], filters=filters)
|
158 |
-
assert list(result) == [{
|
159 |
-
UUID_COLUMN: '2',
|
160 |
-
'text': 'everybody',
|
161 |
-
'test_signal(text)': {
|
162 |
-
'len': 9,
|
163 |
-
'flen': 9.0
|
164 |
-
}
|
165 |
-
}]
|
166 |
-
|
167 |
-
|
168 |
-
def test_udf_with_uuid_filter(make_test_data: TestDataMaker) -> None:
|
169 |
-
|
170 |
-
dataset = make_test_data([{
|
171 |
-
UUID_COLUMN: '1',
|
172 |
-
'text': 'hello'
|
173 |
-
}, {
|
174 |
-
UUID_COLUMN: '2',
|
175 |
-
'text': 'everybody'
|
176 |
-
}])
|
177 |
-
|
178 |
-
# Filter by a specific UUID.
|
179 |
-
filters: list[BinaryFilterTuple] = [(UUID_COLUMN, BinaryOp.EQUALS, '1')]
|
180 |
-
udf_col = Column('text', signal_udf=LengthSignal())
|
181 |
-
result = dataset.select_rows(['text', udf_col], filters=filters)
|
182 |
-
assert list(result) == [{UUID_COLUMN: '1', 'text': 'hello', 'length_signal(text)': 5}]
|
183 |
-
assert cast(LengthSignal, udf_col.signal_udf)._call_count == 1
|
184 |
-
|
185 |
-
filters = [(UUID_COLUMN, BinaryOp.EQUALS, '2')]
|
186 |
-
result = dataset.select_rows(['text', udf_col], filters=filters)
|
187 |
-
assert list(result) == [{UUID_COLUMN: '2', 'text': 'everybody', 'length_signal(text)': 9}]
|
188 |
-
assert cast(LengthSignal, udf_col.signal_udf)._call_count == 1 + 1
|
189 |
-
|
190 |
-
# No filters.
|
191 |
-
result = dataset.select_rows(['text', udf_col])
|
192 |
-
assert list(result) == [{
|
193 |
-
UUID_COLUMN: '1',
|
194 |
-
'text': 'hello',
|
195 |
-
'length_signal(text)': 5
|
196 |
-
}, {
|
197 |
-
UUID_COLUMN: '2',
|
198 |
-
'text': 'everybody',
|
199 |
-
'length_signal(text)': 9
|
200 |
-
}]
|
201 |
-
assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2 + 2
|
202 |
-
|
203 |
-
|
204 |
-
def test_udf_with_uuid_filter_repeated(make_test_data: TestDataMaker) -> None:
|
205 |
-
|
206 |
-
dataset = make_test_data([{
|
207 |
-
UUID_COLUMN: '1',
|
208 |
-
'text': ['hello', 'hi']
|
209 |
-
}, {
|
210 |
-
UUID_COLUMN: '2',
|
211 |
-
'text': ['everybody', 'bye', 'test']
|
212 |
-
}])
|
213 |
-
|
214 |
-
# Filter by a specific UUID.
|
215 |
-
filters: list[BinaryFilterTuple] = [(UUID_COLUMN, BinaryOp.EQUALS, '1')]
|
216 |
-
udf_col = Column(('text', '*'), signal_udf=LengthSignal())
|
217 |
-
result = dataset.select_rows(['text', udf_col], filters=filters)
|
218 |
-
assert list(result) == [{
|
219 |
-
UUID_COLUMN: '1',
|
220 |
-
'text': ['hello', 'hi'],
|
221 |
-
'length_signal(text)': [5, 2]
|
222 |
-
}]
|
223 |
-
assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2
|
224 |
-
|
225 |
-
# Filter by a specific UUID.
|
226 |
-
filters = [(UUID_COLUMN, BinaryOp.EQUALS, '2')]
|
227 |
-
result = dataset.select_rows(['text', udf_col], filters=filters)
|
228 |
-
assert list(result) == [{
|
229 |
-
UUID_COLUMN: '2',
|
230 |
-
'text': ['everybody', 'bye', 'test'],
|
231 |
-
'length_signal(text)': [9, 3, 4]
|
232 |
-
}]
|
233 |
-
assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2 + 3
|
234 |
-
|
235 |
-
|
236 |
-
def test_udf_deeply_nested(make_test_data: TestDataMaker) -> None:
|
237 |
-
dataset = make_test_data([{
|
238 |
-
UUID_COLUMN: '1',
|
239 |
-
'text': [['hello'], ['hi', 'bye']]
|
240 |
-
}, {
|
241 |
-
UUID_COLUMN: '2',
|
242 |
-
'text': [['everybody', 'bye'], ['test']]
|
243 |
-
}])
|
244 |
-
|
245 |
-
udf_col = Column(('text', '*', '*'), signal_udf=LengthSignal())
|
246 |
-
result = dataset.select_rows([udf_col])
|
247 |
-
assert list(result) == [{
|
248 |
-
UUID_COLUMN: '1',
|
249 |
-
'length_signal(text.*)': [[5], [2, 3]]
|
250 |
-
}, {
|
251 |
-
UUID_COLUMN: '2',
|
252 |
-
'length_signal(text.*)': [[9, 3], [4]]
|
253 |
-
}]
|
254 |
-
assert cast(LengthSignal, udf_col.signal_udf)._call_count == 6
|
255 |
-
|
256 |
-
|
257 |
-
def test_udf_with_embedding(make_test_data: TestDataMaker) -> None:
|
258 |
-
dataset = make_test_data([{
|
259 |
-
UUID_COLUMN: '1',
|
260 |
-
'text': 'hello.',
|
261 |
-
}, {
|
262 |
-
UUID_COLUMN: '2',
|
263 |
-
'text': 'hello2.',
|
264 |
-
}])
|
265 |
-
|
266 |
-
dataset.compute_signal(TestEmbedding(), 'text')
|
267 |
-
|
268 |
-
signal_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
|
269 |
-
result = dataset.select_rows([val('text'), signal_col])
|
270 |
-
|
271 |
-
expected_result: list[Item] = [{
|
272 |
-
UUID_COLUMN: '1',
|
273 |
-
f'text.{VALUE_KEY}': 'hello.',
|
274 |
-
'test_embedding_sum(text.test_embedding.*.embedding)': [1.0]
|
275 |
-
}, {
|
276 |
-
UUID_COLUMN: '2',
|
277 |
-
f'text.{VALUE_KEY}': 'hello2.',
|
278 |
-
'test_embedding_sum(text.test_embedding.*.embedding)': [2.0]
|
279 |
-
}]
|
280 |
-
assert list(result) == expected_result
|
281 |
-
|
282 |
-
# Select rows with alias.
|
283 |
-
signal_col = Column(
|
284 |
-
'text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'), alias='emb_sum')
|
285 |
-
result = dataset.select_rows([val('text'), signal_col])
|
286 |
-
expected_result = [{
|
287 |
-
UUID_COLUMN: '1',
|
288 |
-
f'text.{VALUE_KEY}': 'hello.',
|
289 |
-
'emb_sum': [1.0]
|
290 |
-
}, {
|
291 |
-
UUID_COLUMN: '2',
|
292 |
-
f'text.{VALUE_KEY}': 'hello2.',
|
293 |
-
'emb_sum': [2.0]
|
294 |
-
}]
|
295 |
-
assert list(result) == expected_result
|
296 |
-
|
297 |
-
|
298 |
-
def test_udf_with_nested_embedding(make_test_data: TestDataMaker) -> None:
|
299 |
-
dataset = make_test_data([{
|
300 |
-
UUID_COLUMN: '1',
|
301 |
-
'text': ['hello.', 'hello world.'],
|
302 |
-
}, {
|
303 |
-
UUID_COLUMN: '2',
|
304 |
-
'text': ['hello world2.', 'hello2.'],
|
305 |
-
}])
|
306 |
-
|
307 |
-
dataset.compute_signal(TestEmbedding(), ('text', '*'))
|
308 |
-
|
309 |
-
signal_col = Column(('text', '*'), signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
|
310 |
-
result = dataset.select_rows([val(('text', '*')), signal_col])
|
311 |
-
expected_result = [{
|
312 |
-
UUID_COLUMN: '1',
|
313 |
-
f'text.*.{VALUE_KEY}': ['hello.', 'hello world.'],
|
314 |
-
'test_embedding_sum(text.*.test_embedding.*.embedding)': [[1.0], [3.0]]
|
315 |
-
}, {
|
316 |
-
UUID_COLUMN: '2',
|
317 |
-
f'text.*.{VALUE_KEY}': ['hello world2.', 'hello2.'],
|
318 |
-
'test_embedding_sum(text.*.test_embedding.*.embedding)': [[4.0], [2.0]]
|
319 |
-
}]
|
320 |
-
assert list(result) == expected_result
|
321 |
-
|
322 |
-
|
323 |
-
def test_udf_throws_without_precomputing(make_test_data: TestDataMaker) -> None:
|
324 |
-
dataset = make_test_data([{
|
325 |
-
UUID_COLUMN: '1',
|
326 |
-
'text': 'hello.',
|
327 |
-
}, {
|
328 |
-
UUID_COLUMN: '2',
|
329 |
-
'text': 'hello2.',
|
330 |
-
}])
|
331 |
-
|
332 |
-
# Embedding is not precomputed, yet we ask for the embedding.
|
333 |
-
|
334 |
-
signal_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
|
335 |
-
|
336 |
-
with pytest.raises(ValueError, match='Embedding signal "test_embedding" is not computed'):
|
337 |
-
dataset.select_rows([val('text'), signal_col])
|
338 |
-
|
339 |
-
|
340 |
-
class TestSplitter(TextSplitterSignal):
|
341 |
-
"""Split documents into sentence by splitting on period."""
|
342 |
-
name = 'test_splitter'
|
343 |
-
|
344 |
-
@override
|
345 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
346 |
-
for text in data:
|
347 |
-
if not isinstance(text, str):
|
348 |
-
raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
|
349 |
-
result: list[Item] = []
|
350 |
-
for sentence in text.split('.'):
|
351 |
-
start = text.index(sentence)
|
352 |
-
end = start + len(sentence)
|
353 |
-
result.append(lilac_span(start, end))
|
354 |
-
yield result
|
355 |
-
|
356 |
-
|
357 |
-
def test_udf_after_precomputed_split(make_test_data: TestDataMaker) -> None:
|
358 |
-
dataset = make_test_data([{
|
359 |
-
UUID_COLUMN: '1',
|
360 |
-
'text': 'sentence 1. sentence 2 is longer',
|
361 |
-
}, {
|
362 |
-
UUID_COLUMN: '2',
|
363 |
-
'text': 'sentence 1 is longer. sent2 is short',
|
364 |
-
}])
|
365 |
-
dataset.compute_signal(TestSplitter(), 'text')
|
366 |
-
udf = Column('text', signal_udf=LengthSignal())
|
367 |
-
result = dataset.select_rows(['*', udf], combine_columns=True)
|
368 |
-
assert list(result) == [{
|
369 |
-
UUID_COLUMN: '1',
|
370 |
-
'text': enriched_item('sentence 1. sentence 2 is longer', {
|
371 |
-
'length_signal': 32,
|
372 |
-
'test_splitter': [lilac_span(0, 10), lilac_span(11, 32)]
|
373 |
-
})
|
374 |
-
}, {
|
375 |
-
UUID_COLUMN: '2',
|
376 |
-
'text': enriched_item('sentence 1 is longer. sent2 is short', {
|
377 |
-
'length_signal': 36,
|
378 |
-
'test_splitter': [lilac_span(0, 20), lilac_span(21, 36)]
|
379 |
-
})
|
380 |
-
}]
|
381 |
-
|
382 |
-
|
383 |
-
def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None:
|
384 |
-
dataset = make_test_data([{
|
385 |
-
UUID_COLUMN: '1',
|
386 |
-
'text': 'hello.',
|
387 |
-
}, {
|
388 |
-
UUID_COLUMN: '2',
|
389 |
-
'text': 'hello2.',
|
390 |
-
}])
|
391 |
-
|
392 |
-
signal_col = Column('text', signal_udf=ComputedKeySignal())
|
393 |
-
# Filter by source feature.
|
394 |
-
filters: list[BinaryFilterTuple] = [('text', BinaryOp.EQUALS, 'everybody')]
|
395 |
-
result = dataset.select_rows(['text', signal_col])
|
396 |
-
assert list(result) == [{
|
397 |
-
UUID_COLUMN: '1',
|
398 |
-
'text': 'hello.',
|
399 |
-
'key_False(text)': 1
|
400 |
-
}, {
|
401 |
-
UUID_COLUMN: '2',
|
402 |
-
'text': 'hello2.',
|
403 |
-
'key_False(text)': 1
|
404 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_stats_test.py
DELETED
@@ -1,125 +0,0 @@
|
|
1 |
-
"""Tests for dataset.stats()."""
|
2 |
-
|
3 |
-
from typing import Any, cast
|
4 |
-
|
5 |
-
import pytest
|
6 |
-
from pytest_mock import MockerFixture
|
7 |
-
|
8 |
-
from ..schema import UUID_COLUMN, Item, schema
|
9 |
-
from . import dataset_duckdb
|
10 |
-
from .dataset import StatsResult
|
11 |
-
from .dataset_test_utils import TestDataMaker
|
12 |
-
|
13 |
-
SIMPLE_ITEMS: list[Item] = [{
|
14 |
-
UUID_COLUMN: '1',
|
15 |
-
'str': 'a',
|
16 |
-
'int': 1,
|
17 |
-
'bool': False,
|
18 |
-
'float': 3.0,
|
19 |
-
}, {
|
20 |
-
UUID_COLUMN: '2',
|
21 |
-
'str': 'b',
|
22 |
-
'int': 2,
|
23 |
-
'bool': True,
|
24 |
-
'float': 2.0
|
25 |
-
}, {
|
26 |
-
UUID_COLUMN: '3',
|
27 |
-
'str': 'b',
|
28 |
-
'int': 2,
|
29 |
-
'bool': True,
|
30 |
-
'float': 1.0
|
31 |
-
}, {
|
32 |
-
UUID_COLUMN: '4',
|
33 |
-
'float': float('nan')
|
34 |
-
}]
|
35 |
-
|
36 |
-
|
37 |
-
def test_simple_stats(make_test_data: TestDataMaker) -> None:
|
38 |
-
dataset = make_test_data(SIMPLE_ITEMS)
|
39 |
-
|
40 |
-
result = dataset.stats(leaf_path='str')
|
41 |
-
assert result == StatsResult(
|
42 |
-
path=('str',), total_count=3, approx_count_distinct=2, avg_text_length=1)
|
43 |
-
|
44 |
-
result = dataset.stats(leaf_path='float')
|
45 |
-
assert result == StatsResult(
|
46 |
-
path=('float',), total_count=4, approx_count_distinct=4, min_val=1.0, max_val=3.0)
|
47 |
-
|
48 |
-
result = dataset.stats(leaf_path='bool')
|
49 |
-
assert result == StatsResult(path=('bool',), total_count=3, approx_count_distinct=2)
|
50 |
-
|
51 |
-
result = dataset.stats(leaf_path='int')
|
52 |
-
assert result == StatsResult(
|
53 |
-
path=('int',), total_count=3, approx_count_distinct=2, min_val=1, max_val=2)
|
54 |
-
|
55 |
-
|
56 |
-
def test_nested_stats(make_test_data: TestDataMaker) -> None:
|
57 |
-
nested_items: list[Item] = [
|
58 |
-
{
|
59 |
-
'name': 'Name1',
|
60 |
-
'addresses': [{
|
61 |
-
'zips': [5, 8]
|
62 |
-
}]
|
63 |
-
},
|
64 |
-
{
|
65 |
-
'name': 'Name2',
|
66 |
-
'addresses': [{
|
67 |
-
'zips': [3]
|
68 |
-
}, {
|
69 |
-
'zips': [11, 8]
|
70 |
-
}]
|
71 |
-
},
|
72 |
-
{
|
73 |
-
'name': 'Name2',
|
74 |
-
'addresses': []
|
75 |
-
}, # No addresses.
|
76 |
-
{
|
77 |
-
'name': 'Name2',
|
78 |
-
'addresses': [{
|
79 |
-
'zips': []
|
80 |
-
}]
|
81 |
-
} # No zips in the first address.
|
82 |
-
]
|
83 |
-
nested_schema = schema({
|
84 |
-
UUID_COLUMN: 'string',
|
85 |
-
'name': 'string',
|
86 |
-
'addresses': [{
|
87 |
-
'zips': ['int32']
|
88 |
-
}]
|
89 |
-
})
|
90 |
-
dataset = make_test_data(nested_items, schema=nested_schema)
|
91 |
-
|
92 |
-
result = dataset.stats(leaf_path='name')
|
93 |
-
assert result == StatsResult(
|
94 |
-
path=('name',), total_count=4, approx_count_distinct=2, avg_text_length=5)
|
95 |
-
|
96 |
-
result = dataset.stats(leaf_path='addresses.*.zips.*')
|
97 |
-
assert result == StatsResult(
|
98 |
-
path=('addresses', '*', 'zips', '*'),
|
99 |
-
total_count=5,
|
100 |
-
approx_count_distinct=4,
|
101 |
-
min_val=3,
|
102 |
-
max_val=11)
|
103 |
-
|
104 |
-
|
105 |
-
def test_stats_approximation(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
|
106 |
-
sample_size = 5
|
107 |
-
mocker.patch(f'{dataset_duckdb.__name__}.SAMPLE_SIZE_DISTINCT_COUNT', sample_size)
|
108 |
-
|
109 |
-
nested_items: list[Item] = [{'feature': str(i)} for i in range(sample_size * 10)]
|
110 |
-
nested_schema = schema({UUID_COLUMN: 'string', 'feature': 'string'})
|
111 |
-
dataset = make_test_data(nested_items, schema=nested_schema)
|
112 |
-
|
113 |
-
result = dataset.stats(leaf_path='feature')
|
114 |
-
assert result == StatsResult(
|
115 |
-
path=('feature',), total_count=50, approx_count_distinct=50, avg_text_length=1)
|
116 |
-
|
117 |
-
|
118 |
-
def test_error_handling(make_test_data: TestDataMaker) -> None:
|
119 |
-
dataset = make_test_data(SIMPLE_ITEMS)
|
120 |
-
|
121 |
-
with pytest.raises(ValueError, match='leaf_path must be provided'):
|
122 |
-
dataset.stats(cast(Any, None))
|
123 |
-
|
124 |
-
with pytest.raises(ValueError, match='Leaf "\\(\'unknown\',\\)" not found in dataset'):
|
125 |
-
dataset.stats(leaf_path='unknown')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_test.py
DELETED
@@ -1,860 +0,0 @@
|
|
1 |
-
"""Implementation-agnostic tests of the Dataset DB API."""
|
2 |
-
|
3 |
-
from typing import Iterable, Optional, cast
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import pytest
|
7 |
-
from typing_extensions import override
|
8 |
-
|
9 |
-
from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, field, schema
|
10 |
-
from ..signals.signal import TextEmbeddingSignal, TextSignal, clear_signal_registry, register_signal
|
11 |
-
from .dataset import Column, DatasetManifest, val
|
12 |
-
from .dataset_test_utils import TEST_DATASET_NAME, TEST_NAMESPACE, TestDataMaker, enriched_item
|
13 |
-
from .dataset_utils import lilac_embedding
|
14 |
-
|
15 |
-
SIMPLE_ITEMS: list[Item] = [{
|
16 |
-
UUID_COLUMN: '1',
|
17 |
-
'str': 'a',
|
18 |
-
'int': 1,
|
19 |
-
'bool': False,
|
20 |
-
'float': 3.0
|
21 |
-
}, {
|
22 |
-
UUID_COLUMN: '2',
|
23 |
-
'str': 'b',
|
24 |
-
'int': 2,
|
25 |
-
'bool': True,
|
26 |
-
'float': 2.0
|
27 |
-
}, {
|
28 |
-
UUID_COLUMN: '3',
|
29 |
-
'str': 'b',
|
30 |
-
'int': 2,
|
31 |
-
'bool': True,
|
32 |
-
'float': 1.0
|
33 |
-
}]
|
34 |
-
|
35 |
-
EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
|
36 |
-
('hello2.', [1.0, 1.0, 0.0]),
|
37 |
-
('hello world.', [1.0, 1.0, 1.0]),
|
38 |
-
('hello world2.', [2.0, 1.0, 1.0])]
|
39 |
-
|
40 |
-
STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
|
41 |
-
|
42 |
-
|
43 |
-
class TestEmbedding(TextEmbeddingSignal):
|
44 |
-
"""A test embed function."""
|
45 |
-
name = 'test_embedding'
|
46 |
-
|
47 |
-
@override
|
48 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
|
49 |
-
"""Call the embedding function."""
|
50 |
-
for example in data:
|
51 |
-
yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
|
52 |
-
|
53 |
-
|
54 |
-
class LengthSignal(TextSignal):
|
55 |
-
name = 'length_signal'
|
56 |
-
|
57 |
-
_call_count: int = 0
|
58 |
-
|
59 |
-
def fields(self) -> Field:
|
60 |
-
return field('int32')
|
61 |
-
|
62 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
63 |
-
for text_content in data:
|
64 |
-
self._call_count += 1
|
65 |
-
yield len(text_content)
|
66 |
-
|
67 |
-
|
68 |
-
class TestSignal(TextSignal):
|
69 |
-
name = 'test_signal'
|
70 |
-
|
71 |
-
@override
|
72 |
-
def fields(self) -> Field:
|
73 |
-
return field(fields={'len': 'int32', 'flen': 'float32'})
|
74 |
-
|
75 |
-
@override
|
76 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
77 |
-
return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data]
|
78 |
-
|
79 |
-
|
80 |
-
@pytest.fixture(scope='module', autouse=True)
|
81 |
-
def setup_teardown() -> Iterable[None]:
|
82 |
-
# Setup.
|
83 |
-
register_signal(TestSignal)
|
84 |
-
register_signal(LengthSignal)
|
85 |
-
register_signal(SignalWithQuoteInIt)
|
86 |
-
register_signal(SignalWithDoubleQuoteInIt)
|
87 |
-
|
88 |
-
# Unit test runs.
|
89 |
-
yield
|
90 |
-
|
91 |
-
# Teardown.
|
92 |
-
clear_signal_registry()
|
93 |
-
|
94 |
-
|
95 |
-
def test_select_all_columns(make_test_data: TestDataMaker) -> None:
|
96 |
-
dataset = make_test_data(SIMPLE_ITEMS)
|
97 |
-
|
98 |
-
result = dataset.select_rows()
|
99 |
-
assert list(result) == SIMPLE_ITEMS
|
100 |
-
|
101 |
-
|
102 |
-
def test_select_subcols_with_dot_seperator(make_test_data: TestDataMaker) -> None:
|
103 |
-
items: list[Item] = [{
|
104 |
-
UUID_COLUMN: '1',
|
105 |
-
'people': [{
|
106 |
-
'name': 'A',
|
107 |
-
'address': {
|
108 |
-
'zip': 1
|
109 |
-
}
|
110 |
-
}, {
|
111 |
-
'name': 'B',
|
112 |
-
'address': {
|
113 |
-
'zip': 2
|
114 |
-
}
|
115 |
-
}]
|
116 |
-
}, {
|
117 |
-
UUID_COLUMN: '2',
|
118 |
-
'people': [{
|
119 |
-
'name': 'C',
|
120 |
-
'address': {
|
121 |
-
'zip': 3
|
122 |
-
}
|
123 |
-
}]
|
124 |
-
}]
|
125 |
-
dataset = make_test_data(items)
|
126 |
-
|
127 |
-
result = dataset.select_rows(['people.*.name', 'people.*.address.zip'])
|
128 |
-
assert list(result) == [{
|
129 |
-
UUID_COLUMN: '1',
|
130 |
-
'people.*.name': ['A', 'B'],
|
131 |
-
'people.*.address.zip': [1, 2]
|
132 |
-
}, {
|
133 |
-
UUID_COLUMN: '2',
|
134 |
-
'people.*.name': ['C'],
|
135 |
-
'people.*.address.zip': [3]
|
136 |
-
}]
|
137 |
-
|
138 |
-
result = dataset.select_rows(['people.*.address.zip'], combine_columns=True)
|
139 |
-
assert list(result) == [{
|
140 |
-
UUID_COLUMN: '1',
|
141 |
-
'people': [{
|
142 |
-
'address': {
|
143 |
-
'zip': 1
|
144 |
-
}
|
145 |
-
}, {
|
146 |
-
'address': {
|
147 |
-
'zip': 2
|
148 |
-
}
|
149 |
-
}]
|
150 |
-
}, {
|
151 |
-
UUID_COLUMN: '2',
|
152 |
-
'people': [{
|
153 |
-
'address': {
|
154 |
-
'zip': 3
|
155 |
-
}
|
156 |
-
}]
|
157 |
-
}]
|
158 |
-
|
159 |
-
result = dataset.select_rows(['people'])
|
160 |
-
assert list(result) == items
|
161 |
-
|
162 |
-
|
163 |
-
def test_select_subcols_with_escaped_dot(make_test_data: TestDataMaker) -> None:
|
164 |
-
items: list[Item] = [{
|
165 |
-
UUID_COLUMN: '1',
|
166 |
-
'people.new': [{
|
167 |
-
'name': 'A'
|
168 |
-
}, {
|
169 |
-
'name': 'B'
|
170 |
-
}]
|
171 |
-
}, {
|
172 |
-
UUID_COLUMN: '2',
|
173 |
-
'people.new': [{
|
174 |
-
'name': 'C'
|
175 |
-
}]
|
176 |
-
}]
|
177 |
-
dataset = make_test_data(items)
|
178 |
-
|
179 |
-
result = dataset.select_rows(['"people.new".*.name'])
|
180 |
-
assert list(result) == [{
|
181 |
-
UUID_COLUMN: '1',
|
182 |
-
'people.new.*.name': ['A', 'B'],
|
183 |
-
}, {
|
184 |
-
UUID_COLUMN: '2',
|
185 |
-
'people.new.*.name': ['C'],
|
186 |
-
}]
|
187 |
-
|
188 |
-
# Escape name even though it does not need to be.
|
189 |
-
result = dataset.select_rows(['"people.new".*."name"'])
|
190 |
-
assert list(result) == [{
|
191 |
-
UUID_COLUMN: '1',
|
192 |
-
'people.new.*.name': ['A', 'B'],
|
193 |
-
}, {
|
194 |
-
UUID_COLUMN: '2',
|
195 |
-
'people.new.*.name': ['C'],
|
196 |
-
}]
|
197 |
-
|
198 |
-
|
199 |
-
def test_select_star(make_test_data: TestDataMaker) -> None:
|
200 |
-
items: list[Item] = [{
|
201 |
-
UUID_COLUMN: '1',
|
202 |
-
'name': 'A',
|
203 |
-
'info': {
|
204 |
-
'age': 40
|
205 |
-
}
|
206 |
-
}, {
|
207 |
-
UUID_COLUMN: '2',
|
208 |
-
'name': 'B',
|
209 |
-
'info': {
|
210 |
-
'age': 42
|
211 |
-
}
|
212 |
-
}]
|
213 |
-
dataset = make_test_data(items)
|
214 |
-
|
215 |
-
# Select *.
|
216 |
-
result = dataset.select_rows(['*'])
|
217 |
-
assert list(result) == items
|
218 |
-
|
219 |
-
# Select (*,).
|
220 |
-
result = dataset.select_rows([('*',)])
|
221 |
-
assert list(result) == items
|
222 |
-
|
223 |
-
# Select *, plus a redundant `info` column.
|
224 |
-
result = dataset.select_rows(['*', 'info'])
|
225 |
-
assert list(result) == [{
|
226 |
-
UUID_COLUMN: '1',
|
227 |
-
'name': 'A',
|
228 |
-
'info': {
|
229 |
-
'age': 40
|
230 |
-
},
|
231 |
-
'info_2': {
|
232 |
-
'age': 40
|
233 |
-
},
|
234 |
-
}, {
|
235 |
-
UUID_COLUMN: '2',
|
236 |
-
'name': 'B',
|
237 |
-
'info': {
|
238 |
-
'age': 42
|
239 |
-
},
|
240 |
-
'info_2': {
|
241 |
-
'age': 42
|
242 |
-
},
|
243 |
-
}]
|
244 |
-
|
245 |
-
# Select * plus an inner `info.age` column.
|
246 |
-
result = dataset.select_rows(['*', ('info', 'age')])
|
247 |
-
assert list(result) == [{
|
248 |
-
UUID_COLUMN: '1',
|
249 |
-
'name': 'A',
|
250 |
-
'info': {
|
251 |
-
'age': 40
|
252 |
-
},
|
253 |
-
'info.age': 40
|
254 |
-
}, {
|
255 |
-
UUID_COLUMN: '2',
|
256 |
-
'name': 'B',
|
257 |
-
'info': {
|
258 |
-
'age': 42
|
259 |
-
},
|
260 |
-
'info.age': 42
|
261 |
-
}]
|
262 |
-
|
263 |
-
|
264 |
-
def test_select_star_with_combine_cols(make_test_data: TestDataMaker) -> None:
|
265 |
-
items: list[Item] = [{
|
266 |
-
UUID_COLUMN: '1',
|
267 |
-
'name': 'A',
|
268 |
-
'info': {
|
269 |
-
'age': 40
|
270 |
-
}
|
271 |
-
}, {
|
272 |
-
UUID_COLUMN: '2',
|
273 |
-
'name': 'B',
|
274 |
-
'info': {
|
275 |
-
'age': 42
|
276 |
-
}
|
277 |
-
}]
|
278 |
-
dataset = make_test_data(items)
|
279 |
-
|
280 |
-
# Select *.
|
281 |
-
result = dataset.select_rows(['*'], combine_columns=True)
|
282 |
-
assert list(result) == items
|
283 |
-
|
284 |
-
# Select *, plus a redundant `info` column.
|
285 |
-
result = dataset.select_rows(['*', 'info'], combine_columns=True)
|
286 |
-
assert list(result) == items
|
287 |
-
|
288 |
-
# Select * plus an inner `info.age` column.
|
289 |
-
result = dataset.select_rows(['*', ('info', 'age')], combine_columns=True)
|
290 |
-
assert list(result) == items
|
291 |
-
|
292 |
-
# Select *, plus redundant `name`, plus a udf.
|
293 |
-
udf = Column('name', signal_udf=TestSignal())
|
294 |
-
result = dataset.select_rows(['*', 'name', udf], combine_columns=True)
|
295 |
-
|
296 |
-
assert list(result) == [{
|
297 |
-
UUID_COLUMN: '1',
|
298 |
-
'name': enriched_item('A', {'test_signal': {
|
299 |
-
'len': 1,
|
300 |
-
'flen': 1.0
|
301 |
-
}}),
|
302 |
-
'info': {
|
303 |
-
'age': 40
|
304 |
-
}
|
305 |
-
}, {
|
306 |
-
UUID_COLUMN: '2',
|
307 |
-
'name': enriched_item('B', {'test_signal': {
|
308 |
-
'len': 1,
|
309 |
-
'flen': 1.0
|
310 |
-
}}),
|
311 |
-
'info': {
|
312 |
-
'age': 42
|
313 |
-
}
|
314 |
-
}]
|
315 |
-
|
316 |
-
|
317 |
-
def test_select_ids(make_test_data: TestDataMaker) -> None:
|
318 |
-
dataset = make_test_data(SIMPLE_ITEMS)
|
319 |
-
|
320 |
-
result = dataset.select_rows([UUID_COLUMN])
|
321 |
-
|
322 |
-
assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}, {UUID_COLUMN: '3'}]
|
323 |
-
|
324 |
-
|
325 |
-
def test_select_ids_with_limit_and_offset(make_test_data: TestDataMaker) -> None:
|
326 |
-
items: list[Item] = [{UUID_COLUMN: str(i)} for i in range(10, 20)]
|
327 |
-
dataset = make_test_data(items)
|
328 |
-
|
329 |
-
result = dataset.select_rows([UUID_COLUMN], offset=1, limit=3)
|
330 |
-
assert list(result) == [{UUID_COLUMN: '11'}, {UUID_COLUMN: '12'}, {UUID_COLUMN: '13'}]
|
331 |
-
|
332 |
-
result = dataset.select_rows([UUID_COLUMN], offset=7, limit=2)
|
333 |
-
assert list(result) == [{UUID_COLUMN: '17'}, {UUID_COLUMN: '18'}]
|
334 |
-
|
335 |
-
result = dataset.select_rows([UUID_COLUMN], offset=9, limit=200)
|
336 |
-
assert list(result) == [{UUID_COLUMN: '19'}]
|
337 |
-
|
338 |
-
result = dataset.select_rows([UUID_COLUMN], offset=10, limit=200)
|
339 |
-
assert list(result) == []
|
340 |
-
|
341 |
-
|
342 |
-
def test_columns(make_test_data: TestDataMaker) -> None:
|
343 |
-
dataset = make_test_data(SIMPLE_ITEMS)
|
344 |
-
|
345 |
-
result = dataset.select_rows(['str', 'float'])
|
346 |
-
|
347 |
-
assert list(result) == [{
|
348 |
-
UUID_COLUMN: '1',
|
349 |
-
'str': 'a',
|
350 |
-
'float': 3.0
|
351 |
-
}, {
|
352 |
-
UUID_COLUMN: '2',
|
353 |
-
'str': 'b',
|
354 |
-
'float': 2.0
|
355 |
-
}, {
|
356 |
-
UUID_COLUMN: '3',
|
357 |
-
'str': 'b',
|
358 |
-
'float': 1.0
|
359 |
-
}]
|
360 |
-
|
361 |
-
|
362 |
-
def test_merge_values(make_test_data: TestDataMaker) -> None:
|
363 |
-
dataset = make_test_data([{
|
364 |
-
UUID_COLUMN: '1',
|
365 |
-
'text': 'hello'
|
366 |
-
}, {
|
367 |
-
UUID_COLUMN: '2',
|
368 |
-
'text': 'everybody'
|
369 |
-
}])
|
370 |
-
test_signal = TestSignal()
|
371 |
-
dataset.compute_signal(test_signal, 'text')
|
372 |
-
length_signal = LengthSignal()
|
373 |
-
dataset.compute_signal(length_signal, 'text')
|
374 |
-
|
375 |
-
result = dataset.select_rows(['text'])
|
376 |
-
assert list(result) == [{
|
377 |
-
UUID_COLUMN: '1',
|
378 |
-
'text': enriched_item('hello', {
|
379 |
-
'length_signal': 5,
|
380 |
-
'test_signal': {
|
381 |
-
'len': 5,
|
382 |
-
'flen': 5.0
|
383 |
-
}
|
384 |
-
})
|
385 |
-
}, {
|
386 |
-
UUID_COLUMN: '2',
|
387 |
-
'text': enriched_item('everybody', {
|
388 |
-
'length_signal': 9,
|
389 |
-
'test_signal': {
|
390 |
-
'len': 9,
|
391 |
-
'flen': 9.0
|
392 |
-
}
|
393 |
-
}),
|
394 |
-
}]
|
395 |
-
|
396 |
-
# Test subselection.
|
397 |
-
result = dataset.select_rows(
|
398 |
-
[val('text'), ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')])
|
399 |
-
assert list(result) == [{
|
400 |
-
UUID_COLUMN: '1',
|
401 |
-
f'text.{VALUE_KEY}': 'hello',
|
402 |
-
'text.test_signal.flen': 5.0,
|
403 |
-
'text.test_signal.len': 5
|
404 |
-
}, {
|
405 |
-
UUID_COLUMN: '2',
|
406 |
-
f'text.{VALUE_KEY}': 'everybody',
|
407 |
-
'text.test_signal.flen': 9.0,
|
408 |
-
'text.test_signal.len': 9
|
409 |
-
}]
|
410 |
-
|
411 |
-
# Test subselection with combine_columns=True.
|
412 |
-
result = dataset.select_rows(
|
413 |
-
['text', ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')], combine_columns=True)
|
414 |
-
assert list(result) == [{
|
415 |
-
UUID_COLUMN: '1',
|
416 |
-
'text': enriched_item('hello', {
|
417 |
-
'length_signal': 5,
|
418 |
-
'test_signal': {
|
419 |
-
'len': 5,
|
420 |
-
'flen': 5.0
|
421 |
-
}
|
422 |
-
})
|
423 |
-
}, {
|
424 |
-
UUID_COLUMN: '2',
|
425 |
-
'text': enriched_item('everybody', {
|
426 |
-
'length_signal': 9,
|
427 |
-
'test_signal': {
|
428 |
-
'len': 9,
|
429 |
-
'flen': 9.0
|
430 |
-
}
|
431 |
-
}),
|
432 |
-
}]
|
433 |
-
|
434 |
-
# Test subselection with aliasing.
|
435 |
-
result = dataset.select_rows(
|
436 |
-
columns=[val('text'), Column(('text', 'test_signal', 'len'), alias='metadata')])
|
437 |
-
assert list(result) == [{
|
438 |
-
UUID_COLUMN: '1',
|
439 |
-
f'text.{VALUE_KEY}': 'hello',
|
440 |
-
'metadata': 5
|
441 |
-
}, {
|
442 |
-
UUID_COLUMN: '2',
|
443 |
-
f'text.{VALUE_KEY}': 'everybody',
|
444 |
-
'metadata': 9
|
445 |
-
}]
|
446 |
-
|
447 |
-
result = dataset.select_rows(columns=[Column(('text'), alias='text_enrichment')])
|
448 |
-
assert list(result) == [{
|
449 |
-
UUID_COLUMN: '1',
|
450 |
-
'text_enrichment': enriched_item('hello', {
|
451 |
-
'length_signal': 5,
|
452 |
-
'test_signal': {
|
453 |
-
'len': 5,
|
454 |
-
'flen': 5.0
|
455 |
-
}
|
456 |
-
})
|
457 |
-
}, {
|
458 |
-
UUID_COLUMN: '2',
|
459 |
-
'text_enrichment': enriched_item('everybody', {
|
460 |
-
'length_signal': 9,
|
461 |
-
'test_signal': {
|
462 |
-
'len': 9,
|
463 |
-
'flen': 9.0
|
464 |
-
}
|
465 |
-
})
|
466 |
-
}]
|
467 |
-
|
468 |
-
|
469 |
-
def test_merge_array_values(make_test_data: TestDataMaker) -> None:
|
470 |
-
dataset = make_test_data([{
|
471 |
-
UUID_COLUMN: '1',
|
472 |
-
'texts': ['hello', 'everybody']
|
473 |
-
}, {
|
474 |
-
UUID_COLUMN: '2',
|
475 |
-
'texts': ['a', 'bc', 'def']
|
476 |
-
}])
|
477 |
-
|
478 |
-
test_signal = TestSignal()
|
479 |
-
dataset.compute_signal(test_signal, ('texts', '*'))
|
480 |
-
length_signal = LengthSignal()
|
481 |
-
dataset.compute_signal(length_signal, ('texts', '*'))
|
482 |
-
|
483 |
-
assert dataset.manifest() == DatasetManifest(
|
484 |
-
namespace=TEST_NAMESPACE,
|
485 |
-
dataset_name=TEST_DATASET_NAME,
|
486 |
-
data_schema=schema({
|
487 |
-
UUID_COLUMN: 'string',
|
488 |
-
'texts': [
|
489 |
-
field(
|
490 |
-
'string',
|
491 |
-
fields={
|
492 |
-
'length_signal': field('int32', length_signal.dict()),
|
493 |
-
'test_signal': field(
|
494 |
-
signal=test_signal.dict(), fields={
|
495 |
-
'len': 'int32',
|
496 |
-
'flen': 'float32'
|
497 |
-
})
|
498 |
-
})
|
499 |
-
],
|
500 |
-
}),
|
501 |
-
num_items=2)
|
502 |
-
|
503 |
-
result = dataset.select_rows(['texts'])
|
504 |
-
assert list(result) == [{
|
505 |
-
UUID_COLUMN: '1',
|
506 |
-
'texts': [
|
507 |
-
enriched_item('hello', {
|
508 |
-
'length_signal': 5,
|
509 |
-
'test_signal': {
|
510 |
-
'len': 5,
|
511 |
-
'flen': 5.0
|
512 |
-
}
|
513 |
-
}),
|
514 |
-
enriched_item('everybody', {
|
515 |
-
'length_signal': 9,
|
516 |
-
'test_signal': {
|
517 |
-
'len': 9,
|
518 |
-
'flen': 9.0
|
519 |
-
}
|
520 |
-
})
|
521 |
-
],
|
522 |
-
}, {
|
523 |
-
UUID_COLUMN: '2',
|
524 |
-
'texts': [
|
525 |
-
enriched_item('a', {
|
526 |
-
'length_signal': 1,
|
527 |
-
'test_signal': {
|
528 |
-
'len': 1,
|
529 |
-
'flen': 1.0
|
530 |
-
}
|
531 |
-
}),
|
532 |
-
enriched_item('bc', {
|
533 |
-
'length_signal': 2,
|
534 |
-
'test_signal': {
|
535 |
-
'len': 2,
|
536 |
-
'flen': 2.0
|
537 |
-
}
|
538 |
-
}),
|
539 |
-
enriched_item('def', {
|
540 |
-
'length_signal': 3,
|
541 |
-
'test_signal': {
|
542 |
-
'len': 3,
|
543 |
-
'flen': 3.0
|
544 |
-
}
|
545 |
-
})
|
546 |
-
],
|
547 |
-
}]
|
548 |
-
|
549 |
-
# Test subselection.
|
550 |
-
result = dataset.select_rows(
|
551 |
-
[val(('texts', '*')), ('texts', '*', 'length_signal'), ('texts', '*', 'test_signal', 'flen')])
|
552 |
-
assert list(result) == [{
|
553 |
-
UUID_COLUMN: '1',
|
554 |
-
f'texts.*.{VALUE_KEY}': ['hello', 'everybody'],
|
555 |
-
'texts.*.test_signal.flen': [5.0, 9.0],
|
556 |
-
'texts.*.length_signal': [5, 9]
|
557 |
-
}, {
|
558 |
-
UUID_COLUMN: '2',
|
559 |
-
f'texts.*.{VALUE_KEY}': ['a', 'bc', 'def'],
|
560 |
-
'texts.*.test_signal.flen': [1.0, 2.0, 3.0],
|
561 |
-
'texts.*.length_signal': [1, 2, 3]
|
562 |
-
}]
|
563 |
-
|
564 |
-
|
565 |
-
def test_combining_columns(make_test_data: TestDataMaker) -> None:
|
566 |
-
dataset = make_test_data([{
|
567 |
-
UUID_COLUMN: '1',
|
568 |
-
'text': 'hello',
|
569 |
-
'extra': {
|
570 |
-
'text': {
|
571 |
-
'length_signal': 5,
|
572 |
-
'test_signal': {
|
573 |
-
'len': 5,
|
574 |
-
'flen': 5.0
|
575 |
-
}
|
576 |
-
}
|
577 |
-
}
|
578 |
-
}, {
|
579 |
-
UUID_COLUMN: '2',
|
580 |
-
'text': 'everybody',
|
581 |
-
'extra': {
|
582 |
-
'text': {
|
583 |
-
'length_signal': 9,
|
584 |
-
'test_signal': {
|
585 |
-
'len': 9,
|
586 |
-
'flen': 9.0
|
587 |
-
}
|
588 |
-
}
|
589 |
-
}
|
590 |
-
}])
|
591 |
-
|
592 |
-
# Sub-select text and test_signal.
|
593 |
-
result = dataset.select_rows(['text', ('extra', 'text', 'test_signal')], combine_columns=True)
|
594 |
-
assert list(result) == [{
|
595 |
-
UUID_COLUMN: '1',
|
596 |
-
'text': 'hello',
|
597 |
-
'extra': {
|
598 |
-
'text': {
|
599 |
-
'test_signal': {
|
600 |
-
'len': 5,
|
601 |
-
'flen': 5.0
|
602 |
-
}
|
603 |
-
}
|
604 |
-
}
|
605 |
-
}, {
|
606 |
-
UUID_COLUMN: '2',
|
607 |
-
'text': 'everybody',
|
608 |
-
'extra': {
|
609 |
-
'text': {
|
610 |
-
'test_signal': {
|
611 |
-
'len': 9,
|
612 |
-
'flen': 9.0
|
613 |
-
}
|
614 |
-
}
|
615 |
-
}
|
616 |
-
}]
|
617 |
-
|
618 |
-
# Sub-select text and length_signal.
|
619 |
-
result = dataset.select_rows(['text', ('extra', 'text', 'length_signal')], combine_columns=True)
|
620 |
-
assert list(result) == [{
|
621 |
-
UUID_COLUMN: '1',
|
622 |
-
'text': 'hello',
|
623 |
-
'extra': {
|
624 |
-
'text': {
|
625 |
-
'length_signal': 5
|
626 |
-
}
|
627 |
-
}
|
628 |
-
}, {
|
629 |
-
UUID_COLUMN: '2',
|
630 |
-
'text': 'everybody',
|
631 |
-
'extra': {
|
632 |
-
'text': {
|
633 |
-
'length_signal': 9
|
634 |
-
}
|
635 |
-
}
|
636 |
-
}]
|
637 |
-
|
638 |
-
# Sub-select length_signal only.
|
639 |
-
result = dataset.select_rows([('extra', 'text', 'length_signal')], combine_columns=True)
|
640 |
-
assert list(result) == [{
|
641 |
-
UUID_COLUMN: '1',
|
642 |
-
'extra': {
|
643 |
-
'text': {
|
644 |
-
'length_signal': 5
|
645 |
-
}
|
646 |
-
}
|
647 |
-
}, {
|
648 |
-
UUID_COLUMN: '2',
|
649 |
-
'extra': {
|
650 |
-
'text': {
|
651 |
-
'length_signal': 9
|
652 |
-
}
|
653 |
-
}
|
654 |
-
}]
|
655 |
-
|
656 |
-
# Aliases are ignored when combing columns.
|
657 |
-
len_col = Column(('extra', 'text', 'length_signal'), alias='hello')
|
658 |
-
result = dataset.select_rows([len_col], combine_columns=True)
|
659 |
-
assert list(result) == [{
|
660 |
-
UUID_COLUMN: '1',
|
661 |
-
'extra': {
|
662 |
-
'text': {
|
663 |
-
'length_signal': 5
|
664 |
-
}
|
665 |
-
}
|
666 |
-
}, {
|
667 |
-
UUID_COLUMN: '2',
|
668 |
-
'extra': {
|
669 |
-
'text': {
|
670 |
-
'length_signal': 9
|
671 |
-
}
|
672 |
-
}
|
673 |
-
}]
|
674 |
-
|
675 |
-
# Works with UDFs and aliases are ignored.
|
676 |
-
udf_col = Column('text', alias='ignored', signal_udf=LengthSignal())
|
677 |
-
result = dataset.select_rows(['text', udf_col], combine_columns=True)
|
678 |
-
assert list(result) == [{
|
679 |
-
UUID_COLUMN: '1',
|
680 |
-
'text': enriched_item('hello', {'length_signal': 5})
|
681 |
-
}, {
|
682 |
-
UUID_COLUMN: '2',
|
683 |
-
'text': enriched_item('everybody', {'length_signal': 9})
|
684 |
-
}]
|
685 |
-
|
686 |
-
|
687 |
-
def test_source_joined_with_named_signal(make_test_data: TestDataMaker) -> None:
|
688 |
-
dataset = make_test_data(SIMPLE_ITEMS)
|
689 |
-
assert dataset.manifest() == DatasetManifest(
|
690 |
-
namespace=TEST_NAMESPACE,
|
691 |
-
dataset_name=TEST_DATASET_NAME,
|
692 |
-
data_schema=schema({
|
693 |
-
UUID_COLUMN: 'string',
|
694 |
-
'str': 'string',
|
695 |
-
'int': 'int32',
|
696 |
-
'bool': 'boolean',
|
697 |
-
'float': 'float32',
|
698 |
-
}),
|
699 |
-
num_items=3)
|
700 |
-
|
701 |
-
test_signal = TestSignal()
|
702 |
-
dataset.compute_signal(test_signal, 'str')
|
703 |
-
|
704 |
-
# Check the enriched dataset manifest has 'text' enriched.
|
705 |
-
assert dataset.manifest() == DatasetManifest(
|
706 |
-
namespace=TEST_NAMESPACE,
|
707 |
-
dataset_name=TEST_DATASET_NAME,
|
708 |
-
data_schema=schema({
|
709 |
-
UUID_COLUMN: 'string',
|
710 |
-
'str': field(
|
711 |
-
'string',
|
712 |
-
fields={
|
713 |
-
'test_signal': field(
|
714 |
-
signal=test_signal.dict(), fields={
|
715 |
-
'len': 'int32',
|
716 |
-
'flen': 'float32'
|
717 |
-
})
|
718 |
-
}),
|
719 |
-
'int': 'int32',
|
720 |
-
'bool': 'boolean',
|
721 |
-
'float': 'float32',
|
722 |
-
}),
|
723 |
-
num_items=3)
|
724 |
-
|
725 |
-
# Select both columns, without val() on str.
|
726 |
-
result = dataset.select_rows(['str', Column(('str', 'test_signal'), alias='test_signal_on_str')])
|
727 |
-
|
728 |
-
assert list(result) == [{
|
729 |
-
UUID_COLUMN: '1',
|
730 |
-
'str': enriched_item('a', {'test_signal': {
|
731 |
-
'len': 1,
|
732 |
-
'flen': 1.0
|
733 |
-
}}),
|
734 |
-
'test_signal_on_str': {
|
735 |
-
'len': 1,
|
736 |
-
'flen': 1.0
|
737 |
-
}
|
738 |
-
}, {
|
739 |
-
UUID_COLUMN: '2',
|
740 |
-
'str': enriched_item('b', {'test_signal': {
|
741 |
-
'len': 1,
|
742 |
-
'flen': 1.0
|
743 |
-
}}),
|
744 |
-
'test_signal_on_str': {
|
745 |
-
'len': 1,
|
746 |
-
'flen': 1.0
|
747 |
-
}
|
748 |
-
}, {
|
749 |
-
UUID_COLUMN: '3',
|
750 |
-
'str': enriched_item('b', {'test_signal': {
|
751 |
-
'len': 1,
|
752 |
-
'flen': 1.0
|
753 |
-
}}),
|
754 |
-
'test_signal_on_str': {
|
755 |
-
'len': 1,
|
756 |
-
'flen': 1.0
|
757 |
-
}
|
758 |
-
}]
|
759 |
-
|
760 |
-
# Select both columns, with val() on str.
|
761 |
-
result = dataset.select_rows(
|
762 |
-
[val('str'), Column(('str', 'test_signal'), alias='test_signal_on_str')])
|
763 |
-
|
764 |
-
assert list(result) == [{
|
765 |
-
UUID_COLUMN: '1',
|
766 |
-
f'str.{VALUE_KEY}': 'a',
|
767 |
-
'test_signal_on_str': {
|
768 |
-
'len': 1,
|
769 |
-
'flen': 1.0
|
770 |
-
}
|
771 |
-
}, {
|
772 |
-
UUID_COLUMN: '2',
|
773 |
-
f'str.{VALUE_KEY}': 'b',
|
774 |
-
'test_signal_on_str': {
|
775 |
-
'len': 1,
|
776 |
-
'flen': 1.0
|
777 |
-
}
|
778 |
-
}, {
|
779 |
-
UUID_COLUMN: '3',
|
780 |
-
f'str.{VALUE_KEY}': 'b',
|
781 |
-
'test_signal_on_str': {
|
782 |
-
'len': 1,
|
783 |
-
'flen': 1.0
|
784 |
-
}
|
785 |
-
}]
|
786 |
-
|
787 |
-
|
788 |
-
def test_invalid_column_paths(make_test_data: TestDataMaker) -> None:
|
789 |
-
dataset = make_test_data([{
|
790 |
-
UUID_COLUMN: '1',
|
791 |
-
'text': enriched_item('hello', {'test_signal': {
|
792 |
-
'len': 5
|
793 |
-
}}),
|
794 |
-
'text2': [
|
795 |
-
enriched_item('hello', {'test_signal': {
|
796 |
-
'len': 5
|
797 |
-
}}),
|
798 |
-
enriched_item('hi', {'test_signal': {
|
799 |
-
'len': 2
|
800 |
-
}})
|
801 |
-
],
|
802 |
-
}])
|
803 |
-
|
804 |
-
with pytest.raises(ValueError, match='Path part "invalid" not found in the dataset'):
|
805 |
-
dataset.select_rows([('text', 'test_signal', 'invalid')])
|
806 |
-
|
807 |
-
with pytest.raises(ValueError, match='Selecting a specific index of a repeated field'):
|
808 |
-
dataset.select_rows([('text2', '4', 'test_signal')])
|
809 |
-
|
810 |
-
|
811 |
-
def test_signal_with_quote(make_test_data: TestDataMaker) -> None:
|
812 |
-
dataset = make_test_data([{
|
813 |
-
UUID_COLUMN: '1',
|
814 |
-
'text': 'hello',
|
815 |
-
}, {
|
816 |
-
UUID_COLUMN: '2',
|
817 |
-
'text': 'world',
|
818 |
-
}])
|
819 |
-
dataset.compute_signal(SignalWithQuoteInIt(), 'text')
|
820 |
-
dataset.compute_signal(SignalWithDoubleQuoteInIt(), 'text')
|
821 |
-
result = dataset.select_rows(['text'])
|
822 |
-
assert list(result) == [{
|
823 |
-
UUID_COLUMN: '1',
|
824 |
-
'text': enriched_item('hello', {
|
825 |
-
"test'signal": True,
|
826 |
-
'test"signal': True
|
827 |
-
})
|
828 |
-
}, {
|
829 |
-
UUID_COLUMN: '2',
|
830 |
-
'text': enriched_item('world', {
|
831 |
-
"test'signal": True,
|
832 |
-
'test"signal': True
|
833 |
-
}),
|
834 |
-
}]
|
835 |
-
|
836 |
-
|
837 |
-
class SignalWithQuoteInIt(TextSignal):
|
838 |
-
name = "test'signal"
|
839 |
-
|
840 |
-
@override
|
841 |
-
def fields(self) -> Field:
|
842 |
-
return field('boolean')
|
843 |
-
|
844 |
-
@override
|
845 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
846 |
-
for d in data:
|
847 |
-
yield True
|
848 |
-
|
849 |
-
|
850 |
-
class SignalWithDoubleQuoteInIt(TextSignal):
|
851 |
-
name = 'test"signal'
|
852 |
-
|
853 |
-
@override
|
854 |
-
def fields(self) -> Field:
|
855 |
-
return field('boolean')
|
856 |
-
|
857 |
-
@override
|
858 |
-
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
|
859 |
-
for d in data:
|
860 |
-
yield True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/dataset_utils.py
CHANGED
@@ -73,7 +73,7 @@ def lilac_embedding(start: int, end: int, embedding: Optional[np.ndarray]) -> It
|
|
73 |
Tflatten = TypeVar('Tflatten', object, np.ndarray)
|
74 |
|
75 |
|
76 |
-
def _flatten(input: Union[
|
77 |
bool]) -> Generator:
|
78 |
"""Flattens a nested iterable."""
|
79 |
if is_primitive_predicate(input):
|
@@ -83,13 +83,13 @@ def _flatten(input: Union[Iterable, object], is_primitive_predicate: Callable[[o
|
|
83 |
elif is_primitive(input):
|
84 |
yield input
|
85 |
else:
|
86 |
-
for elem in cast(
|
87 |
yield from _flatten(elem, is_primitive_predicate)
|
88 |
|
89 |
|
90 |
-
def flatten(input: Union[Iterable, Tflatten],
|
91 |
-
is_primitive_predicate: Callable[[object], bool] = is_primitive) ->
|
92 |
-
"""Flattens a nested
|
93 |
|
94 |
Primitives and dictionaries are not flattened. The user can also provide a predicate to determine
|
95 |
what is a primitive.
|
@@ -97,7 +97,7 @@ def flatten(input: Union[Iterable, Tflatten],
|
|
97 |
return _flatten(input, is_primitive_predicate)
|
98 |
|
99 |
|
100 |
-
def count_primitives(input: Iterable) -> int:
|
101 |
"""Iterate through each element of the input, flattening each one, computing a count.
|
102 |
|
103 |
Sum the final set of counts. This is the important iterable not to exhaust.
|
@@ -128,7 +128,8 @@ def _unflatten(flat_input: Iterator[list[object]],
|
|
128 |
return [_unflatten(flat_input, orig_elem) for orig_elem in values]
|
129 |
|
130 |
|
131 |
-
def unflatten(flat_input: Iterable, original_input: Union[Iterable,
|
|
|
132 |
"""Unflattens a flattened iterable according to the original iterable's structure."""
|
133 |
return cast(list, _unflatten(iter(flat_input), original_input))
|
134 |
|
@@ -234,23 +235,27 @@ def write_item_embeddings_to_disk(keys: Iterable[str], embeddings: Iterable[obje
|
|
234 |
return isinstance(input, np.ndarray)
|
235 |
|
236 |
flat_keys = flatten_keys(keys, embeddings, is_primitive_predicate=embedding_predicate)
|
|
|
|
|
237 |
embedding_vectors: list[np.ndarray] = []
|
238 |
-
|
|
|
|
|
|
|
|
|
|
|
239 |
# We use squeeze here because embedding functions can return outer dimensions of 1.
|
240 |
-
|
241 |
-
|
242 |
-
raise ValueError(f'Expected embeddings to be 1-dimensional, got {embedding_vector.ndim} '
|
243 |
-
f'with shape {embedding_vector.shape}.')
|
244 |
-
embedding_vectors.append(embedding_vector)
|
245 |
|
246 |
-
|
247 |
|
248 |
# Write the embedding index and the ordered UUID column to disk so they can be joined later.
|
249 |
|
250 |
with open_file(output_path_prefix + _EMBEDDINGS_SUFFIX, 'wb') as f:
|
251 |
-
np.save(f,
|
252 |
with open_file(output_path_prefix + _KEYS_SUFFIX, 'wb') as f:
|
253 |
-
pickle.dump(
|
254 |
|
255 |
return output_path_prefix
|
256 |
|
@@ -314,34 +319,63 @@ def parquet_filename(prefix: str, shard_index: int, num_shards: int) -> str:
|
|
314 |
|
315 |
|
316 |
def _flatten_keys(uuid: str, nested_input: Iterable, location: list[int],
|
317 |
-
is_primitive_predicate: Callable[[object], bool]) ->
|
318 |
-
if is_primitive_predicate(nested_input)
|
319 |
-
|
320 |
-
|
321 |
-
return
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
for value in nested_input.values():
|
326 |
-
result.extend(_flatten_keys(uuid, value, location, is_primitive_predicate))
|
327 |
-
else:
|
328 |
-
for i, input in enumerate(nested_input):
|
329 |
-
result.extend(_flatten_keys(uuid, input, [*location, i], is_primitive_predicate))
|
330 |
-
return result
|
331 |
|
332 |
|
333 |
def flatten_keys(
|
334 |
uuids: Iterable[str],
|
335 |
nested_input: Iterable,
|
336 |
-
is_primitive_predicate: Callable[[object],
|
|
|
337 |
"""Flatten the uuid keys of a nested input."""
|
338 |
-
result: list[VectorKey] = []
|
339 |
for uuid, input in zip(uuids, nested_input):
|
340 |
-
|
341 |
-
|
|
|
|
|
342 |
|
343 |
|
344 |
def embedding_index_filename_prefix(output_dir: str, shard_index: int, num_shards: int) -> str:
|
345 |
"""Return the filename prefix for the embedding index."""
|
346 |
npy_filename = f'embeddings-{shard_index:05d}-of-{num_shards:05d}'
|
347 |
return os.path.join(output_dir, npy_filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
Tflatten = TypeVar('Tflatten', object, np.ndarray)
|
74 |
|
75 |
|
76 |
+
def _flatten(input: Union[Iterator, object], is_primitive_predicate: Callable[[object],
|
77 |
bool]) -> Generator:
|
78 |
"""Flattens a nested iterable."""
|
79 |
if is_primitive_predicate(input):
|
|
|
83 |
elif is_primitive(input):
|
84 |
yield input
|
85 |
else:
|
86 |
+
for elem in cast(Iterator, input):
|
87 |
yield from _flatten(elem, is_primitive_predicate)
|
88 |
|
89 |
|
90 |
+
def flatten(input: Union[Iterator, Iterable, Tflatten],
|
91 |
+
is_primitive_predicate: Callable[[object], bool] = is_primitive) -> Iterator[Tflatten]:
|
92 |
+
"""Flattens a nested iterator.
|
93 |
|
94 |
Primitives and dictionaries are not flattened. The user can also provide a predicate to determine
|
95 |
what is a primitive.
|
|
|
97 |
return _flatten(input, is_primitive_predicate)
|
98 |
|
99 |
|
100 |
+
def count_primitives(input: Union[Iterable, Iterator]) -> int:
|
101 |
"""Iterate through each element of the input, flattening each one, computing a count.
|
102 |
|
103 |
Sum the final set of counts. This is the important iterable not to exhaust.
|
|
|
128 |
return [_unflatten(flat_input, orig_elem) for orig_elem in values]
|
129 |
|
130 |
|
131 |
+
def unflatten(flat_input: Union[Iterable, Iterator], original_input: Union[Iterable,
|
132 |
+
object]) -> list:
|
133 |
"""Unflattens a flattened iterable according to the original iterable's structure."""
|
134 |
return cast(list, _unflatten(iter(flat_input), original_input))
|
135 |
|
|
|
235 |
return isinstance(input, np.ndarray)
|
236 |
|
237 |
flat_keys = flatten_keys(keys, embeddings, is_primitive_predicate=embedding_predicate)
|
238 |
+
flat_embeddings = flatten(embeddings, is_primitive_predicate=embedding_predicate)
|
239 |
+
|
240 |
embedding_vectors: list[np.ndarray] = []
|
241 |
+
embedding_keys: list[VectorKey] = []
|
242 |
+
for key, lilac_embedding in zip(flat_keys, flat_embeddings):
|
243 |
+
if not key or not lilac_embedding or EMBEDDING_KEY not in lilac_embedding:
|
244 |
+
# Sparse embeddings may not have an embedding for every key.
|
245 |
+
continue
|
246 |
+
|
247 |
# We use squeeze here because embedding functions can return outer dimensions of 1.
|
248 |
+
embedding_vectors.append(lilac_embedding[EMBEDDING_KEY].reshape(-1))
|
249 |
+
embedding_keys.append(key)
|
|
|
|
|
|
|
250 |
|
251 |
+
embedding_vectors = np.array(embedding_vectors)
|
252 |
|
253 |
# Write the embedding index and the ordered UUID column to disk so they can be joined later.
|
254 |
|
255 |
with open_file(output_path_prefix + _EMBEDDINGS_SUFFIX, 'wb') as f:
|
256 |
+
np.save(f, embedding_vectors, allow_pickle=False)
|
257 |
with open_file(output_path_prefix + _KEYS_SUFFIX, 'wb') as f:
|
258 |
+
pickle.dump(embedding_keys, f)
|
259 |
|
260 |
return output_path_prefix
|
261 |
|
|
|
319 |
|
320 |
|
321 |
def _flatten_keys(uuid: str, nested_input: Iterable, location: list[int],
|
322 |
+
is_primitive_predicate: Callable[[object], bool]) -> Iterator[VectorKey]:
|
323 |
+
if is_primitive_predicate(nested_input) or is_primitive(nested_input) or isinstance(
|
324 |
+
nested_input, dict):
|
325 |
+
yield (uuid, *location)
|
326 |
+
return
|
327 |
+
|
328 |
+
for i, input in enumerate(nested_input):
|
329 |
+
yield from _flatten_keys(uuid, input, [*location, i], is_primitive_predicate)
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
|
332 |
def flatten_keys(
|
333 |
uuids: Iterable[str],
|
334 |
nested_input: Iterable,
|
335 |
+
is_primitive_predicate: Callable[[object],
|
336 |
+
bool] = is_primitive) -> Iterator[Optional[VectorKey]]:
|
337 |
"""Flatten the uuid keys of a nested input."""
|
|
|
338 |
for uuid, input in zip(uuids, nested_input):
|
339 |
+
if input is None:
|
340 |
+
yield None
|
341 |
+
continue
|
342 |
+
yield from _flatten_keys(uuid, input, [], is_primitive_predicate)
|
343 |
|
344 |
|
345 |
def embedding_index_filename_prefix(output_dir: str, shard_index: int, num_shards: int) -> str:
|
346 |
"""Return the filename prefix for the embedding index."""
|
347 |
npy_filename = f'embeddings-{shard_index:05d}-of-{num_shards:05d}'
|
348 |
return os.path.join(output_dir, npy_filename)
|
349 |
+
|
350 |
+
|
351 |
+
Tin = TypeVar('Tin')
|
352 |
+
Tout = TypeVar('Tout')
|
353 |
+
|
354 |
+
|
355 |
+
def sparse_to_dense_compute(
|
356 |
+
sparse_input: Iterator[Optional[Tin]],
|
357 |
+
func: Callable[[Iterable[Tin]], Iterable[Tout]]) -> Iterator[Optional[Tout]]:
|
358 |
+
"""Densifies the input before calling the provided `func` and sparsifies the output."""
|
359 |
+
empty_mask: list[bool] = []
|
360 |
+
|
361 |
+
def densify(x: Iterator[Optional[Tin]]) -> Iterator[Tin]:
|
362 |
+
nonlocal empty_mask
|
363 |
+
for i, value in enumerate(x):
|
364 |
+
empty_mask.append(value is None)
|
365 |
+
if value is not None:
|
366 |
+
yield value
|
367 |
+
|
368 |
+
dense_input = densify(sparse_input)
|
369 |
+
dense_output = iter(func(dense_input))
|
370 |
+
index = 0
|
371 |
+
|
372 |
+
while True:
|
373 |
+
try:
|
374 |
+
out = next(dense_output)
|
375 |
+
yield (None if empty_mask[index] else out)
|
376 |
+
index += 1
|
377 |
+
except StopIteration:
|
378 |
+
while index < len(empty_mask):
|
379 |
+
yield None
|
380 |
+
index += 1
|
381 |
+
return
|
src/data/dataset_utils_test.py
DELETED
@@ -1,114 +0,0 @@
|
|
1 |
-
"""Tests for dataset utils."""
|
2 |
-
from ..schema import PathTuple
|
3 |
-
from .dataset_utils import count_primitives, flatten, unflatten, wrap_in_dicts
|
4 |
-
|
5 |
-
|
6 |
-
def test_flatten() -> None:
|
7 |
-
a = [[1, 2], [[3]], [4, 5, 5]]
|
8 |
-
result = list(flatten(a))
|
9 |
-
assert result == [1, 2, 3, 4, 5, 5]
|
10 |
-
|
11 |
-
|
12 |
-
def test_flatten_primitive() -> None:
|
13 |
-
result = list(flatten('hello'))
|
14 |
-
assert result == ['hello']
|
15 |
-
|
16 |
-
|
17 |
-
def test_unflatten() -> None:
|
18 |
-
a = [[1, 2], [[3]], [4, 5, 5]]
|
19 |
-
flat_a = list(flatten(a))
|
20 |
-
result = unflatten(flat_a, a)
|
21 |
-
assert result == [[1, 2], [[3]], [4, 5, 5]]
|
22 |
-
|
23 |
-
|
24 |
-
def test_count_nested() -> None:
|
25 |
-
a = [[1, 2], [[3]], [4, 5, 6]]
|
26 |
-
assert 6 == count_primitives(a)
|
27 |
-
|
28 |
-
|
29 |
-
def test_wrap_in_dicts_with_spec_of_one_repeated() -> None:
|
30 |
-
a = [[1, 2], [3], [4, 5, 5]]
|
31 |
-
spec: list[PathTuple] = [('a', 'b', 'c'), ('d',)] # Corresponds to a.b.c.*.d.
|
32 |
-
result = wrap_in_dicts(a, spec)
|
33 |
-
assert result == [{
|
34 |
-
'a': {
|
35 |
-
'b': {
|
36 |
-
'c': [{
|
37 |
-
'd': 1
|
38 |
-
}, {
|
39 |
-
'd': 2
|
40 |
-
}]
|
41 |
-
}
|
42 |
-
}
|
43 |
-
}, {
|
44 |
-
'a': {
|
45 |
-
'b': {
|
46 |
-
'c': [{
|
47 |
-
'd': 3
|
48 |
-
}]
|
49 |
-
}
|
50 |
-
}
|
51 |
-
}, {
|
52 |
-
'a': {
|
53 |
-
'b': {
|
54 |
-
'c': [{
|
55 |
-
'd': 4
|
56 |
-
}, {
|
57 |
-
'd': 5
|
58 |
-
}, {
|
59 |
-
'd': 5
|
60 |
-
}]
|
61 |
-
}
|
62 |
-
}
|
63 |
-
}]
|
64 |
-
|
65 |
-
|
66 |
-
def test_wrap_in_dicts_with_spec_of_double_repeated() -> None:
|
67 |
-
a = [[[1, 2], [3, 4, 5]], [[6]], [[7], [8], [9, 10]]]
|
68 |
-
spec: list[PathTuple] = [('a', 'b'), tuple(), ('c',)] # Corresponds to a.b.*.*.c.
|
69 |
-
result = wrap_in_dicts(a, spec)
|
70 |
-
assert result == [{
|
71 |
-
'a': {
|
72 |
-
'b': [[{
|
73 |
-
'c': 1
|
74 |
-
}, {
|
75 |
-
'c': 2
|
76 |
-
}], [{
|
77 |
-
'c': 3
|
78 |
-
}, {
|
79 |
-
'c': 4
|
80 |
-
}, {
|
81 |
-
'c': 5
|
82 |
-
}]]
|
83 |
-
}
|
84 |
-
}, {
|
85 |
-
'a': {
|
86 |
-
'b': [[{
|
87 |
-
'c': 6
|
88 |
-
}]]
|
89 |
-
}
|
90 |
-
}, {
|
91 |
-
'a': {
|
92 |
-
'b': [[{
|
93 |
-
'c': 7
|
94 |
-
}], [{
|
95 |
-
'c': 8
|
96 |
-
}], [{
|
97 |
-
'c': 9
|
98 |
-
}, {
|
99 |
-
'c': 10
|
100 |
-
}]]
|
101 |
-
}
|
102 |
-
}]
|
103 |
-
|
104 |
-
|
105 |
-
def test_unflatten_primitive() -> None:
|
106 |
-
original = 'hello'
|
107 |
-
result = unflatten(['hello'], original)
|
108 |
-
assert result == 'hello'
|
109 |
-
|
110 |
-
|
111 |
-
def test_unflatten_primitive_list() -> None:
|
112 |
-
original = ['hello', 'world']
|
113 |
-
result = unflatten(['hello', 'world'], original)
|
114 |
-
assert result == ['hello', 'world']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/sources/csv_source_test.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
"""Tests for the CSV source."""
|
2 |
-
import csv
|
3 |
-
import os
|
4 |
-
import pathlib
|
5 |
-
|
6 |
-
from ...schema import schema
|
7 |
-
from .csv_source import LINE_NUMBER_COLUMN, CSVDataset
|
8 |
-
from .source import SourceSchema
|
9 |
-
|
10 |
-
|
11 |
-
def test_csv(tmp_path: pathlib.Path) -> None:
|
12 |
-
csv_rows = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
|
13 |
-
|
14 |
-
filename = 'test-dataset.csv'
|
15 |
-
filepath = os.path.join(tmp_path, filename)
|
16 |
-
with open(filepath, 'w') as f:
|
17 |
-
writer = csv.DictWriter(f, fieldnames=list(csv_rows[0].keys()))
|
18 |
-
writer.writeheader()
|
19 |
-
writer.writerows(csv_rows)
|
20 |
-
|
21 |
-
source = CSVDataset(filepaths=[filepath])
|
22 |
-
source.setup()
|
23 |
-
|
24 |
-
source_schema = source.source_schema()
|
25 |
-
assert source_schema == SourceSchema(
|
26 |
-
fields=schema({
|
27 |
-
LINE_NUMBER_COLUMN: 'int64',
|
28 |
-
'x': 'int64',
|
29 |
-
'y': 'string'
|
30 |
-
}).fields, num_items=2)
|
31 |
-
|
32 |
-
items = list(source.process())
|
33 |
-
|
34 |
-
assert items == [{
|
35 |
-
LINE_NUMBER_COLUMN: 0,
|
36 |
-
'x': 1,
|
37 |
-
'y': 'ten'
|
38 |
-
}, {
|
39 |
-
LINE_NUMBER_COLUMN: 1,
|
40 |
-
'x': 2,
|
41 |
-
'y': 'twenty'
|
42 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/sources/huggingface_source_test.py
DELETED
@@ -1,170 +0,0 @@
|
|
1 |
-
"""Tests for the pandas source."""
|
2 |
-
import os
|
3 |
-
import pathlib
|
4 |
-
|
5 |
-
# mypy: disable-error-code="attr-defined"
|
6 |
-
from datasets import Dataset, Features, Sequence, Value
|
7 |
-
|
8 |
-
from ...schema import schema
|
9 |
-
from .huggingface_source import HF_SPLIT_COLUMN, HuggingFaceDataset
|
10 |
-
from .source import SourceSchema
|
11 |
-
|
12 |
-
|
13 |
-
def test_hf(tmp_path: pathlib.Path) -> None:
|
14 |
-
dataset = Dataset.from_list([{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}])
|
15 |
-
|
16 |
-
dataset_name = os.path.join(tmp_path, 'hf-test-dataset')
|
17 |
-
dataset.save_to_disk(dataset_name)
|
18 |
-
|
19 |
-
source = HuggingFaceDataset(dataset_name=dataset_name, load_from_disk=True)
|
20 |
-
|
21 |
-
items = source.process()
|
22 |
-
source.setup()
|
23 |
-
|
24 |
-
source_schema = source.source_schema()
|
25 |
-
assert source_schema == SourceSchema(
|
26 |
-
fields=schema({
|
27 |
-
HF_SPLIT_COLUMN: 'string',
|
28 |
-
'x': 'int64',
|
29 |
-
'y': 'string'
|
30 |
-
}).fields, num_items=2)
|
31 |
-
|
32 |
-
items = list(source.process())
|
33 |
-
|
34 |
-
assert items == [{
|
35 |
-
HF_SPLIT_COLUMN: 'default',
|
36 |
-
'x': 1,
|
37 |
-
'y': 'ten'
|
38 |
-
}, {
|
39 |
-
HF_SPLIT_COLUMN: 'default',
|
40 |
-
'x': 2,
|
41 |
-
'y': 'twenty'
|
42 |
-
}]
|
43 |
-
|
44 |
-
|
45 |
-
def test_hf_sequence(tmp_path: pathlib.Path) -> None:
|
46 |
-
dataset = Dataset.from_list([{
|
47 |
-
'scalar': 1,
|
48 |
-
'seq': [1, 0],
|
49 |
-
'seq_dict': {
|
50 |
-
'x': [1, 2, 3],
|
51 |
-
'y': ['four', 'five', 'six']
|
52 |
-
}
|
53 |
-
}, {
|
54 |
-
'scalar': 2,
|
55 |
-
'seq': [2, 0],
|
56 |
-
'seq_dict': {
|
57 |
-
'x': [10, 20, 30],
|
58 |
-
'y': ['forty', 'fifty', 'sixty']
|
59 |
-
}
|
60 |
-
}],
|
61 |
-
features=Features({
|
62 |
-
'scalar': Value(dtype='int64'),
|
63 |
-
'seq': Sequence(feature=Value(dtype='int64')),
|
64 |
-
'seq_dict': Sequence(feature={
|
65 |
-
'x': Value(dtype='int64'),
|
66 |
-
'y': Value(dtype='string')
|
67 |
-
})
|
68 |
-
}))
|
69 |
-
|
70 |
-
dataset_name = os.path.join(tmp_path, 'hf-test-dataset')
|
71 |
-
dataset.save_to_disk(dataset_name)
|
72 |
-
|
73 |
-
source = HuggingFaceDataset(dataset_name=dataset_name, load_from_disk=True)
|
74 |
-
|
75 |
-
items = source.process()
|
76 |
-
source.setup()
|
77 |
-
|
78 |
-
source_schema = source.source_schema()
|
79 |
-
assert source_schema == SourceSchema(
|
80 |
-
fields=schema({
|
81 |
-
HF_SPLIT_COLUMN: 'string',
|
82 |
-
'scalar': 'int64',
|
83 |
-
'seq': ['int64'],
|
84 |
-
'seq_dict': {
|
85 |
-
'x': ['int64'],
|
86 |
-
'y': ['string'],
|
87 |
-
},
|
88 |
-
}).fields,
|
89 |
-
num_items=2)
|
90 |
-
|
91 |
-
items = list(source.process())
|
92 |
-
|
93 |
-
assert items == [{
|
94 |
-
HF_SPLIT_COLUMN: 'default',
|
95 |
-
'scalar': 1,
|
96 |
-
'seq': [1, 0],
|
97 |
-
'seq_dict': {
|
98 |
-
'x': [1, 2, 3],
|
99 |
-
'y': ['four', 'five', 'six']
|
100 |
-
}
|
101 |
-
}, {
|
102 |
-
HF_SPLIT_COLUMN: 'default',
|
103 |
-
'scalar': 2,
|
104 |
-
'seq': [2, 0],
|
105 |
-
'seq_dict': {
|
106 |
-
'x': [10, 20, 30],
|
107 |
-
'y': ['forty', 'fifty', 'sixty']
|
108 |
-
}
|
109 |
-
}]
|
110 |
-
|
111 |
-
|
112 |
-
def test_hf_list(tmp_path: pathlib.Path) -> None:
|
113 |
-
dataset = Dataset.from_list([{
|
114 |
-
'scalar': 1,
|
115 |
-
'list': [{
|
116 |
-
'x': 1,
|
117 |
-
'y': 'two'
|
118 |
-
}]
|
119 |
-
}, {
|
120 |
-
'scalar': 2,
|
121 |
-
'list': [{
|
122 |
-
'x': 3,
|
123 |
-
'y': 'four'
|
124 |
-
}]
|
125 |
-
}],
|
126 |
-
features=Features({
|
127 |
-
'scalar': Value(dtype='int64'),
|
128 |
-
'list': [{
|
129 |
-
'x': Value(dtype='int64'),
|
130 |
-
'y': Value(dtype='string')
|
131 |
-
}]
|
132 |
-
}))
|
133 |
-
|
134 |
-
dataset_name = os.path.join(tmp_path, 'hf-test-dataset')
|
135 |
-
dataset.save_to_disk(dataset_name)
|
136 |
-
|
137 |
-
source = HuggingFaceDataset(dataset_name=dataset_name, load_from_disk=True)
|
138 |
-
|
139 |
-
items = source.process()
|
140 |
-
source.setup()
|
141 |
-
|
142 |
-
source_schema = source.source_schema()
|
143 |
-
assert source_schema == SourceSchema(
|
144 |
-
fields=schema({
|
145 |
-
HF_SPLIT_COLUMN: 'string',
|
146 |
-
'scalar': 'int64',
|
147 |
-
'list': [{
|
148 |
-
'x': 'int64',
|
149 |
-
'y': 'string',
|
150 |
-
}],
|
151 |
-
}).fields,
|
152 |
-
num_items=2)
|
153 |
-
|
154 |
-
items = list(source.process())
|
155 |
-
|
156 |
-
assert items == [{
|
157 |
-
HF_SPLIT_COLUMN: 'default',
|
158 |
-
'scalar': 1,
|
159 |
-
'list': [{
|
160 |
-
'x': 1,
|
161 |
-
'y': 'two'
|
162 |
-
}]
|
163 |
-
}, {
|
164 |
-
HF_SPLIT_COLUMN: 'default',
|
165 |
-
'scalar': 2,
|
166 |
-
'list': [{
|
167 |
-
'x': 3,
|
168 |
-
'y': 'four'
|
169 |
-
}]
|
170 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/sources/json_source_test.py
DELETED
@@ -1,74 +0,0 @@
|
|
1 |
-
"""Tests for the JSON source."""
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
import pathlib
|
5 |
-
|
6 |
-
from ...schema import schema
|
7 |
-
from .json_source import ROW_ID_COLUMN, JSONDataset
|
8 |
-
from .source import SourceSchema
|
9 |
-
|
10 |
-
|
11 |
-
def test_simple_json(tmp_path: pathlib.Path) -> None:
|
12 |
-
json_records = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
|
13 |
-
|
14 |
-
filename = 'test-dataset.jsonl'
|
15 |
-
filepath = os.path.join(tmp_path, filename)
|
16 |
-
with open(filepath, 'w') as f:
|
17 |
-
f.write(json.dumps(json_records))
|
18 |
-
|
19 |
-
source = JSONDataset(filepaths=[filepath])
|
20 |
-
source.setup()
|
21 |
-
|
22 |
-
source_schema = source.source_schema()
|
23 |
-
assert source_schema == SourceSchema(
|
24 |
-
fields=schema({
|
25 |
-
ROW_ID_COLUMN: 'int64',
|
26 |
-
'x': 'int64',
|
27 |
-
'y': 'string'
|
28 |
-
}).fields, num_items=2)
|
29 |
-
|
30 |
-
items = list(source.process())
|
31 |
-
|
32 |
-
assert items == [{
|
33 |
-
ROW_ID_COLUMN: 0,
|
34 |
-
'x': 1,
|
35 |
-
'y': 'ten'
|
36 |
-
}, {
|
37 |
-
ROW_ID_COLUMN: 1,
|
38 |
-
'x': 2,
|
39 |
-
'y': 'twenty'
|
40 |
-
}]
|
41 |
-
|
42 |
-
|
43 |
-
def test_simple_jsonl(tmp_path: pathlib.Path) -> None:
|
44 |
-
json_records = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
|
45 |
-
json_lines = [json.dumps(record) + '\n' for record in json_records]
|
46 |
-
|
47 |
-
filename = 'test-dataset.jsonl'
|
48 |
-
filepath = os.path.join(tmp_path, filename)
|
49 |
-
with open(filepath, 'w') as f:
|
50 |
-
f.writelines(json_lines)
|
51 |
-
|
52 |
-
source = JSONDataset(dataset_name='test_dataset', filepaths=[filepath])
|
53 |
-
source.setup()
|
54 |
-
|
55 |
-
source_schema = source.source_schema()
|
56 |
-
|
57 |
-
assert source_schema == SourceSchema(
|
58 |
-
fields=schema({
|
59 |
-
ROW_ID_COLUMN: 'int64',
|
60 |
-
'x': 'int64',
|
61 |
-
'y': 'string'
|
62 |
-
}).fields, num_items=2)
|
63 |
-
|
64 |
-
items = list(source.process())
|
65 |
-
|
66 |
-
assert items == [{
|
67 |
-
ROW_ID_COLUMN: 0,
|
68 |
-
'x': 1,
|
69 |
-
'y': 'ten'
|
70 |
-
}, {
|
71 |
-
ROW_ID_COLUMN: 1,
|
72 |
-
'x': 2,
|
73 |
-
'y': 'twenty'
|
74 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/sources/pandas_source_test.py
DELETED
@@ -1,91 +0,0 @@
|
|
1 |
-
"""Tests for the pandas source."""
|
2 |
-
|
3 |
-
import pandas as pd
|
4 |
-
|
5 |
-
from ...schema import schema
|
6 |
-
from .pandas_source import PANDAS_INDEX_COLUMN, PandasDataset
|
7 |
-
from .source import SourceSchema
|
8 |
-
|
9 |
-
|
10 |
-
def test_simple_dataframe() -> None:
|
11 |
-
df = pd.DataFrame.from_records([{
|
12 |
-
'name': 'a',
|
13 |
-
'age': 1
|
14 |
-
}, {
|
15 |
-
'name': 'b',
|
16 |
-
'age': 2
|
17 |
-
}, {
|
18 |
-
'name': 'c',
|
19 |
-
'age': 3
|
20 |
-
}])
|
21 |
-
|
22 |
-
source = PandasDataset(df)
|
23 |
-
source.setup()
|
24 |
-
|
25 |
-
source_schema = source.source_schema()
|
26 |
-
assert source_schema == SourceSchema(
|
27 |
-
fields=schema({
|
28 |
-
PANDAS_INDEX_COLUMN: 'int64',
|
29 |
-
'name': 'string',
|
30 |
-
'age': 'int64'
|
31 |
-
}).fields,
|
32 |
-
num_items=3)
|
33 |
-
|
34 |
-
items = list(source.process())
|
35 |
-
|
36 |
-
assert items == [{
|
37 |
-
PANDAS_INDEX_COLUMN: 0,
|
38 |
-
'name': 'a',
|
39 |
-
'age': 1
|
40 |
-
}, {
|
41 |
-
PANDAS_INDEX_COLUMN: 1,
|
42 |
-
'name': 'b',
|
43 |
-
'age': 2
|
44 |
-
}, {
|
45 |
-
PANDAS_INDEX_COLUMN: 2,
|
46 |
-
'name': 'c',
|
47 |
-
'age': 3
|
48 |
-
}]
|
49 |
-
|
50 |
-
|
51 |
-
def test_simple_dataframe_with_index() -> None:
|
52 |
-
df = pd.DataFrame.from_records([{
|
53 |
-
'name': 'a',
|
54 |
-
'age': 1
|
55 |
-
}, {
|
56 |
-
'name': 'b',
|
57 |
-
'age': 2
|
58 |
-
}, {
|
59 |
-
'name': 'c',
|
60 |
-
'age': 3
|
61 |
-
}],
|
62 |
-
index=['id1', 'id2', 'id3'])
|
63 |
-
|
64 |
-
source = PandasDataset(df)
|
65 |
-
source.setup()
|
66 |
-
|
67 |
-
source_schema = source.source_schema()
|
68 |
-
assert source_schema == SourceSchema(
|
69 |
-
fields=schema({
|
70 |
-
PANDAS_INDEX_COLUMN: 'string',
|
71 |
-
'name': 'string',
|
72 |
-
'age': 'int64'
|
73 |
-
}).fields,
|
74 |
-
num_items=3)
|
75 |
-
|
76 |
-
items = list(source.process())
|
77 |
-
|
78 |
-
# The PANDAS_INDEX_COLUMN aligns with the pandas index.
|
79 |
-
assert items == [{
|
80 |
-
PANDAS_INDEX_COLUMN: 'id1',
|
81 |
-
'name': 'a',
|
82 |
-
'age': 1
|
83 |
-
}, {
|
84 |
-
PANDAS_INDEX_COLUMN: 'id2',
|
85 |
-
'name': 'b',
|
86 |
-
'age': 2
|
87 |
-
}, {
|
88 |
-
PANDAS_INDEX_COLUMN: 'id3',
|
89 |
-
'name': 'c',
|
90 |
-
'age': 3
|
91 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/sources/source_registry_test.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
"""A source to compute semantic search for a document."""
|
2 |
-
from typing import Iterable, cast
|
3 |
-
|
4 |
-
import pytest
|
5 |
-
from typing_extensions import override
|
6 |
-
|
7 |
-
from ...schema import Item
|
8 |
-
from .source import Source, SourceSchema
|
9 |
-
from .source_registry import clear_source_registry, get_source_cls, register_source, resolve_source
|
10 |
-
|
11 |
-
|
12 |
-
class TestSource(Source):
|
13 |
-
"""A test source."""
|
14 |
-
name = 'test_source'
|
15 |
-
|
16 |
-
@override
|
17 |
-
def setup(self) -> None:
|
18 |
-
pass
|
19 |
-
|
20 |
-
@override
|
21 |
-
def source_schema(self) -> SourceSchema:
|
22 |
-
"""Return the source schema."""
|
23 |
-
return cast(SourceSchema, None)
|
24 |
-
|
25 |
-
@override
|
26 |
-
def process(self) -> Iterable[Item]:
|
27 |
-
yield None
|
28 |
-
|
29 |
-
|
30 |
-
@pytest.fixture(scope='module', autouse=True)
|
31 |
-
def setup_teardown() -> Iterable[None]:
|
32 |
-
# Setup.
|
33 |
-
register_source(TestSource)
|
34 |
-
|
35 |
-
# Unit test runs.
|
36 |
-
yield
|
37 |
-
|
38 |
-
# Teardown.
|
39 |
-
clear_source_registry()
|
40 |
-
|
41 |
-
|
42 |
-
def test_get_source_cls() -> None:
|
43 |
-
"""Test getting a source."""
|
44 |
-
assert TestSource == get_source_cls('test_source')
|
45 |
-
|
46 |
-
|
47 |
-
def test_resolve_source() -> None:
|
48 |
-
"""Test resolving a source."""
|
49 |
-
test_source = TestSource()
|
50 |
-
|
51 |
-
# sources pass through.
|
52 |
-
assert resolve_source(test_source) == test_source
|
53 |
-
|
54 |
-
# Dicts resolve to the base class.
|
55 |
-
assert resolve_source(test_source.dict()) == test_source
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data_loader_test.py
DELETED
@@ -1,74 +0,0 @@
|
|
1 |
-
"""Tests for data_loader.py."""
|
2 |
-
|
3 |
-
import os
|
4 |
-
import pathlib
|
5 |
-
import uuid
|
6 |
-
from typing import Iterable
|
7 |
-
|
8 |
-
from pytest_mock import MockerFixture
|
9 |
-
from typing_extensions import override
|
10 |
-
|
11 |
-
from .data.dataset_duckdb import read_source_manifest
|
12 |
-
from .data.dataset_utils import parquet_filename
|
13 |
-
from .data.sources.source import Source, SourceSchema
|
14 |
-
from .data_loader import process_source
|
15 |
-
from .schema import PARQUET_FILENAME_PREFIX, UUID_COLUMN, Item, SourceManifest, schema
|
16 |
-
from .test_utils import fake_uuid, read_items
|
17 |
-
from .utils import DATASETS_DIR_NAME
|
18 |
-
|
19 |
-
|
20 |
-
class TestSource(Source):
|
21 |
-
"""A test source."""
|
22 |
-
name = 'test_source'
|
23 |
-
|
24 |
-
@override
|
25 |
-
def setup(self) -> None:
|
26 |
-
pass
|
27 |
-
|
28 |
-
@override
|
29 |
-
def source_schema(self) -> SourceSchema:
|
30 |
-
"""Return the source schema."""
|
31 |
-
return SourceSchema(fields=schema({'x': 'int64', 'y': 'string'}).fields, num_items=2)
|
32 |
-
|
33 |
-
@override
|
34 |
-
def process(self) -> Iterable[Item]:
|
35 |
-
return [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
|
36 |
-
|
37 |
-
|
38 |
-
def test_data_loader(tmp_path: pathlib.Path, mocker: MockerFixture) -> None:
|
39 |
-
mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True)
|
40 |
-
mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2')]
|
41 |
-
|
42 |
-
source = TestSource()
|
43 |
-
setup_mock = mocker.spy(TestSource, 'setup')
|
44 |
-
|
45 |
-
output_dir, num_items = process_source(tmp_path, 'test_namespace', 'test_dataset', source)
|
46 |
-
|
47 |
-
assert setup_mock.call_count == 1
|
48 |
-
|
49 |
-
assert output_dir == os.path.join(tmp_path, DATASETS_DIR_NAME, 'test_namespace', 'test_dataset')
|
50 |
-
assert num_items == 2
|
51 |
-
|
52 |
-
source_manifest = read_source_manifest(output_dir)
|
53 |
-
|
54 |
-
assert source_manifest == SourceManifest(
|
55 |
-
files=[parquet_filename(PARQUET_FILENAME_PREFIX, 0, 1)],
|
56 |
-
data_schema=schema({
|
57 |
-
# UUID_COLUMN is generated by the data loader.
|
58 |
-
UUID_COLUMN: 'string',
|
59 |
-
'x': 'int64',
|
60 |
-
'y': 'string'
|
61 |
-
}),
|
62 |
-
)
|
63 |
-
|
64 |
-
items = read_items(output_dir, source_manifest.files, source_manifest.data_schema)
|
65 |
-
|
66 |
-
assert items == [{
|
67 |
-
UUID_COLUMN: fake_uuid(b'1').hex,
|
68 |
-
'x': 1,
|
69 |
-
'y': 'ten'
|
70 |
-
}, {
|
71 |
-
UUID_COLUMN: fake_uuid(b'2').hex,
|
72 |
-
'x': 2,
|
73 |
-
'y': 'twenty'
|
74 |
-
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/embeddings/embedding.py
CHANGED
@@ -57,7 +57,7 @@ def compute_split_embeddings(docs: Iterable[str],
|
|
57 |
pool = ThreadPoolExecutor()
|
58 |
|
59 |
def _splitter(doc: str) -> list[TextChunk]:
|
60 |
-
if doc
|
61 |
return []
|
62 |
if split_fn:
|
63 |
return split_fn(doc)
|
@@ -65,15 +65,19 @@ def compute_split_embeddings(docs: Iterable[str],
|
|
65 |
# Return a single chunk that spans the entire document.
|
66 |
return [(doc, (0, len(doc)))]
|
67 |
|
|
|
|
|
68 |
def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
|
69 |
"""Split a batch of documents into chunks and yield them."""
|
|
|
70 |
for i, doc in enumerate(docs):
|
71 |
-
|
|
|
72 |
for chunk in chunks:
|
73 |
yield (i, chunk)
|
74 |
|
75 |
doc_chunks = _flat_split_batch_docs(docs)
|
76 |
-
items_to_yield: list[Item] =
|
77 |
current_index = 0
|
78 |
|
79 |
mega_batch_size = batch_size * num_parallel_requests
|
@@ -81,19 +85,27 @@ def compute_split_embeddings(docs: Iterable[str],
|
|
81 |
for batch in chunks(doc_chunks, mega_batch_size):
|
82 |
texts = [text for _, (text, _) in batch]
|
83 |
embeddings: list[np.ndarray] = []
|
|
|
84 |
for x in list(pool.map(lambda x: embed_fn(x), chunks(texts, batch_size))):
|
85 |
embeddings.extend(x)
|
86 |
matrix = normalize(np.array(embeddings)).astype(np.float16)
|
87 |
# np.split returns a shallow copy of each embedding so we don't increase the mem footprint.
|
88 |
embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
|
89 |
for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
|
|
|
90 |
if index == current_index:
|
|
|
|
|
91 |
items_to_yield.append(lilac_embedding(start, end, embedding))
|
92 |
else:
|
93 |
yield items_to_yield
|
|
|
|
|
|
|
|
|
94 |
items_to_yield = [lilac_embedding(start, end, embedding)]
|
95 |
-
current_index = index
|
96 |
|
97 |
-
|
98 |
-
if items_to_yield:
|
99 |
yield items_to_yield
|
|
|
|
|
|
57 |
pool = ThreadPoolExecutor()
|
58 |
|
59 |
def _splitter(doc: str) -> list[TextChunk]:
|
60 |
+
if not doc:
|
61 |
return []
|
62 |
if split_fn:
|
63 |
return split_fn(doc)
|
|
|
65 |
# Return a single chunk that spans the entire document.
|
66 |
return [(doc, (0, len(doc)))]
|
67 |
|
68 |
+
num_docs = 0
|
69 |
+
|
70 |
def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
|
71 |
"""Split a batch of documents into chunks and yield them."""
|
72 |
+
nonlocal num_docs
|
73 |
for i, doc in enumerate(docs):
|
74 |
+
num_docs += 1
|
75 |
+
chunks = _splitter(doc)
|
76 |
for chunk in chunks:
|
77 |
yield (i, chunk)
|
78 |
|
79 |
doc_chunks = _flat_split_batch_docs(docs)
|
80 |
+
items_to_yield: Optional[list[Item]] = None
|
81 |
current_index = 0
|
82 |
|
83 |
mega_batch_size = batch_size * num_parallel_requests
|
|
|
85 |
for batch in chunks(doc_chunks, mega_batch_size):
|
86 |
texts = [text for _, (text, _) in batch]
|
87 |
embeddings: list[np.ndarray] = []
|
88 |
+
|
89 |
for x in list(pool.map(lambda x: embed_fn(x), chunks(texts, batch_size))):
|
90 |
embeddings.extend(x)
|
91 |
matrix = normalize(np.array(embeddings)).astype(np.float16)
|
92 |
# np.split returns a shallow copy of each embedding so we don't increase the mem footprint.
|
93 |
embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
|
94 |
for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
|
95 |
+
embedding = embedding.reshape(-1)
|
96 |
if index == current_index:
|
97 |
+
if items_to_yield is None:
|
98 |
+
items_to_yield = []
|
99 |
items_to_yield.append(lilac_embedding(start, end, embedding))
|
100 |
else:
|
101 |
yield items_to_yield
|
102 |
+
current_index += 1
|
103 |
+
while current_index < index:
|
104 |
+
yield None
|
105 |
+
current_index += 1
|
106 |
items_to_yield = [lilac_embedding(start, end, embedding)]
|
|
|
107 |
|
108 |
+
while current_index < num_docs:
|
|
|
109 |
yield items_to_yield
|
110 |
+
items_to_yield = None
|
111 |
+
current_index += 1
|