
Commit 256ed0d

Update llama_cpp_utils.py
1 parent 2c064e8 commit 256ed0d

File tree: 1 file changed, +55 / -83 lines changed


quantllm/quant/llama_cpp_utils.py

Lines changed: 55 additions & 83 deletions
@@ -35,7 +35,6 @@ def get_progress_info(self) -> Dict[str, Any]:
         elapsed = time.time() - self.start_time
         progress_pct = min((self.current_step / self.total_steps) * 100, 100)

-        # Estimate remaining time
         if len(self.step_times) > 1:
             avg_step_time = elapsed / len(self.step_times)
             remaining_steps = self.total_steps - self.current_step
@@ -60,16 +59,16 @@ class LlamaCppConverter:
         "stablelm", "phi", "gemma", "qwen", "baichuan", "yi"
     ]

-    # Enhanced quantization type mapping with better defaults
+    # Enhanced quantization type mapping with 2-bit support
     QUANT_TYPE_MAP = {
-        2: "q2_k",    # 2-bit with K-quant for better quality
-        3: "q3_k_m",  # 3-bit medium quality
-        4: "q4_k_m",  # 4-bit medium quality (balanced)
-        5: "q5_k_m",  # 5-bit medium quality
-        6: "q6_k",    # 6-bit high quality
-        8: "q8_0",    # 8-bit highest quality
-        16: "f16",    # 16-bit float
-        32: "f32"     # 32-bit float
+        2: ["q2_k", "iq2_xxs", "iq2_xs"],
+        3: ["q3_k_m", "q3_k_s", "q3_k_l"],
+        4: ["q4_k_m", "q4_k_s", "q4_0", "q4_1"],
+        5: ["q5_k_m", "q5_k_s", "q5_0", "q5_1"],
+        6: ["q6_k"],
+        8: ["q8_0"],
+        16: ["f16"],
+        32: ["f32"]
     }

     # Performance optimization flags
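The map now carries a list of llama.cpp presets per bit width, and later in this diff the quantization step falls back to the first entry (QUANT_TYPE_MAP.get(bits)[0]) when no explicit quant_type is passed. A minimal sketch of that selection rule, using a standalone helper name (resolve_quant_type) that is not part of the module itself:

from typing import Dict, List, Optional

# Abbreviated copy of the mapping introduced in this commit.
QUANT_TYPE_MAP: Dict[int, List[str]] = {
    2: ["q2_k", "iq2_xxs", "iq2_xs"],
    4: ["q4_k_m", "q4_k_s", "q4_0", "q4_1"],
    8: ["q8_0"],
}

def resolve_quant_type(bits: int, quant_type: Optional[str] = None) -> str:
    """Pick the preset handed to the quantize binary: an explicit override wins,
    otherwise the first (default) entry for the requested bit width is used."""
    if quant_type is not None:
        return quant_type.lower()
    return QUANT_TYPE_MAP.get(bits, ["q4_k_m"])[0]

assert resolve_quant_type(2) == "q2_k"
assert resolve_quant_type(2, "IQ2_XXS") == "iq2_xxs"
assert resolve_quant_type(7) == "q4_k_m"  # unknown width falls back to the 4-bit default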
@@ -83,11 +82,11 @@ def __init__(self, verbose: bool = True):
         """Initialize converter with performance optimizations."""
         self.verbose = verbose
         self.progress_tracker = None
+        self.convert_script = None
+        self.quantize_bin = None

         try:
-            # Try multiple ways to find llama-cpp-python
             self._find_llama_cpp_installation()
-
         except ImportError as e:
             raise ImportError(
                 "llama-cpp-python is required for GGUF conversion.\n"
@@ -98,49 +97,32 @@ def __init__(self, verbose: bool = True):
     def _find_llama_cpp_installation(self):
         """Find llama-cpp-python installation and conversion scripts."""
         try:
-            # First try: Direct import
            import llama_cpp
            self.llama_cpp_path = os.path.dirname(llama_cpp.__file__)

-            # Look for conversion script in package
+            # Look for conversion script
            script_path = os.path.join(self.llama_cpp_path, "convert.py")
            if os.path.exists(script_path):
                self.convert_script = script_path
-                return
-
-            # Look in package scripts directory
-            scripts_dir = os.path.join(os.path.dirname(self.llama_cpp_path), "scripts")
-            if os.path.exists(scripts_dir):
+            else:
+                scripts_dir = os.path.join(os.path.dirname(self.llama_cpp_path), "scripts")
                for script in ["convert.py", "convert_hf_to_gguf.py"]:
                    script_path = os.path.join(scripts_dir, script)
                    if os.path.exists(script_path):
                        self.convert_script = script_path
-                        return
+                        break

-            # Try pip installation path
-            try:
-                import site
-                for site_dir in site.getsitepackages():
-                    for script in ["convert.py", "convert_hf_to_gguf.py"]:
-                        script_path = os.path.join(site_dir, "llama_cpp", script)
-                        if os.path.exists(script_path):
-                            self.convert_script = script_path
-                            return
-                        script_path = os.path.join(site_dir, "llama_cpp", "scripts", script)
-                        if os.path.exists(script_path):
-                            self.convert_script = script_path
-                            return
-            except Exception:
-                pass
+            # Look for quantize binary
+            quantize_path = os.path.join(self.llama_cpp_path, "quantize")
+            if os.path.exists(quantize_path):
+                self.quantize_bin = quantize_path

-            # Try PATH
-            for script in ["convert.py", "convert_hf_to_gguf.py"]:
-                script_path = shutil.which(script)
-                if script_path:
-                    self.convert_script = script_path
-                    return
-
-            raise ImportError("GGUF conversion script not found")
+            if not self.convert_script or not self.quantize_bin:
+                raise FileNotFoundError("Required llama.cpp tools (convert.py or quantize) not found")
+
+            if self.verbose:
+                logger.log_info(f"Found convert.py: {self.convert_script}")
+                logger.log_info(f"Found quantize: {self.quantize_bin}")

         except ImportError:
             raise ImportError(
@@ -155,7 +137,6 @@ def _detect_model_type(self, model: PreTrainedModel) -> str:

         model_type = getattr(model.config, 'model_type', 'unknown').lower()

-        # Enhanced type mapping
         type_mapping = {
             "llama": "llama",
             "llama2": "llama",
@@ -189,15 +170,12 @@ def _optimize_for_conversion(self, model: PreTrainedModel) -> PreTrainedModel:
         if self.verbose:
             logger.log_info("⚡ Applying conversion optimizations...")

-        # Move to CPU and optimize memory
         if hasattr(model, 'to'):
             model = model.to('cpu')

-        # Clear CUDA cache
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

-        # Enable eval mode for stability
         model.eval()

         return model
@@ -219,7 +197,6 @@ def read_output():
                 if not line:
                     continue

-                # Parse progress indicators
                 if "%" in line or "processing" in line.lower():
                     current_step += 1
                     if self.progress_tracker:
@@ -228,7 +205,6 @@ def read_output():
                     if progress_callback:
                         progress_callback(current_step, line)

-                # Log important messages
                 if any(keyword in line.lower() for keyword in ['error', 'warning', 'failed']):
                     logger.log_warning(f"⚠️ {line}")
                 elif any(keyword in line.lower() for keyword in ['completed', 'success', 'done']):
@@ -249,7 +225,8 @@ def convert_to_gguf(
         optimization_level: str = "balanced",
         save_tokenizer: bool = True,
         progress_callback: Optional[Callable] = None,
-        custom_name: Optional[str] = None
+        custom_name: Optional[str] = None,
+        quant_type: Optional[str] = None
     ) -> str:
         """
         Convert model to GGUF format with enhanced performance and logging.
@@ -264,18 +241,17 @@ def convert_to_gguf(
             save_tokenizer: Whether to save tokenizer
             progress_callback: Optional progress callback function
             custom_name: Custom output filename (default: model.gguf)
+            quant_type: Specific quantization type (e.g., Q2_K, IQ2_XXS)

         Returns:
             Path to generated GGUF file
         """
         start_time = time.time()

         try:
-            # Setup output directory
             output_dir = os.path.abspath(output_dir)
             os.makedirs(output_dir, exist_ok=True)

-            # Generate output filename
             if custom_name:
                 filename = custom_name if custom_name.endswith('.gguf') else f"{custom_name}.gguf"
             else:
@@ -295,45 +271,50 @@ def convert_to_gguf(
                 logger.log_info(f"💾 Output: {gguf_path}")
                 logger.log_info("="*80 + "\n")

-            # Use a temporary directory for intermediate files
             with tempfile.TemporaryDirectory(prefix="gguf_convert_", dir=output_dir) as temp_dir:
                 if self.verbose:
                     logger.log_info("📁 Setting up workspace...")

-                # Initialize progress tracking
                 self.progress_tracker = ProgressTracker(total_steps=10)

-                # Step 1: Optimize model
                 self.progress_tracker.update(1, "Optimizing model...")
                 model = self._optimize_for_conversion(model)

-                # Step 2: Save model in optimal format with sharding
                 self.progress_tracker.update(2, "Saving model...")
                 model.save_pretrained(
                     temp_dir,
                     safe_serialization=True,
                     max_shard_size="2GB"
                 )

-                # Step 3: Save tokenizer if requested
                 if save_tokenizer:
                     self.progress_tracker.update(3, "Saving tokenizer...")
                     self._save_tokenizer(model, temp_dir)

-                # Step 4: Prepare and run conversion
-                self.progress_tracker.update(4, "Converting to GGUF...")
-                cmd = self._build_conversion_command(
-                    temp_dir, gguf_path,
+                self.progress_tracker.update(4, "Converting to FP16 GGUF...")
+                temp_gguf = os.path.join(temp_dir, "temp_f16.gguf")
+                cmd_convert = self._build_conversion_command(
+                    temp_dir, temp_gguf,
                     model_type or self._detect_model_type(model),
                     bits, optimization_level
                 )

-                success = self._run_conversion(cmd, progress_callback)
+                success = self._run_conversion(cmd_convert, progress_callback)
+                if not success or not os.path.exists(temp_gguf):
+                    raise RuntimeError("FP16 GGUF conversion failed")

+                self.progress_tracker.update(7, f"Quantizing to {quant_type or self.QUANT_TYPE_MAP.get(bits)[0]}...")
+                cmd_quantize = [
+                    self.quantize_bin,
+                    temp_gguf,
+                    gguf_path,
+                    (quant_type or self.QUANT_TYPE_MAP.get(bits)[0]).lower()
+                ]
+
+                success = self._run_conversion(cmd_quantize, progress_callback)
                 if not success or not os.path.exists(gguf_path):
-                    raise RuntimeError("GGUF conversion failed")
+                    raise RuntimeError(f"GGUF quantization to {quant_type} failed")

-                # Log completion
                 file_size = os.path.getsize(gguf_path) / (1024**3)
                 elapsed_time = time.time() - start_time

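The conversion path above is now two-stage: convert.py writes an intermediate FP16 GGUF into the temporary workspace, and the quantize binary rewrites it at the requested preset. A hedged usage sketch of the updated method; the checkpoint name and output directory are placeholders, and the call shape of convert_to_gguf is assumed from the parameters visible in this diff:

# Hypothetical call site; "my-org/my-model" and "./gguf_out" are placeholders.
from transformers import AutoModelForCausalLM
from quantllm.quant.llama_cpp_utils import LlamaCppConverter

model = AutoModelForCausalLM.from_pretrained("my-org/my-model")
converter = LlamaCppConverter(verbose=True)

# Internally: save to a temp dir -> convert.py --outtype f16 -> quantize to the chosen preset.
gguf_path = converter.convert_to_gguf(
    model,
    output_dir="./gguf_out",
    bits=2,
    quant_type="Q2_K",  # new parameter; defaults to QUANT_TYPE_MAP[bits][0] when omitted
)
print(f"Wrote {gguf_path}")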
@@ -360,15 +341,13 @@ def _prepare_model_config(self, model: PreTrainedModel, model_type: str) -> Dict
         """Prepare optimized model configuration."""
         config = model.config.to_dict() if hasattr(model.config, 'to_dict') else {}

-        # Essential fields for GGUF conversion
         essential_config = {
             "model_type": model_type,
             "architectures": getattr(model.config, 'architectures', [model_type]),
             "torch_dtype": str(getattr(model, 'dtype', torch.float32)).replace('torch.', ''),
-            "transformers_version": "4.36.0",  # Compatibility version
+            "transformers_version": "4.36.0",
         }

-        # Preserve important model-specific fields
         important_fields = [
             'hidden_size', 'num_hidden_layers', 'num_attention_heads',
             'vocab_size', 'max_position_embeddings', 'intermediate_size',
@@ -388,7 +367,7 @@ def _save_tokenizer(self, model: PreTrainedModel, temp_dir: str):
             tokenizer = AutoTokenizer.from_pretrained(
                 model.config._name_or_path,
                 trust_remote_code=True,
-                use_fast=False  # Use slow tokenizer for better compatibility
+                use_fast=False
             )
             tokenizer.save_pretrained(temp_dir)
             if self.verbose:
@@ -408,35 +387,30 @@ def _build_conversion_command(
         bits: int,
         optimization_level: str
     ) -> List[str]:
-        """Build optimized conversion command."""
-        quant_type = self.QUANT_TYPE_MAP.get(bits, "q4_k_m")
-
+        """Build optimized conversion command for FP16."""
         cmd = [
             sys.executable,
             self.convert_script,
             temp_dir,
             "--outfile", gguf_path,
-            "--outtype", quant_type,
+            "--outtype", "f16",
         ]

-        # Add optimization flags
         if optimization_level in self.OPTIMIZATION_FLAGS:
             cmd.extend(self.OPTIMIZATION_FLAGS[optimization_level])

-        # Add model-specific flags
         if model_type != "llama":
             cmd.extend(["--model-type", model_type])

-        # Add memory optimization for large models
         cmd.extend([
-            "--no-lazy",    # Disable lazy loading for speed
-            "--big-endian"  # Better compatibility
+            "--no-lazy",
+            "--big-endian"
         ])

         return cmd

     def _run_conversion(self, cmd: List[str], progress_callback: Optional[Callable] = None) -> bool:
-        """Run conversion process with enhanced monitoring."""
+        """Run conversion or quantization process with enhanced monitoring."""
         try:
             process = subprocess.Popen(
                 cmd,
@@ -446,10 +420,8 @@ def _run_conversion(self, cmd: List[str], progress_callback: Optional[Callable]
                 bufsize=1
             )

-            # Monitor progress
             monitor_thread = self._log_progress(process, progress_callback)

-            # Wait for completion
             return_code = process.wait()

             if monitor_thread:
@@ -458,11 +430,11 @@ def _run_conversion(self, cmd: List[str], progress_callback: Optional[Callable]
             return return_code == 0

         except Exception as e:
-            logger.log_error(f"Conversion process error: {e}")
+            logger.log_error(f"Process error: {e}")
             return False

     def get_supported_quantization_types(self, bits: Optional[int] = None) -> Dict[int, List[str]]:
         """Get supported quantization types."""
         if bits:
-            return {bits: [self.QUANT_TYPE_MAP.get(bits, "q4_k_m")]}
-        return {k: [v] for k, v in self.QUANT_TYPE_MAP.items()}
+            return {bits: self.QUANT_TYPE_MAP.get(bits, ["q4_k_m"])}
+        return self.QUANT_TYPE_MAP.copy()
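Since the map values are now lists, get_supported_quantization_types returns every preset per bit width, or a copy of the whole map when called without arguments. A small sketch of inspecting it; the class-level attribute is used here because constructing the converter requires a working llama-cpp-python install:

from quantllm.quant.llama_cpp_utils import LlamaCppConverter

# Class attribute, usable without instantiating the converter.
print(LlamaCppConverter.QUANT_TYPE_MAP[2])   # ['q2_k', 'iq2_xxs', 'iq2_xs']

# With the llama.cpp tooling installed, the instance method filters by width:
# LlamaCppConverter().get_supported_quantization_types(bits=2)
# -> {2: ['q2_k', 'iq2_xxs', 'iq2_xs']}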
