Skip to content

Commit bb3ab32

Browse files
authored
Merge branch 'main' into sdk_dependency
Signed-off-by: Mamta Singh <[email protected]>
2 parents f8515d9 + d020b88 commit bb3ab32

35 files changed

+1708
-964
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#
66
# ----------------------------------------------------------------------------
77

8-
import hashlib
98
import inspect
109
import logging
1110
import shutil
@@ -22,8 +21,16 @@
2221
from QEfficient.base.pytorch_transforms import PytorchTransform
2322
from QEfficient.compile.qnn_compiler import compile as qnn_compile
2423
from QEfficient.generation.cloud_infer import QAICInferenceSession
25-
from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
26-
from QEfficient.utils.cache import QEFF_HOME, to_hashable
24+
from QEfficient.utils import (
25+
constants,
26+
create_json,
27+
create_model_params,
28+
dump_qconfig,
29+
export_wrapper,
30+
generate_mdp_partition_config,
31+
hash_dict_params,
32+
load_json,
33+
)
2734

2835
logger = logging.getLogger(__name__)
2936

@@ -45,12 +52,16 @@ class QEFFBaseModel(ABC):
4552
def _transform_names(cls) -> List[str]:
4653
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
4754

48-
def __init__(self, model: torch.nn.Module) -> None:
55+
def __init__(self, model: torch.nn.Module, **kwargs) -> None:
4956
super().__init__()
5057
self.model = model
58+
self.hash_params = create_model_params(self, **kwargs)
5159
self.onnx_path: Optional[str] = None
5260
self.qpc_path: Optional[str] = None
5361
self.qpc_session: Optional[QAICInferenceSession] = None
62+
self.model_architecture = (
63+
(arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
64+
) or None
5465

5566
# Apply the transformations
5667
any_transformed = False
@@ -67,10 +78,6 @@ def __init__(self, model: torch.nn.Module) -> None:
6778
@abstractmethod
6879
def model_name(self) -> str: ...
6980

70-
@property
71-
@abstractmethod
72-
def model_hash(self) -> str: ...
73-
7481
@abstractmethod
7582
def export(self, export_dir: Optional[str] = None) -> Path:
7683
"""
@@ -114,6 +121,7 @@ def compile(self, *args, **kwargs) -> Path:
114121
:str: Path of the compiled ``qpc`` package.
115122
"""
116123

124+
@export_wrapper
117125
def _export(
118126
self,
119127
example_inputs: Dict[str, torch.Tensor],
@@ -134,8 +142,6 @@ def _export(
134142
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
135143
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
136144
"""
137-
export_dir = Path(export_dir or (QEFF_HOME / self.model_name))
138-
export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash)
139145
onnx_path = export_dir / f"{self.model_name}.onnx"
140146
if onnx_path.is_file():
141147
self.onnx_path = onnx_path
@@ -299,23 +305,16 @@ def _compile(
299305
else:
300306
mdp_ts_json = None
301307

302-
compile_hash = hashlib.sha256(to_hashable(command))
303-
304-
if specializations is not None:
305-
compile_hash.update(to_hashable(specializations))
306-
307-
if custom_io is not None:
308-
compile_hash.update(to_hashable(custom_io))
309-
310-
if num_speculative_tokens:
311-
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
312-
313-
# Hash the MDP partition config and the number of devices.
314-
compile_hash.update(to_hashable(mdp_ts_json))
315-
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))
308+
compile_hash_params = {
309+
"command": command,
310+
"specializations": specializations,
311+
"custom_io": custom_io,
312+
"mdp_ts_num_devices": mdp_ts_num_devices,
313+
"mdp_ts_json": mdp_ts_json,
314+
"num_speculative_tokens": num_speculative_tokens,
315+
}
316+
compile_hash = hash_dict_params(compile_hash_params)
316317

317-
# Check if already compiled
318-
compile_hash = compile_hash.hexdigest()[:16]
319318
compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
320319
qpc_path = compile_dir / "qpc"
321320
qpc_path.mkdir(parents=True, exist_ok=True)
@@ -366,6 +365,10 @@ def _compile(
366365
]
367366
)
368367
)
368+
# Dump JSON file with hashed parameters
369+
hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
370+
create_json(hashed_compile_params_path, compile_hash_params)
371+
logger.info("Hashed parameters exported successfully.")
369372

370373
self.qpc_path = qpc_path
371374

QEfficient/compile/qnn_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from typing import Dict, List, Optional
1313

1414
from QEfficient.utils._utils import create_json, execute_command, load_json
15-
from QEfficient.utils.cache import to_hashable
1615
from QEfficient.utils.constants import QnnConstants
1716
from QEfficient.utils.generate_qnn_network_specialization_config import (
1817
generate_data_format_config,
1918
generate_qnn_specialization,
2019
)
20+
from QEfficient.utils.hash_utils import to_hashable
2121
from QEfficient.utils.logging_utils import logger
2222

2323

QEfficient/finetune/utils/train_utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,9 @@ def train(
124124

125125
if train_config.use_peft and train_config.from_peft_checkpoint:
126126
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
127+
intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
127128
if epoch < intermediate_epoch:
128129
logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
129-
# to bring the count of train_step in sync with where it left off
130-
total_train_steps += len(train_dataloader)
131130
continue
132131

133132
logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
@@ -149,20 +148,18 @@ def train(
149148

150149
num_dummy_samples = 0
151150
for step, batch in enumerate(train_dataloader):
151+
# total_train_steps indicates the cumulative number of training steps completed across all epochs.
152+
# When resuming fine-tuning from previously saved checkpoints, total_train_steps indicates the total number of steps trained across the earlier session and the ongoing one.
153+
total_train_steps = (epoch) * len(train_dataloader) + step
152154
# resume training from a particular checkpoint, assuming the dataset is not shuffled
153155
if train_config.use_peft and train_config.from_peft_checkpoint:
154-
intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
155-
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
156156
# to bring the count of train_step in sync with where it left off
157157
if epoch == intermediate_epoch and step == 0:
158-
total_train_steps += intermediate_step
159158
logger.log_rank_zero(
160159
f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
161160
)
162161
if epoch == intermediate_epoch and step < intermediate_step:
163-
total_train_steps += 1
164162
continue
165-
total_train_steps += 1
166163

167164
if train_config.max_train_step > 0 and total_train_steps >= train_config.max_train_step:
168165
max_steps_reached = True
@@ -235,12 +232,12 @@ def train(
235232
else:
236233
num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps
237234

238-
loss = loss / num_samples_in_cur_update
235+
normalized_loss = loss / num_samples_in_cur_update
239236

240237
if train_config.grad_scaler:
241-
scaler.scale(loss).backward() # backward pass
238+
scaler.scale(normalized_loss).backward() # backward pass
242239
else:
243-
loss.backward() # backward pass
240+
normalized_loss.backward() # backward pass
244241

245242
if is_optimizer_step:
246243
if train_config.grad_scaler:
@@ -358,7 +355,6 @@ def train(
358355
logger.log_rank_zero(
359356
f"Epoch {epoch + 1}: Train epoch loss: {train_epoch_loss:.4f}, Train metric: {train_epoch_metric:.4f}, Epoch time {epoch_end_time:.2f} sec"
360357
)
361-
362358
# Saving the results every epoch to plot later
363359
if train_config.save_metrics:
364360
save_to_json(
@@ -377,9 +373,14 @@ def train(
377373

378374
results["last_epoch_train_loss"] = train_epoch_loss.cpu()
379375
results["last_epoch_train_metric"] = train_epoch_metric.cpu()
376+
results["train_step_loss"] = train_step_loss
377+
results["train_step_metric"] = train_step_metric
378+
380379
if train_config.run_validation:
381380
results["last_epoch_eval_loss"] = eval_epoch_loss.cpu()
382381
results["last_epoch_eval_metric"] = eval_epoch_metric.cpu()
382+
results["eval_step_loss"] = eval_step_loss
383+
results["eval_step_metric"] = eval_step_metric
383384
results["avg_epoch_time"] = avg_epoch_time
384385
results["avg_checkpoint_time"] = avg_checkpoint_time
385386
if train_config.save_metrics:

QEfficient/peft/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform
2828
from QEfficient.utils import constants
2929
from QEfficient.utils._utils import get_padding_shape_from_config
30-
from QEfficient.utils.cache import to_hashable
30+
from QEfficient.utils.hash_utils import to_hashable
3131

3232
logger = logging.getLogger(__name__)
3333

QEfficient/peft/lora/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from QEfficient import QEFFAutoModelForCausalLM
1919
from QEfficient.peft.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform
2020
from QEfficient.utils import constants, get_padding_shape_from_config
21-
from QEfficient.utils.cache import to_hashable
21+
from QEfficient.utils.hash_utils import to_hashable
2222
from QEfficient.utils.logging_utils import logger
2323

2424

QEfficient/transformers/models/llama4/modeling_llama4.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -925,14 +925,6 @@ def get_specializations(
925925
)
926926
vision_size = num_features_per_tile * max_num_tiles
927927

928-
downsample_ratio = int(round(1.0 / (self.config.vision_config.pixel_shuffle_ratio**2)))
929-
num_features_per_tile = int(
930-
(img_size // self.config.vision_config.patch_size)
931-
* (img_size // self.config.vision_config.patch_size)
932-
// downsample_ratio
933-
)
934-
vision_size = num_features_per_tile * max_num_tiles
935-
936928
vision = [
937929
{
938930
"batch_size": batch_size,

0 commit comments

Comments
 (0)