ShoufaChen commited on
Commit
4d20c2f
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__
3
+ *.pyc
4
+ *.egg-info
5
+ dist
6
+ .ipynb_checkpoints
7
+ *.ipynb
8
+
9
+ # Log
10
+ *.log
11
+ *.log.*
12
+ *.json
13
+ *.jsonl
14
+
15
+ # Data
16
+ datasets
17
+ *.zip
18
+ *.png
19
+ *.jpg
20
+ *.jpeg
21
+
22
+ # Model
23
+ checkpoints
24
+ ckpts*
25
+ *.ckpt
26
+ *.pth
27
+ *.pt
28
+ pretrained_models
29
+
30
+ # Other
31
+ .DS_Store
32
+ wandb
33
+ output
34
+ results
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: LlamaGen
3
+ emoji: 🏆
4
+ colorFrom: indigo
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.36.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import gradio as gr
3
+ from imagenet_en_cn import IMAGENET_1K_CLASSES
4
+ from huggingface_hub import hf_hub_download
5
+ import torch
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cudnn.allow_tf32 = True
8
+ torch.set_float32_matmul_precision('high')
9
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
10
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
11
+
12
+ import time
13
+ import argparse
14
+ from tokenizer_image.vq_model import VQ_models
15
+ from models.gpt import GPT_models
16
+ from models.generate import generate
17
+
18
+ device = "cuda"
19
+
20
+ model2ckpt = {
21
+ "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384),
22
+ "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256),
23
+ }
24
+
25
+ def load_model(args):
26
+ ckpt_folder = "./"
27
+ vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
28
+ hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder)
29
+ hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder)
30
+ # create and load model
31
+ vq_model = VQ_models[args.vq_model](
32
+ codebook_size=args.codebook_size,
33
+ codebook_embed_dim=args.codebook_embed_dim)
34
+ vq_model.to(device)
35
+ vq_model.eval()
36
+ checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu")
37
+ vq_model.load_state_dict(checkpoint["model"])
38
+ del checkpoint
39
+ print(f"image tokenizer is loaded")
40
+
41
+ # create and load gpt model
42
+ precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
43
+ latent_size = image_size // args.downsample_size
44
+ gpt_model = GPT_models[args.gpt_model](
45
+ vocab_size=args.codebook_size,
46
+ block_size=latent_size ** 2,
47
+ num_classes=args.num_classes,
48
+ cls_token_num=args.cls_token_num,
49
+ model_type=args.gpt_type,
50
+ ).to(device=device, dtype=precision)
51
+
52
+ checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
53
+ if args.from_fsdp: # fspd
54
+ model_weight = checkpoint
55
+ elif "model" in checkpoint: # ddp
56
+ model_weight = checkpoint["model"]
57
+ elif "module" in checkpoint: # deepspeed
58
+ model_weight = checkpoint["module"]
59
+ elif "state_dict" in checkpoint:
60
+ model_weight = checkpoint["state_dict"]
61
+ else:
62
+ raise Exception("please check model weight")
63
+ # if 'freqs_cis' in model_weight:
64
+ # model_weight.pop('freqs_cis')
65
+ gpt_model.load_state_dict(model_weight, strict=False)
66
+ gpt_model.eval()
67
+ del checkpoint
68
+ print(f"gpt model is loaded")
69
+
70
+ if args.compile:
71
+ print(f"compiling the model...")
72
+ gpt_model = torch.compile(
73
+ gpt_model,
74
+ mode="reduce-overhead",
75
+ fullgraph=True
76
+ ) # requires PyTorch 2.0 (optional)
77
+ else:
78
+ print(f"no need to compile model in demo")
79
+
80
+ return vq_model, gpt_model, image_size
81
+
82
+
83
+ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
84
+ n = 4
85
+ latent_size = image_size // args.downsample_size
86
+ # Labels to condition the model with (feel free to change):
87
+ class_labels = [class_label for _ in range(n)]
88
+ c_indices = torch.tensor(class_labels, device=device)
89
+ qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
90
+
91
+ t1 = time.time()
92
+ torch.manual_seed(seed)
93
+ index_sample = generate(
94
+ gpt_model, c_indices, latent_size ** 2,
95
+ cfg_scale=cfg_scale, cfg_interval=args.cfg_interval,
96
+ temperature=temperature, top_k=top_k,
97
+ top_p=top_p, sample_logits=True,
98
+ )
99
+ sampling_time = time.time() - t1
100
+ print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
101
+
102
+ t2 = time.time()
103
+ samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
104
+ decoder_time = time.time() - t2
105
+ print(f"decoder takes about {decoder_time:.2f} seconds.")
106
+ # Convert to PIL.Image format:
107
+ samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
108
+ samples = [Image.fromarray(sample) for sample in samples]
109
+ return samples
110
+
111
+
112
+ parser = argparse.ArgumentParser()
113
+ parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
114
+ parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
115
+ parser.add_argument("--from-fsdp", action='store_true')
116
+ parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
117
+ parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
118
+ parser.add_argument("--compile", action='store_true', default=False)
119
+ parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
120
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
121
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
122
+ parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
123
+ parser.add_argument("--num-classes", type=int, default=1000)
124
+ parser.add_argument("--cfg-scale", type=float, default=4.0)
125
+ parser.add_argument("--cfg-interval", type=float, default=-1)
126
+ parser.add_argument("--seed", type=int, default=0)
127
+ parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
128
+ parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
129
+ parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
130
+ args = parser.parse_args()
131
+
132
+ vq_model, gpt_model, image_size = load_model(args)
133
+
134
+ with gr.Blocks() as demo:
135
+ gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
136
+
137
+ with gr.Tabs():
138
+ with gr.TabItem('Generate'):
139
+ with gr.Row():
140
+ with gr.Column():
141
+ # with gr.Row():
142
+ # image_size = gr.Radio(choices=[384], value=384, label='Peize Model Resolution')
143
+ with gr.Row():
144
+ i1k_class = gr.Dropdown(
145
+ list(IMAGENET_1K_CLASSES.values()),
146
+ value='Eskimo dog, husky [爱斯基摩犬,哈士奇]',
147
+ type="index", label='ImageNet-1K Class'
148
+ )
149
+ cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
150
+ top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=4000, label='Top-K')
151
+ top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
152
+ temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
153
+ seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
154
+ # seed = gr.Number(value=0, label='Seed')
155
+ button = gr.Button("Generate", variant="primary")
156
+ with gr.Column():
157
+ output = gr.Gallery(label='Generated Images', height=700)
158
+ button.click(infer, inputs=[cfg_scale, top_k, top_p, temperature, i1k_class, seed], outputs=[output])
159
+ demo.queue()
160
+ demo.launch(debug=True)
imagenet_en_cn.py ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ IMAGENET_1K_CLASSES = {
2
+ 0: 'tench, Tinca tinca [丁鲷]',
3
+ 1: 'goldfish, Carassius auratus [金鱼]',
4
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias [大白鲨]',
5
+ 3: 'tiger shark, Galeocerdo cuvieri [虎鲨]',
6
+ 4: 'hammerhead, hammerhead shark [锤头鲨]',
7
+ 5: 'electric ray, crampfish, numbfish, torpedo [电鳐]',
8
+ 6: 'stingray [黄貂鱼]',
9
+ 7: 'cock [公鸡]',
10
+ 8: 'hen [母鸡]',
11
+ 9: 'ostrich, Struthio camelus [鸵鸟]',
12
+ 10: 'brambling, Fringilla montifringilla [燕雀]',
13
+ 11: 'goldfinch, Carduelis carduelis [金翅雀]',
14
+ 12: 'house finch, linnet, Carpodacus mexicanus [家朱雀]',
15
+ 13: 'junco, snowbird [灯芯草雀]',
16
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea [靛蓝雀,靛蓝鸟]',
17
+ 15: 'robin, American robin, Turdus migratorius [蓝鹀]',
18
+ 16: 'bulbul [夜莺]',
19
+ 17: 'jay [松鸦]',
20
+ 18: 'magpie [喜鹊]',
21
+ 19: 'chickadee [山雀]',
22
+ 20: 'water ouzel, dipper [河鸟]',
23
+ 21: 'kite [鸢(猛禽)]',
24
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus [秃头鹰]',
25
+ 23: 'vulture [秃鹫]',
26
+ 24: 'great grey owl, great gray owl, Strix nebulosa [大灰猫头鹰]',
27
+ 25: 'European fire salamander, Salamandra salamandra [欧洲火蝾螈]',
28
+ 26: 'common newt, Triturus vulgaris [普通蝾螈]',
29
+ 27: 'eft [水蜥]',
30
+ 28: 'spotted salamander, Ambystoma maculatum [斑点蝾螈]',
31
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum [蝾螈,泥狗]',
32
+ 30: 'bullfrog, Rana catesbeiana [牛蛙]',
33
+ 31: 'tree frog, tree-frog [树蛙]',
34
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui [尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍]',
35
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta [红海龟]',
36
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea [皮革龟]',
37
+ 35: 'mud turtle [泥龟]',
38
+ 36: 'terrapin [淡水龟]',
39
+ 37: 'box turtle, box tortoise [箱龟]',
40
+ 38: 'banded gecko [带状壁虎]',
41
+ 39: 'common iguana, iguana, Iguana iguana [普通鬣蜥]',
42
+ 40: 'American chameleon, anole, Anolis carolinensis [美国变色龙]',
43
+ 41: 'whiptail, whiptail lizard [鞭尾蜥蜴]',
44
+ 42: 'agama [飞龙科蜥蜴]',
45
+ 43: 'frilled lizard, Chlamydosaurus kingi [褶边蜥蜴]',
46
+ 44: 'alligator lizard [鳄鱼蜥蜴]',
47
+ 45: 'Gila monster, Heloderma suspectum [毒蜥]',
48
+ 46: 'green lizard, Lacerta viridis [绿蜥蜴]',
49
+ 47: 'African chameleon, Chamaeleo chamaeleon [非洲变色龙]',
50
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis [科莫多蜥蜴]',
51
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus [非洲鳄,尼罗河鳄鱼]',
52
+ 50: 'American alligator, Alligator mississipiensis [美国鳄鱼,鳄鱼]',
53
+ 51: 'triceratops [三角龙]',
54
+ 52: 'thunder snake, worm snake, Carphophis amoenus [雷蛇,蠕虫蛇]',
55
+ 53: 'ringneck snake, ring-necked snake, ring snake [环蛇,环颈蛇]',
56
+ 54: 'hognose snake, puff adder, sand viper [希腊蛇]',
57
+ 55: 'green snake, grass snake [绿蛇,草蛇]',
58
+ 56: 'king snake, kingsnake [国王蛇]',
59
+ 57: 'garter snake, grass snake [袜带蛇,草蛇]',
60
+ 58: 'water snake [水蛇]',
61
+ 59: 'vine snake [藤蛇]',
62
+ 60: 'night snake, Hypsiglena torquata [夜蛇]',
63
+ 61: 'boa constrictor, Constrictor constrictor [大蟒蛇]',
64
+ 62: 'rock python, rock snake, Python sebae [岩石蟒蛇,岩蛇,蟒蛇]',
65
+ 63: 'Indian cobra, Naja naja [印度眼镜蛇]',
66
+ 64: 'green mamba [绿曼巴]',
67
+ 65: 'sea snake [海蛇]',
68
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus [角腹蛇]',
69
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus [菱纹响尾蛇]',
70
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes [角响尾蛇]',
71
+ 69: 'trilobite [三叶虫]',
72
+ 70: 'harvestman, daddy longlegs, Phalangium opilio [盲蜘蛛]',
73
+ 71: 'scorpion [蝎子]',
74
+ 72: 'black and gold garden spider, Argiope aurantia [黑金花园蜘蛛]',
75
+ 73: 'barn spider, Araneus cavaticus [谷仓蜘蛛]',
76
+ 74: 'garden spider, Aranea diademata [花园蜘蛛]',
77
+ 75: 'black widow, Latrodectus mactans [黑寡妇蜘蛛]',
78
+ 76: 'tarantula [狼蛛]',
79
+ 77: 'wolf spider, hunting spider [狼蜘蛛,狩猎蜘蛛]',
80
+ 78: 'tick [壁虱]',
81
+ 79: 'centipede [蜈蚣]',
82
+ 80: 'black grouse [黑松鸡]',
83
+ 81: 'ptarmigan [松鸡,雷鸟]',
84
+ 82: 'ruffed grouse, partridge, Bonasa umbellus [披肩鸡,披肩榛鸡]',
85
+ 83: 'prairie chicken, prairie grouse, prairie fowl [草原鸡,草原松鸡]',
86
+ 84: 'peacock [孔雀]',
87
+ 85: 'quail [鹌鹑]',
88
+ 86: 'partridge [鹧鸪]',
89
+ 87: 'African grey, African gray, Psittacus erithacus [非洲灰鹦鹉]',
90
+ 88: 'macaw [金刚鹦鹉]',
91
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita [硫冠鹦鹉]',
92
+ 90: 'lorikeet [短尾鹦鹉]',
93
+ 91: 'coucal [褐翅鸦鹃]',
94
+ 92: 'bee eater [蜜蜂]',
95
+ 93: 'hornbill [犀鸟]',
96
+ 94: 'hummingbird [蜂鸟]',
97
+ 95: 'jacamar [鹟䴕]',
98
+ 96: 'toucan [犀鸟]',
99
+ 97: 'drake [野鸭]',
100
+ 98: 'red-breasted merganser, Mergus serrator [���胸秋沙鸭]',
101
+ 99: 'goose [鹅]',
102
+ 100: 'black swan, Cygnus atratus [黑天鹅]',
103
+ 101: 'tusker [大象]',
104
+ 102: 'echidna, spiny anteater, anteater [针鼹鼠]',
105
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus [鸭嘴兽]',
106
+ 104: 'wallaby, brush kangaroo [沙袋鼠]',
107
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus [考拉,考拉熊]',
108
+ 106: 'wombat [袋熊]',
109
+ 107: 'jellyfish [水母]',
110
+ 108: 'sea anemone, anemone [海葵]',
111
+ 109: 'brain coral [脑珊瑚]',
112
+ 110: 'flatworm, platyhelminth [扁形虫扁虫]',
113
+ 111: 'nematode, nematode worm, roundworm [线虫,蛔虫]',
114
+ 112: 'conch [海螺]',
115
+ 113: 'snail [蜗牛]',
116
+ 114: 'slug [鼻涕虫]',
117
+ 115: 'sea slug, nudibranch [海参]',
118
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore [石鳖]',
119
+ 117: 'chambered nautilus, pearly nautilus, nautilus [鹦鹉螺]',
120
+ 118: 'Dungeness crab, Cancer magister [珍宝蟹]',
121
+ 119: 'rock crab, Cancer irroratus [石蟹]',
122
+ 120: 'fiddler crab [招潮蟹]',
123
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica [帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹]',
124
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus [美国龙虾,缅因州龙虾]',
125
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish [大螯虾]',
126
+ 124: 'crayfish, crawfish, crawdad, crawdaddy [小龙虾]',
127
+ 125: 'hermit crab [寄居蟹]',
128
+ 126: 'isopod [等足目动物(明虾和螃蟹近亲)]',
129
+ 127: 'white stork, Ciconia ciconia [白鹳]',
130
+ 128: 'black stork, Ciconia nigra [黑鹳]',
131
+ 129: 'spoonbill [鹭]',
132
+ 130: 'flamingo [火烈鸟]',
133
+ 131: 'little blue heron, Egretta caerulea [小蓝鹭]',
134
+ 132: 'American egret, great white heron, Egretta albus [美国鹭,大白鹭]',
135
+ 133: 'bittern [麻鸦]',
136
+ 134: 'crane [鹤]',
137
+ 135: 'limpkin, Aramus pictus [秧鹤]',
138
+ 136: 'European gallinule, Porphyrio porphyrio [欧洲水鸡,紫水鸡]',
139
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana [沼泽泥母鸡,水母鸡]',
140
+ 138: 'bustard [鸨]',
141
+ 139: 'ruddy turnstone, Arenaria interpres [红翻石鹬]',
142
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina [红背鹬,黑腹滨鹬]',
143
+ 141: 'redshank, Tringa totanus [红脚鹬]',
144
+ 142: 'dowitcher [半蹼鹬]',
145
+ 143: 'oystercatcher, oyster catcher [蛎鹬]',
146
+ 144: 'pelican [鹈鹕]',
147
+ 145: 'king penguin, Aptenodytes patagonica [国王企鹅]',
148
+ 146: 'albatross, mollymawk [信天翁,大海鸟]',
149
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus [灰鲸]',
150
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca [杀人鲸,逆戟鲸,虎鲸]',
151
+ 149: 'dugong, Dugong dugon [海牛]',
152
+ 150: 'sea lion [海狮]',
153
+ 151: 'Chihuahua [奇瓦瓦]',
154
+ 152: 'Japanese spaniel [日本猎犬]',
155
+ 153: 'Maltese dog, Maltese terrier, Maltese [马尔济斯犬]',
156
+ 154: 'Pekinese, Pekingese, Peke [狮子狗]',
157
+ 155: 'Shih-Tzu [西施犬]',
158
+ 156: 'Blenheim spaniel [布莱尼姆猎犬]',
159
+ 157: 'papillon [巴比狗]',
160
+ 158: 'toy terrier [玩具犬]',
161
+ 159: 'Rhodesian ridgeback [罗得西亚长背猎狗]',
162
+ 160: 'Afghan hound, Afghan [阿富汗猎犬]',
163
+ 161: 'basset, basset hound [猎犬]',
164
+ 162: 'beagle [比格犬,猎兔犬]',
165
+ 163: 'bloodhound, sleuthhound [侦探犬]',
166
+ 164: 'bluetick [蓝色快狗]',
167
+ 165: 'black-and-tan coonhound [黑褐猎浣熊犬]',
168
+ 166: 'Walker hound, Walker foxhound [沃克猎犬]',
169
+ 167: 'English foxhound [英国猎狐犬]',
170
+ 168: 'redbone [美洲赤狗]',
171
+ 169: 'borzoi, Russian wolfhound [俄罗斯猎狼犬]',
172
+ 170: 'Irish wolfhound [爱尔兰猎狼犬]',
173
+ 171: 'Italian greyhound [意大利灰狗]',
174
+ 172: 'whippet [惠比特犬]',
175
+ 173: 'Ibizan hound, Ibizan Podenco [依比沙猎犬]',
176
+ 174: 'Norwegian elkhound, elkhound [挪威猎犬]',
177
+ 175: 'otterhound, otter hound [奥达猎犬,水獭猎犬]',
178
+ 176: 'Saluki, gazelle hound [沙克犬,瞪羚猎犬]',
179
+ 177: 'Scottish deerhound, deerhound [苏格兰猎鹿犬,猎鹿犬]',
180
+ 178: 'Weimaraner [威玛猎犬]',
181
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier [斯塔福德郡牛头梗,斯塔福德郡斗牛梗]',
182
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier [美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗]',
183
+ 181: 'Bedlington terrier [贝德灵顿梗]',
184
+ 182: 'Border terrier [边境梗]',
185
+ 183: 'Kerry blue terrier [凯丽蓝梗]',
186
+ 184: 'Irish terrier [爱尔兰梗]',
187
+ 185: 'Norfolk terrier [诺福克梗]',
188
+ 186: 'Norwich terrier [诺维奇梗]',
189
+ 187: 'Yorkshire terrier [约克郡梗]',
190
+ 188: 'wire-haired fox terrier [刚毛猎狐梗]',
191
+ 189: 'Lakeland terrier [莱克兰梗]',
192
+ 190: 'Sealyham terrier, Sealyham [锡利哈姆梗]',
193
+ 191: 'Airedale, Airedale terrier [艾尔谷犬]',
194
+ 192: 'cairn, cairn terrier [凯恩梗]',
195
+ 193: 'Australian terrier [澳大利亚梗]',
196
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier [丹迪丁蒙梗]',
197
+ 195: 'Boston bull, Boston terrier [波士顿梗]',
198
+ 196: 'miniature schnauzer [迷你雪纳瑞犬]',
199
+ 197: 'giant schnauzer [巨型雪纳瑞犬]',
200
+ 198: 'standard schnauzer [标准雪纳瑞犬]',
201
+ 199: 'Scotch terrier, Scottish terrier, Scottie [苏格兰梗]',
202
+ 200: 'Tibetan terrier, chrysanthemum dog [西藏梗,菊花狗]',
203
+ 201: 'silky terrier, Sydney silky [丝毛梗]',
204
+ 202: 'soft-coated wheaten terrier [软毛麦色梗]',
205
+ 203: 'West Highland white terrier [西高地白梗]',
206
+ 204: 'Lhasa, Lhasa apso [拉萨阿普索犬]',
207
+ 205: 'flat-coated retriever [平毛寻回犬]',
208
+ 206: 'curly-coated retriever [卷毛寻回犬]',
209
+ 207: 'golden retriever [金毛猎犬]',
210
+ 208: 'Labrador retriever [拉布拉多猎犬]',
211
+ 209: 'Chesapeake Bay retriever [乞沙比克猎犬]',
212
+ 210: 'German short-haired pointer [德国短毛猎犬]',
213
+ 211: 'vizsla, Hungarian pointer [维兹拉犬]',
214
+ 212: 'English setter [英国谍犬]',
215
+ 213: 'Irish setter, red setter [爱尔兰雪达犬,红色猎犬]',
216
+ 214: 'Gordon setter [戈登雪达犬]',
217
+ 215: 'Brittany spaniel [布列塔尼犬猎犬]',
218
+ 216: 'clumber, clumber spaniel [黄毛,黄毛猎犬]',
219
+ 217: 'English springer, English springer spaniel [英国史宾格犬]',
220
+ 218: 'Welsh springer spaniel [威尔士史宾格犬]',
221
+ 219: 'cocker spaniel, English cocker spaniel, cocker [可卡犬,英国可卡犬]',
222
+ 220: 'Sussex spaniel [萨塞克斯猎犬]',
223
+ 221: 'Irish water spaniel [爱尔兰水猎犬]',
224
+ 222: 'kuvasz [哥威斯犬]',
225
+ 223: 'schipperke [舒柏奇犬]',
226
+ 224: 'groenendael [比利时牧羊犬]',
227
+ 225: 'malinois [马里努阿犬]',
228
+ 226: 'briard [伯瑞犬]',
229
+ 227: 'kelpie [凯尔皮犬]',
230
+ 228: 'komondor [匈牙利牧羊犬]',
231
+ 229: 'Old English sheepdog, bobtail [老英国牧羊犬]',
232
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland [喜乐蒂牧羊犬]',
233
+ 231: 'collie [牧羊犬]',
234
+ 232: 'Border collie [边境牧羊犬]',
235
+ 233: 'Bouvier des Flandres, Bouviers des Flandres [法兰德斯牧牛狗]',
236
+ 234: 'Rottweiler [罗特韦尔犬]',
237
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian [德国牧羊犬,德国警犬,阿尔萨斯]',
238
+ 236: 'Doberman, Doberman pinscher [多伯曼犬,杜宾犬]',
239
+ 237: 'miniature pinscher [迷你杜宾犬]',
240
+ 238: 'Greater Swiss Mountain dog [大瑞士山地犬]',
241
+ 239: 'Bernese mountain dog [伯恩山犬]',
242
+ 240: 'Appenzeller [Appenzeller狗]',
243
+ 241: 'EntleBucher [EntleBucher狗]',
244
+ 242: 'boxer [拳师狗]',
245
+ 243: 'bull mastiff [斗牛獒]',
246
+ 244: 'Tibetan mastiff [藏獒]',
247
+ 245: 'French bulldog [法国斗牛犬]',
248
+ 246: 'Great Dane [大丹犬]',
249
+ 247: 'Saint Bernard, St Bernard [圣伯纳德狗]',
250
+ 248: 'Eskimo dog, husky [爱斯基摩犬,哈士奇]',
251
+ 249: 'malamute, malemute, Alaskan malamute [雪橇犬,阿拉斯加爱斯基摩狗]',
252
+ 250: 'Siberian husky [哈士奇]',
253
+ 251: 'dalmatian, coach dog, carriage dog [达尔马提亚,教练车狗]',
254
+ 252: 'affenpinscher, monkey pinscher, monkey dog [狮毛狗]',
255
+ 253: 'basenji [巴辛吉狗]',
256
+ 254: 'pug, pug-dog [哈巴狗,狮子狗]',
257
+ 255: 'Leonberg [莱昂贝格狗]',
258
+ 256: 'Newfoundland, Newfoundland dog [纽芬兰岛狗]',
259
+ 257: 'Great Pyrenees [大白熊犬]',
260
+ 258: 'Samoyed, Samoyede [萨摩耶犬]',
261
+ 259: 'Pomeranian [博美犬]',
262
+ 260: 'chow, chow chow [松狮,松狮]',
263
+ 261: 'keeshond [荷兰卷尾狮毛狗]',
264
+ 262: 'Brabancon griffon [布鲁塞尔格林芬犬]',
265
+ 263: 'Pembroke, Pembroke Welsh corgi [彭布洛克威尔士科基犬]',
266
+ 264: 'Cardigan, Cardigan Welsh corgi [威尔士柯基犬]',
267
+ 265: 'toy poodle [玩具贵宾犬]',
268
+ 266: 'miniature poodle [迷你贵宾犬]',
269
+ 267: 'standard poodle [标准贵宾犬]',
270
+ 268: 'Mexican hairless [墨西哥无毛犬]',
271
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus [灰狼]',
272
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum [白狼,北极狼]',
273
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger [红太狼,鬃狼,犬犬鲁弗斯]',
274
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans [狼,草原狼,刷狼,郊狼]',
275
+ 273: 'dingo, warrigal, warragal, Canis dingo [澳洲野狗,澳大利亚野犬]',
276
+ 274: 'dhole, Cuon alpinus [豺]',
277
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus [非洲猎犬,土狼犬]',
278
+ 276: 'hyena, hyaena [鬣狗]',
279
+ 277: 'red fox, Vulpes vulpes [红狐狸]',
280
+ 278: 'kit fox, Vulpes macrotis [沙狐]',
281
+ 279: 'Arctic fox, white fox, Alopex lagopus [北极狐狸,白狐狸]',
282
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus [灰狐狸]',
283
+ 281: 'tabby, tabby cat [虎斑猫]',
284
+ 282: 'tiger cat [山猫,虎猫]',
285
+ 283: 'Persian cat [波斯猫]',
286
+ 284: 'Siamese cat, Siamese [暹罗暹罗猫,]',
287
+ 285: 'Egyptian cat [埃及猫]',
288
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor [美洲狮,美洲豹]',
289
+ 287: 'lynx, catamount [猞猁,山猫]',
290
+ 288: 'leopard, Panthera pardus [豹子]',
291
+ 289: 'snow leopard, ounce, Panthera uncia [雪豹]',
292
+ 290: 'jaguar, panther, Panthera onca, Felis onca [美洲虎]',
293
+ 291: 'lion, king of beasts, Panthera leo [狮子]',
294
+ 292: 'tiger, Panthera tigris [老虎]',
295
+ 293: 'cheetah, chetah, Acinonyx jubatus [猎豹]',
296
+ 294: 'brown bear, bruin, Ursus arctos [棕熊]',
297
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus [美洲黑熊]',
298
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus [冰熊,北极熊]',
299
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus [懒熊]',
300
+ 298: 'mongoose [猫鼬]',
301
+ 299: 'meerkat, mierkat [猫鼬,海猫]',
302
+ 300: 'tiger beetle [虎甲虫]',
303
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle [瓢虫]',
304
+ 302: 'ground beetle, carabid beetle [土鳖虫]',
305
+ 303: 'long-horned beetle, longicorn, longicorn beetle [天牛]',
306
+ 304: 'leaf beetle, chrysomelid [龟甲虫]',
307
+ 305: 'dung beetle [粪甲虫]',
308
+ 306: 'rhinoceros beetle [犀牛甲虫]',
309
+ 307: 'weevil [象甲]',
310
+ 308: 'fly [苍蝇]',
311
+ 309: 'bee [蜜蜂]',
312
+ 310: 'ant, emmet, pismire [蚂蚁]',
313
+ 311: 'grasshopper, hopper [蚱蜢]',
314
+ 312: 'cricket [蟋蟀]',
315
+ 313: 'walking stick, walkingstick, stick insect [竹节虫]',
316
+ 314: 'cockroach, roach [蟑螂]',
317
+ 315: 'mantis, mantid [螳螂]',
318
+ 316: 'cicada, cicala [蝉]',
319
+ 317: 'leafhopper [叶蝉]',
320
+ 318: 'lacewing, lacewing fly [草蜻蛉]',
321
+ 319: 'dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk [蜻蜓]',
322
+ 320: 'damselfly [豆娘,蜻蛉]',
323
+ 321: 'admiral [优红蛱蝶]',
324
+ 322: 'ringlet, ringlet butterfly [小环蝴蝶]',
325
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus [君主蝴蝶,大斑蝶]',
326
+ 324: 'cabbage butterfly [菜粉蝶]',
327
+ 325: 'sulphur butterfly, sulfur butterfly [白蝴蝶]',
328
+ 326: 'lycaenid, lycaenid butterfly [灰蝶]',
329
+ 327: 'starfish, sea star [海星]',
330
+ 328: 'sea urchin [海胆]',
331
+ 329: 'sea cucumber, holothurian [海参,海黄瓜]',
332
+ 330: 'wood rabbit, cottontail, cottontail rabbit [野兔]',
333
+ 331: 'hare [兔]',
334
+ 332: 'Angora, Angora rabbit [安哥拉兔]',
335
+ 333: 'hamster [仓鼠]',
336
+ 334: 'porcupine, hedgehog [刺猬,豪猪,]',
337
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger [黑松鼠]',
338
+ 336: 'marmot [土拨鼠]',
339
+ 337: 'beaver [海狸]',
340
+ 338: 'guinea pig, Cavia cobaya [豚鼠,豚鼠]',
341
+ 339: 'sorrel [栗色马]',
342
+ 340: 'zebra [斑马]',
343
+ 341: 'hog, pig, grunter, squealer, Sus scrofa [猪]',
344
+ 342: 'wild boar, boar, Sus scrofa [野猪]',
345
+ 343: 'warthog [疣猪]',
346
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius [河马]',
347
+ 345: 'ox [牛]',
348
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis [水牛,亚洲水牛]',
349
+ 347: 'bison [野牛]',
350
+ 348: 'ram, tup [公羊]',
351
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis [大角羊,洛矶山大角羊]',
352
+ 350: 'ibex, Capra ibex [山羊]',
353
+ 351: 'hartebeest [狷羚]',
354
+ 352: 'impala, Aepyceros melampus [黑斑羚]',
355
+ 353: 'gazelle [瞪羚]',
356
+ 354: 'Arabian camel, dromedary, Camelus dromedarius [阿拉伯单峰骆驼,骆驼]',
357
+ 355: 'llama [骆驼]',
358
+ 356: 'weasel [黄鼠狼]',
359
+ 357: 'mink [水貂]',
360
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius [臭猫]',
361
+ 359: 'black-footed ferret, ferret, Mustela nigripes [黑足鼬]',
362
+ 360: 'otter [水獭]',
363
+ 361: 'skunk, polecat, wood pussy [臭鼬,木猫]',
364
+ 362: 'badger [獾]',
365
+ 363: 'armadillo [犰狳]',
366
+ 364: 'three-toed sloth, ai, Bradypus tridactylus [树懒]',
367
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus [猩猩,婆罗洲猩猩]',
368
+ 366: 'gorilla, Gorilla gorilla [大猩猩]',
369
+ 367: 'chimpanzee, chimp, Pan troglodytes [黑猩猩]',
370
+ 368: 'gibbon, Hylobates lar [长臂猿]',
371
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus [合趾猿长臂猿,合趾猿]',
372
+ 370: 'guenon, guenon monkey [长尾猴]',
373
+ 371: 'patas, hussar monkey, Erythrocebus patas [赤猴]',
374
+ 372: 'baboon [狒狒]',
375
+ 373: 'macaque [恒河猴,猕猴]',
376
+ 374: 'langur [白头叶猴]',
377
+ 375: 'colobus, colobus monkey [疣猴]',
378
+ 376: 'proboscis monkey, Nasalis larvatus [长鼻猴]',
379
+ 377: 'marmoset [狨(美洲产小型长尾猴)]',
380
+ 378: 'capuchin, ringtail, Cebus capucinus [卷尾猴]',
381
+ 379: 'howler monkey, howler [吼猴]',
382
+ 380: 'titi, titi monkey [伶猴]',
383
+ 381: 'spider monkey, Ateles geoffroyi [蜘蛛猴]',
384
+ 382: 'squirrel monkey, Saimiri sciureus [松鼠猴]',
385
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta [马达加斯加环尾狐猴,鼠狐猴]',
386
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus [大狐猴,马达加斯加大狐猴]',
387
+ 385: 'Indian elephant, Elephas maximus [印度大象,亚洲象]',
388
+ 386: 'African elephant, Loxodonta africana [非洲象,非洲象]',
389
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens [小熊猫]',
390
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca [大熊猫]',
391
+ 389: 'barracouta, snoek [杖鱼]',
392
+ 390: 'eel [鳗鱼]',
393
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch [银鲑,银鲑鱼]',
394
+ 392: 'rock beauty, Holocanthus tricolor [三色刺蝶鱼]',
395
+ 393: 'anemone fish [海葵鱼]',
396
+ 394: 'sturgeon [鲟鱼]',
397
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus [雀鳝]',
398
+ 396: 'lionfish [狮子鱼]',
399
+ 397: 'puffer, pufferfish, blowfish, globefish [河豚]',
400
+ 398: 'abacus [算盘]',
401
+ 399: 'abaya [长袍]',
402
+ 400: 'academic gown, academic robe, judge robe [学位袍]',
403
+ 401: 'accordion, piano accordion, squeeze box [手风琴]',
404
+ 402: 'acoustic guitar [原声吉他]',
405
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier [航空母舰]',
406
+ 404: 'airliner [客机]',
407
+ 405: 'airship, dirigible [飞艇]',
408
+ 406: 'altar [祭坛]',
409
+ 407: 'ambulance [救护车]',
410
+ 408: 'amphibian, amphibious vehicle [水陆两用车]',
411
+ 409: 'analog clock [模拟时钟]',
412
+ 410: 'apiary, bee house [蜂房]',
413
+ 411: 'apron [围裙]',
414
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin [垃圾桶]',
415
+ 413: 'assault rifle, assault gun [攻击步枪,枪]',
416
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack [背包]',
417
+ 415: 'bakery, bakeshop, bakehouse [面包店,面包铺,]',
418
+ 416: 'balance beam, beam [平衡木]',
419
+ 417: 'balloon [热气球]',
420
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro [圆珠笔]',
421
+ 419: 'Band Aid [创可贴]',
422
+ 420: 'banjo [班卓琴]',
423
+ 421: 'bannister, banister, balustrade, balusters, handrail [栏杆,楼梯扶手]',
424
+ 422: 'barbell [杠铃]',
425
+ 423: 'barber chair [理发师的椅子]',
426
+ 424: 'barbershop [理发店]',
427
+ 425: 'barn [牲口棚]',
428
+ 426: 'barometer [晴雨表]',
429
+ 427: 'barrel, cask [圆筒]',
430
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow [园地小车,手推车]',
431
+ 429: 'baseball [棒球]',
432
+ 430: 'basketball [篮球]',
433
+ 431: 'bassinet [婴儿床]',
434
+ 432: 'bassoon [巴松管,低音管]',
435
+ 433: 'bathing cap, swimming cap [游泳帽]',
436
+ 434: 'bath towel [沐浴毛巾]',
437
+ 435: 'bathtub, bathing tub, bath, tub [浴缸,澡盆]',
438
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon [沙滩车,旅行车]',
439
+ 437: 'beacon, lighthouse, beacon light, pharos [灯塔]',
440
+ 438: 'beaker [高脚杯]',
441
+ 439: 'bearskin, busby, shako [熊皮高帽]',
442
+ 440: 'beer bottle [啤酒瓶]',
443
+ 441: 'beer glass [啤酒杯]',
444
+ 442: 'bell cote, bell cot [钟塔]',
445
+ 443: 'bib [(小儿用的)围嘴]',
446
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem [串联自行车,]',
447
+ 445: 'bikini, two-piece [比基尼]',
448
+ 446: 'binder, ring-binder [装订册]',
449
+ 447: 'binoculars, field glasses, opera glasses [双筒望远镜]',
450
+ 448: 'birdhouse [鸟舍]',
451
+ 449: 'boathouse [船库]',
452
+ 450: 'bobsled, bobsleigh, bob [雪橇]',
453
+ 451: 'bolo tie, bolo, bola tie, bola [饰扣式领带]',
454
+ 452: 'bonnet, poke bonnet [阔边女帽]',
455
+ 453: 'bookcase [书橱]',
456
+ 454: 'bookshop, bookstore, bookstall [书店,书摊]',
457
+ 455: 'bottlecap [瓶盖]',
458
+ 456: 'bow [弓箭]',
459
+ 457: 'bow tie, bow-tie, bowtie [蝴蝶结领结]',
460
+ 458: 'brass, memorial tablet, plaque [铜制牌位]',
461
+ 459: 'brassiere, bra, bandeau [奶罩]',
462
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty [防波堤,海堤]',
463
+ 461: 'breastplate, aegis, egis [铠甲]',
464
+ 462: 'broom [扫帚]',
465
+ 463: 'bucket, pail [桶]',
466
+ 464: 'buckle [扣环]',
467
+ 465: 'bulletproof vest [防弹背心]',
468
+ 466: 'bullet train, bullet [动车,子弹头列车]',
469
+ 467: 'butcher shop, meat market [肉铺,肉菜市场]',
470
+ 468: 'cab, hack, taxi, taxicab [出租车]',
471
+ 469: 'caldron, cauldron [大锅]',
472
+ 470: 'candle, taper, wax light [蜡烛]',
473
+ 471: 'cannon [大炮]',
474
+ 472: 'canoe [独木舟]',
475
+ 473: 'can opener, tin opener [开瓶器,开罐器]',
476
+ 474: 'cardigan [开衫]',
477
+ 475: 'car mirror [车镜]',
478
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig [旋转木马]',
479
+ 477: 'carpenters kit, tool kit [木匠的工具包,工具包]',
480
+ 478: 'carton [纸箱]',
481
+ 479: 'car wheel [车轮]',
482
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM [取款机,自动取款机]',
483
+ 481: 'cassette [盒式录音带]',
484
+ 482: 'cassette player [卡带播放器]',
485
+ 483: 'castle [城堡]',
486
+ 484: 'catamaran [双体船]',
487
+ 485: 'CD player [CD播放器]',
488
+ 486: 'cello, violoncello [大提琴]',
489
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone [移动电话,手机]',
490
+ 488: 'chain [铁链]',
491
+ 489: 'chainlink fence [围栏]',
492
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour [链甲]',
493
+ 491: 'chain saw, chainsaw [电锯,油锯]',
494
+ 492: 'chest [箱子]',
495
+ 493: 'chiffonier, commode [衣柜,洗脸台]',
496
+ 494: 'chime, bell, gong [编钟,钟,锣]',
497
+ 495: 'china cabinet, china closet [中国橱柜]',
498
+ 496: 'Christmas stocking [圣诞袜]',
499
+ 497: 'church, church building [教堂,教堂建筑]',
500
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace [电影院,剧场]',
501
+ 499: 'cleaver, meat cleaver, chopper [切肉刀,菜刀]',
502
+ 500: 'cliff dwelling [悬崖屋]',
503
+ 501: 'cloak [斗篷]',
504
+ 502: 'clog, geta, patten, sabot [木屐,木鞋]',
505
+ 503: 'cocktail shaker [鸡尾酒调酒器]',
506
+ 504: 'coffee mug [咖啡杯]',
507
+ 505: 'coffeepot [咖啡壶]',
508
+ 506: 'coil, spiral, volute, whorl, helix [螺旋结构(楼梯)]',
509
+ 507: 'combination lock [组合锁]',
510
+ 508: 'computer keyboard, keypad [电脑键盘,键盘]',
511
+ 509: 'confectionery, confectionary, candy store [糖果,糖果店]',
512
+ 510: 'container ship, containership, container vessel [集装箱船]',
513
+ 511: 'convertible [敞篷车]',
514
+ 512: 'corkscrew, bottle screw [开瓶器,瓶螺杆]',
515
+ 513: 'cornet, horn, trumpet, trump [短号,喇叭]',
516
+ 514: 'cowboy boot [牛仔靴]',
517
+ 515: 'cowboy hat, ten-gallon hat [牛仔帽]',
518
+ 516: 'cradle [摇篮]',
519
+ 517: 'crane [起重机]',
520
+ 518: 'crash helmet [头盔]',
521
+ 519: 'crate [板条箱]',
522
+ 520: 'crib, cot [小儿床]',
523
+ 521: 'Crock Pot [砂锅]',
524
+ 522: 'croquet ball [槌球]',
525
+ 523: 'crutch [拐杖]',
526
+ 524: 'cuirass [胸甲]',
527
+ 525: 'dam, dike, dyke [大坝,堤防]',
528
+ 526: 'desk [书桌]',
529
+ 527: 'desktop computer [台式电脑]',
530
+ 528: 'dial telephone, dial phone [有线电话]',
531
+ 529: 'diaper, nappy, napkin [尿布湿]',
532
+ 530: 'digital clock [数字时钟]',
533
+ 531: 'digital watch [数字手表]',
534
+ 532: 'dining table, board [餐桌板]',
535
+ 533: 'dishrag, dishcloth [抹布]',
536
+ 534: 'dishwasher, dish washer, dishwashing machine [洗碗机,洗碟机]',
537
+ 535: 'disk brake, disc brake [盘式制动器]',
538
+ 536: 'dock, dockage, docking facility [码头,船坞,码头设施]',
539
+ 537: 'dogsled, dog sled, dog sleigh [狗拉雪橇]',
540
+ 538: 'dome [圆顶]',
541
+ 539: 'doormat, welcome mat [门垫,垫子]',
542
+ 540: 'drilling platform, offshore rig [钻井平台,海上钻井]',
543
+ 541: 'drum, membranophone, tympan [鼓,乐器,鼓膜]',
544
+ 542: 'drumstick [鼓槌]',
545
+ 543: 'dumbbell [哑铃]',
546
+ 544: 'Dutch oven [荷兰烤箱]',
547
+ 545: 'electric fan, blower [电风扇,鼓风机]',
548
+ 546: 'electric guitar [电吉他]',
549
+ 547: 'electric locomotive [电力机车]',
550
+ 548: 'entertainment center [电视,电视柜]',
551
+ 549: 'envelope [信封]',
552
+ 550: 'espresso maker [浓缩咖啡机]',
553
+ 551: 'face powder [扑面粉]',
554
+ 552: 'feather boa, boa [女用长围巾]',
555
+ 553: 'file, file cabinet, filing cabinet [文件,文件柜,档案柜]',
556
+ 554: 'fireboat [消防船]',
557
+ 555: 'fire engine, fire truck [消防车]',
558
+ 556: 'fire screen, fireguard [火炉栏]',
559
+ 557: 'flagpole, flagstaff [旗杆]',
560
+ 558: 'flute, transverse flute [长笛]',
561
+ 559: 'folding chair [折叠椅]',
562
+ 560: 'football helmet [橄榄球头盔]',
563
+ 561: 'forklift [叉车]',
564
+ 562: 'fountain [喷泉]',
565
+ 563: 'fountain pen [钢笔]',
566
+ 564: 'four-poster [有四根帷柱的床]',
567
+ 565: 'freight car [运货车厢]',
568
+ 566: 'French horn, horn [圆号,喇叭]',
569
+ 567: 'frying pan, frypan, skillet [煎锅]',
570
+ 568: 'fur coat [裘皮大衣]',
571
+ 569: 'garbage truck, dustcart [垃圾车]',
572
+ 570: 'gasmask, respirator, gas helmet [防毒面具,呼吸器]',
573
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser [汽油泵]',
574
+ 572: 'goblet [高脚杯]',
575
+ 573: 'go-kart [卡丁车]',
576
+ 574: 'golf ball [高尔夫球]',
577
+ 575: 'golfcart, golf cart [高尔夫球车]',
578
+ 576: 'gondola [狭长小船]',
579
+ 577: 'gong, tam-tam [锣]',
580
+ 578: 'gown [礼服]',
581
+ 579: 'grand piano, grand [钢琴]',
582
+ 580: 'greenhouse, nursery, glasshouse [温室,苗圃]',
583
+ 581: 'grille, radiator grille [散热器格栅]',
584
+ 582: 'grocery store, grocery, food market, market [杂货店,食品市场]',
585
+ 583: 'guillotine [断头台]',
586
+ 584: 'hair slide [小发夹]',
587
+ 585: 'hair spray [头发喷雾]',
588
+ 586: 'half track [半履带装甲车]',
589
+ 587: 'hammer [锤子]',
590
+ 588: 'hamper [大篮子]',
591
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier [手摇鼓风机,吹风机]',
592
+ 590: 'hand-held computer, hand-held microcomputer [手提电脑]',
593
+ 591: 'handkerchief, hankie, hanky, hankey [手帕]',
594
+ 592: 'hard disc, hard disk, fixed disk [硬盘]',
595
+ 593: 'harmonica, mouth organ, harp, mouth harp [口琴,口风琴]',
596
+ 594: 'harp [竖琴]',
597
+ 595: 'harvester, reaper [收割机]',
598
+ 596: 'hatchet [斧头]',
599
+ 597: 'holster [手枪皮套]',
600
+ 598: 'home theater, home theatre [家庭影院]',
601
+ 599: 'honeycomb [蜂窝]',
602
+ 600: 'hook, claw [钩爪]',
603
+ 601: 'hoopskirt, crinoline [衬裙]',
604
+ 602: 'horizontal bar, high bar [单杠]',
605
+ 603: 'horse cart, horse-cart [马车]',
606
+ 604: 'hourglass [沙漏]',
607
+ 605: 'iPod [手机,iPad]',
608
+ 606: 'iron, smoothing iron [熨斗]',
609
+ 607: 'jack-o-lantern [南瓜灯笼]',
610
+ 608: 'jean, blue jean, denim [牛仔裤,蓝色牛仔裤]',
611
+ 609: 'jeep, landrover [吉普车]',
612
+ 610: 'jersey, T-shirt, tee shirt [运动衫,T恤]',
613
+ 611: 'jigsaw puzzle [拼图]',
614
+ 612: 'jinrikisha, ricksha, rickshaw [人力车]',
615
+ 613: 'joystick [操纵杆]',
616
+ 614: 'kimono [和服]',
617
+ 615: 'knee pad [护膝]',
618
+ 616: 'knot [蝴蝶结]',
619
+ 617: 'lab coat, laboratory coat [大褂,实验室外套]',
620
+ 618: 'ladle [长柄勺]',
621
+ 619: 'lampshade, lamp shade [灯罩]',
622
+ 620: 'laptop, laptop computer [笔记本电脑]',
623
+ 621: 'lawn mower, mower [割草机]',
624
+ 622: 'lens cap, lens cover [镜头盖]',
625
+ 623: 'letter opener, paper knife, paperknife [开信刀,裁纸刀]',
626
+ 624: 'library [图书馆]',
627
+ 625: 'lifeboat [救生艇]',
628
+ 626: 'lighter, light, igniter, ignitor [点火器,打火机]',
629
+ 627: 'limousine, limo [豪华轿车]',
630
+ 628: 'liner, ocean liner [远洋班轮]',
631
+ 629: 'lipstick, lip rouge [唇膏,口红]',
632
+ 630: 'Loafer [平底便鞋]',
633
+ 631: 'lotion [洗剂]',
634
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system [扬声器]',
635
+ 633: 'loupe, jewelers loupe [放大镜]',
636
+ 634: 'lumbermill, sawmill [锯木厂]',
637
+ 635: 'magnetic compass [磁罗盘]',
638
+ 636: 'mailbag, postbag [邮袋]',
639
+ 637: 'mailbox, letter box [信箱]',
640
+ 638: 'maillot [女游泳衣]',
641
+ 639: 'maillot, tank suit [有肩带浴衣]',
642
+ 640: 'manhole cover [窨井盖]',
643
+ 641: 'maraca [沙球(一种打击乐器)]',
644
+ 642: 'marimba, xylophone [马林巴木琴]',
645
+ 643: 'mask [面膜]',
646
+ 644: 'matchstick [火柴]',
647
+ 645: 'maypole [花柱]',
648
+ 646: 'maze, labyrinth [迷宫]',
649
+ 647: 'measuring cup [量杯]',
650
+ 648: 'medicine chest, medicine cabinet [药箱]',
651
+ 649: 'megalith, megalithic structure [巨石,巨石结构]',
652
+ 650: 'microphone, mike [麦克风]',
653
+ 651: 'microwave, microwave oven [微波炉]',
654
+ 652: 'military uniform [军装]',
655
+ 653: 'milk can [奶桶]',
656
+ 654: 'minibus [迷你巴士]',
657
+ 655: 'miniskirt, mini [迷你裙]',
658
+ 656: 'minivan [面包车]',
659
+ 657: 'missile [导弹]',
660
+ 658: 'mitten [连指手套]',
661
+ 659: 'mixing bowl [搅拌钵]',
662
+ 660: 'mobile home, manufactured home [活动房屋(由汽车拖拉的)]',
663
+ 661: 'Model T [T型发动机小汽车]',
664
+ 662: 'modem [调制解调器]',
665
+ 663: 'monastery [修道院]',
666
+ 664: 'monitor [显示器]',
667
+ 665: 'moped [电瓶车]',
668
+ 666: 'mortar [砂浆]',
669
+ 667: 'mortarboard [学士]',
670
+ 668: 'mosque [清真寺]',
671
+ 669: 'mosquito net [蚊帐]',
672
+ 670: 'motor scooter, scooter [摩托车]',
673
+ 671: 'mountain bike, all-terrain bike, off-roader [山地自行车]',
674
+ 672: 'mountain tent [登山帐]',
675
+ 673: 'mouse, computer mouse [鼠标,电脑鼠标]',
676
+ 674: 'mousetrap [捕鼠器]',
677
+ 675: 'moving van [搬家车]',
678
+ 676: 'muzzle [口套]',
679
+ 677: 'nail [钉子]',
680
+ 678: 'neck brace [颈托]',
681
+ 679: 'necklace [项链]',
682
+ 680: 'nipple [乳头(瓶)]',
683
+ 681: 'notebook, notebook computer [笔记本,笔记本电脑]',
684
+ 682: 'obelisk [方尖碑]',
685
+ 683: 'oboe, hautboy, hautbois [双簧管]',
686
+ 684: 'ocarina, sweet potato [陶笛,卵形笛]',
687
+ 685: 'odometer, hodometer, mileometer, milometer [里程表]',
688
+ 686: 'oil filter [滤油器]',
689
+ 687: 'organ, pipe organ [风琴,管风琴]',
690
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO [示波器]',
691
+ 689: 'overskirt [罩裙]',
692
+ 690: 'oxcart [牛车]',
693
+ 691: 'oxygen mask [氧气面罩]',
694
+ 692: 'packet [包装]',
695
+ 693: 'paddle, boat paddle [船桨]',
696
+ 694: 'paddlewheel, paddle wheel [明轮,桨轮]',
697
+ 695: 'padlock [挂锁,扣锁]',
698
+ 696: 'paintbrush [画笔]',
699
+ 697: 'pajama, pyjama, pjs, jammies [睡衣]',
700
+ 698: 'palace [宫殿]',
701
+ 699: 'panpipe, pandean pipe, syrinx [排箫,鸣管]',
702
+ 700: 'paper towel [纸巾]',
703
+ 701: 'parachute, chute [降落伞]',
704
+ 702: 'parallel bars, bars [双杠]',
705
+ 703: 'park bench [公园长椅]',
706
+ 704: 'parking meter [停车收费表,停车计时器]',
707
+ 705: 'passenger car, coach, carriage [客车,教练车]',
708
+ 706: 'patio, terrace [露台,阳台]',
709
+ 707: 'pay-phone, pay-station [付费电话]',
710
+ 708: 'pedestal, plinth, footstall [基座,基脚]',
711
+ 709: 'pencil box, pencil case [铅笔盒]',
712
+ 710: 'pencil sharpener [卷笔刀]',
713
+ 711: 'perfume, essence [香水(瓶)]',
714
+ 712: 'Petri dish [培养皿]',
715
+ 713: 'photocopier [复印机]',
716
+ 714: 'pick, plectrum, plectron [拨弦片,拨子]',
717
+ 715: 'pickelhaube [尖顶头盔]',
718
+ 716: 'picket fence, paling [栅栏,栅栏]',
719
+ 717: 'pickup, pickup truck [皮卡,皮卡车]',
720
+ 718: 'pier [桥墩]',
721
+ 719: 'piggy bank, penny bank [存钱罐]',
722
+ 720: 'pill bottle [药瓶]',
723
+ 721: 'pillow [枕头]',
724
+ 722: 'ping-pong ball [乒乓球]',
725
+ 723: 'pinwheel [风车]',
726
+ 724: 'pirate, pirate ship [海盗船]',
727
+ 725: 'pitcher, ewer [水罐]',
728
+ 726: 'plane, carpenters plane, woodworking plane [木工刨]',
729
+ 727: 'planetarium [天文馆]',
730
+ 728: 'plastic bag [塑料袋]',
731
+ 729: 'plate rack [板架]',
732
+ 730: 'plow, plough [犁型铲雪机]',
733
+ 731: 'plunger, plumbers helper [手压皮碗泵]',
734
+ 732: 'Polaroid camera, Polaroid Land camera [宝丽来相机]',
735
+ 733: 'pole [电线杆]',
736
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria [警车,巡逻车]',
737
+ 735: 'poncho [雨披]',
738
+ 736: 'pool table, billiard table, snooker table [台球桌]',
739
+ 737: 'pop bottle, soda bottle [充气饮料瓶]',
740
+ 738: 'pot, flowerpot [花盆]',
741
+ 739: 'potters wheel [陶工旋盘]',
742
+ 740: 'power drill [电钻]',
743
+ 741: 'prayer rug, prayer mat [祈祷垫,地毯]',
744
+ 742: 'printer [打印机]',
745
+ 743: 'prison, prison house [监狱]',
746
+ 744: 'projectile, missile [炮弹,导弹]',
747
+ 745: 'projector [投影仪]',
748
+ 746: 'puck, hockey puck [冰球]',
749
+ 747: 'punching bag, punch bag, punching ball, punchball [沙包,吊球]',
750
+ 748: 'purse [钱包]',
751
+ 749: 'quill, quill pen [羽管笔]',
752
+ 750: 'quilt, comforter, comfort, puff [被子]',
753
+ 751: 'racer, race car, racing car [赛车]',
754
+ 752: 'racket, racquet [球拍]',
755
+ 753: 'radiator [散热器]',
756
+ 754: 'radio, wireless [收音机]',
757
+ 755: 'radio telescope, radio reflector [射电望远镜,无线电反射器]',
758
+ 756: 'rain barrel [雨桶]',
759
+ 757: 'recreational vehicle, RV, R.V. [休闲车,房车]',
760
+ 758: 'reel [卷轴,卷筒]',
761
+ 759: 'reflex camera [反射式照相机]',
762
+ 760: 'refrigerator, icebox [冰箱,冰柜]',
763
+ 761: 'remote control, remote [遥控器]',
764
+ 762: 'restaurant, eating house, eating place, eatery [餐厅,饮食店,食堂]',
765
+ 763: 'revolver, six-gun, six-shooter [左轮手枪]',
766
+ 764: 'rifle [步枪]',
767
+ 765: 'rocking chair, rocker [摇椅]',
768
+ 766: 'rotisserie [电转烤肉架]',
769
+ 767: 'rubber eraser, rubber, pencil eraser [橡皮]',
770
+ 768: 'rugby ball [橄榄球]',
771
+ 769: 'rule, ruler [直尺]',
772
+ 770: 'running shoe [跑步鞋]',
773
+ 771: 'safe [保险柜]',
774
+ 772: 'safety pin [安全别针]',
775
+ 773: 'saltshaker, salt shaker [盐瓶(调味用)]',
776
+ 774: 'sandal [凉鞋]',
777
+ 775: 'sarong [纱笼,围裙]',
778
+ 776: 'sax, saxophone [萨克斯管]',
779
+ 777: 'scabbard [剑鞘]',
780
+ 778: 'scale, weighing machine [秤,称重机]',
781
+ 779: 'school bus [校车]',
782
+ 780: 'schooner [帆船]',
783
+ 781: 'scoreboard [记分牌]',
784
+ 782: 'screen, CRT screen [屏幕]',
785
+ 783: 'screw [螺丝]',
786
+ 784: 'screwdriver [螺丝刀]',
787
+ 785: 'seat belt, seatbelt [安全带]',
788
+ 786: 'sewing machine [缝纫机]',
789
+ 787: 'shield, buckler [盾牌,盾牌]',
790
+ 788: 'shoe shop, shoe-shop, shoe store [皮鞋店,鞋店]',
791
+ 789: 'shoji [障子]',
792
+ 790: 'shopping basket [购物篮]',
793
+ 791: 'shopping cart [购物车]',
794
+ 792: 'shovel [铁锹]',
795
+ 793: 'shower cap [浴帽]',
796
+ 794: 'shower curtain [浴帘]',
797
+ 795: 'ski [滑雪板]',
798
+ 796: 'ski mask [滑雪面罩]',
799
+ 797: 'sleeping bag [睡袋]',
800
+ 798: 'slide rule, slipstick [滑尺]',
801
+ 799: 'sliding door [滑动门]',
802
+ 800: 'slot, one-armed bandit [角子老虎机]',
803
+ 801: 'snorkel [潜水通气管]',
804
+ 802: 'snowmobile [雪橇]',
805
+ 803: 'snowplow, snowplough [扫雪机,扫雪机]',
806
+ 804: 'soap dispenser [皂液器]',
807
+ 805: 'soccer ball [足球]',
808
+ 806: 'sock [袜子]',
809
+ 807: 'solar dish, solar collector, solar furnace [碟式太阳能,太阳能集热器,太阳能炉]',
810
+ 808: 'sombrero [宽边帽]',
811
+ 809: 'soup bowl [汤碗]',
812
+ 810: 'space bar [空格键]',
813
+ 811: 'space heater [空间加热器]',
814
+ 812: 'space shuttle [航天飞机]',
815
+ 813: 'spatula [铲(搅拌或涂敷用的)]',
816
+ 814: 'speedboat [快艇]',
817
+ 815: 'spider web, spiders web [蜘蛛网]',
818
+ 816: 'spindle [纺锤,纱锭]',
819
+ 817: 'sports car, sport car [跑车]',
820
+ 818: 'spotlight, spot [聚光灯]',
821
+ 819: 'stage [舞台]',
822
+ 820: 'steam locomotive [蒸汽机车]',
823
+ 821: 'steel arch bridge [钢拱桥]',
824
+ 822: 'steel drum [钢滚筒]',
825
+ 823: 'stethoscope [听诊器]',
826
+ 824: 'stole [女用披肩]',
827
+ 825: 'stone wall [石头墙]',
828
+ 826: 'stopwatch, stop watch [秒表]',
829
+ 827: 'stove [火炉]',
830
+ 828: 'strainer [过滤器]',
831
+ 829: 'streetcar, tram, tramcar, trolley, trolley car [有轨电车,电车]',
832
+ 830: 'stretcher [担架]',
833
+ 831: 'studio couch, day bed [沙发床]',
834
+ 832: 'stupa, tope [佛塔]',
835
+ 833: 'submarine, pigboat, sub, U-boat [潜艇,潜水艇]',
836
+ 834: 'suit, suit of clothes [套装,衣服]',
837
+ 835: 'sundial [日晷]',
838
+ 836: 'sunglass [太阳镜]',
839
+ 837: 'sunglasses, dark glasses, shades [太阳镜,墨镜]',
840
+ 838: 'sunscreen, sunblock, sun blocker [防晒霜,防晒剂]',
841
+ 839: 'suspension bridge [悬索桥]',
842
+ 840: 'swab, swob, mop [拖把]',
843
+ 841: 'sweatshirt [运动衫]',
844
+ 842: 'swimming trunks, bathing trunks [游泳裤]',
845
+ 843: 'swing [秋千]',
846
+ 844: 'switch, electric switch, electrical switch [开关,电器开关]',
847
+ 845: 'syringe [注射器]',
848
+ 846: 'table lamp [台灯]',
849
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle [坦克,装甲战车,装甲战斗车辆]',
850
+ 848: 'tape player [磁带播放器]',
851
+ 849: 'teapot [茶壶]',
852
+ 850: 'teddy, teddy bear [泰迪,泰迪熊]',
853
+ 851: 'television, television system [电视]',
854
+ 852: 'tennis ball [网球]',
855
+ 853: 'thatch, thatched roof [茅草,茅草屋顶]',
856
+ 854: 'theater curtain, theatre curtain [幕布,剧院的帷幕]',
857
+ 855: 'thimble [顶针]',
858
+ 856: 'thresher, thrasher, threshing machine [脱粒机]',
859
+ 857: 'throne [宝座]',
860
+ 858: 'tile roof [瓦屋顶]',
861
+ 859: 'toaster [烤面包机]',
862
+ 860: 'tobacco shop, tobacconist shop, tobacconist [烟草店,烟草]',
863
+ 861: 'toilet seat [马桶]',
864
+ 862: 'torch [火炬]',
865
+ 863: 'totem pole [图腾柱]',
866
+ 864: 'tow truck, tow car, wrecker [拖车,牵引车,清障车]',
867
+ 865: 'toyshop [玩具店]',
868
+ 866: 'tractor [拖拉机]',
869
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi [拖车,铰接式卡车]',
870
+ 868: 'tray [托盘]',
871
+ 869: 'trench coat [风衣]',
872
+ 870: 'tricycle, trike, velocipede [三轮车]',
873
+ 871: 'trimaran [三体船]',
874
+ 872: 'tripod [三脚架]',
875
+ 873: 'triumphal arch [凯旋门]',
876
+ 874: 'trolleybus, trolley coach, trackless trolley [无轨电车]',
877
+ 875: 'trombone [长号]',
878
+ 876: 'tub, vat [浴盆,浴缸]',
879
+ 877: 'turnstile [旋转式栅门]',
880
+ 878: 'typewriter keyboard [打字机键盘]',
881
+ 879: 'umbrella [伞]',
882
+ 880: 'unicycle, monocycle [独轮车]',
883
+ 881: 'upright, upright piano [直立式钢琴]',
884
+ 882: 'vacuum, vacuum cleaner [真空吸尘器]',
885
+ 883: 'vase [花瓶]',
886
+ 884: 'vault [拱顶]',
887
+ 885: 'velvet [天鹅绒]',
888
+ 886: 'vending machine [自动售货机]',
889
+ 887: 'vestment [祭服]',
890
+ 888: 'viaduct [高架桥]',
891
+ 889: 'violin, fiddle [小提琴,小提琴]',
892
+ 890: 'volleyball [排球]',
893
+ 891: 'waffle iron [松饼机]',
894
+ 892: 'wall clock [挂钟]',
895
+ 893: 'wallet, billfold, notecase, pocketbook [钱包,皮夹]',
896
+ 894: 'wardrobe, closet, press [衣柜,壁橱]',
897
+ 895: 'warplane, military plane [军用飞机]',
898
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin [洗脸盆,洗手盆]',
899
+ 897: 'washer, automatic washer, washing machine [洗衣机,自动洗衣机]',
900
+ 898: 'water bottle [水瓶]',
901
+ 899: 'water jug [水壶]',
902
+ 900: 'water tower [水塔]',
903
+ 901: 'whiskey jug [威士忌壶]',
904
+ 902: 'whistle [哨子]',
905
+ 903: 'wig [假发]',
906
+ 904: 'window screen [纱窗]',
907
+ 905: 'window shade [百叶窗]',
908
+ 906: 'Windsor tie [温莎领带]',
909
+ 907: 'wine bottle [葡萄酒瓶]',
910
+ 908: 'wing [飞机翅膀,飞机]',
911
+ 909: 'wok [炒菜锅]',
912
+ 910: 'wooden spoon [木制的勺子]',
913
+ 911: 'wool, woolen, woollen [毛织品,羊绒]',
914
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence [栅栏,围栏]',
915
+ 913: 'wreck [沉船]',
916
+ 914: 'yawl [双桅船]',
917
+ 915: 'yurt [蒙古包]',
918
+ 916: 'web site, website, internet site, site [网站,互联网网站]',
919
+ 917: 'comic book [漫画]',
920
+ 918: 'crossword puzzle, crossword [纵横字谜]',
921
+ 919: 'street sign [路标]',
922
+ 920: 'traffic light, traffic signal, stoplight [交通信号灯]',
923
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper [防尘罩,书皮]',
924
+ 922: 'menu [菜单]',
925
+ 923: 'plate [盘子]',
926
+ 924: 'guacamole [鳄梨酱]',
927
+ 925: 'consomme [清汤]',
928
+ 926: 'hot pot, hotpot [罐焖土豆烧肉]',
929
+ 927: 'trifle [蛋糕]',
930
+ 928: 'ice cream, icecream [冰淇淋]',
931
+ 929: 'ice lolly, lolly, lollipop, popsicle [雪糕,冰棍,冰棒]',
932
+ 930: 'French loaf [法式面包]',
933
+ 931: 'bagel, beigel [百吉饼]',
934
+ 932: 'pretzel [椒盐脆饼]',
935
+ 933: 'cheeseburger [芝士汉堡]',
936
+ 934: 'hotdog, hot dog, red hot [热狗]',
937
+ 935: 'mashed potato [土豆泥]',
938
+ 936: 'head cabbage [结球甘蓝]',
939
+ 937: 'broccoli [西兰花]',
940
+ 938: 'cauliflower [菜花]',
941
+ 939: 'zucchini, courgette [绿皮密生西葫芦]',
942
+ 940: 'spaghetti squash [西葫芦]',
943
+ 941: 'acorn squash [小青南瓜]',
944
+ 942: 'butternut squash [南瓜]',
945
+ 943: 'cucumber, cuke [黄瓜]',
946
+ 944: 'artichoke, globe artichoke [朝鲜蓟]',
947
+ 945: 'bell pepper [甜椒]',
948
+ 946: 'cardoon [刺棘蓟]',
949
+ 947: 'mushroom [蘑菇]',
950
+ 948: 'Granny Smith [绿苹果]',
951
+ 949: 'strawberry [草莓]',
952
+ 950: 'orange [橘子]',
953
+ 951: 'lemon [柠檬]',
954
+ 952: 'fig [无花果]',
955
+ 953: 'pineapple, ananas [菠萝]',
956
+ 954: 'banana [香蕉]',
957
+ 955: 'jackfruit, jak, jack [菠萝蜜]',
958
+ 956: 'custard apple [蛋奶冻苹果]',
959
+ 957: 'pomegranate [石榴]',
960
+ 958: 'hay [干草]',
961
+ 959: 'carbonara [烤面条加干酪沙司]',
962
+ 960: 'chocolate sauce, chocolate syrup [巧克力酱,巧克力糖浆]',
963
+ 961: 'dough [面团]',
964
+ 962: 'meat loaf, meatloaf [瑞士肉包,肉饼]',
965
+ 963: 'pizza, pizza pie [披萨,披萨饼]',
966
+ 964: 'potpie [馅饼]',
967
+ 965: 'burrito [卷饼]',
968
+ 966: 'red wine [红葡萄酒]',
969
+ 967: 'espresso [意大利浓咖啡]',
970
+ 968: 'cup [杯子]',
971
+ 969: 'eggnog [蛋酒]',
972
+ 970: 'alp [高山]',
973
+ 971: 'bubble [泡泡]',
974
+ 972: 'cliff, drop, drop-off [悬崖]',
975
+ 973: 'coral reef [珊瑚礁]',
976
+ 974: 'geyser [间歇泉]',
977
+ 975: 'lakeside, lakeshore [湖边,湖岸]',
978
+ 976: 'promontory, headland, head, foreland [海角]',
979
+ 977: 'sandbar, sand bar [沙洲,沙坝]',
980
+ 978: 'seashore, coast, seacoast, sea-coast [海滨,海岸]',
981
+ 979: 'valley, vale [峡谷]',
982
+ 980: 'volcano [火山]',
983
+ 981: 'ballplayer, baseball player [棒球,棒球运动员]',
984
+ 982: 'groom, bridegroom [新郎]',
985
+ 983: 'scuba diver [潜水员]',
986
+ 984: 'rapeseed [油菜]',
987
+ 985: 'daisy [雏菊]',
988
+ 986: 'yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum [杓兰]',
989
+ 987: 'corn [玉米]',
990
+ 988: 'acorn [橡子]',
991
+ 989: 'hip, rose hip, rosehip [玫瑰果]',
992
+ 990: 'buckeye, horse chestnut, conker [七叶树果实]',
993
+ 991: 'coral fungus [珊瑚菌]',
994
+ 992: 'agaric [木耳]',
995
+ 993: 'gyromitra [鹿花菌]',
996
+ 994: 'stinkhorn, carrion fungus [鬼笔菌]',
997
+ 995: 'earthstar [地星(菌类)]',
998
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa [多叶奇果菌]',
999
+ 997: 'bolete [牛肝菌]',
1000
+ 998: 'ear, spike, capitulum [玉米穗]',
1001
+ 999: 'toilet tissue, toilet paper, bathroom tissue [卫生纸]',
1002
+ }
models/generate.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+ import torch._dynamo.config
8
+ import torch._inductor.config
9
+ import copy
10
+ # torch._inductor.config.coordinate_descent_tuning = True
11
+ # torch._inductor.config.triton.unique_kernel_names = True
12
+ # torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
13
+
14
+
15
+ ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
16
+ def top_k_top_p_filtering(
17
+ logits,
18
+ top_k: int = 0,
19
+ top_p: float = 1.0,
20
+ filter_value: float = -float("Inf"),
21
+ min_tokens_to_keep: int = 1,
22
+ ):
23
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
24
+ Args:
25
+ logits: logits distribution shape (batch size, vocabulary size)
26
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
27
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
28
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
29
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
30
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
31
+ """
32
+ if top_k > 0:
33
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
34
+ # Remove all tokens with a probability less than the last token of the top-k
35
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
36
+ logits[indices_to_remove] = filter_value
37
+
38
+ if top_p < 1.0:
39
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
40
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
41
+
42
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
43
+ sorted_indices_to_remove = cumulative_probs > top_p
44
+ if min_tokens_to_keep > 1:
45
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
46
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
47
+ # Shift the indices to the right to keep also the first token above the threshold
48
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
49
+ sorted_indices_to_remove[..., 0] = 0
50
+
51
+ # scatter sorted tensors to original indexing
52
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
53
+ logits[indices_to_remove] = filter_value
54
+ return logits
55
+
56
+
57
+ def sample(logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
58
+ logits = logits[:, -1, :] / max(temperature, 1e-5)
59
+ if top_k > 0 or top_p < 1.0:
60
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
61
+ probs = F.softmax(logits, dim=-1)
62
+ if sample_logits:
63
+ idx = torch.multinomial(probs, num_samples=1)
64
+ else:
65
+ _, idx = torch.topk(probs, k=1, dim=-1)
66
+ return idx, probs
67
+
68
+
69
+ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs):
70
+ logits = logits / max(temperature, 1e-5)
71
+ if top_k > 0 or top_p < 1.0:
72
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
73
+ probs = torch.nn.functional.softmax(logits, dim=-1)
74
+ return probs
75
+
76
+
77
+ def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, **sampling_kwargs):
78
+ if cfg_scale > 1.0:
79
+ logits, _ = model(None, cond_idx, input_pos)
80
+ logits_combined = logits
81
+ cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
82
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
83
+ else:
84
+ logits, _ = model(None, cond_idx, input_pos)
85
+
86
+ return sample(logits, **sampling_kwargs)[0]
87
+
88
+
89
+ def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs):
90
+ assert input_pos.shape[-1] == 1
91
+ if cfg_scale > 1.0:
92
+ x_combined = torch.cat([x, x])
93
+ logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos)
94
+ logits_combined = logits
95
+ cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
96
+ if cfg_flag:
97
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
98
+ else:
99
+ logits = cond_logits
100
+ else:
101
+ logits, _ = model(x, cond_idx=None, input_pos=input_pos)
102
+ return sample(logits, **sampling_kwargs)
103
+
104
+
105
+ def decode_n_tokens(
106
+ model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
107
+ cfg_scale: float, cfg_interval: int,
108
+ **sampling_kwargs):
109
+ new_tokens, new_probs = [], []
110
+ cfg_flag = True
111
+ for i in range(num_new_tokens):
112
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
113
+ if cfg_interval > -1 and i > cfg_interval:
114
+ cfg_flag = False
115
+ next_token, next_prob = decode_one_token(
116
+ model, cur_token, input_pos, cfg_scale, cfg_flag, **sampling_kwargs
117
+ )
118
+ input_pos += 1
119
+ new_tokens.append(next_token.clone())
120
+ new_probs.append(next_prob.clone())
121
+ cur_token = next_token.view(-1, 1)
122
+
123
+ return new_tokens, new_probs
124
+
125
+
126
+ @torch.no_grad()
127
+ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs):
128
+ if model.model_type == 'c2i':
129
+ if cfg_scale > 1.0:
130
+ cond_null = torch.ones_like(cond) * model.num_classes
131
+ cond_combined = torch.cat([cond, cond_null])
132
+ else:
133
+ cond_combined = cond
134
+ T = 1
135
+ elif model.model_type == 't2i':
136
+ if cfg_scale > 1.0:
137
+ cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
138
+ cond_combined = torch.cat([cond, cond_null])
139
+ else:
140
+ cond_combined = cond
141
+ T = cond.shape[1]
142
+ else:
143
+ raise Exception("please check model type")
144
+
145
+ T_new = T + max_new_tokens
146
+ max_seq_length = T_new
147
+ max_batch_size = cond.shape[0]
148
+
149
+ device = cond.device
150
+ with torch.device(device):
151
+ max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
152
+ model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)
153
+
154
+ if emb_masks is not None:
155
+ assert emb_masks.shape[0] == max_batch_size
156
+ assert emb_masks.shape[-1] == T
157
+ if cfg_scale > 1.0:
158
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
159
+ else:
160
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
161
+
162
+ eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
163
+ model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
164
+
165
+ # create an empty tensor of the expected final shape and fill in the current tokens
166
+ seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
167
+
168
+ input_pos = torch.arange(0, T, device=device)
169
+ next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs)
170
+ seq[:, T:T+1] = next_token
171
+
172
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
173
+ generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, **sampling_kwargs)
174
+ seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
175
+
176
+ return seq[:, T:]
models/gpt.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ # nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
5
+ # llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
6
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
7
+ # PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
8
+ from dataclasses import dataclass
9
+ from typing import Optional, List
10
+
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import functional as F
15
+
16
+
17
+ def find_multiple(n: int, k: int):
18
+ if n % k == 0:
19
+ return n
20
+ return n + k - (n % k)
21
+
22
+ @dataclass
23
+ class ModelArgs:
24
+ dim: int = 4096
25
+ n_layer: int = 32
26
+ n_head: int = 32
27
+ n_kv_head: Optional[int] = None
28
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
29
+ ffn_dim_multiplier: Optional[float] = None
30
+ rope_base: float = 10000
31
+ norm_eps: float = 1e-5
32
+ initializer_range: float = 0.02
33
+
34
+ token_dropout_p: float = 0.1
35
+ attn_dropout_p: float = 0.0
36
+ resid_dropout_p: float = 0.1
37
+ ffn_dropout_p: float = 0.1
38
+ drop_path_rate: float = 0.0
39
+
40
+ num_classes: int = 1000
41
+ caption_dim: int = 2048
42
+ class_dropout_prob: float = 0.1
43
+ model_type: str = 'c2i'
44
+
45
+ vocab_size: int = 16384
46
+ cls_token_num: int = 1
47
+ block_size: int = 256
48
+ max_batch_size: int = 32
49
+ max_seq_len: int = 2048
50
+
51
+
52
+ #################################################################################
53
+ # Embedding Layers for Class Labels #
54
+ #################################################################################
55
+ class LabelEmbedder(nn.Module):
56
+ """
57
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
58
+ """
59
+ def __init__(self, num_classes, hidden_size, dropout_prob):
60
+ super().__init__()
61
+ use_cfg_embedding = dropout_prob > 0
62
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
63
+ self.num_classes = num_classes
64
+ self.dropout_prob = dropout_prob
65
+
66
+ def token_drop(self, labels, force_drop_ids=None):
67
+ """
68
+ Drops labels to enable classifier-free guidance.
69
+ """
70
+ if force_drop_ids is None:
71
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
72
+ else:
73
+ drop_ids = force_drop_ids == 1
74
+ labels = torch.where(drop_ids, self.num_classes, labels)
75
+ return labels
76
+
77
+ def forward(self, labels, train, force_drop_ids=None):
78
+ use_dropout = self.dropout_prob > 0
79
+ if (train and use_dropout) or (force_drop_ids is not None):
80
+ labels = self.token_drop(labels, force_drop_ids)
81
+ embeddings = self.embedding_table(labels).unsqueeze(1)
82
+ return embeddings
83
+
84
+
85
+ #################################################################################
86
+ # Embedding Layers for Text Feature #
87
+ #################################################################################
88
+ class CaptionEmbedder(nn.Module):
89
+ """
90
+ Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance.
91
+ """
92
+ def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
93
+ super().__init__()
94
+ self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
95
+ self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
96
+ self.uncond_prob = uncond_prob
97
+
98
+ def token_drop(self, caption, force_drop_ids=None):
99
+ """
100
+ Drops labels to enable classifier-free guidance.
101
+ """
102
+ if force_drop_ids is None:
103
+ drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
104
+ else:
105
+ drop_ids = force_drop_ids == 1
106
+ caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)
107
+ return caption
108
+
109
+ def forward(self, caption, train, force_drop_ids=None):
110
+ use_dropout = self.uncond_prob > 0
111
+ if (train and use_dropout) or (force_drop_ids is not None):
112
+ caption = self.token_drop(caption, force_drop_ids)
113
+ embeddings = self.cap_proj(caption)
114
+ return embeddings
115
+
116
+
117
+ class MLP(nn.Module):
118
+ def __init__(self, in_features, hidden_features, out_features):
119
+ super().__init__()
120
+ out_features = out_features or in_features
121
+ hidden_features = hidden_features or in_features
122
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
123
+ self.act = nn.GELU(approximate='tanh')
124
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
125
+
126
+ def forward(self, x):
127
+ x = self.fc1(x)
128
+ x = self.act(x)
129
+ x = self.fc2(x)
130
+ return x
131
+
132
+
133
+ #################################################################################
134
+ # GPT Model #
135
+ #################################################################################
136
+ class RMSNorm(torch.nn.Module):
137
+ def __init__(self, dim: int, eps: float = 1e-5):
138
+ super().__init__()
139
+ self.eps = eps
140
+ self.weight = nn.Parameter(torch.ones(dim))
141
+
142
+ def _norm(self, x):
143
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
144
+
145
+ def forward(self, x):
146
+ output = self._norm(x.float()).type_as(x)
147
+ return output * self.weight
148
+
149
+
150
+ class FeedForward(nn.Module):
151
+ def __init__(self, config: ModelArgs):
152
+ super().__init__()
153
+ hidden_dim = 4 * config.dim
154
+ hidden_dim = int(2 * hidden_dim / 3)
155
+ # custom dim factor multiplier
156
+ if config.ffn_dim_multiplier is not None:
157
+ hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
158
+ hidden_dim = find_multiple(hidden_dim, config.multiple_of)
159
+
160
+ self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
161
+ self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
162
+ self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
163
+ self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
164
+
165
+ def forward(self, x):
166
+ return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
167
+
168
+
169
+ class KVCache(nn.Module):
170
+ def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
171
+ super().__init__()
172
+ cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
173
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
174
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
175
+
176
+ def update(self, input_pos, k_val, v_val):
177
+ # input_pos: [S], k_val: [B, H, S, D]
178
+ assert input_pos.shape[0] == k_val.shape[2]
179
+ k_out = self.k_cache
180
+ v_out = self.v_cache
181
+ k_out[:, :, input_pos] = k_val
182
+ v_out[:, :, input_pos] = v_val
183
+
184
+ return k_out, v_out
185
+
186
+
187
+ class Attention(nn.Module):
188
+ def __init__(self, config: ModelArgs):
189
+ super().__init__()
190
+ assert config.dim % config.n_head == 0
191
+ self.dim = config.dim
192
+ self.head_dim = config.dim // config.n_head
193
+ self.n_head = config.n_head
194
+ self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
195
+ total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
196
+
197
+ # key, query, value projections for all heads, but in a batch
198
+ self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
199
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
200
+ self.kv_cache = None
201
+
202
+ # regularization
203
+ self.attn_dropout_p = config.attn_dropout_p
204
+ self.resid_dropout = nn.Dropout(config.resid_dropout_p)
205
+
206
+ def forward(
207
+ self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
208
+ input_pos: Optional[torch.Tensor] = None,
209
+ mask: Optional[torch.Tensor] = None
210
+ ):
211
+ bsz, seqlen, _ = x.shape
212
+ kv_size = self.n_kv_head * self.head_dim
213
+ xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
214
+
215
+ xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
216
+ xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
217
+ xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
218
+
219
+ xq = apply_rotary_emb(xq, freqs_cis)
220
+ xk = apply_rotary_emb(xk, freqs_cis)
221
+
222
+ xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
223
+
224
+ if self.kv_cache is not None:
225
+ keys, values = self.kv_cache.update(input_pos, xk, xv)
226
+ else:
227
+ keys, values = xk, xv
228
+ keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
229
+ values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
230
+
231
+ output = F.scaled_dot_product_attention(
232
+ xq, keys, values,
233
+ attn_mask=mask,
234
+ is_causal=True if mask is None else False, # is_causal=False is for KV cache
235
+ dropout_p=self.attn_dropout_p if self.training else 0)
236
+
237
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
238
+
239
+ output = self.resid_dropout(self.wo(output))
240
+ return output
241
+
242
+
243
+ class TransformerBlock(nn.Module):
244
+ def __init__(self, config: ModelArgs, drop_path: float):
245
+ super().__init__()
246
+ self.attention = Attention(config)
247
+ self.feed_forward = FeedForward(config)
248
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
249
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
250
+
251
+ def forward(
252
+ self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
253
+ h = x + self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)
254
+ out = h + self.feed_forward(self.ffn_norm(h))
255
+ return out
256
+
257
+
258
+ class Transformer(nn.Module):
259
+ def __init__(self, config: ModelArgs):
260
+ super().__init__()
261
+ self.config = config
262
+ self.vocab_size = config.vocab_size
263
+ self.n_layer = config.n_layer
264
+ self.block_size = config.block_size
265
+ self.num_classes = config.num_classes
266
+ self.model_type = config.model_type
267
+ self.cls_token_num = config.cls_token_num
268
+ if self.model_type == 'c2i':
269
+ self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
270
+ elif self.model_type == 't2i':
271
+ self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
272
+ else:
273
+ raise Exception("please check model type")
274
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
275
+ self.tok_dropout = nn.Dropout(config.token_dropout_p)
276
+
277
+ # transformer blocks
278
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
279
+ self.layers = torch.nn.ModuleList()
280
+ for layer_id in range(config.n_layer):
281
+ self.layers.append(TransformerBlock(config, dpr[layer_id]))
282
+
283
+ # output layer
284
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
285
+ self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
286
+
287
+ # 2d rotary pos embedding
288
+ grid_size = int(self.block_size ** 0.5)
289
+ assert grid_size * grid_size == self.block_size
290
+ self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)
291
+
292
+ # KVCache
293
+ self.max_batch_size = -1
294
+ self.max_seq_length = -1
295
+
296
+ self.initialize_weights()
297
+
298
+ def initialize_weights(self):
299
+ # Initialize nn.Linear and nn.Embedding
300
+ self.apply(self._init_weights)
301
+
302
+ # Zero-out output layers:
303
+ nn.init.constant_(self.output.weight, 0)
304
+
305
+ def _init_weights(self, module):
306
+ std = self.config.initializer_range
307
+ if isinstance(module, nn.Linear):
308
+ module.weight.data.normal_(mean=0.0, std=std)
309
+ if module.bias is not None:
310
+ module.bias.data.zero_()
311
+ elif isinstance(module, nn.Embedding):
312
+ module.weight.data.normal_(mean=0.0, std=std)
313
+
314
+ def setup_caches(self, max_batch_size, max_seq_length, dtype):
315
+ # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
316
+ # return
317
+ head_dim = self.config.dim // self.config.n_head
318
+ max_seq_length = find_multiple(max_seq_length, 8)
319
+ self.max_seq_length = max_seq_length
320
+ self.max_batch_size = max_batch_size
321
+ for b in self.layers:
322
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)
323
+
324
+ causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
325
+ self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
326
+ grid_size = int(self.config.block_size ** 0.5)
327
+ assert grid_size * grid_size == self.block_size
328
+ self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)
329
+
330
+ def forward(
331
+ self,
332
+ idx: torch.Tensor,
333
+ cond_idx: torch.Tensor, # cond_idx_or_embed
334
+ input_pos: Optional[torch.Tensor] = None,
335
+ targets: Optional[torch.Tensor] = None,
336
+ mask: Optional[torch.Tensor] = None,
337
+ valid: Optional[torch.Tensor] = None,
338
+ ):
339
+ if idx is not None and cond_idx is not None: # training or naive inference
340
+ cond_embeddings = self.cls_embedding(cond_idx, train=self.training)[:,:self.cls_token_num]
341
+ token_embeddings = self.tok_embeddings(idx)
342
+ token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
343
+ h = self.tok_dropout(token_embeddings)
344
+ self.freqs_cis = self.freqs_cis.to(h.device)
345
+ else:
346
+ if cond_idx is not None: # prefill in inference
347
+ token_embeddings = self.cls_embedding(cond_idx, train=self.training)[:,:self.cls_token_num]
348
+ else: # decode_n_tokens(kv cache) in inference
349
+ token_embeddings = self.tok_embeddings(idx)
350
+
351
+ bs = token_embeddings.shape[0]
352
+ mask = self.causal_mask[:bs, None, input_pos]
353
+ h = self.tok_dropout(token_embeddings)
354
+ self.freqs_cis = self.freqs_cis
355
+
356
+ if self.training:
357
+ freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
358
+ else:
359
+ freqs_cis = self.freqs_cis[input_pos]
360
+ # transformer blocks
361
+ for layer in self.layers:
362
+ h = layer(h, freqs_cis, input_pos, mask)
363
+
364
+ # output layers
365
+ h = self.norm(h)
366
+ logits = self.output(h).float()
367
+
368
+ if self.training:
369
+ logits = logits[:, self.cls_token_num - 1:].contiguous()
370
+
371
+ # if we are given some desired targets also calculate the loss
372
+ loss = None
373
+ if valid is not None:
374
+ loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
375
+ valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
376
+ loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
377
+ elif targets is not None:
378
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
379
+
380
+ return logits, loss
381
+
382
+
383
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
384
+ return list(self.layers)
385
+
386
+
387
+
388
+ #################################################################################
389
+ # Rotary Positional Embedding Functions #
390
+ #################################################################################
391
+ # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
392
+ def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
393
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
394
+ t = torch.arange(seq_len, device=freqs.device)
395
+ freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
396
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
397
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (cls_token_num+seq_len, head_dim // 2, 2)
398
+ cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
399
+ return cond_cache
400
+
401
+
402
+ def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
403
+ # split the dimension into half, one for x and one for y
404
+ half_dim = n_elem // 2
405
+ freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
406
+ t = torch.arange(grid_size, device=freqs.device)
407
+ freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2)
408
+ freqs_grid = torch.concat([
409
+ freqs[:, None, :].expand(-1, grid_size, -1),
410
+ freqs[None, :, :].expand(grid_size, -1, -1),
411
+ ], dim=-1) # (grid_size, grid_size, head_dim // 2)
412
+ cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
413
+ cache = cache_grid.flatten(0, 1)
414
+ cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
415
+ return cond_cache
416
+
417
+
418
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
419
+ # x: (bs, seq_len, n_head, head_dim)
420
+ # freqs_cis (seq_len, head_dim // 2, 2)
421
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
422
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
423
+ x_out2 = torch.stack([
424
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
425
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
426
+ ], dim=-1)
427
+ x_out2 = x_out2.flatten(3)
428
+ return x_out2.type_as(x)
429
+
430
+
431
+
432
+ #################################################################################
433
+ # GPT Configs #
434
+ #################################################################################
435
+ ### text-conditional
436
+ def GPT_7B(**kwargs):
437
+ return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B
438
+
439
+ def GPT_3B(**kwargs):
440
+ return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B
441
+
442
+ def GPT_1B(**kwargs):
443
+ return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B
444
+
445
+ ### class-conditional
446
+ def GPT_XXXL(**kwargs):
447
+ return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B
448
+
449
+ def GPT_XXL(**kwargs):
450
+ return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B
451
+
452
+ def GPT_XL(**kwargs):
453
+ return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M
454
+
455
+ def GPT_L(**kwargs):
456
+ return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M
457
+
458
+ def GPT_B(**kwargs):
459
+ return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M
460
+
461
+
462
+ GPT_models = {
463
+ 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
464
+ 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
465
+ }
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ torch
tokenizer_image/discriminator.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # taming-transformers: https://github.com/CompVis/taming-transformers
3
+ # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4
+ # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
5
+ import functools
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ try:
10
+ from kornia.filters import filter2d
11
+ except:
12
+ pass
13
+
14
+ #################################################################################
15
+ # PatchGAN #
16
+ #################################################################################
17
+ class PatchGANDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ norm_layer -- normalization layer
28
+ """
29
+ super(PatchGANDiscriminator, self).__init__()
30
+ if not use_actnorm:
31
+ norm_layer = nn.BatchNorm2d
32
+ else:
33
+ norm_layer = ActNorm
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ self.apply(self._init_weights)
66
+
67
+ def _init_weights(self, module):
68
+ if isinstance(module, nn.Conv2d):
69
+ nn.init.normal_(module.weight.data, 0.0, 0.02)
70
+ elif isinstance(module, nn.BatchNorm2d):
71
+ nn.init.normal_(module.weight.data, 1.0, 0.02)
72
+ nn.init.constant_(module.bias.data, 0)
73
+
74
+ def forward(self, input):
75
+ """Standard forward."""
76
+ return self.main(input)
77
+
78
+
79
+ class ActNorm(nn.Module):
80
+ def __init__(self, num_features, logdet=False, affine=True,
81
+ allow_reverse_init=False):
82
+ assert affine
83
+ super().__init__()
84
+ self.logdet = logdet
85
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
86
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
87
+ self.allow_reverse_init = allow_reverse_init
88
+
89
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
90
+
91
+ def initialize(self, input):
92
+ with torch.no_grad():
93
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
94
+ mean = (
95
+ flatten.mean(1)
96
+ .unsqueeze(1)
97
+ .unsqueeze(2)
98
+ .unsqueeze(3)
99
+ .permute(1, 0, 2, 3)
100
+ )
101
+ std = (
102
+ flatten.std(1)
103
+ .unsqueeze(1)
104
+ .unsqueeze(2)
105
+ .unsqueeze(3)
106
+ .permute(1, 0, 2, 3)
107
+ )
108
+
109
+ self.loc.data.copy_(-mean)
110
+ self.scale.data.copy_(1 / (std + 1e-6))
111
+
112
+ def forward(self, input, reverse=False):
113
+ if reverse:
114
+ return self.reverse(input)
115
+ if len(input.shape) == 2:
116
+ input = input[:,:,None,None]
117
+ squeeze = True
118
+ else:
119
+ squeeze = False
120
+
121
+ _, _, height, width = input.shape
122
+
123
+ if self.training and self.initialized.item() == 0:
124
+ self.initialize(input)
125
+ self.initialized.fill_(1)
126
+
127
+ h = self.scale * (input + self.loc)
128
+
129
+ if squeeze:
130
+ h = h.squeeze(-1).squeeze(-1)
131
+
132
+ if self.logdet:
133
+ log_abs = torch.log(torch.abs(self.scale))
134
+ logdet = height*width*torch.sum(log_abs)
135
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
136
+ return h, logdet
137
+
138
+ return h
139
+
140
+ def reverse(self, output):
141
+ if self.training and self.initialized.item() == 0:
142
+ if not self.allow_reverse_init:
143
+ raise RuntimeError(
144
+ "Initializing ActNorm in reverse direction is "
145
+ "disabled by default. Use allow_reverse_init=True to enable."
146
+ )
147
+ else:
148
+ self.initialize(output)
149
+ self.initialized.fill_(1)
150
+
151
+ if len(output.shape) == 2:
152
+ output = output[:,:,None,None]
153
+ squeeze = True
154
+ else:
155
+ squeeze = False
156
+
157
+ h = output / self.scale - self.loc
158
+
159
+ if squeeze:
160
+ h = h.squeeze(-1).squeeze(-1)
161
+ return h
162
+
163
+
164
+
165
+ #################################################################################
166
+ # StyleGAN #
167
+ #################################################################################
168
+ class StyleGANDiscriminator(nn.Module):
169
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
170
+ super().__init__()
171
+ channels = {
172
+ 4: 512,
173
+ 8: 512,
174
+ 16: 512,
175
+ 32: 512,
176
+ 64: 256 * channel_multiplier,
177
+ 128: 128 * channel_multiplier,
178
+ 256: 64 * channel_multiplier,
179
+ 512: 32 * channel_multiplier,
180
+ 1024: 16 * channel_multiplier,
181
+ }
182
+
183
+ log_size = int(math.log(image_size, 2))
184
+ in_channel = channels[image_size]
185
+
186
+ blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
187
+ for i in range(log_size, 2, -1):
188
+ out_channel = channels[2 ** (i - 1)]
189
+ blocks.append(DiscriminatorBlock(in_channel, out_channel))
190
+ in_channel = out_channel
191
+ self.blocks = nn.ModuleList(blocks)
192
+
193
+ self.final_conv = nn.Sequential(
194
+ nn.Conv2d(in_channel, channels[4], 3, padding=1),
195
+ leaky_relu(),
196
+ )
197
+ self.final_linear = nn.Sequential(
198
+ nn.Linear(channels[4] * 4 * 4, channels[4]),
199
+ leaky_relu(),
200
+ nn.Linear(channels[4], 1)
201
+ )
202
+
203
+ def forward(self, x):
204
+ for block in self.blocks:
205
+ x = block(x)
206
+ x = self.final_conv(x)
207
+ x = x.view(x.shape[0], -1)
208
+ x = self.final_linear(x)
209
+ return x
210
+
211
+
212
+ class DiscriminatorBlock(nn.Module):
213
+ def __init__(self, input_channels, filters, downsample=True):
214
+ super().__init__()
215
+ self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
216
+
217
+ self.net = nn.Sequential(
218
+ nn.Conv2d(input_channels, filters, 3, padding=1),
219
+ leaky_relu(),
220
+ nn.Conv2d(filters, filters, 3, padding=1),
221
+ leaky_relu()
222
+ )
223
+
224
+ self.downsample = nn.Sequential(
225
+ Blur(),
226
+ nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
227
+ ) if downsample else None
228
+
229
+ def forward(self, x):
230
+ res = self.conv_res(x)
231
+ x = self.net(x)
232
+ if exists(self.downsample):
233
+ x = self.downsample(x)
234
+ x = (x + res) * (1 / math.sqrt(2))
235
+ return x
236
+
237
+
238
+ class Blur(nn.Module):
239
+ def __init__(self):
240
+ super().__init__()
241
+ f = torch.Tensor([1, 2, 1])
242
+ self.register_buffer('f', f)
243
+
244
+ def forward(self, x):
245
+ f = self.f
246
+ f = f[None, None, :] * f [None, :, None]
247
+ return filter2d(x, f, normalized=True)
248
+
249
+
250
+ def leaky_relu(p=0.2):
251
+ return nn.LeakyReLU(p, inplace=True)
252
+
253
+
254
+ def exists(val):
255
+ return val is not None
tokenizer_image/discriminator_patchgan.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # taming-transformers: https://github.com/CompVis/taming-transformers
3
+ import functools
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class NLayerDiscriminator(nn.Module):
9
+ """Defines a PatchGAN discriminator as in Pix2Pix
10
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
11
+ """
12
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
13
+ """Construct a PatchGAN discriminator
14
+ Parameters:
15
+ input_nc (int) -- the number of channels in input images
16
+ ndf (int) -- the number of filters in the last conv layer
17
+ n_layers (int) -- the number of conv layers in the discriminator
18
+ norm_layer -- normalization layer
19
+ """
20
+ super(NLayerDiscriminator, self).__init__()
21
+ if not use_actnorm:
22
+ norm_layer = nn.BatchNorm2d
23
+ else:
24
+ norm_layer = ActNorm
25
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
26
+ use_bias = norm_layer.func != nn.BatchNorm2d
27
+ else:
28
+ use_bias = norm_layer != nn.BatchNorm2d
29
+
30
+ kw = 4
31
+ padw = 1
32
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
33
+ nf_mult = 1
34
+ nf_mult_prev = 1
35
+ for n in range(1, n_layers): # gradually increase the number of filters
36
+ nf_mult_prev = nf_mult
37
+ nf_mult = min(2 ** n, 8)
38
+ sequence += [
39
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
40
+ norm_layer(ndf * nf_mult),
41
+ nn.LeakyReLU(0.2, True)
42
+ ]
43
+
44
+ nf_mult_prev = nf_mult
45
+ nf_mult = min(2 ** n_layers, 8)
46
+ sequence += [
47
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
48
+ norm_layer(ndf * nf_mult),
49
+ nn.LeakyReLU(0.2, True)
50
+ ]
51
+
52
+ sequence += [
53
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
54
+ self.main = nn.Sequential(*sequence)
55
+
56
+ self.apply(self._init_weights)
57
+
58
+ def _init_weights(self, module):
59
+ if isinstance(module, nn.Conv2d):
60
+ nn.init.normal_(module.weight.data, 0.0, 0.02)
61
+ elif isinstance(module, nn.BatchNorm2d):
62
+ nn.init.normal_(module.weight.data, 1.0, 0.02)
63
+ nn.init.constant_(module.bias.data, 0)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
68
+
69
+
70
+ class ActNorm(nn.Module):
71
+ def __init__(self, num_features, logdet=False, affine=True,
72
+ allow_reverse_init=False):
73
+ assert affine
74
+ super().__init__()
75
+ self.logdet = logdet
76
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
77
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
78
+ self.allow_reverse_init = allow_reverse_init
79
+
80
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
81
+
82
+ def initialize(self, input):
83
+ with torch.no_grad():
84
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
85
+ mean = (
86
+ flatten.mean(1)
87
+ .unsqueeze(1)
88
+ .unsqueeze(2)
89
+ .unsqueeze(3)
90
+ .permute(1, 0, 2, 3)
91
+ )
92
+ std = (
93
+ flatten.std(1)
94
+ .unsqueeze(1)
95
+ .unsqueeze(2)
96
+ .unsqueeze(3)
97
+ .permute(1, 0, 2, 3)
98
+ )
99
+
100
+ self.loc.data.copy_(-mean)
101
+ self.scale.data.copy_(1 / (std + 1e-6))
102
+
103
+ def forward(self, input, reverse=False):
104
+ if reverse:
105
+ return self.reverse(input)
106
+ if len(input.shape) == 2:
107
+ input = input[:,:,None,None]
108
+ squeeze = True
109
+ else:
110
+ squeeze = False
111
+
112
+ _, _, height, width = input.shape
113
+
114
+ if self.training and self.initialized.item() == 0:
115
+ self.initialize(input)
116
+ self.initialized.fill_(1)
117
+
118
+ h = self.scale * (input + self.loc)
119
+
120
+ if squeeze:
121
+ h = h.squeeze(-1).squeeze(-1)
122
+
123
+ if self.logdet:
124
+ log_abs = torch.log(torch.abs(self.scale))
125
+ logdet = height*width*torch.sum(log_abs)
126
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
127
+ return h, logdet
128
+
129
+ return h
130
+
131
+ def reverse(self, output):
132
+ if self.training and self.initialized.item() == 0:
133
+ if not self.allow_reverse_init:
134
+ raise RuntimeError(
135
+ "Initializing ActNorm in reverse direction is "
136
+ "disabled by default. Use allow_reverse_init=True to enable."
137
+ )
138
+ else:
139
+ self.initialize(output)
140
+ self.initialized.fill_(1)
141
+
142
+ if len(output.shape) == 2:
143
+ output = output[:,:,None,None]
144
+ squeeze = True
145
+ else:
146
+ squeeze = False
147
+
148
+ h = output / self.scale - self.loc
149
+
150
+ if squeeze:
151
+ h = h.squeeze(-1).squeeze(-1)
152
+ return h
tokenizer_image/discriminator_stylegan.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py
3
+ # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4
+ # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ try:
9
+ from kornia.filters import filter2d
10
+ except:
11
+ pass
12
+
13
+ class Discriminator(nn.Module):
14
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
15
+ super().__init__()
16
+ channels = {
17
+ 4: 512,
18
+ 8: 512,
19
+ 16: 512,
20
+ 32: 512,
21
+ 64: 256 * channel_multiplier,
22
+ 128: 128 * channel_multiplier,
23
+ 256: 64 * channel_multiplier,
24
+ 512: 32 * channel_multiplier,
25
+ 1024: 16 * channel_multiplier,
26
+ }
27
+
28
+ log_size = int(math.log(image_size, 2))
29
+ in_channel = channels[image_size]
30
+
31
+ blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
32
+ for i in range(log_size, 2, -1):
33
+ out_channel = channels[2 ** (i - 1)]
34
+ blocks.append(DiscriminatorBlock(in_channel, out_channel))
35
+ in_channel = out_channel
36
+ self.blocks = nn.ModuleList(blocks)
37
+
38
+ self.final_conv = nn.Sequential(
39
+ nn.Conv2d(in_channel, channels[4], 3, padding=1),
40
+ leaky_relu(),
41
+ )
42
+ self.final_linear = nn.Sequential(
43
+ nn.Linear(channels[4] * 4 * 4, channels[4]),
44
+ leaky_relu(),
45
+ nn.Linear(channels[4], 1)
46
+ )
47
+
48
+ def forward(self, x):
49
+ for block in self.blocks:
50
+ x = block(x)
51
+ x = self.final_conv(x)
52
+ x = x.view(x.shape[0], -1)
53
+ x = self.final_linear(x)
54
+ return x
55
+
56
+
57
+ class DiscriminatorBlock(nn.Module):
58
+ def __init__(self, input_channels, filters, downsample=True):
59
+ super().__init__()
60
+ self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
61
+
62
+ self.net = nn.Sequential(
63
+ nn.Conv2d(input_channels, filters, 3, padding=1),
64
+ leaky_relu(),
65
+ nn.Conv2d(filters, filters, 3, padding=1),
66
+ leaky_relu()
67
+ )
68
+
69
+ self.downsample = nn.Sequential(
70
+ Blur(),
71
+ nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
72
+ ) if downsample else None
73
+
74
+ def forward(self, x):
75
+ res = self.conv_res(x)
76
+ x = self.net(x)
77
+ if exists(self.downsample):
78
+ x = self.downsample(x)
79
+ x = (x + res) * (1 / math.sqrt(2))
80
+ return x
81
+
82
+
83
+
84
+ class Blur(nn.Module):
85
+ def __init__(self):
86
+ super().__init__()
87
+ f = torch.Tensor([1, 2, 1])
88
+ self.register_buffer('f', f)
89
+
90
+ def forward(self, x):
91
+ f = self.f
92
+ f = f[None, None, :] * f [None, :, None]
93
+ return filter2d(x, f, normalized=True)
94
+
95
+
96
+ def leaky_relu(p=0.2):
97
+ return nn.LeakyReLU(p, inplace=True)
98
+
99
+
100
+ def exists(val):
101
+ return val is not None
tokenizer_image/lpips.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import os, hashlib
4
+ import requests
5
+ from tqdm import tqdm
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import models
10
+ from collections import namedtuple
11
+
12
+ URL_MAP = {
13
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
14
+ }
15
+
16
+ CKPT_MAP = {
17
+ "vgg_lpips": "vgg.pth"
18
+ }
19
+
20
+ MD5_MAP = {
21
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
22
+ }
23
+
24
+ def download(url, local_path, chunk_size=1024):
25
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
26
+ with requests.get(url, stream=True) as r:
27
+ total_size = int(r.headers.get("content-length", 0))
28
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
29
+ with open(local_path, "wb") as f:
30
+ for data in r.iter_content(chunk_size=chunk_size):
31
+ if data:
32
+ f.write(data)
33
+ pbar.update(chunk_size)
34
+
35
+
36
+ def md5_hash(path):
37
+ with open(path, "rb") as f:
38
+ content = f.read()
39
+ return hashlib.md5(content).hexdigest()
40
+
41
+
42
+ def get_ckpt_path(name, root, check=False):
43
+ assert name in URL_MAP
44
+ path = os.path.join(root, CKPT_MAP[name])
45
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
46
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
47
+ download(URL_MAP[name], path)
48
+ md5 = md5_hash(path)
49
+ assert md5 == MD5_MAP[name], md5
50
+ return path
51
+
52
+
53
+ class LPIPS(nn.Module):
54
+ # Learned perceptual metric
55
+ def __init__(self, use_dropout=True):
56
+ super().__init__()
57
+ self.scaling_layer = ScalingLayer()
58
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
59
+ self.net = vgg16(pretrained=True, requires_grad=False)
60
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
61
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
62
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
63
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
64
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
65
+ self.load_from_pretrained()
66
+ for param in self.parameters():
67
+ param.requires_grad = False
68
+
69
+ def load_from_pretrained(self, name="vgg_lpips"):
70
+ ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
71
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
72
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
73
+
74
+ @classmethod
75
+ def from_pretrained(cls, name="vgg_lpips"):
76
+ if name != "vgg_lpips":
77
+ raise NotImplementedError
78
+ model = cls()
79
+ ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
80
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
81
+ return model
82
+
83
+ def forward(self, input, target):
84
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
85
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
86
+ feats0, feats1, diffs = {}, {}, {}
87
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
88
+ for kk in range(len(self.chns)):
89
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
90
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
91
+
92
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
93
+ val = res[0]
94
+ for l in range(1, len(self.chns)):
95
+ val += res[l]
96
+ return val
97
+
98
+
99
+ class ScalingLayer(nn.Module):
100
+ def __init__(self):
101
+ super(ScalingLayer, self).__init__()
102
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
103
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
104
+
105
+ def forward(self, inp):
106
+ return (inp - self.shift) / self.scale
107
+
108
+
109
+ class NetLinLayer(nn.Module):
110
+ """ A single linear layer which does a 1x1 conv """
111
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
112
+ super(NetLinLayer, self).__init__()
113
+ layers = [nn.Dropout(), ] if (use_dropout) else []
114
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
115
+ self.model = nn.Sequential(*layers)
116
+
117
+
118
+ class vgg16(torch.nn.Module):
119
+ def __init__(self, requires_grad=False, pretrained=True):
120
+ super(vgg16, self).__init__()
121
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
122
+ self.slice1 = torch.nn.Sequential()
123
+ self.slice2 = torch.nn.Sequential()
124
+ self.slice3 = torch.nn.Sequential()
125
+ self.slice4 = torch.nn.Sequential()
126
+ self.slice5 = torch.nn.Sequential()
127
+ self.N_slices = 5
128
+ for x in range(4):
129
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
130
+ for x in range(4, 9):
131
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
132
+ for x in range(9, 16):
133
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
134
+ for x in range(16, 23):
135
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
136
+ for x in range(23, 30):
137
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
138
+ if not requires_grad:
139
+ for param in self.parameters():
140
+ param.requires_grad = False
141
+
142
+ def forward(self, X):
143
+ h = self.slice1(X)
144
+ h_relu1_2 = h
145
+ h = self.slice2(h)
146
+ h_relu2_2 = h
147
+ h = self.slice3(h)
148
+ h_relu3_3 = h
149
+ h = self.slice4(h)
150
+ h_relu4_3 = h
151
+ h = self.slice5(h)
152
+ h_relu5_3 = h
153
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
154
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
155
+ return out
156
+
157
+
158
+ def normalize_tensor(x,eps=1e-10):
159
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
160
+ return x/(norm_factor+eps)
161
+
162
+
163
+ def spatial_average(x, keepdim=True):
164
+ return x.mean([2,3],keepdim=keepdim)
tokenizer_image/reconstruction_vq_ddp.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torchvision import transforms
9
+ from tqdm import tqdm
10
+ import os
11
+ from PIL import Image
12
+ import numpy as np
13
+ import argparse
14
+ import itertools
15
+
16
+ from skimage.metrics import peak_signal_noise_ratio as psnr_loss
17
+ from skimage.metrics import structural_similarity as ssim_loss
18
+
19
+ from dataset.augmentation import center_crop_arr
20
+ from dataset.build import build_dataset
21
+ from tokenizer.tokenizer_image.vq_model import VQ_models
22
+
23
+
24
+
25
+ def create_npz_from_sample_folder(sample_dir, num=50000):
26
+ """
27
+ Builds a single .npz file from a folder of .png samples.
28
+ """
29
+ samples = []
30
+ for i in tqdm(range(num), desc="Building .npz file from samples"):
31
+ sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
32
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
33
+ samples.append(sample_np)
34
+ samples = np.stack(samples)
35
+ assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
36
+ npz_path = f"{sample_dir}.npz"
37
+ np.savez(npz_path, arr_0=samples)
38
+ print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
39
+ return npz_path
40
+
41
+
42
+
43
+ def main(args):
44
+ # Setup PyTorch:
45
+ assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
46
+ torch.set_grad_enabled(False)
47
+
48
+ # Setup DDP:
49
+ dist.init_process_group("nccl")
50
+ rank = dist.get_rank()
51
+ device = rank % torch.cuda.device_count()
52
+ seed = args.global_seed * dist.get_world_size() + rank
53
+ torch.manual_seed(seed)
54
+ torch.cuda.set_device(device)
55
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
56
+
57
+ # create and load model
58
+ vq_model = VQ_models[args.vq_model](
59
+ codebook_size=args.codebook_size,
60
+ codebook_embed_dim=args.codebook_embed_dim)
61
+ vq_model.to(device)
62
+ vq_model.eval()
63
+ checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
64
+ if "ema" in checkpoint: # ema
65
+ model_weight = checkpoint["ema"]
66
+ elif "model" in checkpoint: # ddp
67
+ model_weight = checkpoint["model"]
68
+ elif "state_dict" in checkpoint:
69
+ model_weight = checkpoint["state_dict"]
70
+ else:
71
+ raise Exception("please check model weight")
72
+ vq_model.load_state_dict(model_weight)
73
+ del checkpoint
74
+
75
+ # Create folder to save samples:
76
+ folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}"
77
+ f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}")
78
+ sample_folder_dir = f"{args.sample_dir}/{folder_name}"
79
+ if rank == 0:
80
+ os.makedirs(sample_folder_dir, exist_ok=True)
81
+ print(f"Saving .png samples at {sample_folder_dir}")
82
+ dist.barrier()
83
+
84
+ # Setup data:
85
+ transform = transforms.Compose([
86
+ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
87
+ transforms.ToTensor(),
88
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
89
+ ])
90
+
91
+ if args.dataset == 'imagenet':
92
+ dataset = build_dataset(args, transform=transform)
93
+ num_fid_samples = 50000
94
+ elif args.dataset == 'coco':
95
+ dataset = build_dataset(args, transform=transform)
96
+ num_fid_samples = 5000
97
+ else:
98
+ raise Exception("please check dataset")
99
+
100
+ sampler = DistributedSampler(
101
+ dataset,
102
+ num_replicas=dist.get_world_size(),
103
+ rank=rank,
104
+ shuffle=False,
105
+ seed=args.global_seed
106
+ )
107
+ loader = DataLoader(
108
+ dataset,
109
+ batch_size=args.per_proc_batch_size,
110
+ shuffle=False,
111
+ sampler=sampler,
112
+ num_workers=args.num_workers,
113
+ pin_memory=True,
114
+ drop_last=False
115
+ )
116
+
117
+ # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
118
+ n = args.per_proc_batch_size
119
+ global_batch_size = n * dist.get_world_size()
120
+
121
+ psnr_val_rgb = []
122
+ ssim_val_rgb = []
123
+ loader = tqdm(loader) if rank == 0 else loader
124
+ total = 0
125
+ for x, _ in loader:
126
+ if args.image_size_eval != args.image_size:
127
+ rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
128
+ else:
129
+ rgb_gts = x
130
+ rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1]
131
+ x = x.to(device, non_blocking=True)
132
+ with torch.no_grad():
133
+ latent, _, [_, _, indices] = vq_model.encode(x)
134
+ samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1]
135
+ if args.image_size_eval != args.image_size:
136
+ samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
137
+ samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
138
+
139
+ # Save samples to disk as individual .png files
140
+ for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
141
+ index = i * dist.get_world_size() + rank + total
142
+ Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
143
+ # metric
144
+ rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1]
145
+ psnr = psnr_loss(rgb_restored, rgb_gt)
146
+ ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
147
+ psnr_val_rgb.append(psnr)
148
+ ssim_val_rgb.append(ssim)
149
+
150
+ total += global_batch_size
151
+
152
+ # ------------------------------------
153
+ # Summary
154
+ # ------------------------------------
155
+ # Make sure all processes have finished saving their samples
156
+ dist.barrier()
157
+ world_size = dist.get_world_size()
158
+ gather_psnr_val = [None for _ in range(world_size)]
159
+ gather_ssim_val = [None for _ in range(world_size)]
160
+ dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
161
+ dist.all_gather_object(gather_ssim_val, ssim_val_rgb)
162
+
163
+ if rank == 0:
164
+ gather_psnr_val = list(itertools.chain(*gather_psnr_val))
165
+ gather_ssim_val = list(itertools.chain(*gather_ssim_val))
166
+ psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
167
+ ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
168
+ print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))
169
+
170
+ result_file = f"{sample_folder_dir}_results.txt"
171
+ print("writing results to {}".format(result_file))
172
+ with open(result_file, 'w') as f:
173
+ print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)
174
+
175
+ create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
176
+ print("Done.")
177
+
178
+ dist.barrier()
179
+ dist.destroy_process_group()
180
+
181
+
182
+ if __name__ == "__main__":
183
+ parser = argparse.ArgumentParser()
184
+ parser.add_argument("--data-path", type=str, required=True)
185
+ parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
186
+ parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
187
+ parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
188
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
189
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
190
+ parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
191
+ parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
192
+ parser.add_argument("--sample-dir", type=str, default="reconstructions")
193
+ parser.add_argument("--per-proc-batch-size", type=int, default=32)
194
+ parser.add_argument("--global-seed", type=int, default=0)
195
+ parser.add_argument("--num-workers", type=int, default=4)
196
+ args = parser.parse_args()
197
+ main(args)
tokenizer_image/vq_demo.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from tokenizer.tokenizer_image.vq_model import VQ_models
10
+ from dataset.augmentation import center_crop_arr
11
+
12
+
13
+ def main(args):
14
+ # Setup PyTorch:
15
+ torch.manual_seed(args.seed)
16
+ torch.set_grad_enabled(False)
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ # create and load model
20
+ model = VQ_models[args.vq_model](
21
+ codebook_size=args.codebook_size,
22
+ codebook_embed_dim=args.codebook_embed_dim)
23
+ model.to(device)
24
+ model.eval()
25
+ checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
26
+ if "ema" in checkpoint: # ema
27
+ model_weight = checkpoint["ema"]
28
+ elif "model" in checkpoint: # ddp
29
+ model_weight = checkpoint["model"]
30
+ elif "state_dict" in checkpoint:
31
+ model_weight = checkpoint["state_dict"]
32
+ else:
33
+ raise Exception("please check model weight")
34
+ model.load_state_dict(model_weight)
35
+ del checkpoint
36
+
37
+ # output dir
38
+ os.makedirs(args.output_dir, exist_ok=True)
39
+ out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix))
40
+ out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix))
41
+ out_path = out_path.replace('.png', '_{}.png'.format(args.suffix))
42
+ out_filename = out_path.split('/')[-1]
43
+ out_path = os.path.join(args.output_dir, out_filename)
44
+
45
+ # load image
46
+ pil_image = Image.open(args.image_path).convert("RGB")
47
+ img = center_crop_arr(pil_image, args.image_size)
48
+ # # preprocess
49
+ # size_org = img.size
50
+ # img = img.resize((input_size, input_size))
51
+ img = np.array(img) / 255.
52
+ x = 2.0 * img - 1.0 # x value is between [-1, 1]
53
+ x = torch.tensor(x)
54
+ x = x.unsqueeze(dim=0)
55
+ x = torch.einsum('nhwc->nchw', x)
56
+ x_input = x.float().to("cuda")
57
+
58
+ # inference
59
+ with torch.no_grad():
60
+ latent, _, [_, _, indices] = model.encode(x_input)
61
+ output = model.decode_code(indices, latent.shape) # output value is between [-1, 1]
62
+
63
+ # postprocess
64
+ output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0]
65
+ sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
66
+
67
+ # save
68
+ Image.fromarray(sample).save(out_path)
69
+ print("Reconstructed image is saved to {}".format(out_path))
70
+
71
+
72
+ if __name__ == "__main__":
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument("--image-path", type=str, default="assets/example.jpg")
75
+ parser.add_argument("--output-dir", type=str, default="output_vq_demo")
76
+ parser.add_argument("--suffix", type=str, default="tokenizer_image")
77
+ parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
78
+ parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
79
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
80
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
81
+ parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512)
82
+ parser.add_argument("--seed", type=int, default=0)
83
+ args = parser.parse_args()
84
+ main(args)
tokenizer_image/vq_loss.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # taming-transformers: https://github.com/CompVis/taming-transformers
3
+ # muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from tokenizer.tokenizer_image.lpips import LPIPS
9
+ from tokenizer.tokenizer_image.discriminator_patchgan import NLayerDiscriminator as PatchGANDiscriminator
10
+ from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator as StyleGANDiscriminator
11
+
12
+
13
+
14
+ def hinge_d_loss(logits_real, logits_fake):
15
+ loss_real = torch.mean(F.relu(1. - logits_real))
16
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
17
+ d_loss = 0.5 * (loss_real + loss_fake)
18
+ return d_loss
19
+
20
+
21
+ def vanilla_d_loss(logits_real, logits_fake):
22
+ loss_real = torch.mean(F.softplus(-logits_real))
23
+ loss_fake = torch.mean(F.softplus(logits_fake))
24
+ d_loss = 0.5 * (loss_real + loss_fake)
25
+ return d_loss
26
+
27
+
28
+ def non_saturating_d_loss(logits_real, logits_fake):
29
+ loss_real = torch.mean(F.binary_cross_entropy_with_logits(torch.ones_like(logits_real), logits_real))
30
+ loss_fake = torch.mean(F.binary_cross_entropy_with_logits(torch.zeros_like(logits_fake), logits_fake))
31
+ d_loss = 0.5 * (loss_real + loss_fake)
32
+ return d_loss
33
+
34
+
35
+ def hinge_gen_loss(logit_fake):
36
+ return -torch.mean(logit_fake)
37
+
38
+
39
+ def non_saturating_gen_loss(logit_fake):
40
+ return torch.mean(F.binary_cross_entropy_with_logits(torch.ones_like(logit_fake), logit_fake))
41
+
42
+
43
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
44
+ if global_step < threshold:
45
+ weight = value
46
+ return weight
47
+
48
+
49
+ class VQLoss(nn.Module):
50
+ def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
51
+ disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight = False,
52
+ gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
53
+ codebook_weight=1.0, perceptual_weight=1.0,
54
+ ):
55
+ super().__init__()
56
+ # discriminator loss
57
+ assert disc_type in ["patchgan", "stylegan"]
58
+ assert disc_loss in ["hinge", "vanilla", "non-saturating"]
59
+ if disc_type == "patchgan":
60
+ self.discriminator = PatchGANDiscriminator(
61
+ input_nc=disc_in_channels,
62
+ n_layers=disc_num_layers,
63
+ ndf=disc_dim,
64
+ )
65
+ elif disc_type == "stylegan":
66
+ self.discriminator = StyleGANDiscriminator(
67
+ input_nc=disc_in_channels,
68
+ image_size=image_size,
69
+ )
70
+ else:
71
+ raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
72
+ if disc_loss == "hinge":
73
+ self.disc_loss = hinge_d_loss
74
+ elif disc_loss == "vanilla":
75
+ self.disc_loss = vanilla_d_loss
76
+ elif disc_loss == "non-saturating":
77
+ self.disc_loss = non_saturating_d_loss
78
+ else:
79
+ raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
80
+ self.discriminator_iter_start = disc_start
81
+ self.disc_weight = disc_weight
82
+ self.disc_adaptive_weight = disc_adaptive_weight
83
+
84
+ assert gen_adv_loss in ["hinge", "non-saturating"]
85
+ # gen_adv_loss
86
+ if gen_adv_loss == "hinge":
87
+ self.gen_adv_loss = hinge_gen_loss
88
+ elif gen_adv_loss == "non-saturating":
89
+ self.gen_adv_loss = non_saturating_gen_loss
90
+ else:
91
+ raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")
92
+
93
+ # perceptual loss
94
+ self.perceptual_loss = LPIPS().eval()
95
+ self.perceptual_weight = perceptual_weight
96
+
97
+ # reconstruction loss
98
+ if reconstruction_loss == "l1":
99
+ self.rec_loss = F.l1_loss
100
+ elif reconstruction_loss == "l2":
101
+ self.rec_loss = F.mse_loss
102
+ else:
103
+ raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
104
+ self.rec_weight = reconstruction_weight
105
+
106
+ # codebook loss
107
+ self.codebook_weight = codebook_weight
108
+
109
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
110
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
111
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
112
+
113
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
114
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
115
+ return d_weight.detach()
116
+
117
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None,
118
+ logger=None, log_every=100):
119
+ # generator update
120
+ if optimizer_idx == 0:
121
+ # reconstruction loss
122
+ rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())
123
+
124
+ # perceptual loss
125
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
126
+ p_loss = torch.mean(p_loss)
127
+
128
+ # discriminator loss
129
+ logits_fake = self.discriminator(reconstructions.contiguous())
130
+ generator_adv_loss = self.gen_adv_loss(logits_fake)
131
+
132
+ if self.disc_adaptive_weight:
133
+ null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss
134
+ disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, generator_adv_loss, last_layer=last_layer)
135
+ else:
136
+ disc_adaptive_weight = 1
137
+ disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
138
+
139
+ loss = self.rec_weight * rec_loss + \
140
+ self.perceptual_weight * p_loss + \
141
+ disc_adaptive_weight * disc_weight * generator_adv_loss + \
142
+ codebook_loss[0] + codebook_loss[1] + codebook_loss[2]
143
+
144
+ if global_step % log_every == 0:
145
+ rec_loss = self.rec_weight * rec_loss
146
+ p_loss = self.perceptual_weight * p_loss
147
+ generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss
148
+ logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
149
+ f"vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
150
+ f"codebook_usage: {codebook_loss[3]:.4f}, generator_adv_loss: {generator_adv_loss:.4f}, "
151
+ f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}")
152
+ return loss
153
+
154
+ # discriminator update
155
+ if optimizer_idx == 1:
156
+ logits_real = self.discriminator(inputs.contiguous().detach())
157
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
158
+
159
+ disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
160
+ d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)
161
+
162
+ if global_step % log_every == 0:
163
+ logits_real = logits_real.detach().mean()
164
+ logits_fake = logits_fake.detach().mean()
165
+ logger.info(f"(Discriminator) "
166
+ f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
167
+ f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}")
168
+ return d_adversarial_loss
tokenizer_image/vq_model.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # taming-transformers: https://github.com/CompVis/taming-transformers
3
+ # maskgit: https://github.com/google-research/maskgit
4
+ from dataclasses import dataclass, field
5
+ from typing import List
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ @dataclass
13
+ class ModelArgs:
14
+ codebook_size: int = 16384
15
+ codebook_embed_dim: int = 8
16
+ codebook_l2_norm: bool = True
17
+ codebook_show_usage: bool = True
18
+ commit_loss_beta: float = 0.25
19
+ entropy_loss_ratio: float = 0.0
20
+
21
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
22
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
23
+ z_channels: int = 256
24
+ dropout_p: float = 0.0
25
+
26
+
27
+
28
+ class VQModel(nn.Module):
29
+ def __init__(self, config: ModelArgs):
30
+ super().__init__()
31
+ self.config = config
32
+ self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
33
+ self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
34
+
35
+ self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
36
+ config.commit_loss_beta, config.entropy_loss_ratio,
37
+ config.codebook_l2_norm, config.codebook_show_usage)
38
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
39
+ self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
40
+
41
+ def encode(self, x):
42
+ h = self.encoder(x)
43
+ h = self.quant_conv(h)
44
+ quant, emb_loss, info = self.quantize(h)
45
+ return quant, emb_loss, info
46
+
47
+ def decode(self, quant):
48
+ quant = self.post_quant_conv(quant)
49
+ dec = self.decoder(quant)
50
+ return dec
51
+
52
+ def decode_code(self, code_b, shape=None, channel_first=True):
53
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
54
+ dec = self.decode(quant_b)
55
+ return dec
56
+
57
+ def forward(self, input):
58
+ quant, diff, _ = self.encode(input)
59
+ dec = self.decode(quant)
60
+ return dec, diff
61
+
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
66
+ norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
67
+ super().__init__()
68
+ self.num_resolutions = len(ch_mult)
69
+ self.num_res_blocks = num_res_blocks
70
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
71
+
72
+ # downsampling
73
+ in_ch_mult = (1,) + tuple(ch_mult)
74
+ self.conv_blocks = nn.ModuleList()
75
+ for i_level in range(self.num_resolutions):
76
+ conv_block = nn.Module()
77
+ # res & attn
78
+ res_block = nn.ModuleList()
79
+ attn_block = nn.ModuleList()
80
+ block_in = ch*in_ch_mult[i_level]
81
+ block_out = ch*ch_mult[i_level]
82
+ for _ in range(self.num_res_blocks):
83
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
84
+ block_in = block_out
85
+ if i_level == self.num_resolutions - 1:
86
+ attn_block.append(AttnBlock(block_in, norm_type))
87
+ conv_block.res = res_block
88
+ conv_block.attn = attn_block
89
+ # downsample
90
+ if i_level != self.num_resolutions-1:
91
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
92
+ self.conv_blocks.append(conv_block)
93
+
94
+ # middle
95
+ self.mid = nn.ModuleList()
96
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
97
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
98
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
99
+
100
+ # end
101
+ self.norm_out = Normalize(block_in, norm_type)
102
+ self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
103
+
104
+
105
+ def forward(self, x):
106
+ h = self.conv_in(x)
107
+ # downsampling
108
+ for i_level, block in enumerate(self.conv_blocks):
109
+ for i_block in range(self.num_res_blocks):
110
+ h = block.res[i_block](h)
111
+ if len(block.attn) > 0:
112
+ h = block.attn[i_block](h)
113
+ if i_level != self.num_resolutions - 1:
114
+ h = block.downsample(h)
115
+
116
+ # middle
117
+ for mid_block in self.mid:
118
+ h = mid_block(h)
119
+
120
+ # end
121
+ h = self.norm_out(h)
122
+ h = nonlinearity(h)
123
+ h = self.conv_out(h)
124
+ return h
125
+
126
+
127
+
128
+ class Decoder(nn.Module):
129
+ def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
130
+ dropout=0.0, resamp_with_conv=True, out_channels=3):
131
+ super().__init__()
132
+ self.num_resolutions = len(ch_mult)
133
+ self.num_res_blocks = num_res_blocks
134
+
135
+ block_in = ch*ch_mult[self.num_resolutions-1]
136
+ # z to block_in
137
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
138
+
139
+ # middle
140
+ self.mid = nn.ModuleList()
141
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
142
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
143
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
144
+
145
+ # upsampling
146
+ self.conv_blocks = nn.ModuleList()
147
+ for i_level in reversed(range(self.num_resolutions)):
148
+ conv_block = nn.Module()
149
+ # res & attn
150
+ res_block = nn.ModuleList()
151
+ attn_block = nn.ModuleList()
152
+ block_out = ch*ch_mult[i_level]
153
+ for _ in range(self.num_res_blocks + 1):
154
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
155
+ block_in = block_out
156
+ if i_level == self.num_resolutions - 1:
157
+ attn_block.append(AttnBlock(block_in, norm_type))
158
+ conv_block.res = res_block
159
+ conv_block.attn = attn_block
160
+ # downsample
161
+ if i_level != 0:
162
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
163
+ self.conv_blocks.append(conv_block)
164
+
165
+ # end
166
+ self.norm_out = Normalize(block_in, norm_type)
167
+ self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
168
+
169
+ @property
170
+ def last_layer(self):
171
+ return self.conv_out.weight
172
+
173
+ def forward(self, z):
174
+ # z to block_in
175
+ h = self.conv_in(z)
176
+
177
+ # middle
178
+ for mid_block in self.mid:
179
+ h = mid_block(h)
180
+
181
+ # upsampling
182
+ for i_level, block in enumerate(self.conv_blocks):
183
+ for i_block in range(self.num_res_blocks + 1):
184
+ h = block.res[i_block](h)
185
+ if len(block.attn) > 0:
186
+ h = block.attn[i_block](h)
187
+ if i_level != self.num_resolutions - 1:
188
+ h = block.upsample(h)
189
+
190
+ # end
191
+ h = self.norm_out(h)
192
+ h = nonlinearity(h)
193
+ h = self.conv_out(h)
194
+ return h
195
+
196
+
197
+ class VectorQuantizer(nn.Module):
198
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
199
+ super().__init__()
200
+ self.n_e = n_e
201
+ self.e_dim = e_dim
202
+ self.beta = beta
203
+ self.entropy_loss_ratio = entropy_loss_ratio
204
+ self.l2_norm = l2_norm
205
+ self.show_usage = show_usage
206
+
207
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
208
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
209
+ if self.l2_norm:
210
+ self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
211
+ if self.show_usage:
212
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
213
+
214
+
215
+ def forward(self, z):
216
+ # reshape z -> (batch, height, width, channel) and flatten
217
+ z = torch.einsum('b c h w -> b h w c', z).contiguous()
218
+ z_flattened = z.view(-1, self.e_dim)
219
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
220
+
221
+ if self.l2_norm:
222
+ z = F.normalize(z, p=2, dim=-1)
223
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
224
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
225
+ else:
226
+ embedding = self.embedding.weight
227
+
228
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
229
+ torch.sum(embedding**2, dim=1) - 2 * \
230
+ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
231
+
232
+ min_encoding_indices = torch.argmin(d, dim=1)
233
+ z_q = embedding[min_encoding_indices].view(z.shape)
234
+ perplexity = None
235
+ min_encodings = None
236
+ vq_loss = None
237
+ commit_loss = None
238
+ entropy_loss = None
239
+ codebook_usage = 0
240
+
241
+ if self.show_usage and self.training:
242
+ cur_len = min_encoding_indices.shape[0]
243
+ self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
244
+ self.codebook_used[-cur_len:] = min_encoding_indices
245
+ codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
246
+
247
+ # compute loss for embedding
248
+ if self.training:
249
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
250
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
251
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
252
+
253
+ # preserve gradients
254
+ z_q = z + (z_q - z).detach()
255
+
256
+ # reshape back to match original input shape
257
+ z_q = torch.einsum('b h w c -> b c h w', z_q)
258
+
259
+ return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
260
+
261
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
262
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
263
+ if self.l2_norm:
264
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
265
+ else:
266
+ embedding = self.embedding.weight
267
+ z_q = embedding[indices] # (b*h*w, c)
268
+
269
+ if shape is not None:
270
+ if channel_first:
271
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
272
+ # reshape back to match original input shape
273
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
274
+ else:
275
+ z_q = z_q.view(shape)
276
+ return z_q
277
+
278
+
279
+ class ResnetBlock(nn.Module):
280
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
281
+ super().__init__()
282
+ self.in_channels = in_channels
283
+ out_channels = in_channels if out_channels is None else out_channels
284
+ self.out_channels = out_channels
285
+ self.use_conv_shortcut = conv_shortcut
286
+
287
+ self.norm1 = Normalize(in_channels, norm_type)
288
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
289
+ self.norm2 = Normalize(out_channels, norm_type)
290
+ self.dropout = nn.Dropout(dropout)
291
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
292
+
293
+ if self.in_channels != self.out_channels:
294
+ if self.use_conv_shortcut:
295
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
296
+ else:
297
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
298
+
299
+ def forward(self, x):
300
+ h = x
301
+ h = self.norm1(h)
302
+ h = nonlinearity(h)
303
+ h = self.conv1(h)
304
+ h = self.norm2(h)
305
+ h = nonlinearity(h)
306
+ h = self.dropout(h)
307
+ h = self.conv2(h)
308
+
309
+ if self.in_channels != self.out_channels:
310
+ if self.use_conv_shortcut:
311
+ x = self.conv_shortcut(x)
312
+ else:
313
+ x = self.nin_shortcut(x)
314
+ return x+h
315
+
316
+
317
+ class AttnBlock(nn.Module):
318
+ def __init__(self, in_channels, norm_type='group'):
319
+ super().__init__()
320
+ self.norm = Normalize(in_channels, norm_type)
321
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
322
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
323
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
324
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
325
+
326
+
327
+ def forward(self, x):
328
+ h_ = x
329
+ h_ = self.norm(h_)
330
+ q = self.q(h_)
331
+ k = self.k(h_)
332
+ v = self.v(h_)
333
+
334
+ # compute attention
335
+ b,c,h,w = q.shape
336
+ q = q.reshape(b,c,h*w)
337
+ q = q.permute(0,2,1) # b,hw,c
338
+ k = k.reshape(b,c,h*w) # b,c,hw
339
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
340
+ w_ = w_ * (int(c)**(-0.5))
341
+ w_ = F.softmax(w_, dim=2)
342
+
343
+ # attend to values
344
+ v = v.reshape(b,c,h*w)
345
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
346
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
347
+ h_ = h_.reshape(b,c,h,w)
348
+
349
+ h_ = self.proj_out(h_)
350
+
351
+ return x+h_
352
+
353
+
354
+ def nonlinearity(x):
355
+ # swish
356
+ return x*torch.sigmoid(x)
357
+
358
+
359
+ def Normalize(in_channels, norm_type='group'):
360
+ assert norm_type in ['group', 'batch']
361
+ if norm_type == 'group':
362
+ return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
363
+ elif norm_type == 'batch':
364
+ return nn.SyncBatchNorm(in_channels)
365
+
366
+
367
+ class Upsample(nn.Module):
368
+ def __init__(self, in_channels, with_conv):
369
+ super().__init__()
370
+ self.with_conv = with_conv
371
+ if self.with_conv:
372
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
373
+
374
+ def forward(self, x):
375
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
376
+ if self.with_conv:
377
+ x = self.conv(x)
378
+ return x
379
+
380
+
381
+ class Downsample(nn.Module):
382
+ def __init__(self, in_channels, with_conv):
383
+ super().__init__()
384
+ self.with_conv = with_conv
385
+ if self.with_conv:
386
+ # no asymmetric padding in torch conv, must do it ourselves
387
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
388
+
389
+ def forward(self, x):
390
+ if self.with_conv:
391
+ pad = (0,1,0,1)
392
+ x = F.pad(x, pad, mode="constant", value=0)
393
+ x = self.conv(x)
394
+ else:
395
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
396
+ return x
397
+
398
+
399
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
400
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
401
+ flat_affinity /= temperature
402
+ probs = F.softmax(flat_affinity, dim=-1)
403
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
404
+ if loss_type == "softmax":
405
+ target_probs = probs
406
+ else:
407
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
408
+ avg_probs = torch.mean(target_probs, dim=0)
409
+ avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
410
+ sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
411
+ loss = sample_entropy - avg_entropy
412
+ return loss
413
+
414
+
415
+ #################################################################################
416
+ # VQ Model Configs #
417
+ #################################################################################
418
+ def VQ_8(**kwargs):
419
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))
420
+
421
+ def VQ_16(**kwargs):
422
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))
423
+
424
+ VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
tokenizer_image/vq_train.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py
3
+ # nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
4
+ import torch
5
+ # the first flag below was False when we tested this script but True makes A100 training a lot faster:
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cudnn.allow_tf32 = True
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from torchvision.datasets import ImageFolder
13
+ from torchvision import transforms
14
+
15
+ import os
16
+ import time
17
+ import argparse
18
+ from glob import glob
19
+ from copy import deepcopy
20
+
21
+ from utils.logger import create_logger
22
+ from utils.distributed import init_distributed_mode
23
+ from utils.ema import update_ema, requires_grad
24
+ from dataset.augmentation import random_crop_arr
25
+ from dataset.build import build_dataset
26
+ from tokenizer.tokenizer_image.vq_model import VQ_models
27
+ from tokenizer.tokenizer_image.vq_loss import VQLoss
28
+
29
+ import warnings
30
+ warnings.filterwarnings('ignore')
31
+
32
+ #################################################################################
33
+ # Training Loop #
34
+ #################################################################################
35
+
36
+ def main(args):
37
+ """
38
+ Trains a new model.
39
+ """
40
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
41
+
42
+ # Setup DDP:
43
+ init_distributed_mode(args)
44
+ assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
45
+ rank = dist.get_rank()
46
+ device = rank % torch.cuda.device_count()
47
+ seed = args.global_seed * dist.get_world_size() + rank
48
+ torch.manual_seed(seed)
49
+ torch.cuda.set_device(device)
50
+
51
+ # Setup an experiment folder:
52
+ if rank == 0:
53
+ os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
54
+ experiment_index = len(glob(f"{args.results_dir}/*"))
55
+ model_string_name = args.vq_model.replace("/", "-")
56
+ experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
57
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
58
+ os.makedirs(checkpoint_dir, exist_ok=True)
59
+ logger = create_logger(experiment_dir)
60
+ logger.info(f"Experiment directory created at {experiment_dir}")
61
+
62
+ time_record = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
63
+ cloud_results_dir = f"{args.cloud_save_path}/{time_record}"
64
+ cloud_checkpoint_dir = f"{cloud_results_dir}/{experiment_index:03d}-{model_string_name}/checkpoints"
65
+ os.makedirs(cloud_checkpoint_dir, exist_ok=True)
66
+ logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}")
67
+
68
+ else:
69
+ logger = create_logger(None)
70
+
71
+ # training args
72
+ logger.info(f"{args}")
73
+
74
+ # training env
75
+ logger.info(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
76
+
77
+ # create and load model
78
+ vq_model = VQ_models[args.vq_model](
79
+ codebook_size=args.codebook_size,
80
+ codebook_embed_dim=args.codebook_embed_dim,
81
+ commit_loss_beta=args.commit_loss_beta,
82
+ entropy_loss_ratio=args.entropy_loss_ratio,
83
+ dropout_p=args.dropout_p,
84
+ )
85
+ logger.info(f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}")
86
+ if args.ema:
87
+ ema = deepcopy(vq_model).to(device) # Create an EMA of the model for use after training
88
+ requires_grad(ema, False)
89
+ logger.info(f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}")
90
+ vq_model = vq_model.to(device)
91
+
92
+ vq_loss = VQLoss(
93
+ disc_start=args.disc_start,
94
+ disc_weight=args.disc_weight,
95
+ disc_type=args.disc_type,
96
+ disc_loss=args.disc_loss,
97
+ gen_adv_loss=args.gen_loss,
98
+ image_size=args.image_size,
99
+ perceptual_weight=args.perceptual_weight,
100
+ reconstruction_weight=args.reconstruction_weight,
101
+ reconstruction_loss=args.reconstruction_loss,
102
+ codebook_weight=args.codebook_weight,
103
+ ).to(device)
104
+ logger.info(f"Discriminator Parameters: {sum(p.numel() for p in vq_loss.discriminator.parameters()):,}")
105
+
106
+ # initialize a GradScaler. If enabled=False scaler is a no-op
107
+ scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision =='fp16'))
108
+ scaler_disc = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision =='fp16'))
109
+ # Setup optimizer
110
+ optimizer = torch.optim.Adam(vq_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
111
+ optimizer_disc = torch.optim.Adam(vq_loss.discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
112
+
113
+ # Setup data:
114
+ transform = transforms.Compose([
115
+ transforms.Lambda(lambda pil_image: random_crop_arr(pil_image, args.image_size)),
116
+ transforms.RandomHorizontalFlip(),
117
+ transforms.ToTensor(),
118
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
119
+ ])
120
+ dataset = build_dataset(args, transform=transform)
121
+ sampler = DistributedSampler(
122
+ dataset,
123
+ num_replicas=dist.get_world_size(),
124
+ rank=rank,
125
+ shuffle=True,
126
+ seed=args.global_seed
127
+ )
128
+ loader = DataLoader(
129
+ dataset,
130
+ batch_size=int(args.global_batch_size // dist.get_world_size()),
131
+ shuffle=False,
132
+ sampler=sampler,
133
+ num_workers=args.num_workers,
134
+ pin_memory=True,
135
+ drop_last=True
136
+ )
137
+ logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
138
+
139
+
140
+ # Prepare models for training:
141
+ if args.vq_ckpt:
142
+ checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
143
+ vq_model.load_state_dict(checkpoint["model"])
144
+ if args.ema:
145
+ ema.load_state_dict(checkpoint["ema"])
146
+ optimizer.load_state_dict(checkpoint["optimizer"])
147
+ vq_loss.discriminator.load_state_dict(checkpoint["discriminator"])
148
+ optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
149
+ if not args.finetune:
150
+ train_steps = checkpoint["steps"] if "steps" in checkpoint else int(args.vq_ckpt.split('/')[-1].split('.')[0])
151
+ start_epoch = int(train_steps / int(len(dataset) / args.global_batch_size))
152
+ train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size))
153
+ else:
154
+ train_steps = 0
155
+ start_epoch = 0
156
+ del checkpoint
157
+ logger.info(f"Resume training from checkpoint: {args.vq_ckpt}")
158
+ logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}")
159
+ else:
160
+ train_steps = 0
161
+ start_epoch = 0
162
+ if args.ema:
163
+ update_ema(ema, vq_model, decay=0) # Ensure EMA is initialized with synced weights
164
+
165
+ if args.compile:
166
+ logger.info("compiling the model... (may take several minutes)")
167
+ vq_model = torch.compile(vq_model) # requires PyTorch 2.0
168
+
169
+ vq_model = DDP(vq_model.to(device), device_ids=[args.gpu])
170
+ vq_model.train()
171
+ if args.ema:
172
+ ema.eval() # EMA model should always be in eval mode
173
+ vq_loss = DDP(vq_loss.to(device), device_ids=[args.gpu])
174
+ vq_loss.train()
175
+
176
+ ptdtype = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.mixed_precision]
177
+
178
+ # Variables for monitoring/logging purposes:
179
+ log_steps = 0
180
+ running_loss = 0
181
+ start_time = time.time()
182
+
183
+ logger.info(f"Training for {args.epochs} epochs...")
184
+ for epoch in range(start_epoch, args.epochs):
185
+ sampler.set_epoch(epoch)
186
+ logger.info(f"Beginning epoch {epoch}...")
187
+ for x, y in loader:
188
+ imgs = x.to(device, non_blocking=True)
189
+
190
+ # generator training
191
+ optimizer.zero_grad()
192
+ with torch.cuda.amp.autocast(dtype=ptdtype):
193
+ recons_imgs, codebook_loss = vq_model(imgs)
194
+ loss_gen = vq_loss(codebook_loss, imgs, recons_imgs, optimizer_idx=0, global_step=train_steps+1,
195
+ last_layer=vq_model.module.decoder.last_layer,
196
+ logger=logger, log_every=args.log_every)
197
+ scaler.scale(loss_gen).backward()
198
+ if args.max_grad_norm != 0.0:
199
+ scaler.unscale_(optimizer)
200
+ torch.nn.utils.clip_grad_norm_(vq_model.parameters(), args.max_grad_norm)
201
+ scaler.step(optimizer)
202
+ scaler.update()
203
+ if args.ema:
204
+ update_ema(ema, vq_model.module._orig_mod if args.compile else vq_model.module)
205
+
206
+ # discriminator training
207
+ optimizer_disc.zero_grad()
208
+ with torch.cuda.amp.autocast(dtype=ptdtype):
209
+ loss_disc = vq_loss(codebook_loss, imgs, recons_imgs, optimizer_idx=1, global_step=train_steps+1,
210
+ logger=logger, log_every=args.log_every)
211
+ scaler_disc.scale(loss_disc).backward()
212
+ if args.max_grad_norm != 0.0:
213
+ scaler_disc.unscale_(optimizer_disc)
214
+ torch.nn.utils.clip_grad_norm_(vq_loss.module.discriminator.parameters(), args.max_grad_norm)
215
+ scaler_disc.step(optimizer_disc)
216
+ scaler_disc.update()
217
+
218
+ # # Log loss values:
219
+ running_loss += loss_gen.item() + loss_disc.item()
220
+
221
+ log_steps += 1
222
+ train_steps += 1
223
+ if train_steps % args.log_every == 0:
224
+ # Measure training speed:
225
+ torch.cuda.synchronize()
226
+ end_time = time.time()
227
+ steps_per_sec = log_steps / (end_time - start_time)
228
+ # Reduce loss history over all processes:
229
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
230
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
231
+ avg_loss = avg_loss.item() / dist.get_world_size()
232
+ logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
233
+ # Reset monitoring variables:
234
+ running_loss = 0
235
+ log_steps = 0
236
+ start_time = time.time()
237
+
238
+ # Save checkpoint:
239
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
240
+ if rank == 0:
241
+ if args.compile:
242
+ model_weight = vq_model.module._orig_mod.state_dict()
243
+ else:
244
+ model_weight = vq_model.module.state_dict()
245
+ checkpoint = {
246
+ "model": model_weight,
247
+ "optimizer": optimizer.state_dict(),
248
+ "discriminator": vq_loss.module.discriminator.state_dict(),
249
+ "optimizer_disc": optimizer_disc.state_dict(),
250
+ "steps": train_steps,
251
+ "args": args
252
+ }
253
+ if args.ema:
254
+ checkpoint["ema"] = ema.state_dict()
255
+ if not args.no_local_save:
256
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
257
+ torch.save(checkpoint, checkpoint_path)
258
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
259
+
260
+ cloud_checkpoint_path = f"{cloud_checkpoint_dir}/{train_steps:07d}.pt"
261
+ torch.save(checkpoint, cloud_checkpoint_path)
262
+ logger.info(f"Saved checkpoint in cloud to {cloud_checkpoint_path}")
263
+ dist.barrier()
264
+
265
+ vq_model.eval() # important! This disables randomized embedding dropout
266
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
267
+
268
+ logger.info("Done!")
269
+ dist.destroy_process_group()
270
+
271
+
272
+
273
+ if __name__ == "__main__":
274
+ parser = argparse.ArgumentParser()
275
+ parser.add_argument("--data-path", type=str, required=True)
276
+ parser.add_argument("--data-face-path", type=str, default=None, help="face datasets to improve vq model")
277
+ parser.add_argument("--cloud-save-path", type=str, required=True, help='please specify a cloud disk path, if not, local path')
278
+ parser.add_argument("--no-local-save", action='store_true', help='no save checkpoints to local path for limited disk volume')
279
+ parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
280
+ parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for resume training")
281
+ parser.add_argument("--finetune", action='store_true', help="finetune a pre-trained vq model")
282
+ parser.add_argument("--ema", action='store_true', help="whether using ema training")
283
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
284
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
285
+ parser.add_argument("--codebook-l2-norm", action='store_true', default=True, help="l2 norm codebook")
286
+ parser.add_argument("--codebook-weight", type=float, default=1.0, help="codebook loss weight for vector quantization")
287
+ parser.add_argument("--entropy-loss-ratio", type=float, default=0.0, help="entropy loss ratio in codebook loss")
288
+ parser.add_argument("--commit-loss-beta", type=float, default=0.25, help="commit loss beta in codebook loss")
289
+ parser.add_argument("--reconstruction-weight", type=float, default=1.0, help="reconstruction loss weight of image pixel")
290
+ parser.add_argument("--reconstruction-loss", type=str, default='l2', help="reconstruction loss type of image pixel")
291
+ parser.add_argument("--perceptual-weight", type=float, default=1.0, help="perceptual loss weight of LPIPS")
292
+ parser.add_argument("--disc-weight", type=float, default=0.5, help="discriminator loss weight for gan training")
293
+ parser.add_argument("--disc-start", type=int, default=20000, help="iteration to start discriminator training and loss")
294
+ parser.add_argument("--disc-type", type=str, choices=['patchgan', 'stylegan'], default='patchgan', help="discriminator type")
295
+ parser.add_argument("--disc-loss", type=str, choices=['hinge', 'vanilla', 'non-saturating'], default='hinge', help="discriminator loss")
296
+ parser.add_argument("--gen-loss", type=str, choices=['hinge', 'non-saturating'], default='hinge', help="generator loss for gan training")
297
+ parser.add_argument("--compile", action='store_true', default=False)
298
+ parser.add_argument("--dropout-p", type=float, default=0.0, help="dropout_p")
299
+ parser.add_argument("--results-dir", type=str, default="results_tokenizer_image")
300
+ parser.add_argument("--dataset", type=str, default='imagenet')
301
+ parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
302
+ parser.add_argument("--epochs", type=int, default=50)
303
+ parser.add_argument("--lr", type=float, default=1e-4)
304
+ parser.add_argument("--weight-decay", type=float, default=5e-2, help="Weight decay to use.")
305
+ parser.add_argument("--beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
306
+ parser.add_argument("--beta2", type=float, default=0.95, help="The beta2 parameter for the Adam optimizer.")
307
+ parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
308
+ parser.add_argument("--global-batch-size", type=int, default=128)
309
+ parser.add_argument("--global-seed", type=int, default=0)
310
+ parser.add_argument("--num-workers", type=int, default=16)
311
+ parser.add_argument("--log-every", type=int, default=100)
312
+ parser.add_argument("--ckpt-every", type=int, default=5000)
313
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
314
+ parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
315
+ args = parser.parse_args()
316
+ main(args)