Skip to content
Open
20 changes: 15 additions & 5 deletions keras/src/layers/preprocessing/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ def __init__(
dtype=None,
name=None,
):
if dtype is None:
dtype = "int64" if output_mode == "int" else backend.floatx()

super().__init__(name=name, dtype=dtype)

if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
Expand Down Expand Up @@ -155,6 +152,13 @@ def __init__(
def input_dtype(self):
    """Return whatever `backend.floatx()` reports as the input dtype."""
    float_dtype = backend.floatx()
    return float_dtype

@property
def compute_dtype(self):
    """Dtype selected from `output_mode`.

    Returns:
        `"int64"` when `output_mode` is `"int"`; otherwise the value of
        `backend.floatx()`.
    """
    return "int64" if self.output_mode == "int" else backend.floatx()

def adapt(self, data, steps=None):
"""Computes bin boundaries from quantiles in a input dataset.

Expand Down Expand Up @@ -213,7 +217,10 @@ def reset_state(self):
self.summary = np.array([[], []], dtype="float32")

def compute_output_spec(self, inputs):
    """Describe the symbolic output produced for `inputs`.

    Args:
        inputs: Symbolic input exposing a `shape` attribute.

    Returns:
        A `backend.KerasTensor` with the same shape as `inputs`. Its dtype
        is `"int64"` when `output_mode` is `"int"`, and
        `self.compute_dtype` otherwise.
    """
    # Select the dtype explicitly instead of returning early with
    # `self.compute_dtype`: an unconditional return ahead of this selection
    # left it unreachable. This mirrors the dtype choice made in `call`.
    output_dtype = (
        "int64" if self.output_mode == "int" else self.compute_dtype
    )
    return backend.KerasTensor(shape=inputs.shape, dtype=output_dtype)

def load_own_variables(self, store):
if len(store) == 1:
Expand All @@ -230,11 +237,14 @@ def call(self, inputs):
)

indices = self.backend.numpy.digitize(inputs, self.bin_boundaries)
output_dtype = (
"int64" if self.output_mode == "int" else self.compute_dtype
)
return numerical_utils.encode_categorical_inputs(
indices,
output_mode=self.output_mode,
depth=len(self.bin_boundaries) + 1,
dtype=self.compute_dtype,
dtype=output_dtype,
sparse=self.sparse,
backend_module=self.backend,
)
Expand Down
21 changes: 21 additions & 0 deletions keras/src/layers/preprocessing/discretization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,24 @@ def test_call_before_adapt_raises(self):
layer = layers.Discretization(num_bins=3)
with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"):
layer([[0.1, 0.8, 0.9]])

def test_model_call_vs_predict_consistency(self):
    """Check `model(x)` and `model.predict(x)` agree in "int" mode."""
    discretizer = layers.Discretization(
        bin_boundaries=[-0.5, 0, 0.1, 0.2, 3],
        output_mode="int",
    )
    data = np.array([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]])

    # Wrap the layer in a functional model so it can be both called
    # directly and run through `predict`.
    model_input = layers.Input(shape=(4,), dtype="float32")
    model_output = discretizer(model_input)
    model = models.Model(inputs=model_input, outputs=model_output)

    # Direct call first, then `predict`, on the same data.
    called = model(data)
    predicted = model.predict(data)

    self.assertAllClose(called, predicted)
Loading