Skip to content

Commit e2cb796

Browse files
authored
Merge pull request #464 from Nuzhny007/master
Fix build for TensorRT 8.x but it works only with TensorRT 10.x
2 parents ce3a58c + 11f3002 commit e2cb796

File tree

11 files changed

+181
-32
lines changed

11 files changed

+181
-32
lines changed

src/Detector/tensorrt_yolo/common/common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,17 +541,23 @@ inline uint32_t getElementSize(nvinfer1::DataType t) noexcept
541541
{
542542
switch (t)
543543
{
544+
#if (NV_TENSORRT_MAJOR > 8)
544545
case nvinfer1::DataType::kINT64: return 8;
546+
#endif
545547
case nvinfer1::DataType::kINT32:
546548
case nvinfer1::DataType::kFLOAT: return 4;
549+
#if (NV_TENSORRT_MAJOR > 8)
547550
case nvinfer1::DataType::kBF16:
551+
#endif
548552
case nvinfer1::DataType::kHALF: return 2;
549553
case nvinfer1::DataType::kBOOL:
550554
case nvinfer1::DataType::kUINT8:
551555
case nvinfer1::DataType::kINT8:
552556
case nvinfer1::DataType::kFP8: return 1;
557+
#if (NV_TENSORRT_MAJOR > 8)
553558
case nvinfer1::DataType::kINT4:
554559
ASSERT(false && "Element size is not implemented for sub-byte data-types");
560+
#endif
555561
}
556562
return 0;
557563
}

src/Detector/tensorrt_yolo/common/safeCommon.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,17 +153,23 @@ inline uint32_t elementSize(nvinfer1::DataType t)
153153
{
154154
switch (t)
155155
{
156+
#if (NV_TENSORRT_MAJOR > 8)
156157
case nvinfer1::DataType::kINT64: return 8;
158+
#endif
157159
case nvinfer1::DataType::kINT32:
158160
case nvinfer1::DataType::kFLOAT: return 4;
159161
case nvinfer1::DataType::kHALF:
162+
#if (NV_TENSORRT_MAJOR > 8)
160163
case nvinfer1::DataType::kBF16: return 2;
164+
#endif
161165
case nvinfer1::DataType::kINT8:
162166
case nvinfer1::DataType::kUINT8:
163167
case nvinfer1::DataType::kBOOL:
164168
case nvinfer1::DataType::kFP8: return 1;
169+
#if (NV_TENSORRT_MAJOR > 8)
165170
case nvinfer1::DataType::kINT4:
166171
SAFE_ASSERT(false && "Element size is not implemented for sub-byte data-types");
172+
#endif
167173
}
168174
return 0;
169175
}

src/Detector/tensorrt_yolo/common/sampleDevice.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,11 +503,19 @@ class OutputAllocator : public nvinfer1::IOutputAllocator
503503
}
504504

505505
//! IMirroredBuffer does not implement Async allocation, hence this is just a wrap around
506+
#if (NV_TENSORRT_MAJOR > 8)
506507
void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment,
507508
cudaStream_t /*stream*/) noexcept override
508509
{
509510
return reallocateOutput(tensorName, currentMemory, size, alignment);
510511
}
512+
#else
513+
void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment,
514+
cudaStream_t /*stream*/) noexcept
515+
{
516+
return reallocateOutput(tensorName, currentMemory, size, alignment);
517+
}
518+
#endif
511519

