Constrained Optimization to Train Neural Networks on Critical and Under-Represented Classes

Part of Advances in Neural Information Processing Systems 34 (NeurIPS 2021)

Authors

Sara Sangalli, Ertunc Erdil, Andreas Hötker, Olivio Donati, Ender Konukoglu

Abstract

Deep neural networks (DNNs) are notorious for making more mistakes on classes that have substantially fewer samples than the others during training. Such class imbalance is ubiquitous in clinical applications and crucial to handle, because the classes with fewer samples most often correspond to critical cases (e.g., cancer) where misclassifications can have severe consequences. To avoid missing such cases, binary classifiers need to be operated at high True Positive Rates (TPRs) by adjusting the decision threshold, but this comes at the cost of very high False Positive Rates (FPRs) for problems with class imbalance. Existing methods for learning under class imbalance most often do not take this into account. We argue that prediction accuracy should be improved by emphasizing the reduction of FPRs at high TPRs for problems where misclassification of the positive, i.e. critical, class samples is associated with higher cost. To this end, we pose the training of a DNN for binary classification as a constrained optimization problem and introduce a novel constraint that can be used with existing loss functions to enforce maximal area under the ROC curve (AUC) through prioritizing FPR reduction at high TPR. We solve the resulting constrained optimization problem using an Augmented Lagrangian method (ALM). Going beyond binary, we also propose two possible extensions of the proposed constraint for multi-class classification problems. We present experimental results for image-based binary and multi-class classification applications using an in-house medical imaging dataset, CIFAR10, and CIFAR100. Our results demonstrate that the proposed method improves on the baselines in the majority of cases by attaining higher accuracy on critical classes while reducing the misclassification rate for the non-critical class samples.
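
To make the TPR/FPR trade-off described in the abstract concrete, the following is a minimal sketch of how one might measure the FPR of a binary classifier at a required TPR. It is not the paper's constraint or its ALM training procedure; the function name `fpr_at_tpr`, the toy scores, and the convention that a sample is predicted positive when its score is at least the threshold are all assumptions made for illustration.

```python
import numpy as np


def fpr_at_tpr(scores, labels, target_tpr):
    """Return (threshold, fpr) at the operating point that reaches target_tpr.

    Hypothetical helper: `scores` are positive-class scores, `labels` are
    1 for the critical (positive) class and 0 otherwise. A sample is
    predicted positive when score >= threshold (assumed convention).
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)

    pos = np.sort(scores[labels])[::-1]  # positive-class scores, descending
    neg = scores[~labels]                # negative-class scores

    # Number of positives that must be detected to reach the target TPR.
    k = max(int(np.ceil(target_tpr * pos.size)), 1)

    # The k-th largest positive score is the threshold that detects
    # at least k positives.
    threshold = float(pos[k - 1])

    # Fraction of negatives that are flagged at that same threshold.
    fpr = float(np.mean(neg >= threshold))
    return threshold, fpr


# Toy, made-up scores: four positives followed by four negatives.
toy_scores = [0.9, 0.8, 0.4, 0.2, 0.7, 0.3, 0.1, 0.05]
toy_labels = [1, 1, 1, 1, 0, 0, 0, 0]

for tpr in (0.5, 0.75, 1.0):
    thr, fpr = fpr_at_tpr(toy_scores, toy_labels, tpr)
    print(f"target TPR={tpr:.2f}  threshold={thr:.2f}  FPR={fpr:.2f}")
```

On this toy data the FPR grows as the required TPR is raised, which is the operating regime the abstract argues a training objective should prioritize.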