-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdemo.py
More file actions
176 lines (147 loc) · 6.78 KB
/
Copy pathdemo.py
File metadata and controls
176 lines (147 loc) · 6.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import time
import yaml
import argparse
from transformers import AutoTokenizer
def get_model_and_cache_classes(model_name, mode):
    """Return ``(model_class, cache_class, hs_cache_class)`` for a run mode.

    Imports happen inside the branches, so only the backend that was actually
    requested gets loaded.

    Args:
        model_name: model path or hub id. If it contains "qwen2"
            (case-insensitive), the Qwen2 SlimInfer model is chosen;
            otherwise the Llama one.
        mode: 'sliminfer' for the project's SlimInfer model plus its two cache
            classes, or 'origin' for the stock Hugging Face
            ``AutoModelForCausalLM``. For 'origin', both cache entries are None.

    Raises:
        KeyError: if ``mode`` is neither 'sliminfer' nor 'origin'.
    """
    # Evaluated up front, before the mode is validated.
    wants_qwen2 = 'qwen2' in model_name.lower()

    if mode == 'origin':
        print("Importing original Hugging Face model...")
        from transformers import AutoModelForCausalLM
        return AutoModelForCausalLM, None, None

    if mode != 'sliminfer':
        raise KeyError(f"Unsupported mode: {mode}")

    if wants_qwen2:
        print("Importing SlimInfer for Qwen2...")
        from models.modeling_qwen2 import Qwen2ForCausalLM as causal_lm_cls
    else:
        print("Importing SlimInfer for Llama...")
        from models.modeling_llama import LlamaForCausalLM as causal_lm_cls
    from utils.cache_utils import SlimInferCache, HsCache
    return causal_lm_cls, SlimInferCache, HsCache
def print_results_table(results, max_new_tokens_list):
    """Print benchmark timings as a fixed-width text table.

    Args:
        results: list of dicts, one per benchmarked input length. Each dict
            must contain 'target_input_length' and 'actual_prompt_len'.
            Timings live under 'time_{n}' keys; a missing key is shown as 'N/A'.
        max_new_tokens_list: generation lengths to show as columns, in order.
            Duplicates collapse into a single column. Each row stores one
            'time_{n}' value per n, so a repeated column could only repeat it.
    """
    # dict.fromkeys keeps first-occurrence order while dropping duplicates.
    token_counts = list(dict.fromkeys(max_new_tokens_list))

    def _fmt_time(value):
        # Any real number (float or int, but not bool) gets 4 decimals.
        # Everything else, e.g. the 'N/A' placeholder, goes through str(),
        # so every cell is a str for rjust().
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:.4f}"
        return str(value)

    header = ["Input Len", "Prompt Len"] + [f"Time ({n} tok) s" for n in token_counts]

    # Render every cell once; the same strings drive widths and output.
    rows = []
    for row_data in results:
        row = [str(row_data['target_input_length']), str(row_data['actual_prompt_len'])]
        row.extend(_fmt_time(row_data.get(f'time_{n}', 'N/A')) for n in token_counts)
        rows.append(row)

    # A list (not varargs) keeps max() valid when there are no rows.
    col_widths = [max([len(title)] + [len(row[i]) for row in rows]) for i, title in enumerate(header)]

    header_str = " | ".join(title.center(width) for title, width in zip(header, col_widths))
    separator = "-+-".join("-" * width for width in col_widths)
    rule = "=" * len(separator)

    print("\n" + rule)
    print("Benchmark Results Summary".center(len(separator)))
    print(rule)
    print(header_str)
    print(separator)
    for row in rows:
        print(" | ".join(cell.rjust(width) for cell, width in zip(row, col_widths)))
    print(rule)
