boris commited on
Commit
7f9514d
2 Parent(s): b49f529 6e79248

Merge pull request #41 from borisdayma/predictions

Browse files

Get predictions from backend

Former-commit-id: 54f7e9533e053884a9d98e69ef9a6ca7090d02ab

dalle_mini/backend.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from io import BytesIO
3
+ import base64
4
+ from PIL import Image
5
+
6
+ class ServiceError(Exception):
7
+ def __init__(self, status_code):
8
+ self.status_code = status_code
9
+
10
+ def get_images_from_backend(prompt, backend_url):
11
+ r = requests.post(
12
+ backend_url,
13
+ json={"prompt": prompt}
14
+ )
15
+ if r.status_code == 200:
16
+ images = r.json()["images"]
17
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
18
+ return images
19
+ else:
20
+ raise ServiceError(r.status_code)
dev/predictions/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Scripts to generate predictions for assessment and reporting.
dev/predictions/dalle_mini ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../dalle_mini
dev/predictions/wandb-examples-from-backend.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import wandb
6
+ import os
7
+
8
+ from dalle_mini.backend import ServiceError, get_images_from_backend
9
+
10
+ os.environ["WANDB_SILENT"] = "true"
11
+ os.environ["WANDB_CONSOLE"] = "off"
12
+
13
+ # set id to None so our latest images don't get overwritten
14
+ id = None
15
+ run = wandb.init(id=id,
16
+ entity='wandb',
17
+ project="hf-flax-dalle-mini",
18
+ job_type="predictions",
19
+ resume="allow"
20
+ )
21
+
22
+ def captioned_strip(images, caption):
23
+ w, h = images[0].size[0], images[0].size[1]
24
+ img = Image.new("RGB", (len(images)*w, h + 48))
25
+ for i, img_ in enumerate(images):
26
+ img.paste(img_, (i*w, 48))
27
+ draw = ImageDraw.Draw(img)
28
+ font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
29
+ draw.text((20, 3), caption, (255,255,255), font=font)
30
+ return img
31
+
32
+ def log_to_wandb(prompts):
33
+ try:
34
+ backend_url = os.environ["BACKEND_SERVER"]
35
+
36
+ strips = []
37
+ for prompt in prompts:
38
+ print(f"Getting selections for: {prompt}")
39
+ selected = get_images_from_backend(prompt, backend_url)
40
+ strip = captioned_strip(selected, prompt)
41
+ strips.append(wandb.Image(strip))
42
+ wandb.log({"images": strips})
43
+ except ServiceError as error:
44
+ print(f"Service unavailable, status: {error.status_code}")
45
+ except KeyError:
46
+ print("Error: BACKEND_SERVER unset")
47
+
48
+ prompts = [
49
+ "white snow covered mountain under blue sky during daytime",
50
+ "aerial view of beach during daytime",
51
+ "aerial view of beach at night",
52
+ "an armchair in the shape of an avocado",
53
+ "a logo of an avocado armchair playing music",
54
+ "young woman riding her bike trough a forest",
55
+ "rice fields by the mediterranean coast",
56
+ "white houses on the hill of a greek coastline",
57
+ "illustration of a shark with a baby shark",
58
+ "painting of an oniric forest glade surrounded by tall trees",
59
+ ]
60
+
61
+ log_to_wandb(prompts)
dev/{wandb-examples.py → predictions/wandb-examples.py} RENAMED
File without changes