Spaces:
Runtime error
Runtime error
ShoufaChen
commited on
Commit
·
4bfb360
1
Parent(s):
4d20c2f
vllm
Browse files- app.py +30 -51
- app_naive.py +160 -0
- requirements.txt +2 -1
- serve/README.md +63 -0
- serve/gpt_model.py +369 -0
- serve/gpu_executor.py +201 -0
- serve/llm.py +267 -0
- serve/llm_engine.py +671 -0
- serve/model_runner.py +1223 -0
- serve/sample_c2i.py +97 -0
- serve/sampler.py +868 -0
- serve/worker.py +349 -0
app.py
CHANGED
@@ -8,12 +8,12 @@ torch.backends.cudnn.allow_tf32 = True
|
|
8 |
torch.set_float32_matmul_precision('high')
|
9 |
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
|
10 |
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
|
11 |
-
|
12 |
import time
|
13 |
import argparse
|
14 |
from tokenizer_image.vq_model import VQ_models
|
15 |
-
from models.
|
16 |
-
from
|
17 |
|
18 |
device = "cuda"
|
19 |
|
@@ -38,46 +38,16 @@ def load_model(args):
|
|
38 |
del checkpoint
|
39 |
print(f"image tokenizer is loaded")
|
40 |
|
41 |
-
#
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
model_type=args.gpt_type,
|
50 |
-
).to(device=device, dtype=precision)
|
51 |
-
|
52 |
-
checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
|
53 |
-
if args.from_fsdp: # fspd
|
54 |
-
model_weight = checkpoint
|
55 |
-
elif "model" in checkpoint: # ddp
|
56 |
-
model_weight = checkpoint["model"]
|
57 |
-
elif "module" in checkpoint: # deepspeed
|
58 |
-
model_weight = checkpoint["module"]
|
59 |
-
elif "state_dict" in checkpoint:
|
60 |
-
model_weight = checkpoint["state_dict"]
|
61 |
-
else:
|
62 |
-
raise Exception("please check model weight")
|
63 |
-
# if 'freqs_cis' in model_weight:
|
64 |
-
# model_weight.pop('freqs_cis')
|
65 |
-
gpt_model.load_state_dict(model_weight, strict=False)
|
66 |
-
gpt_model.eval()
|
67 |
-
del checkpoint
|
68 |
print(f"gpt model is loaded")
|
69 |
-
|
70 |
-
if args.compile:
|
71 |
-
print(f"compiling the model...")
|
72 |
-
gpt_model = torch.compile(
|
73 |
-
gpt_model,
|
74 |
-
mode="reduce-overhead",
|
75 |
-
fullgraph=True
|
76 |
-
) # requires PyTorch 2.0 (optional)
|
77 |
-
else:
|
78 |
-
print(f"no need to compile model in demo")
|
79 |
-
|
80 |
-
return vq_model, gpt_model, image_size
|
81 |
|
82 |
|
83 |
def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
|
@@ -85,20 +55,29 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
|
|
85 |
latent_size = image_size // args.downsample_size
|
86 |
# Labels to condition the model with (feel free to change):
|
87 |
class_labels = [class_label for _ in range(n)]
|
88 |
-
c_indices = torch.tensor(class_labels, device=device)
|
89 |
qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
t1 = time.time()
|
92 |
torch.manual_seed(seed)
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
top_p=top_p, sample_logits=True,
|
98 |
-
)
|
99 |
sampling_time = time.time() - t1
|
100 |
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
101 |
|
|
|
|
|
|
|
102 |
t2 = time.time()
|
103 |
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
104 |
decoder_time = time.time() - t2
|
@@ -110,7 +89,7 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
|
|
110 |
|
111 |
|
112 |
parser = argparse.ArgumentParser()
|
113 |
-
parser.add_argument("--gpt-model", type=str,
|
114 |
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
|
115 |
parser.add_argument("--from-fsdp", action='store_true')
|
116 |
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
|
@@ -129,7 +108,7 @@ parser.add_argument("--temperature", type=float, default=1.0, help="temperature
|
|
129 |
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
130 |
args = parser.parse_args()
|
131 |
|
132 |
-
vq_model,
|
133 |
|
134 |
with gr.Blocks() as demo:
|
135 |
gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
|
|
|
8 |
torch.set_float32_matmul_precision('high')
|
9 |
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
|
10 |
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
|
11 |
+
from vllm import SamplingParams
|
12 |
import time
|
13 |
import argparse
|
14 |
from tokenizer_image.vq_model import VQ_models
|
15 |
+
# from models.generate import generate
|
16 |
+
from serve.llm import LLM
|
17 |
|
18 |
device = "cuda"
|
19 |
|
|
|
38 |
del checkpoint
|
39 |
print(f"image tokenizer is loaded")
|
40 |
|
41 |
+
# Create an LLM.
|
42 |
+
args.image_size = image_size
|
43 |
+
args.gpt_ckpt = f"{ckpt_folder}{gpt_ckpt}"
|
44 |
+
llm = LLM(
|
45 |
+
args=args,
|
46 |
+
model='serve/fake_json/{}.json'.format(args.gpt_model),
|
47 |
+
gpu_memory_utilization=0.6,
|
48 |
+
skip_tokenizer_init=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
print(f"gpt model is loaded")
|
50 |
+
return vq_model, llm, image_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
|
53 |
def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
|
|
|
55 |
latent_size = image_size // args.downsample_size
|
56 |
# Labels to condition the model with (feel free to change):
|
57 |
class_labels = [class_label for _ in range(n)]
|
|
|
58 |
qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
|
59 |
|
60 |
+
prompt_token_ids = [[cind] for cind in class_labels]
|
61 |
+
if cfg_scale > 1.0:
|
62 |
+
prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
|
63 |
+
|
64 |
+
# Create a sampling params object.
|
65 |
+
sampling_params = SamplingParams(
|
66 |
+
temperature=temperature, top_p=top_p, top_k=top_k,
|
67 |
+
max_tokens=latent_size ** 2)
|
68 |
+
|
69 |
t1 = time.time()
|
70 |
torch.manual_seed(seed)
|
71 |
+
outputs = llm.generate(
|
72 |
+
prompt_token_ids=prompt_token_ids,
|
73 |
+
sampling_params=sampling_params,
|
74 |
+
use_tqdm=False)
|
|
|
|
|
75 |
sampling_time = time.time() - t1
|
76 |
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
77 |
|
78 |
+
index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
|
79 |
+
if args.cfg_scale > 1.0:
|
80 |
+
index_sample = index_sample[:len(class_labels)]
|
81 |
t2 = time.time()
|
82 |
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
83 |
decoder_time = time.time() - t2
|
|
|
89 |
|
90 |
|
91 |
parser = argparse.ArgumentParser()
|
92 |
+
parser.add_argument("--gpt-model", type=str, default="GPT-XL")
|
93 |
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
|
94 |
parser.add_argument("--from-fsdp", action='store_true')
|
95 |
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
|
|
|
108 |
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
109 |
args = parser.parse_args()
|
110 |
|
111 |
+
vq_model, llm, image_size = load_model(args)
|
112 |
|
113 |
with gr.Blocks() as demo:
|
114 |
gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
|
app_naive.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import gradio as gr
|
3 |
+
from imagenet_en_cn import IMAGENET_1K_CLASSES
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
import torch
|
6 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
7 |
+
torch.backends.cudnn.allow_tf32 = True
|
8 |
+
torch.set_float32_matmul_precision('high')
|
9 |
+
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
|
10 |
+
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
|
11 |
+
|
12 |
+
import time
|
13 |
+
import argparse
|
14 |
+
from tokenizer_image.vq_model import VQ_models
|
15 |
+
from models.gpt import GPT_models
|
16 |
+
from models.generate import generate
|
17 |
+
|
18 |
+
device = "cuda"
|
19 |
+
|
20 |
+
model2ckpt = {
|
21 |
+
"GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384),
|
22 |
+
"GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256),
|
23 |
+
}
|
24 |
+
|
25 |
+
def load_model(args):
|
26 |
+
ckpt_folder = "./"
|
27 |
+
vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
|
28 |
+
hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder)
|
29 |
+
hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder)
|
30 |
+
# create and load model
|
31 |
+
vq_model = VQ_models[args.vq_model](
|
32 |
+
codebook_size=args.codebook_size,
|
33 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
34 |
+
vq_model.to(device)
|
35 |
+
vq_model.eval()
|
36 |
+
checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu")
|
37 |
+
vq_model.load_state_dict(checkpoint["model"])
|
38 |
+
del checkpoint
|
39 |
+
print(f"image tokenizer is loaded")
|
40 |
+
|
41 |
+
# create and load gpt model
|
42 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
43 |
+
latent_size = image_size // args.downsample_size
|
44 |
+
gpt_model = GPT_models[args.gpt_model](
|
45 |
+
vocab_size=args.codebook_size,
|
46 |
+
block_size=latent_size ** 2,
|
47 |
+
num_classes=args.num_classes,
|
48 |
+
cls_token_num=args.cls_token_num,
|
49 |
+
model_type=args.gpt_type,
|
50 |
+
).to(device=device, dtype=precision)
|
51 |
+
|
52 |
+
checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
|
53 |
+
if args.from_fsdp: # fspd
|
54 |
+
model_weight = checkpoint
|
55 |
+
elif "model" in checkpoint: # ddp
|
56 |
+
model_weight = checkpoint["model"]
|
57 |
+
elif "module" in checkpoint: # deepspeed
|
58 |
+
model_weight = checkpoint["module"]
|
59 |
+
elif "state_dict" in checkpoint:
|
60 |
+
model_weight = checkpoint["state_dict"]
|
61 |
+
else:
|
62 |
+
raise Exception("please check model weight")
|
63 |
+
# if 'freqs_cis' in model_weight:
|
64 |
+
# model_weight.pop('freqs_cis')
|
65 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
66 |
+
gpt_model.eval()
|
67 |
+
del checkpoint
|
68 |
+
print(f"gpt model is loaded")
|
69 |
+
|
70 |
+
if args.compile:
|
71 |
+
print(f"compiling the model...")
|
72 |
+
gpt_model = torch.compile(
|
73 |
+
gpt_model,
|
74 |
+
mode="reduce-overhead",
|
75 |
+
fullgraph=True
|
76 |
+
) # requires PyTorch 2.0 (optional)
|
77 |
+
else:
|
78 |
+
print(f"no need to compile model in demo")
|
79 |
+
|
80 |
+
return vq_model, gpt_model, image_size
|
81 |
+
|
82 |
+
|
83 |
+
def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
|
84 |
+
n = 4
|
85 |
+
latent_size = image_size // args.downsample_size
|
86 |
+
# Labels to condition the model with (feel free to change):
|
87 |
+
class_labels = [class_label for _ in range(n)]
|
88 |
+
c_indices = torch.tensor(class_labels, device=device)
|
89 |
+
qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
|
90 |
+
|
91 |
+
t1 = time.time()
|
92 |
+
torch.manual_seed(seed)
|
93 |
+
index_sample = generate(
|
94 |
+
gpt_model, c_indices, latent_size ** 2,
|
95 |
+
cfg_scale=cfg_scale, cfg_interval=args.cfg_interval,
|
96 |
+
temperature=temperature, top_k=top_k,
|
97 |
+
top_p=top_p, sample_logits=True,
|
98 |
+
)
|
99 |
+
sampling_time = time.time() - t1
|
100 |
+
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
101 |
+
|
102 |
+
t2 = time.time()
|
103 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
104 |
+
decoder_time = time.time() - t2
|
105 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
106 |
+
# Convert to PIL.Image format:
|
107 |
+
samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
|
108 |
+
samples = [Image.fromarray(sample) for sample in samples]
|
109 |
+
return samples
|
110 |
+
|
111 |
+
|
112 |
+
parser = argparse.ArgumentParser()
|
113 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
|
114 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
|
115 |
+
parser.add_argument("--from-fsdp", action='store_true')
|
116 |
+
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
|
117 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
118 |
+
parser.add_argument("--compile", action='store_true', default=False)
|
119 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
120 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
121 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
122 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
123 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
124 |
+
parser.add_argument("--cfg-scale", type=float, default=4.0)
|
125 |
+
parser.add_argument("--cfg-interval", type=float, default=-1)
|
126 |
+
parser.add_argument("--seed", type=int, default=0)
|
127 |
+
parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
|
128 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
129 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
130 |
+
args = parser.parse_args()
|
131 |
+
|
132 |
+
vq_model, gpt_model, image_size = load_model(args)
|
133 |
+
|
134 |
+
with gr.Blocks() as demo:
|
135 |
+
gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
|
136 |
+
|
137 |
+
with gr.Tabs():
|
138 |
+
with gr.TabItem('Generate'):
|
139 |
+
with gr.Row():
|
140 |
+
with gr.Column():
|
141 |
+
# with gr.Row():
|
142 |
+
# image_size = gr.Radio(choices=[384], value=384, label='Peize Model Resolution')
|
143 |
+
with gr.Row():
|
144 |
+
i1k_class = gr.Dropdown(
|
145 |
+
list(IMAGENET_1K_CLASSES.values()),
|
146 |
+
value='Eskimo dog, husky [爱斯基摩犬,哈士奇]',
|
147 |
+
type="index", label='ImageNet-1K Class'
|
148 |
+
)
|
149 |
+
cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
|
150 |
+
top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=4000, label='Top-K')
|
151 |
+
top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
|
152 |
+
temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
|
153 |
+
seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
|
154 |
+
# seed = gr.Number(value=0, label='Seed')
|
155 |
+
button = gr.Button("Generate", variant="primary")
|
156 |
+
with gr.Column():
|
157 |
+
output = gr.Gallery(label='Generated Images', height=700)
|
158 |
+
button.click(infer, inputs=[cfg_scale, top_k, top_p, temperature, i1k_class, seed], outputs=[output])
|
159 |
+
demo.queue()
|
160 |
+
demo.launch(debug=True)
|
requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
|
|
|
|
1 |
+
vllm==0.4.1
|
2 |
+
torchvision==0.17.1
|
serve/README.md
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## serving by vLLM
|
2 |
+
|
3 |
+
### Install
|
4 |
+
```
|
5 |
+
pip install vllm==0.4.1
|
6 |
+
pip install torchvision==0.17.1
|
7 |
+
```
|
8 |
+
|
9 |
+
### Demo
|
10 |
+
```
|
11 |
+
cd ${THIS_REPO_ROOT}
|
12 |
+
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /path/to/vq_ds16size16384dim8.pt --gpt-ckpt /path/to/GPT-B/checkpoints/1500000.pt --gpt-model GPT-B
|
13 |
+
|
14 |
+
```
|
15 |
+
|
16 |
+
|
17 |
+
### Comparison (A100)
|
18 |
+
|
19 |
+
Method | params | baseline(s) | vllm(s) | speed-up ratio
|
20 |
+
--- |:---:|:---:|:---:|:---:
|
21 |
+
GPT-B | 100M | 7.80 | 2.39 | 326 %
|
22 |
+
GPT-L | 300M | 13.72 | 3.48 | 380 %
|
23 |
+
GPT-XL | 700M | 19.76 | 4.84 | 408 %
|
24 |
+
GPT-XXL | 1.4B | 26.38 | 6.36 | 414 %
|
25 |
+
GPT-3B | 3.1B | - | - | -
|
26 |
+
|
27 |
+
|
28 |
+
```
|
29 |
+
### GPT-B
|
30 |
+
# 7.80 seconds
|
31 |
+
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-24-20-56-19/002-GPT-B/checkpoints/1500000.pt
|
32 |
+
|
33 |
+
# 2.39 seconds
|
34 |
+
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-24-20-56-19/002-GPT-B/checkpoints/1500000.pt
|
35 |
+
|
36 |
+
|
37 |
+
### GPT-L
|
38 |
+
# 13.72 seconds
|
39 |
+
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-27-14-27-57/011-GPT-L/checkpoints/1500000.pt --gpt-model GPT-L
|
40 |
+
|
41 |
+
# 3.48 seconds
|
42 |
+
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-27-14-27-57/011-GPT-L/checkpoints/1500000.pt --gpt-model GPT-L
|
43 |
+
|
44 |
+
|
45 |
+
### GPT-XL
|
46 |
+
# 19.76 seconds
|
47 |
+
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-05-05-13-15-40/000-GPT-XL/checkpoints/1500000.pt --gpt-model GPT-XL
|
48 |
+
|
49 |
+
# 4.84 seconds
|
50 |
+
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-05-05-13-15-40/000-GPT-XL/checkpoints/1500000.pt --gpt-model GPT-XL
|
51 |
+
|
52 |
+
|
53 |
+
### GPT-XXL
|
54 |
+
# 26.38 seconds
|
55 |
+
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/20240506150815-GPT-XXXL/0125000/consolidated.pth --from-fsdp --gpt-model GPT-XXXL
|
56 |
+
|
57 |
+
# 6.36 seconds
|
58 |
+
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/20240506150815-GPT-XXXL/0125000/consolidated.pth --from-fsdp --gpt-model GPT-XXXL
|
59 |
+
|
60 |
+
|
61 |
+
```
|
62 |
+
|
63 |
+
In 3B model, head size 100 is not supported by PagedAttention, supported head sizes are: [64, 80, 96, 112, 128, 256]
|
serve/gpt_model.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
8 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
9 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
10 |
+
from vllm.sequence import SamplerOutput
|
11 |
+
|
12 |
+
from vllm.attention import AttentionMetadata
|
13 |
+
from vllm.attention import Attention as pagedAttention
|
14 |
+
|
15 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
16 |
+
from serve.sampler import Sampler
|
17 |
+
|
18 |
+
def find_multiple(n: int, k: int):
|
19 |
+
if n % k == 0:
|
20 |
+
return n
|
21 |
+
return n + k - (n % k)
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class ModelArgs:
|
25 |
+
dim: int = 4096
|
26 |
+
n_layer: int = 32
|
27 |
+
n_head: int = 32
|
28 |
+
n_kv_head: Optional[int] = None
|
29 |
+
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
30 |
+
ffn_dim_multiplier: Optional[float] = None
|
31 |
+
rope_base: float = 10000
|
32 |
+
norm_eps: float = 1e-5
|
33 |
+
initializer_range: float = 0.02
|
34 |
+
|
35 |
+
num_classes: int = 1000
|
36 |
+
class_dropout_prob: float = 0.1
|
37 |
+
model_type: str = 'c2i'
|
38 |
+
cfg_scale: float = 4.0
|
39 |
+
|
40 |
+
vocab_size: int = 16384
|
41 |
+
cls_token_num: int = 1
|
42 |
+
block_size: int = 256
|
43 |
+
max_batch_size: int = 32
|
44 |
+
max_seq_len: int = 2048
|
45 |
+
|
46 |
+
|
47 |
+
#################################################################################
|
48 |
+
# Embedding Layers for Class Labels #
|
49 |
+
#################################################################################
|
50 |
+
class LabelEmbedder(nn.Module):
|
51 |
+
"""
|
52 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
53 |
+
"""
|
54 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
55 |
+
super().__init__()
|
56 |
+
use_cfg_embedding = dropout_prob > 0
|
57 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
58 |
+
self.num_classes = num_classes
|
59 |
+
self.dropout_prob = dropout_prob
|
60 |
+
|
61 |
+
# def token_drop(self, labels, force_drop_ids=None):
|
62 |
+
# """
|
63 |
+
# Drops labels to enable classifier-free guidance.
|
64 |
+
# """
|
65 |
+
# if force_drop_ids is None:
|
66 |
+
# drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
67 |
+
# else:
|
68 |
+
# drop_ids = force_drop_ids == 1
|
69 |
+
# labels = torch.where(drop_ids, self.num_classes, labels)
|
70 |
+
# return labels
|
71 |
+
|
72 |
+
# def forward(self, labels, train, force_drop_ids=None):
|
73 |
+
def forward(self, labels):
|
74 |
+
# use_dropout = self.dropout_prob > 0
|
75 |
+
# if (train and use_dropout) or (force_drop_ids is not None):
|
76 |
+
# labels = self.token_drop(labels, force_drop_ids)
|
77 |
+
embeddings = self.embedding_table(labels)
|
78 |
+
return embeddings
|
79 |
+
|
80 |
+
|
81 |
+
#################################################################################
|
82 |
+
# GPT Model #
|
83 |
+
#################################################################################
|
84 |
+
# class RMSNorm(torch.nn.Module):
|
85 |
+
# def __init__(self, dim: int, eps: float = 1e-5):
|
86 |
+
# super().__init__()
|
87 |
+
# self.eps = eps
|
88 |
+
# self.weight = nn.Parameter(torch.ones(dim))
|
89 |
+
|
90 |
+
# def _norm(self, x):
|
91 |
+
# return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
92 |
+
|
93 |
+
# def forward(self, x):
|
94 |
+
# output = self._norm(x.float()).type_as(x)
|
95 |
+
# return output * self.weight
|
96 |
+
|
97 |
+
|
98 |
+
class FeedForward(nn.Module):
|
99 |
+
def __init__(self, config: ModelArgs):
|
100 |
+
super().__init__()
|
101 |
+
hidden_dim = 4 * config.dim
|
102 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
103 |
+
# custom dim factor multiplier
|
104 |
+
if config.ffn_dim_multiplier is not None:
|
105 |
+
hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
|
106 |
+
hidden_dim = find_multiple(hidden_dim, config.multiple_of)
|
107 |
+
|
108 |
+
# self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
|
109 |
+
# self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
|
110 |
+
self.w_merged = nn.Linear(config.dim, hidden_dim * 2, bias=False)
|
111 |
+
self.act_fn = SiluAndMul()
|
112 |
+
|
113 |
+
self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
|
114 |
+
# self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
|
115 |
+
|
116 |
+
# def forward(self, x):
|
117 |
+
# return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
x = self.w_merged(x)
|
121 |
+
x = self.act_fn(x)
|
122 |
+
x = self.w2(x)
|
123 |
+
# return self.ffn_dropout(x)
|
124 |
+
return x
|
125 |
+
|
126 |
+
|
127 |
+
class Attention(nn.Module):
|
128 |
+
def __init__(self, config: ModelArgs):
|
129 |
+
super().__init__()
|
130 |
+
assert config.dim % config.n_head == 0
|
131 |
+
self.dim = config.dim
|
132 |
+
self.head_dim = config.dim // config.n_head
|
133 |
+
self.n_head = config.n_head
|
134 |
+
self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
|
135 |
+
total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
|
136 |
+
|
137 |
+
# key, query, value projections for all heads, but in a batch
|
138 |
+
self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
|
139 |
+
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
140 |
+
|
141 |
+
# pagedAttention
|
142 |
+
self.attn = pagedAttention(self.n_head,
|
143 |
+
self.head_dim,
|
144 |
+
self.head_dim**-0.5,
|
145 |
+
num_kv_heads=self.n_kv_head,
|
146 |
+
)
|
147 |
+
|
148 |
+
# 2d rotary pos embedding
|
149 |
+
grid_size = int(config.block_size ** 0.5)
|
150 |
+
assert grid_size * grid_size == config.block_size
|
151 |
+
freqs_cis = precompute_freqs_cis_2d(grid_size, config.dim // config.n_head, config.rope_base, config.cls_token_num)
|
152 |
+
self.register_buffer('freqs_cis', freqs_cis)
|
153 |
+
|
154 |
+
|
155 |
+
def forward(
|
156 |
+
self,
|
157 |
+
x: torch.Tensor,
|
158 |
+
positions: torch.Tensor,
|
159 |
+
kv_cache: torch.Tensor,
|
160 |
+
attn_metadata: AttentionMetadata,
|
161 |
+
):
|
162 |
+
kv_size = self.n_kv_head * self.head_dim
|
163 |
+
xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
|
164 |
+
|
165 |
+
xq = xq.view(*xq.shape[:-1], 1, self.n_head, self.head_dim)
|
166 |
+
xk = xk.view(*xk.shape[:-1], 1, self.n_kv_head, self.head_dim)
|
167 |
+
freqs_cis = self.freqs_cis[positions].unsqueeze(1)
|
168 |
+
xq = apply_rotary_emb_bs(xq, freqs_cis)
|
169 |
+
xk = apply_rotary_emb_bs(xk, freqs_cis)
|
170 |
+
xq = xq.flatten(1)
|
171 |
+
xk = xk.flatten(1)
|
172 |
+
|
173 |
+
output = self.attn(xq, xk, xv, kv_cache, attn_metadata)
|
174 |
+
output = self.wo(output)
|
175 |
+
|
176 |
+
return output
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
class TransformerBlock(nn.Module):
|
181 |
+
def __init__(self, config: ModelArgs):
|
182 |
+
super().__init__()
|
183 |
+
self.attention = Attention(config)
|
184 |
+
self.feed_forward = FeedForward(config)
|
185 |
+
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
186 |
+
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
187 |
+
|
188 |
+
def forward(self, x: torch.Tensor, positions: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata):
|
189 |
+
h = x + self.attention(self.attention_norm(x), positions, kv_cache, attn_metadata)
|
190 |
+
out = h + self.feed_forward(self.ffn_norm(h))
|
191 |
+
return out
|
192 |
+
|
193 |
+
|
194 |
+
class Transformer(nn.Module):
|
195 |
+
def __init__(self, config: ModelArgs):
|
196 |
+
super().__init__()
|
197 |
+
self.config = config
|
198 |
+
self.vocab_size = config.vocab_size
|
199 |
+
self.n_layer = config.n_layer
|
200 |
+
self.block_size = config.block_size
|
201 |
+
self.num_classes = config.num_classes
|
202 |
+
self.model_type = config.model_type
|
203 |
+
self.cls_token_num = config.cls_token_num
|
204 |
+
self.cfg_scale = config.cfg_scale
|
205 |
+
if self.model_type == 'c2i':
|
206 |
+
self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
|
207 |
+
else:
|
208 |
+
raise Exception("vllm only supports c2i now, please check model type")
|
209 |
+
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
210 |
+
|
211 |
+
self.layers = torch.nn.ModuleList()
|
212 |
+
for layer_id in range(config.n_layer):
|
213 |
+
self.layers.append(TransformerBlock(config))
|
214 |
+
|
215 |
+
# output layer
|
216 |
+
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
217 |
+
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
|
218 |
+
|
219 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
220 |
+
|
221 |
+
self.sampler = Sampler(config.cfg_scale)
|
222 |
+
|
223 |
+
def forward(
|
224 |
+
self,
|
225 |
+
input_ids: torch.Tensor=None,
|
226 |
+
positions: torch.Tensor=None,
|
227 |
+
kv_caches: List[torch.Tensor]=None,
|
228 |
+
attn_metadata: AttentionMetadata=None,
|
229 |
+
):
|
230 |
+
# if positions.max() == 0: # prefill in inference
|
231 |
+
# token_embeddings = self.cls_embedding(input_ids)
|
232 |
+
# else: # decode_n_tokens(kv cache) in inference
|
233 |
+
# token_embeddings = self.tok_embeddings(input_ids)
|
234 |
+
cond_ids = torch.clamp(input_ids, max=self.num_classes)
|
235 |
+
token_embeddings = self.cls_embedding(cond_ids) * (positions.max() == 0) + \
|
236 |
+
self.tok_embeddings(input_ids) * (positions.max() != 0)
|
237 |
+
|
238 |
+
hh = token_embeddings
|
239 |
+
# transformer blocks
|
240 |
+
for layer_id, layer in enumerate(self.layers):
|
241 |
+
hh = layer(hh, positions, kv_caches[layer_id], attn_metadata)
|
242 |
+
|
243 |
+
# output layers
|
244 |
+
hh = self.norm(hh)
|
245 |
+
return hh
|
246 |
+
|
247 |
+
def compute_logits(self, hidden_states: torch.Tensor,
|
248 |
+
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
249 |
+
logits = self.logits_processor(self.output.weight, hidden_states, sampling_metadata)
|
250 |
+
return logits
|
251 |
+
|
252 |
+
def sample(
|
253 |
+
self,
|
254 |
+
logits: torch.Tensor,
|
255 |
+
sampling_metadata: SamplingMetadata,
|
256 |
+
) -> Optional[SamplerOutput]:
|
257 |
+
next_tokens = self.sampler(logits, sampling_metadata)
|
258 |
+
return next_tokens
|
259 |
+
|
260 |
+
|
261 |
+
def custom_load_state_dict(self, model_weights):
|
262 |
+
model_weights = model_weights.copy()
|
263 |
+
for layer_id in range(len(self.layers)):
|
264 |
+
branch1 = f'layers.{layer_id}.feed_forward.w1.weight'
|
265 |
+
branch3 = f'layers.{layer_id}.feed_forward.w3.weight'
|
266 |
+
branch_merged = f'layers.{layer_id}.feed_forward.w_merged.weight'
|
267 |
+
model_weights[branch_merged] = torch.cat(
|
268 |
+
[model_weights[branch1], model_weights[branch3]], dim=0
|
269 |
+
)
|
270 |
+
model_weights.pop(branch1)
|
271 |
+
model_weights.pop(branch3)
|
272 |
+
|
273 |
+
if 'freqs_cis' in model_weights:
|
274 |
+
model_weights.pop('freqs_cis')
|
275 |
+
|
276 |
+
self.load_state_dict(model_weights, strict=False)
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
#################################################################################
|
281 |
+
# Rotary Positional Embedding Functions #
|
282 |
+
#################################################################################
|
283 |
+
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
|
284 |
+
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
|
285 |
+
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
|
286 |
+
t = torch.arange(seq_len, device=freqs.device)
|
287 |
+
freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
|
288 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
289 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (cls_token_num+seq_len, head_dim // 2, 2)
|
290 |
+
cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
|
291 |
+
return cond_cache
|
292 |
+
|
293 |
+
|
294 |
+
def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
|
295 |
+
# split the dimension into half, one for x and one for y
|
296 |
+
half_dim = n_elem // 2
|
297 |
+
freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
|
298 |
+
t = torch.arange(grid_size, device=freqs.device)
|
299 |
+
freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2)
|
300 |
+
freqs_grid = torch.concat([
|
301 |
+
freqs[:, None, :].expand(-1, grid_size, -1),
|
302 |
+
freqs[None, :, :].expand(grid_size, -1, -1),
|
303 |
+
], dim=-1) # (grid_size, grid_size, head_dim // 2)
|
304 |
+
cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
|
305 |
+
cache = cache_grid.flatten(0, 1)
|
306 |
+
cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
|
307 |
+
return cond_cache
|
308 |
+
|
309 |
+
|
310 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
|
311 |
+
# x: (bs, seq_len, n_head, head_dim)
|
312 |
+
# freqs_cis (seq_len, head_dim // 2, 2)
|
313 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
|
314 |
+
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
|
315 |
+
x_out2 = torch.stack([
|
316 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
317 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
318 |
+
], dim=-1)
|
319 |
+
x_out2 = x_out2.flatten(3)
|
320 |
+
return x_out2.type_as(x)
|
321 |
+
|
322 |
+
|
323 |
+
def apply_rotary_emb_bs(x: torch.Tensor, freqs_cis: torch.Tensor):
|
324 |
+
# x: (bs, seq_len, n_head, head_dim)
|
325 |
+
# freqs_cis (seq_len, head_dim // 2, 2)
|
326 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
|
327 |
+
freqs_cis = freqs_cis.view(xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2) # (bs, seq_len, 1, head_dim//2, 2)
|
328 |
+
x_out2 = torch.stack([
|
329 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
330 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
331 |
+
], dim=-1)
|
332 |
+
x_out2 = x_out2.flatten(3)
|
333 |
+
return x_out2.type_as(x)
|
334 |
+
|
335 |
+
|
336 |
+
#################################################################################
|
337 |
+
# GPT Configs #
|
338 |
+
#################################################################################
|
339 |
+
### text-conditional
|
340 |
+
def GPT_7B(**kwargs):
|
341 |
+
return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B
|
342 |
+
|
343 |
+
def GPT_3B(**kwargs):
|
344 |
+
return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B
|
345 |
+
|
346 |
+
def GPT_1B(**kwargs):
|
347 |
+
return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B
|
348 |
+
|
349 |
+
### class-conditional
|
350 |
+
def GPT_XXXL(**kwargs):
|
351 |
+
return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B
|
352 |
+
|
353 |
+
def GPT_XXL(**kwargs):
|
354 |
+
return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B
|
355 |
+
|
356 |
+
def GPT_XL(**kwargs):
|
357 |
+
return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M
|
358 |
+
|
359 |
+
def GPT_L(**kwargs):
|
360 |
+
return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M
|
361 |
+
|
362 |
+
def GPT_B(**kwargs):
|
363 |
+
return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M
|
364 |
+
|
365 |
+
|
366 |
+
GPT_models = {
|
367 |
+
'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
|
368 |
+
'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
|
369 |
+
}
|
serve/gpu_executor.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Set, Tuple, Optional, Set
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
5 |
+
ModelConfig, ParallelConfig, SchedulerConfig,
|
6 |
+
SpeculativeConfig, VisionLanguageConfig)
|
7 |
+
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
8 |
+
from vllm.logger import init_logger
|
9 |
+
from vllm.lora.request import LoRARequest
|
10 |
+
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
11 |
+
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
12 |
+
make_async)
|
13 |
+
|
14 |
+
logger = init_logger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class GPUExecutor(ExecutorBase):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
args: argparse.ArgumentParser,
|
21 |
+
model_config: ModelConfig,
|
22 |
+
cache_config: CacheConfig,
|
23 |
+
parallel_config: ParallelConfig,
|
24 |
+
scheduler_config: SchedulerConfig,
|
25 |
+
device_config: DeviceConfig,
|
26 |
+
load_config: LoadConfig,
|
27 |
+
lora_config: Optional[LoRAConfig],
|
28 |
+
vision_language_config: Optional[VisionLanguageConfig],
|
29 |
+
speculative_config: Optional[SpeculativeConfig],
|
30 |
+
) -> None:
|
31 |
+
self.args = args
|
32 |
+
self.model_config = model_config
|
33 |
+
self.cache_config = cache_config
|
34 |
+
self.lora_config = lora_config
|
35 |
+
self.load_config = load_config
|
36 |
+
self.parallel_config = parallel_config
|
37 |
+
self.scheduler_config = scheduler_config
|
38 |
+
self.device_config = device_config
|
39 |
+
self.vision_language_config = vision_language_config
|
40 |
+
self.speculative_config = speculative_config
|
41 |
+
|
42 |
+
self._init_executor()
|
43 |
+
|
44 |
+
def _init_executor(self) -> None:
|
45 |
+
"""Initialize the worker and load the model.
|
46 |
+
|
47 |
+
If speculative decoding is enabled, we instead create the speculative
|
48 |
+
worker.
|
49 |
+
"""
|
50 |
+
if self.speculative_config is None:
|
51 |
+
self._init_non_spec_worker()
|
52 |
+
else:
|
53 |
+
self._init_spec_worker()
|
54 |
+
|
55 |
+
def _init_non_spec_worker(self):
|
56 |
+
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
57 |
+
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
58 |
+
# from vllm.worker.worker import Worker
|
59 |
+
from serve.worker import Worker
|
60 |
+
|
61 |
+
assert self.parallel_config.world_size == 1, (
|
62 |
+
"GPUExecutor only supports single GPU.")
|
63 |
+
|
64 |
+
distributed_init_method = get_distributed_init_method(
|
65 |
+
get_ip(), get_open_port())
|
66 |
+
self.driver_worker = Worker(
|
67 |
+
model_config=self.model_config,
|
68 |
+
parallel_config=self.parallel_config,
|
69 |
+
scheduler_config=self.scheduler_config,
|
70 |
+
device_config=self.device_config,
|
71 |
+
cache_config=self.cache_config,
|
72 |
+
load_config=self.load_config,
|
73 |
+
local_rank=0,
|
74 |
+
rank=0,
|
75 |
+
distributed_init_method=distributed_init_method,
|
76 |
+
lora_config=self.lora_config,
|
77 |
+
vision_language_config=self.vision_language_config,
|
78 |
+
is_driver_worker=True,
|
79 |
+
)
|
80 |
+
self.driver_worker.init_device()
|
81 |
+
self.driver_worker.load_model(self.args)
|
82 |
+
|
83 |
+
def _init_spec_worker(self):
|
84 |
+
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
|
85 |
+
"""
|
86 |
+
assert self.speculative_config is not None
|
87 |
+
|
88 |
+
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
89 |
+
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
90 |
+
from vllm.worker.worker import Worker
|
91 |
+
|
92 |
+
distributed_init_method = get_distributed_init_method(
|
93 |
+
get_ip(), get_open_port())
|
94 |
+
|
95 |
+
target_worker = Worker(
|
96 |
+
model_config=self.model_config,
|
97 |
+
parallel_config=self.parallel_config,
|
98 |
+
scheduler_config=self.scheduler_config,
|
99 |
+
device_config=self.device_config,
|
100 |
+
cache_config=self.cache_config,
|
101 |
+
load_config=self.load_config,
|
102 |
+
local_rank=0,
|
103 |
+
rank=0,
|
104 |
+
distributed_init_method=distributed_init_method,
|
105 |
+
lora_config=self.lora_config,
|
106 |
+
vision_language_config=self.vision_language_config,
|
107 |
+
is_driver_worker=True,
|
108 |
+
)
|
109 |
+
|
110 |
+
draft_worker = MultiStepWorker(
|
111 |
+
model_config=self.speculative_config.draft_model_config,
|
112 |
+
parallel_config=self.speculative_config.draft_parallel_config,
|
113 |
+
scheduler_config=self.scheduler_config,
|
114 |
+
device_config=self.device_config,
|
115 |
+
cache_config=self.cache_config,
|
116 |
+
load_config=self.load_config,
|
117 |
+
local_rank=0,
|
118 |
+
rank=0,
|
119 |
+
distributed_init_method=distributed_init_method,
|
120 |
+
lora_config=self.lora_config,
|
121 |
+
vision_language_config=self.vision_language_config,
|
122 |
+
is_driver_worker=True,
|
123 |
+
)
|
124 |
+
|
125 |
+
spec_decode_worker = SpecDecodeWorker.from_workers(
|
126 |
+
proposer_worker=draft_worker, scorer_worker=target_worker)
|
127 |
+
|
128 |
+
assert self.parallel_config.world_size == 1, (
|
129 |
+
"GPUExecutor only supports single GPU.")
|
130 |
+
|
131 |
+
self.driver_worker = spec_decode_worker
|
132 |
+
|
133 |
+
# Load model handled in spec decode worker.
|
134 |
+
self.driver_worker.init_device()
|
135 |
+
|
136 |
+
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
137 |
+
"""Determine the number of available KV blocks by invoking the
|
138 |
+
underlying worker.
|
139 |
+
"""
|
140 |
+
return self.driver_worker.determine_num_available_blocks()
|
141 |
+
|
142 |
+
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
|
143 |
+
"""Initialize the KV cache by invoking the underlying worker.
|
144 |
+
"""
|
145 |
+
# NOTE: This is logged in the executor because there can be >1 worker
|
146 |
+
# with other executors. We could log in the engine level, but work
|
147 |
+
# remains to abstract away the device for non-GPU configurations.
|
148 |
+
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
149 |
+
f"# CPU blocks: {num_cpu_blocks}")
|
150 |
+
|
151 |
+
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
152 |
+
|
153 |
+
def execute_model(
|
154 |
+
self,
|
155 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
156 |
+
blocks_to_swap_in: Dict[int, int],
|
157 |
+
blocks_to_swap_out: Dict[int, int],
|
158 |
+
blocks_to_copy: Dict[int, List[int]],
|
159 |
+
num_lookahead_slots: int,
|
160 |
+
) -> List[SamplerOutput]:
|
161 |
+
output = self.driver_worker.execute_model(
|
162 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
163 |
+
blocks_to_swap_in=blocks_to_swap_in,
|
164 |
+
blocks_to_swap_out=blocks_to_swap_out,
|
165 |
+
blocks_to_copy=blocks_to_copy,
|
166 |
+
num_lookahead_slots=num_lookahead_slots,
|
167 |
+
)
|
168 |
+
return output
|
169 |
+
|
170 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
171 |
+
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
172 |
+
return self.driver_worker.add_lora(lora_request)
|
173 |
+
|
174 |
+
def remove_lora(self, lora_id: int) -> bool:
|
175 |
+
assert lora_id > 0, "lora_id must be greater than 0."
|
176 |
+
return self.driver_worker.remove_lora(lora_id)
|
177 |
+
|
178 |
+
def list_loras(self) -> Set[int]:
|
179 |
+
return self.driver_worker.list_loras()
|
180 |
+
|
181 |
+
def check_health(self) -> None:
|
182 |
+
# GPUExecutor will always be healthy as long as
|
183 |
+
# it's running.
|
184 |
+
return
|
185 |
+
|
186 |
+
|
187 |
+
class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
|
188 |
+
|
189 |
+
async def execute_model_async(
|
190 |
+
self,
|
191 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
192 |
+
blocks_to_swap_in: Dict[int, int],
|
193 |
+
blocks_to_swap_out: Dict[int, int],
|
194 |
+
blocks_to_copy: Dict[int, List[int]],
|
195 |
+
) -> SamplerOutput:
|
196 |
+
output = await make_async(self.driver_worker.execute_model)(
|
197 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
198 |
+
blocks_to_swap_in=blocks_to_swap_in,
|
199 |
+
blocks_to_swap_out=blocks_to_swap_out,
|
200 |
+
blocks_to_copy=blocks_to_copy)
|
201 |
+
return output
|
serve/llm.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
9 |
+
|
10 |
+
from vllm.engine.arg_utils import EngineArgs
|
11 |
+
# from vllm.engine.llm_engine import LLMEngine
|
12 |
+
from vllm.lora.request import LoRARequest
|
13 |
+
from vllm.outputs import RequestOutput
|
14 |
+
from vllm.sampling_params import SamplingParams
|
15 |
+
from vllm.sequence import MultiModalData
|
16 |
+
from vllm.usage.usage_lib import UsageContext
|
17 |
+
from vllm.utils import Counter
|
18 |
+
|
19 |
+
from serve.llm_engine import LLMEngine
|
20 |
+
|
21 |
+
|
22 |
+
class LLM:
|
23 |
+
"""An LLM for generating texts from given prompts and sampling parameters.
|
24 |
+
|
25 |
+
This class includes a tokenizer, a language model (possibly distributed
|
26 |
+
across multiple GPUs), and GPU memory space allocated for intermediate
|
27 |
+
states (aka KV cache). Given a batch of prompts and sampling parameters,
|
28 |
+
this class generates texts from the model, using an intelligent batching
|
29 |
+
mechanism and efficient memory management.
|
30 |
+
|
31 |
+
NOTE: This class is intended to be used for offline inference. For online
|
32 |
+
serving, use the `AsyncLLMEngine` class instead.
|
33 |
+
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
model: The name or path of a HuggingFace Transformers model.
|
37 |
+
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
|
38 |
+
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
|
39 |
+
if available, and "slow" will always use the slow tokenizer.
|
40 |
+
skip_tokenizer_init: If true, skip initialization of tokenizer and
|
41 |
+
detokenizer. Expect valid prompt_token_ids and None for prompt
|
42 |
+
from the input.
|
43 |
+
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
44 |
+
downloading the model and tokenizer.
|
45 |
+
tensor_parallel_size: The number of GPUs to use for distributed
|
46 |
+
execution with tensor parallelism.
|
47 |
+
dtype: The data type for the model weights and activations. Currently,
|
48 |
+
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
49 |
+
the `torch_dtype` attribute specified in the model config file.
|
50 |
+
However, if the `torch_dtype` in the config is `float32`, we will
|
51 |
+
use `float16` instead.
|
52 |
+
quantization: The method used to quantize the model weights. Currently,
|
53 |
+
we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
|
54 |
+
If None, we first check the `quantization_config` attribute in the
|
55 |
+
model config file. If that is None, we assume the model weights are
|
56 |
+
not quantized and use `dtype` to determine the data type of
|
57 |
+
the weights.
|
58 |
+
revision: The specific model version to use. It can be a branch name,
|
59 |
+
a tag name, or a commit id.
|
60 |
+
tokenizer_revision: The specific tokenizer version to use. It can be a
|
61 |
+
branch name, a tag name, or a commit id.
|
62 |
+
seed: The seed to initialize the random number generator for sampling.
|
63 |
+
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
64 |
+
reserve for the model weights, activations, and KV cache. Higher
|
65 |
+
values will increase the KV cache size and thus improve the model's
|
66 |
+
throughput. However, if the value is too high, it may cause out-of-
|
67 |
+
memory (OOM) errors.
|
68 |
+
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
69 |
+
This can be used for temporarily storing the states of the requests
|
70 |
+
when their `best_of` sampling parameters are larger than 1. If all
|
71 |
+
requests will have `best_of=1`, you can safely set this to 0.
|
72 |
+
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
73 |
+
enforce_eager: Whether to enforce eager execution. If True, we will
|
74 |
+
disable CUDA graph and always execute the model in eager mode.
|
75 |
+
If False, we will use CUDA graph and eager execution in hybrid.
|
76 |
+
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
77 |
+
When a sequence has context length larger than this, we fall back
|
78 |
+
to eager mode.
|
79 |
+
disable_custom_all_reduce: See ParallelConfig
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
args: argparse.ArgumentParser,
|
85 |
+
model: str,
|
86 |
+
tokenizer: Optional[str] = None,
|
87 |
+
tokenizer_mode: str = "auto",
|
88 |
+
skip_tokenizer_init: bool = False,
|
89 |
+
trust_remote_code: bool = False,
|
90 |
+
tensor_parallel_size: int = 1,
|
91 |
+
dtype: str = "auto",
|
92 |
+
quantization: Optional[str] = None,
|
93 |
+
revision: Optional[str] = None,
|
94 |
+
tokenizer_revision: Optional[str] = None,
|
95 |
+
seed: int = 0,
|
96 |
+
gpu_memory_utilization: float = 0.9,
|
97 |
+
swap_space: int = 4,
|
98 |
+
enforce_eager: bool = False,
|
99 |
+
max_context_len_to_capture: int = 8192,
|
100 |
+
disable_custom_all_reduce: bool = False,
|
101 |
+
**kwargs,
|
102 |
+
) -> None:
|
103 |
+
if "disable_log_stats" not in kwargs:
|
104 |
+
kwargs["disable_log_stats"] = True
|
105 |
+
engine_args = EngineArgs(
|
106 |
+
model=model,
|
107 |
+
tokenizer=tokenizer,
|
108 |
+
tokenizer_mode=tokenizer_mode,
|
109 |
+
skip_tokenizer_init=skip_tokenizer_init,
|
110 |
+
trust_remote_code=trust_remote_code,
|
111 |
+
tensor_parallel_size=tensor_parallel_size,
|
112 |
+
dtype=dtype,
|
113 |
+
quantization=quantization,
|
114 |
+
revision=revision,
|
115 |
+
tokenizer_revision=tokenizer_revision,
|
116 |
+
seed=seed,
|
117 |
+
gpu_memory_utilization=gpu_memory_utilization,
|
118 |
+
swap_space=swap_space,
|
119 |
+
enforce_eager=enforce_eager,
|
120 |
+
max_context_len_to_capture=max_context_len_to_capture,
|
121 |
+
disable_custom_all_reduce=disable_custom_all_reduce,
|
122 |
+
**kwargs,
|
123 |
+
)
|
124 |
+
self.llm_engine = LLMEngine.from_engine_args(
|
125 |
+
engine_args, usage_context=UsageContext.LLM_CLASS, args=args)
|
126 |
+
self.request_counter = Counter()
|
127 |
+
|
128 |
+
def get_tokenizer(
|
129 |
+
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
130 |
+
return self.llm_engine.tokenizer.tokenizer
|
131 |
+
|
132 |
+
def set_tokenizer(
|
133 |
+
self,
|
134 |
+
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
135 |
+
) -> None:
|
136 |
+
self.llm_engine.tokenizer.tokenizer = tokenizer
|
137 |
+
|
138 |
+
def generate(
|
139 |
+
self,
|
140 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
141 |
+
sampling_params: Optional[Union[SamplingParams,
|
142 |
+
List[SamplingParams]]] = None,
|
143 |
+
prompt_token_ids: Optional[List[List[int]]] = None,
|
144 |
+
use_tqdm: bool = True,
|
145 |
+
lora_request: Optional[LoRARequest] = None,
|
146 |
+
multi_modal_data: Optional[MultiModalData] = None,
|
147 |
+
) -> List[RequestOutput]:
|
148 |
+
"""Generates the completions for the input prompts.
|
149 |
+
|
150 |
+
NOTE: This class automatically batches the given prompts, considering
|
151 |
+
the memory constraint. For the best performance, put all of your prompts
|
152 |
+
into a single list and pass it to this method.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
prompts: A list of prompts to generate completions for.
|
156 |
+
sampling_params: The sampling parameters for text generation. If
|
157 |
+
None, we use the default sampling parameters.
|
158 |
+
When it is a single value, it is applied to every prompt.
|
159 |
+
When it is a list, the list must have the same length as the
|
160 |
+
prompts and it is paired one by one with the prompt.
|
161 |
+
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
162 |
+
use the tokenizer to convert the prompts to token IDs.
|
163 |
+
use_tqdm: Whether to use tqdm to display the progress bar.
|
164 |
+
lora_request: LoRA request to use for generation, if any.
|
165 |
+
multi_modal_data: Multi modal data.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
A list of `RequestOutput` objects containing the generated
|
169 |
+
completions in the same order as the input prompts.
|
170 |
+
"""
|
171 |
+
if prompts is None and prompt_token_ids is None:
|
172 |
+
raise ValueError("Either prompts or prompt_token_ids must be "
|
173 |
+
"provided.")
|
174 |
+
if self.llm_engine.model_config.skip_tokenizer_init \
|
175 |
+
and prompts is not None:
|
176 |
+
raise ValueError("prompts must be None if skip_tokenizer_init "
|
177 |
+
"is True")
|
178 |
+
if isinstance(prompts, str):
|
179 |
+
# Convert a single prompt to a list.
|
180 |
+
prompts = [prompts]
|
181 |
+
if (prompts is not None and prompt_token_ids is not None
|
182 |
+
and len(prompts) != len(prompt_token_ids)):
|
183 |
+
raise ValueError("The lengths of prompts and prompt_token_ids "
|
184 |
+
"must be the same.")
|
185 |
+
|
186 |
+
if prompts is not None:
|
187 |
+
num_requests = len(prompts)
|
188 |
+
else:
|
189 |
+
assert prompt_token_ids is not None
|
190 |
+
num_requests = len(prompt_token_ids)
|
191 |
+
|
192 |
+
if sampling_params is None:
|
193 |
+
# Use default sampling params.
|
194 |
+
sampling_params = SamplingParams()
|
195 |
+
|
196 |
+
elif isinstance(sampling_params,
|
197 |
+
list) and len(sampling_params) != num_requests:
|
198 |
+
raise ValueError("The lengths of prompts and sampling_params "
|
199 |
+
"must be the same.")
|
200 |
+
if multi_modal_data:
|
201 |
+
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
|
202 |
+
|
203 |
+
# Add requests to the engine.
|
204 |
+
for i in range(num_requests):
|
205 |
+
prompt = prompts[i] if prompts is not None else None
|
206 |
+
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
|
207 |
+
self._add_request(
|
208 |
+
prompt,
|
209 |
+
sampling_params[i]
|
210 |
+
if isinstance(sampling_params, list) else sampling_params,
|
211 |
+
token_ids,
|
212 |
+
lora_request=lora_request,
|
213 |
+
# Get ith image while maintaining the batch dim.
|
214 |
+
multi_modal_data=MultiModalData(
|
215 |
+
type=multi_modal_data.type,
|
216 |
+
data=multi_modal_data.data[i].unsqueeze(0))
|
217 |
+
if multi_modal_data else None,
|
218 |
+
)
|
219 |
+
return self._run_engine(use_tqdm)
|
220 |
+
|
221 |
+
def _add_request(
|
222 |
+
self,
|
223 |
+
prompt: Optional[str],
|
224 |
+
sampling_params: SamplingParams,
|
225 |
+
prompt_token_ids: Optional[List[int]],
|
226 |
+
lora_request: Optional[LoRARequest] = None,
|
227 |
+
multi_modal_data: Optional[MultiModalData] = None,
|
228 |
+
) -> None:
|
229 |
+
request_id = str(next(self.request_counter))
|
230 |
+
self.llm_engine.add_request(request_id,
|
231 |
+
prompt,
|
232 |
+
sampling_params,
|
233 |
+
prompt_token_ids,
|
234 |
+
lora_request=lora_request,
|
235 |
+
multi_modal_data=multi_modal_data)
|
236 |
+
|
237 |
+
|
238 |
+
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
|
239 |
+
# Initialize tqdm.
|
240 |
+
if use_tqdm:
|
241 |
+
num_requests = self.llm_engine.get_num_unfinished_requests()
|
242 |
+
pbar = tqdm(
|
243 |
+
total=num_requests,
|
244 |
+
desc="Processed prompts",
|
245 |
+
dynamic_ncols=True,
|
246 |
+
postfix=f"Generation Speed: {0:.2f} toks/s",
|
247 |
+
)
|
248 |
+
# Run the engine.
|
249 |
+
outputs: List[RequestOutput] = []
|
250 |
+
while self.llm_engine.has_unfinished_requests():
|
251 |
+
step_outputs = self.llm_engine.step()
|
252 |
+
for output in step_outputs:
|
253 |
+
if output.finished:
|
254 |
+
outputs.append(output)
|
255 |
+
if use_tqdm:
|
256 |
+
total_toks += (sum(
|
257 |
+
len(stp.token_ids) for stp in output.outputs))
|
258 |
+
spd = total_toks / pbar.format_dict["elapsed"]
|
259 |
+
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
|
260 |
+
pbar.update(1)
|
261 |
+
if use_tqdm:
|
262 |
+
pbar.close()
|
263 |
+
# Sort the outputs by request ID.
|
264 |
+
# This is necessary because some requests may be finished earlier than
|
265 |
+
# its previous requests.
|
266 |
+
outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
267 |
+
return outputs
|
serve/llm_engine.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
|
3 |
+
import time
|
4 |
+
from typing import Iterable, List, Optional, Type, Union
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
from transformers import GenerationConfig, PreTrainedTokenizer
|
8 |
+
|
9 |
+
import vllm
|
10 |
+
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
11 |
+
LoRAConfig, ModelConfig, ParallelConfig,
|
12 |
+
SchedulerConfig, SpeculativeConfig,
|
13 |
+
VisionLanguageConfig)
|
14 |
+
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
15 |
+
from vllm.engine.arg_utils import EngineArgs
|
16 |
+
from vllm.engine.metrics import StatLogger, Stats
|
17 |
+
from vllm.engine.output_processor.interfaces import (
|
18 |
+
SequenceGroupOutputProcessor)
|
19 |
+
from vllm.engine.output_processor.stop_checker import StopChecker
|
20 |
+
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
21 |
+
from vllm.engine.ray_utils import initialize_ray_cluster
|
22 |
+
from vllm.executor.executor_base import ExecutorBase
|
23 |
+
from vllm.logger import init_logger
|
24 |
+
from vllm.lora.request import LoRARequest
|
25 |
+
from vllm.outputs import RequestOutput
|
26 |
+
from vllm.sampling_params import SamplingParams
|
27 |
+
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
|
28 |
+
SequenceGroup)
|
29 |
+
from vllm.transformers_utils.detokenizer import Detokenizer
|
30 |
+
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
31 |
+
get_tokenizer_group)
|
32 |
+
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
33 |
+
usage_message)
|
34 |
+
from vllm.utils import Counter
|
35 |
+
|
36 |
+
logger = init_logger(__name__)
|
37 |
+
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
38 |
+
|
39 |
+
|
40 |
+
def _load_generation_config_dict(model_config: ModelConfig):
|
41 |
+
try:
|
42 |
+
return GenerationConfig.from_pretrained(
|
43 |
+
model_config.model,
|
44 |
+
revision=model_config.revision,
|
45 |
+
).to_diff_dict()
|
46 |
+
except OSError:
|
47 |
+
# Not found.
|
48 |
+
return {}
|
49 |
+
|
50 |
+
|
51 |
+
class LLMEngine:
|
52 |
+
"""An LLM engine that receives requests and generates texts.
|
53 |
+
|
54 |
+
This is the main class for the vLLM engine. It receives requests
|
55 |
+
from clients and generates texts from the LLM. It includes a tokenizer, a
|
56 |
+
language model (possibly distributed across multiple GPUs), and GPU memory
|
57 |
+
space allocated for intermediate states (aka KV cache). This class utilizes
|
58 |
+
iteration-level scheduling and efficient memory management to maximize the
|
59 |
+
serving throughput.
|
60 |
+
|
61 |
+
The `LLM` class wraps this class for offline batched inference and the
|
62 |
+
`AsyncLLMEngine` class wraps this class for online serving.
|
63 |
+
|
64 |
+
NOTE: The config arguments are derived from the `EngineArgs` class. For the
|
65 |
+
comprehensive list of arguments, see `EngineArgs`.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
model_config: The configuration related to the LLM model.
|
69 |
+
cache_config: The configuration related to the KV cache memory
|
70 |
+
management.
|
71 |
+
parallel_config: The configuration related to distributed execution.
|
72 |
+
scheduler_config: The configuration related to the request scheduler.
|
73 |
+
device_config: The configuration related to the device.
|
74 |
+
lora_config (Optional): The configuration related to serving multi-LoRA.
|
75 |
+
vision_language_config (Optional): The configuration related to vision
|
76 |
+
language models.
|
77 |
+
speculative_config (Optional): The configuration related to speculative
|
78 |
+
decoding.
|
79 |
+
executor_class: The model executor class for managing distributed
|
80 |
+
execution.
|
81 |
+
log_stats: Whether to log statistics.
|
82 |
+
usage_context: Specified entry point, used for usage info collection
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
args: argparse.ArgumentParser,
|
88 |
+
model_config: ModelConfig,
|
89 |
+
cache_config: CacheConfig,
|
90 |
+
parallel_config: ParallelConfig,
|
91 |
+
scheduler_config: SchedulerConfig,
|
92 |
+
device_config: DeviceConfig,
|
93 |
+
load_config: LoadConfig,
|
94 |
+
lora_config: Optional[LoRAConfig],
|
95 |
+
vision_language_config: Optional[VisionLanguageConfig],
|
96 |
+
speculative_config: Optional[SpeculativeConfig],
|
97 |
+
decoding_config: Optional[DecodingConfig],
|
98 |
+
executor_class: Type[ExecutorBase],
|
99 |
+
log_stats: bool,
|
100 |
+
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
101 |
+
) -> None:
|
102 |
+
logger.info(
|
103 |
+
f"Initializing an LLM engine (v{vllm.__version__}) with config: "
|
104 |
+
f"model={model_config.model!r}, "
|
105 |
+
f"speculative_config={speculative_config!r}, "
|
106 |
+
f"tokenizer={model_config.tokenizer!r}, "
|
107 |
+
f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
|
108 |
+
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
109 |
+
f"revision={model_config.revision}, "
|
110 |
+
f"tokenizer_revision={model_config.tokenizer_revision}, "
|
111 |
+
f"trust_remote_code={model_config.trust_remote_code}, "
|
112 |
+
f"dtype={model_config.dtype}, "
|
113 |
+
f"max_seq_len={model_config.max_model_len}, "
|
114 |
+
f"download_dir={load_config.download_dir!r}, "
|
115 |
+
f"load_format={load_config.load_format}, "
|
116 |
+
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
117 |
+
f"disable_custom_all_reduce="
|
118 |
+
f"{parallel_config.disable_custom_all_reduce}, "
|
119 |
+
f"quantization={model_config.quantization}, "
|
120 |
+
f"enforce_eager={model_config.enforce_eager}, "
|
121 |
+
f"kv_cache_dtype={cache_config.cache_dtype}, "
|
122 |
+
f"quantization_param_path={model_config.quantization_param_path}, "
|
123 |
+
f"device_config={device_config.device}, "
|
124 |
+
f"decoding_config={decoding_config!r}, "
|
125 |
+
f"seed={model_config.seed})")
|
126 |
+
# TODO(woosuk): Print more configs in debug mode.
|
127 |
+
|
128 |
+
self.model_config = model_config
|
129 |
+
self.cache_config = cache_config
|
130 |
+
self.lora_config = lora_config
|
131 |
+
self.vision_language_config = vision_language_config
|
132 |
+
self.parallel_config = parallel_config
|
133 |
+
self.scheduler_config = scheduler_config
|
134 |
+
self.device_config = device_config
|
135 |
+
self.speculative_config = speculative_config
|
136 |
+
self.load_config = load_config
|
137 |
+
self.decoding_config = decoding_config or DecodingConfig()
|
138 |
+
self.log_stats = log_stats
|
139 |
+
|
140 |
+
if not self.model_config.skip_tokenizer_init:
|
141 |
+
self.tokenizer: BaseTokenizerGroup
|
142 |
+
self._init_tokenizer()
|
143 |
+
self.detokenizer = Detokenizer(self.tokenizer)
|
144 |
+
else:
|
145 |
+
self.detokenizer = None
|
146 |
+
self.tokenizer = None
|
147 |
+
|
148 |
+
self.seq_counter = Counter()
|
149 |
+
self.generation_config_fields = _load_generation_config_dict(
|
150 |
+
model_config)
|
151 |
+
|
152 |
+
self.model_executor = executor_class(
|
153 |
+
args=args,
|
154 |
+
model_config=model_config,
|
155 |
+
cache_config=cache_config,
|
156 |
+
parallel_config=parallel_config,
|
157 |
+
scheduler_config=scheduler_config,
|
158 |
+
device_config=device_config,
|
159 |
+
lora_config=lora_config,
|
160 |
+
vision_language_config=vision_language_config,
|
161 |
+
speculative_config=speculative_config,
|
162 |
+
load_config=load_config,
|
163 |
+
)
|
164 |
+
|
165 |
+
self._initialize_kv_caches()
|
166 |
+
|
167 |
+
# If usage stat is enabled, collect relevant info.
|
168 |
+
if is_usage_stats_enabled():
|
169 |
+
from vllm.model_executor.model_loader import (
|
170 |
+
get_architecture_class_name)
|
171 |
+
usage_message.report_usage(
|
172 |
+
get_architecture_class_name(model_config),
|
173 |
+
usage_context,
|
174 |
+
extra_kvs={
|
175 |
+
# Common configuration
|
176 |
+
"dtype":
|
177 |
+
str(model_config.dtype),
|
178 |
+
"tensor_parallel_size":
|
179 |
+
parallel_config.tensor_parallel_size,
|
180 |
+
"block_size":
|
181 |
+
cache_config.block_size,
|
182 |
+
"gpu_memory_utilization":
|
183 |
+
cache_config.gpu_memory_utilization,
|
184 |
+
|
185 |
+
# Quantization
|
186 |
+
"quantization":
|
187 |
+
model_config.quantization,
|
188 |
+
"kv_cache_dtype":
|
189 |
+
cache_config.cache_dtype,
|
190 |
+
|
191 |
+
# Feature flags
|
192 |
+
"enable_lora":
|
193 |
+
bool(lora_config),
|
194 |
+
"enable_prefix_caching":
|
195 |
+
cache_config.enable_prefix_caching,
|
196 |
+
"enforce_eager":
|
197 |
+
model_config.enforce_eager,
|
198 |
+
"disable_custom_all_reduce":
|
199 |
+
parallel_config.disable_custom_all_reduce,
|
200 |
+
})
|
201 |
+
|
202 |
+
if self.tokenizer:
|
203 |
+
# Ping the tokenizer to ensure liveness if it runs in a
|
204 |
+
# different process.
|
205 |
+
self.tokenizer.ping()
|
206 |
+
|
207 |
+
# Create the scheduler.
|
208 |
+
# NOTE: the cache_config here have been updated with the numbers of
|
209 |
+
# GPU and CPU blocks, which are profiled in the distributed executor.
|
210 |
+
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
211 |
+
|
212 |
+
# Metric Logging.
|
213 |
+
if self.log_stats:
|
214 |
+
self.stat_logger = StatLogger(
|
215 |
+
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
216 |
+
labels=dict(model_name=model_config.model))
|
217 |
+
self.stat_logger.info("cache_config", self.cache_config)
|
218 |
+
|
219 |
+
# Create sequence output processor, e.g. for beam search or
|
220 |
+
# speculative decoding.
|
221 |
+
self.output_processor = (
|
222 |
+
SequenceGroupOutputProcessor.create_output_processor(
|
223 |
+
self.scheduler_config,
|
224 |
+
self.detokenizer,
|
225 |
+
self.scheduler,
|
226 |
+
self.seq_counter,
|
227 |
+
self.get_tokenizer_for_seq,
|
228 |
+
stop_checker=StopChecker(
|
229 |
+
self.scheduler_config.max_model_len,
|
230 |
+
self.get_tokenizer_for_seq,
|
231 |
+
),
|
232 |
+
))
|
233 |
+
|
234 |
+
def _initialize_kv_caches(self) -> None:
|
235 |
+
"""Initialize the KV cache in the worker(s).
|
236 |
+
|
237 |
+
The workers will determine the number of blocks in both the GPU cache
|
238 |
+
and the swap CPU cache.
|
239 |
+
"""
|
240 |
+
num_gpu_blocks, num_cpu_blocks = (
|
241 |
+
self.model_executor.determine_num_available_blocks())
|
242 |
+
|
243 |
+
if self.cache_config.num_gpu_blocks_override is not None:
|
244 |
+
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
|
245 |
+
logger.info(f"Overriding {num_gpu_blocks=} with "
|
246 |
+
f"{num_gpu_blocks_override=}")
|
247 |
+
num_gpu_blocks = num_gpu_blocks_override
|
248 |
+
|
249 |
+
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
250 |
+
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
251 |
+
|
252 |
+
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
253 |
+
|
254 |
+
@classmethod
|
255 |
+
def from_engine_args(
|
256 |
+
cls,
|
257 |
+
engine_args: EngineArgs,
|
258 |
+
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
259 |
+
args: argparse.ArgumentParser = None,
|
260 |
+
) -> "LLMEngine":
|
261 |
+
"""Creates an LLM engine from the engine arguments."""
|
262 |
+
# Create the engine configs.
|
263 |
+
engine_config = engine_args.create_engine_config()
|
264 |
+
|
265 |
+
# Initialize the cluster and specify the executor class.
|
266 |
+
if engine_config.device_config.device_type == "neuron":
|
267 |
+
from vllm.executor.neuron_executor import NeuronExecutor
|
268 |
+
executor_class = NeuronExecutor
|
269 |
+
elif engine_config.device_config.device_type == "cpu":
|
270 |
+
from vllm.executor.cpu_executor import CPUExecutor
|
271 |
+
executor_class = CPUExecutor
|
272 |
+
elif engine_config.parallel_config.worker_use_ray:
|
273 |
+
initialize_ray_cluster(engine_config.parallel_config)
|
274 |
+
from vllm.executor.ray_gpu_executor import RayGPUExecutor
|
275 |
+
executor_class = RayGPUExecutor
|
276 |
+
else:
|
277 |
+
assert engine_config.parallel_config.world_size == 1, (
|
278 |
+
"Ray is required if parallel_config.world_size > 1.")
|
279 |
+
# from vllm.executor.gpu_executor import GPUExecutor
|
280 |
+
from serve.gpu_executor import GPUExecutor
|
281 |
+
executor_class = GPUExecutor
|
282 |
+
|
283 |
+
# Create the LLM engine.
|
284 |
+
engine = cls(
|
285 |
+
**engine_config.to_dict(),
|
286 |
+
executor_class=executor_class,
|
287 |
+
log_stats=not engine_args.disable_log_stats,
|
288 |
+
usage_context=usage_context,
|
289 |
+
args=args,
|
290 |
+
)
|
291 |
+
return engine
|
292 |
+
|
293 |
+
def __reduce__(self):
|
294 |
+
# This is to ensure that the LLMEngine is not referenced in
|
295 |
+
# the closure used to initialize Ray worker actors
|
296 |
+
raise RuntimeError("LLMEngine should not be pickled!")
|
297 |
+
|
298 |
+
def get_tokenizer(self) -> "PreTrainedTokenizer":
|
299 |
+
return self.tokenizer.get_lora_tokenizer(None)
|
300 |
+
|
301 |
+
def get_tokenizer_for_seq(self,
|
302 |
+
sequence: Sequence) -> "PreTrainedTokenizer":
|
303 |
+
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
304 |
+
|
305 |
+
def _init_tokenizer(self, **tokenizer_init_kwargs):
|
306 |
+
init_kwargs = dict(
|
307 |
+
tokenizer_id=self.model_config.tokenizer,
|
308 |
+
enable_lora=bool(self.lora_config),
|
309 |
+
max_num_seqs=self.scheduler_config.max_num_seqs,
|
310 |
+
max_input_length=None,
|
311 |
+
tokenizer_mode=self.model_config.tokenizer_mode,
|
312 |
+
trust_remote_code=self.model_config.trust_remote_code,
|
313 |
+
revision=self.model_config.tokenizer_revision)
|
314 |
+
init_kwargs.update(tokenizer_init_kwargs)
|
315 |
+
self.tokenizer = get_tokenizer_group(
|
316 |
+
self.parallel_config.tokenizer_pool_config, **init_kwargs)
|
317 |
+
|
318 |
+
def _verify_args(self) -> None:
|
319 |
+
self.model_config.verify_with_parallel_config(self.parallel_config)
|
320 |
+
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
321 |
+
if self.lora_config:
|
322 |
+
self.lora_config.verify_with_model_config(self.model_config)
|
323 |
+
self.lora_config.verify_with_scheduler_config(
|
324 |
+
self.scheduler_config)
|
325 |
+
|
326 |
+
def encode_request(
|
327 |
+
self,
|
328 |
+
request_id: str, # pylint: disable=unused-argument
|
329 |
+
prompt: Optional[str],
|
330 |
+
prompt_token_ids: Optional[List[int]] = None,
|
331 |
+
lora_request: Optional[LoRARequest] = None,
|
332 |
+
):
|
333 |
+
if prompt_token_ids is None:
|
334 |
+
assert prompt is not None
|
335 |
+
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
|
336 |
+
prompt=prompt,
|
337 |
+
lora_request=lora_request)
|
338 |
+
return prompt_token_ids
|
339 |
+
|
340 |
+
def add_request(
|
341 |
+
self,
|
342 |
+
request_id: str,
|
343 |
+
prompt: Optional[str],
|
344 |
+
sampling_params: SamplingParams,
|
345 |
+
prompt_token_ids: Optional[List[int]] = None,
|
346 |
+
arrival_time: Optional[float] = None,
|
347 |
+
lora_request: Optional[LoRARequest] = None,
|
348 |
+
multi_modal_data: Optional[MultiModalData] = None,
|
349 |
+
) -> None:
|
350 |
+
"""Add a request to the engine's request pool.
|
351 |
+
|
352 |
+
The request is added to the request pool and will be processed by the
|
353 |
+
scheduler as `engine.step()` is called. The exact scheduling policy is
|
354 |
+
determined by the scheduler.
|
355 |
+
|
356 |
+
Args:
|
357 |
+
request_id: The unique ID of the request.
|
358 |
+
prompt: The prompt string. Can be None if prompt_token_ids is
|
359 |
+
provided.
|
360 |
+
sampling_params: The sampling parameters for text generation.
|
361 |
+
prompt_token_ids: The token IDs of the prompt. If None, we
|
362 |
+
use the tokenizer to convert the prompts to token IDs.
|
363 |
+
arrival_time: The arrival time of the request. If None, we use
|
364 |
+
the current monotonic time.
|
365 |
+
multi_modal_data: Multi modal data per request.
|
366 |
+
|
367 |
+
Details:
|
368 |
+
- Set arrival_time to the current time if it is None.
|
369 |
+
- Set prompt_token_ids to the encoded prompt if it is None.
|
370 |
+
- Create `best_of` number of :class:`~vllm.Sequence` objects.
|
371 |
+
- Create a :class:`~vllm.SequenceGroup` object
|
372 |
+
from the list of :class:`~vllm.Sequence`.
|
373 |
+
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
|
374 |
+
|
375 |
+
Example:
|
376 |
+
>>> # initialize engine
|
377 |
+
>>> engine = LLMEngine.from_engine_args(engine_args)
|
378 |
+
>>> # set request arguments
|
379 |
+
>>> example_prompt = "Who is the president of the United States?"
|
380 |
+
>>> sampling_params = SamplingParams(temperature=0.0)
|
381 |
+
>>> request_id = 0
|
382 |
+
>>>
|
383 |
+
>>> # add the request to the engine
|
384 |
+
>>> engine.add_request(
|
385 |
+
>>> str(request_id),
|
386 |
+
>>> example_prompt,
|
387 |
+
>>> SamplingParams(temperature=0.0))
|
388 |
+
>>> # continue the request processing
|
389 |
+
>>> ...
|
390 |
+
"""
|
391 |
+
if lora_request is not None and not self.lora_config:
|
392 |
+
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
393 |
+
"not enabled!")
|
394 |
+
max_logprobs = self.get_model_config().max_logprobs
|
395 |
+
if (sampling_params.logprobs
|
396 |
+
and sampling_params.logprobs > max_logprobs) or (
|
397 |
+
sampling_params.prompt_logprobs
|
398 |
+
and sampling_params.prompt_logprobs > max_logprobs):
|
399 |
+
raise ValueError(f"Cannot request more than "
|
400 |
+
f"{max_logprobs} logprobs.")
|
401 |
+
if arrival_time is None:
|
402 |
+
arrival_time = time.time()
|
403 |
+
prompt_token_ids = self.encode_request(
|
404 |
+
request_id=request_id,
|
405 |
+
prompt=prompt,
|
406 |
+
prompt_token_ids=prompt_token_ids,
|
407 |
+
lora_request=lora_request)
|
408 |
+
|
409 |
+
# Create the sequences.
|
410 |
+
block_size = self.cache_config.block_size
|
411 |
+
seq_id = next(self.seq_counter)
|
412 |
+
eos_token_id = None
|
413 |
+
if self.tokenizer:
|
414 |
+
eos_token_id = self.tokenizer.get_lora_tokenizer(
|
415 |
+
lora_request).eos_token_id
|
416 |
+
else:
|
417 |
+
logger.warning("Use None for EOS token id because tokenizer is "
|
418 |
+
"not initialized")
|
419 |
+
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
420 |
+
eos_token_id, lora_request)
|
421 |
+
|
422 |
+
# Defensive copy of SamplingParams, which are used by the sampler,
|
423 |
+
# this doesn't deep-copy LogitsProcessor objects
|
424 |
+
sampling_params = sampling_params.clone()
|
425 |
+
# Add the eos token id into the sampling_params to support min_tokens
|
426 |
+
# processing
|
427 |
+
if seq.eos_token_id is not None:
|
428 |
+
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
|
429 |
+
sampling_params.update_from_generation_config(
|
430 |
+
self.generation_config_fields)
|
431 |
+
|
432 |
+
# Create the sequence group.
|
433 |
+
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
434 |
+
arrival_time, lora_request, multi_modal_data)
|
435 |
+
|
436 |
+
# Add the sequence group to the scheduler.
|
437 |
+
self.scheduler.add_seq_group(seq_group)
|
438 |
+
|
439 |
+
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
440 |
+
"""Aborts a request(s) with the given ID.
|
441 |
+
|
442 |
+
Args:
|
443 |
+
request_id: The ID(s) of the request to abort.
|
444 |
+
|
445 |
+
Details:
|
446 |
+
- Refer to the
|
447 |
+
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
|
448 |
+
from class :class:`~vllm.core.scheduler.Scheduler`.
|
449 |
+
|
450 |
+
Example:
|
451 |
+
>>> # initialize engine and add a request with request_id
|
452 |
+
>>> request_id = str(0)
|
453 |
+
>>> # abort the request
|
454 |
+
>>> engine.abort_request(request_id)
|
455 |
+
"""
|
456 |
+
self.scheduler.abort_seq_group(request_id)
|
457 |
+
|
458 |
+
def get_model_config(self) -> ModelConfig:
|
459 |
+
"""Gets the model configuration."""
|
460 |
+
return self.model_config
|
461 |
+
|
462 |
+
def get_num_unfinished_requests(self) -> int:
|
463 |
+
"""Gets the number of unfinished requests."""
|
464 |
+
return self.scheduler.get_num_unfinished_seq_groups()
|
465 |
+
|
466 |
+
def has_unfinished_requests(self) -> bool:
|
467 |
+
"""Returns True if there are unfinished requests."""
|
468 |
+
return self.scheduler.has_unfinished_seqs()
|
469 |
+
|
470 |
+
def _process_model_outputs(
|
471 |
+
self, output: List[SamplerOutput],
|
472 |
+
scheduled_seq_groups: List[SequenceGroup],
|
473 |
+
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
|
474 |
+
"""Apply the model output to the sequences in the scheduled seq groups.
|
475 |
+
|
476 |
+
Returns RequestOutputs that can be returned to the client.
|
477 |
+
"""
|
478 |
+
now = time.time()
|
479 |
+
|
480 |
+
# Organize outputs by [sequence group][step] instead of
|
481 |
+
# [step][sequence group].
|
482 |
+
output_by_sequence_group = create_output_by_sequence_group(
|
483 |
+
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
|
484 |
+
|
485 |
+
# Update the scheduled sequence groups with the model outputs.
|
486 |
+
for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
|
487 |
+
output_by_sequence_group):
|
488 |
+
seq_group = scheduled_seq_group.seq_group
|
489 |
+
seq_group.update_num_computed_tokens(
|
490 |
+
scheduled_seq_group.token_chunk_size)
|
491 |
+
# If uncomputed tokens > 0, it means prefill is chunked.
|
492 |
+
# We don't need to process outputs in that case.
|
493 |
+
if seq_group.get_num_uncomputed_tokens() == 0:
|
494 |
+
self.output_processor.process_outputs(seq_group, outputs)
|
495 |
+
|
496 |
+
# Free the finished sequence groups.
|
497 |
+
self.scheduler.free_finished_seq_groups()
|
498 |
+
|
499 |
+
# Create the outputs.
|
500 |
+
request_outputs: List[RequestOutput] = []
|
501 |
+
for scheduled_seq_group in scheduled_seq_groups:
|
502 |
+
seq_group = scheduled_seq_group.seq_group
|
503 |
+
seq_group.maybe_set_first_token_time(now)
|
504 |
+
request_output = RequestOutput.from_seq_group(seq_group)
|
505 |
+
request_outputs.append(request_output)
|
506 |
+
for seq_group in ignored_seq_groups:
|
507 |
+
request_output = RequestOutput.from_seq_group(seq_group)
|
508 |
+
request_outputs.append(request_output)
|
509 |
+
return request_outputs
|
510 |
+
|
511 |
+
def step(self) -> List[RequestOutput]:
|
512 |
+
"""Performs one decoding iteration and returns newly generated results.
|
513 |
+
|
514 |
+
.. figure:: https://i.imgur.com/sv2HssD.png
|
515 |
+
:alt: Overview of the step function
|
516 |
+
:align: center
|
517 |
+
|
518 |
+
Overview of the step function.
|
519 |
+
|
520 |
+
Details:
|
521 |
+
- Step 1: Schedules the sequences to be executed in the next
|
522 |
+
iteration and the token blocks to be swapped in/out/copy.
|
523 |
+
|
524 |
+
- Depending on the scheduling policy,
|
525 |
+
sequences may be `preempted/reordered`.
|
526 |
+
- A Sequence Group (SG) refer to a group of sequences
|
527 |
+
that are generated from the same prompt.
|
528 |
+
|
529 |
+
- Step 2: Calls the distributed executor to execute the model.
|
530 |
+
- Step 3: Processes the model output. This mainly includes:
|
531 |
+
|
532 |
+
- Decodes the relevant outputs.
|
533 |
+
- Updates the scheduled sequence groups with model outputs
|
534 |
+
based on its `sampling parameters` (`use_beam_search` or not).
|
535 |
+
- Frees the finished sequence groups.
|
536 |
+
|
537 |
+
- Finally, it creates and returns the newly generated results.
|
538 |
+
|
539 |
+
Example:
|
540 |
+
>>> # Please see the example/ folder for more detailed examples.
|
541 |
+
>>>
|
542 |
+
>>> # initialize engine and request arguments
|
543 |
+
>>> engine = LLMEngine.from_engine_args(engine_args)
|
544 |
+
>>> example_inputs = [(0, "What is LLM?",
|
545 |
+
>>> SamplingParams(temperature=0.0))]
|
546 |
+
>>>
|
547 |
+
>>> # Start the engine with an event loop
|
548 |
+
>>> while True:
|
549 |
+
>>> if example_inputs:
|
550 |
+
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
|
551 |
+
>>> engine.add_request(str(req_id), prompt, sampling_params)
|
552 |
+
>>>
|
553 |
+
>>> # continue the request processing
|
554 |
+
>>> request_outputs = engine.step()
|
555 |
+
>>> for request_output in request_outputs:
|
556 |
+
>>> if request_output.finished:
|
557 |
+
>>> # return or show the request output
|
558 |
+
>>>
|
559 |
+
>>> if not (engine.has_unfinished_requests() or example_inputs):
|
560 |
+
>>> break
|
561 |
+
"""
|
562 |
+
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
563 |
+
if not scheduler_outputs.is_empty():
|
564 |
+
output = self.model_executor.execute_model(
|
565 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
566 |
+
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
567 |
+
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
568 |
+
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
569 |
+
num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
|
570 |
+
else:
|
571 |
+
output = []
|
572 |
+
|
573 |
+
request_outputs = self._process_model_outputs(
|
574 |
+
output, scheduler_outputs.scheduled_seq_groups,
|
575 |
+
scheduler_outputs.ignored_seq_groups)
|
576 |
+
|
577 |
+
# Log stats.
|
578 |
+
if self.log_stats:
|
579 |
+
self.stat_logger.log(self._get_stats(scheduler_outputs))
|
580 |
+
|
581 |
+
return request_outputs
|
582 |
+
|
583 |
+
def do_log_stats(self) -> None:
|
584 |
+
"""Forced log when no requests active."""
|
585 |
+
if self.log_stats:
|
586 |
+
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
|
587 |
+
|
588 |
+
def _get_stats(self,
|
589 |
+
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
|
590 |
+
"""Get Stats to be Logged to Prometheus."""
|
591 |
+
now = time.time()
|
592 |
+
|
593 |
+
# KV Cache Usage in %.
|
594 |
+
num_total_gpu = self.cache_config.num_gpu_blocks
|
595 |
+
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
|
596 |
+
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
|
597 |
+
|
598 |
+
num_total_cpu = self.cache_config.num_cpu_blocks
|
599 |
+
cpu_cache_usage = 0.
|
600 |
+
if num_total_cpu > 0:
|
601 |
+
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
|
602 |
+
)
|
603 |
+
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
|
604 |
+
|
605 |
+
# Scheduler State
|
606 |
+
num_running = len(self.scheduler.running)
|
607 |
+
num_swapped = len(self.scheduler.swapped)
|
608 |
+
num_waiting = len(self.scheduler.waiting)
|
609 |
+
|
610 |
+
# Iteration stats if we have scheduler output.
|
611 |
+
num_prompt_tokens = 0
|
612 |
+
num_generation_tokens = 0
|
613 |
+
time_to_first_tokens = []
|
614 |
+
time_per_output_tokens = []
|
615 |
+
time_e2e_requests = []
|
616 |
+
if scheduler_outputs is not None:
|
617 |
+
prompt_run = scheduler_outputs.num_prefill_groups > 0
|
618 |
+
|
619 |
+
# Number of Tokens.
|
620 |
+
if prompt_run:
|
621 |
+
num_prompt_tokens = sum(
|
622 |
+
len(scheduled_seq_group.seq_group.prompt_token_ids)
|
623 |
+
for scheduled_seq_group in
|
624 |
+
scheduler_outputs.scheduled_seq_groups)
|
625 |
+
num_generation_tokens = sum(
|
626 |
+
scheduled_seq_group.seq_group.num_seqs()
|
627 |
+
for scheduled_seq_group in
|
628 |
+
scheduler_outputs.scheduled_seq_groups)
|
629 |
+
else:
|
630 |
+
num_generation_tokens = scheduler_outputs.num_batched_tokens
|
631 |
+
|
632 |
+
# Latency Timings.
|
633 |
+
time_last_iters = []
|
634 |
+
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
|
635 |
+
seq_group = scheduled_seq_group.seq_group
|
636 |
+
# Time since last token.
|
637 |
+
# (n.b. updates seq_group.metrics.last_token_time)
|
638 |
+
time_last_iters.append(seq_group.get_last_latency(now))
|
639 |
+
# Time since arrival for all finished requests.
|
640 |
+
if seq_group.is_finished():
|
641 |
+
time_e2e_requests.append(now -
|
642 |
+
seq_group.metrics.arrival_time)
|
643 |
+
|
644 |
+
time_to_first_tokens = time_last_iters if prompt_run else []
|
645 |
+
time_per_output_tokens = [] if prompt_run else time_last_iters
|
646 |
+
|
647 |
+
return Stats(
|
648 |
+
now=now,
|
649 |
+
num_running=num_running,
|
650 |
+
num_swapped=num_swapped,
|
651 |
+
num_waiting=num_waiting,
|
652 |
+
gpu_cache_usage=gpu_cache_usage,
|
653 |
+
cpu_cache_usage=cpu_cache_usage,
|
654 |
+
num_prompt_tokens=num_prompt_tokens,
|
655 |
+
num_generation_tokens=num_generation_tokens,
|
656 |
+
time_to_first_tokens=time_to_first_tokens,
|
657 |
+
time_per_output_tokens=time_per_output_tokens,
|
658 |
+
time_e2e_requests=time_e2e_requests,
|
659 |
+
)
|
660 |
+
|
661 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
662 |
+
return self.model_executor.add_lora(lora_request)
|
663 |
+
|
664 |
+
def remove_lora(self, lora_id: int) -> bool:
|
665 |
+
return self.model_executor.remove_lora(lora_id)
|
666 |
+
|
667 |
+
def list_loras(self) -> List[int]:
|
668 |
+
return self.model_executor.list_loras()
|
669 |
+
|
670 |
+
def check_health(self) -> None:
|
671 |
+
self.model_executor.check_health()
|
serve/model_runner.py
ADDED
@@ -0,0 +1,1223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import time
|
3 |
+
from enum import IntEnum
|
4 |
+
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
|
11 |
+
get_attn_backend)
|
12 |
+
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
13 |
+
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
14 |
+
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
|
15 |
+
from vllm.distributed.device_communicators import (custom_all_reduce,
|
16 |
+
pynccl_utils)
|
17 |
+
from vllm.logger import init_logger
|
18 |
+
from vllm.lora.layers import LoRAMapping
|
19 |
+
from vllm.lora.request import LoRARequest
|
20 |
+
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
21 |
+
from vllm.model_executor import SamplingMetadata
|
22 |
+
from vllm.model_executor.model_loader import get_model
|
23 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
24 |
+
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
|
25 |
+
SequenceGroupMetadata)
|
26 |
+
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
|
27 |
+
is_pin_memory_available, make_tensor_with_pad,
|
28 |
+
maybe_expand_dim)
|
29 |
+
from serve.gpt_model import GPT_models
|
30 |
+
|
31 |
+
logger = init_logger(__name__)
|
32 |
+
|
33 |
+
_PAD_SLOT_ID = -1
|
34 |
+
LORA_WARMUP_RANK = 8
|
35 |
+
_BATCH_SIZE_ALIGNMENT = 8
|
36 |
+
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
|
37 |
+
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
|
38 |
+
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
39 |
+
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
|
40 |
+
]
|
41 |
+
|
42 |
+
|
43 |
+
class PreparePromptMetadata(NamedTuple):
|
44 |
+
input_tokens: List[int]
|
45 |
+
input_positions: List[int]
|
46 |
+
attn_metadata: Optional[AttentionMetadataPerStage]
|
47 |
+
prompt_lens: List[int]
|
48 |
+
subquery_lens: List[int]
|
49 |
+
lora_index_mapping: List[int]
|
50 |
+
lora_prompt_mapping: List[int]
|
51 |
+
lora_requests: Set[LoRARequest]
|
52 |
+
multi_modal_input: Optional[torch.Tensor]
|
53 |
+
slot_mapping: List[int]
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
def empty(cls):
|
57 |
+
return PreparePromptMetadata(
|
58 |
+
input_tokens=[],
|
59 |
+
input_positions=[],
|
60 |
+
attn_metadata=None,
|
61 |
+
prompt_lens=[],
|
62 |
+
subquery_lens=[],
|
63 |
+
lora_index_mapping=[],
|
64 |
+
lora_prompt_mapping=[],
|
65 |
+
lora_requests=set(),
|
66 |
+
multi_modal_input=None,
|
67 |
+
slot_mapping=[],
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
class PrepareDecodeMetadata(NamedTuple):
|
72 |
+
input_tokens: List[int]
|
73 |
+
input_positions: List[int]
|
74 |
+
attn_metadata: Optional[AttentionMetadata]
|
75 |
+
lora_index_mapping: List[int]
|
76 |
+
lora_prompt_mapping: List[int]
|
77 |
+
lora_requests: Set[LoRARequest]
|
78 |
+
slot_mapping: List[int]
|
79 |
+
|
80 |
+
@classmethod
|
81 |
+
def empty(cls):
|
82 |
+
return PrepareDecodeMetadata(
|
83 |
+
input_tokens=[],
|
84 |
+
input_positions=[],
|
85 |
+
attn_metadata=None,
|
86 |
+
lora_index_mapping=[],
|
87 |
+
lora_prompt_mapping=[],
|
88 |
+
lora_requests=set(),
|
89 |
+
slot_mapping=[],
|
90 |
+
)
|
91 |
+
|
92 |
+
|
93 |
+
# How batches are constructed.
|
94 |
+
class BatchType(IntEnum):
|
95 |
+
# Every batch is prefill.
|
96 |
+
PREFILL = 0
|
97 |
+
# Every batch is decode.
|
98 |
+
DECODE = 1
|
99 |
+
# Batch is a mixture of prefill and decode.
|
100 |
+
MIXED = 2
|
101 |
+
|
102 |
+
|
103 |
+
class ModelRunner:
|
104 |
+
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
model_config: ModelConfig,
|
108 |
+
parallel_config: ParallelConfig,
|
109 |
+
scheduler_config: SchedulerConfig,
|
110 |
+
device_config: DeviceConfig,
|
111 |
+
load_config: LoadConfig,
|
112 |
+
lora_config: Optional[LoRAConfig],
|
113 |
+
kv_cache_dtype: Optional[str] = "auto",
|
114 |
+
is_driver_worker: bool = False,
|
115 |
+
vision_language_config: Optional[VisionLanguageConfig] = None,
|
116 |
+
):
|
117 |
+
self.model_config = model_config
|
118 |
+
self.parallel_config = parallel_config
|
119 |
+
self.scheduler_config = scheduler_config
|
120 |
+
self.lora_config = lora_config
|
121 |
+
self.load_config = load_config
|
122 |
+
self.is_driver_worker = is_driver_worker
|
123 |
+
|
124 |
+
# model_config can be None in tests/samplers/test_sampler.py.
|
125 |
+
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
126 |
+
self.sliding_window = (model_config.get_sliding_window()
|
127 |
+
if model_config is not None else None)
|
128 |
+
self.device_config = (device_config
|
129 |
+
if device_config is not None else DeviceConfig())
|
130 |
+
self.device = self.device_config.device
|
131 |
+
|
132 |
+
# Set after load_model.
|
133 |
+
self.lora_manager: LRUCacheWorkerLoRAManager = None
|
134 |
+
|
135 |
+
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
136 |
+
self.graph_memory_pool: Optional[Tuple[
|
137 |
+
int, int]] = None # Set during graph capture.
|
138 |
+
|
139 |
+
self.max_context_len_to_capture = (
|
140 |
+
self.model_config.max_context_len_to_capture
|
141 |
+
if self.model_config is not None else 0)
|
142 |
+
|
143 |
+
self.pin_memory = is_pin_memory_available()
|
144 |
+
self.kv_cache_dtype = kv_cache_dtype
|
145 |
+
self.vision_language_config = vision_language_config
|
146 |
+
|
147 |
+
self.attn_backend = get_attn_backend(
|
148 |
+
self.model_config.dtype if model_config is not None else None)
|
149 |
+
|
150 |
+
# Lazy initialization
|
151 |
+
self.model: torch.nn.Module # Set after load_model
|
152 |
+
self.block_size: int # Set after initial profiling.
|
153 |
+
# When using CUDA graph, the input block tables must be padded to
|
154 |
+
# max_context_len_to_capture. However, creating the block table in
|
155 |
+
# Python can be expensive. To optimize this, we cache the block table
|
156 |
+
# in numpy and only copy the actual input content at every iteration.
|
157 |
+
# The shape of the cached block table will be
|
158 |
+
# (max batch size to capture, max context len to capture / block size).
|
159 |
+
self.graph_block_tables: torch.Tensor # Set after initial profiling.
|
160 |
+
|
161 |
+
def load_model(self, args) -> None:
|
162 |
+
with CudaMemoryProfiler() as m:
|
163 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
164 |
+
latent_size = args.image_size // args.downsample_size
|
165 |
+
gpt_model = GPT_models[args.gpt_model](
|
166 |
+
vocab_size=args.codebook_size,
|
167 |
+
block_size=latent_size ** 2,
|
168 |
+
num_classes=args.num_classes,
|
169 |
+
cls_token_num=args.cls_token_num,
|
170 |
+
model_type=args.gpt_type,
|
171 |
+
cfg_scale=args.cfg_scale,
|
172 |
+
).to(device='cuda', dtype=precision) # TODO: make device configurable
|
173 |
+
|
174 |
+
checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
|
175 |
+
if args.from_fsdp: # fspd
|
176 |
+
model_weight = checkpoint
|
177 |
+
elif "model" in checkpoint: # ddp
|
178 |
+
model_weight = checkpoint["model"]
|
179 |
+
elif "state_dict" in checkpoint:
|
180 |
+
model_weight = checkpoint["state_dict"]
|
181 |
+
else:
|
182 |
+
raise Exception("please check model weight")
|
183 |
+
gpt_model.custom_load_state_dict(model_weight)
|
184 |
+
gpt_model.eval()
|
185 |
+
del checkpoint
|
186 |
+
self.model = gpt_model
|
187 |
+
|
188 |
+
self.model_memory_usage = m.consumed_memory
|
189 |
+
logger.info(f"Loading model weights took "
|
190 |
+
f"{self.model_memory_usage / float(2**30):.4f} GB")
|
191 |
+
|
192 |
+
if self.lora_config:
|
193 |
+
assert hasattr(self.model, "supported_lora_modules"
|
194 |
+
) and self.model.supported_lora_modules, (
|
195 |
+
"Model does not support LoRA")
|
196 |
+
assert hasattr(
|
197 |
+
self.model,
|
198 |
+
"embedding_modules"), "Model does not have embedding_modules"
|
199 |
+
assert hasattr(self.model, "embedding_padding_modules"
|
200 |
+
), "Model does not have embedding_padding_modules"
|
201 |
+
self.lora_manager = LRUCacheWorkerLoRAManager(
|
202 |
+
self.scheduler_config.max_num_seqs,
|
203 |
+
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
|
204 |
+
self.lora_config, self.device, self.model.embedding_modules,
|
205 |
+
self.model.embedding_padding_modules)
|
206 |
+
self.model = self.lora_manager.create_lora_manager(self.model)
|
207 |
+
|
208 |
+
if self.kv_cache_dtype == "fp8" and is_hip():
|
209 |
+
# Currently scaled KV cache is only enabled on ROCm
|
210 |
+
if self.model_config.quantization_param_path is not None:
|
211 |
+
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
212 |
+
self.model.load_kv_cache_scales(
|
213 |
+
self.model_config.quantization_param_path)
|
214 |
+
else:
|
215 |
+
raise RuntimeError("Using FP8 KV cache and scaling "
|
216 |
+
"factors provided but model "
|
217 |
+
f"{self.model.__class__} does not "
|
218 |
+
"support loading scaling factors.")
|
219 |
+
else:
|
220 |
+
logger.warn("Using FP8 KV cache but no scaling factors "
|
221 |
+
"provided. Defaulting to scaling factors of 1.0. "
|
222 |
+
"This may lead to less accurate results!")
|
223 |
+
elif self.model_config.quantization_param_path is not None:
|
224 |
+
logger.warn("KV cache scaling factors provided, "
|
225 |
+
"but the KV cache data type is not FP8. "
|
226 |
+
"KV cache scaling factors will not be used.")
|
227 |
+
|
228 |
+
def set_block_size(self, block_size: int) -> None:
|
229 |
+
self.block_size = block_size
|
230 |
+
|
231 |
+
self.graph_block_tables = np.zeros(
|
232 |
+
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
|
233 |
+
dtype=np.int32)
|
234 |
+
|
235 |
+
def get_max_block_per_batch(self) -> int:
|
236 |
+
block_size = self.block_size
|
237 |
+
return (self.max_context_len_to_capture + block_size - 1) // block_size
|
238 |
+
|
239 |
+
def _prepare_prompt(
|
240 |
+
self,
|
241 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
242 |
+
) -> PreparePromptMetadata:
|
243 |
+
input_tokens: List[int] = []
|
244 |
+
input_positions: List[int] = []
|
245 |
+
slot_mapping: List[int] = []
|
246 |
+
lora_index_mapping: List[int] = []
|
247 |
+
lora_prompt_mapping: List[int] = []
|
248 |
+
lora_requests: Set[LoRARequest] = set()
|
249 |
+
|
250 |
+
prompt_lens: List[int] = []
|
251 |
+
context_lens: List[int] = []
|
252 |
+
subquery_lens: List[int] = []
|
253 |
+
prefix_block_tables: List[List[int]] = []
|
254 |
+
multi_modal_input_list: List[torch.Tensor] = []
|
255 |
+
|
256 |
+
if len(seq_group_metadata_list) == 0:
|
257 |
+
return PreparePromptMetadata.empty()
|
258 |
+
|
259 |
+
for seq_group_metadata in seq_group_metadata_list:
|
260 |
+
assert seq_group_metadata.is_prompt
|
261 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
262 |
+
assert len(seq_ids) == 1
|
263 |
+
seq_id = seq_ids[0]
|
264 |
+
|
265 |
+
computed_block_nums = seq_group_metadata.computed_block_nums
|
266 |
+
if (self.scheduler_config is not None
|
267 |
+
and self.scheduler_config.chunked_prefill_enabled
|
268 |
+
and not (computed_block_nums is None
|
269 |
+
or computed_block_nums == [])):
|
270 |
+
raise RuntimeError(
|
271 |
+
"chunked prefill cannot be used with prefix caching "
|
272 |
+
"now.")
|
273 |
+
|
274 |
+
token_chunk_size = seq_group_metadata.token_chunk_size
|
275 |
+
seq_data = seq_group_metadata.seq_data[seq_id]
|
276 |
+
computed_len = seq_data.get_num_computed_tokens()
|
277 |
+
# We should use get_len here because in case of preemption
|
278 |
+
# it contains output tokens.
|
279 |
+
prefill_end = min(seq_data.get_len(),
|
280 |
+
computed_len + token_chunk_size)
|
281 |
+
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
|
282 |
+
prompt_len = prefill_end
|
283 |
+
prompt_lens.append(prompt_len)
|
284 |
+
|
285 |
+
# NOTE: This only works for oooooooxxx style attention.
|
286 |
+
if computed_block_nums is not None and len(
|
287 |
+
computed_block_nums) > 0 and self.sliding_window is None:
|
288 |
+
# Prefix is not supported with sliding_window
|
289 |
+
computed_len = len(computed_block_nums) * self.block_size
|
290 |
+
prompt_tokens = prompt_tokens[computed_len:]
|
291 |
+
prefix_block_tables.append(computed_block_nums)
|
292 |
+
elif self.scheduler_config.chunked_prefill_enabled:
|
293 |
+
if seq_group_metadata.block_tables is not None:
|
294 |
+
# Prefill has chunked before.
|
295 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
296 |
+
prefix_block_tables.append(block_table)
|
297 |
+
else:
|
298 |
+
# The first prefill.
|
299 |
+
prefix_block_tables.append([])
|
300 |
+
else:
|
301 |
+
prefix_block_tables.append([])
|
302 |
+
# Right now, prefill start is always 0. However, this
|
303 |
+
# assumption can be changed once chunked prefill is introduced.
|
304 |
+
assert computed_len == 0
|
305 |
+
|
306 |
+
# actual prompt lens
|
307 |
+
context_lens.append(computed_len)
|
308 |
+
subquery_lens.append(prompt_len - computed_len)
|
309 |
+
|
310 |
+
input_tokens.extend(prompt_tokens)
|
311 |
+
# NOTE(woosuk): Here we assume that the first token in the prompt
|
312 |
+
# is always the first token in the sequence.
|
313 |
+
input_positions.extend(list(range(computed_len, prefill_end)))
|
314 |
+
lora_id = seq_group_metadata.lora_int_id
|
315 |
+
|
316 |
+
if lora_id > 0:
|
317 |
+
lora_requests.add(seq_group_metadata.lora_request)
|
318 |
+
|
319 |
+
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
|
320 |
+
lora_prompt_mapping.extend(
|
321 |
+
[lora_id] *
|
322 |
+
(prompt_len - computed_len
|
323 |
+
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
|
324 |
+
|
325 |
+
if seq_group_metadata.multi_modal_data:
|
326 |
+
multi_modal_input_list.append(
|
327 |
+
seq_group_metadata.multi_modal_data.data)
|
328 |
+
|
329 |
+
if seq_group_metadata.block_tables is None:
|
330 |
+
# During memory profiling, the block tables are not initialized
|
331 |
+
# yet. In this case, we just use a dummy slot mapping.
|
332 |
+
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
|
333 |
+
continue
|
334 |
+
|
335 |
+
# Compute the slot mapping.
|
336 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
337 |
+
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
338 |
+
# where start_idx is max(0, prompt_len - sliding_window).
|
339 |
+
# For example, if the prompt len is 10, sliding window is 8, and
|
340 |
+
# block size is 4, the first two tokens are masked and the slot
|
341 |
+
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
342 |
+
start_idx = 0
|
343 |
+
if self.sliding_window is not None:
|
344 |
+
assert computed_len == 0, (
|
345 |
+
"Prefix caching is currently not supported with "
|
346 |
+
"sliding window attention")
|
347 |
+
start_idx = max(0, prompt_len - self.sliding_window)
|
348 |
+
|
349 |
+
for i in range(computed_len, prefill_end):
|
350 |
+
if i < start_idx:
|
351 |
+
slot_mapping.append(_PAD_SLOT_ID)
|
352 |
+
continue
|
353 |
+
|
354 |
+
block_number = block_table[i // self.block_size]
|
355 |
+
block_offset = i % self.block_size
|
356 |
+
slot = block_number * self.block_size + block_offset
|
357 |
+
slot_mapping.append(slot)
|
358 |
+
|
359 |
+
max_subquery_len = max(subquery_lens)
|
360 |
+
max_prompt_len = max(prompt_lens)
|
361 |
+
assert max_subquery_len > 0
|
362 |
+
|
363 |
+
context_lens_tensor = torch.tensor(context_lens,
|
364 |
+
dtype=torch.int,
|
365 |
+
device=self.device)
|
366 |
+
|
367 |
+
if multi_modal_input_list:
|
368 |
+
assert self.vision_language_config, (
|
369 |
+
"Multi-modal inputs are only supported by "
|
370 |
+
"vision language models.")
|
371 |
+
multi_modal_input = torch.cat(multi_modal_input_list,
|
372 |
+
dim=0).to(self.device)
|
373 |
+
else:
|
374 |
+
multi_modal_input = None
|
375 |
+
|
376 |
+
# Prepare prefix block tables
|
377 |
+
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
|
378 |
+
block_tables = make_tensor_with_pad(
|
379 |
+
prefix_block_tables,
|
380 |
+
max_len=max_prompt_block_table_len,
|
381 |
+
pad=0,
|
382 |
+
dtype=torch.int,
|
383 |
+
device=self.device,
|
384 |
+
)
|
385 |
+
|
386 |
+
# Query length can be shorter than key (i.e., prompt) when prefill
|
387 |
+
# is chunked or prefix cached.
|
388 |
+
subquery_lens_tensor = torch.tensor(subquery_lens,
|
389 |
+
dtype=torch.long,
|
390 |
+
device=self.device)
|
391 |
+
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
|
392 |
+
dtype=torch.int32,
|
393 |
+
device=self.device)
|
394 |
+
|
395 |
+
prompt_lens_tensor = torch.tensor(prompt_lens,
|
396 |
+
dtype=torch.long,
|
397 |
+
device=self.device)
|
398 |
+
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
|
399 |
+
dtype=torch.int32,
|
400 |
+
device=self.device)
|
401 |
+
|
402 |
+
torch.cumsum(subquery_lens_tensor,
|
403 |
+
dim=0,
|
404 |
+
dtype=subquery_start_loc.dtype,
|
405 |
+
out=subquery_start_loc[1:])
|
406 |
+
|
407 |
+
torch.cumsum(prompt_lens_tensor,
|
408 |
+
dim=0,
|
409 |
+
dtype=seq_start_loc.dtype,
|
410 |
+
out=seq_start_loc[1:])
|
411 |
+
|
412 |
+
attn_metadata = self.attn_backend.make_metadata(
|
413 |
+
is_prompt=True,
|
414 |
+
prompt_lens=prompt_lens,
|
415 |
+
prompt_lens_tensor=prompt_lens_tensor,
|
416 |
+
max_subquery_len=max_subquery_len,
|
417 |
+
max_context_len=None,
|
418 |
+
max_prompt_len=max_prompt_len,
|
419 |
+
subquery_start_loc=subquery_start_loc,
|
420 |
+
seq_start_loc=seq_start_loc,
|
421 |
+
context_lens=context_lens_tensor,
|
422 |
+
block_tables=block_tables,
|
423 |
+
use_cuda_graph=False,
|
424 |
+
)
|
425 |
+
|
426 |
+
return PreparePromptMetadata(
|
427 |
+
input_tokens=input_tokens,
|
428 |
+
input_positions=input_positions,
|
429 |
+
attn_metadata=attn_metadata,
|
430 |
+
prompt_lens=prompt_lens,
|
431 |
+
subquery_lens=subquery_lens,
|
432 |
+
lora_index_mapping=lora_index_mapping,
|
433 |
+
lora_prompt_mapping=lora_prompt_mapping,
|
434 |
+
lora_requests=lora_requests,
|
435 |
+
multi_modal_input=multi_modal_input,
|
436 |
+
slot_mapping=slot_mapping,
|
437 |
+
)
|
438 |
+
|
439 |
+
def _prepare_decode(
|
440 |
+
self,
|
441 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
442 |
+
) -> PrepareDecodeMetadata:
|
443 |
+
input_tokens: List[int] = []
|
444 |
+
input_positions: List[int] = []
|
445 |
+
slot_mapping: List[int] = []
|
446 |
+
context_lens: List[int] = []
|
447 |
+
block_tables: List[List[int]] = []
|
448 |
+
lora_index_mapping: List[int] = []
|
449 |
+
lora_prompt_mapping: List[int] = []
|
450 |
+
lora_requests: Set[LoRARequest] = set()
|
451 |
+
|
452 |
+
if len(seq_group_metadata_list) == 0:
|
453 |
+
return PrepareDecodeMetadata.empty()
|
454 |
+
|
455 |
+
for seq_group_metadata in seq_group_metadata_list:
|
456 |
+
assert not seq_group_metadata.is_prompt
|
457 |
+
assert seq_group_metadata.token_chunk_size == 1
|
458 |
+
|
459 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
460 |
+
lora_id = seq_group_metadata.lora_int_id
|
461 |
+
|
462 |
+
if lora_id > 0:
|
463 |
+
lora_requests.add(seq_group_metadata.lora_request)
|
464 |
+
|
465 |
+
for seq_id in seq_ids:
|
466 |
+
seq_data = seq_group_metadata.seq_data[seq_id]
|
467 |
+
generation_token = seq_data.get_last_token_id()
|
468 |
+
input_tokens.append(generation_token)
|
469 |
+
|
470 |
+
seq_len = seq_data.get_len()
|
471 |
+
position = seq_len - 1
|
472 |
+
input_positions.append(position)
|
473 |
+
|
474 |
+
context_len = seq_len if self.sliding_window is None else min(
|
475 |
+
seq_len, self.sliding_window)
|
476 |
+
context_lens.append(context_len)
|
477 |
+
|
478 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
479 |
+
block_number = block_table[position // self.block_size]
|
480 |
+
block_offset = position % self.block_size
|
481 |
+
slot = block_number * self.block_size + block_offset
|
482 |
+
slot_mapping.append(slot)
|
483 |
+
lora_index_mapping.append(lora_id)
|
484 |
+
lora_prompt_mapping.append(lora_id)
|
485 |
+
|
486 |
+
if self.sliding_window is not None:
|
487 |
+
sliding_window_blocks = (self.sliding_window //
|
488 |
+
self.block_size)
|
489 |
+
block_table = block_table[-sliding_window_blocks:]
|
490 |
+
block_tables.append(block_table)
|
491 |
+
|
492 |
+
# vLLM uses cuda graph only for decoding requests.
|
493 |
+
# See `capture_model` API for more details.
|
494 |
+
# For decoding requests, batch_size == input_tokens.
|
495 |
+
batch_size = len(input_tokens)
|
496 |
+
max_context_len = max(context_lens)
|
497 |
+
use_captured_graph = (
|
498 |
+
not self.model_config.enforce_eager
|
499 |
+
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
500 |
+
and max_context_len <= self.max_context_len_to_capture)
|
501 |
+
if use_captured_graph:
|
502 |
+
graph_batch_size = _get_graph_batch_size(batch_size)
|
503 |
+
assert graph_batch_size >= batch_size
|
504 |
+
for _ in range(graph_batch_size - batch_size):
|
505 |
+
input_tokens.append(0)
|
506 |
+
input_positions.append(0)
|
507 |
+
slot_mapping.append(_PAD_SLOT_ID)
|
508 |
+
context_lens.append(1)
|
509 |
+
block_tables.append([])
|
510 |
+
lora_index_mapping.append(0)
|
511 |
+
batch_size = graph_batch_size
|
512 |
+
|
513 |
+
context_lens_tensor = torch.tensor(context_lens,
|
514 |
+
dtype=torch.int,
|
515 |
+
device=self.device)
|
516 |
+
|
517 |
+
if use_captured_graph:
|
518 |
+
# When using cuda-graph all these tensors should be
|
519 |
+
# padded.
|
520 |
+
assert context_lens_tensor.shape[0] == len(input_tokens)
|
521 |
+
assert context_lens_tensor.shape[0] == len(input_positions)
|
522 |
+
assert context_lens_tensor.shape[0] == len(slot_mapping)
|
523 |
+
|
524 |
+
# The shape of graph_block_tables is
|
525 |
+
# [max batch size, max context len // block size].
|
526 |
+
input_block_tables = self.graph_block_tables[:batch_size]
|
527 |
+
for i, block_table in enumerate(block_tables):
|
528 |
+
if block_table:
|
529 |
+
input_block_tables[i, :len(block_table)] = block_table
|
530 |
+
block_tables = torch.tensor(input_block_tables, device=self.device)
|
531 |
+
else:
|
532 |
+
max_block_table_len = max(
|
533 |
+
len(block_table) for block_table in block_tables)
|
534 |
+
block_tables = make_tensor_with_pad(
|
535 |
+
block_tables,
|
536 |
+
max_len=max_block_table_len,
|
537 |
+
pad=0,
|
538 |
+
dtype=torch.int,
|
539 |
+
device=self.device,
|
540 |
+
)
|
541 |
+
|
542 |
+
attn_metadata = self.attn_backend.make_metadata(
|
543 |
+
is_prompt=False,
|
544 |
+
prompt_lens=None,
|
545 |
+
prompt_lens_tensor=None,
|
546 |
+
max_subquery_len=None,
|
547 |
+
max_context_len=max_context_len,
|
548 |
+
max_prompt_len=None,
|
549 |
+
subquery_start_loc=None,
|
550 |
+
seq_start_loc=None,
|
551 |
+
context_lens=context_lens_tensor,
|
552 |
+
block_tables=block_tables,
|
553 |
+
use_cuda_graph=use_captured_graph,
|
554 |
+
)
|
555 |
+
return PrepareDecodeMetadata(
|
556 |
+
input_tokens=input_tokens,
|
557 |
+
input_positions=input_positions,
|
558 |
+
attn_metadata=attn_metadata,
|
559 |
+
lora_index_mapping=lora_index_mapping,
|
560 |
+
lora_prompt_mapping=lora_prompt_mapping,
|
561 |
+
lora_requests=lora_requests,
|
562 |
+
slot_mapping=slot_mapping,
|
563 |
+
)
|
564 |
+
|
565 |
+
def _prepare_sample(
|
566 |
+
self,
|
567 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
568 |
+
prompt_lens: List[int],
|
569 |
+
subquery_lens: Optional[List[int]],
|
570 |
+
) -> SamplingMetadata:
|
571 |
+
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
572 |
+
selected_token_indices: List[int] = []
|
573 |
+
generators: List[torch.Generator] = []
|
574 |
+
selected_token_start_idx = 0
|
575 |
+
categorized_sample_indices: Dict[SamplingType,
|
576 |
+
List[Tuple[int, int]]] = {
|
577 |
+
t: []
|
578 |
+
for t in SamplingType
|
579 |
+
}
|
580 |
+
categorized_sample_indices_start_idx = 0
|
581 |
+
categorized_sampled_token_indices_start_idx = 0
|
582 |
+
|
583 |
+
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
584 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
585 |
+
sampling_params = seq_group_metadata.sampling_params
|
586 |
+
seq_groups.append((seq_ids, sampling_params))
|
587 |
+
|
588 |
+
if seq_group_metadata.is_prompt:
|
589 |
+
assert len(seq_ids) == 1
|
590 |
+
assert subquery_lens is not None
|
591 |
+
subquery_len = subquery_lens[i]
|
592 |
+
if sampling_params.prompt_logprobs is not None:
|
593 |
+
# NOTE: prompt token positions do not need sample, skip
|
594 |
+
categorized_sample_indices_start_idx += subquery_len - 1
|
595 |
+
|
596 |
+
categorized_sample_indices[
|
597 |
+
sampling_params.sampling_type].append(
|
598 |
+
(categorized_sample_indices_start_idx,
|
599 |
+
categorized_sampled_token_indices_start_idx))
|
600 |
+
categorized_sample_indices_start_idx += 1
|
601 |
+
categorized_sampled_token_indices_start_idx += 1
|
602 |
+
|
603 |
+
if sampling_params.prompt_logprobs is not None:
|
604 |
+
selected_token_indices.extend(
|
605 |
+
range(selected_token_start_idx,
|
606 |
+
selected_token_start_idx + subquery_len - 1))
|
607 |
+
selected_token_indices.append(selected_token_start_idx +
|
608 |
+
subquery_len - 1)
|
609 |
+
selected_token_start_idx += subquery_len
|
610 |
+
|
611 |
+
if sampling_params.seed is not None:
|
612 |
+
seq_group_metadata.state.generator = torch.Generator(
|
613 |
+
device=self.device).manual_seed(sampling_params.seed)
|
614 |
+
else:
|
615 |
+
num_seqs = len(seq_ids)
|
616 |
+
selected_token_indices.extend(
|
617 |
+
range(selected_token_start_idx,
|
618 |
+
selected_token_start_idx + num_seqs))
|
619 |
+
selected_token_start_idx += num_seqs
|
620 |
+
|
621 |
+
categorized_sample_indices[
|
622 |
+
sampling_params.sampling_type].extend(
|
623 |
+
list(
|
624 |
+
zip(
|
625 |
+
range(
|
626 |
+
categorized_sample_indices_start_idx,
|
627 |
+
categorized_sample_indices_start_idx +
|
628 |
+
num_seqs),
|
629 |
+
range(
|
630 |
+
categorized_sampled_token_indices_start_idx,
|
631 |
+
categorized_sampled_token_indices_start_idx
|
632 |
+
+ num_seqs))))
|
633 |
+
categorized_sample_indices_start_idx += num_seqs
|
634 |
+
categorized_sampled_token_indices_start_idx += num_seqs
|
635 |
+
|
636 |
+
if sampling_params.seed is not None:
|
637 |
+
generators.append(seq_group_metadata.state.generator)
|
638 |
+
|
639 |
+
selected_token_indices = async_tensor_h2d(selected_token_indices,
|
640 |
+
dtype=torch.long,
|
641 |
+
target_device=self.device,
|
642 |
+
pin_memory=self.pin_memory)
|
643 |
+
|
644 |
+
categorized_sample_indices = {
|
645 |
+
t: maybe_expand_dim(
|
646 |
+
async_tensor_h2d(seq_ids,
|
647 |
+
dtype=torch.int,
|
648 |
+
target_device=self.device,
|
649 |
+
pin_memory=self.pin_memory), 2, 2)
|
650 |
+
for t, seq_ids in categorized_sample_indices.items()
|
651 |
+
}
|
652 |
+
|
653 |
+
seq_data: Dict[int, SequenceData] = {}
|
654 |
+
for seq_group_metadata in seq_group_metadata_list:
|
655 |
+
seq_data.update(seq_group_metadata.seq_data)
|
656 |
+
|
657 |
+
sampling_metadata = SamplingMetadata(
|
658 |
+
seq_groups=seq_groups,
|
659 |
+
seq_data=seq_data,
|
660 |
+
prompt_lens=prompt_lens,
|
661 |
+
selected_token_indices=selected_token_indices,
|
662 |
+
categorized_sample_indices=categorized_sample_indices,
|
663 |
+
generators=generators,
|
664 |
+
)
|
665 |
+
return sampling_metadata
|
666 |
+
|
667 |
+
def prepare_input_tensors(
|
668 |
+
self,
|
669 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
670 |
+
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
671 |
+
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
672 |
+
if self.is_driver_worker:
|
673 |
+
prefill_reqs = []
|
674 |
+
decode_reqs = []
|
675 |
+
for seq_group_meta in seq_group_metadata_list:
|
676 |
+
if seq_group_meta.is_prompt:
|
677 |
+
prefill_reqs.append(seq_group_meta)
|
678 |
+
else:
|
679 |
+
decode_reqs.append(seq_group_meta)
|
680 |
+
|
681 |
+
# Prepare input tensors.
|
682 |
+
(
|
683 |
+
input_tokens,
|
684 |
+
input_positions,
|
685 |
+
prefill_attn_metadata,
|
686 |
+
prompt_lens,
|
687 |
+
subquery_lens,
|
688 |
+
lora_index_mapping,
|
689 |
+
lora_prompt_mapping,
|
690 |
+
lora_requests,
|
691 |
+
multi_modal_input,
|
692 |
+
slot_mapping,
|
693 |
+
) = self._prepare_prompt(prefill_reqs)
|
694 |
+
(
|
695 |
+
decode_input_tokens,
|
696 |
+
decode_input_positions,
|
697 |
+
decode_attn_metadata,
|
698 |
+
decode_lora_index_mapping,
|
699 |
+
decode_lora_prompt_mapping,
|
700 |
+
decode_lora_requests,
|
701 |
+
decode_slot_mapping,
|
702 |
+
) = self._prepare_decode(decode_reqs)
|
703 |
+
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
|
704 |
+
prompt_lens,
|
705 |
+
subquery_lens)
|
706 |
+
|
707 |
+
if not self.scheduler_config.chunked_prefill_enabled:
|
708 |
+
assert (len(prefill_reqs) and len(decode_reqs)) == 0
|
709 |
+
|
710 |
+
num_prefills = len(prompt_lens)
|
711 |
+
num_prefill_tokens = len(input_tokens)
|
712 |
+
num_decode_tokens = len(decode_input_tokens)
|
713 |
+
|
714 |
+
# Coalesce tensors. Note that attn_metadata is currently not
|
715 |
+
# coalesced for simplicity.
|
716 |
+
input_tokens.extend(decode_input_tokens)
|
717 |
+
input_positions.extend(decode_input_positions)
|
718 |
+
slot_mapping.extend(decode_slot_mapping)
|
719 |
+
lora_index_mapping.extend(decode_lora_index_mapping)
|
720 |
+
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
|
721 |
+
lora_requests.update(decode_lora_requests)
|
722 |
+
|
723 |
+
input_tokens = torch.tensor(input_tokens,
|
724 |
+
dtype=torch.long,
|
725 |
+
device=self.device)
|
726 |
+
input_positions = torch.tensor(input_positions,
|
727 |
+
dtype=torch.long,
|
728 |
+
device=self.device)
|
729 |
+
slot_mapping = torch.tensor(slot_mapping,
|
730 |
+
dtype=torch.long,
|
731 |
+
device=self.device)
|
732 |
+
|
733 |
+
if self.lora_config:
|
734 |
+
lora_mapping = LoRAMapping(
|
735 |
+
lora_index_mapping,
|
736 |
+
lora_prompt_mapping,
|
737 |
+
)
|
738 |
+
else:
|
739 |
+
lora_mapping = None
|
740 |
+
|
741 |
+
# Broadcast the metadata.
|
742 |
+
# If batch contains both prefill and decode, it sends 2 broadcasts.
|
743 |
+
# If it only contains 1 type, it triggers a single broadcast.
|
744 |
+
if (prefill_attn_metadata is not None
|
745 |
+
and decode_attn_metadata is not None):
|
746 |
+
batch_type = BatchType.MIXED
|
747 |
+
elif prefill_attn_metadata is not None:
|
748 |
+
batch_type = BatchType.PREFILL
|
749 |
+
else:
|
750 |
+
batch_type = BatchType.DECODE
|
751 |
+
|
752 |
+
metadata_dict = {
|
753 |
+
"input_tokens": input_tokens,
|
754 |
+
"input_positions": input_positions,
|
755 |
+
"selected_token_indices":
|
756 |
+
sampling_metadata.selected_token_indices,
|
757 |
+
"lora_requests": lora_requests,
|
758 |
+
"lora_mapping": lora_mapping,
|
759 |
+
"multi_modal_input": multi_modal_input,
|
760 |
+
"num_prefill_tokens": num_prefill_tokens,
|
761 |
+
"num_decode_tokens": num_decode_tokens,
|
762 |
+
"slot_mapping": slot_mapping,
|
763 |
+
"num_prefills": num_prefills,
|
764 |
+
"batch_type": batch_type,
|
765 |
+
}
|
766 |
+
if prefill_attn_metadata is not None:
|
767 |
+
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
768 |
+
else:
|
769 |
+
assert decode_attn_metadata is not None
|
770 |
+
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
771 |
+
broadcast_tensor_dict(metadata_dict, src=0)
|
772 |
+
|
773 |
+
# Broadcast decode attn metadata for mixed batch type.
|
774 |
+
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
|
775 |
+
# We can potentially reduce the overhead by coelescing tensors.
|
776 |
+
if batch_type == BatchType.MIXED:
|
777 |
+
assert decode_attn_metadata is not None
|
778 |
+
metadata_dict = decode_attn_metadata.asdict_zerocopy()
|
779 |
+
broadcast_tensor_dict(metadata_dict, src=0)
|
780 |
+
else:
|
781 |
+
metadata_dict = broadcast_tensor_dict(src=0)
|
782 |
+
input_tokens = metadata_dict.pop("input_tokens")
|
783 |
+
input_positions = metadata_dict.pop("input_positions")
|
784 |
+
slot_mapping = metadata_dict.pop("slot_mapping")
|
785 |
+
num_prefills = metadata_dict.pop("num_prefills")
|
786 |
+
selected_token_indices = metadata_dict.pop(
|
787 |
+
"selected_token_indices")
|
788 |
+
lora_mapping = metadata_dict.pop("lora_mapping")
|
789 |
+
lora_requests = metadata_dict.pop("lora_requests")
|
790 |
+
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
791 |
+
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
|
792 |
+
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
|
793 |
+
batch_type = metadata_dict.pop("batch_type")
|
794 |
+
|
795 |
+
# Create an attention metadata.
|
796 |
+
prefill_attn_metadata = None
|
797 |
+
decode_attn_metadata = None
|
798 |
+
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
|
799 |
+
prefill_attn_metadata = self.attn_backend.make_metadata(
|
800 |
+
**metadata_dict)
|
801 |
+
else:
|
802 |
+
decode_attn_metadata = self.attn_backend.make_metadata(
|
803 |
+
**metadata_dict)
|
804 |
+
sampling_metadata = SamplingMetadata(
|
805 |
+
seq_groups=None,
|
806 |
+
seq_data=None,
|
807 |
+
prompt_lens=None,
|
808 |
+
selected_token_indices=selected_token_indices,
|
809 |
+
categorized_sample_indices=None,
|
810 |
+
generators=None,
|
811 |
+
perform_sampling=False,
|
812 |
+
)
|
813 |
+
|
814 |
+
# if it is a mixed batch, decode attn_metadata is broadcasted
|
815 |
+
# separately.
|
816 |
+
if batch_type == BatchType.MIXED:
|
817 |
+
metadata_dict = broadcast_tensor_dict(src=0)
|
818 |
+
decode_attn_metadata = self.attn_backend.make_metadata(
|
819 |
+
**metadata_dict)
|
820 |
+
|
821 |
+
attn_metadata = AttentionMetadata(
|
822 |
+
num_prefills=num_prefills,
|
823 |
+
slot_mapping=slot_mapping,
|
824 |
+
num_prefill_tokens=num_prefill_tokens,
|
825 |
+
num_decode_tokens=num_decode_tokens,
|
826 |
+
prefill_metadata=prefill_attn_metadata,
|
827 |
+
decode_metadata=decode_attn_metadata,
|
828 |
+
kv_cache_dtype=self.kv_cache_dtype,
|
829 |
+
)
|
830 |
+
|
831 |
+
return (input_tokens, input_positions, attn_metadata,
|
832 |
+
sampling_metadata, lora_requests, lora_mapping,
|
833 |
+
multi_modal_input)
|
834 |
+
|
835 |
+
@torch.inference_mode()
|
836 |
+
def execute_model(
|
837 |
+
self,
|
838 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
839 |
+
kv_caches: List[torch.Tensor],
|
840 |
+
) -> Optional[SamplerOutput]:
|
841 |
+
(input_tokens, input_positions, attn_metadata, sampling_metadata,
|
842 |
+
lora_requests, lora_mapping, multi_modal_input
|
843 |
+
) = self.prepare_input_tensors(seq_group_metadata_list)
|
844 |
+
if self.lora_config:
|
845 |
+
self.set_active_loras(lora_requests, lora_mapping)
|
846 |
+
|
847 |
+
# Currently cuda graph is only supported by the decode phase.
|
848 |
+
prefill_meta = attn_metadata.prefill_metadata
|
849 |
+
decode_meta = attn_metadata.decode_metadata
|
850 |
+
if prefill_meta is None and decode_meta.use_cuda_graph:
|
851 |
+
graph_batch_size = input_tokens.shape[0]
|
852 |
+
model_executable = self.graph_runners[graph_batch_size]
|
853 |
+
else:
|
854 |
+
model_executable = self.model
|
855 |
+
execute_model_kwargs = {
|
856 |
+
"input_ids": input_tokens,
|
857 |
+
"positions": input_positions,
|
858 |
+
"kv_caches": kv_caches,
|
859 |
+
"attn_metadata": attn_metadata,
|
860 |
+
}
|
861 |
+
if self.vision_language_config:
|
862 |
+
execute_model_kwargs.update({"image_input": multi_modal_input})
|
863 |
+
hidden_states = model_executable(**execute_model_kwargs)
|
864 |
+
|
865 |
+
# Compute the logits.
|
866 |
+
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
867 |
+
|
868 |
+
# Only perform sampling in the driver worker.
|
869 |
+
if not sampling_metadata.perform_sampling:
|
870 |
+
return None
|
871 |
+
|
872 |
+
# Sample the next token.
|
873 |
+
output = self.model.sample(
|
874 |
+
logits=logits,
|
875 |
+
sampling_metadata=sampling_metadata,
|
876 |
+
)
|
877 |
+
return output
|
878 |
+
|
879 |
+
@torch.inference_mode()
|
880 |
+
def profile_run(self) -> None:
|
881 |
+
# Enable top-k sampling to reflect the accurate memory usage.
|
882 |
+
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
|
883 |
+
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
884 |
+
max_num_seqs = self.scheduler_config.max_num_seqs
|
885 |
+
|
886 |
+
# This represents the maximum number of different requests
|
887 |
+
# that will have unique loras, an therefore the max amount of memory
|
888 |
+
# consumption create dummy lora request copies from the lora request
|
889 |
+
# passed in, which contains a lora from the lora warmup path.
|
890 |
+
dummy_lora_requests = []
|
891 |
+
dummy_lora_requests_per_seq = []
|
892 |
+
if self.lora_config:
|
893 |
+
for idx in range(self.lora_config.max_loras):
|
894 |
+
lora_id = idx + 1
|
895 |
+
dummy_lora_request = LoRARequest(
|
896 |
+
lora_name=f"warmup_{lora_id}",
|
897 |
+
lora_int_id=lora_id,
|
898 |
+
lora_local_path="/not/a/real/path",
|
899 |
+
)
|
900 |
+
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
901 |
+
rank=LORA_WARMUP_RANK)
|
902 |
+
dummy_lora_requests.append(dummy_lora_request)
|
903 |
+
dummy_lora_requests_per_seq = [
|
904 |
+
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
905 |
+
for idx in range(max_num_seqs)
|
906 |
+
]
|
907 |
+
|
908 |
+
# Profile memory usage with max_num_sequences sequences and the total
|
909 |
+
# number of tokens equal to max_num_batched_tokens.
|
910 |
+
seqs: List[SequenceGroupMetadata] = []
|
911 |
+
# Additional GPU memory may be needed for vision encoding, which needs
|
912 |
+
# to be accounted for when calculating the GPU blocks for
|
913 |
+
# vLLM blocker manager.
|
914 |
+
# To exercise the worst scenario for GPU memory consumption,
|
915 |
+
# the number of seqs (batch_size) is chosen to maximize the number
|
916 |
+
# of images processed.
|
917 |
+
if self.vision_language_config:
|
918 |
+
max_num_seqs = min(
|
919 |
+
max_num_seqs,
|
920 |
+
int(max_num_batched_tokens /
|
921 |
+
self.vision_language_config.image_feature_size))
|
922 |
+
for group_id in range(max_num_seqs):
|
923 |
+
seq_len = (max_num_batched_tokens // max_num_seqs +
|
924 |
+
(group_id < max_num_batched_tokens % max_num_seqs))
|
925 |
+
seq_data, fake_multi_modal_input = _prepare_fake_inputs(
|
926 |
+
seq_len, self.vision_language_config)
|
927 |
+
seq = SequenceGroupMetadata(
|
928 |
+
request_id=str(group_id),
|
929 |
+
is_prompt=True,
|
930 |
+
seq_data={group_id: seq_data},
|
931 |
+
sampling_params=sampling_params,
|
932 |
+
block_tables=None,
|
933 |
+
lora_request=dummy_lora_requests_per_seq[group_id]
|
934 |
+
if dummy_lora_requests_per_seq else None,
|
935 |
+
multi_modal_data=fake_multi_modal_input,
|
936 |
+
)
|
937 |
+
seqs.append(seq)
|
938 |
+
|
939 |
+
# Run the model with the dummy inputs.
|
940 |
+
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
941 |
+
kv_caches = [None] * num_layers
|
942 |
+
self.execute_model(seqs, kv_caches)
|
943 |
+
torch.cuda.synchronize()
|
944 |
+
return
|
945 |
+
|
946 |
+
def remove_all_loras(self) -> bool:
|
947 |
+
if not self.lora_manager:
|
948 |
+
raise RuntimeError("LoRA is not enabled.")
|
949 |
+
return self.lora_manager.remove_all_loras()
|
950 |
+
|
951 |
+
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
952 |
+
lora_mapping: LoRAMapping) -> None:
|
953 |
+
if not self.lora_manager:
|
954 |
+
raise RuntimeError("LoRA is not enabled.")
|
955 |
+
self.lora_manager.set_active_loras(lora_requests, lora_mapping)
|
956 |
+
|
957 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
958 |
+
if not self.lora_manager:
|
959 |
+
raise RuntimeError("LoRA is not enabled.")
|
960 |
+
return self.lora_manager.add_lora(lora_request)
|
961 |
+
|
962 |
+
def remove_lora(self, lora_id: int) -> bool:
|
963 |
+
if not self.lora_manager:
|
964 |
+
raise RuntimeError("LoRA is not enabled.")
|
965 |
+
return self.lora_manager.remove_lora(lora_id)
|
966 |
+
|
967 |
+
def list_loras(self) -> Set[int]:
|
968 |
+
if not self.lora_manager:
|
969 |
+
raise RuntimeError("LoRA is not enabled.")
|
970 |
+
return self.lora_manager.list_loras()
|
971 |
+
|
972 |
+
@torch.inference_mode()
|
973 |
+
def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
|
974 |
+
"""Cuda graph capture a model.
|
975 |
+
|
976 |
+
Note that CUDA graph's performance gain is negligible if number
|
977 |
+
of batched tokens are larger than 200. And since CUDA graph
|
978 |
+
requires fixed sized tensors, supporting large/variable batch
|
979 |
+
size requires high GPU memory overhead. Thus, vLLM only captures
|
980 |
+
decoding requests. Mixed batch (chunked prefill + decoding) or
|
981 |
+
prefill requests are not captured.
|
982 |
+
|
983 |
+
Since it is used for decoding-only, it assumes there's only 1 token
|
984 |
+
per sequence in the batch.
|
985 |
+
"""
|
986 |
+
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
|
987 |
+
# deleted before the CUDA graphs.
|
988 |
+
self.pynccl_backend = pynccl_utils.get_nccl_backend()
|
989 |
+
|
990 |
+
assert not self.model_config.enforce_eager
|
991 |
+
logger.info("Capturing the model for CUDA graphs. This may lead to "
|
992 |
+
"unexpected consequences if the model is not static. To "
|
993 |
+
"run the model in eager mode, set 'enforce_eager=True' or "
|
994 |
+
"use '--enforce-eager' in the CLI.")
|
995 |
+
logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
|
996 |
+
"If you are running out of memory, consider decreasing "
|
997 |
+
"`gpu_memory_utilization` or enforcing eager mode. "
|
998 |
+
"You can also reduce the `max_num_seqs` as needed "
|
999 |
+
"to decrease memory usage.")
|
1000 |
+
start_time = time.perf_counter()
|
1001 |
+
|
1002 |
+
# Prepare dummy inputs. These will be reused for all batch sizes.
|
1003 |
+
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
1004 |
+
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
1005 |
+
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
1006 |
+
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
|
1007 |
+
slot_mapping.fill_(_PAD_SLOT_ID)
|
1008 |
+
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
|
1009 |
+
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
1010 |
+
|
1011 |
+
graph_batch_size = _get_graph_batch_size(
|
1012 |
+
self.scheduler_config.max_num_seqs)
|
1013 |
+
batch_size_capture_list = [
|
1014 |
+
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
|
1015 |
+
]
|
1016 |
+
|
1017 |
+
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
|
1018 |
+
# kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
|
1019 |
+
# either custom all-reduce kernel or pynccl. When not using CUDA
|
1020 |
+
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
|
1021 |
+
# We always prioritize using custom all-reduce kernel but fall back
|
1022 |
+
# to PyTorch or pynccl if it is disabled or not supported.
|
1023 |
+
with custom_all_reduce.capture():
|
1024 |
+
# NOTE: Capturing the largest batch size first may help reduce the
|
1025 |
+
# memory usage of CUDA graph.
|
1026 |
+
for batch_size in reversed(batch_size_capture_list):
|
1027 |
+
# Create dummy attn_metadata.
|
1028 |
+
decode_metadata = self.attn_backend.make_metadata(
|
1029 |
+
is_prompt=False,
|
1030 |
+
prompt_lens=None,
|
1031 |
+
prompt_lens_tensor=None,
|
1032 |
+
max_subquery_len=None,
|
1033 |
+
max_context_len=self.max_context_len_to_capture,
|
1034 |
+
max_prompt_len=None,
|
1035 |
+
subquery_start_loc=None,
|
1036 |
+
seq_start_loc=None,
|
1037 |
+
context_lens=context_lens[:batch_size],
|
1038 |
+
block_tables=block_tables[:batch_size],
|
1039 |
+
use_cuda_graph=True,
|
1040 |
+
)
|
1041 |
+
attn_metadata = AttentionMetadata(
|
1042 |
+
num_prefills=0,
|
1043 |
+
num_prefill_tokens=0,
|
1044 |
+
num_decode_tokens=batch_size,
|
1045 |
+
slot_mapping=slot_mapping[:batch_size],
|
1046 |
+
prefill_metadata=None,
|
1047 |
+
decode_metadata=decode_metadata,
|
1048 |
+
kv_cache_dtype=self.kv_cache_dtype,
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
if self.lora_config:
|
1052 |
+
lora_mapping = LoRAMapping(
|
1053 |
+
[0] * batch_size,
|
1054 |
+
[0] * batch_size,
|
1055 |
+
)
|
1056 |
+
self.set_active_loras(set(), lora_mapping)
|
1057 |
+
|
1058 |
+
graph_runner = CUDAGraphRunner(self.model)
|
1059 |
+
graph_runner.capture(
|
1060 |
+
input_tokens[:batch_size],
|
1061 |
+
input_positions[:batch_size],
|
1062 |
+
kv_caches,
|
1063 |
+
attn_metadata,
|
1064 |
+
memory_pool=self.graph_memory_pool,
|
1065 |
+
)
|
1066 |
+
self.graph_memory_pool = graph_runner.graph.pool()
|
1067 |
+
self.graph_runners[batch_size] = graph_runner
|
1068 |
+
|
1069 |
+
end_time = time.perf_counter()
|
1070 |
+
elapsed_time = end_time - start_time
|
1071 |
+
# This usually takes < 10 seconds.
|
1072 |
+
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
|
1073 |
+
|
1074 |
+
def __del__(self) -> None:
|
1075 |
+
# Delete the CUDA graphs before deleting the pynccl communicator.
|
1076 |
+
# NOTE(woosuk): This is necessary because otherwise deadlocks can
|
1077 |
+
# happen.
|
1078 |
+
# FIXME(woosuk): This is a bit hacky. Find a more robust solution.
|
1079 |
+
# TODO(youkaichao): when we get enough user feedback that pynccl is
|
1080 |
+
# more stable than cupy, we can remove this, e.g. in v0.4.1.
|
1081 |
+
self.graph_runners.clear()
|
1082 |
+
self.pynccl_backend = None
|
1083 |
+
|
1084 |
+
@property
|
1085 |
+
def vocab_size(self) -> int:
|
1086 |
+
return self.model_config.get_vocab_size()
|
1087 |
+
|
1088 |
+
|
1089 |
+
class CUDAGraphRunner:
|
1090 |
+
|
1091 |
+
def __init__(self, model: nn.Module):
|
1092 |
+
self.model = model
|
1093 |
+
self.input_buffers: Dict[str, torch.Tensor] = {}
|
1094 |
+
self.output_buffers: Dict[str, torch.Tensor] = {}
|
1095 |
+
|
1096 |
+
self._graph: Optional[torch.cuda.CUDAGraph] = None
|
1097 |
+
|
1098 |
+
@property
|
1099 |
+
def graph(self):
|
1100 |
+
assert self._graph is not None
|
1101 |
+
return self._graph
|
1102 |
+
|
1103 |
+
def capture(
|
1104 |
+
self,
|
1105 |
+
input_ids: torch.Tensor,
|
1106 |
+
positions: torch.Tensor,
|
1107 |
+
kv_caches: List[torch.Tensor],
|
1108 |
+
attn_metadata: AttentionMetadata,
|
1109 |
+
memory_pool,
|
1110 |
+
**kwargs,
|
1111 |
+
) -> None:
|
1112 |
+
assert self._graph is None
|
1113 |
+
# Run the model once without capturing the graph.
|
1114 |
+
# This is to make sure that the captured graph does not include the
|
1115 |
+
# kernel launches for initial benchmarking (e.g., Triton autotune).
|
1116 |
+
with _maybe_pynccl():
|
1117 |
+
self.model(
|
1118 |
+
input_ids,
|
1119 |
+
positions,
|
1120 |
+
kv_caches,
|
1121 |
+
attn_metadata,
|
1122 |
+
**kwargs,
|
1123 |
+
)
|
1124 |
+
torch.cuda.synchronize()
|
1125 |
+
|
1126 |
+
# Capture the graph.
|
1127 |
+
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
|
1128 |
+
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
|
1129 |
+
self._graph = torch.cuda.CUDAGraph()
|
1130 |
+
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
|
1131 |
+
with _maybe_pynccl():
|
1132 |
+
hidden_states = self.model(
|
1133 |
+
input_ids,
|
1134 |
+
positions,
|
1135 |
+
kv_caches,
|
1136 |
+
attn_metadata,
|
1137 |
+
**kwargs,
|
1138 |
+
)
|
1139 |
+
torch.cuda.synchronize()
|
1140 |
+
|
1141 |
+
# Save the input and output buffers.
|
1142 |
+
self.input_buffers = {
|
1143 |
+
"input_ids": input_ids,
|
1144 |
+
"positions": positions,
|
1145 |
+
"kv_caches": kv_caches,
|
1146 |
+
"slot_mapping": attn_metadata.slot_mapping,
|
1147 |
+
"context_lens": attn_metadata.decode_metadata.context_lens,
|
1148 |
+
"block_tables": attn_metadata.decode_metadata.block_tables,
|
1149 |
+
}
|
1150 |
+
self.output_buffers = {"hidden_states": hidden_states}
|
1151 |
+
return
|
1152 |
+
|
1153 |
+
def forward(
|
1154 |
+
self,
|
1155 |
+
input_ids: torch.Tensor,
|
1156 |
+
positions: torch.Tensor,
|
1157 |
+
kv_caches: List[torch.Tensor],
|
1158 |
+
attn_metadata: AttentionMetadata,
|
1159 |
+
**kwargs,
|
1160 |
+
) -> torch.Tensor:
|
1161 |
+
# KV caches are fixed tensors, so we don't need to copy them.
|
1162 |
+
del kv_caches
|
1163 |
+
|
1164 |
+
# Copy the input tensors to the input buffers.
|
1165 |
+
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
1166 |
+
self.input_buffers["positions"].copy_(positions, non_blocking=True)
|
1167 |
+
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
|
1168 |
+
non_blocking=True)
|
1169 |
+
self.input_buffers["context_lens"].copy_(
|
1170 |
+
attn_metadata.decode_metadata.context_lens, non_blocking=True)
|
1171 |
+
self.input_buffers["block_tables"].copy_(
|
1172 |
+
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
1173 |
+
# Run the graph.
|
1174 |
+
self.graph.replay()
|
1175 |
+
|
1176 |
+
# Return the output tensor.
|
1177 |
+
return self.output_buffers["hidden_states"]
|
1178 |
+
|
1179 |
+
def __call__(self, *args, **kwargs):
|
1180 |
+
return self.forward(*args, **kwargs)
|
1181 |
+
|
1182 |
+
|
1183 |
+
@contextlib.contextmanager
|
1184 |
+
def _maybe_pynccl():
|
1185 |
+
if pynccl_utils.is_initialized(
|
1186 |
+
) and not custom_all_reduce.is_initialized():
|
1187 |
+
with with_pynccl_for_all_reduce():
|
1188 |
+
yield
|
1189 |
+
else:
|
1190 |
+
yield
|
1191 |
+
|
1192 |
+
|
1193 |
+
def _get_graph_batch_size(batch_size: int) -> int:
|
1194 |
+
"""Returns the padded batch size given actual batch size.
|
1195 |
+
|
1196 |
+
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
|
1197 |
+
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
|
1198 |
+
"""
|
1199 |
+
if batch_size <= 2:
|
1200 |
+
return batch_size
|
1201 |
+
elif batch_size <= 4:
|
1202 |
+
return 4
|
1203 |
+
else:
|
1204 |
+
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
|
1205 |
+
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
|
1206 |
+
|
1207 |
+
|
1208 |
+
def _prepare_fake_inputs(
|
1209 |
+
seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
|
1210 |
+
"""Prepare fake inputs for profile run."""
|
1211 |
+
if vision_language_config:
|
1212 |
+
prompt_tokens = [
|
1213 |
+
vision_language_config.image_token_id
|
1214 |
+
] * vision_language_config.image_feature_size + [0] * (
|
1215 |
+
seq_len - vision_language_config.image_feature_size)
|
1216 |
+
fake_image_input = MultiModalData(
|
1217 |
+
type=MultiModalData.Type.IMAGE,
|
1218 |
+
data=torch.zeros(vision_language_config.image_input_shape,
|
1219 |
+
dtype=torch.float16))
|
1220 |
+
else:
|
1221 |
+
prompt_tokens = [0] * seq_len
|
1222 |
+
fake_image_input = None
|
1223 |
+
return SequenceData(prompt_tokens), fake_image_input
|
serve/sample_c2i.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
from torchvision.utils import save_image
|
5 |
+
|
6 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
7 |
+
from serve.gpt_model import GPT_models
|
8 |
+
from serve.llm import LLM
|
9 |
+
from vllm import SamplingParams
|
10 |
+
|
11 |
+
|
12 |
+
def main(args):
|
13 |
+
# Setup PyTorch:
|
14 |
+
torch.manual_seed(args.seed)
|
15 |
+
torch.backends.cudnn.deterministic = True
|
16 |
+
torch.backends.cudnn.benchmark = False
|
17 |
+
torch.set_grad_enabled(False)
|
18 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
+
|
20 |
+
# create and load model
|
21 |
+
vq_model = VQ_models[args.vq_model](
|
22 |
+
codebook_size=args.codebook_size,
|
23 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
24 |
+
vq_model.to(device)
|
25 |
+
vq_model.eval()
|
26 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
27 |
+
vq_model.load_state_dict(checkpoint["model"])
|
28 |
+
del checkpoint
|
29 |
+
print(f"image tokenizer is loaded")
|
30 |
+
|
31 |
+
# Labels to condition the model with (feel free to change):
|
32 |
+
class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
|
33 |
+
latent_size = args.image_size // args.downsample_size
|
34 |
+
qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
|
35 |
+
prompt_token_ids = [[cind] for cind in class_labels]
|
36 |
+
if args.cfg_scale > 1.0:
|
37 |
+
prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
|
38 |
+
# Create an LLM.
|
39 |
+
llm = LLM(
|
40 |
+
args=args,
|
41 |
+
model='autoregressive/serve/fake_json/{}.json'.format(args.gpt_model),
|
42 |
+
gpu_memory_utilization=0.9,
|
43 |
+
skip_tokenizer_init=True)
|
44 |
+
print(f"gpt model is loaded")
|
45 |
+
|
46 |
+
# Create a sampling params object.
|
47 |
+
sampling_params = SamplingParams(
|
48 |
+
temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
|
49 |
+
max_tokens=latent_size ** 2)
|
50 |
+
|
51 |
+
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
52 |
+
# that contain the prompt, generated text, and other information.
|
53 |
+
t1 = time.time()
|
54 |
+
outputs = llm.generate(
|
55 |
+
prompt_token_ids=prompt_token_ids,
|
56 |
+
sampling_params=sampling_params,
|
57 |
+
use_tqdm=False)
|
58 |
+
sampling_time = time.time() - t1
|
59 |
+
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
60 |
+
|
61 |
+
# decode to image
|
62 |
+
index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
|
63 |
+
if args.cfg_scale > 1.0:
|
64 |
+
index_sample = index_sample[:len(class_labels)]
|
65 |
+
t2 = time.time()
|
66 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
67 |
+
decoder_time = time.time() - t2
|
68 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
69 |
+
|
70 |
+
# Save and display images:
|
71 |
+
save_image(samples, "sample_{}.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1))
|
72 |
+
print(f"image is saved to sample_{args.gpt_type}.png")
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == '__main__':
|
76 |
+
parser = argparse.ArgumentParser()
|
77 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
|
78 |
+
parser.add_argument("--gpt-ckpt", type=str, required=True, help="ckpt path for gpt model")
|
79 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
|
80 |
+
parser.add_argument("--from-fsdp", action='store_true')
|
81 |
+
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
|
82 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
83 |
+
parser.add_argument("--compile", action='store_true', default=False)
|
84 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
85 |
+
parser.add_argument("--vq-ckpt", type=str, required=True, help="ckpt path for vq model")
|
86 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
87 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
88 |
+
parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
|
89 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
90 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
91 |
+
parser.add_argument("--cfg-scale", type=float, default=4.0)
|
92 |
+
parser.add_argument("--seed", type=int, default=0)
|
93 |
+
parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
|
94 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
95 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
96 |
+
args = parser.parse_args()
|
97 |
+
main(args)
|
serve/sampler.py
ADDED
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A layer that samples the next tokens from the model's outputs."""
|
2 |
+
import itertools
|
3 |
+
from typing import Dict, List, Optional, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from vllm.model_executor.layers.ops.sample import sample as sample_triton
|
9 |
+
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
10 |
+
SamplingTensors)
|
11 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
12 |
+
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
|
13 |
+
SamplerOutput, SequenceData, SequenceGroupOutput,
|
14 |
+
SequenceOutput)
|
15 |
+
|
16 |
+
|
17 |
+
class Sampler(nn.Module):
|
18 |
+
"""Samples the next tokens from the model's outputs.
|
19 |
+
|
20 |
+
This layer does the following:
|
21 |
+
1. Discard the hidden states that are not used for sampling (i.e., all
|
22 |
+
tokens except the final one in each prompt).
|
23 |
+
2. Compute the logits for the next tokens.
|
24 |
+
3. Apply presence, frequency and repetition penalties.
|
25 |
+
4. Apply temperature scaling.
|
26 |
+
5. Apply top-p and top-k truncation.
|
27 |
+
6. Sample the next tokens.
|
28 |
+
Here, each sequence group within the batch can have different sampling
|
29 |
+
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
|
30 |
+
|
31 |
+
The structure of the logits tensor is coupled with the seq_groups in
|
32 |
+
sampling_metadata. Typically, each sequence in each seq_group has one row in
|
33 |
+
logits for the next token to be sampled; however, for a seq_group with a
|
34 |
+
prompt request with the prompt_logprobs sampling parameter, there are rows
|
35 |
+
in logits for each token in the input prompt.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, cfg_scale=1.0):
|
39 |
+
super().__init__()
|
40 |
+
self.cfg_scale = cfg_scale
|
41 |
+
# Whether or not the SamplerOutput should have on-device tensors
|
42 |
+
# containing the sampled token ids and probabilities. This is used by
|
43 |
+
# speculative decoding.
|
44 |
+
self.include_gpu_probs_tensor = False
|
45 |
+
|
46 |
+
def forward(
|
47 |
+
self,
|
48 |
+
logits: torch.Tensor,
|
49 |
+
sampling_metadata: SamplingMetadata,
|
50 |
+
) -> Optional[SamplerOutput]:
|
51 |
+
assert logits is not None
|
52 |
+
_, vocab_size = logits.shape
|
53 |
+
|
54 |
+
if self.cfg_scale > 1.0:
|
55 |
+
logits_combined = logits
|
56 |
+
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
|
57 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_scale
|
58 |
+
logits = torch.cat([logits, logits], dim=0)
|
59 |
+
|
60 |
+
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
|
61 |
+
# have not been generated yet
|
62 |
+
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
|
63 |
+
|
64 |
+
# Prepare sampling tensors with pinned memory to avoid blocking.
|
65 |
+
(sampling_tensors, do_penalties, do_top_p_top_k,
|
66 |
+
do_min_p) = SamplingTensors.from_sampling_metadata(
|
67 |
+
sampling_metadata, vocab_size, logits.device, logits.dtype)
|
68 |
+
|
69 |
+
# Apply presence and frequency penalties.
|
70 |
+
if do_penalties:
|
71 |
+
logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
|
72 |
+
sampling_tensors.output_tokens,
|
73 |
+
sampling_tensors.presence_penalties,
|
74 |
+
sampling_tensors.frequency_penalties,
|
75 |
+
sampling_tensors.repetition_penalties)
|
76 |
+
|
77 |
+
# Apply temperature scaling.
|
78 |
+
# Use in-place division to avoid creating a new tensor.
|
79 |
+
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
|
80 |
+
|
81 |
+
if do_top_p_top_k:
|
82 |
+
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
|
83 |
+
sampling_tensors.top_ks)
|
84 |
+
|
85 |
+
if do_min_p:
|
86 |
+
logits = _apply_min_p(logits, sampling_tensors.min_ps)
|
87 |
+
|
88 |
+
# We use float32 for probabilities and log probabilities.
|
89 |
+
# Compute the probabilities.
|
90 |
+
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
91 |
+
# Compute the log probabilities.
|
92 |
+
# Use log_softmax to ensure numerical stability.
|
93 |
+
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
94 |
+
|
95 |
+
# Sample the next tokens.
|
96 |
+
sample_results, maybe_sampled_tokens_tensor = _sample(
|
97 |
+
probs,
|
98 |
+
logprobs,
|
99 |
+
sampling_metadata,
|
100 |
+
sampling_tensors,
|
101 |
+
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
|
102 |
+
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
if self.cfg_scale > 1.0:
|
107 |
+
cond_result = sample_results[:len(sample_results) // 2]
|
108 |
+
sample_results = cond_result + cond_result
|
109 |
+
|
110 |
+
|
111 |
+
if self.include_gpu_probs_tensor:
|
112 |
+
assert maybe_sampled_tokens_tensor is not None
|
113 |
+
sampled_tokens_tensor = maybe_sampled_tokens_tensor
|
114 |
+
on_device_tensors = (probs, sampled_tokens_tensor)
|
115 |
+
else:
|
116 |
+
on_device_tensors = None
|
117 |
+
|
118 |
+
# Get the logprobs query results.
|
119 |
+
prompt_logprobs, sample_logprobs = _get_logprobs(
|
120 |
+
logprobs, sampling_metadata, sample_results)
|
121 |
+
return _build_sampler_output(sample_results,
|
122 |
+
sampling_metadata,
|
123 |
+
prompt_logprobs,
|
124 |
+
sample_logprobs,
|
125 |
+
on_device_tensors=on_device_tensors)
|
126 |
+
|
127 |
+
@property
|
128 |
+
def _should_modify_greedy_probs_inplace(self) -> bool:
|
129 |
+
"""Whether or not the sampler should modify the probability distribution
|
130 |
+
of greedily-sampled tokens such that multinomial sampling would sample
|
131 |
+
the greedily-sampled token.
|
132 |
+
|
133 |
+
In other words, if True then we set the probability of the greedily-
|
134 |
+
sampled token to 1.
|
135 |
+
|
136 |
+
This is used by speculative decoding, which requires that the sampling
|
137 |
+
method be encoded into the probability distribution.
|
138 |
+
"""
|
139 |
+
# Modify greedy probs if include_gpu_probs_tensor is set.
|
140 |
+
return self.include_gpu_probs_tensor
|
141 |
+
|
142 |
+
|
143 |
+
def _get_bin_counts_and_mask(
|
144 |
+
tokens: torch.Tensor,
|
145 |
+
vocab_size: int,
|
146 |
+
num_seqs: int,
|
147 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
148 |
+
# Compute the bin counts for the tokens.
|
149 |
+
# vocab_size + 1 for padding.
|
150 |
+
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
151 |
+
dtype=torch.long,
|
152 |
+
device=tokens.device)
|
153 |
+
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
|
154 |
+
bin_counts = bin_counts[:, :vocab_size]
|
155 |
+
mask = bin_counts > 0
|
156 |
+
|
157 |
+
return bin_counts, mask
|
158 |
+
|
159 |
+
|
160 |
+
def _apply_min_tokens_penalty(
|
161 |
+
logits: torch.Tensor,
|
162 |
+
sampling_metadata: SamplingMetadata,
|
163 |
+
) -> torch.Tensor:
|
164 |
+
# list of indices in logits that will be set to -inf
|
165 |
+
logits_to_penalize = []
|
166 |
+
start_idx = 0
|
167 |
+
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
168 |
+
seq_ids, sampling_params = seq_group
|
169 |
+
|
170 |
+
# handle prompt_logprobs by skipping rows in logits added for the prompt
|
171 |
+
# tokens (prompt logprobs are not penalized)
|
172 |
+
if (i < sampling_metadata.num_prompts
|
173 |
+
and sampling_params.prompt_logprobs is not None):
|
174 |
+
assert len(seq_ids) == 1
|
175 |
+
start_idx += sampling_metadata.prompt_lens[i] - 1
|
176 |
+
|
177 |
+
min_tokens = sampling_params.min_tokens
|
178 |
+
if min_tokens > 0:
|
179 |
+
seqs_to_penalize = []
|
180 |
+
for i, seq_id in enumerate(seq_ids):
|
181 |
+
seq_data = sampling_metadata.seq_data[seq_id]
|
182 |
+
if len(seq_data.output_token_ids) < min_tokens:
|
183 |
+
seqs_to_penalize.append(i)
|
184 |
+
|
185 |
+
if seqs_to_penalize:
|
186 |
+
# convert to the index into logits
|
187 |
+
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
|
188 |
+
# use set() to remove any duplicates
|
189 |
+
token_ids_to_penalize = set(sampling_params.stop_token_ids +
|
190 |
+
[sampling_params.eos_token_id])
|
191 |
+
# itertools.product pairs each seq index with every token id
|
192 |
+
logits_to_penalize.extend(
|
193 |
+
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
194 |
+
|
195 |
+
start_idx += len(seq_ids)
|
196 |
+
|
197 |
+
if logits_to_penalize:
|
198 |
+
# use zip and * to group indices along each dimension
|
199 |
+
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
|
200 |
+
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
|
201 |
+
|
202 |
+
# verifies that no rows in logits were missed unexpectedly
|
203 |
+
assert start_idx == logits.shape[0]
|
204 |
+
return logits
|
205 |
+
|
206 |
+
|
207 |
+
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
208 |
+
output_tokens_tensor: torch.Tensor,
|
209 |
+
presence_penalties: torch.Tensor,
|
210 |
+
frequency_penalties: torch.Tensor,
|
211 |
+
repetition_penalties: torch.Tensor) -> torch.Tensor:
|
212 |
+
num_seqs, vocab_size = logits.shape
|
213 |
+
_, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
|
214 |
+
num_seqs)
|
215 |
+
output_bin_counts, output_mask = _get_bin_counts_and_mask(
|
216 |
+
output_tokens_tensor, vocab_size, num_seqs)
|
217 |
+
|
218 |
+
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
|
219 |
+
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
|
220 |
+
logits = torch.where(logits > 0, logits / repetition_penalties,
|
221 |
+
logits * repetition_penalties)
|
222 |
+
|
223 |
+
# We follow the definition in OpenAI API.
|
224 |
+
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
225 |
+
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
|
226 |
+
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
|
227 |
+
return logits
|
228 |
+
|
229 |
+
|
230 |
+
def _apply_top_k_top_p(
|
231 |
+
logits: torch.Tensor,
|
232 |
+
p: torch.Tensor,
|
233 |
+
k: torch.Tensor,
|
234 |
+
) -> torch.Tensor:
|
235 |
+
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
236 |
+
|
237 |
+
# Apply top-k.
|
238 |
+
top_k_mask = logits_sort.size(1) - k.to(torch.long)
|
239 |
+
# Get all the top_k values.
|
240 |
+
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
241 |
+
top_k_mask = logits_sort < top_k_mask
|
242 |
+
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
243 |
+
|
244 |
+
# Apply top-p.
|
245 |
+
probs_sort = logits_sort.softmax(dim=-1)
|
246 |
+
probs_sum = probs_sort.cumsum(dim=-1)
|
247 |
+
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
248 |
+
# at least one
|
249 |
+
top_p_mask[:, -1] = False
|
250 |
+
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
251 |
+
|
252 |
+
# Re-sort the probabilities.
|
253 |
+
src = torch.arange(logits_idx.shape[-1],
|
254 |
+
device=logits_idx.device).expand_as(logits_idx)
|
255 |
+
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
|
256 |
+
index=logits_idx,
|
257 |
+
src=src)
|
258 |
+
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
|
259 |
+
return logits
|
260 |
+
|
261 |
+
|
262 |
+
def _apply_min_p(
|
263 |
+
logits: torch.Tensor,
|
264 |
+
min_p: torch.Tensor,
|
265 |
+
) -> torch.Tensor:
|
266 |
+
"""
|
267 |
+
Adapted from
|
268 |
+
https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
|
269 |
+
"""
|
270 |
+
probs = torch.softmax(logits, dim=-1)
|
271 |
+
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
272 |
+
scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
|
273 |
+
tokens_to_remove = probs < scaled_min_p
|
274 |
+
logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
|
275 |
+
|
276 |
+
return logits
|
277 |
+
|
278 |
+
|
279 |
+
def _greedy_sample(
|
280 |
+
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
281 |
+
samples: torch.Tensor,
|
282 |
+
) -> List[Tuple[List[int], List[int]]]:
|
283 |
+
samples = samples.tolist()
|
284 |
+
sample_idx = 0
|
285 |
+
results = []
|
286 |
+
for seq_group in selected_seq_groups:
|
287 |
+
seq_ids, _ = seq_group
|
288 |
+
num_parent_seqs = len(seq_ids)
|
289 |
+
assert num_parent_seqs == 1, (
|
290 |
+
"Greedy sampling should have only one seq.")
|
291 |
+
parent_ids = list(range(num_parent_seqs))
|
292 |
+
next_token_ids = [samples[sample_idx]]
|
293 |
+
results.append((next_token_ids, parent_ids))
|
294 |
+
sample_idx += num_parent_seqs
|
295 |
+
return results
|
296 |
+
|
297 |
+
|
298 |
+
def _random_sample(
|
299 |
+
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
300 |
+
is_prompts: List[bool],
|
301 |
+
random_samples: torch.Tensor,
|
302 |
+
) -> List[Tuple[List[int], List[int]]]:
|
303 |
+
# Find the maximum best_of value of the prompt phase requests.
|
304 |
+
random_samples = random_samples.cpu()
|
305 |
+
sample_idx = 0
|
306 |
+
results = []
|
307 |
+
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
308 |
+
seq_ids, sampling_params = seq_group
|
309 |
+
num_parent_seqs = len(seq_ids)
|
310 |
+
if is_prompt:
|
311 |
+
# Prompt phase.
|
312 |
+
parent_ids = [0] * sampling_params.best_of
|
313 |
+
next_token_ids = random_samples[
|
314 |
+
sample_idx, :sampling_params.best_of].tolist()
|
315 |
+
else:
|
316 |
+
# Generation phase.
|
317 |
+
parent_ids = list(range(num_parent_seqs))
|
318 |
+
next_token_ids = random_samples[sample_idx:sample_idx +
|
319 |
+
num_parent_seqs, 0].tolist()
|
320 |
+
results.append((next_token_ids, parent_ids))
|
321 |
+
sample_idx += num_parent_seqs
|
322 |
+
return results
|
323 |
+
|
324 |
+
|
325 |
+
def _beam_search_sample(
|
326 |
+
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
327 |
+
is_prompts: List[bool],
|
328 |
+
seq_data: Dict[int, SequenceData],
|
329 |
+
logprobs: torch.Tensor,
|
330 |
+
) -> List[Tuple[List[int], List[int]]]:
|
331 |
+
# We sample 2 * beam_width candidates to make sure that with high
|
332 |
+
# probability we can get `beam_width` candidates in addition to
|
333 |
+
# the finished sequences for the next iteration. See
|
334 |
+
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
335 |
+
# for details. See also HF reference:
|
336 |
+
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
337 |
+
#
|
338 |
+
# NOTE: Beam search is not vectorized, so its speed can be slower than
|
339 |
+
# other sampling methods.
|
340 |
+
sample_idx = 0
|
341 |
+
results = []
|
342 |
+
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
343 |
+
seq_ids, sampling_params = seq_group
|
344 |
+
num_parent_seqs = len(seq_ids)
|
345 |
+
beam_width = sampling_params.best_of
|
346 |
+
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
347 |
+
if is_prompt:
|
348 |
+
# Prompt phase.
|
349 |
+
assert num_parent_seqs == 1, (
|
350 |
+
"Prompt input should have only one seq.")
|
351 |
+
parent_ids = [0] * (2 * beam_width)
|
352 |
+
_, next_token_ids = torch.topk(seq_group_logprobs[0],
|
353 |
+
2 * beam_width)
|
354 |
+
next_token_ids = next_token_ids.tolist()
|
355 |
+
else:
|
356 |
+
# Generation phase.
|
357 |
+
cumulative_logprobs = [
|
358 |
+
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
|
359 |
+
]
|
360 |
+
cumulative_logprobs = torch.tensor(
|
361 |
+
cumulative_logprobs,
|
362 |
+
dtype=torch.float,
|
363 |
+
device=seq_group_logprobs.device)
|
364 |
+
seq_group_logprobs = (seq_group_logprobs +
|
365 |
+
cumulative_logprobs.unsqueeze(dim=1))
|
366 |
+
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
|
367 |
+
2 * beam_width)
|
368 |
+
topk_ids = topk_ids.tolist()
|
369 |
+
vocab_size = seq_group_logprobs.size(-1)
|
370 |
+
parent_ids = [i // vocab_size for i in topk_ids]
|
371 |
+
next_token_ids = [i % vocab_size for i in topk_ids]
|
372 |
+
results.append((next_token_ids, parent_ids))
|
373 |
+
sample_idx += num_parent_seqs
|
374 |
+
assert sample_idx == logprobs.size(0)
|
375 |
+
return results
|
376 |
+
|
377 |
+
|
378 |
+
# torch.multinomial forces a GPU<->CPU sync.
|
379 |
+
# Therefore, we use an optimized implementation instead.
|
380 |
+
# Note that we always sample with replacement.
|
381 |
+
# probs will be modified in place, but this is fine, as we pass
|
382 |
+
# in a copy already.
|
383 |
+
def _multinomial(
|
384 |
+
probs: torch.Tensor,
|
385 |
+
num_samples: int,
|
386 |
+
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
|
387 |
+
generators: Optional[List[torch.Generator]] = None,
|
388 |
+
) -> torch.Tensor:
|
389 |
+
if num_samples > 1:
|
390 |
+
# This is equivalent to torch.repeat_interleaved (which also
|
391 |
+
# forces a GPU<->CPU sync).
|
392 |
+
# This allows us to do sampling with replacement by creating
|
393 |
+
# num_samples copies of each row in the tensor, and then
|
394 |
+
# batch sampling the resulting tensor.
|
395 |
+
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
|
396 |
+
probs.shape[1]).contiguous().view(
|
397 |
+
-1, probs.shape[1])
|
398 |
+
q = torch.empty_like(probs)
|
399 |
+
if seq_groups is None:
|
400 |
+
q.exponential_()
|
401 |
+
else:
|
402 |
+
sample_idx = 0
|
403 |
+
for (seq_ids, _), generator in zip(seq_groups, generators):
|
404 |
+
next_sample_idx = sample_idx + len(seq_ids) * num_samples
|
405 |
+
q[sample_idx:next_sample_idx].exponential_(generator=generator)
|
406 |
+
sample_idx = next_sample_idx
|
407 |
+
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
408 |
+
|
409 |
+
|
410 |
+
def _sample_with_torch(
|
411 |
+
probs: torch.Tensor,
|
412 |
+
logprobs: torch.Tensor,
|
413 |
+
sampling_metadata: SamplingMetadata,
|
414 |
+
include_gpu_probs_tensor: bool,
|
415 |
+
modify_greedy_probs: bool,
|
416 |
+
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
|
417 |
+
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
418 |
+
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
419 |
+
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
420 |
+
_, sampling_params = seq_group
|
421 |
+
sampling_type = sampling_params.sampling_type
|
422 |
+
categorized_seq_group_ids[sampling_type].append(i)
|
423 |
+
|
424 |
+
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
425 |
+
sample_metadata = {}
|
426 |
+
multinomial_samples = {}
|
427 |
+
|
428 |
+
# Create output tensor for sampled token ids.
|
429 |
+
if include_gpu_probs_tensor:
|
430 |
+
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
|
431 |
+
1,
|
432 |
+
dtype=torch.long,
|
433 |
+
device=logprobs.device)
|
434 |
+
else:
|
435 |
+
sampled_token_ids_tensor = None
|
436 |
+
|
437 |
+
# Counterintiutively, having two loops here is actually faster.
|
438 |
+
# The first loop can run without waiting on GPU<->CPU sync.
|
439 |
+
for sampling_type in SamplingType:
|
440 |
+
sample_indices = categorized_sample_indices[sampling_type][:, 0]
|
441 |
+
num_tokens = len(sample_indices)
|
442 |
+
if num_tokens == 0:
|
443 |
+
continue
|
444 |
+
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
445 |
+
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
446 |
+
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
447 |
+
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
|
448 |
+
is_prompts, sample_indices)
|
449 |
+
long_sample_indices = sample_indices.long()
|
450 |
+
|
451 |
+
if sampling_type == SamplingType.GREEDY:
|
452 |
+
greedy_samples = torch.argmax(logprobs[long_sample_indices],
|
453 |
+
dim=-1)
|
454 |
+
|
455 |
+
if include_gpu_probs_tensor:
|
456 |
+
# Store sampled tokens in output tensor.
|
457 |
+
sampled_token_ids_tensor[
|
458 |
+
long_sample_indices] = greedy_samples.unsqueeze(-1)
|
459 |
+
|
460 |
+
if modify_greedy_probs:
|
461 |
+
# If required, modify the probabilities such that sampling from
|
462 |
+
# the modified distribution would always sample the argmax
|
463 |
+
# token id.
|
464 |
+
_modify_greedy_probs_inplace(logprobs, probs,
|
465 |
+
long_sample_indices,
|
466 |
+
greedy_samples)
|
467 |
+
|
468 |
+
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
469 |
+
max_best_of_in_batch = 1
|
470 |
+
for seq_group, is_prompt in zip(seq_groups, is_prompts):
|
471 |
+
if is_prompt:
|
472 |
+
_, sampling_params = seq_group
|
473 |
+
max_best_of_in_batch = max(max_best_of_in_batch,
|
474 |
+
sampling_params.best_of)
|
475 |
+
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
|
476 |
+
"seq_groups": seq_groups,
|
477 |
+
"generators": sampling_metadata.generators,
|
478 |
+
}
|
479 |
+
|
480 |
+
multinomial_samples[sampling_type] = _multinomial(
|
481 |
+
probs[long_sample_indices], max_best_of_in_batch,
|
482 |
+
**seeded_args)
|
483 |
+
|
484 |
+
if include_gpu_probs_tensor:
|
485 |
+
# Store sampled tokens in output tensor.
|
486 |
+
sampled_token_ids_tensor[
|
487 |
+
long_sample_indices] = multinomial_samples[sampling_type]
|
488 |
+
|
489 |
+
elif sampling_type == SamplingType.BEAM:
|
490 |
+
beam_search_logprobs = logprobs[sample_indices]
|
491 |
+
else:
|
492 |
+
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
493 |
+
|
494 |
+
# GPU<->CPU sync happens in the loop below.
|
495 |
+
# This also converts the sample output to Python objects.
|
496 |
+
|
497 |
+
for sampling_type in SamplingType:
|
498 |
+
if sampling_type not in sample_metadata:
|
499 |
+
continue
|
500 |
+
seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
|
501 |
+
sampling_type]
|
502 |
+
if sampling_type == SamplingType.GREEDY:
|
503 |
+
sample_results = _greedy_sample(seq_groups, greedy_samples)
|
504 |
+
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
505 |
+
sample_results = _random_sample(seq_groups, is_prompts,
|
506 |
+
multinomial_samples[sampling_type])
|
507 |
+
elif sampling_type == SamplingType.BEAM:
|
508 |
+
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
509 |
+
sampling_metadata.seq_data,
|
510 |
+
beam_search_logprobs)
|
511 |
+
sample_results_dict.update(zip(seq_group_ids, sample_results))
|
512 |
+
|
513 |
+
sample_results = [
|
514 |
+
sample_results_dict[i]
|
515 |
+
for i in range(len(sampling_metadata.seq_groups))
|
516 |
+
]
|
517 |
+
return sample_results, sampled_token_ids_tensor
|
518 |
+
|
519 |
+
|
520 |
+
def _sample_with_triton_kernel(
|
521 |
+
probs: torch.Tensor,
|
522 |
+
logprobs: torch.Tensor,
|
523 |
+
sampling_metadata: SamplingMetadata,
|
524 |
+
sampling_tensors: SamplingTensors,
|
525 |
+
) -> List[Tuple[List[int], List[int]]]:
|
526 |
+
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
527 |
+
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
528 |
+
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
529 |
+
_, sampling_params = seq_group
|
530 |
+
sampling_type = sampling_params.sampling_type
|
531 |
+
categorized_seq_group_ids[sampling_type].append(i)
|
532 |
+
|
533 |
+
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
534 |
+
sample_metadata = {}
|
535 |
+
max_best_of_in_batch = 1
|
536 |
+
|
537 |
+
# Counterintiutively, having two loops here is actually faster.
|
538 |
+
# The first loop can run without waiting on GPU<->CPU sync.
|
539 |
+
for sampling_type in SamplingType:
|
540 |
+
sample_indices = categorized_sample_indices[sampling_type][:, 0]
|
541 |
+
sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
|
542 |
+
num_tokens = len(sample_indices)
|
543 |
+
if num_tokens == 0:
|
544 |
+
continue
|
545 |
+
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
546 |
+
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
547 |
+
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
548 |
+
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
|
549 |
+
is_prompts, sample_indices,
|
550 |
+
sampled_token_indices)
|
551 |
+
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
|
552 |
+
SamplingType.RANDOM_SEED):
|
553 |
+
for seq_group, is_prompt in zip(seq_groups, is_prompts):
|
554 |
+
if is_prompt:
|
555 |
+
_, sampling_params = seq_group
|
556 |
+
max_best_of_in_batch = max(max_best_of_in_batch,
|
557 |
+
sampling_params.best_of)
|
558 |
+
elif sampling_type == SamplingType.BEAM:
|
559 |
+
beam_search_logprobs = logprobs[sample_indices]
|
560 |
+
else:
|
561 |
+
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
562 |
+
|
563 |
+
sampled_tokens, _, _ = sample_triton(
|
564 |
+
probs=probs,
|
565 |
+
seeds=sampling_tensors.sampling_seeds,
|
566 |
+
max_best_of=max_best_of_in_batch,
|
567 |
+
sample_indices=sampling_tensors.sample_indices,
|
568 |
+
logprobs=logprobs,
|
569 |
+
# don't save logprobs because we have logic for that below
|
570 |
+
# TODO: use this instead of the CPU-based logic below
|
571 |
+
save_logprobs=False,
|
572 |
+
)
|
573 |
+
|
574 |
+
# GPU<->CPU sync happens in the loop below.
|
575 |
+
|
576 |
+
for sampling_type in SamplingType:
|
577 |
+
if sampling_type not in sample_metadata:
|
578 |
+
continue
|
579 |
+
(seq_group_ids, seq_groups, is_prompts, sample_indices,
|
580 |
+
sampled_token_indices) = sample_metadata[sampling_type]
|
581 |
+
if sampling_type == SamplingType.GREEDY:
|
582 |
+
sample_results = _greedy_sample(
|
583 |
+
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
|
584 |
+
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
585 |
+
sample_results = _random_sample(
|
586 |
+
seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
|
587 |
+
elif sampling_type == SamplingType.BEAM:
|
588 |
+
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
589 |
+
sampling_metadata.seq_data,
|
590 |
+
beam_search_logprobs)
|
591 |
+
sample_results_dict.update(zip(seq_group_ids, sample_results))
|
592 |
+
|
593 |
+
sample_results = [
|
594 |
+
sample_results_dict[i]
|
595 |
+
for i in range(len(sampling_metadata.seq_groups))
|
596 |
+
]
|
597 |
+
return sample_results
|
598 |
+
|
599 |
+
|
600 |
+
def _sample(
|
601 |
+
probs: torch.Tensor, logprobs: torch.Tensor,
|
602 |
+
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
|
603 |
+
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
|
604 |
+
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
|
605 |
+
return _sample_with_torch(
|
606 |
+
probs,
|
607 |
+
logprobs,
|
608 |
+
sampling_metadata,
|
609 |
+
include_gpu_probs_tensor=include_gpu_probs_tensor,
|
610 |
+
modify_greedy_probs=modify_greedy_probs,
|
611 |
+
)
|
612 |
+
|
613 |
+
# TODO: Enable once Triton kernel & associated code is faster.
|
614 |
+
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
|
615 |
+
# sampling_tensors)
|
616 |
+
|
617 |
+
|
618 |
+
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
619 |
+
"""
|
620 |
+
This function calculates the ranks of the chosen tokens in a logprob tensor.
|
621 |
+
|
622 |
+
Args:
|
623 |
+
x (torch.Tensor): 2D logprob tensor of shape (N, M)
|
624 |
+
where N is the no. of tokens and M is the vocab dim.
|
625 |
+
indices (torch.Tensor): List of chosen token indices.
|
626 |
+
|
627 |
+
Returns:
|
628 |
+
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
|
629 |
+
Each element in the returned tensor represents the rank
|
630 |
+
of the chosen token in the input logprob tensor.
|
631 |
+
"""
|
632 |
+
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
|
633 |
+
indices]
|
634 |
+
return (x > vals[:, None]).long().sum(1).add_(1)
|
635 |
+
|
636 |
+
|
637 |
+
def _get_logprobs(
|
638 |
+
logprobs: torch.Tensor,
|
639 |
+
sampling_metadata: SamplingMetadata,
|
640 |
+
sample_results: List[Tuple[List[int], List[int]]],
|
641 |
+
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
|
642 |
+
int, float]]]]:
|
643 |
+
# Prepare query indices
|
644 |
+
batched_logprobs_query_seq_indices: List[int] = []
|
645 |
+
batched_logprobs_query_token_indices: List[int] = []
|
646 |
+
# at least get one logprob for each token
|
647 |
+
largest_num_logprobs = 1
|
648 |
+
sample_idx = 0
|
649 |
+
for i, (seq_group, sample_result) in enumerate(
|
650 |
+
zip(sampling_metadata.seq_groups, sample_results)):
|
651 |
+
seq_ids, sampling_params = seq_group
|
652 |
+
next_token_ids, parent_ids = sample_result
|
653 |
+
num_parent_seqs = len(seq_ids)
|
654 |
+
if (i < sampling_metadata.num_prompts
|
655 |
+
and sampling_params.prompt_logprobs is not None):
|
656 |
+
largest_num_logprobs = max(largest_num_logprobs,
|
657 |
+
sampling_params.prompt_logprobs)
|
658 |
+
prompt_len = sampling_metadata.prompt_lens[i]
|
659 |
+
prompt_tokens = sampling_metadata.seq_data[
|
660 |
+
seq_ids[0]].prompt_token_ids
|
661 |
+
batched_logprobs_query_seq_indices.extend(
|
662 |
+
sample_idx + j for j in range(prompt_len - 1))
|
663 |
+
batched_logprobs_query_token_indices.extend(
|
664 |
+
token_id for token_id in prompt_tokens[1:])
|
665 |
+
sample_idx += prompt_len - 1
|
666 |
+
batched_logprobs_query_seq_indices.extend(
|
667 |
+
[sample_idx + parent_id for parent_id in parent_ids])
|
668 |
+
batched_logprobs_query_token_indices.extend(next_token_ids)
|
669 |
+
if sampling_params.logprobs is not None:
|
670 |
+
largest_num_logprobs = max(largest_num_logprobs,
|
671 |
+
sampling_params.logprobs)
|
672 |
+
sample_idx += num_parent_seqs
|
673 |
+
assert sample_idx == logprobs.size(0)
|
674 |
+
|
675 |
+
batched_logprobs_query_seq_indices_gpu = torch.tensor(
|
676 |
+
batched_logprobs_query_seq_indices, device=logprobs.device)
|
677 |
+
batched_logprobs_query_token_indices_gpu = torch.tensor(
|
678 |
+
batched_logprobs_query_token_indices, device=logprobs.device)
|
679 |
+
|
680 |
+
# Batched query for logprobs of selected token
|
681 |
+
batched_logprobs_query_result = logprobs[[
|
682 |
+
batched_logprobs_query_seq_indices_gpu,
|
683 |
+
batched_logprobs_query_token_indices_gpu
|
684 |
+
]]
|
685 |
+
|
686 |
+
batched_ranks_query_result = _get_ranks(
|
687 |
+
logprobs[batched_logprobs_query_seq_indices_gpu],
|
688 |
+
batched_logprobs_query_token_indices_gpu)
|
689 |
+
|
690 |
+
# Batched query for logprobs of topk tokens
|
691 |
+
if largest_num_logprobs > 0:
|
692 |
+
top_logprobs, top_token_ids = torch.topk(logprobs,
|
693 |
+
largest_num_logprobs,
|
694 |
+
dim=-1)
|
695 |
+
top_logprobs = top_logprobs.cpu()
|
696 |
+
top_token_ids = top_token_ids.cpu()
|
697 |
+
else:
|
698 |
+
top_logprobs, top_token_ids = None, None
|
699 |
+
|
700 |
+
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
|
701 |
+
batched_ranks_query_result = batched_ranks_query_result.cpu()
|
702 |
+
|
703 |
+
# Gather results
|
704 |
+
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
|
705 |
+
result_sample_logprobs: List[SampleLogprobs] = []
|
706 |
+
sample_idx = 0
|
707 |
+
query_result_idx = 0
|
708 |
+
for i, (seq_group, sample_result) in enumerate(
|
709 |
+
zip(sampling_metadata.seq_groups, sample_results)):
|
710 |
+
seq_ids, sampling_params = seq_group
|
711 |
+
next_token_ids, parent_ids = sample_result
|
712 |
+
|
713 |
+
# Prompt logprobs
|
714 |
+
if (i < sampling_metadata.num_prompts
|
715 |
+
and sampling_params.prompt_logprobs is not None):
|
716 |
+
num_logprobs = sampling_params.prompt_logprobs
|
717 |
+
prompt_tokens = sampling_metadata.seq_data[
|
718 |
+
seq_ids[0]].prompt_token_ids
|
719 |
+
group_prompt_logprobs: PromptLogprobs = [None]
|
720 |
+
for token_id in prompt_tokens[1:]:
|
721 |
+
prompt_logprobs_dict = {
|
722 |
+
token_id:
|
723 |
+
(batched_logprobs_query_result[query_result_idx].item(),
|
724 |
+
batched_ranks_query_result[query_result_idx].item())
|
725 |
+
}
|
726 |
+
if num_logprobs > 0:
|
727 |
+
prompt_logprobs_dict.update(
|
728 |
+
zip(
|
729 |
+
top_token_ids[sample_idx, :num_logprobs].tolist(),
|
730 |
+
zip(
|
731 |
+
top_logprobs[
|
732 |
+
sample_idx, :num_logprobs].tolist(),
|
733 |
+
range(1, num_logprobs + 1))))
|
734 |
+
group_prompt_logprobs.append({
|
735 |
+
token_id: Logprob(*logprob_rank)
|
736 |
+
for token_id, logprob_rank in prompt_logprobs_dict.items()
|
737 |
+
})
|
738 |
+
sample_idx += 1
|
739 |
+
query_result_idx += 1
|
740 |
+
result_prompt_logprobs.append(group_prompt_logprobs)
|
741 |
+
else:
|
742 |
+
result_prompt_logprobs.append(None)
|
743 |
+
|
744 |
+
# Sample logprobs
|
745 |
+
num_logprobs = sampling_params.logprobs
|
746 |
+
if num_logprobs is None:
|
747 |
+
num_logprobs = 0
|
748 |
+
group_sample_logprobs: SampleLogprobs = []
|
749 |
+
for next_token_id, parent_id in zip(next_token_ids, parent_ids):
|
750 |
+
sample_logprobs_dict = {
|
751 |
+
next_token_id:
|
752 |
+
(batched_logprobs_query_result[query_result_idx].item(),
|
753 |
+
batched_ranks_query_result[query_result_idx].item())
|
754 |
+
}
|
755 |
+
query_result_idx += 1
|
756 |
+
if num_logprobs >= 0:
|
757 |
+
sample_logprobs_dict.update(
|
758 |
+
zip(
|
759 |
+
top_token_ids[sample_idx +
|
760 |
+
parent_id, :num_logprobs].tolist(),
|
761 |
+
zip(
|
762 |
+
top_logprobs[sample_idx +
|
763 |
+
parent_id, :num_logprobs].tolist(),
|
764 |
+
range(1, num_logprobs + 1))))
|
765 |
+
group_sample_logprobs.append({
|
766 |
+
token_id: Logprob(*logprob_rank)
|
767 |
+
for token_id, logprob_rank in sample_logprobs_dict.items()
|
768 |
+
})
|
769 |
+
result_sample_logprobs.append(group_sample_logprobs)
|
770 |
+
sample_idx += len(seq_ids)
|
771 |
+
|
772 |
+
return result_prompt_logprobs, result_sample_logprobs
|
773 |
+
|
774 |
+
|
775 |
+
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
|
776 |
+
sample_indices: torch.Tensor,
|
777 |
+
greedy_samples: torch.Tensor) -> None:
|
778 |
+
"""Modify the probability distributions of the greedily-sampled tokens such
|
779 |
+
that each sampled token has a "probability" of 1.0. This is required by
|
780 |
+
speculative decoding, which depends on the sampling method being encoded
|
781 |
+
within the probability distribution for correctness.
|
782 |
+
|
783 |
+
# Why do we only need to do this for greedy sampling?
|
784 |
+
|
785 |
+
vLLM's sampler performs the following steps for greedy or multinomial
|
786 |
+
(random) sampling:
|
787 |
+
1. Get logits from model.
|
788 |
+
2. Modify logits according to per-sequence sampling parameters.
|
789 |
+
- Multiply by temperature, top-k and top-p masking, penalize tokens
|
790 |
+
according to their frequency, etc.
|
791 |
+
3. Sample a token.
|
792 |
+
- Random sampling simply samples from the modified probability
|
793 |
+
distribution.
|
794 |
+
- Greedy sampling performs `argmax` to obtain the token with the
|
795 |
+
highest likelihood.
|
796 |
+
|
797 |
+
Ignoring greedy sampling for a moment, we find that the computed probability
|
798 |
+
distribution has the following property: we can sample from it independently
|
799 |
+
and find that the token sampled by the Sampler has a frequency corresponding
|
800 |
+
to how often we see it in our sampling. In other words, for tokens sampled
|
801 |
+
with vLLM's random SamplingType, the computed probability distribution
|
802 |
+
encodes the sampling methodology completely.
|
803 |
+
|
804 |
+
Greedy sampling does not normally have this property. vLLM modifies logits
|
805 |
+
according to sampling params, then performs `argmax`, then returns the
|
806 |
+
sampled token and the computed probability distribution. If we sample from
|
807 |
+
the distribution, we'll find the likelihood of the greedily-sampled token
|
808 |
+
is not always 1.0.
|
809 |
+
|
810 |
+
Since lossless speculative decoding requires that the sampling methodology
|
811 |
+
be encoded within the probability distribution, we are motivated to modify
|
812 |
+
the probability distribution such that the sampled token has probability 1
|
813 |
+
when speculative decoding is used.
|
814 |
+
|
815 |
+
NOTE: Alternatively, we could use an extremely low temperature to achieve
|
816 |
+
greedy sampling using multinomial computation and unite the codepaths. This
|
817 |
+
has implications on the overall design of the sampler, e.g. how to record
|
818 |
+
accurate logprobs for the user, so this improvement is deferred to later.
|
819 |
+
"""
|
820 |
+
logprobs[sample_indices, :] = -float('inf')
|
821 |
+
logprobs[sample_indices, greedy_samples] = 0.0
|
822 |
+
probs[sample_indices, :] = 0
|
823 |
+
probs[sample_indices, greedy_samples] = 1.0
|
824 |
+
|
825 |
+
|
826 |
+
def _build_sampler_output(
|
827 |
+
sample_results: List[Tuple[List[int], List[int]]],
|
828 |
+
sampling_metadata: SamplingMetadata,
|
829 |
+
prompt_logprobs: List[Optional[PromptLogprobs]],
|
830 |
+
sample_logprobs: List[SampleLogprobs],
|
831 |
+
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
832 |
+
) -> SamplerOutput:
|
833 |
+
"""Construct Python objects with the output of sampling.
|
834 |
+
|
835 |
+
Args:
|
836 |
+
on_device_tensors: Tuple containing on-device tensors with the
|
837 |
+
probabilities used in sampling and the sampled token ids. This
|
838 |
+
allows post-processing without copies to CPU/serialization, e.g. in
|
839 |
+
speculative decoding rejection sampling.
|
840 |
+
"""
|
841 |
+
|
842 |
+
sampler_output = []
|
843 |
+
for (seq_group, sample_result, group_prompt_logprobs,
|
844 |
+
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
|
845 |
+
sample_results, prompt_logprobs,
|
846 |
+
sample_logprobs):
|
847 |
+
seq_ids, _ = seq_group
|
848 |
+
next_token_ids, parent_ids = sample_result
|
849 |
+
seq_outputs = []
|
850 |
+
for parent_id, next_token_id, logprobs in zip(parent_ids,
|
851 |
+
next_token_ids,
|
852 |
+
group_sample_logprobs):
|
853 |
+
seq_outputs.append(
|
854 |
+
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
|
855 |
+
sampler_output.append(
|
856 |
+
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
857 |
+
|
858 |
+
# If not specified, store None values in SamplerOutput.
|
859 |
+
if on_device_tensors is not None:
|
860 |
+
sampled_token_probs, sampled_token_ids = on_device_tensors
|
861 |
+
else:
|
862 |
+
sampled_token_probs, sampled_token_ids = (None, None)
|
863 |
+
|
864 |
+
return SamplerOutput(
|
865 |
+
outputs=sampler_output,
|
866 |
+
sampled_token_probs=sampled_token_probs,
|
867 |
+
sampled_token_ids=sampled_token_ids,
|
868 |
+
)
|
serve/worker.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A GPU worker class."""
|
2 |
+
import gc
|
3 |
+
import os
|
4 |
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed
|
8 |
+
|
9 |
+
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
10 |
+
ModelConfig, ParallelConfig, SchedulerConfig,
|
11 |
+
VisionLanguageConfig)
|
12 |
+
from vllm.distributed import (broadcast_tensor_dict,
|
13 |
+
ensure_model_parallel_initialized,
|
14 |
+
init_distributed_environment)
|
15 |
+
from vllm.distributed.device_communicators import pynccl_utils
|
16 |
+
from vllm.distributed.device_communicators.custom_all_reduce import (
|
17 |
+
init_custom_ar)
|
18 |
+
from vllm.lora.request import LoRARequest
|
19 |
+
from vllm.model_executor import set_random_seed
|
20 |
+
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
21 |
+
from vllm.worker.cache_engine import CacheEngine
|
22 |
+
# from vllm.worker.model_runner import ModelRunner
|
23 |
+
from vllm.worker.worker_base import WorkerBase
|
24 |
+
from serve.model_runner import ModelRunner
|
25 |
+
|
26 |
+
|
27 |
+
class Worker(WorkerBase):
|
28 |
+
"""A worker class that executes (a partition of) the model on a GPU.
|
29 |
+
|
30 |
+
Each worker is associated with a single GPU. The worker is responsible for
|
31 |
+
maintaining the KV cache and executing the model on the GPU. In case of
|
32 |
+
distributed inference, each worker is assigned a partition of the model.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
model_config: ModelConfig,
|
38 |
+
parallel_config: ParallelConfig,
|
39 |
+
scheduler_config: SchedulerConfig,
|
40 |
+
device_config: DeviceConfig,
|
41 |
+
cache_config: CacheConfig,
|
42 |
+
load_config: LoadConfig,
|
43 |
+
local_rank: int,
|
44 |
+
rank: int,
|
45 |
+
distributed_init_method: str,
|
46 |
+
lora_config: Optional[LoRAConfig] = None,
|
47 |
+
vision_language_config: Optional[VisionLanguageConfig] = None,
|
48 |
+
is_driver_worker: bool = False,
|
49 |
+
) -> None:
|
50 |
+
self.model_config = model_config
|
51 |
+
self.parallel_config = parallel_config
|
52 |
+
self.scheduler_config = scheduler_config
|
53 |
+
self.device_config = device_config
|
54 |
+
self.cache_config = cache_config
|
55 |
+
self.local_rank = local_rank
|
56 |
+
self.rank = rank
|
57 |
+
self.distributed_init_method = distributed_init_method
|
58 |
+
self.lora_config = lora_config
|
59 |
+
self.load_config = load_config
|
60 |
+
self.is_driver_worker = is_driver_worker
|
61 |
+
if self.is_driver_worker:
|
62 |
+
assert self.rank == 0, "The driver worker must have rank 0."
|
63 |
+
|
64 |
+
if self.model_config.trust_remote_code:
|
65 |
+
# note: lazy import to avoid importing torch before initializing
|
66 |
+
from vllm.utils import init_cached_hf_modules
|
67 |
+
init_cached_hf_modules()
|
68 |
+
self.vision_language_config = vision_language_config
|
69 |
+
if self.vision_language_config:
|
70 |
+
assert not self.lora_config, (
|
71 |
+
"To be tested: vision language model with LoRA settings.")
|
72 |
+
|
73 |
+
self.model_runner = ModelRunner(
|
74 |
+
model_config,
|
75 |
+
parallel_config,
|
76 |
+
scheduler_config,
|
77 |
+
device_config,
|
78 |
+
load_config=load_config,
|
79 |
+
lora_config=self.lora_config,
|
80 |
+
kv_cache_dtype=self.cache_config.cache_dtype,
|
81 |
+
is_driver_worker=is_driver_worker,
|
82 |
+
vision_language_config=vision_language_config,
|
83 |
+
)
|
84 |
+
# Uninitialized cache engine. Will be initialized by
|
85 |
+
# initialize_cache.
|
86 |
+
self.cache_engine: CacheEngine
|
87 |
+
self.gpu_cache: List[torch.Tensor]
|
88 |
+
|
89 |
+
def init_device(self) -> None:
|
90 |
+
if self.device_config.device.type == "cuda":
|
91 |
+
# torch.distributed.all_reduce does not free the input tensor until
|
92 |
+
# the synchronization point. This causes the memory usage to grow
|
93 |
+
# as the number of all_reduce calls increases. This env var disables
|
94 |
+
# this behavior.
|
95 |
+
# Related issue:
|
96 |
+
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
97 |
+
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
98 |
+
|
99 |
+
# This env var set by Ray causes exceptions with graph building.
|
100 |
+
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
101 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
102 |
+
torch.cuda.set_device(self.device)
|
103 |
+
|
104 |
+
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
105 |
+
torch.cuda.empty_cache()
|
106 |
+
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
|
107 |
+
else:
|
108 |
+
raise RuntimeError(
|
109 |
+
f"Not support device type: {self.device_config.device}")
|
110 |
+
# Initialize the distributed environment.
|
111 |
+
init_worker_distributed_environment(self.parallel_config, self.rank,
|
112 |
+
self.distributed_init_method,
|
113 |
+
self.local_rank)
|
114 |
+
# Set random seed.
|
115 |
+
set_random_seed(self.model_config.seed)
|
116 |
+
|
117 |
+
def load_model(self, args):
|
118 |
+
self.model_runner.load_model(args)
|
119 |
+
|
120 |
+
@torch.inference_mode()
|
121 |
+
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
122 |
+
"""Profiles the peak memory usage of the model to determine how many
|
123 |
+
KV blocks may be allocated without OOMs.
|
124 |
+
|
125 |
+
The engine will first conduct a profiling of the existing memory usage.
|
126 |
+
Then, it calculate the maximum possible number of GPU and CPU blocks
|
127 |
+
that can be allocated with the remaining free memory.
|
128 |
+
|
129 |
+
.. tip::
|
130 |
+
You may limit the usage of GPU memory
|
131 |
+
by adjusting the `gpu_memory_utilization` parameter.
|
132 |
+
"""
|
133 |
+
# Profile the memory usage of the model and get the maximum number of
|
134 |
+
# cache blocks that can be allocated with the remaining free memory.
|
135 |
+
torch.cuda.empty_cache()
|
136 |
+
|
137 |
+
# Execute a forward pass with dummy inputs to profile the memory usage
|
138 |
+
# of the model.
|
139 |
+
self.model_runner.profile_run()
|
140 |
+
|
141 |
+
# Calculate the number of blocks that can be allocated with the
|
142 |
+
# profiled peak memory.
|
143 |
+
torch.cuda.synchronize()
|
144 |
+
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
|
145 |
+
# NOTE(woosuk): Here we assume that the other processes using the same
|
146 |
+
# GPU did not change their memory usage during the profiling.
|
147 |
+
peak_memory = self.init_gpu_memory - free_gpu_memory
|
148 |
+
assert peak_memory > 0, (
|
149 |
+
"Error in memory profiling. This happens when the GPU memory was "
|
150 |
+
"not properly cleaned up before initializing the vLLM instance.")
|
151 |
+
|
152 |
+
cache_block_size = self.get_cache_block_size_bytes()
|
153 |
+
num_gpu_blocks = int(
|
154 |
+
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
155 |
+
peak_memory) // cache_block_size)
|
156 |
+
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
157 |
+
cache_block_size)
|
158 |
+
num_gpu_blocks = max(num_gpu_blocks, 0)
|
159 |
+
num_cpu_blocks = max(num_cpu_blocks, 0)
|
160 |
+
if self.model_runner.lora_manager:
|
161 |
+
self.model_runner.remove_all_loras()
|
162 |
+
gc.collect()
|
163 |
+
torch.cuda.empty_cache()
|
164 |
+
return num_gpu_blocks, num_cpu_blocks
|
165 |
+
|
166 |
+
def initialize_cache(self, num_gpu_blocks: int,
|
167 |
+
num_cpu_blocks: int) -> None:
|
168 |
+
"""Allocate GPU and CPU KV cache with the specified number of blocks.
|
169 |
+
|
170 |
+
This also warms up the model, which may record CUDA graphs.
|
171 |
+
"""
|
172 |
+
raise_if_cache_size_invalid(num_gpu_blocks,
|
173 |
+
self.cache_config.block_size,
|
174 |
+
self.model_config.max_model_len)
|
175 |
+
|
176 |
+
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
177 |
+
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
178 |
+
|
179 |
+
self._init_cache_engine()
|
180 |
+
self._warm_up_model()
|
181 |
+
|
182 |
+
def _init_cache_engine(self):
|
183 |
+
assert self.cache_config.num_gpu_blocks is not None
|
184 |
+
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
185 |
+
self.parallel_config)
|
186 |
+
self.gpu_cache = self.cache_engine.gpu_cache
|
187 |
+
self.model_runner.set_block_size(self.cache_engine.block_size)
|
188 |
+
|
189 |
+
def _warm_up_model(self) -> None:
|
190 |
+
if not self.model_config.enforce_eager:
|
191 |
+
self.model_runner.capture_model(self.gpu_cache)
|
192 |
+
# Reset the seed to ensure that the random state is not affected by
|
193 |
+
# the model initialization and profiling.
|
194 |
+
set_random_seed(self.model_config.seed)
|
195 |
+
|
196 |
+
def cache_swap(
|
197 |
+
self,
|
198 |
+
blocks_to_swap_in: Dict[int, int],
|
199 |
+
blocks_to_swap_out: Dict[int, int],
|
200 |
+
blocks_to_copy: Dict[int, List[int]],
|
201 |
+
) -> None:
|
202 |
+
# Issue cache operations.
|
203 |
+
# TODO(woosuk): Profile swapping overhead and optimize if needed.
|
204 |
+
if blocks_to_swap_in:
|
205 |
+
self.cache_engine.swap_in(blocks_to_swap_in)
|
206 |
+
if blocks_to_swap_out:
|
207 |
+
self.cache_engine.swap_out(blocks_to_swap_out)
|
208 |
+
if blocks_to_copy:
|
209 |
+
self.cache_engine.copy(blocks_to_copy)
|
210 |
+
|
211 |
+
@torch.inference_mode()
|
212 |
+
def execute_model(
|
213 |
+
self,
|
214 |
+
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
|
215 |
+
blocks_to_swap_in: Optional[Dict[int, int]] = None,
|
216 |
+
blocks_to_swap_out: Optional[Dict[int, int]] = None,
|
217 |
+
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
|
218 |
+
num_lookahead_slots: int = 0,
|
219 |
+
) -> List[SamplerOutput]:
|
220 |
+
|
221 |
+
if self.is_driver_worker:
|
222 |
+
assert seq_group_metadata_list is not None
|
223 |
+
num_seq_groups = len(seq_group_metadata_list)
|
224 |
+
assert blocks_to_swap_in is not None
|
225 |
+
assert blocks_to_swap_out is not None
|
226 |
+
assert blocks_to_copy is not None
|
227 |
+
data: Dict[str, Any] = {
|
228 |
+
"num_seq_groups": num_seq_groups,
|
229 |
+
"blocks_to_swap_in": blocks_to_swap_in,
|
230 |
+
"blocks_to_swap_out": blocks_to_swap_out,
|
231 |
+
"blocks_to_copy": blocks_to_copy,
|
232 |
+
}
|
233 |
+
broadcast_tensor_dict(data, src=0)
|
234 |
+
else:
|
235 |
+
data = broadcast_tensor_dict(src=0)
|
236 |
+
num_seq_groups = data["num_seq_groups"]
|
237 |
+
blocks_to_swap_in = data["blocks_to_swap_in"]
|
238 |
+
blocks_to_swap_out = data["blocks_to_swap_out"]
|
239 |
+
blocks_to_copy = data["blocks_to_copy"]
|
240 |
+
|
241 |
+
assert blocks_to_swap_in is not None
|
242 |
+
assert blocks_to_swap_out is not None
|
243 |
+
assert blocks_to_copy is not None
|
244 |
+
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
|
245 |
+
|
246 |
+
# If there is no input, we don't need to execute the model.
|
247 |
+
if num_seq_groups == 0:
|
248 |
+
return []
|
249 |
+
|
250 |
+
output = self.model_runner.execute_model(seq_group_metadata_list,
|
251 |
+
self.gpu_cache)
|
252 |
+
|
253 |
+
# Worker only supports single-step execution. Wrap the output in a list
|
254 |
+
# to conform to interface.
|
255 |
+
return [output]
|
256 |
+
|
257 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
258 |
+
return self.model_runner.add_lora(lora_request)
|
259 |
+
|
260 |
+
def remove_lora(self, lora_id: int) -> bool:
|
261 |
+
return self.model_runner.remove_lora(lora_id)
|
262 |
+
|
263 |
+
def list_loras(self) -> Set[int]:
|
264 |
+
return self.model_runner.list_loras()
|
265 |
+
|
266 |
+
@property
|
267 |
+
def max_model_len(self) -> int:
|
268 |
+
return self.model_config.max_model_len
|
269 |
+
|
270 |
+
@property
|
271 |
+
def vocab_size(self) -> int:
|
272 |
+
return self.model_runner.vocab_size
|
273 |
+
|
274 |
+
def get_cache_block_size_bytes(self) -> int:
|
275 |
+
"""Get the size of the KV cache block size in bytes.
|
276 |
+
"""
|
277 |
+
return CacheEngine.get_cache_block_size(self.cache_config,
|
278 |
+
self.model_config,
|
279 |
+
self.parallel_config)
|
280 |
+
|
281 |
+
|
282 |
+
def init_worker_distributed_environment(
|
283 |
+
parallel_config: ParallelConfig,
|
284 |
+
rank: int,
|
285 |
+
distributed_init_method: Optional[str] = None,
|
286 |
+
local_rank: int = -1,
|
287 |
+
) -> None:
|
288 |
+
"""Initialize the distributed environment."""
|
289 |
+
init_distributed_environment(parallel_config.world_size, rank,
|
290 |
+
distributed_init_method, local_rank)
|
291 |
+
|
292 |
+
if pynccl_utils.is_initialized():
|
293 |
+
pynccl_world_size = pynccl_utils.get_world_size()
|
294 |
+
if pynccl_world_size != parallel_config.world_size:
|
295 |
+
raise RuntimeError(
|
296 |
+
"pynccl is already initialized but the pynccl world "
|
297 |
+
"size does not match parallel_config.world_size "
|
298 |
+
f"({pynccl_world_size} vs. {parallel_config.world_size}).")
|
299 |
+
elif parallel_config.world_size > 1:
|
300 |
+
# NOTE(woosuk): We don't initialize pynccl process group when world size
|
301 |
+
# is 1.
|
302 |
+
pynccl_utils.init_process_group(
|
303 |
+
world_size=parallel_config.world_size,
|
304 |
+
local_rank=local_rank,
|
305 |
+
rank=rank,
|
306 |
+
init_method=distributed_init_method,
|
307 |
+
)
|
308 |
+
|
309 |
+
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
310 |
+
parallel_config.pipeline_parallel_size)
|
311 |
+
|
312 |
+
# Initialize a custom fast all-reduce implementation.
|
313 |
+
if not parallel_config.disable_custom_all_reduce:
|
314 |
+
init_custom_ar()
|
315 |
+
|
316 |
+
# A small all_reduce for warmup.
|
317 |
+
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
318 |
+
if pynccl_utils.is_initialized():
|
319 |
+
pynccl_utils.all_reduce(torch.zeros(1).cuda())
|
320 |
+
|
321 |
+
|
322 |
+
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
323 |
+
# Check if the GPU supports the dtype.
|
324 |
+
if torch_dtype == torch.bfloat16:
|
325 |
+
compute_capability = torch.cuda.get_device_capability()
|
326 |
+
if compute_capability[0] < 8:
|
327 |
+
gpu_name = torch.cuda.get_device_name()
|
328 |
+
raise ValueError(
|
329 |
+
"Bfloat16 is only supported on GPUs with compute capability "
|
330 |
+
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
331 |
+
f"{compute_capability[0]}.{compute_capability[1]}. "
|
332 |
+
"You can use float16 instead by explicitly setting the"
|
333 |
+
"`dtype` flag in CLI, for example: --dtype=half.")
|
334 |
+
|
335 |
+
|
336 |
+
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
|
337 |
+
max_model_len) -> None:
|
338 |
+
if num_gpu_blocks <= 0:
|
339 |
+
raise ValueError("No available memory for the cache blocks. "
|
340 |
+
"Try increasing `gpu_memory_utilization` when "
|
341 |
+
"initializing the engine.")
|
342 |
+
max_seq_len = block_size * num_gpu_blocks
|
343 |
+
if max_model_len > max_seq_len:
|
344 |
+
raise ValueError(
|
345 |
+
f"The model's max seq len ({max_model_len}) "
|
346 |
+
"is larger than the maximum number of tokens that can be "
|
347 |
+
f"stored in KV cache ({max_seq_len}). Try increasing "
|
348 |
+
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
349 |
+
"initializing the engine.")
|