RVC_V2 / docs /kr /faiss_tips_ko.md
Rejekts's picture
Add folders
528df8b
|
raw
history blame
9.18 kB

Facebook AI Similarity Search (Faiss) ํŒ

Faiss์— ๋Œ€ํ•˜์—ฌ

Faiss ๋Š” Facebook Research๊ฐ€ ๊ฐœ๋ฐœํ•˜๋Š”, ๊ณ ๋ฐ€๋„ ๋ฒกํ„ฐ ์ด์›ƒ ๊ฒ€์ƒ‰ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์ž…๋‹ˆ๋‹ค. ๊ทผ์‚ฌ ๊ทผ์ ‘ ํƒ์ƒ‰๋ฒ• (Approximate Neigbor Search)์€ ์•ฝ๊ฐ„์˜ ์ •ํ™•์„ฑ์„ ํฌ์ƒํ•˜์—ฌ ์œ ์‚ฌ ๋ฒกํ„ฐ๋ฅผ ๊ณ ์†์œผ๋กœ ์ฐพ์Šต๋‹ˆ๋‹ค.

RVC์— ์žˆ์–ด์„œ Faiss

RVC์—์„œ๋Š” HuBERT๋กœ ๋ณ€ํ™˜ํ•œ feature์˜ embedding์„ ์œ„ํ•ด ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์—์„œ ์ƒ์„ฑ๋œ embedding๊ณผ ์œ ์‚ฌํ•œ embadding์„ ๊ฒ€์ƒ‰ํ•˜๊ณ  ํ˜ผํ•ฉํ•˜์—ฌ ์›๋ž˜์˜ ์Œ์„ฑ์— ๋”์šฑ ๊ฐ€๊นŒ์šด ๋ณ€ํ™˜์„ ๋‹ฌ์„ฑํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜, ์ด ํƒ์ƒ‰๋ฒ•์€ ๋‹จ์ˆœํžˆ ์ˆ˜ํ–‰ํ•˜๋ฉด ์‹œ๊ฐ„์ด ๋‹ค์†Œ ์†Œ๋ชจ๋˜๋ฏ€๋กœ, ๊ทผ์‚ฌ ๊ทผ์ ‘ ํƒ์ƒ‰๋ฒ•์„ ํ†ตํ•ด ๊ณ ์† ๋ณ€ํ™˜์„ ๊ฐ€๋Šฅ์ผ€ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

๊ตฌํ˜„ ๊ฐœ์š”

๋ชจ๋ธ์ด ์œ„์น˜ํ•œ /logs/your-experiment/3_feature256์—๋Š” ๊ฐ ์Œ์„ฑ ๋ฐ์ดํ„ฐ์—์„œ HuBERT๊ฐ€ ์ถ”์ถœํ•œ feature๋“ค์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์—์„œ ํŒŒ์ผ ์ด๋ฆ„๋ณ„๋กœ ์ •๋ ฌ๋œ npy ํŒŒ์ผ์„ ์ฝ๊ณ , ๋ฒกํ„ฐ๋ฅผ ์—ฐ๊ฒฐํ•˜์—ฌ big_npy ([N, 256] ๋ชจ์–‘์˜ ๋ฒกํ„ฐ) ๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค. big_npy๋ฅผ /logs/your-experiment/total_fea.npy๋กœ ์ €์žฅํ•œ ํ›„, Faiss๋กœ ํ•™์Šต์‹œํ‚ต๋‹ˆ๋‹ค.

2023/04/18 ๊ธฐ์ค€์œผ๋กœ, Faiss์˜ Index Factory ๊ธฐ๋Šฅ์„ ์ด์šฉํ•ด, L2 ๊ฑฐ๋ฆฌ์— ๊ทผ๊ฑฐํ•˜๋Š” IVF๋ฅผ ์ด์šฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. IVF์˜ ๋ถ„ํ• ์ˆ˜(n_ivf)๋Š” N//39๋กœ, n_probe๋Š” int(np.power(n_ivf, 0.3))๊ฐ€ ์‚ฌ์šฉ๋˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. (infer-web.py์˜ train_index ์ฃผ์œ„๋ฅผ ์ฐพ์œผ์‹ญ์‹œ์˜ค.)

