@@ -131,62 +131,48 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
131131 return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
132132 }
133133
134- BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
135- return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
136- }
137-
138- ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr () const {
139- return llvm::dyn_cast_if_present<ScatterTensorDescAttr >(getEncoding());
134+ template <typename T,
135+ typename = std::enable_if_t<
136+ std::is_same_v<T, BlockTensorDescAttr> ||
137+ std::is_same_v<T, ScatterTensorDescAttr>>>
138+ T getEncodingOfType () const {
139+ return llvm::dyn_cast_if_present<T >(getEncoding());
140140 }
141141
142142 LayoutAttr getLayoutAttr() const {
143143 return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
144144 }
145145
146146 xegpu::MemorySpace getMemorySpace() const {
147- auto block_attr = getEncodingAsBlockTensorDescAttr();
148- if (block_attr && block_attr.getMemorySpace())
149- return block_attr.getMemorySpace().getValue();
150-
151- auto scatter_attr = getEncodingAsScatterTensorDescAttr();
152- if (scatter_attr && scatter_attr.getMemorySpace())
153- return scatter_attr.getMemorySpace().getValue();
147+ if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
148+ return attr.getMemorySpace().getValue();
154149
155- // return default value
156- return MemorySpace::Global ;
150+ auto attr = getEncodingOfType<ScatterTensorDescAttr>();
151+ return attr.getMemorySpace().getValue() ;
157152 }
158153
159154 // get the ArrayLength for blocked TensorDesc
160155 int getArrayLength() {
161- auto attr = getEncoding();
162- auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
163- assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
164- if (block_attr && block_attr.getArrayLength())
165- return block_attr.getArrayLength().getInt();
166- // return default value
167- return 1;
156+ auto attr = getEncodingOfType<BlockTensorDescAttr>();
157+ assert(attr && "invalid on non BlockTensorDescAttr.");
158+ return attr.getArrayLength().getInt();
168159 }
169160
170161 bool getBoundaryCheck() {
171- auto attr = getEncoding();
172- auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
173- assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
174- if (block_attr && block_attr.getBoundaryCheck())
175- return block_attr.getBoundaryCheck().getValue();
176- // return default value
177- return true;
162+ auto attr = getEncodingOfType<BlockTensorDescAttr>();
163+ assert(attr && "invalid on non BlockTensorDescAttr.");
164+ return attr.getBoundaryCheck().getValue();
178165 }
179166
180167 bool isScattered() {
181- return bool(getEncodingAsScatterTensorDescAttr ());
168+ return bool(getEncodingOfType<ScatterTensorDescAttr> ());
182169 }
183170
184171 // get the ChunkSize for scattered TensorDesc
185172 int getChunkSizeAsInt() {
186- auto attr = getEncoding();
187- auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
188- assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
189- return scatter_attr.getChunkSizeAsInt();
173+ auto attr = getEncodingOfType<ScatterTensorDescAttr>();
174+ assert(attr && "invalid on non ScatterTensorDescAttr.");
175+ return attr.getChunkSizeAsInt();
190176 }
191177
192178 /// Helper to drop all layout information from the TensorDesc type.
0 commit comments