ShoufaChen commited on
Commit
4bfb360
·
1 Parent(s): 4d20c2f
app.py CHANGED
@@ -8,12 +8,12 @@ torch.backends.cudnn.allow_tf32 = True
8
  torch.set_float32_matmul_precision('high')
9
  setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
10
  setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
11
-
12
  import time
13
  import argparse
14
  from tokenizer_image.vq_model import VQ_models
15
- from models.gpt import GPT_models
16
- from models.generate import generate
17
 
18
  device = "cuda"
19
 
@@ -38,46 +38,16 @@ def load_model(args):
38
  del checkpoint
39
  print(f"image tokenizer is loaded")
40
 
41
- # create and load gpt model
42
- precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
43
- latent_size = image_size // args.downsample_size
44
- gpt_model = GPT_models[args.gpt_model](
45
- vocab_size=args.codebook_size,
46
- block_size=latent_size ** 2,
47
- num_classes=args.num_classes,
48
- cls_token_num=args.cls_token_num,
49
- model_type=args.gpt_type,
50
- ).to(device=device, dtype=precision)
51
-
52
- checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
53
- if args.from_fsdp: # fspd
54
- model_weight = checkpoint
55
- elif "model" in checkpoint: # ddp
56
- model_weight = checkpoint["model"]
57
- elif "module" in checkpoint: # deepspeed
58
- model_weight = checkpoint["module"]
59
- elif "state_dict" in checkpoint:
60
- model_weight = checkpoint["state_dict"]
61
- else:
62
- raise Exception("please check model weight")
63
- # if 'freqs_cis' in model_weight:
64
- # model_weight.pop('freqs_cis')
65
- gpt_model.load_state_dict(model_weight, strict=False)
66
- gpt_model.eval()
67
- del checkpoint
68
  print(f"gpt model is loaded")
69
-
70
- if args.compile:
71
- print(f"compiling the model...")
72
- gpt_model = torch.compile(
73
- gpt_model,
74
- mode="reduce-overhead",
75
- fullgraph=True
76
- ) # requires PyTorch 2.0 (optional)
77
- else:
78
- print(f"no need to compile model in demo")
79
-
80
- return vq_model, gpt_model, image_size
81
 
82
 
83
  def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
@@ -85,20 +55,29 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
85
  latent_size = image_size // args.downsample_size
86
  # Labels to condition the model with (feel free to change):
87
  class_labels = [class_label for _ in range(n)]
88
- c_indices = torch.tensor(class_labels, device=device)
89
  qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
90
 
 
 
 
 
 
 
 
 
 
91
  t1 = time.time()
92
  torch.manual_seed(seed)
93
- index_sample = generate(
94
- gpt_model, c_indices, latent_size ** 2,
95
- cfg_scale=cfg_scale, cfg_interval=args.cfg_interval,
96
- temperature=temperature, top_k=top_k,
97
- top_p=top_p, sample_logits=True,
98
- )
99
  sampling_time = time.time() - t1
100
  print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
101
 
 
 
 
102
  t2 = time.time()
103
  samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
104
  decoder_time = time.time() - t2
@@ -110,7 +89,7 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
110
 
111
 
112
  parser = argparse.ArgumentParser()
113
- parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
114
  parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
115
  parser.add_argument("--from-fsdp", action='store_true')
116
  parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
