diff --git a/utils.py b/utils.py index 15e09f5..7e02b42 100755 --- a/utils.py +++ b/utils.py @@ -353,9 +353,11 @@ def CoxLoss(survtime, censor, hazard_pred, device): R_mat[i,j] = survtime[j] >= survtime[i] R_mat = torch.FloatTensor(R_mat).to(device) + censor = censor.reshape(-1) theta = hazard_pred.reshape(-1) exp_theta = torch.exp(theta) - loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * censor) + loss_cox = (theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * censor + loss_cox = -(loss_cox.sum() / censor.sum()) # mean over samples who experienced the event only (where censor==1) return loss_cox @@ -889,4 +891,4 @@ def makeAUROCPlot(ckpt_name='./checkpoints/grad_15/', model_list=['path', 'omic' zoom = '_zoom' if use_zoom else '' for i, fig in enumerate(figures): - fig.savefig(ckpt_name+'/AUC_%s%s.png' % (classes[i], zoom)) \ No newline at end of file + fig.savefig(ckpt_name+'/AUC_%s%s.png' % (classes[i], zoom))