Skip to content

Commit 29dea9c

Browse files
[Mosaic] Internal change.
PiperOrigin-RevId: 831038411
1 parent 24e80c4 commit 29dea9c

File tree

5 files changed

+69
-2
lines changed

5 files changed

+69
-2
lines changed

jaxlib/mlir/_mlir_libs/tpu_ext.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,4 +871,17 @@ NB_MODULE(_tpu_ext, m) {
871871
mlirOperationGetRegion(src, i));
872872
}
873873
});
874+
875+
mlir::python::nanobind_adaptors::mlir_type_subclass(m, "Float8EXMYType",
876+
mlirTpuIsAFloat8EXMYType)
877+
.def_classmethod(
878+
"get",
879+
[](nb::object cls, MlirType exmy_type, MlirContext ctx) {
880+
return cls(mlirTpuFloat8EXMYTypeGet(ctx, exmy_type));
881+
},
882+
nb::arg("self"), nb::arg("exmy_type") = nullptr,
883+
nb::arg("ctx") = nullptr)
884+
.def_property_readonly("underlying_type", [](MlirType self) {
885+
return mlirTpuFloat8EXMYTypeGetUnderlyingType(self);
886+
});
874887
}

jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,26 @@ MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() {
418418
mlir::tpu::registerMosaicSerdePass();
419419
}
420420

421+
//===----------------------------------------------------------------------===//
422+
// Type API.
423+
//===----------------------------------------------------------------------===//
424+
425+
// Float8EXMYType
426+
//===----------------------------------------------------------------------===//
427+
428+
MlirType mlirTpuFloat8EXMYTypeGetUnderlyingType(MlirType exmy_type) {
429+
return wrap(llvm::cast<mlir::tpu::Float8EXMYType>(unwrap(exmy_type))
430+
.getUnderlyingType());
431+
}
432+
433+
bool mlirTpuIsAFloat8EXMYType(MlirType type) {
434+
return llvm::isa<mlir::tpu::Float8EXMYType>(unwrap(type));
435+
}
436+
437+
MlirType mlirTpuFloat8EXMYTypeGet(MlirContext ctx, MlirType exmy_type) {
438+
return wrap(mlir::tpu::Float8EXMYType::get(unwrap(ctx), unwrap(exmy_type)));
439+
}
440+
421441
#include "mlir/CAPI/Pass.h" // IWYU pragma: keep
422442
#include "mlir/CAPI/Support.h" // IWYU pragma: keep
423443

jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,14 @@ mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val,
240240

241241
MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass();
242242

243+
MLIR_CAPI_EXPORTED MlirType mlirTpuFloat8EXMYTypeGetUnderlyingType(
244+
MlirType exmy_type);
245+
246+
MLIR_CAPI_EXPORTED bool mlirTpuIsAFloat8EXMYType(MlirType type);
247+
248+
MLIR_CAPI_EXPORTED MlirType mlirTpuFloat8EXMYTypeGet(
249+
MlirContext ctx, MlirType exmy_type);
250+
243251
#ifdef __cplusplus
244252
}
245253
#endif

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
4848
// TODO(b/369418606): Find out the way to verify vreg size.
4949
def TPU_Vreg : Type<IsVectorOfNonZeroRankTypePred, "native-sized vreg", "::mlir::VectorType">;
5050

51-
class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
52-
: TypeDef<TPU_Dialect, name, traits> {
51+
class TPU_Type<string name, string mnemonic_, list<Trait> traits = [],
52+
string baseCppType = "::mlir::Type">
53+
: TypeDef<TPU_Dialect, name, traits, baseCppType> {
5354
let mnemonic = mnemonic_;
5455
}
5556

@@ -82,6 +83,25 @@ def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInt
8283
def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>;
8384
def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>;
8485

86+
def TPU_Float8EXMYType : TPU_Type<"Float8EXMY", "float8_exmy",
87+
[DeclareTypeInterfaceMethods<FloatTypeInterface, ["getFloatSemantics"]>]> {
88+
let summary = "EXMY type in a nearest power-of-2 bitwidth type container";
89+
let description = [{
90+
EXMY type in a nearest power-of-2 bitwidth type container. Meaningful bits
91+
are aligned to LSB, and bits higher than the underlying exmy type in the
92+
container are considered as ignored. See https://arxiv.org/abs/2405.13938
93+
for more details.
94+
}];
95+
96+
let parameters = (ins
97+
TypeParameter<"::mlir::Type", "Underlying EXMY type">:$underlying_type
98+
);
99+
100+
let assemblyFormat = [{
101+
`<` $underlying_type `>`
102+
}];
103+
}
104+
85105
def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [
86106
I32EnumAttrCase<"parallel", 0>,
87107
I32EnumAttrCase<"arbitrary", 1>,

jaxlib/mosaic/dialect/tpu/tpu_dialect.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "absl/hash/hash.h"
2222
#include "absl/log/log.h"
23+
#include "llvm/ADT/APFloat.h"
2324
#include "llvm/ADT/Hashing.h"
2425
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep.
2526
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -292,4 +293,9 @@ DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder,
292293
/*rhs_batch_dims=*/{});
293294
}
294295

296+
const ::llvm::fltSemantics& Float8EXMYType::getFloatSemantics() const {
297+
// TODO(twsung): Fix this.
298+
return llvm::APFloat::Float8E8M0FNU();
299+
}
300+
295301
} // namespace mlir::tpu

0 commit comments

Comments
 (0)