kxhit commited on
Commit
5833474
1 Parent(s): ed19cf4
Files changed (2) hide show
  1. 6DoF/dataset.py +1 -29
  2. app.py +6 -14
6DoF/dataset.py CHANGED
@@ -2,40 +2,12 @@ import os
2
  import math
3
  from pathlib import Path
4
  import torch
5
- import torchvision
6
- from torch.utils.data import Dataset, DataLoader
7
- from torchvision import transforms
8
  from PIL import Image
9
  import numpy as np
10
- import webdataset as wds
11
- from torch.utils.data.distributed import DistributedSampler
12
  import matplotlib.pyplot as plt
13
  import sys
14
 
15
- class ObjaverseDataLoader():
16
- def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
17
- self.root_dir = root_dir
18
- self.batch_size = batch_size
19
- self.num_workers = num_workers
20
- self.total_view = total_view
21
-
22
- image_transforms = [torchvision.transforms.Resize((256, 256)),
23
- transforms.ToTensor(),
24
- transforms.Normalize([0.5], [0.5])]
25
- self.image_transforms = torchvision.transforms.Compose(image_transforms)
26
-
27
- def train_dataloader(self):
28
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
29
- image_transforms=self.image_transforms)
30
- # sampler = DistributedSampler(dataset)
31
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
32
- # sampler=sampler)
33
-
34
- def val_dataloader(self):
35
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
36
- image_transforms=self.image_transforms)
37
- sampler = DistributedSampler(dataset)
38
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
39
 
40
  def get_pose(transformation):
41
  # transformation: 4x4
 
2
  import math
3
  from pathlib import Path
4
  import torch
5
+ from torch.utils.data import Dataset
 
 
6
  from PIL import Image
7
  import numpy as np
 
 
8
  import matplotlib.pyplot as plt
9
  import sys
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def get_pose(transformation):
13
  # transformation: 4x4
app.py CHANGED
@@ -183,19 +183,11 @@ def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_
183
  # run inference
184
  # pipeline.to(device)
185
  pipeline.enable_xformers_memory_efficient_attention()
186
- if CaPE_TYPE == "6DoF":
187
- with torch.autocast("cuda"):
188
- image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
189
- poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
190
- height=h, width=w, T_in=T_in, T_out=T_out,
191
- guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
192
- output_type="numpy").images
193
- elif CaPE_TYPE == "4DoF":
194
- with torch.autocast("cuda"):
195
- image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in],
196
- height=h, width=w, T_in=T_in, T_out=T_out,
197
- guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
198
- output_type="numpy").images
199
 
200
  # save output image
201
  output_dir = os.path.join(tmpdirname, "eschernet")
@@ -748,7 +740,7 @@ with gr.Blocks() as demo:
748
 
749
  # demo.queue(max_size=10)
750
  # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
751
- demo.queue(max_size=1).launch()
752
 
753
  # if __name__ == '__main__':
754
  # main()
 
183
  # run inference
184
  # pipeline.to(device)
185
  pipeline.enable_xformers_memory_efficient_attention()
186
+ image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
187
+ poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
188
+ height=h, width=w, T_in=T_in, T_out=T_out,
189
+ guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
190
+ output_type="numpy").images
 
 
 
 
 
 
 
 
191
 
192
  # save output image
193
  output_dir = os.path.join(tmpdirname, "eschernet")
 
740
 
741
  # demo.queue(max_size=10)
742
  # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
743
+ demo.queue(max_size=10).launch()
744
 
745
  # if __name__ == '__main__':
746
  # main()