
Commit dc8e22a

[large tensor] Use int64_t instead of int/int32_t for dims (PaddlePaddle#76290)
* add new enforce
* apply int32-dims
* fix
* add rule
* fix
* fix rule
* remove PADDLE_ENFORCE_LE_INT_MAX
* fix
* fix ast-grep
1 parent 6fd1723 commit dc8e22a

61 files changed, +629 −159 lines
(ast-grep snapshot file for this rule; path not shown in this view)

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+id: no-int32-type-dims
+snapshots:
+  int class_num = prob->dims()[1];:
+    fixed: |
+      int64_t class_num = prob->dims()[1];
+      // TODO(large-tensor): downstream functors may still use int; guard until upgraded.
+    labels:
+      - source: int class_num = prob->dims()[1];
+        style: primary
+        start: 0
+        end: 32
+  int input_ids_num = input.numel();:
+    fixed: |
+      int64_t input_ids_num = input.numel();
+      // TODO(large-tensor): downstream functors may still use int; guard until upgraded.
+      PADDLE_ENFORCE_LE_INT_MAX(input_ids_num, "input_ids_num");
+    labels:
+      - source: int input_ids_num = input.numel();
+        style: primary
+        start: 0
+        end: 34
+  int row_size = x.dims()[ndim - 1];:
+    fixed: |
+      int64_t row_size = x.dims()[ndim - 1];
+      // TODO(large-tensor): downstream functors may still use int; guard until upgraded.
+    labels:
+      - source: int row_size = x.dims()[ndim - 1];
+        style: primary
+        start: 0
+        end: 34
+  int32_t input_ids_num = input.numel();:
+    fixed: |
+      int64_t input_ids_num = input.numel();
+      // TODO(large-tensor): downstream functors may still use int; guard until upgraded.
+    labels:
+      - source: int32_t input_ids_num = input.numel();
+        style: primary
+        start: 0
+        end: 38
(ast-grep rule test file; path not shown in this view)

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+id: no-int32-type-dims
+valid:
+  - int64_t input_ids_num = input.numel();
+  - int64_t class_num = prob->dims()[1];
+  - int64_t row_size = x.dims()[ndim - 1];
+invalid:
+  - int32_t input_ids_num = input.numel();
+  - int class_num = prob->dims()[1];
+  - int row_size = x.dims()[ndim - 1];

ci/rules/no-int32-type-dims.yml

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+id: no-int32-type-dims
+language: cpp
+files:
+  - paddle/phi/kernels/**
+ignores:
+  - paddle/phi/kernels/legacy/**
+  - paddle/phi/kernels/xpu/**
+  - paddle/phi/kernels/custom/**
+  - paddle/phi/kernels/sparse/**
+severity: error
+message: |
+  Use int64_t instead of int/int32_t for dims/shape/strides/numel/offset (large tensor support).
+  If this is a false positive, please contact zrr1999 (recommended) or wanghuancoder for more information.
+note: dims/shape/strides/numel/offset may be >= INT32_MAX; dims()/numel() return int64_t
+rule:
+  any:
+    # int x = <RIGHT>;
+    - pattern: int $VAR = $RIGHT
+    # int32_t x = <RIGHT>;
+    - pattern: int32_t $VAR = $RIGHT
+    # direct-initializer: int x(<RIGHT>);
+    - pattern: int $VAR($RIGHT)
+    # direct-initializer: int32_t x(<RIGHT>);
+    - pattern: int32_t $VAR($RIGHT)
+constraints:
+  RIGHT:
+    any:
+      # dims[...] / expr.dims()[...] / expr->dims()[...]
+      - pattern: $E.dims[$INDEX]
+      - pattern: $E.dims()[$INDEX]
+      - pattern: $E->dims[$INDEX]
+      - pattern: $E->dims()[$INDEX]
+      # shape[...] / expr.shape()[...] / expr->shape()[...]
+      - pattern: $E.shape[$INDEX]
+      - pattern: $E.shape()[$INDEX]
+      - pattern: $E->shape[$INDEX]
+      - pattern: $E->shape()[$INDEX]
+      # strides[...] / expr.strides()[...] / expr->strides()[...]
+      - pattern: $E.strides[$INDEX]
+      - pattern: $E.strides()[$INDEX]
+      - pattern: $E->strides[$INDEX]
+      - pattern: $E->strides()[$INDEX]
+      # numel / numel()
+      - pattern: $E.numel
+      - pattern: $E.numel()
+      - pattern: $E->numel
+      - pattern: $E->numel()
+      # offset / offset()
+      # unsafe
+      # - pattern: $E.offset
+      # - pattern: $E.offset()
+fix: |
+  int64_t $VAR = $RIGHT;
+  // TODO(large-tensor): downstream functors may still use int; guard until upgraded.
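To make the rule's `note` concrete, here is a short, self-contained illustration (not part of the commit) of the wraparound that motivates it: any extent above INT32_MAX is mangled the moment it lands in an int.

#include <cstdint>
#include <iostream>

int main() {
  // A tensor of shape [3, 1000000000] has 3e9 elements; dims()/numel()
  // report this correctly as int64_t, but an int cannot hold it.
  int64_t numel = 3LL * 1000000000LL;
  int narrowed = static_cast<int>(numel);  // the pattern the rule rejects
  std::cout << numel << " -> " << narrowed << '\n';  // 3000000000 -> -1294967296
  return 0;
}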

paddle/common/enforce.h

Lines changed: 10 additions & 0 deletions
@@ -335,6 +335,16 @@ using CommonType2 = typename std::add_lvalue_reference<
 #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
   __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
 
+#define PADDLE_ENFORCE_LE_INT_MAX(var, var_name)                             \
+  PADDLE_ENFORCE_LE(var,                                                     \
+                    std::numeric_limits<int>::max(),                         \
+                    common::errors::InvalidArgument(                         \
+                        "Tensor dimension %s=%ld exceeds the maximum value " \
+                        "that int can represent (%d).",                      \
+                        var_name,                                            \
+                        var,                                                 \
+                        std::numeric_limits<int>::max()))
+
 TEST_API bool RegisterLogSimplyStr(const std::string& type,
                                    const std::string& simply);
 template <typename T>
paddle/phi/kernels/cpu/tdm_child_kernel.cc

Lines changed: 4 additions & 1 deletion
@@ -36,7 +36,10 @@ void TDMChildInner(const Context &dev_ctx,
   int node_nums = info_dims[0];
   int length = info_dims[1];
 
-  int input_ids_num = input.numel();
+  int64_t input_ids_num = input.numel();
+  // TODO(large-tensor): downstream functors may still use int; guard until
+  // upgraded.
+
   VLOG(4) << "TDM child op: input numel -> " << input_ids_num;
 
   std::vector<OutT> child_vec{};
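The TODO left behind each widened declaration is the point, not an oversight: widening removes one truncation, but any downstream helper that still takes int re-introduces it at the call site. A hedged, self-contained sketch (ProcessIds is hypothetical, not the actual TDM code):

#include <cstdint>

// Hypothetical legacy helper that still takes int.
void ProcessIds(int ids_num) { (void)ids_num; }

int main() {
  int64_t input_ids_num = (1LL << 31);  // one past INT32_MAX
  // The cast wraps to INT32_MIN; widening the declaration alone does not
  // make this call safe, hence the TODO until ProcessIds is upgraded.
  ProcessIds(static_cast<int>(input_ids_num));
  return 0;
}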

paddle/phi/kernels/cpu/tdm_sampler_kernel.cc

Lines changed: 4 additions & 1 deletion
@@ -43,7 +43,10 @@ void TDMSamplerInner(const Context &dev_ctx,
                      phi::DenseTensor *label,
                      phi::DenseTensor *mask) {
   // get dimension
-  int input_ids_num = input_tensor.numel();
+  int64_t input_ids_num = input_tensor.numel();
+  // TODO(large-tensor): downstream functors may still use int; guard until
+  // upgraded.
+
   VLOG(3) << "TDM: input ids nums: " << input_ids_num;
   auto layer_nums = neg_samples_num_list.size();
   VLOG(3) << "TDM: tree layer nums: " << layer_nums;

paddle/phi/kernels/funcs/cross_entropy.cu

Lines changed: 8 additions & 2 deletions
@@ -122,8 +122,14 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
   T* loss_data = dev_ctx.template Alloc<T>(out);
   const T* prob_data = prob->data<T>();
 
-  int batch_size = prob->dims()[0];
-  int class_num = prob->dims()[1];
+  int64_t batch_size = prob->dims()[0];
+  // TODO(large-tensor): downstream functors may still use int; guard until
+  // upgraded.
+
+  int64_t class_num = prob->dims()[1];
+  // TODO(large-tensor): downstream functors may still use int; guard until
+  // upgraded.
+
   constexpr int kMaxBlockDim = 512;
 
   // big tensor currently not supported

paddle/phi/kernels/funcs/fake_dequantize_functor.cu

Lines changed: 4 additions & 1 deletion
@@ -142,7 +142,10 @@ void ChannelDequantizeFunctor<Context, T>::operator()(
   // quantized on. `x_num_col_dims` is -1 for operator in ['matmul',
   // 'matmul_v2', 'mul'] and is 1 for other operators.
   int64_t num = in->numel();
-  int n_scales = in->dims()[x_num_col_dims];
+  int64_t n_scales = in->dims()[x_num_col_dims];
+  // TODO(large-tensor): downstream functors may still use int; guard until
+  // upgraded.
+
   const T* scale_one = scales[0]->data<T>();
   const T* scale_two = scales[1]->data<T>();

paddle/phi/kernels/funcs/im2col.cu

Lines changed: 90 additions & 24 deletions
@@ -121,10 +121,21 @@ class Im2ColFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
         (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
     int im_width =
         (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
-    int filter_height = col->dims()[1];
-    int filter_width = col->dims()[2];
-    int col_height = col->dims()[3];
-    int col_width = col->dims()[4];
+    int64_t filter_height = col->dims()[1];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t filter_width = col->dims()[2];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_height = col->dims()[3];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_width = col->dims()[4];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
 
     int num_outputs = im_channels * col_height * col_width;
     int num_thread = 1024;
@@ -256,10 +267,21 @@ class Col2ImFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
         (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
     int im_width =
         (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
-    int filter_height = col.dims()[1];
-    int filter_width = col.dims()[2];
-    int col_height = col.dims()[3];
-    int col_width = col.dims()[4];
+    int64_t filter_height = col.dims()[1];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t filter_width = col.dims()[2];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_height = col.dims()[3];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_width = col.dims()[4];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
 
     PADDLE_ENFORCE_EQ(
         (im_height + padding[0] + padding[2] -
@@ -406,13 +428,33 @@ class Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
             "the dims of tensor 'col' is [%s].",
             col->dims()));
 
-    int im_channels = im.dims()[0];
-    int im_height = im.dims()[1];
-    int im_width = im.dims()[2];
-    int filter_height = col->dims()[3];
-    int filter_width = col->dims()[4];
-    int col_height = col->dims()[0];
-    int col_width = col->dims()[1];
+    int64_t im_channels = im.dims()[0];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t im_height = im.dims()[1];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t im_width = im.dims()[2];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t filter_height = col->dims()[3];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t filter_width = col->dims()[4];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_height = col->dims()[0];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_width = col->dims()[1];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
 
     int block_dim_x = 0;
     int block_dim_y = 0;
@@ -431,7 +473,9 @@ class Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
     }
 
     int block_dim_z = 1024 / block_dim_x / block_dim_y;
-    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
+    dim3 threads(block_dim_x,
+                 block_dim_y,
+                 std::min(block_dim_z, static_cast<int>(im_channels)));
     dim3 grid(col_width, col_height);
     im2colOCF<T><<<grid, threads, 0, dev_ctx.stream()>>>(im.data<T>(),
                                                          im_channels,
@@ -516,13 +560,33 @@ class Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
             "the dims of tensor 'col' is [%s].",
             col.dims()));
 
-    int im_channels = im->dims()[0];
-    int im_height = im->dims()[1];
-    int im_width = im->dims()[2];
-    int filter_height = col.dims()[3];
-    int filter_width = col.dims()[4];
-    int col_height = col.dims()[0];
-    int col_width = col.dims()[1];
+    int64_t im_channels = im->dims()[0];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t im_height = im->dims()[1];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t im_width = im->dims()[2];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t filter_height = col.dims()[3];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t filter_width = col.dims()[4];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_height = col.dims()[0];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
+
+    int64_t col_width = col.dims()[1];
+    // TODO(large-tensor): downstream functors may still use int; guard until
+    // upgraded.
 
     PADDLE_ENFORCE_EQ(
         (im_height + padding[0] + padding[2] -
@@ -558,7 +622,9 @@ class Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
     }
 
     int block_dim_z = 1024 / block_dim_x / block_dim_y;
-    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
+    dim3 threads(block_dim_x,
+                 block_dim_y,
+                 std::min(block_dim_z, static_cast<int>(im_channels)));
     dim3 grid(col_width, col_height);
     col2imOCF<T><<<grid, threads, 0, dev_ctx.stream()>>>(col.data<T>(),
                                                          im_channels,
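One non-mechanical detail in this file: widening im_channels to int64_t breaks std::min(block_dim_z, im_channels), because std::min deduces a single type for both arguments. The diff therefore casts back to int inside the launch-config computation, which is safe since CUDA block dimensions are capped far below INT_MAX. A self-contained illustration (not from the commit):

#include <algorithm>
#include <cstdint>

int main() {
  int block_dim_z = 8;
  int64_t im_channels = 512;
  // std::min(block_dim_z, im_channels);  // ill-formed: deduction conflict
  int z = std::min(block_dim_z, static_cast<int>(im_channels));  // compiles, z == 8
  return z == 8 ? 0 : 1;
}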

0 commit comments