์ด ํŒ์—์„œ๋Š” ๋จผ์ € ์ด๋Ÿฌํ•œ ๋งค๊ฐœ ๋ณ€์ˆ˜์˜ ์˜๋ฏธ๋ฅผ ์„ค๋ช…ํ•˜๊ณ , ๊ฐœ๋ฐœ์ž๊ฐ€ ์ถ”ํ›„ ๋” ๋‚˜์€ index๋ฅผ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ์กฐ์–ธ์„ ์ž‘์„ฑํ•ฉ๋‹ˆ๋‹ค.

๋ฐฉ๋ฒ•์˜ ์„ค๋ช…

Index factory

index factory๋Š” ์—ฌ๋Ÿฌ ๊ทผ์‚ฌ ๊ทผ์ ‘ ํƒ์ƒ‰๋ฒ•์„ ๋ฌธ์ž์—ด๋กœ ์—ฐ๊ฒฐํ•˜๋Š” pipeline์„ ๋ฌธ์ž์—ด๋กœ ํ‘œ๊ธฐํ•˜๋Š” Faiss๋งŒ์˜ ๋…์ž์ ์ธ ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด index factory์˜ ๋ฌธ์ž์—ด์„ ๋ณ€๊ฒฝํ•˜๋Š” ๊ฒƒ๋งŒ์œผ๋กœ ๋‹ค์–‘ํ•œ ๊ทผ์‚ฌ ๊ทผ์ ‘ ํƒ์ƒ‰์„ ์‹œ๋„ํ•ด ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. RVC์—์„œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค:

index = Faiss.index_factory(256, "IVF%s,Flat" % n_ivf)

index_factory์˜ ์ธ์ˆ˜๋“ค ์ค‘ ์ฒซ ๋ฒˆ์งธ๋Š” ๋ฒกํ„ฐ์˜ ์ฐจ์› ์ˆ˜์ด๊ณ , ๋‘๋ฒˆ์งธ๋Š” index factory ๋ฌธ์ž์—ด์ด๋ฉฐ, ์„ธ๋ฒˆ์งธ์—๋Š” ์‚ฌ์šฉํ•  ๊ฑฐ๋ฆฌ๋ฅผ ์ง€์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๊ธฐ๋ฒ•์˜ ๋ณด๋‹ค ์ž์„ธํ•œ ์„ค๋ช…์€ https://github.com/facebookresearch/Faiss/wiki/The-index-factory ๋ฅผ ํ™•์ธํ•ด ์ฃผ์‹ญ์‹œ์˜ค.

๊ฑฐ๋ฆฌ์— ๋Œ€ํ•œ index

embedding์˜ ์œ ์‚ฌ๋„๋กœ์„œ ์‚ฌ์šฉ๋˜๋Š” ๋Œ€ํ‘œ์ ์ธ ์ง€ํ‘œ๋กœ์„œ ์ดํ•˜์˜ 2๊ฐœ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.

  • ์œ ํด๋ฆฌ๋“œ ๊ฑฐ๋ฆฌ (METRIC_L2)
  • ๋‚ด์ (ๅ†…็ฉ) (METRIC_INNER_PRODUCT)

์œ ํด๋ฆฌ๋“œ ๊ฑฐ๋ฆฌ์—์„œ๋Š” ๊ฐ ์ฐจ์›์—์„œ ์ œ๊ณฑ์˜ ์ฐจ๋ฅผ ๊ตฌํ•˜๊ณ , ๊ฐ ์ฐจ์›์—์„œ ๊ตฌํ•œ ์ฐจ๋ฅผ ๋ชจ๋‘ ๋”ํ•œ ํ›„ ์ œ๊ณฑ๊ทผ์„ ์ทจํ•ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ์ผ์ƒ์ ์œผ๋กœ ์‚ฌ์šฉ๋˜๋Š” 2์ฐจ์›, 3์ฐจ์›์—์„œ์˜ ๊ฑฐ๋ฆฌ์˜ ์—ฐ์‚ฐ๋ฒ•๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค. ๋‚ด์ ์€ ๊ทธ ๊ฐ’์„ ๊ทธ๋Œ€๋กœ ์œ ์‚ฌ๋„ ์ง€ํ‘œ๋กœ ์‚ฌ์šฉํ•˜์ง€ ์•Š๊ณ , L2 ์ •๊ทœํ™”๋ฅผ ํ•œ ์ดํ›„ ๋‚ด์ ์„ ์ทจํ•˜๋Š” ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

