@@ -29,11 +29,11 @@ def _get_gpu_info(self) -> Dict[str, Any]:
2929 """Get comprehensive GPU information."""
3030 gpu_info = {"available" : False , "devices" : []}
3131
32- if torch .cuda .is_available ():
32+ if torch .cuda .is_available ():
3333 gpu_info ["available" ] = True
3434 gpu_info ["device_count" ] = torch .cuda .device_count ()
3535
36- for i in range (torch .cuda .device_count ()):
36+ for i in range (torch .cuda .device_count ()):
3737 props = torch .cuda .get_device_properties (i )
3838 total_mem = props .total_memory / (1024 ** 3 )
3939 allocated_mem = torch .cuda .memory_allocated (i ) / (1024 ** 3 )
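
The per-device figures collected above come straight from PyTorch's CUDA APIs. A minimal standalone sketch of the same query, assuming only that a CUDA-enabled torch build is installed:

    import torch

    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            total_gb = props.total_memory / (1024 ** 3)
            allocated_gb = torch.cuda.memory_allocated(i) / (1024 ** 3)
            # "Free" here is an approximation; it ignores the caching allocator's reserved blocks.
            print(f"GPU {i} ({props.name}): {total_gb - allocated_gb:.1f} GiB free of {total_gb:.1f} GiB")
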
@@ -155,14 +155,14 @@ def estimate_model_size(model_name: Union[str, PreTrainedModel]) -> Dict[str, fl
 
 def _estimate_from_config(config, model_name: str) -> Dict[str, float]:
     """Estimate model size from configuration."""
-    if hasattr(config, 'num_parameters'):
+    if hasattr(config, 'num_parameters'):
         params = config.num_parameters
-    elif hasattr(config, 'n_params'):
+    elif hasattr(config, 'n_params'):
         params = config.n_params
-    elif hasattr(config, 'hidden_size') and hasattr(config, 'num_hidden_layers'):
+    elif hasattr(config, 'hidden_size') and hasattr(config, 'num_hidden_layers'):
         # Enhanced estimation for various architectures
-        hidden_size = config.hidden_size
-        num_layers = config.num_hidden_layers
+        hidden_size = config.hidden_size
+        num_layers = config.num_hidden_layers
         vocab_size = getattr(config, 'vocab_size', 32000)
 
         # Architecture-specific calculations
@@ -187,9 +187,9 @@ def _estimate_from_config(config, model_name: str) -> Dict[str, float]:
187187 "total_params" : params ,
188188 "size_fp16_gb" : size_fp16 ,
189189 "size_fp32_gb" : size_fp16 * 2 ,
190- "embedding_params" : 0 , # Not available from config
191- "attention_params" : 0 , # Not available from config
192- "other_params" : params
190+ "embedding_params" : embedding_params if 'embedding_params' in locals () else 0 ,
191+ "attention_params" : attention_params if 'attention_params' in locals () else 0 ,
192+ "other_params" : params - ( embedding_params + attention_params ) if 'embedding_params' in locals () else params
193193 }
194194
195195def _fallback_size_estimation (model_name : str ) -> Dict [str , float ]:
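
The `size_fp16_gb` and `size_fp32_gb` fields in the dict above follow plain bytes-per-parameter arithmetic (2 bytes for fp16, 4 for fp32). A quick sketch of that conversion; the 7B parameter count is only illustrative:

    def params_to_gb(params: int, bytes_per_param: float) -> float:
        """Convert a raw parameter count to a size in GiB."""
        return params * bytes_per_param / (1024 ** 3)

    params = 7_000_000_000                      # illustrative 7B-parameter model
    size_fp16_gb = params_to_gb(params, 2.0)    # ~13.0 GiB at 2 bytes per parameter
    size_fp32_gb = size_fp16_gb * 2             # fp32 doubles it, as in the dict above
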
@@ -324,6 +324,71 @@ def get_system_memory():
 class QuantLLM:
     """Enhanced high-level API for GGUF model quantization."""
 
+    def __init__(self):
+        """Initialize QuantLLM with system monitoring."""
+        self.system_monitor = SystemResourceMonitor()
+        self.progress_tracker = None
+
+    def get_system_info(self) -> Dict[str, Any]:
+        """Get current system information."""
+        return {
+            "gpu": self.system_monitor.gpu_info,
+            "cpu": self.system_monitor.cpu_info,
+            "memory": self.system_monitor.memory_info
+        }
+
+    def get_optimal_config(self, model_size_gb: float) -> Dict[str, Any]:
+        """Get optimal configuration based on system resources."""
+        return self.system_monitor.get_optimal_config(model_size_gb)
+
+    def estimate_model_size(self, model_name: Union[str, PreTrainedModel]) -> Dict[str, float]:
+        """Estimate model size and get detailed breakdown."""
+        return estimate_model_size(model_name)
+
+    def get_recommended_bits(
+        self,
+        model_size_gb: float,
+        target_size_gb: Optional[float] = None,
+        priority: str = "balanced"
+    ) -> Tuple[int, str]:
+        """Get recommended quantization bits and type."""
+        return self.get_recommended_quant_type(
+            model_size_gb=model_size_gb,
+            target_size_gb=target_size_gb,
+            priority=priority
+        )
+
+    def start_progress_tracking(self, total_steps: int = 100):
+        """Initialize progress tracking."""
+        self.progress_tracker = ProgressTracker()
+        self.progress_tracker.start(total_steps)
+
+    def update_progress(self, step: int, message: Optional[str] = None):
+        """Update progress tracking."""
+        if self.progress_tracker:
+            self.progress_tracker.update(step, message)
+
+    def end_progress_tracking(self):
+        """End progress tracking."""
+        if self.progress_tracker:
+            self.progress_tracker.finish()
+            self.progress_tracker = None
+
+    def cleanup(self):
+        """Clean up resources."""
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit with cleanup."""
+        self.cleanup()
+        if self.progress_tracker:
+            self.end_progress_tracking()
+
     @staticmethod
     def list_quant_types(bits: Optional[int] = None) -> Dict[str, str]:
         """
@@ -507,44 +572,14 @@ def quantize_from_pretrained(
         verbose: bool = True,
         progress_callback: Optional[Callable] = None
     ) -> PreTrainedModel:
-        """
-        Quantize a model using GGUF format with optimized resource handling.
-
-        Args:
-            model_name: Model identifier or instance
-            bits: Number of bits for GGUF quantization
-            group_size: Size of quantization groups
-            quant_type: GGUF quantization type
-            use_packed: Whether to use packed format
-            device: Target device for quantization
-            load_in_8bit: Whether to load model in 8-bit precision
-            load_in_4bit: Whether to load model in 4-bit precision
-            bnb_4bit_quant_type: BitsAndBytes 4-bit quantization type
-            bnb_4bit_compute_dtype: Compute dtype for 4-bit quantization
-            bnb_4bit_use_double_quant: Whether to use double quantization
-            use_gradient_checkpointing: Whether to use gradient checkpointing
-            device_map: Device mapping strategy
-            max_memory: Maximum memory configuration
-            offload_folder: Folder for offloading
-            offload_state_dict: Whether to offload state dict
-            torch_dtype: Default torch dtype
-            auto_device: Automatically determine optimal device
-            optimize_for: Optimization priority ("speed", "quality", or "balanced")
-            cpu_offload: Whether to use CPU offloading
-            verbose: Whether to show detailed progress
-            progress_callback: Optional callback for progress updates
-
-        Returns:
-            Quantized model
-        """
         try:
             # Initialize progress tracking
             progress = ProgressTracker()
             if verbose:
                 progress.start(100)
                 progress.start_phase("Initialization")
 
-            logger.log_info(f"Starting GGUF quantization with {bits} bits")
+            logger.log_info(f"Starting quantization with {bits} bits")
 
             if bits not in SUPPORTED_GGUF_BITS:
                 raise ValueError(f"Unsupported bits: {bits}. Supported values: {SUPPORTED_GGUF_BITS}")
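
The removed docstring documented the call signature, which this commit leaves unchanged. Assuming `quantize_from_pretrained` is exposed as a static method like `list_quant_types`, a hypothetical call (the model id is only a placeholder) would look like:

    model = QuantLLM.quantize_from_pretrained(
        "some-org/some-model",      # placeholder Hugging Face model id
        bits=4,
        quant_type=None,            # None lets get_recommended_quant_type pick one
        auto_device=True,
        optimize_for="balanced",
        verbose=True,
    )
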
@@ -579,7 +614,7 @@ def quantize_from_pretrained(
 
             if device is None:
                 device = optimal_config["device"]
-            if device_map == "auto":
+            if device_map == "auto":
                 device_map = optimal_config["device_map"]
             if max_memory is None:
                 max_memory = optimal_config.get("max_memory")
@@ -592,12 +627,24 @@ def quantize_from_pretrained(
                 logger.log_info(f" • Device map: {device_map}")
                 logger.log_info(f" • CPU offload: {cpu_offload}")
                 logger.log_info(f" • Optimization level: {optimal_config['optimization_level']}")
-
-            # Configure BitsAndBytes for 4-bit quantization
-            if load_in_4bit:
+
+            # Configure quantization based on bits
+            if bits <= 4:
+                load_in_4bit = True
+                load_in_8bit = False
+            elif bits <= 8:
+                load_in_8bit = True
+                load_in_4bit = False
+            else:
+                load_in_4bit = False
+                load_in_8bit = False
+
+            # Configure BitsAndBytes for quantization
+            if load_in_4bit or load_in_8bit:
                 compute_dtype = bnb_4bit_compute_dtype or torch.float16
                 bnb_config = BitsAndBytesConfig(
-                    load_in_4bit=True,
+                    load_in_4bit=load_in_4bit,
+                    load_in_8bit=load_in_8bit,
                     bnb_4bit_quant_type=bnb_4bit_quant_type,
                     bnb_4bit_compute_dtype=compute_dtype,
                     bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
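
The new branch derives the BitsAndBytes flags from the requested bit width instead of trusting the caller's `load_in_4bit` flag: widths of 4 or less map to 4-bit loading, widths up to 8 map to 8-bit, and anything wider skips BitsAndBytes entirely. A simplified standalone sketch of that mapping; the `nf4` and double-quantization defaults are assumptions, since the code above passes those settings through from the caller:

    from typing import Optional

    import torch
    from transformers import BitsAndBytesConfig

    def make_bnb_config(bits: int) -> Optional[BitsAndBytesConfig]:
        """Map a requested bit width onto a BitsAndBytes loading config."""
        if bits <= 4:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",           # assumed default
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,      # assumed default
            )
        if bits <= 8:
            return BitsAndBytesConfig(load_in_8bit=True)
        return None                                  # no BitsAndBytes for wider formats
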
@@ -608,7 +655,7 @@ def quantize_from_pretrained(
 
             # If no quant_type specified, use recommended type
             if not quant_type:
-                bits, quant_type = QuantLLM.get_recommended_quant_type(
+                _, quant_type = QuantLLM.get_recommended_quant_type(
                     model_size_gb=model_size_gb,
                     priority=optimize_for
                 )
@@ -644,8 +691,13 @@ def quantize_from_pretrained(
             if verbose:
                 progress.update(40, "Quantizer created, starting quantization...")
 
-            # Store quantizer instance in model for later use
+            # Store quantizer instance and config in model for later use
             quantizer.model._quantizer = quantizer
+            quantizer.model.config.quantization_config = {
+                "bits": bits,
+                "quant_type": quant_type,
+                "group_size": group_size
+            }
 
             if verbose:
                 progress.update(30, "Quantization completed")
@@ -703,10 +755,10 @@ def save_quantized_model(
         # Get original model path from cache if available
         original_path = None
         if hasattr(model, 'config') and hasattr(model.config, '_name_or_path'):
-            from transformers.utils import HUGGINGFACE_HUB_CACHE
+            from huggingface_hub import HfFolder
+            cache_dir = os.getenv('HF_HOME', HfFolder.default_cache_path)
             model_id = model.config._name_or_path
             if '/' in model_id:  # It's a hub model
-                cache_dir = os.getenv('TRANSFORMERS_CACHE', HUGGINGFACE_HUB_CACHE)
                 org, model_name = model_id.split('/')
                 potential_paths = glob.glob(os.path.join(cache_dir, 'models--' + org + '--' + model_name, '*', 'snapshots', '*'))
                 if potential_paths:
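
The cache lookup now honors `HF_HOME` and the `huggingface_hub` default location instead of the older `TRANSFORMERS_CACHE` variable. A standalone sketch of the same lookup, reusing the import and glob pattern from the diff (the model id is a placeholder):

    import glob
    import os

    from huggingface_hub import HfFolder

    cache_dir = os.getenv('HF_HOME', HfFolder.default_cache_path)
    model_id = "some-org/some-model"            # placeholder hub id
    org, model_name = model_id.split('/')
    snapshots = glob.glob(os.path.join(cache_dir, 'models--' + org + '--' + model_name, '*', 'snapshots', '*'))
    original_path = snapshots[0] if snapshots else None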