Update app.py
Browse files
app.py
CHANGED
@@ -25,12 +25,12 @@ def infer(prompt):
|
|
25 |
num_samples = 1 #jax.device_count()
|
26 |
rng = create_key(0)
|
27 |
rng = jax.random.split(rng, jax.device_count())
|
28 |
-
im = image
|
29 |
-
image = Image.fromarray(im)
|
30 |
|
31 |
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
32 |
#negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
33 |
-
processed_image = pipe.prepare_image_inputs([image] * num_samples)
|
34 |
|
35 |
p_params = replicate(params)
|
36 |
prompt_ids = shard(prompt_ids)
|
|
|
25 |
num_samples = 1 #jax.device_count()
|
26 |
rng = create_key(0)
|
27 |
rng = jax.random.split(rng, jax.device_count())
|
28 |
+
#im = image
|
29 |
+
#image = Image.fromarray(im)
|
30 |
|
31 |
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
32 |
#negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
33 |
+
#processed_image = pipe.prepare_image_inputs([image] * num_samples)
|
34 |
|
35 |
p_params = replicate(params)
|
36 |
prompt_ids = shard(prompt_ids)
|