Make a transformer that rewrites `x.transpose(1, 2)` into `torch.einsum('abc...->acb...', x)` so that these operations can then be fused with the rest of the einsums.
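A minimal sketch of the rewrite rule, assuming only literal, non-negative `transpose` dims are handled. To keep it runnable without PyTorch, it uses Python's standard `ast` module to rewrite source text. That is a stand-in: a real pass would more likely operate on a `torch.fx` graph, which this does not show. The names `transpose_equation`, `TransposeToEinsum` and `rewrite` are hypothetical. The rewritten code also assumes `torch` is in scope. The general rule is that `transpose(i, j)` gets explicit letters for the leading `max(i, j) + 1` dims, with `...` covering the rest.

```python
import ast
import string


def transpose_equation(d0: int, d1: int) -> str:
    """Build an einsum equation equivalent to x.transpose(d0, d1).

    Only non-negative dims are supported: the leading max(d0, d1) + 1 dims
    get explicit letters and any remaining trailing dims are covered by '...'.
    """
    if d0 < 0 or d1 < 0:
        raise ValueError("negative dims are not handled in this sketch")
    letters = list(string.ascii_lowercase[: max(d0, d1) + 1])
    swapped = letters.copy()
    swapped[d0], swapped[d1] = swapped[d1], swapped[d0]
    return "".join(letters) + "...->" + "".join(swapped) + "..."


class TransposeToEinsum(ast.NodeTransformer):
    """Hypothetical pass: rewrite `<expr>.transpose(i, j)` into torch.einsum."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Rewrite nested calls first, so chained transposes are all converted.
        self.generic_visit(node)
        func = node.func
        is_literal_dims = len(node.args) == 2 and all(
            isinstance(a, ast.Constant) and type(a.value) is int and a.value >= 0
            for a in node.args
        )
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "transpose"
            and not node.keywords
            and is_literal_dims
        ):
            eq = transpose_equation(node.args[0].value, node.args[1].value)
            einsum = ast.Attribute(
                value=ast.Name(id="torch", ctx=ast.Load()), attr="einsum", ctx=ast.Load()
            )
            # torch.einsum(eq, x): the tensor being transposed becomes the operand.
            new = ast.Call(func=einsum, args=[ast.Constant(eq), func.value], keywords=[])
            return ast.copy_location(new, node)
        # Anything else (e.g. negative or non-literal dims) is left untouched.
        return node


def rewrite(source: str) -> str:
    """Apply the pass to a source string and return the rewritten source."""
    tree = TransposeToEinsum().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

For example, `rewrite("y = x.transpose(1, 2)")` should give `y = torch.einsum('abc...->acb...', x)`, while `x.transpose(-2, -1)` is left as-is because negative dims would need either the tensor rank or a leading-ellipsis equation such as `'...ab->...ba'`.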