Spaces:
Running
on
Zero
Running
on
Zero
kxhit
commited on
Commit
•
5833474
1
Parent(s):
ed19cf4
queue 1
Browse files- 6DoF/dataset.py +1 -29
- app.py +6 -14
6DoF/dataset.py
CHANGED
@@ -2,40 +2,12 @@ import os
|
|
2 |
import math
|
3 |
from pathlib import Path
|
4 |
import torch
|
5 |
-
import
|
6 |
-
from torch.utils.data import Dataset, DataLoader
|
7 |
-
from torchvision import transforms
|
8 |
from PIL import Image
|
9 |
import numpy as np
|
10 |
-
import webdataset as wds
|
11 |
-
from torch.utils.data.distributed import DistributedSampler
|
12 |
import matplotlib.pyplot as plt
|
13 |
import sys
|
14 |
|
15 |
-
class ObjaverseDataLoader():
|
16 |
-
def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
|
17 |
-
self.root_dir = root_dir
|
18 |
-
self.batch_size = batch_size
|
19 |
-
self.num_workers = num_workers
|
20 |
-
self.total_view = total_view
|
21 |
-
|
22 |
-
image_transforms = [torchvision.transforms.Resize((256, 256)),
|
23 |
-
transforms.ToTensor(),
|
24 |
-
transforms.Normalize([0.5], [0.5])]
|
25 |
-
self.image_transforms = torchvision.transforms.Compose(image_transforms)
|
26 |
-
|
27 |
-
def train_dataloader(self):
|
28 |
-
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
|
29 |
-
image_transforms=self.image_transforms)
|
30 |
-
# sampler = DistributedSampler(dataset)
|
31 |
-
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
32 |
-
# sampler=sampler)
|
33 |
-
|
34 |
-
def val_dataloader(self):
|
35 |
-
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
|
36 |
-
image_transforms=self.image_transforms)
|
37 |
-
sampler = DistributedSampler(dataset)
|
38 |
-
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
39 |
|
40 |
def get_pose(transformation):
|
41 |
# transformation: 4x4
|
|
|
2 |
import math
|
3 |
from pathlib import Path
|
4 |
import torch
|
5 |
+
from torch.utils.data import Dataset
|
|
|
|
|
6 |
from PIL import Image
|
7 |
import numpy as np
|
|
|
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
import sys
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def get_pose(transformation):
|
13 |
# transformation: 4x4
|
app.py
CHANGED
@@ -183,19 +183,11 @@ def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_
|
|
183 |
# run inference
|
184 |
# pipeline.to(device)
|
185 |
pipeline.enable_xformers_memory_efficient_attention()
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
192 |
-
output_type="numpy").images
|
193 |
-
elif CaPE_TYPE == "4DoF":
|
194 |
-
with torch.autocast("cuda"):
|
195 |
-
image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in],
|
196 |
-
height=h, width=w, T_in=T_in, T_out=T_out,
|
197 |
-
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
198 |
-
output_type="numpy").images
|
199 |
|
200 |
# save output image
|
201 |
output_dir = os.path.join(tmpdirname, "eschernet")
|
@@ -748,7 +740,7 @@ with gr.Blocks() as demo:
|
|
748 |
|
749 |
# demo.queue(max_size=10)
|
750 |
# demo.launch(share=True, server_name="0.0.0.0", server_port=None)
|
751 |
-
demo.queue(max_size=
|
752 |
|
753 |
# if __name__ == '__main__':
|
754 |
# main()
|
|
|
183 |
# run inference
|
184 |
# pipeline.to(device)
|
185 |
pipeline.enable_xformers_memory_efficient_attention()
|
186 |
+
image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
|
187 |
+
poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
|
188 |
+
height=h, width=w, T_in=T_in, T_out=T_out,
|
189 |
+
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
190 |
+
output_type="numpy").images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
# save output image
|
193 |
output_dir = os.path.join(tmpdirname, "eschernet")
|
|
|
740 |
|
741 |
# demo.queue(max_size=10)
|
742 |
# demo.launch(share=True, server_name="0.0.0.0", server_port=None)
|
743 |
+
demo.queue(max_size=10).launch()
|
744 |
|
745 |
# if __name__ == '__main__':
|
746 |
# main()
|