@@ -129,7 +108,7 @@ parser.add_argument("--temperature", type=float, default=1.0, help="temperature
129
  parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
130
  args = parser.parse_args()
131
 
132
- vq_model, gpt_model, image_size = load_model(args)
133
 
134
  with gr.Blocks() as demo:
135
  gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
 
8
  torch.set_float32_matmul_precision('high')
9
  setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
10
  setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
11
+ from vllm import SamplingParams
12
  import time
13
  import argparse
14
  from tokenizer_image.vq_model import VQ_models
15
+ # from models.generate import generate
16
+ from serve.llm import LLM
17
 
18
  device = "cuda"
19
 
 
38
  del checkpoint
39
  print(f"image tokenizer is loaded")
40
 
41
+ # Create an LLM.
42
+ args.image_size = image_size
43
+ args.gpt_ckpt = f"{ckpt_folder}{gpt_ckpt}"
44
+ llm = LLM(
45
+ args=args,
46
+ model='serve/fake_json/{}.json'.format(args.gpt_model),
47
+ gpu_memory_utilization=0.6,
48
+ skip_tokenizer_init=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  print(f"gpt model is loaded")
50
+ return vq_model, llm, image_size
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
  def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
 
55
  latent_size = image_size // args.downsample_size
56
  # Labels to condition the model with (feel free to change):
57
  class_labels = [class_label for _ in range(n)]
 
58
  qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
59
 
60
+ prompt_token_ids = [[cind] for cind in class_labels]
61
+ if cfg_scale > 1.0:
62
+ prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
63
+
64
+ # Create a sampling params object.
65
+ sampling_params = SamplingParams(
66
+ temperature=temperature, top_p=top_p, top_k=top_k,
67
+ max_tokens=latent_size ** 2)
68
+
69
  t1 = time.time()
70
  torch.manual_seed(seed)
71
+ outputs = llm.generate(
72
+ prompt_token_ids=prompt_token_ids,
73
+ sampling_params=sampling_params,
74
+ use_tqdm=False)
 
 
75
  sampling_time = time.time() - t1
76
  print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
77
 
78
+ index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
79
+ if args.cfg_scale > 1.0:
80
+ index_sample = index_sample[:len(class_labels)]
81
  t2 = time.time()
82
  samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
83
  decoder_time = time.time() - t2
 
89
 
90
 
91
  parser = argparse.ArgumentParser()
92
+ parser.add_argument("--gpt-model", type=str, default="GPT-XL")
93
  parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
94
  parser.add_argument("--from-fsdp", action='store_true')
95
  parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
 
108
  parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
109
  args = parser.parse_args()
110
 
111
+ vq_model, llm, image_size = load_model(args)
112
 
113
  with gr.Blocks() as demo:
114
  gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
app_naive.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import gradio as gr
3
+ from imagenet_en_cn import IMAGENET_1K_CLASSES
4
+ from huggingface_hub import hf_hub_download
5
+ import torch
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cudnn.allow_tf32 = True
8
+ torch.set_float32_matmul_precision('high')
9
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
10
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
11
+
12
+ import time
13
+ import argparse
14
+ from tokenizer_image.vq_model import VQ_models
15
+ from models.gpt import GPT_models
16
+ from models.generate import generate
17
+
18
+ device = "cuda"
19
+
20
+ model2ckpt = {
21
+ "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384),
22
+ "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256),
23
+ }
24
+
25
+ def load_model(args):
26
+ ckpt_folder = "./"
27
+ vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
28
+ hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder)
29
+ hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder)
30
+ # create and load model
31
+ vq_model = VQ_models[args.vq_model](
32
+ codebook_size=args.codebook_size,
33
+ codebook_embed_dim=args.codebook_embed_dim)
34
+ vq_model.to(device)
35
+ vq_model.eval()
36
+ checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu")
37
+ vq_model.load_state_dict(checkpoint["model"])
38
+ del checkpoint
39
+ print(f"image tokenizer is loaded")
40
+
41
+ # create and load gpt model
42
+ precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
43
+ latent_size = image_size // args.downsample_size
44
+ gpt_model = GPT_models[args.gpt_model](
45
+ vocab_size=args.codebook_size,
46
+ block_size=latent_size ** 2,
47
+ num_classes=args.num_classes,
48
+ cls_token_num=args.cls_token_num,
49
+ model_type=args.gpt_type,
50
+ ).to(device=device, dtype=precision)
51
+
52
+ checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
53
+ if args.from_fsdp: # fspd
54
+ model_weight = checkpoint
55
+ elif "model" in checkpoint: # ddp
56
+ model_weight = checkpoint["model"]
57
+ elif "module" in checkpoint: # deepspeed
58
+ model_weight = checkpoint["module"]
59
+ elif "state_dict" in checkpoint:
60
+ model_weight = checkpoint["state_dict"]
61
+ else:
62
+ raise Exception("please check model weight")
63
+ # if 'freqs_cis' in model_weight:
64
+ # model_weight.pop('freqs_cis')
65
+ gpt_model.load_state_dict(model_weight, strict=False)
66
+ gpt_model.eval()
67
+ del checkpoint
68
+ print(f"gpt model is loaded")
69
+
70
+ if args.compile:
71
+ print(f"compiling the model...")
72
+ gpt_model = torch.compile(
73
+ gpt_model,
74
+ mode="reduce-overhead",
75
+ fullgraph=True
76
+ ) # requires PyTorch 2.0 (optional)
77
+ else:
78
+ print(f"no need to compile model in demo")
79
+
80
+ return vq_model, gpt_model, image_size
81
+
82
+
83
+ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
84
+ n = 4
85
+ latent_size = image_size // args.downsample_size
86
+ # Labels to condition the model with (feel free to change):
87
+ class_labels = [class_label for _ in range(n)]
88
+ c_indices = torch.tensor(class_labels, device=device)
89
+ qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
90
+
91
+ t1 = time.time()
92
+ torch.manual_seed(seed)
93
+ index_sample = generate(
94
+ gpt_model, c_indices, latent_size ** 2,
95
+ cfg_scale=cfg_scale, cfg_interval=args.cfg_interval,
96
+ temperature=temperature, top_k=top_k,
97
+ top_p=top_p, sample_logits=True,
98
+ )
99
+ sampling_time = time.time() - t1
100
+ print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
101
+
102
+ t2 = time.time()
103
+ samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
104
+ decoder_time = time.time() - t2
105
+ print(f"decoder takes about {decoder_time:.2f} seconds.")
106
+ # Convert to PIL.Image format:
107
+ samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
108
+ samples = [Image.fromarray(sample) for sample in samples]
109
+ return samples
110
+
111
+
112
+ parser = argparse.ArgumentParser()
113
+ parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
114
+ parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
115
+ parser.add_argument("--from-fsdp", action='store_true')
116
+ parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
117
+ parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
118
+ parser.add_argument("--compile", action='store_true', default=False)
119
+ parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
120
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
121
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
122
+ parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
123
+ parser.add_argument("--num-classes", type=int, default=1000)
124
+ parser.add_argument("--cfg-scale", type=float, default=4.0)
125
+ parser.add_argument("--cfg-interval", type=float, default=-1)
126
+ parser.add_argument("--seed", type=int, default=0)
127
+ parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
128
+ parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
129
+ parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
130
+ args = parser.parse_args()
131
+
132
+ vq_model, gpt_model, image_size = load_model(args)
133
+
134
+ with gr.Blocks() as demo:
135
+ gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
136
+
137
+ with gr.Tabs():
138
+ with gr.TabItem('Generate'):
139
+ with gr.Row():
140
+ with gr.Column():
141
+ # with gr.Row():
142
+ # image_size = gr.Radio(choices=[384], value=384, label='Peize Model Resolution')
143
+ with gr.Row():
144
+ i1k_class = gr.Dropdown(
145
+ list(IMAGENET_1K_CLASSES.values()),
146
+ value='Eskimo dog, husky [爱斯基摩犬,哈士奇]',
147
+ type="index", label='ImageNet-1K Class'
148
+ )
149
+ cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
150
+ top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=4000, label='Top-K')
151
+ top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
152
+ temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
153
+ seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
154
+ # seed = gr.Number(value=0, label='Seed')
155
+ button = gr.Button("Generate", variant="primary")
156
+ with gr.Column():
157
+ output = gr.Gallery(label='Generated Images', height=700)
158
+ button.click(infer, inputs=[cfg_scale, top_k, top_p, temperature, i1k_class, seed], outputs=[output])
159
+ demo.queue()
160
+ demo.launch(debug=True)
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- torch
 
 
1
+ vllm==0.4.1
2
+ torchvision==0.17.1
serve/README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## serving by vLLM
2
+
3
+ ### Install
4
+ ```
5
+ pip install vllm==0.4.1
6
+ pip install torchvision==0.17.1
7
+ ```
8
+
9
+ ### Demo
10
+ ```
11
+ cd ${THIS_REPO_ROOT}
12
+ python3 autoregressive/serve/sample_c2i.py --vq-ckpt /path/to/vq_ds16size16384dim8.pt --gpt-ckpt /path/to/GPT-B/checkpoints/1500000.pt --gpt-model GPT-B
13
+
14
+ ```
15
+
16
+
17
+ ### Comparison (A100)
18
+
19
+ Method | params | baseline(s) | vllm(s) | speed-up ratio
20
+ --- |:---:|:---:|:---:|:---:
21
+ GPT-B | 100M | 7.80 | 2.39 | 326 %
22
+ GPT-L | 300M | 13.72 | 3.48 | 380 %
23
+ GPT-XL | 700M | 19.76 | 4.84 | 408 %
24
+ GPT-XXL | 1.4B | 26.38 | 6.36 | 414 %
25
+ GPT-3B | 3.1B | - | - | -
26
+
27
+
28
+ ```
29
+ ### GPT-B
30
+ # 7.80 seconds
31
+ python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-24-20-56-19/002-GPT-B/checkpoints/1500000.pt
32
+
33
+ # 2.39 seconds
34
+ python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-24-20-56-19/002-GPT-B/checkpoints/1500000.pt
35
+
36
+
37
+ ### GPT-L
38
+ # 13.72 seconds
39
+ python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-27-14-27-57/011-GPT-L/checkpoints/1500000.pt --gpt-model GPT-L
40
+
41
+ # 3.48 seconds
42
+ python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-27-14-27-57/011-GPT-L/checkpoints/1500000.pt --gpt-model GPT-L
43
+
44
+
45
+ ### GPT-XL
46
+ # 19.76 seconds
47
+ python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-05-05-13-15-40/000-GPT-XL/checkpoints/1500000.pt --gpt-model GPT-XL
48
+
49
+ # 4.84 seconds
50
+ python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-05-05-13-15-40/000-GPT-XL/checkpoints/1500000.pt --gpt-model GPT-XL
51
+
52
+
53
+ ### GPT-XXL
54
+ # 26.38 seconds
55
+ python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/20240506150815-GPT-XXXL/0125000/consolidated.pth --from-fsdp --gpt-model GPT-XXXL
56
+
57
+ # 6.36 seconds
58
+ python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/20240506150815-GPT-XXXL/0125000/consolidated.pth --from-fsdp --gpt-model GPT-XXXL
59
+
60
+
61
+ ```
62
+
63
+ In 3B model, head size 100 is not supported by PagedAttention, supported head sizes are: [64, 80, 96, 112, 128, 256]
serve/gpt_model.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from vllm.model_executor.layers.layernorm import RMSNorm
8
+ from vllm.model_executor.layers.activation import SiluAndMul
9
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
10
+ from vllm.sequence import SamplerOutput
11
+
12
+ from vllm.attention import AttentionMetadata
13
+ from vllm.attention import Attention as pagedAttention
14
+
15
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
+ from serve.sampler import Sampler
17
+
18
+ def find_multiple(n: int, k: int):
19
+ if n % k == 0:
20
+ return n
21
+ return n + k - (n % k)
22
+
23
+ @dataclass
24
+ class ModelArgs:
25
+ dim: int = 4096
26
+ n_layer: int = 32
27
+ n_head: int = 32
28
+ n_kv_head: Optional[int] = None
29
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
30
+ ffn_dim_multiplier: Optional[float] = None
31
+ rope_base: float = 10000
32
+ norm_eps: float = 1e-5
33
+ initializer_range: float = 0.02
34
+
35
+ num_classes: int = 1000
36
+ class_dropout_prob: float = 0.1
37
+ model_type: str = 'c2i'
38
+ cfg_scale: float = 4.0
39
+
40
+ vocab_size: int = 16384
41
+ cls_token_num: int = 1
42
+ block_size: int = 256
43
+ max_batch_size: int = 32
44
+ max_seq_len: int = 2048
45
+
46
+
47
+ #################################################################################
48
+ # Embedding Layers for Class Labels #
49
+ #################################################################################
50
+ class LabelEmbedder(nn.Module):
51
+ """
52
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
53
+ """
54
+ def __init__(self, num_classes, hidden_size, dropout_prob):
55
+ super().__init__()
56
+ use_cfg_embedding = dropout_prob > 0
57
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
58
+ self.num_classes = num_classes
59
+ self.dropout_prob = dropout_prob
60
+
61
+ # def token_drop(self, labels, force_drop_ids=None):
62
+ # """
63
+ # Drops labels to enable classifier-free guidance.
64
+ # """
65
+ # if force_drop_ids is None:
66
+ # drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
67
+ # else:
68
+ # drop_ids = force_drop_ids == 1
69
+ # labels = torch.where(drop_ids, self.num_classes, labels)
70
+ # return labels
71
+
72
+ # def forward(self, labels, train, force_drop_ids=None):
73
+ def forward(self, labels):
74
+ # use_dropout = self.dropout_prob > 0
75
+ # if (train and use_dropout) or (force_drop_ids is not None):
76
+ # labels = self.token_drop(labels, force_drop_ids)
77
+ embeddings = self.embedding_table(labels)
78
+ return embeddings
79
+
80
+
81
+ #################################################################################
82
+ # GPT Model #
83
+ #################################################################################
84
+ # class RMSNorm(torch.nn.Module):
85
+ # def __init__(self, dim: int, eps: float = 1e-5):
86
+ # super().__init__()
87
+ # self.eps = eps
88
+ # self.weight = nn.Parameter(torch.ones(dim))
89
+
90
+ # def _norm(self, x):
91
+ # return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
92
+
93
+ # def forward(self, x):
94
+ # output = self._norm(x.float()).type_as(x)
95
+ # return output * self.weight
96
+
97
+
98
+ class FeedForward(nn.Module):
99
+ def __init__(self, config: ModelArgs):
100
+ super().__init__()
101
+ hidden_dim = 4 * config.dim
102
+ hidden_dim = int(2 * hidden_dim / 3)
103
+ # custom dim factor multiplier
104
+ if config.ffn_dim_multiplier is not None:
105
+ hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
106
+ hidden_dim = find_multiple(hidden_dim, config.multiple_of)
107
+
108
+ # self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
109
+ # self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
110
+ self.w_merged = nn.Linear(config.dim, hidden_dim * 2, bias=False)
111
+ self.act_fn = SiluAndMul()
112
+
113
+ self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
114
+ # self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
115
+
116
+ # def forward(self, x):
117
+ # return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
118
+
119
+ def forward(self, x):
120
+ x = self.w_merged(x)
121
+ x = self.act_fn(x)
122
+ x = self.w2(x)
123
+ # return self.ffn_dropout(x)
124
+ return x
125
+
126
+
127
+ class Attention(nn.Module):
128
+ def __init__(self, config: ModelArgs):
129
+ super().__init__()
130
+ assert config.dim % config.n_head == 0
131
+ self.dim = config.dim
132
+ self.head_dim = config.dim // config.n_head
133
+ self.n_head = config.n_head
134
+ self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
135
+ total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
136
+
137
+ # key, query, value projections for all heads, but in a batch
138
+ self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
139
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
140
+
141
+ # pagedAttention
142
+ self.attn = pagedAttention(self.n_head,
143
+ self.head_dim,
144
+ self.head_dim**-0.5,
145
+ num_kv_heads=self.n_kv_head,
146
+ )
147
+
148
+ # 2d rotary pos embedding
149
+ grid_size = int(config.block_size ** 0.5)
150
+ assert grid_size * grid_size == config.block_size
151
+ freqs_cis = precompute_freqs_cis_2d(grid_size, config.dim // config.n_head, config.rope_base, config.cls_token_num)
152
+ self.register_buffer('freqs_cis', freqs_cis)
153
+
154
+
155
+ def forward(
156
+ self,
157
+ x: torch.Tensor,
158
+ positions: torch.Tensor,
159
+ kv_cache: torch.Tensor,
160
+ attn_metadata: AttentionMetadata,
161
+ ):
162
+ kv_size = self.n_kv_head * self.head_dim
163
+ xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
164
+
165
+ xq = xq.view(*xq.shape[:-1], 1, self.n_head, self.head_dim)
166
+ xk = xk.view(*xk.shape[:-1], 1, self.n_kv_head, self.head_dim)
167
+ freqs_cis = self.freqs_cis[positions].unsqueeze(1)
168
+ xq = apply_rotary_emb_bs(xq, freqs_cis)
169
+ xk = apply_rotary_emb_bs(xk, freqs_cis)
170
+ xq = xq.flatten(1)
171
+ xk = xk.flatten(1)
172
+
173
+ output = self.attn(xq, xk, xv, kv_cache, attn_metadata)
174
+ output = self.wo(output)
175
+
176
+ return output
177
+
178
+
179
+
180
+ class TransformerBlock(nn.Module):
181
+ def __init__(self, config: ModelArgs):
182
+ super().__init__()
183
+ self.attention = Attention(config)
184
+ self.feed_forward = FeedForward(config)
185
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
186
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
187
+
188
+ def forward(self, x: torch.Tensor, positions: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata):
189
+ h = x + self.attention(self.attention_norm(x), positions, kv_cache, attn_metadata)
190
+ out = h + self.feed_forward(self.ffn_norm(h))
191
+ return out
192
+
193
+
194
+ class Transformer(nn.Module):
195
+ def __init__(self, config: ModelArgs):
196
+ super().__init__()
197
+ self.config = config
198
+ self.vocab_size = config.vocab_size
199
+ self.n_layer = config.n_layer
200
+ self.block_size = config.block_size
201
+ self.num_classes = config.num_classes
202
+ self.model_type = config.model_type
203
+ self.cls_token_num = config.cls_token_num
204
+ self.cfg_scale = config.cfg_scale
205
+ if self.model_type == 'c2i':
206
+ self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
207
+ else:
208
+ raise Exception("vllm only supports c2i now, please check model type")
209
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
210
+
211
+ self.layers = torch.nn.ModuleList()
212
+ for layer_id in range(config.n_layer):
213
+ self.layers.append(TransformerBlock(config))
214
+
215
+ # output layer
216
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
217
+ self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
218
+
219
+ self.logits_processor = LogitsProcessor(config.vocab_size)
220
+
221
+ self.sampler = Sampler(config.cfg_scale)
222
+
223
+ def forward(
224
+ self,
225
+ input_ids: torch.Tensor=None,
226
+ positions: torch.Tensor=None,
227
+ kv_caches: List[torch.Tensor]=None,
228
+ attn_metadata: AttentionMetadata=None,
229
+ ):
230
+ # if positions.max() == 0: # prefill in inference
231
+ # token_embeddings = self.cls_embedding(input_ids)
232
+ # else: # decode_n_tokens(kv cache) in inference
233
+ # token_embeddings = self.tok_embeddings(input_ids)
234
+ cond_ids = torch.clamp(input_ids, max=self.num_classes)
235
+ token_embeddings = self.cls_embedding(cond_ids) * (positions.max() == 0) + \
236
+ self.tok_embeddings(input_ids) * (positions.max() != 0)
237
+
238
+ hh = token_embeddings
239
+ # transformer blocks
240
+ for layer_id, layer in enumerate(self.layers):
241
+ hh = layer(hh, positions, kv_caches[layer_id], attn_metadata)
242
+
243
+ # output layers
244
+ hh = self.norm(hh)
245
+ return hh
246
+
247
+ def compute_logits(self, hidden_states: torch.Tensor,
248
+ sampling_metadata: SamplingMetadata) -> torch.Tensor:
249
+ logits = self.logits_processor(self.output.weight, hidden_states, sampling_metadata)
250
+ return logits
251
+
252
+ def sample(
253
+ self,
254
+ logits: torch.Tensor,
255
+ sampling_metadata: SamplingMetadata,
256
+ ) -> Optional[SamplerOutput]:
257
+ next_tokens = self.sampler(logits, sampling_metadata)
258
+ return next_tokens
259
+
260
+
261
+ def custom_load_state_dict(self, model_weights):
262
+ model_weights = model_weights.copy()
263
+ for layer_id in range(len(self.layers)):
264
+ branch1 = f'layers.{layer_id}.feed_forward.w1.weight'
265
+ branch3 = f'layers.{layer_id}.feed_forward.w3.weight'
266
+ branch_merged = f'layers.{layer_id}.feed_forward.w_merged.weight'
267
+ model_weights[branch_merged] = torch.cat(
268
+ [model_weights[branch1], model_weights[branch3]], dim=0
269
+ )
270
+ model_weights.pop(branch1)
271
+ model_weights.pop(branch3)
272
+
273
+ if 'freqs_cis' in model_weights:
274
+ model_weights.pop('freqs_cis')
275
+
276
+ self.load_state_dict(model_weights, strict=False)
277
+
278
+
279
+
280
+ #################################################################################
281
+ # Rotary Positional Embedding Functions #
282
+ #################################################################################
283
+ # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
284
+ def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
285
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
286
+ t = torch.arange(seq_len, device=freqs.device)
287
+ freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
288
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
289
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (cls_token_num+seq_len, head_dim // 2, 2)
290
+ cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
291
+ return cond_cache
292
+
293
+
294
+ def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
295
+ # split the dimension into half, one for x and one for y
296
+ half_dim = n_elem // 2
297
+ freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
298
+ t = torch.arange(grid_size, device=freqs.device)
299
+ freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2)
300
+ freqs_grid = torch.concat([
301
+ freqs[:, None, :].expand(-1, grid_size, -1),
302
+ freqs[None, :, :].expand(grid_size, -1, -1),
303
+ ], dim=-1) # (grid_size, grid_size, head_dim // 2)
304
+ cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
305
+ cache = cache_grid.flatten(0, 1)
306
+ cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
307
+ return cond_cache
308
+
309
+
310
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
311
+ # x: (bs, seq_len, n_head, head_dim)
312
+ # freqs_cis (seq_len, head_dim // 2, 2)
313
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
314
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
315
+ x_out2 = torch.stack([
316
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
317
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
318
+ ], dim=-1)
319
+ x_out2 = x_out2.flatten(3)
320
+ return x_out2.type_as(x)
321
+
322
+
323
+ def apply_rotary_emb_bs(x: torch.Tensor, freqs_cis: torch.Tensor):
324
+ # x: (bs, seq_len, n_head, head_dim)
325
+ # freqs_cis (seq_len, head_dim // 2, 2)
326
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
327
+ freqs_cis = freqs_cis.view(xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2) # (bs, seq_len, 1, head_dim//2, 2)
328
+ x_out2 = torch.stack([
329
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
330
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
331
+ ], dim=-1)
332
+ x_out2 = x_out2.flatten(3)
333
+ return x_out2.type_as(x)
334
+
335
+
336
+ #################################################################################
337
+ # GPT Configs #
338
+ #################################################################################
339
+ ### text-conditional
340
+ def GPT_7B(**kwargs):
341
+ return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B
342
+
343
+ def GPT_3B(**kwargs):
344
+ return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B
345
+
346
+ def GPT_1B(**kwargs):
347
+ return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B
348
+
349
+ ### class-conditional
350
+ def GPT_XXXL(**kwargs):
351
+ return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B
352
+
353
+ def GPT_XXL(**kwargs):
354
+ return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B
355
+
356
+ def GPT_XL(**kwargs):
357
+ return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M
358
+
359
+ def GPT_L(**kwargs):
360
+ return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M
361
+
362
+ def GPT_B(**kwargs):
363
+ return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M
364
+
365
+
366
+ GPT_models = {
367
+ 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
368
+ 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
369
+ }
serve/gpu_executor.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Set, Tuple, Optional, Set
2
+ import argparse
3
+
4
+ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
5
+ ModelConfig, ParallelConfig, SchedulerConfig,
6
+ SpeculativeConfig, VisionLanguageConfig)
7
+ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
8
+ from vllm.logger import init_logger
9
+ from vllm.lora.request import LoRARequest
10
+ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
11
+ from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
12
+ make_async)
13
+
14
+ logger = init_logger(__name__)
15
+
16
+
17
+ class GPUExecutor(ExecutorBase):
18
+ def __init__(
19
+ self,
20
+ args: argparse.ArgumentParser,
21
+ model_config: ModelConfig,
22
+ cache_config: CacheConfig,
23
+ parallel_config: ParallelConfig,
24
+ scheduler_config: SchedulerConfig,
25
+ device_config: DeviceConfig,
26
+ load_config: LoadConfig,
27
+ lora_config: Optional[LoRAConfig],
28
+ vision_language_config: Optional[VisionLanguageConfig],
29
+ speculative_config: Optional[SpeculativeConfig],
30
+ ) -> None:
31
+ self.args = args
32
+ self.model_config = model_config
33
+ self.cache_config = cache_config
34
+ self.lora_config = lora_config
35
+ self.load_config = load_config
36
+ self.parallel_config = parallel_config
37
+ self.scheduler_config = scheduler_config
38
+ self.device_config = device_config
39
+ self.vision_language_config = vision_language_config
40
+ self.speculative_config = speculative_config
41
+
42
+ self._init_executor()
43
+
44
+ def _init_executor(self) -> None:
45
+ """Initialize the worker and load the model.
46
+
47
+ If speculative decoding is enabled, we instead create the speculative
48
+ worker.
49
+ """
50
+ if self.speculative_config is None:
51
+ self._init_non_spec_worker()
52
+ else:
53
+ self._init_spec_worker()
54
+
55
+ def _init_non_spec_worker(self):
56
+ # Lazy import the Worker to avoid importing torch.cuda/xformers
57
+ # before CUDA_VISIBLE_DEVICES is set in the Worker
58
+ # from vllm.worker.worker import Worker
59
+ from serve.worker import Worker
60
+
61
+ assert self.parallel_config.world_size == 1, (
62
+ "GPUExecutor only supports single GPU.")
63
+
64
+ distributed_init_method = get_distributed_init_method(
65
+ get_ip(), get_open_port())
66
+ self.driver_worker = Worker(
67
+ model_config=self.model_config,
68
+ parallel_config=self.parallel_config,
69
+ scheduler_config=self.scheduler_config,
70
+ device_config=self.device_config,
71
+ cache_config=self.cache_config,
72
+ load_config=self.load_config,
73
+ local_rank=0,
74
+ rank=0,
75
+ distributed_init_method=distributed_init_method,
76
+ lora_config=self.lora_config,
77
+ vision_language_config=self.vision_language_config,
78
+ is_driver_worker=True,
79
+ )
80
+ self.driver_worker.init_device()
81
+ self.driver_worker.load_model(self.args)
82
+
83
+ def _init_spec_worker(self):
84
+ """Initialize a SpecDecodeWorker, using a draft model for proposals.
85
+ """
86
+ assert self.speculative_config is not None
87
+
88
+ from vllm.spec_decode.multi_step_worker import MultiStepWorker
89
+ from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
90
+ from vllm.worker.worker import Worker
91
+
92
+ distributed_init_method = get_distributed_init_method(
93
+ get_ip(), get_open_port())
94
+
95
+ target_worker = Worker(
96
+ model_config=self.model_config,
97
+ parallel_config=self.parallel_config,
98
+ scheduler_config=self.scheduler_config,
99
+ device_config=self.device_config,
100
+ cache_config=self.cache_config,
101
+ load_config=self.load_config,
102
+ local_rank=0,
103
+ rank=0,
104
+ distributed_init_method=distributed_init_method,
105
+ lora_config=self.lora_config,
106
+ vision_language_config=self.vision_language_config,
107
+ is_driver_worker=True,
108
+ )
109
+
110
+ draft_worker = MultiStepWorker(
111
+ model_config=self.speculative_config.draft_model_config,
112
+ parallel_config=self.speculative_config.draft_parallel_config,
113
+ scheduler_config=self.scheduler_config,
114
+ device_config=self.device_config,
115
+ cache_config=self.cache_config,
116
+ load_config=self.load_config,
117
+ local_rank=0,
118
+ rank=0,
119
+ distributed_init_method=distributed_init_method,
120
+ lora_config=self.lora_config,
121
+ vision_language_config=self.vision_language_config,
122
+ is_driver_worker=True,
123
+ )
124
+
125
+ spec_decode_worker = SpecDecodeWorker.from_workers(
126
+ proposer_worker=draft_worker, scorer_worker=target_worker)
127
+
128
+ assert self.parallel_config.world_size == 1, (
129
+ "GPUExecutor only supports single GPU.")
130
+
131
+ self.driver_worker = spec_decode_worker
132
+
133
+ # Load model handled in spec decode worker.
134
+ self.driver_worker.init_device()
135
+
136
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
137
+ """Determine the number of available KV blocks by invoking the
138
+ underlying worker.
139
+ """
140
+ return self.driver_worker.determine_num_available_blocks()
141
+
142
+ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
143
+ """Initialize the KV cache by invoking the underlying worker.
144
+ """
145
+ # NOTE: This is logged in the executor because there can be >1 worker
146
+ # with other executors. We could log in the engine level, but work
147
+ # remains to abstract away the device for non-GPU configurations.
148
+ logger.info(f"# GPU blocks: {num_gpu_blocks}, "
149
+ f"# CPU blocks: {num_cpu_blocks}")
150
+
151
+ self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
152
+
153
+ def execute_model(
154
+ self,
155
+ seq_group_metadata_list: List[SequenceGroupMetadata],
156
+ blocks_to_swap_in: Dict[int, int],
157
+ blocks_to_swap_out: Dict[int, int],
158
+ blocks_to_copy: Dict[int, List[int]],
159
+ num_lookahead_slots: int,
160
+ ) -> List[SamplerOutput]:
161
+ output = self.driver_worker.execute_model(
162
+ seq_group_metadata_list=seq_group_metadata_list,
163
+ blocks_to_swap_in=blocks_to_swap_in,
164
+ blocks_to_swap_out=blocks_to_swap_out,
165
+ blocks_to_copy=blocks_to_copy,
166
+ num_lookahead_slots=num_lookahead_slots,
167
+ )
168
+ return output
169
+
170
+ def add_lora(self, lora_request: LoRARequest) -> bool:
171
+ assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
172
+ return self.driver_worker.add_lora(lora_request)
173
+
174
+ def remove_lora(self, lora_id: int) -> bool:
175
+ assert lora_id > 0, "lora_id must be greater than 0."
176
+ return self.driver_worker.remove_lora(lora_id)
177
+
178
+ def list_loras(self) -> Set[int]:
179
+ return self.driver_worker.list_loras()
180
+
181
+ def check_health(self) -> None:
182
+ # GPUExecutor will always be healthy as long as
183
+ # it's running.
184
+ return
185
+
186
+
187
+ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
188
+
189
+ async def execute_model_async(
190
+ self,
191
+ seq_group_metadata_list: List[SequenceGroupMetadata],
192
+ blocks_to_swap_in: Dict[int, int],
193
+ blocks_to_swap_out: Dict[int, int],
194
+ blocks_to_copy: Dict[int, List[int]],
195
+ ) -> SamplerOutput:
196
+ output = await make_async(self.driver_worker.execute_model)(
197
+ seq_group_metadata_list=seq_group_metadata_list,
198
+ blocks_to_swap_in=blocks_to_swap_in,
199
+ blocks_to_swap_out=blocks_to_swap_out,
200
+ blocks_to_copy=blocks_to_copy)
201
+ return output
serve/llm.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
3
+ from typing import List, Optional, Union
4
+ import argparse
5
+
6
+ import torch
7
+ from tqdm import tqdm
8
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
9
+
10
+ from vllm.engine.arg_utils import EngineArgs
11
+ # from vllm.engine.llm_engine import LLMEngine
12
+ from vllm.lora.request import LoRARequest
13
+ from vllm.outputs import RequestOutput
14
+ from vllm.sampling_params import SamplingParams
15
+ from vllm.sequence import MultiModalData
16
+ from vllm.usage.usage_lib import UsageContext
17
+ from vllm.utils import Counter
18
+
19
+ from serve.llm_engine import LLMEngine
20
+
21
+
22
+ class LLM:
23
+ """An LLM for generating texts from given prompts and sampling parameters.
24
+
25
+ This class includes a tokenizer, a language model (possibly distributed
26
+ across multiple GPUs), and GPU memory space allocated for intermediate
27
+ states (aka KV cache). Given a batch of prompts and sampling parameters,
28
+ this class generates texts from the model, using an intelligent batching
29
+ mechanism and efficient memory management.
30
+
31
+ NOTE: This class is intended to be used for offline inference. For online
32
+ serving, use the `AsyncLLMEngine` class instead.
33
+ NOTE: For the comprehensive list of arguments, see `EngineArgs`.
34
+
35
+ Args:
36
+ model: The name or path of a HuggingFace Transformers model.
37
+ tokenizer: The name or path of a HuggingFace Transformers tokenizer.
38
+ tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
39
+ if available, and "slow" will always use the slow tokenizer.
40
+ skip_tokenizer_init: If true, skip initialization of tokenizer and
41
+ detokenizer. Expect valid prompt_token_ids and None for prompt
42
+ from the input.
43
+ trust_remote_code: Trust remote code (e.g., from HuggingFace) when
44
+ downloading the model and tokenizer.
45
+ tensor_parallel_size: The number of GPUs to use for distributed
46
+ execution with tensor parallelism.
47
+ dtype: The data type for the model weights and activations. Currently,
48
+ we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
49
+ the `torch_dtype` attribute specified in the model config file.
50
+ However, if the `torch_dtype` in the config is `float32`, we will
51
+ use `float16` instead.
52
+ quantization: The method used to quantize the model weights. Currently,
53
+ we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
54
+ If None, we first check the `quantization_config` attribute in the
55
+ model config file. If that is None, we assume the model weights are
56
+ not quantized and use `dtype` to determine the data type of
57
+ the weights.
58
+ revision: The specific model version to use. It can be a branch name,
59
+ a tag name, or a commit id.
60
+ tokenizer_revision: The specific tokenizer version to use. It can be a
61
+ branch name, a tag name, or a commit id.
62
+ seed: The seed to initialize the random number generator for sampling.
63
+ gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
64
+ reserve for the model weights, activations, and KV cache. Higher
65
+ values will increase the KV cache size and thus improve the model's
66
+ throughput. However, if the value is too high, it may cause out-of-
67
+ memory (OOM) errors.
68
+ swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
69
+ This can be used for temporarily storing the states of the requests
70
+ when their `best_of` sampling parameters are larger than 1. If all
71
+ requests will have `best_of=1`, you can safely set this to 0.
72
+ Otherwise, too small values may cause out-of-memory (OOM) errors.
73
+ enforce_eager: Whether to enforce eager execution. If True, we will
74
+ disable CUDA graph and always execute the model in eager mode.
75
+ If False, we will use CUDA graph and eager execution in hybrid.
76
+ max_context_len_to_capture: Maximum context len covered by CUDA graphs.
77
+ When a sequence has context length larger than this, we fall back
78
+ to eager mode.
79
+ disable_custom_all_reduce: See ParallelConfig
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ args: argparse.ArgumentParser,
85
+ model: str,
86
+ tokenizer: Optional[str] = None,
87
+ tokenizer_mode: str = "auto",
88
+ skip_tokenizer_init: bool = False,
89
+ trust_remote_code: bool = False,
90
+ tensor_parallel_size: int = 1,
91
+ dtype: str = "auto",
92
+ quantization: Optional[str] = None,
93
+ revision: Optional[str] = None,
94
+ tokenizer_revision: Optional[str] = None,
95
+ seed: int = 0,
96
+ gpu_memory_utilization: float = 0.9,
97
+ swap_space: int = 4,
98
+ enforce_eager: bool = False,
99
+ max_context_len_to_capture: int = 8192,
100
+ disable_custom_all_reduce: bool = False,
101
+ **kwargs,
102
+ ) -> None:
103
+ if "disable_log_stats" not in kwargs:
104
+ kwargs["disable_log_stats"] = True
105
+ engine_args = EngineArgs(
106
+ model=model,
107
+ tokenizer=tokenizer,
108
+ tokenizer_mode=tokenizer_mode,
109
+ skip_tokenizer_init=skip_tokenizer_init,
110
+ trust_remote_code=trust_remote_code,
111
+ tensor_parallel_size=tensor_parallel_size,
112
+ dtype=dtype,
113
+ quantization=quantization,
114
+ revision=revision,
115
+ tokenizer_revision=tokenizer_revision,
116
+ seed=seed,
117
+ gpu_memory_utilization=gpu_memory_utilization,
118
+ swap_space=swap_space,
119
+ enforce_eager=enforce_eager,
120
+ max_context_len_to_capture=max_context_len_to_capture,
121
+ disable_custom_all_reduce=disable_custom_all_reduce,
122
+ **kwargs,
123
+ )
124
+ self.llm_engine = LLMEngine.from_engine_args(
125
+ engine_args, usage_context=UsageContext.LLM_CLASS, args=args)
126
+ self.request_counter = Counter()
127
+
128
+ def get_tokenizer(
129
+ self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
130
+ return self.llm_engine.tokenizer.tokenizer
131
+
132
+ def set_tokenizer(
133
+ self,
134
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
135
+ ) -> None:
136
+ self.llm_engine.tokenizer.tokenizer = tokenizer
137
+
138
+ def generate(
139
+ self,
140
+ prompts: Optional[Union[str, List[str]]] = None,
141
+ sampling_params: Optional[Union[SamplingParams,
142
+ List[SamplingParams]]] = None,
143
+ prompt_token_ids: Optional[List[List[int]]] = None,
144
+ use_tqdm: bool = True,
145
+ lora_request: Optional[LoRARequest] = None,
146
+ multi_modal_data: Optional[MultiModalData] = None,
147
+ ) -> List[RequestOutput]:
148
+ """Generates the completions for the input prompts.
149
+
150
+ NOTE: This class automatically batches the given prompts, considering
151
+ the memory constraint. For the best performance, put all of your prompts
152
+ into a single list and pass it to this method.
153
+
154
+ Args:
155
+ prompts: A list of prompts to generate completions for.
156
+ sampling_params: The sampling parameters for text generation. If
157
+ None, we use the default sampling parameters.
158
+ When it is a single value, it is applied to every prompt.
159
+ When it is a list, the list must have the same length as the
160
+ prompts and it is paired one by one with the prompt.
161
+ prompt_token_ids: A list of token IDs for the prompts. If None, we
162
+ use the tokenizer to convert the prompts to token IDs.
163
+ use_tqdm: Whether to use tqdm to display the progress bar.
164
+ lora_request: LoRA request to use for generation, if any.
165
+ multi_modal_data: Multi modal data.
166
+
167
+ Returns:
168
+ A list of `RequestOutput` objects containing the generated
169
+ completions in the same order as the input prompts.
170
+ """
171
+ if prompts is None and prompt_token_ids is None:
172
+ raise ValueError("Either prompts or prompt_token_ids must be "
173
+ "provided.")
174
+ if self.llm_engine.model_config.skip_tokenizer_init \
175
+ and prompts is not None:
176
+ raise ValueError("prompts must be None if skip_tokenizer_init "
177
+ "is True")
178
+ if isinstance(prompts, str):
179
+ # Convert a single prompt to a list.
180
+ prompts = [prompts]
181
+ if (prompts is not None and prompt_token_ids is not None
182
+ and len(prompts) != len(prompt_token_ids)):
183
+ raise ValueError("The lengths of prompts and prompt_token_ids "
184
+ "must be the same.")
185
+
186
+ if prompts is not None:
187
+ num_requests = len(prompts)
188
+ else:
189
+ assert prompt_token_ids is not None
190
+ num_requests = len(prompt_token_ids)
191
+
192
+ if sampling_params is None:
193
+ # Use default sampling params.
194
+ sampling_params = SamplingParams()
195
+
196
+ elif isinstance(sampling_params,
197
+ list) and len(sampling_params) != num_requests:
198
+ raise ValueError("The lengths of prompts and sampling_params "
199
+ "must be the same.")
200
+ if multi_modal_data:
201
+ multi_modal_data.data = multi_modal_data.data.to(torch.float16)
202
+
203
+ # Add requests to the engine.
204
+ for i in range(num_requests):
205
+ prompt = prompts[i] if prompts is not None else None
206
+ token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
207
+ self._add_request(
208
+ prompt,
209
+ sampling_params[i]
210
+ if isinstance(sampling_params, list) else sampling_params,
211
+ token_ids,
212
+ lora_request=lora_request,
213
+ # Get ith image while maintaining the batch dim.
214
+ multi_modal_data=MultiModalData(
215
+ type=multi_modal_data.type,
216
+ data=multi_modal_data.data[i].unsqueeze(0))
217
+ if multi_modal_data else None,
218
+ )
219
+ return self._run_engine(use_tqdm)
220
+
221
+ def _add_request(
222
+ self,
223
+ prompt: Optional[str],
224
+ sampling_params: SamplingParams,
225
+ prompt_token_ids: Optional[List[int]],
226
+ lora_request: Optional[LoRARequest] = None,
227
+ multi_modal_data: Optional[MultiModalData] = None,
228
+ ) -> None:
229
+ request_id = str(next(self.request_counter))
230
+ self.llm_engine.add_request(request_id,
231
+ prompt,
232
+ sampling_params,
233
+ prompt_token_ids,
234
+ lora_request=lora_request,
235
+ multi_modal_data=multi_modal_data)
236
+
237
+
238
+ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
239
+ # Initialize tqdm.
240
+ if use_tqdm:
241
+ num_requests = self.llm_engine.get_num_unfinished_requests()
242
+ pbar = tqdm(
243
+ total=num_requests,
244
+ desc="Processed prompts",
245
+ dynamic_ncols=True,
246
+ postfix=f"Generation Speed: {0:.2f} toks/s",
247
+ )
248
+ # Run the engine.
249
+ outputs: List[RequestOutput] = []
250
+ while self.llm_engine.has_unfinished_requests():
251
+ step_outputs = self.llm_engine.step()
252
+ for output in step_outputs:
253
+ if output.finished:
254
+ outputs.append(output)
255
+ if use_tqdm:
256
+ total_toks += (sum(
257
+ len(stp.token_ids) for stp in output.outputs))
258
+ spd = total_toks / pbar.format_dict["elapsed"]
259
+ pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
260
+ pbar.update(1)
261
+ if use_tqdm:
262
+ pbar.close()
263
+ # Sort the outputs by request ID.
264
+ # This is necessary because some requests may be finished earlier than
265
+ # its previous requests.
266
+ outputs = sorted(outputs, key=lambda x: int(x.request_id))
267
+ return outputs
serve/llm_engine.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
3
+ import time
4
+ from typing import Iterable, List, Optional, Type, Union
5
+ import argparse
6
+
7
+ from transformers import GenerationConfig, PreTrainedTokenizer
8
+
9
+ import vllm
10
+ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
11
+ LoRAConfig, ModelConfig, ParallelConfig,
12
+ SchedulerConfig, SpeculativeConfig,
13
+ VisionLanguageConfig)
14
+ from vllm.core.scheduler import Scheduler, SchedulerOutputs
15
+ from vllm.engine.arg_utils import EngineArgs
16
+ from vllm.engine.metrics import StatLogger, Stats
17
+ from vllm.engine.output_processor.interfaces import (
18
+ SequenceGroupOutputProcessor)
19
+ from vllm.engine.output_processor.stop_checker import StopChecker
20
+ from vllm.engine.output_processor.util import create_output_by_sequence_group
21
+ from vllm.engine.ray_utils import initialize_ray_cluster
22
+ from vllm.executor.executor_base import ExecutorBase
23
+ from vllm.logger import init_logger
24
+ from vllm.lora.request import LoRARequest
25
+ from vllm.outputs import RequestOutput
26
+ from vllm.sampling_params import SamplingParams
27
+ from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
28
+ SequenceGroup)
29
+ from vllm.transformers_utils.detokenizer import Detokenizer
30
+ from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
31
+ get_tokenizer_group)
32
+ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
33
+ usage_message)
34
+ from vllm.utils import Counter
35
+
36
+ logger = init_logger(__name__)
37
+ _LOCAL_LOGGING_INTERVAL_SEC = 5
38
+
39
+
40
+ def _load_generation_config_dict(model_config: ModelConfig):
41
+ try:
42
+ return GenerationConfig.from_pretrained(
43
+ model_config.model,
44
+ revision=model_config.revision,
45
+ ).to_diff_dict()
46
+ except OSError:
47
+ # Not found.
48
+ return {}
49
+
50
+
51
+ class LLMEngine:
52
+ """An LLM engine that receives requests and generates texts.
53
+
54
+ This is the main class for the vLLM engine. It receives requests
55
+ from clients and generates texts from the LLM. It includes a tokenizer, a
56
+ language model (possibly distributed across multiple GPUs), and GPU memory
57
+ space allocated for intermediate states (aka KV cache). This class utilizes
58
+ iteration-level scheduling and efficient memory management to maximize the
59
+ serving throughput.
60
+
61
+ The `LLM` class wraps this class for offline batched inference and the
62
+ `AsyncLLMEngine` class wraps this class for online serving.
63
+
64
+ NOTE: The config arguments are derived from the `EngineArgs` class. For the
65
+ comprehensive list of arguments, see `EngineArgs`.
66
+
67
+ Args:
68
+ model_config: The configuration related to the LLM model.
69
+ cache_config: The configuration related to the KV cache memory
70
+ management.
71
+ parallel_config: The configuration related to distributed execution.
72
+ scheduler_config: The configuration related to the request scheduler.
73
+ device_config: The configuration related to the device.
74
+ lora_config (Optional): The configuration related to serving multi-LoRA.
75
+ vision_language_config (Optional): The configuration related to vision
76
+ language models.
77
+ speculative_config (Optional): The configuration related to speculative
78
+ decoding.
79
+ executor_class: The model executor class for managing distributed
80
+ execution.
81
+ log_stats: Whether to log statistics.
82
+ usage_context: Specified entry point, used for usage info collection
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ args: argparse.ArgumentParser,
88
+ model_config: ModelConfig,
89
+ cache_config: CacheConfig,
90
+ parallel_config: ParallelConfig,
91
+ scheduler_config: SchedulerConfig,
92
+ device_config: DeviceConfig,
93
+ load_config: LoadConfig,
94
+ lora_config: Optional[LoRAConfig],
95
+ vision_language_config: Optional[VisionLanguageConfig],
96
+ speculative_config: Optional[SpeculativeConfig],
97
+ decoding_config: Optional[DecodingConfig],
98
+ executor_class: Type[ExecutorBase],
99
+ log_stats: bool,
100
+ usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
101
+ ) -> None:
102
+ logger.info(
103
+ f"Initializing an LLM engine (v{vllm.__version__}) with config: "
104
+ f"model={model_config.model!r}, "
105
+ f"speculative_config={speculative_config!r}, "
106
+ f"tokenizer={model_config.tokenizer!r}, "
107
+ f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
108
+ f"tokenizer_mode={model_config.tokenizer_mode}, "
109
+ f"revision={model_config.revision}, "
110
+ f"tokenizer_revision={model_config.tokenizer_revision}, "
111
+ f"trust_remote_code={model_config.trust_remote_code}, "
112
+ f"dtype={model_config.dtype}, "
113
+ f"max_seq_len={model_config.max_model_len}, "
114
+ f"download_dir={load_config.download_dir!r}, "
115
+ f"load_format={load_config.load_format}, "
116
+ f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
117
+ f"disable_custom_all_reduce="
118
+ f"{parallel_config.disable_custom_all_reduce}, "
119
+ f"quantization={model_config.quantization}, "
120
+ f"enforce_eager={model_config.enforce_eager}, "
121
+ f"kv_cache_dtype={cache_config.cache_dtype}, "
122
+ f"quantization_param_path={model_config.quantization_param_path}, "
123
+ f"device_config={device_config.device}, "
124
+ f"decoding_config={decoding_config!r}, "
125
+ f"seed={model_config.seed})")
126
+ # TODO(woosuk): Print more configs in debug mode.
127
+
128
+ self.model_config = model_config
129
+ self.cache_config = cache_config
130
+ self.lora_config = lora_config
131
+ self.vision_language_config = vision_language_config
132
+ self.parallel_config = parallel_config
133
+ self.scheduler_config = scheduler_config
134
+ self.device_config = device_config
135
+ self.speculative_config = speculative_config
136
+ self.load_config = load_config
137
+ self.decoding_config = decoding_config or DecodingConfig()
138
+ self.log_stats = log_stats
139
+
140
+ if not self.model_config.skip_tokenizer_init:
141
+ self.tokenizer: BaseTokenizerGroup
142
+ self._init_tokenizer()
143
+ self.detokenizer = Detokenizer(self.tokenizer)
144
+ else:
145
+ self.detokenizer = None
146
+ self.tokenizer = None
147
+
148
+ self.seq_counter = Counter()
149
+ self.generation_config_fields = _load_generation_config_dict(
150
+ model_config)
151
+
152
+ self.model_executor = executor_class(
153
+ args=args,
154
+ model_config=model_config,
155
+ cache_config=cache_config,
156
+ parallel_config=parallel_config,
157
+ scheduler_config=scheduler_config,
158
+ device_config=device_config,
159
+ lora_config=lora_config,
160
+ vision_language_config=vision_language_config,
161
+ speculative_config=speculative_config,
162
+ load_config=load_config,
163
+ )
164
+
165
+ self._initialize_kv_caches()
166
+
167
+ # If usage stat is enabled, collect relevant info.
168
+ if is_usage_stats_enabled():
169
+ from vllm.model_executor.model_loader import (
170
+ get_architecture_class_name)
171
+ usage_message.report_usage(
172
+ get_architecture_class_name(model_config),
173
+ usage_context,
174
+ extra_kvs={
175
+ # Common configuration
176
+ "dtype":
177
+ str(model_config.dtype),
178
+ "tensor_parallel_size":
179
+ parallel_config.tensor_parallel_size,
180
+ "block_size":
181
+ cache_config.block_size,
182
+ "gpu_memory_utilization":
183
+ cache_config.gpu_memory_utilization,
184
+
185
+ # Quantization
186
+ "quantization":
187
+ model_config.quantization,
188
+ "kv_cache_dtype":
189
+ cache_config.cache_dtype,
190
+
191
+ # Feature flags
192
+ "enable_lora":
193
+ bool(lora_config),
194
+ "enable_prefix_caching":
195
+ cache_config.enable_prefix_caching,
196
+ "enforce_eager":
197
+ model_config.enforce_eager,
198
+ "disable_custom_all_reduce":
199
+ parallel_config.disable_custom_all_reduce,
200
+ })
201
+
202
+ if self.tokenizer:
203
+ # Ping the tokenizer to ensure liveness if it runs in a
204
+ # different process.
205
+ self.tokenizer.ping()
206
+
207
+ # Create the scheduler.
208
+ # NOTE: the cache_config here have been updated with the numbers of
209
+ # GPU and CPU blocks, which are profiled in the distributed executor.
210
+ self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
211
+
212
+ # Metric Logging.
213
+ if self.log_stats:
214
+ self.stat_logger = StatLogger(
215
+ local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
216
+ labels=dict(model_name=model_config.model))
217
+ self.stat_logger.info("cache_config", self.cache_config)
218
+
219
+ # Create sequence output processor, e.g. for beam search or
220
+ # speculative decoding.
221
+ self.output_processor = (
222
+ SequenceGroupOutputProcessor.create_output_processor(
223
+ self.scheduler_config,
224
+ self.detokenizer,
225
+ self.scheduler,
226
+ self.seq_counter,
227
+ self.get_tokenizer_for_seq,
228
+ stop_checker=StopChecker(
229
+ self.scheduler_config.max_model_len,
230
+ self.get_tokenizer_for_seq,
231
+ ),
232
+ ))
233
+
234
+ def _initialize_kv_caches(self) -> None:
235
+ """Initialize the KV cache in the worker(s).
236
+
237
+ The workers will determine the number of blocks in both the GPU cache
238
+ and the swap CPU cache.
239
+ """
240
+ num_gpu_blocks, num_cpu_blocks = (
241
+ self.model_executor.determine_num_available_blocks())
242
+
243
+ if self.cache_config.num_gpu_blocks_override is not None:
244
+ num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
245
+ logger.info(f"Overriding {num_gpu_blocks=} with "
246
+ f"{num_gpu_blocks_override=}")
247
+ num_gpu_blocks = num_gpu_blocks_override
248
+
249
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
250
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
251
+
252
+ self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
253
+
254
+ @classmethod
255
+ def from_engine_args(
256
+ cls,
257
+ engine_args: EngineArgs,
258
+ usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
259
+ args: argparse.ArgumentParser = None,
260
+ ) -> "LLMEngine":
261
+ """Creates an LLM engine from the engine arguments."""
262
+ # Create the engine configs.
263
+ engine_config = engine_args.create_engine_config()
264
+
265
+ # Initialize the cluster and specify the executor class.
266
+ if engine_config.device_config.device_type == "neuron":
267
+ from vllm.executor.neuron_executor import NeuronExecutor
268
+ executor_class = NeuronExecutor
269
+ elif engine_config.device_config.device_type == "cpu":
270
+ from vllm.executor.cpu_executor import CPUExecutor
271
+ executor_class = CPUExecutor
272
+ elif engine_config.parallel_config.worker_use_ray:
273
+ initialize_ray_cluster(engine_config.parallel_config)
274
+ from vllm.executor.ray_gpu_executor import RayGPUExecutor
275
+ executor_class = RayGPUExecutor
276
+ else:
277
+ assert engine_config.parallel_config.world_size == 1, (
278
+ "Ray is required if parallel_config.world_size > 1.")
279
+ # from vllm.executor.gpu_executor import GPUExecutor
280
+ from serve.gpu_executor import GPUExecutor
281
+ executor_class = GPUExecutor
282
+
283
+ # Create the LLM engine.
284
+ engine = cls(
285
+ **engine_config.to_dict(),
286
+ executor_class=executor_class,
287
+ log_stats=not engine_args.disable_log_stats,
288
+ usage_context=usage_context,
289
+ args=args,
290
+ )
291
+ return engine
292
+
293
+ def __reduce__(self):
294
+ # This is to ensure that the LLMEngine is not referenced in
295
+ # the closure used to initialize Ray worker actors
296
+ raise RuntimeError("LLMEngine should not be pickled!")
297
+
298
+ def get_tokenizer(self) -> "PreTrainedTokenizer":
299
+ return self.tokenizer.get_lora_tokenizer(None)
300
+
301
+ def get_tokenizer_for_seq(self,
302
+ sequence: Sequence) -> "PreTrainedTokenizer":
303
+ return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
304
+
305
+ def _init_tokenizer(self, **tokenizer_init_kwargs):
306
+ init_kwargs = dict(
307
+ tokenizer_id=self.model_config.tokenizer,
308
+ enable_lora=bool(self.lora_config),
309
+ max_num_seqs=self.scheduler_config.max_num_seqs,
310
+ max_input_length=None,
311
+ tokenizer_mode=self.model_config.tokenizer_mode,
312
+ trust_remote_code=self.model_config.trust_remote_code,
313
+ revision=self.model_config.tokenizer_revision)
314
+ init_kwargs.update(tokenizer_init_kwargs)
315
+ self.tokenizer = get_tokenizer_group(
316
+ self.parallel_config.tokenizer_pool_config, **init_kwargs)
317
+
318
+ def _verify_args(self) -> None:
319
+ self.model_config.verify_with_parallel_config(self.parallel_config)
320
+ self.cache_config.verify_with_parallel_config(self.parallel_config)
321
+ if self.lora_config:
322
+ self.lora_config.verify_with_model_config(self.model_config)
323
+ self.lora_config.verify_with_scheduler_config(
324
+ self.scheduler_config)
325
+
326
+ def encode_request(
327
+ self,
328
+ request_id: str, # pylint: disable=unused-argument
329
+ prompt: Optional[str],
330
+ prompt_token_ids: Optional[List[int]] = None,
331
+ lora_request: Optional[LoRARequest] = None,
332
+ ):
333
+ if prompt_token_ids is None:
334
+ assert prompt is not None
335
+ prompt_token_ids = self.tokenizer.encode(request_id=request_id,
336
+ prompt=prompt,
337
+ lora_request=lora_request)
338
+ return prompt_token_ids
339
+
340
+ def add_request(
341
+ self,
342
+ request_id: str,
343
+ prompt: Optional[str],
344
+ sampling_params: SamplingParams,
345
+ prompt_token_ids: Optional[List[int]] = None,
346
+ arrival_time: Optional[float] = None,
347
+ lora_request: Optional[LoRARequest] = None,
348
+ multi_modal_data: Optional[MultiModalData] = None,
349
+ ) -> None:
350
+ """Add a request to the engine's request pool.
351
+
352
+ The request is added to the request pool and will be processed by the
353
+ scheduler as `engine.step()` is called. The exact scheduling policy is
354
+ determined by the scheduler.
355
+
356
+ Args:
357
+ request_id: The unique ID of the request.
358
+ prompt: The prompt string. Can be None if prompt_token_ids is
359
+ provided.
360
+ sampling_params: The sampling parameters for text generation.
361
+ prompt_token_ids: The token IDs of the prompt. If None, we
362
+ use the tokenizer to convert the prompts to token IDs.
363
+ arrival_time: The arrival time of the request. If None, we use
364
+ the current monotonic time.
365
+ multi_modal_data: Multi modal data per request.
366
+
367
+ Details:
368
+ - Set arrival_time to the current time if it is None.
369
+ - Set prompt_token_ids to the encoded prompt if it is None.
370
+ - Create `best_of` number of :class:`~vllm.Sequence` objects.
371
+ - Create a :class:`~vllm.SequenceGroup` object
372
+ from the list of :class:`~vllm.Sequence`.
373
+ - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
374
+
375
+ Example:
376
+ >>> # initialize engine
377
+ >>> engine = LLMEngine.from_engine_args(engine_args)
378
+ >>> # set request arguments
379
+ >>> example_prompt = "Who is the president of the United States?"
380
+ >>> sampling_params = SamplingParams(temperature=0.0)
381
+ >>> request_id = 0
382
+ >>>
383
+ >>> # add the request to the engine
384
+ >>> engine.add_request(
385
+ >>> str(request_id),
386
+ >>> example_prompt,
387
+ >>> SamplingParams(temperature=0.0))
388
+ >>> # continue the request processing
389
+ >>> ...
390
+ """
391
+ if lora_request is not None and not self.lora_config:
392
+ raise ValueError(f"Got lora_request {lora_request} but LoRA is "
393
+ "not enabled!")
394
+ max_logprobs = self.get_model_config().max_logprobs
395
+ if (sampling_params.logprobs
396
+ and sampling_params.logprobs > max_logprobs) or (
397
+ sampling_params.prompt_logprobs
398
+ and sampling_params.prompt_logprobs > max_logprobs):
399
+ raise ValueError(f"Cannot request more than "
400
+ f"{max_logprobs} logprobs.")
401
+ if arrival_time is None:
402
+ arrival_time = time.time()
403
+ prompt_token_ids = self.encode_request(
404
+ request_id=request_id,
405
+ prompt=prompt,
406
+ prompt_token_ids=prompt_token_ids,
407
+ lora_request=lora_request)
408
+
409
+ # Create the sequences.
410
+ block_size = self.cache_config.block_size
411
+ seq_id = next(self.seq_counter)
412
+ eos_token_id = None
413
+ if self.tokenizer:
414
+ eos_token_id = self.tokenizer.get_lora_tokenizer(
415
+ lora_request).eos_token_id
416
+ else:
417
+ logger.warning("Use None for EOS token id because tokenizer is "
418
+ "not initialized")
419
+ seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
420
+ eos_token_id, lora_request)
421
+
422
+ # Defensive copy of SamplingParams, which are used by the sampler,
423
+ # this doesn't deep-copy LogitsProcessor objects
424
+ sampling_params = sampling_params.clone()
425
+ # Add the eos token id into the sampling_params to support min_tokens
426
+ # processing
427
+ if seq.eos_token_id is not None:
428
+ sampling_params.all_stop_token_ids.add(seq.eos_token_id)
429
+ sampling_params.update_from_generation_config(
430
+ self.generation_config_fields)
431
+
432
+ # Create the sequence group.
433
+ seq_group = SequenceGroup(request_id, [seq], sampling_params,
434
+ arrival_time, lora_request, multi_modal_data)
435
+
436
+ # Add the sequence group to the scheduler.
437
+ self.scheduler.add_seq_group(seq_group)
438
+
439
+ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
440
+ """Aborts a request(s) with the given ID.
441
+
442
+ Args:
443
+ request_id: The ID(s) of the request to abort.
444
+
445
+ Details:
446
+ - Refer to the
447
+ :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
448
+ from class :class:`~vllm.core.scheduler.Scheduler`.
449
+
450
+ Example:
451
+ >>> # initialize engine and add a request with request_id
452
+ >>> request_id = str(0)
453
+ >>> # abort the request
454
+ >>> engine.abort_request(request_id)
455
+ """
456
+ self.scheduler.abort_seq_group(request_id)
457
+
458
+ def get_model_config(self) -> ModelConfig:
459
+ """Gets the model configuration."""
460
+ return self.model_config
461
+
462
+ def get_num_unfinished_requests(self) -> int:
463
+ """Gets the number of unfinished requests."""
464
+ return self.scheduler.get_num_unfinished_seq_groups()
465
+
466
+ def has_unfinished_requests(self) -> bool:
467
+ """Returns True if there are unfinished requests."""
468
+ return self.scheduler.has_unfinished_seqs()
469
+
470
+ def _process_model_outputs(
471
+ self, output: List[SamplerOutput],
472
+ scheduled_seq_groups: List[SequenceGroup],
473
+ ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
474
+ """Apply the model output to the sequences in the scheduled seq groups.
475
+
476
+ Returns RequestOutputs that can be returned to the client.
477
+ """
478
+ now = time.time()
479
+
480
+ # Organize outputs by [sequence group][step] instead of
481
+ # [step][sequence group].
482
+ output_by_sequence_group = create_output_by_sequence_group(
483
+ sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
484
+
485
+ # Update the scheduled sequence groups with the model outputs.
486
+ for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
487
+ output_by_sequence_group):
488
+ seq_group = scheduled_seq_group.seq_group
489
+ seq_group.update_num_computed_tokens(
490
+ scheduled_seq_group.token_chunk_size)
491
+ # If uncomputed tokens > 0, it means prefill is chunked.
492
+ # We don't need to process outputs in that case.
493
+ if seq_group.get_num_uncomputed_tokens() == 0:
494
+ self.output_processor.process_outputs(seq_group, outputs)
495
+
496
+ # Free the finished sequence groups.
497
+ self.scheduler.free_finished_seq_groups()
498
+
499
+ # Create the outputs.
500
+ request_outputs: List[RequestOutput] = []
501
+ for scheduled_seq_group in scheduled_seq_groups:
502
+ seq_group = scheduled_seq_group.seq_group
503
+ seq_group.maybe_set_first_token_time(now)
504
+ request_output = RequestOutput.from_seq_group(seq_group)
505
+ request_outputs.append(request_output)
506
+ for seq_group in ignored_seq_groups:
507
+ request_output = RequestOutput.from_seq_group(seq_group)
508
+ request_outputs.append(request_output)
509
+ return request_outputs
510
+
511
+ def step(self) -> List[RequestOutput]:
512
+ """Performs one decoding iteration and returns newly generated results.
513
+
514
+ .. figure:: https://i.imgur.com/sv2HssD.png
515
+ :alt: Overview of the step function
516
+ :align: center
517
+
518
+ Overview of the step function.
519
+
520
+ Details:
521
+ - Step 1: Schedules the sequences to be executed in the next
522
+ iteration and the token blocks to be swapped in/out/copy.
523
+
524
+ - Depending on the scheduling policy,
525
+ sequences may be `preempted/reordered`.
526
+ - A Sequence Group (SG) refer to a group of sequences
527
+ that are generated from the same prompt.
528
+
529
+ - Step 2: Calls the distributed executor to execute the model.
530
+ - Step 3: Processes the model output. This mainly includes:
531
+
532
+ - Decodes the relevant outputs.
533
+ - Updates the scheduled sequence groups with model outputs
534
+ based on its `sampling parameters` (`use_beam_search` or not).
535
+ - Frees the finished sequence groups.
536
+
537
+ - Finally, it creates and returns the newly generated results.
538
+
539
+ Example:
540
+ >>> # Please see the example/ folder for more detailed examples.
541
+ >>>
542
+ >>> # initialize engine and request arguments
543
+ >>> engine = LLMEngine.from_engine_args(engine_args)
544
+ >>> example_inputs = [(0, "What is LLM?",
545
+ >>> SamplingParams(temperature=0.0))]
546
+ >>>
547
+ >>> # Start the engine with an event loop
548
+ >>> while True:
549
+ >>> if example_inputs:
550
+ >>> req_id, prompt, sampling_params = example_inputs.pop(0)
551
+ >>> engine.add_request(str(req_id), prompt, sampling_params)
552
+ >>>
553
+ >>> # continue the request processing
554
+ >>> request_outputs = engine.step()
555
+ >>> for request_output in request_outputs:
556
+ >>> if request_output.finished:
557
+ >>> # return or show the request output
558
+ >>>
559
+ >>> if not (engine.has_unfinished_requests() or example_inputs):
560
+ >>> break
561
+ """
562
+ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
563
+ if not scheduler_outputs.is_empty():
564
+ output = self.model_executor.execute_model(
565
+ seq_group_metadata_list=seq_group_metadata_list,
566
+ blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
567
+ blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
568
+ blocks_to_copy=scheduler_outputs.blocks_to_copy,
569
+ num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
570
+ else:
571
+ output = []
572
+
573
+ request_outputs = self._process_model_outputs(
574
+ output, scheduler_outputs.scheduled_seq_groups,
575
+ scheduler_outputs.ignored_seq_groups)
576
+
577
+ # Log stats.
578
+ if self.log_stats:
579
+ self.stat_logger.log(self._get_stats(scheduler_outputs))
580
+
581
+ return request_outputs
582
+
583
+ def do_log_stats(self) -> None:
584
+ """Forced log when no requests active."""
585
+ if self.log_stats:
586
+ self.stat_logger.log(self._get_stats(scheduler_outputs=None))
587
+
588
+ def _get_stats(self,
589
+ scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
590
+ """Get Stats to be Logged to Prometheus."""
591
+ now = time.time()
592
+
593
+ # KV Cache Usage in %.
594
+ num_total_gpu = self.cache_config.num_gpu_blocks
595
+ num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
596
+ gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
597
+
598
+ num_total_cpu = self.cache_config.num_cpu_blocks
599
+ cpu_cache_usage = 0.
600
+ if num_total_cpu > 0:
601
+ num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
602
+ )
603
+ cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
604
+
605
+ # Scheduler State
606
+ num_running = len(self.scheduler.running)
607
+ num_swapped = len(self.scheduler.swapped)
608
+ num_waiting = len(self.scheduler.waiting)
609
+
610
+ # Iteration stats if we have scheduler output.
611
+ num_prompt_tokens = 0
612
+ num_generation_tokens = 0
613
+ time_to_first_tokens = []
614
+ time_per_output_tokens = []
615
+ time_e2e_requests = []
616
+ if scheduler_outputs is not None:
617
+ prompt_run = scheduler_outputs.num_prefill_groups > 0
618
+
619
+ # Number of Tokens.
620
+ if prompt_run:
621
+ num_prompt_tokens = sum(
622
+ len(scheduled_seq_group.seq_group.prompt_token_ids)
623
+ for scheduled_seq_group in
624
+ scheduler_outputs.scheduled_seq_groups)
625
+ num_generation_tokens = sum(
626
+ scheduled_seq_group.seq_group.num_seqs()
627
+ for scheduled_seq_group in
628
+ scheduler_outputs.scheduled_seq_groups)
629
+ else:
630
+ num_generation_tokens = scheduler_outputs.num_batched_tokens
631
+
632
+ # Latency Timings.
633
+ time_last_iters = []
634
+ for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
635
+ seq_group = scheduled_seq_group.seq_group
636
+ # Time since last token.
637
+ # (n.b. updates seq_group.metrics.last_token_time)
638
+ time_last_iters.append(seq_group.get_last_latency(now))
639
+ # Time since arrival for all finished requests.
640
+ if seq_group.is_finished():
641
+ time_e2e_requests.append(now -
642
+ seq_group.metrics.arrival_time)
643
+
644
+ time_to_first_tokens = time_last_iters if prompt_run else []
645
+ time_per_output_tokens = [] if prompt_run else time_last_iters
646
+
647
+ return Stats(
648
+ now=now,
649
+ num_running=num_running,
650
+ num_swapped=num_swapped,
651
+ num_waiting=num_waiting,
652
+ gpu_cache_usage=gpu_cache_usage,
653
+ cpu_cache_usage=cpu_cache_usage,
654
+ num_prompt_tokens=num_prompt_tokens,
655
+ num_generation_tokens=num_generation_tokens,
656
+ time_to_first_tokens=time_to_first_tokens,
657
+ time_per_output_tokens=time_per_output_tokens,
658
+ time_e2e_requests=time_e2e_requests,
659
+ )
660
+
661
+ def add_lora(self, lora_request: LoRARequest) -> bool:
662
+ return self.model_executor.add_lora(lora_request)
663
+
664
+ def remove_lora(self, lora_id: int) -> bool:
665
+ return self.model_executor.remove_lora(lora_id)
666
+
667
+ def list_loras(self) -> List[int]:
668
+ return self.model_executor.list_loras()
669
+
670
+ def check_health(self) -> None:
671
+ self.model_executor.check_health()
serve/model_runner.py ADDED
@@ -0,0 +1,1223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import time
3
+ from enum import IntEnum
4
+ from typing import Dict, List, NamedTuple, Optional, Set, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
11
+ get_attn_backend)
12
+ from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
13
+ ParallelConfig, SchedulerConfig, VisionLanguageConfig)
14
+ from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
15
+ from vllm.distributed.device_communicators import (custom_all_reduce,
16
+ pynccl_utils)
17
+ from vllm.logger import init_logger
18
+ from vllm.lora.layers import LoRAMapping
19
+ from vllm.lora.request import LoRARequest
20
+ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
21
+ from vllm.model_executor import SamplingMetadata
22
+ from vllm.model_executor.model_loader import get_model
23
+ from vllm.sampling_params import SamplingParams, SamplingType
24
+ from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
25
+ SequenceGroupMetadata)
26
+ from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
27
+ is_pin_memory_available, make_tensor_with_pad,
28
+ maybe_expand_dim)
29
+ from serve.gpt_model import GPT_models
30
+
31
+ logger = init_logger(__name__)
32
+
33
+ _PAD_SLOT_ID = -1
34
+ LORA_WARMUP_RANK = 8
35
+ _BATCH_SIZE_ALIGNMENT = 8
36
+ # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
37
+ # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
38
+ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
39
+ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
40
+ ]
41
+
42
+
43
+ class PreparePromptMetadata(NamedTuple):
44
+ input_tokens: List[int]
45
+ input_positions: List[int]
46
+ attn_metadata: Optional[AttentionMetadataPerStage]
47
+ prompt_lens: List[int]
48
+ subquery_lens: List[int]
49
+ lora_index_mapping: List[int]
50
+ lora_prompt_mapping: List[int]
51
+ lora_requests: Set[LoRARequest]
52
+ multi_modal_input: Optional[torch.Tensor]
53
+ slot_mapping: List[int]
54
+
55
+ @classmethod
56
+ def empty(cls):
57
+ return PreparePromptMetadata(
58
+ input_tokens=[],
59
+ input_positions=[],
60
+ attn_metadata=None,
61
+ prompt_lens=[],
62
+ subquery_lens=[],
63
+ lora_index_mapping=[],
64
+ lora_prompt_mapping=[],
65
+ lora_requests=set(),
66
+ multi_modal_input=None,
67
+ slot_mapping=[],
68
+ )
69
+
70
+
71
+ class PrepareDecodeMetadata(NamedTuple):
72
+ input_tokens: List[int]
73
+ input_positions: List[int]
74
+ attn_metadata: Optional[AttentionMetadata]
75
+ lora_index_mapping: List[int]
76
+ lora_prompt_mapping: List[int]
77
+ lora_requests: Set[LoRARequest]
78
+ slot_mapping: List[int]
79
+
80
+ @classmethod
81
+ def empty(cls):
82
+ return PrepareDecodeMetadata(
83
+ input_tokens=[],
84
+ input_positions=[],
85
+ attn_metadata=None,
86
+ lora_index_mapping=[],
87
+ lora_prompt_mapping=[],
88
+ lora_requests=set(),
89
+ slot_mapping=[],
90
+ )
91
+
92
+
93
+ # How batches are constructed.
94
+ class BatchType(IntEnum):
95
+ # Every batch is prefill.
96
+ PREFILL = 0
97
+ # Every batch is decode.
98
+ DECODE = 1
99
+ # Batch is a mixture of prefill and decode.
100
+ MIXED = 2
101
+
102
+
103
+ class ModelRunner:
104
+
105
+ def __init__(
106
+ self,
107
+ model_config: ModelConfig,
108
+ parallel_config: ParallelConfig,
109
+ scheduler_config: SchedulerConfig,
110
+ device_config: DeviceConfig,
111
+ load_config: LoadConfig,
112
+ lora_config: Optional[LoRAConfig],
113
+ kv_cache_dtype: Optional[str] = "auto",
114
+ is_driver_worker: bool = False,
115
+ vision_language_config: Optional[VisionLanguageConfig] = None,
116
+ ):
117
+ self.model_config = model_config
118
+ self.parallel_config = parallel_config
119
+ self.scheduler_config = scheduler_config
120
+ self.lora_config = lora_config
121
+ self.load_config = load_config
122
+ self.is_driver_worker = is_driver_worker
123
+
124
+ # model_config can be None in tests/samplers/test_sampler.py.
125
+ # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
126
+ self.sliding_window = (model_config.get_sliding_window()
127
+ if model_config is not None else None)
128
+ self.device_config = (device_config
129
+ if device_config is not None else DeviceConfig())
130
+ self.device = self.device_config.device
131
+
132
+ # Set after load_model.
133
+ self.lora_manager: LRUCacheWorkerLoRAManager = None
134
+
135
+ self.graph_runners: Dict[int, CUDAGraphRunner] = {}
136
+ self.graph_memory_pool: Optional[Tuple[
137
+ int, int]] = None # Set during graph capture.
138
+
139
+ self.max_context_len_to_capture = (
140
+ self.model_config.max_context_len_to_capture
141
+ if self.model_config is not None else 0)
142
+
143
+ self.pin_memory = is_pin_memory_available()
144
+ self.kv_cache_dtype = kv_cache_dtype
145
+ self.vision_language_config = vision_language_config
146
+
147
+ self.attn_backend = get_attn_backend(
148
+ self.model_config.dtype if model_config is not None else None)
149
+
150
+ # Lazy initialization
151
+ self.model: torch.nn.Module # Set after load_model
152
+ self.block_size: int # Set after initial profiling.
153
+ # When using CUDA graph, the input block tables must be padded to
154
+ # max_context_len_to_capture. However, creating the block table in
155
+ # Python can be expensive. To optimize this, we cache the block table
156
+ # in numpy and only copy the actual input content at every iteration.
157
+ # The shape of the cached block table will be
158
+ # (max batch size to capture, max context len to capture / block size).
159
+ self.graph_block_tables: torch.Tensor # Set after initial profiling.
160
+
161
+ def load_model(self, args) -> None:
162
+ with CudaMemoryProfiler() as m:
163
+ precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
164
+ latent_size = args.image_size // args.downsample_size
165
+ gpt_model = GPT_models[args.gpt_model](
166
+ vocab_size=args.codebook_size,
167
+ block_size=latent_size ** 2,
168
+ num_classes=args.num_classes,
169
+ cls_token_num=args.cls_token_num,
170
+ model_type=args.gpt_type,
171
+ cfg_scale=args.cfg_scale,
172
+ ).to(device='cuda', dtype=precision) # TODO: make device configurable
173
+
174
+ checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
175
+ if args.from_fsdp: # fspd
176
+ model_weight = checkpoint
177
+ elif "model" in checkpoint: # ddp
178
+ model_weight = checkpoint["model"]
179
+ elif "state_dict" in checkpoint:
180
+ model_weight = checkpoint["state_dict"]
181
+ else:
182
+ raise Exception("please check model weight")
183
+ gpt_model.custom_load_state_dict(model_weight)
184
+ gpt_model.eval()
185
+ del checkpoint
186
+ self.model = gpt_model
187
+
188
+ self.model_memory_usage = m.consumed_memory
189
+ logger.info(f"Loading model weights took "
190
+ f"{self.model_memory_usage / float(2**30):.4f} GB")
191
+
192
+ if self.lora_config:
193
+ assert hasattr(self.model, "supported_lora_modules"
194
+ ) and self.model.supported_lora_modules, (
195
+ "Model does not support LoRA")
196
+ assert hasattr(
197
+ self.model,
198
+ "embedding_modules"), "Model does not have embedding_modules"
199
+ assert hasattr(self.model, "embedding_padding_modules"
200
+ ), "Model does not have embedding_padding_modules"
201
+ self.lora_manager = LRUCacheWorkerLoRAManager(
202
+ self.scheduler_config.max_num_seqs,
203
+ self.scheduler_config.max_num_batched_tokens, self.vocab_size,
204
+ self.lora_config, self.device, self.model.embedding_modules,
205
+ self.model.embedding_padding_modules)
206
+ self.model = self.lora_manager.create_lora_manager(self.model)
207
+
208
+ if self.kv_cache_dtype == "fp8" and is_hip():
209
+ # Currently scaled KV cache is only enabled on ROCm
210
+ if self.model_config.quantization_param_path is not None:
211
+ if callable(getattr(self.model, "load_kv_cache_scales", None)):
212
+ self.model.load_kv_cache_scales(
213
+ self.model_config.quantization_param_path)
214
+ else:
215
+ raise RuntimeError("Using FP8 KV cache and scaling "
216
+ "factors provided but model "
217
+ f"{self.model.__class__} does not "
218
+ "support loading scaling factors.")
219
+ else:
220
+ logger.warn("Using FP8 KV cache but no scaling factors "
221
+ "provided. Defaulting to scaling factors of 1.0. "
222
+ "This may lead to less accurate results!")
223
+ elif self.model_config.quantization_param_path is not None:
224
+ logger.warn("KV cache scaling factors provided, "
225
+ "but the KV cache data type is not FP8. "
226
+ "KV cache scaling factors will not be used.")
227
+
228
+ def set_block_size(self, block_size: int) -> None:
229
+ self.block_size = block_size
230
+
231
+ self.graph_block_tables = np.zeros(
232
+ (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
233
+ dtype=np.int32)
234
+
235
+ def get_max_block_per_batch(self) -> int:
236
+ block_size = self.block_size
237
+ return (self.max_context_len_to_capture + block_size - 1) // block_size
238
+
239
+ def _prepare_prompt(
240
+ self,
241
+ seq_group_metadata_list: List[SequenceGroupMetadata],
242
+ ) -> PreparePromptMetadata:
243
+ input_tokens: List[int] = []
244
+ input_positions: List[int] = []
245
+ slot_mapping: List[int] = []
246
+ lora_index_mapping: List[int] = []
247
+ lora_prompt_mapping: List[int] = []
248
+ lora_requests: Set[LoRARequest] = set()
249
+
250
+ prompt_lens: List[int] = []
251
+ context_lens: List[int] = []
252
+ subquery_lens: List[int] = []
253
+ prefix_block_tables: List[List[int]] = []
254
+ multi_modal_input_list: List[torch.Tensor] = []
255
+
256
+ if len(seq_group_metadata_list) == 0:
257
+ return PreparePromptMetadata.empty()
258
+
259
+ for seq_group_metadata in seq_group_metadata_list:
260
+ assert seq_group_metadata.is_prompt
261
+ seq_ids = list(seq_group_metadata.seq_data.keys())
262
+ assert len(seq_ids) == 1
263
+ seq_id = seq_ids[0]
264
+
265
+ computed_block_nums = seq_group_metadata.computed_block_nums
266
+ if (self.scheduler_config is not None
267
+ and self.scheduler_config.chunked_prefill_enabled
268
+ and not (computed_block_nums is None
269
+ or computed_block_nums == [])):
270
+ raise RuntimeError(
271
+ "chunked prefill cannot be used with prefix caching "
272
+ "now.")
273
+
274
+ token_chunk_size = seq_group_metadata.token_chunk_size
275
+ seq_data = seq_group_metadata.seq_data[seq_id]
276
+ computed_len = seq_data.get_num_computed_tokens()
277
+ # We should use get_len here because in case of preemption
278
+ # it contains output tokens.
279
+ prefill_end = min(seq_data.get_len(),
280
+ computed_len + token_chunk_size)
281
+ prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
282
+ prompt_len = prefill_end
283
+ prompt_lens.append(prompt_len)
284
+
285
+ # NOTE: This only works for oooooooxxx style attention.
286
+ if computed_block_nums is not None and len(
287
+ computed_block_nums) > 0 and self.sliding_window is None:
288
+ # Prefix is not supported with sliding_window
289
+ computed_len = len(computed_block_nums) * self.block_size
290
+ prompt_tokens = prompt_tokens[computed_len:]
291
+ prefix_block_tables.append(computed_block_nums)
292
+ elif self.scheduler_config.chunked_prefill_enabled:
293
+ if seq_group_metadata.block_tables is not None:
294
+ # Prefill has chunked before.
295
+ block_table = seq_group_metadata.block_tables[seq_id]
296
+ prefix_block_tables.append(block_table)
297
+ else:
298
+ # The first prefill.
299
+ prefix_block_tables.append([])
300
+ else:
301
+ prefix_block_tables.append([])
302
+ # Right now, prefill start is always 0. However, this
303
+ # assumption can be changed once chunked prefill is introduced.
304
+ assert computed_len == 0
305
+
306
+ # actual prompt lens
307
+ context_lens.append(computed_len)
308
+ subquery_lens.append(prompt_len - computed_len)
309
+
310
+ input_tokens.extend(prompt_tokens)
311
+ # NOTE(woosuk): Here we assume that the first token in the prompt
312
+ # is always the first token in the sequence.
313
+ input_positions.extend(list(range(computed_len, prefill_end)))
314
+ lora_id = seq_group_metadata.lora_int_id
315
+
316
+ if lora_id > 0:
317
+ lora_requests.add(seq_group_metadata.lora_request)
318
+
319
+ lora_index_mapping += [lora_id] * (prompt_len - computed_len)
320
+ lora_prompt_mapping.extend(
321
+ [lora_id] *
322
+ (prompt_len - computed_len
323
+ if seq_group_metadata.sampling_params.prompt_logprobs else 1))
324
+
325
+ if seq_group_metadata.multi_modal_data:
326
+ multi_modal_input_list.append(
327
+ seq_group_metadata.multi_modal_data.data)
328
+
329
+ if seq_group_metadata.block_tables is None:
330
+ # During memory profiling, the block tables are not initialized
331
+ # yet. In this case, we just use a dummy slot mapping.
332
+ slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
333
+ continue
334
+
335
+ # Compute the slot mapping.
336
+ block_table = seq_group_metadata.block_tables[seq_id]
337
+ # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
338
+ # where start_idx is max(0, prompt_len - sliding_window).
339
+ # For example, if the prompt len is 10, sliding window is 8, and
340
+ # block size is 4, the first two tokens are masked and the slot
341
+ # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
342
+ start_idx = 0
343
+ if self.sliding_window is not None:
344
+ assert computed_len == 0, (
345
+ "Prefix caching is currently not supported with "
346
+ "sliding window attention")
347
+ start_idx = max(0, prompt_len - self.sliding_window)
348
+
349
+ for i in range(computed_len, prefill_end):
350
+ if i < start_idx:
351
+ slot_mapping.append(_PAD_SLOT_ID)
352
+ continue
353
+
354
+ block_number = block_table[i // self.block_size]
355
+ block_offset = i % self.block_size
356
+ slot = block_number * self.block_size + block_offset
357
+ slot_mapping.append(slot)
358
+
359
+ max_subquery_len = max(subquery_lens)
360
+ max_prompt_len = max(prompt_lens)
361
+ assert max_subquery_len > 0
362
+
363
+ context_lens_tensor = torch.tensor(context_lens,
364
+ dtype=torch.int,
365
+ device=self.device)
366
+
367
+ if multi_modal_input_list:
368
+ assert self.vision_language_config, (
369
+ "Multi-modal inputs are only supported by "
370
+ "vision language models.")
371
+ multi_modal_input = torch.cat(multi_modal_input_list,
372
+ dim=0).to(self.device)
373
+ else:
374
+ multi_modal_input = None
375
+
376
+ # Prepare prefix block tables
377
+ max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
378
+ block_tables = make_tensor_with_pad(
379
+ prefix_block_tables,
380
+ max_len=max_prompt_block_table_len,
381
+ pad=0,
382
+ dtype=torch.int,
383
+ device=self.device,
384
+ )
385
+
386
+ # Query length can be shorter than key (i.e., prompt) when prefill
387
+ # is chunked or prefix cached.
388
+ subquery_lens_tensor = torch.tensor(subquery_lens,
389
+ dtype=torch.long,
390
+ device=self.device)
391
+ subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
392
+ dtype=torch.int32,
393
+ device=self.device)
394
+
395
+ prompt_lens_tensor = torch.tensor(prompt_lens,
396
+ dtype=torch.long,
397
+ device=self.device)
398
+ seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
399
+ dtype=torch.int32,
400
+ device=self.device)
401
+
402
+ torch.cumsum(subquery_lens_tensor,
403
+ dim=0,
404
+ dtype=subquery_start_loc.dtype,
405
+ out=subquery_start_loc[1:])
406
+
407
+ torch.cumsum(prompt_lens_tensor,
408
+ dim=0,
409
+ dtype=seq_start_loc.dtype,
410
+ out=seq_start_loc[1:])
411
+
412
+ attn_metadata = self.attn_backend.make_metadata(
413
+ is_prompt=True,
414
+ prompt_lens=prompt_lens,
415
+ prompt_lens_tensor=prompt_lens_tensor,
416
+ max_subquery_len=max_subquery_len,
417
+ max_context_len=None,
418
+ max_prompt_len=max_prompt_len,
419
+ subquery_start_loc=subquery_start_loc,
420
+ seq_start_loc=seq_start_loc,
421
+ context_lens=context_lens_tensor,
422
+ block_tables=block_tables,
423
+ use_cuda_graph=False,
424
+ )
425
+
426
+ return PreparePromptMetadata(
427
+ input_tokens=input_tokens,
428
+ input_positions=input_positions,
429
+ attn_metadata=attn_metadata,
430
+ prompt_lens=prompt_lens,
431
+ subquery_lens=subquery_lens,
432
+ lora_index_mapping=lora_index_mapping,
433
+ lora_prompt_mapping=lora_prompt_mapping,
434
+ lora_requests=lora_requests,
435
+ multi_modal_input=multi_modal_input,
436
+ slot_mapping=slot_mapping,
437
+ )
438
+
439
+ def _prepare_decode(
440
+ self,
441
+ seq_group_metadata_list: List[SequenceGroupMetadata],
442
+ ) -> PrepareDecodeMetadata:
443
+ input_tokens: List[int] = []
444
+ input_positions: List[int] = []
445
+ slot_mapping: List[int] = []
446
+ context_lens: List[int] = []
447
+ block_tables: List[List[int]] = []
448
+ lora_index_mapping: List[int] = []
449
+ lora_prompt_mapping: List[int] = []
450
+ lora_requests: Set[LoRARequest] = set()
451
+
452
+ if len(seq_group_metadata_list) == 0:
453
+ return PrepareDecodeMetadata.empty()
454
+
455
+ for seq_group_metadata in seq_group_metadata_list:
456
+ assert not seq_group_metadata.is_prompt
457
+ assert seq_group_metadata.token_chunk_size == 1
458
+
459
+ seq_ids = list(seq_group_metadata.seq_data.keys())
460
+ lora_id = seq_group_metadata.lora_int_id
461
+
462
+ if lora_id > 0:
463
+ lora_requests.add(seq_group_metadata.lora_request)
464
+
465
+ for seq_id in seq_ids:
466
+ seq_data = seq_group_metadata.seq_data[seq_id]
467
+ generation_token = seq_data.get_last_token_id()
468
+ input_tokens.append(generation_token)
469
+
470
+ seq_len = seq_data.get_len()
471
+ position = seq_len - 1
472
+ input_positions.append(position)
473
+
474
+ context_len = seq_len if self.sliding_window is None else min(
475
+ seq_len, self.sliding_window)
476
+ context_lens.append(context_len)
477
+
478
+ block_table = seq_group_metadata.block_tables[seq_id]
479
+ block_number = block_table[position // self.block_size]
480
+ block_offset = position % self.block_size
481
+ slot = block_number * self.block_size + block_offset
482
+ slot_mapping.append(slot)
483
+ lora_index_mapping.append(lora_id)
484
+ lora_prompt_mapping.append(lora_id)
485
+
486
+ if self.sliding_window is not None:
487
+ sliding_window_blocks = (self.sliding_window //
488
+ self.block_size)
489
+ block_table = block_table[-sliding_window_blocks:]
490
+ block_tables.append(block_table)
491
+
492
+ # vLLM uses cuda graph only for decoding requests.
493
+ # See `capture_model` API for more details.
494
+ # For decoding requests, batch_size == input_tokens.
495
+ batch_size = len(input_tokens)
496
+ max_context_len = max(context_lens)
497
+ use_captured_graph = (
498
+ not self.model_config.enforce_eager
499
+ and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
500
+ and max_context_len <= self.max_context_len_to_capture)
501
+ if use_captured_graph:
502
+ graph_batch_size = _get_graph_batch_size(batch_size)
503
+ assert graph_batch_size >= batch_size
504
+ for _ in range(graph_batch_size - batch_size):
505
+ input_tokens.append(0)
506
+ input_positions.append(0)
507
+ slot_mapping.append(_PAD_SLOT_ID)
508
+ context_lens.append(1)
509
+ block_tables.append([])
510
+ lora_index_mapping.append(0)
511
+ batch_size = graph_batch_size
512
+
513
+ context_lens_tensor = torch.tensor(context_lens,
514
+ dtype=torch.int,
515
+ device=self.device)
516
+
517
+ if use_captured_graph:
518
+ # When using cuda-graph all these tensors should be
519
+ # padded.
520
+ assert context_lens_tensor.shape[0] == len(input_tokens)
521
+ assert context_lens_tensor.shape[0] == len(input_positions)
522
+ assert context_lens_tensor.shape[0] == len(slot_mapping)
523
+
524
+ # The shape of graph_block_tables is
525
+ # [max batch size, max context len // block size].
526
+ input_block_tables = self.graph_block_tables[:batch_size]
527
+ for i, block_table in enumerate(block_tables):
528
+ if block_table:
529
+ input_block_tables[i, :len(block_table)] = block_table
530
+ block_tables = torch.tensor(input_block_tables, device=self.device)
531
+ else:
532
+ max_block_table_len = max(
533
+ len(block_table) for block_table in block_tables)
534
+ block_tables = make_tensor_with_pad(
535
+ block_tables,
536
+ max_len=max_block_table_len,
537
+ pad=0,
538
+ dtype=torch.int,
539
+ device=self.device,
540
+ )
541
+
542
+ attn_metadata = self.attn_backend.make_metadata(
543
+ is_prompt=False,
544
+ prompt_lens=None,
545
+ prompt_lens_tensor=None,
546
+ max_subquery_len=None,
547
+ max_context_len=max_context_len,
548
+ max_prompt_len=None,
549
+ subquery_start_loc=None,
550
+ seq_start_loc=None,
551
+ context_lens=context_lens_tensor,
552
+ block_tables=block_tables,
553
+ use_cuda_graph=use_captured_graph,
554
+ )
555
+ return PrepareDecodeMetadata(
556
+ input_tokens=input_tokens,
557
+ input_positions=input_positions,
558
+ attn_metadata=attn_metadata,
559
+ lora_index_mapping=lora_index_mapping,
560
+ lora_prompt_mapping=lora_prompt_mapping,
561
+ lora_requests=lora_requests,
562
+ slot_mapping=slot_mapping,
563
+ )
564
+
565
+ def _prepare_sample(
566
+ self,
567
+ seq_group_metadata_list: List[SequenceGroupMetadata],
568
+ prompt_lens: List[int],
569
+ subquery_lens: Optional[List[int]],
570
+ ) -> SamplingMetadata:
571
+ seq_groups: List[Tuple[List[int], SamplingParams]] = []
572
+ selected_token_indices: List[int] = []
573
+ generators: List[torch.Generator] = []
574
+ selected_token_start_idx = 0
575
+ categorized_sample_indices: Dict[SamplingType,
576
+ List[Tuple[int, int]]] = {
577
+ t: []
578
+ for t in SamplingType
579
+ }
580
+ categorized_sample_indices_start_idx = 0
581
+ categorized_sampled_token_indices_start_idx = 0
582
+
583
+ for i, seq_group_metadata in enumerate(seq_group_metadata_list):
584
+ seq_ids = list(seq_group_metadata.seq_data.keys())
585
+ sampling_params = seq_group_metadata.sampling_params
586
+ seq_groups.append((seq_ids, sampling_params))
587
+
588
+ if seq_group_metadata.is_prompt:
589
+ assert len(seq_ids) == 1
590
+ assert subquery_lens is not None
591
+ subquery_len = subquery_lens[i]
592
+ if sampling_params.prompt_logprobs is not None:
593
+ # NOTE: prompt token positions do not need sample, skip
594
+ categorized_sample_indices_start_idx += subquery_len - 1
595
+
596
+ categorized_sample_indices[
597
+ sampling_params.sampling_type].append(
598
+ (categorized_sample_indices_start_idx,
599
+ categorized_sampled_token_indices_start_idx))
600
+ categorized_sample_indices_start_idx += 1
601
+ categorized_sampled_token_indices_start_idx += 1
602
+
603
+ if sampling_params.prompt_logprobs is not None:
604
+ selected_token_indices.extend(
605
+ range(selected_token_start_idx,
606
+ selected_token_start_idx + subquery_len - 1))
607
+ selected_token_indices.append(selected_token_start_idx +
608
+ subquery_len - 1)
609
+ selected_token_start_idx += subquery_len
610
+
611
+ if sampling_params.seed is not None:
612
+ seq_group_metadata.state.generator = torch.Generator(
613
+ device=self.device).manual_seed(sampling_params.seed)
614
+ else:
615
+ num_seqs = len(seq_ids)
616
+ selected_token_indices.extend(
617
+ range(selected_token_start_idx,
618
+ selected_token_start_idx + num_seqs))
619
+ selected_token_start_idx += num_seqs
620
+
621
+ categorized_sample_indices[
622
+ sampling_params.sampling_type].extend(
623
+ list(
624
+ zip(
625
+ range(
626
+ categorized_sample_indices_start_idx,
627
+ categorized_sample_indices_start_idx +
628
+ num_seqs),
629
+ range(
630
+ categorized_sampled_token_indices_start_idx,
631
+ categorized_sampled_token_indices_start_idx
632
+ + num_seqs))))
633
+ categorized_sample_indices_start_idx += num_seqs
634
+ categorized_sampled_token_indices_start_idx += num_seqs
635
+
636
+ if sampling_params.seed is not None:
637
+ generators.append(seq_group_metadata.state.generator)
638
+
639
+ selected_token_indices = async_tensor_h2d(selected_token_indices,
640
+ dtype=torch.long,
641
+ target_device=self.device,
642
+ pin_memory=self.pin_memory)
643
+
644
+ categorized_sample_indices = {
645
+ t: maybe_expand_dim(
646
+ async_tensor_h2d(seq_ids,
647
+ dtype=torch.int,
648
+ target_device=self.device,
649
+ pin_memory=self.pin_memory), 2, 2)
650
+ for t, seq_ids in categorized_sample_indices.items()
651
+ }
652
+
653
+ seq_data: Dict[int, SequenceData] = {}
654
+ for seq_group_metadata in seq_group_metadata_list:
655
+ seq_data.update(seq_group_metadata.seq_data)
656
+
657
+ sampling_metadata = SamplingMetadata(
658
+ seq_groups=seq_groups,
659
+ seq_data=seq_data,
660
+ prompt_lens=prompt_lens,
661
+ selected_token_indices=selected_token_indices,
662
+ categorized_sample_indices=categorized_sample_indices,
663
+ generators=generators,
664
+ )
665
+ return sampling_metadata
666
+
667
+ def prepare_input_tensors(
668
+ self,
669
+ seq_group_metadata_list: List[SequenceGroupMetadata],
670
+ ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
671
+ Set[LoRARequest], LoRAMapping, torch.Tensor]:
672
+ if self.is_driver_worker:
673
+ prefill_reqs = []
674
+ decode_reqs = []
675
+ for seq_group_meta in seq_group_metadata_list:
676
+ if seq_group_meta.is_prompt:
677
+ prefill_reqs.append(seq_group_meta)
678
+ else:
679
+ decode_reqs.append(seq_group_meta)
680
+
681
+ # Prepare input tensors.
682
+ (
683
+ input_tokens,
684
+ input_positions,
685
+ prefill_attn_metadata,
686
+ prompt_lens,
687
+ subquery_lens,
688
+ lora_index_mapping,
689
+ lora_prompt_mapping,
690
+ lora_requests,
691
+ multi_modal_input,
692
+ slot_mapping,
693
+ ) = self._prepare_prompt(prefill_reqs)
694
+ (
695
+ decode_input_tokens,
696
+ decode_input_positions,
697
+ decode_attn_metadata,
698
+ decode_lora_index_mapping,
699
+ decode_lora_prompt_mapping,
700
+ decode_lora_requests,
701
+ decode_slot_mapping,
702
+ ) = self._prepare_decode(decode_reqs)
703
+ sampling_metadata = self._prepare_sample(seq_group_metadata_list,
704
+ prompt_lens,
705
+ subquery_lens)
706
+
707
+ if not self.scheduler_config.chunked_prefill_enabled:
708
+ assert (len(prefill_reqs) and len(decode_reqs)) == 0
709
+
710
+ num_prefills = len(prompt_lens)
711
+ num_prefill_tokens = len(input_tokens)
712
+ num_decode_tokens = len(decode_input_tokens)
713
+
714
+ # Coalesce tensors. Note that attn_metadata is currently not
715
+ # coalesced for simplicity.
716
+ input_tokens.extend(decode_input_tokens)
717
+ input_positions.extend(decode_input_positions)
718
+ slot_mapping.extend(decode_slot_mapping)
719
+ lora_index_mapping.extend(decode_lora_index_mapping)
720
+ lora_prompt_mapping.extend(decode_lora_prompt_mapping)
721
+ lora_requests.update(decode_lora_requests)
722
+
723
+ input_tokens = torch.tensor(input_tokens,
724
+ dtype=torch.long,
725
+ device=self.device)
726
+ input_positions = torch.tensor(input_positions,
727
+ dtype=torch.long,
728
+ device=self.device)
729
+ slot_mapping = torch.tensor(slot_mapping,
730
+ dtype=torch.long,
731
+ device=self.device)
732
+
733
+ if self.lora_config:
734
+ lora_mapping = LoRAMapping(
735
+ lora_index_mapping,
736
+ lora_prompt_mapping,
737
+ )
738
+ else:
739
+ lora_mapping = None
740
+
741
+ # Broadcast the metadata.
742
+ # If batch contains both prefill and decode, it sends 2 broadcasts.
743
+ # If it only contains 1 type, it triggers a single broadcast.
744
+ if (prefill_attn_metadata is not None
745
+ and decode_attn_metadata is not None):
746
+ batch_type = BatchType.MIXED
747
+ elif prefill_attn_metadata is not None:
748
+ batch_type = BatchType.PREFILL
749
+ else:
750
+ batch_type = BatchType.DECODE
751
+
752
+ metadata_dict = {
753
+ "input_tokens": input_tokens,
754
+ "input_positions": input_positions,
755
+ "selected_token_indices":
756
+ sampling_metadata.selected_token_indices,
757
+ "lora_requests": lora_requests,
758
+ "lora_mapping": lora_mapping,
759
+ "multi_modal_input": multi_modal_input,
760
+ "num_prefill_tokens": num_prefill_tokens,
761
+ "num_decode_tokens": num_decode_tokens,
762
+ "slot_mapping": slot_mapping,
763
+ "num_prefills": num_prefills,
764
+ "batch_type": batch_type,
765
+ }
766
+ if prefill_attn_metadata is not None:
767
+ metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
768
+ else:
769
+ assert decode_attn_metadata is not None
770
+ metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
771
+ broadcast_tensor_dict(metadata_dict, src=0)
772
+
773
+ # Broadcast decode attn metadata for mixed batch type.
774
+ # The additional broadcast costs 300us overhead on 4 A10 GPUs.
775
+ # We can potentially reduce the overhead by coelescing tensors.
776
+ if batch_type == BatchType.MIXED:
777
+ assert decode_attn_metadata is not None
778
+ metadata_dict = decode_attn_metadata.asdict_zerocopy()
779
+ broadcast_tensor_dict(metadata_dict, src=0)
780
+ else:
781
+ metadata_dict = broadcast_tensor_dict(src=0)
782
+ input_tokens = metadata_dict.pop("input_tokens")
783
+ input_positions = metadata_dict.pop("input_positions")
784
+ slot_mapping = metadata_dict.pop("slot_mapping")
785
+ num_prefills = metadata_dict.pop("num_prefills")
786
+ selected_token_indices = metadata_dict.pop(
787
+ "selected_token_indices")
788
+ lora_mapping = metadata_dict.pop("lora_mapping")
789
+ lora_requests = metadata_dict.pop("lora_requests")
790
+ multi_modal_input = metadata_dict.pop("multi_modal_input")
791
+ num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
792
+ num_decode_tokens = metadata_dict.pop("num_decode_tokens")
793
+ batch_type = metadata_dict.pop("batch_type")
794
+
795
+ # Create an attention metadata.
796
+ prefill_attn_metadata = None
797
+ decode_attn_metadata = None
798
+ if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
799
+ prefill_attn_metadata = self.attn_backend.make_metadata(
800
+ **metadata_dict)
801
+ else:
802
+ decode_attn_metadata = self.attn_backend.make_metadata(
803
+ **metadata_dict)
804
+ sampling_metadata = SamplingMetadata(
805
+ seq_groups=None,
806
+ seq_data=None,
807
+ prompt_lens=None,
808
+ selected_token_indices=selected_token_indices,
809
+ categorized_sample_indices=None,
810
+ generators=None,
811
+ perform_sampling=False,
812
+ )
813
+
814
+ # if it is a mixed batch, decode attn_metadata is broadcasted
815
+ # separately.
816
+ if batch_type == BatchType.MIXED:
817
+ metadata_dict = broadcast_tensor_dict(src=0)
818
+ decode_attn_metadata = self.attn_backend.make_metadata(
819
+ **metadata_dict)
820
+
821
+ attn_metadata = AttentionMetadata(
822
+ num_prefills=num_prefills,
823
+ slot_mapping=slot_mapping,
824
+ num_prefill_tokens=num_prefill_tokens,
825
+ num_decode_tokens=num_decode_tokens,
826
+ prefill_metadata=prefill_attn_metadata,
827
+ decode_metadata=decode_attn_metadata,
828
+ kv_cache_dtype=self.kv_cache_dtype,
829
+ )
830
+
831
+ return (input_tokens, input_positions, attn_metadata,
832
+ sampling_metadata, lora_requests, lora_mapping,
833
+ multi_modal_input)
834
+
835
+ @torch.inference_mode()
836
+ def execute_model(
837
+ self,
838
+ seq_group_metadata_list: List[SequenceGroupMetadata],
839
+ kv_caches: List[torch.Tensor],
840
+ ) -> Optional[SamplerOutput]:
841
+ (input_tokens, input_positions, attn_metadata, sampling_metadata,
842
+ lora_requests, lora_mapping, multi_modal_input
843
+ ) = self.prepare_input_tensors(seq_group_metadata_list)
844
+ if self.lora_config:
845
+ self.set_active_loras(lora_requests, lora_mapping)
846
+
847
+ # Currently cuda graph is only supported by the decode phase.
848
+ prefill_meta = attn_metadata.prefill_metadata
849
+ decode_meta = attn_metadata.decode_metadata
850
+ if prefill_meta is None and decode_meta.use_cuda_graph:
851
+ graph_batch_size = input_tokens.shape[0]
852
+ model_executable = self.graph_runners[graph_batch_size]
853
+ else:
854
+ model_executable = self.model
855
+ execute_model_kwargs = {
856
+ "input_ids": input_tokens,
857
+ "positions": input_positions,
858
+ "kv_caches": kv_caches,
859
+ "attn_metadata": attn_metadata,
860
+ }
861
+ if self.vision_language_config:
862
+ execute_model_kwargs.update({"image_input": multi_modal_input})
863
+ hidden_states = model_executable(**execute_model_kwargs)
864
+
865
+ # Compute the logits.
866
+ logits = self.model.compute_logits(hidden_states, sampling_metadata)
867
+
868
+ # Only perform sampling in the driver worker.
869
+ if not sampling_metadata.perform_sampling:
870
+ return None
871
+
872
+ # Sample the next token.
873
+ output = self.model.sample(
874
+ logits=logits,
875
+ sampling_metadata=sampling_metadata,
876
+ )
877
+ return output
878
+
879
+ @torch.inference_mode()
880
+ def profile_run(self) -> None:
881
+ # Enable top-k sampling to reflect the accurate memory usage.
882
+ sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
883
+ max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
884
+ max_num_seqs = self.scheduler_config.max_num_seqs
885
+
886
+ # This represents the maximum number of different requests
887
+ # that will have unique loras, an therefore the max amount of memory
888
+ # consumption create dummy lora request copies from the lora request
889
+ # passed in, which contains a lora from the lora warmup path.
890
+ dummy_lora_requests = []
891
+ dummy_lora_requests_per_seq = []
892
+ if self.lora_config:
893
+ for idx in range(self.lora_config.max_loras):
894
+ lora_id = idx + 1
895
+ dummy_lora_request = LoRARequest(
896
+ lora_name=f"warmup_{lora_id}",
897
+ lora_int_id=lora_id,
898
+ lora_local_path="/not/a/real/path",
899
+ )
900
+ self.lora_manager.add_dummy_lora(dummy_lora_request,
901
+ rank=LORA_WARMUP_RANK)
902
+ dummy_lora_requests.append(dummy_lora_request)
903
+ dummy_lora_requests_per_seq = [
904
+ dummy_lora_requests[idx % len(dummy_lora_requests)]
905
+ for idx in range(max_num_seqs)
906
+ ]
907
+
908
+ # Profile memory usage with max_num_sequences sequences and the total
909
+ # number of tokens equal to max_num_batched_tokens.
910
+ seqs: List[SequenceGroupMetadata] = []
911
+ # Additional GPU memory may be needed for vision encoding, which needs
912
+ # to be accounted for when calculating the GPU blocks for
913
+ # vLLM blocker manager.
914
+ # To exercise the worst scenario for GPU memory consumption,
915
+ # the number of seqs (batch_size) is chosen to maximize the number
916
+ # of images processed.
917
+ if self.vision_language_config:
918
+ max_num_seqs = min(
919
+ max_num_seqs,
920
+ int(max_num_batched_tokens /
921
+ self.vision_language_config.image_feature_size))
922
+ for group_id in range(max_num_seqs):
923
+ seq_len = (max_num_batched_tokens // max_num_seqs +
924
+ (group_id < max_num_batched_tokens % max_num_seqs))
925
+ seq_data, fake_multi_modal_input = _prepare_fake_inputs(
926
+ seq_len, self.vision_language_config)
927
+ seq = SequenceGroupMetadata(
928
+ request_id=str(group_id),
929
+ is_prompt=True,
930
+ seq_data={group_id: seq_data},
931
+ sampling_params=sampling_params,
932
+ block_tables=None,
933
+ lora_request=dummy_lora_requests_per_seq[group_id]
934
+ if dummy_lora_requests_per_seq else None,
935
+ multi_modal_data=fake_multi_modal_input,
936
+ )
937
+ seqs.append(seq)
938
+
939
+ # Run the model with the dummy inputs.
940
+ num_layers = self.model_config.get_num_layers(self.parallel_config)
941
+ kv_caches = [None] * num_layers
942
+ self.execute_model(seqs, kv_caches)
943
+ torch.cuda.synchronize()
944
+ return
945
+
946
+ def remove_all_loras(self) -> bool:
947
+ if not self.lora_manager:
948
+ raise RuntimeError("LoRA is not enabled.")
949
+ return self.lora_manager.remove_all_loras()
950
+
951
+ def set_active_loras(self, lora_requests: Set[LoRARequest],
952
+ lora_mapping: LoRAMapping) -> None:
953
+ if not self.lora_manager:
954
+ raise RuntimeError("LoRA is not enabled.")
955
+ self.lora_manager.set_active_loras(lora_requests, lora_mapping)
956
+
957
+ def add_lora(self, lora_request: LoRARequest) -> bool:
958
+ if not self.lora_manager:
959
+ raise RuntimeError("LoRA is not enabled.")
960
+ return self.lora_manager.add_lora(lora_request)
961
+
962
+ def remove_lora(self, lora_id: int) -> bool:
963
+ if not self.lora_manager:
964
+ raise RuntimeError("LoRA is not enabled.")
965
+ return self.lora_manager.remove_lora(lora_id)
966
+
967
+ def list_loras(self) -> Set[int]:
968
+ if not self.lora_manager:
969
+ raise RuntimeError("LoRA is not enabled.")
970
+ return self.lora_manager.list_loras()
971
+
972
+ @torch.inference_mode()
973
+ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
974
+ """Cuda graph capture a model.
975
+
976
+ Note that CUDA graph's performance gain is negligible if number
977
+ of batched tokens are larger than 200. And since CUDA graph
978
+ requires fixed sized tensors, supporting large/variable batch
979
+ size requires high GPU memory overhead. Thus, vLLM only captures
980
+ decoding requests. Mixed batch (chunked prefill + decoding) or
981
+ prefill requests are not captured.
982
+
983
+ Since it is used for decoding-only, it assumes there's only 1 token
984
+ per sequence in the batch.
985
+ """
986
+ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
987
+ # deleted before the CUDA graphs.
988
+ self.pynccl_backend = pynccl_utils.get_nccl_backend()
989
+
990
+ assert not self.model_config.enforce_eager
991
+ logger.info("Capturing the model for CUDA graphs. This may lead to "
992
+ "unexpected consequences if the model is not static. To "
993
+ "run the model in eager mode, set 'enforce_eager=True' or "
994
+ "use '--enforce-eager' in the CLI.")
995
+ logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
996
+ "If you are running out of memory, consider decreasing "
997
+ "`gpu_memory_utilization` or enforcing eager mode. "
998
+ "You can also reduce the `max_num_seqs` as needed "
999
+ "to decrease memory usage.")
1000
+ start_time = time.perf_counter()
1001
+
1002
+ # Prepare dummy inputs. These will be reused for all batch sizes.
1003
+ max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
1004
+ input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
1005
+ input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
1006
+ slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
1007
+ slot_mapping.fill_(_PAD_SLOT_ID)
1008
+ context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
1009
+ block_tables = torch.from_numpy(self.graph_block_tables).cuda()
1010
+
1011
+ graph_batch_size = _get_graph_batch_size(
1012
+ self.scheduler_config.max_num_seqs)
1013
+ batch_size_capture_list = [
1014
+ bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
1015
+ ]
1016
+
1017
+ # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
1018
+ # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
1019
+ # either custom all-reduce kernel or pynccl. When not using CUDA
1020
+ # graph, we use either custom all-reduce kernel or PyTorch NCCL.
1021
+ # We always prioritize using custom all-reduce kernel but fall back
1022
+ # to PyTorch or pynccl if it is disabled or not supported.
1023
+ with custom_all_reduce.capture():
1024
+ # NOTE: Capturing the largest batch size first may help reduce the
1025
+ # memory usage of CUDA graph.
1026
+ for batch_size in reversed(batch_size_capture_list):
1027
+ # Create dummy attn_metadata.
1028
+ decode_metadata = self.attn_backend.make_metadata(
1029
+ is_prompt=False,
1030
+ prompt_lens=None,
1031
+ prompt_lens_tensor=None,
1032
+ max_subquery_len=None,
1033
+ max_context_len=self.max_context_len_to_capture,
1034
+ max_prompt_len=None,
1035
+ subquery_start_loc=None,
1036
+ seq_start_loc=None,
1037
+ context_lens=context_lens[:batch_size],
1038
+ block_tables=block_tables[:batch_size],
1039
+ use_cuda_graph=True,
1040
+ )
1041
+ attn_metadata = AttentionMetadata(
1042
+ num_prefills=0,
1043
+ num_prefill_tokens=0,
1044
+ num_decode_tokens=batch_size,
1045
+ slot_mapping=slot_mapping[:batch_size],
1046
+ prefill_metadata=None,
1047
+ decode_metadata=decode_metadata,
1048
+ kv_cache_dtype=self.kv_cache_dtype,
1049
+ )
1050
+
1051
+ if self.lora_config:
1052
+ lora_mapping = LoRAMapping(
1053
+ [0] * batch_size,
1054
+ [0] * batch_size,
1055
+ )
1056
+ self.set_active_loras(set(), lora_mapping)
1057
+
1058
+ graph_runner = CUDAGraphRunner(self.model)
1059
+ graph_runner.capture(
1060
+ input_tokens[:batch_size],
1061
+ input_positions[:batch_size],
1062
+ kv_caches,
1063
+ attn_metadata,
1064
+ memory_pool=self.graph_memory_pool,
1065
+ )
1066
+ self.graph_memory_pool = graph_runner.graph.pool()
1067
+ self.graph_runners[batch_size] = graph_runner
1068
+
1069
+ end_time = time.perf_counter()
1070
+ elapsed_time = end_time - start_time
1071
+ # This usually takes < 10 seconds.
1072
+ logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
1073
+
1074
+ def __del__(self) -> None:
1075
+ # Delete the CUDA graphs before deleting the pynccl communicator.
1076
+ # NOTE(woosuk): This is necessary because otherwise deadlocks can
1077
+ # happen.
1078
+ # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
1079
+ # TODO(youkaichao): when we get enough user feedback that pynccl is
1080
+ # more stable than cupy, we can remove this, e.g. in v0.4.1.
1081
+ self.graph_runners.clear()
1082
+ self.pynccl_backend = None
1083
+
1084
+ @property
1085
+ def vocab_size(self) -> int:
1086
+ return self.model_config.get_vocab_size()
1087
+
1088
+
1089
+ class CUDAGraphRunner:
1090
+
1091
+ def __init__(self, model: nn.Module):
1092
+ self.model = model
1093
+ self.input_buffers: Dict[str, torch.Tensor] = {}
1094
+ self.output_buffers: Dict[str, torch.Tensor] = {}
1095
+
1096
+ self._graph: Optional[torch.cuda.CUDAGraph] = None
1097
+
1098
+ @property
1099
+ def graph(self):
1100
+ assert self._graph is not None
1101
+ return self._graph
1102
+
1103
+ def capture(
1104
+ self,
1105
+ input_ids: torch.Tensor,
1106
+ positions: torch.Tensor,
1107
+ kv_caches: List[torch.Tensor],
1108
+ attn_metadata: AttentionMetadata,
1109
+ memory_pool,
1110
+ **kwargs,
1111
+ ) -> None:
1112
+ assert self._graph is None
1113
+ # Run the model once without capturing the graph.
1114
+ # This is to make sure that the captured graph does not include the
1115
+ # kernel launches for initial benchmarking (e.g., Triton autotune).
1116
+ with _maybe_pynccl():
1117
+ self.model(
1118
+ input_ids,
1119
+ positions,
1120
+ kv_caches,
1121
+ attn_metadata,
1122
+ **kwargs,
1123
+ )
1124
+ torch.cuda.synchronize()
1125
+
1126
+ # Capture the graph.
1127
+ # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
1128
+ # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
1129
+ self._graph = torch.cuda.CUDAGraph()
1130
+ with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
1131
+ with _maybe_pynccl():
1132
+ hidden_states = self.model(
1133
+ input_ids,
1134
+ positions,
1135
+ kv_caches,
1136
+ attn_metadata,
1137
+ **kwargs,
1138
+ )
1139
+ torch.cuda.synchronize()
1140
+
1141
+ # Save the input and output buffers.
1142
+ self.input_buffers = {
1143
+ "input_ids": input_ids,
1144
+ "positions": positions,
1145
+ "kv_caches": kv_caches,
1146
+ "slot_mapping": attn_metadata.slot_mapping,
1147
+ "context_lens": attn_metadata.decode_metadata.context_lens,
1148
+ "block_tables": attn_metadata.decode_metadata.block_tables,
1149
+ }
1150
+ self.output_buffers = {"hidden_states": hidden_states}
1151
+ return
1152
+
1153
+ def forward(
1154
+ self,
1155
+ input_ids: torch.Tensor,
1156
+ positions: torch.Tensor,
1157
+ kv_caches: List[torch.Tensor],
1158
+ attn_metadata: AttentionMetadata,
1159
+ **kwargs,
1160
+ ) -> torch.Tensor:
1161
+ # KV caches are fixed tensors, so we don't need to copy them.
1162
+ del kv_caches
1163
+
1164
+ # Copy the input tensors to the input buffers.
1165
+ self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
1166
+ self.input_buffers["positions"].copy_(positions, non_blocking=True)
1167
+ self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
1168
+ non_blocking=True)
1169
+ self.input_buffers["context_lens"].copy_(
1170
+ attn_metadata.decode_metadata.context_lens, non_blocking=True)
1171
+ self.input_buffers["block_tables"].copy_(
1172
+ attn_metadata.decode_metadata.block_tables, non_blocking=True)
1173
+ # Run the graph.
1174
+ self.graph.replay()
1175
+
1176
+ # Return the output tensor.
1177
+ return self.output_buffers["hidden_states"]
1178
+
1179
+ def __call__(self, *args, **kwargs):
1180
+ return self.forward(*args, **kwargs)
1181
+
1182
+
1183
+ @contextlib.contextmanager
1184
+ def _maybe_pynccl():
1185
+ if pynccl_utils.is_initialized(
1186
+ ) and not custom_all_reduce.is_initialized():
1187
+ with with_pynccl_for_all_reduce():
1188
+ yield
1189
+ else:
1190
+ yield
1191
+
1192
+
1193
+ def _get_graph_batch_size(batch_size: int) -> int:
1194
+ """Returns the padded batch size given actual batch size.
1195
+
1196
+ Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
1197
+ 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
1198
+ """
1199
+ if batch_size <= 2:
1200
+ return batch_size
1201
+ elif batch_size <= 4:
1202
+ return 4
1203
+ else:
1204
+ return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
1205
+ _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
1206
+
1207
+
1208
+ def _prepare_fake_inputs(
1209
+ seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
1210
+ """Prepare fake inputs for profile run."""
1211
+ if vision_language_config:
1212
+ prompt_tokens = [
1213
+ vision_language_config.image_token_id
1214
+ ] * vision_language_config.image_feature_size + [0] * (
1215
+ seq_len - vision_language_config.image_feature_size)
1216
+ fake_image_input = MultiModalData(
1217
+ type=MultiModalData.Type.IMAGE,
1218
+ data=torch.zeros(vision_language_config.image_input_shape,
1219
+ dtype=torch.float16))
1220
+ else:
1221
+ prompt_tokens = [0] * seq_len
1222
+ fake_image_input = None
1223
+ return SequenceData(prompt_tokens), fake_image_input
serve/sample_c2i.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import argparse
3
+ import torch
4
+ from torchvision.utils import save_image
5
+
6
+ from tokenizer.tokenizer_image.vq_model import VQ_models
7
+ from serve.gpt_model import GPT_models
8
+ from serve.llm import LLM
9
+ from vllm import SamplingParams
10
+
11
+
12
+ def main(args):
13
+ # Setup PyTorch:
14
+ torch.manual_seed(args.seed)
15
+ torch.backends.cudnn.deterministic = True
16
+ torch.backends.cudnn.benchmark = False
17
+ torch.set_grad_enabled(False)
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ # create and load model
21
+ vq_model = VQ_models[args.vq_model](
22
+ codebook_size=args.codebook_size,
23
+ codebook_embed_dim=args.codebook_embed_dim)
24
+ vq_model.to(device)
25
+ vq_model.eval()
26
+ checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
27
+ vq_model.load_state_dict(checkpoint["model"])
28
+ del checkpoint
29
+ print(f"image tokenizer is loaded")
30
+
31
+ # Labels to condition the model with (feel free to change):
32
+ class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
33
+ latent_size = args.image_size // args.downsample_size
34
+ qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
35
+ prompt_token_ids = [[cind] for cind in class_labels]
36
+ if args.cfg_scale > 1.0:
37
+ prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
38
+ # Create an LLM.
39
+ llm = LLM(
40
+ args=args,
41
+ model='autoregressive/serve/fake_json/{}.json'.format(args.gpt_model),
42
+ gpu_memory_utilization=0.9,
43
+ skip_tokenizer_init=True)
44
+ print(f"gpt model is loaded")
45
+
46
+ # Create a sampling params object.
47
+ sampling_params = SamplingParams(
48
+ temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
49
+ max_tokens=latent_size ** 2)
50
+
51
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
52
+ # that contain the prompt, generated text, and other information.
53
+ t1 = time.time()
54
+ outputs = llm.generate(
55
+ prompt_token_ids=prompt_token_ids,
56
+ sampling_params=sampling_params,
57
+ use_tqdm=False)
58
+ sampling_time = time.time() - t1
59
+ print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
60
+
61
+ # decode to image
62
+ index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
63
+ if args.cfg_scale > 1.0:
64
+ index_sample = index_sample[:len(class_labels)]
65
+ t2 = time.time()
66
+ samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
67
+ decoder_time = time.time() - t2
68
+ print(f"decoder takes about {decoder_time:.2f} seconds.")
69
+
70
+ # Save and display images:
71
+ save_image(samples, "sample_{}.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1))
72
+ print(f"image is saved to sample_{args.gpt_type}.png")
73
+
74
+
75
+ if __name__ == '__main__':
76
+ parser = argparse.ArgumentParser()
77
+ parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
78
+ parser.add_argument("--gpt-ckpt", type=str, required=True, help="ckpt path for gpt model")
79
+ parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
80
+ parser.add_argument("--from-fsdp", action='store_true')
81
+ parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
82
+ parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
83
+ parser.add_argument("--compile", action='store_true', default=False)
84
+ parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
85
+ parser.add_argument("--vq-ckpt", type=str, required=True, help="ckpt path for vq model")
86
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
87
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
88
+ parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
89
+ parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
90
+ parser.add_argument("--num-classes", type=int, default=1000)
91
+ parser.add_argument("--cfg-scale", type=float, default=4.0)
92
+ parser.add_argument("--seed", type=int, default=0)
93
+ parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
94
+ parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
95
+ parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
96
+ args = parser.parse_args()
97
+ main(args)
serve/sampler.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A layer that samples the next tokens from the model's outputs."""
2
+ import itertools
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from vllm.model_executor.layers.ops.sample import sample as sample_triton
9
+ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
10
+ SamplingTensors)
11
+ from vllm.sampling_params import SamplingParams, SamplingType
12
+ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
13
+ SamplerOutput, SequenceData, SequenceGroupOutput,
14
+ SequenceOutput)
15
+
16
+
17
+ class Sampler(nn.Module):
18
+ """Samples the next tokens from the model's outputs.
19
+
20
+ This layer does the following:
21
+ 1. Discard the hidden states that are not used for sampling (i.e., all
22
+ tokens except the final one in each prompt).
23
+ 2. Compute the logits for the next tokens.
24
+ 3. Apply presence, frequency and repetition penalties.
25
+ 4. Apply temperature scaling.
26
+ 5. Apply top-p and top-k truncation.
27
+ 6. Sample the next tokens.
28
+ Here, each sequence group within the batch can have different sampling
29
+ parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
30
+
31
+ The structure of the logits tensor is coupled with the seq_groups in
32
+ sampling_metadata. Typically, each sequence in each seq_group has one row in
33
+ logits for the next token to be sampled; however, for a seq_group with a
34
+ prompt request with the prompt_logprobs sampling parameter, there are rows
35
+ in logits for each token in the input prompt.
36
+ """
37
+
38
+ def __init__(self, cfg_scale=1.0):
39
+ super().__init__()
40
+ self.cfg_scale = cfg_scale
41
+ # Whether or not the SamplerOutput should have on-device tensors
42
+ # containing the sampled token ids and probabilities. This is used by
43
+ # speculative decoding.
44
+ self.include_gpu_probs_tensor = False
45
+
46
+ def forward(
47
+ self,
48
+ logits: torch.Tensor,
49
+ sampling_metadata: SamplingMetadata,
50
+ ) -> Optional[SamplerOutput]:
51
+ assert logits is not None
52
+ _, vocab_size = logits.shape
53
+
54
+ if self.cfg_scale > 1.0:
55
+ logits_combined = logits
56
+ cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
57
+ logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_scale
58
+ logits = torch.cat([logits, logits], dim=0)
59
+
60
+ # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
61
+ # have not been generated yet
62
+ logits = _apply_min_tokens_penalty(logits, sampling_metadata)
63
+
64
+ # Prepare sampling tensors with pinned memory to avoid blocking.
65
+ (sampling_tensors, do_penalties, do_top_p_top_k,
66
+ do_min_p) = SamplingTensors.from_sampling_metadata(
67
+ sampling_metadata, vocab_size, logits.device, logits.dtype)
68
+
69
+ # Apply presence and frequency penalties.
70
+ if do_penalties:
71
+ logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
72
+ sampling_tensors.output_tokens,
73
+ sampling_tensors.presence_penalties,
74
+ sampling_tensors.frequency_penalties,
75
+ sampling_tensors.repetition_penalties)
76
+
77
+ # Apply temperature scaling.
78
+ # Use in-place division to avoid creating a new tensor.
79
+ logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
80
+
81
+ if do_top_p_top_k:
82
+ logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
83
+ sampling_tensors.top_ks)
84
+
85
+ if do_min_p:
86
+ logits = _apply_min_p(logits, sampling_tensors.min_ps)
87
+
88
+ # We use float32 for probabilities and log probabilities.
89
+ # Compute the probabilities.
90
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float)
91
+ # Compute the log probabilities.
92
+ # Use log_softmax to ensure numerical stability.
93
+ logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
94
+
95
+ # Sample the next tokens.
96
+ sample_results, maybe_sampled_tokens_tensor = _sample(
97
+ probs,
98
+ logprobs,
99
+ sampling_metadata,
100
+ sampling_tensors,
101
+ include_gpu_probs_tensor=self.include_gpu_probs_tensor,
102
+ modify_greedy_probs=self._should_modify_greedy_probs_inplace,
103
+ )
104
+
105
+
106
+ if self.cfg_scale > 1.0:
107
+ cond_result = sample_results[:len(sample_results) // 2]
108
+ sample_results = cond_result + cond_result
109
+
110
+
111
+ if self.include_gpu_probs_tensor:
112
+ assert maybe_sampled_tokens_tensor is not None
113
+ sampled_tokens_tensor = maybe_sampled_tokens_tensor
114
+ on_device_tensors = (probs, sampled_tokens_tensor)
115
+ else:
116
+ on_device_tensors = None
117
+
118
+ # Get the logprobs query results.
119
+ prompt_logprobs, sample_logprobs = _get_logprobs(
120
+ logprobs, sampling_metadata, sample_results)
121
+ return _build_sampler_output(sample_results,
122
+ sampling_metadata,
123
+ prompt_logprobs,
124
+ sample_logprobs,
125
+ on_device_tensors=on_device_tensors)
126
+
127
+ @property
128
+ def _should_modify_greedy_probs_inplace(self) -> bool:
129
+ """Whether or not the sampler should modify the probability distribution
130
+ of greedily-sampled tokens such that multinomial sampling would sample
131
+ the greedily-sampled token.
132
+
133
+ In other words, if True then we set the probability of the greedily-
134
+ sampled token to 1.
135
+
136
+ This is used by speculative decoding, which requires that the sampling
137
+ method be encoded into the probability distribution.
138
+ """
139
+ # Modify greedy probs if include_gpu_probs_tensor is set.
140
+ return self.include_gpu_probs_tensor
141
+
142
+
143
+ def _get_bin_counts_and_mask(
144
+ tokens: torch.Tensor,
145
+ vocab_size: int,
146
+ num_seqs: int,
147
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
148
+ # Compute the bin counts for the tokens.
149
+ # vocab_size + 1 for padding.
150
+ bin_counts = torch.zeros((num_seqs, vocab_size + 1),
151
+ dtype=torch.long,
152
+ device=tokens.device)
153
+ bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
154
+ bin_counts = bin_counts[:, :vocab_size]
155
+ mask = bin_counts > 0
156
+
157
+ return bin_counts, mask
158
+
159
+
160
+ def _apply_min_tokens_penalty(
161
+ logits: torch.Tensor,
162
+ sampling_metadata: SamplingMetadata,
163
+ ) -> torch.Tensor:
164
+ # list of indices in logits that will be set to -inf
165
+ logits_to_penalize = []
166
+ start_idx = 0
167
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
168
+ seq_ids, sampling_params = seq_group
169
+
170
+ # handle prompt_logprobs by skipping rows in logits added for the prompt
171
+ # tokens (prompt logprobs are not penalized)
172
+ if (i < sampling_metadata.num_prompts
173
+ and sampling_params.prompt_logprobs is not None):
174
+ assert len(seq_ids) == 1
175
+ start_idx += sampling_metadata.prompt_lens[i] - 1
176
+
177
+ min_tokens = sampling_params.min_tokens
178
+ if min_tokens > 0:
179
+ seqs_to_penalize = []
180
+ for i, seq_id in enumerate(seq_ids):
181
+ seq_data = sampling_metadata.seq_data[seq_id]
182
+ if len(seq_data.output_token_ids) < min_tokens:
183
+ seqs_to_penalize.append(i)
184
+
185
+ if seqs_to_penalize:
186
+ # convert to the index into logits
187
+ seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
188
+ # use set() to remove any duplicates
189
+ token_ids_to_penalize = set(sampling_params.stop_token_ids +
190
+ [sampling_params.eos_token_id])
191
+ # itertools.product pairs each seq index with every token id
192
+ logits_to_penalize.extend(
193
+ itertools.product(seqs_to_penalize, token_ids_to_penalize))
194
+
195
+ start_idx += len(seq_ids)
196
+
197
+ if logits_to_penalize:
198
+ # use zip and * to group indices along each dimension
199
+ # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
200
+ logits[tuple(zip(*logits_to_penalize))] = -float("inf")
201
+
202
+ # verifies that no rows in logits were missed unexpectedly
203
+ assert start_idx == logits.shape[0]
204
+ return logits
205
+
206
+
207
+ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
208
+ output_tokens_tensor: torch.Tensor,
209
+ presence_penalties: torch.Tensor,
210
+ frequency_penalties: torch.Tensor,
211
+ repetition_penalties: torch.Tensor) -> torch.Tensor:
212
+ num_seqs, vocab_size = logits.shape
213
+ _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
214
+ num_seqs)
215
+ output_bin_counts, output_mask = _get_bin_counts_and_mask(
216
+ output_tokens_tensor, vocab_size, num_seqs)
217
+
218
+ repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
219
+ repetition_penalties[~(prompt_mask | output_mask)] = 1.0
220
+ logits = torch.where(logits > 0, logits / repetition_penalties,
221
+ logits * repetition_penalties)
222
+
223
+ # We follow the definition in OpenAI API.
224
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details
225
+ logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
226
+ logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
227
+ return logits
228
+
229
+
230
+ def _apply_top_k_top_p(
231
+ logits: torch.Tensor,
232
+ p: torch.Tensor,
233
+ k: torch.Tensor,
234
+ ) -> torch.Tensor:
235
+ logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
236
+
237
+ # Apply top-k.
238
+ top_k_mask = logits_sort.size(1) - k.to(torch.long)
239
+ # Get all the top_k values.
240
+ top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
241
+ top_k_mask = logits_sort < top_k_mask
242
+ logits_sort.masked_fill_(top_k_mask, -float("inf"))
243
+
244
+ # Apply top-p.
245
+ probs_sort = logits_sort.softmax(dim=-1)
246
+ probs_sum = probs_sort.cumsum(dim=-1)
247
+ top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
248
+ # at least one
249
+ top_p_mask[:, -1] = False
250
+ logits_sort.masked_fill_(top_p_mask, -float("inf"))
251
+
252
+ # Re-sort the probabilities.
253
+ src = torch.arange(logits_idx.shape[-1],
254
+ device=logits_idx.device).expand_as(logits_idx)
255
+ logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
256
+ index=logits_idx,
257
+ src=src)
258
+ logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
259
+ return logits
260
+
261
+
262
+ def _apply_min_p(
263
+ logits: torch.Tensor,
264
+ min_p: torch.Tensor,
265
+ ) -> torch.Tensor:
266
+ """
267
+ Adapted from
268
+ https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
269
+ """
270
+ probs = torch.softmax(logits, dim=-1)
271
+ top_probs, _ = probs.max(dim=-1, keepdim=True)
272
+ scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
273
+ tokens_to_remove = probs < scaled_min_p
274
+ logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
275
+
276
+ return logits
277
+
278
+
279
+ def _greedy_sample(
280
+ selected_seq_groups: List[Tuple[List[int], SamplingParams]],
281
+ samples: torch.Tensor,
282
+ ) -> List[Tuple[List[int], List[int]]]:
283
+ samples = samples.tolist()
284
+ sample_idx = 0
285
+ results = []
286
+ for seq_group in selected_seq_groups:
287
+ seq_ids, _ = seq_group
288
+ num_parent_seqs = len(seq_ids)
289
+ assert num_parent_seqs == 1, (
290
+ "Greedy sampling should have only one seq.")
291
+ parent_ids = list(range(num_parent_seqs))
292
+ next_token_ids = [samples[sample_idx]]
293
+ results.append((next_token_ids, parent_ids))
294
+ sample_idx += num_parent_seqs
295
+ return results
296
+
297
+
298
+ def _random_sample(
299
+ selected_seq_groups: List[Tuple[List[int], SamplingParams]],
300
+ is_prompts: List[bool],
301
+ random_samples: torch.Tensor,
302
+ ) -> List[Tuple[List[int], List[int]]]:
303
+ # Find the maximum best_of value of the prompt phase requests.
304
+ random_samples = random_samples.cpu()
305
+ sample_idx = 0
306
+ results = []
307
+ for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
308
+ seq_ids, sampling_params = seq_group
309
+ num_parent_seqs = len(seq_ids)
310
+ if is_prompt:
311
+ # Prompt phase.
312
+ parent_ids = [0] * sampling_params.best_of
313
+ next_token_ids = random_samples[
314
+ sample_idx, :sampling_params.best_of].tolist()
315
+ else:
316
+ # Generation phase.
317
+ parent_ids = list(range(num_parent_seqs))
318
+ next_token_ids = random_samples[sample_idx:sample_idx +
319
+ num_parent_seqs, 0].tolist()
320
+ results.append((next_token_ids, parent_ids))
321
+ sample_idx += num_parent_seqs
322
+ return results
323
+
324
+
325
+ def _beam_search_sample(
326
+ selected_seq_groups: List[Tuple[List[int], SamplingParams]],
327
+ is_prompts: List[bool],
328
+ seq_data: Dict[int, SequenceData],
329
+ logprobs: torch.Tensor,
330
+ ) -> List[Tuple[List[int], List[int]]]:
331
+ # We sample 2 * beam_width candidates to make sure that with high
332
+ # probability we can get `beam_width` candidates in addition to
333
+ # the finished sequences for the next iteration. See
334
+ # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
335
+ # for details. See also HF reference:
336
+ # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
337
+ #
338
+ # NOTE: Beam search is not vectorized, so its speed can be slower than
339
+ # other sampling methods.
340
+ sample_idx = 0
341
+ results = []
342
+ for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
343
+ seq_ids, sampling_params = seq_group
344
+ num_parent_seqs = len(seq_ids)
345
+ beam_width = sampling_params.best_of
346
+ seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
347
+ if is_prompt:
348
+ # Prompt phase.
349
+ assert num_parent_seqs == 1, (
350
+ "Prompt input should have only one seq.")
351
+ parent_ids = [0] * (2 * beam_width)
352
+ _, next_token_ids = torch.topk(seq_group_logprobs[0],
353
+ 2 * beam_width)
354
+ next_token_ids = next_token_ids.tolist()
355
+ else:
356
+ # Generation phase.
357
+ cumulative_logprobs = [
358
+ seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
359
+ ]
360
+ cumulative_logprobs = torch.tensor(
361
+ cumulative_logprobs,
362
+ dtype=torch.float,
363
+ device=seq_group_logprobs.device)
364
+ seq_group_logprobs = (seq_group_logprobs +
365
+ cumulative_logprobs.unsqueeze(dim=1))
366
+ _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
367
+ 2 * beam_width)
368
+ topk_ids = topk_ids.tolist()
369
+ vocab_size = seq_group_logprobs.size(-1)
370
+ parent_ids = [i // vocab_size for i in topk_ids]
371
+ next_token_ids = [i % vocab_size for i in topk_ids]
372
+ results.append((next_token_ids, parent_ids))
373
+ sample_idx += num_parent_seqs
374
+ assert sample_idx == logprobs.size(0)
375
+ return results
376
+
377
+
378
+ # torch.multinomial forces a GPU<->CPU sync.
379
+ # Therefore, we use an optimized implementation instead.
380
+ # Note that we always sample with replacement.
381
+ # probs will be modified in place, but this is fine, as we pass
382
+ # in a copy already.
383
+ def _multinomial(
384
+ probs: torch.Tensor,
385
+ num_samples: int,
386
+ seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
387
+ generators: Optional[List[torch.Generator]] = None,
388
+ ) -> torch.Tensor:
389
+ if num_samples > 1:
390
+ # This is equivalent to torch.repeat_interleaved (which also
391
+ # forces a GPU<->CPU sync).
392
+ # This allows us to do sampling with replacement by creating
393
+ # num_samples copies of each row in the tensor, and then
394
+ # batch sampling the resulting tensor.
395
+ probs = probs[:, None, :].expand(probs.shape[0], num_samples,
396
+ probs.shape[1]).contiguous().view(
397
+ -1, probs.shape[1])
398
+ q = torch.empty_like(probs)
399
+ if seq_groups is None:
400
+ q.exponential_()
401
+ else:
402
+ sample_idx = 0
403
+ for (seq_ids, _), generator in zip(seq_groups, generators):
404
+ next_sample_idx = sample_idx + len(seq_ids) * num_samples
405
+ q[sample_idx:next_sample_idx].exponential_(generator=generator)
406
+ sample_idx = next_sample_idx
407
+ return probs.div_(q).argmax(dim=1).view(-1, num_samples)
408
+
409
+
410
+ def _sample_with_torch(
411
+ probs: torch.Tensor,
412
+ logprobs: torch.Tensor,
413
+ sampling_metadata: SamplingMetadata,
414
+ include_gpu_probs_tensor: bool,
415
+ modify_greedy_probs: bool,
416
+ ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
417
+ categorized_seq_group_ids = {t: [] for t in SamplingType}
418
+ categorized_sample_indices = sampling_metadata.categorized_sample_indices
419
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
420
+ _, sampling_params = seq_group
421
+ sampling_type = sampling_params.sampling_type
422
+ categorized_seq_group_ids[sampling_type].append(i)
423
+
424
+ sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
425
+ sample_metadata = {}
426
+ multinomial_samples = {}
427
+
428
+ # Create output tensor for sampled token ids.
429
+ if include_gpu_probs_tensor:
430
+ sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
431
+ 1,
432
+ dtype=torch.long,
433
+ device=logprobs.device)
434
+ else:
435
+ sampled_token_ids_tensor = None
436
+
437
+ # Counterintiutively, having two loops here is actually faster.
438
+ # The first loop can run without waiting on GPU<->CPU sync.
439
+ for sampling_type in SamplingType:
440
+ sample_indices = categorized_sample_indices[sampling_type][:, 0]
441
+ num_tokens = len(sample_indices)
442
+ if num_tokens == 0:
443
+ continue
444
+ seq_group_ids = categorized_seq_group_ids[sampling_type]
445
+ seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
446
+ is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
447
+ sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
448
+ is_prompts, sample_indices)
449
+ long_sample_indices = sample_indices.long()
450
+
451
+ if sampling_type == SamplingType.GREEDY:
452
+ greedy_samples = torch.argmax(logprobs[long_sample_indices],
453
+ dim=-1)
454
+
455
+ if include_gpu_probs_tensor:
456
+ # Store sampled tokens in output tensor.
457
+ sampled_token_ids_tensor[
458
+ long_sample_indices] = greedy_samples.unsqueeze(-1)
459
+
460
+ if modify_greedy_probs:
461
+ # If required, modify the probabilities such that sampling from
462
+ # the modified distribution would always sample the argmax
463
+ # token id.
464
+ _modify_greedy_probs_inplace(logprobs, probs,
465
+ long_sample_indices,
466
+ greedy_samples)
467
+
468
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
469
+ max_best_of_in_batch = 1
470
+ for seq_group, is_prompt in zip(seq_groups, is_prompts):
471
+ if is_prompt:
472
+ _, sampling_params = seq_group
473
+ max_best_of_in_batch = max(max_best_of_in_batch,
474
+ sampling_params.best_of)
475
+ seeded_args = {} if sampling_type == SamplingType.RANDOM else {
476
+ "seq_groups": seq_groups,
477
+ "generators": sampling_metadata.generators,
478
+ }
479
+
480
+ multinomial_samples[sampling_type] = _multinomial(
481
+ probs[long_sample_indices], max_best_of_in_batch,
482
+ **seeded_args)
483
+
484
+ if include_gpu_probs_tensor:
485
+ # Store sampled tokens in output tensor.
486
+ sampled_token_ids_tensor[
487
+ long_sample_indices] = multinomial_samples[sampling_type]
488
+
489
+ elif sampling_type == SamplingType.BEAM:
490
+ beam_search_logprobs = logprobs[sample_indices]
491
+ else:
492
+ raise ValueError(f"Unsupported sampling type: {sampling_type}")
493
+
494
+ # GPU<->CPU sync happens in the loop below.
495
+ # This also converts the sample output to Python objects.
496
+
497
+ for sampling_type in SamplingType:
498
+ if sampling_type not in sample_metadata:
499
+ continue
500
+ seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
501
+ sampling_type]
502
+ if sampling_type == SamplingType.GREEDY:
503
+ sample_results = _greedy_sample(seq_groups, greedy_samples)
504
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
505
+ sample_results = _random_sample(seq_groups, is_prompts,
506
+ multinomial_samples[sampling_type])
507
+ elif sampling_type == SamplingType.BEAM:
508
+ sample_results = _beam_search_sample(seq_groups, is_prompts,
509
+ sampling_metadata.seq_data,
510
+ beam_search_logprobs)
511
+ sample_results_dict.update(zip(seq_group_ids, sample_results))
512
+
513
+ sample_results = [
514
+ sample_results_dict[i]
515
+ for i in range(len(sampling_metadata.seq_groups))
516
+ ]
517
+ return sample_results, sampled_token_ids_tensor
518
+
519
+
520
+ def _sample_with_triton_kernel(
521
+ probs: torch.Tensor,
522
+ logprobs: torch.Tensor,
523
+ sampling_metadata: SamplingMetadata,
524
+ sampling_tensors: SamplingTensors,
525
+ ) -> List[Tuple[List[int], List[int]]]:
526
+ categorized_seq_group_ids = {t: [] for t in SamplingType}
527
+ categorized_sample_indices = sampling_metadata.categorized_sample_indices
528
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
529
+ _, sampling_params = seq_group
530
+ sampling_type = sampling_params.sampling_type
531
+ categorized_seq_group_ids[sampling_type].append(i)
532
+
533
+ sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
534
+ sample_metadata = {}
535
+ max_best_of_in_batch = 1
536
+
537
+ # Counterintiutively, having two loops here is actually faster.
538
+ # The first loop can run without waiting on GPU<->CPU sync.
539
+ for sampling_type in SamplingType:
540
+ sample_indices = categorized_sample_indices[sampling_type][:, 0]
541
+ sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
542
+ num_tokens = len(sample_indices)
543
+ if num_tokens == 0:
544
+ continue
545
+ seq_group_ids = categorized_seq_group_ids[sampling_type]
546
+ seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
547
+ is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
548
+ sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
549
+ is_prompts, sample_indices,
550
+ sampled_token_indices)
551
+ if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
552
+ SamplingType.RANDOM_SEED):
553
+ for seq_group, is_prompt in zip(seq_groups, is_prompts):
554
+ if is_prompt:
555
+ _, sampling_params = seq_group
556
+ max_best_of_in_batch = max(max_best_of_in_batch,
557
+ sampling_params.best_of)
558
+ elif sampling_type == SamplingType.BEAM:
559
+ beam_search_logprobs = logprobs[sample_indices]
560
+ else:
561
+ raise ValueError(f"Unsupported sampling type: {sampling_type}")
562
+
563
+ sampled_tokens, _, _ = sample_triton(
564
+ probs=probs,
565
+ seeds=sampling_tensors.sampling_seeds,
566
+ max_best_of=max_best_of_in_batch,
567
+ sample_indices=sampling_tensors.sample_indices,
568
+ logprobs=logprobs,
569
+ # don't save logprobs because we have logic for that below
570
+ # TODO: use this instead of the CPU-based logic below
571
+ save_logprobs=False,
572
+ )
573
+
574
+ # GPU<->CPU sync happens in the loop below.
575
+
576
+ for sampling_type in SamplingType:
577
+ if sampling_type not in sample_metadata:
578
+ continue
579
+ (seq_group_ids, seq_groups, is_prompts, sample_indices,
580
+ sampled_token_indices) = sample_metadata[sampling_type]
581
+ if sampling_type == SamplingType.GREEDY:
582
+ sample_results = _greedy_sample(
583
+ seq_groups, sampled_tokens[sampled_token_indices][:, 0])
584
+ elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
585
+ sample_results = _random_sample(
586
+ seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
587
+ elif sampling_type == SamplingType.BEAM:
588
+ sample_results = _beam_search_sample(seq_groups, is_prompts,
589
+ sampling_metadata.seq_data,
590
+ beam_search_logprobs)
591
+ sample_results_dict.update(zip(seq_group_ids, sample_results))
592
+
593
+ sample_results = [
594
+ sample_results_dict[i]
595
+ for i in range(len(sampling_metadata.seq_groups))
596
+ ]
597
+ return sample_results
598
+
599
+
600
+ def _sample(
601
+ probs: torch.Tensor, logprobs: torch.Tensor,
602
+ sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
603
+ include_gpu_probs_tensor: bool, modify_greedy_probs: bool
604
+ ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
605
+ return _sample_with_torch(
606
+ probs,
607
+ logprobs,
608
+ sampling_metadata,
609
+ include_gpu_probs_tensor=include_gpu_probs_tensor,
610
+ modify_greedy_probs=modify_greedy_probs,
611
+ )
612
+
613
+ # TODO: Enable once Triton kernel & associated code is faster.
614
+ # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
615
+ # sampling_tensors)
616
+
617
+
618
+ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
619
+ """
620
+ This function calculates the ranks of the chosen tokens in a logprob tensor.
621
+
622
+ Args:
623
+ x (torch.Tensor): 2D logprob tensor of shape (N, M)
624
+ where N is the no. of tokens and M is the vocab dim.
625
+ indices (torch.Tensor): List of chosen token indices.
626
+
627
+ Returns:
628
+ torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
629
+ Each element in the returned tensor represents the rank
630
+ of the chosen token in the input logprob tensor.
631
+ """
632
+ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
633
+ indices]
634
+ return (x > vals[:, None]).long().sum(1).add_(1)
635
+
636
+
637
+ def _get_logprobs(
638
+ logprobs: torch.Tensor,
639
+ sampling_metadata: SamplingMetadata,
640
+ sample_results: List[Tuple[List[int], List[int]]],
641
+ ) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
642
+ int, float]]]]:
643
+ # Prepare query indices
644
+ batched_logprobs_query_seq_indices: List[int] = []
645
+ batched_logprobs_query_token_indices: List[int] = []
646
+ # at least get one logprob for each token
647
+ largest_num_logprobs = 1
648
+ sample_idx = 0
649
+ for i, (seq_group, sample_result) in enumerate(
650
+ zip(sampling_metadata.seq_groups, sample_results)):
651
+ seq_ids, sampling_params = seq_group
652
+ next_token_ids, parent_ids = sample_result
653
+ num_parent_seqs = len(seq_ids)
654
+ if (i < sampling_metadata.num_prompts
655
+ and sampling_params.prompt_logprobs is not None):
656
+ largest_num_logprobs = max(largest_num_logprobs,
657
+ sampling_params.prompt_logprobs)
658
+ prompt_len = sampling_metadata.prompt_lens[i]
659
+ prompt_tokens = sampling_metadata.seq_data[
660
+ seq_ids[0]].prompt_token_ids
661
+ batched_logprobs_query_seq_indices.extend(
662
+ sample_idx + j for j in range(prompt_len - 1))
663
+ batched_logprobs_query_token_indices.extend(
664
+ token_id for token_id in prompt_tokens[1:])
665
+ sample_idx += prompt_len - 1
666
+ batched_logprobs_query_seq_indices.extend(
667
+ [sample_idx + parent_id for parent_id in parent_ids])
668
+ batched_logprobs_query_token_indices.extend(next_token_ids)
669
+ if sampling_params.logprobs is not None:
670
+ largest_num_logprobs = max(largest_num_logprobs,
671
+ sampling_params.logprobs)
672
+ sample_idx += num_parent_seqs
673
+ assert sample_idx == logprobs.size(0)
674
+
675
+ batched_logprobs_query_seq_indices_gpu = torch.tensor(
676
+ batched_logprobs_query_seq_indices, device=logprobs.device)
677
+ batched_logprobs_query_token_indices_gpu = torch.tensor(
678
+ batched_logprobs_query_token_indices, device=logprobs.device)
679
+
680
+ # Batched query for logprobs of selected token
681
+ batched_logprobs_query_result = logprobs[[
682
+ batched_logprobs_query_seq_indices_gpu,
683
+ batched_logprobs_query_token_indices_gpu
684
+ ]]
685
+
686
+ batched_ranks_query_result = _get_ranks(
687
+ logprobs[batched_logprobs_query_seq_indices_gpu],
688
+ batched_logprobs_query_token_indices_gpu)
689
+
690
+ # Batched query for logprobs of topk tokens
691
+ if largest_num_logprobs > 0:
692
+ top_logprobs, top_token_ids = torch.topk(logprobs,
693
+ largest_num_logprobs,
694
+ dim=-1)
695
+ top_logprobs = top_logprobs.cpu()
696
+ top_token_ids = top_token_ids.cpu()
697
+ else:
698
+ top_logprobs, top_token_ids = None, None
699
+
700
+ batched_logprobs_query_result = batched_logprobs_query_result.cpu()
701
+ batched_ranks_query_result = batched_ranks_query_result.cpu()
702
+
703
+ # Gather results
704
+ result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
705
+ result_sample_logprobs: List[SampleLogprobs] = []
706
+ sample_idx = 0
707
+ query_result_idx = 0
708
+ for i, (seq_group, sample_result) in enumerate(
709
+ zip(sampling_metadata.seq_groups, sample_results)):
710
+ seq_ids, sampling_params = seq_group
711
+ next_token_ids, parent_ids = sample_result
712
+
713
+ # Prompt logprobs
714
+ if (i < sampling_metadata.num_prompts
715
+ and sampling_params.prompt_logprobs is not None):
716
+ num_logprobs = sampling_params.prompt_logprobs
717
+ prompt_tokens = sampling_metadata.seq_data[
718
+ seq_ids[0]].prompt_token_ids
719
+ group_prompt_logprobs: PromptLogprobs = [None]
720
+ for token_id in prompt_tokens[1:]:
721
+ prompt_logprobs_dict = {
722
+ token_id:
723
+ (batched_logprobs_query_result[query_result_idx].item(),
724
+ batched_ranks_query_result[query_result_idx].item())
725
+ }
726
+ if num_logprobs > 0:
727
+ prompt_logprobs_dict.update(
728
+ zip(
729
+ top_token_ids[sample_idx, :num_logprobs].tolist(),
730
+ zip(
731
+ top_logprobs[
732
+ sample_idx, :num_logprobs].tolist(),
733
+ range(1, num_logprobs + 1))))
734
+ group_prompt_logprobs.append({
735
+ token_id: Logprob(*logprob_rank)
736
+ for token_id, logprob_rank in prompt_logprobs_dict.items()
737
+ })
738
+ sample_idx += 1
739
+ query_result_idx += 1
740
+ result_prompt_logprobs.append(group_prompt_logprobs)
741
+ else:
742
+ result_prompt_logprobs.append(None)
743
+
744
+ # Sample logprobs
745
+ num_logprobs = sampling_params.logprobs
746
+ if num_logprobs is None:
747
+ num_logprobs = 0
748
+ group_sample_logprobs: SampleLogprobs = []
749
+ for next_token_id, parent_id in zip(next_token_ids, parent_ids):
750
+ sample_logprobs_dict = {
751
+ next_token_id:
752
+ (batched_logprobs_query_result[query_result_idx].item(),
753
+ batched_ranks_query_result[query_result_idx].item())
754
+ }
755
+ query_result_idx += 1
756
+ if num_logprobs >= 0:
757
+ sample_logprobs_dict.update(
758
+ zip(
759
+ top_token_ids[sample_idx +
760
+ parent_id, :num_logprobs].tolist(),
761
+ zip(
762
+ top_logprobs[sample_idx +
763
+ parent_id, :num_logprobs].tolist(),
764
+ range(1, num_logprobs + 1))))
765
+ group_sample_logprobs.append({
766
+ token_id: Logprob(*logprob_rank)
767
+ for token_id, logprob_rank in sample_logprobs_dict.items()
768
+ })
769
+ result_sample_logprobs.append(group_sample_logprobs)
770
+ sample_idx += len(seq_ids)
771
+
772
+ return result_prompt_logprobs, result_sample_logprobs
773
+
774
+
775
+ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
776
+ sample_indices: torch.Tensor,
777
+ greedy_samples: torch.Tensor) -> None:
778
+ """Modify the probability distributions of the greedily-sampled tokens such
779
+ that each sampled token has a "probability" of 1.0. This is required by
780
+ speculative decoding, which depends on the sampling method being encoded
781
+ within the probability distribution for correctness.
782
+
783
+ # Why do we only need to do this for greedy sampling?
784
+
785
+ vLLM's sampler performs the following steps for greedy or multinomial
786
+ (random) sampling:
787
+ 1. Get logits from model.
788
+ 2. Modify logits according to per-sequence sampling parameters.
789
+ - Multiply by temperature, top-k and top-p masking, penalize tokens
790
+ according to their frequency, etc.
791
+ 3. Sample a token.
792
+ - Random sampling simply samples from the modified probability
793
+ distribution.
794
+ - Greedy sampling performs `argmax` to obtain the token with the
795
+ highest likelihood.
796
+
797
+ Ignoring greedy sampling for a moment, we find that the computed probability
798
+ distribution has the following property: we can sample from it independently
799
+ and find that the token sampled by the Sampler has a frequency corresponding
800
+ to how often we see it in our sampling. In other words, for tokens sampled
801
+ with vLLM's random SamplingType, the computed probability distribution
802
+ encodes the sampling methodology completely.
803
+
804
+ Greedy sampling does not normally have this property. vLLM modifies logits
805
+ according to sampling params, then performs `argmax`, then returns the
806
+ sampled token and the computed probability distribution. If we sample from
807
+ the distribution, we'll find the likelihood of the greedily-sampled token
808
+ is not always 1.0.
809
+
810
+ Since lossless speculative decoding requires that the sampling methodology
811
+ be encoded within the probability distribution, we are motivated to modify
812
+ the probability distribution such that the sampled token has probability 1
813
+ when speculative decoding is used.
814
+
815
+ NOTE: Alternatively, we could use an extremely low temperature to achieve
816
+ greedy sampling using multinomial computation and unite the codepaths. This
817
+ has implications on the overall design of the sampler, e.g. how to record
818
+ accurate logprobs for the user, so this improvement is deferred to later.
819
+ """
820
+ logprobs[sample_indices, :] = -float('inf')
821
+ logprobs[sample_indices, greedy_samples] = 0.0
822
+ probs[sample_indices, :] = 0
823
+ probs[sample_indices, greedy_samples] = 1.0
824
+
825
+
826
+ def _build_sampler_output(
827
+ sample_results: List[Tuple[List[int], List[int]]],
828
+ sampling_metadata: SamplingMetadata,
829
+ prompt_logprobs: List[Optional[PromptLogprobs]],
830
+ sample_logprobs: List[SampleLogprobs],
831
+ on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
832
+ ) -> SamplerOutput:
833
+ """Construct Python objects with the output of sampling.
834
+
835
+ Args:
836
+ on_device_tensors: Tuple containing on-device tensors with the
837
+ probabilities used in sampling and the sampled token ids. This
838
+ allows post-processing without copies to CPU/serialization, e.g. in
839
+ speculative decoding rejection sampling.
840
+ """
841
+
842
+ sampler_output = []
843
+ for (seq_group, sample_result, group_prompt_logprobs,
844
+ group_sample_logprobs) in zip(sampling_metadata.seq_groups,
845
+ sample_results, prompt_logprobs,
846
+ sample_logprobs):
847
+ seq_ids, _ = seq_group
848
+ next_token_ids, parent_ids = sample_result
849
+ seq_outputs = []
850
+ for parent_id, next_token_id, logprobs in zip(parent_ids,
851
+ next_token_ids,
852
+ group_sample_logprobs):
853
+ seq_outputs.append(
854
+ SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
855
+ sampler_output.append(
856
+ SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
857
+
858
+ # If not specified, store None values in SamplerOutput.
859
+ if on_device_tensors is not None:
860
+ sampled_token_probs, sampled_token_ids = on_device_tensors
861
+ else:
862
+ sampled_token_probs, sampled_token_ids = (None, None)
863
+
864
+ return SamplerOutput(
865
+ outputs=sampler_output,
866
+ sampled_token_probs=sampled_token_probs,
867
+ sampled_token_ids=sampled_token_ids,
868
+ )
serve/worker.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A GPU worker class."""
2
+ import gc
3
+ import os
4
+ from typing import Any, Dict, List, Optional, Set, Tuple
5
+
6
+ import torch
7
+ import torch.distributed
8
+
9
+ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
10
+ ModelConfig, ParallelConfig, SchedulerConfig,
11
+ VisionLanguageConfig)
12
+ from vllm.distributed import (broadcast_tensor_dict,
13
+ ensure_model_parallel_initialized,
14
+ init_distributed_environment)
15
+ from vllm.distributed.device_communicators import pynccl_utils
16
+ from vllm.distributed.device_communicators.custom_all_reduce import (
17
+ init_custom_ar)
18
+ from vllm.lora.request import LoRARequest
19
+ from vllm.model_executor import set_random_seed
20
+ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
21
+ from vllm.worker.cache_engine import CacheEngine
22
+ # from vllm.worker.model_runner import ModelRunner
23
+ from vllm.worker.worker_base import WorkerBase
24
+ from serve.model_runner import ModelRunner
25
+
26
+
27
+ class Worker(WorkerBase):
28
+ """A worker class that executes (a partition of) the model on a GPU.
29
+
30
+ Each worker is associated with a single GPU. The worker is responsible for
31
+ maintaining the KV cache and executing the model on the GPU. In case of
32
+ distributed inference, each worker is assigned a partition of the model.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ model_config: ModelConfig,
38
+ parallel_config: ParallelConfig,
39
+ scheduler_config: SchedulerConfig,
40
+ device_config: DeviceConfig,
41
+ cache_config: CacheConfig,
42
+ load_config: LoadConfig,
43
+ local_rank: int,
44
+ rank: int,
45
+ distributed_init_method: str,
46
+ lora_config: Optional[LoRAConfig] = None,
47
+ vision_language_config: Optional[VisionLanguageConfig] = None,
48
+ is_driver_worker: bool = False,
49
+ ) -> None:
50
+ self.model_config = model_config
51
+ self.parallel_config = parallel_config
52
+ self.scheduler_config = scheduler_config
53
+ self.device_config = device_config
54
+ self.cache_config = cache_config
55
+ self.local_rank = local_rank
56
+ self.rank = rank
57
+ self.distributed_init_method = distributed_init_method
58
+ self.lora_config = lora_config
59
+ self.load_config = load_config
60
+ self.is_driver_worker = is_driver_worker
61
+ if self.is_driver_worker:
62
+ assert self.rank == 0, "The driver worker must have rank 0."
63
+
64
+ if self.model_config.trust_remote_code:
65
+ # note: lazy import to avoid importing torch before initializing
66
+ from vllm.utils import init_cached_hf_modules
67
+ init_cached_hf_modules()
68
+ self.vision_language_config = vision_language_config
69
+ if self.vision_language_config:
70
+ assert not self.lora_config, (
71
+ "To be tested: vision language model with LoRA settings.")
72
+
73
+ self.model_runner = ModelRunner(
74
+ model_config,
75
+ parallel_config,
76
+ scheduler_config,
77
+ device_config,
78
+ load_config=load_config,
79
+ lora_config=self.lora_config,
80
+ kv_cache_dtype=self.cache_config.cache_dtype,
81
+ is_driver_worker=is_driver_worker,
82
+ vision_language_config=vision_language_config,
83
+ )
84
+ # Uninitialized cache engine. Will be initialized by
85
+ # initialize_cache.
86
+ self.cache_engine: CacheEngine
87
+ self.gpu_cache: List[torch.Tensor]
88
+
89
+ def init_device(self) -> None:
90
+ if self.device_config.device.type == "cuda":
91
+ # torch.distributed.all_reduce does not free the input tensor until
92
+ # the synchronization point. This causes the memory usage to grow
93
+ # as the number of all_reduce calls increases. This env var disables
94
+ # this behavior.
95
+ # Related issue:
96
+ # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
97
+ os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
98
+
99
+ # This env var set by Ray causes exceptions with graph building.
100
+ os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
101
+ self.device = torch.device(f"cuda:{self.local_rank}")
102
+ torch.cuda.set_device(self.device)
103
+
104
+ _check_if_gpu_supports_dtype(self.model_config.dtype)
105
+ torch.cuda.empty_cache()
106
+ self.init_gpu_memory = torch.cuda.mem_get_info()[0]
107
+ else:
108
+ raise RuntimeError(
109
+ f"Not support device type: {self.device_config.device}")
110
+ # Initialize the distributed environment.
111
+ init_worker_distributed_environment(self.parallel_config, self.rank,
112
+ self.distributed_init_method,
113
+ self.local_rank)
114
+ # Set random seed.
115
+ set_random_seed(self.model_config.seed)
116
+
117
+ def load_model(self, args):
118
+ self.model_runner.load_model(args)
119
+
120
+ @torch.inference_mode()
121
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
122
+ """Profiles the peak memory usage of the model to determine how many
123
+ KV blocks may be allocated without OOMs.
124
+
125
+ The engine will first conduct a profiling of the existing memory usage.
126
+ Then, it calculate the maximum possible number of GPU and CPU blocks
127
+ that can be allocated with the remaining free memory.
128
+
129
+ .. tip::
130
+ You may limit the usage of GPU memory
131
+ by adjusting the `gpu_memory_utilization` parameter.
132
+ """
133
+ # Profile the memory usage of the model and get the maximum number of
134
+ # cache blocks that can be allocated with the remaining free memory.
135
+ torch.cuda.empty_cache()
136
+
137
+ # Execute a forward pass with dummy inputs to profile the memory usage
138
+ # of the model.
139
+ self.model_runner.profile_run()
140
+
141
+ # Calculate the number of blocks that can be allocated with the
142
+ # profiled peak memory.
143
+ torch.cuda.synchronize()
144
+ free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
145
+ # NOTE(woosuk): Here we assume that the other processes using the same
146
+ # GPU did not change their memory usage during the profiling.
147
+ peak_memory = self.init_gpu_memory - free_gpu_memory
148
+ assert peak_memory > 0, (
149
+ "Error in memory profiling. This happens when the GPU memory was "
150
+ "not properly cleaned up before initializing the vLLM instance.")
151
+
152
+ cache_block_size = self.get_cache_block_size_bytes()
153
+ num_gpu_blocks = int(
154
+ (total_gpu_memory * self.cache_config.gpu_memory_utilization -
155
+ peak_memory) // cache_block_size)
156
+ num_cpu_blocks = int(self.cache_config.swap_space_bytes //
157
+ cache_block_size)
158
+ num_gpu_blocks = max(num_gpu_blocks, 0)
159
+ num_cpu_blocks = max(num_cpu_blocks, 0)
160
+ if self.model_runner.lora_manager:
161
+ self.model_runner.remove_all_loras()
162
+ gc.collect()
163
+ torch.cuda.empty_cache()
164
+ return num_gpu_blocks, num_cpu_blocks
165
+
166
+ def initialize_cache(self, num_gpu_blocks: int,
167
+ num_cpu_blocks: int) -> None:
168
+ """Allocate GPU and CPU KV cache with the specified number of blocks.
169
+
170
+ This also warms up the model, which may record CUDA graphs.
171
+ """
172
+ raise_if_cache_size_invalid(num_gpu_blocks,
173
+ self.cache_config.block_size,
174
+ self.model_config.max_model_len)
175
+
176
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
177
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
178
+
179
+ self._init_cache_engine()
180
+ self._warm_up_model()
181
+
182
+ def _init_cache_engine(self):
183
+ assert self.cache_config.num_gpu_blocks is not None
184
+ self.cache_engine = CacheEngine(self.cache_config, self.model_config,
185
+ self.parallel_config)
186
+ self.gpu_cache = self.cache_engine.gpu_cache
187
+ self.model_runner.set_block_size(self.cache_engine.block_size)
188
+
189
+ def _warm_up_model(self) -> None:
190
+ if not self.model_config.enforce_eager:
191
+ self.model_runner.capture_model(self.gpu_cache)
192
+ # Reset the seed to ensure that the random state is not affected by
193
+ # the model initialization and profiling.
194
+ set_random_seed(self.model_config.seed)
195
+
196
+ def cache_swap(
197
+ self,
198
+ blocks_to_swap_in: Dict[int, int],
199
+ blocks_to_swap_out: Dict[int, int],
200
+ blocks_to_copy: Dict[int, List[int]],
201
+ ) -> None:
202
+ # Issue cache operations.
203
+ # TODO(woosuk): Profile swapping overhead and optimize if needed.
204
+ if blocks_to_swap_in:
205
+ self.cache_engine.swap_in(blocks_to_swap_in)
206
+ if blocks_to_swap_out:
207
+ self.cache_engine.swap_out(blocks_to_swap_out)
208
+ if blocks_to_copy:
209
+ self.cache_engine.copy(blocks_to_copy)
210
+
211
+ @torch.inference_mode()
212
+ def execute_model(
213
+ self,
214
+ seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
215
+ blocks_to_swap_in: Optional[Dict[int, int]] = None,
216
+ blocks_to_swap_out: Optional[Dict[int, int]] = None,
217
+ blocks_to_copy: Optional[Dict[int, List[int]]] = None,
218
+ num_lookahead_slots: int = 0,
219
+ ) -> List[SamplerOutput]:
220
+
221
+ if self.is_driver_worker:
222
+ assert seq_group_metadata_list is not None
223
+ num_seq_groups = len(seq_group_metadata_list)
224
+ assert blocks_to_swap_in is not None
225
+ assert blocks_to_swap_out is not None
226
+ assert blocks_to_copy is not None
227
+ data: Dict[str, Any] = {
228
+ "num_seq_groups": num_seq_groups,
229
+ "blocks_to_swap_in": blocks_to_swap_in,
230
+ "blocks_to_swap_out": blocks_to_swap_out,
231
+ "blocks_to_copy": blocks_to_copy,
232
+ }
233
+ broadcast_tensor_dict(data, src=0)
234
+ else:
235
+ data = broadcast_tensor_dict(src=0)
236
+ num_seq_groups = data["num_seq_groups"]
237
+ blocks_to_swap_in = data["blocks_to_swap_in"]
238
+ blocks_to_swap_out = data["blocks_to_swap_out"]
239
+ blocks_to_copy = data["blocks_to_copy"]
240
+
241
+ assert blocks_to_swap_in is not None
242
+ assert blocks_to_swap_out is not None
243
+ assert blocks_to_copy is not None
244
+ self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
245
+
246
+ # If there is no input, we don't need to execute the model.
247
+ if num_seq_groups == 0:
248
+ return []
249
+
250
+ output = self.model_runner.execute_model(seq_group_metadata_list,
251
+ self.gpu_cache)
252
+
253
+ # Worker only supports single-step execution. Wrap the output in a list
254
+ # to conform to interface.
255
+ return [output]
256
+
257
+ def add_lora(self, lora_request: LoRARequest) -> bool:
258
+ return self.model_runner.add_lora(lora_request)
259
+
260
+ def remove_lora(self, lora_id: int) -> bool:
261
+ return self.model_runner.remove_lora(lora_id)
262
+
263
+ def list_loras(self) -> Set[int]:
264
+ return self.model_runner.list_loras()
265
+
266
+ @property
267
+ def max_model_len(self) -> int:
268
+ return self.model_config.max_model_len
269
+
270
+ @property
271
+ def vocab_size(self) -> int:
272
+ return self.model_runner.vocab_size
273
+
274
+ def get_cache_block_size_bytes(self) -> int:
275
+ """Get the size of the KV cache block size in bytes.
276
+ """
277
+ return CacheEngine.get_cache_block_size(self.cache_config,
278
+ self.model_config,
279
+ self.parallel_config)
280
+
281
+
282
+ def init_worker_distributed_environment(
283
+ parallel_config: ParallelConfig,
284
+ rank: int,
285
+ distributed_init_method: Optional[str] = None,
286
+ local_rank: int = -1,
287
+ ) -> None:
288
+ """Initialize the distributed environment."""
289
+ init_distributed_environment(parallel_config.world_size, rank,
290
+ distributed_init_method, local_rank)
291
+
292
+ if pynccl_utils.is_initialized():
293
+ pynccl_world_size = pynccl_utils.get_world_size()
294
+ if pynccl_world_size != parallel_config.world_size:
295
+ raise RuntimeError(
296
+ "pynccl is already initialized but the pynccl world "
297
+ "size does not match parallel_config.world_size "
298
+ f"({pynccl_world_size} vs. {parallel_config.world_size}).")
299
+ elif parallel_config.world_size > 1:
300
+ # NOTE(woosuk): We don't initialize pynccl process group when world size
301
+ # is 1.
302
+ pynccl_utils.init_process_group(
303
+ world_size=parallel_config.world_size,
304
+ local_rank=local_rank,
305
+ rank=rank,
306
+ init_method=distributed_init_method,
307
+ )
308
+
309
+ ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
310
+ parallel_config.pipeline_parallel_size)
311
+
312
+ # Initialize a custom fast all-reduce implementation.
313
+ if not parallel_config.disable_custom_all_reduce:
314
+ init_custom_ar()
315
+
316
+ # A small all_reduce for warmup.
317
+ torch.distributed.all_reduce(torch.zeros(1).cuda())
318
+ if pynccl_utils.is_initialized():
319
+ pynccl_utils.all_reduce(torch.zeros(1).cuda())
320
+
321
+
322
+ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
323
+ # Check if the GPU supports the dtype.
324
+ if torch_dtype == torch.bfloat16:
325
+ compute_capability = torch.cuda.get_device_capability()
326
+ if compute_capability[0] < 8:
327
+ gpu_name = torch.cuda.get_device_name()
328
+ raise ValueError(
329
+ "Bfloat16 is only supported on GPUs with compute capability "
330
+ f"of at least 8.0. Your {gpu_name} GPU has compute capability "
331
+ f"{compute_capability[0]}.{compute_capability[1]}. "
332
+ "You can use float16 instead by explicitly setting the"
333
+ "`dtype` flag in CLI, for example: --dtype=half.")
334
+
335
+
336
+ def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
337
+ max_model_len) -> None:
338
+ if num_gpu_blocks <= 0:
339
+ raise ValueError("No available memory for the cache blocks. "
340
+ "Try increasing `gpu_memory_utilization` when "
341
+ "initializing the engine.")
342
+ max_seq_len = block_size * num_gpu_blocks
343
+ if max_model_len > max_seq_len:
344
+ raise ValueError(
345
+ f"The model's max seq len ({max_model_len}) "
346
+ "is larger than the maximum number of tokens that can be "
347
+ f"stored in KV cache ({max_seq_len}). Try increasing "
348
+ "`gpu_memory_utilization` or decreasing `max_model_len` when "
349
+ "initializing the engine.")