Skip to content

Commit ed9f84e

Browse files
committed
Fix condition for broadcast
1 parent 2ebe86a commit ed9f84e

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2786,8 +2786,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2786 2786       }
2787 2787
2788 2788       //if rms norm is the B operand, then we don't handle broadcast
2789      -     if (rms_norm == mul->src[1] &&
2790      -         mul->src[0]->ne[1] != rms_norm->src[1]->ne[1]) {
     2789 +     if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2791 2790           return false;
2792 2791       }
2793 2792

0 commit comments

Comments
 (0)