[Feat] Add user history

#4
by Wauplin HF staff - opened
Files changed (4) hide show
  1. README.md +2 -1
  2. app.py +15 -3
  3. gallery_history.py +129 -0
  4. requirements.txt +2 -1
README.md CHANGED
@@ -4,10 +4,11 @@ emoji: 🌍
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.39.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.44.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ hf_oauth: true
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -9,6 +9,9 @@ from diffusers.utils import numpy_to_pil
9
  from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
10
  from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
11
  from previewer.modules import Previewer
 
 
 
12
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
13
 
14
  DESCRIPTION = "# Würstchen"
@@ -38,7 +41,7 @@ if torch.cuda.is_available():
38
  if USE_TORCH_COMPILE:
39
  prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
40
  decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
41
-
42
  if PREVIEW_IMAGES:
43
  previewer = Previewer()
44
  previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
@@ -48,6 +51,7 @@ if torch.cuda.is_available():
48
  output = previewer(latents)
49
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
50
  return output
 
51
  else:
52
  previewer = None
53
  callback_prior = None
@@ -96,7 +100,7 @@ def generate(
96
  if isinstance(r, list):
97
  yield r
98
  prior_output = r
99
-
100
  decoder_output = decoder_pipeline(
101
  image_embeddings=prior_output.image_embeddings,
102
  prompt=prompt,
@@ -209,6 +213,8 @@ with gr.Blocks(css="style.css") as demo:
209
  cache_examples=CACHE_EXAMPLES,
210
  )
211
 
 
 
212
  inputs = [
213
  prompt,
214
  negative_prompt,
@@ -234,6 +240,8 @@ with gr.Blocks(css="style.css") as demo:
234
  inputs=inputs,
235
  outputs=result,
236
  api_name="run",
 
 
237
  )
238
  negative_prompt.submit(
239
  fn=randomize_seed_fn,
@@ -246,6 +254,8 @@ with gr.Blocks(css="style.css") as demo:
246
  inputs=inputs,
247
  outputs=result,
248
  api_name=False,
 
 
249
  )
250
  run_button.click(
251
  fn=randomize_seed_fn,
@@ -258,7 +268,9 @@ with gr.Blocks(css="style.css") as demo:
258
  inputs=inputs,
259
  outputs=result,
260
  api_name=False,
 
 
261
  )
262
 
263
  if __name__ == "__main__":
264
- demo.queue(max_size=20).launch()
 
9
  from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
10
  from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
11
  from previewer.modules import Previewer
12
+
13
+ from gallery_history import fetch_gallery_history, show_gallery_history
14
+
15
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
16
 
17
  DESCRIPTION = "# Würstchen"
 
41
  if USE_TORCH_COMPILE:
42
  prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
43
  decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
44
+
45
  if PREVIEW_IMAGES:
46
  previewer = Previewer()
47
  previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
 
51
  output = previewer(latents)
52
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
53
  return output
54
+
55
  else:
56
  previewer = None
57
  callback_prior = None
 
100
  if isinstance(r, list):
101
  yield r
102
  prior_output = r
103
+
104
  decoder_output = decoder_pipeline(
105
  image_embeddings=prior_output.image_embeddings,
106
  prompt=prompt,
 
213
  cache_examples=CACHE_EXAMPLES,
214
  )
215
 
216
+ history = show_gallery_history()
217
+
218
  inputs = [
219
  prompt,
220
  negative_prompt,
 
240
  inputs=inputs,
241
  outputs=result,
242
  api_name="run",
243
+ ).then(
244
+ fn=fetch_gallery_history, inputs=[prompt, result], outputs=history
245
  )
246
  negative_prompt.submit(
247
  fn=randomize_seed_fn,
 
254
  inputs=inputs,
255
  outputs=result,
256
  api_name=False,
257
+ ).then(
258
+ fn=fetch_gallery_history, inputs=[prompt, result], outputs=history
259
  )
260
  run_button.click(
261
  fn=randomize_seed_fn,
 
268
  inputs=inputs,
269
  outputs=result,
270
  api_name=False,
271
+ ).then(
272
+ fn=fetch_gallery_history, inputs=[prompt, result], outputs=history
273
  )
274
 
275
  if __name__ == "__main__":
276
+ demo.queue(max_size=20).launch()
gallery_history.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ How to use:
3
+ 1. Create a Space with a Persistent Storage attached. Filesystem will be available under `/data`.
4
+ 2. Add `hf_oauth: true` to the Space metadata (README.md). Make sure to have Gradio>=3.41.0 configured.
5
+ 3. Add `HISTORY_FOLDER` as a Space variable (example. `"/data/history"`).
6
+ 4. Add `filelock` as dependency in `requirements.txt`.
7
+ 5. Add history gallery to your Gradio app:
8
+ a. Add imports: `from gallery_history import fetch_gallery_history, show_gallery_history`
9
+ a. Add `history = show_gallery_history()` within `gr.Blocks` context.
10
+ b. Add `.then(fn=fetch_gallery_history, inputs=[prompt, result], outputs=history)` on the generate event.
11
+ """
12
+ import json
13
+ import os
14
+ import shutil
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Tuple
17
+ from uuid import uuid4
18
+
19
+ import gradio as gr
20
+ from filelock import FileLock
21
+
22
+ _folder = os.environ.get("HISTORY_FOLDER")
23
+ if _folder is None:
24
+ print(
25
+ "'HISTORY_FOLDER' environment variable not set. User history will be saved "
26
+ "locally and will be lost when the Space instance is restarted."
27
+ )
28
+ _folder = Path(__file__).parent / "history"
29
+ HISTORY_FOLDER_PATH = Path(_folder)
30
+
31
+ IMAGES_FOLDER_PATH = HISTORY_FOLDER_PATH / "images"
32
+ IMAGES_FOLDER_PATH.mkdir(parents=True, exist_ok=True)
33
+
34
+
35
+ def show_gallery_history():
36
+ gr.Markdown(
37
+ "## Your past generations\n\n(Log in to keep a gallery of your previous generations."
38
+ " Your history will be saved and available on your next visit.)"
39
+ )
40
+ with gr.Column():
41
+ with gr.Row():
42
+ gr.LoginButton(min_width=250)
43
+ gr.LogoutButton(min_width=250)
44
+ gallery = gr.Gallery(
45
+ label="Past images",
46
+ show_label=True,
47
+ elem_id="gallery",
48
+ object_fit="contain",
49
+ columns=3,
50
+ height=300,
51
+ preview=False,
52
+ show_share_button=False,
53
+ show_download_button=False,
54
+ )
55
+ gr.Markdown(
56
+ "Make sure to save your images from time to time, this gallery may be deleted in the future."
57
+ )
58
+ gallery.attach_load_event(fetch_gallery_history, every=None)
59
+ return gallery
60
+
61
+
62
+ def fetch_gallery_history(
63
+ prompt: Optional[str] = None,
64
+ result: Optional[Dict] = None,
65
+ user: Optional[gr.OAuthProfile] = None,
66
+ ):
67
+ if user is None:
68
+ return []
69
+ try:
70
+ if prompt is not None and result is not None: # None values means no new images
71
+ return _update_user_history(
72
+ user["preferred_username"], [(item["name"], prompt) for item in result]
73
+ )
74
+ else:
75
+ return _read_user_history(user["preferred_username"])
76
+ except Exception as e:
77
+ raise gr.Error(f"Error while fetching history: {e}") from e
78
+
79
+
80
+ ####################
81
+ # Internal helpers #
82
+ ####################
83
+
84
+
85
+ def _read_user_history(username: str) -> List[Tuple[str, str]]:
86
+ """Return saved history for that user."""
87
+ with _user_lock(username):
88
+ path = _user_history_path(username)
89
+ if path.exists():
90
+ return json.loads(path.read_text())
91
+ return [] # No history yet
92
+
93
+
94
+ def _update_user_history(
95
+ username: str, new_images: List[Tuple[str, str]]
96
+ ) -> List[Tuple[str, str]]:
97
+ """Update history for that user and return it."""
98
+ with _user_lock(username):
99
+ # Read existing
100
+ path = _user_history_path(username)
101
+ if path.exists():
102
+ images = json.loads(path.read_text())
103
+ else:
104
+ images = [] # No history yet
105
+
106
+ # Copy images to persistent folder
107
+ images = [
108
+ (_copy_image(src_path), prompt) for src_path, prompt in new_images
109
+ ] + images
110
+
111
+ # Save and return
112
+ path.write_text(json.dumps(images))
113
+ return images
114
+
115
+
116
+ def _user_history_path(username: str) -> Path:
117
+ return HISTORY_FOLDER_PATH / f"{username}.json"
118
+
119
+
120
+ def _user_lock(username: str) -> FileLock:
121
+ """Ensure history is not corrupted if concurrent calls."""
122
+ return FileLock(f"{_user_history_path(username)}.lock")
123
+
124
+
125
+ def _copy_image(src: str) -> str:
126
+ """Copy image to the persistent storage."""
127
+ dst = IMAGES_FOLDER_PATH / f"{uuid4().hex}_{Path(src).name}" # keep file ext
128
+ shutil.copyfile(src, dst)
129
+ return str(dst)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ invisible-watermark==0.2.0
5
  Pillow==10.0.0
6
  torch==2.0.1
7
  transformers==4.32.1
8
- compel
 
 
5
  Pillow==10.0.0
6
  torch==2.0.1
7
  transformers==4.32.1
8
+ compel
9
+ filelock