hanshu.yan committed
Commit 2ec72fb • 1 Parent(s): b83e3cf

add app.py

Files changed (49)
  1. LICENSE +21 -0
  2. README.md +80 -12
  3. app.py +182 -0
  4. gradio_app.py +188 -0
  5. output/.DS_Store +0 -0
  6. output/0/input.png +0 -0
  7. output/0/mesh.obj +0 -0
  8. requirements.txt +17 -0
  9. requirements2.txt +9 -0
  10. run.py +162 -0
  11. src/__pycache__/__init__.cpython-38.pyc +0 -0
  12. src/__pycache__/scheduler_perflow.cpython-310.pyc +0 -0
  13. src/__pycache__/scheduler_perflow.cpython-38.pyc +0 -0
  14. src/__pycache__/utils_perflow.cpython-38.pyc +0 -0
  15. src/laion_bytenas.py +257 -0
  16. src/pfode_solver.py +120 -0
  17. src/scheduler_perflow.py +343 -0
  18. src/utils_perflow.py +77 -0
  19. test.yaml +10 -0
  20. tsr/__pycache__/system.cpython-310.pyc +0 -0
  21. tsr/__pycache__/system.cpython-38.pyc +0 -0
  22. tsr/__pycache__/utils.cpython-310.pyc +0 -0
  23. tsr/__pycache__/utils.cpython-38.pyc +0 -0
  24. tsr/models/__pycache__/isosurface.cpython-310.pyc +0 -0
  25. tsr/models/__pycache__/isosurface.cpython-38.pyc +0 -0
  26. tsr/models/__pycache__/nerf_renderer.cpython-310.pyc +0 -0
  27. tsr/models/__pycache__/nerf_renderer.cpython-38.pyc +0 -0
  28. tsr/models/__pycache__/network_utils.cpython-310.pyc +0 -0
  29. tsr/models/__pycache__/network_utils.cpython-38.pyc +0 -0
  30. tsr/models/isosurface.py +52 -0
  31. tsr/models/nerf_renderer.py +180 -0
  32. tsr/models/network_utils.py +124 -0
  33. tsr/models/tokenizers/__pycache__/image.cpython-310.pyc +0 -0
  34. tsr/models/tokenizers/__pycache__/image.cpython-38.pyc +0 -0
  35. tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc +0 -0
  36. tsr/models/tokenizers/__pycache__/triplane.cpython-38.pyc +0 -0
  37. tsr/models/tokenizers/image.py +66 -0
  38. tsr/models/tokenizers/triplane.py +45 -0
  39. tsr/models/transformer/__pycache__/attention.cpython-310.pyc +0 -0
  40. tsr/models/transformer/__pycache__/attention.cpython-38.pyc +0 -0
  41. tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc +0 -0
  42. tsr/models/transformer/__pycache__/basic_transformer_block.cpython-38.pyc +0 -0
  43. tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc +0 -0
  44. tsr/models/transformer/__pycache__/transformer_1d.cpython-38.pyc +0 -0
  45. tsr/models/transformer/attention.py +653 -0
  46. tsr/models/transformer/basic_transformer_block.py +334 -0
  47. tsr/models/transformer/transformer_1d.py +219 -0
  48. tsr/system.py +203 -0
  49. tsr/utils.py +474 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Tripo AI & Stability AI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,80 @@
- ---
- title: Perflow Triposr
- emoji: 📊
- colorFrom: blue
- colorTo: gray
- sdk: gradio
- sdk_version: 4.20.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # TripoSR <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a> <a href="https://huggingface.co/spaces/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a> <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/Arxiv-2403.02151-B31B1B.svg"></a>
+
+ <div align="center">
+   <img src="figures/teaser800.gif" alt="Teaser Video">
+ </div>
+
+ This is the official codebase for **TripoSR**, a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
+ <br><br>
+ Leveraging the principles of the [Large Reconstruction Model (LRM)](https://yiconghong.me/LRM/), TripoSR introduces key advancements that significantly boost both the speed and quality of 3D reconstruction. The model rapidly processes inputs, generating high-quality 3D models in under 0.5 seconds on an NVIDIA A100 GPU. TripoSR has exhibited superior performance in both qualitative and quantitative evaluations, outperforming other open-source alternatives across multiple public datasets. The figures below illustrate visual comparisons and metrics showcasing TripoSR's performance relative to other leading models. Details about the model architecture, training process, and comparisons can be found in this [technical report](https://arxiv.org/abs/2403.02151).
+
+ <!--
+ <div align="center">
+   <img src="figures/comparison800.gif" alt="Teaser Video">
+ </div>
+ -->
+ <p align="center">
+   <img width="800" src="figures/visual_comparisons.jpg"/>
+ </p>
+
+ <p align="center">
+   <img width="450" src="figures/scatter-comparison.png"/>
+ </p>
+
+
+ The model is released under the MIT license; the release includes the source code, pretrained models, and an interactive online demo. Our goal is to empower researchers, developers, and creatives to push the boundaries of what's possible in 3D generative AI and 3D content creation.
+
+ ## Getting Started
+ ### Installation
+ - Python >= 3.8
+ - Install CUDA if available
+ - Install PyTorch according to your platform: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) **[Please make sure that the locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example, if you have CUDA 11.x installed, install PyTorch compiled with CUDA 11.x.]**
+ - Update setuptools with `pip install --upgrade setuptools`
+ - Install the other dependencies with `pip install -r requirements.txt` (see the example sequence below)
+
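+ For example, on a Linux machine with CUDA 11.8, one possible installation sequence looks like this (the pinned versions come from `requirements.txt`; adjust the PyTorch index URL to your local CUDA version):
+ ```sh
+ pip install --upgrade setuptools
+ # pick the wheel index that matches your locally installed CUDA toolkit
+ pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118
+ pip install -r requirements.txt
+ ```
+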
+ ### Manual Inference
+ ```sh
+ python run.py examples/chair.png --output-dir output/
+ ```
+ This will save the reconstructed 3D model to `output/`. You can also specify more than one image path, separated by spaces. The default options take about **6GB of VRAM** for a single image input.
+
+ For detailed usage of this script, use `python run.py --help`.
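+
+ For example, the following call reconstructs two images in one run, saves the meshes in GLB format, and additionally renders a preview video (all flags are defined in `run.py`):
+ ```sh
+ python run.py examples/chair.png examples/robot.png \
+     --output-dir output/ \
+     --mc-resolution 256 \
+     --model-save-format glb \
+     --render
+ ```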
+
+ ### Local Gradio App
+ Install Gradio:
+ ```sh
+ pip install gradio
+ ```
+ Start the Gradio App:
+ ```sh
+ python gradio_app.py
+ ```
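+
+ `gradio_app.py` also accepts a few optional launch flags (see its argument parser), e.g. to listen on all network interfaces behind basic authentication:
+ ```sh
+ # placeholder credentials; replace with your own
+ python gradio_app.py --listen --port 7860 --username demo --password demo123
+ ```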
+
+ ## Troubleshooting
+ > AttributeError: module 'torchmcubes_module' has no attribute 'mcubes_cuda'
+
+ or
+
+ > torchmcubes was not compiled with CUDA support, use CPU version instead.
+
+ This is because `torchmcubes` was compiled without CUDA support. Please make sure that
+
+ - The locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example, if you have CUDA 11.x installed, make sure to install PyTorch compiled with CUDA 11.x (see the version check below).
+ - `setuptools>=49.6.0`. If not, upgrade with `pip install --upgrade setuptools`.
+
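+ You can compare the two CUDA versions with:
+ ```sh
+ nvcc --version                                        # locally installed CUDA toolkit
+ python -c "import torch; print(torch.version.cuda)"   # CUDA version PyTorch ships with
+ ```
+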
+ Then re-install `torchmcubes` by:
+
+ ```sh
+ pip uninstall torchmcubes
+ pip install git+https://github.com/tatsy/torchmcubes.git
+ ```
+
+ ## Citation
+ ```BibTeX
+ @article{TripoSR2024,
+   title={TripoSR: Fast 3D Object Reconstruction from a Single Image},
+   author={Tochilkin, Dmitry and Pankratz, David and Liu, Zexiang and Huang, Zixuan and Letts, Adam and Li, Yangguang and Liang, Ding and Laforte, Christian and Jampani, Varun and Cao, Yan-Pei},
+   journal={arXiv preprint arXiv:2403.02151},
+   year={2024}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,182 @@
+ import os, logging, time, argparse, random, tempfile, rembg
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+ from functools import partial
+ from tsr.system import TSR
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
+
+ from src.scheduler_perflow import PeRFlowScheduler
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+
+ def merge_delta_weights_into_unet(pipe, delta_weights, org_alpha = 1.0):
+     unet_weights = pipe.unet.state_dict()
+     for key in delta_weights.keys():
+         dtype = unet_weights[key].dtype
+         try:
+             unet_weights[key] = org_alpha * unet_weights[key].to(dtype=delta_weights[key].dtype) + delta_weights[key].to(device=unet_weights[key].device)
+         except Exception:
+             unet_weights[key] = unet_weights[key].to(dtype=delta_weights[key].dtype)
+         unet_weights[key] = unet_weights[key].to(dtype)
+     pipe.unet.load_state_dict(unet_weights, strict=True)
+     return pipe
+
+ def setup_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+
+ if torch.cuda.is_available():
+     device = "cuda:0"
+ else:
+     device = "cpu"
+
+ ### TripoSR
+ model = TSR.from_pretrained(
+     "stabilityai/TripoSR",
+     config_name="config.yaml",
+     weight_name="model.ckpt",
+ )
+ # adjust the chunk size to balance between speed and memory usage
+ model.renderer.set_chunk_size(8192)
+ model.to(device)
+
+
+ ### PeRFlow-T2I
+ # pipe_t2i = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-8", torch_dtype=torch.float16, safety_checker=None)
+ pipe_t2i = StableDiffusionPipeline.from_pretrained("stablediffusionapi/disney-pixar-cartoon", torch_dtype=torch.float16, safety_checker=None)
+ delta_weights = UNet2DConditionModel.from_pretrained("hansyan/piecewise-rectified-flow-delta-weights", torch_dtype=torch.float16, variant="v0-1",).state_dict()
+ pipe_t2i = merge_delta_weights_into_unet(pipe_t2i, delta_weights)
+ pipe_t2i.scheduler = PeRFlowScheduler.from_config(pipe_t2i.scheduler.config, prediction_type="epsilon", num_time_windows=4)
+ pipe_t2i.to('cuda:0', torch.float16)
+
+
+ ### gradio
+ rembg_session = rembg.new_session()
+
+ def generate(text, seed):
+     def fill_background(image):
+         image = np.array(image).astype(np.float32) / 255.0
+         image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+         image = Image.fromarray((image * 255.0).astype(np.uint8))
+         return image
+
+     setup_seed(int(seed))
+     # text = text
+     samples = pipe_t2i(
+         prompt = [text],
+         negative_prompt = ["distorted, blur, low-quality, haze, out of focus"],
+         height = 512,
+         width = 512,
+         # num_inference_steps = 4,
+         # guidance_scale = 4.5,
+         num_inference_steps = 6,
+         guidance_scale = 7,
+         output_type = 'pt',
+     ).images
+     samples = torch.nn.functional.interpolate(samples, size=768, mode='bilinear')
+     samples = samples.squeeze(0).permute(1, 2, 0).cpu().numpy()*255.
+     samples = samples.astype(np.uint8)
+     samples = Image.fromarray(samples[:, :, :3])
+
+     image = remove_background(samples, rembg_session)
+     image = resize_foreground(image, 0.85)
+     image = fill_background(image)
+     return image
+
+ def render(image, mc_resolution=256, formats=["obj"]):
+     scene_codes = model(image, device=device)
+     mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
+     mesh = to_gradio_3d_orientation(mesh)
+     rv = []
+     for format in formats:
+         mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
+         mesh.export(mesh_path.name)
+         rv.append(mesh_path.name)
+     return rv[0]
+
+ # warm up
+ _ = generate("a bird", 42)
+
+ # layout
+ css = """
+ h1 {
+     text-align: center;
+     display:block;
+ }
+ h2 {
+     text-align: center;
+     display:block;
+ }
+ h3 {
+     text-align: center;
+     display:block;
+ }
+ """
+ with gr.Blocks(title="TripoSR", css=css) as interface:
+     gr.Markdown(
+         """
+ # Instant Text-to-3D Mesh Demo
+
+ ### [PeRFlow](https://github.com/magic-research/piecewise-rectified-flow)-T2I + [TripoSR](https://github.com/VAST-AI-Research/TripoSR)
+
+ Two-stage synthesis: 1) generating images by PeRFlow-T2I with 6-step inference; 2) rendering 3D assets.
+         """
+     )
+
+     with gr.Column():
+         with gr.Row():
+             output_image = gr.Image(label='Generated Image', height=384, width=384)
+
+             output_model_obj = gr.Model3D(
+                 label="Output 3D Model (OBJ Format)",
+                 interactive=False,
+                 height=384, width=384,
+             )
+
+         with gr.Row():
+             textbox = gr.Textbox(label="Input Prompt", value="a colorful bird")
+             seed = gr.Textbox(label="Random Seed", value=42)
+
+     # activate
+     textbox.submit(
+         fn=generate,
+         inputs=[textbox, seed],
+         outputs=[output_image],
+     ).success(
+         fn=render,
+         inputs=[output_image],
+         outputs=[output_model_obj],
+     )
+
+     seed.submit(
+         fn=generate,
+         inputs=[textbox, seed],
+         outputs=[output_image],
+     ).success(
+         fn=render,
+         inputs=[output_image],
+         outputs=[output_model_obj],
+     )
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--username', type=str, default=None, help='Username for authentication')
+     parser.add_argument('--password', type=str, default=None, help='Password for authentication')
+     parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
+     parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
+     parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
+     parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
+     args = parser.parse_args()
+     interface.queue(max_size=args.queuesize)
+     interface.launch(
+         auth=(args.username, args.password) if (args.username and args.password) else None,
+         share=args.share,
+         server_name="0.0.0.0" if args.listen else None,
+         server_port=args.port
+     )
gradio_app.py ADDED
@@ -0,0 +1,188 @@
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import rembg
9
+ import torch
10
+ from PIL import Image
11
+ from functools import partial
12
+
13
+ from tsr.system import TSR
14
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
+
16
+ import argparse
17
+
18
+
19
+ if torch.cuda.is_available():
20
+ device = "cuda:0"
21
+ else:
22
+ device = "cpu"
23
+
24
+ model = TSR.from_pretrained(
25
+ "stabilityai/TripoSR",
26
+ config_name="config.yaml",
27
+ weight_name="model.ckpt",
28
+ )
29
+
30
+ # adjust the chunk size to balance between speed and memory usage
31
+ model.renderer.set_chunk_size(8192)
32
+ model.to(device)
33
+
34
+ rembg_session = rembg.new_session()
35
+
36
+
37
+ def check_input_image(input_image):
38
+ if input_image is None:
39
+ raise gr.Error("No image uploaded!")
40
+
41
+
42
+ def preprocess(input_image, do_remove_background, foreground_ratio):
43
+ def fill_background(image):
44
+ image = np.array(image).astype(np.float32) / 255.0
45
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
46
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
47
+ return image
48
+
49
+ if do_remove_background:
50
+ image = input_image.convert("RGB")
51
+ image = remove_background(image, rembg_session)
52
+ image = resize_foreground(image, foreground_ratio)
53
+ image = fill_background(image)
54
+ else:
55
+ image = input_image
56
+ if image.mode == "RGBA":
57
+ image = fill_background(image)
58
+ return image
59
+
60
+
61
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
62
+ print(image.shape, image.min(), image.max())
63
+ scene_codes = model(image, device=device)
64
+ mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
65
+ mesh = to_gradio_3d_orientation(mesh)
66
+ rv = []
67
+ for format in formats:
68
+ mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
69
+ mesh.export(mesh_path.name)
70
+ rv.append(mesh_path.name)
71
+ return rv
72
+
73
+
74
+ def run_example(image_pil):
75
+ preprocessed = preprocess(image_pil, False, 0.9)
76
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
77
+ return preprocessed, mesh_name_obj, mesh_name_glb
78
+
79
+
80
+ with gr.Blocks(title="TripoSR") as interface:
81
+ gr.Markdown(
82
+ """
83
+ # TripoSR Demo
84
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
85
+
86
+ **Tips:**
87
+ 1. If the result is unsatisfactory, try changing the foreground ratio; it might improve the results.
88
+ 2. You can disable "Remove Background" for the provided examples since they have already been preprocessed.
89
+ 3. Otherwise, disable "Remove Background" only if your input image is RGBA with a transparent background and the content is centered and occupies more than 70% of the image width or height.
90
+ """
91
+ )
92
+ with gr.Row(variant="panel"):
93
+ with gr.Column():
94
+ with gr.Row():
95
+ input_image = gr.Image(
96
+ label="Input Image",
97
+ image_mode="RGBA",
98
+ sources="upload",
99
+ type="pil",
100
+ elem_id="content_image",
101
+ )
102
+ processed_image = gr.Image(label="Processed Image", interactive=False)
103
+ with gr.Row():
104
+ with gr.Group():
105
+ do_remove_background = gr.Checkbox(
106
+ label="Remove Background", value=True
107
+ )
108
+ foreground_ratio = gr.Slider(
109
+ label="Foreground Ratio",
110
+ minimum=0.5,
111
+ maximum=1.0,
112
+ value=0.85,
113
+ step=0.05,
114
+ )
115
+ mc_resolution = gr.Slider(
116
+ label="Marching Cubes Resolution",
117
+ minimum=32,
118
+ maximum=320,
119
+ value=256,
120
+ step=32
121
+ )
122
+ with gr.Row():
123
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
124
+ with gr.Column():
125
+ with gr.Tab("OBJ"):
126
+ output_model_obj = gr.Model3D(
127
+ label="Output Model (OBJ Format)",
128
+ interactive=False,
129
+ )
130
+ gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
131
+ with gr.Tab("GLB"):
132
+ output_model_glb = gr.Model3D(
133
+ label="Output Model (GLB Format)",
134
+ interactive=False,
135
+ )
136
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
137
+ with gr.Row(variant="panel"):
138
+ gr.Examples(
139
+ examples=[
140
+ "examples/hamburger.png",
141
+ "examples/poly_fox.png",
142
+ "examples/robot.png",
143
+ "examples/teapot.png",
144
+ "examples/tiger_girl.png",
145
+ "examples/horse.png",
146
+ "examples/flamingo.png",
147
+ "examples/unicorn.png",
148
+ "examples/chair.png",
149
+ "examples/iso_house.png",
150
+ "examples/marble.png",
151
+ "examples/police_woman.png",
152
+ "examples/captured_p.png",
153
+ ],
154
+ inputs=[input_image],
155
+ outputs=[processed_image, output_model_obj, output_model_glb],
156
+ cache_examples=False,
157
+ fn=partial(run_example),
158
+ label="Examples",
159
+ examples_per_page=20,
160
+ )
161
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
162
+ fn=preprocess,
163
+ inputs=[input_image, do_remove_background, foreground_ratio],
164
+ outputs=[processed_image],
165
+ ).success(
166
+ fn=generate,
167
+ inputs=[processed_image, mc_resolution],
168
+ outputs=[output_model_obj, output_model_glb],
169
+ )
170
+
171
+
172
+
173
+ if __name__ == '__main__':
174
+ parser = argparse.ArgumentParser()
175
+ parser.add_argument('--username', type=str, default=None, help='Username for authentication')
176
+ parser.add_argument('--password', type=str, default=None, help='Password for authentication')
177
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
178
+ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
179
+ parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
180
+ parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
181
+ args = parser.parse_args()
182
+ interface.queue(max_size=args.queuesize)
183
+ interface.launch(
184
+ auth=(args.username, args.password) if (args.username and args.password) else None,
185
+ share=args.share,
186
+ server_name="0.0.0.0" if args.listen else None,
187
+ server_port=args.port
188
+ )
output/.DS_Store ADDED
Binary file (6.15 kB).
 
output/0/input.png ADDED
output/0/mesh.obj ADDED
The diff for this file is too large to render.
 
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ diffusers==0.24.0
+ einops==0.7.0
+ gradio==4.20.1
+ huggingface_hub==0.21.4
+ imageio==2.27.0
+ numpy==1.24.3
+ omegaconf==2.3.0
+ packaging==23.2
+ Pillow==10.1.0
+ rembg==2.0.55
+ safetensors==0.3.2
+ torch==2.0.0
+ torchvision==0.15.1
+ tqdm==4.64.1
+ transformers==4.27.0
+ trimesh==4.0.5
+ git+https://github.com/tatsy/torchmcubes.git
requirements2.txt ADDED
@@ -0,0 +1,9 @@
+ omegaconf==2.3.0
+ Pillow==10.1.0
+ einops==0.7.0
+ git+https://github.com/tatsy/torchmcubes.git
+ transformers==4.35.0
+ trimesh==4.0.5
+ rembg
+ huggingface-hub
+ imageio[ffmpeg]
run.py ADDED
@@ -0,0 +1,162 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import rembg
8
+ import torch
9
+ from PIL import Image
10
+
11
+ from tsr.system import TSR
12
+ from tsr.utils import remove_background, resize_foreground, save_video
13
+
14
+
15
+ class Timer:
16
+ def __init__(self):
17
+ self.items = {}
18
+ self.time_scale = 1000.0 # ms
19
+ self.time_unit = "ms"
20
+
21
+ def start(self, name: str) -> None:
22
+ if torch.cuda.is_available():
23
+ torch.cuda.synchronize()
24
+ self.items[name] = time.time()
25
+ logging.info(f"{name} ...")
26
+
27
+ def end(self, name: str) -> float:
28
+ if name not in self.items:
29
+ return
30
+ if torch.cuda.is_available():
31
+ torch.cuda.synchronize()
32
+ start_time = self.items.pop(name)
33
+ delta = time.time() - start_time
34
+ t = delta * self.time_scale
35
+ logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
36
+
37
+
38
+ timer = Timer()
39
+
40
+
41
+ logging.basicConfig(
42
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
43
+ )
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
46
+ parser.add_argument(
47
+ "--device",
48
+ default="cuda:0",
49
+ type=str,
50
+ help="Device to use. If no CUDA-compatible device is found, it will fall back to 'cpu'. Default: 'cuda:0'",
51
+ )
52
+ parser.add_argument(
53
+ "--pretrained-model-name-or-path",
54
+ default="stabilityai/TripoSR",
55
+ type=str,
56
+ help="Path to the pretrained model. Can be either a Hugging Face model ID or a local path. Default: 'stabilityai/TripoSR'",
57
+ )
58
+ parser.add_argument(
59
+ "--chunk-size",
60
+ default=8192,
61
+ type=int,
62
+ help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
63
+ )
64
+ parser.add_argument(
65
+ "--mc-resolution",
66
+ default=256,
67
+ type=int,
68
+ help="Marching cubes grid resolution. Default: 256"
69
+ )
70
+ parser.add_argument(
71
+ "--no-remove-bg",
72
+ action="store_true",
73
+ help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
74
+ )
75
+ parser.add_argument(
76
+ "--foreground-ratio",
77
+ default=0.85,
78
+ type=float,
79
+ help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
80
+ )
81
+ parser.add_argument(
82
+ "--output-dir",
83
+ default="output/",
84
+ type=str,
85
+ help="Output directory to save the results. Default: 'output/'",
86
+ )
87
+ parser.add_argument(
88
+ "--model-save-format",
89
+ default="obj",
90
+ type=str,
91
+ choices=["obj", "glb"],
92
+ help="Format to save the extracted mesh. Default: 'obj'",
93
+ )
94
+ parser.add_argument(
95
+ "--render",
96
+ action="store_true",
97
+ help="If specified, save a NeRF-rendered video. Default: false",
98
+ )
99
+ args = parser.parse_args()
100
+
101
+ output_dir = args.output_dir
102
+ os.makedirs(output_dir, exist_ok=True)
103
+
104
+ device = args.device
105
+ if not torch.cuda.is_available():
106
+ device = "cpu"
107
+
108
+ timer.start("Initializing model")
109
+ model = TSR.from_pretrained(
110
+ args.pretrained_model_name_or_path,
111
+ config_name="config.yaml",
112
+ weight_name="model.ckpt",
113
+ )
114
+ model.renderer.set_chunk_size(args.chunk_size)
115
+ model.to(device)
116
+ timer.end("Initializing model")
117
+
118
+ timer.start("Processing images")
119
+ images = []
120
+
121
+ if args.no_remove_bg:
122
+ rembg_session = None
123
+ else:
124
+ rembg_session = rembg.new_session()
125
+
126
+ for i, image_path in enumerate(args.image):
127
+ if args.no_remove_bg:
128
+ image = np.array(Image.open(image_path).convert("RGB"))
129
+ else:
130
+ image = remove_background(Image.open(image_path), rembg_session)
131
+ image = resize_foreground(image, args.foreground_ratio)
132
+ image = np.array(image).astype(np.float32) / 255.0
133
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
134
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
135
+ if not os.path.exists(os.path.join(output_dir, str(i))):
136
+ os.makedirs(os.path.join(output_dir, str(i)))
137
+ image.save(os.path.join(output_dir, str(i), f"input.png"))
138
+ images.append(image)
139
+ timer.end("Processing images")
140
+
141
+ for i, image in enumerate(images):
142
+ logging.info(f"Running image {i + 1}/{len(images)} ...")
143
+
144
+ timer.start("Running model")
145
+ with torch.no_grad():
146
+ scene_codes = model([image], device=device)
147
+ timer.end("Running model")
148
+
149
+ if args.render:
150
+ timer.start("Rendering")
151
+ render_images = model.render(scene_codes, n_views=30, return_type="pil")
152
+ for ri, render_image in enumerate(render_images[0]):
153
+ render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
154
+ save_video(
155
+ render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
156
+ )
157
+ timer.end("Rendering")
158
+
159
+ timer.start("Exporting mesh")
160
+ meshes = model.extract_mesh(scene_codes, resolution=args.mc_resolution)
161
+ meshes[0].export(os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}"))
162
+ timer.end("Exporting mesh")
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (147 Bytes).
 
src/__pycache__/scheduler_perflow.cpython-310.pyc ADDED
Binary file (12.2 kB).
 
src/__pycache__/scheduler_perflow.cpython-38.pyc ADDED
Binary file (12.1 kB).
 
src/__pycache__/utils_perflow.cpython-38.pyc ADDED
Binary file (2.64 kB).
 
src/laion_bytenas.py ADDED
@@ -0,0 +1,257 @@
1
+ import os
2
+ import json
3
+ import random
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ from PIL import Image, ImageStat
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader, IterableDataset, get_worker_info
9
+ from torchvision import transforms as T
10
+
11
+
12
+ ### >>>>>>>> >>>>>>>> text related >>>>>>>> >>>>>>>> ###
13
+
14
+ class TokenizerWrapper():
15
+ def __init__(self, tokenizer, is_train, proportion_empty_prompts, use_generic_prompts=False):
16
+ self.tokenizer = tokenizer
17
+ self.is_train = is_train
18
+ self.proportion_empty_prompts = proportion_empty_prompts
19
+ self.use_generic_prompts = use_generic_prompts
20
+
21
+ def __call__(self, prompts):
22
+ if isinstance(prompts, str):
23
+ prompts = [prompts]
24
+ captions = []
25
+ for caption in prompts:
26
+ if random.random() < self.proportion_empty_prompts:
27
+ captions.append("")
28
+ else:
29
+ if self.use_generic_prompts:
30
+ captions.append("best quality, high quality")
31
+ elif isinstance(caption, str):
32
+ captions.append(caption)
33
+ elif isinstance(caption, (list, np.ndarray)):
34
+ # take a random caption if there are multiple
35
+ captions.append(random.choice(caption) if self.is_train else caption[0])
36
+ else:
37
+ raise ValueError(
38
+ f"Caption column should contain either strings or lists of strings."
39
+ )
40
+ inputs = self.tokenizer(
41
+ captions, max_length=self.tokenizer.model_max_length, padding="max_length",
42
+ truncation=True, return_tensors="pt"
43
+ )
44
+ return inputs.input_ids
45
+
46
+
47
+
48
+ ### >>>>>>>> >>>>>>>> image related >>>>>>>> >>>>>>>> ###
49
+
50
+ MONOCHROMATIC_MAX_VARIANCE = 0.3
51
+
52
+ def is_monochromatic_image(pil_img):
53
+ v = ImageStat.Stat(pil_img.convert('RGB')).var
54
+ return sum(v)<MONOCHROMATIC_MAX_VARIANCE
55
+
56
+ def isnumeric(text):
57
+ return (''.join(filter(str.isalnum, text))).isnumeric()
58
+
59
+
60
+
61
+ class TextPromptDataset(IterableDataset):
62
+ '''
63
+ The dataset for (text embedding, noise, generated latent) triplets.
64
+ '''
65
+ def __init__(self,
66
+ data_root,
67
+ tokenizer = None,
68
+ transform = None,
69
+ rank = 0,
70
+ world_size = 1,
71
+ shuffle = True,
72
+ ):
73
+ self.tokenizer = tokenizer
74
+ self.transform = transform
75
+
76
+ self.img_root = os.path.join(data_root, 'JPEGImages')
77
+ self.data_list = []
78
+
79
+ print("#### Loading filename list...")
80
+ json_root = os.path.join(data_root, 'list')
81
+ json_list = [p for p in os.listdir(json_root) if p.startswith("shard") and p.endswith('.json')]
82
+
83
+ # duplicate several shards to make sure each process has the same number of shards
84
+ assert len(json_list) > world_size
85
+ duplicate = world_size - len(json_list)%world_size if len(json_list)%world_size>0 else 0
86
+ json_list = json_list + json_list[:duplicate]
87
+ json_list = json_list[rank::world_size]
88
+
89
+ for json_file in tqdm(json_list):
90
+ shard_name = os.path.basename(json_file).split('.')[0]
91
+ with open(os.path.join(json_root, json_file)) as f:
92
+ key_text_pairs = json.load(f)
93
+
94
+ for pair in key_text_pairs:
95
+ self.data_list.append( [shard_name] + pair )
96
+
97
+ print("#### All filename loaded...")
98
+
99
+ self.shuffle = shuffle
100
+
101
+ def __len__(self):
102
+ return len(self.data_list)
103
+
104
+
105
+ def __iter__(self):
106
+ worker_info = get_worker_info()
107
+
108
+ if worker_info is None: # single-process data loading, return the full iterator
109
+ data_list = self.data_list
110
+ else:
111
+ len_data = len(self.data_list) - len(self.data_list) % worker_info.num_workers
112
+ data_list = self.data_list[:len_data][worker_info.id :: worker_info.num_workers]
113
+ # print(worker_info.num_workers, worker_info.id, len(data_list)/len(self.data_list))
114
+
115
+ if self.shuffle:
116
+ random.shuffle(data_list)
117
+
118
+ while True:
119
+ for idx in range(len(data_list)):
120
+ # try:
121
+ shard_name = data_list[idx][0]
122
+ data = {}
123
+
124
+ img_file = data_list[idx][1]
125
+ img = Image.open(os.path.join(self.img_root, shard_name, img_file+'.jpg')).convert("RGB")
126
+
127
+ if is_monochromatic_image(img):
128
+ continue
129
+
130
+ if self.transform is not None:
131
+ img = self.transform(img)
132
+
133
+ data['pixel_values'] = img
134
+
135
+ text = data_list[idx][2]
136
+ if self.tokenizer is not None:
137
+ if isinstance(self.tokenizer, list):
138
+ assert len(self.tokenizer)==2
139
+ data['input_ids'] = self.tokenizer[0](text)[0]
140
+ data['input_ids_2'] = self.tokenizer[1](text)[0]
141
+ else:
142
+ data['input_ids'] = self.tokenizer(text)[0]
143
+ else:
144
+ data['input_ids'] = text
145
+
146
+ yield data
147
+
148
+ # except Exception as e:
149
+ # raise(e)
150
+
151
+ def collate_fn(self, examples):
152
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
153
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
154
+
155
+ if self.tokenizer is not None:
156
+ if isinstance(self.tokenizer, list):
157
+ assert len(self.tokenizer)==2
158
+ input_ids = torch.stack([example["input_ids"] for example in examples])
159
+ input_ids_2 = torch.stack([example["input_ids_2"] for example in examples])
160
+ return {"pixel_values": pixel_values, "input_ids": input_ids, "input_ids_2": input_ids_2,}
161
+ else:
162
+ input_ids = torch.stack([example["input_ids"] for example in examples])
163
+ return {"pixel_values": pixel_values, "input_ids": input_ids,}
164
+ else:
165
+ input_ids = [example["input_ids"] for example in examples]
166
+ return {"pixel_values": pixel_values, "input_ids": input_ids,}
167
+
168
+
169
+ def make_train_dataset(
170
+ train_data_path,
171
+ size = 512,
172
+ tokenizer=None,
173
+ cfg_drop_ratio=0,
174
+ rank=0,
175
+ world_size=1,
176
+ shuffle=True,
177
+ ):
178
+
179
+ _image_transform = T.Compose([
180
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
181
+ T.Resize(size),
182
+ T.CenterCrop((size,size)),
183
+ T.ToTensor(),
184
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
185
+ ])
186
+
187
+ if tokenizer is not None:
188
+ if isinstance(tokenizer, list):
189
+ assert len(tokenizer)==2
190
+ tokenizer_1 = TokenizerWrapper(
191
+ tokenizer[0],
192
+ is_train=True,
193
+ proportion_empty_prompts=cfg_drop_ratio,
194
+ use_generic_prompts=False,
195
+ )
196
+ tokenizer_2 = TokenizerWrapper(
197
+ tokenizer[1],
198
+ is_train=True,
199
+ proportion_empty_prompts=cfg_drop_ratio,
200
+ use_generic_prompts=False,
201
+ )
202
+ tokenizer = [tokenizer_1, tokenizer_2]
203
+
204
+ else:
205
+ tokenizer = TokenizerWrapper(
206
+ tokenizer,
207
+ is_train=True,
208
+ proportion_empty_prompts=cfg_drop_ratio,
209
+ use_generic_prompts=False,
210
+ )
211
+
212
+
213
+ train_dataset = TextPromptDataset(
214
+ data_root=train_data_path,
215
+ transform=_image_transform,
216
+ rank=rank,
217
+ world_size=world_size,
218
+ tokenizer=tokenizer,
219
+ shuffle=shuffle,
220
+ )
221
+ return train_dataset
222
+
223
+
224
+
225
+
226
+
227
+
228
+
229
+
230
+
231
+
232
+ ### >>>>>>>> >>>>>>>> Test >>>>>>>> >>>>>>>> ###
233
+ if __name__ == "__main__":
234
+ from transformers import CLIPTextModel, CLIPTokenizer
235
+ tokenizer = CLIPTokenizer.from_pretrained(
236
+ "/mnt/bn/ic-research-aigc-editing/fast-diffusion-models/assets/public_models/StableDiffusion/stable-diffusion-v1-5",
237
+ subfolder="tokenizer"
238
+ )
239
+ train_dataset = make_train_dataset(tokenizer=tokenizer, rank=0, world_size=10)
240
+
241
+ loader = torch.utils.data.DataLoader(
242
+ train_dataset, batch_size=64, num_workers=0,
243
+ collate_fn=train_dataset.collate_fn if hasattr(train_dataset, 'collate_fn') else None,
244
+ )
245
+ for batch in loader:
246
+ pixel_values = batch["pixel_values"]
247
+ prompt_ids = batch['input_ids']
248
+ from einops import rearrange
249
+ pixel_values = rearrange(pixel_values, 'b c h w -> b h w c')
250
+
251
+ for i in range(pixel_values.shape[0]):
252
+ import pdb; pdb.set_trace()
253
+ Image.fromarray(((pixel_values[i] + 1 )/2 * 255 ).numpy().astype(np.uint8)).save('tmp.png')
254
+ input_id = prompt_ids[i]
255
+ text = tokenizer.decode(input_id).split('<|startoftext|>')[-1].split('<|endoftext|>')[0]
256
+ print(text)
257
+ pass
src/pfode_solver.py ADDED
@@ -0,0 +1,120 @@
1
+ import os, math, random, argparse, logging
2
+ from pathlib import Path
3
+ from typing import Optional, Union, List, Callable
4
+ from collections import OrderedDict
5
+ from packaging import version
6
+ from tqdm.auto import tqdm
7
+ from omegaconf import OmegaConf
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ import torchvision
14
+
15
+
16
+ class PFODESolver():
17
+ def __init__(self, scheduler, t_initial=1, t_terminal=0,) -> None:
18
+ self.t_initial = t_initial
19
+ self.t_terminal = t_terminal
20
+ self.scheduler = scheduler
21
+
22
+ train_step_terminal = 0
23
+ train_step_initial = train_step_terminal + self.scheduler.config.num_train_timesteps # 0+1000
24
+ self.stepsize = (t_terminal-t_initial) / (train_step_terminal - train_step_initial) #1/1000
25
+
26
+ def get_timesteps(self, t_start, t_end, num_steps):
27
+ # (b,) -> (b,1)
28
+ t_start = t_start[:, None]
29
+ t_end = t_end[:, None]
30
+ assert t_start.dim() == 2
31
+
32
+ timepoints = torch.arange(0, num_steps, 1).expand(t_start.shape[0], num_steps).to(device=t_start.device)
33
+ interval = (t_end - t_start) / (torch.ones([1], device=t_start.device) * num_steps)
34
+ timepoints = t_start + interval * timepoints
35
+
36
+ timesteps = (self.scheduler.num_train_timesteps - 1) + (timepoints - self.t_initial) / self.stepsize # corresponding to the StableDiffusion indexing system, from 999 (t_init) -> 0 (dt)
37
+ return timesteps.round().long()
38
+
39
+ def solve(self,
40
+ latents,
41
+ unet,
42
+ t_start,
43
+ t_end,
44
+ prompt_embeds,
45
+ negative_prompt_embeds,
46
+ guidance_scale=1.0,
47
+ num_steps = 2,
48
+ num_windows = 1,
49
+ ):
50
+ assert t_start.dim() == 1
51
+ assert guidance_scale >= 1 and torch.all(torch.gt(t_start, t_end))
52
+
53
+ do_classifier_free_guidance = True if guidance_scale > 1 else False
54
+ bsz = latents.shape[0]
55
+
56
+ if do_classifier_free_guidance:
57
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
58
+
59
+ timestep_cond = None
60
+ if unet.config.time_cond_proj_dim is not None:
61
+ guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(bsz)
62
+ timestep_cond = self.get_guidance_scale_embedding(
63
+ guidance_scale_tensor, embedding_dim=unet.config.time_cond_proj_dim
64
+ ).to(device=latents.device, dtype=latents.dtype)
65
+
66
+
67
+ timesteps = self.get_timesteps(t_start, t_end, num_steps).to(device=latents.device)
68
+ timestep_interval = self.scheduler.config.num_train_timesteps // (num_windows * num_steps)
69
+
70
+ # Denoising loop
71
+ with torch.no_grad():
72
+ for i in range(num_steps):
73
+ t = torch.cat([timesteps[:, i]]*2) if do_classifier_free_guidance else timesteps[:, i]
74
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
75
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
76
+
77
+ noise_pred = unet(
78
+ latent_model_input,
79
+ t,
80
+ encoder_hidden_states=prompt_embeds,
81
+ timestep_cond=timestep_cond,
82
+ return_dict=False,
83
+ )[0]
84
+
85
+ if do_classifier_free_guidance:
86
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
87
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
88
+
89
+ ##### STEP: compute the previous noisy sample x_t -> x_t-1
90
+ batch_timesteps = timesteps[:, i].cpu()
91
+ prev_timestep = batch_timesteps - timestep_interval
92
+
93
+ alpha_prod_t = self.scheduler.alphas_cumprod[batch_timesteps]
94
+ alpha_prod_t_prev = torch.zeros_like(alpha_prod_t)
95
+ for ib in range(prev_timestep.shape[0]):
96
+ alpha_prod_t_prev[ib] = self.scheduler.alphas_cumprod[prev_timestep[ib]] if prev_timestep[ib] >= 0 else self.scheduler.final_alpha_cumprod
97
+ beta_prod_t = 1 - alpha_prod_t
98
+
99
+ alpha_prod_t = alpha_prod_t.to(device=latents.device, dtype=latents.dtype)
100
+ alpha_prod_t_prev = alpha_prod_t_prev.to(device=latents.device, dtype=latents.dtype)
101
+ beta_prod_t = beta_prod_t.to(device=latents.device, dtype=latents.dtype)
102
+
103
+ if self.scheduler.config.prediction_type == "epsilon":
104
+ pred_original_sample = (latents - beta_prod_t[:,None,None,None] ** (0.5) * noise_pred) / alpha_prod_t[:, None,None,None] ** (0.5)
105
+ pred_epsilon = noise_pred
106
+ elif self.scheduler.config.prediction_type == "v_prediction":
107
+ pred_original_sample = (alpha_prod_t[:,None,None,None]**0.5) * latents - (beta_prod_t[:,None,None,None]**0.5) * noise_pred
108
+ pred_epsilon = (alpha_prod_t[:,None,None,None]**0.5) * noise_pred + (beta_prod_t[:,None,None,None]**0.5) * latents
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ pred_sample_direction = (1 - alpha_prod_t_prev[:,None,None,None]) ** (0.5) * pred_epsilon
113
+ latents = alpha_prod_t_prev[:,None,None,None] ** (0.5) * pred_original_sample + pred_sample_direction
114
+
115
+
116
+ return latents
117
+
118
+
119
+
120
+
src/scheduler_perflow.py ADDED
@@ -0,0 +1,343 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+ import numpy as np
22
+ import torch
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import BaseOutput
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
27
+
28
+
29
+ class Time_Windows():
30
+ def __init__(self, t_initial=1, t_terminal=0, num_windows=4, precision=1./1000) -> None:
31
+ assert t_terminal < t_initial
32
+ time_windows = [ 1.*i/num_windows for i in range(1, num_windows+1)][::-1]
33
+
34
+ self.window_starts = time_windows # [1.0, 0.75, 0.5, 0.25]
35
+ self.window_ends = time_windows[1:] + [t_terminal] # [0.75, 0.5, 0.25, 0]
36
+ self.precision = precision
37
+
38
+ def get_window(self, tp):
39
+ idx = 0
40
+ # robust to numerical error; e.g., (0.6 + 1/10000) still belongs to the window [0.6, 0.3)
41
+ while (tp-0.1*self.precision) <= self.window_ends[idx]:
42
+ idx += 1
43
+ return self.window_starts[idx], self.window_ends[idx]
44
+
45
+ def lookup_window(self, timepoint):
46
+ if timepoint.dim() == 0:
47
+ t_start, t_end = self.get_window(timepoint)
48
+ t_start = torch.ones_like(timepoint) * t_start
49
+ t_end = torch.ones_like(timepoint) * t_end
50
+ else:
51
+ t_start = torch.zeros_like(timepoint)
52
+ t_end = torch.zeros_like(timepoint)
53
+ bsz = timepoint.shape[0]
54
+ for i in range(bsz):
55
+ tp = timepoint[i]
56
+ ts, te = self.get_window(tp)
57
+ t_start[i] = ts
58
+ t_end[i] = te
59
+ return t_start, t_end
60
+
61
+
62
+
63
+ @dataclass
64
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
65
+ class PeRFlowSchedulerOutput(BaseOutput):
66
+ """
67
+ Output class for the scheduler's `step` function output.
68
+
69
+ Args:
70
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
71
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
72
+ denoising loop.
73
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
74
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
75
+ `pred_original_sample` can be used to preview progress or for guidance.
76
+ """
77
+
78
+ prev_sample: torch.FloatTensor
79
+ pred_original_sample: Optional[torch.FloatTensor] = None
80
+
81
+
82
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
83
+ def betas_for_alpha_bar(
84
+ num_diffusion_timesteps,
85
+ max_beta=0.999,
86
+ alpha_transform_type="cosine",
87
+ ):
88
+ """
89
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
90
+ (1-beta) over time from t = [0,1].
91
+
92
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
93
+ to that part of the diffusion process.
94
+
95
+
96
+ Args:
97
+ num_diffusion_timesteps (`int`): the number of betas to produce.
98
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
99
+ prevent singularities.
100
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
101
+ Choose from `cosine` or `exp`
102
+
103
+ Returns:
104
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
105
+ """
106
+ if alpha_transform_type == "cosine":
107
+
108
+ def alpha_bar_fn(t):
109
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
110
+
111
+ elif alpha_transform_type == "exp":
112
+
113
+ def alpha_bar_fn(t):
114
+ return math.exp(t * -12.0)
115
+
116
+ else:
117
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
118
+
119
+ betas = []
120
+ for i in range(num_diffusion_timesteps):
121
+ t1 = i / num_diffusion_timesteps
122
+ t2 = (i + 1) / num_diffusion_timesteps
123
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
124
+ return torch.tensor(betas, dtype=torch.float32)
125
+
126
+
127
+
128
+ class PeRFlowScheduler(SchedulerMixin, ConfigMixin):
129
+ """
130
+ `PeRFlowScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
131
+ non-Markovian guidance.
132
+
133
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
134
+ methods the library implements for all schedulers such as loading and saving.
135
+
136
+ Args:
137
+ num_train_timesteps (`int`, defaults to 1000):
138
+ The number of diffusion steps to train the model.
139
+ beta_start (`float`, defaults to 0.0001):
140
+ The starting `beta` value of inference.
141
+ beta_end (`float`, defaults to 0.02):
142
+ The final `beta` value.
143
+ beta_schedule (`str`, defaults to `"linear"`):
144
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
145
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
146
+ trained_betas (`np.ndarray`, *optional*):
147
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
148
+ set_alpha_to_one (`bool`, defaults to `True`):
149
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
150
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
151
+ otherwise it uses the alpha value at step 0.
152
+ prediction_type (`str`, defaults to `epsilon`, *optional*)
153
+ """
154
+
155
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
156
+ order = 1
157
+
158
+ @register_to_config
159
+ def __init__(
160
+ self,
161
+ num_train_timesteps: int = 1000,
162
+ beta_start: float = 0.00085,
163
+ beta_end: float = 0.012,
164
+ beta_schedule: str = "scaled_linear",
165
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
166
+ set_alpha_to_one: bool = False,
167
+ prediction_type: str = "epsilon",
168
+ t_noise: float = 1,
169
+ t_clean: float = 0,
170
+ num_time_windows = 4,
171
+ ):
172
+ if trained_betas is not None:
173
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
174
+ elif beta_schedule == "linear":
175
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
176
+ elif beta_schedule == "scaled_linear":
177
+ # this schedule is very specific to the latent diffusion model.
178
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
179
+ elif beta_schedule == "squaredcos_cap_v2":
180
+ # Glide cosine schedule
181
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
182
+ else:
183
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
184
+
185
+ self.alphas = 1.0 - self.betas
186
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
187
+
188
+ # At every step in ddim, we are looking into the previous alphas_cumprod
189
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
190
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
191
+ # whether we use the final alpha of the "non-previous" one.
192
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
193
+
194
+ # # standard deviation of the initial noise distribution
195
+ self.init_noise_sigma = 1.0
196
+
197
+ self.time_windows = Time_Windows(t_initial=t_noise, t_terminal=t_clean,
198
+ num_windows=num_time_windows,
199
+ precision=1./num_train_timesteps)
200
+
201
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
202
+ """
203
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
204
+ current timestep.
205
+
206
+ Args:
207
+ sample (`torch.FloatTensor`):
208
+ The input sample.
209
+ timestep (`int`, *optional*):
210
+ The current timestep in the diffusion chain.
211
+
212
+ Returns:
213
+ `torch.FloatTensor`:
214
+ A scaled input sample.
215
+ """
216
+ return sample
217
+
218
+
219
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
220
+ """
221
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
222
+
223
+ Args:
224
+ num_inference_steps (`int`):
225
+ The number of diffusion steps used when generating samples with a pre-trained model.
226
+ """
227
+ if num_inference_steps < self.config.num_time_windows:
228
+ num_inference_steps = self.config.num_time_windows
229
+ print(f"### num_inference_steps should be at least num_time_windows; it has been set to {self.config.num_time_windows}.")
230
+
231
+ timesteps = []
232
+ for i in range(self.config.num_time_windows):
233
+ if i < num_inference_steps%self.config.num_time_windows:
234
+ num_steps_cur_win = num_inference_steps//self.config.num_time_windows+1
235
+ else:
236
+ num_steps_cur_win = num_inference_steps//self.config.num_time_windows
237
+
238
+ t_s = self.time_windows.window_starts[i]
239
+ t_e = self.time_windows.window_ends[i]
240
+ timesteps_cur_win = np.linspace(t_s, t_e, num=num_steps_cur_win, endpoint=False)
241
+ timesteps.append(timesteps_cur_win)
242
+
243
+ timesteps = np.concatenate(timesteps)
244
+
245
+ self.timesteps = torch.from_numpy(
246
+ (timesteps*self.config.num_train_timesteps).astype(np.int64)
247
+ ).to(device)
248
+
249
+ def get_window_alpha(self, timestep):
250
+ time_windows = self.time_windows
251
+ num_train_timesteps = self.config.num_train_timesteps
252
+
253
+ t_win_start, t_win_end = time_windows.lookup_window(timestep / num_train_timesteps)
254
+ t_win_len = t_win_end - t_win_start
255
+ t_interval = timestep / num_train_timesteps - t_win_start # NOTE: negative value
256
+
257
+ idx_start = (t_win_start*num_train_timesteps - 1 ).long()
258
+ idx_end = torch.clamp( (t_win_end*num_train_timesteps - 1 ).long(), min=0)
259
+ alpha_cumprod_s_e = self.alphas_cumprod[idx_start] / self.alphas_cumprod[idx_end]
260
+ gamma_s_e = alpha_cumprod_s_e ** 0.5
261
+
262
+ return t_win_start, t_win_end, t_win_len, t_interval, gamma_s_e
263
+
264
+ def step(
265
+ self,
266
+ model_output: torch.FloatTensor,
267
+ timestep: int,
268
+ sample: torch.FloatTensor,
269
+ return_dict: bool = True,
270
+ ) -> Union[PeRFlowSchedulerOutput, Tuple]:
271
+ """
272
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
273
+ process from the learned model outputs (most often the predicted noise).
274
+
275
+ Args:
276
+ model_output (`torch.FloatTensor`):
277
+ The direct output from learned diffusion model.
278
+ timestep (`float`):
279
+ The current discrete timestep in the diffusion chain.
280
+ sample (`torch.FloatTensor`):
281
+ A current instance of a sample created by the diffusion process.
282
+ return_dict (`bool`, *optional*, defaults to `True`):
283
+ Whether or not to return a [`~schedulers.scheduling_ddim.PeRFlowSchedulerOutput`] or `tuple`.
284
+
285
+ Returns:
286
+ [`~schedulers.scheduling_utils.PeRFlowSchedulerOutput`] or `tuple`:
287
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.PeRFlowSchedulerOutput`] is returned, otherwise a
288
+ tuple is returned where the first element is the sample tensor.
289
+ """
290
+
291
+ if self.config.prediction_type == "epsilon":
292
+ pred_epsilon = model_output
293
+ t_win_start, t_win_end, t_win_len, t_interval, gamma_s_e = self.get_window_alpha(timestep)
294
+ pred_sample_end = ( sample - (1-t_interval/t_win_len) * ((1-gamma_s_e**2)**0.5) * pred_epsilon ) \
295
+ / ( gamma_s_e + t_interval / t_win_len * (1-gamma_s_e) )
296
+ pred_velocity = (pred_sample_end - sample) / (t_win_end - (t_win_start + t_interval))
297
+
298
+ elif self.config.prediction_type == "velocity":
299
+ pred_velocity = model_output
300
+ else:
301
+ raise ValueError(
302
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `velocity`."
303
+ )
304
+
305
+ # get dt
306
+ idx = torch.argwhere(torch.where(self.timesteps==timestep, 1,0))
307
+ prev_step = self.timesteps[idx+1] if (idx+1)<len(self.timesteps) else 0
308
+ dt = (prev_step - timestep) / self.config.num_train_timesteps
309
+ dt = dt.to(sample.device, sample.dtype)
310
+
311
+ prev_sample = sample + dt * pred_velocity
312
+
313
+ if not return_dict:
314
+ return (prev_sample,)
315
+ return PeRFlowSchedulerOutput(prev_sample=prev_sample, pred_original_sample=None)
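For reference, the `epsilon` branch of `step()` above amounts to the following window-local update (a restatement of the code, writing $\lambda = t_{\text{interval}} / t_{\text{win\_len}}$, $\gamma_{s,e}$ for the in-window noise ratio returned by `get_window_alpha`, and $\hat{\epsilon}$ for the model output):

$$\hat{x}_{t_e} = \frac{x_t - (1-\lambda)\sqrt{1-\gamma_{s,e}^2}\,\hat{\epsilon}}{\gamma_{s,e} + \lambda\,(1-\gamma_{s,e})}, \qquad \hat{v} = \frac{\hat{x}_{t_e} - x_t}{t_e - t}, \qquad x_{t+\Delta t} = x_t + \Delta t\,\hat{v},$$

where $t = t_s + t_{\text{interval}}$ and $\Delta t$ is the signed gap to the next discrete timestep.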
316
+
317
+
318
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
319
+ def add_noise(
320
+ self,
321
+ original_samples: torch.FloatTensor,
322
+ noise: torch.FloatTensor,
323
+ timesteps: torch.IntTensor,
324
+ ) -> torch.FloatTensor:
325
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
326
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
327
+ timesteps = timesteps.to(original_samples.device) - 1 # indexing from 0
328
+
329
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
330
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
331
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
332
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
333
+
334
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
335
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
336
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
337
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
338
+
339
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
340
+ return noisy_samples
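`add_noise` above is the standard DDPM forward process (with the timestep index shifted by one so that indexing starts at 0):

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I).$$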
341
+
342
+ def __len__(self):
343
+ return self.config.num_train_timesteps
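A minimal usage sketch for the scheduler above (not part of the commit; it assumes the default constructor arguments are usable as-is and stands in a random tensor for the real denoiser):

```python
import torch

from src.scheduler_perflow import PeRFlowScheduler

scheduler = PeRFlowScheduler()                    # assumption: the defaults registered in __init__ are sensible
scheduler.set_timesteps(num_inference_steps=8, device="cpu")

sample = torch.randn(1, 4, 64, 64)                # start from pure noise, latent-shaped
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)       # stand-in for the UNet's epsilon prediction
    sample = scheduler.step(model_output, t, sample).prev_sample
```

In practice `model_output` would come from a UNet such as the one assembled by `src/utils_perflow.py` below.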
src/utils_perflow.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ from collections import OrderedDict
3
+ import torch
4
+ from safetensors import safe_open
5
+ from safetensors.torch import save_file
6
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
7
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_ldm_clip_checkpoint
8
+
9
+
10
+ def merge_delta_weights_into_unet(pipe, delta_weights):
11
+ unet_weights = pipe.unet.state_dict()
12
+ assert unet_weights.keys() == delta_weights.keys()
13
+ for key in delta_weights.keys():
14
+ dtype = unet_weights[key].dtype
15
+ unet_weights[key] = unet_weights[key].to(dtype=delta_weights[key].dtype) + delta_weights[key].to(device=unet_weights[key].device)
16
+ unet_weights[key] = unet_weights[key].to(dtype)
17
+ pipe.unet.load_state_dict(unet_weights, strict=True)
18
+ return pipe
19
+
20
+
21
+ def load_delta_weights_into_unet(
22
+ pipe,
23
+ model_path = "hsyan/piecewise-rectified-flow-v0-1",
24
+ base_path = "runwayml/stable-diffusion-v1-5",
25
+ ):
26
+ ## load delta_weights
27
+ if os.path.exists(os.path.join(model_path, "delta_weights.safetensors")):
28
+ print("### delta_weights exists, loading...")
29
+ delta_weights = OrderedDict()
30
+ with safe_open(os.path.join(model_path, "delta_weights.safetensors"), framework="pt", device="cpu") as f:
31
+ for key in f.keys():
32
+ delta_weights[key] = f.get_tensor(key)
33
+
34
+ elif os.path.exists(os.path.join(model_path, "diffusion_pytorch_model.safetensors")):
35
+ print("### merged_weights exists, loading...")
36
+ merged_weights = OrderedDict()
37
+ with safe_open(os.path.join(model_path, "diffusion_pytorch_model.safetensors"), framework="pt", device="cpu") as f:
38
+ for key in f.keys():
39
+ merged_weights[key] = f.get_tensor(key)
40
+
41
+ base_weights = StableDiffusionPipeline.from_pretrained(
42
+ base_path, torch_dtype=torch.float16, safety_checker=None).unet.state_dict()
43
+ assert base_weights.keys() == merged_weights.keys()
44
+
45
+ delta_weights = OrderedDict()
46
+ for key in merged_weights.keys():
47
+ delta_weights[key] = merged_weights[key] - base_weights[key].to(device=merged_weights[key].device, dtype=merged_weights[key].dtype)
48
+
49
+ print("### saving delta_weights...")
50
+ save_file(delta_weights, os.path.join(model_path, "delta_weights.safetensors"))
51
+
52
+ else:
53
+ raise ValueError(f"{model_path} does not contain delta weights or merged weights")
54
+
55
+ ## merge delta_weights to the target pipeline
56
+ pipe = merge_delta_weights_into_unet(pipe, delta_weights)
57
+ return pipe
58
+
59
+
60
+
61
+
62
+ def load_dreambooth_into_pipeline(pipe, sd_dreambooth):
63
+ assert sd_dreambooth.endswith(".safetensors")
64
+ state_dict = {}
65
+ with safe_open(sd_dreambooth, framework="pt", device="cpu") as f:
66
+ for key in f.keys():
67
+ state_dict[key] = f.get_tensor(key)
68
+
69
+ unet_config = {} # unet, line 449 in convert_ldm_unet_checkpoint
70
+ for key in pipe.unet.config.keys():
71
+ if key != 'num_class_embeds':
72
+ unet_config[key] = pipe.unet.config[key]
73
+
74
+ pipe.unet.load_state_dict(convert_ldm_unet_checkpoint(state_dict, unet_config), strict=False)
75
+ pipe.vae.load_state_dict(convert_ldm_vae_checkpoint(state_dict, pipe.vae.config))
76
+ pipe.text_encoder = convert_ldm_clip_checkpoint(state_dict, text_encoder=pipe.text_encoder)
77
+ return pipe
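A hypothetical usage sketch for the helpers above (not part of the commit). Note that `load_delta_weights_into_unet` checks the local filesystem, so `model_path` must be a downloaded checkpoint directory rather than a Hub id; the paths below are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

from src.utils_perflow import load_delta_weights_into_unet

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
)

# merge the PeRFlow delta weights into the base UNet, in place
pipe = load_delta_weights_into_unet(
    pipe,
    model_path="./piecewise-rectified-flow-v0-1",  # local dir containing delta_weights.safetensors
    base_path="runwayml/stable-diffusion-v1-5",
)
```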
test.yaml ADDED
@@ -0,0 +1,10 @@
1
+ name: test
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ - conda-forge
7
+ dependencies:
8
+ - python=3.10.12
9
+ - pip=23.2.1
10
+ - cudatoolkit=11.7
tsr/__pycache__/system.cpython-310.pyc ADDED
Binary file (5.19 kB). View file
 
tsr/__pycache__/system.cpython-38.pyc ADDED
Binary file (5.07 kB). View file
 
tsr/__pycache__/utils.cpython-310.pyc ADDED
Binary file (13.6 kB). View file
 
tsr/__pycache__/utils.cpython-38.pyc ADDED
Binary file (13.5 kB). View file
 
tsr/models/__pycache__/isosurface.cpython-310.pyc ADDED
Binary file (2.27 kB). View file
 
tsr/models/__pycache__/isosurface.cpython-38.pyc ADDED
Binary file (2.23 kB). View file
 
tsr/models/__pycache__/nerf_renderer.cpython-310.pyc ADDED
Binary file (5.32 kB). View file
 
tsr/models/__pycache__/nerf_renderer.cpython-38.pyc ADDED
Binary file (5.31 kB). View file
 
tsr/models/__pycache__/network_utils.cpython-310.pyc ADDED
Binary file (3.44 kB). View file
 
tsr/models/__pycache__/network_utils.cpython-38.pyc ADDED
Binary file (3.39 kB). View file
 
tsr/models/isosurface.py ADDED
@@ -0,0 +1,52 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchmcubes import marching_cubes
7
+
8
+
9
+ class IsosurfaceHelper(nn.Module):
10
+ points_range: Tuple[float, float] = (0, 1)
11
+
12
+ @property
13
+ def grid_vertices(self) -> torch.FloatTensor:
14
+ raise NotImplementedError
15
+
16
+
17
+ class MarchingCubeHelper(IsosurfaceHelper):
18
+ def __init__(self, resolution: int) -> None:
19
+ super().__init__()
20
+ self.resolution = resolution
21
+ self.mc_func: Callable = marching_cubes
22
+ self._grid_vertices: Optional[torch.FloatTensor] = None
23
+
24
+ @property
25
+ def grid_vertices(self) -> torch.FloatTensor:
26
+ if self._grid_vertices is None:
27
+ # keep the vertices on CPU so that we can support very large resolution
28
+ x, y, z = (
29
+ torch.linspace(*self.points_range, self.resolution),
30
+ torch.linspace(*self.points_range, self.resolution),
31
+ torch.linspace(*self.points_range, self.resolution),
32
+ )
33
+ x, y, z = torch.meshgrid(x, y, z, indexing="ij")
34
+ verts = torch.cat(
35
+ [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
36
+ ).reshape(-1, 3)
37
+ self._grid_vertices = verts
38
+ return self._grid_vertices
39
+
40
+ def forward(
41
+ self,
42
+ level: torch.FloatTensor,
43
+ ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
44
+ level = -level.view(self.resolution, self.resolution, self.resolution)
45
+ try:
46
+ v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
47
+ except AttributeError:
48
+ print("torchmcubes was not compiled with CUDA support; falling back to the CPU version.")
49
+ v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
50
+ v_pos = v_pos[..., [2, 1, 0]]
51
+ v_pos = v_pos / (self.resolution - 1.0)
52
+ return v_pos.to(level.device), t_pos_idx.to(level.device)
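A small usage sketch for `MarchingCubeHelper` (not part of the commit; it assumes `torchmcubes` is installed and uses an analytic sphere in place of a real density field):

```python
import torch

from tsr.models.isosurface import MarchingCubeHelper

res = 64
helper = MarchingCubeHelper(resolution=res)

# evaluate any scalar field at the helper's grid vertices; here a sphere of
# radius 0.3 centred in the unit cube (positive inside, negative outside)
pts = helper.grid_vertices                 # (res**3, 3), coordinates in [0, 1]
level = 0.3 - (pts - 0.5).norm(dim=-1)
verts, faces = helper(level)               # vertices in [0, 1]^3 and triangle indices
```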
tsr/models/nerf_renderer.py ADDED
@@ -0,0 +1,180 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, reduce
7
+
8
+ from ..utils import (
9
+ BaseModule,
10
+ chunk_batch,
11
+ get_activation,
12
+ rays_intersect_bbox,
13
+ scale_tensor,
14
+ )
15
+
16
+
17
+ class TriplaneNeRFRenderer(BaseModule):
18
+ @dataclass
19
+ class Config(BaseModule.Config):
20
+ radius: float
21
+
22
+ feature_reduction: str = "concat"
23
+ density_activation: str = "trunc_exp"
24
+ density_bias: float = -1.0
25
+ color_activation: str = "sigmoid"
26
+ num_samples_per_ray: int = 128
27
+ randomized: bool = False
28
+
29
+ cfg: Config
30
+
31
+ def configure(self) -> None:
32
+ assert self.cfg.feature_reduction in ["concat", "mean"]
33
+ self.chunk_size = 0
34
+
35
+ def set_chunk_size(self, chunk_size: int):
36
+ assert (
37
+ chunk_size >= 0
38
+ ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
+ self.chunk_size = chunk_size
40
+
41
+ def query_triplane(
42
+ self,
43
+ decoder: torch.nn.Module,
44
+ positions: torch.Tensor,
45
+ triplane: torch.Tensor,
46
+ ) -> Dict[str, torch.Tensor]:
47
+ input_shape = positions.shape[:-1]
48
+ positions = positions.view(-1, 3)
49
+
50
+ # positions in (-radius, radius)
51
+ # normalized to (-1, 1) for grid sample
52
+ positions = scale_tensor(
53
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
54
+ )
55
+
56
+ def _query_chunk(x):
57
+ indices2D: torch.Tensor = torch.stack(
58
+ (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
59
+ dim=-3,
60
+ )
61
+ out: torch.Tensor = F.grid_sample(
62
+ rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
63
+ rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
64
+ align_corners=False,
65
+ mode="bilinear",
66
+ )
67
+ if self.cfg.feature_reduction == "concat":
68
+ out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
69
+ elif self.cfg.feature_reduction == "mean":
70
+ out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ net_out: Dict[str, torch.Tensor] = decoder(out)
75
+ return net_out
76
+
77
+ if self.chunk_size > 0:
78
+ net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
79
+ else:
80
+ net_out = _query_chunk(positions)
81
+
82
+ net_out["density_act"] = get_activation(self.cfg.density_activation)(
83
+ net_out["density"] + self.cfg.density_bias
84
+ )
85
+ net_out["color"] = get_activation(self.cfg.color_activation)(
86
+ net_out["features"]
87
+ )
88
+
89
+ net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
90
+
91
+ return net_out
92
+
93
+ def _forward(
94
+ self,
95
+ decoder: torch.nn.Module,
96
+ triplane: torch.Tensor,
97
+ rays_o: torch.Tensor,
98
+ rays_d: torch.Tensor,
99
+ **kwargs,
100
+ ):
101
+ rays_shape = rays_o.shape[:-1]
102
+ rays_o = rays_o.view(-1, 3)
103
+ rays_d = rays_d.view(-1, 3)
104
+ n_rays = rays_o.shape[0]
105
+
106
+ t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
107
+ t_near, t_far = t_near[rays_valid], t_far[rays_valid]
108
+
109
+ t_vals = torch.linspace(
110
+ 0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
111
+ )
112
+ t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
113
+ z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
114
+
115
+ xyz = (
116
+ rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
117
+ ) # (N_rays, N_sample, 3)
118
+
119
+ mlp_out = self.query_triplane(
120
+ decoder=decoder,
121
+ positions=xyz,
122
+ triplane=triplane,
123
+ )
124
+
125
+ eps = 1e-10
126
+ # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
127
+ deltas = t_vals[1:] - t_vals[:-1] # (N_samples,), broadcast over rays
128
+ alpha = 1 - torch.exp(
129
+ -deltas * mlp_out["density_act"][..., 0]
130
+ ) # (N_rays, N_samples)
131
+ accum_prod = torch.cat(
132
+ [
133
+ torch.ones_like(alpha[:, :1]),
134
+ torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
135
+ ],
136
+ dim=-1,
137
+ )
138
+ weights = alpha * accum_prod # (N_rays, N_samples)
139
+ comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
140
+ opacity_ = weights.sum(dim=-1) # (N_rays)
141
+
142
+ comp_rgb = torch.zeros(
143
+ n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
144
+ )
145
+ opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
146
+ comp_rgb[rays_valid] = comp_rgb_
147
+ opacity[rays_valid] = opacity_
148
+
149
+ comp_rgb += 1 - opacity[..., None]
150
+ comp_rgb = comp_rgb.view(*rays_shape, 3)
151
+
152
+ return comp_rgb
153
+
154
+ def forward(
155
+ self,
156
+ decoder: torch.nn.Module,
157
+ triplane: torch.Tensor,
158
+ rays_o: torch.Tensor,
159
+ rays_d: torch.Tensor,
160
+ ) -> Dict[str, torch.Tensor]:
161
+ if triplane.ndim == 4:
162
+ comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
163
+ else:
164
+ comp_rgb = torch.stack(
165
+ [
166
+ self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
167
+ for i in range(triplane.shape[0])
168
+ ],
169
+ dim=0,
170
+ )
171
+
172
+ return comp_rgb
173
+
174
+ def train(self, mode=True):
175
+ self.randomized = mode and self.cfg.randomized
176
+ return super().train(mode=mode)
177
+
178
+ def eval(self):
179
+ self.randomized = False
180
+ return super().eval()
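For reference, `_forward` above is standard emission-absorption compositing over a white background; in the code's notation ($\sigma_i$ = `density_act`, $\delta_i$ = `deltas`, $c_i$ = `color`):

$$\alpha_i = 1 - e^{-\sigma_i \delta_i}, \qquad w_i = \alpha_i \prod_{j<i} (1-\alpha_j), \qquad C = \sum_i w_i\,c_i + \Bigl(1 - \sum_i w_i\Bigr),$$

where the trailing term is the `comp_rgb += 1 - opacity` line compositing the accumulated colour onto white.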
tsr/models/network_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from ..utils import BaseModule
9
+
10
+
11
+ class TriplaneUpsampleNetwork(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ in_channels: int
15
+ out_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.upsample = nn.ConvTranspose2d(
21
+ self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
22
+ )
23
+
24
+ def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
25
+ triplanes_up = rearrange(
26
+ self.upsample(
27
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
28
+ ),
29
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
30
+ Np=3,
31
+ )
32
+ return triplanes_up
33
+
34
+
35
+ class NeRFMLP(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ in_channels: int
39
+ n_neurons: int
40
+ n_hidden_layers: int
41
+ activation: str = "relu"
42
+ bias: bool = True
43
+ weight_init: Optional[str] = "kaiming_uniform"
44
+ bias_init: Optional[str] = None
45
+
46
+ cfg: Config
47
+
48
+ def configure(self) -> None:
49
+ layers = [
50
+ self.make_linear(
51
+ self.cfg.in_channels,
52
+ self.cfg.n_neurons,
53
+ bias=self.cfg.bias,
54
+ weight_init=self.cfg.weight_init,
55
+ bias_init=self.cfg.bias_init,
56
+ ),
57
+ self.make_activation(self.cfg.activation),
58
+ ]
59
+ for i in range(self.cfg.n_hidden_layers - 1):
60
+ layers += [
61
+ self.make_linear(
62
+ self.cfg.n_neurons,
63
+ self.cfg.n_neurons,
64
+ bias=self.cfg.bias,
65
+ weight_init=self.cfg.weight_init,
66
+ bias_init=self.cfg.bias_init,
67
+ ),
68
+ self.make_activation(self.cfg.activation),
69
+ ]
70
+ layers += [
71
+ self.make_linear(
72
+ self.cfg.n_neurons,
73
+ 4, # density 1 + features 3
74
+ bias=self.cfg.bias,
75
+ weight_init=self.cfg.weight_init,
76
+ bias_init=self.cfg.bias_init,
77
+ )
78
+ ]
79
+ self.layers = nn.Sequential(*layers)
80
+
81
+ def make_linear(
82
+ self,
83
+ dim_in,
84
+ dim_out,
85
+ bias=True,
86
+ weight_init=None,
87
+ bias_init=None,
88
+ ):
89
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
90
+
91
+ if weight_init is None:
92
+ pass
93
+ elif weight_init == "kaiming_uniform":
94
+ torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ if bias:
99
+ if bias_init is None:
100
+ pass
101
+ elif bias_init == "zero":
102
+ torch.nn.init.zeros_(layer.bias)
103
+ else:
104
+ raise NotImplementedError
105
+
106
+ return layer
107
+
108
+ def make_activation(self, activation):
109
+ if activation == "relu":
110
+ return nn.ReLU(inplace=True)
111
+ elif activation == "silu":
112
+ return nn.SiLU(inplace=True)
113
+ else:
114
+ raise NotImplementedError
115
+
116
+ def forward(self, x):
117
+ inp_shape = x.shape[:-1]
118
+ x = x.reshape(-1, x.shape[-1])
119
+
120
+ features = self.layers(x)
121
+ features = features.reshape(*inp_shape, -1)
122
+ out = {"density": features[..., 0:1], "features": features[..., 1:4]}
123
+
124
+ return out
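A hypothetical instantiation of `NeRFMLP` (not part of the commit; the config values are illustrative rather than the pretrained checkpoint's, and it assumes `BaseModule` accepts a plain dict config as elsewhere in the repo):

```python
import torch

from tsr.models.network_utils import NeRFMLP

mlp = NeRFMLP({"in_channels": 120, "n_neurons": 64, "n_hidden_layers": 2})

feats = torch.randn(4096, 120)    # one triplane feature vector per sampled 3D point
out = mlp(feats)
print(out["density"].shape, out["features"].shape)   # (4096, 1) and (4096, 3)
```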
tsr/models/tokenizers/__pycache__/image.cpython-310.pyc ADDED
Binary file (2.42 kB). View file
 
tsr/models/tokenizers/__pycache__/image.cpython-38.pyc ADDED
Binary file (2.39 kB). View file
 
tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc ADDED
Binary file (1.79 kB). View file
 
tsr/models/tokenizers/__pycache__/triplane.cpython-38.pyc ADDED
Binary file (1.77 kB). View file
 
tsr/models/tokenizers/image.py ADDED
@@ -0,0 +1,66 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers.models.vit.modeling_vit import ViTModel
8
+
9
+ from ...utils import BaseModule
10
+
11
+
12
+ class DINOSingleImageTokenizer(BaseModule):
13
+ @dataclass
14
+ class Config(BaseModule.Config):
15
+ pretrained_model_name_or_path: str = "facebook/dino-vitb16"
16
+ enable_gradient_checkpointing: bool = False
17
+
18
+ cfg: Config
19
+
20
+ def configure(self) -> None:
21
+ self.model: ViTModel = ViTModel(
22
+ ViTModel.config_class.from_pretrained(
23
+ hf_hub_download(
24
+ repo_id=self.cfg.pretrained_model_name_or_path,
25
+ filename="config.json",
26
+ )
27
+ )
28
+ )
29
+
30
+ if self.cfg.enable_gradient_checkpointing:
31
+ self.model.encoder.gradient_checkpointing = True
32
+
33
+ self.register_buffer(
34
+ "image_mean",
35
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
36
+ persistent=False,
37
+ )
38
+ self.register_buffer(
39
+ "image_std",
40
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
41
+ persistent=False,
42
+ )
43
+
44
+ def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
45
+ packed = False
46
+ if images.ndim == 4:
47
+ packed = True
48
+ images = images.unsqueeze(1)
49
+
50
+ batch_size, n_input_views = images.shape[:2]
51
+ images = (images - self.image_mean) / self.image_std
52
+ out = self.model(
53
+ rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
54
+ )
55
+ local_features, global_features = out.last_hidden_state, out.pooler_output
56
+ local_features = local_features.permute(0, 2, 1)
57
+ local_features = rearrange(
58
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
59
+ )
60
+ if packed:
61
+ local_features = local_features.squeeze(1)
62
+
63
+ return local_features
64
+
65
+ def detokenize(self, *args, **kwargs):
66
+ raise NotImplementedError
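A hypothetical usage sketch for the image tokenizer above (not part of the commit). `configure()` builds the ViT from its downloaded config only, so the weights are random here; the pretrained weights arrive with the full TSR checkpoint loaded by the surrounding system:

```python
import torch

from tsr.models.tokenizers.image import DINOSingleImageTokenizer

tokenizer = DINOSingleImageTokenizer({})   # assumption: an empty cfg falls back to facebook/dino-vitb16
images = torch.rand(1, 3, 512, 512)        # (B, C, H, W), values in [0, 1]
tokens = tokenizer(images)                 # (B, Ct, Nt): feature channels x patch tokens
```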
tsr/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,45 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from ...utils import BaseModule
9
+
10
+
11
+ class Triplane1DTokenizer(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ plane_size: int
15
+ num_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.embeddings = nn.Parameter(
21
+ torch.randn(
22
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
23
+ dtype=torch.float32,
24
+ )
25
+ * 1
26
+ / math.sqrt(self.cfg.num_channels)
27
+ )
28
+
29
+ def forward(self, batch_size: int) -> torch.Tensor:
30
+ return rearrange(
31
+ repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
32
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
33
+ )
34
+
35
+ def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
36
+ batch_size, Ct, Nt = tokens.shape
37
+ assert Nt == self.cfg.plane_size**2 * 3
38
+ assert Ct == self.cfg.num_channels
39
+ return rearrange(
40
+ tokens,
41
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
42
+ Np=3,
43
+ Hp=self.cfg.plane_size,
44
+ Wp=self.cfg.plane_size,
45
+ )
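A small round-trip sketch for `Triplane1DTokenizer` (not part of the commit; the plane size and channel count are illustrative):

```python
import torch

from tsr.models.tokenizers.triplane import Triplane1DTokenizer

tok = Triplane1DTokenizer({"plane_size": 32, "num_channels": 16})

tokens = tok(batch_size=2)         # (2, 16, 3 * 32 * 32): learned queries flattened to a 1D sequence
planes = tok.detokenize(tokens)    # (2, 3, 16, 32, 32): back to three feature planes
```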
tsr/models/transformer/__pycache__/attention.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
tsr/models/transformer/__pycache__/attention.cpython-38.pyc ADDED
Binary file (15.2 kB). View file
 
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc ADDED
Binary file (9.65 kB). View file
 
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-38.pyc ADDED
Binary file (9.49 kB). View file
 
tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc ADDED
Binary file (4.91 kB). View file
 
tsr/models/transformer/__pycache__/transformer_1d.cpython-38.pyc ADDED
Binary file (4.85 kB). View file
 
tsr/models/transformer/attention.py ADDED
@@ -0,0 +1,653 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+
46
+ class Attention(nn.Module):
47
+ r"""
48
+ A cross attention layer.
49
+
50
+ Parameters:
51
+ query_dim (`int`):
52
+ The number of channels in the query.
53
+ cross_attention_dim (`int`, *optional*):
54
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
55
+ heads (`int`, *optional*, defaults to 8):
56
+ The number of heads to use for multi-head attention.
57
+ dim_head (`int`, *optional*, defaults to 64):
58
+ The number of channels in each head.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout probability to use.
61
+ bias (`bool`, *optional*, defaults to False):
62
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
63
+ upcast_attention (`bool`, *optional*, defaults to False):
64
+ Set to `True` to upcast the attention computation to `float32`.
65
+ upcast_softmax (`bool`, *optional*, defaults to False):
66
+ Set to `True` to upcast the softmax computation to `float32`.
67
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
68
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
69
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
70
+ The number of groups to use for the group norm in the cross attention.
71
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
72
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
73
+ norm_num_groups (`int`, *optional*, defaults to `None`):
74
+ The number of groups to use for the group norm in the attention.
75
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
76
+ The number of channels to use for the spatial normalization.
77
+ out_bias (`bool`, *optional*, defaults to `True`):
78
+ Set to `True` to use a bias in the output linear layer.
79
+ scale_qk (`bool`, *optional*, defaults to `True`):
80
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
81
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
82
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
83
+ `added_kv_proj_dim` is not `None`.
84
+ eps (`float`, *optional*, defaults to 1e-5):
85
+ An additional value added to the denominator in group normalization that is used for numerical stability.
86
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
87
+ A factor to rescale the output by dividing it with this value.
88
+ residual_connection (`bool`, *optional*, defaults to `False`):
89
+ Set to `True` to add the residual connection to the output.
90
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
91
+ Set to `True` if the attention block is loaded from a deprecated state dict.
92
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
93
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
94
+ `AttnProcessor` otherwise.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ query_dim: int,
100
+ cross_attention_dim: Optional[int] = None,
101
+ heads: int = 8,
102
+ dim_head: int = 64,
103
+ dropout: float = 0.0,
104
+ bias: bool = False,
105
+ upcast_attention: bool = False,
106
+ upcast_softmax: bool = False,
107
+ cross_attention_norm: Optional[str] = None,
108
+ cross_attention_norm_num_groups: int = 32,
109
+ added_kv_proj_dim: Optional[int] = None,
110
+ norm_num_groups: Optional[int] = None,
111
+ out_bias: bool = True,
112
+ scale_qk: bool = True,
113
+ only_cross_attention: bool = False,
114
+ eps: float = 1e-5,
115
+ rescale_output_factor: float = 1.0,
116
+ residual_connection: bool = False,
117
+ _from_deprecated_attn_block: bool = False,
118
+ processor: Optional["AttnProcessor"] = None,
119
+ out_dim: int = None,
120
+ ):
121
+ super().__init__()
122
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
123
+ self.query_dim = query_dim
124
+ self.cross_attention_dim = (
125
+ cross_attention_dim if cross_attention_dim is not None else query_dim
126
+ )
127
+ self.upcast_attention = upcast_attention
128
+ self.upcast_softmax = upcast_softmax
129
+ self.rescale_output_factor = rescale_output_factor
130
+ self.residual_connection = residual_connection
131
+ self.dropout = dropout
132
+ self.fused_projections = False
133
+ self.out_dim = out_dim if out_dim is not None else query_dim
134
+
135
+ # we make use of this private variable to know whether this class is loaded
136
+ # with an deprecated state dict so that we can convert it on the fly
137
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
138
+
139
+ self.scale_qk = scale_qk
140
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
141
+
142
+ self.heads = out_dim // dim_head if out_dim is not None else heads
143
+ # for slice_size > 0 the attention score computation
144
+ # is split across the batch axis to save memory
145
+ # You can set slice_size with `set_attention_slice`
146
+ self.sliceable_head_dim = heads
147
+
148
+ self.added_kv_proj_dim = added_kv_proj_dim
149
+ self.only_cross_attention = only_cross_attention
150
+
151
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
152
+ raise ValueError(
153
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
154
+ )
155
+
156
+ if norm_num_groups is not None:
157
+ self.group_norm = nn.GroupNorm(
158
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.group_norm = None
162
+
163
+ self.spatial_norm = None
164
+
165
+ if cross_attention_norm is None:
166
+ self.norm_cross = None
167
+ elif cross_attention_norm == "layer_norm":
168
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
169
+ elif cross_attention_norm == "group_norm":
170
+ if self.added_kv_proj_dim is not None:
171
+ # The given `encoder_hidden_states` are initially of shape
172
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
173
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
174
+ # before the projection, so we need to use `added_kv_proj_dim` as
175
+ # the number of channels for the group norm.
176
+ norm_cross_num_channels = added_kv_proj_dim
177
+ else:
178
+ norm_cross_num_channels = self.cross_attention_dim
179
+
180
+ self.norm_cross = nn.GroupNorm(
181
+ num_channels=norm_cross_num_channels,
182
+ num_groups=cross_attention_norm_num_groups,
183
+ eps=1e-5,
184
+ affine=True,
185
+ )
186
+ else:
187
+ raise ValueError(
188
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
189
+ )
190
+
191
+ linear_cls = nn.Linear
192
+
193
+ self.linear_cls = linear_cls
194
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
195
+
196
+ if not self.only_cross_attention:
197
+ # only relevant for the `AddedKVProcessor` classes
198
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
199
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
200
+ else:
201
+ self.to_k = None
202
+ self.to_v = None
203
+
204
+ if self.added_kv_proj_dim is not None:
205
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
206
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
207
+
208
+ self.to_out = nn.ModuleList([])
209
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
210
+ self.to_out.append(nn.Dropout(dropout))
211
+
212
+ # set attention processor
213
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
214
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
215
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
216
+ if processor is None:
217
+ processor = (
218
+ AttnProcessor2_0()
219
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
220
+ else AttnProcessor()
221
+ )
222
+ self.set_processor(processor)
223
+
224
+ def set_processor(self, processor: "AttnProcessor") -> None:
225
+ self.processor = processor
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states: torch.FloatTensor,
230
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
231
+ attention_mask: Optional[torch.FloatTensor] = None,
232
+ **cross_attention_kwargs,
233
+ ) -> torch.Tensor:
234
+ r"""
235
+ The forward method of the `Attention` class.
236
+
237
+ Args:
238
+ hidden_states (`torch.Tensor`):
239
+ The hidden states of the query.
240
+ encoder_hidden_states (`torch.Tensor`, *optional*):
241
+ The hidden states of the encoder.
242
+ attention_mask (`torch.Tensor`, *optional*):
243
+ The attention mask to use. If `None`, no mask is applied.
244
+ **cross_attention_kwargs:
245
+ Additional keyword arguments to pass along to the cross attention.
246
+
247
+ Returns:
248
+ `torch.Tensor`: The output of the attention layer.
249
+ """
250
+ # The `Attention` class can call different attention processors / attention functions
251
+ # here we simply pass along all tensors to the selected processor class
252
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
253
+ return self.processor(
254
+ self,
255
+ hidden_states,
256
+ encoder_hidden_states=encoder_hidden_states,
257
+ attention_mask=attention_mask,
258
+ **cross_attention_kwargs,
259
+ )
260
+
261
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
262
+ r"""
263
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
264
+ is the number of heads initialized while constructing the `Attention` class.
265
+
266
+ Args:
267
+ tensor (`torch.Tensor`): The tensor to reshape.
268
+
269
+ Returns:
270
+ `torch.Tensor`: The reshaped tensor.
271
+ """
272
+ head_size = self.heads
273
+ batch_size, seq_len, dim = tensor.shape
274
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
275
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
276
+ batch_size // head_size, seq_len, dim * head_size
277
+ )
278
+ return tensor
279
+
280
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
281
+ r"""
282
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
283
+ the number of heads initialized while constructing the `Attention` class.
284
+
285
+ Args:
286
+ tensor (`torch.Tensor`): The tensor to reshape.
287
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
288
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
289
+
290
+ Returns:
291
+ `torch.Tensor`: The reshaped tensor.
292
+ """
293
+ head_size = self.heads
294
+ batch_size, seq_len, dim = tensor.shape
295
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
296
+ tensor = tensor.permute(0, 2, 1, 3)
297
+
298
+ if out_dim == 3:
299
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
300
+
301
+ return tensor
302
+
303
+ def get_attention_scores(
304
+ self,
305
+ query: torch.Tensor,
306
+ key: torch.Tensor,
307
+ attention_mask: torch.Tensor = None,
308
+ ) -> torch.Tensor:
309
+ r"""
310
+ Compute the attention scores.
311
+
312
+ Args:
313
+ query (`torch.Tensor`): The query tensor.
314
+ key (`torch.Tensor`): The key tensor.
315
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
316
+
317
+ Returns:
318
+ `torch.Tensor`: The attention probabilities/scores.
319
+ """
320
+ dtype = query.dtype
321
+ if self.upcast_attention:
322
+ query = query.float()
323
+ key = key.float()
324
+
325
+ if attention_mask is None:
326
+ baddbmm_input = torch.empty(
327
+ query.shape[0],
328
+ query.shape[1],
329
+ key.shape[1],
330
+ dtype=query.dtype,
331
+ device=query.device,
332
+ )
333
+ beta = 0
334
+ else:
335
+ baddbmm_input = attention_mask
336
+ beta = 1
337
+
338
+ attention_scores = torch.baddbmm(
339
+ baddbmm_input,
340
+ query,
341
+ key.transpose(-1, -2),
342
+ beta=beta,
343
+ alpha=self.scale,
344
+ )
345
+ del baddbmm_input
346
+
347
+ if self.upcast_softmax:
348
+ attention_scores = attention_scores.float()
349
+
350
+ attention_probs = attention_scores.softmax(dim=-1)
351
+ del attention_scores
352
+
353
+ attention_probs = attention_probs.to(dtype)
354
+
355
+ return attention_probs
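For reference, `get_attention_scores` above computes the masked, scaled dot-product attention probabilities with a single fused `baddbmm`:

$$P = \operatorname{softmax}\!\Bigl(M + \frac{Q K^{\top}}{\sqrt{d_{\text{head}}}}\Bigr),$$

with $M = 0$ when no mask is passed, the scale replaced by 1 if `scale_qk=False`, and an optional upcast to float32 before the softmax.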
356
+
357
+ def prepare_attention_mask(
358
+ self,
359
+ attention_mask: torch.Tensor,
360
+ target_length: int,
361
+ batch_size: int,
362
+ out_dim: int = 3,
363
+ ) -> torch.Tensor:
364
+ r"""
365
+ Prepare the attention mask for the attention computation.
366
+
367
+ Args:
368
+ attention_mask (`torch.Tensor`):
369
+ The attention mask to prepare.
370
+ target_length (`int`):
371
+ The target length of the attention mask. This is the length of the attention mask after padding.
372
+ batch_size (`int`):
373
+ The batch size, which is used to repeat the attention mask.
374
+ out_dim (`int`, *optional*, defaults to `3`):
375
+ The output dimension of the attention mask. Can be either `3` or `4`.
376
+
377
+ Returns:
378
+ `torch.Tensor`: The prepared attention mask.
379
+ """
380
+ head_size = self.heads
381
+ if attention_mask is None:
382
+ return attention_mask
383
+
384
+ current_length: int = attention_mask.shape[-1]
385
+ if current_length != target_length:
386
+ if attention_mask.device.type == "mps":
387
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
388
+ # Instead, we can manually construct the padding tensor.
389
+ padding_shape = (
390
+ attention_mask.shape[0],
391
+ attention_mask.shape[1],
392
+ target_length,
393
+ )
394
+ padding = torch.zeros(
395
+ padding_shape,
396
+ dtype=attention_mask.dtype,
397
+ device=attention_mask.device,
398
+ )
399
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
400
+ else:
401
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
402
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
403
+ # remaining_length: int = target_length - current_length
404
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
405
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
406
+
407
+ if out_dim == 3:
408
+ if attention_mask.shape[0] < batch_size * head_size:
409
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
410
+ elif out_dim == 4:
411
+ attention_mask = attention_mask.unsqueeze(1)
412
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
413
+
414
+ return attention_mask
415
+
416
+ def norm_encoder_hidden_states(
417
+ self, encoder_hidden_states: torch.Tensor
418
+ ) -> torch.Tensor:
419
+ r"""
420
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
421
+ `Attention` class.
422
+
423
+ Args:
424
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
425
+
426
+ Returns:
427
+ `torch.Tensor`: The normalized encoder hidden states.
428
+ """
429
+ assert (
430
+ self.norm_cross is not None
431
+ ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
432
+
433
+ if isinstance(self.norm_cross, nn.LayerNorm):
434
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
435
+ elif isinstance(self.norm_cross, nn.GroupNorm):
436
+ # Group norm norms along the channels dimension and expects
437
+ # input to be in the shape of (N, C, *). In this case, we want
438
+ # to norm along the hidden dimension, so we need to move
439
+ # (batch_size, sequence_length, hidden_size) ->
440
+ # (batch_size, hidden_size, sequence_length)
441
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
442
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
443
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
444
+ else:
445
+ assert False
446
+
447
+ return encoder_hidden_states
448
+
449
+ @torch.no_grad()
450
+ def fuse_projections(self, fuse=True):
451
+ is_cross_attention = self.cross_attention_dim != self.query_dim
452
+ device = self.to_q.weight.data.device
453
+ dtype = self.to_q.weight.data.dtype
454
+
455
+ if not is_cross_attention:
456
+ # fetch weight matrices.
457
+ concatenated_weights = torch.cat(
458
+ [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
459
+ )
460
+ in_features = concatenated_weights.shape[1]
461
+ out_features = concatenated_weights.shape[0]
462
+
463
+ # create a new single projection layer and copy over the weights.
464
+ self.to_qkv = self.linear_cls(
465
+ in_features, out_features, bias=False, device=device, dtype=dtype
466
+ )
467
+ self.to_qkv.weight.copy_(concatenated_weights)
468
+
469
+ else:
470
+ concatenated_weights = torch.cat(
471
+ [self.to_k.weight.data, self.to_v.weight.data]
472
+ )
473
+ in_features = concatenated_weights.shape[1]
474
+ out_features = concatenated_weights.shape[0]
475
+
476
+ self.to_kv = self.linear_cls(
477
+ in_features, out_features, bias=False, device=device, dtype=dtype
478
+ )
479
+ self.to_kv.weight.copy_(concatenated_weights)
480
+
481
+ self.fused_projections = fuse
482
+
483
+
484
+ class AttnProcessor:
485
+ r"""
486
+ Default processor for performing attention-related computations.
487
+ """
488
+
489
+ def __call__(
490
+ self,
491
+ attn: Attention,
492
+ hidden_states: torch.FloatTensor,
493
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
494
+ attention_mask: Optional[torch.FloatTensor] = None,
495
+ ) -> torch.Tensor:
496
+ residual = hidden_states
497
+
498
+ input_ndim = hidden_states.ndim
499
+
500
+ if input_ndim == 4:
501
+ batch_size, channel, height, width = hidden_states.shape
502
+ hidden_states = hidden_states.view(
503
+ batch_size, channel, height * width
504
+ ).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape
508
+ if encoder_hidden_states is None
509
+ else encoder_hidden_states.shape
510
+ )
511
+ attention_mask = attn.prepare_attention_mask(
512
+ attention_mask, sequence_length, batch_size
513
+ )
514
+
515
+ if attn.group_norm is not None:
516
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
517
+ 1, 2
518
+ )
519
+
520
+ query = attn.to_q(hidden_states)
521
+
522
+ if encoder_hidden_states is None:
523
+ encoder_hidden_states = hidden_states
524
+ elif attn.norm_cross:
525
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
526
+ encoder_hidden_states
527
+ )
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ query = attn.head_to_batch_dim(query)
533
+ key = attn.head_to_batch_dim(key)
534
+ value = attn.head_to_batch_dim(value)
535
+
536
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
537
+ hidden_states = torch.bmm(attention_probs, value)
538
+ hidden_states = attn.batch_to_head_dim(hidden_states)
539
+
540
+ # linear proj
541
+ hidden_states = attn.to_out[0](hidden_states)
542
+ # dropout
543
+ hidden_states = attn.to_out[1](hidden_states)
544
+
545
+ if input_ndim == 4:
546
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
547
+ batch_size, channel, height, width
548
+ )
549
+
550
+ if attn.residual_connection:
551
+ hidden_states = hidden_states + residual
552
+
553
+ hidden_states = hidden_states / attn.rescale_output_factor
554
+
555
+ return hidden_states
556
+
557
+
558
+ class AttnProcessor2_0:
559
+ r"""
560
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
561
+ """
562
+
563
+ def __init__(self):
564
+ if not hasattr(F, "scaled_dot_product_attention"):
565
+ raise ImportError(
566
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
567
+ )
568
+
569
+ def __call__(
570
+ self,
571
+ attn: Attention,
572
+ hidden_states: torch.FloatTensor,
573
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
574
+ attention_mask: Optional[torch.FloatTensor] = None,
575
+ ) -> torch.FloatTensor:
576
+ residual = hidden_states
577
+
578
+ input_ndim = hidden_states.ndim
579
+
580
+ if input_ndim == 4:
581
+ batch_size, channel, height, width = hidden_states.shape
582
+ hidden_states = hidden_states.view(
583
+ batch_size, channel, height * width
584
+ ).transpose(1, 2)
585
+
586
+ batch_size, sequence_length, _ = (
587
+ hidden_states.shape
588
+ if encoder_hidden_states is None
589
+ else encoder_hidden_states.shape
590
+ )
591
+
592
+ if attention_mask is not None:
593
+ attention_mask = attn.prepare_attention_mask(
594
+ attention_mask, sequence_length, batch_size
595
+ )
596
+ # scaled_dot_product_attention expects attention_mask shape to be
597
+ # (batch, heads, source_length, target_length)
598
+ attention_mask = attention_mask.view(
599
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
600
+ )
601
+
602
+ if attn.group_norm is not None:
603
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
604
+ 1, 2
605
+ )
606
+
607
+ query = attn.to_q(hidden_states)
608
+
609
+ if encoder_hidden_states is None:
610
+ encoder_hidden_states = hidden_states
611
+ elif attn.norm_cross:
612
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
613
+ encoder_hidden_states
614
+ )
615
+
616
+ key = attn.to_k(encoder_hidden_states)
617
+ value = attn.to_v(encoder_hidden_states)
618
+
619
+ inner_dim = key.shape[-1]
620
+ head_dim = inner_dim // attn.heads
621
+
622
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
623
+
624
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
625
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
626
+
627
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
628
+ # TODO: add support for attn.scale when we move to Torch 2.1
629
+ hidden_states = F.scaled_dot_product_attention(
630
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
631
+ )
632
+
633
+ hidden_states = hidden_states.transpose(1, 2).reshape(
634
+ batch_size, -1, attn.heads * head_dim
635
+ )
636
+ hidden_states = hidden_states.to(query.dtype)
637
+
638
+ # linear proj
639
+ hidden_states = attn.to_out[0](hidden_states)
640
+ # dropout
641
+ hidden_states = attn.to_out[1](hidden_states)
642
+
643
+ if input_ndim == 4:
644
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
645
+ batch_size, channel, height, width
646
+ )
647
+
648
+ if attn.residual_connection:
649
+ hidden_states = hidden_states + residual
650
+
651
+ hidden_states = hidden_states / attn.rescale_output_factor
652
+
653
+ return hidden_states
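A hypothetical self-attention smoke test for the `Attention` layer above (not part of the commit; on PyTorch 2.x it dispatches to `AttnProcessor2_0`):

```python
import torch

from tsr.models.transformer.attention import Attention

attn = Attention(query_dim=512, heads=8, dim_head=64)

tokens = torch.randn(2, 1025, 512)   # (batch, sequence length, channels)
out = attn(tokens)                   # self-attention keeps the input shape
assert out.shape == tokens.shape
```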
tsr/models/transformer/basic_transformer_block.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+ from .attention import Attention
46
+
47
+
48
+ class BasicTransformerBlock(nn.Module):
49
+ r"""
50
+ A basic Transformer block.
51
+
52
+ Parameters:
53
+ dim (`int`): The number of channels in the input and output.
54
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
55
+ attention_head_dim (`int`): The number of channels in each head.
56
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
57
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
58
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
59
+ attention_bias (:
60
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
61
+ only_cross_attention (`bool`, *optional*):
62
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
63
+ double_self_attention (`bool`, *optional*):
64
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
65
+ upcast_attention (`bool`, *optional*):
66
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
67
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
68
+ Whether to use learnable elementwise affine parameters for normalization.
69
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
70
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
71
+ final_dropout (`bool` *optional*, defaults to False):
72
+ Whether to apply a final dropout after the last feed-forward layer.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ dim: int,
78
+ num_attention_heads: int,
79
+ attention_head_dim: int,
80
+ dropout=0.0,
81
+ cross_attention_dim: Optional[int] = None,
82
+ activation_fn: str = "geglu",
83
+ attention_bias: bool = False,
84
+ only_cross_attention: bool = False,
85
+ double_self_attention: bool = False,
86
+ upcast_attention: bool = False,
87
+ norm_elementwise_affine: bool = True,
88
+ norm_type: str = "layer_norm",
89
+ final_dropout: bool = False,
90
+ ):
91
+ super().__init__()
92
+ self.only_cross_attention = only_cross_attention
93
+
94
+ assert norm_type == "layer_norm"
95
+
96
+ # Define 3 blocks. Each block has its own normalization layer.
97
+ # 1. Self-Attn
98
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
99
+ self.attn1 = Attention(
100
+ query_dim=dim,
101
+ heads=num_attention_heads,
102
+ dim_head=attention_head_dim,
103
+ dropout=dropout,
104
+ bias=attention_bias,
105
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
106
+ upcast_attention=upcast_attention,
107
+ )
108
+
109
+ # 2. Cross-Attn
110
+ if cross_attention_dim is not None or double_self_attention:
111
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
112
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
113
+ # the second cross attention block.
114
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
115
+
116
+ self.attn2 = Attention(
117
+ query_dim=dim,
118
+ cross_attention_dim=(
119
+ cross_attention_dim if not double_self_attention else None
120
+ ),
121
+ heads=num_attention_heads,
122
+ dim_head=attention_head_dim,
123
+ dropout=dropout,
124
+ bias=attention_bias,
125
+ upcast_attention=upcast_attention,
126
+ ) # is self-attn if encoder_hidden_states is none
127
+ else:
128
+ self.norm2 = None
129
+ self.attn2 = None
130
+
131
+ # 3. Feed-forward
132
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
133
+ self.ff = FeedForward(
134
+ dim,
135
+ dropout=dropout,
136
+ activation_fn=activation_fn,
137
+ final_dropout=final_dropout,
138
+ )
139
+
140
+ # let chunk size default to None
141
+ self._chunk_size = None
142
+ self._chunk_dim = 0
143
+
144
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
145
+ # Sets chunk feed-forward
146
+ self._chunk_size = chunk_size
147
+ self._chunk_dim = dim
148
+
149
+ def forward(
150
+ self,
151
+ hidden_states: torch.FloatTensor,
152
+ attention_mask: Optional[torch.FloatTensor] = None,
153
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
154
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
155
+ ) -> torch.FloatTensor:
156
+ # Notice that normalization is always applied before the real computation in the following blocks.
157
+ # 0. Self-Attention
158
+ norm_hidden_states = self.norm1(hidden_states)
159
+
160
+ attn_output = self.attn1(
161
+ norm_hidden_states,
162
+ encoder_hidden_states=(
163
+ encoder_hidden_states if self.only_cross_attention else None
164
+ ),
165
+ attention_mask=attention_mask,
166
+ )
167
+
168
+ hidden_states = attn_output + hidden_states
169
+
170
+ # 3. Cross-Attention
171
+ if self.attn2 is not None:
172
+ norm_hidden_states = self.norm2(hidden_states)
173
+
174
+ attn_output = self.attn2(
175
+ norm_hidden_states,
176
+ encoder_hidden_states=encoder_hidden_states,
177
+ attention_mask=encoder_attention_mask,
178
+ )
179
+ hidden_states = attn_output + hidden_states
180
+
181
+ # 3. Feed-forward
182
+ norm_hidden_states = self.norm3(hidden_states)
183
+
184
+ if self._chunk_size is not None:
185
+ # "feed_forward_chunk_size" can be used to save memory
186
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
187
+ raise ValueError(
188
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
189
+ )
190
+
191
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
192
+ ff_output = torch.cat(
193
+ [
194
+ self.ff(hid_slice)
195
+ for hid_slice in norm_hidden_states.chunk(
196
+ num_chunks, dim=self._chunk_dim
197
+ )
198
+ ],
199
+ dim=self._chunk_dim,
200
+ )
201
+ else:
202
+ ff_output = self.ff(norm_hidden_states)
203
+
204
+ hidden_states = ff_output + hidden_states
205
+
206
+ return hidden_states
207
+
208
+
209
+ class FeedForward(nn.Module):
210
+ r"""
211
+ A feed-forward layer.
212
+
213
+ Parameters:
214
+ dim (`int`): The number of channels in the input.
215
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
216
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
217
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
218
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
219
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ dim: int,
225
+ dim_out: Optional[int] = None,
226
+ mult: int = 4,
227
+ dropout: float = 0.0,
228
+ activation_fn: str = "geglu",
229
+ final_dropout: bool = False,
230
+ ):
231
+ super().__init__()
232
+ inner_dim = int(dim * mult)
233
+ dim_out = dim_out if dim_out is not None else dim
234
+ linear_cls = nn.Linear
235
+
236
+ if activation_fn == "gelu":
237
+ act_fn = GELU(dim, inner_dim)
238
+ elif activation_fn == "gelu-approximate":
239
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
240
+ elif activation_fn == "geglu":
241
+ act_fn = GEGLU(dim, inner_dim)
242
+ elif activation_fn == "geglu-approximate":
243
+ act_fn = ApproximateGELU(dim, inner_dim)
244
+
245
+ self.net = nn.ModuleList([])
246
+ # project in
247
+ self.net.append(act_fn)
248
+ # project dropout
249
+ self.net.append(nn.Dropout(dropout))
250
+ # project out
251
+ self.net.append(linear_cls(inner_dim, dim_out))
252
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
253
+ if final_dropout:
254
+ self.net.append(nn.Dropout(dropout))
255
+
256
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
257
+ for module in self.net:
258
+ hidden_states = module(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class GELU(nn.Module):
263
+ r"""
264
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
265
+
266
+ Parameters:
267
+ dim_in (`int`): The number of channels in the input.
268
+ dim_out (`int`): The number of channels in the output.
269
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
270
+ """
271
+
272
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
273
+ super().__init__()
274
+ self.proj = nn.Linear(dim_in, dim_out)
275
+ self.approximate = approximate
276
+
277
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
278
+ if gate.device.type != "mps":
279
+ return F.gelu(gate, approximate=self.approximate)
280
+ # mps: gelu is not implemented for float16
281
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
282
+ dtype=gate.dtype
283
+ )
284
+
285
+ def forward(self, hidden_states):
286
+ hidden_states = self.proj(hidden_states)
287
+ hidden_states = self.gelu(hidden_states)
288
+ return hidden_states
289
+
290
+
291
+ class GEGLU(nn.Module):
292
+ r"""
293
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
294
+
295
+ Parameters:
296
+ dim_in (`int`): The number of channels in the input.
297
+ dim_out (`int`): The number of channels in the output.
298
+ """
299
+
300
+ def __init__(self, dim_in: int, dim_out: int):
301
+ super().__init__()
302
+ linear_cls = nn.Linear
303
+
304
+ self.proj = linear_cls(dim_in, dim_out * 2)
305
+
306
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
307
+ if gate.device.type != "mps":
308
+ return F.gelu(gate)
309
+ # mps: gelu is not implemented for float16
310
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
311
+
312
+ def forward(self, hidden_states, scale: float = 1.0):
313
+ args = ()
314
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
315
+ return hidden_states * self.gelu(gate)
316
+
317
+
318
+ class ApproximateGELU(nn.Module):
319
+ r"""
320
+ The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
321
+ https://arxiv.org/abs/1606.08415.
322
+
323
+ Parameters:
324
+ dim_in (`int`): The number of channels in the input.
325
+ dim_out (`int`): The number of channels in the output.
326
+ """
327
+
328
+ def __init__(self, dim_in: int, dim_out: int):
329
+ super().__init__()
330
+ self.proj = nn.Linear(dim_in, dim_out)
331
+
332
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
333
+ x = self.proj(x)
334
+ return x * torch.sigmoid(1.702 * x)
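The class above is a standard pre-norm transformer layer: self-attention, optional cross-attention against a conditioning sequence, then a feed-forward that can be chunked along the sequence dimension to save memory. A minimal usage sketch follows, assuming the tsr package from this commit is on the Python path; all sizes are illustrative, not values used by the model.

import torch
from tsr.models.transformer.basic_transformer_block import BasicTransformerBlock

# Hypothetical sizes, chosen only for the example.
block = BasicTransformerBlock(
    dim=512,                  # width of the token stream
    num_attention_heads=8,
    attention_head_dim=64,
    cross_attention_dim=768,  # width of the conditioning tokens
)
tokens = torch.randn(2, 1024, 512)  # (batch, sequence length, dim)
cond = torch.randn(2, 77, 768)      # conditioning sequence for cross-attention
out = block(tokens, encoder_hidden_states=cond)
print(out.shape)  # torch.Size([2, 1024, 512])

# Optionally trade speed for memory by running the feed-forward in chunks
# along the sequence dimension (1024 must be divisible by the chunk size).
block.set_chunk_feed_forward(chunk_size=256, dim=1)
out = block(tokens, encoder_hidden_states=cond)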
tsr/models/transformer/transformer_1d.py ADDED
@@ -0,0 +1,219 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from dataclasses import dataclass
40
+ from typing import Optional
41
+
42
+ import torch
43
+ import torch.nn.functional as F
44
+ from torch import nn
45
+
46
+ from ...utils import BaseModule
47
+ from .basic_transformer_block import BasicTransformerBlock
48
+
49
+
50
+ class Transformer1D(BaseModule):
51
+ @dataclass
52
+ class Config(BaseModule.Config):
53
+ num_attention_heads: int = 16
54
+ attention_head_dim: int = 88
55
+ in_channels: Optional[int] = None
56
+ out_channels: Optional[int] = None
57
+ num_layers: int = 1
58
+ dropout: float = 0.0
59
+ norm_num_groups: int = 32
60
+ cross_attention_dim: Optional[int] = None
61
+ attention_bias: bool = False
62
+ activation_fn: str = "geglu"
63
+ only_cross_attention: bool = False
64
+ double_self_attention: bool = False
65
+ upcast_attention: bool = False
66
+ norm_type: str = "layer_norm"
67
+ norm_elementwise_affine: bool = True
68
+ gradient_checkpointing: bool = False
69
+
70
+ cfg: Config
71
+
72
+ def configure(self) -> None:
73
+ self.num_attention_heads = self.cfg.num_attention_heads
74
+ self.attention_head_dim = self.cfg.attention_head_dim
75
+ inner_dim = self.num_attention_heads * self.attention_head_dim
76
+
77
+ linear_cls = nn.Linear
78
+
79
+ # 2. Define input layers
80
+ self.in_channels = self.cfg.in_channels
81
+
82
+ self.norm = torch.nn.GroupNorm(
83
+ num_groups=self.cfg.norm_num_groups,
84
+ num_channels=self.cfg.in_channels,
85
+ eps=1e-6,
86
+ affine=True,
87
+ )
88
+ self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
89
+
90
+ # 3. Define transformers blocks
91
+ self.transformer_blocks = nn.ModuleList(
92
+ [
93
+ BasicTransformerBlock(
94
+ inner_dim,
95
+ self.num_attention_heads,
96
+ self.attention_head_dim,
97
+ dropout=self.cfg.dropout,
98
+ cross_attention_dim=self.cfg.cross_attention_dim,
99
+ activation_fn=self.cfg.activation_fn,
100
+ attention_bias=self.cfg.attention_bias,
101
+ only_cross_attention=self.cfg.only_cross_attention,
102
+ double_self_attention=self.cfg.double_self_attention,
103
+ upcast_attention=self.cfg.upcast_attention,
104
+ norm_type=self.cfg.norm_type,
105
+ norm_elementwise_affine=self.cfg.norm_elementwise_affine,
106
+ )
107
+ for d in range(self.cfg.num_layers)
108
+ ]
109
+ )
110
+
111
+ # 4. Define output layers
112
+ self.out_channels = (
113
+ self.cfg.in_channels
114
+ if self.cfg.out_channels is None
115
+ else self.cfg.out_channels
116
+ )
117
+
118
+ self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
119
+
120
+ self.gradient_checkpointing = self.cfg.gradient_checkpointing
121
+
122
+ def forward(
123
+ self,
124
+ hidden_states: torch.Tensor,
125
+ encoder_hidden_states: Optional[torch.Tensor] = None,
126
+ attention_mask: Optional[torch.Tensor] = None,
127
+ encoder_attention_mask: Optional[torch.Tensor] = None,
128
+ ):
129
+ """
130
+ The [`Transformer1D`] forward method.
131
+
132
+ Args:
133
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channels, sequence length)`):
134
+ Input `hidden_states`.
135
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
136
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
137
+ self-attention.
138
+ attention_mask ( `torch.Tensor`, *optional*):
139
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
140
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
141
+ negative values to the attention scores corresponding to "discard" tokens.
142
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
143
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
144
+
145
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
146
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
147
+
148
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
149
+ above. This bias will be added to the cross-attention scores.
150
+
151
+ Returns:
152
+ torch.FloatTensor
153
+ """
154
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
155
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
156
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
157
+ # expects mask of shape:
158
+ # [batch, key_tokens]
159
+ # adds singleton query_tokens dimension:
160
+ # [batch, 1, key_tokens]
161
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
162
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
163
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
164
+ if attention_mask is not None and attention_mask.ndim == 2:
165
+ # assume that mask is expressed as:
166
+ # (1 = keep, 0 = discard)
167
+ # convert mask into a bias that can be added to attention scores:
168
+ # (keep = +0, discard = -10000.0)
169
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
170
+ attention_mask = attention_mask.unsqueeze(1)
171
+
172
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
173
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
174
+ encoder_attention_mask = (
175
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
176
+ ) * -10000.0
177
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
178
+
179
+ # 1. Input
180
+ batch, _, seq_len = hidden_states.shape
181
+ residual = hidden_states
182
+
183
+ hidden_states = self.norm(hidden_states)
184
+ inner_dim = hidden_states.shape[1]
185
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(
186
+ batch, seq_len, inner_dim
187
+ )
188
+ hidden_states = self.proj_in(hidden_states)
189
+
190
+ # 2. Blocks
191
+ for block in self.transformer_blocks:
192
+ if self.training and self.gradient_checkpointing:
193
+ hidden_states = torch.utils.checkpoint.checkpoint(
194
+ block,
195
+ hidden_states,
196
+ attention_mask,
197
+ encoder_hidden_states,
198
+ encoder_attention_mask,
199
+ use_reentrant=False,
200
+ )
201
+ else:
202
+ hidden_states = block(
203
+ hidden_states,
204
+ attention_mask=attention_mask,
205
+ encoder_hidden_states=encoder_hidden_states,
206
+ encoder_attention_mask=encoder_attention_mask,
207
+ )
208
+
209
+ # 3. Output
210
+ hidden_states = self.proj_out(hidden_states)
211
+ hidden_states = (
212
+ hidden_states.reshape(batch, seq_len, inner_dim)
213
+ .permute(0, 2, 1)
214
+ .contiguous()
215
+ )
216
+
217
+ output = hidden_states + residual
218
+
219
+ return output
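Transformer1D stacks these blocks behind a GroupNorm and a pair of linear projections, operating on (batch, channels, sequence length) tensors with a residual connection around the whole stack; 2-D attention masks are converted to additive biases before the blocks run. A minimal sketch, assuming the tsr package is importable; the config values below are made up for illustration, with the one real constraint that in_channels must be divisible by norm_num_groups (32 by default).

import torch
from tsr.models.transformer.transformer_1d import Transformer1D

# Hypothetical configuration, for illustration only.
backbone = Transformer1D(
    {
        "num_attention_heads": 16,
        "attention_head_dim": 64,    # inner width = 16 * 64 = 1024
        "in_channels": 1024,         # must be divisible by norm_num_groups (32)
        "num_layers": 2,
        "cross_attention_dim": 768,  # width of the conditioning tokens
    }
)
tokens = torch.randn(1, 1024, 256)  # (batch, channels, sequence length)
cond = torch.randn(1, 77, 768)      # conditioning sequence
out = backbone(tokens, encoder_hidden_states=cond)
print(out.shape)  # torch.Size([1, 1024, 256])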
tsr/system.py ADDED
@@ -0,0 +1,203 @@
1
+ import math
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import List, Union
5
+
6
+ import numpy as np
7
+ import PIL.Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import trimesh
11
+ from einops import rearrange
12
+ from huggingface_hub import hf_hub_download
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+
16
+ from .models.isosurface import MarchingCubeHelper
17
+ from .utils import (
18
+ BaseModule,
19
+ ImagePreprocessor,
20
+ find_class,
21
+ get_spherical_cameras,
22
+ scale_tensor,
23
+ )
24
+
25
+
26
+ class TSR(BaseModule):
27
+ @dataclass
28
+ class Config(BaseModule.Config):
29
+ cond_image_size: int
30
+
31
+ image_tokenizer_cls: str
32
+ image_tokenizer: dict
33
+
34
+ tokenizer_cls: str
35
+ tokenizer: dict
36
+
37
+ backbone_cls: str
38
+ backbone: dict
39
+
40
+ post_processor_cls: str
41
+ post_processor: dict
42
+
43
+ decoder_cls: str
44
+ decoder: dict
45
+
46
+ renderer_cls: str
47
+ renderer: dict
48
+
49
+ cfg: Config
50
+
51
+ @classmethod
52
+ def from_pretrained(
53
+ cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
54
+ ):
55
+ if os.path.isdir(pretrained_model_name_or_path):
56
+ config_path = os.path.join(pretrained_model_name_or_path, config_name)
57
+ weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
58
+ else:
59
+ config_path = hf_hub_download(
60
+ repo_id=pretrained_model_name_or_path, filename=config_name
61
+ )
62
+ weight_path = hf_hub_download(
63
+ repo_id=pretrained_model_name_or_path, filename=weight_name
64
+ )
65
+
66
+ cfg = OmegaConf.load(config_path)
67
+ OmegaConf.resolve(cfg)
68
+ model = cls(cfg)
69
+ ckpt = torch.load(weight_path, map_location="cpu")
70
+ model.load_state_dict(ckpt)
71
+ return model
72
+
73
+ def configure(self):
74
+ self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
75
+ self.cfg.image_tokenizer
76
+ )
77
+ self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
78
+ self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
79
+ self.post_processor = find_class(self.cfg.post_processor_cls)(
80
+ self.cfg.post_processor
81
+ )
82
+ self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
83
+ self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
84
+ self.image_processor = ImagePreprocessor()
85
+ self.isosurface_helper = None
86
+
87
+ def forward(
88
+ self,
89
+ image: Union[
90
+ PIL.Image.Image,
91
+ np.ndarray,
92
+ torch.FloatTensor,
93
+ List[PIL.Image.Image],
94
+ List[np.ndarray],
95
+ List[torch.FloatTensor],
96
+ ],
97
+ device: str,
98
+ ) -> torch.FloatTensor:
99
+ rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
100
+ device
101
+ )
102
+ batch_size = rgb_cond.shape[0]
103
+
104
+ input_image_tokens: torch.Tensor = self.image_tokenizer(
105
+ rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
106
+ )
107
+
108
+ input_image_tokens = rearrange(
109
+ input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
110
+ )
111
+
112
+ tokens: torch.Tensor = self.tokenizer(batch_size)
113
+
114
+ tokens = self.backbone(
115
+ tokens,
116
+ encoder_hidden_states=input_image_tokens,
117
+ )
118
+
119
+ scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
120
+ return scene_codes
121
+
122
+ def render(
123
+ self,
124
+ scene_codes,
125
+ n_views: int,
126
+ elevation_deg: float = 0.0,
127
+ camera_distance: float = 1.9,
128
+ fovy_deg: float = 40.0,
129
+ height: int = 256,
130
+ width: int = 256,
131
+ return_type: str = "pil",
132
+ ):
133
+ rays_o, rays_d = get_spherical_cameras(
134
+ n_views, elevation_deg, camera_distance, fovy_deg, height, width
135
+ )
136
+ rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
137
+
138
+ def process_output(image: torch.FloatTensor):
139
+ if return_type == "pt":
140
+ return image
141
+ elif return_type == "np":
142
+ return image.detach().cpu().numpy()
143
+ elif return_type == "pil":
144
+ return Image.fromarray(
145
+ (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
146
+ )
147
+ else:
148
+ raise NotImplementedError
149
+
150
+ images = []
151
+ for scene_code in scene_codes:
152
+ images_ = []
153
+ for i in range(n_views):
154
+ with torch.no_grad():
155
+ image = self.renderer(
156
+ self.decoder, scene_code, rays_o[i], rays_d[i]
157
+ )
158
+ images_.append(process_output(image))
159
+ images.append(images_)
160
+
161
+ return images
162
+
163
+ def set_marching_cubes_resolution(self, resolution: int):
164
+ if (
165
+ self.isosurface_helper is not None
166
+ and self.isosurface_helper.resolution == resolution
167
+ ):
168
+ return
169
+ self.isosurface_helper = MarchingCubeHelper(resolution)
170
+
171
+ def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
172
+ self.set_marching_cubes_resolution(resolution)
173
+ meshes = []
174
+ for scene_code in scene_codes:
175
+ with torch.no_grad():
176
+ density = self.renderer.query_triplane(
177
+ self.decoder,
178
+ scale_tensor(
179
+ self.isosurface_helper.grid_vertices.to(scene_codes.device),
180
+ self.isosurface_helper.points_range,
181
+ (-self.renderer.cfg.radius, self.renderer.cfg.radius),
182
+ ),
183
+ scene_code,
184
+ )["density_act"]
185
+ v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
186
+ v_pos = scale_tensor(
187
+ v_pos,
188
+ self.isosurface_helper.points_range,
189
+ (-self.renderer.cfg.radius, self.renderer.cfg.radius),
190
+ )
191
+ with torch.no_grad():
192
+ color = self.renderer.query_triplane(
193
+ self.decoder,
194
+ v_pos,
195
+ scene_code,
196
+ )["color"]
197
+ mesh = trimesh.Trimesh(
198
+ vertices=v_pos.cpu().numpy(),
199
+ faces=t_pos_idx.cpu().numpy(),
200
+ vertex_colors=color.cpu().numpy(),
201
+ )
202
+ meshes.append(mesh)
203
+ return meshes
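TSR ties the pieces together: the image tokenizer encodes the conditioning view, the tokenizer and backbone produce triplane scene codes, and the renderer/decoder turn those codes into rendered views or a mesh. The sketch below mirrors the flow in run.py and app.py from this commit; the Hugging Face repo id and file names are assumptions based on the upstream TripoSR release, and "input.png" stands for any already-preprocessed RGB image.

import torch
from PIL import Image
from tsr.system import TSR

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TSR.from_pretrained(
    "stabilityai/TripoSR",   # assumed repo id; a local directory also works
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.to(device)

image = Image.open("input.png").convert("RGB")  # placeholder path
with torch.no_grad():
    scene_codes = model(image, device=device)

views = model.render(scene_codes, n_views=4, return_type="pil")  # per-scene lists of PIL frames
meshes = model.extract_mesh(scene_codes, resolution=256)
meshes[0].export("mesh.obj")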
tsr/utils.py ADDED
@@ -0,0 +1,474 @@
1
+ import importlib
2
+ import math
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import imageio
8
+ import numpy as np
9
+ import PIL.Image
10
+ import rembg
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import trimesh
15
+ from omegaconf import DictConfig, OmegaConf
16
+ from PIL import Image
17
+
18
+
19
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
20
+ scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
21
+ return scfg
22
+
23
+
24
+ def find_class(cls_string):
25
+ module_string = ".".join(cls_string.split(".")[:-1])
26
+ cls_name = cls_string.split(".")[-1]
27
+ module = importlib.import_module(module_string, package=None)
28
+ cls = getattr(module, cls_name)
29
+ return cls
30
+
31
+
32
+ def get_intrinsic_from_fov(fov, H, W, bs=-1):
33
+ focal_length = 0.5 * H / np.tan(0.5 * fov)
34
+ intrinsic = np.identity(3, dtype=np.float32)
35
+ intrinsic[0, 0] = focal_length
36
+ intrinsic[1, 1] = focal_length
37
+ intrinsic[0, 2] = W / 2.0
38
+ intrinsic[1, 2] = H / 2.0
39
+
40
+ if bs > 0:
41
+ intrinsic = intrinsic[None].repeat(bs, axis=0)
42
+
43
+ return torch.from_numpy(intrinsic)
44
+
45
+
46
+ class BaseModule(nn.Module):
47
+ @dataclass
48
+ class Config:
49
+ pass
50
+
51
+ cfg: Config # add this to every subclass of BaseModule to enable static type checking
52
+
53
+ def __init__(
54
+ self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
55
+ ) -> None:
56
+ super().__init__()
57
+ self.cfg = parse_structured(self.Config, cfg)
58
+ self.configure(*args, **kwargs)
59
+
60
+ def configure(self, *args, **kwargs) -> None:
61
+ raise NotImplementedError
62
+
63
+
64
+ class ImagePreprocessor:
65
+ def convert_and_resize(
66
+ self,
67
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
68
+ size: int,
69
+ ):
70
+ if isinstance(image, PIL.Image.Image):
71
+ image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
72
+ elif isinstance(image, np.ndarray):
73
+ if image.dtype == np.uint8:
74
+ image = torch.from_numpy(image.astype(np.float32) / 255.0)
75
+ else:
76
+ image = torch.from_numpy(image)
77
+ elif isinstance(image, torch.Tensor):
78
+ pass
79
+
80
+ batched = image.ndim == 4
81
+
82
+ if not batched:
83
+ image = image[None, ...]
84
+ image = F.interpolate(
85
+ image.permute(0, 3, 1, 2),
86
+ (size, size),
87
+ mode="bilinear",
88
+ align_corners=False,
89
+ antialias=True,
90
+ ).permute(0, 2, 3, 1)
91
+ if not batched:
92
+ image = image[0]
93
+ return image
94
+
95
+ def __call__(
96
+ self,
97
+ image: Union[
98
+ PIL.Image.Image,
99
+ np.ndarray,
100
+ torch.FloatTensor,
101
+ List[PIL.Image.Image],
102
+ List[np.ndarray],
103
+ List[torch.FloatTensor],
104
+ ],
105
+ size: int,
106
+ ) -> Any:
107
+ if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
108
+ image = self.convert_and_resize(image, size)
109
+ else:
110
+ if not isinstance(image, list):
111
+ image = [image]
112
+ image = [self.convert_and_resize(im, size) for im in image]
113
+ image = torch.stack(image, dim=0)
114
+ return image
115
+
116
+
117
+ def rays_intersect_bbox(
118
+ rays_o: torch.Tensor,
119
+ rays_d: torch.Tensor,
120
+ radius: float,
121
+ near: float = 0.0,
122
+ valid_thresh: float = 0.01,
123
+ ):
124
+ input_shape = rays_o.shape[:-1]
125
+ rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
126
+ rays_d_valid = torch.where(
127
+ rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
128
+ )
129
+ if type(radius) in [int, float]:
130
+ radius = torch.FloatTensor(
131
+ [[-radius, radius], [-radius, radius], [-radius, radius]]
132
+ ).to(rays_o.device)
133
+ radius = (
134
+ 1.0 - 1.0e-3
135
+ ) * radius # tighten the radius to make sure the intersection point lies in the bounding box
136
+ interx0 = (radius[..., 1] - rays_o) / rays_d_valid
137
+ interx1 = (radius[..., 0] - rays_o) / rays_d_valid
138
+ t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
139
+ t_far = torch.maximum(interx0, interx1).amin(dim=-1)
140
+
141
+ # check wheter a ray intersects the bbox or not
142
+ rays_valid = t_far - t_near > valid_thresh
143
+
144
+ t_near[torch.where(~rays_valid)] = 0.0
145
+ t_far[torch.where(~rays_valid)] = 0.0
146
+
147
+ t_near = t_near.view(*input_shape, 1)
148
+ t_far = t_far.view(*input_shape, 1)
149
+ rays_valid = rays_valid.view(*input_shape)
150
+
151
+ return t_near, t_far, rays_valid
152
+
153
+
154
+ def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
155
+ if chunk_size <= 0:
156
+ return func(*args, **kwargs)
157
+ B = None
158
+ for arg in list(args) + list(kwargs.values()):
159
+ if isinstance(arg, torch.Tensor):
160
+ B = arg.shape[0]
161
+ break
162
+ assert (
163
+ B is not None
164
+ ), "No tensor found in args or kwargs, cannot determine batch size."
165
+ out = defaultdict(list)
166
+ out_type = None
167
+ # max(1, B) to support B == 0
168
+ for i in range(0, max(1, B), chunk_size):
169
+ out_chunk = func(
170
+ *[
171
+ arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
172
+ for arg in args
173
+ ],
174
+ **{
175
+ k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
176
+ for k, arg in kwargs.items()
177
+ },
178
+ )
179
+ if out_chunk is None:
180
+ continue
181
+ out_type = type(out_chunk)
182
+ if isinstance(out_chunk, torch.Tensor):
183
+ out_chunk = {0: out_chunk}
184
+ elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
185
+ chunk_length = len(out_chunk)
186
+ out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
187
+ elif isinstance(out_chunk, dict):
188
+ pass
189
+ else:
190
+ print(
191
+ f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
192
+ )
193
+ exit(1)
194
+ for k, v in out_chunk.items():
195
+ v = v if torch.is_grad_enabled() else v.detach()
196
+ out[k].append(v)
197
+
198
+ if out_type is None:
199
+ return None
200
+
201
+ out_merged: Dict[Any, Optional[torch.Tensor]] = {}
202
+ for k, v in out.items():
203
+ if all([vv is None for vv in v]):
204
+ # allow None in return value
205
+ out_merged[k] = None
206
+ elif all([isinstance(vv, torch.Tensor) for vv in v]):
207
+ out_merged[k] = torch.cat(v, dim=0)
208
+ else:
209
+ raise TypeError(
210
+ f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
211
+ )
212
+
213
+ if out_type is torch.Tensor:
214
+ return out_merged[0]
215
+ elif out_type in [tuple, list]:
216
+ return out_type([out_merged[i] for i in range(chunk_length)])
217
+ elif out_type is dict:
218
+ return out_merged
219
+
220
+
221
+ ValidScale = Union[Tuple[float, float], torch.FloatTensor]
222
+
223
+
224
+ def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
225
+ if inp_scale is None:
226
+ inp_scale = (0, 1)
227
+ if tgt_scale is None:
228
+ tgt_scale = (0, 1)
229
+ if isinstance(tgt_scale, torch.FloatTensor):
230
+ assert dat.shape[-1] == tgt_scale.shape[-1]
231
+ dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
232
+ dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
233
+ return dat
234
+
235
+
236
+ def get_activation(name) -> Callable:
237
+ if name is None:
238
+ return lambda x: x
239
+ name = name.lower()
240
+ if name == "none":
241
+ return lambda x: x
242
+ elif name == "exp":
243
+ return lambda x: torch.exp(x)
244
+ elif name == "sigmoid":
245
+ return lambda x: torch.sigmoid(x)
246
+ elif name == "tanh":
247
+ return lambda x: torch.tanh(x)
248
+ elif name == "softplus":
249
+ return lambda x: F.softplus(x)
250
+ else:
251
+ try:
252
+ return getattr(F, name)
253
+ except AttributeError:
254
+ raise ValueError(f"Unknown activation function: {name}")
255
+
256
+
257
+ def get_ray_directions(
258
+ H: int,
259
+ W: int,
260
+ focal: Union[float, Tuple[float, float]],
261
+ principal: Optional[Tuple[float, float]] = None,
262
+ use_pixel_centers: bool = True,
263
+ normalize: bool = True,
264
+ ) -> torch.FloatTensor:
265
+ """
266
+ Get ray directions for all pixels in camera coordinate.
267
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
268
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
269
+
270
+ Inputs:
271
+ H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
272
+ Outputs:
273
+ directions: (H, W, 3), the direction of the rays in camera coordinate
274
+ """
275
+ pixel_center = 0.5 if use_pixel_centers else 0
276
+
277
+ if isinstance(focal, float):
278
+ fx, fy = focal, focal
279
+ cx, cy = W / 2, H / 2
280
+ else:
281
+ fx, fy = focal
282
+ assert principal is not None
283
+ cx, cy = principal
284
+
285
+ i, j = torch.meshgrid(
286
+ torch.arange(W, dtype=torch.float32) + pixel_center,
287
+ torch.arange(H, dtype=torch.float32) + pixel_center,
288
+ indexing="xy",
289
+ )
290
+
291
+ directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
292
+
293
+ if normalize:
294
+ directions = F.normalize(directions, dim=-1)
295
+
296
+ return directions
297
+
298
+
299
+ def get_rays(
300
+ directions,
301
+ c2w,
302
+ keepdim=False,
303
+ normalize=False,
304
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
305
+ # Rotate ray directions from camera coordinate to the world coordinate
306
+ assert directions.shape[-1] == 3
307
+
308
+ if directions.ndim == 2: # (N_rays, 3)
309
+ if c2w.ndim == 2: # (4, 4)
310
+ c2w = c2w[None, :, :]
311
+ assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
312
+ rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
313
+ rays_o = c2w[:, :3, 3].expand(rays_d.shape)
314
+ elif directions.ndim == 3: # (H, W, 3)
315
+ assert c2w.ndim in [2, 3]
316
+ if c2w.ndim == 2: # (4, 4)
317
+ rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
318
+ -1
319
+ ) # (H, W, 3)
320
+ rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
321
+ elif c2w.ndim == 3: # (B, 4, 4)
322
+ rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
323
+ -1
324
+ ) # (B, H, W, 3)
325
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
326
+ elif directions.ndim == 4: # (B, H, W, 3)
327
+ assert c2w.ndim == 3 # (B, 4, 4)
328
+ rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
329
+ -1
330
+ ) # (B, H, W, 3)
331
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
332
+
333
+ if normalize:
334
+ rays_d = F.normalize(rays_d, dim=-1)
335
+ if not keepdim:
336
+ rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
337
+
338
+ return rays_o, rays_d
339
+
340
+
341
+ def get_spherical_cameras(
342
+ n_views: int,
343
+ elevation_deg: float,
344
+ camera_distance: float,
345
+ fovy_deg: float,
346
+ height: int,
347
+ width: int,
348
+ ):
349
+ azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
350
+ elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
351
+ camera_distances = torch.full_like(elevation_deg, camera_distance)
352
+
353
+ elevation = elevation_deg * math.pi / 180
354
+ azimuth = azimuth_deg * math.pi / 180
355
+
356
+ # convert spherical coordinates to cartesian coordinates
357
+ # right hand coordinate system, x back, y right, z up
358
+ # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
359
+ camera_positions = torch.stack(
360
+ [
361
+ camera_distances * torch.cos(elevation) * torch.cos(azimuth),
362
+ camera_distances * torch.cos(elevation) * torch.sin(azimuth),
363
+ camera_distances * torch.sin(elevation),
364
+ ],
365
+ dim=-1,
366
+ )
367
+
368
+ # default scene center at origin
369
+ center = torch.zeros_like(camera_positions)
370
+ # default camera up direction as +z
371
+ up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
372
+
373
+ fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
374
+
375
+ lookat = F.normalize(center - camera_positions, dim=-1)
376
+ right = F.normalize(torch.cross(lookat, up), dim=-1)
377
+ up = F.normalize(torch.cross(right, lookat), dim=-1)
378
+ c2w3x4 = torch.cat(
379
+ [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
380
+ dim=-1,
381
+ )
382
+ c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
383
+ c2w[:, 3, 3] = 1.0
384
+
385
+ # get directions by dividing directions_unit_focal by focal length
386
+ focal_length = 0.5 * height / torch.tan(0.5 * fovy)
387
+ directions_unit_focal = get_ray_directions(
388
+ H=height,
389
+ W=width,
390
+ focal=1.0,
391
+ )
392
+ directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
393
+ directions[:, :, :, :2] = (
394
+ directions[:, :, :, :2] / focal_length[:, None, None, None]
395
+ )
396
+ # must use normalize=True to normalize directions here
397
+ rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
398
+
399
+ return rays_o, rays_d
400
+
401
+
402
+ def remove_background(
403
+ image: PIL.Image.Image,
404
+ rembg_session: Any = None,
405
+ force: bool = False,
406
+ **rembg_kwargs,
407
+ ) -> PIL.Image.Image:
408
+ do_remove = True
409
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
410
+ do_remove = False
411
+ do_remove = do_remove or force
412
+ if do_remove:
413
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
414
+ return image
415
+
416
+
417
+ def resize_foreground(
418
+ image: PIL.Image.Image,
419
+ ratio: float,
420
+ ) -> PIL.Image.Image:
421
+ image = np.array(image)
422
+ assert image.shape[-1] == 4
423
+ alpha = np.where(image[..., 3] > 0)
424
+ y1, y2, x1, x2 = (
425
+ alpha[0].min(),
426
+ alpha[0].max(),
427
+ alpha[1].min(),
428
+ alpha[1].max(),
429
+ )
430
+ # crop the foreground
431
+ fg = image[y1:y2, x1:x2]
432
+ # pad to square
433
+ size = max(fg.shape[0], fg.shape[1])
434
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
435
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
436
+ new_image = np.pad(
437
+ fg,
438
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
439
+ mode="constant",
440
+ constant_values=((0, 0), (0, 0), (0, 0)),
441
+ )
442
+
443
+ # compute padding according to the ratio
444
+ new_size = int(new_image.shape[0] / ratio)
445
+ # pad to size, double side
446
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
447
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
448
+ new_image = np.pad(
449
+ new_image,
450
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
451
+ mode="constant",
452
+ constant_values=((0, 0), (0, 0), (0, 0)),
453
+ )
454
+ new_image = PIL.Image.fromarray(new_image)
455
+ return new_image
456
+
457
+
458
+ def save_video(
459
+ frames: List[PIL.Image.Image],
460
+ output_path: str,
461
+ fps: int = 30,
462
+ ):
463
+ # use imageio to save video
464
+ frames = [np.array(frame) for frame in frames]
465
+ writer = imageio.get_writer(output_path, fps=fps)
466
+ for frame in frames:
467
+ writer.append_data(frame)
468
+ writer.close()
469
+
470
+
471
+ def to_gradio_3d_orientation(mesh):
472
+ mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
473
+ mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
474
+ return mesh
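Besides camera/ray construction and chunked batching, the helpers above implement the image preprocessing used before inference: background removal and foreground recentering, with the calling scripts compositing the result onto a neutral background. A short sketch, assuming rembg is installed as imported at the top of this file; the 0.85 foreground ratio and the 0.5 grey background are illustrative choices, not constants defined here.

import numpy as np
import rembg
from PIL import Image
from tsr.utils import remove_background, resize_foreground

session = rembg.new_session()
image = Image.open("input.png")                # placeholder path
image = remove_background(image, session)      # RGBA with the background made transparent
image = resize_foreground(image, ratio=0.85)   # crop to the foreground and re-pad to a square

# Composite onto a mid-grey background so the model sees a plain RGB image.
rgba = np.array(image).astype(np.float32) / 255.0
rgb = rgba[..., :3] * rgba[..., 3:4] + 0.5 * (1.0 - rgba[..., 3:4])
image = Image.fromarray((rgb * 255.0).astype(np.uint8))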