jbilcke-hf HF staff commited on
Commit
22b2a6e
β€’
1 Parent(s): d96ce03

improve concurrency

Browse files
app.py CHANGED
@@ -79,7 +79,7 @@ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
79
 
80
  elif msg.type == WSMsgType.TEXT:
81
  data = json.loads(msg.data)
82
- webp_bytes = engine.transform_image(data.get('hash'), data.get('params'))
83
  await ws.send_bytes(webp_bytes)
84
 
85
  except Exception as e:
 
79
 
80
  elif msg.type == WSMsgType.TEXT:
81
  data = json.loads(msg.data)
82
+ webp_bytes = await engine.transform_image(data.get('hash'), data.get('params'))
83
  await ws.send_bytes(webp_bytes)
84
 
85
  except Exception as e:
client/src/hooks/useFaceLandmarkDetection.tsx CHANGED
@@ -18,7 +18,7 @@ export function useFaceLandmarkDetection() {
18
  // if we only send the face/square then we can use 138ms
19
  // unfortunately it doesn't work well yet
20
  // const throttleInMs = 138ms
21
- const throttleInMs = 200
22
  ////////////////////////////////////////////////////////////////////////
23
 
24
  // State for face detection
 
18
  // if we only send the face/square then we can use 138ms
19
  // unfortunately it doesn't work well yet
20
  // const throttleInMs = 138ms
21
+ const throttleInMs = 220
22
  ////////////////////////////////////////////////////////////////////////
23
 
24
  // State for face detection
engine.py CHANGED
@@ -129,7 +129,7 @@ class Engine:
129
  # 'bbox_rot': bbox_info['bbox_rot'].toList(), # 4x2
130
  }
131
 
132
- def transform_image(self, image_hash: str, params: Dict[str, float]) -> bytes:
133
  # If we don't have the image in cache yet, add it
134
  if image_hash not in self.processed_cache:
135
  raise ValueError("cache miss")
@@ -197,11 +197,11 @@ class Engine:
197
  x_d_new = processed_data['x_s_info']['scale'] * (x_d_new @ R_new) + processed_data['x_s_info']['t']
198
 
199
  # Apply stitching
200
- x_d_new = self.live_portrait.live_portrait_wrapper.stitching(processed_data['x_s'], x_d_new)
201
 
202
  # Generate the output
203
- out = self.live_portrait.live_portrait_wrapper.warp_decode(processed_data['f_s'], processed_data['x_s'], x_d_new)
204
- I_p = self.live_portrait.live_portrait_wrapper.parse_output(out['out'])
205
 
206
  buffered = io.BytesIO()
207
 
@@ -214,11 +214,11 @@ class Engine:
214
  # I'm currently running some experiments to do it in the frontend
215
  #
216
  # --- old way: we do it in the server-side: ---
217
- mask_ori = prepare_paste_back(
218
  processed_data['inference_cfg'].mask_crop, processed_data['crop_info']['M_c2o'],
219
  dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0])
220
  )
221
- I_p_to_ori_blend = paste_back(
222
  I_p[0], processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori
223
  )
224
  result_image = Image.fromarray(I_p_to_ori_blend)
 
129
  # 'bbox_rot': bbox_info['bbox_rot'].toList(), # 4x2
130
  }
131
 
132
+ async def transform_image(self, image_hash: str, params: Dict[str, float]) -> bytes:
133
  # If we don't have the image in cache yet, add it
134
  if image_hash not in self.processed_cache:
135
  raise ValueError("cache miss")
 
197
  x_d_new = processed_data['x_s_info']['scale'] * (x_d_new @ R_new) + processed_data['x_s_info']['t']
198
 
199
  # Apply stitching
200
+ x_d_new = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.stitching, processed_data['x_s'], x_d_new)
201
 
202
  # Generate the output
203
+ out = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new)
204
+ I_p = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.parse_output, out['out'])
205
 
206
  buffered = io.BytesIO()
207
 
 
214
  # I'm currently running some experiments to do it in the frontend
215
  #
216
  # --- old way: we do it in the server-side: ---
217
+ mask_ori = await asyncio.to_thread(prepare_paste_back,
218
  processed_data['inference_cfg'].mask_crop, processed_data['crop_info']['M_c2o'],
219
  dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0])
220
  )
221
+ I_p_to_ori_blend = await asyncio.to_thread(paste_back,
222
  I_p[0], processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori
223
  )
224
  result_image = Image.fromarray(I_p_to_ori_blend)