Spaces: Runtime error
Commit 4d20c2f: init
ShoufaChen committed
Parent(s): init
Browse files
- .gitattributes +35 -0
- .gitignore +34 -0
- README.md +13 -0
- app.py +160 -0
- imagenet_en_cn.py +1002 -0
- models/generate.py +176 -0
- models/gpt.py +465 -0
- requirements.txt +1 -0
- tokenizer_image/discriminator.py +255 -0
- tokenizer_image/discriminator_patchgan.py +152 -0
- tokenizer_image/discriminator_stylegan.py +101 -0
- tokenizer_image/lpips.py +164 -0
- tokenizer_image/reconstruction_vq_ddp.py +197 -0
- tokenizer_image/vq_demo.py +84 -0
- tokenizer_image/vq_loss.py +168 -0
- tokenizer_image/vq_model.py +424 -0
- tokenizer_image/vq_train.py +316 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,34 @@
+# Python
+__pycache__
+*.pyc
+*.egg-info
+dist
+.ipynb_checkpoints
+*.ipynb
+
+# Log
+*.log
+*.log.*
+*.json
+*.jsonl
+
+# Data
+datasets
+*.zip
+*.png
+*.jpg
+*.jpeg
+
+# Model
+checkpoints
+ckpts*
+*.ckpt
+*.pth
+*.pt
+pretrained_models
+
+# Other
+.DS_Store
+wandb
+output
+results
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: LlamaGen
+emoji: 🏆
+colorFrom: indigo
+colorTo: pink
+sdk: gradio
+sdk_version: 4.36.0
+app_file: app.py
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,160 @@
+from PIL import Image
+import gradio as gr
+from imagenet_en_cn import IMAGENET_1K_CLASSES
+from huggingface_hub import hf_hub_download
+import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision('high')
+setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
+setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
+
+import time
+import argparse
+from tokenizer_image.vq_model import VQ_models
+from models.gpt import GPT_models
+from models.generate import generate
+
+device = "cuda"
+
+model2ckpt = {
+    "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384),
+    "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256),
+}
+
+def load_model(args):
+    ckpt_folder = "./"
+    vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
+    hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder)
+    hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder)
+    # create and load model
+    vq_model = VQ_models[args.vq_model](
+        codebook_size=args.codebook_size,
+        codebook_embed_dim=args.codebook_embed_dim)
+    vq_model.to(device)
+    vq_model.eval()
+    checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu")
+    vq_model.load_state_dict(checkpoint["model"])
+    del checkpoint
+    print("image tokenizer is loaded")
+
+    # create and load gpt model
+    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
+    latent_size = image_size // args.downsample_size
+    gpt_model = GPT_models[args.gpt_model](
+        vocab_size=args.codebook_size,
+        block_size=latent_size ** 2,
+        num_classes=args.num_classes,
+        cls_token_num=args.cls_token_num,
+        model_type=args.gpt_type,
+    ).to(device=device, dtype=precision)
+
+    checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
+    if args.from_fsdp:  # fsdp
+        model_weight = checkpoint
+    elif "model" in checkpoint:  # ddp
+        model_weight = checkpoint["model"]
+    elif "module" in checkpoint:  # deepspeed
+        model_weight = checkpoint["module"]
+    elif "state_dict" in checkpoint:
+        model_weight = checkpoint["state_dict"]
+    else:
+        raise Exception("please check model weight")
+    # if 'freqs_cis' in model_weight:
+    #     model_weight.pop('freqs_cis')
+    gpt_model.load_state_dict(model_weight, strict=False)
+    gpt_model.eval()
+    del checkpoint
+    print("gpt model is loaded")
+
+    if args.compile:
+        print("compiling the model...")
+        gpt_model = torch.compile(
+            gpt_model,
+            mode="reduce-overhead",
+            fullgraph=True
+        )  # requires PyTorch 2.0 (optional)
+    else:
+        print("no need to compile model in demo")
+
+    return vq_model, gpt_model, image_size
+
+
+def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
+    n = 4
+    latent_size = image_size // args.downsample_size
+    # Labels to condition the model with (feel free to change):
+    class_labels = [class_label for _ in range(n)]
+    c_indices = torch.tensor(class_labels, device=device)
+    qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
+
+    t1 = time.time()
+    torch.manual_seed(seed)
+    index_sample = generate(
+        gpt_model, c_indices, latent_size ** 2,
+        cfg_scale=cfg_scale, cfg_interval=args.cfg_interval,
+        temperature=temperature, top_k=top_k,
+        top_p=top_p, sample_logits=True,
+    )
+    sampling_time = time.time() - t1
+    print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
+
+    t2 = time.time()
+    samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
+    decoder_time = time.time() - t2
+    print(f"decoder takes about {decoder_time:.2f} seconds.")
+    # Convert to PIL.Image format:
+    samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
+    samples = [Image.fromarray(sample) for sample in samples]
+    return samples
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
+parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
+parser.add_argument("--from-fsdp", action='store_true')
+parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
+parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
+parser.add_argument("--compile", action='store_true', default=False)
+parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
+parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
+parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
+parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
+parser.add_argument("--num-classes", type=int, default=1000)
+parser.add_argument("--cfg-scale", type=float, default=4.0)
+parser.add_argument("--cfg-interval", type=float, default=-1)
+parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
+parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
+parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
+args = parser.parse_args()
+
+vq_model, gpt_model, image_size = load_model(args)
+
+with gr.Blocks() as demo:
+    gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
+
+    with gr.Tabs():
+        with gr.TabItem('Generate'):
+            with gr.Row():
+                with gr.Column():
+                    # with gr.Row():
+                    #     image_size = gr.Radio(choices=[384], value=384, label='Peize Model Resolution')
+                    with gr.Row():
+                        i1k_class = gr.Dropdown(
+                            list(IMAGENET_1K_CLASSES.values()),
+                            value='Eskimo dog, husky [爱斯基摩犬,哈士奇]',
+                            type="index", label='ImageNet-1K Class'
+                        )
+                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
+                    top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=4000, label='Top-K')
+                    top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
+                    temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
+                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
+                    # seed = gr.Number(value=0, label='Seed')
+                    button = gr.Button("Generate", variant="primary")
+                with gr.Column():
+                    output = gr.Gallery(label='Generated Images', height=700)
+            button.click(infer, inputs=[cfg_scale, top_k, top_p, temperature, i1k_class, seed], outputs=[output])
+demo.queue()
+demo.launch(debug=True)
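
For orientation, the default configuration above (the GPT-XL entry of model2ckpt at 384px, --downsample-size 16, --codebook-embed-dim 8, four images per request) fixes the token counts and latent shapes used in load_model and infer. The following is a worked sketch of that arithmetic, not part of the commit:

    # Shapes implied by app.py's defaults (GPT-XL entry of model2ckpt).
    image_size = 384          # model2ckpt["GPT-XL"][2]
    downsample_size = 16      # --downsample-size default
    codebook_embed_dim = 8    # --codebook-embed-dim default
    n = 4                     # images generated per click in infer()

    latent_size = image_size // downsample_size    # 24
    block_size = latent_size ** 2                  # 576 tokens sampled per image
    qzshape = [n, codebook_embed_dim, latent_size, latent_size]  # [4, 8, 24, 24]
    print(latent_size, block_size, qzshape)
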
imagenet_en_cn.py
ADDED
@@ -0,0 +1,1002 @@
+IMAGENET_1K_CLASSES = {
+    0: 'tench, Tinca tinca [丁鲷]',
+    1: 'goldfish, Carassius auratus [金鱼]',
+    2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias [大白鲨]',
+    3: 'tiger shark, Galeocerdo cuvieri [虎鲨]',
+    4: 'hammerhead, hammerhead shark [锤头鲨]',
+    5: 'electric ray, crampfish, numbfish, torpedo [电鳐]',
+    6: 'stingray [黄貂鱼]',
+    7: 'cock [公鸡]',
+    8: 'hen [母鸡]',
+    9: 'ostrich, Struthio camelus [鸵鸟]',
+    10: 'brambling, Fringilla montifringilla [燕雀]',
+    11: 'goldfinch, Carduelis carduelis [金翅雀]',
+    12: 'house finch, linnet, Carpodacus mexicanus [家朱雀]',
+    13: 'junco, snowbird [灯芯草雀]',
+    14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea [靛蓝雀,靛蓝鸟]',
+    15: 'robin, American robin, Turdus migratorius [蓝鹀]',
+    16: 'bulbul [夜莺]',
+    17: 'jay [松鸦]',
+    18: 'magpie [喜鹊]',
+    19: 'chickadee [山雀]',
+    20: 'water ouzel, dipper [河鸟]',
+    21: 'kite [鸢(猛禽)]',
+    22: 'bald eagle, American eagle, Haliaeetus leucocephalus [秃头鹰]',
+    23: 'vulture [秃鹫]',
+    24: 'great grey owl, great gray owl, Strix nebulosa [大灰猫头鹰]',
+    25: 'European fire salamander, Salamandra salamandra [欧洲火蝾螈]',
+    26: 'common newt, Triturus vulgaris [普通蝾螈]',
+    27: 'eft [水蜥]',
+    28: 'spotted salamander, Ambystoma maculatum [斑点蝾螈]',
+    29: 'axolotl, mud puppy, Ambystoma mexicanum [蝾螈,泥狗]',
+    30: 'bullfrog, Rana catesbeiana [牛蛙]',
+    31: 'tree frog, tree-frog [树蛙]',
+    32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui [尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍]',
+    33: 'loggerhead, loggerhead turtle, Caretta caretta [红海龟]',
+    34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea [皮革龟]',
+    35: 'mud turtle [泥龟]',
+    36: 'terrapin [淡水龟]',
+    37: 'box turtle, box tortoise [箱龟]',
+    38: 'banded gecko [带状壁虎]',
+    39: 'common iguana, iguana, Iguana iguana [普通鬣蜥]',
+    40: 'American chameleon, anole, Anolis carolinensis [美国变色龙]',
+    41: 'whiptail, whiptail lizard [鞭尾蜥蜴]',
+    42: 'agama [飞龙科蜥蜴]',
+    43: 'frilled lizard, Chlamydosaurus kingi [褶边蜥蜴]',
+    44: 'alligator lizard [鳄鱼蜥蜴]',
+    45: 'Gila monster, Heloderma suspectum [毒蜥]',
+    46: 'green lizard, Lacerta viridis [绿蜥蜴]',
+    47: 'African chameleon, Chamaeleo chamaeleon [非洲变色龙]',
+    48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis [科莫多蜥蜴]',
+    49: 'African crocodile, Nile crocodile, Crocodylus niloticus [非洲鳄,尼罗河鳄鱼]',
+    50: 'American alligator, Alligator mississipiensis [美国鳄鱼,鳄鱼]',
+    51: 'triceratops [三角龙]',
+    52: 'thunder snake, worm snake, Carphophis amoenus [雷蛇,蠕虫蛇]',
+    53: 'ringneck snake, ring-necked snake, ring snake [环蛇,环颈蛇]',
+    54: 'hognose snake, puff adder, sand viper [希腊蛇]',
+    55: 'green snake, grass snake [绿蛇,草蛇]',
+    56: 'king snake, kingsnake [国王蛇]',
+    57: 'garter snake, grass snake [袜带蛇,草蛇]',
+    58: 'water snake [水蛇]',
+    59: 'vine snake [藤蛇]',
+    60: 'night snake, Hypsiglena torquata [夜蛇]',
+    61: 'boa constrictor, Constrictor constrictor [大蟒蛇]',
+    62: 'rock python, rock snake, Python sebae [岩石蟒蛇,岩蛇,蟒蛇]',
+    63: 'Indian cobra, Naja naja [印度眼镜蛇]',
+    64: 'green mamba [绿曼巴]',
+    65: 'sea snake [海蛇]',
+    66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus [角腹蛇]',
+    67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus [菱纹响尾蛇]',
+    68: 'sidewinder, horned rattlesnake, Crotalus cerastes [角响尾蛇]',
+    69: 'trilobite [三叶虫]',
+    70: 'harvestman, daddy longlegs, Phalangium opilio [盲蜘蛛]',
+    71: 'scorpion [蝎子]',
+    72: 'black and gold garden spider, Argiope aurantia [黑金花园蜘蛛]',
+    73: 'barn spider, Araneus cavaticus [谷仓蜘蛛]',
+    74: 'garden spider, Aranea diademata [花园蜘蛛]',
+    75: 'black widow, Latrodectus mactans [黑寡妇蜘蛛]',
+    76: 'tarantula [狼蛛]',
+    77: 'wolf spider, hunting spider [狼蜘蛛,狩猎蜘蛛]',
+    78: 'tick [壁虱]',
+    79: 'centipede [蜈蚣]',
+    80: 'black grouse [黑松鸡]',
+    81: 'ptarmigan [松鸡,雷鸟]',
+    82: 'ruffed grouse, partridge, Bonasa umbellus [披肩鸡,披肩榛鸡]',
+    83: 'prairie chicken, prairie grouse, prairie fowl [草原鸡,草原松鸡]',
+    84: 'peacock [孔雀]',
+    85: 'quail [鹌鹑]',
+    86: 'partridge [鹧鸪]',
+    87: 'African grey, African gray, Psittacus erithacus [非洲灰鹦鹉]',
+    88: 'macaw [金刚鹦鹉]',
+    89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita [硫冠鹦鹉]',
+    90: 'lorikeet [短尾鹦鹉]',
+    91: 'coucal [褐翅鸦鹃]',
+    92: 'bee eater [蜜蜂]',
+    93: 'hornbill [犀鸟]',
+    94: 'hummingbird [蜂鸟]',
+    95: 'jacamar [鹟䴕]',
+    96: 'toucan [犀鸟]',
+    97: 'drake [野鸭]',
+    98: 'red-breasted merganser, Mergus serrator [红胸秋沙鸭]',
+    99: 'goose [鹅]',
+    100: 'black swan, Cygnus atratus [黑天鹅]',
+    101: 'tusker [大象]',
+    102: 'echidna, spiny anteater, anteater [针鼹鼠]',
+    103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus [鸭嘴兽]',
+    104: 'wallaby, brush kangaroo [沙袋鼠]',
+    105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus [考拉,考拉熊]',
+    106: 'wombat [袋熊]',
+    107: 'jellyfish [水母]',
+    108: 'sea anemone, anemone [海葵]',
+    109: 'brain coral [脑珊瑚]',
+    110: 'flatworm, platyhelminth [扁形虫扁虫]',
+    111: 'nematode, nematode worm, roundworm [线虫,蛔虫]',
+    112: 'conch [海螺]',
+    113: 'snail [蜗牛]',
+    114: 'slug [鼻涕虫]',
+    115: 'sea slug, nudibranch [海参]',
+    116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore [石鳖]',
+    117: 'chambered nautilus, pearly nautilus, nautilus [鹦鹉螺]',
+    118: 'Dungeness crab, Cancer magister [珍宝蟹]',
+    119: 'rock crab, Cancer irroratus [石蟹]',
+    120: 'fiddler crab [招潮蟹]',
+    121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica [帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹]',
+    122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus [美国龙虾,缅因州龙虾]',
+    123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish [大螯虾]',
+    124: 'crayfish, crawfish, crawdad, crawdaddy [小龙虾]',
+    125: 'hermit crab [寄居蟹]',
+    126: 'isopod [等足目动物(明虾和螃蟹近亲)]',
+    127: 'white stork, Ciconia ciconia [白鹳]',
+    128: 'black stork, Ciconia nigra [黑鹳]',
+    129: 'spoonbill [鹭]',
+    130: 'flamingo [火烈鸟]',
+    131: 'little blue heron, Egretta caerulea [小蓝鹭]',
+    132: 'American egret, great white heron, Egretta albus [美国鹭,大白鹭]',
+    133: 'bittern [麻鸦]',
+    134: 'crane [鹤]',
+    135: 'limpkin, Aramus pictus [秧鹤]',
+    136: 'European gallinule, Porphyrio porphyrio [欧洲水鸡,紫水鸡]',
+    137: 'American coot, marsh hen, mud hen, water hen, Fulica americana [沼泽泥母鸡,水母鸡]',
+    138: 'bustard [鸨]',
+    139: 'ruddy turnstone, Arenaria interpres [红翻石鹬]',
+    140: 'red-backed sandpiper, dunlin, Erolia alpina [红背鹬,黑腹滨鹬]',
+    141: 'redshank, Tringa totanus [红脚鹬]',
+    142: 'dowitcher [半蹼鹬]',
+    143: 'oystercatcher, oyster catcher [蛎鹬]',
+    144: 'pelican [鹈鹕]',
+    145: 'king penguin, Aptenodytes patagonica [国王企鹅]',
+    146: 'albatross, mollymawk [信天翁,大海鸟]',
+    147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus [灰鲸]',
+    148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca [杀人鲸,逆戟鲸,虎鲸]',
+    149: 'dugong, Dugong dugon [海牛]',
+    150: 'sea lion [海狮]',
+    151: 'Chihuahua [奇瓦瓦]',
+    152: 'Japanese spaniel [日本猎犬]',
+    153: 'Maltese dog, Maltese terrier, Maltese [马尔济斯犬]',
+    154: 'Pekinese, Pekingese, Peke [狮子狗]',
+    155: 'Shih-Tzu [西施犬]',
+    156: 'Blenheim spaniel [布莱尼姆猎犬]',
+    157: 'papillon [巴比狗]',
+    158: 'toy terrier [玩具犬]',
+    159: 'Rhodesian ridgeback [罗得西亚长背猎狗]',
+    160: 'Afghan hound, Afghan [阿富汗猎犬]',
+    161: 'basset, basset hound [猎犬]',
+    162: 'beagle [比格犬,猎兔犬]',
+    163: 'bloodhound, sleuthhound [侦探犬]',
+    164: 'bluetick [蓝色快狗]',
+    165: 'black-and-tan coonhound [黑褐猎浣熊犬]',
+    166: 'Walker hound, Walker foxhound [沃克猎犬]',
+    167: 'English foxhound [英国猎狐犬]',
+    168: 'redbone [美洲赤狗]',
+    169: 'borzoi, Russian wolfhound [俄罗斯猎狼犬]',
+    170: 'Irish wolfhound [爱尔兰猎狼犬]',
+    171: 'Italian greyhound [意大利灰狗]',
+    172: 'whippet [惠比特犬]',
+    173: 'Ibizan hound, Ibizan Podenco [依比沙猎犬]',
+    174: 'Norwegian elkhound, elkhound [挪威猎犬]',
+    175: 'otterhound, otter hound [奥达猎犬,水獭猎犬]',
+    176: 'Saluki, gazelle hound [沙克犬,瞪羚猎犬]',
+    177: 'Scottish deerhound, deerhound [苏格兰猎鹿犬,猎鹿犬]',
+    178: 'Weimaraner [威玛猎犬]',
+    179: 'Staffordshire bullterrier, Staffordshire bull terrier [斯塔福德郡牛头梗,斯塔福德郡斗牛梗]',
+    180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier [美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗]',
+    181: 'Bedlington terrier [贝德灵顿梗]',
+    182: 'Border terrier [边境梗]',
+    183: 'Kerry blue terrier [凯丽蓝梗]',
+    184: 'Irish terrier [爱尔兰梗]',
+    185: 'Norfolk terrier [诺福克梗]',
+    186: 'Norwich terrier [诺维奇梗]',
+    187: 'Yorkshire terrier [约克郡梗]',
+    188: 'wire-haired fox terrier [刚毛猎狐梗]',
+    189: 'Lakeland terrier [莱克兰梗]',
+    190: 'Sealyham terrier, Sealyham [锡利哈姆梗]',
+    191: 'Airedale, Airedale terrier [艾尔谷犬]',
+    192: 'cairn, cairn terrier [凯恩梗]',
+    193: 'Australian terrier [澳大利亚梗]',
+    194: 'Dandie Dinmont, Dandie Dinmont terrier [丹迪丁蒙梗]',
+    195: 'Boston bull, Boston terrier [波士顿梗]',
+    196: 'miniature schnauzer [迷你雪纳瑞犬]',
+    197: 'giant schnauzer [巨型雪纳瑞犬]',
+    198: 'standard schnauzer [标准雪纳瑞犬]',
+    199: 'Scotch terrier, Scottish terrier, Scottie [苏格兰梗]',
+    200: 'Tibetan terrier, chrysanthemum dog [西藏梗,菊花狗]',
+    201: 'silky terrier, Sydney silky [丝毛梗]',
+    202: 'soft-coated wheaten terrier [软毛麦色梗]',
+    203: 'West Highland white terrier [西高地白梗]',
+    204: 'Lhasa, Lhasa apso [拉萨阿普索犬]',
+    205: 'flat-coated retriever [平毛寻回犬]',
+    206: 'curly-coated retriever [卷毛寻回犬]',
+    207: 'golden retriever [金毛猎犬]',
+    208: 'Labrador retriever [拉布拉多猎犬]',
+    209: 'Chesapeake Bay retriever [乞沙比克猎犬]',
+    210: 'German short-haired pointer [德国短毛猎犬]',
+    211: 'vizsla, Hungarian pointer [维兹拉犬]',
+    212: 'English setter [英国谍犬]',
+    213: 'Irish setter, red setter [爱尔兰雪达犬,红色猎犬]',
+    214: 'Gordon setter [戈登雪达犬]',
+    215: 'Brittany spaniel [布列塔尼犬猎犬]',
+    216: 'clumber, clumber spaniel [黄毛,黄毛猎犬]',
+    217: 'English springer, English springer spaniel [英国史宾格犬]',
+    218: 'Welsh springer spaniel [威尔士史宾格犬]',
+    219: 'cocker spaniel, English cocker spaniel, cocker [可卡犬,英国可卡犬]',
+    220: 'Sussex spaniel [萨塞克斯猎犬]',
+    221: 'Irish water spaniel [爱尔兰水猎犬]',
+    222: 'kuvasz [哥威斯犬]',
+    223: 'schipperke [舒柏奇犬]',
+    224: 'groenendael [比利时牧羊犬]',
+    225: 'malinois [马里努阿犬]',
+    226: 'briard [伯瑞犬]',
+    227: 'kelpie [凯尔皮犬]',
+    228: 'komondor [匈牙利牧羊犬]',
+    229: 'Old English sheepdog, bobtail [老英国牧羊犬]',
+    230: 'Shetland sheepdog, Shetland sheep dog, Shetland [喜乐蒂牧羊犬]',
+    231: 'collie [牧羊犬]',
+    232: 'Border collie [边境牧羊犬]',
+    233: 'Bouvier des Flandres, Bouviers des Flandres [法兰德斯牧牛狗]',
+    234: 'Rottweiler [罗特韦尔犬]',
+    235: 'German shepherd, German shepherd dog, German police dog, alsatian [德国牧羊犬,德国警犬,阿尔萨斯]',
+    236: 'Doberman, Doberman pinscher [多伯曼犬,杜宾犬]',
+    237: 'miniature pinscher [迷你杜宾犬]',
+    238: 'Greater Swiss Mountain dog [大瑞士山地犬]',
+    239: 'Bernese mountain dog [伯恩山犬]',
+    240: 'Appenzeller [Appenzeller狗]',
+    241: 'EntleBucher [EntleBucher狗]',
+    242: 'boxer [拳师狗]',
+    243: 'bull mastiff [斗牛獒]',
+    244: 'Tibetan mastiff [藏獒]',
+    245: 'French bulldog [法国斗牛犬]',
+    246: 'Great Dane [大丹犬]',
+    247: 'Saint Bernard, St Bernard [圣伯纳德狗]',
+    248: 'Eskimo dog, husky [爱斯基摩犬,哈士奇]',
+    249: 'malamute, malemute, Alaskan malamute [雪橇犬,阿拉斯加爱斯基摩狗]',
+    250: 'Siberian husky [哈士奇]',
+    251: 'dalmatian, coach dog, carriage dog [达尔马提亚,教练车狗]',
+    252: 'affenpinscher, monkey pinscher, monkey dog [狮毛狗]',
+    253: 'basenji [巴辛吉狗]',
+    254: 'pug, pug-dog [哈巴狗,狮子狗]',
+    255: 'Leonberg [莱昂贝格狗]',
+    256: 'Newfoundland, Newfoundland dog [纽芬兰岛狗]',
+    257: 'Great Pyrenees [大白熊犬]',
+    258: 'Samoyed, Samoyede [萨摩耶犬]',
+    259: 'Pomeranian [博美犬]',
+    260: 'chow, chow chow [松狮,松狮]',
+    261: 'keeshond [荷兰卷尾狮毛狗]',
+    262: 'Brabancon griffon [布鲁塞尔格林芬犬]',
+    263: 'Pembroke, Pembroke Welsh corgi [彭布洛克威尔士科基犬]',
+    264: 'Cardigan, Cardigan Welsh corgi [威尔士柯基犬]',
+    265: 'toy poodle [玩具贵宾犬]',
+    266: 'miniature poodle [迷你贵宾犬]',
+    267: 'standard poodle [标准贵宾犬]',
+    268: 'Mexican hairless [墨西哥无毛犬]',
+    269: 'timber wolf, grey wolf, gray wolf, Canis lupus [灰狼]',
+    270: 'white wolf, Arctic wolf, Canis lupus tundrarum [白狼,北极狼]',
+    271: 'red wolf, maned wolf, Canis rufus, Canis niger [红太狼,鬃狼,犬犬鲁弗斯]',
+    272: 'coyote, prairie wolf, brush wolf, Canis latrans [狼,草原狼,刷狼,郊狼]',
+    273: 'dingo, warrigal, warragal, Canis dingo [澳洲野狗,澳大利亚野犬]',
+    274: 'dhole, Cuon alpinus [豺]',
+    275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus [非洲猎犬,土狼犬]',
+    276: 'hyena, hyaena [鬣狗]',
+    277: 'red fox, Vulpes vulpes [红狐狸]',
+    278: 'kit fox, Vulpes macrotis [沙狐]',
+    279: 'Arctic fox, white fox, Alopex lagopus [北极狐狸,白狐狸]',
+    280: 'grey fox, gray fox, Urocyon cinereoargenteus [灰狐狸]',
+    281: 'tabby, tabby cat [虎斑猫]',
+    282: 'tiger cat [山猫,虎猫]',
+    283: 'Persian cat [波斯猫]',
+    284: 'Siamese cat, Siamese [暹罗暹罗猫,]',
+    285: 'Egyptian cat [埃及猫]',
+    286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor [美洲狮,美洲豹]',
+    287: 'lynx, catamount [猞猁,山猫]',
+    288: 'leopard, Panthera pardus [豹子]',
+    289: 'snow leopard, ounce, Panthera uncia [雪豹]',
+    290: 'jaguar, panther, Panthera onca, Felis onca [美洲虎]',
+    291: 'lion, king of beasts, Panthera leo [狮子]',
+    292: 'tiger, Panthera tigris [老虎]',
+    293: 'cheetah, chetah, Acinonyx jubatus [猎豹]',
+    294: 'brown bear, bruin, Ursus arctos [棕熊]',
+    295: 'American black bear, black bear, Ursus americanus, Euarctos americanus [美洲黑熊]',
+    296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus [冰熊,北极熊]',
+    297: 'sloth bear, Melursus ursinus, Ursus ursinus [懒熊]',
+    298: 'mongoose [猫鼬]',
+    299: 'meerkat, mierkat [猫鼬,海猫]',
+    300: 'tiger beetle [虎甲虫]',
+    301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle [瓢虫]',
+    302: 'ground beetle, carabid beetle [土鳖虫]',
+    303: 'long-horned beetle, longicorn, longicorn beetle [天牛]',
+    304: 'leaf beetle, chrysomelid [龟甲虫]',
+    305: 'dung beetle [粪甲虫]',
+    306: 'rhinoceros beetle [犀牛甲虫]',
+    307: 'weevil [象甲]',
+    308: 'fly [苍蝇]',
+    309: 'bee [蜜蜂]',
+    310: 'ant, emmet, pismire [蚂蚁]',
+    311: 'grasshopper, hopper [蚱蜢]',
+    312: 'cricket [蟋蟀]',
+    313: 'walking stick, walkingstick, stick insect [竹节虫]',
+    314: 'cockroach, roach [蟑螂]',
+    315: 'mantis, mantid [螳螂]',
+    316: 'cicada, cicala [蝉]',
+    317: 'leafhopper [叶蝉]',
+    318: 'lacewing, lacewing fly [草蜻蛉]',
+    319: 'dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk [蜻蜓]',
+    320: 'damselfly [豆娘,蜻蛉]',
+    321: 'admiral [优红蛱蝶]',
+    322: 'ringlet, ringlet butterfly [小环蝴蝶]',
+    323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus [君主蝴蝶,大斑蝶]',
+    324: 'cabbage butterfly [菜粉蝶]',
+    325: 'sulphur butterfly, sulfur butterfly [白蝴蝶]',
+    326: 'lycaenid, lycaenid butterfly [灰蝶]',
+    327: 'starfish, sea star [海星]',
+    328: 'sea urchin [海胆]',
+    329: 'sea cucumber, holothurian [海参,海黄瓜]',
+    330: 'wood rabbit, cottontail, cottontail rabbit [野兔]',
+    331: 'hare [兔]',
+    332: 'Angora, Angora rabbit [安哥拉兔]',
+    333: 'hamster [仓鼠]',
+    334: 'porcupine, hedgehog [刺猬,豪猪,]',
+    335: 'fox squirrel, eastern fox squirrel, Sciurus niger [黑松鼠]',
+    336: 'marmot [土拨鼠]',
+    337: 'beaver [海狸]',
+    338: 'guinea pig, Cavia cobaya [豚鼠,豚鼠]',
+    339: 'sorrel [栗色马]',
+    340: 'zebra [斑马]',
+    341: 'hog, pig, grunter, squealer, Sus scrofa [猪]',
+    342: 'wild boar, boar, Sus scrofa [野猪]',
+    343: 'warthog [疣猪]',
+    344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius [河马]',
+    345: 'ox [牛]',
+    346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis [水牛,亚洲水牛]',
+    347: 'bison [野牛]',
+    348: 'ram, tup [公羊]',
+    349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis [大角羊,洛矶山大角羊]',
+    350: 'ibex, Capra ibex [山羊]',
+    351: 'hartebeest [狷羚]',
+    352: 'impala, Aepyceros melampus [黑斑羚]',
+    353: 'gazelle [瞪羚]',
+    354: 'Arabian camel, dromedary, Camelus dromedarius [阿拉伯单峰骆驼,骆驼]',
+    355: 'llama [骆驼]',
+    356: 'weasel [黄鼠狼]',
+    357: 'mink [水貂]',
+    358: 'polecat, fitch, foulmart, foumart, Mustela putorius [臭猫]',
+    359: 'black-footed ferret, ferret, Mustela nigripes [黑足鼬]',
+    360: 'otter [水獭]',
+    361: 'skunk, polecat, wood pussy [臭鼬,木猫]',
+    362: 'badger [獾]',
+    363: 'armadillo [犰狳]',
+    364: 'three-toed sloth, ai, Bradypus tridactylus [树懒]',
+    365: 'orangutan, orang, orangutang, Pongo pygmaeus [猩猩,婆罗洲猩猩]',
+    366: 'gorilla, Gorilla gorilla [大猩猩]',
+    367: 'chimpanzee, chimp, Pan troglodytes [黑猩猩]',
+    368: 'gibbon, Hylobates lar [长臂猿]',
+    369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus [合趾猿长臂猿,合趾猿]',
+    370: 'guenon, guenon monkey [长尾猴]',
+    371: 'patas, hussar monkey, Erythrocebus patas [赤猴]',
+    372: 'baboon [狒狒]',
+    373: 'macaque [恒河猴,猕猴]',
+    374: 'langur [白头叶猴]',
+    375: 'colobus, colobus monkey [疣猴]',
+    376: 'proboscis monkey, Nasalis larvatus [长鼻猴]',
+    377: 'marmoset [狨(美洲产小型长尾猴)]',
+    378: 'capuchin, ringtail, Cebus capucinus [卷尾猴]',
+    379: 'howler monkey, howler [吼猴]',
+    380: 'titi, titi monkey [伶猴]',
+    381: 'spider monkey, Ateles geoffroyi [蜘蛛猴]',
+    382: 'squirrel monkey, Saimiri sciureus [松鼠猴]',
+    383: 'Madagascar cat, ring-tailed lemur, Lemur catta [马达加斯加环尾狐猴,鼠狐猴]',
+    384: 'indri, indris, Indri indri, Indri brevicaudatus [大狐猴,马达加斯加大狐猴]',
+    385: 'Indian elephant, Elephas maximus [印度大象,亚洲象]',
+    386: 'African elephant, Loxodonta africana [非洲象,非洲象]',
+    387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens [小熊猫]',
+    388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca [大熊猫]',
+    389: 'barracouta, snoek [杖鱼]',
+    390: 'eel [鳗鱼]',
+    391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch [银鲑,银鲑鱼]',
+    392: 'rock beauty, Holocanthus tricolor [三色刺蝶鱼]',
+    393: 'anemone fish [海葵鱼]',
+    394: 'sturgeon [鲟鱼]',
+    395: 'gar, garfish, garpike, billfish, Lepisosteus osseus [雀鳝]',
+    396: 'lionfish [狮子鱼]',
+    397: 'puffer, pufferfish, blowfish, globefish [河豚]',
+    398: 'abacus [算盘]',
+    399: 'abaya [长袍]',
+    400: 'academic gown, academic robe, judge robe [学位袍]',
+    401: 'accordion, piano accordion, squeeze box [手风琴]',
+    402: 'acoustic guitar [原声吉他]',
+    403: 'aircraft carrier, carrier, flattop, attack aircraft carrier [航空母舰]',
+    404: 'airliner [客机]',
+    405: 'airship, dirigible [飞艇]',
+    406: 'altar [祭坛]',
+    407: 'ambulance [救护车]',
+    408: 'amphibian, amphibious vehicle [水陆两用车]',
+    409: 'analog clock [模拟时钟]',
+    410: 'apiary, bee house [蜂房]',
+    411: 'apron [围裙]',
+    412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin [垃圾桶]',
+    413: 'assault rifle, assault gun [攻击步枪,枪]',
+    414: 'backpack, back pack, knapsack, packsack, rucksack, haversack [背包]',
+    415: 'bakery, bakeshop, bakehouse [面包店,面包铺,]',
+    416: 'balance beam, beam [平衡木]',
+    417: 'balloon [热气球]',
+    418: 'ballpoint, ballpoint pen, ballpen, Biro [圆珠笔]',
+    419: 'Band Aid [创可贴]',
+    420: 'banjo [班卓琴]',
+    421: 'bannister, banister, balustrade, balusters, handrail [栏杆,楼梯扶手]',
+    422: 'barbell [杠铃]',
+    423: 'barber chair [理发师的椅子]',
+    424: 'barbershop [理发店]',
+    425: 'barn [牲口棚]',
+    426: 'barometer [晴雨表]',
+    427: 'barrel, cask [圆筒]',
+    428: 'barrow, garden cart, lawn cart, wheelbarrow [园地小车,手推车]',
+    429: 'baseball [棒球]',
+    430: 'basketball [篮球]',
+    431: 'bassinet [婴儿床]',
+    432: 'bassoon [巴松管,低音管]',
+    433: 'bathing cap, swimming cap [游泳帽]',
+    434: 'bath towel [沐浴毛巾]',
+    435: 'bathtub, bathing tub, bath, tub [浴缸,澡盆]',
+    436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon [沙滩车,旅行车]',
+    437: 'beacon, lighthouse, beacon light, pharos [灯塔]',
+    438: 'beaker [高脚杯]',
+    439: 'bearskin, busby, shako [熊皮高帽]',
+    440: 'beer bottle [啤酒瓶]',
+    441: 'beer glass [啤酒杯]',
+    442: 'bell cote, bell cot [钟塔]',
+    443: 'bib [(小儿用的)围嘴]',
+    444: 'bicycle-built-for-two, tandem bicycle, tandem [串联自行车,]',
+    445: 'bikini, two-piece [比基尼]',
+    446: 'binder, ring-binder [装订册]',
+    447: 'binoculars, field glasses, opera glasses [双筒望远镜]',
+    448: 'birdhouse [鸟舍]',
+    449: 'boathouse [船库]',
+    450: 'bobsled, bobsleigh, bob [雪橇]',
+    451: 'bolo tie, bolo, bola tie, bola [饰扣式领带]',
+    452: 'bonnet, poke bonnet [阔边女帽]',
+    453: 'bookcase [书橱]',
+    454: 'bookshop, bookstore, bookstall [书店,书摊]',
+    455: 'bottlecap [瓶盖]',
+    456: 'bow [弓箭]',
+    457: 'bow tie, bow-tie, bowtie [蝴蝶结领结]',
+    458: 'brass, memorial tablet, plaque [铜制牌位]',
+    459: 'brassiere, bra, bandeau [奶罩]',
+    460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty [防波堤,海堤]',
+    461: 'breastplate, aegis, egis [铠甲]',
+    462: 'broom [扫帚]',
+    463: 'bucket, pail [桶]',
+    464: 'buckle [扣环]',
+    465: 'bulletproof vest [防弹背心]',
+    466: 'bullet train, bullet [动车,子弹头列车]',
+    467: 'butcher shop, meat market [肉铺,肉菜市场]',
+    468: 'cab, hack, taxi, taxicab [出租车]',
+    469: 'caldron, cauldron [大锅]',
+    470: 'candle, taper, wax light [蜡烛]',
+    471: 'cannon [大炮]',
+    472: 'canoe [独木舟]',
+    473: 'can opener, tin opener [开瓶器,开罐器]',
+    474: 'cardigan [开衫]',
+    475: 'car mirror [车镜]',
+    476: 'carousel, carrousel, merry-go-round, roundabout, whirligig [旋转木马]',
+    477: 'carpenters kit, tool kit [木匠的工具包,工具包]',
+    478: 'carton [纸箱]',
+    479: 'car wheel [车轮]',
+    480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM [取款机,自动取款机]',
+    481: 'cassette [盒式录音带]',
+    482: 'cassette player [卡带播放器]',
+    483: 'castle [城堡]',
+    484: 'catamaran [双体船]',
+    485: 'CD player [CD播放器]',
+    486: 'cello, violoncello [大提琴]',
+    487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone [移动电话,手机]',
+    488: 'chain [铁链]',
+    489: 'chainlink fence [围栏]',
+    490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour [链甲]',
+    491: 'chain saw, chainsaw [电锯,油锯]',
+    492: 'chest [箱子]',
+    493: 'chiffonier, commode [衣柜,洗脸台]',
+    494: 'chime, bell, gong [编钟,钟,锣]',
+    495: 'china cabinet, china closet [中国橱柜]',
+    496: 'Christmas stocking [圣诞袜]',
+    497: 'church, church building [教堂,教堂建筑]',
+    498: 'cinema, movie theater, movie theatre, movie house, picture palace [电影院,剧场]',
+    499: 'cleaver, meat cleaver, chopper [切肉刀,菜刀]',
+    500: 'cliff dwelling [悬崖屋]',
+    501: 'cloak [斗篷]',
+    502: 'clog, geta, patten, sabot [木屐,木鞋]',
+    503: 'cocktail shaker [鸡尾酒调酒器]',
+    504: 'coffee mug [咖啡杯]',
+    505: 'coffeepot [咖啡壶]',
+    506: 'coil, spiral, volute, whorl, helix [螺旋结构(楼梯)]',
+    507: 'combination lock [组合锁]',
+    508: 'computer keyboard, keypad [电脑键盘,键盘]',
+    509: 'confectionery, confectionary, candy store [糖果,糖果店]',
+    510: 'container ship, containership, container vessel [集装箱船]',
+    511: 'convertible [敞篷车]',
+    512: 'corkscrew, bottle screw [开瓶器,瓶螺杆]',
+    513: 'cornet, horn, trumpet, trump [短号,喇叭]',
+    514: 'cowboy boot [牛仔靴]',
+    515: 'cowboy hat, ten-gallon hat [牛仔帽]',
+    516: 'cradle [摇篮]',
+    517: 'crane [起重机]',
+    518: 'crash helmet [头盔]',
+    519: 'crate [板条箱]',
+    520: 'crib, cot [小儿床]',
+    521: 'Crock Pot [砂锅]',
+    522: 'croquet ball [槌球]',
+    523: 'crutch [拐杖]',
+    524: 'cuirass [胸甲]',
+    525: 'dam, dike, dyke [大坝,堤防]',
+    526: 'desk [书桌]',
+    527: 'desktop computer [台式电脑]',
+    528: 'dial telephone, dial phone [有线电话]',
+    529: 'diaper, nappy, napkin [尿布湿]',
+    530: 'digital clock [数字时钟]',
+    531: 'digital watch [数字手表]',
+    532: 'dining table, board [餐桌板]',
+    533: 'dishrag, dishcloth [抹布]',
+    534: 'dishwasher, dish washer, dishwashing machine [洗碗机,洗碟机]',
+    535: 'disk brake, disc brake [盘式制动器]',
+    536: 'dock, dockage, docking facility [码头,船坞,码头设施]',
+    537: 'dogsled, dog sled, dog sleigh [狗拉雪橇]',
+    538: 'dome [圆顶]',
+    539: 'doormat, welcome mat [门垫,垫子]',
+    540: 'drilling platform, offshore rig [钻井平台,海上钻井]',
+    541: 'drum, membranophone, tympan [鼓,乐器,鼓膜]',
+    542: 'drumstick [鼓槌]',
+    543: 'dumbbell [哑铃]',
+    544: 'Dutch oven [荷兰烤箱]',
+    545: 'electric fan, blower [电风扇,鼓风机]',
+    546: 'electric guitar [电吉他]',
+    547: 'electric locomotive [电力机车]',
+    548: 'entertainment center [电视,电视柜]',
+    549: 'envelope [信封]',
+    550: 'espresso maker [浓缩咖啡机]',
+    551: 'face powder [扑面粉]',
+    552: 'feather boa, boa [女用长围巾]',
+    553: 'file, file cabinet, filing cabinet [文件,文件柜,档案柜]',
+    554: 'fireboat [消防船]',
+    555: 'fire engine, fire truck [消防车]',
+    556: 'fire screen, fireguard [火炉栏]',
+    557: 'flagpole, flagstaff [旗杆]',
+    558: 'flute, transverse flute [长笛]',
+    559: 'folding chair [折叠椅]',
+    560: 'football helmet [橄榄球头盔]',
+    561: 'forklift [叉车]',
+    562: 'fountain [喷泉]',
+    563: 'fountain pen [钢笔]',
+    564: 'four-poster [有四根帷柱的床]',
+    565: 'freight car [运货车厢]',
+    566: 'French horn, horn [圆号,喇叭]',
+    567: 'frying pan, frypan, skillet [煎锅]',
+    568: 'fur coat [裘皮大衣]',
+    569: 'garbage truck, dustcart [垃圾车]',
+    570: 'gasmask, respirator, gas helmet [防毒面具,呼吸器]',
+    571: 'gas pump, gasoline pump, petrol pump, island dispenser [汽油泵]',
+    572: 'goblet [高脚杯]',
+    573: 'go-kart [卡丁车]',
+    574: 'golf ball [高尔夫球]',
+    575: 'golfcart, golf cart [高尔夫球车]',
+    576: 'gondola [狭长小船]',
+    577: 'gong, tam-tam [锣]',
+    578: 'gown [礼服]',
+    579: 'grand piano, grand [钢琴]',
+    580: 'greenhouse, nursery, glasshouse [温室,苗圃]',
+    581: 'grille, radiator grille [散热器格栅]',
+    582: 'grocery store, grocery, food market, market [杂货店,食品市场]',
+    583: 'guillotine [断头台]',
+    584: 'hair slide [小发夹]',
+    585: 'hair spray [头发喷雾]',
+    586: 'half track [半履带装甲车]',
+    587: 'hammer [锤子]',
+    588: 'hamper [大篮子]',
+    589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier [手摇鼓风机,吹风机]',
+    590: 'hand-held computer, hand-held microcomputer [手提电脑]',
+    591: 'handkerchief, hankie, hanky, hankey [手帕]',
+    592: 'hard disc, hard disk, fixed disk [硬盘]',
+    593: 'harmonica, mouth organ, harp, mouth harp [口琴,口风琴]',
+    594: 'harp [竖琴]',
+    595: 'harvester, reaper [收割机]',
+    596: 'hatchet [斧头]',
+    597: 'holster [手枪皮套]',
+    598: 'home theater, home theatre [家庭影院]',
+    599: 'honeycomb [蜂窝]',
+    600: 'hook, claw [钩爪]',
+    601: 'hoopskirt, crinoline [衬裙]',
+    602: 'horizontal bar, high bar [单杠]',
+    603: 'horse cart, horse-cart [马车]',
+    604: 'hourglass [沙漏]',
+    605: 'iPod [手机,iPad]',
+    606: 'iron, smoothing iron [熨斗]',
+    607: 'jack-o-lantern [南瓜灯笼]',
+    608: 'jean, blue jean, denim [牛仔裤,蓝色牛仔裤]',
+    609: 'jeep, landrover [吉普车]',
+    610: 'jersey, T-shirt, tee shirt [运动衫,T恤]',
+    611: 'jigsaw puzzle [拼图]',
+    612: 'jinrikisha, ricksha, rickshaw [人力车]',
+    613: 'joystick [操纵杆]',
+    614: 'kimono [和服]',
+    615: 'knee pad [护膝]',
+    616: 'knot [蝴蝶结]',
+    617: 'lab coat, laboratory coat [大褂,实验室外套]',
+    618: 'ladle [长柄勺]',
+    619: 'lampshade, lamp shade [灯罩]',
+    620: 'laptop, laptop computer [笔记本电脑]',
+    621: 'lawn mower, mower [割草机]',
+    622: 'lens cap, lens cover [镜头盖]',
+    623: 'letter opener, paper knife, paperknife [开信刀,裁纸刀]',
+    624: 'library [图书馆]',
+    625: 'lifeboat [救生艇]',
+    626: 'lighter, light, igniter, ignitor [点火器,打火机]',
+    627: 'limousine, limo [豪华轿车]',
+    628: 'liner, ocean liner [远洋班轮]',
+    629: 'lipstick, lip rouge [唇膏,口红]',
+    630: 'Loafer [平底便鞋]',
+    631: 'lotion [洗剂]',
+    632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system [扬声器]',
+    633: 'loupe, jewelers loupe [放大镜]',
+    634: 'lumbermill, sawmill [锯木厂]',
+    635: 'magnetic compass [磁罗盘]',
+    636: 'mailbag, postbag [邮袋]',
+    637: 'mailbox, letter box [信箱]',
+    638: 'maillot [女游泳衣]',
+    639: 'maillot, tank suit [有肩带浴衣]',
+    640: 'manhole cover [窨井盖]',
+    641: 'maraca [沙球(一种打击乐器)]',
+    642: 'marimba, xylophone [马林巴木琴]',
+    643: 'mask [面膜]',
+    644: 'matchstick [火柴]',
+    645: 'maypole [花柱]',
+    646: 'maze, labyrinth [迷宫]',
+    647: 'measuring cup [量杯]',
+    648: 'medicine chest, medicine cabinet [药箱]',
+    649: 'megalith, megalithic structure [巨石,巨石结构]',
+    650: 'microphone, mike [麦克风]',
+    651: 'microwave, microwave oven [微波炉]',
+    652: 'military uniform [军装]',
+    653: 'milk can [奶桶]',
+    654: 'minibus [迷你巴士]',
+    655: 'miniskirt, mini [迷你裙]',
+    656: 'minivan [面包车]',
+    657: 'missile [导弹]',
+    658: 'mitten [连指手套]',
+    659: 'mixing bowl [搅拌钵]',
+    660: 'mobile home, manufactured home [活动房屋(由汽车拖拉的)]',
+    661: 'Model T [T型发动机小汽车]',
+    662: 'modem [调制解调器]',
+    663: 'monastery [修道院]',
+    664: 'monitor [显示器]',
+    665: 'moped [电瓶车]',
+    666: 'mortar [砂浆]',
+    667: 'mortarboard [学士]',
+    668: 'mosque [清真寺]',
+    669: 'mosquito net [蚊帐]',
+    670: 'motor scooter, scooter [摩托车]',
+    671: 'mountain bike, all-terrain bike, off-roader [山地自行车]',
+    672: 'mountain tent [登山帐]',
+    673: 'mouse, computer mouse [鼠标,电脑鼠标]',
+    674: 'mousetrap [捕鼠器]',
+    675: 'moving van [搬家车]',
+    676: 'muzzle [口套]',
+    677: 'nail [钉子]',
+    678: 'neck brace [颈托]',
+    679: 'necklace [项链]',
+    680: 'nipple [乳头(瓶)]',
+    681: 'notebook, notebook computer [笔记本,笔记本电脑]',
+    682: 'obelisk [方尖碑]',
+    683: 'oboe, hautboy, hautbois [双簧管]',
+    684: 'ocarina, sweet potato [陶笛,卵形笛]',
+    685: 'odometer, hodometer, mileometer, milometer [里程表]',
+    686: 'oil filter [滤油器]',
+    687: 'organ, pipe organ [风琴,管风琴]',
+    688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO [示波器]',
+    689: 'overskirt [罩裙]',
+    690: 'oxcart [牛车]',
+    691: 'oxygen mask [氧气面罩]',
+    692: 'packet [包装]',
+    693: 'paddle, boat paddle [船桨]',
+    694: 'paddlewheel, paddle wheel [明轮,桨轮]',
+    695: 'padlock [挂锁,扣锁]',
+    696: 'paintbrush [画笔]',
+    697: 'pajama, pyjama, pjs, jammies [睡衣]',
+    698: 'palace [宫殿]',
+    699: 'panpipe, pandean pipe, syrinx [排箫,鸣管]',
+    700: 'paper towel [纸巾]',
+    701: 'parachute, chute [降落伞]',
+    702: 'parallel bars, bars [双杠]',
+    703: 'park bench [公园长椅]',
+    704: 'parking meter [停车收费表,停车计时器]',
+    705: 'passenger car, coach, carriage [客车,教练车]',
+    706: 'patio, terrace [露台,阳台]',
+    707: 'pay-phone, pay-station [付费电话]',
+    708: 'pedestal, plinth, footstall [基座,基脚]',
+    709: 'pencil box, pencil case [铅笔盒]',
+    710: 'pencil sharpener [卷笔刀]',
+    711: 'perfume, essence [香水(瓶)]',
+    712: 'Petri dish [培养皿]',
+    713: 'photocopier [复印机]',
+    714: 'pick, plectrum, plectron [拨弦片,拨子]',
+    715: 'pickelhaube [尖顶头盔]',
+    716: 'picket fence, paling [栅栏,栅栏]',
+    717: 'pickup, pickup truck [皮卡,皮卡车]',
+    718: 'pier [桥墩]',
+    719: 'piggy bank, penny bank [存钱罐]',
+    720: 'pill bottle [药瓶]',
+    721: 'pillow [枕头]',
+    722: 'ping-pong ball [乒乓球]',
+    723: 'pinwheel [风车]',
+    724: 'pirate, pirate ship [海盗船]',
+    725: 'pitcher, ewer [水罐]',
+    726: 'plane, carpenters plane, woodworking plane [木工刨]',
+    727: 'planetarium [天文馆]',
+    728: 'plastic bag [塑料袋]',
+    729: 'plate rack [板架]',
+    730: 'plow, plough [犁型铲雪机]',
+    731: 'plunger, plumbers helper [手压皮碗泵]',
+    732: 'Polaroid camera, Polaroid Land camera [宝丽来相机]',
+    733: 'pole [电线杆]',
+    734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria [警车,巡逻车]',
+    735: 'poncho [雨披]',
+    736: 'pool table, billiard table, snooker table [台球桌]',
+    737: 'pop bottle, soda bottle [充气饮料瓶]',
+    738: 'pot, flowerpot [花盆]',
+    739: 'potters wheel [陶工旋盘]',
+    740: 'power drill [电钻]',
+    741: 'prayer rug, prayer mat [祈祷垫,地毯]',
+    742: 'printer [打印机]',
+    743: 'prison, prison house [监狱]',
+    744: 'projectile, missile [炮弹,导弹]',
+    745: 'projector [投影仪]',
+    746: 'puck, hockey puck [冰球]',
+    747: 'punching bag, punch bag, punching ball, punchball [沙包,吊球]',
+    748: 'purse [钱包]',
+    749: 'quill, quill pen [羽管笔]',
+    750: 'quilt, comforter, comfort, puff [被子]',
+    751: 'racer, race car, racing car [赛车]',
+    752: 'racket, racquet [球拍]',
+    753: 'radiator [散热器]',
+    754: 'radio, wireless [收音机]',
+    755: 'radio telescope, radio reflector [射电望远镜,无线电反射器]',
+    756: 'rain barrel [雨桶]',
+    757: 'recreational vehicle, RV, R.V. [休闲车,房车]',
+    758: 'reel [卷轴,卷筒]',
+    759: 'reflex camera [反射式照相机]',
+    760: 'refrigerator, icebox [冰箱,冰柜]',
+    761: 'remote control, remote [遥控器]',
+    762: 'restaurant, eating house, eating place, eatery [餐厅,饮食店,食堂]',
+    763: 'revolver, six-gun, six-shooter [左轮手枪]',
+    764: 'rifle [步枪]',
+    765: 'rocking chair, rocker [摇椅]',
+    766: 'rotisserie [电转烤肉架]',
+    767: 'rubber eraser, rubber, pencil eraser [橡皮]',
+    768: 'rugby ball [橄榄球]',
+    769: 'rule, ruler [直尺]',
+    770: 'running shoe [跑步鞋]',
+    771: 'safe [保险柜]',
+    772: 'safety pin [安全别针]',
+    773: 'saltshaker, salt shaker [盐瓶(调味用)]',
+    774: 'sandal [凉鞋]',
+    775: 'sarong [纱笼,围裙]',
+    776: 'sax, saxophone [萨克斯管]',
+    777: 'scabbard [剑鞘]',
+    778: 'scale, weighing machine [秤,称重机]',
+    779: 'school bus [校车]',
+    780: 'schooner [帆船]',
+    781: 'scoreboard [记分牌]',
+    782: 'screen, CRT screen [屏幕]',
+    783: 'screw [螺丝]',
+    784: 'screwdriver [螺丝刀]',
+    785: 'seat belt, seatbelt [安全带]',
+    786: 'sewing machine [缝纫机]',
+    787: 'shield, buckler [盾牌,盾牌]',
+    788: 'shoe shop, shoe-shop, shoe store [皮鞋店,鞋店]',
+    789: 'shoji [障子]',
+    790: 'shopping basket [购物篮]',
+    791: 'shopping cart [购物车]',
+    792: 'shovel [铁锹]',
+    793: 'shower cap [浴帽]',
+    794: 'shower curtain [浴帘]',
+    795: 'ski [滑雪板]',
+    796: 'ski mask [滑雪面罩]',
+    797: 'sleeping bag [睡袋]',
+    798: 'slide rule, slipstick [滑尺]',
+    799: 'sliding door [滑动门]',
+    800: 'slot, one-armed bandit [角子老虎机]',
+    801: 'snorkel [潜水通气管]',
+    802: 'snowmobile [雪橇]',
+    803: 'snowplow, snowplough [扫雪机,扫雪机]',
+    804: 'soap dispenser [皂液器]',
+    805: 'soccer ball [足球]',
+    806: 'sock [袜子]',
+    807: 'solar dish, solar collector, solar furnace [碟式太阳能,太阳能集热器,太阳能炉]',
+    808: 'sombrero [宽边帽]',
+    809: 'soup bowl [汤碗]',
+    810: 'space bar [空格键]',
+    811: 'space heater [空间加热器]',
+    812: 'space shuttle [航天飞机]',
+    813: 'spatula [铲(搅拌或涂敷用的)]',
+    814: 'speedboat [快艇]',
+    815: 'spider web, spiders web [蜘蛛网]',
+    816: 'spindle [纺锤,纱锭]',
+    817: 'sports car, sport car [跑车]',
+    818: 'spotlight, spot [聚光灯]',
+    819: 'stage [舞台]',
+    820: 'steam locomotive [蒸汽机车]',
+    821: 'steel arch bridge [钢拱桥]',
+    822: 'steel drum [钢滚筒]',
+    823: 'stethoscope [听诊器]',
+    824: 'stole [女用披肩]',
+    825: 'stone wall [石头墙]',
+    826: 'stopwatch, stop watch [秒表]',
+    827: 'stove [火炉]',
+    828: 'strainer [过滤器]',
+    829: 'streetcar, tram, tramcar, trolley, trolley car [有轨电车,电车]',
+    830: 'stretcher [担架]',
+    831: 'studio couch, day bed [沙发床]',
+    832: 'stupa, tope [佛塔]',
+    833: 'submarine, pigboat, sub, U-boat [潜艇,潜水艇]',
+    834: 'suit, suit of clothes [套装,衣服]',
+    835: 'sundial [日晷]',
+    836: 'sunglass [太阳镜]',
+    837: 'sunglasses, dark glasses, shades [太阳镜,墨镜]',
+    838: 'sunscreen, sunblock, sun blocker [防晒霜,防晒剂]',
+    839: 'suspension bridge [悬索桥]',
+    840: 'swab, swob, mop [拖把]',
+    841: 'sweatshirt [运动衫]',
+    842: 'swimming trunks, bathing trunks [游泳裤]',
+    843: 'swing [秋千]',
+    844: 'switch, electric switch, electrical switch [开关,电器开关]',
+    845: 'syringe [注射器]',
+    846: 'table lamp [台灯]',
+    847: 'tank, army tank, armored combat vehicle, armoured combat vehicle [坦克,装甲战车,装甲战斗车辆]',
+    848: 'tape player [磁带播放器]',
+    849: 'teapot [茶壶]',
+    850: 'teddy, teddy bear [泰迪,泰迪熊]',
+    851: 'television, television system [电视]',
+    852: 'tennis ball [网球]',
+    853: 'thatch, thatched roof [茅草,茅草屋顶]',
+    854: 'theater curtain, theatre curtain [幕布,剧院的帷幕]',
+    855: 'thimble [顶针]',
+    856: 'thresher, thrasher, threshing machine [脱粒机]',
+    857: 'throne [宝座]',
+    858: 'tile roof [瓦屋顶]',
+    859: 'toaster [烤面包机]',
+    860: 'tobacco shop, tobacconist shop, tobacconist [烟草店,烟草]',
+    861: 'toilet seat [马桶]',
+    862: 'torch [火炬]',
+    863: 'totem pole [图腾柱]',
+    864: 'tow truck, tow car, wrecker [拖车,牵引车,清障车]',
+    865: 'toyshop [玩具店]',
+    866: 'tractor [拖拉机]',
+    867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi [拖车,铰接式卡车]',
+    868: 'tray [托盘]',
+    869: 'trench coat [风衣]',
+    870: 'tricycle, trike, velocipede [三轮车]',
+    871: 'trimaran [三体船]',
+    872: 'tripod [三脚架]',
+    873: 'triumphal arch [凯旋门]',
+    874: 'trolleybus, trolley coach, trackless trolley [无轨电车]',
+    875: 'trombone [长号]',
+    876: 'tub, vat [浴盆,浴缸]',
+    877: 'turnstile [旋转式栅门]',
+    878: 'typewriter keyboard [打字机键盘]',
+    879: 'umbrella [伞]',
+    880: 'unicycle, monocycle [独轮车]',
+    881: 'upright, upright piano [直立式钢琴]',
+    882: 'vacuum, vacuum cleaner [真空吸尘器]',
+    883: 'vase [花瓶]',
+    884: 'vault [拱顶]',
+    885: 'velvet [天鹅绒]',
+    886: 'vending machine [自动售货机]',
+    887: 'vestment [祭服]',
+    888: 'viaduct [高架桥]',
+    889: 'violin, fiddle [小提琴,小提琴]',
+    890: 'volleyball [排球]',
+    891: 'waffle iron [松饼机]',
+    892: 'wall clock [挂钟]',
+    893: 'wallet, billfold, notecase, pocketbook [钱包,皮夹]',
+    894: 'wardrobe, closet, press [衣柜,壁橱]',
+    895: 'warplane, military plane [军用飞机]',
+    896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin [洗脸盆,洗手盆]',
+    897: 'washer, automatic washer, washing machine [洗衣机,自动洗衣机]',
+    898: 'water bottle [水瓶]',
+    899: 'water jug [水壶]',
+    900: 'water tower [水塔]',
+    901: 'whiskey jug [威士忌壶]',
+    902: 'whistle [哨子]',
+    903: 'wig [假发]',
+    904: 'window screen [纱窗]',
+    905: 'window shade [百叶窗]',
+    906: 'Windsor tie [温莎领带]',
+    907: 'wine bottle [葡萄酒瓶]',
+    908: 'wing [飞机翅膀,飞机]',
+    909: 'wok [炒菜锅]',
+    910: 'wooden spoon [木制的勺子]',
+    911: 'wool, woolen, woollen [毛织品,羊绒]',
+    912: 'worm fence, snake fence, snake-rail fence, Virginia fence [栅栏,围栏]',
+    913: 'wreck [沉船]',
+    914: 'yawl [双桅船]',
+    915: 'yurt [蒙古包]',
+    916: 'web site, website, internet site, site [网站,互联网网站]',
+    917: 'comic book [漫画]',
+    918: 'crossword puzzle, crossword [纵横字谜]',
+    919: 'street sign [路标]',
+    920: 'traffic light, traffic signal, stoplight [交通信号灯]',
+    921: 'book jacket, dust cover, dust jacket, dust wrapper [防尘罩,书皮]',
+    922: 'menu [菜单]',
+    923: 'plate [盘子]',
+    924: 'guacamole [鳄梨酱]',
+    925: 'consomme [清汤]',
+    926: 'hot pot, hotpot [罐焖土豆烧肉]',
+    927: 'trifle [蛋糕]',
+    928: 'ice cream, icecream [冰淇淋]',
+    929: 'ice lolly, lolly, lollipop, popsicle [雪糕,冰棍,冰棒]',
+    930: 'French loaf [法式面包]',
+    931: 'bagel, beigel [百吉饼]',
+    932: 'pretzel [椒盐脆饼]',
+    933: 'cheeseburger [芝士汉堡]',
+    934: 'hotdog, hot dog, red hot [热狗]',
+    935: 'mashed potato [土豆泥]',
+    936: 'head cabbage [结球甘蓝]',
+    937: 'broccoli [西兰花]',
+    938: 'cauliflower [菜花]',
+    939: 'zucchini, courgette [绿皮密生西葫芦]',
942 |
+
940: 'spaghetti squash [西葫芦]',
|
943 |
+
941: 'acorn squash [小青南瓜]',
|
944 |
+
942: 'butternut squash [南瓜]',
|
945 |
+
943: 'cucumber, cuke [黄瓜]',
|
946 |
+
944: 'artichoke, globe artichoke [朝鲜蓟]',
|
947 |
+
945: 'bell pepper [甜椒]',
|
948 |
+
946: 'cardoon [刺棘蓟]',
|
949 |
+
947: 'mushroom [蘑菇]',
|
950 |
+
948: 'Granny Smith [绿苹果]',
|
951 |
+
949: 'strawberry [草莓]',
|
952 |
+
950: 'orange [橘子]',
|
953 |
+
951: 'lemon [柠檬]',
|
954 |
+
952: 'fig [无花果]',
|
955 |
+
953: 'pineapple, ananas [菠萝]',
|
956 |
+
954: 'banana [香蕉]',
|
957 |
+
955: 'jackfruit, jak, jack [菠萝蜜]',
|
958 |
+
956: 'custard apple [蛋奶冻苹果]',
|
959 |
+
957: 'pomegranate [石榴]',
|
960 |
+
958: 'hay [干草]',
|
961 |
+
959: 'carbonara [烤面条加干酪沙司]',
|
962 |
+
960: 'chocolate sauce, chocolate syrup [巧克力酱,巧克力糖浆]',
|
963 |
+
961: 'dough [面团]',
|
964 |
+
962: 'meat loaf, meatloaf [瑞士肉包,肉饼]',
|
965 |
+
963: 'pizza, pizza pie [披萨,披萨饼]',
|
966 |
+
964: 'potpie [馅饼]',
|
967 |
+
965: 'burrito [卷饼]',
|
968 |
+
966: 'red wine [红葡萄酒]',
|
969 |
+
967: 'espresso [意大利浓咖啡]',
|
970 |
+
968: 'cup [杯子]',
|
971 |
+
969: 'eggnog [蛋酒]',
|
972 |
+
970: 'alp [高山]',
|
973 |
+
971: 'bubble [泡泡]',
|
974 |
+
972: 'cliff, drop, drop-off [悬崖]',
|
975 |
+
973: 'coral reef [珊瑚礁]',
|
976 |
+
974: 'geyser [间歇泉]',
|
977 |
+
975: 'lakeside, lakeshore [湖边,湖岸]',
|
978 |
+
976: 'promontory, headland, head, foreland [海角]',
|
979 |
+
977: 'sandbar, sand bar [沙洲,沙坝]',
|
980 |
+
978: 'seashore, coast, seacoast, sea-coast [海滨,海岸]',
|
981 |
+
979: 'valley, vale [峡谷]',
|
982 |
+
980: 'volcano [火山]',
|
983 |
+
981: 'ballplayer, baseball player [棒球,棒球运动员]',
|
984 |
+
982: 'groom, bridegroom [新郎]',
|
985 |
+
983: 'scuba diver [潜水员]',
|
986 |
+
984: 'rapeseed [油菜]',
|
987 |
+
985: 'daisy [雏菊]',
|
988 |
+
986: 'yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum [杓兰]',
|
989 |
+
987: 'corn [玉米]',
|
990 |
+
988: 'acorn [橡子]',
|
991 |
+
989: 'hip, rose hip, rosehip [玫瑰果]',
|
992 |
+
990: 'buckeye, horse chestnut, conker [七叶树果实]',
|
993 |
+
991: 'coral fungus [珊瑚菌]',
|
994 |
+
992: 'agaric [木耳]',
|
995 |
+
993: 'gyromitra [鹿花菌]',
|
996 |
+
994: 'stinkhorn, carrion fungus [鬼笔菌]',
|
997 |
+
995: 'earthstar [地星(菌类)]',
|
998 |
+
996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa [多叶奇果菌]',
|
999 |
+
997: 'bolete [牛肝菌]',
|
1000 |
+
998: 'ear, spike, capitulum [玉米穗]',
|
1001 |
+
999: 'toilet tissue, toilet paper, bathroom tissue [卫生纸]',
|
1002 |
+
}
|
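Each value packs the English synset names and a Chinese gloss into one string. A minimal sketch of splitting them apart; the dictionary name imagenet_idx2classname is an assumption, since the actual assignment statement sits earlier in the file:

# Minimal sketch; `imagenet_idx2classname` is an assumed name for the dict above.
entry = imagenet_idx2classname[999]        # "toilet tissue, toilet paper, bathroom tissue [卫生纸]"
english, chinese = entry.rsplit(' [', 1)
chinese = chinese.rstrip(']')
print(english)  # 'toilet tissue, toilet paper, bathroom tissue'
print(chinese)  # '卫生纸'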
models/generate.py
ADDED
@@ -0,0 +1,176 @@
# Modified from:
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch._dynamo.config
import torch._inductor.config
import copy
# torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future


### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True):
    logits = logits[:, -1, :] / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(logits, dim=-1)
    if sample_logits:
        idx = torch.multinomial(probs, num_samples=1)
    else:
        _, idx = torch.topk(probs, k=1, dim=-1)
    return idx, probs


def logits_to_probs(logits, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, **kwargs):
    # top_k previously defaulted to None, which would crash the `top_k > 0` comparison below; 0 disables top-k
    logits = logits / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, **sampling_kwargs):
    if cfg_scale > 1.0:
        logits, _ = model(None, cond_idx, input_pos)
        logits_combined = logits
        cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
        logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
    else:
        logits, _ = model(None, cond_idx, input_pos)

    return sample(logits, **sampling_kwargs)[0]


def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs):
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        x_combined = torch.cat([x, x])
        logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos)
        logits_combined = logits
        cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
        if cfg_flag:
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        else:
            logits = cond_logits
    else:
        logits, _ = model(x, cond_idx=None, input_pos=input_pos)
    return sample(logits, **sampling_kwargs)


def decode_n_tokens(
    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
    cfg_scale: float, cfg_interval: int,
    **sampling_kwargs):
    new_tokens, new_probs = [], []
    cfg_flag = True
    for i in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
            if cfg_interval > -1 and i > cfg_interval:
                cfg_flag = False
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, cfg_flag, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(-1, 1)

    return new_tokens, new_probs


@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs):
    if model.model_type == 'c2i':
        if cfg_scale > 1.0:
            cond_null = torch.ones_like(cond) * model.num_classes
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        T = 1
    elif model.model_type == 't2i':
        if cfg_scale > 1.0:
            cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        T = cond.shape[1]
    else:
        raise Exception("please check model type")

    T_new = T + max_new_tokens
    max_seq_length = T_new
    max_batch_size = cond.shape[0]

    device = cond.device
    with torch.device(device):
        max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
        model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)

    if emb_masks is not None:
        assert emb_masks.shape[0] == max_batch_size
        assert emb_masks.shape[-1] == T
        if cfg_scale > 1.0:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
        else:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)

        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix

    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)

    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs)
    seq[:, T:T+1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, **sampling_kwargs)
    seq[:, T+1:] = torch.cat(generated_tokens, dim=1)

    return seq[:, T:]
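For orientation, a minimal smoke-test sketch of this decoding pipeline with the class-conditional model from models/gpt.py below; with untrained weights the sampled ids are meaningless, so this only exercises prefill, the KV cache, and the CFG batching:

import torch
from models.gpt import GPT_models
from models.generate import generate

model = GPT_models['GPT-B'](block_size=256).eval()  # class-conditional, 16x16 token grid
class_labels = torch.tensor([207, 360])             # two ImageNet class ids
index_sample = generate(
    model, class_labels, max_new_tokens=256,
    cfg_scale=2.0, temperature=1.0, top_k=2000, top_p=1.0, sample_logits=True,
)
print(index_sample.shape)  # torch.Size([2, 256]): codebook indices for the VQ decoder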
models/gpt.py
ADDED
@@ -0,0 +1,465 @@
# Modified from:
# VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
# llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
from dataclasses import dataclass
from typing import Optional, List

import torch
import torch.nn as nn
from torch.nn import functional as F


def find_multiple(n: int, k: int):
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    rope_base: float = 10000
    norm_eps: float = 1e-5
    initializer_range: float = 0.02

    token_dropout_p: float = 0.1
    attn_dropout_p: float = 0.0
    resid_dropout_p: float = 0.1
    ffn_dropout_p: float = 0.1
    drop_path_rate: float = 0.0

    num_classes: int = 1000
    caption_dim: int = 2048
    class_dropout_prob: float = 0.1
    model_type: str = 'c2i'

    vocab_size: int = 16384
    cls_token_num: int = 1
    block_size: int = 256
    max_batch_size: int = 32
    max_seq_len: int = 2048


#################################################################################
#                      Embedding Layers for Class Labels                       #
#################################################################################
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels).unsqueeze(1)
        return embeddings


#################################################################################
#                      Embedding Layers for Text Feature                        #
#################################################################################
class CaptionEmbedder(nn.Module):
    """
    Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
        super().__init__()
        self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
        self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        embeddings = self.cap_proj(caption)
        return embeddings


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


#################################################################################
#                                  GPT Model                                    #
#################################################################################
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        hidden_dim = 4 * config.dim
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
        total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        # regularization
        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ):
        bsz, seqlen, _ = x.shape
        kv_size = self.n_kv_head * self.head_dim
        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)

        xq = apply_rotary_emb(xq, freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis)

        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))

        if self.kv_cache is not None:
            keys, values = self.kv_cache.update(input_pos, xk, xv)
        else:
            keys, values = xk, xv
        keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
        values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)

        output = F.scaled_dot_product_attention(
            xq, keys, values,
            attn_mask=mask,
            is_causal=True if mask is None else False, # is_causal=False is for KV cache
            dropout_p=self.attn_dropout_p if self.training else 0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        output = self.resid_dropout(self.wo(output))
        return output


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        h = x + self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.block_size = config.block_size
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        if self.model_type == 'c2i':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 't2i':
            self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
        else:
            raise Exception("please check model type")
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # 2d rotary pos embedding
        grid_size = int(self.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)

        # KVCache
        self.max_batch_size = -1
        self.max_seq_length = -1

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize nn.Linear and nn.Embedding
        self.apply(self._init_weights)

        # Zero-out output layers:
        nn.init.constant_(self.output.weight, 0)

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        #     return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
        grid_size = int(self.config.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)

    def forward(
        self,
        idx: torch.Tensor,
        cond_idx: torch.Tensor,  # cond_idx_or_embed
        input_pos: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
    ):
        if idx is not None and cond_idx is not None: # training or naive inference
            cond_embeddings = self.cls_embedding(cond_idx, train=self.training)[:,:self.cls_token_num]
            token_embeddings = self.tok_embeddings(idx)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis.to(h.device)
        else:
            if cond_idx is not None: # prefill in inference
                token_embeddings = self.cls_embedding(cond_idx, train=self.training)[:,:self.cls_token_num]
            else: # decode_n_tokens(kv cache) in inference
                token_embeddings = self.tok_embeddings(idx)

            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis

        if self.training:
            freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
        else:
            freqs_cis = self.freqs_cis[input_pos]
        # transformer blocks
        for layer in self.layers:
            h = layer(h, freqs_cis, input_pos, mask)

        # output layers
        h = self.norm(h)
        logits = self.output(h).float()

        if self.training:
            logits = logits[:, self.cls_token_num - 1:].contiguous()

        # if we are given some desired targets also calculate the loss
        loss = None
        if valid is not None:
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        return list(self.layers)


#################################################################################
#                      Rotary Positional Embedding Functions                   #
#################################################################################
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (cls_token_num+seq_len, head_dim // 2, 2)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
    return cond_cache


def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    # split the dimension into half, one for x and one for y
    half_dim = n_elem // 2
    freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
    t = torch.arange(grid_size, device=freqs.device)
    freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2)
    freqs_grid = torch.concat([
        freqs[:, None, :].expand(-1, grid_size, -1),
        freqs[None, :, :].expand(grid_size, -1, -1),
    ], dim=-1) # (grid_size, grid_size, head_dim // 2)
    cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
    cache = cache_grid.flatten(0, 1)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
    return cond_cache


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    # x: (bs, seq_len, n_head, head_dim)
    # freqs_cis (seq_len, head_dim // 2, 2)
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


#################################################################################
#                                 GPT Configs                                   #
#################################################################################
### text-conditional
def GPT_7B(**kwargs):
    return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B

def GPT_3B(**kwargs):
    return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B

def GPT_1B(**kwargs):
    return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B

### class-conditional
def GPT_XXXL(**kwargs):
    return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B

def GPT_XXL(**kwargs):
    return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B

def GPT_XL(**kwargs):
    return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M

def GPT_L(**kwargs):
    return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M

def GPT_B(**kwargs):
    return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M


GPT_models = {
    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
}
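The 2D rotary cache shape and the parameter counts noted in the config comments follow directly from the definitions above; a minimal sketch to check both, assuming the file is importable as models.gpt:

import torch
from models.gpt import GPT_models, precompute_freqs_cis_2d

model = GPT_models['GPT-B']()
print(sum(p.numel() for p in model.parameters()) / 1e6)  # ~111 (M parameters, matching the GPT-B comment)

# GPT-B: dim=768, n_head=12 -> head_dim=64; block_size=256 -> 16x16 grid
freqs = precompute_freqs_cis_2d(grid_size=16, n_elem=64, cls_token_num=1)
print(freqs.shape)  # torch.Size([257, 32, 2]): 1 zeroed cls-token row + 256 grid positions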
requirements.txt
ADDED
@@ -0,0 +1 @@
torch
tokenizer_image/discriminator.py
ADDED
@@ -0,0 +1,255 @@
# Modified from:
# taming-transformers: https://github.com/CompVis/taming-transformers
# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
import functools
import math
import torch
import torch.nn as nn
try:
    from kornia.filters import filter2d
except ImportError:  # kornia is optional; only the StyleGAN Blur layer below needs it
    pass

#################################################################################
#                                  PatchGAN                                     #
#################################################################################
class PatchGANDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(PatchGANDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers): # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height*width*torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


#################################################################################
#                                  StyleGAN                                     #
#################################################################################
class StyleGANDiscriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
        super().__init__()
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(math.log(image_size, 2))
        in_channel = channels[image_size]

        blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            blocks.append(DiscriminatorBlock(in_channel, out_channel))
            in_channel = out_channel
        self.blocks = nn.ModuleList(blocks)

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channel, channels[4], 3, padding=1),
            leaky_relu(),
        )
        self.final_linear = nn.Sequential(
            nn.Linear(channels[4] * 4 * 4, channels[4]),
            leaky_relu(),
            nn.Linear(channels[4], 1)
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.final_conv(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)
        return x


class DiscriminatorBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)


def leaky_relu(p=0.2):
    return nn.LeakyReLU(p, inplace=True)


def exists(val):
    return val is not None
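As a quick sanity check on the two discriminator heads above, a minimal sketch; the output shapes follow from the stride-2 stacks, and the StyleGAN path needs kornia installed for Blur:

import torch
from tokenizer_image.discriminator import PatchGANDiscriminator, StyleGANDiscriminator

x = torch.randn(2, 3, 256, 256)
patch_d = PatchGANDiscriminator(input_nc=3, ndf=64, n_layers=3)
print(patch_d(x).shape)   # torch.Size([2, 1, 30, 30]): per-patch real/fake logits

style_d = StyleGANDiscriminator(image_size=256)  # requires kornia for the Blur filter
print(style_d(x).shape)   # torch.Size([2, 1]): one logit per image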
tokenizer_image/discriminator_patchgan.py
ADDED
@@ -0,0 +1,152 @@
# Modified from:
# taming-transformers: https://github.com/CompVis/taming-transformers
import functools
import torch
import torch.nn as nn


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers): # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height*width*torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
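ActNorm initializes its offset and scale from the first training batch so activations start per-channel normalized, and reverse() inverts the affine map exactly; a minimal sketch to check both properties in isolation:

import torch
from tokenizer_image.discriminator_patchgan import ActNorm

norm = ActNorm(num_features=8).train()
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
y = norm(x)  # first call sets loc/scale from the batch statistics
print(round(y.mean().item(), 3), round(y.std().item(), 3))   # ~0.0, ~1.0
print(torch.allclose(norm(y, reverse=True), x, atol=1e-4))   # True: reverse undoes forward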
tokenizer_image/discriminator_stylegan.py
ADDED
@@ -0,0 +1,101 @@
# Modified from:
# stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py
# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
import math
import torch
import torch.nn as nn
try:
    from kornia.filters import filter2d
except ImportError:  # kornia is only needed for the Blur block
    pass

class Discriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
        super().__init__()
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(math.log(image_size, 2))
        in_channel = channels[image_size]

        blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            blocks.append(DiscriminatorBlock(in_channel, out_channel))
            in_channel = out_channel
        self.blocks = nn.ModuleList(blocks)

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channel, channels[4], 3, padding=1),
            leaky_relu(),
        )
        self.final_linear = nn.Sequential(
            nn.Linear(channels[4] * 4 * 4, channels[4]),
            leaky_relu(),
            nn.Linear(channels[4], 1)
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.final_conv(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)
        return x


class DiscriminatorBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)


def leaky_relu(p=0.2):
    return nn.LeakyReLU(p, inplace=True)


def exists(val):
    return val is not None
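A quick sanity check of the architecture above: each DiscriminatorBlock halves the spatial resolution, so a 256-pixel input passes through six blocks down to the 4x4 head and yields one realness logit per image. A sketch (kornia must be installed for Blur's filter2d; the import path follows the repo's own convention and is an assumption):

import torch
from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator  # path assumed

disc = Discriminator(input_nc=3, image_size=256)
x = torch.randn(2, 3, 256, 256)   # a batch of two random RGB "images"
logits = disc(x)
print(logits.shape)               # torch.Size([2, 1]): one logit per image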
tokenizer_image/lpips.py
ADDED
@@ -0,0 +1,164 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import os, hashlib
import requests
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
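Usage note for lpips.py: inputs are expected in [-1, 1] (the ScalingLayer remaps them to the VGG statistics) and the metric returns one distance per pair, shaped (N, 1, 1, 1). A minimal sketch; the first construction downloads both the torchvision VGG-16 weights and the LPIPS linear heads (vgg.pth) into the module's cache folder, and the import path is assumed from the repo's own convention:

import torch
from tokenizer.tokenizer_image.lpips import LPIPS  # path assumed

lpips = LPIPS().eval()                    # downloads pretrained weights on first use
x = torch.rand(4, 3, 256, 256) * 2 - 1    # two batches of images in [-1, 1]
y = torch.rand(4, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(x, y)                    # shape (4, 1, 1, 1); larger = more different
print(dist.flatten())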
tokenizer_image/reconstruction_vq_ddp.py
ADDED
@@ -0,0 +1,197 @@
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import argparse
import itertools

from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss

from dataset.augmentation import center_crop_arr
from dataset.build import build_dataset
from tokenizer.tokenizer_image.vq_model import VQ_models


def create_npz_from_sample_folder(sample_dir, num=50000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path


def main(args):
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Reconstruction with DDP requires at least one GPU."
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    if "ema" in checkpoint:  # ema
        model_weight = checkpoint["ema"]
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    vq_model.load_state_dict(model_weight)
    del checkpoint

    # Create folder to save samples:
    folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}"
                   f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}")
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    if args.dataset == 'imagenet':
        dataset = build_dataset(args, transform=transform)
        num_fid_samples = 50000
    elif args.dataset == 'coco':
        dataset = build_dataset(args, transform=transform)
        num_fid_samples = 5000
    else:
        raise Exception("please check dataset")

    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=args.per_proc_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()

    psnr_val_rgb = []
    ssim_val_rgb = []
    loader = tqdm(loader) if rank == 0 else loader
    total = 0
    for x, _ in loader:
        if args.image_size_eval != args.image_size:
            rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
        else:
            rgb_gts = x
        rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0  # rgb_gt value is between [0, 1]
        x = x.to(device, non_blocking=True)
        with torch.no_grad():
            latent, _, [_, _, indices] = vq_model.encode(x)
            samples = vq_model.decode_code(indices, latent.shape)  # output value is between [-1, 1]
        if args.image_size_eval != args.image_size:
            samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            # metric
            rgb_restored = sample.astype(np.float32) / 255.  # rgb_restored value is between [0, 1]
            psnr = psnr_loss(rgb_restored, rgb_gt)
            ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
            psnr_val_rgb.append(psnr)
            ssim_val_rgb.append(ssim)

        total += global_batch_size

    # ------------------------------------
    #           Summary
    # ------------------------------------
    # Make sure all processes have finished saving their samples
    dist.barrier()
    world_size = dist.get_world_size()
    gather_psnr_val = [None for _ in range(world_size)]
    gather_ssim_val = [None for _ in range(world_size)]
    dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
    dist.all_gather_object(gather_ssim_val, ssim_val_rgb)

    if rank == 0:
        gather_psnr_val = list(itertools.chain(*gather_psnr_val))
        gather_ssim_val = list(itertools.chain(*gather_ssim_val))
        psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
        ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
        print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))

        result_file = f"{sample_folder_dir}_results.txt"
        print("writing results to {}".format(result_file))
        with open(result_file, 'w') as f:
            print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)

        create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--sample-dir", type=str, default="reconstructions")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=4)
    args = parser.parse_args()
    main(args)
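The per-rank bookkeeping above writes sample i of each batch to global index i * world_size + rank + total, so ranks interleave within a batch and total advances by the global batch size. A plain-Python sketch confirming that this layout covers a contiguous index range without collisions:

# Pure bookkeeping check, no torch: 4 ranks, per-process batch of 3, 2 iterations.
world_size, batch, iters = 4, 3, 2
indices = set()
for rank in range(world_size):
    total = 0
    for _ in range(iters):
        for i in range(batch):
            indices.add(i * world_size + rank + total)
        total += batch * world_size  # "total += global_batch_size" in the script
print(sorted(indices) == list(range(batch * world_size * iters)))  # True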
tokenizer_image/vq_demo.py
ADDED
@@ -0,0 +1,84 @@
import torch
import torch.nn.functional as F

import os
import argparse
import numpy as np
from PIL import Image

from tokenizer.tokenizer_image.vq_model import VQ_models
from dataset.augmentation import center_crop_arr


def main(args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    model.to(device)
    model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    if "ema" in checkpoint:  # ema
        model_weight = checkpoint["ema"]
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    model.load_state_dict(model_weight)
    del checkpoint

    # output dir
    os.makedirs(args.output_dir, exist_ok=True)
    out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix))
    out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix))
    out_path = out_path.replace('.png', '_{}.png'.format(args.suffix))
    out_filename = out_path.split('/')[-1]
    out_path = os.path.join(args.output_dir, out_filename)

    # load image
    pil_image = Image.open(args.image_path).convert("RGB")
    img = center_crop_arr(pil_image, args.image_size)
    # # preprocess
    # size_org = img.size
    # img = img.resize((input_size, input_size))
    img = np.array(img) / 255.
    x = 2.0 * img - 1.0  # x value is between [-1, 1]
    x = torch.tensor(x)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    x_input = x.float().to(device)  # use the device selected above rather than hard-coding "cuda"

    # inference
    with torch.no_grad():
        latent, _, [_, _, indices] = model.encode(x_input)
        output = model.decode_code(indices, latent.shape)  # output value is between [-1, 1]

    # postprocess
    output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0]
    sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()

    # save
    Image.fromarray(sample).save(out_path)
    print("Reconstructed image is saved to {}".format(out_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image-path", type=str, default="assets/example.jpg")
    parser.add_argument("--output-dir", type=str, default="output_vq_demo")
    parser.add_argument("--suffix", type=str, default="tokenizer_image")
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
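The demo's pixel scaling is an affine round trip: [0, 255] maps to [-1, 1] via 2*(p/255) - 1, and decoding inverts it with clamp(127.5*x + 128, 0, 255). A tiny standalone check of that mapping:

import numpy as np
import torch

pixels = np.array([0, 64, 128, 255], dtype=np.float32)
x = 2.0 * (pixels / 255.0) - 1.0                                   # preprocess: [0,255] -> [-1,1]
restored = torch.clamp(127.5 * torch.from_numpy(x) + 128.0, 0, 255).to(torch.uint8)
print(restored.tolist())                                            # [0, 64, 128, 255]: recovered up to rounding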
tokenizer_image/vq_loss.py
ADDED
@@ -0,0 +1,168 @@
# Modified from:
# taming-transformers: https://github.com/CompVis/taming-transformers
# muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from tokenizer.tokenizer_image.lpips import LPIPS
from tokenizer.tokenizer_image.discriminator_patchgan import NLayerDiscriminator as PatchGANDiscriminator
from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator as StyleGANDiscriminator


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.softplus(-logits_real))
    loss_fake = torch.mean(F.softplus(logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def non_saturating_d_loss(logits_real, logits_fake):
    # F.binary_cross_entropy_with_logits expects (input, target); targets are
    # ones for real logits and zeros for fake logits (the original had the
    # arguments swapped).
    loss_real = torch.mean(F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real)))
    loss_fake = torch.mean(F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake)))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def hinge_gen_loss(logit_fake):
    return -torch.mean(logit_fake)


def non_saturating_gen_loss(logit_fake):
    # the generator wants fake logits classified as real (target = ones)
    return torch.mean(F.binary_cross_entropy_with_logits(logit_fake, torch.ones_like(logit_fake)))


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


class VQLoss(nn.Module):
    def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
                 disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight=False,
                 gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
                 codebook_weight=1.0, perceptual_weight=1.0,
                 ):
        super().__init__()
        # discriminator loss
        assert disc_type in ["patchgan", "stylegan"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        if disc_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        assert gen_adv_loss in ["hinge", "non-saturating"]
        # gen_adv_loss
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # reconstruction loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # codebook loss
        self.codebook_weight = codebook_weight

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None,
                logger=None, log_every=100):
        # generator update
        if optimizer_idx == 0:
            # reconstruction loss
            rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

            # perceptual loss
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            p_loss = torch.mean(p_loss)

            # discriminator loss
            logits_fake = self.discriminator(reconstructions.contiguous())
            generator_adv_loss = self.gen_adv_loss(logits_fake)

            if self.disc_adaptive_weight:
                nll_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss
                disc_adaptive_weight = self.calculate_adaptive_weight(nll_loss, generator_adv_loss, last_layer=last_layer)
            else:
                disc_adaptive_weight = 1
            disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)

            loss = self.rec_weight * rec_loss + \
                self.perceptual_weight * p_loss + \
                disc_adaptive_weight * disc_weight * generator_adv_loss + \
                codebook_loss[0] + codebook_loss[1] + codebook_loss[2]

            if global_step % log_every == 0:
                rec_loss = self.rec_weight * rec_loss
                p_loss = self.perceptual_weight * p_loss
                generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss
                logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
                            f"vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
                            f"codebook_usage: {codebook_loss[3]:.4f}, generator_adv_loss: {generator_adv_loss:.4f}, "
                            f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}")
            return loss

        # discriminator update
        if optimizer_idx == 1:
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
            d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

            if global_step % log_every == 0:
                logits_real = logits_real.detach().mean()
                logits_fake = logits_fake.detach().mean()
                logger.info(f"(Discriminator) "
                            f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                            f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}")
            return d_adversarial_loss
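VQLoss is designed to be called twice per training step: optimizer_idx=0 returns the generator objective (reconstruction + perceptual + adversarial + the three codebook terms) and optimizer_idx=1 returns the discriminator loss on detached inputs. A minimal single-step sketch with a small VQ-8 model and a random batch; the configuration values are illustrative, the first run downloads the LPIPS weights, and the real loop (with AMP, DDP and EMA) lives in vq_train.py:

import torch
from tokenizer.tokenizer_image.vq_model import VQ_models
from tokenizer.tokenizer_image.vq_loss import VQLoss

vq_model = VQ_models['VQ-8'](codebook_size=1024, codebook_embed_dim=8).train()
vq_loss = VQLoss(disc_start=0, disc_type='patchgan', gen_adv_loss='hinge').train()
opt_g = torch.optim.Adam(vq_model.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(vq_loss.discriminator.parameters(), lr=1e-4)

imgs = torch.randn(2, 3, 64, 64)  # stand-in batch, values roughly in [-1, 1]

# generator pass: reconstruct, then backprop the combined objective
recons, codebook_loss = vq_model(imgs)
loss_gen = vq_loss(codebook_loss, imgs, recons, optimizer_idx=0, global_step=1)
opt_g.zero_grad(); loss_gen.backward(); opt_g.step()

# discriminator pass: VQLoss detaches both inputs internally
loss_disc = vq_loss(codebook_loss, imgs, recons.detach(), optimizer_idx=1, global_step=1)
opt_d.zero_grad(); loss_disc.backward(); opt_d.step()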
tokenizer_image/vq_model.py
ADDED
@@ -0,0 +1,424 @@
# Modified from:
# taming-transformers: https://github.com/CompVis/taming-transformers
# maskgit: https://github.com/google-research/maskgit
from dataclasses import dataclass, field
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    codebook_size: int = 16384
    codebook_embed_dim: int = 8
    codebook_l2_norm: bool = True
    codebook_show_usage: bool = True
    commit_loss_beta: float = 0.25
    entropy_loss_ratio: float = 0.0

    encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    z_channels: int = 256
    dropout_p: float = 0.0


class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
        self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                        config.commit_loss_beta, config.entropy_loss_ratio,
                                        config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b, shape=None, channel_first=True):
        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff


class Encoder(nn.Module):
    def __init__(self, in_channels=3, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                 norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions - 1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self, z_channels=256, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, norm_type="group",
                 dropout=0.0, resamp_with_conv=True, out_channels=3):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch * ch_mult[self.num_resolutions - 1]
        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    @property
    def last_layer(self):
        return self.conv_out.weight

    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = l2_norm
        self.show_usage = show_usage

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:
            # rolling buffer of the most recently used code indices (buffers
            # should be plain tensors, not Parameters)
            self.register_buffer("codebook_used", torch.zeros(65536))

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = torch.einsum('b c h w -> b h w c', z).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z

        if self.l2_norm:
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = embedding[min_encoding_indices].view(z.shape)
        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None
        codebook_usage = 0

        if self.show_usage and self.training:
            cur_len = min_encoding_indices.shape[0]
            self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
            self.codebook_used[-cur_len:] = min_encoding_indices
            codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e

        # compute loss for embedding
        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = torch.einsum('b h w c -> b c h w', z_q)

        return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q


class ResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels, norm_type)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type='group'):
        super().__init__()
        self.norm = Normalize(in_channels, norm_type)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, h*w) # b,c,hw
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)    # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, norm_type='group'):
    assert norm_type in ['group', 'batch']
    if norm_type == 'group':
        return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    elif norm_type == 'batch':
        return nn.SyncBatchNorm(in_channels)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss


#################################################################################
#                              VQ Model Configs                                 #
#################################################################################
def VQ_8(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))

def VQ_16(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))

VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
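Shape summary for the configs above: VQ-16 has five resolution levels (four downsamples), so a 256x256 image becomes a 16x16 grid of indices into the 16384-entry codebook, while VQ-8 downsamples by 8 and yields 32x32. A round-trip sketch on a randomly initialized model (import path follows the repo's own convention and is an assumption):

import torch
from tokenizer.tokenizer_image.vq_model import VQ_models

vq = VQ_models['VQ-16'](codebook_size=16384, codebook_embed_dim=8).eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    quant, _, (_, _, indices) = vq.encode(x)
    recon = vq.decode_code(indices, quant.shape)
print(quant.shape)    # torch.Size([1, 8, 16, 16]): 256/16 tokens per side, dim-8 codes
print(indices.shape)  # torch.Size([256]): 16*16 flattened codebook indices
print(recon.shape)    # torch.Size([1, 3, 256, 256])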
tokenizer_image/vq_train.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py
|
3 |
+
# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
|
4 |
+
import torch
|
5 |
+
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
|
6 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
7 |
+
torch.backends.cudnn.allow_tf32 = True
|
8 |
+
import torch.distributed as dist
|
9 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
10 |
+
from torch.utils.data import Dataset, DataLoader
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from torchvision.datasets import ImageFolder
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
import os
|
16 |
+
import time
|
17 |
+
import argparse
|
18 |
+
from glob import glob
|
19 |
+
from copy import deepcopy
|
20 |
+
|
21 |
+
from utils.logger import create_logger
|
22 |
+
from utils.distributed import init_distributed_mode
|
23 |
+
from utils.ema import update_ema, requires_grad
|
24 |
+
from dataset.augmentation import random_crop_arr
|
25 |
+
from dataset.build import build_dataset
|
26 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
27 |
+
from tokenizer.tokenizer_image.vq_loss import VQLoss
|
28 |
+
|
29 |
+
import warnings
|
30 |
+
warnings.filterwarnings('ignore')
|
31 |
+
|
32 |
+
#################################################################################
|
33 |
+
# Training Loop #
|
34 |
+
#################################################################################
|
35 |
+
|
36 |
+
def main(args):
|
37 |
+
"""
|
38 |
+
Trains a new model.
|
39 |
+
"""
|
40 |
+
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
41 |
+
|
42 |
+
# Setup DDP:
|
43 |
+
init_distributed_mode(args)
|
44 |
+
assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
|
45 |
+
rank = dist.get_rank()
|
46 |
+
device = rank % torch.cuda.device_count()
|
47 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
torch.cuda.set_device(device)
|
50 |
+
|
51 |
+
# Setup an experiment folder:
|
52 |
+
if rank == 0:
|
53 |
+
os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
|
54 |
+
experiment_index = len(glob(f"{args.results_dir}/*"))
|
55 |
+
model_string_name = args.vq_model.replace("/", "-")
|
56 |
+
experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
|
57 |
+
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
|
58 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
59 |
+
logger = create_logger(experiment_dir)
|
60 |
+
logger.info(f"Experiment directory created at {experiment_dir}")
|
61 |
+
|
62 |
+
time_record = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
63 |
+
cloud_results_dir = f"{args.cloud_save_path}/{time_record}"
|
64 |
+
cloud_checkpoint_dir = f"{cloud_results_dir}/{experiment_index:03d}-{model_string_name}/checkpoints"
|
65 |
+
os.makedirs(cloud_checkpoint_dir, exist_ok=True)
|
66 |
+
logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}")
|
67 |
+
|
68 |
+
else:
|
69 |
+
logger = create_logger(None)
|
70 |
+
|
71 |
+
# training args
|
72 |
+
logger.info(f"{args}")
|
73 |
+
|
74 |
+
# training env
|
75 |
+
logger.info(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
76 |
+
|
77 |
+
# create and load model
|
78 |
+
vq_model = VQ_models[args.vq_model](
|
79 |
+
codebook_size=args.codebook_size,
|
80 |
+
codebook_embed_dim=args.codebook_embed_dim,
|
81 |
+
commit_loss_beta=args.commit_loss_beta,
|
82 |
+
entropy_loss_ratio=args.entropy_loss_ratio,
|
83 |
+
dropout_p=args.dropout_p,
|
84 |
+
)
|
85 |
+
logger.info(f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}")
|
86 |
+
if args.ema:
|
87 |
+
ema = deepcopy(vq_model).to(device) # Create an EMA of the model for use after training
|
88 |
+
requires_grad(ema, False)
|
89 |
+
logger.info(f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}")
|
90 |
+
vq_model = vq_model.to(device)
|
91 |
+
|
92 |
+
    vq_loss = VQLoss(
        disc_start=args.disc_start,
        disc_weight=args.disc_weight,
        disc_type=args.disc_type,
        disc_loss=args.disc_loss,
        gen_adv_loss=args.gen_loss,
        image_size=args.image_size,
        perceptual_weight=args.perceptual_weight,
        reconstruction_weight=args.reconstruction_weight,
        reconstruction_loss=args.reconstruction_loss,
        codebook_weight=args.codebook_weight,
    ).to(device)
    logger.info(f"Discriminator Parameters: {sum(p.numel() for p in vq_loss.discriminator.parameters()):,}")

    # initialize a GradScaler. If enabled=False scaler is a no-op
    scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == 'fp16'))
    scaler_disc = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == 'fp16'))
    # Setup optimizer
    optimizer = torch.optim.Adam(vq_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optimizer_disc = torch.optim.Adam(vq_loss.discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

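    # NOTE: the generator (vq_model) and the discriminator (owned by vq_loss) get
    # independent optimizer/GradScaler pairs, so the two adversarial updates in
    # each step never share optimizer state or loss-scaling statistics.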
    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: random_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
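    # The Normalize above maps pixels from [0, 1] to [-1, 1], the value range the
    # reconstruction loss is computed in.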
    dataset = build_dataset(args, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // dist.get_world_size()),
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

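    # NOTE: each rank draws global_batch_size // world_size samples per step;
    # DistributedSampler shards the dataset so ranks see disjoint indices.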
    # Prepare models for training:
    if args.vq_ckpt:
        checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
        vq_model.load_state_dict(checkpoint["model"])
        if args.ema:
            ema.load_state_dict(checkpoint["ema"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        vq_loss.discriminator.load_state_dict(checkpoint["discriminator"])
        optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
        if not args.finetune:
            train_steps = checkpoint["steps"] if "steps" in checkpoint else int(args.vq_ckpt.split('/')[-1].split('.')[0])
            start_epoch = int(train_steps / int(len(dataset) / args.global_batch_size))
            train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size))
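            # (train_steps is rounded down to the last full-epoch boundary, so a
            # partially completed epoch is replayed from its start on resume)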
        else:
            train_steps = 0
            start_epoch = 0
        del checkpoint
        logger.info(f"Resume training from checkpoint: {args.vq_ckpt}")
        logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}")
    else:
        train_steps = 0
        start_epoch = 0
        if args.ema:
            update_ema(ema, vq_model, decay=0)  # Ensure EMA is initialized with synced weights

    if args.compile:
        logger.info("compiling the model... (may take several minutes)")
        vq_model = torch.compile(vq_model)  # requires PyTorch 2.0

    vq_model = DDP(vq_model.to(device), device_ids=[args.gpu])
    vq_model.train()
    if args.ema:
        ema.eval()  # EMA model should always be in eval mode
    vq_loss = DDP(vq_loss.to(device), device_ids=[args.gpu])
    vq_loss.train()
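    # NOTE: vq_loss is DDP-wrapped as well because it owns the trainable
    # discriminator, whose gradients must be all-reduced like the generator's.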

    ptdtype = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.mixed_precision]
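    # ptdtype is the autocast compute dtype used in the loop below; the GradScalers
    # above are only enabled for fp16, since bf16 needs no loss scaling.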

    # Variables for monitoring/logging purposes:
    log_steps = 0
    running_loss = 0
    start_time = time.time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            imgs = x.to(device, non_blocking=True)

            # generator training
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(dtype=ptdtype):
                recons_imgs, codebook_loss = vq_model(imgs)
                loss_gen = vq_loss(codebook_loss, imgs, recons_imgs, optimizer_idx=0, global_step=train_steps+1,
                                   last_layer=vq_model.module.decoder.last_layer,
                                   logger=logger, log_every=args.log_every)
            scaler.scale(loss_gen).backward()
            if args.max_grad_norm != 0.0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(vq_model.parameters(), args.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            if args.ema:
                update_ema(ema, vq_model.module._orig_mod if args.compile else vq_model.module)
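            # NOTE: torch.compile wraps the module, so the raw weights live at
            # vq_model.module._orig_mod; the EMA copy must track those weights.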

            # discriminator training
            optimizer_disc.zero_grad()
            with torch.cuda.amp.autocast(dtype=ptdtype):
                loss_disc = vq_loss(codebook_loss, imgs, recons_imgs, optimizer_idx=1, global_step=train_steps+1,
                                    logger=logger, log_every=args.log_every)
            scaler_disc.scale(loss_disc).backward()
            if args.max_grad_norm != 0.0:
                scaler_disc.unscale_(optimizer_disc)
                torch.nn.utils.clip_grad_norm_(vq_loss.module.discriminator.parameters(), args.max_grad_norm)
            scaler_disc.step(optimizer_disc)
            scaler_disc.update()
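            # (optimizer_idx selects the branch inside VQLoss: 0 returns the
            # generator-side loss, 1 the discriminator loss on the same batch)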

            # Log loss values:
            running_loss += loss_gen.item() + loss_disc.item()

            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time.time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
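                # summing across ranks and dividing by world_size gives the mean loss over all GPUs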
logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
|
233 |
+
# Reset monitoring variables:
|
234 |
+
running_loss = 0
|
235 |
+
log_steps = 0
|
236 |
+
start_time = time.time()
|
237 |
+
|
238 |
+
            # Save checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    if args.compile:
                        model_weight = vq_model.module._orig_mod.state_dict()
                    else:
                        model_weight = vq_model.module.state_dict()
                    checkpoint = {
                        "model": model_weight,
                        "optimizer": optimizer.state_dict(),
                        "discriminator": vq_loss.module.discriminator.state_dict(),
                        "optimizer_disc": optimizer_disc.state_dict(),
                        "steps": train_steps,
                        "args": args
                    }
                    if args.ema:
                        checkpoint["ema"] = ema.state_dict()
                    if not args.no_local_save:
                        checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                        torch.save(checkpoint, checkpoint_path)
                        logger.info(f"Saved checkpoint to {checkpoint_path}")

                    cloud_checkpoint_path = f"{cloud_checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, cloud_checkpoint_path)
                    logger.info(f"Saved checkpoint in cloud to {cloud_checkpoint_path}")
                dist.barrier()
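                # (the barrier keeps non-zero ranks from racing ahead while rank 0 writes)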

    vq_model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--data-face-path", type=str, default=None, help="face datasets to improve vq model")
    parser.add_argument("--cloud-save-path", type=str, required=True, help="cloud disk path for saving checkpoints; a local path also works if no cloud disk is mounted")
    parser.add_argument("--no-local-save", action='store_true', help="skip saving checkpoints locally (useful when local disk space is limited)")
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for resume training")
    parser.add_argument("--finetune", action='store_true', help="finetune a pre-trained vq model")
    parser.add_argument("--ema", action='store_true', help="whether to use ema training")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--codebook-l2-norm", action='store_true', default=True, help="l2 norm codebook")
    parser.add_argument("--codebook-weight", type=float, default=1.0, help="codebook loss weight for vector quantization")
    parser.add_argument("--entropy-loss-ratio", type=float, default=0.0, help="entropy loss ratio in codebook loss")
    parser.add_argument("--commit-loss-beta", type=float, default=0.25, help="commit loss beta in codebook loss")
    parser.add_argument("--reconstruction-weight", type=float, default=1.0, help="reconstruction loss weight of image pixel")
    parser.add_argument("--reconstruction-loss", type=str, default='l2', help="reconstruction loss type of image pixel")
    parser.add_argument("--perceptual-weight", type=float, default=1.0, help="perceptual loss weight of LPIPS")
    parser.add_argument("--disc-weight", type=float, default=0.5, help="discriminator loss weight for gan training")
    parser.add_argument("--disc-start", type=int, default=20000, help="iteration to start discriminator training and loss")
    parser.add_argument("--disc-type", type=str, choices=['patchgan', 'stylegan'], default='patchgan', help="discriminator type")
    parser.add_argument("--disc-loss", type=str, choices=['hinge', 'vanilla', 'non-saturating'], default='hinge', help="discriminator loss")
    parser.add_argument("--gen-loss", type=str, choices=['hinge', 'non-saturating'], default='hinge', help="generator loss for gan training")
    parser.add_argument("--compile", action='store_true', default=False)
    parser.add_argument("--dropout-p", type=float, default=0.0, help="dropout probability")
    parser.add_argument("--results-dir", type=str, default="results_tokenizer_image")
    parser.add_argument("--dataset", type=str, default='imagenet')
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=5e-2, help="Weight decay to use.")
    parser.add_argument("--beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--beta2", type=float, default=0.95, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--global-batch-size", type=int, default=128)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=16)
    parser.add_argument("--log-every", type=int, default=100)
    parser.add_argument("--ckpt-every", type=int, default=5000)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    args = parser.parse_args()
    main(args)
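# A typical multi-GPU launch (illustrative only; the paths and flag values below
# are assumptions, not part of this commit):
#   torchrun --nproc_per_node=8 tokenizer_image/vq_train.py \
#       --data-path /path/to/imagenet/train --cloud-save-path /path/to/cloud_disk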