@@ -34,7 +34,13 @@ void CuDNNTypeRule::registerMatcher(MatchFinder &MF) {
34
34
" cudnnDataType_t" , " cudnnActivationDescriptor_t" ,
35
35
" cudnnActivationMode_t" , " cudnnLRNDescriptor_t" , " cudnnLRNMode_t" ,
36
36
" cudnnPoolingDescriptor_t" , " cudnnPoolingMode_t" ,
37
- " cudnnSoftmaxAlgorithm_t" , " cudnnSoftmaxMode_t" ))))))
37
+ " cudnnSoftmaxAlgorithm_t" , " cudnnSoftmaxMode_t" , " cudnnStatus_t" ,
38
+ " cudnnReduceTensorDescriptor_t" , " cudnnReduceTensorOp_t" ,
39
+ " cudnnOpTensorDescriptor_t" , " cudnnOpTensorOp_t" ,
40
+ " cudnnBatchNormOps_t" , " cudnnBatchNormMode_t" , " cudnnNormMode_t" ,
41
+ " cudnnNormOps_t" , " cudnnConvolutionDescriptor_t" ,
42
+ " cudnnConvolutionFwdAlgo_t" , " cudnnConvolutionBwdDataAlgo_t" ,
43
+ " cudnnConvolutionBwdFilterAlgo_t" , " cudnnFilterDescriptor_t" ))))))
38
44
.bind (" CuDNNType" ),
39
45
this );
40
46
MF.addMatcher (declRefExpr (to (enumConstantDecl (matchesName (" CUDNN_.*" ))))
@@ -76,17 +82,23 @@ void CuDNNTypeRule::runRule(const MatchFinder::MatchResult &Result) {
76
82
emplaceTransformation (new ReplaceText (BeginLoc, Len, std::move (Str)));
77
83
return ;
78
84
}
79
- } else if (auto *E = getNodeAsType<DeclRefExpr>(Result, " CuDNNEnumConstant" )) {
80
- if (!E)
81
- return ;
85
+ } else if (auto *E =
86
+ getNodeAsType<DeclRefExpr>(Result, " CuDNNEnumConstant" )) {
82
87
std::string EnumName = E->getNameInfo ().getName ().getAsString ();
83
- if (EnumName == " CUDNN_DATA_DOUBLE" ) {
84
- report (E->getBeginLoc (), Diagnostics::API_NOT_MIGRATED, false ,
85
- " data type double" );
86
- return ;
88
+
89
+ if (EnumName.find (" CUDNN_STATUS_" ) != std::string::npos) {
90
+ if (auto EC = dyn_cast<EnumConstantDecl>(E->getDecl ())) {
91
+ std::string Repl = toString (EC->getInitVal (), 10 );
92
+ emplaceTransformation (new ReplaceStmt (E, Repl));
93
+ return ;
94
+ }
95
+ } else if (EnumName == " CUDNN_BATCHNORM_SPATIAL_PERSISTENT" ) {
96
+ report (E->getBeginLoc (), Diagnostics::API_NOT_MIGRATED, false , EnumName);
87
97
}
98
+
88
99
auto Search = CuDNNEnumNamesMap.find (EnumName);
89
100
if (Search == CuDNNEnumNamesMap.end ()) {
101
+ report (E->getBeginLoc (), Diagnostics::API_NOT_MIGRATED, false , EnumName);
90
102
return ;
91
103
}
92
104
@@ -117,7 +129,41 @@ void CuDNNAPIRule::registerMatcher(ast_matchers::MatchFinder &MF) {
117
129
" cudnnGetPooling2dDescriptor" , " cudnnGetPooling2dForwardOutputDim" ,
118
130
" cudnnGetPoolingNdDescriptor" , " cudnnGetPoolingNdForwardOutputDim" ,
119
131
" cudnnPoolingForward" , " cudnnPoolingBackward" , " cudnnSoftmaxForward" ,
120
- " cudnnSoftmaxBackward" , " cudnnSetTensor" );
132
+ " cudnnSoftmaxBackward" , " cudnnSetTensor" ,
133
+ " cudnnCreateReduceTensorDescriptor" ,
134
+ " cudnnDestroyReduceTensorDescriptor" , " cudnnSetReduceTensorDescriptor" ,
135
+ " cudnnSetReduceTensorDescriptor" , " cudnnGetReduceTensorDescriptor" ,
136
+ " cudnnGetReductionWorkspaceSize" , " cudnnReduceTensor" ,
137
+ " cudnnCreateOpTensorDescriptor" , " cudnnDestroyOpTensorDescriptor" ,
138
+ " cudnnGetOpTensorDescriptor" , " cudnnSetOpTensorDescriptor" ,
139
+ " cudnnOpTensor" , " cudnnBatchNormalizationForwardInference" ,
140
+ " cudnnBatchNormalizationForwardTraining" ,
141
+ " cudnnBatchNormalizationForwardTrainingEx" ,
142
+ " cudnnBatchNormalizationBackward" , " cudnnBatchNormalizationBackwardEx" ,
143
+ " cudnnDeriveBNTensorDescriptor" ,
144
+ " cudnnGetBatchNormalizationBackwardExWorkspaceSize" ,
145
+ " cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize" ,
146
+ " cudnnGetBatchNormalizationTrainingExReserveSpaceSize" ,
147
+ " cudnnNormalizationForwardInference" ,
148
+ " cudnnNormalizationForwardTraining" , " cudnnNormalizationBackward" ,
149
+ " cudnnDeriveNormTensorDescriptor" ,
150
+ " cudnnGetNormalizationForwardTrainingWorkspaceSize" ,
151
+ " cudnnGetNormalizationTrainingReserveSpaceSize" ,
152
+ " cudnnCreateFilterDescriptor" , " cudnnDestroyFilterDescriptor" ,
153
+ " cudnnGetFilter4dDescriptor" , " cudnnGetFilterNdDescriptor" ,
154
+ " cudnnGetFilterSizeInBytes" , " cudnnSetFilter4dDescriptor" ,
155
+ " cudnnSetFilterNdDescriptor" , " cudnnCreateConvolutionDescriptor" ,
156
+ " cudnnDestroyConvolutionDescriptor" , " cudnnGetConvolution2dDescriptor" ,
157
+ " cudnnGetConvolution2dForwardOutputDim" ,
158
+ " cudnnGetConvolutionGroupCount" , " cudnnGetConvolutionNdDescriptor" ,
159
+ " cudnnGetConvolutionNdForwardOutputDim" ,
160
+ " cudnnSetConvolution2dDescriptor" , " cudnnSetConvolutionGroupCount" ,
161
+ " cudnnSetConvolutionNdDescriptor" , " cudnnConvolutionForward" ,
162
+ " cudnnConvolutionBackwardData" , " cudnnConvolutionBiasActivationForward" ,
163
+ " cudnnConvolutionBackwardBias" , " cudnnConvolutionBackwardFilter" ,
164
+ " cudnnGetConvolutionForwardWorkspaceSize" , " cudnnGetConvolutionBackwardDataWorkspaceSize" ,
165
+ " cudnnGetConvolutionBackwardFilterWorkspaceSize" ,
166
+ " cudnnGetNormalizationBackwardWorkspaceSize" );
121
167
};
122
168
123
169
MF.addMatcher (callExpr (callee (functionDecl (CuDNNAPI ()))).bind (" call" ), this );
0 commit comments