diff --git a/.bazelignore b/.bazelignore
index dac859a..05aafad 100644
--- a/.bazelignore
+++ b/.bazelignore
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-# ignore local_repos of llvm-project, torch-mlir, stablehlo
+# ignore local_repos of llvm-project, torch-mlir
 third_party/llvm-project
 third_party/torch-mlir
-third_party/stablehlo
diff --git a/.github/workflows/bazelBuildAndTestStablehlo.yml b/.github/workflows/bazelBuildAndTestStablehlo.yml
deleted file mode 100644
index 9079e87..0000000
--- a/.github/workflows/bazelBuildAndTestStablehlo.yml
+++ /dev/null
@@ -1,64 +0,0 @@
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-name: Bazel Build and Test (stablehlo)
-
-# Only run when stablehlo hash changes (deps.bzl)
-on:
-  pull_request:
-    branches:
-      - main
-    paths:
-      - 'deps.bzl'
-      - '.github/workflows/bazelBuildAndTestStablehlo.yml'
-  # TODO: Use self-hosted runners as we hit disk space issues with GitHub hosted runners
-  # push:
-  #   branches:
-  #     - main
-  #   paths:
-  #     - 'deps.bzl'
-  #     - '.github/workflows/bazelBuildAndTestStablehlo.yml'
-  workflow_dispatch:
-
-# Ensure that only a single job or workflow using the same
-# concurrency group will run at a time. This would cancel
-# any in-progress jobs in the same github workflow and github
-# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
-concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
-  cancel-in-progress: true
-
-
-jobs:
-  ubuntu-build:
-    name: ubuntu-x86_64 / stablehlo
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Checkout mlir-tcp
-        uses: actions/checkout@v4
-
-      - name: Setup workspace
-        uses: ./.github/actions/setup-build
-        with:
-          cache-prefix: 'stablehlo'
-
-      - name: Build docker image
-        run: |
-          docker build -f docker/Dockerfile \
-                       -t mlir-tcp:ci \
-                       --build-arg GROUP=$(id -gn) \
-                       --build-arg GID=$(id -g) \
-                       --build-arg USER=$(id -un) \
-                       --build-arg UID=$(id -u) \
-                       .
-
-      - name: Bazel build and test stablehlo
-        run: |
-          docker run --rm \
-                     -v "$(pwd)":"/opt/src/mlir-tcp" \
-                     -v "${HOME}/.cache/bazel":"${HOME}/.cache/bazel" \
-                     mlir-tcp:ci \
-                     bazel test --test_output=errors @stablehlo//...
diff --git a/.gitignore b/.gitignore
index e57e392..dab755f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,10 +9,9 @@ bazel-out
 bazel-mlir-tcp
 bazel-testlogs
 
-# ignore local_repos of llvm, torch-mlir, stablehlo
+# ignore local_repos of llvm, torch-mlir
 third_party/llvm-project
 third_party/torch-mlir
-third_party/stablehlo
 
 # clangd related
 .cache
diff --git a/BUILD b/BUILD
index 67765eb..d0a6782 100644
--- a/BUILD
+++ b/BUILD
@@ -150,6 +150,7 @@ cc_library(
     name = "TcpDialectPasses",
     srcs = [
        "lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp",
+       "lib/Dialect/Transforms/EliminateUnusedTorchOpsPass.cpp",
        "lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
        "lib/Dialect/Transforms/FusionPatterns.cpp",
        "lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
@@ -160,6 +161,7 @@ cc_library(
     ],
    hdrs = [
        "include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h",
+       "include/mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h",
        "include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h",
        "include/mlir-tcp/Dialect/Transforms/FusionPatterns.h",
        "include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h",
@@ -175,6 +177,7 @@ cc_library(
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TensorTransforms",
        "@llvm-project//mlir:Transforms",
+       "@torch-mlir//:TorchMLIRTorchDialect",
    ],
 )
 
@@ -198,7 +201,6 @@ cc_library(
    hdrs = ["include/mlir-tcp/Conversion/Passes.h"],
    strip_include_prefix = "include",
    deps = [
-       ":StablehloToTcp",
        ":TcpToArith",
        ":TcpToLinalg",
        ":TcpToTensor",
@@ -237,25 +239,6 @@ cc_library(
    ],
 )
 
-cc_library(
-    name = "StablehloToTcp",
-    srcs = [
-        "lib/Conversion/PassDetail.h",
-        "lib/Conversion/StablehloToTcp/StablehloToTcp.cpp",
-    ],
-    hdrs = ["include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h"],
-    strip_include_prefix = "include",
-    deps = [
-        ":TcpConversionPassesIncGen",
-        ":TcpDialect",
-        "@llvm-project//mlir:Dialect",
-        "@llvm-project//mlir:LinalgDialect",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Transforms",
-        "@stablehlo//:stablehlo_ops",
-    ],
-)
-
 cc_library(
     name = "TcpToLinalg",
     srcs = [
@@ -364,6 +347,5 @@ cc_binary(
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:QuantOps",
-       "@stablehlo//:register",
    ],
 )
diff --git a/README.md b/README.md
index c733cf2..d76571d 100644
--- a/README.md
+++ b/README.md
@@ -51,17 +51,15 @@ bazel run //tools/clangd:refresh_compile_commands
 ```
 When run successfully, a `compile_commands.json` is generated at the workspace root (and refreshed upon re-runs). If you're using VSCode, just hit CMD+SHIFT+P and select `clangd: Restart language server` to start clangd. Note that this only works for non-docker builds at the moment.
 
-When bumping upstream dependencies (LLVM, Torch-MLIR, StableHLO), you may validate the set of "green commits" by running the corresponding third-party tests:
+When bumping upstream dependencies (LLVM, Torch-MLIR), you may validate the set of "green commits" by running the corresponding third-party tests:
 ```shell
 bazel test @llvm-project//mlir/...
 bazel test @torch-mlir//...
-bazel test @stablehlo//...
 ```
 
 The following CI workflows are automatically triggered anytime upstream dependencies (`deps.bzl`) are updated:
 - [![Bazel Build and Test (llvm-project)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestLlvm.yml/badge.svg)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestLlvm.yml)
 - [![Bazel Build and Test (torch-mlir)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestTorchmlir.yml/badge.svg)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestTorchmlir.yml)
-- [![Bazel Build and Test (stablehlo)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestStablehlo.yml/badge.svg)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestStablehlo.yml)
 
 To use newer `torch-mlir` and/or `torch` python packages in our hermetic python sandbox, just regenerate `requirements_lock.txt` as follows:
 ```shell
diff --git a/deps.bzl b/deps.bzl
index e783d9a..33e4de5 100644
--- a/deps.bzl
+++ b/deps.bzl
@@ -7,10 +7,8 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 load(
     ":local_repos.bzl",
     "local_llvm_repo_path",
-    "local_stablehlo_repo_path",
     "local_torch_mlir_repo_path",
     "use_local_llvm_repo",
-    "use_local_stablehlo_repo",
     "use_local_torch_mlir_repo",
 )
 
@@ -22,8 +20,8 @@ def third_party_deps():
             path = local_llvm_repo_path(),
         )
     else:
-        LLVM_COMMIT = "72144d119a7291f8b6b8e022a2947fbe31e66afc"
-        LLVM_SHA256 = "2caacb6925a13cb5886a5d7f225fa408b80ca8e1efe0736186954b2abc4ee1c3"
+        LLVM_COMMIT = "b231e5ff504295641b0f580ceefa2e1048011614"
+        LLVM_SHA256 = "88dfa59052730710cb48fa20b00a4344144edd1c3cb524c06d983899835e491a"
         http_archive(
             name = "llvm-raw",
             build_file_content = "# empty",
@@ -39,32 +37,15 @@ def third_party_deps():
             path = local_torch_mlir_repo_path(),
         )
     else:
-        TORCH_MLIR_COMMIT = "9f2ba5abaa85cefd95cc85579fafd0c53c1101e8"
-        TORCH_MLIR_SHA256 = "09444281839eeae4aff42c029d87b1728f307fa26511b896ff448d51aaa98049"
+        TORCH_MLIR_COMMIT = "1ad9702d2a290b693c4f6f17921d0e0a8d14a999"
+        TORCH_MLIR_SHA256 = "8843399168c34ca3ca16d2417703fe4e1440ca7240d9e04844b3deedf256f0ab"
         http_archive(
             name = "torch-mlir-raw",
             build_file_content = "# empty",
+            patches = ["//third_party/patches:torch-mlir-bazel-build.1.patch", "//third_party/patches:torch-mlir-bazel-build.2.patch"],
             sha256 = TORCH_MLIR_SHA256,
             strip_prefix = "torch-mlir-" + TORCH_MLIR_COMMIT,
             urls = ["https://github.com/llvm/torch-mlir/archive/{commit}.tar.gz".format(commit = TORCH_MLIR_COMMIT)],
-            patches = [
-                "//third_party/patches:torch-mlir.1.patch",
-            ],
-        )
-
-    if use_local_stablehlo_repo():
-        native.local_repository(
-            name = "stablehlo",
-            path = local_stablehlo_repo_path(),
-        )
-    else:
-        STABLEHLO_COMMIT = "a54938f0651d3b4b7be9771848eda2463c92a8e7"
-        STABLEHLO_SHA256 = "edab2288f0b19e3efbf08815d17d4efb106984aa6fe02fed0cb2165284e6a5b7"
-        http_archive(
-            name = "stablehlo",
-            sha256 = STABLEHLO_SHA256,
-            strip_prefix = "stablehlo-" + STABLEHLO_COMMIT,
-            urls = ["https://github.com/openxla/stablehlo/archive/{commit}.tar.gz".format(commit = STABLEHLO_COMMIT)],
         )
 
     SKYLIB_VERSION = "1.3.0"
diff --git a/include/mlir-tcp/Conversion/Passes.td b/include/mlir-tcp/Conversion/Passes.td
index 3a11a72..8da8d7b 100644
--- a/include/mlir-tcp/Conversion/Passes.td
+++ b/include/mlir-tcp/Conversion/Passes.td
@@ -46,23 +46,6 @@ def ConvertTorchToTcpCustomOp : Pass<"convert-torch-to-tcp-custom-op", "func::FuncOp"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// StablehloToTcp
-//===----------------------------------------------------------------------===//
-
-def ConvertStablehloToTcp
-    : Pass<"convert-stablehlo-to-tcp", "func::FuncOp"> {
-  let summary = "Lower StableHLO to TCP";
-  let description = [{
-    Pass that converts StableHLO operations to equivalent operations in TCP.
-  }];
-
-  let constructor = "mlir::tcp::createConvertStablehloToTcpPass()";
-  let dependentDialects = [
-    "mlir::tcp::TcpDialect",
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // TcpToLinalg
 //===----------------------------------------------------------------------===//
diff --git a/include/mlir-tcp/Dialect/IR/TcpTypes.td b/include/mlir-tcp/Dialect/IR/TcpTypes.td
index 2ab1573..62e51aa 100644
--- a/include/mlir-tcp/Dialect/IR/TcpTypes.td
+++ b/include/mlir-tcp/Dialect/IR/TcpTypes.td
@@ -24,8 +24,8 @@ include "mlir-tcp/Dialect/IR/TcpBase.td"
 // Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
 // the 8-bit case.
 class Tcp_QuantizedType<string n, list<int> params, bit signed>
-    : Type<And<[CPred<"$_self.isa<mlir::quant::UniformQuantizedType>()">,
-                CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
+    : Type<And<[CPred<"::llvm::isa<mlir::quant::UniformQuantizedType>($_self)">,
+                CPred<"::llvm::cast<mlir::quant::UniformQuantizedType>($_self)" #
                       ".getStorageTypeIntegralWidth() == " # !head(params)>]>,
            "Q" # !if (signed, "int", "uint") # !head(params) # " type"> {
   string name = n;
diff --git a/include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h b/include/mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h
similarity index 61%
rename from include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h
rename to include/mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h
index 788c750..a81fec6 100644
--- a/include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h
+++ b/include/mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h
@@ -9,17 +9,13 @@
 
 #pragma once
 
-#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
+#include <memory>
 
-namespace mlir {
+namespace mlir::tcp {
 
-#define GEN_PASS_DECL_CONVERTSTABLEHLOTOTCP
-#include "mlir-tcp/Conversion/Passes.h.inc"
+std::unique_ptr<OperationPass<ModuleOp>>
+createEliminateUnusedTorchOpsPass();
 
-namespace tcp {
-
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertStablehloToTcpPass();
-
-} // namespace tcp
-} // namespace mlir
+} // namespace mlir::tcp
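The TcpTypes.td hunk above is the first instance of a migration applied throughout the rest of this patch: the deprecated member-function casts on `mlir::Type`/`mlir::Attribute` (`.isa<T>()`, `.cast<T>()`, `.dyn_cast<T>()`) become the LLVM free functions `isa<T>(x)` / `cast<T>(x)` / `dyn_cast<T>(x)`. A minimal self-contained sketch of the pattern (the helper name is illustrative, not from this patch):

```cpp
#include "mlir/IR/BuiltinTypes.h" // RankedTensorType, FloatType
#include "mlir/IR/Value.h"

using namespace mlir;

// Illustrative helper, not part of this patch.
static bool isFloatTensor(Value v) {
  // Before: v.getType().isa<RankedTensorType>() && v.getType()
  //             .cast<RankedTensorType>().getElementType().isa<FloatType>();
  auto tensorTy = dyn_cast<RankedTensorType>(v.getType());
  return tensorTy && isa<FloatType>(tensorTy.getElementType());
}
```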
diff --git a/include/mlir-tcp/Dialect/Transforms/Passes.td b/include/mlir-tcp/Dialect/Transforms/Passes.td
index 58dbdbb..c55b910 100644
--- a/include/mlir-tcp/Dialect/Transforms/Passes.td
+++ b/include/mlir-tcp/Dialect/Transforms/Passes.td
@@ -45,4 +45,10 @@ def DropSymbolicShapeOps : Pass<"drop-symbolic-shape-ops", "func::FuncOp"> {
   let constructor = "mlir::tcp::createDropSymbolicShapeOpsPass()";
 }
 
+// \brief This pass removes unused torch ops.
+def EliminateUnusedTorchOps : Pass<"eliminate-unused-torch-ops", "ModuleOp"> {
+  let summary = "Removes unused/unnecessary torch ops";
+  let constructor = "mlir::tcp::createEliminateUnusedTorchOpsPass()";
+}
+
 #endif // TCP_PASSES
diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp
index 6173eca..3e91b6b 100644
--- a/lib/Conversion/Passes.cpp
+++ b/lib/Conversion/Passes.cpp
@@ -9,7 +9,6 @@
 
 #include "mlir-tcp/Conversion/Passes.h"
 
-#include "mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h"
 #include "mlir-tcp/Conversion/TcpToArith/TcpToArith.h"
 #include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"
 #include "mlir-tcp/Conversion/TcpToTensor/TcpToTensor.h"
diff --git a/lib/Conversion/StablehloToTcp/StablehloToTcp.cpp b/lib/Conversion/StablehloToTcp/StablehloToTcp.cpp
deleted file mode 100644
index 2b5fa58..0000000
--- a/lib/Conversion/StablehloToTcp/StablehloToTcp.cpp
+++ /dev/null
@@ -1,75 +0,0 @@
-//===------------------------------------------------------------*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// Also available under a BSD-style license. See LICENSE.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h"
-
-#include "mlir-tcp/Dialect/IR/TcpDialect.h"
-#include "mlir-tcp/Dialect/IR/TcpOps.h"
-
-#include "../PassDetail.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Passes.h"
-#include "stablehlo/dialect/StablehloOps.h"
-
-namespace mlir {
-
-#define GEN_PASS_DEF_CONVERTSTABLEHLOTOTCP
-#include "mlir-tcp/Conversion/Passes.h.inc"
-
-namespace tcp {
-
-namespace {
-
-class TanhOpConverter : public OpRewritePattern<stablehlo::TanhOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(stablehlo::TanhOp op,
-                                PatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<tcp::TanhOp>(op, op.getType(), op.getOperand());
-    return success();
-  }
-};
-
-void populateStablehloToTcpPatternsAndLegality(RewritePatternSet &patterns,
-                                               ConversionTarget &target) {
-  MLIRContext *context = patterns.getContext();
-
-  target.addIllegalOp<stablehlo::TanhOp>();
-  patterns.add<TanhOpConverter>(context);
-}
-
-class ConvertStablehloToTcp
-    : public ConvertStablehloToTcpBase<ConvertStablehloToTcp> {
-public:
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    ConversionTarget target(*context);
-    target.addLegalDialect<TcpDialect>();
-
-    RewritePatternSet patterns(context);
-    populateStablehloToTcpPatternsAndLegality(patterns, target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertStablehloToTcpPass() {
-  return std::make_unique<ConvertStablehloToTcp>();
-}
-
-} // namespace tcp
-} // namespace mlir
diff --git a/lib/Conversion/TcpToLinalg/DataMovement.cpp b/lib/Conversion/TcpToLinalg/DataMovement.cpp
index 7c3b35d..45e8b25 100644
--- a/lib/Conversion/TcpToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -36,9 +36,8 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
   matchAndRewrite(GatherOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto resultTensorType = getTypeConverter()
-                                ->convertType(op.getOut().getType())
-                                .cast<RankedTensorType>();
+    auto resultTensorType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getOut().getType()));
     auto inputTensor = adaptor.getInput();
     auto indicesTensor = adaptor.getIndices();
 
@@ -110,9 +109,8 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
   matchAndRewrite(GatherNDOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto resultTensorType = getTypeConverter()
-                                ->convertType(op.getOut().getType())
-                                .cast<RankedTensorType>();
+    auto resultTensorType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getOut().getType()));
     auto inputTensor = adaptor.getInput();
     auto indicesTensor = adaptor.getIndices();
diff --git a/lib/Conversion/TcpToLinalg/Elementwise.cpp b/lib/Conversion/TcpToLinalg/Elementwise.cpp
index 3ff292b..3c6ad94 100644
--- a/lib/Conversion/TcpToLinalg/Elementwise.cpp
+++ b/lib/Conversion/TcpToLinalg/Elementwise.cpp
@@ -69,7 +69,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
     // This implementation always performs the max followed by min.
     // TODO: Is this going to work for degenerative floating point numbers?
     Value result = payloadArgs[0];
-    if (elemType.isa<mlir::FloatType>()) {
+    if (isa<mlir::FloatType>(elemType)) {
       auto minFloat = clampOp.getMinFloat();
       auto maxFloat = clampOp.getMaxFloat();
       if (minFloat)
@@ -80,7 +80,7 @@ createLinalgPayloadForElementwiseOp(Operation *op,
         result = b.create<arith::MinimumFOp>(
             loc, result,
            b.create<arith::ConstantFloatOp>(loc, *maxFloat, b.getF32Type()));
-    } else if (elemType.isa<mlir::IntegerType>()) {
+    } else if (isa<mlir::IntegerType>(elemType)) {
       auto minInt = clampOp.getMinInt();
       auto maxInt = clampOp.getMaxInt();
       if (minInt)
@@ -136,9 +136,9 @@ createLinalgPayloadForElementwiseOp(Operation *op,
   }
 
   if (isa<AbsOp>(op)) {
-    if (elemType.isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(elemType))
       return {b.create<math::AbsFOp>(loc, payloadArgs[0])};
-    else if (elemType.isa<mlir::IntegerType>())
+    else if (isa<mlir::IntegerType>(elemType))
       return {b.create<math::AbsIOp>(loc, payloadArgs[0])};
     else
       llvm_unreachable("unsupported element type in "
@@ -158,9 +158,9 @@
   }
 
   if (isa<AddOp>(op)) {
-    if (elemType.isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(elemType))
       return {b.create<arith::AddFOp>(loc, payloadArgs[0], payloadArgs[1])};
-    else if (elemType.isa<mlir::IntegerType>())
+    else if (isa<mlir::IntegerType>(elemType))
       return {b.create<arith::AddIOp>(loc, payloadArgs[0], payloadArgs[1])};
     else
       llvm_unreachable("unsupported element type in "
@@ -168,9 +168,9 @@
   }
 
   if (isa<SubOp>(op)) {
-    if (elemType.isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(elemType))
       return {b.create<arith::SubFOp>(loc, payloadArgs[0], payloadArgs[1])};
-    else if (elemType.isa<mlir::IntegerType>())
+    else if (isa<mlir::IntegerType>(elemType))
       return {b.create<arith::SubIOp>(loc, payloadArgs[0], payloadArgs[1])};
     else
       llvm_unreachable("unsupported element type in "
@@ -178,9 +178,9 @@
   }
 
   if (isa<MulOp>(op)) {
-    if (elemType.isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(elemType))
       return {b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1])};
-    else if (elemType.isa<mlir::IntegerType>())
+    else if (isa<mlir::IntegerType>(elemType))
       return {b.create<arith::MulIOp>(loc, payloadArgs[0], payloadArgs[1])};
     else
       llvm_unreachable("unsupported element type in "
@@ -188,7 +188,7 @@
   }
 
   if (isa<DivFOp>(op)) {
-    if (elemType.isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(elemType))
       return {b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1])};
     else
       llvm_unreachable("unsupported element type in "
@@ -196,7 +196,7 @@
   }
 
   if (auto divOp = dyn_cast<DivSIOp>(op)) {
-    if (!elemType.isa<mlir::IntegerType>())
+    if (!isa<mlir::IntegerType>(elemType))
       llvm_unreachable("unsupported element type in "
                        "createLinalgPayloadForElementwiseOp for tcp.divsi");
     if (divOp.getRoundingMode() == RoundingMode::Trunc)
@@ -210,7 +210,7 @@
   }
 
   if (auto divOp = dyn_cast<DivUIOp>(op)) {
-    if (!elemType.isa<mlir::IntegerType>())
+    if (!isa<mlir::IntegerType>(elemType))
       llvm_unreachable("unsupported element type in "
                        "createLinalgPayloadForElementwiseOp for tcp.divui");
     if (divOp.getRoundingMode() == RoundingMode::Trunc ||
@@ -222,7 +222,7 @@
   }
 
   if (isa<Atan2Op>(op)) {
-    if (elemType.isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(elemType))
       return {b.create<math::Atan2Op>(loc, payloadArgs[0], payloadArgs[1])};
     else
       llvm_unreachable("unsupported element type in "
@@ -231,7 +231,7 @@
 
   if (auto castOp = dyn_cast<CastOp>(op)) {
     auto inputType =
-        castOp.getIn().getType().dyn_cast<RankedTensorType>().getElementType();
+        dyn_cast<RankedTensorType>(castOp.getIn().getType()).getElementType();
     auto outputType = resultTensorType.getElementType();
 
     if (inputType.getIntOrFloatBitWidth() ==
@@ -246,24 +246,24 @@
       // To I1 (Bool) type
       Value cstZero =
          b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputType));
-      if (inputType.isa<mlir::FloatType>()) {
+      if (isa<mlir::FloatType>(inputType)) {
         return {b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                         payloadArgs[0], cstZero)};
-      } else if (inputType.isa<mlir::IntegerType>()) {
+      } else if (isa<mlir::IntegerType>(inputType)) {
         return {b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                         payloadArgs[0], cstZero)};
       }
-    } else if (outputType.isa<mlir::FloatType>()) {
+    } else if (isa<mlir::FloatType>(outputType)) {
       // TO FP type
       // FP -> FP
-      if (inputType.dyn_cast<mlir::FloatType>()) {
+      if (dyn_cast<mlir::FloatType>(inputType)) {
        if (inputType.getIntOrFloatBitWidth() >
            outputType.getIntOrFloatBitWidth())
          return {b.create<arith::TruncFOp>(loc, outputType, payloadArgs[0])};
        return {b.create<arith::ExtFOp>(loc, outputType, payloadArgs[0])};
       }
       // INT -> FP
-      else if (inputType.dyn_cast<mlir::IntegerType>()) {
+      else if (dyn_cast<mlir::IntegerType>(inputType)) {
         // Signless or Unsigned INT to FP
         // Curently, signless is only for i1 (bool) case,
         // which has been handeled above
@@ -274,10 +274,10 @@
         else if (castOp.getInIntSignedness().value() == Signedness::Signed)
           return {b.create<arith::SIToFPOp>(loc, outputType, payloadArgs[0])};
       }
-    } else if (outputType.isa<mlir::IntegerType>()) {
+    } else if (isa<mlir::IntegerType>(outputType)) {
       // TO INT type
       // FP -> INT
-      if (inputType.dyn_cast<mlir::FloatType>()) {
+      if (dyn_cast<mlir::FloatType>(inputType)) {
         // FP to Signless or Unsigned INT
         if (castOp.getOutIntSignedness().value() == Signedness::Signless ||
             castOp.getOutIntSignedness().value() == Signedness::Unsigned)
@@ -287,7 +287,7 @@
         return {b.create<arith::FPToSIOp>(loc, outputType, payloadArgs[0])};
       }
       // INT -> INT
-      if (inputType.dyn_cast<mlir::IntegerType>()) {
+      if (dyn_cast<mlir::IntegerType>(inputType)) {
        if (inputType.getIntOrFloatBitWidth() >
            outputType.getIntOrFloatBitWidth())
          return {b.create<arith::TruncIOp>(loc, outputType, payloadArgs[0])};
@@ -318,12 +318,12 @@ class ConvertElementwiseOp : public OpConversionPattern<TcpOpT> {
   matchAndRewrite(TcpOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto resultTensorType = OpConversionPattern<TcpOpT>::getTypeConverter()
-                                ->convertType(op->getResult(0).getType())
-                                .template cast<RankedTensorType>();
+    auto resultTensorType = cast<RankedTensorType>(
+        OpConversionPattern<TcpOpT>::getTypeConverter()->convertType(
+            op->getResult(0).getType()));
     auto tensorOperands = llvm::to_vector<6>(
         llvm::make_filter_range(adaptor.getOperands(), [](Value v) {
-          return v.getType().isa<RankedTensorType>();
+          return isa<RankedTensorType>(v.getType());
         }));
 
     // Create Linalg payload
diff --git a/lib/Conversion/TcpToLinalg/Misc.cpp b/lib/Conversion/TcpToLinalg/Misc.cpp
index 4dfcbf5..5741599 100644
--- a/lib/Conversion/TcpToLinalg/Misc.cpp
+++ b/lib/Conversion/TcpToLinalg/Misc.cpp
@@ -29,7 +29,7 @@ namespace {
 SmallVector<int64_t> getValuesFromIndexArrayAttribute(ArrayAttr attr) {
   SmallVector<int64_t> arrayValues;
   for (Attribute val : attr.getValue())
-    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+    arrayValues.push_back(cast<IntegerAttr>(val).getValue().getSExtValue());
   return arrayValues;
 }
 
@@ -40,9 +40,8 @@ class ConvertBroadcastOp : public OpConversionPattern<BroadcastOp> {
   LogicalResult
   matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &b) const override {
     Location loc = op->getLoc();
-    auto resultTensorType = getTypeConverter()
-                                ->convertType(op->getResult(0).getType())
-                                .cast<RankedTensorType>();
+    auto resultTensorType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     auto inputTensor = op->getOperands()[0];
 
     SmallVector<int64_t> axes = getValuesFromIndexArrayAttribute(op.getAxes());
diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index 3b2990c..50544a7 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -40,8 +40,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
                                            SmallVector<Value> &strides) {
   Location loc = op.getLoc();
   auto input = adaptor.getSelf();
-  RankedTensorType inputType =
-      input.getType().template cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
 
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -64,8 +63,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
     Value builtinTypeStart = adaptor.getStart();
     Value builtinTypeEnd = adaptor.getEnd();
 
-    if (torchTypeStart.getType().isa<Torch::OptionalType>() ||
-        torchTypeEnd.getType().isa<Torch::OptionalType>())
+    if (isa<Torch::OptionalType>(torchTypeStart.getType()) ||
+        isa<Torch::OptionalType>(torchTypeEnd.getType()))
       return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
 
     Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
@@ -75,7 +74,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
     // We cannot use to positive valid dim as for negative strides we need to
     // clamp to `-1` so that the full tensor bounds are available:
     Value end = builtinTypeEnd;
-    if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
+    if (isa<Torch::NoneType>(torchTypeEnd.getType())) {
       end = dimSize;
     } else {
       end = castIntToIndex(rewriter, loc, end);
@@ -140,7 +139,7 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
         getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
 
     RankedTensorType newResultType =
-        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     int rank = newResultType.getRank();
     Value dimValue = op.getDim();
     int64_t dim;
@@ -185,9 +184,8 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
       return failure();
 
     auto input = adaptor.getSelf();
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
 
     SmallVector<Value> resultShape;
     SmallVector<Value> offsets;
@@ -213,14 +211,13 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto input = adaptor.getSelf();
     auto indices = adaptor.getIndex();
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .template cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
 
     int64_t dim = 0;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return op.emitError("dim on torch.gather must be an int constant");
-    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputType = cast<RankedTensorType>(input.getType());
     dim = Torch::toPositiveDim(dim, inputType.getRank());
 
     bool sparseGrad = false;
diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp
index 3ac8c86..de53612 100644
--- a/lib/Conversion/TorchToTcp/Elementwise.cpp
+++ b/lib/Conversion/TorchToTcp/Elementwise.cpp
@@ -46,12 +46,12 @@ Value convertScalarOperandToTensor(ConversionPatternRewriter &rewriter,
       RankedTensorType::get({}, convertedScalarValue.getType());
   Value resultValue = torch_to_tcp::scalarToTcpTensor(
       rewriter, op, scalarToTensorType, scalarValue);
-  if (convertedScalarValue.getType().isa<mlir::FloatType>())
+  if (isa<mlir::FloatType>(convertedScalarValue.getType()))
     // FP scalarValue is treated as fp64
     resultValue = torch_to_tcp::castTensorToDtype(
         rewriter, rewriter.getF64Type(), outputType, resultValue,
         convertedOutputType);
-  else if (convertedScalarValue.getType().isa<mlir::IntegerType>())
+  else if (isa<mlir::IntegerType>(convertedScalarValue.getType()))
     // INT scalarValue is treated as si64
     resultValue = torch_to_tcp::castTensorToDtype(
         rewriter, rewriter.getIntegerType(64, true), outputType, resultValue,
@@ -69,26 +69,23 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.getSelf();
-    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
 
     Value rhs = adaptor.getOther();
 
-    RankedTensorType resultType =
-        OpConversionPattern<AtenOpT>::getTypeConverter()
-            ->convertType(op.getType())
-            .template cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
 
     if (!lhsType || !resultType)
       return rewriter.notifyMatchFailure(
           op, "Only Ranked Tensor types are supported in TCP");
 
-    auto inputAType = op.getSelf()
-                          .getType()
-                          .template dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
-    auto outputType = op.getType()
-                          .template dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
+    auto inputAType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
+            .getDtype();
+    auto outputType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
 
     if (isa<AtenAddScalarOp>(op) || isa<AtenSubScalarOp>(op)) {
       rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
@@ -97,10 +94,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
       if (!rhs)
         return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
     } else {
-      auto inputBType = op.getOther()
-                            .getType()
-                            .template dyn_cast<torch::Torch::ValueTensorType>()
-                            .getDtype();
+      auto inputBType =
+          dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
+              .getDtype();
       rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
                                             rhs, resultType.getElementType());
     }
@@ -135,26 +131,23 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.getSelf();
-    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
 
     Value rhs = adaptor.getOther();
 
-    RankedTensorType resultType =
-        OpConversionPattern<AtenOpT>::getTypeConverter()
-            ->convertType(op.getType())
-            .template cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
 
     if (!lhsType || !resultType)
       return rewriter.notifyMatchFailure(
           op, "Only Ranked Tensor types are supported in TCP");
 
-    auto inputAType = op.getSelf()
-                          .getType()
-                          .template dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
-    auto outputType = op.getType()
-                          .template dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
+    auto inputAType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
+            .getDtype();
+    auto outputType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
 
     if (isa<AtenMulScalarOp>(op)) {
       rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
@@ -163,10 +156,9 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
       if (!rhs)
         return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
     } else {
-      auto inputBType = op.getOther()
-                            .getType()
-                            .template dyn_cast<torch::Torch::ValueTensorType>()
-                            .getDtype();
+      auto inputBType =
+          dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
+              .getDtype();
       rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
                                             rhs, resultType.getElementType());
     }
@@ -188,24 +180,24 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
   matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getInput();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
 
     Value weight = adaptor.getWeight();
-    RankedTensorType weightType = weight.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType weightType = dyn_cast<RankedTensorType>(weight.getType());
 
     Value bias = adaptor.getBias();
-    RankedTensorType biasType = bias.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType biasType = dyn_cast<RankedTensorType>(bias.getType());
 
     Value runningMean = adaptor.getRunningMean();
     RankedTensorType runningMeanType =
-        runningMean.getType().dyn_cast<RankedTensorType>();
+        dyn_cast<RankedTensorType>(runningMean.getType());
 
     Value runningVar = adaptor.getRunningVar();
     RankedTensorType runningVarType =
-        runningVar.getType().dyn_cast<RankedTensorType>();
+        dyn_cast<RankedTensorType>(runningVar.getType());
 
     RankedTensorType resultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
     if (!inputType || !weightType || !biasType || !runningMeanType ||
         !runningVarType || !resultType)
@@ -294,22 +286,19 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
 
     Value rhs = adaptor.getOther();
 
-    RankedTensorType resultType =
-        OpConversionPattern<AtenOpT>::getTypeConverter()
-            ->convertType(op.getType())
-            .template cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
 
     if (!lhsType || !resultType)
       return rewriter.notifyMatchFailure(
           op, "Only Ranked Tensor types are supported in TCP");
 
-    auto inputAType = op.getSelf()
-                          .getType()
-                          .template dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
-    auto outputType = op.getType()
-                          .template dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
+    auto inputAType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
+            .getDtype();
+    auto outputType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
 
     Type inputBType = nullptr;
     if (isa<AtenDivScalarOp>(op)) {
@@ -321,10 +310,9 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
       if (!rhs)
         return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
     } else {
-      inputBType = op.getOther()
-                       .getType()
-                       .template dyn_cast<torch::Torch::ValueTensorType>()
-                       .getDtype();
+      inputBType =
+          dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
+              .getDtype();
       rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
                                             rhs, resultType.getElementType());
     }
@@ -368,7 +356,7 @@ class ConvertAtenClampOp : public OpConversionPattern<AtenClampOp> {
   matchAndRewrite(AtenClampOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
     if (!inputType)
       return rewriter.notifyMatchFailure(
           op, "Only Ranked Tensor types are supported in TCP");
@@ -390,15 +378,15 @@
     double floatValue;
     int64_t intValue;
     if (matchPattern(value, m_TorchConstantFloat(&floatValue))) {
-      if (elementType.isa<mlir::FloatType>())
+      if (isa<mlir::FloatType>(elementType))
         floatAttr = rewriter.getF32FloatAttr(floatValue);
-      else if (elementType.isa<mlir::IntegerType>())
+      else if (isa<mlir::IntegerType>(elementType))
         intAttr = rewriter.getI64IntegerAttr(static_cast<int64_t>(floatValue));
     } else if (matchPattern(value, m_TorchConstantInt(&intValue))) {
-      if (elementType.isa<mlir::FloatType>())
+      if (isa<mlir::FloatType>(elementType))
         floatAttr = rewriter.getF32FloatAttr(static_cast<float>(intValue));
-      else if (elementType.isa<mlir::IntegerType>())
+      else if (isa<mlir::IntegerType>(elementType))
         intAttr = rewriter.getI64IntegerAttr(intValue);
     } else {
       llvm_unreachable("only float or integer constants are supported as min "
@@ -430,7 +418,7 @@ class ConvertAtenReluOp : public OpConversionPattern<AtenReluOp> {
   matchAndRewrite(AtenReluOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
 
     if (!inputType)
       return rewriter.notifyMatchFailure(
@@ -443,7 +431,7 @@ class ConvertAtenReluOp : public OpConversionPattern<AtenReluOp> {
 
     FloatAttr minFloatAttr, maxFloatAttr;
     IntegerAttr minIntAttr, maxIntAttr;
-    if (elementType.isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(elementType))
       minFloatAttr = rewriter.getF32FloatAttr(0.0f);
     else
       minIntAttr = rewriter.getI64IntegerAttr(0);
@@ -463,10 +451,10 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
   matchAndRewrite(AtenSqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
 
     RankedTensorType resultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
     if (!inputType || !resultType)
       return rewriter.notifyMatchFailure(
@@ -478,13 +466,12 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
       op, "Input tensor must have integer or floating-point datatype");
 
     Value newInput = input;
-    if (elementType.isa<mlir::IntegerType>()) {
-      auto inputDType = op.getSelf()
-                            .getType()
-                            .dyn_cast<torch::Torch::ValueTensorType>()
-                            .getDtype();
+    if (isa<mlir::IntegerType>(elementType)) {
+      auto inputDType =
+          dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
+              .getDtype();
       auto outputDType =
-          op.getType().dyn_cast<torch::Torch::ValueTensorType>().getDtype();
+          dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
       newInput =
           torch_to_tcp::castTensorToDtype(rewriter, inputDType, outputDType,
                                           input, resultType.getElementType());
@@ -504,7 +491,7 @@ class ConvertAtenLog1pOp : public OpConversionPattern<AtenLog1pOp> {
   matchAndRewrite(AtenLog1pOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
 
     if (!inputType)
       return rewriter.notifyMatchFailure(
@@ -537,7 +524,7 @@ class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
 
     if (!inputType)
       return rewriter.notifyMatchFailure(
@@ -548,10 +535,9 @@ class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "Input tensor must have integer or floating-point datatype");
 
-    RankedTensorType resultType =
-        OpConversionPattern<AtenOpT>::getTypeConverter()
-            ->convertType(op.getType())
-            .template cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
 
     rewriter.replaceOpWithNewOp<TcpOpT>(op, resultType, input);
     return success();
@@ -568,13 +554,13 @@ class ConvertAtenUnaryFpOnlyOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
 
     if (!inputType)
       return rewriter.notifyMatchFailure(
           op, "Only Ranked Tensor types are supported in TCP");
 
-    if (!inputType.getElementType().isa<mlir::FloatType>())
+    if (!isa<mlir::FloatType>(inputType.getElementType()))
       return rewriter.notifyMatchFailure(
           op, "Input tensor must have floating-point datatype");
 
@@ -591,33 +577,31 @@ class ConvertAtenAtan2Op : public OpConversionPattern<AtenAtan2Op> {
   matchAndRewrite(AtenAtan2Op op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.getSelf();
-    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
 
     Value rhs = adaptor.getOther();
-    RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType rhsType = dyn_cast<RankedTensorType>(rhs.getType());
 
     RankedTensorType resultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
     if (!lhsType || !rhsType || !resultType)
       return rewriter.notifyMatchFailure(
           op, "Only Ranked Tensor types are supported in TCP");
 
-    if (!lhsType.getElementType().isa<mlir::FloatType>() ||
-        !rhsType.getElementType().isa<mlir::FloatType>())
+    if (!isa<mlir::FloatType>(lhsType.getElementType()) ||
+        !isa<mlir::FloatType>(rhsType.getElementType()))
       return rewriter.notifyMatchFailure(
           op, "Input tensors must have floating-point datatype");
 
-    auto inputAType = op.getSelf()
-                          .getType()
-                          .dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
-    auto inputBType = op.getOther()
-                          .getType()
-                          .dyn_cast<torch::Torch::ValueTensorType>()
-                          .getDtype();
+    auto inputAType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
+            .getDtype();
+    auto inputBType =
+        dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
+            .getDtype();
     auto outputType =
-        op.getType().dyn_cast<torch::Torch::ValueTensorType>().getDtype();
+        dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
 
     rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType, rhs,
                                           resultType.getElementType());
@@ -640,10 +624,10 @@ class ConvertAtenToDtypeOp : public OpConversionPattern<AtenToDtypeOp> {
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = op.getContext();
     Value input = op.getSelf();
-    auto inputType = input.getType().dyn_cast<torch::Torch::ValueTensorType>();
-    auto outputType = op.getType().dyn_cast<torch::Torch::ValueTensorType>();
+    auto inputType = dyn_cast<torch::Torch::ValueTensorType>(input.getType());
+    auto outputType = dyn_cast<torch::Torch::ValueTensorType>(op.getType());
     RankedTensorType resultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
     if (!inputType || !outputType)
       return rewriter.notifyMatchFailure(
@@ -671,7 +655,7 @@ class ConvertAtenToDtypeOp : public OpConversionPattern<AtenToDtypeOp> {
     }
 
     // Only `none`, `contiguous` and `preserve` memory_format is supported.
-    if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
       int64_t memoryFormat;
       if (!matchPattern(op.getMemoryFormat(),
                         m_TorchConstantInt(&memoryFormat)) ||
@@ -682,15 +666,15 @@ class ConvertAtenToDtypeOp : public OpConversionPattern<AtenToDtypeOp> {
           "an integer constant with none, contiguous or preserve value");
     }
 
-    if (inputType.getDtype().isa<mlir::FloatType>() &&
-        outputType.getDtype().isa<mlir::FloatType>())
+    if (isa<mlir::FloatType>(inputType.getDtype()) &&
+        isa<mlir::FloatType>(outputType.getDtype()))
       // FP -> FP
       rewriter.replaceOpWithNewOp<tcp::CastOp>(
          loc? op, resultType, adaptor.getSelf(), SignednessAttr{},
          SignednessAttr{});
-    else if (inputType.getDtype().isa<mlir::FloatType>()) {
+    else if (isa<mlir::FloatType>(inputType.getDtype())) {
       // FP -> INT
-      if (auto intType = outputType.getDtype().dyn_cast<mlir::IntegerType>())
+      if (auto intType = dyn_cast<mlir::IntegerType>(outputType.getDtype()))
        rewriter.replaceOpWithNewOp<tcp::CastOp>(
            op, resultType, adaptor.getSelf(), SignednessAttr{},
            torch_to_tcp::getTcpSignednessAttr(context,
@@ -698,9 +682,9 @@ class ConvertAtenToDtypeOp : public OpConversionPattern<AtenToDtypeOp> {
       else
         return rewriter.notifyMatchFailure(
             op, "expect output type to be signless/signed/unsigned integer");
-    } else if (outputType.getDtype().isa<mlir::FloatType>()) {
+    } else if (isa<mlir::FloatType>(outputType.getDtype())) {
       // INT -> FP
-      if (auto intType = inputType.getDtype().dyn_cast<mlir::IntegerType>())
+      if (auto intType = dyn_cast<mlir::IntegerType>(inputType.getDtype()))
        rewriter.replaceOpWithNewOp<tcp::CastOp>(
            op, resultType, adaptor.getSelf(),
            torch_to_tcp::getTcpSignednessAttr(context,
@@ -711,8 +695,8 @@ class ConvertAtenToDtypeOp : public OpConversionPattern<AtenToDtypeOp> {
            op, "expect input type to be signless/signed/unsigned integer");
     } else {
       // INT -> INT
-      auto inIntType = inputType.getDtype().dyn_cast<mlir::IntegerType>();
-      auto outIntType = outputType.getDtype().dyn_cast<mlir::IntegerType>();
+      auto inIntType = dyn_cast<mlir::IntegerType>(inputType.getDtype());
+      auto outIntType = dyn_cast<mlir::IntegerType>(outputType.getDtype());
       if (inIntType && outIntType)
        rewriter.replaceOpWithNewOp<tcp::CastOp>(
            op, resultType, adaptor.getSelf(),
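A note on the ConvertAtenReluOp hunks above: relu is realized as a `tcp.clamp` whose lower bound alone is populated (an f32 attr for float tensors, an i64 attr for integer tensors). A rough sketch of the resulting builder call; the ClampOp builder signature below is assumed from the attribute accessors seen in this patch, not copied from the repo:

```cpp
// Hypothetical sketch: clamp-with-min-only is how relu is expressed in TCP.
// ClampOp's optional attrs (min/max float, min/max int) are assumed to be
// passed positionally; only the pair matching the element type may be set,
// which is what ClampOp::verify() (later in this patch) enforces.
Value relu = rewriter.create<mlir::tcp::ClampOp>(
    loc, inputType, input,
    /*min_float=*/rewriter.getF32FloatAttr(0.0f), /*max_float=*/FloatAttr(),
    /*min_int=*/IntegerAttr(), /*max_int=*/IntegerAttr());
```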
diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp
index fdde4d4..4cf633f 100644
--- a/lib/Conversion/TorchToTcp/Misc.cpp
+++ b/lib/Conversion/TorchToTcp/Misc.cpp
@@ -40,7 +40,7 @@ bool checkZerosOnesOpAttributes(AtenOpT op, RankedTensorType outType) {
   // check default layout
   int64_t memoryLayout;
-  if (!op.getLayout().getType().template isa<Torch::NoneType>() &&
+  if (!isa<Torch::NoneType>(op.getLayout().getType()) &&
       (!matchPattern(op.getLayout(), m_TorchConstantInt(&memoryLayout)) ||
        memoryLayout != 0)) {
     return false;
@@ -48,7 +48,7 @@ bool checkZerosOnesOpAttributes(AtenOpT op, RankedTensorType outType) {
 
   // check default pin_memory
   bool pinMemory;
-  if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
+  if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
      (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
       pinMemory)) {
     return false;
@@ -67,7 +67,7 @@ class ConvertAtenBroadcastLikeOps : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
 
     ArrayRef<int64_t> inputShape = inputType.getShape();
 
@@ -84,7 +84,7 @@ class ConvertAtenBroadcastLikeOps : public OpConversionPattern<AtenOpT> {
     SmallVector<int64_t> axes;
     SmallVector<Value> resultShape;
     ArrayRef<int64_t> newInputShape =
-        input.getType().dyn_cast<RankedTensorType>().getShape();
+        dyn_cast<RankedTensorType>(input.getType()).getShape();
 
     for (int64_t i = 0; i < static_cast<int64_t>(newDimSizes.size()); ++i) {
       Value newDimSize = newDimSizes[i];
@@ -133,10 +133,9 @@ class ConvertAtenBroadcastLikeOps : public OpConversionPattern<AtenOpT> {
       rewriter.replaceOp(op, input);
       return success();
     }
-    RankedTensorType resultType =
-        OpConversionPattern<AtenOpT>::getTypeConverter()
-            ->convertType(op->getResult(0).getType())
-            .template cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op->getResult(0).getType()));
     auto axesAttr = rewriter.getI64ArrayAttr(axes);
     rewriter.replaceOpWithNewOp<tcp::BroadcastOp>(op, resultType, input,
                                                   resultShape, axesAttr);
@@ -153,9 +152,9 @@ class ConvertValueTensorLiteralOp
   matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     RankedTensorType resultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
-    if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
+    if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
       Type elementType = resultType.getElementType();
       auto denseIntAttr = elements.mapValues(elementType, [&](const APInt &v) {
         return APInt(elementType.getIntOrFloatBitWidth(), v.getSExtValue());
@@ -189,7 +188,7 @@ class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     Value self = adaptor.getSelf();
-    auto type = self.getType().cast<RankedTensorType>();
+    auto type = cast<RankedTensorType>(self.getType());
     if (!isa<Torch::ConstantIntOp>(op->getOperand(1).getDefiningOp())) {
       return rewriter.notifyMatchFailure(op, "dim must be a constant int");
     }
@@ -219,9 +218,9 @@ class ConvertAtenZerosOnesOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
-                       ->convertType(op.getType())
-                       .template dyn_cast<RankedTensorType>();
+    auto outType = dyn_cast<RankedTensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
     Type outElemTy = outType.getElementType();
 
     if (!checkZerosOnesOpAttributes(op, outType)) {
@@ -265,13 +264,13 @@ class ConvertAtenZerosOnesLikeOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
-    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
-                       ->convertType(op.getType())
-                       .template dyn_cast<RankedTensorType>();
+    auto outType = dyn_cast<RankedTensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
     Type outElemTy = outType.getElementType();
 
     // TODO: Check the attribute for input vtensor
-    if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>())
+    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
       return rewriter.notifyMatchFailure(
           op, "Only default memory format is supported");
 
@@ -287,7 +286,7 @@ class ConvertAtenZerosOnesLikeOp : public OpConversionPattern<AtenOpT> {
 
     Value resultOp = torch_to_tcp::broadcast0DOr1DToNDAndMatchShape(
         rewriter, constOp, input,
-        constOp.getType().cast<RankedTensorType>().getElementType());
+        cast<RankedTensorType>(constOp.getType()).getElementType());
 
     rewriter.replaceOp(op, resultOp);
diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
index 37b42ab..6ef1304 100644
--- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
+++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -79,7 +79,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
                                             getTypeConverter()};
     helper.addOperand("input", adaptor.getInput());
     helper.addOperand("weight", adaptor.getWeight());
-    if (!adaptor.getBias().getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(adaptor.getBias().getType())) {
      helper.addOperand("bias", adaptor.getBias());
    }
 
@@ -131,7 +131,7 @@ class ConvertAtenFakeQuantizePerTensorAffineTensorQparamsOp
     helper.addIntAttr("quant_max", op.getQuantMax());
 
     // scale
-    auto scaleTy = adaptor.getScale().getType().dyn_cast<RankedTensorType>();
+    auto scaleTy = dyn_cast<RankedTensorType>(adaptor.getScale().getType());
     if (!scaleTy || scaleTy.getShape().size() != 1 ||
         scaleTy.getNumElements() != 1)
       // scale should be a [1] tensor.
@@ -140,7 +140,7 @@ class ConvertAtenFakeQuantizePerTensorAffineTensorQparamsOp
 
     // zero_point
     auto zeroPointTy =
-        adaptor.getZeroPoint().getType().dyn_cast<RankedTensorType>();
+        dyn_cast<RankedTensorType>(adaptor.getZeroPoint().getType());
     if (!zeroPointTy || zeroPointTy.getShape().size() != 1 ||
         zeroPointTy.getNumElements() != scaleTy.getNumElements())
       // zero_point should be a [1] tensor.
@@ -168,7 +168,7 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
     helper.addIntAttr("quant_max", op.getQuantMax());
 
     // scale
-    auto scaleTy = adaptor.getScale().getType().dyn_cast<RankedTensorType>();
+    auto scaleTy = dyn_cast<RankedTensorType>(adaptor.getScale().getType());
     if (!scaleTy || scaleTy.getShape().size() != 1)
       // scale should be a [C] tensor.
       return rewriter.notifyMatchFailure(op, "Unsupported scale type or size");
@@ -176,7 +176,7 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
 
     // zero_point
     auto zeroPointTy =
-        adaptor.getZeroPoint().getType().dyn_cast<RankedTensorType>();
+        dyn_cast<RankedTensorType>(adaptor.getZeroPoint().getType());
     if (!zeroPointTy || zeroPointTy.getShape().size() != 1 ||
         zeroPointTy.getNumElements() != scaleTy.getNumElements())
       // zero_point should be a [C] tensor.
@@ -428,8 +428,7 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
   // Torch -> TOSA supports only 2D convolutions; map the rest to
   // TCP custom_op instead.
   auto is2dConvOp = [](AtenConvolutionOp op) {
-    auto inputTy =
-        op.getInput().getType().cast<torch::Torch::ValueTensorType>();
+    auto inputTy = cast<torch::Torch::ValueTensorType>(op.getInput().getType());
     return inputTy.getSizes().size() == 4;
   };
 
diff --git a/lib/Conversion/TorchToTcp/Utils.cpp b/lib/Conversion/TorchToTcp/Utils.cpp
index 65fe948..499eb62 100644
--- a/lib/Conversion/TorchToTcp/Utils.cpp
+++ b/lib/Conversion/TorchToTcp/Utils.cpp
@@ -51,7 +51,7 @@ Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
                                  Value input, int64_t rankIncrease) {
   if (rankIncrease == 0)
     return input;
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
 
   SmallVector<ReassociationIndices> reassociationMap(inputType.getRank());
   if (inputType.getRank() > 0) {
@@ -77,7 +77,7 @@ Value broadcastRankInTrailingDims(ConversionPatternRewriter &rewriter,
                                   Value input, int64_t rankIncrease) {
   if (rankIncrease == 0)
     return input;
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
 
   SmallVector<ReassociationIndices> reassociationMap(inputType.getRank());
   if (inputType.getRank() > 0) {
@@ -100,7 +100,7 @@ Value broadcastRankInTrailingDims(ConversionPatternRewriter &rewriter,
 
 Value broadcastRank0Dor1DToND(ConversionPatternRewriter &rewriter, Value input,
                               int64_t targetRank, int64_t axisInOutput) {
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
   auto inputRank = inputType.getRank();
   assert(inputRank < 2 && "Only 0D and 1D tensors are supported!");
 
@@ -127,10 +127,10 @@ Value broadcastRank0Dor1DToND(ConversionPatternRewriter &rewriter, Value input,
 Value broadcastShapeExceptDims(ConversionPatternRewriter &rewriter, Value input,
                                Value target,
                                llvm::SmallDenseSet<int64_t> dimsToExclude) {
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
   auto inputShape = inputType.getShape();
 
-  RankedTensorType targetType =
-      target.getType().cast<RankedTensorType>();
+  RankedTensorType targetType = cast<RankedTensorType>(target.getType());
   auto targetShape = targetType.getShape();
 
   SmallVector<int64_t> axes;
@@ -252,8 +252,8 @@ broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
 std::pair<Value, Value>
 broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,
                       Value rhs) {
-  RankedTensorType inputAType = lhs.getType().cast<RankedTensorType>();
-  RankedTensorType inputBType = rhs.getType().cast<RankedTensorType>();
+  RankedTensorType inputAType = cast<RankedTensorType>(lhs.getType());
+  RankedTensorType inputBType = cast<RankedTensorType>(rhs.getType());
 
   Value resultA = lhs;
   Value resultB = rhs;
@@ -264,8 +264,8 @@ broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,
     resultA = broadcastRankInLeadingDims(
         rewriter, resultA, inputBType.getRank() - inputAType.getRank());
 
-  inputAType = resultA.getType().cast<RankedTensorType>();
-  inputBType = resultB.getType().cast<RankedTensorType>();
+  inputAType = cast<RankedTensorType>(resultA.getType());
+  inputBType = cast<RankedTensorType>(resultB.getType());
   SmallVector<int64_t> inputAShape(inputAType.getShape().begin(),
                                    inputAType.getShape().end());
   SmallVector<int64_t> inputBShape(inputBType.getShape().begin(),
@@ -312,8 +312,8 @@ broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,
 Value broadcast0DOr1DToNDAndMatchShape(ConversionPatternRewriter &rewriter,
                                        Value input, Value target,
                                        Type resultType, int64_t axisInOutput) {
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
-  RankedTensorType targetType = target.getType().cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
+  RankedTensorType targetType = cast<RankedTensorType>(target.getType());
 
   auto inputRank = inputType.getRank();
   auto targetRank = targetType.getRank();
@@ -367,9 +367,9 @@ Value broadcast0DOr1DFromShape(ConversionPatternRewriter &rewriter, Value input,
                                ArrayRef<Value> targetVal,
                                SmallVector<int64_t> resultShape,
                                int64_t axisInOutput) {
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
   auto inputRank = inputType.getRank();
-  RankedTensorType targetType = input.getType().cast<RankedTensorType>();
+  RankedTensorType targetType = cast<RankedTensorType>(input.getType());
 
   int64_t targetRank = 0;
   SmallVector<Value> dimSizes;
@@ -417,15 +417,15 @@ Value castTensorToDtype(ConversionPatternRewriter &rewriter, Type srcType,
   if (srcType == dstType)
     return input;
 
-  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
   auto resultType = inputType.cloneWith(inputType.getShape(), convertedType);
 
   SignednessAttr inputSignedness;
   SignednessAttr outputSignedness;
-  if (auto inputIntType = srcType.dyn_cast<mlir::IntegerType>())
+  if (auto inputIntType = dyn_cast<mlir::IntegerType>(srcType))
     inputSignedness = getTcpSignednessAttr(input.getDefiningOp()->getContext(),
                                            inputIntType.getSignedness());
-  if (auto outputIntType = dstType.dyn_cast<mlir::IntegerType>())
+  if (auto outputIntType = dyn_cast<mlir::IntegerType>(dstType))
     outputSignedness = getTcpSignednessAttr(input.getDefiningOp()->getContext(),
                                             outputIntType.getSignedness());
   return rewriter.create<tcp::CastOp>(input.getDefiningOp()->getLoc(),
diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp
index 7dbd2bb..ff1a3b9 100644
--- a/lib/Dialect/IR/TcpOps.cpp
+++ b/lib/Dialect/IR/TcpOps.cpp
@@ -21,9 +21,9 @@ namespace mlir::tcp {
 
 LogicalResult ClampOp::verify() {
-  auto inputType = getIn().getType().cast<RankedTensorType>();
+  auto inputType = cast<RankedTensorType>(getIn().getType());
 
-  if (inputType.getElementType().isa<FloatType>()) {
+  if (isa<FloatType>(inputType.getElementType())) {
     if (getMinInt() || getMaxInt())
       return emitOpError("failed to verify that int min / max attributes "
                          "must not be set when input is a float tensor");
@@ -32,7 +32,7 @@ LogicalResult ClampOp::verify() {
                         "attributes
must be set"); } - if (inputType.getElementType().isa()) { + if (isa(inputType.getElementType())) { if (getMinFloat() || getMaxFloat()) return emitOpError("failed to verify that float min / max attributes " "must not be set when input is an int tensor"); @@ -46,7 +46,7 @@ LogicalResult ClampOp::verify() { LogicalResult BroadcastOp::verify() { auto compareIntAttr = [](Attribute v1, Attribute v2) { - return v1.cast().getInt() < v2.cast().getInt(); + return cast(v1).getInt() < cast(v2).getInt(); }; auto getInt = [](IntegerAttr v) { return v.getInt(); }; @@ -134,20 +134,20 @@ LogicalResult ConstOp::verify() { } LogicalResult CastOp::verify() { - auto inputType = getIn().getType().cast(); - auto outputType = getOut().getType().cast(); + auto inputType = cast(getIn().getType()); + auto outputType = cast(getOut().getType()); if (!inputType.getElementType().isIntOrFloat() || !outputType.getElementType().isIntOrFloat()) return emitOpError("Cast Op must have integer or floating-point datatype"); - if (inputType.getElementType().isa()) { + if (isa(inputType.getElementType())) { if (getInIntSignedness()) return emitOpError( "in_int_signedness attr should not set when input is FP"); } - if (inputType.getElementType().isa()) { + if (isa(inputType.getElementType())) { if (!getInIntSignedness()) return emitOpError( "in_int_signedness attr must be set when input is INT"); @@ -157,13 +157,13 @@ LogicalResult CastOp::verify() { "Signedness::Signless when input is i1"); } - if (outputType.getElementType().isa()) { + if (isa(outputType.getElementType())) { if (getOutIntSignedness()) return emitOpError( "out_int_signedness attr should not set when output is FP"); } - if (outputType.getElementType().isa()) { + if (isa(outputType.getElementType())) { if (!getOutIntSignedness()) return emitOpError( "out_int_signedness attr must be set when output is INT"); diff --git a/lib/Dialect/Transforms/EliminateUnusedTorchOpsPass.cpp b/lib/Dialect/Transforms/EliminateUnusedTorchOpsPass.cpp new file mode 100644 index 0000000..8da3c75 --- /dev/null +++ b/lib/Dialect/Transforms/EliminateUnusedTorchOpsPass.cpp @@ -0,0 +1,74 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h"
+
+#include "./PassDetail.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace mlir::tcp {
+
+namespace {
+
+bool isTargetOpInEraselist(Operation *op) {
+  if (isa<Torch::TorchDialect>(op->getDialect()))
+    return true;
+  return false;
+}
+
+class RemoveTargetedTorchOps : public RewritePattern {
+public:
+  RemoveTargetedTorchOps(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!op->use_empty())
+      return failure();
+    if (!isTargetOpInEraselist(op))
+      return failure();
+    // These contain dynamic shape annotations, do not DCE
+    if (isa<BindSymbolicShapeOp>(op))
+      return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+class EliminateUnusedTorchOpsPass
+    : public EliminateUnusedTorchOpsBase<EliminateUnusedTorchOpsPass> {
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+    MLIRContext *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    patterns.add<RemoveTargetedTorchOps>(context);
+
+    if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createEliminateUnusedTorchOpsPass() {
+  return std::make_unique<EliminateUnusedTorchOpsPass>();
+}
+
+} // namespace mlir::tcp
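The new pass above walks the module with a greedy rewrite driver, erasing any use-free op from the Torch dialect (except `torch.bind_symbolic_shape`, which carries dynamic-shape annotations). Together with the pipeline change below it runs module-wide before the Torch -> TCP conversions. A minimal sketch of invoking it on its own (standalone driver code, not part of this patch):

```cpp
#include "mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Runs only the new cleanup pass over a module, e.g. ahead of lowering.
static mlir::LogicalResult eliminateUnusedTorchOps(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::tcp::createEliminateUnusedTorchOpsPass());
  return pm.run(module);
}
```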
diff --git a/lib/Dialect/Transforms/Passes.cpp b/lib/Dialect/Transforms/Passes.cpp
index dd1dab6..e2ef20a 100644
--- a/lib/Dialect/Transforms/Passes.cpp
+++ b/lib/Dialect/Transforms/Passes.cpp
@@ -9,6 +9,7 @@
 #include "mlir-tcp/Dialect/Transforms/Passes.h"
 
 #include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"
+#include "mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h"
 #include "mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h"
 #include "mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h"
 #include "mlir-tcp/Dialect/Transforms/TransformTensorOps.h"
diff --git a/lib/Pipeline/Pipeline.cpp b/lib/Pipeline/Pipeline.cpp
index 6ff1d68..66b77aa 100644
--- a/lib/Pipeline/Pipeline.cpp
+++ b/lib/Pipeline/Pipeline.cpp
@@ -15,6 +15,7 @@
 #include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"
 #include "mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h"
 #include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"
+#include "mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h"
 #include "mlir-tcp/Dialect/Transforms/TransformTensorOps.h"
 #include "mlir-tcp/Dialect/Transforms/VerifyTcpBackendContractPass.h"
@@ -44,6 +45,9 @@ using namespace mlir;
 
 static void createTorchBackendToTcpBackendPipeline(OpPassManager &pm) {
+  // Remove unused / unnecessary torch ops first
+  pm.addPass(tcp::createEliminateUnusedTorchOpsPass());
+
   // Torch -> TCP conversions.
   pm.addNestedPass<func::FuncOp>(tcp::createConvertTorchToTcpPass());
   pm.addNestedPass<func::FuncOp>(tcp::createConvertTorchToTcpCustomOpPass());
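Note the ordering: the cleanup is added module-wide ahead of the per-function Torch -> TCP conversions, so dead `torch.aten.size.int` / `torch.runtime.assert` chains (typical torch.export artifacts, see the new lit test below) are already gone by the time the conversion patterns run. A sketch of composing the same front half by hand; `buildTorchToTcpFront` is illustrative, while the two factory functions are the ones referenced in the hunk above:

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

#include "mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"
#include "mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h"

// Illustrative: module-level cleanup first, then the function-level
// conversion, mirroring createTorchBackendToTcpBackendPipeline above.
void buildTorchToTcpFront(mlir::OpPassManager &pm) {
  pm.addPass(mlir::tcp::createEliminateUnusedTorchOpsPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tcp::createConvertTorchToTcpPass());
}
```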
diff --git a/local_repos.bzl b/local_repos.bzl
index 31b2584..ec33f5e 100644
--- a/local_repos.bzl
+++ b/local_repos.bzl
@@ -21,16 +21,8 @@ def use_local_torch_mlir_repo():
     # `local_torch_mlir_repo_path()`
     return False
 
-def use_local_stablehlo_repo():
-    # Change this to return True to have mlir-tcp use the source tree at
-    # `local_stablehlo_repo_path()`
-    return False
-
 def local_llvm_repo_path():
     return "./third_party/llvm-project"
 
 def local_torch_mlir_repo_path():
     return "./third_party/torch-mlir"
-
-def local_stablehlo_repo_path():
-    return "./third_party/stablehlo"
diff --git a/requirements_lock.txt b/requirements_lock.txt
index e23abc0..d2d088d 100644
--- a/requirements_lock.txt
+++ b/requirements_lock.txt
@@ -7,17 +7,17 @@
 --find-links https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels
 --find-links https://download.pytorch.org/whl/nightly/cpu/torch/
 
-filelock==3.17.0 \
-    --hash=sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338 \
-    --hash=sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e
+filelock==3.18.0 \
+    --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \
+    --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de
     # via torch
-fsspec==2025.2.0 \
-    --hash=sha256:1c24b16eaa0a1798afa0337aa0db9b256718ab2a89c425371f5628d22c3b6afd \
-    --hash=sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b
+fsspec==2025.5.1 \
+    --hash=sha256:24d3a2e663d5fc735ab256263c4075f374a174c3410c0b25e5bd1970bceaa462 \
+    --hash=sha256:2e55e47a540b91843b755e83ded97c6e897fa0942b11490113f09e9c443c2475
     # via torch
-jinja2==3.1.5 \
-    --hash=sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb \
-    --hash=sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb
+jinja2==3.1.6 \
+    --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \
+    --hash=sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67
     # via torch
 markupsafe==3.0.2 \
     --hash=sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4 \
@@ -90,104 +90,107 @@ networkx==3.4.2 \
     --hash=sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1 \
     --hash=sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f
     # via torch
-numpy==2.2.3 \
-    --hash=sha256:0391ea3622f5c51a2e29708877d56e3d276827ac5447d7f45e9bc4ade8923c52 \
-    --hash=sha256:12c045f43b1d2915eca6b880a7f4a256f59d62df4f044788c8ba67709412128d \
-    --hash=sha256:136553f123ee2951bfcfbc264acd34a2fc2f29d7cdf610ce7daf672b6fbaa693 \
-    --hash=sha256:1402da8e0f435991983d0a9708b779f95a8c98c6b18a171b9f1be09005e64d9d \
-    --hash=sha256:16372619ee728ed67a2a606a614f56d3eabc5b86f8b615c79d01957062826ca8 \
-    --hash=sha256:1ad78ce7f18ce4e7df1b2ea4019b5817a2f6a8a16e34ff2775f646adce0a5027 \
-    --hash=sha256:1b416af7d0ed3271cad0f0a0d0bee0911ed7eba23e66f8424d9f3dfcdcae1304 \
-    --hash=sha256:1f45315b2dc58d8a3e7754fe4e38b6fce132dab284a92851e41b2b344f6441c5 \
-    --hash=sha256:2376e317111daa0a6739e50f7ee2a6353f768489102308b0d98fcf4a04f7f3b5 \
-    --hash=sha256:23c9f4edbf4c065fddb10a4f6e8b6a244342d95966a48820c614891e5059bb50 \
-    --hash=sha256:246535e2f7496b7ac85deffe932896a3577be7af8fb7eebe7146444680297e9a \
-    --hash=sha256:2e8da03bd561504d9b20e7a12340870dfc206c64ea59b4cfee9fceb95070ee94 \
-    --hash=sha256:34c1b7e83f94f3b564b35f480f5652a47007dd91f7c839f404d03279cc8dd021 \
-    --hash=sha256:39261798d208c3095ae4f7bc8eaeb3481ea8c6e03dc48028057d3cbdbdb8937e \
-    --hash=sha256:3b787adbf04b0db1967798dba8da1af07e387908ed1553a0d6e74c084d1ceafe \
-    --hash=sha256:3c2ec8a0f51d60f1e9c0c5ab116b7fc104b165ada3f6c58abf881cb2eb16044d \
-    --hash=sha256:435e7a933b9fda8126130b046975a968cc2d833b505475e588339e09f7672890 \
-    --hash=sha256:4d8335b5f1b6e2bce120d55fb17064b0262ff29b459e8493d1785c18ae2553b8 \
-    --hash=sha256:4d9828d25fb246bedd31e04c9e75714a4087211ac348cb39c8c5f99dbb6683fe \
-    --hash=sha256:52659ad2534427dffcc36aac76bebdd02b67e3b7a619ac67543bc9bfe6b7cdb1 \
-    --hash=sha256:5266de33d4c3420973cf9ae3b98b54a2a6d53a559310e3236c4b2b06b9c07d4e \
-    --hash=sha256:5521a06a3148686d9269c53b09f7d399a5725c47bbb5b35747e1cb76326b714b \
-    --hash=sha256:596140185c7fa113563c67c2e894eabe0daea18cf8e33851738c19f70ce86aeb \
-    --hash=sha256:5b732c8beef1d7bc2d9e476dbba20aaff6167bf205ad9aa8d30913859e82884b \
-    --hash=sha256:5ebeb7ef54a7be11044c33a17b2624abe4307a75893c001a4800857956b41094 \
-    --hash=sha256:712a64103d97c404e87d4d7c47fb0c7ff9acccc625ca2002848e0d53288b90ea \
-    --hash=sha256:7678556eeb0152cbd1522b684dcd215250885993dd00adb93679ec3c0e6e091c \
-    --hash=sha256:77974aba6c1bc26e3c205c2214f0d5b4305bdc719268b93e768ddb17e3fdd636 \
-    --hash=sha256:783145835458e60fa97afac25d511d00a1eca94d4a8f3ace9fe2043003c678e4 \
-    --hash=sha256:7bfdb06b395385ea9b91bf55c1adf1b297c9fdb531552845ff1d3ea6e40d5aba \
-    --hash=sha256:7c8dde0ca2f77828815fd1aedfdf52e59071a5bae30dac3b4da2a335c672149a \
-    --hash=sha256:83807d445817326b4bcdaaaf8e8e9f1753da04341eceec705c001ff342002e5d \
-    --hash=sha256:87eed225fd415bbae787f93a457af7f5990b92a334e346f72070bf569b9c9c95 \
-    --hash=sha256:8fb62fe3d206d72fe1cfe31c4a1106ad2b136fcc1606093aeab314f02930fdf2 \
-    --hash=sha256:95172a21038c9b423e68be78fd0be6e1b97674cde269b76fe269a5dfa6fadf0b \
-    --hash=sha256:9f48ba6f6c13e5e49f3d3efb1b51c8193215c42ac82610a04624906a9270be6f \
-    --hash=sha256:a0c03b6be48aaf92525cccf393265e02773be8fd9551a2f9adbe7db1fa2b60f1 \
-    --hash=sha256:a5ae282abe60a2db0fd407072aff4599c279bcd6e9a2475500fc35b00a57c532 \
-    --hash=sha256:aee2512827ceb6d7f517c8b85aa5d3923afe8fc7a57d028cffcd522f1c6fd082 \
-    --hash=sha256:c8b0451d2ec95010d1db8ca733afc41f659f425b7f608af569711097fd6014e2 \
-    --hash=sha256:c9aa4496fd0e17e3843399f533d62857cef5900facf93e735ef65aa4bbc90ef0 \
-    --hash=sha256:cbc6472e01952d3d1b2772b720428f8b90e2deea8344e854df22b0618e9cce71 \
-    --hash=sha256:cdfe0c22692a30cd830c0755746473ae66c4a8f2e7bd508b35fb3b6a0813d787 \
-    --hash=sha256:cf802eef1f0134afb81fef94020351be4fe1d6681aadf9c5e862af6602af64ef \
-    --hash=sha256:d42f9c36d06440e34226e8bd65ff065ca0963aeecada587b937011efa02cdc9d \
-    --hash=sha256:d5b47c440210c5d1d67e1cf434124e0b5c395eee1f5806fdd89b553ed1acd0a3 \
-    --hash=sha256:d9b4a8148c57ecac25a16b0e11798cbe88edf5237b0df99973687dd866f05e1b \
-    --hash=sha256:daf43a3d1ea699402c5a850e5313680ac355b4adc9770cd5cfc2940e7861f1bf \
-    --hash=sha256:dbdc15f0c81611925f382dfa97b3bd0bc2c1ce19d4fe50482cb0ddc12ba30020 \
-    --hash=sha256:deaa09cd492e24fd9b15296844c0ad1b3c976da7907e1c1ed3a0ad21dded6f76 \
-    --hash=sha256:e37242f5324ffd9f7ba5acf96d774f9276aa62a966c0bad8dae692deebec7716 \
-    --hash=sha256:ed2cf9ed4e8ebc3b754d398cba12f24359f018b416c380f577bbae112ca52fc9 \
-    --hash=sha256:f2712c5179f40af9ddc8f6727f2bd910ea0eb50206daea75f58ddd9fa3f715bb \
-    --hash=sha256:f4ca91d61a4bf61b0f2228f24bbfa6a9facd5f8af03759fe2a655c50ae2c6610 \
-    --hash=sha256:f6b3dfc7661f8842babd8ea07e9897fe3d9b69a1d7e5fbb743e4160f9387833b
+numpy==2.2.6 \
+    --hash=sha256:038613e9fb8c72b0a41f025a7e4c3f0b7a1b5d768ece4796b674c8f3fe13efff \
+    --hash=sha256:0678000bb9ac1475cd454c6b8c799206af8107e310843532b04d49649c717a47 \
+    --hash=sha256:0811bb762109d9708cca4d0b13c4f67146e3c3b7cf8d34018c722adb2d957c84 \
+    --hash=sha256:0b605b275d7bd0c640cad4e5d30fa701a8d59302e127e5f79138ad62762c3e3d \
+    --hash=sha256:0bca768cd85ae743b2affdc762d617eddf3bcf8724435498a1e80132d04879e6 \
+    --hash=sha256:1bc23a79bfabc5d056d106f9befb8d50c31ced2fbc70eedb8155aec74a45798f \
+    --hash=sha256:287cc3162b6f01463ccd86be154f284d0893d2b3ed7292439ea97eafa8170e0b \
+    --hash=sha256:37c0ca431f82cd5fa716eca9506aefcabc247fb27ba69c5062a6d3ade8cf8f49 \
+    --hash=sha256:37e990a01ae6ec7fe7fa1c26c55ecb672dd98b19c3d0e1d1f326fa13cb38d163 \
+    --hash=sha256:389d771b1623ec92636b0786bc4ae56abafad4a4c513d36a55dce14bd9ce8571 \
+    --hash=sha256:3d70692235e759f260c3d837193090014aebdf026dfd167834bcba43e30c2a42 \
+    --hash=sha256:41c5a21f4a04fa86436124d388f6ed60a9343a6f767fced1a8a71c3fbca038ff \
+    --hash=sha256:481b49095335f8eed42e39e8041327c05b0f6f4780488f61286ed3c01368d491 \
+    --hash=sha256:4eeaae00d789f66c7a25ac5f34b71a7035bb474e679f410e5e1a94deb24cf2d4 \
+    --hash=sha256:55a4d33fa519660d69614a9fad433be87e5252f4b03850642f88993f7b2ca566 \
+    --hash=sha256:5a6429d4be8ca66d889b7cf70f536a397dc45ba6faeb5f8c5427935d9592e9cf \
+    --hash=sha256:5bd4fc3ac8926b3819797a7c0e2631eb889b4118a9898c84f585a54d475b7e40 \
+    --hash=sha256:5beb72339d9d4fa36522fc63802f469b13cdbe4fdab4a288f0c441b74272ebfd \
+    --hash=sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06 \
+    --hash=sha256:71594f7c51a18e728451bb50cc60a3ce4e6538822731b2933209a1f3614e9282 \
+    --hash=sha256:74d4531beb257d2c3f4b261bfb0fc09e0f9ebb8842d82a7b4209415896adc680 \
+    --hash=sha256:7befc596a7dc9da8a337f79802ee8adb30a552a94f792b9c9d18c840055907db \
+    --hash=sha256:894b3a42502226a1cac872f840030665f33326fc3dac8e57c607905773cdcde3 \
+    --hash=sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90 \
+    --hash=sha256:8e9ace4a37db23421249ed236fdcdd457d671e25146786dfc96835cd951aa7c1 \
+    --hash=sha256:8fc377d995680230e83241d8a96def29f204b5782f371c532579b4f20607a289 \
+    --hash=sha256:9551a499bf125c1d4f9e250377c1ee2eddd02e01eac6644c080162c0c51778ab \
+    --hash=sha256:b0544343a702fa80c95ad5d3d608ea3599dd54d4632df855e4c8d24eb6ecfa1c \
+    --hash=sha256:b093dd74e50a8cba3e873868d9e93a85b78e0daf2e98c6797566ad8044e8363d \
+    --hash=sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb \
+    --hash=sha256:b4f13750ce79751586ae2eb824ba7e1e8dba64784086c98cdbbcc6a42112ce0d \
+    --hash=sha256:b64d8d4d17135e00c8e346e0a738deb17e754230d7e0810ac5012750bbd85a5a \
+    --hash=sha256:ba10f8411898fc418a521833e014a77d3ca01c15b0c6cdcce6a0d2897e6dbbdf \
+    --hash=sha256:bd48227a919f1bafbdda0583705e547892342c26fb127219d60a5c36882609d1 \
+    --hash=sha256:c1f9540be57940698ed329904db803cf7a402f3fc200bfe599334c9bd84a40b2 \
+    --hash=sha256:c820a93b0255bc360f53eca31a0e676fd1101f673dda8da93454a12e23fc5f7a \
+    --hash=sha256:ce47521a4754c8f4593837384bd3424880629f718d87c5d44f8ed763edd63543 \
+    --hash=sha256:d042d24c90c41b54fd506da306759e06e568864df8ec17ccc17e9e884634fd00 \
+    --hash=sha256:de749064336d37e340f640b05f24e9e3dd678c57318c7289d222a8a2f543e90c \
+    --hash=sha256:e1dda9c7e08dc141e0247a5b8f49cf05984955246a327d4c48bda16821947b2f \
+    --hash=sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd \
+    --hash=sha256:e3143e4451880bed956e706a3220b4e5cf6172ef05fcc397f6f36a550b1dd868 \
+    --hash=sha256:e8213002e427c69c45a52bbd94163084025f533a55a59d6f9c5b820774ef3303 \
+    --hash=sha256:efd28d4e9cd7d7a8d39074a4d44c63eda73401580c5c76acda2ce969e0a38e83 \
+    --hash=sha256:f0fd6321b839904e15c46e0d257fdd101dd7f530fe03fd6359c1ea63738703f3 \
+    --hash=sha256:f1372f041402e37e5e633e586f62aa53de2eac8d98cbfb822806ce4bbefcb74d \
+    --hash=sha256:f2618db89be1b4e05f7a1a847a9c1c0abd63e63a1607d892dd54668dd92faf87 \
+    --hash=sha256:f447e6acb680fd307f40d3da4852208af94afdfab89cf850986c3ca00562f4fa \
+    --hash=sha256:f92729c95468a2f4f15e9bb94c432a9229d0d50de67304399627a943201baa2f \
+    --hash=sha256:f9f1adb22318e121c5c69a09142811a201ef17ab257a1e66ca3025065b7f53ae \
+    --hash=sha256:fc0c5673685c508a142ca65209b4e79ed6740a4ed6b2267dbba90f34b0b3cfda \
+    --hash=sha256:fc7b73d02efb0e18c000e9ad8b83480dfcd5dfd11065997ed4c6747470ae8915 \
+    --hash=sha256:fd83c01228a688733f1ded5201c678f0c53ecc1006ffbc404db9f7a899ac6249 \
+    --hash=sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de \
+    --hash=sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8
     # via
     #   -r requirements.txt
     #   torch-mlir
-packaging==24.2 \
-    --hash=sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759 \
-    --hash=sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f
+packaging==25.0 \
+    --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \
+    --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f
     # via torch-mlir
-sympy==1.13.3 \
-    --hash=sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73 \
-    --hash=sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9
+sympy==1.14.0 \
+    --hash=sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517 \
+    --hash=sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5
     # via torch
-torch==2.8.0.dev20250506+cpu \
-    --hash=sha256:02abfdcdbb9ca15e3c561d31b1617f9d88f978af49b3b76cc048a5159c4bbb19 \
-    --hash=sha256:0304c11aa1a404a664a776dea4b61dab31707d5fecc1e165ea17b1c780049911 \
-    --hash=sha256:081ecdc2ced1285b92cce4684922710af244ccf4e4430d36c746f025e6872a30 \
-    --hash=sha256:0bdc6883695004803ea0e062382d21e432168d7ee93e6f77375d34fc43778ca8 \
-    --hash=sha256:1c82f3cd449bee2adcfc8c1dc25b087fc3ed9eba239ea46449e1a087ddbf5f97 \
-    --hash=sha256:370ae6fb1c8c132c4578973eb6066f14d10fb6cdc05a89e44660fec15bbce9a4 \
-    --hash=sha256:3c68844186c4d43db95f096b120b91c530c4e92540eeeece90e59fd6ec078f03 \
-    --hash=sha256:4017473f0a77cd2774a3c8245032fb9979ac08f92831f94f70d9e22612e2d5c1 \
-    --hash=sha256:4575a76e5459285311d1f94fb8835fec81d5509321192716fcff8631aa258ae3 \
-    --hash=sha256:48c682f8f369b573045d5922e989812b77183f4020a750b3339c3e64e42fd733 \
-    --hash=sha256:4a64fd103df112e2dbfb00ab04ffef839bc1838caa40ff8bf86647eb39daa7ad \
-    --hash=sha256:5f2a251b87dc7a359fe5b83772cb2830e01b0d75a585edc1ffe659a3e59ae17b \
-    --hash=sha256:690f44ae8974588810a6c58052e908fb1abc7c3d34e335faccec0baba852596b \
-    --hash=sha256:810c8106d575256c6e429e26a8edf58e4ab43fea0b10c4d56eed011f0712ee90 \
-    --hash=sha256:8701a35246db0aa148ea3bb6edb022a639c16115912d2dc90cbad9a56c0ded2e \
-    --hash=sha256:a5974f2958d12d01577e206417ee4d04dc2f2275505d266323cf23e828e46d96 \
-    --hash=sha256:b17959e888c65cef0765bfef3e4813f3dad7d3d55f73c976ca33a47d2ff875b5 \
-    --hash=sha256:b91059dce8f9c97fce586b1367c91c64912ba0866e2213510a3ffe522cee3aee \
-    --hash=sha256:c8a7058db5c6c478d2f93a14f911dcc045d5470ed0920797ec5a6008a0bce354 \
-    --hash=sha256:ce7960db4fb7899626a4a94c361b0d2c80c9b3bd6907b929380a176df27b9908 \
-    --hash=sha256:e23ba269a7f189dc65c1b0ff937beb0630dfbe9a810cd307d284a51cbc8409d6 \
-    --hash=sha256:e2cccdcc64938ede25afc43efaa4e70fdf45709c3f2b48549adc0d163aa7fadf \
-    --hash=sha256:fb30a20142ed498569649208d67f03e9e9f345be79ab340ceec734439a475d9a
+torch==2.8.0.dev20250606+cpu \
+    --hash=sha256:024466c9c8f4ef792ae4e5624d188304e101c1c6a1680952c86271a765607230 \
+    --hash=sha256:144f8db1f5d4bd74eabab63dba52ae2d9301e72fb7d59c812e2fca8eff9ccaf4 \
+    --hash=sha256:2465ea2a5a024e2c8bf8894a647aa8f7d1006d50f4476184cf4979ab2159fce2 \
+    --hash=sha256:362213175bb34d5d7cb1f5b0b247f44aa83586c991f82638fb2d1814743134b2 \
+    --hash=sha256:3be665e2d06e1528d0cafe3294d1a29e18bb8983af6184a9e911f511ffa331f1 \
+    --hash=sha256:3f00f475532d07747325f0cffb248d388a694036841e0673284bc8fe6e856a9d \
+    --hash=sha256:453aa6171bd1a01a3bcd50069a5863e9f626bebc9b6ed47bf9e884372624ecd6 \
+    --hash=sha256:4cd340471009c300382c061f1c30d466c9fb8a9a5b1ad95cca654a4a2de1781a \
+    --hash=sha256:58d9bab4b7c1951a7336a94c3b8634215e2ead479860d664ceb7f628d884f95e \
+    --hash=sha256:6662f37741d7321533181891ad43a36be5390c096d49dcc94378fe347bc5d0a1 \
+    --hash=sha256:7132697c8180f7caecb452319a13f6735ab38b26cc3d9aade24fda935d350ed4 \
+    --hash=sha256:763f061d1d4d73c977aa9edebf3d1f98f2e32b2f1a84585dddb8bb4ae43d561f \
+    --hash=sha256:7ffda2c87dfce179cc3933b5300288e1bbf1e57554cfdd91c0735621cdfbb923 \
+    --hash=sha256:94380f5f4ca10cab93167f772c141618d088a65e93ff1c632b15249b85054183 \
+    --hash=sha256:a306bb6c25b23bb268b3b174c29d6908dfe9287702ec911fb65dce8078526642 \
+    --hash=sha256:aaad0ab77149acf9fdb68c42a926a95823170744dbe96587b8794c55967417fd \
+    --hash=sha256:af57e85b177423d7cbd45194629ff4ab9206ab2ab8d1af4aa04e9d54734f1ccd \
+    --hash=sha256:b7770b44b4ad60803c7ab66ff7441dc892f2af0d433188ab1f75b2e97e0f8920 \
+    --hash=sha256:bf5679d571eb06eaf0ab9be95c9419f1538e879f756575cc86dcfa4cc07f26ae \
+    --hash=sha256:cc5ea5ec7a2020f3c8899f6cf456ae1cc0e69cc7ad522409827c6aee1945bf18 \
+    --hash=sha256:da64af888e572ef56f593a6d00ac5ab62a55a4ed0f45898b303659fbb5c3a9c7 \
+    --hash=sha256:dad798b9049e677327e4403a70ff56fd50bafc6d915d47ebdec30c342c247a67 \
+    --hash=sha256:e11044e9d23dcfcb8d4dd6806c0294afcece4c1888feb11920608a9890bd9c4b \
+    --hash=sha256:e5d9fec8530dc58e3e7f871af6b4f17f58d1b5cd19f3d8d0f9be7dd1a3c1678c \
+    --hash=sha256:ee57a73dc9ed97d98578ade99b3303e4ec3d9b1cb927cafc303af3fcc91e4c39 \
+    --hash=sha256:f46b2aeed966e5224cc47cb35b34f00ff54f2f06d34b4c7ac552bcd0570534dc
     # via -r requirements.txt
-torch-mlir==20250127.357 \
-    --hash=sha256:43c2362b6a5265405ac5d2291982d6b0d83afafc7ee37165f4cc6b845dec4c15 \
-    --hash=sha256:62dd44c74212ce772cf245be5610c2ed9cc60c0a51fcf1c3a2e4f1bd3245da29 \
-    --hash=sha256:b72097a773b3a90adae71a92919868695d32bd487622a32486baa9a75b033bdd
+torch-mlir==20250606.490 \
+    --hash=sha256:1c00e4edb87455734f7bd079dee8a498ee052acb3248e8dff14285fc8fcb2df4 \
+    --hash=sha256:8ff32632c5336b5cf85d470655739cb9750e2da4c75d4c8723455be416b1b681 \
+    --hash=sha256:a909b17d4a367bba553d0884a410e338b55f0bbbc7863f296814c59725c66f81
     # via -r requirements.txt
-typing-extensions==4.12.2 \
-    --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
-    --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
+typing-extensions==4.14.0 \
+    --hash=sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4 \
+    --hash=sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af
     # via torch
diff --git a/test/Conversion/StablehloToTcp/stablehlo.mlir b/test/Conversion/StablehloToTcp/stablehlo.mlir
deleted file mode 100644
index a6c9db2..0000000
--- a/test/Conversion/StablehloToTcp/stablehlo.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: tcp-opt %s -convert-stablehlo-to-tcp | FileCheck %s
-
-// CHECK-LABEL: func.func @tanh(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[TANH:.*]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32>
-// CHECK: return %[[TANH]] : tensor<?x?xf32>
-// CHECK: }
-func.func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = stablehlo.tanh %arg0 : tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
-}
diff --git a/test/Dialect/eliminate_unused_torch_ops.mlir b/test/Dialect/eliminate_unused_torch_ops.mlir
new file mode 100644
index 0000000..82c1398
--- /dev/null
+++ b/test/Dialect/eliminate_unused_torch_ops.mlir
@@ -0,0 +1,27 @@
+// RUN: tcp-opt %s -eliminate-unused-torch-ops | FileCheck %s
+
+// CHECK-LABEL: func.func @test_eliminate_unused_torch_ops(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> {
+// CHECK: %[[CONST2:.*]] = torch.constant.int 2
+// CHECK: %[[S0:.*]] = torch.symbolic_int "s35" {min_val = 0, max_val = 9223372036854775807} : !torch.int
+// CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+// CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[ARG1]], %[[CONST2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.int -> !torch.vtensor<[?,3],f32>
+// CHECK: torch.bind_symbolic_shape %[[SUB]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+// CHECK: return %[[SUB]] : !torch.vtensor<[?,3],f32>
+func.func @test_eliminate_unused_torch_ops(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> {
+%int2 = torch.constant.int 2
+%int0 = torch.constant.int 0
+%0 = torch.symbolic_int "s35" {min_val = 0, max_val = 9223372036854775807} : !torch.int
+torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+torch.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+%1 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+%2 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+%3 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.int -> !torch.vtensor<[?,3],f32>
+torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+%4 = torch.aten.eq.int %1, %2 : !torch.int, !torch.int -> !torch.bool
+%5 = torch.aten.Int.bool %4 : !torch.bool -> !torch.int
+%6 = torch.aten.Bool.int %5 : !torch.int -> !torch.bool
+torch.runtime.assert %6, "Runtime assertion failed for expression Eq(s35, s58) on node 'eq_2'"
+return %3 : !torch.vtensor<[?,3],f32>
+}
diff --git a/third_party/patches/torch-mlir-bazel-build.1.patch b/third_party/patches/torch-mlir-bazel-build.1.patch
new file mode 100644
index 0000000..e493b38
--- /dev/null
+++ b/third_party/patches/torch-mlir-bazel-build.1.patch
@@ -0,0 +1,129 @@
+diff --git utils/bazel/torch-mlir-overlay/BUILD.bazel utils/bazel/torch-mlir-overlay/BUILD.bazel
+index fdde0d63..6611d78b 100644
+--- utils/bazel/torch-mlir-overlay/BUILD.bazel
++++ utils/bazel/torch-mlir-overlay/BUILD.bazel
+@@ -277,7 +277,7 @@ gentbl_cc_library(
+         (
+             [
+                 "-gen-pass-decls",
+-                "-DTORCH_MLIR_ENABLE_STABLEHLO",
++                # "-DTORCH_MLIR_ENABLE_STABLEHLO",
+                 "-DTORCH_MLIR_ENABLE_TOSA",
+             ],
+             "include/torch-mlir/Conversion/Passes.h.inc",
+@@ -334,7 +334,7 @@ gentbl_cc_library(
+         (
+             [
+                 "-gen-pass-decls",
+-                "-DTORCH_MLIR_ENABLE_STABLEHLO",
++                # "-DTORCH_MLIR_ENABLE_STABLEHLO",
+                 "-DTORCH_MLIR_ENABLE_TOSA",
+             ],
+             "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc",
+@@ -495,28 +495,28 @@ cc_library(
+     ],
+ )
+ 
+-cc_library(
+-    name = "TorchMLIRTorchToStablehlo",
+-    srcs = glob([
+-        "lib/Conversion/*.h",
+-        "lib/Conversion/TorchToStablehlo/*.h",
+-        "lib/Conversion/TorchToStablehlo/*.cpp",
+-    ]),
+-    hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]),
+-    defines = [
+-        "TORCH_MLIR_ENABLE_STABLEHLO",
+-    ],
+-    strip_include_prefix = "include",
+-    deps = [
+-        ":TorchMLIRConversionPassesIncGen",
+-        ":TorchMLIRConversionUtils",
+-        ":TorchMLIRTorchBackendTypeConversion",
+-        ":TorchMLIRTorchConversionDialect",
+-        "@llvm-project//mlir:Dialect",
+-        "@stablehlo//:register",
+-        "@stablehlo//:stablehlo_passes",
+-    ],
+-)
++# cc_library(
++#     name = "TorchMLIRTorchToStablehlo",
++#     srcs = glob([
++#         "lib/Conversion/*.h",
++#         "lib/Conversion/TorchToStablehlo/*.h",
++#         "lib/Conversion/TorchToStablehlo/*.cpp",
++#     ]),
++#     hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]),
++#     defines = [
++#         "TORCH_MLIR_ENABLE_STABLEHLO",
++#     ],
++#     strip_include_prefix = "include",
++#     deps = [
++#         ":TorchMLIRConversionPassesIncGen",
++#         ":TorchMLIRConversionUtils",
++#         ":TorchMLIRTorchBackendTypeConversion",
++#         ":TorchMLIRTorchConversionDialect",
++#         "@llvm-project//mlir:Dialect",
++#         "@stablehlo//:register",
++#         "@stablehlo//:stablehlo_passes",
++#     ],
++# )
+ 
+ cc_library(
+     name = "TorchMLIRTorchOnnxToTorch",
+@@ -543,7 +543,7 @@ cc_library(
+         "include/torch-mlir/Conversion/Passes.h",
+     ],
+     defines = [
+-        "TORCH_MLIR_ENABLE_STABLEHLO",
++        # "TORCH_MLIR_ENABLE_STABLEHLO",
+         "TORCH_MLIR_ENABLE_TOSA",
+     ],
+     strip_include_prefix = "include",
+@@ -553,7 +553,7 @@ cc_library(
+         ":TorchMLIRTorchToArith",
+         ":TorchMLIRTorchToLinalg",
+         ":TorchMLIRTorchToSCF",
+-        ":TorchMLIRTorchToStablehlo",
++        # ":TorchMLIRTorchToStablehlo",
+         ":TorchMLIRTorchToTMTensor",
+         ":TorchMLIRTorchToTensor",
+         ":TorchMLIRTorchToTosa",
+@@ -568,7 +568,7 @@ cc_library(
+     ]),
+     hdrs = glob(["include/torch-mlir/Dialect/TorchConversion/Transforms/*.h"]),
+     defines = [
+-        "TORCH_MLIR_ENABLE_STABLEHLO",
++        # "TORCH_MLIR_ENABLE_STABLEHLO",
+         "TORCH_MLIR_ENABLE_TOSA",
+     ],
+     strip_include_prefix = "include",
+@@ -582,7 +582,7 @@ cc_library(
+         ":TorchMLIRTorchToArith",
+         ":TorchMLIRTorchToLinalg",
+         ":TorchMLIRTorchToSCF",
+-        ":TorchMLIRTorchToStablehlo",
++        # ":TorchMLIRTorchToStablehlo",
+         ":TorchMLIRTorchToTMTensor",
+         ":TorchMLIRTorchToTensor",
+         ":TorchMLIRTorchToTosa",
+@@ -893,7 +893,7 @@ cc_library(
+     ],
+     copts = [
+         "-DTORCH_MLIR_ENABLE_REFBACKEND",
+-        "-DTORCH_MLIR_ENABLE_STABLEHLO",
++        # "-DTORCH_MLIR_ENABLE_STABLEHLO",
+         "-DTORCH_MLIR_ENABLE_TOSA",
+     ],
+     strip_include_prefix = "include",
+@@ -911,8 +911,8 @@ cc_library(
+         "@llvm-project//mlir:DialectUtils",
+         "@llvm-project//mlir:IR",
+         "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl",
+-        "@stablehlo//:linalg_passes",
+-        "@stablehlo//:stablehlo_passes",
++        # "@stablehlo//:linalg_passes",
++        # "@stablehlo//:stablehlo_passes",
+     ],
+ )
+ 
diff --git a/third_party/patches/torch-mlir-bazel-build.2.patch b/third_party/patches/torch-mlir-bazel-build.2.patch
new file mode 100644
index 0000000..89159cc
--- /dev/null
+++ b/third_party/patches/torch-mlir-bazel-build.2.patch
@@ -0,0 +1,11 @@
+diff --git utils/bazel/torch-mlir-overlay/test/Conversion/BUILD.bazel utils/bazel/torch-mlir-overlay/test/Conversion/BUILD.bazel
+index 2cbe8091..852c9886 100644
+--- utils/bazel/torch-mlir-overlay/test/Conversion/BUILD.bazel
++++ utils/bazel/torch-mlir-overlay/test/Conversion/BUILD.bazel
+@@ -11,5 +11,5 @@ package(default_visibility = ["//visibility:public"])
+             "@torch-mlir//test:lit_data",
+         ],
+     )
+-    for src in glob(["**/*.mlir"])
++    for src in glob(["**/*.mlir"], exclude = ["TorchToStablehlo/*.mlir"])
+ ]
diff --git a/third_party/patches/torch-mlir.1.patch b/third_party/patches/torch-mlir.1.patch
deleted file mode 100644
index 0426372..0000000
--- a/third_party/patches/torch-mlir.1.patch
+++ /dev/null
@@ -1,12 +0,0 @@
-diff --git lib/InitAll.cpp lib/InitAll.cpp
-index d9096929..2a9be6cc 100644
---- lib/InitAll.cpp
-+++ lib/InitAll.cpp
-@@ -33,6 +33,7 @@
- #ifdef TORCH_MLIR_ENABLE_STABLEHLO
- #include "stablehlo/conversions/linalg/transforms/Passes.h"
- #include "stablehlo/transforms/Passes.h"
-+#include "stablehlo/transforms/optimization/Passes.h"
- #endif
- 
- #ifdef TORCH_MLIR_ENABLE_TOSA
diff --git a/tools/clangd/BUILD b/tools/clangd/BUILD
index 51ec5de..acceef3 100644
--- a/tools/clangd/BUILD
+++ b/tools/clangd/BUILD
@@ -12,7 +12,6 @@ refresh_compile_commands(
     # bazel query 'kind("(cc.*) rule", //...)'
     targets = [
         "//:Pipeline",
-        "//:StablehloToTcp",
         "//:TcpConversionPasses",
         "//:TcpConversionPassesIncGen",
         "//:TcpDialect",
diff --git a/tools/tcp-opt/tcp-opt.cpp b/tools/tcp-opt/tcp-opt.cpp
index 9a21e76..2dac034 100644
--- a/tools/tcp-opt/tcp-opt.cpp
+++ b/tools/tcp-opt/tcp-opt.cpp
@@ -15,8 +15,6 @@
 #include "mlir-tcp/InitAll.h"
 #include "mlir-tcp/Pipeline/Pipeline.h"
 
-#include "stablehlo/dialect/Register.h"
-
 using namespace mlir;
 
 int main(int argc, char **argv) {
@@ -28,8 +26,6 @@ int main(int argc, char **argv) {
   registerAllExtensions(registry);
  mlir::tcp::registerAllDialects(registry);
 
-  mlir::stablehlo::registerAllDialects(registry);
-
   mlir::tcp::registerTcpPipelines();
 
   return mlir::asMainReturnCode(mlir::MlirOptMain(