# Script entry point. Benchmarks end-to-end model.generate() latency for either
# the SlimInfer model variants or the stock Hugging Face model, across several
# prompt lengths and generation lengths, then prints a summary table.
if __name__ == "__main__":
    import json

    # Resolve relative file paths (default --pruning_config and
    # data/example_data.json) against this script's directory, not the caller's CWD.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    if script_dir:
        os.chdir(script_dir)

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default='/path/to/llama3.1-8b-instruct', help="Path to the model")
    parser.add_argument("--pruning_config", type=str, default="prune_configs/b64_t09_w4_prune_fx_9_8_19_4_29_2.yaml")
    parser.add_argument("--mode", type=str, default="sliminfer", choices=['sliminfer', 'origin'])
    args = parser.parse_args()
    print(args)

    ModelForCausalLM, SlimInferCache, HsCache = get_model_and_cache_classes(args.model, args.mode)

    # The SlimInfer path passes this YAML config to every generate() call below.
    sliminfer_config = None
    if "sliminfer" in args.mode:
        with open(args.pruning_config, 'r') as f:
            sliminfer_config = yaml.safe_load(f)

    model = ModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation='flash_attention_2',
    )
    use_cuda = torch.cuda.is_available()
    model = model.to("cuda:0" if use_cuda else "cpu")
    model = model.eval()

    def _sync_device():
        # torch.cuda.synchronize() raises when CUDA is unavailable (the CPU
        # fallback above), so it is only called on the GPU path.
        if use_cuda:
            torch.cuda.synchronize()

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        # Use the tokenizer's scalar pad id. config.eos_token_id can be a list
        # of ids for some checkpoints (e.g. Llama-3.1 configs), which is not a
        # valid pad_token_id.
        model.config.pad_token_id = tokenizer.pad_token_id

    json_path = "data/example_data.json"
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    story = data[0]["prompt"]

    print("\n--- Starting Global Warm-up ---")
    if "sliminfer" in args.mode:
        # The SlimInfer warm-up uses a longer prompt than the stock-model one.
        warmup_input = tokenizer("hello world" * 500, return_tensors="pt").to(model.device)
        warmup_kwargs = {
            "past_key_values": SlimInferCache(model.device),
            "hs_cache": HsCache(),
            "sliminfer_config": sliminfer_config,
            "all_original_prompt_indices_abs": list(range(warmup_input.input_ids.shape[1])),
        }
    else:
        warmup_input = tokenizer("hello world", return_tensors="pt").to(model.device)
        warmup_kwargs = {}
    with torch.no_grad():
        # attention_mask is passed just like in the timed calls below; without
        # it generate() cannot tell padding from content when pad == eos.
        _ = model.generate(
            warmup_input.input_ids,
            attention_mask=warmup_input.attention_mask,
            max_new_tokens=2,
            use_cache=True,
            **warmup_kwargs,
        )
    torch.cuda.empty_cache()
    _sync_device()
    print("--- Global Warm-up Finished ---\n")

    all_results = []
    input_lengths = [4 * 1024, 8 * 1024, 16 * 1024, 24 * 1024, 28 * 1024, 32 * 1024]
    # NOTE: n=1 appears twice. Results are keyed by f'time_{n}', so the first
    # n=1 run at each input length acts as a per-length warm-up and its timing
    # is overwritten by the second (presumably intentional).
    max_new_tokens_list = [1, 1, 4, 8, 16]
    # Repeat the prompt, presumably so truncation can reach even the largest
    # target length.
    story *= 100
    for input_length in input_lengths:
        torch.cuda.empty_cache()
        print(f"--- Testing for target input length: {input_length} ---")
        model_inputs = tokenizer(story, return_tensors="pt", truncation=True, max_length=input_length).to(model.device)
        real_input_length = model_inputs.input_ids.shape[1]
        print(f"Actual tokenized input length: {real_input_length}")
        current_run_results = {
            'target_input_length': input_length,
            'actual_prompt_len': real_input_length,
        }
        for n_tokens in max_new_tokens_list:
            print(f" - Benchmarking with max_new_tokens = {n_tokens}...")
            kwargs = {}
            if "sliminfer" in args.mode:
                # Fresh caches per run, so no state leaks between timed calls.
                kwargs = {
                    "past_key_values": SlimInferCache(model.device),
                    "hs_cache": HsCache(),
                    "sliminfer_config": sliminfer_config,
                    "all_original_prompt_indices_abs": list(range(real_input_length)),
                }
            # Sync before and after so the timed window covers all queued GPU
            # work for this generate() call.
            _sync_device()
            start_time = time.perf_counter()
            with torch.no_grad():
                _ = model.generate(
                    model_inputs.input_ids,
                    attention_mask=model_inputs.attention_mask,
                    max_new_tokens=n_tokens,
                    do_sample=False,
                    use_cache=True,
                    **kwargs
                )
            _sync_device()
            end_time = time.perf_counter()
            duration = end_time - start_time
            current_run_results[f'time_{n_tokens}'] = duration
            print(f" -> Duration: {duration:.4f} seconds")
        all_results.append(current_run_results)
    print_results_table(all_results, max_new_tokens_list)