Knowledge Distillation Performs Partial Variance Reduction

Mher Safaryan (IST Austria, mher.safaryan@ista.ac.at), Alexandra Peste (IST Austria, alexandra.peste@ista.ac.at), Dan Alistarh (IST Austria, dan.alistarh@ista.ac.at)

Knowledge distillation is a popular approach for enhancing the performance of student models, with lower representational capacity, by taking advantage of more powerful teacher models. Despite its apparent simplicity and widespread use, the underlying mechanics behind knowledge distillation (KD) are still not fully understood. In this work, we shed new light on the inner workings of this method, by examining it from an optimization perspective. We show that, in the context of linear and deep linear models, KD can be interpreted as a novel type of stochastic variance reduction mechanism. We provide a detailed convergence analysis of the resulting dynamics, which holds under standard assumptions for both strongly-convex and non-convex losses, showing that KD acts as a form of partial variance reduction, which can reduce the stochastic gradient noise, but may not eliminate it completely, depending on the properties of the teacher model. Our analysis puts further emphasis on the need for careful parametrization of KD, in particular w.r.t. the weighting of the distillation loss, and is validated empirically on both linear models and deep neural networks.

1 Introduction

Knowledge Distillation (KD) [13, 3] is a standard tool for transferring information between a machine learning model of lower representational capacity, usually called the student, and a more accurate and powerful teacher model. In the context of classification using neural networks, it is common to consider the student to be a smaller network [2], whereas the teacher is a network that is larger and more computationally-heavy, but also more accurate. Assuming a supervised classification task, distillation consists in training the student to minimize the cross-entropy with respect to the teacher's logits on every given sample, in addition to minimizing the standard cross-entropy loss with respect to the ground truth labels. Since its introduction [3], distillation has been developed and applied in a wide variety of settings, from obtaining compact high-accuracy encodings of model ensembles [14], to boosting the accuracy of compressed models [50, 39, 32], to reinforcement learning [51, 43, 36, 5, 7, 46] and learning with privileged information [52].

Given its apparent simplicity, there has been significant interest in finding explanations for the effectiveness of distillation [2, 14, 38]. For instance, one hypothesis [2, 14] is that the smoothed labels resulting from distillation present the student with a decision surface that is easier to learn than the one presented by the categorical (one-hot) outputs. Another hypothesis [2, 14, 52] starts from the observation that the teacher's outputs have higher entropy than the ground truth labels, and therefore higher information content. Despite this work, we still have a limited analytical understanding regarding why knowledge distillation is so effective [38]. Specifically, very little is known about the interplay between distillation and stochastic gradient descent (SGD), which is the standard optimization setting in which this method is applied.

1.1. Contributions.
In this paper, we investigate the impact of knowledge distillation on the convergence of the student model when optimizing via SGD. Our approach starts from a simple re-formulation of distillation in the context of gradient-based optimization, which allows us to connect KD to stochastic variance reduction techniques, such as SVRG [16], which are popular in stochastic optimization. Our results apply both to self-distillation, where KD is applied while training relative to an earlier version of the same model, as well as to distillation for compression, where a compressed model leverages outputs from an uncompressed one during training. In a nutshell, in both cases, we show that SGD with distillation preserves the convergence speed of vanilla SGD, but that the teacher's outputs serve to reduce the gradient variance term, proportionally to the distance between the teacher model and the true optimum. Since the teacher model may not be at an optimum, this means that variance reduction is only partial, as distillation may not completely eliminate noise, and in fact may introduce bias or even increase variance for certain parameter values. Our analysis precisely characterizes this effect, which can be controlled by the weight of the distillation loss, and is validated empirically for both linear models and deep networks.

1.2. Results Overview.

To illustrate our results, we consider the case of self-distillation [59], which is a popular supervised training technique in which both the student model $x \in \mathbb{R}^d$ and the teacher model $\theta \in \mathbb{R}^d$ have the same structure and dimensionality. The process starts from a teacher model $\theta$ trained using regular cross-entropy loss, and then trains the student model with respect to a weighted combination of cross-entropy w.r.t. the data labels, and cross-entropy w.r.t. the teacher's outputs. This process is often executed over multiple iterations, in the sense that the student at iteration $k$ becomes the teacher for iteration $k+1$, and so on.

Our first observation is that, in the case of self-distillation with teacher weight $1 \ge \lambda \ge 0$, the gradient of the distilled student model on a sample $i$, denoted by $\nabla_x f_i(x \mid \theta, \lambda)$ at a given iteration, can simply be written as
$$\nabla_x f_i(x \mid \theta, \lambda) = \nabla f_i(x) - \lambda\,\nabla f_i(\theta),$$
where $\nabla f_i(x)$ and $\nabla f_i(\theta)$ are the student's and teacher's standard gradients on sample $i$, respectively. This expression is exact for linear models and generalized linear networks, and we provide evidence that it holds approximately for general classifiers. With this re-formulation, self-distillation can be interpreted as a truncated form of the classic SVRG iteration [16], which never employs full (non-stochastic) gradients.

Our main technical contribution is in providing convergence guarantees for iterations of this form, for both strongly quasi-convex and non-convex functions under the Polyak-Łojasiewicz (PL) condition. Our analysis covers both self-distillation (where the teacher is a partially-trained version of the same model), and more general distillation, in which the student is a compressed version of the teacher model. The convergence rates we provide are similar in nature to SGD convergence, with one key difference: the rate dependence on the gradient variance $\sigma^2$ is dampened by a term depending on the gap between the teacher model and an optimum. Intuitively, this says that, if the teacher's accuracy is poor, then distillation will not have any positive effect. However, the better-trained the teacher is, the more it can help reduce the student's variance during optimization.
Importantly, this effect occurs even if the teacher is not trained to near-zero loss, thus motivating the usefulness of the teacher in self-distillation. Our analysis highlights the importance of the distillation weight, as a means to maximize the positive effects of distillation: for linear models, we can even derive a closed-form solution for the optimal distillation weight. We validate our findings experimentally for both linear models and deep networks, confirming the effects predicted by the analysis.

2 Related Work

We now provide an overview of some of the relevant related work regarding KD. Knowledge distillation, in its current formulation, was introduced in the seminal work of [13], which showed that the predictive performance of a model can be improved if it is trained to match the soft targets produced by a large and accurate model. This observation has motivated the adoption of KD as a standard mechanism to enhance the training of neural networks in a wide range of settings, such as compression [39] and learning with noisy labels [26], and it has also become an essential tool in training accurate compressed versions of large models [39, 49, 54]. Despite these important practical advantages, providing a thorough theoretical justification for the mechanisms driving the success of KD has so far been elusive.

Several works have focused on studying KD from different theoretical perspectives. For example, Lopez-Paz et al. [28] connected distillation with privileged information [52] by proposing the notion of generalized distillation, and presented an intuitive explanation for why generalized distillation should allow for higher sample efficiency, relative to regular training. Phuong and Lampert [38] studied distillation from the perspective of generalization bounds in the case of linear and deep linear models. They identify three factors which influence the success of KD: (1) the geometry of the data; (2) the fact that the expected risk of the student always decreases with more data; (3) the fact that gradient descent finds a favorable minimum of the distillation objective. By contrast to all these previous references, our work studies the impact of distillation on stochastic (SGD-based) optimization. More broadly, there has been extensive work on connecting KD with other areas in learning. Dao et al. [6] examined links between KD and semi-parametric inference, whereas Li et al. [61] performed an empirical study on KD in the context of learning with noisy labels. Yuan et al. [58] and Sultan [48] investigated the relationship between KD and label smoothing, a popular heuristic for training neural networks, showing both similarities and substantive differences. In this context, our work is the first to signal the connection between KD and variance reduction, as well as to investigate the convergence of KD in the context of stochastic optimization.

3 Knowledge Distillation

3.1. Background. Assume we are given a finite dataset $\{(a_n, b_n) \mid n = 1, 2, \ldots, N\}$, where the inputs $a_n \in \mathcal{A}$ (e.g., vectors from $\mathbb{R}^d$) and the outputs $b_n \in \mathcal{B}$ (e.g., categorical labels or real numbers). Consider a set of models $\mathcal{F} = \{\phi_x : \mathcal{A} \to \mathcal{B} \mid x \in \mathcal{P} \subseteq \mathbb{R}^d\}$ with a fixed neural network architecture parameterized by the vector $x$. Depending on the supervised learning task and the class $\mathcal{F}$ of models, we define a loss function $\ell : \mathcal{B} \times \mathcal{B} \to \mathbb{R}_+$ in order to measure the performance of the model. In particular, the loss associated with a data point $(a_n, b_n)$ and model $\phi_x \in \mathcal{F}$ would be $\ell(\phi_x(a_n), b_n)$.
In this framework, the standard Empirical Risk Minimization (ERM) takes the following form:
$$\min_{x \in \mathbb{R}^d}\ \frac{1}{N}\sum_{n=1}^N \ell(\phi_x(a_n), b_n). \qquad (1)$$
In the objective above, the model $\phi_x$ is trained to match the true outputs $b_n$ given in the training dataset. Suppose that, in addition to the true labels $b_n$, we have access to the outputs $\Phi_\theta(a_n) \in \mathcal{B}$ of a sufficiently well-trained and perhaps more complicated teacher model, for each input $a_n \in \mathcal{A}$. Similar to the student model $\phi_x$, the teacher model $\Phi_\theta$ maps $\mathcal{A} \to \mathcal{B}$, but it can have a different architecture, with more layers and parameters. The fundamental question is how to exploit the additional knowledge of the teacher $\Phi_\theta$ to facilitate the training of a more compact student model $\phi_x$ with lower representational capacity. Knowledge Distillation with parameter $\lambda \in [0, 1]$ from teacher model $\Phi_\theta$ to student model $\phi_x$ is the following modification of the objective (1):
$$\min_{x \in \mathbb{R}^d}\ \frac{1}{N}\sum_{n=1}^N \Big[ (1-\lambda)\,\ell(\phi_x(a_n), b_n) + \lambda\,\ell(\phi_x(a_n), \Phi_\theta(a_n)) \Big]. \qquad (2)$$
Here we customize the loss, penalizing dissimilarities from the teacher's feedback $\Phi_\theta(a_n)$ in addition to the true outputs $b_n$. In case $\ell$ is linear in the second argument (e.g., cross-entropy loss), the problem simplifies into
$$\min_{x \in \mathbb{R}^d}\ \frac{1}{N}\sum_{n=1}^N \ell\big(\phi_x(a_n),\ (1-\lambda) b_n + \lambda \Phi_\theta(a_n)\big), \qquad (3)$$
which is a standard ERM (1) with modified soft labels $s_n := (1-\lambda) b_n + \lambda \Phi_\theta(a_n)$ as the target.

3.2. Self-distillation. As already mentioned, the teacher model $\Phi_\theta$ can have a more complicated neural network architecture and a potentially larger parameter space $\theta \in \mathcal{Q} \subseteq \mathbb{R}^D$. In particular, $\Phi_\theta$ does not have to be from the same set of models $\mathcal{F}$ as the student model $\phi_x$. The special case when both the student and the teacher share the same structure/architecture is called self-distillation [33, 60], which is the key setup for our work. In this case, the teacher model $\Phi_\theta \equiv \phi_\theta \in \mathcal{F}$ with $\theta \in \mathbb{R}^d$ (i.e., $\mathcal{Q} = \mathcal{P}$, $D = d$), and the corresponding distillation objective would be
$$\min_{x \in \mathbb{R}^d}\ \frac{1}{N}\sum_{n=1}^N \Big[ (1-\lambda)\,\ell(\phi_x(a_n), b_n) + \lambda\,\ell(\phi_x(a_n), \phi_\theta(a_n)) \Big]. \qquad (4)$$

Algorithm 1 Knowledge Distillation via SGD
1: Input: learning rate $\gamma > 0$, initial student model $x_0 \in \mathcal{P} \subseteq \mathbb{R}^d$
2: for each distillation iteration $m$ do
3:   choose a teacher model $\theta_m \in \mathcal{Q} \subseteq \mathbb{R}^D$ and distillation weight $\lambda_m \in [0, 1]$ (e.g., see Sec. 5)
4:   for each training iteration $t$ do
5:     sample an unbiased mini-batch $\xi \sim \mathcal{D}$ from the train set
6:     compute the distillation gradient $\nabla f_\xi(x_t \mid \theta_m, \lambda_m) = (1-\lambda_m)\,\nabla f_\xi(x_t) + \lambda_m\,\nabla f_\xi(x_t \mid \theta_m)$
7:     update the student model via $x_{t+1} = x_t - \gamma\,\nabla f_\xi(x_t \mid \theta_m, \lambda_m)$
8:   end for
9: end for

Our primary focus in this work will be the self-distillation objective above. For convenience, let $f_n(x) := \ell(\phi_x(a_n), b_n)$ be the prediction loss with respect to the output $b_n$, $f_n(x \mid \theta) := \ell(\phi_x(a_n), \phi_\theta(a_n))$ be the loss with respect to the teacher's output probabilities, and $f_n(x \mid \theta, \lambda) := (1-\lambda) f_n(x) + \lambda f_n(x \mid \theta)$ be the distillation loss. See Algorithm 1 for an illustration.
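For concreteness, the following is a minimal PyTorch sketch of the inner loop of Algorithm 1 in the soft-label form (3): the mini-batch target mixes the one-hot labels with the teacher's output probabilities using the weight λ. The names (`student`, `teacher`, `train_loader`) and hyper-parameters are illustrative placeholders, not taken from the paper.

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, batch, optimizer, lam=0.5):
    """One inner step of Algorithm 1, written in the soft-label form (3)."""
    inputs, labels = batch                                     # labels: integer class indices
    with torch.no_grad():
        teacher_probs = F.softmax(teacher(inputs), dim=1)      # Φ_θ(a_n)
    one_hot = F.one_hot(labels, teacher_probs.shape[1]).float()
    soft_labels = (1 - lam) * one_hot + lam * teacher_probs    # s_n = (1-λ) b_n + λ Φ_θ(a_n)

    logits = student(inputs)
    # cross-entropy against the soft labels: -Σ_k s_{n,k} log softmax(ψ_n(x))_k
    loss = -(soft_labels * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Usage sketch (assumes `student`, `teacher`, and `train_loader` are defined):
# opt = torch.optim.SGD(student.parameters(), lr=0.1)
# for epoch in range(num_epochs):
#     for batch in train_loader:
#         distillation_step(student, teacher, batch, opt, lam=0.5)
```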
3.3. Distillation Gradient. As the first step towards understanding how self-distillation affects the training procedure, we analyze the modified loss landscape (4) via stochastic gradients of (1) and (4). In particular, we put forward the following proposition regarding the form of the distillation gradient in terms of gradients of (1).

Proposition 1 (Distillation Gradient). For a student model $x \in \mathbb{R}^d$, teacher model $\theta \in \mathbb{R}^d$ and distillation weight $\lambda$, the distillation gradient corresponding to self-distillation (4) is given by
$$\nabla_x f_n(x \mid \theta, \lambda) = \nabla f_n(x) - \lambda\,\nabla f_n(\theta). \qquad (5)$$

Before justifying this proposition formally, let us provide some intuition behind the expression (5) and its connection to the distillation loss (4). First, the gradient expression (5) suggests that the teacher has little or no effect on the data points classified correctly with high confidence (i.e., those for which $f_n(\theta)$ is close to $0$ or $\phi_\theta(a_n)$ is close to $b_n$). In other words, the more accurate the teacher is, the less it can affect the learning process. In the extreme case, a perfect or overfitted teacher (one for which $f_n(\theta) = 0$ or $\phi_\theta(a_n) = b_n$ for all $n$) will have no effect. In fact, this is expected since, in this case, problems (1) and (4) coincide. Alternatively, if the teacher is not perfect, then the modified objective (4) intuitively suggests that the learning dynamics of the student model are adjusted based on the teacher's knowledge. As we can see in (5), the adjustment from the teacher is enforced by the $\lambda\,\nabla f_n(\theta)$ term. It is worth mentioning that the direction of the distillation gradient $\nabla_x f_n(x \mid \theta, \lambda)$ can be different from the usual gradient direction $\nabla f_n(x)$ due to the influence of the teacher. Thus, Proposition 1 explicitly shows how the teacher guides the student by adjusting its stochastic gradient.

As we will show later, the distillation gradient (5) leads to partial variance reduction because of the additional $\lambda\,\nabla f_\xi(\theta)$ term. When chosen properly (distillation weight $\lambda$ and proximity of $\theta$ to the optimal solution $x^*$), this additional stochastic gradient is capable of adjusting the student's stochastic gradient, since both are computed using the same batch from the train data. In other words, both gradients have the same source of randomness, which makes partial cancellations feasible.

Linear regression. To support Proposition 1 rigorously, consider the simple setup of linear regression. Let $\mathcal{A} = \mathbb{R}^d$, $\mathcal{P} = \mathbb{R}^{d+1}$, $\phi_x(a) = x^\top a \in \mathbb{R}$, where $a = [a\ \ 1] \in \mathbb{R}^{d+1}$ is the input vector in the lifted space (to include the bias term), $\mathcal{B} = \mathbb{R}$, and the loss is defined by $\ell(t, t') = (t - t')^2$ for all $t, t' \in \mathbb{R}$. Thus, based on (4), we have $f_n(x \mid \theta, \lambda) = (1-\lambda)(x^\top a_n - b_n)^2 + \lambda(x^\top a_n - \theta^\top a_n)^2$, from which we compute its gradient straightforwardly as
$$\nabla_x f_n(x \mid \theta, \lambda) = 2(1-\lambda)(x^\top a_n - b_n)\,a_n + 2\lambda(x^\top a_n - \theta^\top a_n)\,a_n = 2(x^\top a_n - b_n)\,a_n - 2\lambda(\theta^\top a_n - b_n)\,a_n = \nabla f_n(x) - \lambda\,\nabla f_n(\theta).$$
Hence, the distillation gradient for linear regression tasks has the form (5).

Classification with a single hidden layer. We can extend the above argument and derivation to a $K$-class classification model with one hidden layer that has soft-max as the last layer, i.e., $\phi_X(a_n) = \sigma(X^\top a_n) \in \mathbb{R}^K$, where $X = [x_1\ x_2\ \ldots\ x_K] \in \mathbb{R}^{d \times K}$ are the model parameters, $a_n \in \mathcal{A} = \mathbb{R}^d$ is the input data and $\sigma$ is the soft-max function. Then, we show in Appendix B.2 that for all $k = 1, 2, \ldots, K$ it holds that
$$\nabla_{x_k} f_n(X \mid \Theta, \lambda) = \nabla_{x_k} f_n(X) - \lambda\,\nabla_{\theta_k} f_n(\Theta),$$
where $\Theta = [\theta_1\ \theta_2\ \ldots\ \theta_K] \in \mathbb{R}^{d \times K}$ are the teacher's parameters.
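As a quick numerical sanity check of Proposition 1, the snippet below compares, for a randomly initialized soft-max linear classifier, the autograd gradient of the distillation loss (4) against the closed form (5). The dimensions, seed, and λ value are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, K, lam = 20, 5, 0.7
a = torch.randn(d)                                   # one input a_n
b = torch.tensor(2)                                  # true label b_n
X = torch.randn(d, K, requires_grad=True)            # student parameters
Theta = torch.randn(d, K, requires_grad=True)        # teacher parameters

def ce(params, target_probs):
    # cross-entropy between softmax(params^T a) and an arbitrary target distribution
    return -(target_probs * F.log_softmax(params.t() @ a, dim=0)).sum()

one_hot = F.one_hot(b, K).float()
teacher_probs = F.softmax(Theta.t() @ a, dim=0).detach()

# gradient of the distillation loss (4): (1-λ) ℓ(student, label) + λ ℓ(student, teacher)
distill_loss = (1 - lam) * ce(X, one_hot) + lam * ce(X, teacher_probs)
(g_distill,) = torch.autograd.grad(distill_loss, X)

# closed form (5): ∇f_n(X) - λ ∇f_n(Θ), both losses taken w.r.t. the true label
(g_student,) = torch.autograd.grad(ce(X, one_hot), X)
(g_teacher,) = torch.autograd.grad(ce(Theta, one_hot), Theta)

print(torch.allclose(g_distill, g_student - lam * g_teacher, atol=1e-5))  # expected: True
```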
Generic classification. Proposition 1 will not hold precisely for arbitrary deep non-linear neural networks. However, careful calculations reveal that, in general, the distillation gradient takes a form similar to (5). Detailed derivations are deferred to Appendix B.3; here we provide a sketch. Consider an arbitrary neural network architecture for classification that ends with a soft-max layer, i.e., $\phi_x(a_n) = \sigma(\psi_n(x))$, where $a_n \in \mathcal{A}$ is the input data, $\psi_n(x)$ are the produced logits with respect to the model parameters $x$, and $\sigma$ is the soft-max function. Denote by $\varphi_n(z) := \ell(\sigma(z), b_n)$ the loss associated with logits $z$ and the true label $b_n$. In words, $\psi_n$ gives the logits from the input data, while $\varphi_n$ computes the loss from given logits. Then, the representation for the loss function is $f_n(x) = \varphi_n(\psi_n(x))$. We show in Appendix B.3 that the distillation gradient can be written as
$$\nabla_x f_n(x \mid \theta, \lambda) = \mathbf{J}_{\psi_n}(x)\,\big(\nabla\varphi_n(\psi_n(x)) - \lambda\,\nabla\varphi_n(\psi_n(\theta))\big) = \frac{\partial \psi_n(x)}{\partial x}\,\frac{\partial f_n(x)}{\partial \psi_n(x)} - \lambda\,\frac{\partial \psi_n(x)}{\partial x}\,\frac{\partial f_n(\theta)}{\partial \psi_n(\theta)},$$
where $\mathbf{J}_{\psi_n}(x) := \frac{\partial \psi_n(x)}{\partial x} = [\nabla\psi_{n,1}(x)\ \nabla\psi_{n,2}(x)\ \ldots\ \nabla\psi_{n,K}(x)] \in \mathbb{R}^{d \times K}$ is the Jacobian of the vector-valued logit function $\psi_n$. Notice that the first term $\frac{\partial \psi_n(x)}{\partial x}\frac{\partial f_n(x)}{\partial \psi_n(x)}$ coincides with the student's gradient $\nabla_x f_n(x) = \frac{\partial f_n(x)}{\partial x}$. However, the second term $\frac{\partial \psi_n(x)}{\partial x}\frac{\partial f_n(\theta)}{\partial \psi_n(\theta)}$ differs from the teacher's gradient, as the partial derivatives of the logits are taken with respect to the student model.

Despite these differences in the case of deep non-linear models, we observe that the distillation gradient defined by Equation 5 can approximate well the true distillation gradient from Equation 3. Specifically, we consider a fully connected neural network with one hidden layer and ReLU activation [34], trained on the MNIST dataset [24] using regular self-distillation from an SGD teacher, with a fixed learning rate, and SGD with weight decay and no momentum. At each training iteration we compute the cosine similarity between the gradient of the distillation loss and the approximation from Equation 5, and we average the results across each epoch.

Figure 1: Cosine similarity, $\ell_2$ distance and SNR (i.e., $\ell_2$ distance over the gradient norm of standard KD) statistics between the true and approximated distillation gradient for a neural network during training, for $\lambda \in \{0.1, 0.3, 0.5, 0.7, 0.9\}$. As predicted, larger $\lambda$ leads to larger differences, although gradients remain well-correlated.

The results presented in Figure 1 show that the distillation gradient approximates well the true distillation gradient. Moreover, the behavior is monotonic in the distillation weight $\lambda$ (higher similarity for smaller $\lambda$), as predicted by the analysis above, and it stabilizes as training progresses. The decrease of cosine similarity can be explained as follows: at the beginning the cosine similarity is high (and the SNR is low), since we start from the same model. Then, initial perturbations caused by either the KD or the modified KD gradient do not cause big shifts (the teacher has enough confidence and small gradients). These perturbations accumulate over training, leading to decreased cosine similarity, and eventually stabilize.

4 Convergence Theory for Self-Distillation

4.1. Optimization Setup and Assumptions. We abstract the standard ERM problem (1) into a stochastic optimization problem of the form
$$\min_{x \in \mathbb{R}^d}\ \Big\{ f(x) := E_{\xi \sim \mathcal{D}}\,[f_\xi(x)] \Big\}, \qquad (6)$$
where $f_\xi(x)$ is the loss associated with the data sample $\xi \sim \mathcal{D}$ given model parameters $x \in \mathbb{R}^d$. For instance, if $\xi = (a_n, b_n)$ is a single data point, then the corresponding loss is $f_\xi(x) = \ell(\phi_x(a_n), b_n)$. The goal is to find parameters $x$ minimizing the risk $f(x)$. To solve the problem (6), we employ Stochastic Gradient Descent (SGD). Based on Section 3, applying SGD to the problem (6) with self-distillation amounts to the following optimization updates in the parameter space:
$$x_{t+1} = x_t - \gamma\big(\nabla f_\xi(x_t) - \lambda\,\nabla f_\xi(\theta)\big), \qquad (7)$$
with initialization $x_0 \in \mathbb{R}^d$, step size or learning rate $\gamma > 0$, teacher model parameters $\theta \in \mathbb{R}^d$, and distillation weight $\lambda$.
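The reformulated update (7) can be implemented directly from two stochastic gradients on the same mini-batch, without forming soft labels. Below is a sketch in PyTorch for the self-distillation case, assuming the (frozen) teacher has the same architecture and parameter ordering as the student and that its parameters keep `requires_grad=True` so that a backward pass populates their gradients; all names are placeholders.

```python
import torch

def self_distillation_step(student, teacher, loss_fn, batch, lam, gamma):
    """One iterate of (7): x_{t+1} = x_t - γ (∇f_ξ(x_t) - λ ∇f_ξ(θ)),
    with both gradients evaluated on the same mini-batch ξ."""
    inputs, labels = batch

    # student's stochastic gradient ∇f_ξ(x_t) on the hard labels
    student.zero_grad()
    loss_fn(student(inputs), labels).backward()
    student_grads = [p.grad.detach().clone() for p in student.parameters()]

    # teacher's stochastic gradient ∇f_ξ(θ) on the same batch and labels
    # (the teacher's parameters must have requires_grad=True for .grad to be filled)
    teacher.zero_grad()
    loss_fn(teacher(inputs), labels).backward()
    teacher_grads = [p.grad.detach().clone() for p in teacher.parameters()]

    # partially variance-reduced SGD step
    with torch.no_grad():
        for p, gs, gt in zip(student.parameters(), student_grads, teacher_grads):
            p -= gamma * (gs - lam * gt)
```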
To analyze the convergence behavior of the iterates (7), we need to impose some assumptions in order to derive reasonable convergence guarantees. First, we assume that the problem (6) has a non-empty solution set $X^* \neq \emptyset$ and write $f^* := f(x^*)$ for some minimizer $x^* \in X^*$.

Assumption 1 (Strong quasi-convexity). The function $f : \mathbb{R}^d \to \mathbb{R}$ is differentiable and $\mu$-strongly quasi-convex for some constant $\mu > 0$, i.e., for any $x \in \mathbb{R}^d$ it holds that
$$f(x^*) \ge f(x) + \langle \nabla f(x), x^* - x \rangle + \frac{\mu}{2}\,\|x^* - x\|^2. \qquad (8)$$

Strong quasi-convexity [9] is a weaker version of strong convexity [35], which requires the quadratic lower bound above to hold with any point $y \in \mathbb{R}^d$ in place of $x^* \in X^*$. Notice that strong quasi-convexity implies that the minimizer $x^*$ is unique (if $x_1^*$ and $x_2^*$ are both minimizers, then (8) gives $\frac{\mu}{2}\|x_1^* - x_2^*\|^2 \le f(x_1^*) - f(x_2^*) = 0$, hence $x_1^* = x_2^*$). A more relaxed version of this assumption is the Polyak-Łojasiewicz (PL) condition [40].

Assumption 2 (Polyak-Łojasiewicz condition). The function $f : \mathbb{R}^d \to \mathbb{R}$ is differentiable and satisfies the PL condition with parameter $\mu > 0$ if for any $x \in \mathbb{R}^d$ it holds that
$$\|\nabla f(x)\|^2 \ge 2\mu\,(f(x) - f^*). \qquad (9)$$

Note that the requirement imposed by the PL condition above is weaker than strong convexity. Functions satisfying the PL condition do not have to be convex and can have multiple minimizers [17]. We make use of the following form of smoothness assumption on the stochastic gradient, commonly referred to as expected smoothness in the optimization literature [11, 10, 19].

Assumption 3 (Expected Smoothness). The functions $f_\xi(x)$ are differentiable and $\mathcal{L}$-smooth in expectation with respect to the subsampling $\xi \sim \mathcal{D}$, i.e., for any $x \in \mathbb{R}^d$ it holds that
$$E_{\xi \sim \mathcal{D}}\,\|\nabla f_\xi(x) - \nabla f_\xi(x^*)\|^2 \le 2\mathcal{L}\,(f(x) - f^*) \qquad (10)$$
for some constant $\mathcal{L} = \mathcal{L}(f, \mathcal{D})$.

The expected smoothness condition above is a joint property of the loss function $f$ and the data subsampling strategy from the distribution $\mathcal{D}$. In particular, it subsumes the smoothness condition for $f(x)$, since (10) also implies $\|\nabla f(x) - \nabla f(x^*)\|^2 \le 2\mathcal{L}(f(x) - f^*)$ for any $x \in \mathbb{R}^d$. We denote by $L$ the smoothness constant of $f(x)$ and notice that $L \le \mathcal{L}$.

4.2. Convergence Theory and Partial Variance Reduction. Equipped with the assumptions described in the previous part, we now present our convergence guarantees for the iterates (7), for both strongly quasi-convex and PL loss functions.

Theorem 1 (See Appendix C.2). Let Assumptions 1 and 3 hold. For any $\gamma \le \frac{1}{8\mathcal{L}}$ and properly chosen distillation weight $\lambda$, the iterates (7) of SGD with self-distillation using teacher parameters $\theta$ converge as
$$E\,\|x_t - x^*\|^2 \le (1-\gamma\mu)^t\,\|x_0 - x^*\|^2 + \frac{2\sigma_*^2}{\mu}\,\min\big(\gamma,\ \mathcal{O}(f(\theta) - f^*)\big), \qquad (11)$$
where $\sigma_*^2 := E[\|\nabla f_\xi(x^*)\|^2]$ is the stochastic noise at the optimum.

Theorem 2 (See Appendix C.3). Let Assumptions 2 and 3 hold. For any $\gamma \le \frac{1}{4\mathcal{L}}\frac{\mu}{L}$ and properly chosen distillation weight $\lambda$, the iterates (7) of SGD with self-distillation using teacher parameters $\theta$ converge as
$$E\,[f(x_t) - f^*] \le (1-\gamma\mu)^t\,(f(x_0) - f^*) + \frac{L\sigma_*^2}{\mu}\,\min\big(\gamma,\ \mathcal{O}(f(\theta) - f^*)\big). \qquad (12)$$

Proof overview. Both proofs follow similar steps and can be divided into three logical parts.

Part 1 (Descent inequality). Generally, an integral part of essentially any convergence theory for optimization methods is a descent inequality quantifying the progress of the algorithm in one iteration. Our theory is not an exception: we first define the potential $e_t = \|x_t - x^*\|^2$ for the strongly quasi-convex setup, and $e_t = f(x_t) - f^*$ for the PL setup. Then, we start our derivations by bounding $E_t[e_{t+1}]$ in terms of $(1-\gamma\mu)\,e_t$, where $E_t$ denotes the conditional expectation given the previous iterate $x_t$. Specifically, up to constants, both setups allow the following bound:
$$E_t[e_{t+1}] \le (1-\gamma\mu)\,e_t - \mathcal{O}(\gamma)\big(1 - \mathcal{O}(\gamma)\big)\big(f(x_t) - f^*\big) + \mathcal{O}(\gamma)\,N(\lambda), \qquad (13)$$
where $N(\lambda) = \lambda^2\|\nabla f(\theta)\|^2 + \gamma\, E\|\nabla f_\xi(x^*) - \lambda \nabla f_\xi(\theta)\|^2$. Choosing the learning rate $\gamma$ to be small enough, we ensure that the second term is non-positive and hence negligible in the upper bound.

Part 2 (Optimal distillation weight).
Next, we focus our attention on the third term in (13), involving the iteration-independent neighborhood term $N(\lambda)$. Note that the $\mathcal{O}(\gamma)$ factor next to $N(\lambda)$ will be absorbed once we unfold the recursion (13) up to the initial iterate. Hence, the convergence neighborhood is proportional to $N(\lambda)$. Now the question is how small this term can get if we properly tune the parameter $\lambda$. Notice that $N(0) = \gamma\sigma_*^2$ corresponds to the neighborhood size for plain SGD without any distillation involved. Luckily, due to the quadratic dependence, we can minimize $N(\lambda)$ analytically with respect to $\lambda$ and find the optimal value
$$\lambda^* = \frac{E[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta)\rangle]}{E[\|\nabla f_\xi(\theta)\|^2] + \frac{1}{\gamma}\|\nabla f(\theta)\|^2}. \qquad (14)$$
Consequently, the analysis puts further emphasis on the need for careful parametrization with respect to the weighting $\lambda$ of the distillation loss, as there exists a particularly privileged value $\lambda^*$.

Part 3 (Impact of the teacher). In the final step, we quantify the impact of the teacher through the reduction in the neighborhood term $N(\lambda^*)$ compared to the plain SGD neighborhood $N(0)$. Via algebraic transformations, we show that
$$\frac{N(\lambda^*)}{N(0)} = 1 - \rho^2(x^*, \theta)\,\frac{1 - \beta(\theta)}{1 + \frac{1}{\gamma}\beta(\theta)}, \qquad (15)$$
where $\beta(\theta) = \|\nabla f(\theta)\|^2 / E[\|\nabla f_\xi(\theta)\|^2] \in [0, 1]$ is the signal-to-noise ratio, and $\rho(x^*, \theta) \in [-1, 1]$ is the correlation coefficient between the stochastic gradients $\nabla f_\xi(x^*)$ and $\nabla f_\xi(\theta)$. This representation gives us analytical means to measure the impact of the teacher. For instance, the optimal teacher $\theta = x^*$ satisfies $\rho(x^*, \theta) = 1$ and $\beta(\theta) = 0$, and thus $N(\lambda^*) = 0$. In general, if the teacher is not optimal, then the reduction factor (15) is of order $\frac{1}{\gamma}\mathcal{O}(f(\theta) - f^*)$ and $N(\lambda^*) = \sigma_*^2 \min(\gamma, \mathcal{O}(f(\theta) - f^*))$.

We now discuss these results, highlighting their key aspects and significance.

Structure of the rates. The structure of these two convergence rates is typical in the gradient-based stochastic optimization literature: linear convergence up to some neighborhood controlled by the stochastic noise term $\sigma_*^2$ and the learning rate $\gamma$. In fact, these results (including the learning rate restrictions) are identical to the ones for SGD [8], except that the non-vanishing terms in (11) and (12) include an additional $\mathcal{O}(f(\theta) - f^*)$ factor due to distillation and the proper selection of the weight parameter $\lambda$. For both setups, the rate of SGD is the same as (11) or (12) with only one difference: the $\min(\gamma, \mathcal{O}(f(\theta) - f^*))$ term is replaced with $\gamma$. So, $\mathcal{O}(f(\theta) - f^*)$ is the factor that makes our results better compared to SGD in terms of optimization performance.

Importance of the results. First, observe that in the worst case, when the teacher's parameters are trained inadequately (or not trained at all), that is, $\mathcal{O}(f(\theta) - f^*) \ge \gamma$, the obtained rates recover the known results for plain SGD. However, the crucial benefit of these results is to show that a sufficiently well-trained teacher, i.e., $\mathcal{O}(f(\theta) - f^*) < \gamma$, provably reduces the neighborhood size of SGD without slowing down the speed of convergence. In the best-case scenario, when the teacher's model is perfectly trained, namely $f(\theta) = f^*$, the neighborhood term vanishes, and the method converges to the exact solution (see the SGD-star Algorithm 4 in [9]). Thus, self-distillation in the form of iterates (7) acts as a form of partial variance reduction, which can reduce the stochastic gradient noise, but may not eliminate it completely, depending on the properties of the teacher model.

Choice of distillation weight λ. As we discussed in the proof sketch above, our analysis reveals that the performance of distillation is optimized for a specific value (14) of the distillation weight $\lambda$, depending on the teacher model.
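As an illustration of how (14) could be estimated in practice, the helper below computes λ* from arrays of per-sample gradients evaluated at a (near-)optimal point x* and at the teacher θ. This is our own sketch; the paper does not prescribe such an estimator.

```python
import numpy as np

def optimal_lambda(grads_at_opt, grads_at_teacher, gamma):
    """Estimate λ* from (14):
       λ* = E<∇f_ξ(x*), ∇f_ξ(θ)> / ( E||∇f_ξ(θ)||² + (1/γ) ||∇f(θ)||² ).
    Both inputs are (num_samples, d) arrays of per-sample gradients,
    evaluated at x* and at the teacher θ, respectively."""
    inner = np.mean(np.sum(grads_at_opt * grads_at_teacher, axis=1))    # E<∇f_ξ(x*), ∇f_ξ(θ)>
    second_moment = np.mean(np.sum(grads_at_teacher ** 2, axis=1))      # E||∇f_ξ(θ)||²
    full_grad_norm_sq = np.sum(np.mean(grads_at_teacher, axis=0) ** 2)  # ||∇f(θ)||²
    return inner / (second_moment + full_grad_norm_sq / gamma)
```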
One way to interpret the expression (14) for the weight parameter intuitively is that the better the teacher's model $\theta$ is, the larger $\lambda^* \le 1$ gets. In other words, $\lambda^*$ quantifies the quality of the teacher: $\lambda^* \approx 0$ indicates a poor teacher model ($f(\theta) \gg f^*$), while $\lambda^* = 1$ corresponds to the optimal teacher ($f(\theta) = f^*$).

4.3. Experimental Validation. In this section we illustrate that our theoretical analysis in the convex case also holds empirically. Specifically, we consider classification problems using linear models in two different setups: training a linear model on the MNIST dataset [24], and linear probing on the CIFAR-10 dataset [23] using a ResNet50 model [12] pre-trained on the ImageNet dataset [42]. For the second setup, we train a linear classifier on top of the features extracted from the ResNet50 model pre-trained on ImageNet. This is a standard setting, commonly used in the transfer learning literature, see e.g. [21, 45, 15]. In both cases we train using SGD without momentum and regularization, with a fixed learning rate and mini-batches of size 10, for a total of 100 epochs. The models trained with SGD are compared against self-distillation (Equation 7), using the same training hyper-parameters, where the teacher is the model trained with SGD. In the case of CIFAR-10 features, we also consider the optimal teacher, which is a model trained with L-BFGS [27]. We perform all experiments using multiple values of the distillation parameter $\lambda$ and measure the cross-entropy loss between the student and the true labels. At each training epoch we compute the running average over all mini-batch losses seen during that epoch.

Figure 2: (a) Self-distillation on MNIST; (b) self-distillation on CIFAR-10, with an L-BFGS teacher and an SGD teacher. The plots show the minimum training loss, and the average over the last 10 epochs, for models trained with SGD and with self-distillation, using different values of the distillation parameter $\lambda$; SGD corresponds to $\lambda = 0$. The curves for the SGD-based teachers (which do not have zero loss) reflect our analysis, corroborating the existence of the optimal distillation weight. By contrast, with the L-BFGS teacher, a higher distillation weight always leads to a lower loss.

The results presented in Figure 2 show the minimum cross-entropy train loss obtained over 100 epochs, as well as the average over the last 10 epochs, for models trained with SGD, as well as with self-distillation with $\lambda \in [0.01, 1]$. We observe that when the teacher is the model trained with SGD ($\lambda = 0$), there exists a $\lambda > 0$ which achieves a lower training loss than SGD, which is in line with our statement from Theorem 1. Furthermore, when the teacher is very close to the optimum, $\lambda$ closer to 1 reduces the training loss the most compared to SGD, which is also in line with the theory (see Theorem 1). This behavior is illustrated in Figure 2b, when using an L-BFGS teacher.
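The following self-contained NumPy sketch reproduces the qualitative behavior of Figure 2 on a synthetic least-squares problem: with an imperfect SGD teacher, an intermediate λ typically attains the lowest loss floor, while with the exact optimum as teacher, λ close to 1 helps most. The synthetic data, step size, and iteration counts are our own illustrative choices, not the paper's MNIST/CIFAR-10 setup.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, gamma, T = 500, 10, 0.02, 40000
A = rng.normal(size=(N, d))
b = A @ rng.normal(size=d) + rng.normal(size=N)           # linear data with label noise
x_star = np.linalg.lstsq(A, b, rcond=None)[0]             # ERM optimum of (1)
f = lambda w: np.mean((A @ w - b) ** 2)

def sgrad(w, i):
    """Per-sample gradient of the squared loss: ∇f_i(w) = 2 (a_i^T w - b_i) a_i."""
    return 2.0 * (A[i] @ w - b[i]) * A[i]

def train(theta, lam):
    """Run the self-distillation iterate (7); return the final iterate and the
    average loss gap over the second half of training (the 'noise floor')."""
    x, gaps = np.zeros(d), []
    for t in range(T):
        i = rng.integers(N)
        x = x - gamma * (sgrad(x, i) - lam * sgrad(theta, i))
        if t >= T // 2:
            gaps.append(f(x) - f(x_star))
    return x, float(np.mean(gaps))

sgd_teacher, sgd_floor = train(np.zeros(d), lam=0.0)       # plain SGD; reused as the teacher
print("plain SGD floor:", sgd_floor)
for lam in (0.25, 0.5, 0.75, 1.0):                         # imperfect (SGD) teacher
    print("SGD teacher,     lam =", lam, "->", train(sgd_teacher, lam)[1])
for lam in (0.5, 1.0):                                     # (near-)optimal teacher
    print("optimal teacher, lam =", lam, "->", train(x_star, lam)[1])
```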
5 Removing Bias and Improving Variance Reduction

In this section, we investigate the cause of the variance reduction being only partial, and suggest a possible workaround to obtain complete variance reduction. In brief, the potential source of partial variance reduction is the biased nature of distillation. Essentially, the distillation bias is reflected in the iterates (7), since the expected update direction $E[\nabla f_\xi(x_t) - \lambda\nabla f_\xi(\theta) \mid x_t] = \nabla f(x_t) - \lambda\nabla f(\theta)$ can be different from $\nabla f(x_t)$. This comes from the fact that the distillation loss (4) modifies the initial loss (1) composed of the true outputs. To make our argument compelling, we next correct the bias by adding $\lambda\nabla f(\theta)$ to the iterates (7) and analyze the following dynamics:
$$x_{t+1} = x_t - \gamma\big(\nabla f_\xi(x_t) - \lambda\nabla f_\xi(\theta) + \lambda\nabla f(\theta)\big). \qquad (16)$$
Besides making the estimate unbiased, the advantage of this adjustment is that no tuning is required for the distillation weight $\lambda$; we may simply set $\lambda = 1$. The obvious disadvantage is that $\nabla f(\theta)$ is the batch gradient over the whole train data, which can be very costly to compute. However, we can compute it once and reuse it for all further iterates. The resulting iteration is similar to the popular and well-studied SVRG method [16, 57, 41, 20, 25, 22, 29], and therefore the iterates (16) will enjoy full variance reduction.

Theorem 3 (See Appendix D). Let Assumptions 2 and 3 hold. Then, for any $\gamma \le \frac{\mu}{3L\mathcal{L}}$, the iterates (16) with $\lambda = 1$ converge as
$$E\,[f(x_t) - f^*] \le (1-\gamma\mu)^t\,(f(x_0) - f^*) + \frac{3L(L + \mathcal{L})}{\mu}\,\gamma\,(f(\theta) - f^*). \qquad (17)$$

The key improvement that the bias correction brings in (17) is convergence up to a neighborhood $\mathcal{O}(\gamma(f(\theta) - f^*))$, in contrast to $\min(\gamma, \mathcal{O}(f(\theta) - f^*))$ as in (11) and (12). The multiplicative dependence on the learning rate and the quality of the teacher leads the method (16) to full variance reduction. Indeed, if we choose the teacher model as $\theta = x_0$, then the rate (17) becomes
$$E\,[f(x_t) - f^*] \le \big[(1-\gamma\mu)^t + \gamma\,3L(L + \mathcal{L})/\mu\big]\,(f(x_0) - f^*) \le \tfrac{1}{2}\,(f(x_0) - f^*),$$
provided a sufficiently small step size $\gamma \le \frac{\mu}{12L\mathcal{L}}$ and enough training iterations $t = \mathcal{O}(1/\gamma\mu)$. Hence, following SVRG and updating the teacher model every $\tau = \mathcal{O}(1/\gamma\mu)$ training iterations, that is, choosing $\theta_m = x_{m\tau}$ as the teacher at the $m$-th distillation iteration (see line 3 of Algorithm 1), we have
$$E\,[f(x_{m\tau}) - f^*] \le \tfrac{1}{2}\,E\,[f(x_{(m-1)\tau}) - f^*] \le \tfrac{1}{2^m}\,(f(x_0) - f^*).$$
Thus, we need $\mathcal{O}(\log\frac{1}{\epsilon})$ bias-corrected distillation phases, each with $\tau = \mathcal{O}(1/\gamma\mu)$ training iterates, to get $\epsilon$ accuracy in function value. Overall, this amounts to $\mathcal{O}(\frac{1}{\epsilon})$ iterations of (16).
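A minimal sketch of the bias-corrected dynamics (16) with periodic teacher refresh, in the SVRG style discussed above: the full gradient ∇f(θ) is recomputed once per phase and reused for τ inner steps. The gradient oracles and parameters are placeholders supplied by the caller; plugging in the least-squares oracles from the earlier sketch reproduces the full variance reduction effect.

```python
import numpy as np

def unbiased_self_distillation(grad_i, full_grad, x0, gamma, num_phases, tau, n, rng):
    """SVRG-style iterates (16) with λ = 1.

    The teacher θ is refreshed at the start of every phase (θ_m = x_{mτ}, as in
    line 3 of Algorithm 1); its full gradient ∇f(θ) is computed once per phase
    and reused for all τ inner steps.
      grad_i(w, i): stochastic gradient ∇f_i(w) for sample i in {0, ..., n-1}
      full_grad(w): full-batch gradient ∇f(w)
    """
    x = np.array(x0, dtype=float)
    for _ in range(num_phases):
        theta = x.copy()                  # refresh the teacher
        g_theta = full_grad(theta)        # ∇f(θ), the expensive full pass
        for _ in range(tau):
            i = rng.integers(n)
            # iterate (16): x_{t+1} = x_t - γ (∇f_ξ(x_t) - ∇f_ξ(θ) + ∇f(θ))
            x -= gamma * (grad_i(x, i) - grad_i(theta, i) + g_theta)
    return x
```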
Experimental Validation. Similarly to the previous section, we further validate empirically the result from Theorem 3. Specifically, we consider the convex setup described before, where we train linear models on features extracted from the CIFAR-10 dataset. Based on Figure 2b, we select $\lambda = 0.4$, which achieves the largest reduction in train loss compared to SGD, and we additionally perform unbiased self-distillation (Equation 16), using the same training hyperparameters. Similar to the setup from Figure 2, we measure the cross-entropy train loss between the student and the true labels, which is computed at each epoch by averaging the mini-batch losses. The results are averaged over three runs and presented in Figure 3. The first plot on the left shows that, indeed, the unbiased self-distillation update further reduces the training loss, compared to the update from Equation 7. The second plot explicitly tracks the gradient variance (averaged over the iterations within each epoch) for the same setup. As expected, both variants of KD (biased and unbiased) have reduced gradient variance compared to plain SGD. The plot also highlights that both variants of KD have similar variance reduction properties, while the unbiasedness of unbiased KD amplifies the reduction of the train loss.

Figure 3: (Left) The train loss of self-distillation ($\lambda = 0.4$), unbiased self-distillation ($\lambda = 0.4$) and vanilla SGD training. (Right) The progress of the gradient variances (averaged over the iterations within each epoch) for the same setup.

6 Convergence for Distillation of Compressed Models

So far, the theory we have presented is for self-distillation, i.e., the teacher's and student's architectures are identical. To understand the impact of knowledge distillation more broadly, we relax this requirement and allow the student's model to be a sub-network of the larger and more powerful teacher's model. Our approach to modeling this relationship between the student and the teacher is to view the student as a masked or, in general, compressed version of the teacher. Hence, as an extension of (7), we analyze the following dynamics of distillation with compressed iterates:
$$x_{t+1} = \mathcal{C}\big(x_t - \gamma(\nabla f_\xi(x_t) - \lambda\nabla f_\xi(\theta))\big), \qquad (18)$$
where the student's parameters are additionally compressed in each iteration using an unbiased compression operator defined below.

Assumption 4. The compression operator $\mathcal{C} : \mathbb{R}^d \to \mathbb{R}^d$ is unbiased and there exists a finite $\omega \ge 0$ bounding the compression variance, i.e., for all $x \in \mathbb{R}^d$ we have
$$E[\mathcal{C}(x)] = x, \qquad E[\|\mathcal{C}(x) - x\|^2] \le \omega\,\|x\|^2. \qquad (19)$$

Typical examples of compression operators satisfying conditions (19) are sparsification [55, 47] and quantization [1, 56], which are heavily used in the context of communication-efficient distributed optimization and federated learning [30, 44, 37, 53]. In this context, we obtain the following:

Theorem 4 (See Appendix E). Let the smoothness Assumption 3 hold and let $f$ be $\mu$-strongly convex. Choose any $\gamma \le \frac{1}{16\mathcal{L}}$ and a compression operator with variance parameter $\omega = \mathcal{O}(\mu/L)$. Then, properly selecting the distillation weight $\lambda$, the iterates (18) satisfy
$$E\big[\|x_t - x^*\|^2\big] \le \mathcal{O}(\omega + 1)\Big[(1-\gamma\mu)^t\,\|x_0 - x^*\|^2 + \frac{\omega L}{\mu}\,\|x^*\|^2 + \frac{\sigma_*^2}{\mu}\,\min\big(\gamma,\ \mathcal{O}(f(\theta) - f^*)\big)\Big].$$

Clearly, there are several factors influencing the speed of the rate and the neighborhood of convergence that require some discussion. First of all, choosing the identity map as the compression operator ($\mathcal{C}(x) = x$ for all $x \in \mathbb{R}^d$), we recover the same rate (11) as before ($\omega = 0$ in this case). Next, consider the case when the stochastic noise at the optimum vanishes ($\sigma_*^2 = 0$) and distillation is switched off ($\lambda = 0$) in (18). In this case, the convergence is still up to some neighborhood proportional to $\|x^*\|^2$, since compression is applied to the iterates. Intuitively, the neighborhood term $\mathcal{O}(\|x^*\|^2)$ corresponds to the compression noise at the optimum $x^*$ ((19) when $x = x^*$). Also note that the presence of this non-vanishing term $\mathcal{O}(\|x^*\|^2)$ and the variance restriction $\omega = \mathcal{O}(\mu/L)$ are consistent with the prior work [18]. So, the convergence neighborhood of the iterates (18) has two terms, one from each source of randomness: the compression noise/variance $\mathcal{O}(\|x^*\|^2)$ at the optimum and the stochastic noise/variance $\mathcal{O}(\sigma_*^2)$ at the optimum. Therefore, in this case as well, distillation with a properly chosen weight parameter (partially) reduces the stochastic variance of sub-sampling.
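To make Assumption 4 and the iterate (18) concrete, here is a sketch of an unbiased random-k sparsification operator, which satisfies (19) with ω = d/k − 1, applied to the distilled SGD update. The operator choice and interface are illustrative.

```python
import numpy as np

def rand_k(x, k, rng):
    """Unbiased random-k sparsification: keeps k coordinates, rescales by d/k.
    Satisfies (19) with ω = d/k - 1."""
    d = x.shape[0]
    mask = np.zeros(d)
    mask[rng.choice(d, size=k, replace=False)] = 1.0
    return (d / k) * mask * x

def compressed_distillation_step(x, theta, grad_i, i, gamma, lam, k, rng):
    """One iterate of (18): compress the distilled SGD update."""
    g = grad_i(x, i) - lam * grad_i(theta, i)
    return rand_k(x - gamma * g, k, rng)
```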
7 Discussion and Future Work

Our work has provided a new interpretation of knowledge distillation, examining this mechanism for the first time from the point of view of optimization. Specifically, we have shown that knowledge distillation acts as a form of partial variance reduction, whose strength depends on the characteristics of the teacher model. This finding holds across several variants of distillation, such as self-distillation and distillation of compressed models, as well as across various families of objective functions.

Prior observations showed that a significant capacity gap between the student and the teacher may in fact lead to poorer distillation performance [31]. To reconcile the issue of a large capacity gap with our results, notice that, in our case, a better teacher means better parameter values (i.e., weights and biases), evaluated in terms of training loss. In particular, in the case of self-distillation, covered in Sections 4 and 5, the teacher and student architectures are identical, and hence they have the same capacity. In our second regime, distillation for compressed models (Section 6), we actually consider the case when the student network is a sub-network of the teacher; we consider a sparsification compression operator that selects $k$ parameters for the student out of the $d$ parameters of the teacher. Then, clearly, the teacher has a larger capacity, with a capacity ratio $d/k \ge 1$. However, our result in this direction (Theorem 4) does not allow the capacity ratio to be arbitrarily large. Indeed, the constraint $\omega = \mathcal{O}(\mu/L)$ on the compression variance implies a constraint on the capacity ratio, since $\omega = d/k - 1$ for the sparsification operator. Thus, our result holds when the teacher's size is not significantly larger than the student's size, which is in line with the prior observations on the large capacity gap.

As we mentioned, our Proposition 1 does not hold precisely for arbitrary deep non-linear neural networks. However, we showed that this simple model (5) of the distillation gradient approximates the true distillation gradient reasonably well, both empirically (see Figure 1) and analytically (see Appendix B.3). There is much more to investigate for the case of non-convex deep networks, where exact tracking of the teacher's impact across multiple layers of non-linearities becomes harder. We see our results as a promising first step towards a more complete understanding of the effectiveness of distillation. One interesting direction of future work would be to construct more complex models for the distillation gradient and to investigate further connections with more complex variance-reduction methods, e.g. [4], which may yield even better-performing variants of KD.

Acknowledgements. MS has received funding from the European Union's Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No 101034413.

[1] Dan Alistarh, Demjan Grubic, Jerry Li, Ryota Tomioka, and Milan Vojnovic. QSGD: Communication-efficient SGD via gradient quantization and encoding. In Advances in Neural Information Processing Systems 30, pages 1709–1720, 2017.
[2] Lei Jimmy Ba and Rich Caruana. Do deep nets really need to be deep? CoRR, abs/1312.6184, 2013.
[3] C. Buciluǎ, R. Caruana, and A. Niculescu-Mizil. Model compression. In Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD '06, pages 535–541, New York, NY, USA, 2006.
[4] Ashok Cutkosky and Francesco Orabona. Momentum-based variance reduction in non-convex SGD. Advances in Neural Information Processing Systems, 32, 2019.
[5] Wojciech Czarnecki, Siddhant Jayakumar, Max Jaderberg, Leonard Hasenclever, Yee Whye Teh, Nicolas Heess, Simon Osindero, and Razvan Pascanu. Mix & match agent curricula for reinforcement learning. In Proceedings of the 35th International Conference on Machine Learning, 2018.
[6] Tri Dao, Govinda M Kamath, Vasilis Syrgkanis, and Lester Mackey. Knowledge distillation as semiparametric inference. International Conference on Learning Representations (ICLR), 2021.
[7] Alexandre Galashov, Siddhant Jayakumar, Leonard Hasenclever, Dhruva Tirumala, Jonathan Schwarz, Guillaume Desjardins, Wojtek M. Czarnecki, Yee Whye Teh, Razvan Pascanu, and Nicolas Heess. Information asymmetry in KL-regularized RL. In International Conference on Learning Representations, 2019. [8] Guillaume Garrigos and Robert M Gower. Handbook of convergence theorems for (stochastic) gradient methods. ar Xiv preprint ar Xiv:2301.11235, 2023. [9] Eduard Gorbunov, Filip Hanzely, and Peter Richtárik. A Unified Theory of SGD: Variance Reduction, Sampling, Quantization and Coordinate Descent. 23rd International Conference on Artificial Intelligence and Statistics (AISTATS), 2020. [10] R.M. Gower, P. Richtárik, and F. Bach. Stochastic quasi-gradient methods: variance reduction via Jacobian sketching. Math. Program. 188, pages 135 192, 2021. [11] Robert Mansel Gower, Nicolas Loizou, Xun Qian, Alibek Sailanbayev, Egor Shulgin, and Peter Richtárik. SGD: General Analysis and Improved Rates. Proceedings of the 36th International Conference on Machine Learning, 2019. [12] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. Conference on Computer Vision and Pattern Recognition, 2016. [13] G. Hinton, O. Vinyals, and J. Dean. Distilling the knowledge in a neural network. Deep Learning Workshop at NIPS, 2014. [14] G. Hinton, O. Vinyals, and J. Dean. Distilling the Knowledge in a Neural Network. Ar Xiv e-prints, March 2015. [15] Eugenia Iofinova, Alexandra Peste, Mark Kurtz, and Dan Alistarh. How Well Do Sparse Imagenet Models Transfer? Conference on Computer Vision and Pattern Recognition, 2022. [16] Rie Johnson and Tong Zhang. Accelerating stochastic gradient descent using predictive variance reduction. Advances in Neural Information Processing Systems, 26, 2013. [17] Hamed Karimi, Julie Nutini, and Mark Schmidt. Linear Convergence of Gradient and Proximal Gradient Methods Under the Polyak-Łojasiewicz Condition. ar Xiv:1608.04636, 2020. [18] Ahmed Khaled and Peter Richtárik. Gradient descent with compressed iterates. ar Xiv preprint ar Xiv:1909.04716, 2019. [19] Ahmed Khaled and Peter Richtárik. Better Theory for SGD in the Nonconvex World. Transactions on Machine Learning Research, 2023. [20] Jakub Konˇecný and Peter Richtárik. Semi-Stochastic Gradient Descent Methods. Frontiers in Applied Mathematics and Statistics 3:9, 2017. [21] Simon Kornblith, Jonathon Shlens, and Quoc V. Le. Do Better Image Net Models Transfer Better? Conference on Computer Vision and Pattern Recognition, 2019. [22] Dmitry Kovalev, Samuel Horváth, and Peter Richtárik. Don t jump through hoops and remove those loops: SVRG and Katyusha are better without the outer loop. 31st International Conference on Learning Theory (ALT), 2020. [23] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. Cite Seer, 2009. [24] Yann Le Cun and Corinna Cortes. MNIST handwritten digit database. http://yann.lecun.com/exdb/mnist/, 2010. [25] Lihua Lei, Cheng Ju, Jianbo Chen, and Michael I. Jordan. Non-Convex Finite-Sum Optimization Via SCSG Methods. 31st Conference on Neural Information Processing Systems (NIPS), 2017. [26] Yuncheng Li, Jianchao Yang, Yale Song, Liangliang Cao, Jiebo Luo, and Li-Jia Li. Learning from noisy labels with distillation. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 1928 1936. IEEE, 2017. [27] Dong C Liu and Jorge Nocedal. On the limited memory bfgs method for large scale optimization. 
Mathematical programming, 45(1):503 528, 1989. [28] David Lopez-Paz, Léon Bottou, Bernhard Schölkopf, and Vladimir Vapnik. Unifying distillation and privileged information. ar Xiv preprint ar Xiv:1511.03643, 2015. [29] Grigory Malinovsky, Alibek Sailanbayev, and Peter Richtárik. Random reshuffling with variance reduction: new analysis and better rates. 39th Conference on Uncertainty in Artificial Intelligence (UAI), 2023. [30] Ilia Markov, Adrian Vladu, Qi Guo, and Dan Alistarh. Quantized Distributed Training of Large Models with Convergence Guarantees. ar Xiv preprint ar Xiv:2302.02390, 2023. [31] Seyed Iman Mirzadeh, Mehrdad Farajtabar, Ang Li, Nir Levine, Akihiro Matsukawa, and Hassan Ghasemzadeh. Improved knowledge distillation via teacher assistant. Proceedings of the AAAI conference on artificial intelligence, volume 34, pages 5191 5198, 2020. [32] Asit Mishra and Debbie Marr. Apprentice: Using knowledge distillation techniques to improve low-precision network accuracy. ar Xiv preprint ar Xiv:1711.05852, 2017. [33] Hossein Mobahi, Mehrdad Farajtabar, and Peter L. Bartlett. Self-Distillation Amplifies Regularization in Hilbert Space. 34th Conference on Neural Information Processing Systems (Neur IPS), 2020. [34] Vinod Nair and Geoffrey E Hinton. Rectified linear units improve restricted boltzmann machines. In International Conference on Machine Learning, 2010. [35] Yurii Nesterov. Introductory lectures on convex optimization: A basic course, volume 87. Springer Science & Business Media, 2013. [36] Emilio Parisotto, Jimmy Lei Ba, and Ruslan Salakhutdinov. Actor-mimic: Deep multitask and transfer reinforcement learning. In International Conference on Learning Representations, 2016. [37] Constantin Philippenko and Aymeric Dieuleveut. Preserved central model for faster bidirectional compression in distributed settings. 35th Advances in Neural Information Processing Systems, 2021. [38] Mary Phuong and Christoph Lampert. Towards understanding knowledge distillation. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 5142 5151, Long Beach, California, USA, 09 15 Jun 2019. PMLR. [39] Antonio Polino, Razvan Pascanu, and Dan Alistarh. Model compression via distillation and quantization. International Conference on Learning Representations (ICLR), 2018. [40] B. T. Polyak. Gradient methods for minimizing functionals (in Russian). Zh. Vychisl. Mat. Mat. Fiz., pages 643 653, 1963. [41] Sashank J. Reddi, Ahmed Hefny, Suvrit Sra, Barnabas Poczos, and Alex Smola. Stochastic Variance Reduction for Nonconvex Optimization. Proceedings of the 33rd International Conference on Machine Learning, 2016. [42] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, et al. Image Net large scale visual recognition challenge. IJCV, 115(3):211 252, 2015. [43] Andrei A Rusu, Sergio Gomez Colmenarejo, Caglar Gulcehre, Guillaume Desjardins, James Kirkpatrick, Razvan Pascanu, Volodymyr Mnih, Koray Kavukcuoglu, and Raia Hadsell. Policy distillation. In International Conference on Learning Representations, 2016. [44] Mher Safaryan, Filip Hanzely, and Peter Richtárik. Smoothness Matrices Beat Smoothness Constants: Better Communication Compression Techniques for Distributed Optimization. In 35th Conference on Neural Information Processing Systems, 2021. 
[45] Hadi Salman, Andrew Ilyas, Logan Engstrom, Ashish Kapoor, and Aleksander Madry. Do Adversarially Robust Image Net Models Transfer Better? Advances in Neural Information Processing Systems, 2020. [46] Simon Schmitt, Jonathan J Hudson, Augustin Zidek, Simon Osindero, Carl Doersch, Wojciech M Czarnecki, Joel Z Leibo, Heinrich Kuttler, Andrew Zisserman, Karen Simonyan, et al. Kickstarting deep reinforcement learning. ar Xiv preprint ar Xiv:1803.03835, 2018. [47] Sebastian U Stich, Jean-Baptiste Cordonnier, and Martin Jaggi. Sparsified SGD with memory. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems 31, pages 4452 4463. Curran Associates, Inc., 2018. [48] Md Arafat Sultan. Knowledge Distillation Label Smoothing: Fact or Fallacy? ar Xiv preprint ar Xiv:2301.12609, 2023. [49] Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. Mobilebert: a compact task-agnostic bert for resource-limited devices. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 2158 2170, 2020. [50] Hokchhay Tann, Soheil Hashemi, R Iris Bahar, and Sherief Reda. Hardware-software codesign of accurate, multiplier-free deep neural networks. In Design Automation Conference (DAC), 2017 54th ACM/EDAC/IEEE, pages 1 6. IEEE, 2017. [51] Yee Teh, Victor Bapst, Wojciech M Czarnecki, John Quan, James Kirkpatrick, Raia Hadsell, Nicolas Heess, and Razvan Pascanu. Distral: Robust multitask reinforcement learning. In Advances in Neural Information Processing Systems, pages 4496 4506, 2017. [52] Vladimir Vapnik and Rauf Izmailov. Learning using privileged information: similarity control and knowledge transfer. Journal of machine learning research, 16(20232049):55, 2015. [53] Thijs Vogels, Sai Praneeth Karimireddy, and Martin Jaggi. Power SGD: Practical Low-Rank Gradient Compression for Distributed Optimization. 33th Advances in Neural Information Processing Systems, 2019. [54] Wenhui Wang, Furu Wei, Li Dong, Hangbo Bao, Nan Yang, and Ming Zhou. Minilm: Deep self-attention distillation for task-agnostic compression of pre-trained transformers. Advances in Neural Information Processing Systems, 33:5776 5788, 2020. [55] Jianqiao Wangni, Jialei Wang, Ji Liu, and Tong Zhang. Gradient sparsification for communication-efficient distributed optimization. In Advances in Neural Information Processing Systems, pages 1306 1316, 2018. [56] Wei Wen, Cong Xu, Feng Yan, Chunpeng Wu, Yandan Wang, Yiran Chen, and Hai Li. Terngrad: Ternary gradients to reduce communication in distributed deep learning. In Advances in Neural Information Processing Systems, page 1509 1519, 2017. [57] Lin Xiao and Tong Zhang. A Proximal Stochastic Gradient Method with Progressive Variance Reduction. SIAM Journal on Optimization, 2014. [58] Li Yuan, Francis EH Tay, Guilin Li, Tao Wang, and Jiashi Feng. Revisiting Knowledge Distillation via Label Smoothing Regularization. Conference on Computer Vision and Pattern Recognition, 2020. [59] Linfeng Zhang, Jiebo Song, Anni Gao, Jingwei Chen, Chenglong Bao, and Kaisheng Ma. Be your own teacher: Improve the performance of convolutional neural networks via self distillation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 3713 3722, 2019. [60] Linfeng Zhang, Jiebo Song, Anni Gao, Jingwei Chen, Chenglong Bao, and Kaisheng Ma. Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation. ICCV, 2019. 
[61] Xinshao Wang Ziyun Li, Di Hu, Neil M. Robertson, David A. Clifton, Christoph Meinel, and Haojin Yang. Not All Knowledge Is Created Equal: Mutual Distillation of Confident Knowledge. Neur IPS 2022 Workshop(Trustworthy and Socially Responsible Machine Learning), 2022. 1 Introduction 1 2 Related Work 2 3 Knowledge Distillation 3 4 Convergence Theory for Self-Distillation 5 5 Removing Bias and Improving Variance Reduction 8 6 Convergence for Distillation of Compressed Models 9 7 Discussion and Future Work 10 A Basic Facts and Inequalities 16 B Proofs for Section 3 17 B.1 Binary Logistic Regression . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 17 B.2 Multi-class Classification with Soft-max . . . . . . . . . . . . . . . . . . . . . . . 18 B.3 Generic Non-linear Classification . . . . . . . . . . . . . . . . . . . . . . . . . . . 18 C Proofs for Section 4 20 C.1 Key lemma . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 20 C.2 Proof of Theorem 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 21 C.3 Proof of Theorem 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 23 D Proofs for Section 5 23 D.1 Proof of Theorem 3 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 23 E Proofs for Section 6 24 E.1 Five Lemmas . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 24 E.2 Proof of Theorem 4 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 26 F Additional Experimental Validation 28 F.1 The impact of the learning hyperparameters . . . . . . . . . . . . . . . . . . . . . 28 F.2 The impact of the teacher s quality . . . . . . . . . . . . . . . . . . . . . . . . . . 29 F.3 The impact of knowledge distillation on compression . . . . . . . . . . . . . . . . 29 G Limitations 29 A Basic Facts and Inequalities To facilitate the reading of technical part of the work, here we present several standard inequalities and basic facts that we are going to use in the proofs. Usually we need to bound the sum of two error terms by individual errors using the following simple bound a + b 2 2 a 2 + 2 b 2 . (20) A direct generalization of this inequality for arbitrary number of summation terms is the following one: i=1 ai 2 . (21) This inequality can be seen as a special case of Jensen s inequality for convex functions h: Rd R, i=1 αih(xi), where xi Rd are any vectors and αi [0, 1] with Pn i=1 αi = 1. Then, (20) follows from Jensen s inequality when h(x) = x 2 and α1 = = αn = 1 n. A more general version of the Jensen s inequality, from the perspective of probability theory, is h(Ez) Eh(z) (22) for any convex function h and random vector z Rd. Another extension of (20), which we actually apply in that form, is Peter-Paul inequality given by 2 b 2 , (23) for any positive s > 0. Then, (20) is the special case of (23) with s = 1. Typically, a function f is called L-smooth if its gradient is Lipschitz continuous with Lipschitz constant L 0, namely f(x) f(y) L x y , (24) In particular, smoothness inequality (24) implies the following quadratic upper bound [35]: f(y) f(x) + f(x), y x + L 2 y x 2 , (25) for all points x, y Rd. If the function f is additionally convex, then the following lower bound holds too: f(y) f(x) + f(x), y x + 1 2L f(y) f(x) 2 . (26) As we mentioned in the main part of the paper, function f is called µ-strongly convex if the following inequality holds f(y) f(x) + f(x), y x + µ 2 y x 2 , (27) for all points x, y Rd. 
Recall that strong quasi-convexity assumption (1) is the special case of (27) when y = x . In other words, strong convexity (27) implies strong quasi-convexity (1). Continuing this chain of conditions, strong quasi-convexity (1) implies PL condition (9), which, in turn, implies the so-called Quadratic Growth condition given below 8 x x 2, (28) for all x Rd. Derivations of these implications and relationship with other conditions can be found in [17]. As many analyses with linear or exponential convergence speed, our analysis also uses a very standard transformation from a single-step recurrence relation to convergence inequality. Specifically, assume we have the following recursion for et, t = 0, 1, 2, . . . : et+1 (1 η)et + N, with constants η (0, 1] and N 0. Then, repeated application of the above recursion gives et (1 η)te0 + (1 η)t 1N + (1 η)t 2N + + (1 η)0N (1 η)te0 + N j=0 (1 η)j = (1 η)te0 + N If z Rd is a random vector and x Rd is fixed (or has randomness independent of z), then the following decomposition holds: E h x + z 2i = x 2 + E h z 2i (30) B Proofs for Section 3 For the sake of presentation, we first consider binary logistic regression problem as a special case of multi-class classification. B.1 Binary Logistic Regression In this case we have d-dimensional input vectors an Rd with their binary true labels bn B = [0, 1], predictor ϕx(a) = σ(x a) (0, 1) with parameters x P = Rd+1 and lifted input vector a = [a 1] Rd+1 (to avoid additional notation for the bias terms), where σ(t) = 1 1+e t , t R is the sigmoid function. Besides, the loss is given by the cross entropy loss below ℓ(p, q) = H q 1 q , p 1 p = q log p (1 q) log(1 p), p, q [0, 1]. Based on (2) and (3), we have fn(x | θ, λ) = (1 λ)fn(x) + λfn(x | θ) = (1 λ)ℓ(ϕx(an), bn) + λℓ(ϕx(an), ϕθ(an)) = ℓ(ϕx(an), (1 λ)bn + λϕθ(an)) = ℓ(σ(x an), (1 λ)bn + λσ(θ an)) = ℓ(σ(x an), sn) = sn log 1 + e x an + (1 sn) log 1 + ex an , where sn = (1 λ)bn + λσ(θ an) are the soft labels. Notice that fn(x | θ, λ = 0) = fn(x). Next, we derive an expression for the stochastic gradient for the above objective, namely the gradient xfn(x | θ, λ) of the loss associated with nth data point (an, bn). x h sn log 1 + e x an i = sn e x an 1 + e x an an = snσ( x an) an, x h (1 sn) log 1 + ex an i = (1 sn) ex an 1 + ex an an = (1 sn)σ(x an) an. Hence, using the identity σ(t) + σ( t) = 1 for the sigmoid function, we get xfn(x | θ, λ) = sn(1 σ(x an)) + (1 sn)σ(x an) an = σ(x an) sn an = σ(x an) (1 λ)bn λσ(θ an) an = σ(x an) bn an λ σ(θ an) bn an = fn(x) λ fn(θ). Thus, the distillation gradient for binary logistic regression tasks the same form as (5). B.2 Multi-class Classification with Soft-max Now we extend the above steps for multi-class problem. Here we consider a K-classification model with one hidden layer that has soft-max as the last layer, i.e., the forward pass has the following steps: an X an ϕX(an) := σ(X an) RK, where X = [x1 x2 . . . x K] Rd K are the student s model parameters, an A = Rd is the input data and σ is the soft-max function (as a generalization of sigmoid function). Then we simplify the loss fn(X | Θ, λ) = (1 λ)fn(X) + λfn(X | Θ) = (1 λ)ℓ(ϕX(an), bn) + λℓ(ϕX(an), ϕΘ(an)) = ℓ(ϕX(an), (1 λ)bn + λϕΘ(an)) = ℓ(σ(X an), sn) k=1 sn,k log σ(X an)k = k=1 sn,k log ex k an PK j=1 ex j an j=1 ex j an log ex k an k=1 sn,k log j=1 ex j an k=1 sn,k log ex k an k=1 ex k an k=1 sn,k log ex k an = log k=1 ex k an k=1 sn,kx k an, where sn = (1 λ)bn + λϕΘ(an) RK are the soft labels. 
Next, we derive an expression for the stochastic gradient of the above objective, namely the gradient \nabla_{x_k} f_n(X \mid \Theta, \lambda) of the loss associated with the n-th data point (a_n, b_n):

\nabla_{x_k} f_n(X \mid \Theta, \lambda) = \nabla_{x_k}\left[ \log\left( \sum_{j=1}^{K} e^{x_j^\top a_n} \right) - \sum_{j=1}^{K} s_{n,j}\,x_j^\top a_n \right]
= \frac{e^{x_k^\top a_n}}{\sum_{i=1}^{K} e^{x_i^\top a_n}}\,a_n - s_{n,k}\,a_n
= \left( \sigma(X^\top a_n) - s_n \right)_k a_n
= \left( \sigma(X^\top a_n) - b_n \right)_k a_n - \lambda\left( \sigma(\Theta^\top a_n) - b_n \right)_k a_n
= \nabla_{x_k} f_n(X) - \lambda\,\nabla_{\theta_k} f_n(\Theta).

Again, we get the same expression (5) for the distillation gradient:

\nabla_X f_n(X \mid \Theta, \lambda) = \nabla_X f_n(X) - \lambda\,\nabla_\Theta f_n(\Theta).

B.3 Generic Non-linear Classification

Finally, consider an arbitrary classification model that ends with a linear layer followed by soft-max, i.e., the forward pass has the following steps:

a_n \mapsto \psi_n(x) \mapsto \phi_x(a_n) := \sigma(\psi_n(x)) \in \mathbb{R}^K,

where a_n \in A = \mathbb{R}^d is the input data and \psi_n(x) \in \mathbb{R}^K are the logits with respect to the model parameters x. Denote by \varphi_n(z) := \ell(\sigma(z), b_n) the loss associated with logits z and the true label b_n. In words, \psi_n produces the logits from the input data, while \varphi_n produces the loss from given logits. Then, clearly, we have the following representation for the loss function: f_n(x) = \varphi_n(\psi_n(x)). Next, let us simplify the distillation loss as

f_n(x \mid \theta, \lambda) = (1 - \lambda) f_n(x) + \lambda f_n(x \mid \theta)
= (1 - \lambda)\,\ell(\phi_x(a_n), b_n) + \lambda\,\ell(\phi_x(a_n), \phi_\theta(a_n))
= \ell(\phi_x(a_n), (1 - \lambda) b_n + \lambda \phi_\theta(a_n))
= \ell(\sigma(\psi_n(x)), (1 - \lambda) b_n + \lambda\,\sigma(\psi_n(\theta)))
= \ell(\sigma(\psi_n(x)), s_n)
= -\sum_{k=1}^{K} s_{n,k}\log\sigma(\psi_n(x))_k
= -\sum_{k=1}^{K} s_{n,k}\log\frac{e^{\psi_{n,k}(x)}}{\sum_{j=1}^{K} e^{\psi_{n,j}(x)}}
= \log\left( \sum_{k=1}^{K} e^{\psi_{n,k}(x)} \right) - \sum_{k=1}^{K} s_{n,k}\,\psi_{n,k}(x),

where s_n = (1 - \lambda) b_n + \lambda\,\phi_\theta(a_n). Now we differentiate the obtained expression to get the stochastic gradient of the above objective, namely the gradient \nabla_x f_n(x \mid \theta, \lambda) of the loss associated with the n-th data point (a_n, b_n). Applying the gradient operator, we get

\nabla_x f_n(x \mid \theta, \lambda) = \nabla_x\left[ \log\left( \sum_{k=1}^{K} e^{\psi_{n,k}(x)} \right) - \sum_{k=1}^{K} s_{n,k}\,\psi_{n,k}(x) \right]
= \sum_{k=1}^{K}\frac{e^{\psi_{n,k}(x)}}{\sum_{j=1}^{K} e^{\psi_{n,j}(x)}}\,\nabla_x\psi_{n,k}(x) - \sum_{k=1}^{K} s_{n,k}\,\nabla_x\psi_{n,k}(x)
= \sum_{k=1}^{K}\left( \sigma(\psi_n(x)) - s_n \right)_k\,\nabla_x\psi_{n,k}(x)
= J_{\psi_n}(x)\left( \sigma(\psi_n(x)) - s_n \right),

where

J_{\psi_n}(x) := \frac{\partial \psi_n(x)}{\partial x} = \left[ \nabla\psi_{n,1}(x)\;\; \nabla\psi_{n,2}(x)\;\; \dots\;\; \nabla\psi_{n,K}(x) \right] \in \mathbb{R}^{d \times K}

is the Jacobian of the vector-valued function \psi_n: \mathbb{R}^d \to \mathbb{R}^K. From the derivation so far we conclude that

\sigma(\psi_n(x)) - s_n = \left( \sigma(\psi_n(x)) - b_n \right) - \lambda\left( \sigma(\psi_n(\theta)) - b_n \right) = \nabla\varphi_n(\psi_n(x)) - \lambda\,\nabla\varphi_n(\psi_n(\theta)).

Taking into account that f_n(x) = \varphi_n(\psi_n(x)), we obtain the following form for the distilled gradient:

\nabla_x f_n(x \mid \theta, \lambda) = J_{\psi_n}(x)\left( \nabla\varphi_n(\psi_n(x)) - \lambda\,\nabla\varphi_n(\psi_n(\theta)) \right) = \frac{\partial\psi_n(x)}{\partial x}\frac{\partial f_n(x)}{\partial\psi_n(x)} - \lambda\,\frac{\partial\psi_n(x)}{\partial x}\frac{\partial f_n(\theta)}{\partial\psi_n(\theta)},

which coincides with \nabla f_n(x) - \lambda\,\nabla f_n(\theta), except that the teacher term is multiplied by the student's Jacobian J_{\psi_n}(x) rather than the teacher's Jacobian J_{\psi_n}(\theta).

C Proofs for Section 4

Before we proceed to the proofs of Theorems 1 and 2, we prove a key lemma that will be useful in both. The lemma we are about to present covers Part 2 (Optimal distillation weight) and Part 3 (Impact of the teacher) of the proof overview discussed in the main text.

C.1 Key lemma

To simplify the expressions in our proofs, let us introduce some notation describing the stochastic gradients. Denote the signal-to-noise ratio with respect to the parameters \theta by

\beta(\theta) := \frac{\|\nabla f(\theta)\|^2}{\mathbb{E}[\|\nabla f_\xi(\theta)\|^2]} \in [0, 1], (31)

and the correlation coefficient between the stochastic gradients \nabla f_\xi(x) and \nabla f_\xi(y) by

\rho(x, y) := \frac{\mathrm{Cov}(x, y)}{\sqrt{\mathrm{Var}(x)\,\mathrm{Var}(y)}} \in [-1, 1], (32)

where

\mathrm{Cov}(x, y) := \mathbb{E}[\langle \nabla f_\xi(x) - \nabla f(x), \nabla f_\xi(y) - \nabla f(y) \rangle], \qquad \mathrm{Var}(x) := \mathrm{Cov}(x, x)

are the covariance and variance, respectively.

Lemma 1. Let N(\lambda) = \lambda^2\|\nabla f(\theta)\|^2 + c\gamma\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2 for some constant c > 0. Then, the optimal \lambda^* that minimizes N(\lambda) is given by

\lambda^* = \frac{\mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle]}{\mathbb{E}[\|\nabla f_\xi(\theta)\|^2] + \tfrac{1}{c\gamma}\|\nabla f(\theta)\|^2}. (33)

Moreover,

\frac{N(\lambda^*)}{N(0)} = 1 - \rho^2(x^*, \theta)\,\frac{1 - \beta(\theta)}{1 + \tfrac{\beta(\theta)}{c\gamma}} \le \min\left\{ 1, O\left( \tfrac{1}{\gamma}\left( f(\theta) - f^* \right) \right) \right\}.

Proof. Notice that N(\lambda) is quadratic in \lambda; using the first-order optimality condition

\frac{d}{d\lambda}N(\lambda) = 2\lambda\|\nabla f(\theta)\|^2 + c\gamma\left( -2\,\mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] + 2\lambda\,\mathbb{E}\|\nabla f_\xi(\theta)\|^2 \right) = 0,

we get (33).
Furthermore, plugging the expression of \lambda^* into N(\lambda), we get

N(\lambda^*) = c\gamma\left( \tfrac{(\lambda^*)^2}{c\gamma}\|\nabla f(\theta)\|^2 + \mathbb{E}\|\nabla f_\xi(x^*)\|^2 - 2\lambda^*\,\mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] + (\lambda^*)^2\,\mathbb{E}\|\nabla f_\xi(\theta)\|^2 \right)
= c\gamma\left( \mathbb{E}\|\nabla f_\xi(x^*)\|^2 - \frac{\left( \mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] \right)^2}{\mathbb{E}[\|\nabla f_\xi(\theta)\|^2] + \tfrac{1}{c\gamma}\|\nabla f(\theta)\|^2} \right)
= c\gamma\,\mathbb{E}\|\nabla f_\xi(x^*)\|^2\left( 1 - \frac{\left( \mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] \right)^2}{\left( \mathbb{E}[\|\nabla f_\xi(\theta)\|^2] + \tfrac{1}{c\gamma}\|\nabla f(\theta)\|^2 \right)\mathbb{E}[\|\nabla f_\xi(x^*)\|^2]} \right).

Note that N(0) = c\gamma\sigma^2. From \mathbb{E}[\nabla f_\xi(x^*)] = \nabla f(x^*) = 0, we get

\mathrm{Cov}(x^*, \theta) = \mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) - \nabla f(\theta) \rangle] = \mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle],

and therefore

\rho(x^*, \theta) = \frac{\mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle]}{\sqrt{\mathbb{E}[\|\nabla f_\xi(x^*)\|^2]}\,\sqrt{\mathbb{E}[\|\nabla f_\xi(\theta)\|^2] - \|\nabla f(\theta)\|^2}}.

Now we can simplify the expression for N(\lambda^*) as follows:

\frac{N(\lambda^*)}{N(0)} = 1 - \frac{\left( \mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] \right)^2}{\left( \mathbb{E}[\|\nabla f_\xi(\theta)\|^2] + \tfrac{1}{c\gamma}\|\nabla f(\theta)\|^2 \right)\mathbb{E}[\|\nabla f_\xi(x^*)\|^2]}
= 1 - \frac{\left( \mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] \right)^2}{\left( \mathbb{E}[\|\nabla f_\xi(\theta)\|^2] - \|\nabla f(\theta)\|^2 \right)\mathbb{E}[\|\nabla f_\xi(x^*)\|^2]} \cdot \frac{\mathbb{E}[\|\nabla f_\xi(\theta)\|^2] - \|\nabla f(\theta)\|^2}{\mathbb{E}[\|\nabla f_\xi(\theta)\|^2] + \tfrac{1}{c\gamma}\|\nabla f(\theta)\|^2}
= 1 - \rho^2(x^*, \theta)\,\frac{1 - \beta(\theta)}{1 + \tfrac{\beta(\theta)}{c\gamma}}.

Here, \rho(x^*, \theta) is the correlation coefficient between the stochastic gradients at x^* and \theta. Hence, we showed that with a tuned distillation weight the neighborhood can shrink by a factor that depends on the teacher's parameters. In the extreme case when the teacher \theta = x^* is optimal, we have \rho(x^*, \theta) = 1 and \beta(\theta) = 0, and thus no neighborhood: N(\lambda^*) = 0. This hints at the fact that the reduction factor N(\lambda^*)/N(0) of the neighborhood is controlled by the quality of the teacher. To make this argument rigorous, consider the teacher's model to be away from the optimal solution x^* within the limit described by the following inequality:

f(\theta) - f^* \le \frac{\sigma(x^*)\sigma(\theta)}{\mathcal{L}}, (34)

where \sigma^2(x) := \mathbb{E}\|\nabla f_\xi(x)\|^2 is the second moment of the stochastic gradients. Without loss of generality, we assume that \sigma^2(x) > 0 for all parameter choices x \in \mathbb{R}^d: otherwise we have \sigma^2 = 0 and even plain SGD ensures full variance reduction. Then, we can simplify the reduction factor as

1 - \rho^2(x^*, \theta)\,\frac{1 - \beta(\theta)}{1 + \tfrac{\beta(\theta)}{c\gamma}}
= 1 - \frac{\left( \mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] \right)^2}{\mathbb{E}[\|\nabla f_\xi(\theta)\|^2]\,\mathbb{E}[\|\nabla f_\xi(x^*)\|^2]} \cdot \frac{1}{1 + \tfrac{\beta(\theta)}{c\gamma}}
\le 1 - \frac{\left( 1 - \frac{\mathbb{E}\|\nabla f_\xi(x^*) - \nabla f_\xi(\theta)\|^2}{2\sigma(x^*)\sigma(\theta)} \right)^2}{1 + \tfrac{\beta(\theta)}{c\gamma}}
\overset{(10)+(34)}{\le} \left( \frac{2\mathcal{L}}{\sigma(x^*)\sigma(\theta)} + \frac{1}{c\gamma}\cdot\frac{2L}{\sigma^2(\theta)} \right)\left( f(\theta) - f^* \right) = O\left( \tfrac{1}{\gamma}\left( f(\theta) - f^* \right) \right),

where the second inequality used 2\,\mathbb{E}[\langle \nabla f_\xi(x^*), \nabla f_\xi(\theta) \rangle] = \sigma^2(x^*) + \sigma^2(\theta) - \mathbb{E}\|\nabla f_\xi(x^*) - \nabla f_\xi(\theta)\|^2, the bound \sigma^2(x^*) + \sigma^2(\theta) \ge 2\sigma(x^*)\sigma(\theta), and condition (34), and where the last inequality used the elementary bound 1 - \tfrac{(1-u)^2}{1+uv} \le (2 + v)u for all u, v \ge 0, applied with u = \frac{\mathbb{E}\|\nabla f_\xi(x^*) - \nabla f_\xi(\theta)\|^2}{2\sigma(x^*)\sigma(\theta)} and uv = \tfrac{\beta(\theta)}{c\gamma}, together with \mathbb{E}\|\nabla f_\xi(x^*) - \nabla f_\xi(\theta)\|^2 \le 2\mathcal{L}(f(\theta) - f^*) from (10) and \|\nabla f(\theta)\|^2 \le 2L(f(\theta) - f^*) from smoothness.

C.2 Proof of Theorem 1

Denote by \mathbb{E}_t[\,\cdot\,] := \mathbb{E}[\,\cdot \mid x_t] the conditional expectation with respect to x_t. We start bounding the error using the update rule (7):

\mathbb{E}_t\|x_{t+1} - x^*\|^2 = \|x_t - x^*\|^2 - 2\gamma\langle x_t - x^*, \nabla f(x_t) - \lambda\nabla f(\theta) \rangle + \gamma^2\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \lambda\nabla f_\xi(\theta)\|^2
= \|x_t - x^*\|^2 - 2\gamma\langle x_t - x^*, \nabla f(x_t) \rangle + 2\gamma\lambda\langle x_t - x^*, \nabla f(\theta) \rangle + \gamma^2\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(1)+(20)}{\le} (1 - \gamma\mu)\|x_t - x^*\|^2 - 2\gamma\left( f(x_t) - f(x^*) \right) + 2\gamma\lambda\langle x_t - x^*, \nabla f(\theta) \rangle + 2\gamma^2\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \nabla f_\xi(x^*)\|^2 + 2\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(28)+(23)}{\le} (1 - \gamma\mu)\|x_t - x^*\|^2 - \gamma\left( f(x_t) - f(x^*) \right) - \tfrac{\gamma\mu}{8}\|x_t - x^*\|^2 + \tfrac{\gamma\mu}{8}\|x_t - x^*\|^2 + \tfrac{8\gamma}{\mu}\lambda^2\|\nabla f(\theta)\|^2 + 2\gamma^2\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \nabla f_\xi(x^*)\|^2 + 2\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(10)}{\le} (1 - \gamma\mu)\|x_t - x^*\|^2 - \gamma\left( f(x_t) - f(x^*) \right) + \tfrac{8\gamma}{\mu}\lambda^2\|\nabla f(\theta)\|^2 + 4\gamma^2\mathcal{L}\left( f(x_t) - f(x^*) \right) + 2\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
= (1 - \gamma\mu)\|x_t - x^*\|^2 - \gamma\left( 1 - 4\gamma\mathcal{L} \right)\left( f(x_t) - f(x^*) \right) + \tfrac{8\gamma}{\mu}\lambda^2\|\nabla f(\theta)\|^2 + 2\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\le (1 - \gamma\mu)\|x_t - x^*\|^2 + \tfrac{8\gamma}{\mu}\lambda^2\|\nabla f(\theta)\|^2 + 2\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2,

where we used the Peter-Paul inequality (23) with parameter s = 8/\mu for the inner-product term, the quadratic growth condition (28) to absorb it, and the step-size bound \gamma \le \tfrac{1}{4\mathcal{L}} in the last inequality. Applying full expectation and unrolling the recursion, we get

\mathbb{E}\|x_t - x^*\|^2 \le (1 - \gamma\mu)\,\mathbb{E}\|x_{t-1} - x^*\|^2 + \tfrac{8\gamma}{\mu}\lambda^2\|\nabla f(\theta)\|^2 + 2\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(29)}{\le} (1 - \gamma\mu)^t\|x_0 - x^*\|^2 + \tfrac{8\lambda^2}{\mu^2}\|\nabla f(\theta)\|^2 + \tfrac{2\gamma}{\mu}\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
= (1 - \gamma\mu)^t\|x_0 - x^*\|^2 + \tfrac{8}{\mu^2}N_1(\lambda), (35)

where N_1(\lambda) := \lambda^2\|\nabla f(\theta)\|^2 + \tfrac{\gamma\mu}{4}\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2. Applying Lemma 1 with c = \mu/4, we conclude that for \lambda = \lambda^* the neighborhood size is

N_1(\lambda^*) \overset{\text{Lemma 1}}{\le} N_1(0)\min\left\{ 1, O\left( \tfrac{1}{\gamma}\left( f(\theta) - f^* \right) \right) \right\} = \tfrac{\mu\sigma^2}{4}\min\left( \gamma, O(f(\theta) - f^*) \right).

Plugging this bound on N_1 into (35) completes the proof.
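To make Lemma 1 and the tuned weight (33) concrete, here is a minimal Python sketch on a synthetic least-squares problem (our own illustration; the constants c and gamma and all variable names are assumptions of the example, not values prescribed by the theory). It estimates \lambda^* from per-sample gradients and reports the resulting neighborhood reduction N(\lambda^*)/N(0).

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 200, 5
A = rng.normal(size=(n, d))
b = A @ rng.normal(size=d) + 0.5 * rng.normal(size=n)   # noisy labels

# Empirical-risk minimizer x* of f(x) = (1/n) sum_i 0.5 * (a_i^T x - b_i)^2
x_star = np.linalg.lstsq(A, b, rcond=None)[0]
theta = x_star + 0.05 * rng.normal(size=d)               # an imperfect "teacher"

def stoch_grads(w):
    # Per-sample gradients of 0.5 * (a_i^T w - b_i)^2, stacked as rows.
    return (A @ w - b)[:, None] * A

g_star, g_theta = stoch_grads(x_star), stoch_grads(theta)
grad_f_theta = g_theta.mean(axis=0)                      # full-batch gradient at theta

c, gamma = 0.25, 0.1                                     # illustrative constants in N(lambda)

def N(lam):
    # N(lambda) = lambda^2 ||grad f(theta)||^2 + c*gamma*E||grad f_xi(x*) - lambda*grad f_xi(theta)||^2
    resid = g_star - lam * g_theta
    return lam**2 * grad_f_theta @ grad_f_theta + c * gamma * (resid**2).sum(axis=1).mean()

# Optimal weight from (33)
lam_star = (g_star * g_theta).sum(axis=1).mean() / (
    (g_theta**2).sum(axis=1).mean() + grad_f_theta @ grad_f_theta / (c * gamma)
)
print(f"lambda* = {lam_star:.3f},  N(lambda*)/N(0) = {N(lam_star) / N(0):.3f}")
```

Moving theta closer to or farther from x_star should drive the printed ratio toward 0 or 1, respectively, in line with the discussion above.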
C.3 Proof of Theorem 2

We start the recursion from the L-smoothness condition of f. As before, \mathbb{E}_t denotes the conditional expectation with respect to x_t:

\mathbb{E}_t[f(x_{t+1})] - f^*
\overset{(25)}{\le} f(x_t) - f^* - \gamma\langle \nabla f(x_t), \nabla f(x_t) - \lambda\nabla f(\theta) \rangle + \tfrac{L\gamma^2}{2}\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(20)}{\le} f(x_t) - f^* - \gamma\|\nabla f(x_t)\|^2 + \gamma\lambda\langle \nabla f(x_t), \nabla f(\theta) \rangle + L\gamma^2\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \nabla f_\xi(x^*)\|^2 + L\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(9)+(23)}{\le} (1 - \gamma\mu)\left( f(x_t) - f^* \right) - \tfrac{\gamma}{4}\|\nabla f(x_t)\|^2 - \tfrac{\gamma\mu}{2}\left( f(x_t) - f^* \right) + \tfrac{\gamma}{4}\|\nabla f(x_t)\|^2 + \gamma\lambda^2\|\nabla f(\theta)\|^2 + L\gamma^2\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \nabla f_\xi(x^*)\|^2 + L\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(10)}{\le} (1 - \gamma\mu)\left( f(x_t) - f^* \right) - \tfrac{\gamma\mu}{2}\left( f(x_t) - f^* \right) + \gamma\lambda^2\|\nabla f(\theta)\|^2 + 2L\mathcal{L}\gamma^2\left( f(x_t) - f^* \right) + L\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
= (1 - \gamma\mu)\left( f(x_t) - f^* \right) - \gamma\left( \tfrac{\mu}{2} - 2L\mathcal{L}\gamma \right)\left( f(x_t) - f^* \right) + \gamma\lambda^2\|\nabla f(\theta)\|^2 + L\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\le (1 - \gamma\mu)\left( f(x_t) - f^* \right) + \gamma\lambda^2\|\nabla f(\theta)\|^2 + L\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2,

where in the last inequality we used the step-size bound \gamma \le \tfrac{1}{4\mathcal{L}}\cdot\tfrac{\mu}{L}. Applying full expectation and unrolling the recursion, we get

\mathbb{E}[f(x_t)] - f^* \le (1 - \gamma\mu)\left( \mathbb{E}[f(x_{t-1})] - f^* \right) + \gamma\lambda^2\|\nabla f(\theta)\|^2 + L\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(29)}{\le} (1 - \gamma\mu)^t\left( f(x_0) - f^* \right) + \tfrac{\lambda^2}{\mu}\|\nabla f(\theta)\|^2 + \tfrac{L\gamma}{\mu}\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
= (1 - \gamma\mu)^t\left( f(x_0) - f^* \right) + \tfrac{1}{\mu}N_2(\lambda), (36)

where N_2(\lambda) := \lambda^2\|\nabla f(\theta)\|^2 + L\gamma\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2. Similarly to the previous case, we apply Lemma 1 with c = L and conclude that for \lambda = \lambda^* the neighborhood size is

N_2(\lambda^*) \overset{\text{Lemma 1}}{\le} N_2(0)\min\left\{ 1, O\left( \tfrac{1}{\gamma}\left( f(\theta) - f^* \right) \right) \right\} = L\sigma^2\min\left( \gamma, O(f(\theta) - f^*) \right).

Plugging this bound on N_2 into (36) completes the proof.

D Proofs for Section 5

D.1 Proof of Theorem 3

Again, we start the recursion from the smoothness condition of f:

\mathbb{E}_t[f(x_{t+1})] - f^*
\overset{(25)}{\le} f(x_t) - f^* - \gamma\|\nabla f(x_t)\|^2 + \tfrac{L\gamma^2}{2}\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \nabla f_\xi(\theta) + \nabla f(\theta)\|^2
\overset{(21)}{\le} f(x_t) - f^* - \gamma\|\nabla f(x_t)\|^2 + \tfrac{3}{2}L\gamma^2\,\mathbb{E}_t\|\nabla f_\xi(x_t) - \nabla f_\xi(x^*)\|^2 + \tfrac{3}{2}L\gamma^2\,\mathbb{E}\|\nabla f_\xi(\theta) - \nabla f_\xi(x^*)\|^2 + \tfrac{3}{2}L\gamma^2\|\nabla f(\theta)\|^2
\overset{(9)+(10)}{\le} (1 - \gamma\mu)\left( f(x_t) - f^* \right) - \gamma\mu\left( f(x_t) - f^* \right) + 3L\mathcal{L}\gamma^2\left( f(x_t) - f^* \right) + 3L\mathcal{L}\gamma^2\left( f(\theta) - f^* \right) + 3L^2\gamma^2\left( f(\theta) - f^* \right)
\le (1 - \gamma\mu)\left( f(x_t) - f^* \right) + 3L(\mathcal{L} + L)\gamma^2\left( f(\theta) - f^* \right),

where we used the step-size bound \gamma \le \tfrac{\mu}{3L\mathcal{L}} in the last step. Therefore,

\mathbb{E}[f(x_t)] - f^* \le (1 - \gamma\mu)\left( \mathbb{E}[f(x_{t-1})] - f^* \right) + 3L(\mathcal{L} + L)\gamma^2\left( f(\theta) - f^* \right) \overset{(29)}{\le} (1 - \gamma\mu)^t\left( f(x_0) - f^* \right) + \tfrac{3L(\mathcal{L} + L)}{\mu}\gamma\left( f(\theta) - f^* \right),

which concludes the proof.

E Proofs for Section 6

First, we break the update rule (18) into two parts by introducing auxiliary model parameters y_t \in \mathbb{R}^d:

y_{t+1} = x_t - \gamma\left( \nabla f_\xi(x_t) - \lambda\nabla f_\xi(\theta) \right), \qquad x_{t+1} = C(y_{t+1}).

Without loss of generality, we assume that the initialization satisfies x_0 = y_0 = C(y_0). Then, for all t \ge 0 we have x_t = C(y_t), and the update rule for y_{t+1} can be written recursively without x_t via

y_{t+1} = C(y_t) - \gamma\left( \nabla f_\xi(C(y_t)) - \lambda\nabla f_\xi(\theta) \right). (37)

Using the unbiasedness of the compression operator C, we decompose the error

\mathbb{E}\left[ \|x_t - x^*\|^2 \right] \overset{(30)}{=} \mathbb{E}\left[ \|x_t - y_t\|^2 \right] + \mathbb{E}\left[ \|y_t - x^*\|^2 \right] = \mathbb{E}\left[ \|C(y_t) - y_t\|^2 \right] + \mathbb{E}\left[ \|y_t - x^*\|^2 \right] \overset{(19)}{\le} \omega\,\mathbb{E}\left[ \|y_t\|^2 \right] + \mathbb{E}\left[ \|y_t - x^*\|^2 \right] \overset{(20)}{\le} (2\omega + 1)\,\mathbb{E}\left[ \|y_t - x^*\|^2 \right] + 2\omega\|x^*\|^2. (38)

Thus, our goal is to analyze the iterates (37) and derive a rate for x_t. In fact, the special case of (37) was analyzed by [18] in the non-stochastic case and without distillation (\lambda = 0). The analysis we provide here for (37) is based on [18] and, from this perspective, can be seen as an extension of their analysis. To avoid yet another notation for the expected smoothness constant (see Assumption 3), analogously to (24) we assume that L also satisfies the following smoothness inequality:

\mathbb{E}\left[ \|\nabla f_\xi(x) - \nabla f_\xi(y)\|^2 \right] \le L^2\|x - y\|^2. (39)

E.1 Five Lemmas

First, we need to upper bound the compression error by the function suboptimality. Denote \delta(x) := C(x) - x, and let \mathbb{E}_\delta and \mathbb{E}_\xi be the expectations with respect to the compression operator C and the sampling \xi, respectively (a minimal numerical example of an unbiased compressor satisfying (19) is sketched after Lemma 2 below).

Lemma 2 (Lemma 1 in [18]). Let \alpha = \tfrac{2\omega}{\mu} and \nu = 2\omega\|x^*\|^2. For all x \in \mathbb{R}^d we have

\mathbb{E}_\delta\|\delta(x)\|^2 \le 2\alpha\left( f(x) - f(x^*) \right) + \nu. (40)

Proof. From (20) we have \|x\|^2 \le 2\|x - x^*\|^2 + 2\|x^*\|^2, and from the \mu-strong convexity condition (27) we get \|x - x^*\|^2 \le \tfrac{2}{\mu}\left( f(x) - f(x^*) \right). Putting these inequalities together with (19), we arrive at

\mathbb{E}_\delta\|\delta(x)\|^2 \overset{(19)}{\le} \omega\|x\|^2 \le 2\omega\|x - x^*\|^2 + 2\omega\|x^*\|^2 \le \tfrac{4\omega}{\mu}\left( f(x) - f(x^*) \right) + 2\omega\|x^*\|^2.
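Before continuing with the remaining lemmas, and as referenced above, here is a minimal Python sketch of an unbiased Rand-k sparsification operator (our own example; the name rand_k and all constants are illustrative choices). It satisfies E[C(x)] = x and E||C(x) - x||^2 <= omega*||x||^2 with omega = d/k - 1, and is therefore one instance of the compressors covered by (19) and Lemma 2.

```python
import numpy as np

def rand_k(x, k, rng):
    # Unbiased Rand-k sparsification: keep k random coordinates, rescale by d/k.
    d = x.size
    mask = np.zeros(d)
    mask[rng.choice(d, size=k, replace=False)] = 1.0
    return (d / k) * mask * x

rng = np.random.default_rng(3)
d, k = 10, 3
omega = d / k - 1                      # variance parameter in (19)
x = rng.normal(size=d)

samples = np.stack([rand_k(x, k, rng) for _ in range(100_000)])
mean_err = np.linalg.norm(samples.mean(axis=0) - x)    # ~0, i.e. unbiasedness
var = ((samples - x) ** 2).sum(axis=1).mean()          # E||C(x) - x||^2

print(f"||E[C(x)] - x|| ~ {mean_err:.3f}")
print(f"E||C(x)-x||^2 = {var:.3f} <= omega*||x||^2 = {omega * (x @ x):.3f}")
```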
Lemma 3. For all x, y \in \mathbb{R}^d we have

\mathbb{E}\left[ \|\nabla f_\xi(x + \delta(x)) - \nabla f_\xi(y)\|^2 \right] \le L^2\left( \|x - y\|^2 + \mathbb{E}\|\delta(x)\|^2 \right), (41)

f(x) \le \mathbb{E}\left[ f(x + \delta(x)) \right] \le f(y) + \langle \nabla f(y), x - y \rangle + \tfrac{L}{2}\|x - y\|^2 + \tfrac{L}{2}\,\mathbb{E}\|\delta(x)\|^2. (42)

Proof. Fix x and let \delta = \delta(x). Inequality (41) follows from the Lipschitz continuity of the stochastic gradients (39), applying expectation and using (19):

\mathbb{E}\|\nabla f_\xi(x + \delta) - \nabla f_\xi(y)\|^2 \overset{(39)}{\le} L^2\,\mathbb{E}_\delta\|x + \delta - y\|^2 \overset{(19)+(30)}{=} L^2\left( \|x - y\|^2 + \mathbb{E}_\delta\|\delta\|^2 \right).

The first inequality in (42) follows by applying Jensen's inequality (22) and using (19). Since f is L-smooth, we have

\mathbb{E}[f(x + \delta)] \overset{(25)}{\le} f(y) + \langle \nabla f(y), \mathbb{E}[x + \delta] - y \rangle + \tfrac{L}{2}\,\mathbb{E}\|x + \delta - y\|^2 \overset{(19)+(30)}{=} f(y) + \langle \nabla f(y), x - y \rangle + \tfrac{L}{2}\|x - y\|^2 + \tfrac{L}{2}\,\mathbb{E}_\delta\|\delta\|^2.

Lemma 4. For all x, y \in \mathbb{R}^d it holds that

\mathbb{E}\left\| \tfrac{1}{\gamma}\delta(x) - \nabla f_\xi(x + \delta(x)) + \lambda\nabla f_\xi(\theta) \right\|^2 \le 2\,\mathbb{E}\|\nabla f_\xi(y) - \lambda\nabla f_\xi(\theta)\|^2 + 2L^2\|x - y\|^2 + 2\left( L^2 + \tfrac{1}{\gamma^2} \right)\mathbb{E}_\delta\|\delta(x)\|^2. (43)

Proof. Fix x and let \delta = \delta(x). Then for every y \in \mathbb{R}^d we can write

\mathbb{E}\left\| \tfrac{\delta}{\gamma} - \nabla f_\xi(x + \delta) + \lambda\nabla f_\xi(\theta) \right\|^2
= \mathbb{E}\left\| \tfrac{\delta}{\gamma} - \nabla f_\xi(y) + \lambda\nabla f_\xi(\theta) + \nabla f_\xi(y) - \nabla f_\xi(x + \delta) \right\|^2
\overset{(20)}{\le} 2\,\mathbb{E}\left\| \tfrac{\delta}{\gamma} - \nabla f_\xi(y) + \lambda\nabla f_\xi(\theta) \right\|^2 + 2\,\mathbb{E}\|\nabla f_\xi(y) - \nabla f_\xi(x + \delta)\|^2
\overset{(41)}{\le} 2\,\mathbb{E}\left\| \tfrac{\delta}{\gamma} - \nabla f_\xi(y) + \lambda\nabla f_\xi(\theta) \right\|^2 + 2L^2\left( \|x - y\|^2 + \mathbb{E}_\delta\|\delta\|^2 \right)
\overset{(19)}{=} \tfrac{2}{\gamma^2}\,\mathbb{E}_\delta\|\delta\|^2 + 2\,\mathbb{E}\|\nabla f_\xi(y) - \lambda\nabla f_\xi(\theta)\|^2 + 2L^2\left( \|x - y\|^2 + \mathbb{E}_\delta\|\delta\|^2 \right),

where in the last step the cross term involving \mathbb{E}[\langle \delta, \nabla f_\xi(y) - \lambda\nabla f_\xi(\theta) \rangle] vanishes because \mathbb{E}_\delta[\delta] = 0 by unbiasedness (19) and \delta is independent of \xi.

The next lemma generalizes the strong convexity inequality (27) (which is the special case \delta(x) \equiv 0).

Lemma 5. If f is L-smooth and \mu-strongly convex, then for all x, y \in \mathbb{R}^d it holds that

f(y) \ge f(x) + \langle \mathbb{E}[\nabla f(x + \delta(x))], y - x \rangle + \tfrac{\mu}{2}\|y - x\|^2 - \tfrac{L - \mu}{2}\,\mathbb{E}\|\delta(x)\|^2. (44)

Proof. Fix x and let \delta = \delta(x). Using (27) with x replaced by x + \delta, we get

f(y) \ge f(x + \delta) + \langle \nabla f(x + \delta), y - x - \delta \rangle + \tfrac{\mu}{2}\|y - x - \delta\|^2.

Applying expectation and (30), we get

f(y) \ge \mathbb{E}[f(x + \delta)] + \langle \mathbb{E}[\nabla f(x + \delta)], y - x \rangle - \mathbb{E}[\langle \nabla f(x + \delta), \delta \rangle] + \tfrac{\mu}{2}\|y - x\|^2 + \tfrac{\mu}{2}\,\mathbb{E}\|\delta\|^2.

It remains to bound the term \mathbb{E}[\langle \nabla f(x + \delta), \delta \rangle], which can be done using L-smoothness and applying expectation as follows:

\mathbb{E}[\langle \nabla f(x + \delta), \delta \rangle] \overset{(25)}{\le} \mathbb{E}\left[ f(x + \delta) - f(x) + \tfrac{L}{2}\|\delta\|^2 \right] = \mathbb{E}[f(x + \delta)] - f(x) + \tfrac{L}{2}\,\mathbb{E}\|\delta\|^2.

Combining the two displays yields (44).

Lemma 6. Denote A = 2L + \left( L^2 + \tfrac{1}{\gamma^2} \right)\alpha and B = \left( L^2 + \tfrac{1}{\gamma^2} \right)\nu, where \alpha, \nu are defined in Lemma 2. Then

\mathbb{E}\left\| \tfrac{1}{\gamma}\delta(x) - \nabla f_\xi(x + \delta(x)) + \lambda\nabla f_\xi(\theta) \right\|^2 \le 4A\left( f(x) - f(x^*) \right) + 2B + 4\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2. (45)

Proof. Using (43) with y = x, we get

\mathbb{E}\left\| \tfrac{1}{\gamma}\delta(x) - \nabla f_\xi(x + \delta(x)) + \lambda\nabla f_\xi(\theta) \right\|^2
\overset{(43)}{\le} 2\,\mathbb{E}\|\nabla f_\xi(x) - \lambda\nabla f_\xi(\theta)\|^2 + 2\left( L^2 + \tfrac{1}{\gamma^2} \right)\mathbb{E}_\delta\|\delta(x)\|^2
\overset{(20)+(26)+(40)}{\le} 8L\left( f(x) - f(x^*) \right) + 4\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2 + 2\left( L^2 + \tfrac{1}{\gamma^2} \right)\left( 2\alpha\left( f(x) - f(x^*) \right) + \nu \right)
= 4\left( 2L + \left( L^2 + \tfrac{1}{\gamma^2} \right)\alpha \right)\left( f(x) - f(x^*) \right) + 2\left( L^2 + \tfrac{1}{\gamma^2} \right)\nu + 4\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2.

E.2 Proof of Theorem 4

Denoting \delta_t = \delta(y_t), we have C(y_t) = y_t + \delta_t. Then

y_{t+1} - x^* = C(y_t) - \gamma\nabla f_\xi(C(y_t)) + \gamma\lambda\nabla f_\xi(\theta) - x^* = y_t - x^* + \delta_t - \gamma\nabla f_\xi(y_t + \delta_t) + \gamma\lambda\nabla f_\xi(\theta),

and hence

\|y_{t+1} - x^*\|^2 = \|y_t - x^*\|^2 + 2\langle \delta_t - \gamma\nabla f_\xi(y_t + \delta_t) + \gamma\lambda\nabla f_\xi(\theta), y_t - x^* \rangle + \|\delta_t - \gamma\nabla f_\xi(y_t + \delta_t) + \gamma\lambda\nabla f_\xi(\theta)\|^2.

Taking the conditional expectation \mathbb{E}_t := \mathbb{E}[\,\cdot \mid y_t], we get

\mathbb{E}_t\left[ \|y_{t+1} - x^*\|^2 \right]
= \|y_t - x^*\|^2 + 2\gamma\langle \mathbb{E}_t[\nabla f(y_t + \delta_t)] - \lambda\nabla f(\theta), x^* - y_t \rangle + \mathbb{E}_t\left[ \|\delta_t - \gamma\nabla f_\xi(y_t + \delta_t) + \gamma\lambda\nabla f_\xi(\theta)\|^2 \right]
\overset{(44)}{\le} \|y_t - x^*\|^2 + 2\gamma\left( f(x^*) - f(y_t) - \tfrac{\mu}{2}\|y_t - x^*\|^2 + \tfrac{L - \mu}{2}\,\mathbb{E}_t\|\delta_t\|^2 \right) + 2\gamma\lambda\langle \nabla f(\theta), y_t - x^* \rangle + \gamma^2\,\mathbb{E}_t\left\| \tfrac{1}{\gamma}\delta_t - \nabla f_\xi(y_t + \delta_t) + \lambda\nabla f_\xi(\theta) \right\|^2
= (1 - \gamma\mu)\|y_t - x^*\|^2 - 2\gamma\left( f(y_t) - f(x^*) \right) + \gamma(L - \mu)\,\mathbb{E}_t\|\delta_t\|^2 + 2\gamma\lambda\langle \nabla f(\theta), y_t - x^* \rangle + \gamma^2\,\mathbb{E}_t\left\| \tfrac{1}{\gamma}\delta_t - \nabla f_\xi(y_t + \delta_t) + \lambda\nabla f_\xi(\theta) \right\|^2
\overset{(27)+(23)}{\le} (1 - \gamma\mu)\|y_t - x^*\|^2 - \gamma\left( f(y_t) - f(x^*) \right) + \gamma(L - \mu)\,\mathbb{E}_t\|\delta_t\|^2 + \tfrac{2\gamma\lambda^2}{\mu}\|\nabla f(\theta)\|^2 + \gamma^2\,\mathbb{E}_t\left\| \tfrac{1}{\gamma}\delta_t - \nabla f_\xi(y_t + \delta_t) + \lambda\nabla f_\xi(\theta) \right\|^2
\overset{(45)}{\le} (1 - \gamma\mu)\|y_t - x^*\|^2 + \gamma\left( 4\gamma A - 1 \right)\left( f(y_t) - f(x^*) \right) + 2\gamma^2 B + \gamma(L - \mu)\,\mathbb{E}_t\|\delta_t\|^2 + \tfrac{2\gamma\lambda^2}{\mu}\|\nabla f(\theta)\|^2 + 4\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2
\overset{(40)}{\le} (1 - \gamma\mu)\|y_t - x^*\|^2 + \gamma\left( 4\gamma A + 2\alpha(L - \mu) - 1 \right)\left( f(y_t) - f(x^*) \right) + 2\gamma^2 B + \gamma(L - \mu)\nu + \tfrac{2\gamma\lambda^2}{\mu}\|\nabla f(\theta)\|^2 + 4\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2,

where \alpha and \nu are as in Lemma 2, and A and B are defined in Lemma 6. In the step marked (27)+(23) we bounded 2\gamma\lambda\langle \nabla f(\theta), y_t - x^* \rangle \le \tfrac{\gamma\mu}{2}\|y_t - x^*\|^2 + \tfrac{2\gamma\lambda^2}{\mu}\|\nabla f(\theta)\|^2 and absorbed the first term using -\gamma\left( f(y_t) - f(x^*) \right) \le -\tfrac{\gamma\mu}{2}\|y_t - x^*\|^2 from (27).

Next, we show that the bounds on \gamma and \omega imply 2\gamma A + \alpha(L - \mu) \le 1/2. Plugging in the expression for A, the former inequality becomes

2\gamma\left( 2L + \left( L^2 + \tfrac{1}{\gamma^2} \right)\alpha \right) + \alpha(L - \mu) \le \tfrac{1}{2}.

Rearranging terms, we get

\tfrac{2\omega}{\mu} = \alpha \le \frac{1/2 - 4\gamma L}{2\gamma L^2 + 2/\gamma + L - \mu}.

For \gamma \le \tfrac{1}{16L}, the above inequality holds if \omega = O(\gamma\mu). If \gamma = \tfrac{1}{16L}, the condition on the variance parameter \omega becomes \omega = O(\mu/L).
Thus, by the assumptions on \omega and \gamma, we have 2\gamma A + \alpha(L - \mu) \le 1/2, and hence

\mathbb{E}_t\left[ \|y_{t+1} - x^*\|^2 \right] \le (1 - \gamma\mu)\|y_t - x^*\|^2 + D,

where

D = 2\gamma^2 B + \gamma(L - \mu)\nu + \tfrac{2\gamma\lambda^2}{\mu}\|\nabla f(\theta)\|^2 + 4\gamma^2\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2 = 2\gamma^2 B + \gamma(L - \mu)\nu + \tfrac{2\gamma}{\mu}N_3(\lambda),

with N_3(\lambda) = \lambda^2\|\nabla f(\theta)\|^2 + 2\gamma\mu\,\mathbb{E}\|\nabla f_\xi(x^*) - \lambda\nabla f_\xi(\theta)\|^2. Applying Lemma 1 with c = 2\mu, we get

N_3(\lambda^*) \le N_3(0)\min\left\{ 1, O\left( \tfrac{1}{\gamma}\left( f(\theta) - f^* \right) \right) \right\} = 2\mu\sigma^2\min\left( \gamma, O(f(\theta) - f^*) \right).

Taking expectation, unrolling the recurrence, and applying the tower property, we get

\mathbb{E}\left[ \|y_t - x^*\|^2 \right] \le (1 - \gamma\mu)^t\|y_0 - x^*\|^2 + \frac{D}{\gamma\mu},

where the neighborhood is given by

\frac{D}{\gamma\mu} = \frac{1}{\gamma\mu}\left( 2\gamma^2 B + \gamma(L - \mu)\nu + \tfrac{2\gamma}{\mu}N_3(\lambda^*) \right) = \frac{1}{\gamma\mu}\left( 2\gamma^2\left( L^2 + \tfrac{1}{\gamma^2} \right)\nu + \gamma(L - \mu)\nu + \tfrac{2\gamma}{\mu}N_3(\lambda^*) \right)
= \frac{\nu}{\mu}\left( 2\gamma L^2 + \tfrac{2}{\gamma} + L - \mu \right) + \frac{2}{\mu^2}N_3(\lambda^*) = \frac{\nu}{\mu}\left( 2\gamma L^2 + \tfrac{2}{\gamma} + L - \mu \right) + \frac{4\sigma^2}{\mu}\min\left( \gamma, O(f(\theta) - f^*) \right).

Ignoring the absolute constants and using the bounds \gamma \le \tfrac{1}{16L} and \omega = O(\gamma\mu), we have

\mathbb{E}\left[ \|y_t - x^*\|^2 \right] \le (1 - \gamma\mu)^t\|x_0 - x^*\|^2 + O\left( \tfrac{\omega}{\gamma\mu}\|x^*\|^2 \right) + \tfrac{4\sigma^2}{\mu}\min\left( \gamma, O(f(\theta) - f^*) \right).

Applying this inequality in (38), we conclude the proof.²

² The rate in Theorem 4 uses the value \gamma = \tfrac{1}{16L} for the learning rate.

[Figure 4: Validation loss statistics for the same setup as in Figure 2a.]

F Additional Experimental Validation

In this section we provide additional experimental validation for our theory.

F.1 The impact of the learning hyperparameters

We begin by measuring the impact that the optimization parameters, in particular the step size, have on the convergence of SGD and KD. For this, we perform experiments on linear models trained on the MNIST dataset, without momentum or regularization, using a mini-batch size of 10, for a total of 100 epochs. We compute the cross-entropy train loss between the student and the true labels, measured as a running average over all iterations within an epoch, similar to Figure 2. We also compute the minimum train loss, as well as the average and standard deviation across the last 10 epochs. The results in Figure 5a show the impact that different learning rates have on the overall training dynamics of self-distillation. In all cases, the teacher was trained using the same hyperparameters as the self-distilled models (\lambda = 0 in the plot). We can see that using a higher learning rate introduces more variance in the SGD update, in which case KD has a more pronounced variance-reduction effect. In all cases, however, we can find an optimal \lambda achieving a lower train loss compared to SGD.

[Figure 5: Ablation study on the training loss of self-distillation and SGD, taking into account the training hyperparameters (learning rate) and the quality of the teacher. (a) The impact of the step size on the learning dynamics of SGD and KD; the teacher was trained with the same hyperparameters as the corresponding self-distilled models (learning rates 0.01, 0.02, 0.05). (b) The impact of a better teacher on the learning dynamics of self-distillation; the same teacher was used in both setups (learning rates 0.02, 0.05). Both panels plot train loss statistics against the distillation weight \lambda.]

[Figure 6: The impact of self-distillation on the train loss of compressed models, plotted against the distillation weight \lambda for random masks of 50% and 70% sparsity. Here a random mask is computed at initialization and kept fixed throughout training.]
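For completeness, the following NumPy sketch mirrors the self-distillation training loop used in this ablation, under our own simplifying assumptions: synthetic Gaussian data stands in for MNIST, and all hyperparameters are illustrative rather than the ones used in the paper. The student is trained on the soft labels s_n, so its per-sample gradient takes the form \nabla f_n(x) - \lambda\nabla f_n(\theta) from (5).

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, K = 1000, 20, 3
A = rng.normal(size=(n, d))
labels = (A @ rng.normal(size=(d, K))).argmax(axis=1)
B = np.eye(K)[labels]                        # one-hot true labels

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sgd_self_distill(theta, lam, lr=0.1, epochs=30, batch=10):
    # Student X trained on soft labels s = (1 - lam)*b + lam*softmax(Theta^T a);
    # its per-sample gradient equals grad f_i(X) - lam * grad f_i(Theta), cf. (5).
    X = np.zeros((d, K))
    for _ in range(epochs):
        for idx in rng.permutation(n).reshape(-1, batch):
            a, b = A[idx], B[idx]
            s = (1 - lam) * b + lam * softmax(a @ theta)
            X -= lr * a.T @ (softmax(a @ X) - s) / batch
    return X

def train_loss(X):
    # Cross-entropy of the student against the true labels.
    return -np.log(softmax(A @ X)[np.arange(n), labels] + 1e-12).mean()

teacher = sgd_self_distill(np.zeros((d, K)), lam=0.0)    # plain-SGD "teacher"
for lam in (0.0, 0.3, 0.6, 0.9):
    print(f"lambda={lam:.1f}  train loss={train_loss(sgd_self_distill(teacher, lam)):.4f}")
```

Sweeping \lambda at several learning rates corresponds to the ablation summarized in Figure 5.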
F.2 The impact of the teacher's quality

Next, we quantify the impact that a better-trained teacher can have on self-distillation. Using the same convex MNIST setup as above, we perform self-distillation using a better teacher, i.e., one that achieves 93.7% train accuracy and 92.5% test accuracy. (In comparison, the teacher trained using a step size of 0.05 achieved 92% train accuracy and 90.9% test accuracy.) We can see in Figure 5b that this better-trained teacher has a more substantial impact on the models which inherently have higher variance, i.e., those trained with a higher learning rate; in this case, the optimal value of \lambda is closer to 1, which is also suggested by the theory (see Equation 14). We note that a similar behavior was also observed on the linear classification task on CIFAR-10 features presented in Figure 2b.

F.3 The impact of knowledge distillation on compression

We now turn our attention to validating the theoretical results developed in the context of knowledge distillation for compressed models, presented in Section 6. We consider again the convex MNIST setting, as described in the previous section, and we perform self-distillation from the better-trained teacher. We prune the weights at initialization, using a random mask at a chosen sparsity level, and we apply this fixed mask after each parameter update. The results presented in Figure 6 show that self-distillation can indeed reduce the train loss, compared to SGD, even for compressed updates. Moreover, we observe that with increased sparsity the impact of self-distillation is less pronounced, as also suggested by the theory.

G Limitations

Lastly, following our discussion from Section 7, we discuss some limitations of our work. As a theoretical paper, we used several assumptions to make our claims rigorous. However, one can always question each assumption and extend the theory under certain relaxations. Our theoretical claims are based on strong (quasi-)convexity or the Polyak-Łojasiewicz condition, which are standard assumptions in the optimization literature. Another limitation, concerning the "distillation for compression" part of our theory, is the unbiasedness condition \mathbb{E}[C(x)] = x in Assumption 4. Ideally, we would be able to use any "biased" compression operator, such as Top-K, with similar convergence properties. However, it is known that biased estimators (e.g., biased compression operators or biased stochastic gradients) are harder to analyze in theory.