Skip to content

Commit 15f1ea7

Browse files
authored
WOQ qconfig API compatibility in scripts (#2354)
* Update run_llama_quantization.py
* Update run_codegen_quantization.py
* Update run_falcon_quantization.py
* Update run_gpt-j_quantization.py
* Update run_gpt-neox_quantization.py
* Update run_opt_quantization.py
1 parent 6ea47af commit 15f1ea7

File tree

6 files changed

+113
-72
lines changed

6 files changed

+113
-72
lines changed

examples/cpu/inference/python/llm/single_instance/run_codegen_quantization.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,25 @@
190190
else:
191191
lowp_mode = ipex.quantization.WoqLowpMode.BF16
192192

193-
act_quant_mode_dict = {
194-
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
195-
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
196-
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
197-
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
198-
}
199-
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
200-
weight_dtype=weight_dtype,
201-
lowp_mode=lowp_mode,
202-
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
203-
group_size=args.group_size
204-
)
193+
try:
194+
act_quant_mode_dict = {
195+
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
196+
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
197+
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
198+
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
199+
}
200+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
201+
weight_dtype=weight_dtype,
202+
lowp_mode=lowp_mode,
203+
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
204+
group_size=args.group_size
205+
)
206+
except:
207+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
208+
weight_dtype=weight_dtype,
209+
lowp_mode=lowp_mode,
210+
)
211+
205212
if args.low_precision_checkpoint != "":
206213
low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
207214
else:

examples/cpu/inference/python/llm/single_instance/run_falcon_quantization.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,18 +205,25 @@
205205
else:
206206
lowp_mode = ipex.quantization.WoqLowpMode.BF16
207207

208-
act_quant_mode_dict = {
209-
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
210-
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
211-
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
212-
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
213-
}
214-
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
215-
weight_dtype=weight_dtype,
216-
lowp_mode=lowp_mode,
217-
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
218-
group_size=args.group_size
219-
)
208+
try:
209+
act_quant_mode_dict = {
210+
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
211+
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
212+
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
213+
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
214+
}
215+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
216+
weight_dtype=weight_dtype,
217+
lowp_mode=lowp_mode,
218+
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
219+
group_size=args.group_size
220+
)
221+
except:
222+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
223+
weight_dtype=weight_dtype,
224+
lowp_mode=lowp_mode,
225+
)
226+
220227
if args.low_precision_checkpoint != "":
221228
low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
222229
else:

examples/cpu/inference/python/llm/single_instance/run_gpt-j_quantization.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,18 +196,25 @@
196196
else:
197197
lowp_mode = ipex.quantization.WoqLowpMode.BF16
198198

199-
act_quant_mode_dict = {
200-
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
201-
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
202-
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
203-
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
204-
}
205-
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
206-
weight_dtype=weight_dtype,
207-
lowp_mode=lowp_mode,
208-
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
209-
group_size=args.group_size
210-
)
199+
try:
200+
act_quant_mode_dict = {
201+
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
202+
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
203+
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
204+
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
205+
}
206+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
207+
weight_dtype=weight_dtype,
208+
lowp_mode=lowp_mode,
209+
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
210+
group_size=args.group_size
211+
)
212+
except:
213+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
214+
weight_dtype=weight_dtype,
215+
lowp_mode=lowp_mode,
216+
)
217+
211218
if args.low_precision_checkpoint != "":
212219
low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
213220
else:

examples/cpu/inference/python/llm/single_instance/run_gpt-neox_quantization.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,25 @@
193193
else:
194194
lowp_mode = ipex.quantization.WoqLowpMode.BF16
195195

196-
act_quant_mode_dict = {
197-
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
198-
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
199-
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
200-
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
201-
}
202-
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
203-
weight_dtype=weight_dtype,
204-
lowp_mode=lowp_mode,
205-
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
206-
group_size=args.group_size
207-
)
196+
try:
197+
act_quant_mode_dict = {
198+
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
199+
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
200+
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
201+
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
202+
}
203+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
204+
weight_dtype=weight_dtype,
205+
lowp_mode=lowp_mode,
206+
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
207+
group_size=args.group_size
208+
)
209+
except:
210+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
211+
weight_dtype=weight_dtype,
212+
lowp_mode=lowp_mode,
213+
)
214+
208215
if args.low_precision_checkpoint != "":
209216
low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
210217
else:

examples/cpu/inference/python/llm/single_instance/run_llama_quantization.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -329,19 +329,25 @@ def calib_func(prepared_model):
329329
lowp_mode = ipex.quantization.WoqLowpMode.INT8
330330
else:
331331
lowp_mode = ipex.quantization.WoqLowpMode.BF16
332+
try:
333+
act_quant_mode_dict = {
334+
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
335+
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
336+
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
337+
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
338+
}
339+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
340+
weight_dtype=weight_dtype,
341+
lowp_mode=lowp_mode,
342+
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
343+
group_size=args.group_size
344+
)
345+
except:
346+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
347+
weight_dtype=weight_dtype,
348+
lowp_mode=lowp_mode,
349+
)
332350

333-
act_quant_mode_dict = {
334-
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
335-
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
336-
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
337-
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
338-
}
339-
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
340-
weight_dtype=weight_dtype,
341-
lowp_mode=lowp_mode,
342-
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
343-
group_size=args.group_size
344-
)
345351
if args.low_precision_checkpoint != "":
346352
low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
347353
else:

examples/cpu/inference/python/llm/single_instance/run_opt_quantization.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,25 @@
190190
else:
191191
lowp_mode = ipex.quantization.WoqLowpMode.BF16
192192

193-
act_quant_mode_dict = {
194-
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
195-
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
196-
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
197-
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
198-
}
199-
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
200-
weight_dtype=weight_dtype,
201-
lowp_mode=lowp_mode,
202-
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
203-
group_size=args.group_size
204-
)
193+
try:
194+
act_quant_mode_dict = {
195+
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
196+
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
197+
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
198+
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
199+
}
200+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
201+
weight_dtype=weight_dtype,
202+
lowp_mode=lowp_mode,
203+
act_quant_mode=act_quant_mode_dict[args.act_quant_mode],
204+
group_size=args.group_size
205+
)
206+
except:
207+
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
208+
weight_dtype=weight_dtype,
209+
lowp_mode=lowp_mode,
210+
)
211+
205212
if args.low_precision_checkpoint != "":
206213
low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
207214
else:

0 commit comments

Comments (0)