File size: 4,989 Bytes
03f3fb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread

# Model identifier handed to `from_pretrained` for both the tokenizer and the model weights.
MODEL_PATH = 'THUDM/glm-4-9b-chat'


def _build_random_prompt(token_len, vocab_size, device):
    """Build a batch-of-one random prompt that is ``token_len`` token ids long.

    The prompt is four fixed leading ids, ``token_len - 5`` random ids drawn
    from ``[3, vocab_size - 200)`` and one fixed trailing id.

    Args:
        token_len: Total prompt length in tokens; must be at least 5.
        vocab_size: Tokenizer vocabulary size used to bound the random ids.
        device: Torch device the tensors are moved to.

    Returns:
        ``(input_ids, attention_mask)``: two ``torch.long`` tensors of shape
        ``(1, token_len)`` on ``device``; the mask is all ones.
    """
    # NOTE(review): these ids are presumably GLM-4 chat-template special
    # tokens (prefix / role markers) -- verify against the tokenizer.
    start_tokens = [151331, 151333, 151336, 198]
    end_tokens = [151337]
    random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
    input_ids = torch.tensor(
        start_tokens + random_token_ids.tolist() + end_tokens,
        dtype=torch.long,
    ).unsqueeze(0).to(device)
    # Integer mask with the same dtype/device as the ids; nothing is padded.
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def _run_generate(model, generate_kwargs, streamer, errors):
    """Run ``model.generate`` in a worker thread and surface any failure.

    If generation raises, the exception is stored in ``errors`` and the
    streamer is closed so the consuming thread stops waiting instead of
    blocking until the streamer timeout expires.
    """
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # re-raised by the caller in the main thread
        errors.append(exc)
        streamer.end()


def stress_test(token_len, n, num_gpu):
    """Benchmark first-token latency and decode throughput of the model.

    Loads the tokenizer and model from ``MODEL_PATH``, runs one warm-up
    generation on a 20-token random prompt, then runs ``n`` streamed greedy
    generations on random prompts of ``token_len`` tokens each. Per-iteration
    and averaged results are printed.

    Args:
        token_len: Prompt length in tokens for every measured iteration
            (at least 5: four leading ids, one trailing id, the rest random).
        n: Number of measured iterations (must be positive).
        num_gpu: When CUDA is available and ``num_gpu > 0`` the model is
            placed on ``cuda:{num_gpu - 1}``; otherwise it runs on the CPU.

    Returns:
        ``(times, avg_first_token_time, decode_times, avg_decode_time)`` where
        ``times`` holds the seconds until the first streamed chunk for each
        iteration and ``decode_times`` holds, despite the name, a throughput:
        streamed chunks per second measured from the first chunk to the end
        of generation. The two averages are the means over the ``n`` runs.

    Raises:
        ValueError: If ``n`` is not positive, ``token_len`` is below 5, or
            ``num_gpu`` exceeds the number of visible CUDA devices.
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised in the calling thread.
    """
    # Validate before the (expensive) model load so bad arguments fail fast.
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    if token_len < 5:
        raise ValueError(f"token_len must be at least 5, got {token_len}")

    use_cuda = torch.cuda.is_available() and num_gpu > 0
    if use_cuda and num_gpu > torch.cuda.device_count():
        raise ValueError(
            f"num_gpu={num_gpu} exceeds the {torch.cuda.device_count()} visible CUDA device(s)"
        )
    device = torch.device(f"cuda:{num_gpu - 1}" if use_cuda else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        padding_side="left"
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to(device).eval()

    # Use INT4 weight infer
    # model = AutoModelForCausalLM.from_pretrained(
    #     MODEL_PATH,
    #     trust_remote_code=True,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    # ).eval()

    eos_token_ids = [151329, 151336, 151338]
    vocab_size = tokenizer.vocab_size
    times = []
    decode_times = []

    print("Warming up...")
    warmup_token_len = 20
    warmup_ids, warmup_mask = _build_random_prompt(warmup_token_len, vocab_size, device)
    with torch.no_grad():
        _ = model.generate(
            input_ids=warmup_ids,
            attention_mask=warmup_mask,
            max_new_tokens=2048,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=eos_token_ids
        )
    print("Warming up complete. Starting stress test...")

    for i in range(n):
        input_ids, attention_mask = _build_random_prompt(token_len, vocab_size, device)

        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=36000,
            skip_prompt=True,
            skip_special_tokens=True
        )

        generate_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": 512,
            "do_sample": False,
            "repetition_penalty": 1.0,
            "eos_token_id": eos_token_ids,
            "streamer": streamer
        }

        worker_errors = []
        # perf_counter is monotonic, so intervals cannot be skewed by
        # wall-clock adjustments while the benchmark is running.
        start_time = time.perf_counter()
        worker = Thread(
            target=_run_generate,
            args=(model, generate_kwargs, streamer, worker_errors),
        )
        worker.start()

        first_token_time = None
        chunk_count = 0
        for _ in streamer:
            if first_token_time is None:
                first_token_time = time.perf_counter()
            chunk_count += 1

        worker.join()
        end_time = time.perf_counter()

        if worker_errors:
            raise worker_errors[0]

        # Record exactly one latency per iteration so the averages below
        # always divide n samples by n, even if nothing was streamed.
        if first_token_time is None:
            first_token_time = end_time
        times.append(first_token_time - start_time)

        decode_elapsed = end_time - first_token_time
        if chunk_count and decode_elapsed > 0:
            avg_decode_time_per_token = chunk_count / decode_elapsed
        else:
            avg_decode_time_per_token = 0.0
        decode_times.append(avg_decode_time_per_token)
        print(
            f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Time: {avg_decode_time_per_token:.4f} tokens/second")

        if device.type == "cuda":
            torch.cuda.empty_cache()

    avg_first_token_time = sum(times) / n
    avg_decode_time = sum(decode_times) / n
    print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
    print(f"Average Decode Time per Token over {n} iterations: {avg_decode_time:.4f} tokens/second")
    return times, avg_first_token_time, decode_times, avg_decode_time


def main():
    """Parse the command-line options and launch the stress test.

    Three integer options are accepted: ``--token_len`` (default 1000),
    ``--n`` (default 3) and ``--num_gpu`` (default 1). Their parsed values
    are forwarded, in that order, to ``stress_test``.
    """
    cli = argparse.ArgumentParser(description="Stress test for model inference")
    # (flag, default, help) for every option; all of them are integers.
    int_options = (
        ('--token_len', 1000, 'Number of tokens for each test'),
        ('--n', 3, 'Number of iterations for the stress test'),
        ('--num_gpu', 1, 'Number of GPUs to use for inference'),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    options = cli.parse_args()

    stress_test(options.token_len, options.n, options.num_gpu)


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()