Learning Compact Representations of Neural Networks using DiscriminAtive Masking (DAM)

Part of Advances in Neural Information Processing Systems 34 (NeurIPS 2021)


Authors

Jie Bu, Arka Daw, M. Maruf, Anuj Karpatne

Abstract

A central goal in deep learning is to learn compact representations of features at every layer of a neural network, which is useful for both unsupervised representation learning and structured network pruning. While there is a growing body of work in structured pruning, current state-of-the-art methods suffer from two key limitations: (i) instability during training, and (ii) the need for an additional fine-tuning step, which is resource-intensive. At the core of these limitations is the lack of a systematic approach that jointly prunes and refines weights during training in a single stage and does not require any fine-tuning upon convergence to achieve state-of-the-art performance. We present a novel single-stage structured pruning method termed DiscriminAtive Masking (DAM). The key intuition behind DAM is to discriminatively prefer some of the neurons to be refined during the training process, while gradually masking out other neurons. We show that our proposed DAM approach has remarkably good performance across a diverse range of applications in representation learning and structured pruning, including dimensionality reduction, recommendation systems, graph representation learning, and structured pruning for image classification. We also theoretically show that the learning objective of DAM is directly related to minimizing the L_0 norm of the masking layer. All of our code and datasets are available at https://github.com/jayroxis/dam-pytorch.
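
The idea of gradually masking out some neurons while keeping others can be illustrated with a small sketch. This is not the authors' implementation: the gate function, the fixed per-neuron offsets `mu`, and the parameters `alpha` and `beta` are assumptions chosen for illustration, since the abstract does not specify them. The sketch only shows how a single shift parameter can switch off neurons in a fixed order, and how the L_0 norm of the mask (the number of nonzero gates) shrinks as that parameter decreases.

```python
import math

def mask_gates(num_neurons, beta, alpha=1.0):
    """Return one gate value in [0, 1) per neuron.

    Hypothetical gate: max(0, tanh(alpha * (mu_i + beta))), where
    mu_i = i / num_neurons is a fixed offset that grows with the index.
    Neurons with small offsets are masked first as beta decreases.
    """
    gates = []
    for i in range(num_neurons):
        mu = i / num_neurons  # assumed fixed, evenly spaced offset
        gates.append(max(0.0, math.tanh(alpha * (mu + beta))))
    return gates

def l0_norm(gates):
    """Count the gates that are strictly positive (active neurons)."""
    return sum(1 for g in gates if g > 0.0)

def apply_mask(features, gates):
    """Scale each feature by its gate; masked features become zero."""
    return [f * g for f, g in zip(features, gates)]

if __name__ == "__main__":
    # Lowering beta masks out more neurons, lowest index first.
    for beta in (1.0, 0.0, -0.5, -1.0):
        gates = mask_gates(8, beta)
        print(f"beta={beta:+.1f}  active neurons={l0_norm(gates)}")
```

In a trainable setting, `beta` would be a learned parameter and a penalty on it would be added to the task loss, so that the active neurons keep being refined while the mask shrinks; that training loop is omitted here.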