์–ด๋Š ์ชฝ์ด ๋” ์ข‹์€์ง€๋Š” ๊ฒฝ์šฐ์— ๋”ฐ๋ผ ๋‹ค๋ฅด์ง€๋งŒ, word2vec์—์„œ ์–ป์€ embedding ๋ฐ ArcFace๋ฅผ ํ™œ์šฉํ•œ ์ด๋ฏธ์ง€ ๊ฒ€์ƒ‰ ๋ชจ๋ธ์€ ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ์„ฑ์ด ์ด์šฉ๋˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค. numpy๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ฒกํ„ฐ X์— ๋Œ€ํ•ด L2 ์ •๊ทœํ™”๋ฅผ ํ•˜๊ณ ์ž ํ•˜๋Š” ๊ฒฝ์šฐ, 0 division์„ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด ์ถฉ๋ถ„ํžˆ ์ž‘์€ ๊ฐ’์„ eps๋กœ ํ•œ ๋’ค ์ดํ•˜์— ์ฝ”๋“œ๋ฅผ ํ™œ์šฉํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

X_normed = X / np.maximum(eps, np.linalg.norm(X, ord=2, axis=-1, keepdims=True))

๋˜ํ•œ, index factory์˜ 3๋ฒˆ์งธ ์ธ์ˆ˜์— ๊ฑด๋„ค์ฃผ๋Š” ๊ฐ’์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์„ ํ†ตํ•ด ๊ณ„์‚ฐ์— ์‚ฌ์šฉํ•˜๋Š” ๊ฑฐ๋ฆฌ index๋ฅผ ๋ณ€๊ฒฝํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

index = Faiss.index_factory(dimention, text, Faiss.METRIC_INNER_PRODUCT)

IVF

IVF (Inverted file indexes)๋Š” ์—ญ์ƒ‰์ธ ํƒ์ƒ‰๋ฒ•๊ณผ ์œ ์‚ฌํ•œ ์•Œ๊ณ ๋ฆฌ์ฆ˜์ž…๋‹ˆ๋‹ค. ํ•™์Šต์‹œ์—๋Š” ๊ฒ€์ƒ‰ ๋Œ€์ƒ์— ๋Œ€ํ•ด k-ํ‰๊ท  ๊ตฐ์ง‘๋ฒ•์„ ์‹ค์‹œํ•˜๊ณ  ํด๋Ÿฌ์Šคํ„ฐ ์ค‘์‹ฌ์„ ์ด์šฉํ•ด ๋ณด๋กœ๋…ธ์ด ๋ถ„ํ• ์„ ์‹ค์‹œํ•ฉ๋‹ˆ๋‹ค. ๊ฐ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ์—๋Š” ํด๋Ÿฌ์Šคํ„ฐ๊ฐ€ ํ• ๋‹น๋˜๋ฏ€๋กœ, ํด๋Ÿฌ์Šคํ„ฐ์—์„œ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ๋ฅผ ์กฐํšŒํ•˜๋Š” dictionary๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด, ํด๋Ÿฌ์Šคํ„ฐ๊ฐ€ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ํ• ๋‹น๋œ ๊ฒฝ์šฐ

index Cluster
1 A
2 B
3 A
4 C
5 B

IVF ์ดํ›„์˜ ๊ฒฐ๊ณผ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

cluster index
A 1, 3
B 2, 5
C 4

