
Conversation

@avik-pal (Collaborator) commented Oct 10, 2025

fixes #1733
fixes #1295

@avik-pal force-pushed the ap/stacked_batchdup branch 2 times, most recently from 0f51436 to 36da208 on October 10, 2025 at 20:02
@avik-pal (Collaborator, Author) commented

We should merge EnzymeAD/Enzyme-JAX#1466 and release a JLL with it before merging this.

@avik-pal requested a review from wsmoses on October 10, 2025 at 21:49
@avik-pal marked this pull request as ready for review on October 10, 2025 at 21:49
@avik-pal force-pushed the ap/stacked_batchdup branch 2 times, most recently from f437aae to d71b23f on October 11, 2025 at 17:19
@avik-pal (Collaborator, Author) commented Oct 11, 2025

SliceSimplify seems to be incorrect:

module {
  func.func @main(%arg0: tensor<2x3xf32> {enzymexla.memory_effects = []}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<[[1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x6xf32>
    %0 = stablehlo.slice %cst [0:6, 0:1] : (tensor<6x6xf32>) -> tensor<6x1xf32>
    %1 = stablehlo.reshape %0 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %2 = stablehlo.slice %cst [0:6, 1:2] : (tensor<6x6xf32>) -> tensor<6x1xf32>
    %3 = stablehlo.reshape %2 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %4 = stablehlo.slice %cst [0:6, 2:3] : (tensor<6x6xf32>) -> tensor<6x1xf32>
    %5 = stablehlo.reshape %4 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %6 = stablehlo.slice %cst [0:6, 3:4] : (tensor<6x6xf32>) -> tensor<6x1xf32>
    %7 = stablehlo.reshape %6 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %8 = stablehlo.slice %cst [0:6, 4:5] : (tensor<6x6xf32>) -> tensor<6x1xf32>
    %9 = stablehlo.reshape %8 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %10 = stablehlo.slice %cst [0:6, 5:6] : (tensor<6x6xf32>) -> tensor<6x1xf32>
    %11 = stablehlo.reshape %10 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    return %1, %3, %5, %7, %9, %11 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
  }
}

After SliceSimplify (note that the second result has been folded to all zeros, and the remaining one-hot columns come back in reverse order):

module {
  func.func @main(%arg0: tensor<2x3xf32> {enzymexla.memory_effects = []}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<[[0.000000e+00], [0.000000e+00], [1.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]]> : tensor<6x1xf32>
    %cst_0 = stablehlo.constant dense<[[0.000000e+00], [0.000000e+00], [0.000000e+00], [1.000000e+00], [0.000000e+00], [0.000000e+00]]> : tensor<6x1xf32>
    %cst_1 = stablehlo.constant dense<[[0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [1.000000e+00], [0.000000e+00]]> : tensor<6x1xf32>
    %cst_2 = stablehlo.constant dense<[[0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [1.000000e+00]]> : tensor<6x1xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<6x1xf32>
    %cst_4 = stablehlo.constant dense<[[1.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]]> : tensor<6x1xf32>
    %0 = stablehlo.reshape %cst_4 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %1 = stablehlo.reshape %cst_3 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %2 = stablehlo.reshape %cst_2 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %3 = stablehlo.reshape %cst_1 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %4 = stablehlo.reshape %cst_0 : (tensor<6x1xf32>) -> tensor<2x3xf32>
    %5 = stablehlo.reshape %cst : (tensor<6x1xf32>) -> tensor<2x3xf32>
    return %0, %1, %2, %3, %4, %5 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
  }
}
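For reference (my reading of the expected behavior, not output from the pass; the value names below are illustrative): a correct fold would replace each stablehlo.slice of the identity with the matching one-hot column, in order. A sketch of the expected constants for the first two slices:

  // column 0 of the 6x6 identity: one-hot at row 0 (the buggy output gets this right)
  %c0 = stablehlo.constant dense<[[1.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]]> : tensor<6x1xf32>
  // column 1: one-hot at row 1 (the buggy output folds this to all zeros)
  %c1 = stablehlo.constant dense<[[0.000000e+00], [1.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00], [0.000000e+00]]> : tensor<6x1xf32>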

EDIT: fixed with the latest push to Enzyme-JAX.

@avik-pal force-pushed the ap/stacked_batchdup branch from 7c12684 to bbadac4 on October 12, 2025 at 00:05
@avik-pal (Collaborator, Author) commented

This needs JuliaPackaging/Yggdrasil#12270 before merging.

@avik-pal force-pushed the ap/stacked_batchdup branch 2 times, most recently from bb431db to 49b55b9 on October 12, 2025 at 14:49
@avik-pal (Collaborator, Author) commented

This is good to go from my end.

@avik-pal force-pushed the ap/stacked_batchdup branch from 3fda458 to 10d888e on October 13, 2025 at 04:08


Development

Successfully merging this pull request may close these issues:

Incorrect results for nested AD
Alternate implementation of BatchDuplicated
