kfahn commited on
Commit
e941ff9
1 Parent(s): 81541d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -25,12 +25,12 @@ def infer(prompt):
25
  num_samples = 1 #jax.device_count()
26
  rng = create_key(0)
27
  rng = jax.random.split(rng, jax.device_count())
28
- im = image
29
- image = Image.fromarray(im)
30
 
31
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
32
  #negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
33
- processed_image = pipe.prepare_image_inputs([image] * num_samples)
34
 
35
  p_params = replicate(params)
36
  prompt_ids = shard(prompt_ids)
 
25
  num_samples = 1 #jax.device_count()
26
  rng = create_key(0)
27
  rng = jax.random.split(rng, jax.device_count())
28
+ #im = image
29
+ #image = Image.fromarray(im)
30
 
31
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
32
  #negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
33
+ #processed_image = pipe.prepare_image_inputs([image] * num_samples)
34
 
35
  p_params = replicate(params)
36
  prompt_ids = shard(prompt_ids)