ํƒ์ƒ‰ ์‹œ, ์šฐ์„  ํด๋Ÿฌ์Šคํ„ฐ์—์„œ n_probe๊ฐœ์˜ ํด๋Ÿฌ์Šคํ„ฐ๋ฅผ ํƒ์ƒ‰ํ•œ ๋‹ค์Œ, ๊ฐ ํด๋Ÿฌ์Šคํ„ฐ์— ์†ํ•œ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ์˜ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.

๊ถŒ์žฅ ๋งค๊ฐœ๋ณ€์ˆ˜

index์˜ ์„ ํƒ ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด์„œ๋Š” ๊ณต์‹์ ์œผ๋กœ ๊ฐ€์ด๋“œ ๋ผ์ธ์ด ์žˆ์œผ๋ฏ€๋กœ, ๊ฑฐ๊ธฐ์— ์ค€ํ•ด ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค. https://github.com/facebookresearch/Faiss/wiki/Guidelines-to-choose-an-index

1M ์ดํ•˜์˜ ๋ฐ์ดํ„ฐ ์„ธํŠธ์— ์žˆ์–ด์„œ๋Š” 4bit-PQ๊ฐ€ 2023๋…„ 4์›” ์‹œ์ ์—์„œ๋Š” Faiss๋กœ ์ด์šฉํ•  ์ˆ˜ ์žˆ๋Š” ๊ฐ€์žฅ ํšจ์œจ์ ์ธ ์ˆ˜๋ฒ•์ž…๋‹ˆ๋‹ค. ์ด๊ฒƒ์„ IVF์™€ ์กฐํ•ฉํ•ด, 4bit-PQ๋กœ ํ›„๋ณด๋ฅผ ์ถ”๋ ค๋‚ด๊ณ , ๋งˆ์ง€๋ง‰์œผ๋กœ ์ดํ•˜์˜ index factory๋ฅผ ์ด์šฉํ•˜์—ฌ ์ •ํ™•ํ•œ ์ง€ํ‘œ๋กœ ๊ฑฐ๋ฆฌ๋ฅผ ์žฌ๊ณ„์‚ฐํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

index = Faiss.index_factory(256, "IVF1024,PQ128x4fs,RFlat")

IVF ๊ถŒ์žฅ ๋งค๊ฐœ๋ณ€์ˆ˜

IVF์˜ ์ˆ˜๊ฐ€ ๋„ˆ๋ฌด ๋งŽ์œผ๋ฉด, ๊ฐ€๋ น ๋ฐ์ดํ„ฐ ์ˆ˜์˜ ์ˆ˜๋งŒํผ IVF๋กœ ์–‘์žํ™”(Quantization)๋ฅผ ์ˆ˜ํ–‰ํ•˜๋ฉด, ์ด๊ฒƒ์€ ์™„์ „ํƒ์ƒ‰๊ณผ ๊ฐ™์•„์ ธ ํšจ์œจ์ด ๋‚˜๋น ์ง€๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. 1M ์ดํ•˜์˜ ๊ฒฝ์šฐ IVF ๊ฐ’์€ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ ์ˆ˜ N์— ๋Œ€ํ•ด 4sqrt(N) ~ 16sqrt(N)๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์„ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค.

n_probe๋Š” n_probe์˜ ์ˆ˜์— ๋น„๋ก€ํ•˜์—ฌ ๊ณ„์‚ฐ ์‹œ๊ฐ„์ด ๋Š˜์–ด๋‚˜๋ฏ€๋กœ ์ •ํ™•๋„์™€ ์‹œ๊ฐ„์„ ์ ์ ˆํžˆ ๊ท ํ˜•์„ ๋งž์ถ”์–ด ์ฃผ์‹ญ์‹œ์˜ค. ๊ฐœ์ธ์ ์œผ๋กœ RVC์— ์žˆ์–ด์„œ ๊ทธ๋ ‡๊ฒŒ๊นŒ์ง€ ์ •ํ™•๋„๋Š” ํ•„์š” ์—†๋‹ค๊ณ  ์ƒ๊ฐํ•˜๊ธฐ ๋•Œ๋ฌธ์— n_probe = 1์ด๋ฉด ๋œ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.

FastScan