512520
void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override
513521
{

src/Detector/tensorrt_yolo/common/sampleEngines.cpp

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ nvinfer1::ICudaEngine* LazilyDeserializedEngine::get()
8080

8181
if (mEngine == nullptr)
8282
{
83+
#if (NV_TENSORRT_MAJOR > 8)
8384
SMP_RETVAL_IF_FALSE(getFileReader().isOpen() || !getBlob().empty(), "Engine is empty. Nothing to deserialize!",
8485
nullptr, sample::gLogError);
85-
86+
#endif
8687
using time_point = std::chrono::time_point<std::chrono::high_resolution_clock>;
8788
using duration = std::chrono::duration<float>;
8889
time_point const deserializeStartTime{std::chrono::high_resolution_clock::now()};
@@ -126,6 +127,7 @@ nvinfer1::ICudaEngine* LazilyDeserializedEngine::get()
126127
}
127128
#endif
128129

130+
#if (NV_TENSORRT_MAJOR > 8)
129131
if (getFileReader().isOpen())
130132
{
131133
mEngine.reset(mRuntime->deserializeCudaEngine(getFileReader()));
@@ -135,6 +137,11 @@ nvinfer1::ICudaEngine* LazilyDeserializedEngine::get()
135137
auto const& engineBlob = getBlob();
136138
mEngine.reset(mRuntime->deserializeCudaEngine(engineBlob.data, engineBlob.size));
137139
}
140+
#else
141+
auto const& engineBlob = getBlob();
142+
mEngine.reset(mRuntime->deserializeCudaEngine(engineBlob.data, engineBlob.size));
143+
std::cerr << "getFileReader is not implemented! Use TensorRT 10.x and higher" << std::endl;
144+
#endif
138145
SMP_RETVAL_IF_FALSE(mEngine != nullptr, "Engine deserialization failed", nullptr, sample::gLogError);
139146

140147
time_point const deserializeEndTime{std::chrono::high_resolution_clock::now()};
@@ -405,8 +412,12 @@ bool setTensorDynamicRange(INetworkDefinition const& network, float inRange = 2.
405412

406413
bool isNonActivationType(nvinfer1::DataType const type)
407414
{
408-
return type == nvinfer1::DataType::kINT32 || type == nvinfer1::DataType::kINT64 || type == nvinfer1::DataType::kBOOL
409-
|| type == nvinfer1::DataType::kUINT8;
415+
return type == nvinfer1::DataType::kINT32
416+
#if (NV_TENSORRT_MAJOR > 8)
417+
|| type == nvinfer1::DataType::kINT64
418+
#endif
419+
|| type == nvinfer1::DataType::kBOOL
420+
|| type == nvinfer1::DataType::kUINT8;
410421
}
411422

412423
void setLayerPrecisions(INetworkDefinition& network, LayerPrecisions const& layerPrecisions)
@@ -567,6 +578,7 @@ void setLayerDeviceTypes(
567578

568579
void markDebugTensors(INetworkDefinition& network, StringSet const& debugTensors)
569580
{
581+
#if (NV_TENSORRT_MAJOR > 8)
570582
for (int64_t inputIndex = 0; inputIndex < network.getNbInputs(); ++inputIndex)
571583
{
572584
auto* t = network.getInput(inputIndex);
@@ -589,6 +601,9 @@ void markDebugTensors(INetworkDefinition& network, StringSet const& debugTensors
589601
}
590602
}
591603
}
604+
#else
605+
std::cerr << "Can not markDebugTensors. Use TensorRT 10.x or higher" << std::endl;
606+
#endif
592607
}
593608

594609
void setMemoryPoolLimits(IBuilderConfig& config, BuildOptions const& build)
@@ -626,10 +641,12 @@ void setMemoryPoolLimits(IBuilderConfig& config, BuildOptions const& build)
626641
{
627642
config.setMemoryPoolLimit(MemoryPoolType::kDLA_GLOBAL_DRAM, roundToBytes(build.dlaGlobalDRAM));
628643
}
644+
#if (NV_TENSORRT_MAJOR > 8)
629645
if (build.tacticSharedMem >= 0)
630646
{
631647
config.setMemoryPoolLimit(MemoryPoolType::kTACTIC_SHARED_MEMORY, roundToBytes(build.tacticSharedMem, false));
632648
}
649+
#endif
633650
}
634651

635652
void setPreviewFeatures(IBuilderConfig& config, BuildOptions const& build)
@@ -641,7 +658,9 @@ void setPreviewFeatures(IBuilderConfig& config, BuildOptions const& build)
641658
config.setPreviewFeature(feat, build.previewFeatures.at(featVal));
642659
}
643660
};
661+
#if (NV_TENSORRT_MAJOR > 8)
644662
setFlag(PreviewFeature::kALIASED_PLUGIN_IO_10_03);
663+
#endif
645664
}
646665

647666
} // namespace
@@ -845,7 +864,7 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
845864

