@@ -149,11 +149,18 @@ def _load_eager_pretrained(
 
     if qembedding_config:
         logging.info("Quantizing embedding layers.")
+        embedding_config = {
+            "4w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+            "8w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            ),
+        }[qembedding_config]
+
         # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
-        embedding_config = IntxWeightOnlyConfig(
-            weight_dtype=torch.int8,
-            granularity=PerAxis(0),
-        )
         quantize_(
             eager_model,
             embedding_config,
@@ -162,10 +169,20 @@ def _load_eager_pretrained(
 
     if qlinear_config:
         logging.info("Quantizing linear layers.")
-        linear_config = Int8DynamicActivationIntxWeightConfig(
-            weight_dtype=torch.int4,
-            weight_granularity=PerGroup(32),
-        )
+        linear_config = {
+            "8da4w": Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(32),
+            ),
+            "4w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+            "8w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            ),
+        }[qlinear_config]
         quantize_(
             eager_model,
             linear_config,
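For context, the diff replaces the single hard-coded quantization config with a dict lookup keyed on the user-supplied option string ("4w"/"8w" for embeddings; "8da4w"/"4w"/"8w" for linears). Below is a minimal, self-contained sketch of the resulting behavior. The toy model, the torchao import paths, and the embedding filter lambda are my assumptions for illustration; the config constructors and the `quantize_` calls mirror the diff itself.

```python
import torch
import torch.nn as nn
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)

# Hypothetical stand-in for the eager pretrained model.
eager_model = nn.Sequential(nn.Embedding(1024, 64), nn.Linear(64, 64))

qembedding_config = "8w"  # one of "4w", "8w"
qlinear_config = "8da4w"  # one of "8da4w", "4w", "8w"

if qembedding_config:
    embedding_config = {
        "4w": IntxWeightOnlyConfig(
            weight_dtype=torch.int4,
            granularity=PerGroup(32),
        ),
        "8w": IntxWeightOnlyConfig(
            weight_dtype=torch.int8,
            granularity=PerAxis(0),
        ),
    }[qembedding_config]
    # quantize_ mutates the model in place; the filter restricts
    # quantization to embedding layers (assumed here, not shown in the hunk).
    quantize_(
        eager_model,
        embedding_config,
        lambda m, fqn: isinstance(m, nn.Embedding),
    )

if qlinear_config:
    linear_config = {
        "8da4w": Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
        ),
        "4w": IntxWeightOnlyConfig(
            weight_dtype=torch.int4,
            granularity=PerGroup(32),
        ),
        "8w": IntxWeightOnlyConfig(
            weight_dtype=torch.int8,
            granularity=PerAxis(0),
        ),
    }[qlinear_config]
    # quantize_'s default filter targets nn.Linear modules.
    quantize_(eager_model, linear_config)
```

One consequence of the plain dict lookup is that an unsupported option string surfaces as a bare `KeyError`; validating the option up front and raising a clearer error could be worth a follow-up.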