FastScan์€ ์ง์  ์–‘์žํ™”๋ฅผ ๋ ˆ์ง€์Šคํ„ฐ์—์„œ ์ˆ˜ํ–‰ํ•จ์œผ๋กœ์จ ๊ฑฐ๋ฆฌ์˜ ๊ณ ์† ๊ทผ์‚ฌ๋ฅผ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.์ง์  ์–‘์žํ™”๋Š” ํ•™์Šต์‹œ์— d์ฐจ์›๋งˆ๋‹ค(๋ณดํ†ต d=2)์— ๋…๋ฆฝ์ ์œผ๋กœ ํด๋Ÿฌ์Šคํ„ฐ๋ง์„ ์‹ค์‹œํ•ด, ํด๋Ÿฌ์Šคํ„ฐ๋ผ๋ฆฌ์˜ ๊ฑฐ๋ฆฌ๋ฅผ ์‚ฌ์ „ ๊ณ„์‚ฐํ•ด lookup table๋ฅผ ์ž‘์„ฑํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ์ธก์‹œ๋Š” lookup table์„ ๋ณด๋ฉด ๊ฐ ์ฐจ์›์˜ ๊ฑฐ๋ฆฌ๋ฅผ O(1)๋กœ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ PQ ๋‹ค์Œ์— ์ง€์ •ํ•˜๋Š” ์ˆซ์ž๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ๋ฒกํ„ฐ์˜ ์ ˆ๋ฐ˜ ์ฐจ์›์„ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.

FastScan์— ๋Œ€ํ•œ ์ž์„ธํ•œ ์„ค๋ช…์€ ๊ณต์‹ ๋ฌธ์„œ๋ฅผ ์ฐธ์กฐํ•˜์‹ญ์‹œ์˜ค. https://github.com/facebookresearch/Faiss/wiki/Fast-accumulation-of-PQ-and-AQ-codes-(FastScan)

RFlat

RFlat์€ FastScan์ด ๊ณ„์‚ฐํ•œ ๋Œ€๋žต์ ์ธ ๊ฑฐ๋ฆฌ๋ฅผ index factory์˜ 3๋ฒˆ์งธ ์ธ์ˆ˜๋กœ ์ง€์ •ํ•œ ์ •ํ™•ํ•œ ๊ฑฐ๋ฆฌ๋กœ ๋‹ค์‹œ ๊ณ„์‚ฐํ•˜๋ผ๋Š” ์ธ์ŠคํŠธ๋Ÿญ์…˜์ž…๋‹ˆ๋‹ค. k๊ฐœ์˜ ๊ทผ์ ‘ ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ ธ์˜ฌ ๋•Œ k*k_factor๊ฐœ์˜ ์ ์— ๋Œ€ํ•ด ์žฌ๊ณ„์‚ฐ์ด ์ด๋ฃจ์–ด์ง‘๋‹ˆ๋‹ค.

Embedding ํ…Œํฌ๋‹‰

Alpha ์ฟผ๋ฆฌ ํ™•์žฅ

ํ€ด๋ฆฌ ํ™•์žฅ์ด๋ž€ ํƒ์ƒ‰์—์„œ ์‚ฌ์šฉ๋˜๋Š” ๊ธฐ์ˆ ๋กœ, ์˜ˆ๋ฅผ ๋“ค์–ด ์ „๋ฌธ ํƒ์ƒ‰ ์‹œ, ์ž…๋ ฅ๋œ ๊ฒ€์ƒ‰๋ฌธ์— ๋‹จ์–ด๋ฅผ ๋ช‡ ๊ฐœ๋ฅผ ์ถ”๊ฐ€ํ•จ์œผ๋กœ์จ ๊ฒ€์ƒ‰ ์ •ํ™•๋„๋ฅผ ์˜ฌ๋ฆฌ๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค. ๋ฐฑํ„ฐ ํƒ์ƒ‰์„ ์œ„ํ•ด์„œ๋„ ๋ช‡๊ฐ€์ง€ ๋ฐฉ๋ฒ•์ด ์ œ์•ˆ๋˜์—ˆ๋Š”๋ฐ, ๊ทธ ์ค‘ ฮฑ-์ฟผ๋ฆฌ ํ™•์žฅ์€ ์ถ”๊ฐ€ ํ•™์Šต์ด ํ•„์š” ์—†๋Š” ๋งค์šฐ ํšจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ•์œผ๋กœ ์•Œ๋ ค์ ธ ์žˆ์Šต๋‹ˆ๋‹ค. Attention-Based Query Expansion Learning์™€ 2nd place solution of kaggle shopee competition ๋…ผ๋ฌธ์—์„œ ์†Œ๊ฐœ๋œ ๋ฐ” ์žˆ์Šต๋‹ˆ๋‹ค..

