@@ -121,10 +121,21 @@ class Im2ColFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
121121 (data_layout != DataLayout::kNHWC ? im.dims ()[1 ] : im.dims ()[0 ]);
122122 int im_width =
123123 (data_layout != DataLayout::kNHWC ? im.dims ()[2 ] : im.dims ()[1 ]);
124- int filter_height = col->dims ()[1 ];
125- int filter_width = col->dims ()[2 ];
126- int col_height = col->dims ()[3 ];
127- int col_width = col->dims ()[4 ];
124+ int64_t filter_height = col->dims ()[1 ];
125+ // TODO(large-tensor): downstream functors may still use int; guard until
126+ // upgraded.
127+
128+ int64_t filter_width = col->dims ()[2 ];
129+ // TODO(large-tensor): downstream functors may still use int; guard until
130+ // upgraded.
131+
132+ int64_t col_height = col->dims ()[3 ];
133+ // TODO(large-tensor): downstream functors may still use int; guard until
134+ // upgraded.
135+
136+ int64_t col_width = col->dims ()[4 ];
137+ // TODO(large-tensor): downstream functors may still use int; guard until
138+ // upgraded.
128139
129140 int num_outputs = im_channels * col_height * col_width;
130141 int num_thread = 1024 ;
@@ -256,10 +267,21 @@ class Col2ImFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
256267 (data_layout != DataLayout::kNHWC ? im->dims ()[1 ] : im->dims ()[0 ]);
257268 int im_width =
258269 (data_layout != DataLayout::kNHWC ? im->dims ()[2 ] : im->dims ()[1 ]);
259- int filter_height = col.dims ()[1 ];
260- int filter_width = col.dims ()[2 ];
261- int col_height = col.dims ()[3 ];
262- int col_width = col.dims ()[4 ];
270+ int64_t filter_height = col.dims ()[1 ];
271+ // TODO(large-tensor): downstream functors may still use int; guard until
272+ // upgraded.
273+
274+ int64_t filter_width = col.dims ()[2 ];
275+ // TODO(large-tensor): downstream functors may still use int; guard until
276+ // upgraded.
277+
278+ int64_t col_height = col.dims ()[3 ];
279+ // TODO(large-tensor): downstream functors may still use int; guard until
280+ // upgraded.
281+
282+ int64_t col_width = col.dims ()[4 ];
283+ // TODO(large-tensor): downstream functors may still use int; guard until
284+ // upgraded.
263285
264286 PADDLE_ENFORCE_EQ (
265287 (im_height + padding[0 ] + padding[2 ] -
@@ -406,13 +428,33 @@ class Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
406428 " the dims of tensor 'col' is [%s]." ,
407429 col->dims ()));
408430
409- int im_channels = im.dims ()[0 ];
410- int im_height = im.dims ()[1 ];
411- int im_width = im.dims ()[2 ];
412- int filter_height = col->dims ()[3 ];
413- int filter_width = col->dims ()[4 ];
414- int col_height = col->dims ()[0 ];
415- int col_width = col->dims ()[1 ];
431+ int64_t im_channels = im.dims ()[0 ];
432+ // TODO(large-tensor): downstream functors may still use int; guard until
433+ // upgraded.
434+
435+ int64_t im_height = im.dims ()[1 ];
436+ // TODO(large-tensor): downstream functors may still use int; guard until
437+ // upgraded.
438+
439+ int64_t im_width = im.dims ()[2 ];
440+ // TODO(large-tensor): downstream functors may still use int; guard until
441+ // upgraded.
442+
443+ int64_t filter_height = col->dims ()[3 ];
444+ // TODO(large-tensor): downstream functors may still use int; guard until
445+ // upgraded.
446+
447+ int64_t filter_width = col->dims ()[4 ];
448+ // TODO(large-tensor): downstream functors may still use int; guard until
449+ // upgraded.
450+
451+ int64_t col_height = col->dims ()[0 ];
452+ // TODO(large-tensor): downstream functors may still use int; guard until
453+ // upgraded.
454+
455+ int64_t col_width = col->dims ()[1 ];
456+ // TODO(large-tensor): downstream functors may still use int; guard until
457+ // upgraded.
416458
417459 int block_dim_x = 0 ;
418460 int block_dim_y = 0 ;
@@ -431,7 +473,9 @@ class Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
431473 }
432474
433475 int block_dim_z = 1024 / block_dim_x / block_dim_y;
434- dim3 threads (block_dim_x, block_dim_y, std::min (block_dim_z, im_channels));
476+ dim3 threads (block_dim_x,
477+ block_dim_y,
478+ std::min (block_dim_z, static_cast <int >(im_channels)));
435479 dim3 grid (col_width, col_height);
436480 im2colOCF<T><<<grid, threads, 0 , dev_ctx.stream()>>> (im.data <T>(),
437481 im_channels,
@@ -516,13 +560,33 @@ class Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
516560 " the dims of tensor 'col' is [%s]." ,
517561 col.dims ()));
518562
519- int im_channels = im->dims ()[0 ];
520- int im_height = im->dims ()[1 ];
521- int im_width = im->dims ()[2 ];
522- int filter_height = col.dims ()[3 ];
523- int filter_width = col.dims ()[4 ];
524- int col_height = col.dims ()[0 ];
525- int col_width = col.dims ()[1 ];
563+ int64_t im_channels = im->dims ()[0 ];
564+ // TODO(large-tensor): downstream functors may still use int; guard until
565+ // upgraded.
566+
567+ int64_t im_height = im->dims ()[1 ];
568+ // TODO(large-tensor): downstream functors may still use int; guard until
569+ // upgraded.
570+
571+ int64_t im_width = im->dims ()[2 ];
572+ // TODO(large-tensor): downstream functors may still use int; guard until
573+ // upgraded.
574+
575+ int64_t filter_height = col.dims ()[3 ];
576+ // TODO(large-tensor): downstream functors may still use int; guard until
577+ // upgraded.
578+
579+ int64_t filter_width = col.dims ()[4 ];
580+ // TODO(large-tensor): downstream functors may still use int; guard until
581+ // upgraded.
582+
583+ int64_t col_height = col.dims ()[0 ];
584+ // TODO(large-tensor): downstream functors may still use int; guard until
585+ // upgraded.
586+
587+ int64_t col_width = col.dims ()[1 ];
588+ // TODO(large-tensor): downstream functors may still use int; guard until
589+ // upgraded.
526590
527591 PADDLE_ENFORCE_EQ (
528592 (im_height + padding[0 ] + padding[2 ] -
@@ -558,7 +622,9 @@ class Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
558622 }
559623
560624 int block_dim_z = 1024 / block_dim_x / block_dim_y;
561- dim3 threads (block_dim_x, block_dim_y, std::min (block_dim_z, im_channels));
625+ dim3 threads (block_dim_x,
626+ block_dim_y,
627+ std::min (block_dim_z, static_cast <int >(im_channels)));
562628 dim3 grid (col_width, col_height);
563629 col2imOCF<T><<<grid, threads, 0 , dev_ctx.stream()>>> (col.data <T>(),
564630 im_channels,
0 commit comments