-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Support numpy.prod operation #30212 #21873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -336,9 +336,36 @@ def argmin(x, axis=None, keepdims=False): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def argsort(x, axis=-1): | ||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||
| "`argsort` is not supported with openvino backend" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| x = get_ov_output(x) | ||||||||||||||||||||||||||||||||||||||||||
| x_shape = x.get_partial_shape() | ||||||||||||||||||||||||||||||||||||||||||
| rank = x_shape.rank.get_length() | ||||||||||||||||||||||||||||||||||||||||||
| if rank == 0: | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0)) | ||||||||||||||||||||||||||||||||||||||||||
| if axis is None: | ||||||||||||||||||||||||||||||||||||||||||
| flatten_shape = ov_opset.constant([-1], Type.i32).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| x = ov_opset.reshape(x, flatten_shape, False).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| k = ov_opset.reduce_prod( | ||||||||||||||||||||||||||||||||||||||||||
| x_shape_tensor, ov_opset.constant([0], Type.i32), keep_dims=False | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| axis = 0 | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| if axis < 0: | ||||||||||||||||||||||||||||||||||||||||||
| axis = rank + axis | ||||||||||||||||||||||||||||||||||||||||||
| x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| k = ov_opset.gather( | ||||||||||||||||||||||||||||||||||||||||||
| x_shape_tensor, | ||||||||||||||||||||||||||||||||||||||||||
| ov_opset.constant(axis, Type.i32).output(0), | ||||||||||||||||||||||||||||||||||||||||||
| ov_opset.constant(0, Type.i32).output(0), | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| sorted_indices = ov_opset.topk( | ||||||||||||||||||||||||||||||||||||||||||
| x, | ||||||||||||||||||||||||||||||||||||||||||
| k=k, | ||||||||||||||||||||||||||||||||||||||||||
| axis=axis, | ||||||||||||||||||||||||||||||||||||||||||
| mode="min", | ||||||||||||||||||||||||||||||||||||||||||
| sort="value", | ||||||||||||||||||||||||||||||||||||||||||
| ).output(1) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(sorted_indices) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def array(x, dtype=None): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -380,9 +407,48 @@ def average(x, axis=None, weights=None): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def bincount(x, weights=None, minlength=0, sparse=False): | ||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||
| "`bincount` is not supported with openvino backend" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| if x is None: | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("input x is None") | ||||||||||||||||||||||||||||||||||||||||||
| if sparse: | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Unsupported value `sparse=True`") | ||||||||||||||||||||||||||||||||||||||||||
| x = get_ov_output(x) | ||||||||||||||||||||||||||||||||||||||||||
| x_type = x.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| shape_x = ov_opset.shape_of(x, "i64").output(0) | ||||||||||||||||||||||||||||||||||||||||||
| rank_x = ov_opset.shape_of(shape_x, "i64").output(0) | ||||||||||||||||||||||||||||||||||||||||||
| rank_x = ov_opset.convert(rank_x, x_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| scalar_shape = ov_opset.constant([], x_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| rank_x = ov_opset.reshape(rank_x, scalar_shape, False).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| const_minus_one = ov_opset.constant(-1, x_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| rank_minus_one = ov_opset.add(rank_x, const_minus_one).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| minlength = get_ov_output(minlength) | ||||||||||||||||||||||||||||||||||||||||||
| minlength = ov_opset.convert(minlength, x_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| const_one = ov_opset.constant(1, x_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| const_zero = ov_opset.constant(0, x_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| max_element = ov_opset.reduce_max(x, const_zero, keep_dims=False).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| depth = ov_opset.add(max_element, const_one).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| depth = ov_opset.maximum(depth, minlength).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| depth_scalar = ov_opset.reduce_max( | ||||||||||||||||||||||||||||||||||||||||||
| depth, const_zero, keep_dims=False | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| one_hot = ov_opset.one_hot( | ||||||||||||||||||||||||||||||||||||||||||
| x, depth_scalar, const_one, const_zero, axis=-1 | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| if weights is not None: | ||||||||||||||||||||||||||||||||||||||||||
| weights = get_ov_output(weights) | ||||||||||||||||||||||||||||||||||||||||||
| weights_type = weights.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| weights_new = ov_opset.reshape(weights, [-1, 1], False).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| one_hot = ov_opset.convert(one_hot, weights_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| final_one_hot = ov_opset.multiply(one_hot, weights_new).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| final_output = ov_opset.reduce_sum( | ||||||||||||||||||||||||||||||||||||||||||
| final_one_hot, rank_minus_one, keep_dims=False | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(final_output) | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| final_output = ov_opset.reduce_sum( | ||||||||||||||||||||||||||||||||||||||||||
| one_hot, rank_minus_one, keep_dims=False | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| final_output = ov_opset.convert(final_output, Type.i32).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(final_output) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def broadcast_to(x, shape): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -502,7 +568,76 @@ def diagonal(x, offset=0, axis1=0, axis2=1): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def diff(a, n=1, axis=-1): | ||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("`diff` is not supported with openvino backend") | ||||||||||||||||||||||||||||||||||||||||||
| if n == 0: | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(get_ov_output(a)) | ||||||||||||||||||||||||||||||||||||||||||
| if n < 0: | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("order must be non-negative but got " + repr(n)) | ||||||||||||||||||||||||||||||||||||||||||
| a = get_ov_output(a) | ||||||||||||||||||||||||||||||||||||||||||
| a_type = a.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(a, np.ndarray): | ||||||||||||||||||||||||||||||||||||||||||
| rank = a.ndim | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| rank = a.get_partial_shape().rank.get_length() | ||||||||||||||||||||||||||||||||||||||||||
| if axis < 0: | ||||||||||||||||||||||||||||||||||||||||||
| axis = axis + rank | ||||||||||||||||||||||||||||||||||||||||||
| result = a | ||||||||||||||||||||||||||||||||||||||||||
| for _ in range(n): | ||||||||||||||||||||||||||||||||||||||||||
| rank = result.get_partial_shape().rank.get_length() | ||||||||||||||||||||||||||||||||||||||||||
| strides = ov_opset.constant( | ||||||||||||||||||||||||||||||||||||||||||
| np.array([1] * rank, dtype=np.int64), Type.i64 | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| begin_upper_list = [0] * rank | ||||||||||||||||||||||||||||||||||||||||||
| begin_upper_list[axis] = 1 | ||||||||||||||||||||||||||||||||||||||||||
| begin_upper = ov_opset.constant( | ||||||||||||||||||||||||||||||||||||||||||
| np.array(begin_upper_list, dtype=np.int64), Type.i64 | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| end_upper = ov_opset.constant( | ||||||||||||||||||||||||||||||||||||||||||
| np.array([0] * rank, dtype=np.int64), Type.i64 | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| begin_mask_upper = [1] * rank | ||||||||||||||||||||||||||||||||||||||||||
| begin_mask_upper[axis] = 0 | ||||||||||||||||||||||||||||||||||||||||||
| end_mask_upper = [1] * rank | ||||||||||||||||||||||||||||||||||||||||||
| upper = ov_opset.strided_slice( | ||||||||||||||||||||||||||||||||||||||||||
| data=result, | ||||||||||||||||||||||||||||||||||||||||||
| begin=begin_upper, | ||||||||||||||||||||||||||||||||||||||||||
| end=end_upper, | ||||||||||||||||||||||||||||||||||||||||||
| strides=strides, | ||||||||||||||||||||||||||||||||||||||||||
| begin_mask=begin_mask_upper, | ||||||||||||||||||||||||||||||||||||||||||
| end_mask=end_mask_upper, | ||||||||||||||||||||||||||||||||||||||||||
| new_axis_mask=[], | ||||||||||||||||||||||||||||||||||||||||||
| shrink_axis_mask=[], | ||||||||||||||||||||||||||||||||||||||||||
| ellipsis_mask=[], | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| begin_lower = ov_opset.constant( | ||||||||||||||||||||||||||||||||||||||||||
| np.array([0] * rank, dtype=np.int64), Type.i64 | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| end_lower_list = [0] * rank | ||||||||||||||||||||||||||||||||||||||||||
| end_lower_list[axis] = -1 | ||||||||||||||||||||||||||||||||||||||||||
| end_lower = ov_opset.constant( | ||||||||||||||||||||||||||||||||||||||||||
| np.array(end_lower_list, dtype=np.int64), Type.i64 | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| begin_mask_lower = [1] * rank | ||||||||||||||||||||||||||||||||||||||||||
| end_mask_lower = [1] * rank | ||||||||||||||||||||||||||||||||||||||||||
| end_mask_lower[axis] = 0 | ||||||||||||||||||||||||||||||||||||||||||
| lower = ov_opset.strided_slice( | ||||||||||||||||||||||||||||||||||||||||||
| data=result, | ||||||||||||||||||||||||||||||||||||||||||
| begin=begin_lower, | ||||||||||||||||||||||||||||||||||||||||||
| end=end_lower, | ||||||||||||||||||||||||||||||||||||||||||
| strides=strides, | ||||||||||||||||||||||||||||||||||||||||||
| begin_mask=begin_mask_lower, | ||||||||||||||||||||||||||||||||||||||||||
| end_mask=end_mask_lower, | ||||||||||||||||||||||||||||||||||||||||||
| new_axis_mask=[], | ||||||||||||||||||||||||||||||||||||||||||
| shrink_axis_mask=[], | ||||||||||||||||||||||||||||||||||||||||||
| ellipsis_mask=[], | ||||||||||||||||||||||||||||||||||||||||||
| ).output(0) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if a_type == Type.boolean: | ||||||||||||||||||||||||||||||||||||||||||
| result = ov_opset.not_equal(upper, lower).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| result = ov_opset.subtract(upper, lower).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(result) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def digitize(x, bins): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -512,11 +647,30 @@ def digitize(x, bins): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def dot(x, y): | ||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("`dot` is not supported with openvino backend") | ||||||||||||||||||||||||||||||||||||||||||
| element_type = None | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(x, OpenVINOKerasTensor): | ||||||||||||||||||||||||||||||||||||||||||
| element_type = x.output.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(y, OpenVINOKerasTensor): | ||||||||||||||||||||||||||||||||||||||||||
| element_type = y.output.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| x = get_ov_output(x, element_type) | ||||||||||||||||||||||||||||||||||||||||||
| y = get_ov_output(y, element_type) | ||||||||||||||||||||||||||||||||||||||||||
| x, y = _align_operand_types(x, y, "dot()") | ||||||||||||||||||||||||||||||||||||||||||
| if x.get_partial_shape().rank == 0 or y.get_partial_shape().rank == 0: | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(ov_opset.multiply(x, y).output(0)) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(ov_opset.matmul(x, y, False, False).output(0)) | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def empty(shape, dtype=None): | ||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("`empty` is not supported with openvino backend") | ||||||||||||||||||||||||||||||||||||||||||
| dtype = standardize_dtype(dtype) or config.floatx() | ||||||||||||||||||||||||||||||||||||||||||
| ov_type = OPENVINO_DTYPES[dtype] | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(shape, tuple): | ||||||||||||||||||||||||||||||||||||||||||
| shape = list(shape) | ||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(shape, int): | ||||||||||||||||||||||||||||||||||||||||||
| shape = [shape] | ||||||||||||||||||||||||||||||||||||||||||
| shape_node = ov_opset.constant(shape, Type.i32).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| const_zero = ov_opset.constant(0, dtype=ov_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| empty_tensor = ov_opset.broadcast(const_zero, shape_node).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(empty_tensor) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
663
to
+673
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementation of
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def equal(x1, x2): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -533,14 +687,17 @@ def equal(x1, x2): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def exp(x): | ||||||||||||||||||||||||||||||||||||||||||
| x = get_ov_output(x) | ||||||||||||||||||||||||||||||||||||||||||
| x_type = x.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| if x_type.is_integral(): | ||||||||||||||||||||||||||||||||||||||||||
| ov_type = OPENVINO_DTYPES[config.floatx()] | ||||||||||||||||||||||||||||||||||||||||||
| x = ov_opset.convert(x, ov_type) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(ov_opset.exp(x).output(0)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def expand_dims(x, axis): | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(x, OpenVINOKerasTensor): | ||||||||||||||||||||||||||||||||||||||||||
| x = x.output | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| assert False | ||||||||||||||||||||||||||||||||||||||||||
| x = get_ov_output(x) | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(axis, tuple): | ||||||||||||||||||||||||||||||||||||||||||
| axis = list(axis) | ||||||||||||||||||||||||||||||||||||||||||
| axis = ov_opset.constant(axis, Type.i32).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(ov_opset.unsqueeze(x, axis).output(0)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -571,9 +728,15 @@ def full(shape, fill_value, dtype=None): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def full_like(x, fill_value, dtype=None): | ||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||
| "`full_like` is not supported with openvino backend" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| x = get_ov_output(x) | ||||||||||||||||||||||||||||||||||||||||||
| shape_x = ov_opset.shape_of(x) | ||||||||||||||||||||||||||||||||||||||||||
| if dtype is not None: | ||||||||||||||||||||||||||||||||||||||||||
| ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| ov_type = x.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| const_value = ov_opset.constant(fill_value, ov_type).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| res = ov_opset.broadcast(const_value, shape_x).output(0) | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(res) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def greater(x1, x2): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -601,7 +764,20 @@ def greater_equal(x1, x2): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def hstack(xs): | ||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("`hstack` is not supported with openvino backend") | ||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(xs, (list, tuple)): | ||||||||||||||||||||||||||||||||||||||||||
| raise TypeError("Input to `hstack` must be a list or tuple of tensors.") | ||||||||||||||||||||||||||||||||||||||||||
| if len(xs) == 0: | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Input list to `hstack` cannot be empty.") | ||||||||||||||||||||||||||||||||||||||||||
| element_type = None | ||||||||||||||||||||||||||||||||||||||||||
| for x in xs: | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(x, OpenVINOKerasTensor): | ||||||||||||||||||||||||||||||||||||||||||
| element_type = x.output.get_element_type() | ||||||||||||||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||||||||||||||
| xs = [get_ov_output(x, element_type) for x in xs] | ||||||||||||||||||||||||||||||||||||||||||
| xs = _align_operand_types(xs[0], xs[1], "hstack()") | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+771
to
+777
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type alignment for the input tensors in
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| rank = len(xs[0].get_partial_shape()) | ||||||||||||||||||||||||||||||||||||||||||
| axis = 1 if rank > 1 else 0 | ||||||||||||||||||||||||||||||||||||||||||
| return OpenVINOKerasTensor(ov_opset.concat(xs, axis=axis).output(0)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def identity(n, dtype=None): | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The calculation of
rank_xis a bit complex. You can simplify this by usingov_opset.rank(x)which directly returns the rank of the tensor as a scalar tensor. This would make the code more readable and concise.