{"title": "Neural Proximal Gradient Descent for Compressive Imaging", "book": "Advances in Neural Information Processing Systems", "page_first": 9573, "page_last": 9583, "abstract": "Recovering high-resolution images from limited sensory data typically leads to a serious ill-posed inverse problem, demanding inversion algorithms that effectively capture the prior information. Learning a good inverse mapping from training data faces severe challenges, including: (i) scarcity of training data; (ii) need for  plausible reconstructions that are physically feasible; (iii) need for fast reconstruction, especially in real-time applications. We develop a successful system solving all these challenges, using as basic architecture the repetitive application of alternating proximal and data fidelity constraints. We learn a proximal map that works well with real images based on residual networks with recurrent blocks. Extensive experiments are carried out under different settings: (a) reconstructing abdominal MRI of pediatric patients from highly undersampled k-space data and (b) super-resolving natural face images. Our key findings include: 1. a recurrent ResNet with a single residual block (10-fold repetition) yields an effective proximal which accurately reveals MR image details. 2. Our architecture significantly outperforms conventional non-recurrent deep ResNets by 2dB SNR; it is also trained much more rapidly. 3. It outperforms state-of-the-art compressed-sensing Wavelet-based methods by 4dB SNR, with 100x speedups in reconstruction time.", "full_text": "Neural Proximal Gradient Descent for Compressive\n\nImaging\n\nMorteza Mardani1, Qingyun Sun4, Shreyas Vasawanala2, Vardan Papyan3,\n\nHatef Monajemi3, John Pauly1, and David Donoho3\n\nDepts. 
of 1Electrical Eng., 2Radiology, 3Statistics, and 4Mathematics; Stanford University\n\nmorteza,qysun,vasanawala,papyan,monajemi,pauly,donoho@stanford.edu\n\nAbstract\n\nRecovering high-resolution images from limited sensory data typically leads to a\nserious ill-posed inverse problem, demanding inversion algorithms that effectively\ncapture the prior information. Learning a good inverse mapping from training data\nfaces severe challenges, including: (i) scarcity of training data; (ii) need for plausible\nreconstructions that are physically feasible; (iii) need for fast reconstruction,\nespecially in real-time applications. We develop a successful system solving all\nthese challenges, using as basic architecture the recurrent application of proximal\ngradient algorithm. We learn a proximal map that works well with real images\nbased on residual networks. Contraction of the resulting map is analyzed, and\nincoherence conditions are investigated that drive the convergence of the iterates.\nExtensive experiments are carried out under different settings: (a) reconstructing\nabdominal MRI of pediatric patients from highly undersampled Fourier-space\ndata and (b) superresolving natural face images. Our key findings include: 1. a\nrecurrent ResNet with a single residual block unrolled from an iterative algorithm\nyields an effective proximal which accurately reveals MR image details. 2. Our\narchitecture significantly outperforms conventional non-recurrent deep ResNets by\n2dB SNR; it is also trained much more rapidly. 3. It outperforms state-of-the-art\ncompressed-sensing Wavelet-based methods by 4dB SNR, with 100x speedups in\nreconstruction time.\n\n1 Introduction\nLinear inverse problems appear broadly in image restoration tasks, in applications ranging from\nnatural image superresolution to biomedical image reconstruction. 
In such tasks, one oftentimes\nencounters a seriously ill-posed recovery task, which necessitates regularization with proper statistical\npriors. This is however impeded by the following challenges: c1) real-time and interactive tasks\ndemand a low overhead for inference; e.g., imagine MRI visualization for neurosurgery [1], or,\ninteractive superresolution on cell phones [2]; c2) the need for recovering plausible images that are\nconsistent with the physical model; this is particularly important for medical diagnosis, which is\nsensitive to artifacts; and c3) limited labeled training data, especially for medical imaging.\nConventional compressed sensing (CS) relies on sparse coding of images in a proper transform domain\nvia a universal \u2113_1-regularization; see e.g., [3, 4, 5]. To automate the time-intensive iterative\nsoft-thresholding algorithm (ISTA) for sparse coding, [6] puts forth the learned ISTA (LISTA). Relying\non soft-thresholding it trains a simple (single dense layer) recurrent network to map measurements to\nthe \u2113_1 sparse code as a surrogate for the \u2113_0 code. [7] advocates a wider class of functions derived\nfrom proximal operators. [8] also adopts LSTMs to learn the minimal \u2113_0 sparse code, where the\nlearned network was seen to improve the RIP of coherent dictionaries. Sparse recovery however is\nthe common objective of [8, 6], and the measurement model is not explicitly taken into account. Nor\nwere any guarantees provided for the convergence and quality of the iterates.\n\n32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montr\u00e9al, Canada.\n\n\fDeep neural networks have recently proven quite powerful in modeling prior distributions for\nimages [9, 10, 11, 12, 13, 14]. 
There are a handful of recent attempts to integrate the priors offered\nby generative nets for solving linear inverse tasks dealing with local image restoration such as\nsuperresolution [10, 12], inpainting [13]; and more global tasks such as biomedical image\nreconstruction [15, 16, 17, 19, 20, 21, 22, 23]. One can divide them into two main categories, with the\nfirst category being the post-processing methods that train a deep network to map a poor (linear)\nestimate of the image to the true one [15, 17, 20, 23, 10, 12, 19]. Residual networks (ResNets) are\na suitable choice for training such deep nets due to their stable training behavior [24] along with\npixel-wise and perceptual costs induced e.g., by generative adversarial networks (GANs) [9, 19].\nThe post-processing schemes offer a clear gain in computation time, but they offer no guarantee for\ndata fidelity. Their accuracy is also only comparable with CS-based iterative methods. The second\ncategory is inspired by unrolling the iterations of classical optimization algorithms, and learns the\nfilters and nonlinearities by training deep CNNs [25, 16, 26, 27]. They improve the accuracy relative\nto CS, but deep denoising CNNs that are changing over iterations incur a huge training overhead.\nNote also that for a signal that has a low-dimensional code under a deep pre-trained generative\nmodel, [28, 29] establish reconstruction guarantees. The inference however relies on an iterative\nprocedure based on empirical risk minimization that is quite time intensive for real-time applications.\nContributions. Aiming for rapid, feasible, and plausible image recovery in ill-posed linear inverse\ntasks, this paper puts forth a novel neural proximal gradient descent algorithm that learns the proximal\nmap using a recurrent ResNet. 
Local convergence of the iterates is studied for the inference phase\nassuming that the true image is a fixed point of the proximal (i.e., it lies on a manifold represented by\nthe proximal). In particular, contraction of the learned proximal is empirically analyzed to ensure the\nRNN iterates converge to the true solution. Extensive evaluations are carried out for the global task of\nMRI reconstruction, and a local task of natural image superresolution. We find:\n\u2022 For MRI reconstruction, it works better to repeat a small ResNet (with a single RB) several\ntimes than to build a general deep network.\n\u2022 Our recurrent ResNet architecture outperforms general deep network schemes by about 2dB\nSNR, with much less training data needed. It is also trained much more rapidly.\n\u2022 Our architecture outperforms existing state-of-the-art CS-WV schemes, with a 4dB gain in\nSNR, while achieving reconstruction with 100x reduction in computing time.\n\nThese findings rest on several novel project contributions:\n\u2022 Successful design and construction of a neural proximal gradient descent scheme based on\nrecurrent ResNets.\n\u2022 Rigorous experimental evaluations, both for undersampled pediatric MRI data, and for\nsuperresolving natural face images, comparing our proposed architecture with conventional\nnon-recurrent deep ResNets and with CS-WV.\n\u2022 Formal analysis of the map contraction for the proximal gradient algorithm with accompanying\nempirical measurements.\n\n2 Preliminaries and problem statement\nConsider an ill-posed linear system y = \u03a6x\u2217 + v with \u03a6 \u2208 C^{m\u00d7n} where m \u226a n, and v captures\nthe noise and unmodeled dynamics. Suppose the unknown and (complex-valued) image x lies in a\nlow-dimensional manifold. 
No information is known about the manifold besides the training samples\nX := {x_i}_{i=1}^N drawn from it, and the corresponding (possibly) noisy observations Y := {y_i}_{i=1}^N.\nGiven a new undersampled observation y, the goal is to quickly recover a plausible image x\u2217.\nThe stated problem covers a wide range of image restoration tasks. For instance, in medical image\nreconstruction, \u03a6 describes a projection driven by physics of the acquisition system (e.g., Fourier\ntransform for MRI scanner, and Radon transform for the CT scanner). For image superresolution\nit is the downsampling operator that averages out nonoverlapping image regions to arrive at a\nlow-resolution image. Given an image prior distribution, one typically forms a maximum-likelihood\nestimator formulated as a regularized least-squares (LS) program\n\n(P1)    min_x (1/2) \u2016y \u2212 \u03a6x\u2016^2 + \u03c8(x; W)    (1)\n\nwith the regularizer \u03c8(\u00b7) parameterized by W that incorporates the image prior.\n\nIn order to solve (P1) one can adopt a variation of proximal gradient algorithm [30] with a proximal\noperator P_\u03c8 that depends on the regularizer \u03c8(\u00b7,\u00b7) [30]. Starting from x_0, and adopting a small step\nsize \u03b1 the overall iterative procedure is expressed as\n\nx_{t+1} = P_\u03c8(x_t \u2212 \u03b1\u2207(1/2)\u2016y \u2212 \u03a6x_t\u2016^2) = P_\u03c8(x_t + \u03b1\u03a6^H(y \u2212 \u03a6x_t))    (2)\n\nFor convex function \u03c8, the proximal map is monotone, and the fixed point of (2) coincides with the\nglobal optimum for (P1) [30]. For some simple prior distributions, the proximal operation is tractable\nin closed-form. One popular example of such a proximal pertains to \u2113_1-norm regularization for sparse\ncoding, where the proximal operator gives rise to soft-thresholding and shrinkage in a certain domain\nsuch as Wavelet, or, Fourier. 
The associated iterations have been labeled ISTA; the related FISTA\niterations offer accelerated convergence [31].\n3 Neural Proximal learning\nMotivated by the proximal gradient iterations in (2), to design efficient network architectures that\nautomatically invert linear inverse tasks, the following questions need to be first addressed:\nQ1. How to ensure rapid inference with affordable training for real-time image recovery?\nQ2. How to ensure plausible reconstructions that are physically feasible?\n3.1 Deep recurrent network architecture\nThe recursion in (2) can be envisioned as a feedback loop which at the t-th iteration takes an image\nestimate x_t, moves it towards the affine subspace of data consistent images, and then applies the\nproximal operator to obtain x_{t+1}. The iterations adhere to the state-space model\n\ns_{t+1} = g(x_t; y)    (3)\nx_{t+1} = P_\u03c8(s_{t+1})    (4)\n\nwhere g(x_t; y) := \u03b1\u03a6^H y + (I \u2212 \u03b1\u03a6^H\u03a6)x_t is the gradient descent step that encourages data consistency.\nThe initial input is x_0 = 0, with initial state s_1 = \u03b1\u03a6^H y that is a linear (low-quality) image estimate.\nThe state variable is essentially a linear network with the learnable step size \u03b1 that linearly combines\nthe linear image estimate \u03a6^H y with the output of the previous iteration, namely x_t.\nIn order to model the proximal mapping we use a homogeneous recurrent neural network depicted\nin Fig. 1. In essence, a truncated RNN with T iterations is used for training. The measurement y\nforms the input variables for all iterations, which together with the output of the previous iteration\nform the state variable for the current iteration. The proximal operator is modeled via a possibly deep\nneural network, as will be elaborated in the next section. As argued earlier, the proximal resembles\nprojection onto the manifold of visually plausible images. 
Thus, one can interpret P_\u03c8 as a denoiser\nthat gradually removes the aliasing artifacts from the input image.\n3.2 Proximal modeling\nWe consider a K-layer neural network with element-wise activation function \u03c3(z) = D(z) \u00b7 z. We\nstudy several examples of the mask function D(z), including the step function for ReLU, and the\nsigmoid function for Swish [32]. The k-th layer maps h_{k\u22121} to h_k through\n\nz_k = W_k h_{k\u22121},\nh_k = \u03c3(z_k) = D(z_k) \u00b7 z_k\n\nwhere the bias term is included in the weight matrix. At the t-th iteration, the network starts with\nthe input z_0 = x_t, and outputs z_K := x_{t+1}. Typically, the linear weights W_k are modeled by\na convolution operation with a certain kernel and stride size. The network weights collected in\nW := {W_k}_{k=1}^K then parameterize the proximal. To avoid vanishing gradients associated with\ntraining RNNs we can use ResNets [24] or highway nets [33]. An alternate path to our model goes\nvia DiracNets [34] with W_k = I + W\u0304_k, which are shown to exhibit similar behavior as ResNet.\n\nFigure 1: Truncated RNN architecture for neural proximal learning with T iterations. P_\u03c8 is modeled\nwith a multi-layer NN.\n\n3.3 Neural proximal training\nIn order to learn the proximal map, the recurrent neural network in Fig. 1 is trained end-to-end using\nthe population of training data X and Y. For the measurement y_i, RNN with T iterations recovers\nx\u0302_i = x_T^i = (P_\u03c8 \u25e6 g)^T(\u03a6^H y_i), where the composite map P_\u03c8 \u25e6 g is parameterized by the network\ntraining weights W and step size \u03b1. Let X\u0302 denote the population of recovered images. In general,\none can use the population-wise costs such as GANs [9, 19], or, the element-wise costs such as \u2113_1/\u2113_2\nto penalize the difference between X and X\u0302. 
To ease the exposition, we adopt the element-wise\nempirical risk minimization\n\n(P2)    min_{W,\u03b1}  \u03b2 \u2211_{i=1}^N \u2113(x_i, x\u0302_i^T) + (1 \u2212 \u03b2) \u2211_{i=1}^N \u2211_{t=1}^T \u2016y_i \u2212 \u03a6x\u0302_i^t\u2016^2\n        s.t.  x\u0302_i^t = (P_\u03c8 \u25e6 g)^t(\u03a6^H y_i),  \u2200i \u2208 [N], t \u2208 [T]\n\nfor some \u03b2 \u2208 [0, 1], where a typical choice for loss \u2113 is MSE, i.e., \u2113(x\u0302, x) = \u2016x \u2212 x\u0302\u2016^2. The second\nterm encourages the outputs of different iterations to be consistent with the measurements. It is\nfound to significantly improve training convergence of RNN for large iteration numbers T when\nvanishing gradients can occur. Alternatively, to facilitate the training one can ask reconstructions\nat each iteration to be faithful with the ground-truth images as in [18]. Note, one can additionally\naugment (P2) with an adversarial GAN loss as in our companion work [19], which places more weight on\nperceptual image quality, which is critical in medical imaging.\n4 Contraction Analysis\nConsider the trained RNN in Fig. 1. In the inference phase with a new measurement y, we are\nmotivated to study whether the iterates {(s_t, x_t)} in (3)-(4) converge, their speed of convergence,\nand whether upon convergence they coincide with the true unknown image. To make the analysis\ntractable, the following assumptions are made:\n(A1) The measurements are noiseless, namely, y = \u03a6x\u2217, and the true image x\u2217 is close to a fixed\npoint of the proximal operator, namely \u2016x\u2217 \u2212 P_\u03c8(x\u2217)\u2016 \u2264 \u03b5 for some small \u03b5.\nThe fixed point assumption seems to be a stringent requirement, but it is typically made in this\ncontext to make the analysis tractable; see e.g., [28]. It roughly means that the images lie on a\nmanifold\u00b9 represented by the map P_\u03c8. 
Assuming that the train and test data lie on the same manifold,\none can enforce it during the training by adding a penalty term to (P2).\n\n\u00b9 We use here the term manifold purely in an informal sense.\n\nThe mask can then be decomposed as\n\nd^k_t = D(z^k_\u2217) + (D(z^k_t) \u2212 D(z^k_\u2217)) = d^k_\u2217 + \u03b4^k_t    (5)\n\nwhere d^k_\u2217 = D(z^k_\u2217) is the true mask, and \u03b4^k_t models the perturbation. Passing the input image x_t into\nthe K-layer neural network then yields the output\n\nx_{t+1} = M^K_t ... M^2_t M^1_t (\u03b1\u03a6^H y + (I \u2212 \u03b1\u03a6^H\u03a6)x_t),    (6)\n\nwhere M^k_t = diag(d^k_t)W^k. One can further write M^k_t as\n\nM^k_t = diag(d^k_\u2217 + \u03b4^k_t)W^k = diag(d^k_\u2217)W^k + diag(\u03b4^k_t)W^k = M^k_\u2217 + diag(\u03b4^k_t)W^k.    (7)\n\nLet us define the residual operator\n\n\u2206_t := [M^K_t ... M^2_t M^1_t] \u2212 [M^K_\u2217 ... M^2_\u2217 M^1_\u2217],    (8)\n\nwhere the first product defines M_t and the second defines M_\u2217. It can then be expressed as \u2206_t = \u2206^1_t + ... + \u2206^K_t with\n\n\u2206^s_t := \u2211_{j_1,...,j_s} M^K_\u2217 ... (diag(\u03b4^{j_s}_t)W^{j_s}) ... (diag(\u03b4^{j_1}_t)W^{j_1}) ... M^1_\u2217.    (9)\n\nThe term \u2206^s_t captures the mask perturbation in every s-subset of the layers. 
Rearranging the terms in (6), and using the assumption (A1), namely M_\u2217x\u2217 = x\u2217 + \u03be for some\nrepresentation error \u03be such that \u2016\u03be\u2016 \u2264 \u03b5, and the noiseless model y = \u03a6x\u2217, we arrive at\n\nx_{t+1} \u2212 x\u2217 = (M_\u2217 + \u2206_t)(\u03b1\u03a6^H\u03a6x\u2217 + (I \u2212 \u03b1\u03a6^H\u03a6)x_t) \u2212 x\u2217\n             = M_\u2217(I \u2212 \u03b1\u03a6^H\u03a6)(x_t \u2212 x\u2217) + \u2206_t(I \u2212 \u03b1\u03a6^H\u03a6)(x_t \u2212 x\u2217) + \u2206_t x\u2217 + \u03be    (10)\n\nTo study the contraction property and thus local convergence of the iterates {x_t} to the true solution\nx\u2217, let us first suppose that the perturbation x_t \u2212 x\u2217 at the t-th iteration belongs to the set S_t. We then\nintroduce the contraction parameter associated with M_\u2217 as\n\n\u03b7^t_1 := sup_{\u03b4 \u2208 S_t} \u2016M_\u2217(I \u2212 \u03b1\u03a6^H\u03a6)\u03b4\u2016 / \u2016\u03b4\u2016.    (11)\n\nSimilarly, for the perturbation map \u2206_t define the contraction parameter\n\n\u03b7^t_2 := sup_{\u03b4 \u2208 S_t} \u2016\u2206_t[x\u2217 + (I \u2212 \u03b1\u03a6^H\u03a6)\u03b4]\u2016 / \u2016\u03b4\u2016.    (12)\n\nApplying the triangle inequality to (10), one then simply arrives at\n\n\u2016x_{t+1} \u2212 x\u2217\u2016 \u2264 \u2016M_\u2217(I \u2212 \u03b1\u03a6^H\u03a6)(x_t \u2212 x\u2217)\u2016 + \u2016\u2206_t[(I \u2212 \u03b1\u03a6^H\u03a6)(x_t \u2212 x\u2217) + x\u2217]\u2016 + \u2016\u03be\u2016    (13)\n              \u2264 (\u03b7^t_1 + \u03b7^t_2)\u2016x_t \u2212 x\u2217\u2016 + \u03b5    (14)\n\nAccording to (14), for small values \u03b5 \u2248 0 a sufficient condition for (asymptotic) linear convergence of\nthe iterates {x_t} to true x\u2217 is that lim sup_{t\u2192\u221e}(\u03b7^t_1 + \u03b7^t_2) < 1. 
For the non-negligible representation\nerror \u03be, if one wants the iterates to converge within a \u03bd-ball of x\u2217, i.e., \u2016x_t \u2212 x\u2217\u2016 \u2264 \u03bd, a sufficient\ncondition is that lim sup_{t\u2192\u221e}(\u03b7^t_1 + \u03b7^t_2) < 1 \u2212 \u03b5/\u03bd.\nMotivated by real-time applications, e.g., in MRI neurosurgery visualization, it is of high interest\nto use the minimum iteration count T for which the algorithm reaches a close neighborhood of x\u2217.\nOur conjecture is that for a reasonably expressive neural proximal network, the perturbation masks\n\u03b4^j_t become highly sparse for the perturbed layers over the iterations so that \u03b7^t_2 \u2264 \u03b5_2, t \u2265 T for\nsome small \u03b5_2. Further analysis of this phenomenon, and establishing guarantees under simple and\ninterpretable conditions in terms of network parameters is an important next step. This is the subject\nof our ongoing research, and will be reported elsewhere. Nonetheless, the next section provides\nempirical observations about the contraction parameters, where in particular \u03b7^t_1 is observed to be an\norder-of-magnitude larger than \u03b7^t_2.\nRemark 1 [De-biasing]. In sparse linear regression, LASSO is used to obtain a sparse solution that\nis possibly biased, while the support is accurate. The solution can then be de-biased by solving a\nLS program given the LASSO support. In a similar manner, neural proximal gradient descent may\nintroduce a bias due to e.g., the representation error \u03be. To reduce the bias, after the convergence of\niterates to x_T, one can fix the masks at all layers and replace the proximal map with the linear map\nM_T, and then find another fixed point for the iterates (6).\n\n5 Experiments\nPerformance of our novel neural proximal gradient descent scheme was assessed in two tasks:\nreconstructing pediatric MR images from undersampled k-space data; and superresolving natural\nface images. 
In the first task, undersampling k-space introduces aliasing artifacts that globally impact\nthe entire image, while in the second task the blurring is local. While our focus is mostly on MRI,\nexperiments with the image superresolution task are included to shed some light on the contraction\nanalysis in the previous section. In particular, we aim to address the following questions:\nQ1. What is the performance compared with the conventional deep architectures and with CS-MRI?\nQ2. What is the proper depth for the proximal network, and number of iterations (T) for training?\nQ3. Can one empirically verify the deep contraction conditions for the convergence of the iterates?\n\n5.1 ResNets for proximal training\n\nTo address the above questions, we adopted a ResNet with a variable number of residual blocks\n(RB). Each RB consisted of two convolutional layers with 3 \u00d7 3 kernels and a fixed number of 128\nfeature maps, respectively, that were followed by batch normalization (BN) and ReLU activation. We\nfollowed these by three simple convolutional layers with 1 \u00d7 1 kernels, where the first two layers\nundergo ReLU activation.\nWe used the Adam SGD optimizer with the momentum parameter 0.9, mini-batch size 2, and initial\nlearning rate 10^{\u22125} that is halved every 10K iterations. Training was performed with TensorFlow\ninterface on an NVIDIA Titan X Pascal GPU with 12GB RAM. The source code for TensorFlow\nimplementation is publicly available in the GitHub page [35].\n5.2 MRI reconstruction and artifact suppression\nPerformance of our novel recurrent scheme was assessed in removing k-space undersampling artifacts\nfrom MR images. In essence, the MR scanner acquires Fourier coefficients (k-space data) of the\nunderlying image across various coils. 
We focused on a single-coil MR acquisition model, where for\nthe n-th patient, the acquired k-space data admits\n\ny^{(n)}_{i,j} = [F(x_n)]_{i,j} + v^{(n)}_{i,j},  (i, j) \u2208 \u2126    (15)\n\nHere, F refers to the 2D Fourier transform, and the set \u2126 indexes the sampled Fourier coefficients.\nJust as in conventional CS MRI, we selected \u2126 based on variable-density sampling with radial view\nordering that is more likely to pick low frequency components from the center of k-space [4]. Only\n20% of Fourier coefficients were collected.\nDataset. T1-weighted abdominal image volumes were acquired for 350 pediatric patients. Each 3D\nvolume includes 151 axial slices of size 200 \u00d7 100 pixels. All in-vivo scans were acquired on a 3T\nMRI scanner (GE MR750) with voxel resolution 1.07 \u00d7 1.12 \u00d7 2.4 mm. The input and output were\ncomplex-valued images of the same size and each included two channels for real and imaginary\ncomponents. The input image was generated using an inverse 2D FT of the k-space data where the\nmissing data were filled with zeros (ZF); it is severely contaminated with artifacts.\n5.2.1 Performance for various number/size of iterations\nIn order to assess the impact of network architecture on image recovery performance, the RNN was\ntrained for a variable number of iterations (T) with a variable number of residual blocks (RBs). 10K\nslices (67 patients) from the train dataset were randomly picked for training, and 1,280 slices (9\npatients) from the test dataset for test. For training RNN, we use the \u2113_2 cost in (P2) with \u03b2 = 0.75.\nFig. 2 depicts the SNR and structural similarity index metric (SSIM) [36] versus the number of\niterations (copies), when the proximal network comprises 1/2/5/10 RBs. It is observed that increasing\nthe number of iterations significantly improves the SNR and SSIM, but leads to a longer inference and\ntraining time. 
In particular, using three iterations instead of one achieves more than 2dB SNR gain\nfor 1 RB, and more than 3dB for 2 RBs. Interestingly, when using a single iteration, adding more\nthan 5 RBs to make a deeper network does not yield further improvements; the SNR=24.33 for 10\nRBs, and SNR=24.15 for 5 RBs. Notice also that a single RB tends to be reasonably expressive to\nmodel the MR image denoising proximal, and as a result, repeating it several times, the SNR does not\nseem to exceed 27dB. Using 2 RBs however turns out to be more expressive to learn the proximal,\nand performs as well as using 5 RBs. Similar observations are made for SSIM.\n\nFigure 2: Average SNR and SSIM versus the number of copies (iterations). Note, single copy ResNet\nrefers to the deep ResNet that is an existing alternative to our proposed RNN.\n\nTable 1: Performance trade-off for various RNN architectures.\n\niterations | RBs | train time (hours) | inference time (sec) | SNR (dB) | SSIM\n10 | 1 | 2 | 0.04 | 26.07 | 0.9117\n5 | 2 | 4 | 0.10 | 26.94 | 0.9221\n2 | 5 | 8 | 0.12 | 26.55 | 0.9194\ndeep ResNet | 10 | 12 | 0.0522 | 24.33 | 0.8810\nCS-TV | n/a | n/a | 1.30 | 22.20 | 0.82\nCS-WV | n/a | n/a | 1.16 | 22.51 | 0.86\n\nTraining and inference time. Inference time is proportional to the number of unrolled iterations.\nPassing each image through one unrolled iteration with one RB takes 4 msec when fully using the\nGPU. It is hard to precisely evaluate the training and inference time under fair conditions as it strongly\ndepends on the implementation and the allocated memory and processing power per run. Estimated\ninference times as listed in Table 1 are averaged out over a few runs on the GPU. We observed\nempirically that with shared weights, e.g., 10 iterations with 1 RB, the training converges in 2\u20133\nhours. 
In contrast, training a deep ResNet with 10 RBs takes around 10\u201312 hours to converge.\n\n5.2.2 Comparison with sparse coding\nTo compare with conventional CS-MRI, CS-WV is tuned for best SNR performance using BART [37]\nthat runs 300 iterations of FISTA along with 100 iterations of conjugate gradient descent to reach\nconvergence. Quantitative results are listed under Table 1, where it is evident that the recurrent\nscheme with shared weights significantly outperforms CS with more than 4dB SNR gain that leads to\nsharper images with finer texture details as seen in Fig. 3. As a representative example, Fig. 3 depicts\nthe reconstructed abdominal slice of a test patient. CS-WV retrieves a blurry image that misses\nout the sharp details of the liver vessels. A deep ResNet with one iteration and 10 RBs captures a\ncleaner image, but still blurs out fine texture details such as vessels. However, when using 10 unrolled\niterations with a single RB for proximal modeling, more details of the liver vessels are visible, and\nthe texture appears to be more realistic. Similarly, using 5 iterations and 2 RBs retrieves finer details\nthan 2 iterations with a relatively large 5-RB network for the proximal.\nIn summary, we make three key findings:\nF1. The proximal for denoising MR images can be well represented by training a ResNet with a small\nnumber (1\u20132) of RBs.\nF2. Multiple back-and-forth iterations are needed to recover a plausible MR image that is physically\nfeasible.\nF3. 
Considering the training and inference overhead and the quality of reconstructed images, RNN\nwith 10 iterations and 1 RB proximal is promising to implement in clinical scanners.\n\nFigure 3: A representative axial abdominal slice for a test patient reconstructed by zero-filling (1st\ncolumn); CS-WV (2nd column); deep ResNet with 10 RBs (3rd column); and neural proximal\ngradient descent with 10 iterations and 1 RB (4th column), 2 iterations and 5 RBs (5th column), 5\niterations and 2 RBs (6th column); and the gold-standard (7th column).\n\nFigure 4: Superresolved (2\u00d7) face images at different iterations (x_0, x_1, x_5, x_25) compared with the\nground-truth (x\u2217). Proximal is a single nonlinear layer CNN with kernel size 32.\n\n5.3 Verification of the contraction conditions\n\nTo verify the contraction analysis developed in Section 4, we focus on the image superresolution\n(SR) task. In this linear inverse task, one only has access to a low-resolution (LR) image y = \u03c6 \u2217 x\ndownsampled via the convolution kernel \u03c6. To form y, the image pixels in 2 \u00d7 2 non-overlapping\nregions are averaged out. SR is a challenging ill-posed problem, and has been the subject of intensive\nresearch; see e.g., [38, 39, 10, 2]. Our goal is not to achieve state-of-the-art performance, but to set up a simple\nscenario to study the behavior of contraction parameters for proximal learning.\nCelebA dataset. Adopting the CelebFaces Attributes Dataset (CelebA) [40], for training and test we use\n10K and 1,280 images, respectively. Ground-truth images have 128 \u00d7 128 pixels and are downsampled\nto 64 \u00d7 64 LR images.\nThe proximal net is modeled as a 5-layer linear CNN with Swish nonlinearity [32] in the last layer.\nThe hidden layers undergo no nonlinearity and kernel sizes 8 and 32 are adopted. Thus, it is\neffectively a single layer nonlinear neural network. The proximal then admits P_\u03c8(x) = \u03c3(W x)\nas per (6). 
RNN with T = 25 is trained, and normalized RMSE, i.e., \u2016x_t \u2212 x\u2217\u2016/\u2016x\u2217\u2016 is plotted\nversus the iteration index in Fig. 5 (top) for various kernel sizes. It decreases quickly and after a few\niterations it converges which suggests that the converged solution is possibly a fixed point for the\nproximal map. For a representative face image, output of different iterations x_0, x_1, x_5, x_25 as well as\nthe ground-truth x\u2217 are plotted in Fig. 4. Apparently, the resolution improves over the iterations.\n\nFigure 5: The top figure is normalized RMSE evolution over iterations for image superresolution\ntask with different kernel sizes. The bottom ones show the error bars for \u03b7_1 and \u03b7_2 per iteration for\nimage superresolution where the proximal is a single nonlinear layer CNN.\n\nThe contraction parameters are also plotted in Fig. 5. The space of perturbations for the operator\nnorm is limited to the admissible ones that inherit the structure of iterations. For the i-th test sample,\nwe inspect the behavior \u03b7^i_{1,t} = \u2016M_\u2217(I \u2212 \u03b1\u03a6^H\u03a6)\u03b4^i_t\u2016/\u2016\u03b4^i_t\u2016, where \u03b4^i_t := x^i_t \u2212 x^i_\u2217. The corresponding\nerror bars are then plotted in Fig. 5 for kernel size 32. It is apparent that \u03b7^i_{1,t} and \u03b7^i_{2,t} quickly decay\nacross iterations, indicating that later iterations produce perturbations that are more incoherent to\nthe proximal map. Also, we can see that \u03b7^i_{2,t} converges to a level that represents the bias generated\nby the iterates, similar to the bias introduced in LASSO. In addition, one can observe that \u03b7^i_{1,t} is the\ndominant term, usually an order of magnitude larger than \u03b7^i_{2,t}.\n6 Conclusions\nThis paper develops a novel neural proximal gradient descent scheme for recovery of images from\nhighly compressed measurements. 
Unrolling the proximal gradient iterations, a recurrent architecture\nis proposed that models the proximal map via ResNets. For the trained network, contraction of\nthe proximal map and subsequently the local convergence of the iterates are studied and empirically\nevaluated. Extensive experiments are performed to assess various network wirings, and to verify the\ncontraction conditions in reconstructing MR images of pediatric patients and superresolving natural\nimages. Our findings for MRI indicate that a small ResNet can effectively model the proximal, and\nsignificantly improve on the reconstruction quality and reduce the complexity relative to recent deep architectures as well as conventional\nCS-MRI.\nWhile this paper sheds some light on the local convergence of neural proximal gradient descent, our\nongoing research focuses on a more rigorous analysis to derive simple and interpretable contraction\nconditions. The main challenge pertains to understanding the distribution of activation masks, which\nneeds extensive empirical evaluation. Other important avenues that are the focus of our current\nresearch include: 1) stable training of neural PGD for large iteration counts using gated recurrent\nnetworks; 2) comparing with existing deep-learning-based MRI reconstruction schemes such as deep\nADMM-net and LDAMP; 3) more extensive experiments for natural image superresolution with\ndeeper proximals and possibly using dilated convolutions for capturing a large image field of view.\n\n7 Acknowledgements\nWe would like to acknowledge Dr. Marcus Alley from the Radiology Department at Stanford\nUniversity for setting up the infrastructure to automatically collect the MRI dataset used in this paper.\nWe would also like to acknowledge Dr. Enhao Gong and Dr. 
Joseph Cheng for fruitful discussions\nand their feedback about the MRI reconstruction and software implementation.\n\nReferences\n\n[1] http://www.mriinterventions.com/clearpoint/clearpoint-overview.html.\n\n[2] Yaniv Romano, John Isidoro, and Peyman Milanfar. RAISR: rapid and accurate image super resolution. IEEE Transactions on Computational Imaging, 3(1):110\u2013125, 2017.\n\n[3] David L Donoho. Compressed sensing. IEEE Transactions on information theory, 52(4):1289\u20131306, 2006.\n\n[4] Michael Lustig, David Donoho, and John M. Pauly. Sparse MRI: The application of compressed sensing for rapid MR imaging. Magnetic Resonance in Medicine, 58(6):1182\u20131195, December 2007.\n\n[5] Julio Martin Duarte-Carvajalino and Guillermo Sapiro. Learning to sense sparse signals: Simultaneous sensing matrix and sparsifying dictionary optimization. IEEE Transactions on Image Processing, 18(7):1395\u20131408, 2009.\n\n[6] Karol Gregor and Yann LeCun. Learning fast approximations of sparse coding. In Proceedings of the 27th International Conference on Machine Learning (ICML), pages 399\u2013406, June 2010.\n\n[7] Pablo Sprechmann, Alexander M Bronstein, and Guillermo Sapiro. Learning efficient sparse and low rank models. IEEE Transactions on Pattern Analysis and Machine Intelligence, 37(9):1821\u20131833, January 2015.\n\n[8] Bo Xin, Yizhou Wang, Wen Gao, David Wipf, and Baoyuan Wang. Maximal sparsity with deep networks? In Advances in Neural Information Processing Systems, pages 4340\u20134348, December 2016.\n\n[9] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in neural information processing systems, pages 2672\u20132680, December 2014.\n\n[10] Justin Johnson, Alexandre Alahi, and Li Fei-Fei. Perceptual losses for real-time style transfer and super-resolution. In European Conference on Computer Vision, pages 694\u2013711. 
Springer, 2016.\n\n[11] Xudong Mao, Qing Li, Haoran Xie, Raymond YK Lau, Zhen Wang, and Stephen Paul Smolley. Least squares generative adversarial networks. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 2813\u20132821. IEEE, 2017.\n\n[12] Christian Ledig, Lucas Theis, Ferenc Husz\u00e1r, Jose Caballero, Andrew Cunningham, Alejandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, et al. Photo-realistic single image super-resolution using a generative adversarial network. arXiv preprint, 2016.\n\n[13] Raymond Yeh, Chen Chen, Teck Yian Lim, Mark Hasegawa-Johnson, and Minh N Do. Semantic image inpainting with perceptual and contextual losses. arXiv preprint arXiv:1607.07539, 2016.\n\n[14] Hang Zhao, Orazio Gallo, Iuri Frosio, and Jan Kautz. Loss functions for image restoration with neural networks. IEEE Transactions on Computational Imaging, 3(1):47\u201357, 2017.\n\n[15] A. Majumdar. Real-time dynamic MRI reconstruction using stacked denoising autoencoder. arXiv preprint, arXiv:1503.06383 [cs.CV], March 2015.\n\n[16] Jian Sun, Huibin Li, Zongben Xu, et al. Deep ADMM-net for compressive sensing MRI. In Advances in Neural Information Processing Systems, pages 10\u201318, 2016.\n\n[17] Hu Chen, Yi Zhang, Mannudeep K Kalra, Feng Lin, Yang Chen, Peixi Liao, Jiliu Zhou, and Ge Wang. Low-dose CT with a residual encoder-decoder convolutional neural network. IEEE transactions on medical imaging, 36(12):2524\u20132535, 2017.\n\n[18] Jiwon Kim, Jung Kwon Lee, and Kyoung Mu Lee. Deeply-recursive convolutional network for image super-resolution. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2016.\n\n[19] Morteza Mardani, Enhao Gong, Joseph Y Cheng, Shreyas Vasanawala, Greg Zaharchuk, Lei Xing, and John M. Pauly. Generative Adversarial Neural Networks for Compressive Sensing (GANCS) MRI. IEEE transactions on medical imaging, July 2018 (to appear).\n\n[20] Bo Zhu, Jeremiah Z. 
Liu, Bruce R. Rosen, and Matthew S. Rosen. Neural network MR image reconstruction with AUTOMAP: Automated transform by manifold approximation. In Proceedings of the 25th Annual Meeting of ISMRM, Honolulu, HI, USA, 2017.\n\n[21] Shanshan Wang, Ningbo Huang, Tao Zhao, Yong Yang, Leslie Ying, and Dong Liang. 1D partial Fourier parallel MR imaging with deep convolutional neural network. In Proceedings of the 25th Annual Meeting of ISMRM, Honolulu, HI, USA, 2017.\n\n[22] Jo Schlemper, Jose Caballero, Joseph V. Hajnal, Anthony Price, and Daniel Rueckert. A deep cascade of convolutional neural networks for MR image reconstruction. In Proceedings of the 25th Annual Meeting of ISMRM, Honolulu, HI, USA, 2017.\n\n[23] Dongwook Lee, Jaejun Yoo, and Jong Chul Ye. Compressed sensing and parallel MRI using deep residual learning. In Proceedings of the 25th Annual Meeting of ISMRM, Honolulu, HI, USA, 2017.\n\n[24] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residual networks. In European Conference on Computer Vision, pages 630\u2013645. Springer, 2016.\n\n[25] Steven Diamond, Vincent Sitzmann, Felix Heide, and Gordon Wetzstein. Unrolled optimization with deep priors. arXiv preprint arXiv:1705.08041, 2017.\n\n[26] Chris Metzler, Ali Mousavi, and Richard Baraniuk. Learned D-AMP: Principled neural network based compressive image recovery. In Advances in Neural Information Processing Systems, pages 1770\u20131781, 2017.\n\n[27] Jonas Adler and Ozan \u00d6ktem. Learned primal-dual reconstruction. arXiv preprint arXiv:1707.06474, 2017.\n\n[28] Ashish Bora, Ajil Jalal, Eric Price, and Alexandros G Dimakis. Compressed sensing using generative models. arXiv preprint arXiv:1703.03208, 2017.\n\n[29] Paul Hand and Vladislav Voroninski. Global guarantees for enforcing deep generative priors by empirical risk. arXiv preprint arXiv:1705.07576, 2017.\n\n[30] Neal Parikh, Stephen Boyd, et al. Proximal algorithms. 
Foundations and Trends\u00ae in Optimization, 1(3):127\u2013239, 2014.\n\n[31] Amir Beck and Marc Teboulle. A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM journal on imaging sciences, 2(1):183\u2013202, 2009.\n\n[32] Prajit Ramachandran, Barret Zoph, and Quoc V Le. Searching for activation functions. 2018.\n\n[33] Klaus Greff, Rupesh K Srivastava, and J\u00fcrgen Schmidhuber. Highway and residual networks learn unrolled iterative estimation. arXiv preprint arXiv:1612.07771, 2016.\n\n[34] Sergey Zagoruyko and Nikos Komodakis. Diracnets: training very deep neural networks without skip-connections. arXiv preprint arXiv:1706.00388, 2017.\n\n[35] https://github.com/MortezaMardani/NeuralPGD.html.\n\n[36] Zhou Wang, Alan C Bovik, Hamid R Sheikh, and Eero P Simoncelli. Image quality assessment: from error visibility to structural similarity. IEEE Transactions on Image Processing, 13(4):600\u2013612, 2004.\n\n[37] Jonathan I Tamir, Frank Ong, Joseph Y Cheng, Martin Uecker, and Michael Lustig. Generalized Magnetic Resonance Image Reconstruction using The Berkeley Advanced Reconstruction Toolbox. In ISMRM Workshop on Data Sampling and Image Reconstruction, Sedona, 2016.\n\n[38] Joan Bruna, Pablo Sprechmann, and Yann LeCun. Super-resolution with deep convolutional sufficient statistics. arXiv preprint arXiv:1511.05666, 2015.\n\n[39] Casper Kaae S\u00f8nderby, Jose Caballero, Lucas Theis, Wenzhe Shi, and Ferenc Husz\u00e1r. Amortised map inference for image super-resolution. arXiv preprint arXiv:1610.04490, 2016.\n\n[40] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. 
In Proceedings of International Conference on Computer Vision (ICCV), 2015.", "award": [], "sourceid": 5840, "authors": [{"given_name": "Morteza", "family_name": "Mardani", "institution": "Stanford University"}, {"given_name": "Qingyun", "family_name": "Sun", "institution": "Stanford university"}, {"given_name": "David", "family_name": "Donoho", "institution": "Stanford University"}, {"given_name": "Vardan", "family_name": "Papyan", "institution": "Stanford University"}, {"given_name": "Hatef", "family_name": "Monajemi", "institution": "Stanford University"}, {"given_name": "Shreyas", "family_name": "Vasanawala", "institution": "Stanford University"}, {"given_name": "John", "family_name": "Pauly", "institution": "Stanford University"}]}