# KNOWLEDGE DISTILLATION AS SEMIPARAMETRIC INFERENCE

Published as a conference paper at ICLR 2021

Tri Dao¹, Govinda M. Kamath², Vasilis Syrgkanis², Lester Mackey²

¹ Department of Computer Science, Stanford University

² Microsoft Research, New England

trid@stanford.edu, {govinda.kamath,vasy,lmackey}@microsoft.com

## ABSTRACT

A popular approach to model compression is to train an inexpensive student model to mimic the class probabilities of a highly accurate but cumbersome teacher model. Surprisingly, this two-step knowledge distillation process often leads to higher accuracy than training the student directly on labeled data. To explain and enhance this phenomenon, we cast knowledge distillation as a semiparametric inference problem with the optimal student model as the target, the unknown Bayes class probabilities as nuisance, and the teacher probabilities as a plug-in nuisance estimate. By adapting modern semiparametric tools, we derive new guarantees for the prediction error of standard distillation and develop two enhancements, cross-fitting and loss correction, to mitigate the impact of teacher overfitting and underfitting on student performance. We validate our findings empirically on both tabular and image data and observe consistent improvements from our knowledge distillation enhancements.

## 1 INTRODUCTION

Knowledge distillation (KD) (Craven & Shavlik, 1996; Breiman & Shang, 1996; Bucila et al., 2006; Li et al., 2014; Ba & Caruana, 2014; Hinton et al., 2015) is a widely used model compression technique that enables the deployment of highly accurate predictive models on devices such as phones, watches, and virtual assistants (Stock et al., 2020). KD operates by training a compressed student model to mimic the predicted class probabilities of an expensive, high-quality teacher model.
Remarkably, and across a wide variety of domains (Hinton et al., 2015; Sanh et al., 2019; Jiao et al., 2019; Liu et al., 2018; Tan et al., 2018; Fakoor et al., 2020), this two-step process often leads to higher accuracy than training the student directly on the raw labeled dataset. While the practice of KD is now well developed, a general theoretical understanding of its successes and failures is still lacking.

As we detail below, a number of authors have argued that the success of KD lies in the more precise soft labels provided by the teacher's predicted class probabilities. Recently, Menon et al. (2020) observed that these teacher probabilities can serve as a proxy for the Bayes probabilities (i.e., the true class probabilities) and that the closer the teacher and Bayes probabilities, the better the student's performance should be.

Building on this observation, we cast KD as a plug-in approach to semiparametric inference (Kosorok, 2007): that is, we view KD as fitting a student model $\hat f$ in the presence of nuisance (the Bayes probabilities $p_0$), with the teacher's probabilities $\hat p$ as a plug-in estimate of $p_0$. This insight allows us to adapt modern tools from semiparametric inference to analyze the error of a distilled student in Sec. 3. Our analysis also reveals two distinct failure modes of KD: one due to teacher overfitting and data reuse, and the other due to teacher underfitting from model misspecification or insufficient training. In Sec. 4, we introduce and analyze two complementary KD enhancements that correct for these failures: cross-fitting, a popular technique from semiparametric inference (see, e.g., Chernozhukov et al., 2018), mitigates teacher overfitting through data partitioning, while loss correction mitigates teacher underfitting by reducing the bias of the plug-in estimate $\hat p$.
The latter enhancement was inspired by the orthogonal machine learning approach to semiparametric inference (Chernozhukov et al., 2018; Foster & Syrgkanis, 2019), which suggests a particular adjustment for the teacher's log probabilities. We argue in Sec. 4 that this orthogonal correction minimizes the teacher bias, but often at the cost of unacceptably large variance. Our proposed correction avoids this variance explosion by balancing the bias and variance terms in our generalization bounds.

In Sec. 5, we complement our theoretical analysis with a pair of experiments demonstrating the value of our enhancements on six real classification problems. On five real tabular datasets, cross-fitting and loss correction improve student performance by up to 4% AUC over vanilla KD. Furthermore, on CIFAR-10 (Krizhevsky & Hinton, 2009), a benchmark image classification dataset, our enhancements improve vanilla KD accuracy by up to 1.5% when the teacher model overfits.

**Related work.** Since we cannot review the vast literature on KD in its entirety, we point the interested reader to Gou et al. (2020) for a recent overview of the field. We devote this section to reviewing theoretical advances in the understanding of KD, and we summarize complementary empirical studies and applications of KD in the extended literature review in App. A.

A number of papers have argued that the availability of soft class probabilities from the teacher, rather than hard labels, enables us to improve training of the student model. This was hypothesized in Hinton et al. (2015) with empirical justification. Phuong & Lampert (2019) consider the case in which the teacher is a fixed linear classifier and the student is either a linear model or a deep linear network. They show that the student can learn the teacher perfectly if the number of training examples exceeds the ambient dimension.
Vapnik & Izmailov (2015) discuss the setting of learning with privileged information, where one has additional information at training time which is not available at test time. Lopez-Paz et al. (2015) draw a connection between this setting and KD, arguing that KD is effective because the teacher learns a better representation, allowing the student to learn at a faster rate. They hypothesize that a teacher's class probabilities enable student improvement by indicating how difficult each point is to classify. Tang et al. (2020) argue using empirical evidence that label smoothing and reweighting of training examples using the teacher's predictions are key to the success of KD. Mobahi et al. (2020) analyzed the case of self-distillation, in which the student and teacher function classes are identical. Focusing on kernel ridge regression models, they proved that self-distillation can act as increased regularization strength. Bu et al. (2020) consider more generic model compression in a rate-distortion framework, where the rate is the size of the student model and the distortion is the difference in excess risk between the teacher and the student.

Menon et al. (2020) consider the case of losses such that the population risk is linear in the Bayes class probabilities. They consider the distilled empirical risk and the Bayes-distilled empirical risk, which are the risks computed using the teacher class probabilities and the Bayes class probabilities, respectively, rather than the observed labels. They show that the variance of the Bayes-distilled empirical risk is lower than that of the empirical risk. Then, using analysis from Maurer & Pontil (2009) and Bennett (1962), they derive the excess risk of the distilled empirical risk as a function of the $\ell_2$ distance between the teacher's class probabilities and the Bayes class probabilities. We significantly depart from Menon et al. (2020) in multiple ways: (i) our Thm. 1 allows for the common practice of data re-use, (ii) our results cover the standard KD losses SEL and ACE, which are non-linear in $p_0$, (iii) we use localized Rademacher analysis to achieve tight fast rates for standard KD losses, and (iv) we use techniques from semiparametric inference to improve upon vanilla KD.

## 2 KNOWLEDGE DISTILLATION BACKGROUND

We consider a multiclass classification problem with $k$ classes and $n$ training datapoints $z_i = (x_i, y_i)$ sampled independently from some distribution $P$. Each feature vector $x$ belongs to a set $\mathcal X$, each label vector $y \in \{e_1, \dots, e_k\} \subset \{0,1\}^k$ is a one-hot encoding of the class label, and the conditional probability of observing each label is the Bayes class probability function $p_0(x) = \mathbb E[Y \mid X = x]$. Our aim is to identify a scoring rule $f : \mathcal X \to \mathbb R^k$ that minimizes a prediction loss on average under the distribution $P$.

**Knowledge distillation.** Knowledge distillation (KD) is a two-step training process where one first uses a labeled dataset to train a teacher model and then trains a student model to predict the teacher's predicted class probabilities. Typically the teacher model is larger and more cumbersome, while the student is smaller and more efficient. Knowledge distillation was first motivated by model compression (Bucila et al., 2006), to find compact yet high-performing models to be deployed (such as on mobile devices). In training the student to match the teacher's predicted probabilities, several types of loss functions are commonly used. Let $\hat p(x) \in \mathbb R^k$ be the teacher's vector of predicted class probabilities, $f(x) \in \mathbb R^k$ be the student model's output, and $[k] \triangleq \{1, 2, \dots, k\}$.
The most popular distillation loss functions¹ $\ell(z; f(x), \hat p(x))$ include the squared error logit (SEL) loss (Ba & Caruana, 2014)

$$\ell_{se}(z; f(x), \hat p(x)) \triangleq \sum_{j\in[k]} \tfrac12\big(f_j(x) - \log(\hat p_j(x))\big)^2 \qquad \text{(SEL)}$$

and the annealed cross-entropy (ACE) loss (Hinton et al., 2015)

$$\ell_\beta(z; f(x), \hat p(x)) = -\sum_{j\in[k]} \frac{\hat p_j(x)^\beta}{\sum_{l\in[k]} \hat p_l(x)^\beta} \log\frac{\exp(\beta f_j(x))}{\sum_{l\in[k]} \exp(\beta f_l(x))} \qquad \text{(ACE)}$$

for an inverse temperature $\beta > 0$. These loss functions measure the divergence between the probabilities predicted by the teacher and the student. A student model trained with knowledge distillation often performs better than the same model trained from scratch (Bucila et al., 2006; Hinton et al., 2015). In Secs. 3 and 4, we will adapt modern tools from semiparametric inference to understand and enhance this phenomenon.

## 3 DISTILLATION AS SEMIPARAMETRIC INFERENCE

In semiparametric inference (Kosorok, 2007), one aims to estimate a target parameter or function $f_0$, but that estimation depends on an auxiliary nuisance function $p_0$ that is unknown and not of primary interest. We cast the knowledge distillation process as a semiparametric inference problem by treating the unknown Bayes class probabilities $p_0$ as nuisance and the teacher's predicted probabilities as a plug-in estimate of that nuisance. This perspective allows us to bound the generalization error of the student in terms of the mean squared error (MSE) between the teacher and the Bayes probabilities. In the next section (Sec. 4), we use techniques from semiparametric inference to enhance the performance of the student. The interested reader can consult Tsiatis (2007) for more details on semiparametric inference.

Our analysis starts from the following perspective on distillation. For a given pointwise loss function $\ell(z; f(x), p_0(x))$, we view the goal of the student as minimizing an oracle population loss over a function class $\mathcal F$,

$$L_D(f, p_0) = \mathbb E[\ell(Z; f(X), p_0(X))] \quad\text{with}\quad f_0 \in \arg\min_{f\in\mathcal F} L_D(f, p_0).$$
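The SEL and ACE losses defined in Sec. 2 can be evaluated for a single example with a short NumPy sketch. The function names `sel_loss` and `ace_loss` are illustrative, not taken from the paper's released code, and the sketch assumes the teacher probabilities are strictly positive so that $\log\hat p$ is finite.

```python
import numpy as np


def sel_loss(f, p_hat):
    """Squared error logit (SEL) loss: sum_j 0.5 * (f_j - log p_hat_j)^2."""
    f = np.asarray(f, dtype=float)
    return 0.5 * np.sum((f - np.log(p_hat)) ** 2, axis=-1)


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - np.max(z, axis=-1, keepdims=True)
    return z - np.log(np.sum(np.exp(z), axis=-1, keepdims=True))


def ace_loss(f, p_hat, beta=1.0):
    """Annealed cross-entropy (ACE) loss with inverse temperature beta.

    Teacher targets are the tempered probabilities p_hat^beta / sum_l p_hat_l^beta,
    and the student probabilities are softmax(beta * f).
    """
    f = np.asarray(f, dtype=float)
    w = np.asarray(p_hat, dtype=float) ** beta
    targets = w / np.sum(w, axis=-1, keepdims=True)
    return -np.sum(targets * log_softmax(beta * f), axis=-1)


p_hat = np.array([0.7, 0.2, 0.1])
print(sel_loss(np.log(p_hat), p_hat))  # 0.0: student logits equal teacher log-probabilities
print(ace_loss(np.log(p_hat), p_hat))  # with beta = 1 this is the entropy of p_hat
```

Both losses are minimized when the student's logits match $\log\hat p$; ACE is invariant to adding a constant to all logits, whereas SEL also pins down that constant.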
The main hurdle is that this objective depends on the unknown Bayes probabilities $p_0$. We view the teacher's model $\hat p$ as an approximate version of $p_0$ and bound the distillation error of the student as a function of the teacher's estimation error. Typical semiparametric inference considers cases where $f_0$ is a finite-dimensional parameter; however, recent work of Foster & Syrgkanis (2019) extends this framework to infinite-dimensional models $f_0$, developing a statistical learning theory with a nuisance component. The distillation problem fits exactly into this setup.

**Bounds on vanilla KD.** As a first step, we derive a vanilla bound on the error of the distilled student model without any further modifications of the distillation process; i.e., we assume that the student is trained on the same data as the teacher by running empirical risk minimization (ERM) on the plug-in loss, plugging in the teacher's model instead of $p_0$:

$$\hat f = \arg\min_{f\in\mathcal F} L_n(f, \hat p) \quad\text{for}\quad L_n(f, \hat p) \triangleq \mathbb E_n[\ell(Z; f(X), \hat p(X))], \qquad \text{(Vanilla KD)}$$

where $\mathbb E_n[X] = \frac1n\sum_{i=1}^n X_i$ denotes the empirical expectation of a random variable.

**Technical definitions.** Before presenting our main theorem, we introduce some technical notation. For a vector-valued function $f$ that takes as input a random variable $X$, we use the shorthand $\|f\|_{p,q} \triangleq \big\|\|f(X)\|_p\big\|_{L^q} = \mathbb E[\|f(X)\|_p^q]^{1/q}$. Let $\nabla_\phi$ and $\nabla_\pi$ denote the partial derivatives of $\ell(z; \phi, \pi)$ with respect to its second and third inputs, respectively, and $\nabla_{\phi\pi}$ the Jacobian of cross partial derivatives, i.e., $[\nabla_{\phi\pi}\ell(z;\phi,\pi)]_{i,j} = \frac{\partial^2}{\partial\phi_j\,\partial\pi_i}\ell(z;\phi,\pi)$. Finally, let $q_{f,p}(x) = \mathbb E[\nabla_{\phi\pi}\ell(Z; f(X), p(X)) \mid X = x]$ and $\gamma_{f,p}(x) = \mathbb E_{U\sim\mathrm{Unif}([0,1])}[q_{f,\,Up + (1-U)p_0}(x)]$.

**Critical radius.** Finally, we need to define the notion of the critical radius (see, e.g., Wainwright, 2019, §14.1.1) of a function class, which typically provides tight learning rates for statistical learning theory tasks.
For any function class $\mathcal F$, we define the localized Rademacher complexity as

$$\mathcal R(\delta; \mathcal F) = \mathbb E_{X_{1:n}, \epsilon_{1:n}}\Big[\sup_{f\in\mathcal F:\, \|f\|_2 \le \delta} \frac1n\sum_{i=1}^n \epsilon_i f(X_i)\Big],$$

where the $\epsilon_i$ are i.i.d. random variables taking values equiprobably in $\{-1, 1\}$. The critical radius of a class $\mathcal F$ taking values in $[-H, H]$ is the smallest positive solution $\delta_n$ to the inequality $\mathcal R(\delta; \mathcal F) \le \delta^2$.

¹ These loss functions do not depend on the ground-truth label $y$, but we use the augmented notation $\ell(z; f(x), \hat p(x))$ to accommodate the enhanced distillation losses presented in Sec. 4.

**Theorem 1 (Vanilla KD analysis).** Suppose $f_0$ belongs to a convex set $\mathcal F$ satisfying the $\ell_2/\ell_4$ ratio condition $\sup_{f\in\mathcal F}\|f - f_0\|_{2,4}/\|f - f_0\|_{2,2} \le C$ and that the teacher estimates $\hat p \in \mathcal P$ from the same dataset used to train the student. Let $\delta_{n,\zeta} = \delta_n + c_0\sqrt{\log(c_1/\zeta)/n}$ for universal constants $c_0, c_1$ and $\delta_n$ an upper bound on the critical radius of the function class

$$\mathcal G \triangleq \{z \mapsto r\,(\ell(z; f(x), p(x)) - \ell(z; f_0(x), p(x))) : f\in\mathcal F,\ p\in\mathcal P,\ r\in[0,1]\}.$$

Let $\mu(z) = \sup_\phi \|\nabla_\phi\ell(z; \phi, \hat p(x))\|_2$, and assume that the loss $\ell(z; \phi, \pi)$ is $\sigma$-strongly convex in $\phi$ for each $z$ and that each $g\in\mathcal G$ is uniformly bounded in $[-H, H]$. Then the Vanilla KD $\hat f$ satisfies

$$\|\hat f - f_0\|_{2,2}^2 = \frac{1}{\sigma^2}\,O\Big(\delta_{n,\zeta}^2 C^2 H^2 \|\mu\|_4^2 + \|\gamma_{f_0,\hat p}^\top(\hat p - p_0)\|_{2,2}^2\Big)$$

with probability at least $1 - \zeta$.

Thm. 1, proved in App. C, shows that vanilla distillation yields an accurate student whenever the teacher generalizes well (i.e., $\|\hat p - p_0\|_{2,2}$ is small) and the student and teacher model classes $\mathcal F$ and $\mathcal P$ are not too complex. The $\ell_2/\ell_4$ ratio requirement can be removed at the expense of replacing $\|\mu\|_4$ by $\|\mu\|_\infty = \sup_z|\mu(z)|$ in the final bound. Moreover, we highlight that the strong convexity requirement for $\ell$ is satisfied by all standard distillation objectives, including SEL and ACE, as it is strong convexity with respect to the output of $f$ and not the parameters of $f$. Even this requirement could be removed, but this would yield slow-rate bounds of the form $\|\hat f - f_0\|_{2,2}^2 = O(\delta_{n,\zeta} + \|\gamma_{f_0,\hat p}^\top(\hat p - p_0)\|_{2,2}^2)$.

**Failure modes of vanilla KD.** Thm. 1 also hints at two distinct ways in which vanilla distillation could fail. First, since the student only learns from the teacher and does not have access to the original labels, we would expect the student to be erroneous when the teacher probabilities are inaccurate due to model misspecification, an overly restrictive teacher function class, or insufficient training. Prop. 2, proved in App. D, confirms that, in the worst case, student error suffers from this teacher underfitting even when both the student and teacher belong to low-complexity model classes.

**Proposition 2 (Impact of teacher underfitting on vanilla KD).** There exists a classification problem in which the following properties all hold simultaneously with high probability for $f_0 = \log(p_0)$:

- The teacher learns $\hat p(x) = \frac{1}{n(1+\lambda)}\sum_{i=1}^n y_i$ for all $x\in\mathcal X$ via ridge regression with $\lambda = \Theta(1/n^{1/4})$.
- Vanilla KD with SEL loss and constant $\hat f$ satisfies $\|\hat f - f_0\|_{2,2}^2 \ge \|\gamma_{f_0,p_0}^\top(\hat p - p_0)\|_{2,2}^2 = \Omega(1/\sqrt n)$, matching the dependence of the Thm. 1 upper bound up to a constant factor.
- Enhanced KD with SEL loss, $\hat\gamma^{(t)} = -\mathrm{diag}(1/\hat p^{(t)})$, and constant $\hat f$ satisfies $\|\hat f - f_0\|_{2,2}^2 = O(1/n)$.

Second, the critical radius in Thm. 1 depends on the complexity of the teacher model class $\mathcal P$. If $\mathcal P$ has a large critical radius, then the student error bound suffers due to potential teacher overfitting, even if the teacher generalizes well. Prop. 3, proved in App. E, shows that, in the worst case, this teacher overfitting penalty is unavoidable and does in fact lead to increased student error. This occurs because the student only has access to the teacher's training-set probabilities, which, due to overfitting, need not reflect its test-set probabilities.

**Proposition 3 (Impact of teacher overfitting on vanilla KD).** There exists a classification problem in which the following properties all hold simultaneously with high probability for $f_0 = \mathbb E[\log(p_0(X))]$:

- The critical radius $\delta_n$ of the teacher-student function class $\mathcal G$ in Thm. 1 is a non-vanishing constant, due to the complexity of the teacher's function class.
- The Vanilla KD error $\|\hat f - f_0\|_{2,2}^2$ for constant $\hat f$ with SEL loss is lower bounded by a non-vanishing constant, matching the $\delta_n$ dependence of the Thm. 1 upper bound up to a constant factor.
- Enhanced KD with SEL loss, $\hat\gamma^{(t)} = 0$, and constant $\hat f$ satisfies $\|\hat f - f_0\|_{2,2}^2 = O(n^{-4/(4+d)})$.

These examples serve to lower bound student performance in the worst case by the teacher's critical radius and class probability MSE, matching the upper bounds given in Thm. 1. However, we note that in other, better-case scenarios, vanilla distillation can perform better than the upper bound of Thm. 1 would imply. In the next section, we adapt and generalize techniques from semiparametric inference to mitigate the effects of teacher overfitting and underfitting in all cases.

## 4 ENHANCING KNOWLEDGE DISTILLATION

To address the two distinct inefficiencies of vanilla distillation revealed in Sec. 3, we adapt and generalize two distinct techniques from semiparametric inference: orthogonal correction and cross-fitting.

### 4.1 COMBATING TEACHER UNDERFITTING WITH LOSS CORRECTION

We can view the plug-in distillation loss $\ell(z; f(x), \hat p(x))$ as a zeroth-order Taylor approximation to the ideal loss $\ell(z; f(x), p_0(x))$ around $\hat p$. An ideal first-order approximation would take the form $\ell(z; f(x), \hat p(x)) + \langle p_0(x) - \hat p(x), \nabla_\pi\ell(z; f(x), \hat p(x))\rangle$. However, its computation also requires knowledge of $p_0$. Nevertheless, since $p_0(x) = \mathbb E[Y \mid X = x]$, we can always construct an unbiased estimate of the ideal first-order term by replacing $p_0(x)$ with $y$:

$$\ell_{ortho}(z; f(x), \hat p(x)) = \ell(z; f(x), \hat p(x)) + \langle y - \hat p(x), \mathbb E[\nabla_\pi\ell(z; f(x), \hat p(x)) \mid x]\rangle. \qquad (1)$$

For standard distillation base losses like SEL and ACE, the orthogonal loss (1) has an especially simple form, as $\nabla_\pi\ell(z; f(x), \hat p(x))$ is linear in $f$. Indeed, this is true more generally for the following class of Bregman divergence losses.

**Definition 1 (Bregman divergence losses).**
Any Bregman divergence loss function of the form $\ell(z; f(x), p(x)) \triangleq \Psi(f(x)) - \Psi(g(p(x))) - \langle\nabla_g\Psi(g(p(x))), f(x) - g(p(x))\rangle$ has

$$\ell_{ortho}(z; f(x), p(x)) = \ell(z; f(x), p(x)) - (y - p(x))^\top\nabla_p g(p(x))^\top\nabla^2_{gg}\Psi(g(p(x)))\,f(x) + \text{const} \qquad (2)$$

with the second term bilinear in $f(x)$ and $y - p(x)$, and with a constant that does not depend on $f(x)$. (Differentiating the Bregman form gives $\nabla_\pi\ell = -\nabla_p g(p(x))^\top\nabla^2_{gg}\Psi(g(p(x)))\,(f(x) - g(p(x)))$, and substituting this into (1) yields (2).) For the SEL loss, $\Psi(s) = \frac12\|s\|_2^2$, $g(p) = \log(p)$, and the correction matrix $\nabla_p g(p(x))^\top\nabla^2_{gg}\Psi(g(p(x))) = \mathrm{diag}(1/p(x))$. Similarly, the ACE loss falls into the class of Bregman divergence losses.

We will show that orthogonal correction (1) can significantly reduce student bias due to teacher underfitting; however, for our standard distillation losses (SEL and ACE), the same orthogonal correction term often introduces unreasonably large variance due to division by small probabilities appearing in the correction matrix (see Definition 1). To grant ourselves more flexibility in balancing bias and variance, we propose and analyze a family of $\gamma$-corrected losses, parameterized by a matrix-valued function $\gamma : \mathcal X \to \mathbb R^{k\times k}$,

$$\ell_\gamma(z; f(x), p(x)) \triangleq \ell(z; f(x), p(x)) + (y - p(x))^\top\gamma(x)f(x),$$

to mimic the bilinear structure of the Bregman orthogonal losses (2). Note that we can always recover the vanilla distillation loss by taking $\gamma \equiv 0$, while the orthogonal loss (2) corresponds to $\gamma(x) = -\nabla_p g(p(x))^\top\nabla^2_{gg}\Psi(g(p(x)))$. We denote the associated population and empirical risks by $L_D(f, p, \gamma) \triangleq \mathbb E[\ell_\gamma(Z; f(X), p(X))]$ and $L_n(f, p, \gamma) \triangleq \mathbb E_n[\ell_\gamma(Z; f(X), p(X))]$. Observe that at $p_0$ the correction term is mean-zero, and hence $L_D(f, p_0, \gamma)$ is independent of $\gamma$:

$$L_D(f, p_0) \triangleq \mathbb E[\ell(Z; f(X), p_0(X))] = L_D(f, p_0, \gamma) \quad\text{for all } \gamma.$$

The $\gamma$-corrected loss has strong connections to the literature on Neyman orthogonality (Chernozhukov et al., 2018; Chernozhukov et al., 2016; Nekipelov et al., 2018; Chernozhukov et al., 2018; Foster & Syrgkanis, 2019). In particular, if the function $\gamma$ is set appropriately, then one can show that the $\gamma$-corrected loss function satisfies the condition of a Neyman orthogonal loss defined by Foster & Syrgkanis (2019).
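To make the $\gamma$-corrected loss concrete, here is a small NumPy sketch for the SEL base loss at a single input $x$. It assumes the convention $\ell_\gamma = \ell + (y - p)^\top\gamma f$ and uses the SEL cross-derivative $\partial^2\ell_{se}/\partial\phi_j\partial\pi_j = -1/\pi_j$, so that $q = -\mathrm{diag}(1/\hat p)$ serves as the orthogonal choice. The expectation over $Y\sim\mathrm{Categorical}(p_0(x))$ is computed exactly by summing over the $k$ classes; the probability values and all function names are hypothetical. The sketch checks that the correction is mean-zero at $p_0$ and that the orthogonal choice removes the first-order teacher error from the population minimizer.

```python
import numpy as np


def sel_gamma_loss(f, y, p, gamma):
    """gamma-corrected SEL loss: 0.5 * ||f - log p||^2 + (y - p)^T gamma f."""
    return 0.5 * np.sum((f - np.log(p)) ** 2) + (y - p) @ gamma @ f


def expected_gamma_loss(f, p, p0, gamma):
    """Exact E_{Y ~ Categorical(p0)} of the gamma-corrected loss at one input x."""
    labels = np.eye(len(p0))  # one-hot label vectors e_1, ..., e_k
    return sum(p0[j] * sel_gamma_loss(f, labels[j], p, gamma) for j in range(len(p0)))


def population_minimizer(p, p0, gamma):
    """Minimizer over f of the expected loss: log p - gamma^T (p0 - p)."""
    return np.log(p) - gamma.T @ (p0 - p)


p0 = np.array([0.5, 0.3, 0.2])        # Bayes probabilities at x (illustrative)
p_hat = np.array([0.45, 0.33, 0.22])  # an imperfect teacher
gamma_ortho = -np.diag(1.0 / p_hat)   # q = E[grad_{phi pi} SEL | x] for SEL

vanilla = population_minimizer(p_hat, p0, np.zeros((3, 3)))
ortho = population_minimizer(p_hat, p0, gamma_ortho)
print(np.linalg.norm(vanilla - np.log(p0)))  # first-order error inherited from the teacher
print(np.linalg.norm(ortho - np.log(p0)))    # only a second-order error remains
```

The orthogonal minimizer is $\log\hat p + (p_0 - \hat p)/\hat p$, a first-order Taylor expansion of $\log p_0$ around $\hat p$, which is why its error is quadratic in $p_0 - \hat p$.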
We begin our analysis by showing a general lemma for any estimator $\hat f$, which adapts the main theorem of Foster & Syrgkanis (2019) to account for approximate orthogonality; the proof can be found in App. F.

**Lemma 4 (Algorithm-agnostic analysis).** Consider any estimation algorithm that produces an estimate $\hat f$ with small plug-in excess risk, i.e., $L_D(\hat f, \hat p, \gamma) - L_D(f_0, \hat p, \gamma) \le \epsilon(\hat f, \hat p, \gamma)$. If the loss $L_D$ is $\sigma$-strongly convex with respect to $f$ and $\mathcal F$ is a convex set, then

$$\frac\sigma4\|\hat f - f_0\|_{2,2}^2 \le \epsilon(\hat f, \hat p, \gamma) + \frac1\sigma\|(\gamma_{f_0,\hat p} - \gamma)^\top(\hat p - p_0)\|_{2,2}^2.$$

If, in addition, $\sup_{z,\phi,\pi,i\in[d]}\|\nabla_{\phi_i\pi\pi}\ell(z;\phi,\pi)\|_{op} \le M$, then

$$\|(\gamma_{f_0,\hat p} - \gamma)^\top(\hat p - p_0)\|_{2,2}^2 \le 2\|(q_{f_0,\hat p} - \gamma)^\top(\hat p - p_0)\|_{2,2}^2 + M^2k\,\|\hat p - p_0\|_{2,4}^4.$$

**Connection to Neyman orthogonality.** Remarkably, if we set $\gamma = q_{f_0,\hat p}$, then the $\gamma$-corrected loss is Neyman orthogonal (Foster & Syrgkanis, 2019), and the student MSE bound depends only on the squared MSE of the teacher. Moreover, $q_{f_0,\hat p}$ is an observable quantity for any Bregman divergence loss (Definition 1), as $q_{f_0,\hat p}$ is independent of $f_0$. However, we note that this setting of $\gamma$ can lead to larger variance; i.e., the achievable excess risk can be much larger than the excess risk without the correction. For instance, in the case of the SEL loss, $q_{f_0,\hat p}(x) = -\mathrm{diag}(1/\hat p(x))$, whose entries can be excessively large in magnitude when $\hat p$ is close to 0, leading to a large increase in the variance of our loss. Thus, in a departure from the standard approach in semiparametric inference, we will choose $\gamma$ in practice to balance bias and variance.
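The variance issue can be checked directly. For a fixed student output $f$, the correction term $(Y - \hat p)^\top\gamma f$ equals $(\gamma f)_Y$ minus a quantity that does not depend on $Y$, so its conditional variance is the variance of $(\gamma f)_Y$ under $Y\sim\mathrm{Categorical}(p_0(x))$. The NumPy sketch below evaluates this variance exactly for the SEL orthogonal choice $\gamma = -\mathrm{diag}(1/\hat p)$ (obtained from $\partial^2\ell_{se}/\partial\phi_j\partial\pi_j = -1/\pi_j$) and for shrunken versions $c\,\gamma$, at a hypothetical input where one class is rare. The function name and probability values are assumptions for illustration.

```python
import numpy as np


def correction_variance(f, p_hat, p0, gamma):
    """Exact Var[(Y - p_hat)^T gamma f] for Y ~ Categorical(p0).

    With Y = e_j the term equals (gamma f)_j - p_hat^T gamma f, so its variance
    is the variance of (gamma f)_Y.
    """
    a = gamma @ f
    mean = p0 @ a
    return p0 @ (a - mean) ** 2


p0 = np.array([0.98, 0.019, 0.001])  # one rare class (illustrative values)
p_hat = p0.copy()                    # even a perfect teacher incurs this variance
f = np.log(p0)                       # student output at the oracle solution

gamma_ortho = -np.diag(1.0 / p_hat)
for c in [0.0, 0.5, 1.0]:            # shrink the orthogonal correction by a factor c
    print(c, correction_variance(f, p_hat, p0, c * gamma_ortho))
```

The variance grows quadratically in the shrinkage factor $c$ and is dominated by the rare class, whose correction entry has magnitude $1/\hat p_j = 1000$. Shrinking $\gamma$ toward 0 trades this variance against bias reduction, which is the balance targeted by the selection rule of Sec. 5.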
**Example instantiation of the student's estimation algorithm.** If we use plug-in empirical risk minimization, i.e., $\hat f = \arg\min_{f\in\mathcal F} L_n(f, \hat p, \gamma)$, to estimate $f_0$ with $\hat p$ estimated on an independent sample, then the results of Maurer & Pontil (2009) directly imply that, if the loss function $\ell(z;\phi,\pi)$ is uniformly bounded in $[-H, H]$, then, with probability at least $1 - \delta$,

$$\epsilon(\hat f, \hat p, \gamma) = O\left(\sqrt{\frac{\sup_{f\in\mathcal F}\mathrm{Var}(\ell_\gamma(Z; f(X), \hat p(X)))\,\log(\tau(n)/\delta)}{n}} + \frac{H\log(\tau(n)/\delta)}{n}\right),$$

where $\tau(n) = \mathcal N_\infty(1/n, \mathcal F, 2n)$ and $\mathcal N_\infty(\epsilon, \mathcal F, m)$ is the $\ell_\infty$ empirical covering number of the function class $\mathcal F$ in the worst case over all realizations of $m$ data points and at approximation level $\epsilon$. This result has two drawbacks: it is a slow-rate result that scales as $1/\sqrt n$ for parametric or bounded Vapnik–Chervonenkis (VC)-dimension classes, and it requires the student to be fit on a completely separate dataset from the teacher's. In the next theorem, we address both of these drawbacks: (i) we invoke localized Rademacher complexity analysis to provide a fast-rate result, which would be of the order of $1/n$ for VC or parametric function classes, and (ii) we use a more sophisticated data-partitioning technique called cross-fitting, which allows the student to be trained using all of the available teacher data.

### 4.2 COMBATING TEACHER OVERFITTING WITH CROSS-FITTING

We now describe a more sophisticated version of data partitioning that makes use of all data points in our student estimation while not suffering from the sample complexity of the teacher's function space. This approach is referred to as cross-fitting (CF) in the semiparametric inference literature (see, e.g., Chernozhukov et al. (2018)):

1. Partition the dataset into $B$ equally sized folds $P_1, \dots, P_B$.
2. For each fold $t\in[B]$, estimate $\hat p^{(t)}$ and $\hat\gamma^{(t)}$ using all the out-of-fold data points.
3. Estimate $\hat f$ by minimizing the empirical loss:

   $$\hat f = \arg\min_{f\in\mathcal F}\frac1n\sum_{t=1}^B\sum_{i\in P_t}\ell_{\hat\gamma^{(t)}}(Z_i; f(X_i), \hat p^{(t)}(X_i)).$$
We refer to this procedure as Enhanced KD. The nuisance estimates $(\hat\gamma^{(t)}, \hat p^{(t)})$ that are evaluated on the data points in fold $t$ when fitting the student in step 3 are estimated using only data points outside of $P_t$.

**Theorem 5 (Enhanced KD analysis).** Suppose $f_0$ belongs to a convex set $\mathcal F$. Let $\delta_{n/B,\zeta/B} = \delta_{n/B} + c_0\sqrt{B\log(c_1B/\zeta)/n}$ for universal constants $c_0, c_1$ and $\delta_{n/B}$ an upper bound on the critical radius of the class

$$\mathcal G(\hat p^{(t)}, \hat\gamma^{(t)}) = \{z \mapsto r\,(\ell_{\hat\gamma^{(t)}}(z; f(x), \hat p^{(t)}(x)) - \ell_{\hat\gamma^{(t)}}(z; f_0(x), \hat p^{(t)}(x))) : f\in\mathcal F,\ r\in[0,1]\}$$

for each $t\in[B]$. Let $\mu(z) = \sup_{f\in\mathcal F, t\in[B]}\|\nabla_\phi\ell_{\hat\gamma^{(t)}}(z; f(x), \hat p^{(t)}(x))\|_2$, and assume that, with probability 1 for each $t\in[B]$, the loss $\ell_{\hat\gamma^{(t)}}(z; \phi, \hat p^{(t)}(x))$ is $\sigma$-strongly convex in $\phi$ for each $z$ and each $g\in\mathcal G(\hat p^{(t)}, \hat\gamma^{(t)})$ is uniformly bounded in $[-H, H]$. Moreover, suppose that the function class $\mathcal F$ satisfies the $\ell_2/\ell_4$ ratio condition $\sup_{f\in\mathcal F}\|f - f_0\|_{2,4}/\|f - f_0\|_{2,2} \le C$. If $\hat f$ is the output of Enhanced KD, then, with probability at least $1 - \zeta$,

$$\frac\sigma8\|\hat f - f_0\|_{2,2}^2 = \frac1\sigma O\Big(\delta_{n/B,\zeta/B}^2C^2\Big(H^2\|\mu\|_4^2 + \frac1B\sum_{t=1}^B\sqrt{\mathbb E\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4}\Big) + \frac1B\sum_{t=1}^B\|(\gamma_{f_0,\hat p^{(t)}} - \hat\gamma^{(t)})^\top(\hat p^{(t)} - p_0)\|_{2,2}^2\Big).$$

The proof is found in App. G. Observe that, unlike Thm. 1, the function classes $\mathcal G(\hat p^{(t)}, \hat\gamma^{(t)})$ in Thm. 5 do not vary the teacher's model over $\mathcal P$ but rather evaluate $p$ at the specific out-of-fold estimates $\hat p^{(t)}$ and only vary $f\in\mathcal F$. Since in practice the teacher's model can be quite complex, removing this dependence on the sample complexity of the teacher's function space can bring immense improvement, with the critical radius of $\mathcal G(\hat p^{(t)}, \hat\gamma^{(t)})$ significantly smaller than that of $\mathcal G$ from Thm. 1. For instance, suppose that the loss function $\ell_{\hat\gamma^{(t)}}(z; f, \hat p^{(t)})$ is $L$-Lipschitz with respect to $f$ and that $\mathcal F$ is a VC-subgraph class with VC dimension $d_{\mathcal F}$. Then the critical radius of the function class $\mathcal G(\hat p^{(t)}, \hat\gamma^{(t)})$ is of order $\sqrt{d_{\mathcal F}\log(n)/n}$ for any choice of $(\hat p^{(t)}, \hat\gamma^{(t)})$ (see, e.g., Foster & Syrgkanis, 2019, Sec. 4.2).² However, under the same conditions, the critical radius of the teacher-student function class $\mathcal G$ in Thm. 1 will still depend on the teacher's function space. If $\mathcal P$ is also a VC-subgraph class with VC dimension $d_{\mathcal P} \gg d_{\mathcal F}$, then the critical radius of $\mathcal G$ will be of the much larger order $\sqrt{d_{\mathcal P}\log(n)/n}$.

² In fact, under the Lipschitz condition alone and using contraction lemma arguments as in Foster & Syrgkanis (2019, Lem. 11), one can derive a version of Thm. 5 in which the upper bound depends only on the critical radius of the function class $\{r(f - f_0) : f\in\mathcal F, r\in[0,1]\}$, which solely depends on the function space of the student.

We can also see in the bound of Thm. 5 the interplay between bias and variance introduced by $\gamma$. In particular, the part of the bound that depends on $\hat\gamma^{(t)}$ can be further simplified as

$$\sqrt{\mathbb E\big[\delta_{n/B,\zeta/B}^4C^4\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4 + \|(\gamma_{f_0,\hat p^{(t)}}(X) - \hat\gamma^{(t)}(X))^\top(\hat p^{(t)}(X) - p_0(X))\|_2^4\big]}, \qquad (3)$$

where the two terms respectively encode the increase in variance and the decrease in bias from employing loss correction. Notably, Thm. 5 implies that CF without $\gamma$-correction (i.e., $\hat\gamma^{(t)}(x) = 0$) is sufficient to reduce student error due to teacher overfitting but may still be susceptible to excessive student error due to teacher underfitting. These qualitative predictions accord with our experimental observations in Sec. 5 and Fig. 5.

### 4.3 BIASED STOCHASTIC GRADIENT DESCENT ANALYSIS

When the set of candidate prediction rules $f_\theta$ is parameterized by a vector $\theta\in\mathbb R^d$, we may alternatively fit $\theta$ via stochastic gradient descent (SGD) (Robbins & Monro, 1951; Bottou & Bousquet, 2008) on the $\gamma$-corrected objective $L_D(f_\theta, \hat p, \hat\gamma)$. With a minibatch size of 1 and a starting point $\theta_0$, the parameter updates take the form

$$\theta_{t+1} = \theta_t - \eta_t\nabla_\theta f_{\theta_t}(X_t)^\top\nabla_\phi\ell_{\hat\gamma}(W_t; f_{\theta_t}(X_t), \hat p(X_t)) \quad\text{for } t+1\in[n]. \qquad (4)$$

Ideally, these updates would converge to a minimizer of the ideal risk $L(\theta; p_0) = L_D(f_\theta, p_0)$.
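For intuition about update (4), consider a hypothetical linear student $f_\theta(x) = Wx$ with $\theta = W\in\mathbb R^{k\times d}$ and the SEL base loss. Under the convention $\ell_\gamma = \ell + (y - p)^\top\gamma f$, the output gradient is $\nabla_\phi\ell_{\hat\gamma} = (f - \log\hat p(x)) + \hat\gamma(x)^\top(y - \hat p(x))$, and $\nabla_\theta f_\theta(x)^\top g$ is the outer product $g\,x^\top$, so each step is a rank-one update. The NumPy sketch below is illustrative only; the step size, function names, and data are assumptions.

```python
import numpy as np


def corrected_sel_grad(f, y, p_hat, gamma):
    """Gradient in the student output phi of the gamma-corrected SEL loss."""
    return (f - np.log(p_hat)) + gamma.T @ (y - p_hat)


def sgd_step(W, x, y, p_hat, gamma, lr):
    """One update of (4) for the linear student f_W(x) = W x (minibatch size 1)."""
    g = corrected_sel_grad(W @ x, y, p_hat, gamma)
    return W - lr * np.outer(g, x)


def run_sgd(X, Y, P_hat, Gammas, lr=0.05, seed=0):
    """A single pass of SGD over the data in random order, starting from W = 0."""
    rng = np.random.default_rng(seed)
    W = np.zeros((Y.shape[1], X.shape[1]))
    for i in rng.permutation(len(X)):
        W = sgd_step(W, X[i], Y[i], P_hat[i], Gammas[i], lr)
    return W


# Usage: one step from W = 0 with x = e_1 and lr = 1 lands on the corrected label.
x = np.array([1.0, 0.0])
y = np.array([1.0, 0.0])
p_hat = np.array([0.5, 0.5])
gamma = -np.diag(1.0 / p_hat)  # SEL orthogonal choice under this sign convention
W1 = sgd_step(np.zeros((2, 2)), x, y, p_hat, gamma, lr=1.0)
print(W1[:, 0], np.log(p_hat) - gamma.T @ (y - p_hat))  # the two vectors agree
```

Because the teacher $\hat p$ and the correction $\hat\gamma$ enter only through the per-example gradient, the same loop covers vanilla KD ($\hat\gamma = 0$) and loss-corrected KD.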
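Our next result shows that, if the teacher $\hat p$ is independent of $(W_t)_{t\in[n]}$, then the SGD updates (4) have excess ideal risk governed by a bias term $\zeta(\hat\gamma)$ and a variance term $\sigma(\hat\gamma)^2/n$. Here, $\sigma_0^2(\theta)$ represents the baseline stochastic gradient variance that would be incurred if SGD were run directly on the ideal risk $L(\theta; p_0)$ rather than on our surrogate risk. Our proof in App. H builds upon the biased SGD bounds of Ajalloeian & Stich (2020).

**Theorem 6 (Biased SGD analysis).** Suppose that the loss $L(\theta; p_0)$ is $\lambda$-strongly smooth in $\theta$. Define the bias and root-variance parameters

$$\zeta(\hat\gamma) \triangleq \sup_{\theta\in\mathbb R^d}\|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,\hat p} - \hat\gamma)^\top(\hat p - p_0)\|_{2,2},$$

$$\sigma(\hat\gamma) \triangleq \sup_{\theta\in\mathbb R^d}\Big(\sigma_0(\theta) + \sqrt{\mathbb E[\|\nabla_\theta f_\theta(X)^\top\hat\gamma(X)^\top(Y - p_0(X))\|_2^2] + \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,\hat p} - \hat\gamma)^\top(\hat p - p_0)\|_{2,2}^2}\Big)$$

for $\sigma_0^2(\theta) \triangleq \sum_{i\in[d]}\mathrm{Var}[\nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X))]$, the unbiased SGD variance. If $F_0 = L(\theta_0; p_0) - \min_{\theta\in\mathbb R^d}L(\theta; p_0)$, then the iterates $\{\theta_t\}_{t=1}^n$ of the loss-corrected SGD algorithm satisfy

$$\min_{t\in[n]}\mathbb E\|\nabla_\theta L(\theta_t; p_0)\|_2^2 = O\Big(\sigma(\hat\gamma)\sqrt{\frac{\lambda F_0}{n}} + \zeta^2(\hat\gamma)\Big).$$

If, in addition, $L(\theta; p_0)$ is $\mu$-strongly convex in $\theta$, then the iterates satisfy

$$\mathbb E\big[L(\theta_n; p_0) - \min_{\theta\in\mathbb R^d}L(\theta; p_0)\big] = \tilde O\Big(\frac1\mu\Big(\frac{\lambda\sigma^2(\hat\gamma)}{\mu n} + \zeta^2(\hat\gamma)\Big)\Big) + O\big(F_0e^{-\mu n/\lambda}\big).$$

Similar to Thm. 5, the bound in Thm. 6 portrays the interplay of bias and variance as $\hat\gamma$ ranges from 0 to $q_{f_\theta,\hat p}$ (recall that $q_{f_\theta,\hat p}$ is independent of $f_\theta$ for any Bregman loss). In particular, the part of the bound for strongly convex losses that depends on $\hat\gamma$ can be further simplified to

$$\mathbb E\Big[\frac{\lambda\|\hat\gamma(X)\|_2^2\,\|Y - p_0(X)\|_2^2}{\mu n} + \|(\gamma_{f_\theta,\hat p}(X) - \hat\gamma(X))^\top(\hat p(X) - p_0(X))\|_2^2\,\|\nabla_\theta f_{\hat\theta}(X)\|_2^2\Big]. \qquad (5)$$

This has a very intuitive form: the first term is the impact of $\hat\gamma(X)$ on the variance, which is proportional to the squared noise in $y$ and decays at the standard $1/n$ rate. The second term controls how $\hat\gamma$ reduces the bias introduced by the error in the teacher's $\hat p$.

## 5 EXPERIMENTS

We complement our theoretical analysis with a pair of experiments demonstrating the practical benefits of cross-fitting and loss correction on six real-world classification tasks.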
Throughout, we use the SEL loss and report mean performance ± 1 standard error across 5 independent runs. Code to replicate all experiments can be found at https://github.com/microsoft/semiparametric-distillation, and supplementary experimental details and results can be found in App. I.

**Selecting the loss correction matrix $\hat\gamma$.** Motivated by the analyses in Sec. 4, for each training point $(x, y)$, we select our correction matrix $\hat\gamma(x)$ to balance bias and variance by minimizing a pointwise upper bound on the loss correction error (5) (ideally with a closed-form solution to avoid excessive computational overhead).³ To eliminate dependence on the unobserved $p_0$, we observe that the bias term $\|(\gamma_{f_\theta,\hat p}(x) - \hat\gamma(x))^\top(\hat p(x) - p_0(x))\|_2^2 = O(\|q_{f_\theta,\hat p}(x) - \hat\gamma(x)\|_{op}^2)$ up to additive terms independent of $\hat\gamma$. We introduce a tunable hyperparameter $\alpha > 0$ to trade off between this bias bound and the variance term in (5), and we select $\hat\gamma(x) = -\mathrm{diag}(v(x))$ (so that $v(x) = 1/\hat p(x)$ recovers the SEL orthogonal choice $q_{f_\theta,\hat p}(x) = -\mathrm{diag}(1/\hat p(x))$) by minimizing the right-hand side of

$$\mathbb E[\|\hat\gamma(x)(y - \hat p(x))\|_2^2 \mid x] + \alpha\|q_{f_\theta,\hat p}(x) - \hat\gamma(x)\|_{op}^2 \le \mathbb E[\|v(x)\odot(y - \hat p(x))\|_2^2 \mid x] + \alpha\|1/\hat p(x) - v(x)\|_2^2,$$

where the inequality holds because the operator norm of a diagonal matrix is its largest absolute entry. Since the conditional expectation involves the unknown quantity $p_0$, we estimate $\mathbb E[\|v(x)\odot(y - \hat p(x))\|_2^2 \mid x]$ with its sample $\|v(x)\odot(y - \hat p(x))\|_2^2$.⁴ This objective is quadratic in $v(x)$ and thus has a closed-form solution, $v_j(x) = \frac{\alpha/\hat p_j(x)}{(y_j - \hat p_j(x))^2 + \alpha}$ for each $j\in[k]$. Given $\hat\gamma(x)$, the student's loss-corrected objective is equivalent to a square loss with labels $\log(\hat p(x)) - \hat\gamma(x)^\top(y - \hat p(x)) = \log(\hat p(x)) + v(x)\odot(y - \hat p(x))$.

**Tabular data.** We first validate our KD enhancements on five real-world tabular datasets, FICO (FIC), StumbleUpon (Eve; Liu et al., 2017), and Adult, Higgs, and MAGIC from Dheeru & Karra Taniskidou (2017), with random forest (Breiman, 2001) students and teachers. In Fig. 1a, we examine the impact of varying student model capacity for a fixed high-capacity teacher with 500 trees on FICO. This setting lends itself to teacher overfitting, and we find that cross-fitting consistently improves upon vanilla KD by up to 4 AUC percentage points. In Fig. 1b, we explore the impact of teacher underfitting by limiting the teacher's maximum tree depth on Adult. Here we observe consistent gains from loss correction, with student performance exceeding even that of the teacher for smaller maximum tree depths. Analogous results for the remaining datasets can be found in App. I.1.

Figure 1: For random forest students and teachers, cross-fitting improves student performance when the teacher overfits, while loss correction improves student performance when the teacher underfits. (a) FICO dataset, when the teacher overfits (x-axis: student's number of trees; teacher: 500 trees). (b) Adult dataset, when the teacher underfits (x-axis: teacher's max tree depth; student: 10 trees; teacher: 100 trees). Curves: student without KD, KD student, cross-fit KD student, and enhanced KD student.

**Image data.** We next validate our KD enhancements on the image classification dataset CIFAR-10 (Krizhevsky & Hinton, 2009). We pair a residual network (ResNet-8) student with teacher networks of varying depths (ResNet-14/20/32/44/56) (He et al., 2016). It has been observed that larger and deeper teachers need not yield better students, as the teacher might overfit to the training set (Cho & Hariharan, 2019; Müller et al., 2019). To induce this overfitting, we turn off data augmentation (random horizontal flipping and cropping). We compare students trained with Vanilla KD and Enhanced KD, with and without loss correction, in Fig. 2. We find that cross-fitting consistently reduces the effect of teacher overfitting, with the largest impact realized for the deepest models. This effect is most evident in the cross-entropy test loss, where the Vanilla KD student incurs significantly larger loss than the cross-fitted student. For both accuracy and test loss, employing loss correction on top of cross-fitting provides an additional small performance boost.
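The Enhanced KD recipe used in these experiments can be sketched end to end in NumPy: cross-fitting as in Sec. 4.2, the closed-form choice of $v(x)$ obtained by setting the derivative of $v_j^2(y_j - \hat p_j)^2 + \alpha(1/\hat p_j - v_j)^2$ to zero, and a student trained by least squares on the corrected labels $\log\hat p(x) + v(x)\odot(y - \hat p(x))$. This is a minimal sketch under stated assumptions, not the released implementation: the one-dimensional histogram teacher, the linear least-squares student, the synthetic data, and all function names are hypothetical stand-ins for the random forests and ResNets used above.

```python
import numpy as np

EPS = 1e-6  # keeps teacher probabilities away from 0 and 1 before taking logs


def select_v(p_hat, y, alpha):
    """Closed-form minimizer of v_j^2 (y_j - p_hat_j)^2 + alpha (1/p_hat_j - v_j)^2."""
    return (alpha / p_hat) / ((y - p_hat) ** 2 + alpha)


def fit_histogram_teacher(x, Y, n_bins=10):
    """Hypothetical teacher: Laplace-smoothed class frequencies in equal-width bins of x."""
    edges = np.linspace(x.min(), x.max(), n_bins + 1)

    def bin_of(z):
        return np.clip(np.searchsorted(edges, z, side="right") - 1, 0, n_bins - 1)

    counts = np.ones((n_bins, Y.shape[1]))  # add-one smoothing
    np.add.at(counts, bin_of(x), Y)
    probs = counts / counts.sum(axis=1, keepdims=True)
    return lambda z: probs[bin_of(z)]


def cross_fit_labels(x, Y, alpha, n_folds=5, seed=0):
    """Enhanced KD targets: each fold's teacher is fit only on out-of-fold data."""
    n = len(x)
    folds = np.array_split(np.random.default_rng(seed).permutation(n), n_folds)
    labels = np.empty(Y.shape)
    for idx in folds:
        out_of_fold = np.setdiff1d(np.arange(n), idx)
        teacher = fit_histogram_teacher(x[out_of_fold], Y[out_of_fold])
        p_hat = np.clip(teacher(x[idx]), EPS, 1 - EPS)
        v = select_v(p_hat, Y[idx], alpha)
        labels[idx] = np.log(p_hat) + v * (Y[idx] - p_hat)
    return labels


def fit_linear_student(x, labels):
    """SEL student f(x) = a + b x per class, fit by least squares on the labels."""
    features = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(features, labels, rcond=None)
    return coef  # row 0: intercepts, row 1: slopes


# Synthetic 3-class problem with true logits (2x, 0, -2x); illustrative only.
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=2000)
logits = np.column_stack([2 * x, np.zeros_like(x), -2 * x])
p0 = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
Y = np.eye(3)[[rng.choice(3, p=p) for p in p0]]

for alpha in [0.0, 0.01, 1.0]:  # alpha = 0 gives cross-fit KD without loss correction
    coef = fit_linear_student(x, cross_fit_labels(x, Y, alpha))
    print(alpha, coef.round(2))
```

As $\alpha\to\infty$, $v_j\to 1/\hat p_j$ and the labels approach the orthogonal target $\log\hat p + (y - \hat p)/\hat p$; intermediate $\alpha$ shrinks the correction most where the observed residual $(y_j - \hat p_j)^2$ is large.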
Effect of the loss correction hyperparameter α. Our hyperparameter α controls the tradeoff between bias and variance in loss correction. When α is very small, the objective is close to the vanilla KD objective. When α is large, the objective is closer to the Neyman-orthogonal loss. In Fig. 3, we show the effect of varying α on CIFAR-10, with ResNet-8 as the student and ResNet-20 as the teacher. Large values of α lead to high variance and thus lower test accuracy. Intermediate values of α improve on both the Vanilla KD objective (α = 0) and the orthogonal objective (α = ∞). Beyond some threshold of α, the test accuracy drops sharply. At that point the variance becomes too high because of the terms $q_{\hat p}(x) = \mathrm{diag}(1/\hat p_1(x), \dots, 1/\hat p_K(x))$, and training becomes unstable.

³ Balancing the bias and variance terms (3) of Thm. 5 yields a similar objective.
⁴ An alternative estimate that performs slightly worse is $\sum_j v_j(x)^2\,\hat p_j(x)(1 - \hat p_j(x))$.

[Figure 2 panels: test accuracy and test loss vs. the teacher's network depth (14, 20, 32, 44, 56). Curves: KD Student, Cross-fit KD Student, Enhanced KD Student, Teacher, Cross-fit Teacher.]
Figure 2: On CIFAR-10 with ResNet students and teachers, cross-fitting reduces the effect of teacher overfitting, and loss correction yields an additional small performance boost. Here, the test loss is cross-entropy.

[Figure 3 panels: test accuracy and test loss of the Enhanced KD Student vs. the loss correction hyperparameter α (ticks at 0, 10⁻⁴, 10⁻³, 10⁻², 10⁻¹).]
Figure 3: On CIFAR-10 with ResNet students and teachers, large values of the loss correction hyperparameter α (corresponding to the orthogonal loss correction) lead to large variance and training instability. Intermediate values improve upon cross-fit KD without loss correction (α = 0). Here, the test loss is cross-entropy.

6 CONCLUSION

We developed a new analysis of knowledge distillation under the lens of semiparametric inference.
By framing the KD process as learning with plug-in estimation in the presence of nuisance, we obtained new generalization bounds for distillation. We also obtained new lower bounds that highlight the susceptibility of KD to teacher overfitting and underfitting. To address these failure modes, we introduced two complementary KD enhancements, cross-fitting and loss correction, which improve student performance both in theory and in practice. Past work has shown that augmenting the student training set with synthetic data from a generative model (e.g., a generative adversarial network (Liu et al., 2018) or MUNGE (Bucila et al., 2006)) often leads to improved student performance. A natural next step is to prove an analogue of Thm. 5 for synthetic augmentation, to understand when this strategy successfully mitigates the impact of teacher overfitting. Two tantalizing open questions also remain. First, can other techniques from semiparametric inference, such as targeted maximum likelihood (Van Der Laan & Rubin, 2006), improve KD performance? Second, can a semiparametric perspective explain the surprising success of self-distillation (Furlanello et al., 2018) and noisy student training (Xie et al., 2020), through which students routinely outperform their teachers?

REFERENCES

Stumbleupon evergreen dataset. https://www.kaggle.com/c/stumbleupon.

FICO: Explainable machine learning challenge. https://community.fico.com/s/explainable-machine-learning-challenge.

Ahmad Ajalloeian and Sebastian U Stich. Analysis of SGD with biased gradient estimators. arXiv preprint arXiv:2008.00051, 2020.

Jimmy Ba and Rich Caruana. Do deep nets really need to be deep? In Advances in neural information processing systems, pp. 2654–2662, 2014.

Mikhail Belkin, Alexander Rakhlin, and Alexandre B Tsybakov. Does data interpolation contradict statistical optimality? In The 22nd International Conference on Artificial Intelligence and Statistics, pp.
1611–1619. PMLR, 2019.

George Bennett. Probability inequalities for the sum of independent random variables. Journal of the American Statistical Association, 57(297):33–45, 1962.

Sergei Bernstein. The theory of probabilities. Gastehizdat Publishing House, 1946.

Léon Bottou and Olivier Bousquet. The tradeoffs of large scale learning. In Advances in neural information processing systems, pp. 161–168, 2008.

Leo Breiman. Random forests. Machine learning, 45(1):5–32, 2001.

Leo Breiman and Nong Shang. Born again trees. University of California, Berkeley, Berkeley, CA, Technical Report, 1:2, 1996.

Yuheng Bu, Weihao Gao, Shaofeng Zou, and Venugopal V Veeravalli. Information-theoretic understanding of population risk improvement with model compression. In AAAI, pp. 3300–3307, 2020.

Cristian Bucila, Rich Caruana, and Alexandru Niculescu-Mizil. Model compression. In Proceedings of the Twelfth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, Philadelphia, PA, USA, August 20-23, 2006, pp. 535–541, 2006. doi: 10.1145/1150402.1150464.

Yevgen Chebotar and Austin Waters. Distilling knowledge from ensembles of neural networks for speech recognition. In Interspeech, pp. 3439–3443, 2016.

Wei-Chun Chen, Chia-Che Chang, and Che-Rung Lee. Knowledge distillation with feature maps for image classification. In Asian Conference on Computer Vision, pp. 200–215. Springer, 2018.

Xu Cheng, Zhefan Rao, Yilan Chen, and Quanshi Zhang. Explaining knowledge distillation by quantifying the knowledge. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12925–12935, 2020.

Victor Chernozhukov, Juan Carlos Escanciano, Hidehiko Ichimura, Whitney K. Newey, and James M. Robins. Locally Robust Semiparametric Estimation. arXiv e-prints, art. arXiv:1608.00033, July 2016.

Victor Chernozhukov, Denis Chetverikov, Mert Demirer, Esther Duflo, Christian Hansen, Whitney Newey, and James Robins.
Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21(1):C1–C68, 2018. doi: 10.1111/ectj.12097. URL https://onlinelibrary.wiley.com/doi/abs/10.1111/ectj.12097.

Victor Chernozhukov, Whitney Newey, and Rahul Singh. De-Biased Machine Learning of Global and Local Parameters Using Regularized Riesz Representers. arXiv e-prints, art. arXiv:1802.08667, February 2018.

Jang Hyun Cho and Bharath Hariharan. On the efficacy of knowledge distillation. In Proceedings of the IEEE International Conference on Computer Vision, pp. 4794–4802, 2019.

Mark Craven and Jude W Shavlik. Extracting tree-structured representations of trained networks. In Advances in neural information processing systems, pp. 24–30, 1996.

Dua Dheeru and Efi Karra Taniskidou. UCI machine learning repository, 2017. URL http://archive.ics.uci.edu/ml.

Rasool Fakoor, Jonas Mueller, Nick Erickson, Pratik Chaudhari, and Alexander J Smola. Fast, accurate, and simple models for tabular data via augmented distillation. arXiv preprint arXiv:2006.14284, 2020.

Dylan J Foster and Vasilis Syrgkanis. Orthogonal statistical learning. arXiv preprint arXiv:1901.09036, 2019.

Markus Freitag, Yaser Al-Onaizan, and Baskaran Sankaran. Ensemble distillation for neural machine translation. arXiv preprint arXiv:1702.01802, 2017.

Tommaso Furlanello, Zachary C Lipton, Michael Tschannen, Laurent Itti, and Anima Anandkumar. Born again neural networks. arXiv preprint arXiv:1805.04770, 2018.

Yotam Gil, Yoav Chai, Or Gorodissky, and Jonathan Berant. White-to-black: Efficient distillation of black-box adversarial attacks. arXiv preprint arXiv:1904.02405, 2019.

Micah Goldblum, Liam Fowl, Soheil Feizi, and Tom Goldstein. Adversarially robust distillation. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 3996–4003, 2020.

Jianping Gou, Baosheng Yu, Stephen John Maybank, and Dacheng Tao.
Knowledge distillation: A survey. arXiv preprint arXiv:2006.05525, 2020.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.

Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.

Minghao Hu, Yuxing Peng, Furu Wei, Zhen Huang, Dongsheng Li, Nan Yang, and Ming Zhou. Attention-guided answer distillation for machine reading comprehension. arXiv preprint arXiv:1808.07644, 2018.

Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. TinyBERT: Distilling BERT for natural language understanding. arXiv preprint arXiv:1909.10351, 2019.

Michael R Kosorok. Introduction to empirical processes and semiparametric inference. Springer Science & Business Media, 2007.

Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.

Jinyu Li, Rui Zhao, Jui-Ting Huang, and Yifan Gong. Learning small-size DNN with output-distribution-based criteria. In Fifteenth annual conference of the international speech communication association, 2014.

Quanquan Li, Shengying Jin, and Junjie Yan. Mimicking very efficient network for object detection. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 6356–6364, 2017.

Zhizhong Li and Derek Hoiem. Learning without forgetting. IEEE transactions on pattern analysis and machine intelligence, 40(12):2935–2947, 2017.

Ruishan Liu, Nicolo Fusi, and Lester Mackey. Teacher-student compression with generative adversarial networks. arXiv preprint arXiv:1812.02271, 2018.

Yu Liu, Hantian Zhang, Luyuan Zeng, Wentao Wu, and Ce Zhang. MLBench: How good are machine learning clouds for binary classification tasks on structured data. arXiv e-prints, 2017.
Raphael Gontijo Lopes, Stefano Fenu, and Thad Starner. Data-free knowledge distillation for deep neural networks. arXiv preprint arXiv:1710.07535, 2017.

David Lopez-Paz, Léon Bottou, Bernhard Schölkopf, and Vladimir Vapnik. Unifying distillation and privileged information. arXiv preprint arXiv:1511.03643, 2015.

Liang Lu, Michelle Guo, and Steve Renals. Knowledge distillation for small-footprint highway networks. In 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 4820–4824. IEEE, 2017.

Andreas Maurer and Massimiliano Pontil. Empirical Bernstein bounds and sample variance penalization. arXiv preprint arXiv:0907.3740, 2009.

Aditya Krishna Menon, Ankit Singh Rawat, Sashank J Reddi, Seungyeon Kim, and Sanjiv Kumar. Why distillation helps: a statistical perspective. arXiv preprint arXiv:2005.10419, 2020.

Hossein Mobahi, Mehrdad Farajtabar, and Peter L Bartlett. Self-distillation amplifies regularization in Hilbert space. arXiv preprint arXiv:2002.05715, 2020.

Lili Mou, Ran Jia, Yan Xu, Ge Li, Lu Zhang, and Zhi Jin. Distilling word embeddings: An encoding approach. In Proceedings of the 25th ACM International on Conference on Information and Knowledge Management, pp. 1977–1980, 2016.

Rafael Müller, Simon Kornblith, and Geoffrey E Hinton. When does label smoothing help? In Advances in Neural Information Processing Systems, pp. 4694–4703, 2019.

Elizbar A Nadaraya. On estimating regression. Theory of Probability & Its Applications, 9(1):141–142, 1964.

Ndapandula Nakashole and Raphael Flauger. Knowledge distillation for bilingual dictionary induction. In Proceedings of the 2017 conference on empirical methods in natural language processing, pp. 2497–2506, 2017.

Denis Nekipelov, Vira Semenova, and Vasilis Syrgkanis. Regularized Orthogonal Machine Learning for Nonlinear Semiparametric Models. arXiv e-prints, art. arXiv:1806.04823, June 2018.
Aaron Oord, Yazhe Li, Igor Babuschkin, Karen Simonyan, Oriol Vinyals, Koray Kavukcuoglu, George Driessche, Edward Lockhart, Luis Cobo, Florian Stimberg, et al. Parallel WaveNet: Fast high-fidelity speech synthesis. In International conference on machine learning, pp. 3918–3926. PMLR, 2018.

Nicolas Papernot, Martín Abadi, Úlfar Erlingsson, Ian Goodfellow, and Kunal Talwar. Semi-supervised knowledge transfer for deep learning from private training data. arXiv preprint arXiv:1610.05755, 2016a.

Nicolas Papernot, Patrick McDaniel, Xi Wu, Somesh Jha, and Ananthram Swami. Distillation as a defense to adversarial perturbations against deep neural networks. In 2016 IEEE Symposium on Security and Privacy (SP), pp. 582–597. IEEE, 2016b.

Mary Phuong and Christoph Lampert. Towards understanding knowledge distillation. In International Conference on Machine Learning, pp. 5142–5151, 2019.

Herbert Robbins and Sutton Monro. A stochastic approximation method. Ann. Math. Statist., 22(3):400–407, 09 1951. doi: 10.1214/aoms/1177729586. URL https://doi.org/10.1214/aoms/1177729586.

Andrew Slavin Ross and Finale Doshi-Velez. Improving the adversarial robustness and interpretability of deep neural networks by regularizing their input gradients. arXiv preprint arXiv:1711.09404, 2017.

Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019.

Peng Shen, Xugang Lu, Sheng Li, and Hisashi Kawai. Feature representation of short utterances based on knowledge distillation for spoken language identification. In Interspeech, pp. 1813–1817, 2018.

Pierre Stock, Armand Joulin, Rémi Gribonval, Benjamin Graham, and Hervé Jégou. And the bit goes down: Revisiting the quantization of neural networks. 2020.

Sarah Tan, Rich Caruana, Giles Hooker, Paul Koch, and Albert Gordo. Learning global additive explanations for neural nets using model distillation.
arXiv preprint arXiv:1801.08640, 2018.

Jiaxi Tang, Rakesh Shivanna, Zhe Zhao, Dong Lin, Anima Singh, Ed H Chi, and Sagar Jain. Understanding and improving knowledge distillation. arXiv preprint arXiv:2002.03532, 2020.

Anastasios Tsiatis. Semiparametric theory and missing data. Springer Science & Business Media, 2007.

Mark J Van Der Laan and Daniel Rubin. Targeted maximum likelihood learning. The international journal of biostatistics, 2(1), 2006.

Vladimir Vapnik and Rauf Izmailov. Learning using privileged information: similarity control and knowledge transfer. J. Mach. Learn. Res., 16(1):2023–2049, 2015.

Martin J. Wainwright. High-Dimensional Statistics: A Non-Asymptotic Viewpoint. Cambridge Series in Statistical and Probabilistic Mathematics. Cambridge University Press, 2019. doi: 10.1017/9781108627771.

Chong Wang, Xipeng Lan, and Yangang Zhang. Model distillation with knowledge transfer from face classification to alignment and verification. arXiv preprint arXiv:1709.02929, 2017.

Ji Wang, Weidong Bao, Lichao Sun, Xiaomin Zhu, Bokai Cao, and S Yu Philip. Private model compression via knowledge distillation. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pp. 1190–1197, 2019.

Shinji Watanabe, Takaaki Hori, Jonathan Le Roux, and John R Hershey. Student-teacher network learning with enhanced features. In 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 5275–5279. IEEE, 2017.

Geoffrey S Watson. Smooth regression analysis. Sankhyā: The Indian Journal of Statistics, Series A, pp. 359–372, 1964.

Qizhe Xie, Minh-Thang Luong, Eduard Hovy, and Quoc V Le. Self-training with noisy student improves imagenet classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10687–10698, 2020.

A EXTENDED LITERATURE REVIEW

We point the interested reader to Gou et al.
(2020) for a sweeping survey of the many developments in knowledge distillation over the past half decade. In addition to the references on theoretical aspects of knowledge distillation in Sec. 1, we highlight here a number of empirical investigations of why distillation works. Cho & Hariharan (2019) show that larger teacher models do not necessarily improve the performance of student models, because parsimonious student models are not able to mimic the teacher model. They suggest early stopping in the training of large teacher neural networks as a means of regularization. Cheng et al. (2020) study distillation on image data. They demonstrate that it allows the student neural network to learn multiple visual concepts simultaneously, whereas neural networks trained on raw data learn concepts sequentially.

Knowledge distillation has also been used for adversarial attacks (Papernot et al., 2016b; Ross & Doshi-Velez, 2017; Gil et al., 2019; Goldblum et al., 2020) and data security (Papernot et al., 2016a; Lopes et al., 2017; Wang et al., 2019). Further applications include image processing (Li & Hoiem, 2017; Wang et al., 2017; Chen et al., 2018; Li et al., 2017), natural language processing (Nakashole & Flauger, 2017; Mou et al., 2016; Hu et al., 2018; Freitag et al., 2017), and speech processing (Chebotar & Waters, 2016; Lu et al., 2017; Watanabe et al., 2017; Oord et al., 2018; Shen et al., 2018).
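Table 1: Glossary of notation

Loss function on a random data point: $\ell(Z; f(X), p_0(X))$
Population risk: $L_D(f,p) \triangleq \mathbb{E}[\ell(Z; f(X), p(X))]$
Empirical risk: $L_n(f,p) \triangleq \mathbb{E}_n[\ell(Z; f(X), p(X))]$
Population optimal student model: $f_0 \triangleq \arg\min_{f\in\mathcal{F}} L_D(f, p_0)$
Empirical optimal student model: $\hat f \triangleq \arg\min_{f\in\mathcal{F}} L_n(f, \hat p)$
$\|f\|_{p,q}$: $\|\|f(X)\|_p\|_{L_q} = \mathbb{E}[\|f(X)\|_p^q]^{1/q}$
$\nabla_\phi$: partial derivative of $\ell(z;\phi,\pi)$ with respect to the second input
$\nabla_\pi$: partial derivative of $\ell(z;\phi,\pi)$ with respect to the third input
$\nabla_{\phi\pi}$: $[\nabla_{\phi\pi}\ell(z;\phi,\pi)]_{i,j} = \frac{\partial^2}{\partial\phi_j\,\partial\pi_i}\ell(z;\phi,\pi)$
$q_{f,p}(x)$: $\mathbb{E}[\nabla_{\phi\pi}\ell(Z; f(X), p(X)) \mid X = x]$
$\gamma_{f,p}(x)$: $\mathbb{E}_{U\sim\mathrm{Unif}([0,1])}[q_{f,\,Up+(1-U)p_0}(x)]$
$\mathcal{R}(\delta;\mathcal{F})$: localized Rademacher complexity of function class $\mathcal{F}$
$\delta_n$: critical radius
γ-corrected loss: $\ell_\gamma(z; f(x), p(x)) \triangleq \ell(z; f(x), p(x)) + (y - p(x))^\top\gamma(x)f(x)$
Population γ-risk: $L_D(f,p,\gamma) \triangleq \mathbb{E}[\ell_\gamma(Z; f(X), p(X))]$
Empirical γ-risk: $L_n(f,p,\gamma) \triangleq \mathbb{E}_n[\ell_\gamma(Z; f(X), p(X))]$

C PROOF OF THM. 1: VANILLA DISTILLATION ANALYSIS

Introduce the shorthand $\ell_{f,\hat p}(z) = \ell(z; f(x), \hat p(x))$. Since $\delta_n$ upper bounds the critical radius of the function class $\mathcal{G}$, the localized Rademacher analysis of Foster & Syrgkanis (2019, Lem. 11) implies⁵

$$\big|L_n(\hat f,\hat p) - L_n(f_0,\hat p) - (L_D(\hat f,\hat p) - L_D(f_0,\hat p))\big| \le O\big(H\delta_{n,\zeta}\|\ell_{\hat f,\hat p} - \ell_{f_0,\hat p}\|_{2,2} + H\delta_{n,\zeta}^2\big)$$

with probability at least $1-\zeta$. Moreover, by Cauchy-Schwarz, $\|\ell_{\hat f,\hat p} - \ell_{f_0,\hat p}\|_{2,2} \le \|\mu\|_4\|\hat f - f_0\|_{2,4}$. By the assumed $\ell_2/\ell_4$ ratio condition, we therefore have

$$\epsilon(\hat f,\hat p,\gamma) \le O\big(\delta_{n,\zeta}CH\|\mu\|_4\|\hat f - f_0\|_{2,2} + H\delta_{n,\zeta}^2\big).$$

We plug this bound into Lemma 4, which holds irrespective of whether data re-use, sample splitting, or cross-fitting is employed. Applying the arithmetic-geometric mean inequality then yields

$$\frac{\sigma}{8}\|\hat f - f_0\|_{2,2}^2 \le O\Big(H\delta_{n,\zeta}^2 + \frac{1}{\sigma}\delta_{n,\zeta}^2C^2H^2\|\mu\|_4^2\Big) + \frac{1}{\sigma}\|\gamma_{f_0,\hat p}^\top(\hat p - p_0)\|_{2,2}^2.$$

⁵ We apply Foster & Syrgkanis (2019, Lem. 11) with $\mathcal{L}g = g$ for $g \in \mathcal{G}$ and $g^* = 0$. Then we instantiate the concentration inequality for the choice $g = \ell_{\hat f,\hat p} - \ell_{f_0,\hat p} \in \mathcal{G}$.

D PROOF OF PROP. 2: IMPACT OF TEACHER UNDERFITTING ON VANILLA DISTILLATION

Suppose that $p_0$ does not vary with $x$ and, for known $\epsilon > 0$, belongs to the set $\mathcal{P} = \{p : p_j(x) \in [\epsilon, 1],\ \forall x \in \mathcal{X},\ j \in [k]\}$.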
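As all quantities in this proof are independent of $x$, we omit the dependence on $x$ whenever convenient. Consider the constant teacher estimate $\hat p = \max(\bar y/(1+\lambda), \epsilon)$ obtained via ridge regression with regularization strength $\lambda \le 1$ and $\bar y \triangleq \frac{1}{n}\sum_{i=1}^n y_i$. Consider also the class $\mathcal{F} = \{f : f_j(x) \in [\log(\epsilon), 0],\ \forall x \in \mathcal{X},\ j \in [k]\}$. A constant student prediction rule in $\mathcal{F}$ trained via Vanilla KD with SEL loss yields $\hat f(x) = \log(\hat p)$. Suppose that, unbeknownst to the teacher and student, the true $p_0$ satisfies the more stringent condition $p_{0,j} \ge 2\epsilon$ for all $j \in [k]$. Then, componentwise, the student satisfies

$$\begin{aligned} f_0 - \hat f &= \log(p_0) - \log(\hat p) \ge \mathrm{diag}(1/p_0)(p_0 - \hat p) = \gamma_{f_0,p_0}^\top(p_0 - \hat p) \\ &= \gamma_{f_0,p_0}^\top\Big(\frac{\lambda p_0 - (\bar y - p_0)}{1+\lambda} + \min\Big(0, \frac{\bar y}{1+\lambda} - \epsilon\Big)\Big) \\ &= \gamma_{f_0,p_0}^\top\Big(\frac{\lambda p_0 - (\bar y - p_0)}{1+\lambda} + \min\Big(0, \frac{\bar y - p_0 + p_0 - (1+\lambda)\epsilon}{1+\lambda}\Big)\Big) \\ &\ge \gamma_{f_0,p_0}^\top\,\frac{\lambda p_0 - |\bar y - p_0|}{1+\lambda} \end{aligned}$$

by the concavity of the logarithm and the choice $\lambda \le 1$, which gives $p_{0,j} \ge 2\epsilon \ge (1+\lambda)\epsilon$. By Bernstein's inequality (Bernstein, 1946), $P(|\bar y_j - p_{0,j}| \ge \theta p_{0,j}) \le 2\zeta/k$ for $\theta = \sqrt{\frac{2}{np_{0,j}}\log(\frac{k}{\zeta})} + \frac{2}{3np_{0,j}}\log(\frac{k}{\zeta})$. Hence

$$P\big(\|f_0 - \hat f\|_{2,2}^2 \ge \|\gamma_{f_0,p_0}^\top(p_0 - \hat p)\|_{2,2}^2\big) \ge P(f_{0,j} - \hat f_j \ge 0,\ \forall j \in [k]) \ge P(\lambda p_{0,j} \ge |\bar y_j - p_{0,j}|,\ \forall j \in [k]) \ge 1 - 2\zeta$$

whenever $\lambda \ge \sqrt{\frac{2}{n\epsilon}\log(\frac{k}{\zeta})} + \frac{2}{3n\epsilon}\log(\frac{k}{\zeta})$. Moreover, by the law of the iterated logarithm, $\limsup_{n\to\infty}\frac{\sqrt{n}\,|\bar y_j - p_{0,j}|}{\sqrt{2p_{0,j}(1-p_{0,j})\log\log(n)}} = 1$ with probability 1. Consequently, $\|\gamma_{f_0,p_0}^\top(p_0 - \hat p)\|_{2,2}^2 = \Omega(\min(1,\lambda^2))$ with probability 1 whenever $1 \ge \lambda \gg \sqrt{\log\log(n)/(n\epsilon)}$. The choice $\lambda = \min\big(1, \max\big(\frac{1}{n^{1/4}}, \sqrt{\frac{2}{n\epsilon}\log(\frac{k}{\zeta})} + \frac{2}{3n\epsilon}\log(\frac{k}{\zeta})\big)\big) = \Theta(\frac{1}{n^{1/4}})$ now yields the first two advertised claims. The final claim follows directly from Thm. 5 with $B = O(1)$. Here $\hat\gamma^{(t)} = \gamma_{f_0,\hat p}$, and the critical radius of $\mathcal{G}(\hat p^{(t)}, \hat\gamma^{(t)})$ satisfies $\delta_{n/B} = O(\sqrt{Bk/n})$ by Wainwright (2019, Ex. 13.8).

E PROOF OF PROP. 3: IMPACT OF TEACHER OVERFITTING ON VANILLA DISTILLATION

Suppose that $p_0$ has Lipschitz gradient and, for known $\epsilon > 0$, belongs to the set $\mathcal{P} = \{p : p_j(x) \in [\epsilon, 1],\ \forall x \in \mathcal{X},\ j \in [k]\}$. Suppose moreover that $X$ has a Lebesgue density on $\mathcal{X} \subseteq \mathbb{R}^d$ that is bounded away from 0 and $\infty$. Finally, suppose that $\epsilon < \frac{1}{4}\,\frac{\mathbb{E}[p_{0,j}(X)(1 - p_{0,j}(X))^2]}{\mathbb{E}[(1 - p_{0,j}(X))/p_{0,j}(X)]}$ for each $j$.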
Consider the teacher estimates $\hat p_j(x) = \max(\epsilon, \tilde p_j(x))$, where $\tilde p$ is the Nadaraya-Watson kernel smoothing estimator (Nadaraya, 1964; Watson, 1964)

$$\tilde p(x) = \begin{cases} y_i & \text{if } x = x_i, \\ \sum_{i=1}^n y_i K((x - x_i)/h) \big/ \sum_{i=1}^n K((x - x_i)/h) & \text{otherwise,} \end{cases}$$

with kernel $K(x) = \|x\|_2^{-a}\,\mathbb{I}[\|x\|_2 \le 1]$, $a \in (0, d/2)$, and $h = n^{-1/(4+d)}$. By Belkin et al. (2019, Thm. 1), the teacher satisfies $\mathbb{E}[\|p_0 - \hat p\|_{2,2}^2] = O(n^{-4/(4+d)})$.

Now instantiate the notation of Thm. 1. Consider a student prediction rule trained to learn a constant prediction rule via Vanilla KD with the SEL loss and

$$\mathcal{F} = \{f : f(x) = f(x') \in [\log(\epsilon), 0]^k \text{ for all } x, x' \in \mathcal{X}\}. \qquad (6)$$

Since $\tilde p$ exactly interpolates the observed labels (i.e., $\tilde p(x_i) = y_i$), the critical radius of the teacher-student function class $\mathcal{G}$ satisfies $\delta_n = \Omega(1)$. Moreover, the student only has access to the teacher's training set probabilities. Its estimate $\hat f(x) = \frac{1}{n}\sum_{i=1}^n \log(\max(y_i, \epsilon))$ is therefore inconsistent for the optimal constant rule $f_0(x) = \mathbb{E}[\log(p_0(X))]$, as

$$\begin{aligned} f_{0,j}(x) - \mathbb{E}\hat f_j(x) &= \mathbb{E}[\log(p_{0,j}(X)) - \log(\max(Y_j, \epsilon))] \\ &\ge \mathbb{E}\Big[\frac{p_{0,j}(X) - \max(Y_j,\epsilon)}{p_{0,j}(X)} + \frac{(\max(Y_j,\epsilon) - p_{0,j}(X))^2}{2}\Big] \\ &= \mathbb{E}\Big[\frac{p_{0,j}(X)(1 - p_{0,j}(X))^2 + (1 - p_{0,j}(X))(p_{0,j}(X) - \epsilon)^2}{2}\Big] - \epsilon\,\mathbb{E}\Big[\frac{1 - p_{0,j}(X)}{p_{0,j}(X)}\Big] \\ &> \mathbb{E}\Big[\frac{p_{0,j}(X)(1 - p_{0,j}(X))^2}{4}\Big] \end{aligned}$$

by Taylor's theorem with Lagrange remainder and the assumed bound on $\epsilon$. This non-vanishing student error reflects the non-vanishing critical radius $\delta_n$ of the composite student-teacher function class $\mathcal{G}$ defined in Thm. 1. Since the student function class $\mathcal{F}$ has low complexity, the complexity of $\mathcal{G}$ is driven by the highly flexible interpolating teacher.

Next, instantiate the notation of Thm. 5. Consider a student prediction rule $\hat f$ trained via Enhanced KD with SEL loss, $\hat\gamma^{(t)} = 0$, $B = O(1)$, and $\mathcal{F}$ as in (6). The critical radius of $\mathcal{G}(\hat p^{(t)}, \hat\gamma^{(t)})$ satisfies $\delta_{n/B} = O(\sqrt{Bk/n})$ by Wainwright (2019, Ex. 13.8). Moreover, each cross-fitted teacher satisfies $\mathbb{E}[\|p_0 - \hat p^{(t)}\|_{2,2}^2] = O(n^{-4/(4+d)})$ by Belkin et al. (2019, Thm.
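1). Therefore, by Chebyshev's and Jensen's inequalities and a union bound over the $B$ folds, with probability at least $1 - \zeta/2$,

$$\|p_0 - \hat p^{(t)}\|_{2,2} \le \mathbb{E}[\|p_0 - \hat p^{(t)}\|_{2,2}] + \sqrt{2B\,\mathrm{Var}(\|p_0 - \hat p^{(t)}\|_{2,2})/\zeta} \le \big(1 + \sqrt{2B/\zeta}\big)\sqrt{\mathbb{E}[\|p_0 - \hat p^{(t)}\|_{2,2}^2]} = O(n^{-2/(4+d)})$$

for all $t$. Therefore, Thm. 5 implies that

$$\begin{aligned} \|\hat f - f_0\|_{2,2}^2 &= O\Big(\frac{k}{n} + \frac{1}{B}\sum_{t=1}^B \|(\gamma_{f_0,\hat p^{(t)}})^\top(\hat p^{(t)} - p_0)\|_{2,2}^2\Big) = O\Big(\frac{k}{n} + \frac{1}{B}\sum_{t=1}^B \|\mathrm{diag}(1/\hat p^{(t)})(\hat p^{(t)} - p_0)\|_{2,2}^2\Big) \\ &= O\Big(\frac{k}{n} + \frac{1}{B\epsilon^2}\sum_{t=1}^B \|\hat p^{(t)} - p_0\|_{2,2}^2\Big) = O(n^{-4/(4+d)}) \end{aligned}$$

with probability at least $1 - \zeta$.

F PROOF OF LEMMA 4: ALGORITHM-AGNOSTIC ANALYSIS

First, for any functional $L(f)$, we define the Fréchet derivative as

$$D_f L(f)[\nu] = \frac{\partial}{\partial t}L(f + t\nu)\Big|_{t=0}.$$

When $L$ is an operator of the form $\mathbb{E}[g(f(X))]$, we have $D_f L(f)[\nu] = \mathbb{E}[\langle\nabla g(f(X)), \nu(X)\rangle]$. By the σ-strong convexity of $L_D$,⁶ we have

$$L_D(\hat f,\hat p,\gamma) \ge L_D(f_0,\hat p,\gamma) + D_f L_D(f_0,\hat p,\gamma)[\hat f - f_0] + \frac{\sigma}{2}\|\hat f - f_0\|_{2,2}^2.$$

Furthermore, our excess risk assumption and the optimality of $f_0$ give us

$$\begin{aligned} \frac{\sigma}{2}\|\hat f - f_0\|_{2,2}^2 &\le \underbrace{L_D(\hat f,\hat p,\gamma) - L_D(f_0,\hat p,\gamma)}_{\text{excess risk of } \hat f} - D_f L_D(f_0,\hat p,\gamma)[\hat f - f_0] \\ &\le \epsilon(\hat f,\hat p,\gamma) - \underbrace{D_f L_D(f_0,p_0,\gamma)[\hat f - f_0]}_{\ge 0 \text{ by optimality of } f_0} + D_f\big(L_D(f_0,p_0,\gamma) - L_D(f_0,\hat p,\gamma)\big)[\hat f - f_0] \\ &\le \epsilon(\hat f,\hat p,\gamma) + D_f\big(L_D(f_0,p_0,\gamma) - L_D(f_0,\hat p,\gamma)\big)[\hat f - f_0]. \end{aligned}$$

By Taylor's theorem with integral remainder,

$$\mathbb{E}[\langle\nabla_\phi\ell(W; f_0(x), p_0(x)) - \nabla_\phi\ell(W; f_0(x), \hat p(x)), \hat f(x) - f_0(x)\rangle \mid X = x] = (p_0(x) - \hat p(x))^\top\gamma_{f_0,\hat p}(x)(\hat f(x) - f_0(x)) \qquad (7)$$

whenever $\nabla_{\phi\pi}\ell$ is well defined. We can now invoke the expansion (7) and Cauchy-Schwarz to obtain the bound

$$\begin{aligned} D_f\big(L_D(f_0,p_0,\gamma) - L_D(f_0,\hat p,\gamma)\big)[\hat f - f_0] &= \mathbb{E}[\langle\nabla_\phi\ell(W; f_0(X), p_0(X)) - \nabla_\phi\ell(W; f_0(X), \hat p(X)), \hat f(X) - f_0(X)\rangle] \\ &\quad - \mathbb{E}[(p_0(X) - \hat p(X))^\top\gamma(X)(\hat f(X) - f_0(X))] \\ &= \mathbb{E}[(p_0(X) - \hat p(X))^\top(\gamma_{f_0,\hat p}(X) - \gamma(X))(\hat f(X) - f_0(X))] \\ &\le \mathbb{E}[\|(p_0(X) - \hat p(X))^\top(\gamma_{f_0,\hat p}(X) - \gamma(X))\|_2\,\|\hat f(X) - f_0(X)\|_2] \\ &\le \|(p_0 - \hat p)^\top(\gamma_{f_0,\hat p} - \gamma)\|_{2,2}\,\|\hat f - f_0\|_{2,2}. \end{aligned}$$

Thus, combining all the above inequalities,

$$\frac{\sigma}{2}\|\hat f - f_0\|_{2,2}^2 \le \epsilon(\hat f,\hat p,\gamma) + \|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - \gamma)\|_{2,2}\,\|\hat f - f_0\|_{2,2}.$$

⁶ Notably, this strong convexity assumption can be relaxed to $\mathbb{E}[\langle\nabla_\phi\ell(W; f_0(X), p_0(X)), \hat f(X) - f_0(X)\rangle] \ge 0$.

By the AM-GM inequality, for all $a, b \ge 0$, $ab \le \frac{1}{2}\big(\frac{\sigma}{2}a^2 + \frac{2}{\sigma}b^2\big)$. Applying this to the product of norms on the right-hand side and rearranging yields

$$\frac{\sigma}{4}\|\hat f - f_0\|_{2,2}^2 \le \epsilon(\hat f,\hat p,\gamma) + \frac{1}{\sigma}\|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - \gamma)\|_{2,2}^2.$$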
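The rearrangement in the last step can be spelled out. For this display only, write $a = \|\hat f - f_0\|_{2,2}$ and $b = \|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - \gamma)\|_{2,2}$; this is shorthand introduced here, not notation from the lemma. The combined inequality reads $\frac{\sigma}{2}a^2 \le \epsilon(\hat f,\hat p,\gamma) + ab$, and the AM-GM step gives

$$ab = \Big(\sqrt{\tfrac{\sigma}{2}}\,a\Big)\Big(\sqrt{\tfrac{2}{\sigma}}\,b\Big) \le \frac{1}{2}\Big(\frac{\sigma}{2}a^2 + \frac{2}{\sigma}b^2\Big) = \frac{\sigma}{4}a^2 + \frac{1}{\sigma}b^2,$$

$$\frac{\sigma}{2}a^2 \le \epsilon(\hat f,\hat p,\gamma) + \frac{\sigma}{4}a^2 + \frac{1}{\sigma}b^2 \;\Longrightarrow\; \frac{\sigma}{4}a^2 \le \epsilon(\hat f,\hat p,\gamma) + \frac{1}{\sigma}b^2.$$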
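To get the final inequality, observe that

$$\|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - \gamma)\|_{2,2}^2 \le 2\|(\hat p - p_0)^\top(q_{f_0,\hat p} - \gamma)\|_{2,2}^2 + 2\|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - q_{f_0,\hat p})\|_{2,2}^2.$$

Moreover, by the boundedness of the third derivative,

$$\begin{aligned} \|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - q_{f_0,\hat p})\|_{2,2}^2 &\le \mathbb{E}[\|\hat p(X) - p_0(X)\|_2^2\,\|\gamma_{f_0,\hat p}(X) - q_{f_0,\hat p}(X)\|_2^2] \\ &\le \mathbb{E}[\|\hat p(X) - p_0(X)\|_2^2\,M^2k\,\|\hat p(X) - p_0(X)\|_2^2] = M^2k\,\|\hat p - p_0\|_{2,4}^4. \end{aligned}$$

Combining all the above yields the final bound.

G PROOF OF THM. 5: CROSS-FITTED ERM ANALYSIS

Let $L_{n,t}$ denote the empirical loss over the samples in the $t$-th fold, and let $\hat p^{(t)}, \hat\gamma^{(t)}$ be the nuisance functions used on the samples in the $t$-th fold. Fix any $t \in [B]$ and condition on $\hat p^{(t)}, \hat\gamma^{(t)}$. Suppose that $\delta_{n/B}$ upper bounds the critical radius of the function class $\mathcal{G}(\hat p^{(t)}, \hat\gamma^{(t)})$, and write $\ell_{t,f}(z) = \ell_{\hat\gamma^{(t)}}(z; f(x), \hat p^{(t)}(x))$. Then, by Lemma 11 of Foster & Syrgkanis (2019),⁷ with probability at least $1 - \zeta$,

$$\big|L_{n,t}(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_{n,t}(f_0,\hat p^{(t)},\hat\gamma^{(t)}) - \big(L_D(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_D(f_0,\hat p^{(t)},\hat\gamma^{(t)})\big)\big| \le O\big(H\delta_{n/B,\zeta}\|\ell_{t,\hat f} - \ell_{t,f_0}\|_{2,2} + H\delta_{n/B,\zeta}^2\big).$$

Moreover, by the definition of cross-fitted ERM,

$$\frac{1}{B}\sum_{t=1}^B\big(L_{n,t}(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_{n,t}(f_0,\hat p^{(t)},\hat\gamma^{(t)})\big) \le 0.$$

Thus, by a union bound over the $B$ folds, with probability at least $1 - B\zeta$,

$$\frac{1}{B}\sum_{t=1}^B\big(L_D(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_D(f_0,\hat p^{(t)},\hat\gamma^{(t)})\big) \le O\Big(H\delta_{n/B,\zeta}\,\frac{1}{B}\sum_{t=1}^B\|\ell_{t,\hat f} - \ell_{t,f_0}\|_{2,2} + H\delta_{n/B,\zeta}^2\Big).$$

Moreover, let $\mu(z) = \sup_{\phi,t}\|\nabla_\phi\ell(z;\phi,\hat p^{(t)}(x))\|_2$. Then, by the Cauchy-Schwarz inequality,

$$\begin{aligned} \|\ell_{t,f} - \ell_{t,f_0}\|_{2,2} &\le \|\mu\|_4\|f - f_0\|_{2,4} + \sqrt{\mathbb{E}\big[\big((Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)(f(X) - f_0(X))\big)^2\big]} \\ &\le \|\mu\|_4\|f - f_0\|_{2,4} + \sqrt{\mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^2\,\|f(X) - f_0(X)\|_2^2\big]} \\ &\le \Big(\|\mu\|_4 + \mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4\big]^{1/4}\Big)\|f - f_0\|_{2,4}. \end{aligned}$$

If we further assume that the function class $\mathcal{F}$ satisfies the $\ell_2/\ell_4$ condition $\frac{\|f - f_0\|_{2,4}}{\|f - f_0\|_{2,2}} \le C$, then, with probability at least $1 - \zeta$,

$$\frac{1}{B}\sum_{t=1}^B\epsilon(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) \le O\Big(H\delta_{n/B,\zeta/B}\,\frac{1}{B}\sum_{t=1}^B C\Big(\|\mu\|_4 + \mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4\big]^{1/4}\Big)\|\hat f - f_0\|_{2,2} + H\delta_{n/B,\zeta/B}^2\Big).$$

Applying Lemma 4 for each $\hat p^{(t)}, \hat\gamma^{(t)}$ and averaging the final inequality, we get

$$\frac{\sigma}{4}\|\hat f - f_0\|_{2,2}^2 \le \frac{1}{B}\sum_{t=1}^B\Big(\epsilon(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) + \frac{1}{\sigma}\|(\gamma_{f_0,\hat p^{(t)}} - \hat\gamma^{(t)})^\top(\hat p^{(t)} - p_0)\|_{2,2}^2\Big).$$

⁷ We apply the lemma with $\mathcal{L}g = g$ for $g \in \mathcal{G}(\hat p^{(t)}, \hat\gamma^{(t)})$ and $g^* = 0$. Then we instantiate the concentration inequality with $g = \ell_{t,\hat f} - \ell_{t,f_0} \in \mathcal{G}(\hat p^{(t)}, \hat\gamma^{(t)})$.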
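Plugging the bound above into this averaged inequality and applying the AM-GM inequality and Jensen's inequality yields

$$\frac{\sigma}{8}\|\hat f - f_0\|_{2,2}^2 \le O\Big(H\delta_{n/B,\zeta/B}^2 + \frac{1}{\sigma}\delta_{n/B,\zeta/B}^2C^2H^2\Big(\|\mu\|_4^2 + \frac{1}{B}\sum_{t=1}^B\mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4\big]^{1/2}\Big)\Big) + \frac{1}{\sigma}\,\frac{1}{B}\sum_{t=1}^B\|(\gamma_{f_0,\hat p^{(t)}} - \hat\gamma^{(t)})^\top(\hat p^{(t)} - p_0)\|_{2,2}^2.$$

H PROOF OF THM. 6: BIASED SGD ANALYSIS

Below, for any integer $s$, we define the operator norm of any vector $v \in \mathbb{R}^s$ and any tensor $T$ operating on $\mathbb{R}^s$ as $\|v\|_{op} \triangleq \|v\|_2$ and $\|T\|_{op} \triangleq \sup_{v:\|v\|_2 = 1}\|T[v]\|_{op}$. Recall the definition

$$\Delta(W;\theta,p,\gamma) = \nabla_\theta f_\theta(X)^\top\nabla_\phi\ell_\gamma(W; f_\theta(X), p(X)) = \nabla_\theta f_\theta(X)^\top\big(\nabla_\phi\ell(W; f_\theta(X), p(X)) + \gamma(X)^\top(Y - p(X))\big).$$

Observe that, since $\mathbb{E}[Y \mid X = x] = p_0(x)$, we can write, for any $\gamma$,

$$L(\theta; p_0) = \mathbb{E}[\ell(W; f_\theta(X), p_0(X)) + (Y - p_0(X))^\top\gamma(X)f_\theta(X)] = \mathbb{E}[\ell_\gamma(W; f_\theta(X), p_0(X))].$$

Thus we also have $\nabla_\theta L(\theta; p_0) = \mathbb{E}[\Delta(W;\theta,p_0,\gamma)]$ for all $\theta$ and $\gamma$. Given this observation, we can decompose the gradient used in our SGD algorithm into a bias and a variance component. This decomposition views our algorithm as a biased SGD algorithm for the population oracle loss:

$$\Delta(W;\theta,p,\gamma) = \nabla_\theta L(\theta;p_0) + \underbrace{\mathbb{E}[\Delta(W;\theta,p,\gamma)] - \mathbb{E}[\Delta(W;\theta,p_0,\gamma)]}_{b(\theta,p,\gamma)} + \underbrace{\Delta(W;\theta,p,\gamma) - \mathbb{E}[\Delta(W;\theta,p,\gamma)]}_{n(W;\theta,p,\gamma)}.$$

The following two lemmas bound the gradient bias and noise terms.

Lemma 7 (Gradient bias). If $\sup_{x,\phi,\pi}\|\mathbb{E}[\nabla_{\pi\pi\phi}\ell(W;\phi,\pi) \mid X = x]\|_{op} \le M$, then for any parameter vector $\theta$ and functions $p$ and $\gamma$, we have

$$b(\theta,p,\gamma) = \mathbb{E}[\nabla_\theta f_\theta(X)^\top(\gamma_{f_\theta,p}(X) - \gamma(X))^\top(p(X) - p_0(X))],$$

$$\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2}, \quad \text{and}$$

$$\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(q_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2} + \frac{M}{2}\|\nabla_\theta f_\theta\|_{F,2}\|p - p_0\|_{2,4}^2.$$

Proof. By Taylor's theorem with integral remainder and with Lagrange remainder, respectively, the SGD bias for each parameter $i$ takes the form

$$\begin{aligned} b_i(\theta,p,\gamma) &= \mathbb{E}[\Delta_i(W;\theta,p,\gamma)] - \mathbb{E}[\Delta_i(W;\theta,p_0,\gamma)] \\ &= \mathbb{E}[\langle\nabla_\phi\ell_\gamma(W; f_\theta(X), p(X)) - \nabla_\phi\ell_\gamma(W; f_\theta(X), p_0(X)), \nabla_{\theta_i}f_\theta(X)\rangle] \\ &= \mathbb{E}[(p(X) - p_0(X))^\top(\gamma_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i}f_\theta(X)] \\ &= \mathbb{E}[(p(X) - p_0(X))^\top(q_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i}f_\theta(X)] \\ &\quad - \frac{1}{2}\mathbb{E}\big[\nabla_{\pi\pi\phi}\ell(W; f_\theta(X), \bar p(X))[\nabla_{\theta_i}f_\theta(X), p(X) - p_0(X), p(X) - p_0(X)]\big] \end{aligned}$$

for some $\bar p(X)$ on the segment between $p(X)$ and $p_0(X)$. Furthermore, our operator norm assumption and Cauchy-Schwarz imply

$$\begin{aligned} |b_i(\theta,p,\gamma)| &\le |\mathbb{E}[(p(X) - p_0(X))^\top(q_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i}f_\theta(X)]| + \frac{M}{2}\mathbb{E}\big[\|\nabla_{\theta_i}f_\theta(X)\|_2\,\|p(X) - p_0(X)\|_2^2\big] \\ &\le |\mathbb{E}[(p(X) - p_0(X))^\top(q_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i}f_\theta(X)]| + \frac{M}{2}\|\nabla_{\theta_i}f_\theta\|_{2,2}\|p - p_0\|_{2,4}^2. \end{aligned}$$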
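Thus, by the triangle inequality and Jensen's inequality, we find that $\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2}$ and $\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(q_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2} + \frac{M}{2}\|\nabla_\theta f_\theta\|_{F,2}\|p - p_0\|_{2,4}^2$.

Lemma 8 (Gradient variance). Define $\sigma_0(\theta)^2 \triangleq \sum_{i\in[d]}\mathrm{Var}[\nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X))]$. For any parameter $\theta$ and functions $p$ and $\gamma$,

$$\sqrt{\mathbb{E}[\|n(W;\theta,p,\gamma)\|_2^2]} \le \sigma_0(\theta) + \sqrt{\mathbb{E}[\|\nabla_\theta f_\theta(X)^\top\gamma(X)^\top(Y - p_0(X))\|_2^2] + \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2}^2}.$$

Proof. For each $i \in [d]$, define the shorthand

$$\xi_i = \nabla_{\theta_i}\ell(W; f_\theta(X), p(X)) + \nabla_{\theta_i}f_\theta(X)^\top\gamma(X)^\top(Y - p(X)) - \nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X))$$

and

$$\begin{aligned} Z_i = \mathbb{E}[\xi_i \mid X] &= \nabla_{\theta_i}f_\theta(X)^\top(\gamma(X) - \gamma_{f_\theta,p}(X))^\top(p_0(X) - p(X)) \\ &= \nabla_{\theta_i}f_\theta(X)^\top(\gamma(X) - q_{f_\theta,p}(X))^\top(p_0(X) - p(X)) \\ &\quad - \frac{1}{2}\mathbb{E}\big[\nabla_{\pi\pi\phi}\ell(W; f_\theta(X), \bar p(X))[\nabla_{\theta_i}f_\theta(X), p(X) - p_0(X), p(X) - p_0(X)] \mid X\big] \end{aligned}$$

for some convex combination $\bar p(X)$ of $p(X)$ and $p_0(X)$. We begin by bounding the target expectation using Cauchy-Schwarz:

$$\begin{aligned} \mathbb{E}[\|n(W;\theta,p,\gamma)\|_2^2] &= \sum_{i\in[d]}\mathrm{Var}[\nabla_{\theta_i}\ell(W; f_\theta(X), p(X)) + \nabla_{\theta_i}f_\theta(X)^\top\gamma(X)^\top(Y - p(X))] \\ &= \sum_{i\in[d]}\mathrm{Var}[\nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X)) + \xi_i] \\ &= \sigma_0(\theta)^2 + \sum_{i\in[d]}\big(\mathrm{Var}[\xi_i] + 2\,\mathrm{Cov}(\nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X)), \xi_i)\big) \\ &\le \sigma_0(\theta)^2 + \sum_{i\in[d]}\Big(\mathrm{Var}[\xi_i] + 2\sqrt{\mathrm{Var}[\nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X))]\,\mathrm{Var}[\xi_i]}\Big) \\ &\le \sigma_0(\theta)^2 + \sum_{i\in[d]}\mathrm{Var}[\xi_i] + 2\sqrt{\sum_{i\in[d]}\mathrm{Var}[\nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X))]\,\sum_{i\in[d]}\mathrm{Var}[\xi_i]} \\ &= \Big(\sigma_0(\theta) + \sqrt{\textstyle\sum_{i\in[d]}\mathrm{Var}[\xi_i]}\Big)^2. \end{aligned}$$

We next employ the law of total variance to rewrite the variance terms:

$$\sum_{i\in[d]}\mathrm{Var}[\xi_i] = \sum_{i\in[d]}\mathrm{Var}[Z_i + \nabla_{\theta_i}f_\theta(X)^\top\gamma(X)^\top(Y - p_0(X))] = \mathbb{E}[\|\nabla_\theta f_\theta(X)^\top\gamma(X)^\top(Y - p_0(X))\|_2^2] + \sum_{i\in[d]}\mathrm{Var}[Z_i].$$

Finally, we control $\mathrm{Var}[Z_i]$ using Cauchy-Schwarz:

$$\sqrt{\sum_{i\in[d]}\mathrm{Var}[Z_i]} \le \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2}.$$

The two claims of Thm. 6 now follow from Theorems 2 and 3 of Ajalloeian & Stich (2020), respectively. The parameters $\sigma^2$ and $\zeta$ there are instantiated with the quantities $\sigma^2(\gamma)$ and $\zeta(\gamma)$ of Lemmas 7 and 8.

I EXPERIMENT DETAILS AND ADDITIONAL RESULTS

I.1 TABULAR DATA

We use cross-fitting with 10 folds. The student is trained using the SEL loss with clipped teacher class probabilities $\max(\hat p(x), \epsilon)$ for $\epsilon = 10^{-3}$. The α hyperparameter of the loss correction was chosen by cross-validation with 5 folds. We repeat the experiments 5 times to measure the mean and standard deviation.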
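The cross-fitting and clipping steps just described can be sketched in NumPy. The data are split into 10 folds. For each fold, a teacher is fit on the other 9 folds, and its class probabilities on the held-out fold are clipped below at $\epsilon = 10^{-3}$. This is a minimal sketch under stated assumptions: the `fit_teacher` interface (fit on training data, return a prediction function) and the toy `memorizing_teacher` are hypothetical stand-ins for the random forest teachers. The toy teacher mimics an overfitting teacher that reproduces its training labels exactly.

```python
import numpy as np


def cross_fit_teacher_probs(X, Y, fit_teacher, n_folds=10, eps=1e-3, seed=0):
    """Out-of-fold teacher probabilities, clipped below at eps.

    X: (n, d) features; Y: (n, k) one-hot labels.
    fit_teacher(X_train, Y_train) must return a function mapping (m, d) -> (m, k).
    Each point's probabilities come from a teacher that never saw that point.
    """
    n = X.shape[0]
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n), n_folds)
    p_hat = np.empty(Y.shape, dtype=float)
    for held_out in folds:
        train = np.ones(n, dtype=bool)
        train[held_out] = False
        predict = fit_teacher(X[train], Y[train])
        p_hat[held_out] = predict(X[held_out])
    return np.maximum(p_hat, eps)


def memorizing_teacher(X_train, Y_train):
    """Hypothetical teacher that interpolates its training labels and
    falls back to the training class frequencies on unseen points."""
    lookup = {tuple(x): y for x, y in zip(X_train, Y_train)}
    prior = Y_train.mean(axis=0)
    return lambda X_new: np.array([lookup.get(tuple(x), prior) for x in X_new])


# Toy data: 20 distinct points with alternating binary labels.
X = np.arange(20, dtype=float).reshape(-1, 1)
Y = np.eye(2)[np.arange(20) % 2]
in_sample = memorizing_teacher(X, Y)(X)  # reproduces Y exactly (overfit)
out_of_fold = cross_fit_teacher_probs(X, Y, memorizing_teacher)
```

On its own training points, this teacher returns exactly the one-hot labels, so a student distilled from those in-sample probabilities sees no information beyond the labels. The out-of-fold probabilities avoid this data reuse, and the stitched-together $\hat p$ can then be used wherever the student needs teacher probabilities.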
For the overfitting experiment, we use a random forest with 500 trees as the teacher and a random forest with 1 to 40 trees as the student. We also evaluate the impact of teacher underfitting by limiting the teacher's maximum tree depth (from 1 to 20); lower depth corresponds to greater underfitting. In this experiment, the teacher has 100 trees and the student has 10 trees. For all of the datasets, loss correction successfully mitigates the teacher's underfitting and thus improves the student's performance. The effect is most pronounced when the teacher underfits more heavily (i.e., has lower tree depth).

We show the full results for all 5 of the datasets in Figs. 4 and 5.

[Figure 4 panels: (a) Adult, (b) FICO, (c) Higgs, (d) MAGIC, and (e) StumbleUpon datasets. Each panel plots performance (test accuracy in panels (c) and (e)) against the student's number of trees (1 to 40). Curves: Student without KD, KD Student, Cross-fit KD Student, Enhanced KD Student; Teacher: 500 trees.]
Figure 4: Tabular random forest distillation with varying student complexity.
[Figure 5 panels: (a) Adult, (b) FICO, (c) Higgs, (d) MAGIC, and (e) StumbleUpon datasets. Each panel plots performance (test accuracy in panels (c) and (e)) against the teacher's max tree depth (1 to 20). Curves: Student without KD, KD Student (10 trees), Cross-fit KD Student, Enhanced KD Student; Teacher: 100 trees.]
Figure 5: Tabular random forest distillation with varying teacher complexity.

I.2 IMAGE DATA (CIFAR-10)

We use SGD with initial learning rate 0.1, momentum 0.9, and batch size 128 to train for 200 epochs. We use the standard learning rate decay schedule, in which the learning rate is divided by 5 at epochs 60, 120, and 160. For loss correction, we select the value of the hyperparameter α that yields the highest accuracy on a held-out validation set. For cross-fitting, we use 10 folds.
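The step decay schedule described above can be written as a small function. This sketch assumes 0-indexed epochs and assumes that each division takes effect from the listed epoch onward; the helper name `learning_rate` is hypothetical.

```python
def learning_rate(epoch, base_lr=0.1, milestones=(60, 120, 160), factor=5.0):
    """Step decay: divide base_lr by `factor` once for every milestone reached."""
    n_decays = sum(epoch >= m for m in milestones)
    return base_lr / factor ** n_decays


for epoch in (0, 59, 60, 120, 160, 199):
    print(epoch, learning_rate(epoch))
```

In PyTorch, `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)` should give the same multiplicative step decay when stepped once per epoch, since multiplying by 0.2 is the same as dividing by 5.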