5 changes: 2 additions & 3 deletions neural_compressor/torch/algorithms/weight_only/modules.py
@@ -26,10 +26,9 @@
 from torch.nn import functional as F

 from neural_compressor.torch.utils import (
-    Version,
     accelerator,
     can_pack_with_numba,
-    get_hpex_version,
+    is_hpex_support_g_idx,
     logger,
 )

@@ -731,7 +730,7 @@ def __init__(
             )
         else:
             self.g_idx = None
-        self.support_g_idx = True if get_hpex_version() >= Version("1.23.0") else False
+        self.support_g_idx = is_hpex_support_g_idx()

         self.half_indim = self.in_features // 2

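Side note on the removed gate: `packaging.version.Version` compares release segments numerically, and the comparison already yields a bool, so the old `True if ... else False` ternary was redundant. A minimal sketch of both points:

```python
from packaging.version import Version

# PEP 440 comparison is numeric per segment; naive string comparison is not.
print("1.9.0" >= "1.23.0")                    # True (lexicographic, misleading)
print(Version("1.9.0") >= Version("1.23.0"))  # False (1.9 < 1.23 numerically)

# The comparison already returns a bool, so wrapping it in
# `True if ... else False` (as the removed line did) adds nothing.
print(type(Version("1.23.0") >= Version("1.23.0")))  # <class 'bool'>
```

Replacing the version pin with a capability probe (the environ.py change below) also avoids guessing which release first shipped the feature.
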
16 changes: 9 additions & 7 deletions neural_compressor/torch/utils/environ.py
@@ -17,6 +17,7 @@
 import importlib
 import os
 import sys
+from functools import lru_cache

 import torch
 from packaging.version import Version
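
The new import backs the `@lru_cache(None)` decorator on the probe below: with an unbounded cache, the schema lookup runs once per process and later calls return the memoized result. A quick sketch with a made-up function name:

```python
from functools import lru_cache

@lru_cache(None)  # maxsize=None: unbounded cache, result kept for the process lifetime
def probe() -> bool:
    print("probing...")  # side effect runs only on the first call
    return True

probe()                    # prints "probing...", returns True
probe()                    # cache hit: no print
print(probe.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
```
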
@@ -79,19 +80,20 @@ def is_hpu_available():
     return get_accelerator().name() == "hpu"


-def get_hpex_version():
-    """Return ipex version if ipex exists."""
+@lru_cache(None)
+def is_hpex_support_g_idx():
+    """Return True if the HPEX hpu::convert_from_int4 op accepts a group_index argument."""
     if is_hpex_available():
         try:
             import habana_frameworks.torch
+            import torch

-            hpex_version = habana_frameworks.torch.__version__
-        except ValueError as e:  # pragma: no cover
-            assert False, "Got an unknown version of habana_frameworks.torch: {}".format(e)
-        version = Version(hpex_version)
-        return version
+            schema = torch._C._get_schema("hpu::convert_from_int4", "")
+            return "group_index" in str(schema)
+        except Exception:  # pragma: no cover
+            return False
     else:
-        return None
+        return False


 ## check optimum
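
The probe leans on `torch._C._get_schema(op_name, overload_name)`, a private PyTorch binding that returns the registered schema for an operator; rendering it as a string exposes the argument names. Since `hpu::convert_from_int4` is only registered when HPEX is present, the sketch below demonstrates the same technique against a stock ATen op (private API: the exact schema text can vary across PyTorch releases):

```python
import torch

# Fetch the schema for aten::add's "Tensor" overload.
schema = torch._C._get_schema("aten::add", "Tensor")
print(schema)
# e.g. aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

# Feature detection by argument name, mirroring the group_index check above:
print("alpha" in str(schema))  # True
```

Probing the schema directly means the check tracks what the installed build actually supports, and the exception handler lets it degrade to False instead of raising when the op is missing.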