846865
if (build.maxTactics != defaultMaxTactics)
847866
{
848-
#if (NV_TENSORRT_MAJOR < 9)
867+
#if (NV_TENSORRT_MAJOR < 8)
849868
config.setMaxNbTactics(build.maxTactics);
850869
#else
851870
config.setTacticSources(build.maxTactics);
@@ -856,7 +875,7 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
856875
{
857876
config.setFlag(BuilderFlag::kDISABLE_TIMING_CACHE);
858877
}
859-
878+
#if (NV_TENSORRT_MAJOR > 8)
860879
if (build.disableCompilationCache)
861880
{
862881
config.setFlag(BuilderFlag::kDISABLE_COMPILATION_CACHE);
@@ -866,7 +885,7 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
866885
{
867886
config.setFlag(BuilderFlag::kERROR_ON_TIMING_CACHE_MISS);
868887
}
869-
888+
#endif
870889
if (!build.tf32)
871890
{
872891
config.clearFlag(BuilderFlag::kTF32);
@@ -876,13 +895,13 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
876895
{
877896
config.setFlag(BuilderFlag::kREFIT);
878897
}
879-
898+
#if (NV_TENSORRT_MAJOR > 8)
880899
if (build.stripWeights)
881900
{
882901
// The kREFIT_IDENTICAL is enabled by default when kSTRIP_PLAN is on.
883902
config.setFlag(BuilderFlag::kSTRIP_PLAN);
884903
}
885-
904+
#endif
886905
if (build.versionCompatible)
887906
{
888907
config.setFlag(BuilderFlag::kVERSION_COMPATIBLE);
@@ -924,23 +943,25 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
924943
{
925944
config.setFlag(BuilderFlag::kINT8);
926945
}
946+
#if (NV_TENSORRT_MAJOR > 8)
927947
if (build.bf16)
928948
{
929949
config.setFlag(BuilderFlag::kBF16);
930950
}
951+
#endif
931952

932953
SMP_RETVAL_IF_FALSE(!(build.int8 && build.fp8), "FP8 and INT8 precisions have been specified", false, err);
933954

934955
if (build.fp8)
935956
{
936957
config.setFlag(BuilderFlag::kFP8);
937958
}
938-
959+
#if (NV_TENSORRT_MAJOR > 8)
939960
if (build.int4)
940961
{
941962
config.setFlag(BuilderFlag::kINT4);
942963
}
943-
964+
#endif
944965
if (build.int8 && !build.fp16)
945966
{
946967
sample::gLogInfo
@@ -1136,7 +1157,9 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
11361157
}
11371158

11381159
config.setHardwareCompatibilityLevel(build.hardwareCompatibilityLevel);
1160+
#if (NV_TENSORRT_MAJOR > 8)
11391161
config.setRuntimePlatform(build.runtimePlatform);
1162+
#endif
11401163

11411164
if (build.maxAuxStreams != defaultMaxAuxStreams)
11421165
{
@@ -1145,7 +1168,11 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys,
11451168

11461169
if (build.allowWeightStreaming)
11471170
{
1171+
#if (NV_TENSORRT_MAJOR > 8)
11481172
config.setFlag(BuilderFlag::kWEIGHT_STREAMING);
1173+
#else
1174+
std::cerr << "BuilderFlag::kWEIGHT_STREAMING not allowed in TensorRT with version less than 10.x" << std::endl;
1175+
#endif
11491176
}
11501177

11511178
return true;
@@ -1208,9 +1235,13 @@ bool modelToBuildEnv(
12081235
env.builder.reset(createBuilder());
12091236
SMP_RETVAL_IF_FALSE(env.builder != nullptr, "Builder creation failed", false, err);
12101237
env.builder->setErrorRecorder(&gRecorder);
1238+
#if (NV_TENSORRT_MAJOR > 8)
12111239
auto networkFlags = (build.stronglyTyped)
12121240
? 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED)
12131241
: 0U;
1242+
#else
1243+
auto networkFlags = 0U;
1244+
#endif
12141245
#if !TRT_WINML
12151246
for (auto const& pluginPath : sys.dynamicPlugins)
12161247
{
@@ -1304,8 +1335,12 @@ std::pair<std::vector<std::string>, std::vector<WeightsRole>> getMissingLayerWei
13041335

13051336
bool loadStreamingEngineToBuildEnv(std::string const& filepath, BuildEnvironment& env, std::ostream& err)
13061337
{
1338+
#if (NV_TENSORRT_MAJOR > 8)
13071339
auto& reader = env.engine.getFileReader();
13081340
SMP_RETVAL_IF_FALSE(reader.open(filepath), "", false, err << "Error opening engine file: " << filepath);
1341+
#else
1342+
SMP_RETVAL_IF_FALSE(false, "", false, err << "Error opening engine file: " << filepath);
1343+
#endif
13091344
return true;
13101345
}
13111346

@@ -1337,12 +1372,14 @@ bool printPlanVersion(BuildEnvironment& env, std::ostream& err)
13371372
std::vector<uint8_t> data(kPLAN_SIZE);
13381373
auto blob = data.data();
13391374

1375+
#if (NV_TENSORRT_MAJOR > 8)
13401376
auto& reader = env.engine.getFileReader();
13411377
if (reader.isOpen())
13421378
{
13431379
SMP_RETVAL_IF_FALSE(reader.read(data.data(), kPLAN_SIZE) == kPLAN_SIZE, "Failed to read plan file", false, err);
13441380
}
13451381
else
1382+
#endif
13461383
{
13471384
SMP_RETVAL_IF_FALSE(env.engine.getBlob().data != nullptr, "Plan file is empty", false, err);
13481385
SMP_RETVAL_IF_FALSE(env.engine.getBlob().size >= 28, "Plan file is incorrect", false, err);
@@ -1473,14 +1510,21 @@ std::vector<std::pair<WeightsRole, Weights>> getAllRefitWeightsForLayer(const IL
14731510
{
14741511
case DataType::kFLOAT:
14751512
case DataType::kHALF:
1513+
#if (NV_TENSORRT_MAJOR > 8)
14761514
case DataType::kBF16:
1515+
#endif
14771516
case DataType::kINT8:
14781517
case DataType::kINT32:
1479-
case DataType::kINT64: return {std::make_pair(WeightsRole::kCONSTANT, weights)};
1518+
#if (NV_TENSORRT_MAJOR > 8)
1519+
case DataType::kINT64:
1520+
#endif
1521+
return {std::make_pair(WeightsRole::kCONSTANT, weights)};
14801522
case DataType::kBOOL:
14811523
case DataType::kUINT8:
14821524
case DataType::kFP8:
1525+
#if (NV_TENSORRT_MAJOR > 8)
14831526
case DataType::kINT4:
1527+
#endif
14841528
// Refit not supported for these types.
14851529
break;
14861530
}
@@ -1530,7 +1574,9 @@ std::vector<std::pair<WeightsRole, Weights>> getAllRefitWeightsForLayer(const IL
15301574
case LayerType::kPARAMETRIC_RELU:
15311575
case LayerType::kPLUGIN:
15321576
case LayerType::kPLUGIN_V2:
1577+
#if (NV_TENSORRT_MAJOR > 8)
15331578
case LayerType::kPLUGIN_V3:
1579+
#endif
15341580
case LayerType::kPOOLING:
15351581
case LayerType::kQUANTIZE:
15361582
case LayerType::kRAGGED_SOFTMAX:
@@ -1610,11 +1656,10 @@ bool timeRefit(INetworkDefinition const& network, nvinfer1::ICudaEngine& engine,
16101656
}
16111657
return layerNames.empty();
16121658
};
1613-
1614-
// Skip weights validation since we are confident that the new weights are similar to the weights used to build
1615-
// engine.
1659+
#if (NV_TENSORRT_MAJOR > 8)
1660+
// Skip weights validation since we are confident that the new weights are similar to the weights used to build engine.
16161661
refitter->setWeightsValidation(false);
1617-
1662+
#endif
16181663
// Warm up and report missing weights
16191664
// We only need to set weights for the first time and that can be reused in later refitting process.
16201665
bool const success = setWeights() && reportMissingWeights() && refitter->refitCudaEngine();
@@ -1623,9 +1668,10 @@ bool timeRefit(INetworkDefinition const& network, nvinfer1::ICudaEngine& engine,
16231668
return false;
16241669
}
16251670

1626-
TrtCudaStream stream;
1627-
constexpr int32_t kLOOP = 10;
16281671
time_point const refitStartTime{std::chrono::steady_clock::now()};
1672+
constexpr int32_t kLOOP = 10;
1673+
#if (NV_TENSORRT_MAJOR > 8)
1674+
TrtCudaStream stream;
16291675
{
16301676
for (int32_t l = 0; l < kLOOP; l++)
16311677
{
@@ -1636,6 +1682,7 @@ bool timeRefit(INetworkDefinition const& network, nvinfer1::ICudaEngine& engine,
16361682
}
16371683
}
16381684
stream.synchronize();
1685+
#endif
16391686
time_point const refitEndTime{std::chrono::steady_clock::now()};
16401687

16411688
sample::gLogInfo << "Engine refitted"

0 commit comments

Comments
 (0)