Need better handling of multi-label tasks, including at inference time: https://github.com/zj-zhang/AMBER/blob/master/amber/backend/pytorch/model.py#L43
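A minimal, framework-agnostic sketch of what multi-label handling usually involves, assuming the model emits one raw logit per label. Each label is scored with its own sigmoid and binary cross-entropy term, and inference thresholds each probability independently instead of taking a single argmax. The function names (`sigmoid`, `multilabel_bce`, `multilabel_infer`) and the 0.5 default threshold are hypothetical illustrations, not code taken from AMBER's `model.py`.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def multilabel_bce(logits: Sequence[float], targets: Sequence[int]) -> float:
    """Mean binary cross-entropy over labels, computed from raw logits.

    Each label is an independent binary problem, unlike the softmax
    cross-entropy used for single-label (multi-class) tasks.
    """
    total = 0.0
    for z, y in zip(logits, targets):
        # Stable BCE-with-logits: max(z, 0) - z*y + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def multilabel_infer(logits: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Return a 0/1 vector marking every label whose probability meets the threshold.

    Hypothetical helper: more than one label (or none) may be active at once.
    """
    return [1 if sigmoid(z) >= threshold else 0 for z in logits]


print(multilabel_infer([2.0, -1.0, 0.3]))  # [1, 0, 1]
```

A per-label threshold list, tuned on validation data, is a common refinement when labels are imbalanced; the single scalar threshold here is only the simplest choice.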