better solution for checking g_idx support #2251


Merged 4 commits on Jul 22, 2025
neural_compressor/torch/algorithms/weight_only/modules.py (5 changes: 2 additions & 3 deletions)
@@ -26,10 +26,9 @@
 from torch.nn import functional as F

 from neural_compressor.torch.utils import (
-    Version,
     accelerator,
     can_pack_with_numba,
-    get_hpex_version,
+    is_hpex_support_g_idx,
     logger,
 )

@@ -731,7 +730,7 @@ def __init__(
             )
         else:
             self.g_idx = None
-        self.support_g_idx = True if get_hpex_version() >= Version("1.23.0") else False
+        self.support_g_idx = is_hpex_support_g_idx()

         self.half_indim = self.in_features // 2
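Since the quantized-linear module's __init__ now calls the helper once per layer, the @lru_cache(None) added in the second file turns the schema lookup into a one-time cost. A minimal sketch of that behavior, where probe() is a hypothetical stand-in for is_hpex_support_g_idx() (the real function does a schema lookup instead of printing):

# Sketch only: why the probe is wrapped in @lru_cache(None).
from functools import lru_cache

@lru_cache(None)
def probe() -> bool:
    print("schema lookup runs once")
    return True

flags = [probe() for _ in range(3)]  # prints once; flags == [True, True, True]
probe.cache_clear()                  # lru_cache also allows forcing a re-probe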
neural_compressor/torch/utils/environ.py (18 changes: 10 additions & 8 deletions)
@@ -17,6 +17,7 @@
 import importlib
 import os
 import sys
+from functools import lru_cache

 import torch
 from packaging.version import Version
@@ -79,19 +80,20 @@ def is_hpu_available():
     return get_accelerator().name() == "hpu"


-def get_hpex_version():
-    """Return ipex version if ipex exists."""
+@lru_cache(None)
+def is_hpex_support_g_idx():
+    """Check if HPEX supports group_index in the schema of hpu::convert_from_int4."""
     if is_hpex_available():
         try:
             import habana_frameworks.torch
+            import torch

-            hpex_version = habana_frameworks.torch.__version__
-        except ValueError as e:  # pragma: no cover
-            assert False, "Got an unknown version of habana_frameworks.torch: {}".format(e)
-        version = Version(hpex_version)
-        return version
+            schema = torch._C._get_schema("hpu::convert_from_int4", "")
+            return "group_index" in str(schema)
+        except:  # pragma: no cover
+            return False
     else:
-        return None
+        return False


 ## check optimum
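The probe relies on torch._C._get_schema(name, overload_name), a private accessor into PyTorch's operator registry; the PR passes "" to select the default overload. A sketch of the same technique against a stock CPU op, so it can be tried without an HPU build; the aten::add example and its printed schema are illustrative and the private API may change across PyTorch releases:

# Sketch only: schema introspection on a stock op, mirroring the group_index probe.
import torch

# Look up the "Tensor" overload of aten::add in the operator registry.
schema = torch._C._get_schema("aten::add", "Tensor")
print(schema)
# Typically: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

# Feature detection by argument name, as the PR does for "group_index":
print("alpha" in str(schema))  # True on recent releases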