# Learning To Stop While Learning To Predict

Xinshi Chen (1), Hanjun Dai (2), Yu Li (3), Xin Gao (3), Le Song (1,4)

(1) Georgia Institute of Technology, USA; (2) Google Research, USA; (3) King Abdullah University of Science and Technology, Saudi Arabia; (4) Ant Financial, China. Correspondence to: Xinshi Chen, Le Song.

Proceedings of the 37th International Conference on Machine Learning, Online, PMLR 119, 2020. Copyright 2020 by the author(s).

Abstract. There is a recent surge of interest in designing deep architectures based on the update steps in traditional algorithms, or in learning neural networks to improve and replace traditional algorithms. While traditional algorithms have certain stopping criteria for outputting results at different iterations, many algorithm-inspired deep models are restricted to a fixed depth for all inputs. Similar to algorithms, the optimal depth of a deep architecture may be different for different input instances, either to avoid over-thinking, or because we want to compute less for operations that have already converged. In this paper, we tackle this varying depth problem using a steerable architecture, where a feed-forward deep model and a variational stopping policy are learned together to sequentially determine the optimal number of layers for each input instance. Training such an architecture is very challenging. We provide a variational Bayes perspective and design a novel and effective training procedure which decomposes the task into an oracle model learning stage and an imitation stage. Experimentally, we show that the learned deep model, along with the stopping policy, improves the performance on a diverse set of tasks, including learning sparse recovery, few-shot meta learning, and computer vision tasks.

## 1. Introduction

Recently, researchers are increasingly interested in the connections between deep learning models and traditional algorithms: deep learning models are viewed as parameterized algorithms that operate on each input instance iteratively, and traditional algorithms are used as templates for designing deep learning architectures. An important concept in traditional algorithms is the stopping criterion for outputting the result, which can be either a convergence condition or an early stopping rule; however, such stopping criteria have been more or less ignored in algorithm-inspired deep learning models, and a fixed-depth deep model is used to operate on all problem instances (Fig. 1 (a)).

Figure 1. Motivation for learning to stop. (a) Learning-based algorithm design: a fixed-depth learned algorithm versus a dynamic-depth traditional algorithm with a hand-designed update step and a stopping criterion. (b) Task-imbalanced meta learning.

Intuitively, for deep learning models, the optimal depth (or the optimal number of steps to operate on an input) can also be different for different input instances, either because we want to compute less for operations that have already converged, or because we want to generalize better by avoiding over-thinking. Such motivation aligns well with the cognitive science literature and with the examples below:

- In learning to optimize (Andrychowicz et al., 2016; Li & Malik, 2016), neural networks are used as the optimizer to minimize some loss function.
Depending on the initialization and the objective function, an optimizer should converge in a different number of steps;

- In learning to solve statistical inverse problems such as compressed sensing (Chen et al., 2018; Liu et al., 2019), inverse covariance estimation (Shrivastava et al., 2020), and image denoising (Zhang et al., 2019), deep models are learned to directly predict the recovery results. In traditional algorithms, problem-dependent early stopping rules are widely used to achieve regularization for a variance-bias trade-off. Deep learning models for solving such problems may also achieve better recovery accuracy by allowing instance-specific computation steps;

- In meta learning, MAML (Finn et al., 2017) used an unrolled and parametrized algorithm to adapt a common parameter to a new task. However, depending on the similarity of the new task to the old tasks, or, in a more realistic task-imbalanced setting where different tasks have different numbers of data points (Fig. 1 (b)), a task-specific number of adaptation steps is more favorable to avoid under- or over-adaptation.

To address the varying depth problem, we propose to learn a steerable architecture, where a shared feed-forward model for normal prediction and an additional stopping policy are learned together to sequentially determine the optimal number of layers for each input instance. In our framework, the model consists of (see Fig. 2):

- A feed-forward or recurrent mapping $\mathcal{F}_\theta$, which transforms the input $x$ to generate a path of features (or states) $x_1, \dots, x_T$; and
- A stopping policy $\pi_\phi : (x, x_t) \mapsto \pi_t \in [0, 1]$, which sequentially observes the states and then determines the probability of stopping the computation of $\mathcal{F}_\theta$ at layer $t$.

Figure 2. Two-component model: learning to predict (blue) while learning to stop (green).

These two components allow us to sequentially predict the next targeted state while at the same time determining when to stop. In this paper, we propose a single objective function for learning both $\theta$ and $\phi$, and we interpret it from the perspective of variational Bayes, where the stopping time $t$ is viewed as a latent variable conditioned on the input $x$. With this interpretation, learning $\theta$ corresponds to maximizing the marginal likelihood, and learning $\phi$ corresponds to the inference step for the latent variable, where a variational distribution $q_\phi(t)$ is optimized to approximate the posterior. A natural algorithm for solving this problem could be the Expectation-Maximization (EM) algorithm, which can be very hard to train and inefficient.

How can we learn $\theta$ and $\phi$ effectively and efficiently? We propose a principled and effective training procedure, where we decompose the task into an oracle model learning stage and an imitation learning stage (Fig. 3). More specifically:

- During the oracle model learning stage, we utilize a closed-form oracle stopping distribution $q^*_\theta$ which can leverage label information not available at testing time.
- In the imitation learning stage, we use a sequential policy $\pi_\phi$ to mimic the behavior of the oracle policy obtained in the first stage. The sequential policy does not have access to the label, so it can be used during the testing phase.

Figure 3. Two-stage training framework: a VAE-based method alternates updates of $\mathcal{F}_\theta$ and $q_\phi$ on $\mathcal{J}_{\beta\text{-VAE}}(\mathcal{F}_\theta, q_\phi)$, whereas our method first optimizes $\mathcal{J}_{\beta\text{-VAE}}(\mathcal{F}_\theta, q^*_\theta)$ and then $\mathcal{J}_{\beta\text{-VAE}}(\mathcal{F}_\theta, q_\phi)$.

This procedure provides us with a very good initial predictive model and stopping policy. We can either directly use these learned models, or plug them back into the variational EM framework and reiterate to further optimize both together.
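To make the two components of Fig. 2 concrete, here is a minimal PyTorch-style sketch of a steerable model. It is only an illustration of the interface (the module names, layer widths, and simple MLP blocks are our own choices, not the authors' released code):

```python
import torch
import torch.nn as nn

class SteerableModel(nn.Module):
    """Sketch of the two-component model: a T-layer predictive mapping F_theta
    and a stopping policy pi_phi that observes (x, x_t) at every step t."""
    def __init__(self, dim, T):
        super().__init__()
        # F_theta = {f_theta_1, ..., f_theta_T}; each layer is a toy MLP block here.
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(T)])
        # pi_phi(x, x_t) -> stopping probability in [0, 1].
        self.policy = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(),
                                    nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, x):
        states, stop_probs = [], []
        x_t = x                                   # x_0 is determined by the input x
        for t, layer in enumerate(self.layers):
            x_t = layer(x_t)                      # x_t = f_theta_t(x_{t-1}), see Sec. 3.1
            states.append(x_t)
            if t < len(self.layers) - 1:          # pi_t is only defined for t = 1..T-1
                pi_t = self.policy(torch.cat([x, x_t], dim=-1))
                stop_probs.append(pi_t)
        return states, stop_probs
```

The forward pass returns the whole path of states together with the per-step stopping probabilities; at test time the loop can instead be terminated as soon as a stop decision is drawn.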
Our proposed learning-to-stop method is a generic framework that can be applied to a diverse range of applications. To summarize, our contributions in this paper include:

1. a variational Bayes perspective to understand the proposed model for learning both the predictive model and the stopping policy together;
2. a principled and efficient algorithm for jointly learning the predictive model and the stopping policy, and the relation of this algorithm to reinforcement learning; and
3. promising experiments on various tasks including learning to solve sparse recovery problems, task-imbalanced few-shot meta learning, and computer vision tasks, where we demonstrate the effectiveness of our method in terms of both prediction accuracy and inference efficiency.

## 2. Related Works

Unrolled algorithm. A line of recent works unfolds and truncates iterative algorithms to design neural architectures. These algorithm-based deep models can be used to automatically learn a better algorithm from data. This idea has been demonstrated on different problems including sparse signal recovery (Gregor & LeCun, 2010; Sun et al., 2016; Borgerding et al., 2017; Metzler et al., 2017; Zhang & Ghanem, 2018; Chen et al., 2018; Liu et al., 2019), sparse inverse covariance estimation (Shrivastava et al., 2020), sequential Bayesian inference (Chen et al., 2019), parameter learning in graphical models (Domke, 2011), non-negative matrix factorization (Yakar et al., 2013), etc. Unrolled-algorithm-based deep modules have also been used for structured prediction (Belanger et al., 2017; Ingraham et al., 2019; Chen et al., 2020). Before the training phase, all these works need to assign a fixed number of iterations that is used for every input instance regardless of its difficulty level. Our proposed method is orthogonal and complementary to all these works, by taking the variety of the input instances into account via an adaptive stopping time.

Meta learning. Optimization-based meta learning techniques are widely applied for solving challenging few-shot learning problems (Ravi & Larochelle, 2017; Finn et al., 2017; Li et al., 2017). Several recent advances proposed task-adaptive meta-learning models which incorporate task-specific parameters (Qiao et al., 2018; Lee & Choi, 2018; Na et al., 2020) or task-dependent metric scaling (Oreshkin et al., 2018). In parallel with these task-adaptive methods, we propose a task-specific number of adaptation steps and demonstrate the effectiveness of this simple modification under task-imbalanced scenarios.

Other adaptive-depth deep models. In image recognition, early-exit models have been proposed mainly to improve computational efficiency during the inference phase (Teerapittayanon et al., 2016; Zamir et al., 2017; Huang et al., 2018), but these methods are based on specific architectures. Kaya et al. (2019) proposed to avoid over-thinking by early stopping. However, as in all other early-exit models, heuristic policies are adopted to choose the output layer based on confidence scores of internal classifiers. Also, their algorithms for training the feed-forward model $\mathcal{F}_\theta$ do not take into account the effect of the stopping policy.

Optimal stopping.
In the optimal control literature, optimal stopping is the problem of choosing a time to take a given action based on sequentially observed random variables, in order to maximize an expected payoff (Shiryaev, 2007). When a policy for controlling the evolution of the random variables (corresponding to the output of $\mathcal{F}_\theta$) is also involved, it is called a mixed control problem, which is highly related to our work. Existing works in this area find the optimal controls by solving the Hamilton-Jacobi-Bellman (HJB) equation, which is theoretically grounded (Pham, 1998; Ceci & Bassan, 2004; Dumitrescu et al., 2018). However, they focus on models based on stochastic differential equations, and the proposed algorithms suffer from the curse of dimensionality. Becker et al. (2019) use deep learning to learn the optimal stopping policy, but the learning of $\theta$ is not considered. Besides, Becker et al. (2019) use reinforcement learning (RL) to solve the problem. In Section 4, we will discuss how our variational inference formulation is related to RL.

## 3. Problem Formulation

In this section, we will introduce how we model the stopping policy together with the predictive deep model, define the joint optimization objective, and interpret this framework from a variational Bayes perspective.

### 3.1. Steerable Model

The predictive model, $\mathcal{F}_\theta$, is a typical $T$-layer deep model that generates a path of embeddings $(x_1, \dots, x_T)$ through

$$\text{Predictive model:}\quad x_t = f_{\theta_t}(x_{t-1}),\quad \text{for } t = 1, \dots, T, \qquad (1)$$

where the initial $x_0$ is determined by the input $x$. We denote it by $\mathcal{F}_\theta = \{f_{\theta_1}, \dots, f_{\theta_T}\}$, where $\theta \in \Theta$ are the parameters. Standard supervised learning methods learn $\theta$ by optimizing an objective estimated on the final state $x_T$.

In our model, the operations in Eq. 1 can be stopped earlier, and for different input instances $x$, the stopping time $t$ can be different. Our stopping policy, $\pi_\phi$, determines whether to stop at the $t$-th step after observing the input $x$ and its first $t$ states $x_{1:t}$ transformed by $\mathcal{F}_\theta$. If we assume the Markov property, then $\pi_\phi$ only needs to observe the most recent state $x_t$. In this paper, we only input $x$ and $x_t$ to $\pi_\phi$ at each step $t$, but it is trivial to generalize it to $\pi_\phi(x, x_{1:t})$. More precisely, $\pi_\phi$ is defined as a randomized policy as follows:

$$\text{Stopping policy:}\quad \pi_t = \pi_\phi(x, x_t),\quad \text{for } t = 1, \dots, T-1, \qquad (2)$$

where $\pi_t \in [0, 1]$ is the probability of stopping. We abuse the notation $\pi$ to represent both the parametrized policy and the probability mass. This stopping policy sequentially makes a decision whenever a new state $x_t$ is observed. Conditioned on the states observed until step $t$, whether to stop before $t$ is independent of the states after $t$. Therefore, once it decides to stop at $t$, the remaining computations can be saved, which is a favorable property when inference time is a concern, or for some optimal stopping problems such as option trading where getting back to earlier states is not allowed.

### 3.2. From Sequential Policy To Stop Time Distribution

The stopping policy $\pi_\phi$ makes sequential actions based on the observations, where $\pi_t := \pi_\phi(x, x_t)$ is the probability of stopping when $x_t$ is observed. These sequential actions $\pi_1, \dots, \pi_{T-1}$ jointly determine the random time $t$ at which the stop occurs. Induced by $\pi_\phi$, the probability mass function of the stop time $t$, denoted as $q_\phi$, is the variational stop time distribution

$$q_\phi(t) = \begin{cases} \pi_t \prod_{\tau=1}^{t-1}(1 - \pi_\tau) & \text{if } t < T, \\ \prod_{\tau=1}^{T-1}(1 - \pi_\tau) & \text{if } t = T. \end{cases} \qquad (3)$$

In this equation, the product $\prod_{\tau=1}^{t-1}(1 - \pi_\tau)$ is the probability of not having stopped before $t$, i.e., the survival probability. Multiplying this survival probability by $\pi_t$ gives the stop time distribution $q_\phi(t)$. For the last time step $T$, the stop probability $q_\phi(T)$ simply equals the survival probability at $T$, which means that if the process has not stopped before $T$, then it must stop at $T$.
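Eq. 3 is straightforward to compute from the per-step stopping probabilities. The short sketch below (a minimal illustration; the helper name `stop_time_distribution` is ours) does exactly this and checks that the result is a valid distribution:

```python
import torch

def stop_time_distribution(pi):
    """pi: tensor of shape (T-1,), with pi[t-1] = pi_t = pi_phi(x, x_t).
    Returns q of shape (T,), with q[t-1] = q_phi(t) as in Eq. 3."""
    one_minus = 1.0 - pi
    # survival[t-1] = prod_{tau < t} (1 - pi_tau): probability of not having stopped before t
    survival = torch.cat([torch.ones(1), torch.cumprod(one_minus, dim=0)])
    q = torch.cat([pi * survival[:-1], survival[-1:]])   # q(T) = prod_{tau=1}^{T-1} (1 - pi_tau)
    return q

# sanity check: the stop time probabilities sum to one
pi = torch.tensor([0.1, 0.3, 0.5])          # T = 4 steps
print(stop_time_distribution(pi).sum())     # tensor(1.)
```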
Note that we only use $\pi_\phi$ in our model to sequentially determine whether to stop. However, we use the induced probability mass $q_\phi$ to help design the training objective and also the algorithm.

### 3.3. Optimization Objective

Note that the stop time $t$ is a discrete random variable with distribution determined by $q_\phi(t)$. Given the observed label $y$ of an input $x$, the loss of the predictive model stopped at position $t$ can be computed as $\ell(y, x_t; \theta)$, where $\ell(\cdot, \cdot)$ is a loss function. Taking into account all possible stopping positions, we will be interested in the loss in expectation over $t$,

$$L(\theta, q_\phi; x, y) := \mathbb{E}_{t \sim q_\phi}\,\ell(y, x_t; \theta) - \beta H(q_\phi), \qquad (4)$$

where $H(q_\phi) := -\sum_t q_\phi(t)\log q_\phi(t)$ is an entropy regularization and $\beta$ is the regularization coefficient. Given a data set $D = \{(x, y)\}$, the parameters of the predictive model and the stopping policy can be estimated by

$$\min_{\theta, \phi}\ \frac{1}{|D|}\sum_{(x, y)\in D} L(\theta, q_\phi; x, y). \qquad (5)$$

To better interpret the model and objective, in the following we make a connection from the perspective of variational Bayes and show how the objective function defined in Eq. 4 is equivalent to the β-VAE objective.

### 3.4. Variational Bayes Perspective

In the Bayes framework, a probabilistic model typically consists of a prior, a likelihood function, and a posterior over the latent variable. We find the correspondence between our model and a probabilistic model as follows (also see Table 1):

- We view the adaptive stopping time $t$ as a latent variable which is unobserved;
- The conditional prior $p(t \mid x)$ of $t$ is a uniform distribution over all the layers in this paper. However, if one wants to reduce the computation cost and penalize stopping decisions at deeper layers, a prior with smaller probability on deeper layers can be defined to regularize the results;
- The likelihood function $p_\theta(y \mid t, x)$ of the observed label $y$ is controlled by $\theta$, since $\mathcal{F}_\theta$ determines the states $x_t$;
- The posterior distribution over the stopping time $t$ can be computed by Bayes' rule, $p_\theta(t \mid y, x) \propto p_\theta(y \mid t, x)\,p(t \mid x)$, but it requires the observation of the label $y$, which is infeasible during the testing phase.

In this probabilistic model, we need to learn $\theta$ to better fit the observed data, and learn a variational distribution $q_\phi$ over $t$ that only takes $x$ and the transformed internal states as inputs, to approximate the true posterior.

Table 1. Correspondence between our model and the Bayes model.

| Our model | Bayes model |
|---|---|
| stop time $t$ | latent variable |
| label $y$ | observation |
| loss $\ell(y, x_t; \theta)$ | likelihood $p_\theta(y \mid t, x)$ |
| stop time distribution $q_\phi$ | posterior $p_\theta(t \mid y, x)$ |
| regularization | prior $p(t \mid x)$ |

More specifically, the parameters in the likelihood function and the variational posterior can be optimized using the variational autoencoder (VAE) framework (Kingma & Welling, 2013). Here we consider a generalized version called β-VAE (Higgins et al., 2017), and obtain the optimization objective for a data point $(x, y)$:

$$\mathcal{J}_{\beta\text{-VAE}}(\theta, q_\phi; x, y) := \mathbb{E}_{q_\phi}\log p_\theta(y \mid t, x) - \beta\,\mathrm{KL}(q_\phi(t)\,\|\,p(t \mid x)), \qquad (6)$$

where $\mathrm{KL}(\cdot\|\cdot)$ is the KL divergence. When $\beta = 1$, it becomes the original VAE objective, i.e., the evidence lower bound (ELBO). Now we are ready to present the equivalence relation between the β-VAE objective and the loss defined in Eq. 4; see Appendix A.1 for the proof.

Lemma 1. Under the assumptions that (i) the loss function $\ell$ in Eq. 4 is defined as the negative log-likelihood (NLL), i.e., $\ell(y, x_t; \theta) := -\log p_\theta(y \mid t, x)$, and (ii) the prior $p(t \mid x)$ is a uniform distribution over $t$, minimizing the loss $L$ in Eq. 4 is equivalent to maximizing the β-VAE objective $\mathcal{J}_{\beta\text{-VAE}}$ in Eq. 6.

For classification problems, the cross-entropy loss is aligned with the NLL. For regression problems with mean squared error (MSE) loss, we can define the likelihood as $p_\theta(y \mid t, x) \sim \mathcal{N}(x_t, I)$. Then the NLL of this Gaussian distribution is $-\log p_\theta(y \mid t, x) = \frac{1}{2}\|y - x_t\|_2^2 + C$, which is equivalent to the MSE loss. More generally, we can always define $p_\theta(y \mid t, x) \propto \exp(-\ell(y, x_t; \theta))$. This VAE view allows us to design a two-step procedure to effectively learn $\theta$ and $\phi$ in the predictive model and stopping policy, which is presented in the next section.
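Lemma 1 can also be checked numerically: with the NLL loss and a uniform prior, $\mathcal{J}_{\beta\text{-VAE}} = -L - \beta\log T$, so the two objectives differ only by a constant. The following stand-alone sketch (not part of the training code) verifies this identity for arbitrary values:

```python
import torch

T, beta = 5, 0.5
log_lik = torch.randn(T)                 # log p_theta(y | t, x) for t = 1..T
q = torch.softmax(torch.randn(T), 0)     # an arbitrary stop time distribution q_phi
prior = torch.full((T,), 1.0 / T)        # uniform prior p(t | x)

entropy = -(q * q.log()).sum()
loss_L = (q * (-log_lik)).sum() - beta * entropy          # Eq. 4 with the NLL loss
kl = (q * (q.log() - prior.log())).sum()
j_beta_vae = (q * log_lik).sum() - beta * kl              # Eq. 6

const = beta * torch.log(torch.tensor(float(T)))
print(torch.isclose(j_beta_vae, -loss_L - const))         # True
```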
## 4. Effective Training Algorithm

VAE-based methods perform optimization steps over $\theta$ (the M step, for learning) and $\phi$ (the E step, for inference) alternately until convergence, which has two limitations in our case:

- (i) The alternating training can be slow to converge and requires tuning the training schedule;
- (ii) The inference step for learning $q_\phi$ may suffer from the mode collapse problem, which in this case means $q_\phi$ only captures the time step $t$ with the highest averaged frequency.

To overcome these limitations, we design a two-stage training procedure followed by an optional fine-tuning stage using the variational lower bound in Eq. 6. More specifically:

- Stage I. Find the optimal $\theta$ by maximizing the conditional marginal likelihood when the stop time distribution follows an oracle distribution $q^*_\theta$.
- Stage II. Fix the optimal $\theta$ learned in Stage I, and only learn the distribution $q_\phi$ to mimic the oracle by minimizing the KL divergence between $q_\phi$ and $q^*_\theta$.
- Stage III. (Optional) Fine-tune $\theta$ and $\phi$ jointly towards the joint objective in Eq. 6.

The overall algorithm steps are summarized in Algorithm 1.

Algorithm 1 (Overall Algorithm):
- Randomly initialize $\theta$ and $\phi$.
- Stage I: for itr = 1 to #iterations, sample a batch of data points $B \subset D$ and take an optimization step to update $\theta$ towards the marginal likelihood function defined in Eq. 9.
- Stage II: for itr = 1 to #iterations, sample a batch of data points $B \subset D$ and take an optimization step to update $\phi$ towards the KL divergence defined in Eq. 10.
- Optional Stage III: for itr = 1 to #iterations, sample a batch of data points $B \subset D$ and update both $\theta$ and $\phi$ towards the β-VAE objective in Eq. 6.
- Return $\theta$, $\phi$.

In the following sections, we focus on the derivation of the first two training stages. Then we discuss several methods to further improve the memory and computation efficiency of training.

### 4.1. Oracle Stop Time Distribution

We first give the definition of the oracle stop time distribution $q^*_\theta$. For each fixed $\theta$, we can find a closed-form solution for the optimal $q^*_\theta$ that optimizes the joint objective,

$$q^*_\theta(\cdot \mid y, x) := \arg\max_{q \in \Delta^{T-1}} \mathcal{J}_{\beta\text{-VAE}}(\theta, q; x, y),$$

or, alternatively, $q^*_\theta(\cdot \mid y, x) = \arg\min_{q \in \Delta^{T-1}} L(\theta, q; x, y)$, where $\Delta^{T-1}$ denotes the probability simplex over the $T$ possible stopping steps. Under the mild assumptions in Lemma 1, these two optimizations lead to the same optimal oracle distribution:

$$q^*_\theta(t \mid y, x) = \frac{p_\theta(y \mid t, x)^{1/\beta}}{\sum_{t'=1}^{T} p_\theta(y \mid t', x)^{1/\beta}} \qquad (7)$$

$$\qquad\qquad\;\; = \frac{\exp\!\big(-\tfrac{1}{\beta}\,\ell(y, x_t; \theta)\big)}{\sum_{t'=1}^{T} \exp\!\big(-\tfrac{1}{\beta}\,\ell(y, x_{t'}; \theta)\big)}. \qquad (8)$$

This closed-form solution makes it clear that the oracle picks a step $t$ according to the smallest loss or largest likelihood, with an exploration coefficient $\beta$.

Remark: When $\beta = 1$, $q^*_\theta$ is the same as the posterior distribution $p_\theta(t \mid y, x) \propto p_\theta(y \mid t, x)\,p(t \mid x)$. Note that there are no new parameters in the oracle distribution; instead, it depends on the parameters $\theta$ of the predictive model. Overall, the oracle $q^*_\theta$ is a function of $\theta$, $t$, $y$ and $x$ that has a closed form. Next, we introduce how we use this oracle in the first two training stages.
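In code, the oracle of Eq. 8 is simply a softmax over the per-layer losses scaled by $-1/\beta$. The helper below is a minimal sketch (the function name is ours):

```python
import torch

def oracle_stop_distribution(losses, beta):
    """losses: tensor of shape (T,), with losses[t-1] = ell(y, x_t; theta).
    Returns q*_theta(t | y, x) from Eq. 8. A smaller beta concentrates the
    distribution on the single best layer; a larger beta makes it closer to uniform."""
    return torch.softmax(-losses / beta, dim=0)

losses = torch.tensor([2.0, 0.5, 0.4, 1.0])
print(oracle_stop_distribution(losses, beta=0.1))   # mass concentrates on layer t = 3
```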
### 4.2. Stage I. Predictive Model Learning

In Stage I, we optimize the parameters $\theta$ in the predictive model by taking into account the oracle stop distribution $q^*_\theta$. This step corresponds to the M step for learning $\theta$ by maximizing the marginal likelihood. The difference with the normal M step is that here $q_\phi$ is replaced by the oracle $q^*_\theta$, which gives the optimal stopping distribution, so that the marginal likelihood is independent of $\phi$. More precisely, Stage I finds the optimum of

$$\max_\theta\ \frac{1}{|D|}\sum_{(x, y)\in D} \mathcal{J}_{\beta\text{-VAE}}(\theta, q^*_\theta; x, y), \qquad (9)$$

where the β-VAE objective here is $\mathcal{J}_{\beta\text{-VAE}}(\theta, q^*_\theta; x, y) = \sum_{t=1}^{T} q^*_\theta(t \mid y, x)\log p_\theta(y \mid t, x) - \beta\,\mathrm{KL}(q^*_\theta(t)\,\|\,p(t \mid x))$.

Remark. For experiments that incur higher memory costs (e.g., MAML), we prefer to drop the entropy term $\beta\,\mathrm{KL}(q^*_\theta(t)\,\|\,p(t \mid x))$ from the objective, so that stochastic sampling can be applied to reduce the memory cost (see more details of the efficient training algorithm in Sec. 4.5). Since we can adjust $\beta$ in the oracle $q^*_\theta$ to control the concentration level of the distribution, dropping the entropy term from the Stage I objective does not affect the performance much.

Since $q^*_\theta$ has a differentiable closed-form expression in terms of $\theta$, $x$, $y$ and $t$, the gradient can also propagate through $q^*_\theta$, which is also different from the normal M step. To summarize, in Stage I we learn the predictive model parameters $\theta$ by assuming that the stop time always follows the best stopping distribution, which itself depends on $\theta$. In this case, the learning of $\theta$ has already taken into account the effect of the data-specific stop time. However, we note that the oracle $q^*_\theta$ is not in the form of sequential actions as in Eq. 2, and it requires access to the true label $y$, so it cannot be used at test time. Nevertheless, it plays an important role in obtaining a sequential policy, which will be explained next.
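For reference, a minimal sketch of the Stage I data term for a single example is given below, assuming the per-layer losses $\ell(y, x_t;\theta)$ have already been computed and kept differentiable. Following the remark above, the entropy term is omitted, and the gradient is allowed to flow through the closed-form oracle. (The stochastic-sampling and MAP shortcuts of Sec. 4.5 replace the weighted sum with a single sampled or arg-max layer.)

```python
import torch

def stage_one_loss(per_layer_losses, beta):
    """per_layer_losses: tensor of shape (T,), per_layer_losses[t-1] = ell(y, x_t; theta),
    computed with gradients w.r.t. theta. Returns the negated data term of Eq. 9,
    sum_t q*_theta(t | y, x) * ell(y, x_t; theta), to be minimized."""
    q_star = torch.softmax(-per_layer_losses / beta, dim=0)  # closed-form oracle, Eq. 8
    # no detach: gradients propagate through the oracle as well as through the losses
    return (q_star * per_layer_losses).sum()
```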
### 4.3. Stage II. Imitation With Sequential Policy

In Stage II, we learn the sequential policy $\pi_\phi$ that best mimics the oracle distribution $q^*_\theta$, where $\theta$ is fixed to the optimal value learned in Stage I. The way to do so is to minimize the divergence between the oracle $q^*_\theta$ and the variational stop time distribution $q_\phi$ induced by $\pi_\phi$ (Eq. 3). There are various variational divergence minimization approaches that we can use (Nowozin et al., 2016). For example, a widely used objective for variational inference is the reverse KL divergence:

$$\mathrm{KL}(q_\phi \,\|\, q^*_\theta) = -\sum_{t=1}^{T} q_\phi(t)\log q^*_\theta(t \mid y, x) - H(q_\phi).$$

Remark. We write $q_\phi(t)$ instead of $q_\phi(t \mid x_{1:T}, x)$ for notational simplicity, but $q_\phi$ depends on $x$ and $x_{1:T}$ (Eq. 3).

If we rewrite $q_\phi$ using $\pi_1, \dots, \pi_{T-1}$ as defined in Eq. 3, we find that minimizing the reverse KL is equivalent to finding the optimal policy $\pi_\phi$ in a reinforcement learning (RL) environment, where the state is $x_t$, the action $a_t \sim \pi_t := \pi_\phi(x, x_t)$ is a stop/continue decision, the state transition is determined by $\theta$ and $a_t$, and the reward is defined as

$$r(x_t, a_t; y) := \begin{cases} -\tfrac{1}{\beta}\,\ell(y, x_t; \theta) & \text{if } a_t = 0 \text{ (i.e., stop)}, \\ 0 & \text{if } a_t = 1 \text{ (i.e., continue)}, \end{cases}$$

where $\ell(y, x_t; \theta) = -\log p_\theta(y \mid t, x)$. More details and the derivation are given in Appendix A.2, which shows that minimizing $\mathrm{KL}(q_\phi \,\|\, q^*_\theta)$ is equivalent to solving the following maximum-entropy RL problem:

$$\max_\phi\ \mathbb{E}_{\pi_\phi}\sum_{t=1}^{T}\big[r(x_t, a_t; y) + H(\pi_t)\big].$$

In some related literature, the optimal stopping problem is often formulated as an RL problem (Becker et al., 2019). Above, we bridge our variational inference formulation and the RL-based optimal stopping literature. Although the reverse KL divergence is a widely used objective, it suffers from the mode collapse issue, which in our case may lead to a distribution $q_\phi$ that captures only a common stopping time $t$ that performs best on average for all $x$, instead of a more spread-out stopping time. Therefore, we consider the forward KL divergence,

$$\mathrm{KL}(q^*_\theta \,\|\, q_\phi) = -\sum_{t=1}^{T} q^*_\theta(t \mid y, x)\log q_\phi(t) - H(q^*_\theta), \qquad (10)$$

which is equivalent to the cross-entropy loss, since the term $H(q^*_\theta)$ can be ignored as $\theta$ is fixed in this stage. Experimentally, we find the forward KL leads to better performance.

### 4.4. The Optional Fine-Tuning Stage

It is easy to see that our two-stage training procedure also has an EM flavor. However, with the oracle $q^*_\theta$ incorporated, the training of $\theta$ has already taken into account the effect of the optimal stopping distribution. Therefore, we can save a lot of alternation steps. After the two-stage training, we can fine-tune $\theta$ and $\phi$ jointly towards the β-VAE objective. Experimentally, we find this additional stage does not improve much over the performance obtained after the first two stages.

### 4.5. Implementation Details For Efficient Training

Since both the objective of the oracle learning stage (Eq. 9) and that of the imitation stage (Eq. 10) involve a summation over $T$ layers, the computation and memory costs during training are higher than for standard learning methods. The memory issue is especially important in meta learning. In the following, we introduce several ways of improving the training efficiency.

Fewer output channels. Instead of allowing the model to output $x_t$ at any layer, we can choose a smaller number of output channels that are evenly placed along the layers.

Stochastic sampling in Stage I. A Monte Carlo method can be used to approximate the expectation over $q^*_\theta$ in Stage I. More precisely, for each $(x, y)$ we can randomly sample a layer $t_s \sim q^*_\theta(t \mid y, x)$ from the oracle and only compute $\log p_\theta(y \mid t_s, x)$ at $t_s$, instead of summing over all $t \in [T]$. Note that, in this case, the gradient does not back-propagate through $q^*_\theta(t \mid y, x)$. As explained earlier in Sec. 4.2, the entropy term can be dropped to reduce the memory cost.

MAP estimate in Stage II. Instead of approximating the full distribution $q^*_\theta$, we can approximate the maximum a posteriori (MAP) estimate $\hat{t}(x, y) = \arg\max_{t \in [T]} q^*_\theta(t \mid y, x)$, so that the objective for each sample becomes $\log q_\phi(\hat{t}(x, y))$, which does not involve a summation over $t$. Besides efficiency, we also find this MAP estimate can lead to higher accuracy, by encouraging the learning of $q_\phi$ to focus more on the sample-wise best layer.
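Putting Sections 4.3 and 4.5 together, a minimal sketch of the Stage II imitation loss could look as follows; it reuses the `stop_time_distribution` and `oracle_stop_distribution` helpers sketched earlier, freezes $\theta$ (and hence the per-layer losses), and includes the MAP shortcut as an option:

```python
import torch

def stage_two_loss(pi, per_layer_losses, beta, use_map=False):
    """pi: (T-1,) stopping probabilities produced by pi_phi; per_layer_losses: (T,) losses
    under the frozen theta from Stage I. Returns the forward-KL (cross-entropy)
    imitation loss of Eq. 10, or its MAP variant from Sec. 4.5."""
    q_phi = stop_time_distribution(pi)                        # Eq. 3, differentiable w.r.t. phi
    with torch.no_grad():                                     # theta is fixed in Stage II
        q_star = oracle_stop_distribution(per_layer_losses, beta)
    if use_map:
        t_hat = torch.argmax(q_star)                          # sample-wise best layer
        return -torch.log(q_phi[t_hat])
    # cross-entropy = KL(q* || q_phi) + H(q*), and H(q*) is a constant here
    return -(q_star * torch.log(q_phi)).sum()
```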
## 5. Experiments

We conduct experiments on (i) a learning-based algorithm for sparse recovery, (ii) few-shot meta learning, and (iii) image denoising. The comparison is in an ablation-study fashion, to better examine whether the stopping policy can improve the performance given the same architecture for the predictive model, and whether our training algorithm is more effective than the alternating EM algorithm. In the end, we also discuss our exploration of the image recognition task. A PyTorch implementation of the experiments is released at https://github.com/xinshi-chen/l2stop.

### 5.1. Learning To Optimize: Sparse Recovery

We consider a sparse recovery task which aims at recovering $x^* \in \mathbb{R}^n$ from its noisy linear measurements $b = Ax^* + \epsilon$, where $A \in \mathbb{R}^{m\times n}$, $\epsilon \in \mathbb{R}^m$ is Gaussian white noise, and $m < n$. A popular approach is to model the problem via the LASSO formulation $\min_x \frac{1}{2}\|b - Ax\|_2^2 + \rho\|x\|_1$ and solve it using iterative methods such as the ISTA (Blumensath & Davies, 2008) and FISTA (Beck & Teboulle, 2009) algorithms. We choose the most popular model, named Learned ISTA (LISTA), as the baseline and also as our predictive model. LISTA is a $T$-layer network with update steps

$$x_t = \eta_{\lambda_t}\!\big(W^1_t b + W^2_t x_{t-1}\big), \quad t = 1, \dots, T, \qquad (11)$$

where $\theta = \{(\lambda_t, W^1_t, W^2_t)\}_{t=1}^{T}$ are learnable parameters.
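A sketch of one LISTA layer (Eq. 11) in PyTorch is given below; $\eta_\lambda$ is the soft-thresholding (shrinkage) operator, and the class and parameter names are illustrative choices of ours:

```python
import torch
import torch.nn as nn

def soft_threshold(z, lam):
    """eta_lambda(z) = sign(z) * max(|z| - lambda, 0)."""
    return torch.sign(z) * torch.clamp(torch.abs(z) - lam, min=0.0)

class LISTALayer(nn.Module):
    """One unrolled ISTA step: x_t = eta_{lambda_t}(W1_t b + W2_t x_{t-1}), Eq. 11."""
    def __init__(self, m, n):
        super().__init__()
        self.W1 = nn.Linear(m, n, bias=False)     # maps the measurement b in R^m to R^n
        self.W2 = nn.Linear(n, n, bias=False)     # maps the previous estimate x_{t-1}
        self.lam = nn.Parameter(torch.tensor(0.1))

    def forward(self, b, x_prev):
        return soft_threshold(self.W1(b) + self.W2(x_prev), self.lam)
```

Stacking $T$ such layers, and attaching the stopping policy to the intermediate estimates $x_t$, gives the LISTA-stop model evaluated below.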
Experiment setting. We follow Chen et al. (2018) to generate the samples. The signal-to-noise ratio (SNR) for each sample is uniformly sampled from 20, 30, and 40. The training loss for LISTA is $\sum_{t=1}^{T}\gamma^{T-t}\|x_t - x^*\|_2^2$ with $\gamma \le 1$, which is commonly used for algorithm-based deep learning so that there is a supervision signal for every layer. For ISTA and FISTA, we use the training set to tune the hyperparameters by grid search. See Appendix B.1 for more details.

Table 2. Recovery performances of different algorithms/models (NMSE in dB).

| | SNR mixed | SNR 20 | SNR 30 | SNR 40 |
|---|---|---|---|---|
| FISTA (T = 100) | -18.96 | -16.75 | -20.46 | -20.97 |
| ISTA (T = 100) | -14.66 | -13.99 | -14.99 | -15.07 |
| ISTA (T = 20) | -9.17 | -9.12 | -9.24 | -9.16 |
| FISTA (T = 20) | -11.12 | -10.98 | -11.19 | -11.19 |
| LISTA (T = 20) | -17.58 | -16.52 | -18.16 | -18.29 |
| LISTA-stop (T ≤ 20) | -22.41 | -20.29 | -23.90 | -24.21 |

Recovery performance (Table 2). We report the NMSE (in dB) for each model/algorithm evaluated on 1000 fixed test samples per SNR level. Table 2 reveals that learning-based methods have better recovery performance, especially for the more difficult tasks (i.e., when the SNR is 20). Compared to LISTA, our proposed adaptive-stopping method (LISTA-stop) significantly improves the recovery performance. Also, LISTA-stop with 20 iterations performs better than ISTA and FISTA with 100 iterations, which indicates better convergence.

Stopping distribution. The stop time distribution $q_\phi(t)$ induced by $\pi_\phi$ can be computed via Eq. 3. We report in Fig. 4 the stopping distribution averaged over the test samples, from which we can see that with high probability LISTA-stop terminates the process before arriving at the 20-th iteration.

Figure 4. (a) Stop time distribution $\frac{1}{|D_{\text{test}}|}\sum_{x \in D_{\text{test}}} q_\phi(t \mid x)$ averaged over the test set. (b) Convergence of different algorithms. For LISTA-stop, the NMSE weighted by the stopping distribution $q_\phi$ is plotted; in the first 13 iterations $q_\phi(t) = 0$, so no red dots are plotted.

Convergence comparison. Fig. 4 shows the change in NMSE as the number of iterations increases. Since LISTA-stop outputs results at different iteration steps, it is not meaningful to draw a unified convergence curve. Therefore, we plot the NMSE weighted by the stopping distribution $q_\phi$, i.e.,

$$10\log_{10}\!\left(\frac{\sum_{i=1}^{N} q_\phi(t \mid i)\,\|x_t - x^{*,i}\|_2^2}{\sum_{i=1}^{N} q_\phi(t \mid i)} \Big/ \frac{\sum_{i=1}^{N}\|x^{*,i}\|_2^2}{N}\right),$$

using the red dots. We observe that for LISTA-stop the expected NMSE increases as the number of iterations increases, which might indicate that the problems stopped later are more difficult to solve. Besides, at the 15-th iteration the NMSE in Fig. 4 (b) is the smallest, while the averaged stop probability mass $q_\phi(15)$ in Fig. 4 (a) is the highest.

Table 3. Different algorithms for training LISTA-stop.

| | SNR mixed | SNR 20 | SNR 30 | SNR 40 |
|---|---|---|---|---|
| AEVB algorithm | -21.92 | -19.92 | -23.27 | -23.58 |
| Stage I + II | -22.41 | -20.29 | -23.90 | -24.21 |
| Stage I + II + III | -22.78 | -20.59 | -24.29 | -24.73 |

Ablation study on training algorithms. To show the effectiveness of our two-stage training, in Table 3 we compare the results with the auto-encoding variational Bayes (AEVB) algorithm (Le et al., 2018) that jointly optimizes $\mathcal{F}_\theta$ and $q_\phi$. We observe that the distribution $q_\phi$ in AEVB gradually becomes concentrated on one layer and does not escape this local minimum, making its final result not as good as the results of our two-stage training. Moreover, Stage III does not improve much over the performance of the two-stage training, which in turn shows the effectiveness of the oracle-based two-stage training.

### 5.2. Task-imbalanced Meta Learning

In this section, we perform meta learning experiments in the few-shot learning domain (Ravi & Larochelle, 2017).

Experiment setting. We follow the setting in MAML (Finn et al., 2017) for the few-shot learning tasks. Each task is an N-way classification problem that contains meta-{train, valid, test} sets. On top of that, the macro dataset with multiple tasks is split into train, valid and test sets. We consider the more realistic task-imbalanced setting proposed by Na et al. (2020). Unlike the standard setting, where the meta-train set of each task contains k shots for each class, here we vary the number of observations to perform k1-k2-shot learning, where k1 < k2 are the minimum/maximum numbers of observations per class, respectively. Built on top of MAML, our variant, denoted MAML-stop, learns how many adaptation gradient descent steps are needed for each task. Intuitively, tasks with less training data would prefer fewer gradient-update steps to prevent overfitting. As we mainly focus on the effect of learning to stop, the neural architecture and other hyperparameters are largely the same as in MAML. Please refer to Appendix B.2 for more details.

Table 4. Few-shot classification in the vanilla meta learning setting (Finn et al., 2017), where all tasks have the same number of data points.

| | Omniglot 5-way 1-shot | Omniglot 5-way 5-shot | Omniglot 20-way 1-shot | Omniglot 20-way 5-shot | MiniImagenet 5-way 1-shot | MiniImagenet 5-way 5-shot |
|---|---|---|---|---|---|---|
| MAML | 98.7 ± 0.4% | 99.1 ± 0.1% | 95.8 ± 0.3% | 98.9 ± 0.2% | 48.70 ± 1.84% | 63.11 ± 0.92% |
| MAML-stop | 99.62 ± 0.22% | 99.68 ± 0.12% | 96.05 ± 0.35% | 98.94 ± 0.10% | 49.56 ± 0.82% | 63.41 ± 0.80% |

Dataset. We use the benchmark datasets Omniglot (Lake et al., 2011) and MiniImagenet (Ravi & Larochelle, 2017). Omniglot consists of 20 instances of 1623 characters from 50 different alphabets, while MiniImagenet involves 64 training classes, 12 validation classes, and 24 test classes. We use exactly the same data split as Finn et al. (2017). To construct the imbalanced tasks, we perform 20-way 1-5 shot classification on Omniglot and 5-way 1-10 shot classification on MiniImagenet. The number of observations per class in each meta-test set is 1 and 5 for Omniglot and MiniImagenet, respectively. For evaluation, we construct 600 tasks from the held-out test set for each setting.

Table 5. Task-imbalanced few-shot image classification.

| | Omniglot 20-way, 1-5 shot | MiniImagenet 5-way, 1-10 shot |
|---|---|---|
| MAML | 97.96 ± 0.3% | 57.20 ± 1.1% |
| MAML-stop | 98.45 ± 0.2% | 60.67 ± 1.0% |

Results. Table 5 summarizes the accuracy and the 95% confidence interval on the held-out tasks for each dataset. The maximum number of adaptation gradient descent steps is 10 for both MAML and MAML-stop. We can see that the optimal stopping variant of MAML outperforms the vanilla MAML consistently.
For the more difficult MiniImagenet task, where the imbalance issue is more severe, the accuracy improvement is 3.5%. For completeness, we include the performance in the vanilla meta learning setting, where all tasks have the same number of observations, in Table 4. MAML-stop still achieves comparable or better performance.

### 5.3. Image Denoising

In this section, we perform image denoising experiments. More implementation details are provided in Appendix B.3.

Dataset. The models are trained on BSD500 (400 images) (Arbelaez et al., 2010), validated on BSD12, and tested on BSD68 (Martin et al., 2001). We follow the standard setting in (Zhang et al., 2019; Lefkimmiatis, 2018; Zhang et al., 2017) and add Gaussian noise to the images with a random noise level σ ≤ 55 during the training and validation phases.

Experiment setting. We compare with two deep learning models, DnCNN (Zhang et al., 2017) and UNLNet5 (Lefkimmiatis, 2018), and two traditional methods, BM3D (Dabov et al., 2007) and WNNM (Gu et al., 2014). Since DnCNN is one of the most widely used models for image denoising, we use it as our predictive model. All deep models, including ours, are considered in the blind Gaussian denoising setting, which means the noise level is not given to the model, while BM3D and WNNM require the noise level to be known.

Table 6. PSNR performance comparison. The sign * indicates that noise levels 65 and 75 do not appear in the training set.

| σ | DnCNN-stop | DnCNN | UNLNet5 | BM3D | WNNM |
|---|---|---|---|---|---|
| 35 | 27.61 | 27.60 | 27.50 | 26.81 | 27.36 |
| 45 | 26.59 | 26.56 | 26.48 | 25.97 | 26.31 |
| 55 | 25.79 | 25.71 | 25.64 | 25.21 | 25.50 |
| *65 | 23.56 | 22.19 | - | 24.60 | 24.92 |
| *75 | 18.62 | 17.90 | - | 24.08 | 24.39 |

Results. The performance is evaluated by the mean peak signal-to-noise ratio (PSNR). Table 6 shows that DnCNN-stop performs better than the original DnCNN. In particular, for images with noise levels 65 and 75, which are unseen during the training phase, DnCNN-stop generalizes significantly better than DnCNN alone. Since there is no released code for UNLNet5, its performance numbers are copied from the paper (Lefkimmiatis, 2018), where results are not reported for σ = 65 and 75. For the traditional methods BM3D and WNNM, the test is in the noise-specific setting; that is, the noise level is given to both BM3D and WNNM, so the comparison is not completely fair to learning-based methods in the blind denoising setting.

Figure 5. Denoising results of an image with noise level 65 (ground truth vs. DnCNN-stop). See Appendix B.3.2 for more visualization results.

### 5.4. Image Recognition

We explore the potential of our idea for improving recognition performance on Tiny-ImageNet, using VGG16 (Simonyan & Zisserman, 2014) as the predictive model. With 14 internal classifiers, after Stage I training, if the oracle $q^*_\theta$ is used to determine the stop time $t$, the accuracy of VGG16 can be improved to 83.26%. A similar observation is made in SDN (Kaya et al., 2019), but their loss $\sum_t w_t \ell_t$ depends on very careful hand-tuning of the weight $w_t$ for each layer, while we directly take an expectation using the oracle, which is more principled and leads to higher accuracy (Table 7). However, it turns out to be very hard to mimic the behavior of the oracle $q^*_\theta$ with $\pi_\phi$ in Stage II, either due to the need for a better parametrization of $\pi_\phi$ or for more sophisticated reasons. Our learned $\pi_\phi$ leads to similar accuracy as the heuristic policy in SDN, which becomes the bottleneck in our exploration.
However, based on the large performance gap between the oracle and the original VGG16, our result still points to a potential direction for breaking the performance bottleneck of deep learning on image recognition.

Table 7. Image recognition with the oracle stop distribution.

| VGG16 | SDN training | Our Stage I training |
|---|---|---|
| 58.60% | 77.78% (best layer) | 83.26% (best layer) |

## 6. Conclusion

In this paper, we introduce a generic framework for modelling and training a deep learning model with input-specific depth, which is determined by a stopping policy $\pi_\phi$. Extensive experiments are conducted to demonstrate the effectiveness of both the model and the training algorithm on a wide range of applications. In the future, it will be interesting to see whether other aspects of algorithms can be incorporated into deep learning models, either to improve the performance or for better theoretical understanding.

## Acknowledgement

We would like to thank the anonymous reviewers for providing constructive feedback. This work is supported in part by NSF grants CDS&E-1900017 D3SC, CCF-1836936 FMitF, IIS-1841351, CAREER IIS-1350983 to L.S., and grants from King Abdullah University of Science and Technology, under award numbers BAS/1/1624-01, FCC/1/1976-18-01, FCC/1/1976-23-01, FCC/1/1976-25-01, FCC/1/1976-26-01, REI/1/0018-01-01, and URF/1/4098-01-01.

## References

Andrychowicz, M., Denil, M., Gomez, S., Hoffman, M. W., Pfau, D., Schaul, T., Shillingford, B., and De Freitas, N. Learning to learn by gradient descent by gradient descent. In Advances in Neural Information Processing Systems, pp. 3981–3989, 2016.

Arbelaez, P., Maire, M., Fowlkes, C., and Malik, J. Contour detection and hierarchical image segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 33(5):898–916, 2010.

Beck, A. and Teboulle, M. A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM Journal on Imaging Sciences, 2(1):183–202, 2009.

Becker, S., Cheridito, P., and Jentzen, A. Deep optimal stopping. Journal of Machine Learning Research, 20(74):1–25, 2019.

Belanger, D., Yang, B., and McCallum, A. End-to-end learning for structured prediction energy networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 429–439. JMLR.org, 2017.

Blumensath, T. and Davies, M. E. Iterative thresholding for sparse approximations. Journal of Fourier Analysis and Applications, 14(5-6):629–654, 2008.

Borgerding, M., Schniter, P., and Rangan, S. AMP-inspired deep networks for sparse linear inverse problems. IEEE Transactions on Signal Processing, 65(16):4293–4308, 2017.

Ceci, C. and Bassan, B. Mixed optimal stopping and stochastic control problems with semicontinuous final reward for diffusion processes. Stochastics and Stochastic Reports, 76(4):323–337, 2004.

Chen, X., Liu, J., Wang, Z., and Yin, W. Theoretical linear convergence of unfolded ISTA and its practical weights and thresholds. In Advances in Neural Information Processing Systems, pp. 9061–9071, 2018.

Chen, X., Dai, H., and Song, L. Particle flow Bayes' rule. In International Conference on Machine Learning, pp. 1022–1031, 2019.

Chen, X., Li, Y., Umarov, R., Gao, X., and Song, L. RNA secondary structure prediction by learning unrolled algorithms. arXiv preprint arXiv:2002.05810, 2020.

Dabov, K., Foi, A., Katkovnik, V., and Egiazarian, K. Image denoising by sparse 3-D transform-domain collaborative filtering. IEEE Transactions on Image Processing, 16(8):2080–2095, 2007.

Domke, J. Parameter learning with truncated message-passing. In CVPR 2011, pp. 2937–2943. IEEE, 2011.
Dumitrescu, R., Reisinger, C., and Zhang, Y. Approximation schemes for mixed optimal stopping and control problems with nonlinear expectations and jumps. arXiv preprint arXiv:1803.03794, 2018.

Finn, C., Abbeel, P., and Levine, S. Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 1126–1135. JMLR.org, 2017.

Gregor, K. and LeCun, Y. Learning fast approximations of sparse coding. In Proceedings of the 27th International Conference on Machine Learning, pp. 399–406. Omnipress, 2010.

Gu, S., Zhang, L., Zuo, W., and Feng, X. Weighted nuclear norm minimization with application to image denoising. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2862–2869, 2014.

Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., and Lerchner, A. beta-VAE: Learning basic visual concepts with a constrained variational framework. ICLR, 2(5):6, 2017.

Huang, G., Chen, D., Li, T., Wu, F., van der Maaten, L., and Weinberger, K. Multi-scale dense networks for resource efficient image classification. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Hk2aImxAb.

Ingraham, J., Riesselman, A., Sander, C., and Marks, D. Learning protein structure with a differentiable simulator. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=Byg3y3C9Km.

Kaya, Y., Hong, S., and Dumitras, T. Shallow-deep networks: Understanding and mitigating network overthinking. In International Conference on Machine Learning, pp. 3301–3310, 2019.

Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.

Lake, B., Salakhutdinov, R., Gross, J., and Tenenbaum, J. One shot learning of simple visual concepts. In Proceedings of the Annual Meeting of the Cognitive Science Society, volume 33, 2011.

Le, T. A., Igl, M., Rainforth, T., Jin, T., and Wood, F. Auto-encoding sequential Monte Carlo. In International Conference on Learning Representations, 2018.

Lee, Y. and Choi, S. Gradient-based meta-learning with learned layerwise metric and subspace. arXiv preprint arXiv:1801.05558, 2018.

Lefkimmiatis, S. Universal denoising networks: A novel CNN architecture for image denoising. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3204–3213, 2018.

Li, K. and Malik, J. Learning to optimize. arXiv preprint arXiv:1606.01885, 2016.

Li, Z., Zhou, F., Chen, F., and Li, H. Meta-SGD: Learning to learn quickly for few-shot learning. arXiv preprint arXiv:1707.09835, 2017.

Liu, J., Chen, X., Wang, Z., and Yin, W. ALISTA: Analytic weights are as good as learned weights in LISTA. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=B1lnzn0ctQ.

Martin, D., Fowlkes, C., Tal, D., and Malik, J. A database of human segmented natural images and its application to evaluating segmentation algorithms and measuring ecological statistics. In Proceedings of the Eighth IEEE International Conference on Computer Vision (ICCV 2001), volume 2, pp. 416–423. IEEE, 2001.

Metzler, C., Mousavi, A., and Baraniuk, R. Learned D-AMP: Principled neural network based compressive image recovery. In Advances in Neural Information Processing Systems, pp. 1772–1783, 2017.

Na, D., Lee, H. B., Lee, H., Kim, S., Park, M., Yang, E., and Hwang, S. J. Learning to balance: Bayesian meta-learning for imbalanced and out-of-distribution tasks. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=rkeZIJBYvr.
Nowozin, S., Cseke, B., and Tomioka, R. f-GAN: Training generative neural samplers using variational divergence minimization. In Advances in Neural Information Processing Systems, pp. 271–279, 2016.

Oreshkin, B., López, P. R., and Lacoste, A. TADAM: Task dependent adaptive metric for improved few-shot learning. In Advances in Neural Information Processing Systems, pp. 721–731, 2018.

Pham, H. Optimal stopping of controlled jump diffusion processes: a viscosity solution approach. In Journal of Mathematical Systems, Estimation and Control. Citeseer, 1998.

Qiao, S., Liu, C., Shen, W., and Yuille, A. L. Few-shot image recognition by predicting parameters from activations. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7229–7238, 2018.

Ravi, S. and Larochelle, H. Optimization as a model for few-shot learning. 2017.

Shiryaev, A. N. Optimal stopping rules, volume 8. Springer Science & Business Media, 2007.

Shrivastava, H., Chen, X., Chen, B., Lan, G., Aluru, S., Liu, H., and Song, L. GLAD: Learning sparse graph recovery. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=BkxpMTEtPB.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Sun, J., Li, H., Xu, Z., et al. Deep ADMM-Net for compressive sensing MRI. In Advances in Neural Information Processing Systems, pp. 10–18, 2016.

Teerapittayanon, S., McDanel, B., and Kung, H.-T. BranchyNet: Fast inference via early exiting from deep neural networks. In 2016 23rd International Conference on Pattern Recognition (ICPR), pp. 2464–2469. IEEE, 2016.

Yakar, T. B., Litman, R., Sprechmann, P., Bronstein, A. M., and Sapiro, G. Bilevel sparse models for polyphonic music transcription. In ISMIR, pp. 65–70, 2013.

Zamir, A. R., Wu, T.-L., Sun, L., Shen, W. B., Shi, B. E., Malik, J., and Savarese, S. Feedback networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1308–1317, 2017.

Zhang, J. and Ghanem, B. ISTA-Net: Interpretable optimization-inspired deep network for image compressive sensing. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1828–1837, 2018.

Zhang, K., Zuo, W., Chen, Y., Meng, D., and Zhang, L. Beyond a Gaussian denoiser: Residual learning of deep CNN for image denoising. IEEE Transactions on Image Processing, 26(7):3142–3155, 2017.

Zhang, X., Lu, Y., Liu, J., and Dong, B. Dynamically unfolding recurrent restorer: A moving endpoint control method for image restoration. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=SJfZKiC5FX.