Spaces:
Runtime error
Runtime error
Push
Browse files
- src/data/duckdb_utils.py +5 -1
- src/data/sources/json_source.py +0 -2
- src/embeddings/sbert.py +3 -1
src/data/duckdb_utils.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1 |
"""Utils for duckdb."""
|
2 |
import duckdb
|
3 |
|
4 |
-
from ..config import CONFIG
|
5 |
|
6 |
|
7 |
def duckdb_gcs_setup(con: duckdb.DuckDBPyConnection) -> str:
|
8 |
"""Setup DuckDB for GCS."""
|
|
|
|
|
|
|
|
|
9 |
con.install_extension('httpfs')
|
10 |
con.load_extension('httpfs')
|
11 |
|
|
|
1 |
"""Utils for duckdb."""
|
2 |
import duckdb
|
3 |
|
4 |
+
from ..config import CONFIG, data_path
|
5 |
|
6 |
|
7 |
def duckdb_gcs_setup(con: duckdb.DuckDBPyConnection) -> str:
|
8 |
"""Setup DuckDB for GCS."""
|
9 |
+
con.execute(f"""
|
10 |
+
SET extension_directory='{data_path()}';
|
11 |
+
""")
|
12 |
+
|
13 |
con.install_extension('httpfs')
|
14 |
con.load_extension('httpfs')
|
15 |
|
src/data/sources/json_source.py
CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
|
|
6 |
from pydantic import Field as PydanticField
|
7 |
from typing_extensions import override
|
8 |
|
9 |
-
from ...config import data_path
|
10 |
from ...schema import Item
|
11 |
from ...utils import download_http_files
|
12 |
from ..duckdb_utils import duckdb_gcs_setup
|
@@ -39,7 +38,6 @@ class JSONDataset(Source):
|
|
39 |
# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
|
40 |
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]
|
41 |
|
42 |
-
con.execute(f"""SET extension_directory='{data_path()}';""")
|
43 |
# NOTE: We use duckdb here to increase parallelism for multiple files.
|
44 |
self._df = con.execute(f"""
|
45 |
{duckdb_gcs_setup(con)}
|
|
|
6 |
from pydantic import Field as PydanticField
|
7 |
from typing_extensions import override
|
8 |
|
|
|
9 |
from ...schema import Item
|
10 |
from ...utils import download_http_files
|
11 |
from ..duckdb_utils import duckdb_gcs_setup
|
|
|
38 |
# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
|
39 |
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]
|
40 |
|
|
|
41 |
# NOTE: We use duckdb here to increase parallelism for multiple files.
|
42 |
self._df = con.execute(f"""
|
43 |
{duckdb_gcs_setup(con)}
|
src/embeddings/sbert.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
from typing_extensions import override
|
8 |
|
|
|
9 |
from ..schema import Item, RichData
|
10 |
from ..signals.signal import TextEmbeddingSignal
|
11 |
from ..signals.splitters.chunk_splitter import split_text
|
@@ -32,7 +33,8 @@ def _sbert() -> tuple[Optional[str], SentenceTransformer]:
|
|
32 |
preferred_device = 'mps'
|
33 |
elif not torch.backends.mps.is_built():
|
34 |
log('MPS not available because the current PyTorch install was not built with MPS enabled.')
|
35 |
-
return preferred_device, SentenceTransformer(
|
|
|
36 |
|
37 |
|
38 |
def _optimal_batch_size(preferred_device: Optional[str]) -> int:
|
|
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
from typing_extensions import override
|
8 |
|
9 |
+
from ..config import data_path
|
10 |
from ..schema import Item, RichData
|
11 |
from ..signals.signal import TextEmbeddingSignal
|
12 |
from ..signals.splitters.chunk_splitter import split_text
|
|
|
33 |
preferred_device = 'mps'
|
34 |
elif not torch.backends.mps.is_built():
|
35 |
log('MPS not available because the current PyTorch install was not built with MPS enabled.')
|
36 |
+
return preferred_device, SentenceTransformer(
|
37 |
+
MODEL_NAME, device=preferred_device, cache_folder=data_path())
|
38 |
|
39 |
|
40 |
def _optimal_batch_size(preferred_device: Optional[str]) -> int:
|