Skip to content

Commit 6595260

Browse files
authored
LSTM bf16: support bias and c_states to be bf16 (#154)
* use bf16 bias, cx and cy of bf16 LSTM
* clang format
1 parent 31bbf68 commit 6595260

File tree

4 files changed

+160
-156
lines changed

4 files changed

+160
-156
lines changed

tests/cpu/test_autocast.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -272,14 +272,12 @@ def _test_lstm(self, training, bf16, prec = 1e-5):
272272
self.assertEqual(h_ipex.dtype, torch.float)
273273
self.assertEqual(c_ipex.dtype, torch.float)
274274

275-
# with mkldnn LSTM, y, hy[0] is bf16 and hy[1] is fp32
276275
self.assertEqual(y_ipex.dtype, torch.bfloat16)
277276
self.assertEqual(hy_ipex[0].dtype, torch.bfloat16)
278-
self.assertEqual(hy_ipex[1].dtype, torch.float)
277+
self.assertEqual(hy_ipex[1].dtype, torch.bfloat16)
279278
self.assertEqual(y, y_ipex, prec=prec)
280279
self.assertEqual(hy[0], hy_ipex[0], prec=prec)
281-
282-
self.assertEqual(hy[1], self._cast_dtype(hy_ipex[1], bf16), prec=prec)
280+
self.assertEqual(hy[1], hy_ipex[1], prec=prec)
283281

284282
def _test_lstm_pack_padded_sequence(self):
285283
embedding_dim = 1024

tests/cpu/test_rnnt_custom_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,8 @@ def test_rnnt_update_batch(self):
120120

121121
x_org = torch.randn([self.max_len, batch_size, 2], dtype=dtype)
122122
x = copy.deepcopy(x_org)
123-
hidden = [torch.zeros([2, batch_size, 320], dtype=dtype), torch.zeros([2, batch_size, 320], dtype=torch.float)]
124-
hidden_prime = [torch.randn([2, batch_size, 320], dtype=dtype), torch.randn([2, batch_size, 320], dtype=torch.float)]
123+
hidden = [torch.zeros([2, batch_size, 320], dtype=dtype), torch.zeros([2, batch_size, 320], dtype=dtype)]
124+
hidden_prime = [torch.randn([2, batch_size, 320], dtype=dtype), torch.randn([2, batch_size, 320], dtype=dtype)]
125125

126126
blank_vec_org, blankness_org, label_col_org, time_idxs_org, symbols_added_org, not_blank_org, label_tensor_org, hidden_org, f_org = self._test_org(hidden, hidden_prime, x_org.transpose(0, 1), batch_size, max_symbol, blank_id, loop_cnt)
127127
blank_vec_out, blankness_out, label_col, time_idxs, symbols_added, not_blank, label_tensor, hidden, f = self._test_rnnt_update_batch_kernel(hidden, hidden_prime, x.transpose(0,1), batch_size, max_symbol, blank_id, loop_cnt)

0 commit comments

Comments
 (0)