@@ -3707,6 +3707,7 @@ struct test_im2col : public test_case {
3707
3707
struct test_conv_2d : public test_case {
3708
3708
const std::array<int64_t , 4 > ne_input;
3709
3709
const std::array<int64_t , 4 > ne_kernel;
3710
+ const ggml_type type_kernel;
3710
3711
const int stride0;
3711
3712
const int stride1;
3712
3713
const int padding0;
@@ -3724,7 +3725,7 @@ struct test_conv_2d : public test_case {
3724
3725
// IM2COL -> MUL_MM graph will be built.
3725
3726
3726
3727
std::string vars () override {
3727
- return VARS_TO_STR9 (ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
3728
+ return VARS_TO_STR10 (ne_input, ne_kernel, type_kernel , stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
3728
3729
}
3729
3730
3730
3731
uint64_t op_flops (ggml_tensor * t) override {
@@ -3755,10 +3756,11 @@ struct test_conv_2d : public test_case {
3755
3756
}
3756
3757
3757
3758
test_conv_2d (std::array<int64_t , 4 > ne_input = { 64 , 64 , 16 , 1 },
3758
- std::array<int64_t , 4 > ne_kernel = { 3 , 3 , 1 , 16 }, int stride0 = 1 , int stride1 = 1 , int padding0 = 0 ,
3759
- int padding1 = 0 , int dilation0 = 1 , int dilation1 = 1 , bool cwhn = false ) :
3759
+ std::array<int64_t , 4 > ne_kernel = { 3 , 3 , 1 , 16 }, ggml_type type_kernel = GGML_TYPE_F32 , int stride0 = 1 ,
3760
+ int stride1 = 1 , int padding0 = 0 , int padding1 = 0 , int dilation0 = 1 , int dilation1 = 1 , bool cwhn = false ) :
3760
3761
ne_input (ne_input),
3761
3762
ne_kernel (ne_kernel),
3763
+ type_kernel (type_kernel),
3762
3764
stride0 (stride0),
3763
3765
stride1 (stride1),
3764
3766
padding0 (padding0),
@@ -3771,7 +3773,7 @@ struct test_conv_2d : public test_case {
3771
3773
ggml_tensor * input = ggml_new_tensor (ctx, GGML_TYPE_F32, 4 , ne_input.data ());
3772
3774
ggml_set_name (input, " input" );
3773
3775
3774
- ggml_tensor * kernel = ggml_new_tensor (ctx, GGML_TYPE_F32 , 4 , ne_kernel.data ());
3776
+ ggml_tensor * kernel = ggml_new_tensor (ctx, type_kernel , 4 , ne_kernel.data ());
3775
3777
ggml_set_name (kernel, " kernel" );
3776
3778
3777
3779
if (cwhn) {
@@ -5141,7 +5143,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
5141
5143
for (auto act_case : cases) {
5142
5144
test_cases.emplace_back (new test_conv_2d (
5143
5145
{ act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5144
- { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5146
+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5147
+ GGML_TYPE_F32, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5148
+ test_cases.emplace_back (new test_conv_2d (
5149
+ { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5150
+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5151
+ GGML_TYPE_F16, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5145
5152
}
5146
5153
#endif
5147
5154
@@ -5168,7 +5175,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
5168
5175
if (calc_conv_output_size (W, KW, s0, p0, d0) > 0 &&
5169
5176
calc_conv_output_size (H, KH, s1, p1, d1) > 0 ) {
5170
5177
test_cases.emplace_back (new test_conv_2d (
5171
- { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false ));
5178
+ { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, GGML_TYPE_F32, s0, s1, p0, p1, d0, d1, false ));
5172
5179
}
5173
5180
}
5174
5181
}
@@ -5817,7 +5824,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
5817
5824
// Direct CONV_2D
5818
5825
test_cases.emplace_back (new test_conv_2d (
5819
5826
{ act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5820
- { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5827
+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5828
+ GGML_TYPE_F32, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5829
+ test_cases.emplace_back (new test_conv_2d (
5830
+ { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5831
+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5832
+ GGML_TYPE_F16, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5821
5833
}
5822
5834
5823
5835
test_cases.emplace_back (new test_bin_bcast (ggml_add, GGML_TYPE_F32, {4096 , 1 , 1 , 1 }, {1 , 1 , 1 , 1 }));
0 commit comments