File size: 3,761 Bytes
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8176be9
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
2ad28d6
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Startup work before running the web server."""

import os
import shutil
from typing import TypedDict

import yaml
from huggingface_hub import scan_cache_dir, snapshot_download

from lilac.concepts.db_concept import CONCEPTS_DIR, DiskConceptDB, get_concept_output_dir
from lilac.env import data_path, env
from lilac.utils import get_datasets_dir, get_lilac_cache_dir, log


def delete_old_files() -> None:
  """Delete all but the most recently modified revision of each repo in the HuggingFace cache.

  This is best-effort. If the HuggingFace cache cannot be scanned (for example because it does
  not exist yet), nothing is deleted.
  """
  # Scan the cache. Catch `Exception` rather than `BaseException` so KeyboardInterrupt and
  # SystemExit still propagate.
  try:
    scan = scan_cache_dir()
  except Exception:
    # Cache was not found or could not be read: nothing to clean up.
    return

  # Keep the most recently modified revision of every repo and mark the rest for deletion.
  to_delete: list[str] = []
  for repo in scan.repos:
    if not repo.revisions:
      # `max()` raises on an empty collection; a repo without revisions has nothing to prune.
      continue
    latest_revision = max(repo.revisions, key=lambda revision: revision.last_modified)
    to_delete.extend(
      revision.commit_hash for revision in repo.revisions if revision != latest_revision)

  if not to_delete:
    return

  # Delete them.
  strategy = scan.delete_revisions(*to_delete)
  log(f'Will delete {len(to_delete)} old revisions and save {strategy.expected_freed_size_str}')
  strategy.execute()


class HfSpaceConfig(TypedDict):
  """The huggingface space config, parsed from the YAML front matter of README.md.

  Only the keys this startup script cares about are declared here. The README front matter may
  contain other keys described in the reference below.

  See:
  https://huggingface.co/docs/hub/spaces-config-reference
  """
  # Title of the space.
  title: str
  # HuggingFace dataset repo IDs; `main` downloads each one into the lilac datasets directory.
  datasets: list[str]


def _read_hf_space_config(readme_path: str) -> HfSpaceConfig:
  """Read the HuggingFace space config from the YAML front matter of a README.md.

  The front matter is the block between the opening and closing '---' lines at the top of the
  file. Markdown after the closing delimiter is ignored. Otherwise yaml.safe_load would see a
  second '---'-separated document and fail.

  Returns:
    The parsed config. This may be None when the front matter is empty.
  """
  with open(readme_path, encoding='utf-8') as f:
    lines = f.read().strip().splitlines()

  if lines and lines[0].strip() == '---':
    body = lines[1:]
    # Index of the closing delimiter; fall back to the end of the file if it is missing.
    end = next((i for i, line in enumerate(body) if line.strip() == '---'), len(body))
    front_matter = '\n'.join(body[:end])
  else:
    # No front-matter delimiter: parse the whole file as YAML.
    front_matter = '\n'.join(lines)

  hf_config: HfSpaceConfig = yaml.safe_load(front_matter)
  return hf_config


def _sync_lilac_cache(spaces_data_dir: str, data_dir: str) -> None:
  """Replace the persistent lilac cache with the cache from the space snapshot, if it has one."""
  # Delete cache files from persistent storage.
  cache_dir = get_lilac_cache_dir(data_dir)
  if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)

  # Copy cache files from the space if they exist.
  spaces_cache_dir = get_lilac_cache_dir(spaces_data_dir)
  if os.path.exists(spaces_cache_dir):
    shutil.copytree(spaces_cache_dir, cache_dir)


def _copy_concepts(spaces_data_dir: str, data_dir: str) -> None:
  """Move non-lilac concepts from the space snapshot into the persistent data directory.

  Each copied concept replaces any existing persistent copy. The snapshot copy is removed
  afterwards.
  """
  for concept in DiskConceptDB(spaces_data_dir).list():
    # Ignore lilac concepts, they're already part of the source code.
    if concept.namespace == 'lilac':
      continue
    spaces_concept_output_dir = get_concept_output_dir(spaces_data_dir, concept.namespace,
                                                       concept.name)
    persistent_output_dir = get_concept_output_dir(data_dir, concept.namespace, concept.name)
    # Replace rather than merge, so files removed from the space do not linger on disk.
    shutil.rmtree(persistent_output_dir, ignore_errors=True)
    shutil.copytree(spaces_concept_output_dir, persistent_output_dir, dirs_exist_ok=True)
    shutil.rmtree(spaces_concept_output_dir, ignore_errors=True)


def main() -> None:
  """Download dataset files from the HF space that was uploaded before building the image.

  Does nothing unless the `SPACE_ID` environment variable is set. Otherwise it:
    1. Prunes old revisions from the HuggingFace cache.
    2. Downloads every dataset listed under `datasets` in the README.md front matter into the
       datasets directory.
    3. Downloads a snapshot of the space and syncs its lilac cache and non-lilac concepts into
       the persistent data path.
  """
  # SPACE_ID is the HuggingFace Space ID environment variable that is automatically set by HF.
  repo_id = env('SPACE_ID', None)
  if not repo_id:
    return

  delete_old_files()

  hf_config = _read_hf_space_config(os.path.abspath('README.md'))
  # Empty front matter parses to None, and the `datasets` key may be absent.
  dataset_ids: list[str] = []
  if hf_config:
    dataset_ids = hf_config.get('datasets') or []

  hf_token = env('HF_ACCESS_TOKEN')
  data_dir = data_path()

  # Download the datasets listed in the space config directly into the datasets directory.
  for lilac_hf_dataset in dataset_ids:
    log(f'Downloading dataset from HuggingFace: {lilac_hf_dataset}')
    snapshot_download(
      repo_id=lilac_hf_dataset,
      repo_type='dataset',
      token=hf_token,
      local_dir=get_datasets_dir(data_dir),
      ignore_patterns=['.gitattributes', 'README.md'])

  # Download the huggingface space itself. It includes code as well as data, so only its data
  # directory is used below.
  snapshot_dir = snapshot_download(repo_id=repo_id, repo_type='space', token=hf_token)
  spaces_data_dir = os.path.join(snapshot_dir, 'data')

  # NOTE: This is temporary during the move of concepts into the pip package. Once all the demos
  # have been updated, this block can be deleted.
  old_lilac_concepts_data_dir = os.path.join(data_dir, CONCEPTS_DIR, 'lilac')
  if os.path.exists(old_lilac_concepts_data_dir):
    shutil.rmtree(old_lilac_concepts_data_dir)

  _sync_lilac_cache(spaces_data_dir, data_dir)
  _copy_concepts(spaces_data_dir, data_dir)


# Run the startup work when this module is executed as a script (before the web server starts).
if __name__ == '__main__':
  main()