Skip to content

Commit 986d178

Browse files
committed
Fix the CUDA backend for_each so that it can handle big indexes.
Bug 2448170 Github NVIDIA#967
1 parent a8187b8 commit 986d178

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

testing/for_each.cu

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <thrust/device_ptr.h>
44
#include <thrust/iterator/counting_iterator.h>
55
#include <thrust/iterator/retag.h>
6+
#include <thrust/device_malloc.h>
7+
#include <thrust/device_free.h>
68
#include <algorithm>
79

810
THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_BEGIN
@@ -352,3 +354,46 @@ void TestForEachNWithLargeTypes(void)
352354
DECLARE_UNITTEST(TestForEachNWithLargeTypes);
353355

354356
THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_END
357+
358+
struct OnlySetWhenExpected
359+
{
360+
unsigned long long expected;
361+
bool * flag;
362+
363+
__device__
364+
void operator()(unsigned long long x)
365+
{
366+
if (x == expected)
367+
{
368+
*flag = true;
369+
}
370+
}
371+
};
372+
373+
void TestForEachWithBigIndexesHelper(int magnitude)
374+
{
375+
thrust::counting_iterator<unsigned long long> begin(0);
376+
thrust::counting_iterator<unsigned long long> end = begin + (1ull << magnitude);
377+
ASSERT_EQUAL(thrust::distance(begin, end), 1ll << magnitude);
378+
379+
thrust::device_ptr<bool> has_executed = thrust::device_malloc<bool>(1);
380+
*has_executed = false;
381+
382+
OnlySetWhenExpected fn = { (1ull << magnitude) - 1, thrust::raw_pointer_cast(has_executed) };
383+
384+
thrust::for_each(thrust::device, begin, end, fn);
385+
386+
bool has_executed_h = *has_executed;
387+
thrust::device_free(has_executed);
388+
389+
ASSERT_EQUAL(has_executed_h, true);
390+
}
391+
392+
void TestForEachWithBigIndexes()
393+
{
394+
TestForEachWithBigIndexesHelper(30);
395+
TestForEachWithBigIndexesHelper(31);
396+
TestForEachWithBigIndexesHelper(32);
397+
TestForEachWithBigIndexesHelper(33);
398+
}
399+
DECLARE_UNITTEST(TestForEachWithBigIndexes);

thrust/system/cuda/detail/core/agent_launcher.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ namespace core {
408408
stream(stream_),
409409
name(name_),
410410
debug_sync(debug_sync_),
411-
grid(static_cast<unsigned int>(count + plan.items_per_tile - 1) / plan.items_per_tile),
411+
grid(static_cast<unsigned int>((count + plan.items_per_tile - 1) / plan.items_per_tile)),
412412
vshmem(NULL),
413413
has_shmem((size_t)core::get_max_shared_memory_per_block() >= (size_t)plan.shared_memory_size),
414414
shmem_size(has_shmem ? plan.shared_memory_size : 0)
@@ -429,7 +429,7 @@ namespace core {
429429
stream(stream_),
430430
name(name_),
431431
debug_sync(debug_sync_),
432-
grid(static_cast<unsigned int>(count + plan.items_per_tile - 1) / plan.items_per_tile),
432+
grid(static_cast<unsigned int>((count + plan.items_per_tile - 1) / plan.items_per_tile)),
433433
vshmem(vshmem),
434434
has_shmem((size_t)core::get_max_shared_memory_per_block() >= (size_t)plan.shared_memory_size),
435435
shmem_size(has_shmem ? plan.shared_memory_size : 0)

thrust/system/cuda/detail/parallel_for.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ namespace __parallel_for {
9393
#pragma unroll
9494
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
9595
{
96-
int idx = BLOCK_THREADS * ITEM + threadIdx.x;
96+
Size idx = BLOCK_THREADS * ITEM + threadIdx.x;
9797
if (IS_FULL_TILE || idx < items_in_tile)
9898
f(tile_base + idx);
9999
}
@@ -103,9 +103,9 @@ namespace __parallel_for {
103103
Size num_items,
104104
char * /*shmem*/ )
105105
{
106-
Size tile_base = blockIdx.x * ITEMS_PER_TILE;
106+
Size tile_base = static_cast<Size>(blockIdx.x) * ITEMS_PER_TILE;
107107
Size num_remaining = num_items - tile_base;
108-
int items_in_tile = static_cast<int>(
108+
Size items_in_tile = static_cast<Size>(
109109
num_remaining < ITEMS_PER_TILE ? num_remaining : ITEMS_PER_TILE);
110110

111111
if (items_in_tile == ITEMS_PER_TILE)

0 commit comments

Comments
 (0)