# A Statistical Perspective on Retrieval-Based Models

Soumya Basu*1, Ankit Singh Rawat*2, Manzil Zaheer*3

Abstract. Many modern high-performing machine learning models increasingly rely on scaling up models, e.g., transformer networks. Simultaneously, a parallel line of work aims to improve the model performance by augmenting an input instance with other (labeled) instances during inference. Examples of such augmentations include task-specific prompts and similar examples retrieved from the training data by a nonparametric component. Despite a growing literature showcasing the promise of these retrieval-based models, their theoretical underpinnings remain under-explored. In this paper, we present a formal treatment of retrieval-based models to characterize their performance via a novel statistical perspective. In particular, we study two broad classes of retrieval-based classification approaches. First, we analyze a local learning framework that employs an explicit local empirical risk minimization based on retrieved examples for each input instance. Interestingly, we show that breaking down the underlying learning task into local sub-tasks enables the model to employ a low-complexity parametric component to ensure good overall performance. The second class of retrieval-based approaches we explore learns a global model using kernel methods to directly map an input instance and retrieved examples to a prediction, without explicitly solving a local learning task.

*Equal contribution; in alphabetical order. 1Google, Mountain View, USA. 2Google Research, New York, USA. 3Google DeepMind, New York, USA. Correspondence to: Soumya Basu. Proceedings of the 40th International Conference on Machine Learning, Honolulu, Hawaii, USA. PMLR 202, 2023. Copyright 2023 by the author(s).

1. Introduction

As our world is complex, we need expressive machine learning (ML) models to make high-accuracy predictions on real-world problems. There are multiple ways to increase the expressiveness of an ML model. A popular one is to homogeneously scale the size of a parametric model, such as a neural network, which has been behind many recent high-performance models such as GPT-3 (Brown et al., 2020) and ViT (Dosovitskiy et al., 2021). Their performance (accuracy) exhibits a monotonic improvement with increasing model size, as demonstrated by scaling laws (Kaplan et al., 2020; Hoffmann et al., 2022). Such large models, however, have their own limitations, including high computation cost, catastrophic forgetting (they are hard to adapt to changing data), lack of provenance, and poor explainability. Classical instance-based models (Fix & Hodges, 1989), on the other hand, offer many desirable properties by design: efficient data structures, incremental learning (easy addition and deletion of knowledge), and some provenance for their predictions based on the nearest neighbors w.r.t. the input. However, these models often suffer from weaker empirical performance compared to deep parametric models. Increasingly, a middle ground combining the two paradigms and retaining the best of both worlds is becoming popular across various domains, ranging from natural language (Das et al., 2021; Wang et al., 2022; Liu et al., 2022; Izacard et al., 2022), to vision (Liu et al., 2015; 2019; Iscen et al., 2022; Long et al., 2022), to reinforcement learning (Blundell et al., 2016; Pritzel et al., 2017; Ritter et al., 2020), to even protein structure prediction (Cramer, 2021).
In such approaches, given a test input, one first retrieves relevant entries from a data index and then processes the retrieved entries along with the test input to make the final prediction using an ML model. This process is visualized in Fig. 1c. While classical learning setups (cf. Fig. 1a and 1b) have been studied extensively over decades, even basic properties and trade-offs pertaining to retrieval-based models (cf. Fig. 1c), despite their aforementioned remarkable successes, remain highly under-explored. Most of the existing efforts on retrieval-based models solely focus on developing end-to-end domain-specific models, without identifying the key dataset properties or structures that are critical in realizing performance gains by such models. Furthermore, at first glance, due to the highly dependent nature of an input and the associated retrieved set, a direct application of existing statistical learning techniques does not appear straightforward. This prompts a natural question: What is the right theoretical framework to rigorously showcase the value of the retrieved set in ensuring superior performance of modern retrieval-based models?

[Figure 1: (a) Parametric setup, (b) Instance-based learning setup, (c) Modern retrieval-based setup. An illustration of a retrieval-based classification model. Given an input instance $x$, similar to an instance-based model, it retrieves similar (labeled) examples $R_x = \{(x'_j, y'_j)\}_j$ from the training data. Subsequently, it processes the input instance along with the retrieved examples (potentially via a nonparametric method) to make the final prediction $\hat{y} = f(x, R_x)$.]

In this paper, we take the first step towards answering this question, while focusing on the classification setting (Sec. 2.1). We begin with the hypothesis that the model might be using the retrieved set to implicitly perform local learning and thereby adapt its predictions to the neighborhood of the test point. Multiple recent works (Garg et al., 2022; Akyürek et al., 2022; von Oswald et al., 2022) have studied the feasibility of such a mechanism in the widely popular Transformer models. Notably, these works show that a Transformer network can emulate gradient descent to optimize a local learning objective when presented with multiple labeled examples as inputs. Such local learning is potentially beneficial when the underlying task has a local structure, i.e., a much simpler function class suffices to explain the data in a given local neighborhood even though the data can be complex overall (formally defined in Sec. 2.2). For example, to solve an issue at hand (a problem instance), it is often faster to search for solutions to similar problems on Stack Overflow and utilize those (i.e., locally learn from the retrieved similar labeled examples) than to understand the whole system (i.e., learn the entire global function). Inspired by Bottou & Vapnik (1992), we analyze an explicit local learning framework: for each test input, 1) we retrieve a few (labeled) training examples located in the vicinity of the test input; 2) we train a local model by performing empirical risk minimization (ERM) with only these retrieved examples (local ERM); and 3) we apply the resulting local model to make a prediction on the test input.
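As a concrete illustration of this three-step procedure, the following is a minimal sketch using off-the-shelf components; the dataset, the k-nearest-neighbor retriever, and the logistic-regression local model are placeholder choices rather than the exact setup analyzed later.

```python
# Minimal sketch of retrieval-based local ERM: retrieve -> fit local model -> predict.
# The data, retriever, and local model below are illustrative placeholders.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X_train = rng.normal(size=(5000, 10))                                     # training inputs
y_train = (X_train[:, 0] + 0.1 * rng.normal(size=5000) > 0).astype(int)   # training labels

retriever = NearestNeighbors(n_neighbors=50).fit(X_train)  # index over the training data

def local_erm_predict(x):
    """Predict for a single test point x via local ERM on its retrieved set R_x."""
    # Step 1: retrieve labeled neighbors R_x of the test input.
    _, idx = retriever.kneighbors(x.reshape(1, -1))
    X_loc, y_loc = X_train[idx[0]], y_train[idx[0]]
    # Degenerate neighborhood (single class): fall back to the retrieved label.
    if len(np.unique(y_loc)) < 2:
        return y_loc[0]
    # Step 2: local ERM over a simple (low-complexity) function class F_loc.
    local_model = LogisticRegression(max_iter=200).fit(X_loc, y_loc)
    # Step 3: apply the local model only at the test input.
    return local_model.predict(x.reshape(1, -1))[0]

print(local_erm_predict(rng.normal(size=10)))
```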
In Sec. 3, for the aforementioned retrieval-based local ERM, we derive finite-sample generalization bounds that highlight a trade-off between the complexity of the underlying function class and the size of the neighborhood in which the local structure of the data distribution holds. Under this assumption of local regularity, we show that by using a much simpler function class for the local model, we can achieve a loss/error similar to that of a complex global model (Thm. 3.7). Thus, breaking down the underlying learning task into local sub-tasks enables the model to employ a low-complexity parametric component to ensure good global accuracy via a retrieval-based model.

We acknowledge that such local learning cannot be the complete picture behind the effectiveness of retrieval-based models. As noted in Zakai & Ritov (2008), there always exists a model with a global component that is preferable to a local-only model. In Sec. 3.4, we extend local ERM to a two-stage setup: first learn a global representation using the entire dataset, and then utilize the representation at test time while solving the local ERM as previously defined. This enables local learning to benefit from good-quality global representations, especially in sparse data regions.

Finally, we move beyond explicit local learning to a setting that more closely resembles empirically successful systems such as REINA (Wang et al., 2022), WebGPT (Nakano et al., 2021), and AlphaFold (Cramer, 2021): a model that directly learns to predict from the input instance and the associated retrieved similar examples end-to-end. Towards this, we take a preliminary step in Sec. 4 by studying a novel formulation of classification over an extended feature space (to account for the retrieved examples) by using kernel methods (Deshmukh et al., 2019).

To summarize, our main contributions include: 1) setting up a formal framework for classification via retrieval-based models under local structure; 2) a finite-sample analysis of an explicit local learning framework; 3) a comparison with simple parametric and nonparametric paradigms; 4) an extension of the analysis to a globally learnt model; and 5) the first rigorous treatment of an end-to-end retrieval-based model to study its generalization by using kernel-based learning.

2. Problem setup

We first provide a brief background on (multiclass) classification along with the necessary notation. Subsequently, we discuss the problem setup considered in this paper, which deals with designing retrieval-based classification models for data distributions with local structure.

2.1. Multiclass classification

In this work, we restrict ourselves to the (multiclass) classification setting, with access to $n$ training examples $S = \{(x_i, y_i)\}_{i \in [n]} \subset \mathcal{X} \times \mathcal{Y}$, sampled i.i.d. from the data distribution $D := D_{X,Y}$. Let $P_D(A) := \mathbb{E}_{(X,Y)\sim D}\,\mathbf{1}\{A\}$ for any event $A$. Given $S$, one is interested in learning a classifier $h : \mathcal{X} \to \mathcal{Y}$ that minimizes the misclassification error. It is common to define a classifier via a scorer $f : x \mapsto \big(f_1(x), \ldots, f_{|\mathcal{Y}|}(x)\big) \in \mathbb{R}^{|\mathcal{Y}|}$ that assigns a score to each class in $\mathcal{Y}$ for an instance $x$. For a scorer $f$, the corresponding classifier takes the form $h_f(x) = \arg\max_{y \in \mathcal{Y}} f_y(x)$. Given a set of scorers $\mathcal{F}_{\mathrm{global}} \subseteq \{f : \mathcal{X} \to \mathbb{R}^{|\mathcal{Y}|}\}$, learning a model amounts to finding a scorer in $\mathcal{F}_{\mathrm{global}}$ that minimizes the misclassification error or expected 0/1-loss:

$$f^{0/1} = \arg\min_{f \in \mathcal{F}_{\mathrm{global}}} P_D\big(h_f(X) \neq Y\big). \qquad (1)$$
One typically employs a surrogate loss $\ell$ (Bartlett et al., 2006) for the misclassification error $\mathbf{1}\{h_f(X) \neq Y\}$ and aims to minimize the associated population risk $R_\ell(f) = \mathbb{E}_{(X,Y)\sim D}\,\ell\big(f(X), Y\big)$. Since the underlying data distribution $D$ is only accessible via the examples in $S$, one learns a good scorer by minimizing the (global) empirical risk over the function class $\mathcal{F}_{\mathrm{global}}$ as follows:

$$\hat{f} = \arg\min_{f \in \mathcal{F}_{\mathrm{global}}} \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i), y_i\big). \qquad (2)$$

We denote $\widehat{R}_\ell(f) := \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i), y_i\big)$.

2.2. Classification with local structure

In this work, we assume that the underlying data distribution $D$ has a local structure, where a much simpler (parametric) function class suffices to explain the data in each local neighborhood. Formally, for $x \in \mathcal{X}$ and $r > 0$, we define $B_{x,r} := \{x' \in \mathcal{X} : d(x, x') \le r\}$, an $r$-radius ball around $x$, w.r.t. a metric $d : \mathcal{X} \times \mathcal{X} \to \mathbb{R}$. Let $D_{x,r}$ be the data distribution restricted to $B_{x,r}$, i.e.,

$$D_{x,r}(A) = D(A)\,/\,D\big(B_{x,r} \times \mathcal{Y}\big) \quad \forall\, A \subseteq B_{x,r} \times \mathcal{Y}. \qquad (3)$$

Further, let us define the local population risk of a function $f$ at a given instance $x \in \mathcal{X}$: $R^x_\ell(f) = \mathbb{E}_{(X',Y')\sim D_{x,r}}\,\ell\big(f(X'), Y'\big)$. Now, the local structure condition on the data distribution ensures that, for each $x \in \mathcal{X}$, there exists a low-complexity function class $\mathcal{F}_x$, with $|\mathcal{F}_x| \ll |\mathcal{F}_{\mathrm{global}}|$, that approximates the Bayes optimal (w.r.t. $\mathcal{F}_{\mathrm{global}}$) for the local classification problem defined by $D_{x,r}$. That is, for a given $\varepsilon_X > 0$ and $x \in \mathcal{X}$, we have that¹

$$\min_{f \in \mathcal{F}_x} R^x_\ell(f) \le \min_{f \in \mathcal{F}_{\mathrm{global}}} R^x_\ell(f) + \varepsilon_X. \qquad (4)$$

As an example, if $\mathcal{F}_{\mathrm{global}}$ is linear in $\mathbb{R}^d$ (possibly dense) with bounded norm $\tau$, then $\mathcal{F}_x$ can be a simpler function class such as linear in $\mathbb{R}^d$ with sparsity $k \ll d$ and with bounded norm $\tau_x \le \tau$.

2.3. Retrieval-based classification model

This work focuses on retrieval-based methods that can leverage the aforementioned local structure of the data distribution. In particular, we focus on two such approaches.

Local empirical risk minimization. Given a (test) instance $x$, the local empirical risk minimization (ERM) approach first retrieves a neighboring set $R_x = \{(x'_j, y'_j)\} \subseteq S$. Subsequently, it identifies a (local) scorer $\hat{f}^x$ from a simple function class $\mathcal{F}_{\mathrm{loc}} \subseteq \{f : \mathcal{X} \to \mathbb{R}^{|\mathcal{Y}|}\}$ as follows:

$$\hat{f}^x = \arg\min_{f \in \mathcal{F}_{\mathrm{loc}}} \frac{1}{|R_x|} \sum_{(x', y') \in R_x} \ell\big(f(x'), y'\big). \qquad (5)$$

By convention, if $|R_x| = 0$, then $\hat{f}^x \in \mathcal{F}_{\mathrm{loc}}$ is chosen arbitrarily. Note that the local ERM approach requires solving a local learning task for each test instance. Such a local learning algorithm was introduced by Bottou & Vapnik (1992). Another point worth mentioning here is that (5) employs the same function class $\mathcal{F}_{\mathrm{loc}}$ for each $x$, whereas the local structure assumption (cf. (4)) allows for an instance-dependent function class $\mathcal{F}_x$. We consider an $\mathcal{F}_{\mathrm{loc}}$ that approximates $\cup_{x \in \mathcal{X}} \mathcal{F}_x$ closely. In particular, we assume that, for some $\varepsilon_{\mathrm{loc}} > 0$, we have for all $x \in \mathcal{X}$ that

$$\min_{f \in \mathcal{F}_{\mathrm{loc}}} R^x_\ell(f) \le \min_{f \in \mathcal{F}_x} R^x_\ell(f) + \varepsilon_{\mathrm{loc}}. \qquad (6)$$

Continuing with the example following (4), where $\mathcal{F}_x$ is linear with sparsity $k \ll d$ and bounded norm $\tau_x$, one can take $\mathcal{F}_{\mathrm{loc}}$ to be linear with the same sparsity $k$ and norm bounded by $\sup_{x \in \mathcal{X}} \tau_x$.

Classification with extended feature space. We also consider the setting where the scorer $f$ can implicitly solve the local ERM using the retrieved neighboring labeled instances to make the classification prediction. In other words, the scorer directly maps the augmented input $(x, R_x)$ — a test instance together with a finite collection of labeled examples from $\mathcal{X} \times \mathcal{Y}$ — to per-class scores.
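As an illustration of what such an augmented input can look like, the sketch below flattens $(x, R_x)$ into a single vector by concatenating the test instance with the retrieved (example, one-hot label) pairs; this particular encoding is only one illustrative choice (in practice, $R_x$ is often rendered as a textual prompt for a Transformer instead).

```python
# Sketch: assembling the augmented input (x, R_x) for an extended-feature scorer.
# Concatenating retrieved (example, one-hot label) pairs is one illustrative encoding.
import numpy as np

def augment_input(x, retrieved_X, retrieved_y, num_classes):
    """Flatten (x, R_x) into a single feature vector over X x (X x Y)^k."""
    one_hot = np.eye(num_classes)[retrieved_y]              # labels of retrieved examples
    pairs = np.concatenate([retrieved_X, one_hot], axis=1)  # each row is (x'_j, y'_j)
    return np.concatenate([x, pairs.ravel()])               # x followed by retrieved pairs

x = np.zeros(10)
R_X = np.ones((5, 10))           # 5 retrieved inputs (placeholder values)
R_y = np.array([0, 1, 1, 0, 1])  # their labels
print(augment_input(x, R_X, R_y, num_classes=2).shape)  # (10 + 5 * (10 + 2),) = (70,)
```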
One can learn such a scorer over the extended feature space $\mathcal{X} \times (\mathcal{X} \times \mathcal{Y})^{*}$ (an instance paired with a finite set of labeled examples) as follows:

$$\hat{f}^{\mathrm{ex}} = \arg\min_{f \in \mathcal{F}_{\mathrm{ex}}} \widehat{R}^{\mathrm{ex}}_\ell(f), \qquad (7)$$

where $\widehat{R}^{\mathrm{ex}}_\ell(f) := \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i, R_{x_i}), y_i\big)$ and the function class of interest over the extended space is denoted by $\mathcal{F}_{\mathrm{ex}} \subseteq \big\{f : \mathcal{X} \times (\mathcal{X} \times \mathcal{Y})^{*} \to \mathbb{R}^{|\mathcal{Y}|}\big\}$. Examples of such a function class include prompting Transformers with the retrieved labeled examples. Moreover, it has recently been shown that a Transformer can express certain algorithms for optimizing a local learning objective, based on the examples in the prompt, using gradient descent (Garg et al., 2022; Akyürek et al., 2022; von Oswald et al., 2022).

¹As stated, we require the local structure condition to hold for each $x$. This can be relaxed to hold with high probability, at the cost of increased complexity of exposition.

Our goal is to develop a statistical understanding of these two retrieval-based methods for classification when the underlying data distribution has local structure. We present our theoretical treatment of local ERM and of classification with the extended feature space in Sec. 3 and 4, respectively.

3. Local empirical risk minimization

In this section, our objective is to characterize the excess risk of local ERM. In particular, we aim to bound

$$\mathbb{E}_{(X,Y)\sim D}\Big[\ell\big(\hat{f}^X(X), Y\big) - \ell\big(f^*(X), Y\big)\Big]. \qquad (8)$$

Note that $\hat{f}^X$ (cf. (5)) in the above equation is a function of $R_X$, and the expectation over $R_X$ is taken implicitly.

3.1. Assumptions

Before presenting an excess risk bound for the local ERM method, we introduce various definitions and assumptions that play a critical role in our analysis. We define the margin of a scorer $f$ at a given label $y \in \mathcal{Y}$ as

$$\gamma_f(x, y) = f_y(x) - \max_{y' \neq y} f_{y'}(x). \qquad (9)$$

In order to ensure that the margin of the scorer $f$ deviates smoothly as $x$ varies, we introduce the $L$-coordinate Lipschitz condition: a scorer $f$ is $L$-coordinate Lipschitz iff for all $y \in \mathcal{Y}$ and $x, x' \in \mathcal{X}$, we have

$$|f_y(x) - f_y(x')| \le L \,\|x - x'\|_2. \qquad (10)$$

Following Döring et al. (2018), we define the weak margin condition for a scorer $f$: given a distribution $D$, a scorer $f$ satisfies the $(\alpha, c)$-weak margin condition iff, for all $t \ge 0$,

$$P_{(X,Y)\sim D}\big(|\gamma_f(X, Y)| \le t\big) \le c\, t^{\alpha}. \qquad (11)$$

One of the key assumptions that we rely on is the existence of an underlying scorer $f_{\mathrm{true}}$ that explains the true labels while ensuring the weak margin condition. Here, we note that the true function $f_{\mathrm{true}}$ may lie neither in the function class $\mathcal{F}_{\mathrm{global}}$ nor in $\mathcal{F}_{\mathrm{loc}}$.

Assumption 3.1 (True scorer function). There exists a scorer $f_{\mathrm{true}}$ such that, for all $(x, y) \in \mathcal{X} \times \mathcal{Y}$, $f_{\mathrm{true}}$ generates the true label, i.e., $\gamma_{f_{\mathrm{true}}}(x, y) > 0$. Furthermore, we assume $f_{\mathrm{true}}$ is $L_{\mathrm{true}}$-coordinate Lipschitz and satisfies the $(\alpha_{\mathrm{true}}, c_{\mathrm{true}})$-weak margin condition.

Furthermore, we restrict ourselves to smooth loss functions that act on the margin of a scorer (cf. (9)).

Assumption 3.2 (Margin-based Lipschitz loss). For any given example $(x, y)$ and any scorer $f$, we have $\ell(f(x), y) = \ell\big(\gamma_f(x, y)\big)$, and $\ell$ is a decreasing function of the margin. Furthermore, the loss function $\ell$ is $L_\ell$-Lipschitz, i.e., $|\ell(\gamma) - \ell(\gamma')| \le L_\ell\, |\gamma - \gamma'|$ for all $\gamma, \gamma'$.
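As a small numerical illustration of the margin in (9) and the weak margin condition in (11), the sketch below computes $\gamma_f$ for a placeholder linear scorer on synthetic Gaussian inputs and reports the empirical mass $P(|\gamma_f(X,Y)| \le t)$ for a few thresholds $t$; the scorer, data, and labeling rule are illustrative assumptions, not the setting analyzed below.

```python
# Sketch: margin gamma_f (Eq. 9) of a placeholder linear scorer, and an empirical
# look at the weak margin condition (Eq. 11): P(|gamma_f(X, Y)| <= t) <= c * t^alpha.
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 10))                  # linear scorer with 3 classes: f_y(x) = <W[y], x>

def margin(scores, y):
    """gamma_f(x, y) = f_y(x) - max_{y' != y} f_{y'}(x), computed row-wise."""
    others = scores.copy()
    others[np.arange(len(y)), y] = -np.inf    # mask out the scored label
    return scores[np.arange(len(y)), y] - others.max(axis=1)

X = rng.normal(size=(20000, 10))
scores = X @ W.T                              # (20000, 3) class scores
y = scores.argmax(axis=1)                     # treat the scorer itself as the labeler (Assumption 3.1-style)
gam = margin(scores, y)

for t in [0.05, 0.1, 0.2, 0.4]:
    # Under an (alpha, c)-weak margin condition this fraction is at most c * t^alpha.
    print(f"t={t}: empirical P(|gamma| <= t) = {np.mean(np.abs(gam) <= t):.4f}")
```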
Recall that $R_x$ corresponds to the samples in $S$ that belong to $B_{x,r}$; hence, it follows the distribution $D_{x,r}$. For the rest of the paper, we limit ourselves to $\mathcal{X} \subseteq \mathbb{R}^d$; the results can be extended to more general metric spaces at the cost of a more involved exposition. Let the density of the distribution of $x \in \mathcal{X} \subseteq \mathbb{R}^d$ be $\rho_D(x)$. A common assumption in the nonparametric estimation literature is the weak density condition (see, e.g., Döring et al., 2018). Moreover, we need to ensure that, with high probability, the density $\rho_D(x)$ is not too low. We do so following the idea of density level sets from Steinwart (2011). Accordingly, we make the following assumption.

Assumption 3.3 (Data regularity condition).
1. (Weak density condition) There exist constants $c_{\mathrm{wdc}} > 0$ and $\delta_{\mathrm{wdc}} > 0$ such that, for all $x \in \mathcal{X}$ with $\rho_D(x)\, r^d \le \delta_{\mathrm{wdc}}^d$, we have $P_{X'\sim D}\big[d(X', x) \le r\big] \ge c_{\mathrm{wdc}}^d\, \rho_D(x)\, r^d$.
2. (Density level-set) There exists a function $f_\rho(\delta)$ with $f_\rho(\delta) \to 0$ as $\delta \to 0$, such that for any $\delta > 0$,

$$P_{X\sim D}\big[\rho_D(X) \le f_\rho(\delta)\big] \le \delta. \qquad (12)$$

For example, for a $d$-dimensional multivariate Gaussian with covariance matrix $\Sigma$, we have $f_\rho(\delta) = \Theta\big(2^{-d/2}\,|\Sigma|^{-1/2}\,\delta\,\ln(1/\delta)^{-d/2}\big)$. This result can be extended to mixtures of Gaussian and sub-Gaussian random variables (see Appendix B.6 for details).

Assumption 3.4 (Weak+ density condition). There exist constants $c_{\mathrm{wdc}+} \ge 0$ and $\alpha_{\mathrm{wdc}+} > 0$ such that, for all $x \in \mathcal{X}$ and $r \in [0, r_{\max}]$,

$$P_{X'\sim D}\big[d(X', x) \le r\big] \ge \rho_D(x)\,\mathrm{vol}_d(r)\,\big(1 - c_{\mathrm{wdc}+}\, r^{\alpha_{\mathrm{wdc}+}}\big).$$

The above assumption implies Assumption 3.3.1. We will show that, under Assumption 3.4, the local ERM error bounds can be tightened further. For example, for a $d$-dimensional multivariate Gaussian with covariance matrix $\Sigma$, we have $c_{\mathrm{wdc}+} = \frac{d\,\lambda_{\max}(\Sigma^{-1})}{2(d+2)}$ and $\alpha_{\mathrm{wdc}+} = 2$ for $r_{\max} = \sqrt{(d+2)\,\lambda_{\max}(\Sigma)}$, where $\lambda_{\max}(\cdot)$ denotes the maximum eigenvalue.

3.2. Excess risk bound for local ERM

We now proceed to our main results on the excess risk bound of local ERM. Recall that, at $x \in \mathcal{X}$, $f^{x,*}$ denotes the minimizer of the population version of the local loss, and $f^*$ the population risk minimizer for the global loss, i.e.,

$$f^{x,*} = \arg\min_{f \in \mathcal{F}_{\mathrm{loc}}} R^x_\ell(f); \qquad f^* = \arg\min_{f \in \mathcal{F}_{\mathrm{global}}} R_\ell(f). \qquad (13)$$

To bound the excess risk defined in (8), we first obtain the following upper bound. For $x \in \mathcal{X}$, let $\widehat{R}^x_\ell(f) := \frac{1}{|R_x|}\sum_{(x',y')\in R_x} \ell\big(f(x'), y'\big)$ denote the empirical risk over the retrieved set $R_x$.

Lemma 3.5 (Risk decomposition). The expected excess risk of the local ERM solution $\hat{f}^X$ is bounded as

$$
\mathbb{E}_{(X,Y)\sim D}\Big[\ell\big(\hat{f}^X(X), Y\big) - \ell\big(f^*(X), Y\big)\Big]
\le \underbrace{\mathbb{E}_{(X,Y)\sim D}\Big[R^X_\ell(f^{X,*}) - R^X_\ell(f^*)\Big]}_{\text{Local vs. Global Population Optimal Risk}}
+ \sum_{\mathcal{F} \in \{\mathcal{F}_{\mathrm{global}}, \mathcal{F}_{\mathrm{loc}}\}} \underbrace{\mathbb{E}_{(X,Y)\sim D}\Big[\sup_{f \in \mathcal{F}} \big|R^X_\ell(f) - \ell\big(f(X), Y\big)\big|\Big]}_{\text{Global and Local: Sample vs. Retrieved Set Risk}}
+ \underbrace{\mathbb{E}_{(X,Y)\sim D}\Big[\sup_{f \in \mathcal{F}_{\mathrm{loc}}} \big(R^X_\ell(f) - \widehat{R}^X_\ell(f)\big)\Big]}_{\text{Generalization of Local ERM}}
+ \underbrace{\mathbb{E}_{(X,Y)\sim D}\Big[\widehat{R}^X_\ell(f^{X,*}) - R^X_\ell(f^{X,*})\Big]}_{\text{Central Absolute Moment of } f^{X,*}}.
$$

We delegate the proof of Lem. 3.5 to Appendix B. As a strategy to obtain the desired excess risk bounds, we separately bound the four terms appearing in Lem. 3.5. The first term captures the expected difference between the loss incurred by the global population optimum $f^* \in \mathcal{F}_{\mathrm{global}}$ and the local population optimum $f^{x,*} \in \mathcal{F}_{\mathrm{loc}}$ in a local region around the test instance $x$. The second term captures the gap between the loss of a scorer evaluated at $x$ and the expected loss of the scorer at a random instance sampled from the local region of $x$ according to $D_{x,r}$. The third term corresponds to the standard generalization error for the local ERM with respect to the local data distribution $D_{X,r}$, whereas the fourth term is the empirical deviation of the local population optimum $f^{X,*}$ around its population mean under $D_{X,r}$.

Let the coordinate-Lipschitz constants for scorers in $\mathcal{F}_{\mathrm{loc}}$ and $\mathcal{F}_{\mathrm{global}}$ be $L_{\mathrm{loc}}$ and $L_{\mathrm{global}}$, respectively. We define the function class $G(X, Y) := \{(x', y') \mapsto \ell(\gamma_f(x', y')) - \ell(\gamma_f(X, Y)) : f \in \mathcal{F}_{\mathrm{loc}}\}$. Here, by subtracting $\ell\big(f(X), Y\big)$ from the loss, we center the losses on $R_X$ for any function $f \in \mathcal{F}_{\mathrm{loc}}$, and obtain a tighter bound by utilizing the local structure of the distribution $D_{X,r}$.
For any $L > 0$ and for notational convenience, let us define

$$M_r(L; \ell, f_{\mathrm{true}}, \mathcal{F}) := 2L_\ell\Big[Lr + \big(2\|\mathcal{F}\|_\infty - Lr\big)\, c_{\mathrm{true}}\,\big(2L_{\mathrm{true}}\, r\big)^{\alpha_{\mathrm{true}}}\Big], \qquad (14)$$

where $\|\mathcal{F}\|_\infty := \sup_{x \in \mathcal{X}} \sup_{f \in \mathcal{F}} \|f(x)\|_\infty$. For any $x \in \mathcal{X}$, the weak density condition provides a high-probability lower bound on the size of the retrieved set $R_x$.

Proposition 3.6. Under Assumption 3.3, for any $x \in \mathcal{X}$, radius $r > 0$, and $\delta > 0$,

$$P_D\big(|R_x| < N(r, \delta)\big) \le \delta, \qquad (15)$$

where $N(r, \delta)$ is of order $n\, c_{\mathrm{wdc}}^d \min\big\{f_\rho(\delta/2)\, r^d,\ \delta_{\mathrm{wdc}}^d\big\}$, up to a square-root concentration correction.

Now, by controlling the different terms appearing in the bound of Lem. 3.5, we obtain the following.

Theorem 3.7 (Excess risk bound). Let (4) and (6) and Assumptions 3.1, 3.2, and 3.3 hold. For any $\delta > 0$, with $N(r, \delta)$ as defined in Proposition 3.6, the expected excess risk of the local ERM solution $\hat{f}^X$ is bounded as

$$
\mathbb{E}_{(X,Y)\sim D}\Big[\ell\big(\hat{f}^X(X), Y\big) - \ell\big(f^*(X), Y\big)\Big]
\le \underbrace{(\varepsilon_X + \varepsilon_{\mathrm{loc}})}_{\text{Local vs. Global Optimal Loss (I)}}
+ \underbrace{M_r(L_{\mathrm{loc}}; \ell, f_{\mathrm{true}}, \mathcal{F}_{\mathrm{loc}}) + M_r(L_{\mathrm{global}}; \ell, f_{\mathrm{true}}, \mathcal{F}_{\mathrm{global}})}_{\text{Global and Local: Sample vs. Retrieved Set Risk (II)}}
+ \underbrace{\mathbb{E}_{(X,Y)\sim D}\Big[\mathfrak{R}_{R_X}\big(G(X, Y)\big)\ \Big|\ |R_X| \ge N(r, \delta)\Big] + 5M_r(L_{\mathrm{loc}}; \ell, f_{\mathrm{true}}, \mathcal{F}_{\mathrm{loc}}) + 8\delta L_\ell\, \|\mathcal{F}_{\mathrm{loc}}\|_\infty}_{\text{Generalization of Local ERM and Central Absolute Moment of } f^{X,*}\text{ (III)}},
$$

where $\mathfrak{R}_{R_X}\big(G(X, Y)\big)$ denotes the empirical Rademacher complexity of $G(X, Y)$ over the retrieved set $R_X$. Under Assumption 3.4 and $r \le r_{\max}$, the Sample vs. Retrieved Set Risk term (II) is $O\big(c_{\mathrm{wdc}+}\, r^{\alpha_{\mathrm{wdc}+}}\big)$.

The above result reveals a trade-off between approximation and generalization error as the retrieval radius $r$ varies.

Approximation error. It comprises two components, given by (I) and (II) in Thm. 3.7. The term $\varepsilon_X$ captures the gap in approximating, within the $r$-radius neighborhood around $X$, the global class by a simple local function class $\mathcal{F}_X$ that varies with $X \in \mathcal{X}$. Next, $\varepsilon_{\mathrm{loc}}$ captures the gap in approximating the union of the local function classes $\cup_{x \in \mathcal{X}} \mathcal{F}_x$ with a single function class $\mathcal{F}_{\mathrm{loc}}$ (possibly of smaller complexity), while still allowing a different optimizer $f^X \in \mathcal{F}_{\mathrm{loc}}$ for each $X \in \mathcal{X}$. Both $\varepsilon_X$ and $\varepsilon_{\mathrm{loc}}$ typically increase with $r$. The second component of the approximation error, (II), corresponds to the difference in risk between the sample $X$ and the retrieved set $R_X$ for $\mathcal{F}_{\mathrm{global}}$ and $\mathcal{F}_{\mathrm{loc}}$, i.e., $M_r(L_{\mathrm{global}}; \ell, f_{\mathrm{true}}, \mathcal{F}_{\mathrm{global}})$ and $M_r(L_{\mathrm{loc}}; \ell, f_{\mathrm{true}}, \mathcal{F}_{\mathrm{loc}})$. Eq. (14) suggests that these terms increase as $O(\mathrm{poly}(r))$. When the data follows a multivariate Gaussian distribution, the term (II) increases as $O(r^2)$.

Generalization error. The term (III) depends on the size of the retrieved set $R_X$ and the Rademacher complexity of $G(X, Y)$, which is induced by $\mathcal{F}_{\mathrm{loc}}$. With increasing radius $r$, the quantity $N(r, \delta)$ increases, and the Rademacher complexity decays, typically at the rate $O(1/\sqrt{N(r, \delta)})$. Thus, for a fixed $\mathcal{F}_{\mathrm{loc}}$, the total approximation error of local ERM increases with the radius $r$, whereas the generalization error decreases with $r$. This suggests a trade-off between the approximation and generalization errors in the design choice of $r$. (We empirically validate this in Fig. 3.) Due to the centering within the set $G(X, Y)$, we also have the upper bound $M_r(L_{\mathrm{loc}}; \ell, f_{\mathrm{true}}, \mathcal{F}_{\mathrm{loc}})$ on this term, which is effective for small $r$; this bound does not decay with $|R_X|$, hence becomes worse with increasing $r$ and complements the standard Rademacher-based bound above.
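The following is a minimal simulation sketch of this trade-off: it sweeps the retrieval radius $r$ for a local linear (logistic) model on placeholder data whose decision boundary is locally simple but globally nonlinear. The data-generating process and model are illustrative assumptions, and the output is only meant to mimic the qualitative behavior (small $r$: few retrieved samples and large generalization error; large $r$: large approximation error).

```python
# Sketch: approximation vs. generalization trade-off of local ERM as the
# retrieval radius r varies (placeholder data with a locally simple,
# globally nonlinear decision boundary; local model = logistic regression).
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n = 4000
X = rng.uniform(-3, 3, size=(n, 2))
y = (np.sin(2 * X[:, 0]) + 0.3 * rng.normal(size=n) > 0).astype(int)
X_te = rng.uniform(-3, 3, size=(500, 2))
y_te = (np.sin(2 * X_te[:, 0]) > 0).astype(int)

for r in [0.1, 0.3, 0.6, 1.0, 2.0, 4.0]:
    errs, sizes = [], []
    for x, yt in zip(X_te, y_te):
        mask = np.linalg.norm(X - x, axis=1) <= r        # R_x = B(x, r) intersected with S
        X_loc, y_loc = X[mask], y[mask]
        sizes.append(mask.sum())
        if len(np.unique(y_loc)) < 2:                    # empty or single-class neighborhood
            pred = y_loc[0] if len(y_loc) else 0
        else:
            pred = LogisticRegression(max_iter=200).fit(X_loc, y_loc).predict(
                x.reshape(1, -1))[0]
        errs.append(pred != yt)
    print(f"r={r:<4} avg |R_x|={np.mean(sizes):8.1f}  test error={np.mean(errs):.3f}")
```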
3.3. Illustrative examples

Assume that $\mathcal{F}_{\mathrm{global}}$ admits a $q_x$-th order derivative in the region $B_{x,r}$. Then a natural choice for $\mathcal{F}_x$ is the set of multivariate polynomial functions of degree $q_x$, namely $\mathcal{P}(q_x)$, for some $q_x \ge 1$. The $L_1$ approximation error between $\mathcal{F}_x \equiv \mathcal{P}(q_x)$ and $\mathcal{F}_{\mathrm{global}}$ can be quantified using the remainder in Taylor's approximation. This remainder typically grows as $C(\mathcal{F}_{\mathrm{global}}, q_x)\, r^{q_x+1}$ for our choice of the neighborhood radius $r$, where $C(\mathcal{F}_{\mathrm{global}}, q_x)$ depends on the function class and the degree. Therefore, we have $\varepsilon_X \le C(\mathcal{F}_{\mathrm{global}}, q_x)\, r^{q_x+1}$.

[Figure 2. Behavior of the excess risk of local ERM as the retrieval radius varies (y-axis: excess risk; comparison curves include fitting $\mathcal{F}_{\mathrm{loc}}$ globally on the entire data).]

Local linear models. Let us consider the setting where $\mathcal{F}_{\mathrm{loc}}$ is the class of linear classifiers in $d$ dimensions. The error in approximating $\mathcal{F}_x = \mathcal{P}(q_x)$, for any $q_x > 1$, with a linear classifier in the neighborhood $B_{x,r}$, for any $x \in \mathcal{X}$, is bounded by $\varepsilon_{\mathrm{loc}} = \Theta(r^2)$. Therefore, the term (I) admits the bound $O(r^2)$. The generalization term varies as $O(1/\sqrt{N(r, \delta)})$. For $r = \Omega\big(n^{-1/(2d)} \log(n)^{1/2}\big)$ and $\delta = n^{-1/(2d)}$, we have $N(r, \delta) = \Omega\big(n^{(2d-1)/(2d)}\, r^d\big)$. Combining these, we obtain

$$
\text{Excess Risk} \le \underbrace{O(r^2) + O\big(r^{\min\{\alpha_{\mathrm{true}},1\}}\big)}_{\text{(II)}}
+ \underbrace{O\left(\sqrt{\frac{d}{n^{(2d-1)/(2d)}\, r^d}} + \frac{r^{\min\{\alpha_{\mathrm{true}},1\}}}{n^{(2d-1)/(4d)}\, r^{d/2}} + \frac{1}{n^{1/(2d)}}\right)}_{\text{(III)}}.
$$

For $r = n^{-1/(2d)} \log(n)^{1/2}$, the excess risk bound is $O\big(n^{-1/(2d)} \log(n)^{1/2}\big)$, where the bottleneck comes from the term (II), i.e., the sample vs. retrieved set risk. This behavior is depicted in Fig. 2. Moreover, when the data has a multivariate Gaussian distribution, the term (II) scales as $O(r^2)$, leading to an excess risk of $O\big(n^{-1/d} \log(n)^{1/2}\big)$. In contrast, global ERM with linear classifiers increases the approximation error considerably: the approximation error becomes a constant $O(\mathrm{diam}(\mathcal{X})^2)$, which dwarfs the generalization error that decreases as $O(1/\sqrt{n})$.

Feed-forward classifiers. As another concrete example, we study the setting where $\mathcal{F}_{\mathrm{loc}}$ is the class of fully connected deep neural networks (FC-DNNs). We take $f_y(\cdot)$ to be an $L$-layer feed-forward network with 1-Lipschitz nonlinearities (Bartlett et al., 2017). For layers $l = 1$ to $L$, let the dimension of the layer-$l$ weight matrix be $(d_l \times d_{l-1})$, with $d_L = |\mathcal{Y}|$. Also, let $b_l$ and $s_l$ be the $\ell_{2,1}$-norm and spectral-norm upper bounds for the layer-$l$ weight matrix, respectively, with $b_l/s_l \le \kappa$. We define $d_{\max} = \max_{l \in [L]} d_l$ and let $B = \max_{x \in \mathcal{X}} \|x\|_2 \prod_{l=1}^L s_l$.

Approximation error $\varepsilon_{\mathrm{loc}}$: For bounding $\varepsilon_{\mathrm{loc}}$ in (6), we require the $L_1$ error of $\mathcal{F}_{\mathrm{loc}}$ in approximating polynomials of degree $q_{\max} = \max_{x \in \mathcal{X}} q_x$. An FC-DNN that can approximate polynomials of degree at most $q_{\max}$ up to $L_1$ error $\varepsilon_{\mathrm{loc}}$ has (see Theorem 9 in Liang & Srikant (2016))² depth $L = O\big(q_{\max} + \log\big(d^{q_{\max}}\, C'(\mathcal{F}_{\mathrm{global}}, q_{\max})/\varepsilon_{\mathrm{loc}}\big)\big)$ and width $d_{\max} = O\big(d\, \log\big(d^{q_{\max}}\, C'(\mathcal{F}_{\mathrm{global}}, q_{\max})/\varepsilon_{\mathrm{loc}}\big)\big)$. Here, $C'(\mathcal{F}_{\mathrm{global}}, q_x)$ is a constant independent of $r$ and $\varepsilon$.

Rademacher complexity: We now bound the term $\mathbb{E}_{(X,Y)\sim D}\big[\mathfrak{R}_{R_X}(G(X, Y)) \,\big|\, |R_X| \ge N(r, \delta)\big]$ for this class. Following Bartlett et al. (2017), for some universal constant $C > 0$ and any $\delta > 0$, we can bound this term as

$$C\, \frac{L_\ell\, B\, \kappa\, \ln(d_{\max})\, L^{3/4}\, \ln(L_\ell B n)^{3/2}}{\sqrt{N(r, \delta)}} + 2\delta B.$$

We now provide an excess risk bound when $\mathcal{F}_{\mathrm{loc}}$ is the class of FC-DNNs. Let $r = \Omega\big(n^{-1/(2d)} \log(n)^{1/2}\big)$ and $\delta = n^{-1/(2d)}$. Then, $N(r, \delta) = \Omega\big(n^{(2d-1)/(2d)}\, r^d\big)$. Now, by setting $\varepsilon_{\mathrm{loc}} = r^{q_{\max}+1}$, it follows from Thm. 3.7 that

$$
\text{Excess Risk} \le \underbrace{O\big(r^{q_{\max}+1}\big) + O\big(r^{\min\{\alpha_{\mathrm{true}},1\}}\big)}_{\text{(II)}}
+ \underbrace{O\left(\frac{q_{\max}^{3/4}\, \ln(d^{q_{\max}}/r)^{3/4}\, \ln(n)^{3/2}}{\sqrt{n^{(2d-1)/(2d)}\, r^d}} + \frac{r^{\min\{\alpha_{\mathrm{true}},1\}}}{n^{(2d-1)/(4d)}\, r^{d/2}} + \frac{1}{n^{1/(2d)}}\right)}_{\text{(III)}}.
$$

With $r = n^{-1/(2d)} \log(n)^{1/2}$, the excess risk is bounded as $O\big(n^{-1/(2d)} \log(n)^{1/2}\big)$. Again, (II) is the bottleneck. This bottleneck can be improved for the multivariate Gaussian distribution, yielding an excess risk of $O\big(n^{-1/d} \log(n)^{1/2}\big)$.

It is also worth comparing local ERM with conventional (non-local) ERM. Under the local structure condition (Sec. 2.2), one would utilize a simple $\mathcal{F}_{\mathrm{loc}}$ for local ERM. This corresponds to the Rademacher complexity term in Thm. 3.7 being small.

²Although the width is not explicitly mentioned in Liang & Srikant (2016), it can be inferred from their constructions.
In contrast, the generalization bound for the traditional (non-local) ERM approach would depend on the Rademacher complexity of a function class $\mathcal{F}_{\mathrm{global}}$ that can achieve a low approximation error on the entire domain. Such a function class (even under the local structure assumption) would be much more complex than $\mathcal{F}_{\mathrm{loc}}$, resulting in a large Rademacher complexity. For the right design choice of $r$ and $\mathcal{F}_{\mathrm{loc}}$, the increase in approximation error of local ERM can be offset by the large generalization error of $\mathcal{F}_{\mathrm{global}}$. As a consequence, local ERM with a simple function class $\mathcal{F}_{\mathrm{loc}}$ can outperform (non-local) ERM with a complex class $\mathcal{F}_{\mathrm{global}}$.

3.4. Endowing local ERM with global representations

Note that the local ERM method takes a somewhat myopic view and does not aim to learn a global hypothesis that (partially or entirely) explains the entire data distribution. Such an approach may result in poor performance in those regions of the input domain that are not well represented in the training set. Here, we explore a two-stage approach that leverages the global pattern present in the training data to address this apparent shortcoming of local ERM. Given training data $S$ and a simple function class $\mathcal{G}_{\mathrm{loc}} \subseteq \{g : \mathbb{R}^d \to \mathbb{R}^{|\mathcal{Y}|}\}$, the first stage involves learning a $d$-dimensional feature map $\Phi_S : \mathcal{X} \to \mathbb{R}^d$ that ensures a good representation for the entire data distribution (Radford et al., 2021; Grill et al., 2020; Cer et al., 2018; Reimers & Gurevych, 2019). Subsequently, given a test instance $x$ and its retrieved neighboring points $R_x = \{(x'_j, y'_j)\} \subseteq S$, one employs local ERM with the function class

$$\mathcal{F}_{\Phi_S} = \{x \mapsto g \circ \Phi_S(x) : g \in \mathcal{G}_{\mathrm{loc}}\}. \qquad (16)$$

At this point, it is tempting to invoke the proof strategy outlined after Lem. 3.5, with $\mathcal{F}_{\mathrm{loc}}$ replaced by $\mathcal{F}_{\Phi_S}$, to characterize the performance of the aforementioned two-stage method. One can indeed bound the first two terms appearing in Lem. 3.5 for the two-stage method as well. However, bounding the third term, which corresponds to the generalization gap of local ERM, becomes challenging because $\mathcal{F}_{\Phi_S}$ depends on $S$ via the global representation $\Phi_S$ learned in the first stage. Interestingly, Foster et al. (2019) explored a general framework to address such dependence for standard (non-retrieval-based) learning. In fact, as an instantiation of their general framework, Foster et al. (2019, Sec. 5.4) consider ERM in the feature space defined by a representation. We employ their techniques to obtain the following result on the generalization gap of local ERM with $\mathcal{F}_{\Phi_S}$.

Proposition 3.8. Assume the representation learned during the first stage is $\beta$-sensitive, i.e., for $S$ and $S'$ that differ in a single example, we have $\|\Phi_S(x) - \Phi_{S'}(x)\| \le \beta$ for all $x \in \mathcal{X}$. Furthermore, assume that each $g \in \mathcal{G}_{\mathrm{loc}}$ (cf. (16)) is $L$-Lipschitz, the loss $\ell : \mathbb{R}^{|\mathcal{Y}|} \times \mathcal{Y} \to \mathbb{R}$ is $L_{\ell,1}$-Lipschitz w.r.t. the $\ell_\infty$-norm in its first argument, and $\ell$ is bounded by $M_\ell$. Then, the following holds with probability at least $1 - \delta$:

$$\mathbb{E}_{(X',Y')\sim D_{x,r}}\big[\ell(f(X'), Y')\big] - \widehat{R}^x_\ell(f) \le \big(M_\ell + 2\beta L L_{\ell,1}\,|R_x|\big)\sqrt{\frac{\log(1/\delta)}{2|R_x|}} + R_\ell(f) - \widehat{R}^x_\ell(f). \qquad (17)$$

Furthermore,

$$R_\ell(f) - \widehat{R}^x_\ell(f) \le 2\,\mathfrak{R}\big(\ell \circ \mathcal{F}_{\Phi_S}\big), \qquad (18)$$

where $\ell \circ \mathcal{F}_{\Phi_S} = \{(x, y) \mapsto \ell(f(x), y) : f \in \mathcal{F}_{\Phi_S}\}$ and $\mathfrak{R}$ denotes the Rademacher complexity of data-dependent hypothesis sets (Foster et al., 2019).

We defer the proof of Prop. 3.8 and the necessary background on Foster et al. (2019) to Appendix D.
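To make the two-stage recipe concrete, here is a minimal sketch in which PCA stands in for the globally learned representation $\Phi_S$ (in practice this would be a pretrained encoder), and retrieval plus local ERM with a simple linear class are carried out in the representation space; all components are illustrative placeholders.

```python
# Sketch of the two-stage approach in Sec. 3.4: learn a global representation Phi_S
# on the full training set, then run local ERM in representation space.
# PCA stands in for a learned encoder (e.g., a pretrained embedding model).
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X_train = rng.normal(size=(5000, 50))
y_train = (X_train[:, :5].sum(axis=1) > 0).astype(int)

# Stage 1: global representation Phi_S learned from the entire dataset.
phi = PCA(n_components=10).fit(X_train)
Z_train = phi.transform(X_train)
retriever = NearestNeighbors(n_neighbors=100).fit(Z_train)

def two_stage_predict(x):
    z = phi.transform(x.reshape(1, -1))
    # Stage 2: retrieve neighbors in representation space and solve local ERM
    # with a very simple class (linear on top of Phi_S), cf. Eq. (16).
    _, idx = retriever.kneighbors(z)
    Z_loc, y_loc = Z_train[idx[0]], y_train[idx[0]]
    if len(np.unique(y_loc)) < 2:
        return y_loc[0]
    return LogisticRegression(max_iter=200).fit(Z_loc, y_loc).predict(z)[0]

print(two_stage_predict(rng.normal(size=50)))
```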
As a potential advantage of utilizing a global representation with local ERM, one can realize high-performance local learning with an even simpler function class. For example, it is a common approach to train only a linear classifier on top of learned representations. Furthermore, a high-quality global representation can ensure good performance in those local regions that are not well represented in the training set. We leave a formal treatment of these topics for a longer version of this manuscript.

4. Classification in extended feature space

Next, we focus on a family of retrieval-based methods that directly learn a scorer to map an input instance and its neighboring labeled instances to a score vector (cf. (7)). In fact, as discussed in Sec. 1, many successful modern instances of retrieval-based models, such as REINA (Wang et al., 2022) and KATE (Liu et al., 2022), belong to this family. In this section, we provide the first rigorous treatment (to the best of our knowledge) of such models.

As introduced in Sec. 2.3, our objective is to learn a function $f : \mathcal{X} \times (\mathcal{X} \times \mathcal{Y})^{*} \to \mathbb{R}^{|\mathcal{Y}|}$. For a given instance $x$, such a function can leverage its neighboring set $R_x \subseteq \mathcal{X} \times \mathcal{Y}$ to improve the prediction on $x$. In this work, we restrict ourselves to a sub-family of such retrieval-based methods that first map $R_x \sim D_{x,r}$ to $\widehat{D}_{x,r}$, an empirical estimate of the local distribution $D_{x,r}$, which is subsequently utilized to make a prediction for $x$. In particular, the scorers of interest are of the form $(x, R_x) \mapsto f(x, \widehat{D}_{x,r})$, with $f(x, \widehat{D}_{x,r}) = \big(f_1(x, \widehat{D}_{x,r}), \ldots, f_{|\mathcal{Y}|}(x, \widehat{D}_{x,r})\big) \in \mathbb{R}^{|\mathcal{Y}|}$, where $f_y(x, \widehat{D}_{x,r})$ denotes the score assigned to the $y$-th class. Thus, denoting by $\mathcal{P}_{\mathcal{X} \times \mathcal{Y}}$ the set of distributions over $\mathcal{X} \times \mathcal{Y}$, we restrict ourselves to a suitable function class in $\{f : \mathcal{X} \times \mathcal{P}_{\mathcal{X} \times \mathcal{Y}} \to \mathbb{R}^{|\mathcal{Y}|}\}$. Note that, given a surrogate loss $\ell : \mathbb{R}^{|\mathcal{Y}|} \times \mathcal{Y} \to \mathbb{R}$ and a scorer $f$, the empirical risk $\widehat{R}^{\mathrm{ex}}_\ell(f)$ and the population risk $R^{\mathrm{ex}}_\ell(f)$ take the following form:

$$\widehat{R}^{\mathrm{ex}}_\ell(f) = \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i, \widehat{D}_{x_i,r}), y_i\big); \qquad R^{\mathrm{ex}}_\ell(f) = \mathbb{E}_{(X,Y)\sim D}\,\ell\big(f(X, D_{X,r}), Y\big).$$

[Figure 3. Performance of local ERM as the size of the retrieved set (number of nearest neighbors) varies, across models of different complexity: (a) Synthetic, with Linear, Poly(deg=3), RBF, MLP(layers=2), kNN, and their Retrieve+ counterparts; (b) CIFAR-10, with Linear, MLP(layers=2), kNN, and their Retrieve+ counterparts; (c) ImageNet, with Linear, MLP(layers=2), ViT-G/14 (SoTA), MobileNetV3, kNN, Retrieve+Linear, Retrieve+MLP(layers=2), and Retrieve+MobileNetV3.]

Note that the general framework for learning in the extended feature space $\widetilde{\mathcal{X}} := \mathcal{X} \times \mathcal{P}_{\mathcal{X} \times \mathcal{Y}}$ provides a very rich class of functions. In this paper, we focus on a specific form of learning methods in the extended feature space based on kernel methods. The method, as well as its analysis, is obtained by adapting the work on utilizing kernel methods for domain generalization (Blanchard et al., 2011; Deshmukh et al., 2019). In particular, we study the generalization of a kernel-based classifier over $\widetilde{\mathcal{X}}$ learnt via regularized ERM.
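One concrete way to define a kernel on the extended space $\widetilde{\mathcal{X}}$, used in the domain-generalization line of work that the analysis adapts (Blanchard et al., 2011), is to combine a kernel on $x$ with a kernel on $\widehat{D}_{x,r}$ through its kernel mean embedding. The sketch below implements such a product kernel with RBF components as an illustration; the specific kernels, bandwidths, and the choice to embed only the features of the retrieved examples (labels could be appended) are assumptions rather than the exact construction analyzed in Appendix E.

```python
# Sketch: a product kernel on the extended feature space X x P(X x Y), built from
# an RBF kernel on inputs and a kernel mean embedding of the retrieved set.
import numpy as np

def rbf(a, b, gamma):
    """Pairwise RBF kernel matrix between rows of a and rows of b."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * d2)

def mean_embedding_inner(P, Q, gamma):
    """<mu_P, mu_Q> for empirical distributions P, Q given as arrays of samples."""
    return rbf(P, Q, gamma).mean()

def extended_kernel(x1, R1, x2, R2, gamma_x=0.5, gamma_p=0.5):
    """k((x1, D_hat_1), (x2, D_hat_2)) = k_X(x1, x2) * k_P(mu_1, mu_2)."""
    k_x = np.exp(-gamma_x * np.sum((x1 - x2) ** 2))
    k_p = mean_embedding_inner(R1, R2, gamma_p)   # retrieved sets as samples of D_hat
    return k_x * k_p

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=5), rng.normal(size=5)
R1 = rng.normal(size=(20, 5))   # retrieved (feature) samples around x1
R2 = rng.normal(size=(30, 5))   # retrieved (feature) samples around x2
print(extended_kernel(x1, R1, x2, R2))
```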
Due to space constraints, we present an informal version of our result below; see Appendix E for the precise statement (cf. Thm. E.4), the necessary background, and a detailed proof.

Theorem 4.1 (Informal). Let $0 \le \delta \le 1$ and $N(r, \delta)$ be as defined in (15). Then, under appropriate assumptions, with probability at least $1 - \delta$, the generalization gap $\big|\widehat{R}^{\mathrm{ex}}_\ell(f) - R^{\mathrm{ex}}_\ell(f)\big|$ is at most $C_1\, n^{-1/2}$ up to $\log^{3/2}$ factors, plus additional terms governed by $N(r, \delta/n)$; see Thm. E.4 for the precise statement. Here, $F$ is the extended-feature kernel function class, and $\widehat{R}^{\mathrm{ex}}_\ell(f)$ and $R^{\mathrm{ex}}_\ell(f)$ are the corresponding empirical and population risks.

Interestingly, the bound in Thm. 4.1 implies that the size of the retrieved set $R_x$ (as captured by $N(r, \delta/n)$) has to scale at least logarithmically in the size of the training set $n$ to ensure convergence.

5. Experiments

There have been numerous successful practical applications of retrieval-based models in the literature (e.g., Wang et al., 2022; Das et al., 2021). Here, we present a brief empirical study of such models in order to corroborate the benefits predicted by our theoretical results. We also present preliminary experiments that empirically examine the kernel-based extended feature space approach in Appendix E.3.

Task and dataset. We perform experiments on both synthetic and real datasets, as summarized below; further details are relegated to Appendix F. (i) Synthetic. We consider a binary classification task on a Gaussian mixture, where each mixture component is endowed with its own local linear decision boundary (see the generation sketch below). We randomly generate a train set of size n = 10000 in a 10-dimensional space, use the Euclidean distance for retrieval, and perform 10-fold cross-validation. (ii) CIFAR-10. Next, we consider a binary classification task on real data for object detection. In particular, we consider a subset of the CIFAR-10 dataset restricted to images from the Cat and Dog classes. We randomly partition the data into a train set of n = 10000 points and use the remaining 2000 points for testing. We use the Euclidean distance for retrieval and perform 10-fold cross-validation. (iii) ImageNet. Finally, we consider the 1000-way classification task on the ImageNet dataset. We use the standard train-test split with n = 1281167 training and 50000 test examples. Following standard practice in the literature, we use unsupervised but globally learned features from ALIGN (Jia et al., 2021) to perform image retrieval. This also showcases the benefit of endowing local ERM with global representations (Sec. 3.4). Given the large computational cost, we could only run each experiment once in this setting.

Methods. On all datasets, as baselines, we consider a simple linear classifier and a two-layer multi-layer perceptron (MLP). For retrieval-based models, we consider each of the above methods as the local model to fit on the retrieved data points via the local ERM framework (Sec. 3). For the synthetic dataset, we also consider support vector machines with a polynomial kernel (of degree 3) and with a radial basis function (RBF) kernel, both as baselines and for local ERM. For ImageNet, we additionally consider, as a baseline, the state-of-the-art (SoTA) single model published for this task, from the most recent CVPR 2022 (Zhai et al., 2022). In addition, for ImageNet, we also consider a pretrain-finetune version of local ERM, in which we use the retrieved set to fine-tune a MobileNetV3 (Howard et al., 2019) model that has been pretrained on the entire ImageNet.
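Before turning to the observations, here is a sketch of the kind of synthetic data described in (i): a Gaussian mixture in which each component carries its own linear decision boundary. The component count, spread, and scale are placeholders; the exact generation protocol used in the paper is given in Appendix F.

```python
# Sketch of the synthetic task in (i): a Gaussian mixture where each component
# has its own local linear decision boundary. Component count, spread, and scale
# are placeholders; see Appendix F for the exact protocol used in the paper.
import numpy as np

rng = np.random.default_rng(0)
d, n_components, n_per = 10, 8, 1250          # 8 * 1250 = 10000 training points
centers = rng.normal(scale=6.0, size=(n_components, d))
local_w = rng.normal(size=(n_components, d))  # one linear boundary per component

X_parts, y_parts = [], []
for c in range(n_components):
    Xc = centers[c] + rng.normal(size=(n_per, d))          # unit-variance component
    yc = ((Xc - centers[c]) @ local_w[c] > 0).astype(int)  # local linear labels
    X_parts.append(Xc)
    y_parts.append(yc)

X_train = np.vstack(X_parts)
y_train = np.concatenate(y_parts)
print(X_train.shape, y_train.mean())   # (10000, 10), roughly balanced labels
```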
Observations. In Fig. 3, we observe, across all settings, the effect of varying the size of the retrieved set (as dictated by the neighborhood radius) on the performance of retrieval-based methods. When the number of retrieved samples is small, local ERM has lower accuracy; this is due to a large generalization error. When the size of the retrieved set is large, local ERM fails to minimize the loss effectively due to the lack of model capacity. This effect is more pronounced for simpler function classes such as the linear classifier than for the MLP. In Fig. 3c, we see that, via local ERM with a small MobileNetV3 model, we achieve a top-1 accuracy of 82.78, whereas a regularly trained MobileNetV3 model achieves a top-1 accuracy of only 65.80. The result is also competitive with the SoTA accuracy of 90.45, which is attained by a much larger model. Thus, our empirical evaluation demonstrates the utility of retrieval-based models via the simple local ERM framework; in particular, it allows small models to attain very high performance.

6. Related work and discussion

Local polynomial regression. Perhaps the most similar to our setup is the rich line of work on local polynomial regression, which goes back to the pioneering works of Stone (1977; 1980). This line of work aims to fit a low-degree polynomial at each point in the data set based on a subset of the data points. Such approaches gained a lot of attention as parametric regression was not adequate in various practical applications of the time. The performance of this approach critically depends on the subset selected to locally fit the data. Towards this, various selection approaches have been considered: fixed bandwidth (Katkovnik & Kheisin, 1979), nearest neighbors (Cleveland, 1979), kernel weighted (Ruppert & Wand, 1994), and adaptive methods (Ruppert et al., 1995). All these works analyze only the mean squared error loss; they neither handle classification nor provide finite-sample generalization bounds, which we obtain in this work.

Multi-task and meta learning. At a surface level, our setup might resemble multi-task and meta-learning frameworks. In multi-task learning, we are given examples from T tasks/distributions and the objective is to ensure good classification performance on all the tasks. In meta-learning, the setting is made harder by requiring good performance on a new target task. As a common approach in these settings, one learns a shared representation across the tasks and then learns a simple task-specific mapping on top of these learned shared features (Vilalta & Drissi, 2002, inter alia). Theoretical investigation is quite limited: a few works study upper bounds on the generalization error in multi-task environments (Ben-David & Borbely, 2008; Ben-David et al., 2010; Pentina & Lampert, 2014; Amit & Meir, 2017), and even fewer do so for meta-learning (Balcan et al., 2019; Khodak et al., 2019; Du et al., 2020; Tripuraneni et al., 2021). However, most of these works assume linear or other simple classes, whereas we consider general function classes using kernel methods. It is not clear whether the aforementioned representation-based approach applies to our setting, because the tasks have little overlap, the number of tasks is very large, and, most importantly, an example is not assigned to a task a priori. Interestingly, in this work, we show that the retrieval-based approach alleviates the need to identify task membership. Here, we would like to highlight a contemporary work (Li et al., 2023) that studies in-context learning by Transformer models in a multi-task/meta-learning setting. In particular, that work relies on the notion of algorithmic stability (Bousquet & Elisseeff, 2002) and presents generalization bounds for Transformers as in-context learners.
7. Conclusion and future direction

In this work, we initiate the development of a theoretical framework to study the statistical properties of retrieval-based modern machine learning models. Our treatment of an explicit local learning paradigm, namely local ERM, establishes an approximation vs. generalization error trade-off. This highlights the advantage realized by access to a retrieved set during classification, as it enables good performance with much simpler (local) function classes. As for retrieval-based models that leverage a retrieved set without explicitly performing local learning, we present a systematic study by considering a kernel-based classifier over an extended feature space. Studying end-to-end retrieval-based models beyond kernel-based classification is a natural and fruitful direction for future work. It is also worth exploring whether existing end-to-end retrieval-based models inherently perform implicit local learning via architectures such as Transformers.

References

Akyürek, E., Schuurmans, D., Andreas, J., Ma, T., and Zhou, D. What learning algorithm is in-context learning? Investigations with linear models. arXiv preprint arXiv:2211.15661, 2022. Amit, R. and Meir, R. Meta-learning by adjusting priors based on extended PAC-Bayes theory. arXiv preprint arXiv:1711.01244, 2017. Bai, Y. and Lee, J. D. Beyond linearization: On quadratic and higher-order approximation of wide neural networks. arXiv preprint arXiv:1910.01619, 2019. Balcan, M.-F., Khodak, M., and Talwalkar, A. Provable guarantees for gradient-based meta-learning. In International Conference on Machine Learning, pp. 424–433. PMLR, 2019. Bartlett, P. L., Jordan, M. I., and McAuliffe, J. D. Convexity, classification, and risk bounds. Journal of the American Statistical Association, 101(473):138–156, 2006. Bartlett, P. L., Foster, D. J., and Telgarsky, M. J. Spectrally-normalized margin bounds for neural networks. Advances in Neural Information Processing Systems, 30, 2017. Ben-David, S. and Borbely, R. S. A notion of task relatedness yielding provable multiple-task learning guarantees. Machine Learning, 73(3):273–287, 2008. Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., and Vaughan, J. W. A theory of learning from different domains. Machine Learning, 79(1-2):151–175, 2010. Blanchard, G., Lee, G., and Scott, C. Generalizing from several related classification tasks to a new unlabeled sample. In Shawe-Taylor, J., Zemel, R., Bartlett, P., Pereira, F., and Weinberger, K. (eds.), Advances in Neural Information Processing Systems, volume 24. Curran Associates, Inc., 2011. Blundell, C., Uria, B., Pritzel, A., Li, Y., Ruderman, A., Leibo, J. Z., Rae, J., Wierstra, D., and Hassabis, D. Model-free episodic control. arXiv preprint arXiv:1606.04460, 2016. Bottou, L. and Vapnik, V. Local learning algorithms. Neural Computation, 4(6):888–900, 1992. ISSN 0899-7667. doi: 10.1162/neco.1992.4.6.888. Bousquet, O. and Elisseeff, A. Stability and generalization. Journal of Machine Learning Research, 2(Mar):499–526, 2002. Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A., Sutskever, I., and Amodei, D. Language models are few-shot learners, 2020.
Cer, D., Yang, Y., Kong, S.-y., Hua, N., Limtiaco, N., John, R. S., Constant, N., Guajardo-Cespedes, M., Yuan, S., Tar, C., et al. Universal sentence encoder. ar Xiv preprint ar Xiv:1803.11175, 2018. Chen, M., Bai, Y., Lee, J. D., Zhao, T., Wang, H., Xiong, C., and Socher, R. Towards understanding hierarchical learning: Benefits of neural representations. Advances in Neural Information Processing Systems, 33:22134 22145, 2020. Cleveland, W. S. Robust locally weighted regression and smoothing scatterplots. Journal of the American statistical association, 74(368):829 836, 1979. Cramer, P. Alphafold2 and the future of structural biology. Nature Structural & Molecular Biology, 28(9):704 705, 2021. Das, R., Zaheer, M., Thai, D., Godbole, A., Perez, E., Lee, J. Y., Tan, L., Polymenakos, L., and Mc Callum, A. Casebased reasoning for natural language queries over knowledge bases. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pp. 9594 9611, Online and Punta Cana, Dominican Republic, November 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.emnlp-main.755. Deshmukh, A. A., Lei, Y., Sharma, S., Dogan, U., Cutler, J. W., and Scott, C. A generalization error bound for multi-class domain generalization, 2019. D oring, M., Gy orfi, L., and Walk, H. Rate of convergence of k-nearest-neighbor classification rule. Journal of Machine Learning Research, 18(227):1 16, 2018. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021. Du, S. S., Hu, W., Kakade, S. M., Lee, J. D., and Lei, Q. Fewshot learning via learning the representation, provably. ar Xiv preprint ar Xiv:2002.09434, 2020. Fix, E. and Hodges, J. L. Discriminatory analysis. nonparametric discrimination: Consistency properties. International Statistical Review/Revue Internationale de Statistique, 57(3):238 247, 1989. Foster, D. J., Greenberg, S., Kale, S., Luo, H., Mohri, M., and Sridharan, K. Hypothesis set stability and generalization. In Wallach, H., Larochelle, H., Beygelzimer, A., d Alch e Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019. Garg, S., Tsipras, D., Liang, P., and Valiant, G. What can transformers learn in-context? a case study of simple function classes. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022. URL https: //openreview.net/forum?id=fl NZJ2e Oet. Grill, J.-B., Strub, F., Altch e, F., Tallec, C., Richemond, P., Buchatskaya, E., Doersch, C., Avila Pires, B., Guo, Z., Gheshlaghi Azar, M., et al. Bootstrap your own latenta new approach to self-supervised learning. Advances A Statistical Perspective on Retrieval-Based Models in Neural Information Processing Systems, 33:21271 21284, 2020. Hoffmann, J., Borgeaud, S., Mensch, A., Buchatskaya, E., Cai, T., Rutherford, E., Casas, D. d. L., Hendricks, L. A., Welbl, J., Clark, A., et al. Training compute-optimal large language models. ar Xiv preprint ar Xiv:2203.15556, 2022. Howard, A., Sandler, M., Chu, G., Chen, L.-C., Chen, B., Tan, M., Wang, W., Zhu, Y., Pang, R., Vasudevan, V., Le, Q. V., and Adam, H. Searching for mobilenetv3. 
In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), October 2019. Iscen, A., Fathi, A., Schmid, C., Caron, M., and Bird, T. A memory transformer network for incremental learning. ar Xiv preprint, 2022. Izacard, G., Lewis, P., Lomeli, M., Hosseini, L., Petroni, F., Schick, T., Dwivedi-Yu, J., Joulin, A., Riedel, S., and Grave, E. Few-shot learning with retrieval augmented language models. ar Xiv preprint ar Xiv:2208.03299, 2022. Jia, C., Yang, Y., Xia, Y., Chen, Y.-T., Parekh, Z., Pham, H., Le, Q., Sung, Y.-H., Li, Z., and Duerig, T. Scaling up visual and vision-language representation learning with noisy text supervision. In International Conference on Machine Learning, pp. 4904 4916. PMLR, 2021. Kaplan, J., Mc Candlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., and Amodei, D. Scaling laws for neural language models. ar Xiv preprint ar Xiv:2001.08361, 2020. Katkovnik, V. Y. and Kheisin, V. Dynamic stochastic approximation of polynomials drifts. Avtomatika i Telemekhanika, pp. 89 98, 1979. Khodak, M., Balcan, M.-F. F., and Talwalkar, A. S. Adaptive gradient-based meta-learning methods. Advances in Neural Information Processing Systems, 32, 2019. Lei, Y., Dogan, U., Zhou, D.-X., and Kloft, M. Datadependent generalization bounds for multi-class classification. IEEE Transactions on Information Theory, 65(5): 2995 3021, 2019. Li, Y., Swersky, K., and Zemel, R. Generative moment matching networks. In International conference on machine learning, pp. 1718 1727. PMLR, 2015. Li, Y., Ildiz, M. E., Papailiopoulos, D., and Oymak, S. Transformers as algorithms: Generalization and stability in in-context learning. ar Xiv preprint ar Xiv:2301.07067, 2023. Liang, S. and Srikant, R. Why deep neural networks for function approximation? ar Xiv preprint ar Xiv:1610.04161, 2016. Liu, J., Shen, D., Zhang, Y., Dolan, B., Carin, L., and Chen, W. What makes good in-context examples for GPT-3? In Proceedings of Deep Learning Inside Out (Dee LIO 2022): The 3rd Workshop on Knowledge Extraction and Integration for Deep Learning Architectures, pp. 100 114, Dublin, Ireland and Online, May 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022. deelio-1.10. URL https://aclanthology.org/ 2022.deelio-1.10. Liu, S., Liang, X., Liu, L., Shen, X., Yang, J., Xu, C., Lin, L., Cao, X., and Yan, S. Matching-cnn meets knn: Quasiparametric human parsing. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1419 1427, 2015. Liu, Z., Miao, Z., Zhan, X., Wang, J., Gong, B., and Yu, S. X. Large-scale long-tailed recognition in an open world. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2537 2546, 2019. Long, A., Yin, W., Ajanthan, T., Nguyen, V., Purkait, P., Garg, R., Blair, A., Shen, C., and van den Hengel, A. Retrieval augmented classification for long-tail visual recognition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 6959 6969, 2022. Muandet, K., Fukumizu, K., Sriperumbudur, B., Sch olkopf, B., et al. Kernel mean embedding of distributions: A review and beyond. Foundations and Trends R in Machine Learning, 10(1-2):1 141, 2017. Nakano, R., Hilton, J., Balaji, S., Wu, J., Ouyang, L., Kim, C., Hesse, C., Jain, S., Kosaraju, V., Saunders, W., et al. Webgpt: Browser-assisted question-answering with human feedback. ar Xiv preprint ar Xiv:2112.09332, 2021. Pentina, A. and Lampert, C. A PAC-bayesian bound for lifelong learning. 
In International Conference on Machine Learning, pp. 991 999, 2014. Pritzel, A., Uria, B., Srinivasan, S., Badia, A. P., Vinyals, O., Hassabis, D., Wierstra, D., and Blundell, C. Neural episodic control. In International Conference on Machine Learning, pp. 2827 2836. PMLR, 2017. Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., et al. Learning transferable visual models from natural language supervision. In International Conference on Machine Learning, pp. 8748 8763. PMLR, 2021. Reimers, N. and Gurevych, I. Sentence-bert: Sentence embeddings using siamese bert-networks. ar Xiv preprint ar Xiv:1908.10084, 2019. A Statistical Perspective on Retrieval-Based Models Ritter, S., Faulkner, R., Sartran, L., Santoro, A., Botvinick, M., and Raposo, D. Rapid task-solving in novel environments. ar Xiv preprint ar Xiv:2006.03662, 2020. Ruppert, D. and Wand, M. P. Multivariate locally weighted least squares regression. The annals of statistics, pp. 1346 1370, 1994. Ruppert, D., Sheather, S. J., and Wand, M. P. An effective bandwidth selector for local least squares regression. Journal of the American Statistical Association, 90(432): 1257 1270, 1995. Samarin, M., Roth, V., and Belius, D. On the empirical neural tangent kernel of standard finite-width convolutional neural network architectures. ar Xiv preprint ar Xiv:2006.13645, 2020. Shalev-Shwartz, S. and Ben-David, S. Understanding machine learning: From theory to algorithms. Cambridge university press, 2014. Smola, A., Gretton, A., Song, L., and Sch olkopf, B. A hilbert space embedding for distributions. In Hutter, M., Servedio, R. A., and Takimoto, E. (eds.), Algorithmic Learning Theory, pp. 13 31, Berlin, Heidelberg, 2007. Springer Berlin Heidelberg. ISBN 978-3-540-75225-7. Steinwart, I. Adaptive density level set clustering. In Proceedings of the 24th Annual Conference on Learning Theory, pp. 703 738. JMLR Workshop and Conference Proceedings, 2011. Steinwart, I. and Christmann, A. Support Vector Machines. Springer Publishing Company, Incorporated, 1st edition, 2008. ISBN 0387772413. Stone, C. J. Consistent nonparametric regression. The annals of statistics, pp. 595 620, 1977. Stone, C. J. Optimal rates of convergence for nonparametric estimators. The annals of Statistics, pp. 1348 1360, 1980. Tripuraneni, N., Jin, C., and Jordan, M. Provable metalearning of linear representations. In International Conference on Machine Learning, pp. 10434 10443. PMLR, 2021. Vilalta, R. and Drissi, Y. A perspective view and survey of meta-learning. Artificial intelligence review, 18(2):77 95, 2002. von Oswald, J., Niklasson, E., Randazzo, E., Sacramento, J., Mordvintsev, A., Zhmoginov, A., and Vladymyrov, M. Transformers learn in-context by gradient descent. ar Xiv preprint ar Xiv:2212.07677, 2022. Wang, S., Xu, Y., Fang, Y., Liu, Y., Sun, S., Xu, R., Zhu, C., and Zeng, M. Training data is more valuable than you think: A simple and effective method by retrieving from training data. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 3170 3179, Dublin, Ireland, May 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022.acl-long.226. URL https:// aclanthology.org/2022.acl-long.226. Yarotsky, D. Error bounds for approximations with deep relu networks. Neural Networks, 94:103 114, 2017. Zakai, A. and Ritov, Y. How local should a learning method be?. In COLT, pp. 205 216. Citeseer, 2008. 
Zhai, X., Kolesnikov, A., Houlsby, N., and Beyer, L. Scaling vision transformers. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12104 12113, 2022. Zhang, T. Covering number bounds of certain regularized linear function classes. Journal of Machine Learning Research, 2(Mar):527 550, 2002. Zhang, T. Statistical analysis of some multi-category large margin classification methods. Journal of Machine Learning Research, 5(Oct):1225 1251, 2004. A Statistical Perspective on Retrieval-Based Models A. Preliminaries Definition A.1 (Rademacher complexity). Given a sample S = {zi = (xi, yi)}i [n] Z and a real-valued function class F : Z R, the empirical Rademacher complexity of F with respect to S is defined as i=1 σif(zi) where σ = {σi}i [n] is a collection of n i.i.d. Bernoulli random variables. For n N, the Rademacher complexity Rn(F) and worst case Rademacher complexity Rn(F) are defined as follows. Rn(F) = ES Dn [RS(F)] , and Rn(F) = sup S Zn RS(F). (20) Definition A.2 (Covering Number). Let ϵ > 0 and be a norm defined over Rn. Given a function class F : Z R and a collection of points S = {zi}i [n] Z, we call a set of points {uj}j [m] Rn an (ϵ, )-cover of F with respect to S, if we have sup f F min j [m] f(S) uj ϵ, (21) where f(S) = f(z1), . . . , f(zn) Rn. The -covering number N (ϵ, F, S) denotes the cardinally of the minimal (ϵ, )-cover of F with respect to S. In particular, if is an normalized-ℓp norm ( v = ( 1 dim(v) Pdim(v) i=1 |vi|p)1/p), then we simply use Np(ϵ, F, S) to denote the corresponding ℓp-covering number. B. Proofs for Section 3.2 B.1. Proof of Lemma 3.5 E(X,Y ) D h ℓ( ˆf X(X), Y ) ℓ(f (X), Y ) i // We add and subtract loss of the local optimizer f X, ( ) expected over DX,r = E(X,Y ) D h ℓ( ˆf X(X), Y ) E(X ,Y ) DX,r ℓ f X, (X ), Y + E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ(f (X), Y ) i // We add and subtract loss of the global optimizer f ( ) expected over DX,r = E(X,Y ) D h ℓ( ˆf X(X), Y ) E(X ,Y ) DX,r ℓ f X, (X ), Y + E(X ,Y ) DX,r ℓ f (X ), Y ℓ(f (X), Y ) + E(X ,Y ) DX,r ℓ f X, (X ), Y E(X ,Y ) DX,r ℓ f (X ), Y i // We group (1) local vs global optimizer, (2) global optimizer at X vs expected over DX,r, // and (3) ERM loss at X vs local optimizer loss expected over DX,r = E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ f (X ), Y i + E(X,Y ) D h E(X ,Y ) DX,r ℓ f (X ), Y ℓ(f (X), Y ) i + E(X,Y ) D h ℓ( ˆf X(X), Y ) E(X ,Y ) DX,r ℓ f X, (X ), Y i // We add and subtract loss of the empirical optimizer ˆf X( ) expected over DX,r = E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ f (X ), Y i + E(X,Y ) D h E(X ,Y ) DX,r ℓ f (X ), Y ℓ(f (X), Y ) i A Statistical Perspective on Retrieval-Based Models + E(X,Y ) D h ℓ( ˆf X(X), Y ) E(X ,Y ) DX,r[ℓ ˆf X(X ), Y ] + E(X ,Y ) DX,r[ℓ ˆf X(X ), Y ] E(X ,Y ) DX,r ℓ f X, (X ), Y i // We (1) bound difference of loss at X and loss expected over DX,r by maximizing over function class, // and (2) subtract empirical loss of empirical optimizer and add (larger) empirical loss of local optimizer E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ f (X ), Y i + E(X,Y ) D h sup f Fglobal E(X ,Y ) DX,r ℓ f(X ), Y ℓ(f(X), Y ) i + E(X,Y ) D h sup f Floc ℓ(f(X), Y ) E(X ,Y ) DX,r[ℓ f(X ), Y ]| i + E(X,Y ) D h E(X ,Y ) DX,r[ℓ ˆf X(X ), Y ] 1 |RX| (x ,y ) RX ℓ ˆf X(x ), y i + E(X,Y ) D h 1 |RX| (x ,y ) RX ℓ f X, (x ), y E(X ,Y ) DX,r ℓ f X, (X ), Y i (22) // We (1) bound difference of empirical vs expected loss of empirical optimizer by maximizing over function class, E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ f (X ), Y i + E(X,Y ) D h sup f 
Fglobal E(X ,Y ) DX,r ℓ f(X ), Y ℓ(f(X), Y ) i + E(X,Y ) D h sup f Floc ℓ(f(X), Y ) E(X ,Y ) DX,r[ℓ f(X ), Y ]| i + E(X,Y ) D h sup f Floc E(X ,Y ) DX,r[ℓ f(X ), Y ] 1 |RX| (x ,y ) RX ℓ f(x ), y i + E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y 1 |RX| (x ,y ) RX ℓ f X, (x ), y i (23) B.2. Proof of Theorem 3.7 As discussed in Sec. 3, the proof of Theorem 3.7 requires bounding three terms in Lemma 3.5. We now proceed to establishing the desired bounds. Local vs global loss. The local vs global loss can bounded easily using the local regularity condition, and due to the fact that Floc x Fx. Let f X,loc = arg min f FX E(X ,Y ) DX,r ℓ f(X ), Y . E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ f (X ), Y i E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ f X,loc(X ), Y i + E(X,Y ) D h E(X ,Y ) DX,r ℓ f X,loc(X ), Y ℓ f (X ), Y i A Statistical Perspective on Retrieval-Based Models Global and local: Sample vs retrieved set risk. The following lemma bounds the second term in Lemma 3.5. Recall the definition, for any L > 0, Mr(L; ℓ, ftrue, F) = 2Lℓ Lr + 2 F Lr ctrue 2Ltruer αtrue . (24) Lemma B.1. Under Assumption 3.1, for a L-coordinate Lipschitz function class F with F := supx X supf F f(x) we have E(X,Y ) D h sup f F ℓ(f(X), Y ) E(X ,Y ) DX,r[ℓ f(X ), Y ]| i 2Lℓ Lr + 2 F Lr ctrue(2Ltruer)αtrue . Proof. We are given the example (X, Y ). Let us fix an arbitrary f F, and any arbitrary example (x , y ) in the r neighborhood of X. We first bound the perturbation in γf( ) for a given label Y . |γf(X1, Y )) γf(X2, Y )| |f Y (X1) max s = Y fs(X1) f Y (X2) + max s = Y fs (X2)| |f Y (X1) f Y (X2)| + | max s = Y fs(X1) max s = Y fs (X2)| |f Y (X1) f Y (X2)| + max s = Y |fs(X1) fs(X2)| We can now proceed with bounding the loss. |ℓ(f(X), Y ) ℓ(f(x ), y )| = |ℓ(γf(X, Y )) ℓ(γf(x , y ))| Lℓ|γf(X, Y ) γf(x , y )| ( 4Lℓ f ; Y = y 2LℓLr; Y = y Under Assumption 3.1, if we have γf true(X, Y ) > 2Ltruer, then following the above argument we have γf true(X , Y ) > 0, thus Y is the true label of X . In other words, γf true(X, Y ) > 2Ltruer imply for any X in the r neighborhood of X its true label Y = Y . |ℓ(f(X), Y ) ℓ(f(x ), y )| 2LℓLr1(γf true(X, Y ) > 2Ltruer) + 4Lℓ f 1(γf true(X, Y ) 2Ltruer) 2LℓLr + 2Lℓ 2 f Lr 1(γf true(X, Y ) 2Ltruer) As (x , y ) was an arbitrary r-neighbor, we have |ℓ(f(X), Y ) E(X ,Y ) DX,rℓ(f(X ), Y )| E(X ,Y ) DX,r|ℓ(f(X), Y ) ℓ(f(X ), Y )| 2LℓLr + 2Lℓ 2 f Lr 1(γf true(X, Y ) 2Ltruer) Furthermore, as f was arbitrary, we have sup f F |ℓ(f(X), Y ) E(X ,Y ) DX,rℓ(f(X ), Y )| sup f F 2LℓLr + 2Lℓ 2 f Lr 1(γf true(X, Y ) 2Ltruer) A Statistical Perspective on Retrieval-Based Models = 2LℓLr + 2Lℓ 2 F Lr 1(γf true(X, Y ) 2Ltruer). Note f true is independent of f, which was used in the derivation of above inequalities. Taking expectation over (X, Y ), and using the margin condition as given in assumption 3.1 we obtain E(X,Y ) D h sup f F |ℓ(f(X), Y ) E(X ,Y ) DX,rℓ(f(X ), Y )| i = 2LℓLr + 2Lℓ 2 F Lr P(X,Y ) D h γf true(X, Y ) 2Ltruer i 2LℓLr + 2Lℓ 2 F Lr ctrue(2Ltruer)αtrue = Mr(L; ℓ, ftrue, F). Plugging in the Lipschitz bounds for the function classes Floc and Fglobal in the above lemma bounds the second term. An alternative way of bounding the risk difference is as follows: E(X,Y ) D h E(X ,Y ) DX,r ℓ f(X ), Y ℓ(f(X), Y ) i h(x )ρD(x ) PD[B(x,r) X]dx ρD(x)dx Z x X h(x)ρD(x)dx x X h(x ) Z x B(x ,r) X ρD(x) PD[B(x,r) X]dx ρD(x )dx Z x X h(x)ρD(x)dx x X h(x ) Z x B(x ,r) X ρD(x) PD[B(x,r) X]dx 1 ρD(x )dx x B(x ,r) X ρD(x) PD[B(x,r) X]dx 1 ρD(x )dx x X max 1 1 cwdc+rαwdc+ 1 ρD(x )dx = hmaxcwdc+rαwdc+ 1 cwdc+rαwdc+ . 
We can express ℓ(f(X), Y ) = h(X) because Y is a deterministic function of X. Under Assumption 3.4, with constants cwdc+ and αwdc+, recall that PX D[d(X , x) r] ρD(x)vold(r) 1 cwdc+rαwdc+. Plugging this in gives us the final inequality. Example: Let us consider the term PD[B(x ,r)] PD[B(x,r) X] for D being multivariate Gaussian N(µ, Σ). Let vold(r) imply the volume, and Sd(r) the surface area of a d-sphere of radius r in dimension d. PD[B(x, r) X] = Z (2π)d|Σ| exp( 1 2(z µ)T Σ 1(z µ))dz (2π)d|Σ| exp( 1 2(x µ)T Σ 1(x µ)) Z z B(x,r) X exp( 1 2(z + x 2µ)T Σ 1(z x))dz u B(0,r) exp( 1 2(u + 2(x µ))T Σ 1u)du = ρD(x) vold(r) c Z u B(0,r) (u + 2(x µ))T Σ 1udu for some c [1/4, 1/2] = ρD(x) vold(r) c Z u B(0,r) (c1 u 2 2 + 2(x µ)T Σ 1u)du for some c1 [λmin(Σ 1), λmax(Σ 1)] We have used exp( x) (1 x, 1 x/2) for x 1.59. Also, u T Σ 1u u 2 2 [λmin(Σ 1), λmax(Σ 1)]. A Statistical Perspective on Retrieval-Based Models Then using polar coordinate transform we obtain Z u B(0,r) u 2 2du = Z r 0 l2Sd(l)dl = Z r 0 l2 d ld 1πd/2 Γ(1+d/2) dl = rd+2 d πd/2 (d+2)Γ(1+d/2) = vold(r) dr2 Let ξ = 2(x µ)Σ 1. We want to integrate ξT u over B(0, r). Through a somewhat different polar transform where the polar axis is parallel to ξ and the angle of u and ξ is θ we can do the integral as follows for d 2. Z u B(0,r) ξT udu = Z r θ=0 |ξ|l cos(θ)Sd 1(l) sind 2(θ)dθdl = dπd/2 Γ(1+d/2) l=0 |ξ|ld 1dl Z π θ=0 cos(θ) sind 2(θ)dθ | {z } =0 Substituting, these values in the above inequality we get PD[B(x, r) X] = ρD(x)vold(r) 1 c2r2 , for some c2 = [ dλmin(Σ 1) 4(d+2) , dλmax(Σ 1) Therefore, the difference of Retrieved vs Sample risk for multi-variate Gaussian is bounded as (d + 2)λmax(Σ) for r p (d + 2)λmax(Σ). Generalization of local ERM. Recall the function class G(X, Y ) = {ℓ(γf( , )) ℓ(γf(X, Y )) : f Floc}. Here G(X, Y ) : X Y R. Note that the function class is parameterized by (X, Y ). Let us define some quantities of the function class on a set S X Y as Gmax((X, Y ); S) = sup g G(X,Y ) sup (x ,y ) S |g(x , y )| By centering each function f Floc at the point (X, Y ) we can transform the generalization over the function class Floc, to the generalization over the function class G(X, Y ). In particular, we have E(X,Y ) D h sup f Floc E(X ,Y ) DX,r[ℓ f(X ), Y ] 1 |RX| (x ,y ) RX ℓ f(x ), y i E(X,Y ) D h sup f Floc E(X ,Y ) DX,r[ℓ f(X ), Y ℓ f(X), Y ] (x ,y ) RX ℓ f(x ), y ℓ f(X), Y |RX| N(r, δ) i + 4δLℓ Floc = E(X,Y ) D h sup g G(X,Y ) E(X ,Y ) DX,r[g(X , Y )] 1 |RX| (x ,y ) RX g(x , y ) |RX| N(r, δ) i + 4δLℓ Floc . We next state a standard result of learning theory that bounds the final term using the Rademacher complexity of the function class G(X, Y ) (Shalev-Shwartz & Ben-David, 2014). Lemma B.2 (Adapted from Theorem 26.5 in Shalev-Shwartz & Ben-David (2014).). For any (X, Y ) X Y and a neighborhood set RX, and any function g G(X, Y ), for each δ > 0 with probability at least (1 δ) the following holds E(X ,Y ) DX,r[g X , Y ] 1 |RX| (x ,y ) RX g x , y 2RRX G(X, Y ) + 4Gmax((X, Y ); RX) Taking expectation with respect to (X, Y ), we obtain E(X,Y ) D h sup g G(X,Y ) E(X ,Y ) DX,r[g X , Y ] 1 |RX| (x ,y ) RX g x , y |RX| N(r, δ) i A Statistical Perspective on Retrieval-Based Models 2E(X,Y ) D h RRX G(X, Y ) |RX| N(r, δ) i + 4E(X,Y ) D h Gmax((X, Y ); RX) |RX| N(r, δ) i + 4δLℓ Floc 2E(X,Y ) D h RRX G(X, Y ) |RX| N(r, δ) i + 4E(X,Y ) D h Gmax((X, Y ); RX) |RX| N(r, δ) i E(X,Y ) D hs |RX| N(r, δ) i + 4δLℓ Floc 2E(X,Y ) D h RRX G(X, Y ) |RX| N(r, δ) i + 4Mr(Lloc; ℓ, ftrue, Floc)E(X,Y ) D hs i + 4δLℓ Floc . 
In the first inequality, with probability (1 δ) we apply the bound from Lemma B.2, whereas we use the bound 4Lℓ Floc with remaining probability δ. Also from the proof of Lemma B.1 we have that Gmax((X, Y ); RX) 2Lℓ Lr + max{Lr, 2 Floc } Lr 1 γf true(X, Y ) 2Ltruer . Taking expectation with respect to D completes the bound. While taking expectation we crucially use the fact that γf true(X, Y ) is independent of |RX| to arrive at the Mr(Lloc; ℓ, ftrue, Floc) bound. Central absolute moment of f X, . As the function f X, is fixed using centering, and then Hoeffding bound, we can directly bound the remaining term. We have with probability at least (1 δ) E(X ,Y ) DX,r ℓ f X, (X ), Y 1 |RX| (x ,y ) RX ℓ f X, (x ), y = E(X ,Y ) DX,r ℓ f X, (X ), Y ℓ f X, (X), Y 1 |RX| (x ,y ) RX ℓ f X, (x ), y ℓ f X, (X), Y Gmax((X, Y ); RX) Taking expectation similar to the previous case we obtain, E(X,Y ) D h E(X ,Y ) DX,r ℓ f X, (X ), Y 1 |RX| (x ,y ) RX ℓ f X, (x ), y i E(X,Y ) D h min{4Lℓ Floc , Gmax((X, Y ); RX) Mr(Lloc; ℓ, ftrue, Floc)E(X,Y ) D hs ln(2/δ) N(r, δ) i + 4δLℓ Floc . Here, we use the fact that γf true(X, Y ) is independent of |RX|. This concludes the proof of Theorem 3.7. B.3. Bounding the Rademacher complexity RRX G(X, Y ) We now derive bounds on the Rademacher complexity of the class G(X, Y ). We use the covering number based bounds for that purpose. We then start by relating it to the covering number of the Floc function class. Finally, we provide a bound on the class of functions residing in bounded norm Reproducing Kernel Hilbert Space. We will use Gmax(X, Y ) instead of Gmax((X, Y ); RX) when the context is clear. Similar to G(X, Y ), we define the function class G = {ℓ(γf( , )) : f Floc} which does not depend on the locality centered around (X, Y ). On a set S X Y we can define Gmax(S) = supg G sup(x ,y ) S |g(x , y )|. Lemma B.3. Under Assumption 3.1 we have for any retrieved set within radius r of X, RX, for any p 1 RRX G(X, Y ) A Statistical Perspective on Retrieval-Based Models Gmax(X, Y ) infϵ [0,Gp,max(X,Y )/2] 4ϵ + 12 |RX| R Gp,max(X,Y )/2 ϵ ν log Np(ν/2, G, RX) dν infϵ [0,Gmax(X,Y )/2] 4ϵ + 12 |RX| R Gmax(X,Y )/2 ϵ log N (ν/2, G, RX {(X, Y )}) dν . As a corollary we obtain the following rates, as the log-covering number varies with ν at different rates. Corollary B.4. Under Assumption 3.1 we have for any retrieved set within radius r of X, RX, for any p 1 RRX G(X, Y ) C (p, Floc) log2 2 max{|RX|,Gmax} |RX| ; if log Np(ε, G, n) C2(p, Floc) log(n/ε)/ε2, C (p,Floc)Gp,max(X,Y )1 α/2 log 2 max{|RX|,Gmax} |RX| ; if log Np(ε, G, n) C2(p, Floc) log(n/ε)/εα, α [0, 2). Proof. Case log Np(ε, G, n) C2(p, Floc) log(n/ε)/ϵα, α [0, 2): RRX G(X, Y ) 4Gp,max(X,Y ) |RX| + C C(p,Floc) Z Gp,max(X,Y )/2 Gp,max(X,Y )/ ν log 2|RX| 4Gp,max(X,Y ) |RX| + C C(p,Floc) Z Gp,max(X,Y )/2 Gp,max(X,Y )/ |RX| log 2 max{|RX|,Gmax} 4Gp,max(X,Y ) |RX| + C C(p,Floc) Gp,max(X, Y )1 α/2 log 2 max{|RX|, Gmax} + (1 α/2) 1e 1 . The last inequality follows from (1 α/2)2 Z b c x α/2 log(a/x)dx = b1 α/2(1 + (1 α/2) log(a/b)) c1 α/2(1 + (1 α/2) log(a/b)) = (b1 α/2 c1 α/2)(1 + (1 α/2) log(a)) + b1 α/2 log(1/b) c1 α/2 log(1/c) (b1 α/2 c1 α/2)(1 + (1 α/2) log(a)) + (1 α/2) 1e 1 Case log Np(ε, G, n) C2(p, Floc) log(n/ε)/ε2: RRX G(X, Y ) 4 |RX| + C C(p,Floc) Z Gp,max(X,Y )/2 ν log 2|RX| |RX| + C C(p,Floc) Z Gp,max(X,Y )/2 |RX| log 2 max{|RX|,Gmax} |RX| + C C(p,Floc) log2 2 max{|RX|, Gmax} q |RX| log2 4 max{|RX|,Gmax} Gp,max(X,Y ) |RX| + C C(p,Floc) |RX| log2 2 max{|RX|, Gmax} q Proof of Lemma B.3. 
Given the set RX, and some function g G(X, Y ) let us define for p 1 g p,RX = 1 |RX| X (x ,y ) RX |g(x , y )|p 1/p . A Statistical Perspective on Retrieval-Based Models Then, we have Gp,max (X, Y ); RX = maxg G g p,RX for all g G(X, Y ). For the sake of brevity we will use Gp,max(X, Y ) in place of Gp,max (X, Y ); RX . Note that we have from previous definition Gmax(X, Y ) = G ,max(X, Y ) Gp,max(X, Y ) for any p 1. A simple bound on the Rademacher complexity comes as a function of the radius r but which is independent of the size of |RX|. Specifically, we have for Rademacher random variable σi-s RRX G(X, Y ) 1 |RX|Eσ h | X (X i,Y i ) RX σigi| i max gi G(X,Y ) |gi| Gmax(X, Y ). Next using the Chaining method (Shalev-Shwartz & Ben-David, 2014, Chapter 27) we can bound the Radamacher complexity as RRX G(X, Y ) inf ϵ [0,Gp,max(X,Y )/2] Z Gp,max(X,Y )/2 log Np(ν, G(X, Y ), RX)dν . To finish the proof we need to show, for p 1 Np(ν, G(X, Y ), RX) Np(ν/2, G, RX)Np(ν/2, G, {(X, Y )}). First we fix any p 1. Let b U (a set of real numbers) be a ν/2 cover (in ℓp norm) of G with respect to {(X, Y )}. We have Np(ν, G(X, Y ), RX) 2Gmax ν for any p 1 and any ν > 0. Further, let U be a ν/2 cover of G with respect to RX. Note for any u U we have u R|RX|. Now, we fix any g G. We have at least one u U, and ˆu b U such that (x ,y ) RX |g (x , y ) u(x , y )|p 1/p ν/2, and |g (X, Y ) ˆu| ν/2. (x ,y ) RX | g (x , y ) g (X, Y ) u(x , y ) ˆu |p 1/p (x ,y ) RX | g (x , y ) u(x , y ) + ˆu g (X, Y ) |p 1/p (x ,y ) RX |g (x , y ) u(x , y )|p 1/p + |ˆu g (X, Y )| ν/2 + ν/2 ν The first inequality follows by applying Minkowski s inequality. Whereas, for the second inequality we apply Jensen s inequality for ( )1/p being a concave function for p 1, and applying the appropriate scaling. Therefore, given the covers U and b U, we can construct the set U with entries u R|RX| as: U := {u = ( u(x, y) ˆu) : u U, ˆu b U}. In particular, |U | = |b U|| U|. As the choice of g G and (x , y ) RX were arbitrary, we have U to be the cover of G(X, Y ). For p = we can specialize the bound. In particular, consider U to be a ν/2 cover (in ℓ norm) of G with respect to RX {(X, Y )}. Then U := {u = ( u(x, y) ˆu(X, Y )) : u U} creates a (normalized) ℓ cover for G with respect to RX. This is true because 1 |RX| P (x ,y ) RX |g (x , y ) u(x , y )|p 1/p |g u| = ν/2 and |ˆu g (X, Y )| |g u| = ν/2. This concludes the proof. The first term in the above Lemma is similar to the Chaining based Rademacher bounds (Shalev-Shwartz & Ben-David, 2014, Chapter 28) for G, but the ϵ (in inf and in the integral) varies in [0, Gmax(X, Y )] instead of [0, Gmax]. For small r we have Gmax(X, Y ) << Gmax, which can be leveraged to give tight bounds in certain situations. A Statistical Perspective on Retrieval-Based Models Example: Floc ℓ -bounded RKHS (Zhang, 2004): Let us consider the setting of Zhang (2004). In this setting, given some Reproducing Kernel Hilbert Space (RKHS) H, and a function f H, we can define the function f( ) = f hx where for some h H. We further define the set of functions with bounded norm HA = { f( ) H : f H sup x X hx H A}. Finally, our local function class can be defined as Floc = H|Y| A = {f( ) : fy( ) HA, y Y}. We have Floc = A. Recall that loss function for any y Y is given as ℓ(γf(x, y)), for any f Floc. We also have for all y Y, |ℓ(γf(x, y)) ℓ(γf (x, y))| 2Lℓsupy |fy(x) f y(x)| (Zhang, 2004, Assumption 15) with γA = 2Lℓ). 
Given the above setting, following Lemma 17 in Zhang (2004) 3, we have for a universal constant c log N (2Lℓν, G, RX {(X, Y )}) c|Y| Floc 2 ln(2 + Floc /ν) + ln(|RX| + 1) This gives us the following bound for the Rademacher complexity of Floc |Y|Lℓ Floc ln(|RX|+1)3/2 Proof of Equation (25). Without optimizing over ϵ above, we plug in ϵ = Gmax(X,Y ) |RX| . We obtain RRX G(X, Y ) 4Gmax(X,Y ) Z Gmax(X,Y )/2 log N ν/2, G, RX {(X, Y )} dν 4Gmax(X,Y ) c|Y|Lℓ Floc Z Gmax(X,Y )/2 ln(2 + 4Lℓ Floc /ν) + ln(|RX| + 1) 4Gmax(X,Y ) c|Y|Lℓ Floc Z Gmax(X,Y )/2 ln((Gmax(X,Y )+4Lℓ Floc )/ν)+ln(|RX|+1) 4Gmax(X,Y ) c|Y|Lℓ Floc ln((1+4Lℓ Floc /Gmax(X,Y ))/ν )+ln(|RX|+1) 4Gmax(X,Y ) c|Y|Lℓ Floc ln (1 + 4Lℓ Floc /Gmax(X, Y )) q |RX| + ln(|RX| + 1) 3/2 ln(a/x) + b/xdx = 2/3(ln(a/x) + b)3/2 for the final inequality, and ignore the negative part. Example: Floc ℓ2 bounded RKHS (Lei et al., 2019): We consider a fixed kernel K(x, x ) = φ(x), φ(x ) for x, x X, and let HK be the RKHS induced by K. Let us define the ℓp,q norm for the vectors W = (w1, w2, . . . , w|Y|) H|Y| K as (w1, . . . , w|Y|) p,q = ( w1 p, . . . , w|Y| p) q. For some norm bound Λ > 0, the local hypothesis space is defined as Floc = {f( ) : fy( ) = wy, φ( ) , wy HK, y Y, (w1, . . . , w|Y|) 2,2 Λ}. Recall that we have the loss function class G = {ℓ(γf( , )) : f Floc}, where the loss function ℓ( ) is assumed to be L-Lipschitz continuous w.r.t. ℓ norm. 3We correct for a typographical error in Zhang (2004), where the n |RX| comes in the denominator of the bound presented in Lemma 17. But Theorem 4 of Zhang (2002) shows this is a typographical error. Indeed, the covering number is not suppossed to decrease with increasing number of points. A Statistical Perspective on Retrieval-Based Models Given the retrieved set RX for some positive integer n 1, FX after Equation (8) in Lei et al. (2019) induced by RX. 4 Let the worst case Rademacher complexity of a function class F over n points be defined as Rn(F). Also, for a set S let ˆB(S) = max(x,y) S sup W : W 2,2 Λ wy, φ(x) . We have from Theorem 23 in Lei et al. (2019) that the covering number is bounded as follows: for any set S = {(xi, yi) : i = 1, . . . , n} of size n 1, for any ε > 4LRn|Y| FX log N ε, G, S 16n|Y|L2(Rn|Y| FX )2 ε2 log 2en|Y| ˆ B(S)L ε . Furthermore, from equation (18) in Lei et al. (2019) we have for any set Λ max(x,y) S φ(x) 2 2n|Y| Rn|Y| FX Λ max(x,y) S φ(x) 2 Therefore, we have for all ε 4L Λ max(x,y) S φ(x) 2 log N ε, G, S 16 max(x,y) S φ(x) 2 2Λ2L2 ε2 log 2en|Y| ˆ B(S)L ε . Plugging this covering number in in our Rademacher bound with ϵ 4L Λ max(x,y) S φ(x) 2 2(|RX|+1)|Y| and taking S = RX {(X, Y )} RRX G(X, Y ) inf ϵ [0,Gmax(X,Y )/2] Z Gmax(X,Y )/2 log N (ν/2, G, RX {(X, Y )})dν 16 max(x,y) RX {(X,Y )} φ(x) 2ΛL p 2(|RX| + 1)|Y| + 12 16 max(x,y) RX {(X,Y )} φ(x) ΛL p Z Gmax(X,Y )/2 4LΛ max(x,y) RX {(X,Y )} φ(x) 2 2(|RX|+1)|Y| log 4e(|RX|+1)|Y| ˆ B(RX {(X,Y )})L 16 max(x,y) RX {(X,Y )} φ(x) 2ΛL p 2(|RX| + 1)|Y| + 8 16 max(x,y) RX {(X,Y )} φ(x) ΛL p 2e L ˆ B(RX {(X,Y )})(|RX|+1)|Y| (|RX|+1)|Y| 4LΛ max(x,y) RX {(X,Y )} φ(x) 2 3/2 16 max(x,y) RX {(X,Y )} φ(x) 2ΛL p 2(|RX| + 1)|Y| + 8 16 max(x,y) RX {(X,Y )} φ(x) ΛL p 2e (|RX| + 1)|Y| 3/2 3/2 In the final inequality we use the fact that ˆB(RX {(X, Y )}) max (x,y) RX {(X,Y )} φ(x) 2 sup W : W 2,2 Λ W 2, max (x,y) RX {(X,Y )} φ(x) 2Λ Therefore, the final bound on the Rademacher complexity can be given as RRX O Lℓ Floc ln(|Y||RX|)3/2 Example: Floc L-layer fully connected deep neural network (DNN)(Bartlett et al., 2017): Following Bartlett et al. 
(2017), we consider a L-layer deep neural network (DNN) f A = σL(ALσL 1(AL 1σL 2(. . . A1x)) for x X where 4We need FX only to state some theorems in Lei et al. (2019). We refer interested readers to Lei et al. (2019) for the details. A Statistical Perspective on Retrieval-Based Models A = (A1, A2, . . . , AL) is the sequence of weight matrices. The matrix Al Rdl 1 dl for l = 1 to L, with d L = |Y|, and d0 = d given X Rd. Furthermore, σl( ) : Rdl Rdl denotes the non-linearity (including pooling and activation), σl-s are taken to be 1-Lipschitz, and σl(0) = 0. We assume that the Al matrix is initialized at M l, for each l = 1 to L. We consider the local function class Floc = {f A : Al M l 2,1 bl, Al σ sl, l l L 1}. Furthermore, we have for any f Floc and any x X the function (f(x), y) ℓ(γf( , )) is 2Lℓ-Lipschitz. Therefore, for a fixed set S, we have from Theorem 3.3 in Bartlett et al. (2017) that the covering number of the G = {ℓ(γf( , )) : f A Floc} is given as log N2 ε, G, S 4L2 ℓB2ln(2d2 max) ε2 L Y l=1 sl 2 L X l=1 (bl/sl)2/3 3/2 = R where dmax = max L l=1 dl, q x S x 2 2 B, and R = 4L2 ℓB2ln(2d2 max) L Y l=1 sl 2 L X l=1 (bl/sl)2/3 3/2. Using a the covering number based bound on Rademacher complexity we obtain RRX G(X, Y ) inf ϵ [0,G2,max(X,Y )/2] Z G2,max(X,Y )/2 log( 4LℓB QL l=1 sl ν ) log N2 ν/2, G, RX dν inf ϵ [0,G2,max(X,Y )/2] Z Gmax(X,Y )/2 log( 4LℓB QL l=1 sl ν ) R inf ϵ [0,G2,max(X,Y )/2] |RX| log3/2( 4LℓB QL l=1 sl ϵ ) 8 |RX| log3/2( 8LℓB QL l=1 sl G2,max(X,Y ) ) 4G2,max(X,Y ) |RX| log3/2( 4LℓB QL l=1 sl |RX| G2,max(X,Y ) ) 8 |RX| log3/2( 8LℓB QL l=1 sl G2,max(X,Y ) ) B.4. Function approximation The following proposition states that the expected loss between Floc classes can be bounded by the L1 regression error between these two. Proposition B.5. Let F1 and F2 be two classes for scorer functions, and loss ℓsatisfy Assumption 3.2 with Lipschitz constant Lℓ. Then for any x X the following holds min f F1 EDx,r[ℓ(f(X), Y )] min f F2 EDx,r[ℓ(f(X), Y )] + 2Lℓmax f F2 min f F1 EDx,r[max s |fs(X) f s(X)|]. Proof. We compare the loss from two arbitrary (measurable w.r.t. DX,r) functions f and f next, where we are trying to approximate f with f. E[ℓ(f(X), Y )] = E[ℓ(γf(X, Y ))] = E[ℓ(γ f(X, Y ))] + E[ℓ(γf(X, Y )) ℓ(γ f(X, Y ))] E[ℓ(γ f(X, Y ))] + LℓE[|γf(X, Y ) γ f(X, Y )|] = E[ℓ( f(X), Y )] + LℓE[|f Y (X) f Y (X) max s =Y fs(X) + max s =Y fs(X)|] E[ℓ( f(X), Y )] + 2LℓE[max s |fs(X) fs(X)|] A Statistical Perspective on Retrieval-Based Models Let f X, := arg minf FX E[ℓ(f(X), Y )]. Replacing the f = f X, and taking minima over Floc on both sides, we obtain. min f Floc E[ℓ(f(X), Y )] min f FX E[ℓ(f(X), Y )] + 2Lℓmin f Floc E[max s |fs(X) f X, s (X)|]. Applying the above result, we have εloc = 2Lℓminf Floc E[maxs |fs(X) f X, s (X)|] in Eq. (6). A similar argument establishes that εX = 2Lℓminf FX E[maxs |fs(X) f s (X)|] in Eq. (4), where the function f is the population minimizer of the loss over distribution DX among the function class Fglobal. B.5. Proof of Proposition 3.6 Under the weak density condition, for any r > 0 we have PD |RX| = 0 = 0. Furthermore, for any N 1, and x X PD |Rx| < N PD n X i=1 1 d(Xi, x) r < N i=1 1 d(Xi, x) min{r, δwdcρD(x) 1/d} < N exp( 2(p(x, r) N/n)2n) Let p(x, r) := min{cd wdcρD(x)rd, cd wdcδd wdc}, then PD d(Xi, x) min{r, δwdcρD(x) 1/d} p(x, r). Using Chernoff bound we obtian the final inequality for the above definition of p(x, r). It can be shown that choosing N = n min{cd wdcρD(x)rd, cd wdcδd wdc} q we obtain PD |Rx| < N δ for any δ > 0. 
Recall, PD[ρD(X) < fρ(δ)] δ} for any δ > 0. Let N(r, δ) n min{cd wdcfρ(δ)rd, cd wdcδd wdc} q Then, we have PD |RX| < N(r, δ) PD |RX| < N(r, δ)|ρD(X) fρ(δ/2) + PD[ρD(X) < fρ(δ/2)] δ/2 + δ/2 = δ For the first term in the final inequality, we use the fact that for all x X such that ρD(x) ρ 1 D (δ/2), we have PD |Rx| < N(r, δ) δ/2. For the second term in the final inequality, we just use the definition of fρ(δ). B.6. Computation of the function fρ(δ) in Proposition 3.6 For non-degenerate multi-dimensional Gaussian distributions we have ρN(µ,Σ)(x) = (2π) d/2|Σ| 1/2 exp( 1 2(x µ)T Σ 1(x µ)) Therefore, the level sets are given as PN(µ,Σ) x : ρN(µ,Σ)(x) (2π) d/2|Σ| 1/2γ x:(x µ)T Σ 1(x µ) 2 ln(1/γ) (2π) d/2|Σ| 1/2 exp( 1 2(x µ)T Σ 1(x µ))dx x:(x µ)T Σ 1(x µ) 2 ln(1/γ) (2π) d/2|Σ| 1/2 Z 2 (x µ)T Σ 1(x µ) exp( t)dtdx A Statistical Perspective on Retrieval-Based Models = (2π) d/2|Σ| 1/2 Z x:(x µ)T Σ 1(x µ) min{2t,2 ln(1/γ)} dx | {z } volume of ellipsoid = (2π) d/2|Σ| 1/2|Σ|1/2 πd/2 Γ(d/2+1) Z t=ln(1/γ) (2 ln(1/γ))d/2 exp( t)dt + Z ln(1/γ) t=0 (2t)d/2 exp( t)dt = 1 Γ(d/2+1) Γ(d/2 + 1) Z ln(1/γ) (td/2 ln(1/γ)d/2) exp( t)dt = 1 1 Γ(d/2+1) 0 ((t + ln(1/γ))d/2 ln(1/γ)d/2) exp( (t + ln(1/γ)))dt 1 γ Γ(d/2+1) 0 (q(t /q)d/2 + r(ln(1/γ)/r)d/2 ln(1/γ)d/2) exp( t ))dt [(qα + rβ)p qαp + rβp : q + r = 1] 1 ln(1/γ)d/2 1γ + ((1 1/ln(1/γ)) d/2+1 1) Γ(d/2+1) γ ln(1/γ)d/2 1 ln(1/γ)d/2 1γ + ((1 1/ln(1/γ)) d/2+1 1) Γ(d/2+1) γ ln(1/γ)d/2 1 2.45γ ln(1/γ)d/2, γ 1/2 We now extend the results to mixture of distributions. It is easily shown below that if each of the mixture component k K satisfies PDk x : ρDk(x) γ cγ ln(1/γ)d/2 then PDmix x : ρDmix(x) γ = X k wk PDk x : X l wlρDl(x) γ X k wk PDk x : ρDk(x) γw 1 k k wkcγw 1 k ln(wk/γ)d/2 c Kγ ln(1/γ)d/2. C. Comparison of risk bounds We now compare the risk bounds of the proposed explicit local ERM (which we think of as proxy towards understanding the implicit local learning happening in Retrieval augmented models) between different parametric, and non-parametric methods. C.1. Sobolev spaces (Yarotsky, 2017) In the following paragraph we briefly describe the setting of (Yarotsky, 2017) for the most part borrowing the notations from the authors. The authors study the approximation of functions f : Rd R, with Relu networks with the metrics maxx [0,1]d |f(x) f(x)| for some approximation f. They consider the Sobolev spaces Wk, ([0, 1]d) for n = 1, 2, . . . . For a function in Sobolev spaces Wk, ([0, 1]d) the weak derivatives upto order k are bounded in L norm. In particular, we define the norm in Wk, ([0, 1]d) as . f Wk, ([0,1]d) = max |k| k ess sup x [0,1]d Dkf(x) , where k {0, 2, . . . , k}d is the multi-index of the weak derivative Dk, and |k| = Pd ki=1 ki. 5 The function class Fk,d = {f Wk, ([0, 1]d) : f Wk, ([0,1]d) 1}. From Theorem 4 we know that to approximate Fk,d within accuracy ϵ (0, 1/2) we require Ω(ϵ d/2k) weights in general, and with a depth O(lnp(1/ϵ)) network, for any p 0, we require Ω(ϵ d/k ln (2p+1)(1/ϵ)) weights. A standard bound of Taylor series ensures that a degree (k 1) Taylor polynomial will approximate the function class Fk,d for any x in the L2-radius r of x (hence L -radius r as L2 norm upper bounds L norm) with accuracy dkrk/k!, i.e. for any f Fk,d there exists f(x ) P(k) max x : x x 2 r |f(x ) f(x )| dkrk k! f Wk, ([0,1]d) dkrk 5Note we adopt bold symbol only for multi-indices, whereas vectors in Rd are denoted without bold symbol. 
A Statistical Perspective on Retrieval-Based Models In particular, f(x ) can be taken as the Taylor polynomial of degree at most k f(x ) = P |k| k 1 ak(x x )k where |ak| 1. Hence, we have P |k| k 1 ak d+k 1 k 1 . Connecting back to our approximation with the approximation error εloc we have for Fglobal = Fk,d the degree qx = (k 1) for all x X, C(Fk,d, k 1) = dkrk k! , and C (Fk,d, k 1) = d+k 1 k 1 . D. Proofs for Section 3.4 This section focuses on providing a proof of Proposition 3.8. It follows the proof technique of (Foster et al., 2019, Eq. (9)). Before presenting the proof of Proposition 3.8, we need to introduce a slight variation of the Rademacher complexity for data-dependent hypothesis set. Let Z = X Y. Let R = {z R j }, T = {z T j } Zm be two m-sized samples and σ {+1, 1}m be a vector of independent Rademacher variables. Now define RT,σ = {z RT,σ j } Zm such that ( z R j , if σj = 1, z T j , if σj = 1, (27) i.e., RT,σ is obtained by replacing i-th element of R by i-th element of T iff σi = 1. Let U Zn m be an m n-sized sample; for R Zm, SR = U R Zn. Note that, following this notation, we have SRT,σ = U RT,σ. For S Zn, let H(S) be a data dependent function class (hypothesis set), which does not depend on the ordering of the elements in S. Definition D.1 (Rademacher complexity for data-dependent function class). Let H = {H(S)}S Zn be a family of data dependent function classes. Given R = {z R j [m]}, T = {z T j [m]} Dm and U = {z U m+i}i [n m], the empirical Rademacher complexity R U,R,T(H) and Rademacher complexity R U,m(H) are defined as follows. R U,R,T(H) = 1 sup h H(SRT,σ ) i=1 σih(z T i ) m ER,T Dm σ sup h H(SRT,σ ) i=1 σih(z T i ) D.1. Proof of Proposition 3.8 We are now ready to establish the proof of Proposition 3.8. As discussed above, we extend the proof technique of (Foster et al., 2019, Eq. (9)) to obtain this result. Our setting differs from that of (Foster et al., 2019) as the local ERM objective only depends on the retrieve samples Rx while the function class of interest FS = FΦS in (16) depends on the entire training set S via representation ΦS. We suitably modify the proof techniques of (Foster et al., 2019) to handle this difference. Let |Rx| := m and U = S\Rx. For R, T Zm, we define Ξ(R, T) = sup f FΦU R E(X ,Y ) Dx,r[ℓ(f(X ), Y )] | {z } :=Rℓ(f;Dx,r) (x ,y ) T ℓ(f(x ), y ) := b Rℓ(f;T) = sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; T) . Note that we are interested in bounding Ξ(Rx, Rx) = sup f FΦS E(X ,Y ) Dx,r[ℓ(f(X ), Y )] | {z } Rℓ(f;Dx,r) (x ,y ) T ℓ(f(x ), y ) | {z } b Rℓ(f;Rx)= b Rx ℓ(f) where we have used the fact that U Rx = S. Towards this, we first establish that Ξ(R, R) satisfies the Mℓ m + 2 LLℓ,1 - bounded difference property, i.e., for R, R Zm that only differ in one element, we have Ξ(R, R) Ξ(R , R ) Mℓ m + 2 LLℓ,1. (29) A Statistical Perspective on Retrieval-Based Models Ξ(R, R) Ξ(R , R ) Ξ(R, R) Ξ(R, R ) | {z } I + Ξ(R, R ) Ξ(R , R ) | {z } II Now, we will separately bound the two terms in the RHS. Let z = ( x, y) R\R and z = ( x , y ) R \R. Thus, we have the following bound on the first term. I = Ξ(R, R) Ξ(R, R ) = sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R) sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R ) sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R) Rℓ(f; Dx,r) b Rℓ(f; R ) sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R) Rℓ(f; Dx,r) + b Rℓ(f; R ) = sup f FΦU R b Rℓ(f; R ) b Rℓ(f; R) = sup f FΦU R ℓ(f( x ), y ) ℓ(f( x), y) Mℓ where the last inequality follows from our boundedness assumption for the loss function ℓ. Now we move to term II. 
Towards this, note that, it follows from the definition of supremum that, for any ϵ > 0, there exists f FΦU R such that sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R ) ϵ Rℓ( f; Dx,r) b Rℓ( f; R ) (32) Let f = g ΦU R FΦU R and f = g ΦU R FΦU R . Note that, for any (x, y) Z, ℓ f(x), y ℓ f (x), y = ℓ g ΦU R(x), y ℓ g ΦU R (x), y (i) Lℓ,1 g ΦU R(x) g ΦU R (x) Lℓ,1 g ΦU R(x) g ΦU R (x) 2 (ii) Lℓ,1L ΦU R(x) ΦU R (x) 2 (iii) Lℓ,1L , (33) where we use Lℓ,1-Lipschitzness of ℓw.r.t. norm, L-Lipschitzness of g, and -sensitivity of the representation Φ in (i), (ii), and (iii), respectively. Now, we have II = Ξ(R, R ) Ξ(R , R ) = sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R ) sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R ) (i) Rℓ( f; Dx,r) b Rℓ( f; R ) + ϵ sup f FΦU R Rℓ(f; Dx,r) b Rℓ(f; R ) Rℓ( f; Dx,r) b Rℓ( f; R ) + ϵ Rℓ( f ; Dx,r) b Rℓ( f ; R ) = Rℓ( f; Dx,r) Rℓ( f ; Dx,r) b Rℓ( f; R ) b Rℓ( f ; R ) + ϵ Rℓ( f; Dx,r) Rℓ( f ; Dx,r) + b Rℓ( f; R ) b Rℓ( f ; R ) + ϵ (ii) 2Lℓ,1L + ϵ, (34) A Statistical Perspective on Retrieval-Based Models where (i) and (ii) follow from (32) and (33), respectively. Now, since ϵ in (32) can be chosen arbitrarily small, it follows from (30), (31), and (34) that Ξ(R, R) Ξ(R , R ) Mℓ m + 2 LLℓ,1, i.e., Ξ(R, R) indeed satisfies the Mℓ m + 2 LLℓ,1 -bounded difference property. Now, it follows from the Mc Diarmid s inequality that, for δ > 0, we have with probability at least 1 δ: Ξ(Rx, Rx) E Ξ(Rx, Rx) + Mℓ+ 2 LLℓ,1m r Rℓ(f; Dx,r) b Rx ℓ(f) ERx sup f FΦS Rℓ(f; Dx,r) b Rx ℓ(f) + Mℓ+ 2 LLℓ,1m r Now, first statement of Proposition 3.8 follows from (35) and the fact that m = |Rx|. It follows from the proof steps in Foster et al. (2019, Section E.1) that ERx h sup f FΦS=U Rx Rℓ(f; Dx,r) b Rx ℓ(f) i 2R U(ℓ F), (36) where F = {FΦU R}R Zm and R U is defined in (28). This completes the proof of Proposition 3.8. E. Classification in extended feature space: A kernel-based approach As introduced in Sec. 2.3, our objective is to learn a function f : X (X Y) R|Y|. For a given instance x, such a function can leverage its neighboring set Rx (X Y) to improve the prediction on x. In this work, we restrict ourselves to a sub-family of such retrieval-based methods that first map Rx Dx,r to ˆDx,r an empirical estimate of the local distribution Dx,r, which is subsequently utilized to make a prediction for x. In particular, the scorers of interest are of the form: (x, Rx) 7 f(x, ˆDx,r) = f1(x, ˆDx,r), . . . , f|Y|(x, ˆDx,r) R|Y|, (37) where fy(x, ˆDx,r) denotes the score assigned to the y-th class. Thus, assuming that X Y denotes the set of distribution over X Y, we restrict to a suitable function class in {f : X X Y R|Y|}. Note that, given a surrogate loss ℓ: R|Y| Y R and scorer f, the empirical risk b Rex ℓ(f) and population risk Rex ℓ(f) take the following form: ˆRex ℓ(f) = 1 i [n] ℓ xi, ˆDxi,r and Rex ℓ(f) = E(X,Y ) D ℓ f(X, DX,r), Y . (38) Note that that the general framework for learning in the extended feature space e X := X X Y provides a very rich class of functions. In this paper, we focus on a specific form of learning methods in the extended feature space by using the kernel methods. The method as well as its analysis is obtained by adapting the work on utilizing kernel methods for domain generalization (Blanchard et al., 2011; Deshmukh et al., 2019). E.1. Kernel-based classification Before introducing a kernel method for the classification, we need to define a suitable kernel k : e X e X R on the extended feature space e X := X X Y. Towards this, let k Z be a kernel over Z := X Y. 
Assuming that Hk Z is the reproducing kernel Hilbert space (RKHS) associated with k Z, we can define a kernel mean embedding (Smola et al., 2007) Ψ : Z Hk Z as follows: Z k Z z, d P. (39) A Statistical Perspective on Retrieval-Based Models For an empirical distribution ˆDx,r defined by Rx, kernel embedding in (39) takes the following form. Ψ(ˆDx,r) = 1 |Rx| (x ,y ) Rx k Z (x , y ), . (40) Now, using a kernel k X over X and a kernel-like function κ over Ψ( Z), we define a desired kernel k : e X e X R as follows: k e X1, e X2 = k (X1, DX1,r), (X2, DX2,r) = k X(X1, X2) κ Ψ(DX1,r), Ψ(DX2,r) . (41) Let Hk be the RKHS corresponding to the kernel k in (41), and Hk be the norm associated with Hk. Equipped with the kernel in (41) and associated Hk, for λ > 0, we propose to learn a scorer f = (f1, . . . , f|Y|) H|Y| k := Hk Hk via the following regularized ERM problem. ˆf ex = arg min f H|Y| k i=1 ℓ f( xi), yi + λ Ω(f), (42) where xi = (xi, ˆDxi,r) and Ω(f) := f 2 H|Y| k := P y Y fy 2 Hk. It follows from the representer theorem that the solution of (42) takes the form ˆf ex( ) = P i [n] αik (xi, ˆDxi,r), . One can apply multiclass extensions of SVMs to learn the weights {αi} (Deshmukh et al., 2019). Next, we focus on studying the generalization behavior of the scorer ˆf ex recovered in (42). E.2. Generalization bounds for kernel-based classification Before presenting a generalization bound for kernel-based classification over the extended feature space e X, we state the three key assumptions that are utilized in our analysis. Assumption E.1. The loss function ℓ: R|Y| Y is Lℓ,1-Lipschitz w.r.t. the first argument, i.e., |ℓ(s1, y) ℓ(s2, y)| Lℓ,1 s1 s2 s1, s2 R|Y| and y Y. (43) Furthermore, assume that sup(x,y) ℓ(x, y) := Mℓ . Assumption E.2. Kernels k X, k Z, and κ are bounded by Mk X, Mk Z, and Mκ, respectively. Assumption E.3. Let Hk Z and Hκ be the RKHS associated with k Z and κ, respectively. Then, the canonical feature map ϕκ : Hk Z Hκ is α-H older continuous with α (0, 1], i.e., ϕκ(h1) ϕκ(h2) Hκ L h1 h2 α Hk Z h1, h2 {h Hk Z : h Hk Z Mk Z} (44) The following result states our generalization bound for the kernel-based classification method described in Sec. E.1. Theorem E.4. Let 0 δ 1 and Assumptions E.1 E.3 hold. Furthermore, let N(r, δ) be as defined in (15). Then, for any B > 0, the following holds with probability at least 1 3δ b Rex ℓ(f) Rex ℓ(f) 32 p log 2Lℓ,1BMκMk Xn 1 2 1 + log 3 2 + Lℓ,1L Mk XB n) + 4Mk Z log( n where Fk B = f = (f1, . . . , f|Y|) H|Y| k : Ω(f) B2 and M := Mℓ+ Lℓ,1BMk XMκ. Before presenting the proof of Theorem E.4, we state two key results from the literature that are used in our analysis. Proposition E.5 ((Steinwart & Christmann, 2008)). Let (Ω, A, P) be a probability space, H be a separable Hilbert space, and M > 0. Let η1, . . . , ηm : Ω H be m independent H-valued random variables satisfying ηj M, for all j [m]. The, for δ > 0, the following holds with probability at least 1 δ. j=1 (ηj EP [ηj] H M 1 m + 4M log(1/δ) A Statistical Perspective on Retrieval-Based Models Proposition E.6. (Deshmukh et al., 2019; Lei et al., 2019) Let e Z = e X Y be (extended) input and output space pair and S = z1, . . . , zn . Let Hk be a RKHS defined on e X, with k being the associated kernel. Let Fk B = (f1, . . . , f|Y|) : fy Hk y Y and X y Y fy p Hk and ℓ: R|Y| Y R be a Lipschitz function in its first argument, i.e., |ℓ(s1, y) ℓ(s2, y)| Lℓ,1 s1 s2 s1, s2 R|Y| and y Y. 
Then the Rademacher complexity of the induced function class ℓ Fk B := {ℓ f : f Fk B} satisfies R S ℓ Fk B := Eσi h sup f Fk B i [n] σiℓ f( xi), yi i log 2B sup x X k( x, x)n 1 1 2 1 max{2,p} 1 + log 3 2 2n|Y| . (46) Note that σ = (σ1, . . . , σn) denotes n i.i.d. Rademacher random variable. Proof of Theorem E.4. Note that b Rex ℓ(f) Rex ℓ(f) = sup f Fk B i=1 ℓ f(xi, b Dxi,r), yi E(X,Y ) D ℓ f(X, DX,r), Y i=1 ℓ f(xi, b Dxi,r), yi 1 i=1 ℓ f(xi, Dxi,r), yi | {z } I i=1 ℓ f(xi, Dxi,r), yi E(X,Y ) D ℓ f(X, DX,r), Y | {z } II Bounding the term-I in (47). Note that I = sup f Fk B i=1 ℓ f(xi, b Dxi,r), yi 1 i=1 ℓ f(xi, Dxi,r), yi i [n] f(xi, b Dxi,r) f(xi, Dxi,r) i [n] max y Y |fy(xi, b Dxi,r) fy(xi, Dxi,r)| Lℓ,1 max y Y max i [n] |fy(xi, b Dxi,r) fy(xi, Dxi,r)| (48) It follows from the reproducing property of the kernel k that, for any y Y, |fy(xi, b Dxi,r) fy(xi, Dxi,r)| = | fy, k((xi, b Dxi,r), ) k((xi, Dxi,r), ) | fy Hk k((xi, b Dxi,r), ) k((xi, Dxi,r), ) Hk. (49) k((xi, b Dxi,r), ) k((xi, Dxi,r), ) Hk = k((xi, b Dxi,r), (xi, b Dxi,r)) + k((xi, Dxi,r)), (xi, Dxi,r)) 2k((xi, b Dxi,r), (xi, Dxi,r)) Hk 1/2 A Statistical Perspective on Retrieval-Based Models k X(xi, xi) κ(Ψ(b Dxi,r), Ψ(b Dxi,r)) + κ(Ψ(Dxi,r)), Ψ(Dxi,r)) 2κ(Ψ(b Dxi,r), Ψ(Dxi,r)) Hk 1/2 k X(xi, xi) κ(Ψ(b Dxi,r), ) κ(Ψ(Dxi,r), ) Hκ Mk X κ(Ψ(b Dxi,r), ) κ(Ψ(Dxi,r), ) Hκ (50) = Mk X ϕκ(Ψ(b Dxi,r)) ϕκ(Ψ(Dxi,r)) Hκ L Mk X Ψ(b Dxi,r) Ψ(Dxi,r) α Hk Z (51) By combining (49) and (50), we obtain that |fy(xi, b Dxi,r) fy(xi, Dxi,r)| L Mk X fy Hk Ψ(b Dxi,r) Ψ(Dxi,r) α Hk Z. (52) Now, Hoeffding s inequality in Hilbert spaces (cf. Proposition E.5) implies that, for i [n], the following holds with probability at least 1 δ. Ψ(b Dxi,r) Ψ(Dxi,r) α Hk Z = 1 |Rxi| (x ,y ) Rxi k Z((x , y ), ) EDxi,r k Z((X , Y ), ) Hk Z |Rxi| + Mk Z 1 |Rxi| + 4Mk Z log(1/δ) 3|Rxi| . (53) It follows from (52) and (53) that, for each i [n], |fy(xi, b Dxi,r) fy(xi, Dxi,r)| L Mk X fy Hk Mk Z δ ) |Rxi| + Mk Z 1 |Rxi| + 4Mk Z log( 1 holds with probability at least 1 δ. Next, taking union bound over i [n] implies that the following holds for all i [n] and y Y with probability at least 1 δ. |fy(xi, b Dxi,r) fy(xi, Dxi,r)| L Mk X fy Hk |Rxi| + Mk Z 1 |Rxi| + 4Mk Z log(n/δ) Recall that, for each i [n], we have |Rxi| N(r, δ) with probability at least 1 δ (cf. (15)). Using union bound, we have |Rxi| N(r, δ/n), i [n], with probability at least 1 δ. Thus, the following holds for all i [n] and y Y with probability at least 1 2δ |fy(xi, b Dxi,r) fy(xi, Dxi,r)| L Mk X fy Hk N(r, δ/n) + Mk Z 1 N(r, δ/n) + 4Mk Z log(n/δ) By using fy Hk B and combining (48) with (56), we obtain that I Lℓ,1L Mk XB N(r, δ/n) + Mk Z 1 N(r, δ/n) + 4Mk Z log(n/δ) holds with probability at least 1 2δ. Bounding the term-II in (47). Note that II = sup f Fk B i=1 ℓ f(xi, Dxi,r), yi E(X,Y ) D ℓ f(X, DX,r), Y (58) A Statistical Perspective on Retrieval-Based Models Using the Assumptions E.1 and E.2 and the fact that f Fk B, we can argue that ℓ f(x, Dx,r), y = ℓ(0, y) + |ℓ f(x, Dx,r), y ℓ(0, y)| Mℓ+ Lℓ,1 f(x, Dx,r) Mℓ+ Lℓ,1 max y Y fy , k (x, Dx,r), Mℓ+ Lℓ,1 max y Y fy Hk Mk Mℓ+ Lℓ,1RMk Mℓ+ Lℓ,1RMk XMκ := M Now, it follows from the Azuma-Mc Diarmid s inequality that the following holds with probability at least 1 δ. i=1 ℓ f(xi, Dxi,r), yi E(X,Y ) D ℓ f(X, DX,r), Y i=1 ℓ f(xi, Dxi,r), yi E(X,Y ) D h ℓ f(X, DX,r)Y i Using the standard symmetrization procedure, we get that i=1 ℓ f(xi, Dxi,r), yi E(X,Y ) D h ℓ f(X, DX,r)Y i n E(Xi,Yi) DEσi i [n] σiℓ f(xi, Dxi,r), yi = 2 Re S ℓ Fk B where σ = (σ1, . . . 
, σn) denotes n i.i.d. Rademacher random variables and $\widetilde{\mathcal{R}}_S(\ell \circ \mathcal{F}^B_k)$ denotes the Rademacher complexity of the function class $\ell \circ \mathcal{F}^B_k = \{(x, y, \mathcal{D}_{x,r}) \mapsto \ell(f(x, \mathcal{D}_{x,r}), y) : f \in \mathcal{F}^B_k\}$. Now, using Proposition E.6 with p = 2 and Assumption E.2, we have R S ℓ Fk B 16Lℓ,1 p log 2B sup x X k( x, x)n 1 2 1 + log 3 2 log 2BMκMk Xn 1 2 1 + log 3 2. Now, by combining (58), (59), and (60), we obtain that with probability at least 1 δ log 2Lℓ,1BMκMk Xn 1 2 1 + log 3 2. Finally, combining (47), (57), and (61) completes the proof.
E.3. Empirical verification
We run an experiment to empirically verify the kernel-based approach over the extended feature space. We design a kernel for the extended feature space as in (41). In particular, we use the Gaussian-like function
$$\kappa\big(\Psi(\mathcal{D}_{x_1,r}), \Psi(\mathcal{D}_{x_2,r})\big) = \exp\big(-\|\Psi(\mathcal{D}_{x_1,r}) - \Psi(\mathcal{D}_{x_2,r})\|^2 / 2\sigma_\kappa^2\big).$$
To empirically estimate the distance between the kernel mean embeddings of the two distributions, $\|\Psi(\hat{\mathcal{D}}_{x_1,r}) - \Psi(\hat{\mathcal{D}}_{x_2,r})\|^2$, we follow Muandet et al. (2017); Li et al. (2015):
$$\|\Psi(\hat{\mathcal{D}}_{x_1,r}) - \Psi(\hat{\mathcal{D}}_{x_2,r})\|^2 = \frac{1}{|R_{x_1}|^2} \sum_{(x',y') \in R_{x_1}} \sum_{(x'',y'') \in R_{x_1}} k_Z\big((x', y'), (x'', y'')\big) + \frac{1}{|R_{x_2}|^2} \sum_{(x',y') \in R_{x_2}} \sum_{(x'',y'') \in R_{x_2}} k_Z\big((x', y'), (x'', y'')\big) - \frac{2}{|R_{x_1}||R_{x_2}|} \sum_{(x',y') \in R_{x_1}} \sum_{(x'',y'') \in R_{x_2}} k_Z\big((x', y'), (x'', y'')\big).$$
We took $k_Z\big((x', y'), (x'', y'')\big) = \exp\big(-\|x' - x''\|^2 / 2\sigma_x^2 - \lambda \mathbf{1}\{y' \neq y''\}\big)$, which is essentially a standard L2 distance with the labels concatenated as one-hot vectors. Also, $k_X(x_1, x_2) = \exp\big(-\|x_1 - x_2\|^2 / 2\sigma_x^2\big)$ is based on the standard L2 distance, with which we finally obtain the overall kernel for the extended feature space as
$$k\big(\tilde{X}_1, \tilde{X}_2\big) = k\big((X_1, \mathcal{D}_{X_1,r}), (X_2, \mathcal{D}_{X_2,r})\big) = k_X(X_1, X_2)\, \kappa\big(\Psi(\mathcal{D}_{X_1,r}), \Psi(\mathcal{D}_{X_2,r})\big).$$
For the synthetic dataset from Sec. 5, results are tabulated below:
Table 1. Accuracy of the kernel classifier over the extended feature space as a function of the number of retrieved neighbors used to form the extended feature space.
Neighbors    Kernel machine
2            0.776 ± 0.023
5            0.769 ± 0.021
10           0.777 ± 0.019
20           0.835 ± 0.021
50           0.819 ± 0.021
100          0.792 ± 0.022
200          0.585 ± 0.020
Note that, as the generalization bound suggests, the model performance improves only up to a certain number of neighbors and starts degrading when the number of neighbors is increased further. It is also worth highlighting that a (4-layer) Transformer model that directly processes an instance along with the associated retrieved neighboring examples achieves a much higher performance of 0.898 with 10 neighbors. This is consistent with similar observations in the deep learning literature, where kernel-based methods are often significantly outperformed by end-to-end neural networks (Bai & Lee, 2019; Chen et al., 2020; Samarin et al., 2020).
F. Additional details for experiments
F.1. Synthetic
Figure 4. Performance of ERM and local ERM for various models on synthetic data (accuracy vs. number of nearest neighbors; methods: Linear, Poly(deg=3), RBF, MLP(layers=2), kNN, and Retrieve+{Linear, Poly(deg=3), RBF, MLP(layers=2)}).
Task and data. We consider the task of binary classification on mixtures using synthetic data: In particular, we assume k = 100 clusters in a D = 10-dimensional space. Each cluster is specified by a mean parameter µi ∈ RD drawn from Uniform(−10, 10) and a classification weight vector wi ∈ RD drawn from N(0, I), for i = 1, 2, . . . , k. We randomly generate a train set of n = 10000 points as follows. To generate a labeled example (xj, yj), j ∈ [n]: 1) select a cluster i uniformly at random, and 2) sample xj ∼ N(µi, I) and set its label to yj = sign(wi⊤(xj − µi)); see the sketch below.
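For concreteness, the following is a minimal NumPy sketch of the above data-generation recipe; the function name make_mixture_data, the seeds, and the tie-breaking rule are illustrative assumptions rather than part of the paper's code.

```python
import numpy as np

def make_mixture_data(n, k=100, dim=10, seed=0):
    """Sample n labeled points from the mixture described above.

    Each cluster i has a mean mu_i ~ Uniform(-10, 10)^dim and a weight
    vector w_i ~ N(0, I); a point drawn from cluster i receives the label
    y = sign(w_i^T (x - mu_i)).
    """
    rng = np.random.default_rng(seed)
    mus = rng.uniform(-10.0, 10.0, size=(k, dim))        # cluster means
    ws = rng.standard_normal(size=(k, dim))              # per-cluster linear rules
    clusters = rng.integers(0, k, size=n)                # 1) pick a cluster uniformly
    xs = mus[clusters] + rng.standard_normal((n, dim))   # 2) x_j ~ N(mu_i, I)
    ys = np.sign(np.einsum("nd,nd->n", ws[clusters], xs - mus[clusters]))
    ys[ys == 0] = 1.0                                    # break measure-zero ties
    return xs, ys

# Train set of n = 10000 points; a test set is drawn with the same procedure.
x_train, y_train = make_mixture_data(10000, seed=0)
```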
Additionally, we also generate a test set of points using the same procedure.
Methods. As baselines, we consider models of varying complexity, from a simple linear classifier, to support vector machines with a polynomial kernel (of degree 3) and with a radial basis function (RBF) kernel, to a multi-layer perceptron (MLP) with two layers. For the retrieval-based models, we use each of the above methods as the local model fit on the retrieved data points via the local ERM framework (Sec. 3). Additionally, we also report a simple kNN baseline. We compare all these methods using classification accuracy on the held-out test set. We repeat all the experiments 10 times.
Observations. In Figure 4, we observe the trade-off of varying the size of the retrieved set (as dictated by the neighborhood radius) on the performance of the proposed algorithms. When the number of retrieved samples is small, the local methods have lower accuracy due to large generalization error. When the size of the retrieved set is large, the local methods fail to minimize the loss effectively due to the lack of model capacity. This effect is more pronounced for simpler function classes such as the linear classifier, compared to the RBF or polynomial classifiers.
F.2. CIFAR-10
Figure 5. Performance of ERM and local ERM for various models on (binary) CIFAR-10 (accuracy vs. number of nearest neighbors; methods: Linear, MLP(layers=2), kNN, Retrieve+Linear, Retrieve+MLP(layers=2)).
Task and data. We consider the task of binary classification on real image data for object recognition. In particular, we consider a subset of the CIFAR-10 dataset restricted to images from the Cat and Dog classes. We randomly partition the data into a train set of n = 10000 points and use the remaining 2000 points for test. We do a 10-fold cross-validation.
Methods. We consider a subset of the methods from Appendix F.1. In particular, we only consider a simple linear classifier and a multi-layer perceptron (MLP) with two layers. For the retrieval-based models, we use each of the above methods as the local model fit on the retrieved data points via the local ERM framework (Sec. 3). The retrieval is done using L2 distance directly in the input space (no features are extracted). Additionally, we also report a simple kNN baseline. We compare all these methods using classification accuracy on the held-out test set. We repeat all the experiments 10 times.
Observations. Similar to Figure 4, Figure 5 exhibits a trade-off, where varying the size of the retrieved set (as dictated by the neighborhood radius) impacts the performance of the proposed algorithms. When the number of retrieved samples is small, the local methods have lower accuracy due to large generalization error; and when the number of retrieved samples is large, the simple local function class incurs a large approximation error.
F.3. ImageNet
Task and data. We consider the task of 1000-way image classification on the ImageNet ILSVRC-12 dataset. We use the standard train-test split, with n = 1281167 points for training and 50000 points for test. Given the large computational cost, we could only run each experiment once.
Figure 6. Performance of ERM and local ERM for various models on ImageNet (accuracy vs. number of nearest neighbors; methods: Linear, MLP(layers=2), ViT-G/14 (SoTA), MobileNetV3, kNN, Retrieve+Linear, Retrieve+MLP(layers=2), Retrieve+MobileNetV3).
Methods. We compare the proposed local ERM (Sec. 3)
to the state-of-the-art (SoTA) single model published for this task, which appeared at CVPR 2022 (Zhai et al., 2022). For the local parametric model, we use a small MobileNetV3 architecture (Howard et al., 2019) with 4.01M parameters and 156 MFLOPs of compute. Contrast this with the SoTA model ViT-G/14, which has 1.84B parameters and 938 GFLOPs of compute. Following standard practice in the literature, we use unsupervised learned features from ALIGN (Jia et al., 2021) to perform image retrieval using L2 distance. For solving the local ERM, we fine-tune a MobileNetV3 model, pretrained on ImageNet, on the retrieved set using the Adam optimizer with a linear decay schedule. Additionally, we also report a simple kNN baseline. We compare all these methods using classification accuracy on the held-out test set.
Observations. In Figure 6, we see that local ERM with a small MobileNetV3 model achieves a top-1 accuracy of 82.78, whereas a regularly trained MobileNetV3 model achieves a top-1 accuracy of only 65.80. The result is also very competitive with the SoTA accuracy of 90.45, which requires a much larger model. Thus, the results suggest that the simple local ERM framework (analyzed in our work) demonstrates the utility of retrieval-based models. In particular, it allows a realistically small model to attain very competitive numbers on the popular ImageNet benchmark. Furthermore, as pointed out at the end of Sec. 3.4, using a global representation from ALIGN embeddings helps even the simplest linear model outperform a MobileNetV3 working directly on the image input, thereby showcasing the benefits of endowing local ERM with a global representation.
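To make the retrieve-then-fit recipe used throughout these experiments concrete, the following is a minimal sketch of local ERM at prediction time using scikit-learn components. The logistic-regression local model, the embedding-based retrieval, and the helper names (fit_retrieval_index, local_erm_predict, n_neighbors) are illustrative assumptions rather than the exact experimental configuration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors

def fit_retrieval_index(train_embeds, n_neighbors=100):
    """Build the nearest-neighbor index once over the (embedded) training set."""
    return NearestNeighbors(n_neighbors=n_neighbors).fit(train_embeds)

def local_erm_predict(query_embed, index, train_embeds, train_labels):
    """Local ERM for a single test input (cf. Sec. 3 and App. F):
    1) retrieve the nearest training examples in the embedding space,
    2) fit a simple local model on the retrieved labeled examples (local ERM),
    3) predict the query's label with that local model.
    """
    q = np.asarray(query_embed).reshape(1, -1)
    _, idx = index.kneighbors(q)
    idx = idx[0]
    ys = train_labels[idx]
    if np.unique(ys).size == 1:              # degenerate neighborhood: majority vote
        return ys[0]
    local_model = LogisticRegression(max_iter=1000)
    local_model.fit(train_embeds[idx], ys)   # ERM restricted to the retrieved set R_x
    return local_model.predict(q)[0]
```

For the larger-scale runs described above, the linear local model would be replaced by fine-tuning the pretrained MobileNetV3 on the retrieved examples, with retrieval still performed in the ALIGN embedding space.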