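The number of inputs to forward in class _AllToAll does not match the number of gradients returned by its backward: forward takes 6 inputs (not counting ctx), but backward returns only 5 values. Autograd requires one return value per forward input, so the backward pass fails with the following error: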
RuntimeError: function _AllToAllBackward returned an incorrect number of gradients (expected 6, got 5)
Returning one additional None from backward, so that it returns one value per forward input, fixes the problem.
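
For reviewers who want to see the rule in isolation, here is a minimal sketch using toy classes. It is not the actual _AllToAll code: the names Broken and Fixed and the inputs x, scale and flag are hypothetical, chosen only to show that backward must return one value per forward input, with None for inputs that need no gradient.

```python
import torch


class Broken(torch.autograd.Function):
    """Hypothetical toy Function: backward returns too few values."""

    @staticmethod
    def forward(ctx, x, scale, flag):
        # Three inputs after ctx: one tensor and two non-tensor arguments.
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Only two values for three forward inputs, so autograd rejects this.
        return grad_output * ctx.scale, None


class Fixed(Broken):
    """Same forward, but backward returns one value per forward input."""

    @staticmethod
    def backward(ctx, grad_output):
        # The extra None covers the third input (flag), which has no gradient.
        return grad_output * ctx.scale, None, None


x = torch.ones(3, requires_grad=True)
Fixed.apply(x, 2.0, True).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])

y = torch.ones(3, requires_grad=True)
try:
    Broken.apply(y, 2.0, True).sum().backward()
except RuntimeError as err:
    # Same kind of "incorrect number of gradients" error as reported above.
    print(err)
```

The order of the returned values follows the order of the forward inputs, so the added None should sit at the position of the input it corresponds to.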