ฮฑ-์ฟผ๋ฆฌ ํ™•์žฅ์€ ํ•œ ๋ฒกํ„ฐ์— ์ธ์ ‘ํ•œ ๋ฒกํ„ฐ๋ฅผ ์œ ์‚ฌ๋„์˜ ฮฑ๊ณฑํ•œ ๊ฐ€์ค‘์น˜๋กœ ๋”ํ•ด์ฃผ๋ฉด ๋ฉ๋‹ˆ๋‹ค. ์ฝ”๋“œ๋กœ ์˜ˆ์‹œ๋ฅผ ๋“ค์–ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. big_npy๋ฅผ ฮฑ query expansion๋กœ ๋Œ€์ฒดํ•ฉ๋‹ˆ๋‹ค.

alpha = 3.
index = Faiss.index_factory(256, "IVF512,PQ128x4fs,RFlat")
original_norm = np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9)
big_npy /= original_norm
index.train(big_npy)
index.add(big_npy)
dist, neighbor = index.search(big_npy, num_expand)

expand_arrays = []
ixs = np.arange(big_npy.shape[0])
for i in range(-(-big_npy.shape[0]//batch_size)):
    ix = ixs[i*batch_size:(i+1)*batch_size]
    weight = np.power(np.einsum("nd,nmd->nm", big_npy[ix], big_npy[neighbor[ix]]), alpha)
    expand_arrays.append(np.sum(big_npy[neighbor[ix]] * np.expand_dims(weight, axis=2),axis=1))
big_npy = np.concatenate(expand_arrays, axis=0)

# index version ์ •๊ทœํ™”
big_npy = big_npy / np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9)

์œ„ ํ…Œํฌ๋‹‰์€ ํƒ์ƒ‰์„ ์ˆ˜ํ–‰ํ•˜๋Š” ์ฟผ๋ฆฌ์—๋„, ํƒ์ƒ‰ ๋Œ€์ƒ DB์—๋„ ์ ์‘ ๊ฐ€๋Šฅํ•œ ํ…Œํฌ๋‹‰์ž…๋‹ˆ๋‹ค.

MiniBatch KMeans์— ์˜ํ•œ embedding ์••์ถ•

total_fea.npy๊ฐ€ ๋„ˆ๋ฌด ํด ๊ฒฝ์šฐ K-means๋ฅผ ์ด์šฉํ•˜์—ฌ ๋ฒกํ„ฐ๋ฅผ ์ž‘๊ฒŒ ๋งŒ๋“œ๋Š” ๊ฒƒ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. ์ดํ•˜ ์ฝ”๋“œ๋กœ embedding์˜ ์••์ถ•์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. n_clusters์— ์••์ถ•ํ•˜๊ณ ์ž ํ•˜๋Š” ํฌ๊ธฐ๋ฅผ ์ง€์ •ํ•˜๊ณ  batch_size์— 256 * CPU์˜ ์ฝ”์–ด ์ˆ˜๋ฅผ ์ง€์ •ํ•จ์œผ๋กœ์จ CPU ๋ณ‘๋ ฌํ™”์˜ ํ˜œํƒ์„ ์ถฉ๋ถ„ํžˆ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

import multiprocessing
from sklearn.cluster import MiniBatchKMeans
kmeans = MiniBatchKMeans(n_clusters=10000, batch_size=256 * multiprocessing.cpu_count(), init="random")
kmeans.fit(big_npy)
sample_npy = kmeans.cluster_centers_