@@ -120,8 +120,8 @@ def test_rnnt_update_batch(self):
120
120
121
121
x_org = torch .randn ([self .max_len , batch_size , 2 ], dtype = dtype )
122
122
x = copy .deepcopy (x_org )
123
- hidden = [torch .zeros ([2 , batch_size , 320 ], dtype = dtype ), torch .zeros ([2 , batch_size , 320 ], dtype = torch . float )]
124
- hidden_prime = [torch .randn ([2 , batch_size , 320 ], dtype = dtype ), torch .randn ([2 , batch_size , 320 ], dtype = torch . float )]
123
+ hidden = [torch .zeros ([2 , batch_size , 320 ], dtype = dtype ), torch .zeros ([2 , batch_size , 320 ], dtype = dtype )]
124
+ hidden_prime = [torch .randn ([2 , batch_size , 320 ], dtype = dtype ), torch .randn ([2 , batch_size , 320 ], dtype = dtype )]
125
125
126
126
blank_vec_org , blankness_org , label_col_org , time_idxs_org , symbols_added_org , not_blank_org , label_tensor_org , hidden_org , f_org = self ._test_org (hidden , hidden_prime , x_org .transpose (0 , 1 ), batch_size , max_symbol , blank_id , loop_cnt )
127
127
blank_vec_out , blankness_out , label_col , time_idxs , symbols_added , not_blank , label_tensor , hidden , f = self ._test_rnnt_update_batch_kernel (hidden , hidden_prime , x .transpose (0 ,1 ), batch_size , max_symbol , blank_id , loop_cnt )
0 commit comments