# Adaptive Sampling for Efficient Softmax Approximation

Tavor Z. Baharav, Eric and Wendy Schmidt Center, Broad Institute, Cambridge, MA 02142 (baharav@broadinstitute.org)
Ryan Kang, Department of Computer Science, Stanford University, Stanford, CA 94305 (txryank@stanford.edu)
Colin Sullivan, AI Division, Software Engineering Institute, Pittsburgh, PA 15213 (csullivan@sei.cmu.edu)
Mo Tiwari, Department of Computer Science, Stanford University, Stanford, CA 94305 (motiwari@stanford.edu)
Eric Luxenberg, Gridmatic, Cupertino, CA 95014 (eric@gridmatic.com)
David Tse, Department of Electrical Engineering, Stanford University, Stanford, CA 94305 (dntse@stanford.edu)
Mert Pilanci, Department of Electrical Engineering, Stanford University, Stanford, CA 94305 (pilanci@stanford.edu)

## Abstract

The softmax function is ubiquitous in machine learning and optimization applications. Computing the full softmax evaluation of a matrix-vector product can be computationally expensive in high-dimensional settings. In many applications, however, it is sufficient to calculate only the top few outputs of the softmax function. In this work, we present an algorithm, dubbed Adaptive Softmax, that adaptively computes the top k softmax values more efficiently than the full softmax computation, with probabilistic guarantees. We demonstrate the sample efficiency improvements afforded by Adaptive Softmax on real and synthetic data to corroborate our theoretical results. Adaptive Softmax yields a more than 10x gain over the full softmax computation on most datasets, and up to a 30x improvement for Mistral7B evaluated on the Wikitext dataset. The adaptive method we propose for estimating the partition function (the softmax denominator) is of independent interest and can be used in other applications such as kernel density estimation.

## 1 Introduction

The softmax function appears in many different fields and applications.
It is often used in multiclass classification problems as the final operation in a neural network to obtain a probability distribution over classes, in reinforcement learning to obtain a probability distribution over possible actions, and in statistical mechanics to derive various thermodynamic quantities. In machine learning applications, the softmax function often appears as the final operation in classification models and in attention layers.

\* Equal contribution. 38th Conference on Neural Information Processing Systems (NeurIPS 2024).

Crucially, the softmax function takes a vector of weights as input and returns a probability distribution defined by those weights. Formally, for a given temperature parameter $\beta \in \mathbb{R}$, the softmax function is defined as

$$\sigma_\beta(\mu)_i = \frac{e^{\beta\mu_i}}{\sum_j e^{\beta\mu_j}}, \qquad (1)$$

where $\mu \in \mathbb{R}^n$ is the input vector of weights, also referred to as the logits. Usually, the logits are the result of a matrix-vector product (e.g., in a fully connected layer where the softmax is used as the nonlinear activation function). The output of the softmax function (Equation 1) is a probability distribution that is a soft version of the max operator and, unlike the max, is differentiable with respect to the logits. The softmax function can thus be used in gradient-based algorithms as a proxy for the non-differentiable max function. Intuitively, the temperature parameter β controls the peakiness of the softmax output: a larger β corresponds to a peakier distribution and a harder max. The choice β = 1 corresponds to the canonical softmax function σ, and the limit β → ∞ corresponds to the hard argmax. The denominator of Equation (1), $\sum_j e^{\beta\mu_j}$, is called the partition function and is denoted by $Z_\beta$.

The softmax function is critical in many popular, recent machine learning applications such as large language models (LLMs). However, it can present a computational bottleneck in high-dimensional applications.
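A direct transcription of Equation (1) overflows once $\beta\mu_j$ is large, so implementations typically shift every exponent by the maximum first, which cancels in the ratio. The sketch below is plain Python with illustrative function names that are not taken from the paper's code. It computes $\sigma_\beta(\mu)$ and $\log Z_\beta$ this way and shows the effect of the temperature.

```python
import math


def softmax(mu, beta=1.0):
    """Temperature softmax sigma_beta(mu) from Equation (1).

    Subtracting max_j beta*mu_j from every exponent cancels in the ratio,
    so the result is unchanged while exp() can no longer overflow.
    """
    scaled = [beta * m for m in mu]
    shift = max(scaled)
    weights = [math.exp(s - shift) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def log_partition(mu, beta=1.0):
    """log Z_beta = log sum_j exp(beta*mu_j), using the same max shift (log-sum-exp)."""
    scaled = [beta * m for m in mu]
    shift = max(scaled)
    return shift + math.log(sum(math.exp(s - shift) for s in scaled))


mu = [2.0, 1.0, 0.5]
p = softmax(mu)                    # canonical softmax (beta = 1)
p_hard = softmax(mu, beta=50.0)    # large beta: close to one-hot at the argmax
z = math.exp(log_partition(mu))    # Z_1 = e^2 + e^1 + e^0.5
p_big = softmax([1000.0, 999.0])   # no overflow despite huge logits
```

Setting `beta=50.0` makes the output nearly one-hot, which illustrates the β → ∞ limit. Each $p_i$ equals $e^{\beta\mu_i}/Z_\beta$, so the two functions are consistent.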
During the training of neural networks, for example, each training example $x$ requires the computation of the softmax function $\sigma_\beta(x)$, the partition function $Z_\beta(x)$, and their gradients. During inference, the number of possible labels for next-token prediction equals the vocabulary size, which can be in the hundreds of thousands for common languages such as English. As such, there has been significant recent interest in accelerating the computation of the softmax function and its derivatives [37, 14, 15].

Key Observations: In many applications, we are only interested in identifying the top few outputs of the softmax function. In these settings, computing the smaller entries exactly is wasteful, which motivates our study. First, we observe that when the input vector to the softmax function is the result of a matrix-vector product, we can approximate this intermediate computation instead of computing it exactly. This, in turn, allows us to approximate the output of the softmax function and converts the problem of computing the softmax function from a computational one into a statistical one. We also note that the softmax output is heavily influenced by the largest input elements. This suggests that we can adaptively allocate more computation to the larger input elements to estimate them with greater certainty. This procedure is inspired by recent work in multi-armed bandits that converts computational problems into statistical ones [4].

Outline: We begin with a summary of related work in Section 2. In Section 3, we formalize the reduction of computing the softmax function to a statistical estimation problem. In Section 4, we propose the Adaptive Softmax algorithm based on this reduction. In the same section, we provide probably approximately correct (PAC) guarantees for Adaptive Softmax and prove that it is more efficient than brute-force computation.
Crucially, Adaptive Softmax allocates greater computational resources towards important output values. In Section 5, we demonstrate the empirical advantages of our algorithm in several real-world applications, including a multiclass classification setting and large language models. In Section 6, we conclude with a discussion of further applications, potential limitations, and directions for future work.

## 2 Related Work

Recent work has identified the computational complexity of the softmax function as a significant bottleneck in modern machine learning applications [37]. Well before the attention revolution, [36] proposed accelerating softmax computation via a hierarchical model. In their work, a binary tree with the possible outputs as leaves is used and, at each step, the model must predict which path to traverse towards a leaf. For a balanced tree with n leaves, the computational complexity of the softmax function is reduced from O(n) to O(log n). This comes at the expense of O(n) internal classifiers, and the method provides only an approximate output whose quality is dictated by the quality of the clustering. Google's word2vec models used Huffman trees instead of vanilla binary trees in a similar tree-based approach [37]. Other approaches include target sampling, noise contrastive estimation, and self-normalization (summarized in [14]). However, all of these methods reduce the complexity in terms of the vocabulary size n rather than the dimension d. Additionally, our proposed algorithm provides direct PAC guarantees on the original softmax output, instead of approximating the softmax in a sequence of steps without provable accuracy guarantees. Independently, some works have developed fast methods to approximate the softmax gradients during training by using importance sampling over the classes. This leads to a sampled softmax that improves scaling with respect to n [11, 10].
This is in contrast with Adaptive Softmax, which uses importance sampling in an orthogonal direction to subsample the features efficiently, enabling gains in d at both training and inference time. These sampled softmax methods were later specialized to kernel-based sampling, resulting in provably bounded bias [12, 38]. However, these and other optimized methods [23] typically require prior knowledge about the desired output label frequencies. This leaves them susceptible to phenomena like distribution shift between training and inference data, where the label frequency distribution changes at the time the model is evaluated. Unlike these approaches, Adaptive Softmax does not require auxiliary knowledge, is adaptive on a per-instance basis, and provides provable guarantees directly for the true softmax computation rather than for a proxy.

Our algorithm is inspired by adaptive sampling techniques from the multi-armed bandit literature. Randomized algorithms based on multi-armed bandits have seen a surge of recent work due to their ability to provide instance-adaptive guarantees for a variety of problems. This idea was first formalized in the specific case of Monte Carlo Tree Search [26] and later studied in the context of hyperparameter tuning [31]. Recent work has formalized this approach into the framework of Bandit-Based Monte Carlo Optimization [4], where a computational task is reduced to one of statistical estimation that is then solved efficiently with adaptivity. Applications of this framework include finding the medoid of a dataset [5, 7, 41], k-nearest neighbor graph construction [30, 34], Monte Carlo permutation-based multiple testing [46], and an adaptive singular value decomposition (SVD) [24]. Most relevant is the recent work of [8], where the authors provide a general framework for adaptive sampling to approximate function evaluation at unknown but estimable points.
This work provides general guarantees, but it requires a bound on the Lipschitz constant of the function's gradients as input. Due to its generality, it can also perform poorly on specific function classes.

A subproblem in our softmax approximation is identifying the index of the largest component. This is equivalent to the Maximum Inner Product Search (MIPS) problem on the preceding matrix-vector product. MIPS is a common problem that has inspired many directions of research [22, 32]. Many of these algorithms focus on specific use cases and place restrictive assumptions on the data (e.g., that all elements of the matrix-vector product are positive), require preprocessing, are not adaptive to the underlying data distribution, or lack PAC guarantees. One large family of MIPS algorithms is based on locality-sensitive hashing (LSH) [19, 2]. These LSH-based approaches suffer from significant preprocessing overhead and practical implementation issues. A further shortcoming is that the maximum dot product is often small compared to the vector norms in high dimensions. This necessitates many hashes and significant storage space, often orders of magnitude more than the data itself. Promising LSH-based algorithms have recently been applied to the problem of softmax computation [1, 15]. These methods rely on intensive preprocessing and attain their gains primarily in terms of n. In contrast, Adaptive Softmax subsamples matrix-vector products and obtains gains with respect to d. Furthermore, Adaptive Softmax is instance-adaptive, requires no preprocessing, and still provides PAC guarantees.

## 3 Problem Formulation

In this work, we focus on the problem of identifying the k largest entries of the softmax of an input vector that is the result of a computationally expensive matrix-vector multiplication.
Specifically, we analyze the setting where the input vector µ is the result of a matrix-vector product Ax, as is common in the final linear layer of neural networks (such scenarios frequently arise in other machine learning problems as well [4]). Our objective is to design an algorithm that, with probability at least 1 − δ, estimates the top k values to multiplicative accuracy ε, where ε and δ are given input parameters. For clarity of exposition, we focus on the case k = 1, i.e., identifying and estimating the largest component. All our theoretical results, however, easily extend to the setting k > 1 (discussed in Section 4.2).

Notation: We use [n] to denote the set {1, 2, ..., n}, and ‖·‖ to denote the vector ℓ2 norm unless otherwise specified. We use $\|\cdot\|_{\psi_2}$ to denote the Orlicz norm of a random variable, i.e., its sub-Gaussianity parameter or variance proxy; this is discussed in greater detail in Appendix A.2 [44]. For a matrix A and vector x, we denote the resulting product as µ = Ax. For notational simplicity, we assume that the arms are in sorted order $\mu_1 > \mu_2 \ge \dots \ge \mu_n$. We define the gaps between the entries of µ as $\Delta_i = \mu_1 - \mu_i$. Following the convention from the best-arm identification literature, we set $\Delta_1 = \Delta_2$, and we assume that $\Delta_2 > 0$ (this assumption is easily relaxed). Furthermore, we define $\alpha_i = e^{\beta\mu_i}$ and $\gamma_i = e^{\beta\mu_i/2}$, which are proportional to the optimal first-order (respectively, second-order) sampling frequencies; these are discussed further in Appendix B.2. Finally, we define p as the softmax output and $i^*$ as the index of its largest entry (assumed to be unique), i.e.,

$$\sigma_\beta(Ax) = p, \qquad i^* = \arg\max_{i \in [n]} p_i.$$

Our goal is to design an algorithm that efficiently outputs the best index $i^*$ and an estimate of its value. With probability at least 1 − δ, the index must be correct and the estimated value must be within multiplicative accuracy ε.
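The notation above can be made concrete in a few lines of plain Python. This is an illustrative sketch, and the helper names are hypothetical. It assumes the logits are already sorted in decreasing order with a unique maximum, and it uses 0-based indices where the text uses 1-based ones. It also computes $\alpha_i$ without overflow protection, so it is only meant for moderate $\beta\mu_i$. The last function checks the success criterion stated above, which requires the correct index and a value within a factor $(1 \pm \epsilon)$ of the truth.

```python
import math


def notation(mu, beta):
    """Return gaps Delta_i, alpha_i, gamma_i, and p = softmax for sorted logits mu.

    Assumes mu[0] > mu[1] >= ... >= mu[n-1] (0-based version of mu_1 > mu_2 >= ...).
    """
    delta = [mu[0] - m for m in mu]                # Delta_i = mu_1 - mu_i
    delta[0] = delta[1]                            # convention Delta_1 = Delta_2
    alpha = [math.exp(beta * m) for m in mu]       # alpha_i = e^{beta mu_i}
    gamma = [math.exp(beta * m / 2) for m in mu]   # gamma_i = e^{beta mu_i / 2}
    z = sum(alpha)                                 # partition function Z_beta
    p = [a / z for a in alpha]                     # p = sigma_beta(mu)
    return delta, alpha, gamma, p


def is_success(i_hat, p_hat, i_star, p_star, eps):
    """Correct index and (1 - eps) * p_star <= p_hat <= (1 + eps) * p_star."""
    return i_hat == i_star and (1 - eps) * p_star <= p_hat <= (1 + eps) * p_star


delta, alpha, gamma, p = notation([3.0, 1.0, 0.0], beta=1.0)
i_star = max(range(len(p)), key=p.__getitem__)    # argmax_i p_i (index 0 here)
```

Note that $\gamma_i^2 = \alpha_i$ and $p_i = \alpha_i / \sum_j \alpha_j$. The sampling frequencies are therefore tied directly to the softmax output.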
Mathematically, we denote the algorithm's outputs by $\hat i^* \in [n]$ and $\hat p_{i^*} \in [0, 1]$. We define the success event $E_{\mathrm{id}}$, where the algorithm identifies the largest entry, and the success event $E_{\mathrm{est}}$, where it estimates that entry's value to within multiplicative accuracy ε:

$$E_{\mathrm{id}} = \{\hat i^* = i^*\}, \qquad (2)$$

$$E_{\mathrm{est}} = \{(1 - \epsilon)\,p_{i^*} \le \hat p_{i^*} \le (1 + \epsilon)\,p_{i^*}\}. \qquad (3)$$

We say that the algorithm provides (ε, δ)-PAC guarantees if these events occur simultaneously with probability at least 1 − δ, with respect to the randomness of the algorithm:

$$\mathbb{P}(E_{\mathrm{id}} \cap E_{\mathrm{est}}) \ge 1 - \delta. \qquad (4)$$

Our objective then becomes to design an algorithm that satisfies Equation (4) while minimizing the requisite sample complexity.

## 4 Adaptive Softmax Algorithm

We now introduce the Adaptive Softmax algorithm (Algorithm 1), which approximates the output of the softmax. First, Adaptive Softmax approximates the softmax normalization constant $Z_\beta$ to multiplicative accuracy ε/4 via Normalization Estimation (Algorithm 2). Next, Adaptive Softmax identifies the best arm (or top k arms, depending on the setting) using a standard multi-armed bandit algorithm, Best Arm Id (Algorithm 3). In our setting, arms correspond to different rows of A, and pulling arm i corresponds to computing a coordinate-wise scalar product $A_{i,j}x_j$ for some coordinate j. We provide a more formal overview of the best-arm identification problem and the associated algorithm in Appendix A. Finally, Adaptive Softmax estimates the value of the identified best arm (or top k arms) to multiplicative accuracy ε/4 by sampling each arm a sufficient number of times via Estimate Arm (Algorithm 10). We prove (ε, δ)-PAC guarantees for Adaptive Softmax by union-bounding over the error probabilities in each step of Algorithm 1. Our results show that, with probability at least 1 − δ, Adaptive Softmax identifies the largest output of the softmax function and estimates its value to multiplicative accuracy ε.
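The arm-pull view of $\mu = Ax$ can be illustrated with a small sketch. Each pull of arm i returns $d\,A_{iJ}x_J$ for a uniformly random coordinate J, which is an unbiased estimate of $\mu_i = A_i^\top x$. The sketch then runs successive elimination, which is one classical best-arm identification strategy and not necessarily the paper's Algorithm 3. Several choices here are assumptions. The confidence radius is a standard anytime union bound with illustrative constants. `sigma2` is taken to be a variance proxy for a single scaled pull $d\,A_{iJ}x_J$. Once sampling would require about d pulls per arm, the remaining arms are computed exactly. All function names are hypothetical.

```python
import math
import random


def pull(A, x, i, rng):
    """One arm pull: d * A[i][J] * x[J] with J ~ Unif{0, ..., d-1}.

    Its expectation is sum_j A[i][j] * x[j] = mu_i, i.e., it is unbiased.
    """
    d = len(x)
    j = rng.randrange(d)
    return d * A[i][j] * x[j]


def best_arm_id(A, x, delta, sigma2, batch=32, seed=0):
    """Successive elimination over the rows of A; returns (best index, #multiplications)."""
    rng = random.Random(seed)
    n, d = len(A), len(x)
    active = list(range(n))
    sums = [0.0] * n
    t = 0      # pulls made so far by every active arm
    cost = 0   # total coordinate-wise multiplications A[i][j] * x[j]
    while len(active) > 1:
        if t + batch >= d:
            # Sampling would cost about as much as exact computation: finish exactly.
            exact = {i: sum(a * b for a, b in zip(A[i], x)) for i in active}
            cost += d * len(active)
            return max(active, key=exact.get), cost
        for i in active:
            sums[i] += sum(pull(A, x, i, rng) for _ in range(batch))
        t += batch
        cost += batch * len(active)
        # Anytime sub-Gaussian confidence radius, union-bounded over arms and rounds.
        radius = math.sqrt(2 * sigma2 * math.log(4 * n * t * t / delta) / t)
        means = {i: sums[i] / t for i in active}
        leader = max(means.values())
        # Keep arm i only if its upper bound reaches the leader's lower bound.
        active = [i for i in active if means[i] + radius >= leader - radius]
    return active[0], cost


# Demo: 10 arms with +/-1 entries. Row 0 is all ones, so mu_0 = d and the
# other mu_i are near 0. A scaled pull lies in [-d, d], so sigma2 = d**2.
rng = random.Random(1)
n, d = 10, 4096
A = [[1.0] * d] + [[rng.choice((-1.0, 1.0)) for _ in range(d)] for _ in range(n - 1)]
x = [1.0] * d
best, cost = best_arm_id(A, x, delta=0.1, sigma2=float(d) ** 2)
```

On this planted instance, elimination finishes after a few hundred pulls per arm. The total number of multiplications is therefore a small fraction of the n·d needed for exact computation.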
Algorithm 1: Adaptive Softmax
1. Input: matrix A, vector x, temperature β, error ε, failure probability δ, variance proxy σ²
2. Output: $\hat p_{i^*}$ and $\hat i^*$, the highest softmax probability and its index
3. ▷ Estimate the denominator of the softmax
4. $\hat Z \leftarrow$ Normalization Estimation(A, x, β, ε/4, δ/3, σ²)
5. ▷ Compute the index of the best arm
6. $\hat i^* \leftarrow$ Best Arm Id(A, x, δ/3, σ²)
7. ▷ Estimate the value of the best arm
8. $\hat\mu_{\hat i^*} \leftarrow$ Estimate Arm($A_{\hat i^*}$, x, ε/4, δ/3)
9. $\hat p_{i^*} = \exp(\beta\hat\mu_{\hat i^*}) / \hat Z$
10. return $\hat p_{i^*}$, $\hat i^*$

These inputs are typical in the multi-armed bandit setting, but the variance proxy σ² merits additional discussion. For our random-sampling-based approach to succeed, we require a bound on the rate of concentration of the estimators $\hat\mu_i$. The quantity σ² governs this concentration rate, as we discuss in Appendix A. In practice, such a bound holds very generally, for example whenever A and x have bounded entries. For algorithmic simplicity, we make the following assumption.

Algorithm 2: Normalization Estimation
1. Input: matrix A, vector x, temperature β, target error ε, failure probability δ, variance proxy σ²
2. Output: $\hat Z_\beta$, an estimate of the partition function
3. Compute $\hat\mu_i$ using $T_0 = 17\beta^2\sigma^2\log(6n/\delta)$ coordinate samples for each arm
4. $C_i \leftarrow \sqrt{2\sigma^2\log(6n/\delta)/T_0}$
5. $\hat\alpha_i \leftarrow e^{\beta(\hat\mu_i - C_i)}$
6. $\hat\gamma_i \leftarrow e^{\beta(\hat\mu_i - C_i)/2}$
7. Sample each arm $n_i = \min(\tilde n_i, d)$ times to recompute the estimates $\hat\mu_i$, where $\tilde n_i = \beta^2\sigma^2\max\left\{(\cdots)\,\frac{\hat\gamma_i\sum_j\hat\gamma_j}{\epsilon\sum_j\hat\alpha_j},\ \frac{16\log(12/\delta)}{\epsilon^2}\cdot\frac{\hat\alpha_i}{\sum_j\hat\alpha_j}\right\}$
8. return $\hat Z_\beta = \sum_i e^{\beta\hat\mu_i}$

Assumption 1. We assume that we are given a variance proxy bound σ² for the sub-Gaussian parameters of the constructed estimators: $\sigma^2 \ge \|A_{iJ}x_J\|_{\psi_2}^2$ for all $i$, where $J \sim \mathrm{Unif}([d])$.

We provide theoretical guarantees for Adaptive Softmax under Assumption 1. Recall that we defined the optimal first- and second-order sampling frequencies as $\alpha_i = e^{\beta\mu_i}$ and $\gamma_i = e^{\beta\mu_i/2}$ (see Appendix B.2). We first show in Proposition 1 that our softmax normalization estimation algorithm (Algorithm 2) obtains the desired guarantees.

Proposition 1.
For input ε ∈ (0, 1/2), δ ∈ (0, 1), and σ satisfying Assumption 1, Algorithm 2 will, with probability at least 1 − δ, estimate $Z_\beta = \sum_j e^{\beta\mu_j}$ to multiplicative accuracy ε. On this success event, Algorithm 2 requires a number of samples bounded by $C(\cdots)$ for some absolute constant C. Non-asymptotic bounds with numerical constants are provided in Appendix B.

Once the sample complexity of Algorithm 2 is bounded, the complexity of best-arm identification and the cost of estimating the best arm to a target accuracy follow directly from the multi-armed bandit literature. This enables us to state an overall result for Adaptive Softmax (Algorithm 1) in the following theorem.

Theorem 1. For input ε ∈ (0, 1/2), δ ∈ (0, 1), and σ satisfying Assumption 1, Algorithm 1 identifies the largest component of $\sigma_\beta(Ax)$ and estimates its value to multiplicative accuracy ε with probability at least 1 − δ, as in (4). On this success event, the algorithm uses T samples, where

$$T \le C\sigma^2\left(\cdots + \sum_i \frac{\log(n\log d/\delta)}{\Delta_i^2} + \frac{\beta^2\log(n/\delta)\left(\sum_j \gamma_j\right)^2}{\epsilon\sum_j \alpha_j} + \frac{\beta^2\log(1/\delta)}{\epsilon^2}\right)$$

for some absolute constant C. Tighter bounds with non-asymptotic numerical constants are provided in Appendix B.

The proofs of these two results are detailed in Appendix B; here we provide some intuition and brief proof sketches. For Proposition 1, we first show that we can estimate the quantities $\{\alpha_i\}$ and $\{\gamma_i\}$ to constant multiplicative error with high probability. This allows us to construct a sampling scheme based on the asymptotically optimal sampling frequencies. It also guarantees that each arm is sampled at least half as often as this asymptotically optimal frequency requires. We then sample each arm i enough times that the first-order Taylor expansion of $e^{\beta\hat\mu_i}$ is sufficiently accurate. After that, we can sample further according to these frequencies to guarantee PAC estimation.
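For intuition about why the first- and second-order frequencies are $\alpha_i$ and $\gamma_i$, the following informal calculation may help. It is a heuristic sketch, not the proof in Appendix B. It assumes $\hat\mu_i = \mu_i + \xi_i$ with $\mathbb{E}[\xi_i] = 0$ and $\mathbb{E}[\xi_i^2] \approx \sigma^2/n_i$ after $n_i$ samples, which presumes a normalization in which a single pull has variance proxy σ². It expands each term of $\hat Z_\beta = \sum_i e^{\beta\hat\mu_i}$ to second order and drops constants and logarithmic factors. Let $N = \sum_i n_i$ denote the total budget:

$$e^{\beta\hat\mu_i} = \alpha_i e^{\beta\xi_i} \approx \alpha_i\left(1 + \beta\xi_i + \tfrac{1}{2}\beta^2\xi_i^2\right)$$

$$\text{first order (fluctuation):}\quad \operatorname{Var}\Big(\sum_i \beta\alpha_i\xi_i\Big) \approx \beta^2\sigma^2\sum_i \frac{\alpha_i^2}{n_i} \ \text{is minimized by}\ n_i = \frac{N\alpha_i}{\sum_j \alpha_j}, \ \text{giving relative standard deviation}\ \frac{\beta\sigma}{\sqrt{N}}$$

$$\text{second order (bias):}\quad \mathbb{E}\Big[\sum_i \tfrac{1}{2}\beta^2\alpha_i\xi_i^2\Big] \approx \tfrac{1}{2}\beta^2\sigma^2\sum_i \frac{\alpha_i}{n_i} \ \text{is minimized by}\ n_i = \frac{N\gamma_i}{\sum_j \gamma_j}, \ \text{giving relative bias}\ \frac{\beta^2\sigma^2\left(\sum_j \gamma_j\right)^2}{2N\sum_j \alpha_j}$$

Both minimizers follow from the same fact: minimizing $\sum_i a_i/n_i$ subject to $\sum_i n_i = N$ gives $n_i \propto \sqrt{a_i}$. Take $a_i = \alpha_i^2$ for the first case and $a_i = \alpha_i$ for the second, and recall that $\gamma_i = \sqrt{\alpha_i}$. Driving the relative standard deviation below ε requires $N \gtrsim \beta^2\sigma^2/\epsilon^2$. Driving the relative bias below ε requires $N \gtrsim \beta^2\sigma^2(\sum_j\gamma_j)^2/(\epsilon\sum_j\alpha_j)$. Since $\sum_j \alpha_j \le (\sum_j \gamma_j)^2 \le n\sum_j \alpha_j$, the second requirement lies between $\beta^2\sigma^2/\epsilon$ and $n\beta^2\sigma^2/\epsilon$, in line with the $\epsilon^{-2}$ and $\epsilon^{-1}$ terms discussed in Section 4.1. Algorithm 2 implements this allocation using the estimates $\hat\alpha_i$ and $\hat\gamma_i$ in place of the unknown $\alpha_i$ and $\gamma_i$.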
This normalization estimation procedure is an improved and specialized modification of [8]. It exploits the structure of the softmax function to remove the assumption of Lipschitz gradients and to yield improved sample complexity (discussed further in Appendix B.3). Next, we use a classical best-arm identification algorithm to identify the best arm with high probability, leveraging standard results in Bandit-Based Monte Carlo Optimization [6]. Finally, we sample the identified best arm enough times to estimate its value to multiplicative accuracy ε/4 with high probability. By union-bounding over these error probabilities, we achieve the desired PAC guarantees.

### 4.1 Interpreting Theoretical Results

We now simplify and further interpret the sample complexity results in Theorem 1. First, note that the $\epsilon^{-2}$ dependence exhibited by Normalization Estimation (Algorithm 2) is optimal, since it is inherent even in estimating the mean of the best arm to accuracy ε. The cost stemming from the second-order error scales as $\epsilon^{-1}$ and is bounded between $\beta^2\sigma^2\log(n/\delta)\epsilon^{-1}$ and $n\beta^2\sigma^2\log(n/\delta)\epsilon^{-1}$. When one arm is much better than the rest, this cost matches the lower end of that range. Concretely, we analyze the setting where the minimum gap is Δ, i.e., $\mu_1 - \mu_i \ge \Delta$ for all $i \ne 1$ (equivalently, $\Delta_i \ge \Delta$ for all i).

Corollary 1. Under the conditions of Theorem 1, suppose the minimum gap is at least Δ. Then Algorithm 1 identifies and provides (ε, δ)-PAC estimation (Equation (4)) of the largest softmax entry, using

$$T \le C\left(\cdots + \beta^2\sigma^2\epsilon^{-2}\log(1/\delta) + \frac{n\sigma^2\log(n\log d/\delta)}{\Delta^2}\right)$$

samples for some universal constant C. Suppose, in addition, that the gap is large ($\Delta \ge \frac{2}{\beta}\log n$), β is not too small, and $d < e^{e^n}$ (see Equation (47) for a precise statement). Then this bound simplifies to

$$T \le C\beta^2\sigma^2\log n\left(n + \epsilon^{-1} + \epsilon^{-2}\log(1/\delta)\right),$$

where all sample complexities are for the 1 − δ success event. The proof of this upper bound is in Appendix B.1. Note that this directly implies that when the gap is large (i.e.,
there is a clear largest output element) and ε is constant, the sample complexity is nearly linear in n and is upper bounded by $C\beta^2\sigma^2 n\log(n/\delta)$.

### 4.2 Implementation Details and Extensions

There are many techniques that we can use to extend and improve Adaptive Softmax in practice. We discuss changes from the written algorithm in detail in Appendix C.

Randomized Hadamard Transform: The variance proxy bound σ² of the arms plays a large role in the sample complexity of Adaptive Softmax. The underlying sub-Gaussianity parameter of these estimators can be improved using techniques from randomized numerical linear algebra, such as the randomized Hadamard transform [42]. When the transform is applied to both the rows of A and to x, it is orthogonal and therefore leaves every inner product $A_i^\top x$ unchanged. If a small number of entries dramatically increase the variance of the estimator, the randomized Hadamard transform makes the coordinates more uniform. We provide theoretical guarantees for this approach in Appendix A.3.1.

Top-k Identification: Extending our algorithmic results from best-arm identification (top 1) to identifying multiple components (top k) follows directly from existing multi-armed bandit algorithms. Numerous algorithms have been developed for this setting [21], and variants for computational settings have been developed and studied in [6]. For simplicity and clarity, we focus on top-1 identification in this paper, but the top-k extension readily follows. Furthermore, in numerical experiments we observe that estimating the normalization constant $Z_\beta$ dominates the sample complexity. The additional cost of identifying the top k arms and estimating their values to multiplicative accuracy ε/4 is minimal.

Relaxing the Assumption of a Known Sub-Gaussian Parameter σ²: Assumptions regarding known arm concentration parameters are common in multi-armed bandit works and simplify theoretical exposition. These results can naturally be extended in several directions.
One simple extension is to the setting where we have a separate sub-Gaussianity parameter $\sigma_i^2$ for each arm, i.e., heterogeneous variances. A more practical extension is to the setting where we do not have a bound on the sub-Gaussianity parameter of each arm but know that the arm returns are bounded. In this setting, a common multi-armed bandit approach is to use the empirical variance [35]. These approaches are discussed further in [8].

Improved Estimators $\hat\mu$: Naïvely, the Adaptive Softmax algorithm samples coordinates uniformly at random with replacement from the set of coordinates {1, ..., d} to estimate each $\sum_j A_{ij}x_j$. This procedure can be improved in several ways. For example, we may use importance sampling and sample each coordinate j with probability $z_j \propto |x_j|$. Furthermore, we can sample coordinates without replacement, which is known to yield tighter confidence intervals than sampling with replacement [9]. We can combine these techniques and compute the effective variance as in [18]. Sampling without replacement can be performed in a computationally efficient manner via Gumbel sampling [27]. We discuss these details further in Appendix A.3; they may be of independent interest.

## 5 Experiments

In this section, we demonstrate the empirical advantages of Adaptive Softmax over the brute-force softmax computation in terms of sample complexity. All of our results are reproducible via a one-line script, publicly available on GitHub at github.com/ThrunGroup/adaptiveSoftmax.

### 5.1 Complexity on Synthetic Data

Crucially, the Adaptive Softmax algorithm scales sublinearly in d. More precisely, Corollary 1 implies that, for fixed ε and δ, the sample complexity of the Adaptive Softmax algorithm scales as O(n log n). We now empirically validate this behavior. We first run the Adaptive Softmax algorithm on two synthetic datasets. In each dataset, we generate A and x with n = 100 and vary d.
In the first synthetic dataset, we set x to be a d-dimensional vector of all 1s. We draw each element of A i.i.d. from N(0, 1) and add the all-ones vector to the first row of A, thereby planting a signal. In expectation, the first row of A has inner product d with x, whereas all other rows have inner product 0 with x. Furthermore, all arms have expected variance $\sigma_i^2$ that scales with d.

[Figure 1 plot: (a) Scaling Baseline: All-Ones Query; (b) Scaling Baseline: Sign Query. Each panel shows sample complexity versus dimension d (10k to 90k) for the naive computation and Adaptive Softmax.]

Figure 1: Sample complexity of the Adaptive Softmax algorithm and the brute-force softmax computation on two different synthetic datasets as a function of d. Error bars are obtained from 100 random trials; confidence intervals are ±1 std. As expected, the sample complexity of the Adaptive Softmax algorithm scales with d for (a) but not for (b). The average gains for δ = 10% and ε = 30% are 3.953x for (a) and 29.039x for (b), and they increase with the dimension.

In the second synthetic dataset, we draw each element of A i.i.d. from N(0, 1) and set x to be $|A_{1,:}|$, the entrywise absolute value of the first row of A. Here, arms have expected variance $\sigma_i^2 = \Theta(1)$. Figures 1(a) and 1(b) show the scaling of the Adaptive Softmax algorithm on each of the two datasets. On the first synthetic dataset, the sample complexity of the Adaptive Softmax algorithm scales with d because the variance proxies $\sigma_i^2$ do. On the second synthetic dataset, it does not exhibit significant scaling with d. On both datasets, the Adaptive Softmax algorithm significantly outperforms the naïve brute-force computation of the softmax function.
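The first synthetic construction is easy to reproduce. The sketch below uses only the Python standard library, and its function names and the choice d = 2000 are illustrative (the experiments vary d from 10k to 90k). It builds the all-ones query with a planted first row and reports the brute-force cost nd against which the gains in Figure 1 are measured.

```python
import random


def all_ones_dataset(n, d, seed=0):
    """A with i.i.d. N(0, 1) entries plus 1 on the first row; x = all-ones vector."""
    rng = random.Random(seed)
    A = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    A[0] = [a + 1.0 for a in A[0]]   # plant the signal: E[A_1 . x] = d
    x = [1.0] * d                    # other rows satisfy E[A_i . x] = 0
    return A, x


def matvec(A, x):
    """Exact logits mu = Ax, costing n * d multiplications."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


n, d = 100, 2000
A, x = all_ones_dataset(n, d)
mu = matvec(A, x)
naive_cost = n * d                   # brute-force sample complexity n*d
best = max(range(n), key=mu.__getitem__)
```

The planted row's logit concentrates around d with fluctuations of order $\sqrt{d}$, while the other logits stay of order $\sqrt{d}$. This clear gap is the regime in which adaptive sampling can stop early.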
### 5.2 Multinomial Logistic Regression

Multinomial logistic regression (MNL) is a form of multiclass classification in which the final operation performed by the classifier has the form

$$\mathbb{P}(y = c) = \frac{e^{\beta\, w_c^\top h(x)}}{\sum_{c'=1}^{C} e^{\beta\, w_{c'}^\top h(x)}}, \qquad (5)$$

i.e., the probability that datapoint x belongs to class c is given by the softmax applied to the vector Wh(x). Here W is the matrix with rows $w_1, \dots, w_C$, and h(x) is a latent representation of x (i.e., the output of the forward pass of some neural network on x). The softmax computation in Equation (5) is naturally amenable to acceleration. In many real-world settings, both the number of classes C and the dimension of the latent representation h(x) (and therefore the dimensionality of each $w_c$) can be very large. This motivates using Adaptive Softmax to identify the most likely class and estimate its probability. The applicability of Adaptive Softmax, however, extends far beyond vanilla MNL. For instance, the final layer of any neural network classifier that uses a softmax can also be viewed as an MNL problem. We now describe several such practical settings in which we demonstrate the benefits of applying the Adaptive Softmax algorithm.

### 5.3 Adaptive Softmax Performance on Real Data

We now demonstrate the performance of the Adaptive Softmax algorithm on several real-world datasets. For each setting, we report the sample complexity gain relative to nd, the sample complexity of the brute-force, naïve softmax computation. We also report the success rate of our algorithm in each setting. This is the proportion of times the Adaptive Softmax algorithm both correctly identifies the maximum-likelihood output (i.e., $\hat i^* = i^*$) and estimates its probability $p_{i^*}$ within a multiplicative error of ε = 30%.

#### 5.3.1 Application to CNNs

We consider the application of Adaptive Softmax to CNN classifiers on two distinct image classification datasets:

1.
The MNIST dataset, containing black-and-white images of handwritten digits as input and ten output classes representing the ten possible digits.
2. The EuroSAT dataset, containing RGB satellite imagery as input and ten output classes representing possible land types (e.g., river, residential).

On both datasets and for distinct architectures, we show that Adaptive Softmax provides a drastic improvement in sample efficiency.

MNIST: For the MNIST dataset, we train a shallow CNN from scratch with two convolutional blocks (Conv2d, ReLU, MaxPool, BatchNorm). This model achieves over 99% accuracy on the test set. The matrix A is obtained by extracting the weight matrix of the model's final linear layer. The vector x is the output of the final hidden layer (the layer before the final linear layer), obtained by passing the MNIST image through the trained model and flattening the result. The dimensionality of x is adjusted by changing the number of output channels of the convolutional blocks. The sample complexity of our algorithm is measured by running the algorithm on 1000 different test-set images with the same matrix A. The empirical error rate δ is the fraction of experiments in which the adaptive algorithm either fails to identify the class assigned by exact computation or fails to estimate its probability to accuracy ε.

EuroSAT: We also use a larger pre-trained CNN classifier fine-tuned on the EuroSAT dataset, to show that Adaptive Softmax works with larger, more sophisticated CNNs. Specifically, we freeze all convolutional blocks of VGG-19 (pre-trained on ImageNet) and change the final output dimension to 10 classes for EuroSAT, leaving the weights of this final layer unfrozen. The resulting model achieves 92% accuracy on the test set. As before, the matrix A is extracted from the weights of the final linear layer, and the vector x represents the final hidden layer activations.
The empirical error rate δ is calculated in the same manner as for MNIST.

| Dataset (Model) | δ = 10% | δ = 5% | δ = 1% |
|---|---|---|---|
| EuroSAT (VGG-19) | 5.18x (80.62%) | 5.16x (83.00%) | 4.54x (98.37%) |
| MNIST (Shallow CNN) | 8.95x (92.25%) | 8.81x (93.75%) | 8.13x (99.38%) |

Table 1: Performance improvement and success rate afforded by Adaptive Softmax for multinomial logistic regression on two different real-world datasets. Each cell reports the sample complexity gain, with the success rate in parentheses. We used a total of q = 800 test queries to measure the success rate.

#### 5.3.2 Application to LLMs

We also apply the Adaptive Softmax algorithm to LLMs using Hugging Face's AutoModelForCausalLM module for the text-generation task [45]. The matrix A is the lm-head layer of each model. The queries x are the model's final hidden states, extracted by running a forward pass of the model over the given dataset with a window moving at a certain stride. The context window and stride are adjusted to generate the desired number of queries.

| Dataset (Model) | δ = 10% | δ = 5% | δ = 1% |
|---|---|---|---|
| Wikitext (GPT-2) | 8.25x (88.94%) | 7.80x (93.54%) | 6.67x (98.26%) |
| Wikitext (Llama3-7B) | 14.68x (91.44%) | 11.43x (94.04%) | 6.88x (99.38%) |
| Wikitext (Mistral7B) | 32.65x (89.08%) | 26.37x (91.20%) | 17.71x (97.77%) |
| Penn Treebank (GPT-2) | 8.10x (81.68%) | 7.50x (90.73%) | 6.66x (96.79%) |
| Penn Treebank (Llama3-7B) | 19.18x (87.82%) | 16.57x (91.60%) | 10.72x (97.81%) |

Table 2: Performance improvement and success rate afforded by Adaptive Softmax for LLM inference (improvement for the final softmax layer). Each cell reports the sample complexity gain, with the success rate in parentheses. Experiment details are in Section 5.3.2. We used q = 1000 unseen test queries to measure δ-accuracy.

Our matrix A is the lm-head extracted from Hugging Face's AutoModelForCausalLM for the four models: GPT-2 (n = 50257, d = 768), Llama3-7B (n = 128256, d = 4096), Mistral7B (n = 32000, d = 4096), and Gemma7B (n = 256000, d = 3072). Our task is text generation, and we generate our queries x from two datasets (Wikitext and Penn Treebank) using a sliding window with a certain stride.
The stride and context window are set to obtain q = 1000 queries.

The constants and confidence intervals given by theory are empirically quite loose, so we tuned the algorithm parameters (constant coefficients for the stage lengths and confidence-interval widths) on initial training data, as described in Appendix C. An aggressive tuning strategy was adopted to demonstrate the potential gains in sample complexity provided by Adaptive Softmax. Specifically, constant multipliers were applied to the variance estimates within Algorithm 3 and Algorithm 2. Because of the limited sample set, this approach occasionally overfit the constants to the training data, yielding lower success rates than targeted. Nevertheless, the results show that the target parameter δ still gives users sufficient control over the tradeoff between true success rate and sample complexity.

6 Discussion, Limitations, and Future Work

In this work, we proposed a theoretically novel and practically efficient algorithm for approximating the softmax function. We provided theoretical guarantees on the accuracy of our approximation and demonstrated that, with fewer samples than exact computation, we can approximate the softmax function to the desired accuracy. We further demonstrated the viability of our algorithm in two real-world settings, multinomial logistic regression and LLM inference. A potential limitation of our algorithm is that it is most beneficial when the inner dimension of the matrix-vector product is large; its benefits over exact computation are more modest when the inner dimension is small. In particular, the exact matrix-vector product preceding a softmax operation is usually computed efficiently using BLAS (Basic Linear Algebra Subprograms, which are highly optimized). Adaptivity is inherently sequential, whereas BLAS operations take advantage of batch computation.
In this work we proposed minimally adaptive algorithms, with only a logarithmic number of rounds of adaptivity, but realizing these theoretical gains in practice remains an important direction for future work. Limitations: Theoretical sample complexity bounds are useful for understanding the fundamental properties of an algorithm, but in practice wall-clock time is often the metric of interest. Many of the steps in our algorithm can be batched and made BLAS-efficient, yielding wall-clock times comparable to brute-force computation. However, adaptivity is in general the opposite of batching, as can be seen when we modify our algorithm to adapt to arm-specific variances: we must then sample each arm individually, since each arm requires a different number of samples. This is a tradeoff between adaptivity and wall-clock time, and which to prioritize depends on the specific application (energy efficiency, computational resources, etc.). Intermediate designs are also open to theoretical analysis; for example, one can group arms with similar empirical variance into batches and sample all arms within a batch together, trading adaptivity for batched computational efficiency. Additionally, in large language models the final softmax layer is often not a computationally significant step, so while such a method may greatly accelerate multinomial logistic regression, more work may be required to accelerate LLMs. Given the ubiquity of the softmax function in today's machine learning workflows, we hope that our algorithm will help pave the way for an optimized adaptive softmax that can accelerate a wide class of machine learning models. An interesting direction for future work is to combine this multi-armed bandit approach with LSH [15] to obtain, in the attention case, complexity subquadratic in n and sublinear in d.
The adaptive method we propose for estimating the normalizing constant of the softmax function is of independent interest, and holds potential for applications in kernel density estimation and other machine learning tasks.

Acknowledgements

Mert Pilanci was supported in part by the National Science Foundation (NSF) under Grant DMS2134248; in part by the NSF CAREER Award under Grant CCF-2236829; in part by the U.S. Army Research Office Early Career Award under Grant W911NF-21-1-0242; and in part by the Office of Naval Research under Grant N00014-24-1-2164. Tavor Baharav was supported by funding from the Eric and Wendy Schmidt Center at the Broad Institute of MIT and Harvard.

References

[1] Josh Alman and Zhao Song. "Fast attention requires bounded entries". In: Advances in Neural Information Processing Systems 36 (2024).
[2] Alexandr Andoni et al. "Practical and optimal LSH for angular distance". In: Advances in Neural Information Processing Systems 28 (2015).
[3] Jean-Yves Audibert, Rémi Munos, and Csaba Szepesvári. "Exploration-exploitation tradeoff using variance estimates in multi-armed bandits". In: Theoretical Computer Science 410.19 (2009), pp. 1876-1902.
[4] Vivek Bagaria et al. "Bandit-Based Monte Carlo Optimization for Nearest Neighbors". In: IEEE Journal on Selected Areas in Information Theory (2021).
[5] Vivek Bagaria et al. "Medoids in almost-linear time via multi-armed bandits". In: International Conference on Artificial Intelligence and Statistics (2018), pp. 500-509.
[6] Tavor Baharav and Tze Leung Lai. "Adaptive Data Depth via Multi-Armed Bandits". In: Journal of Machine Learning Research 24.155 (2023), pp. 1-29.
[7] Tavor Baharav and David Tse. "Ultra fast medoid identification via correlated sequential halving". In: Advances in Neural Information Processing Systems 32 (2019).
[8] Tavor Baharav et al. "Approximate Function Evaluation via Multi-Armed Bandits". In: International Conference on Artificial Intelligence and Statistics. PMLR. 2022, pp. 108-135.
[9] Rémi Bardenet and Odalric-Ambrym Maillard. "Concentration inequalities for sampling without replacement". In: Bernoulli 21.3 (2015), pp. 1361-1385.
[10] Yoshua Bengio and Jean-Sébastien Senécal. "Adaptive importance sampling to accelerate training of a neural probabilistic language model". In: IEEE Transactions on Neural Networks 19.4 (2008), pp. 713-722.
[11] Yoshua Bengio and Jean-Sébastien Senécal. "Quick training of probabilistic neural nets by importance sampling". In: International Workshop on Artificial Intelligence and Statistics. PMLR. 2003, pp. 17-24.
[12] Guy Blanc and Steffen Rendle. "Adaptive sampled softmax with kernel based sampling". In: International Conference on Machine Learning. PMLR. 2018, pp. 590-599.
[13] Sébastien Bubeck, Nicolo Cesa-Bianchi, et al. "Regret analysis of stochastic and nonstochastic multi-armed bandit problems". In: Foundations and Trends in Machine Learning 5.1 (2012), pp. 1-122.
[14] Wenlin Chen, David Grangier, and Michael Auli. "Strategies for Training Large Vocabulary Neural Language Models". In: Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics. 2016, pp. 1975-1985.
[15] Insu Han et al. "HyperAttention: Long-context Attention in Near-Linear Time". In: arXiv preprint arXiv:2310.05869 (2023).
[16] Eshcar Hillel et al. "Distributed exploration in multi-armed bandits". In: Advances in Neural Information Processing Systems 26 (2013).
[17] Ari Holtzman et al. "The curious case of neural text degeneration". In: arXiv preprint arXiv:1904.09751 (2019).
[18] Daniel G Horvitz and Donovan J Thompson. "A generalization of sampling without replacement from a finite universe". In: Journal of the American Statistical Association 47.260 (1952), pp. 663-685.
[19] Piotr Indyk and Rajeev Motwani. "Approximate nearest neighbors: towards removing the curse of dimensionality". In: Proceedings of the Thirtieth Annual ACM Symposium on Theory of Computing. 1998, pp. 604-613.
[20] Andrei Ivanov et al. "Data movement is all you need: A case study on optimizing transformers". In: Proceedings of Machine Learning and Systems 3 (2021), pp. 711-732.
[21] Kevin Jamieson and Robert Nowak. "Best-arm identification algorithms for multi-armed bandits in the fixed confidence setting". In: Annual Conference on Information Sciences and Systems. 2014, pp. 1-6.
[22] Hervé Jégou et al. "Searching in one billion vectors: re-rank with source coding". In: 2011 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE. 2011, pp. 861-864.
[23] Armand Joulin et al. "Efficient softmax approximation for GPUs". In: International Conference on Machine Learning. PMLR. 2017, pp. 1302-1310.
[24] Govinda Kamath, Tavor Baharav, and Ilan Shomorony. "Adaptive learning of rank-one models for efficient pairwise sequence alignment". In: Advances in Neural Information Processing Systems 33 (2020), pp. 7513-7525.
[25] Zohar Karnin, Tomer Koren, and Oren Somekh. "Almost optimal exploration in multi-armed bandits". In: International Conference on Machine Learning. 2013, pp. 1238-1246.
[26] Levente Kocsis and Csaba Szepesvári. "Bandit based Monte-Carlo planning". In: European Conference on Machine Learning. Springer. 2006, pp. 282-293.
[27] Wouter Kool, Herke Van Hoof, and Max Welling. "Stochastic beams and where to find them: The Gumbel-top-k trick for sampling sequences without replacement". In: International Conference on Machine Learning. PMLR. 2019, pp. 3499-3508.
[28] Tze Leung Lai and Herbert Robbins. "Asymptotically efficient adaptive allocation rules". In: Advances in Applied Mathematics 6.1 (1985), pp. 4-22.
[29] Tor Lattimore and Csaba Szepesvári. Bandit algorithms. Cambridge University Press, 2020.
[30] Daniel LeJeune, Reinhard Heckel, and Richard Baraniuk. "Adaptive estimation for approximate k-nearest-neighbor computations". In: The 22nd International Conference on Artificial Intelligence and Statistics. PMLR. 2019, pp. 3099-3107.
[31] Lisha Li et al. "Hyperband: A novel bandit-based approach to hyperparameter optimization". In: The Journal of Machine Learning Research 18.1 (2017), pp. 6765-6816.
[32] Stephan S Lorenzen and Ninh Pham. "Revisiting wedge sampling for budgeted maximum inner product search". In: Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer. 2020, pp. 439-455.
[33] Pinyan Lu, Chao Tao, and Xiaojin Zhang. "Variance-dependent best arm identification". In: Uncertainty in Artificial Intelligence. PMLR. 2021, pp. 1120-1129.
[34] Blake Mason, Ardhendu Tripathy, and Robert Nowak. "Nearest neighbor search under uncertainty". In: Uncertainty in Artificial Intelligence. PMLR. 2021, pp. 1777-1786.
[35] Andreas Maurer and Massimiliano Pontil. "Empirical Bernstein bounds and sample variance penalization". In: arXiv preprint arXiv:0907.3740 (2009).
[36] Frederic Morin and Yoshua Bengio. "Hierarchical probabilistic neural network language model". In: International Workshop on Artificial Intelligence and Statistics. PMLR. 2005, pp. 246-252.
[37] Kezban Dilek Onal et al. "Neural information retrieval: At the end of the early years". In: Information Retrieval Journal 21 (2018), pp. 111-182.
[38] Ankit Singh Rawat et al. "Sampled softmax with random Fourier features". In: Advances in Neural Information Processing Systems 32 (2019).
[39] Max Simchowitz, Kevin Jamieson, and Benjamin Recht. "The simulator: Understanding adaptive sampling in the moderate-confidence regime". In: Conference on Learning Theory. PMLR. 2017, pp. 1794-1834.
[40] Aleksandrs Slivkins et al. "Introduction to multi-armed bandits". In: Foundations and Trends in Machine Learning 12.1-2 (2019), pp. 1-286.
[41] Mo Tiwari et al. "BanditPAM: Almost linear time k-medoids clustering via multi-armed bandits". In: Advances in Neural Information Processing Systems 33 (2020), pp. 10211-10222.
[42] Joel A Tropp. "Improved analysis of the subsampled randomized Hadamard transform". In: Advances in Adaptive Data Analysis 3 (2011), pp. 115-126.
[43] Tim Vieira. Gumbel-max trick and weighted reservoir sampling. 2014. URL: http://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/.
[44] Martin J Wainwright. High-dimensional statistics: A non-asymptotic viewpoint. Vol. 48. Cambridge University Press, 2019.
[45] Thomas Wolf et al. "HuggingFace's Transformers: State-of-the-art natural language processing". In: arXiv preprint arXiv:1910.03771 (2019).
[46] Martin Zhang, James Zou, and David Tse. "Adaptive Monte Carlo multiple testing via multi-armed bandits". In: International Conference on Machine Learning. Proceedings of Machine Learning Research. 2019, pp. 7512-7522.

A Bandit preliminaries

To make this work accessible to a broad audience, we provide a self-contained introduction to the multi-armed bandit setting.

A.1 Best arm identification

We consider a stochastic multi-armed bandit problem [13, 40, 29] with $n$ arms (distributions), where each arm $i$ has an unknown mean reward $\mu_i$. At each time step $t$, the algorithm selects an arm $I_t \in [n]$ and receives a reward $X_{I_t,t}$ drawn from the distribution of arm $I_t$. Early work in the multi-armed bandit literature focused on the regret minimization setting, where the goal is to maximize the cumulative reward (the sum of the rewards observed so far), motivated by applications such as online advertising and gambling [28]. More recently, there has been increased interest in the best-arm identification setting, where the goal is to identify the arm with the highest mean reward with high probability, motivated by applications such as clinical trials. Significant research has been devoted to obtaining optimal logarithmic factors, but for the sake of clarity we highlight here a simpler, empirically well-performing algorithm: multi-round ε-arm from [25].
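To make this setting concrete for the softmax problem, the sketch below treats each row of $A$ as an arm and runs a round-based elimination in the spirit of Algorithm 3. It is an illustration under stated assumptions, not the paper's implementation: an arm pull samples one coordinate $j$ uniformly and returns $d \cdot A_{ij} x_j$ (the naive choice $p_j = 1/d$ discussed in A.3.2), the variance proxy $\sigma^2$ is assumed known, and once the per-arm budget $t_r$ would reach $d$ the surviving means are computed exactly, since exact computation costs $d$ per arm. The names `pull` and `best_arm_id` are hypothetical.

```python
import math
import random

def pull(A, x, i, rng):
    """One pull of arm i: sample a coordinate j uniformly at random and return
    d * A[i][j] * x[j], an unbiased estimate of (Ax)_i (uniform p_j = 1/d)."""
    d = len(x)
    j = rng.randrange(d)
    return d * A[i][j] * x[j]

def best_arm_id(A, x, delta, sigma2, seed=0):
    """Round-based elimination in the spirit of Algorithm 3 (BestArmId).

    sigma2 is an assumed known variance proxy for the arm pulls. Once the
    per-arm budget t_r would reach d, the surviving means are computed exactly.
    """
    rng = random.Random(seed)
    n, d = len(A), len(x)
    alive = list(range(n))
    sums = [0.0] * n      # running sum of observed rewards for each arm
    t_prev, r = 0, 0
    while len(alive) > 1:
        r += 1
        eps_r = 2.0 ** (-r)
        log_term = math.log(4 * n * r * r / delta)
        t_r = math.ceil(8 * sigma2 * log_term / eps_r ** 2)
        if t_r >= d:
            # Sampling d or more coordinates is no cheaper than exact computation.
            return max(alive, key=lambda i: sum(a * b for a, b in zip(A[i], x)))
        for i in alive:
            for _ in range(t_r - t_prev):   # top up each survivor to t_r pulls
                sums[i] += pull(A, x, i, rng)
        means = {i: sums[i] / t_r for i in alive}
        width = math.sqrt(2 * sigma2 * log_term / t_r)   # C_{i,r}, at most eps_r / 2
        best_lcb = max(means[i] - width for i in alive)
        alive = [i for i in alive if means[i] + width >= best_lcb]  # filter bad arms
        t_prev = t_r
    return alive[0]

d = 2000
centers = [0.0, 0.1, 0.2, 0.3, 0.8]
# Row i alternates c_i - 0.5 and c_i + 0.5, so its mean entry is exactly c_i.
A = [[c + (0.5 if j % 2 == 0 else -0.5) for j in range(d)] for c in centers]
x = [1.0 / d] * d   # each pull then equals A[i][j], which lies in an interval of width 1
print(best_arm_id(A, x, delta=0.05, sigma2=0.25))
```

On this synthetic instance the pulls lie in an interval of width 1, so $\sigma^2 = 1/4$ by Hoeffding's lemma, and the sketch should return arm 4. Because $\epsilon_r$ halves every round, the per-arm budget grows by roughly a factor of 4 per round, which is why only a logarithmic number of rounds is needed before either one arm survives or exact computation takes over.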
Algorithm 3 BestArmId (modification of best-arm identification from [25])
Input: $n$ arms, error probability $\delta$, variance proxy $\sigma^2$
Output: Best arm $i^*$ with probability at least $1-\delta$
  $S_0 \leftarrow [n]$; $r \leftarrow 0$; $t_0 \leftarrow 0$
  while $|S_r| > 1$ do
    $r \leftarrow r + 1$
    $\epsilon_r \leftarrow 2^{-r}$
    $t_r \leftarrow 8\sigma^2\epsilon_r^{-2}\log(4nr^2/\delta)$
    for all arms $i \in S_{r-1}$ do
      Pull arm $i$ $t_r - t_{r-1}$ times and observe rewards $X_{i,t_{r-1}+1}, \dots, X_{i,t_r}$
      $\hat\mu_{i,r} \leftarrow \frac{1}{t_r}\sum_{s=1}^{t_r} X_{i,s}$ ▷ Update mean estimates
      $C_{i,r} \leftarrow \sqrt{2\sigma^2\log(4nr^2/\delta)/t_r}$ ▷ Compute confidence interval width
    end for
    $S_r \leftarrow \{i \in S_{r-1} : \hat\mu_{i,r} + C_{i,r} \ge \max_{j \in S_{r-1}} (\hat\mu_{j,r} - C_{j,r})\}$ ▷ Filter bad arms
  end while
  return $i^*$, the only element in $S_r$

We assume $i^*$ is unique (this assumption is easily relaxed).

The algorithm proceeds in rounds, maintaining a set $S_r$ of arms that are still in contention for being the best arm. In each round $r$, the algorithm pulls each surviving arm enough times to construct a high-probability confidence interval of width $\epsilon_r/2$ around the mean of each arm (indeed, $C_{i,r} = \epsilon_r/2$ for the chosen $t_r$). Then, any arm whose empirical mean plus confidence width is less than the largest empirical mean minus confidence width is eliminated: with high probability, that arm's mean is less than the maximum mean, so it is removed from contention. This preserves the best arm with high probability, and the algorithm terminates when only one arm (the best arm) remains.

A.2 Sub-gaussian random variables

Following the exposition of [44], for a strictly increasing convex function $\psi : \mathbb{R}_+ \to \mathbb{R}_+$ with $\psi(0) = 0$, the $\psi$-Orlicz norm of a random variable $X$ is defined as
$$\|X\|_\psi \triangleq \inf\{t > 0 : \mathbb{E}[\psi(t^{-1}|X|)] \le 1\}.$$
$\|X\|_\psi$ is infinite if there is no finite $t$ for which $\mathbb{E}[\psi(t^{-1}|X|)] \le 1$. The sub-Gaussian parameter of a random variable is defined as its $\psi_2$-Orlicz norm, where $\psi_2(u) = e^{u^2} - 1$.
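Because $\mathbb{E}[\psi_2(|X|/t)]$ is decreasing in $t$, the infimum in the Orlicz-norm definition can be found by bisection for any finitely supported $X$. The sketch below (hypothetical helper names) does exactly that. As a check, for a Rademacher variable ($X = \pm 1$ with probability $1/2$ each) the constraint $e^{1/t^2} - 1 \le 1$ gives $\|X\|_{\psi_2} = 1/\sqrt{\ln 2} \approx 1.2011$.

```python
import math

def psi2(u):
    """psi_2(u) = exp(u^2) - 1, returning inf instead of overflowing."""
    s = u * u
    return math.inf if s > 700.0 else math.expm1(s)

def orlicz_psi2(values, probs, rel_tol=1e-12):
    """psi_2-Orlicz norm inf{t > 0 : E[psi_2(|X| / t)] <= 1} of a finitely
    supported X taking value values[k] with probability probs[k]."""
    support = [(abs(v), p) for v, p in zip(values, probs) if p > 0]
    m = max(v for v, _ in support)
    if m == 0.0:
        return 0.0

    def excess(t):
        # E[psi_2(|X| / t)], which is decreasing in t
        return sum(p * psi2(v / t) for v, p in support)

    lo, hi = 0.0, m
    while excess(hi) > 1.0:      # enlarge hi until the constraint holds
        hi *= 2.0
    while hi - lo > rel_tol * hi:  # bisection on the monotone constraint
        mid = 0.5 * (lo + hi)
        if excess(mid) <= 1.0:
            hi = mid
        else:
            lo = mid
    return hi

print(orlicz_psi2([-1.0, 1.0], [0.5, 0.5]))  # about 1.2011, i.e. 1/sqrt(ln 2)
```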
Standard results (Hoeffding's lemma) provide that for a random variable $X$ such that $a \le X \le b$ almost surely, $\|X\|_{\psi_2} \le (b-a)^2/4$.

[Figure 2 panels: histograms of arm pull values for the worst, middle, and best arm, for gpt2, meta-llama/Meta-Llama-3-8B, mistralai/Mistral-7B-v0.1, and google/gemma-7b, each on wikitext and penn_treebank.]

Figure 2: Distribution of arm pulls for the best, middle, and worst arms for a random query. The Gaussian fit plotted for each arm is computed using the empirical variance of that arm, and closely matches the empirical distribution, indicating that the arm pull distributions are well approximated by a Gaussian. This supports the assumption of sub-Gaussianity. Arm pulls are importance-weighted samples based on the magnitude of the query vector, so for an arm $i$ and sample $j$, the value is $A_{i,j}\,\mathrm{sgn}(x_j)\,d$. $\sigma$ is computed across all samples for an arm, and windows are truncated at the lowest and highest ends of the $3\sigma$ ranges across arms for viewing clarity.

Hoeffding's inequality provides a useful concentration bound: for $X$ with $\|X\|_{\psi_2} \le \sigma^2$,
$$P(|X - \mathbb{E}[X]| \ge t) \le 2\exp\left(-\frac{t^2}{2\sigma^2}\right).$$

A.2.1 Sub-gaussianity in practice

The assumption of sub-Gaussianity is the only assumption that we make in this paper.
It is one of the weakest assumptions possible (it does not assume that the arms are Bernoulli or Gaussian), and is a common assumption in the multi-armed bandit and adaptive computation literature [4]. Unfortunately, without such an assumption no nontrivial results are possible. Consider the case where we do not have preprocessing access to $A$, the vector $x$ is all 1s, and $A$ is all 1s except for a randomly selected entry with value 2. In this case, any algorithm for PAC computation of $\mathrm{softmax}(Ax)$ with $\delta = 1 - 1/n$ (even just identification of the largest entry of $Ax$) requires $\Omega(nd)$ samples. In practice, however, these vectors are the output of a machine learning pipeline, not of an adversarial construction. In our simulations this worst-case scenario never arises, and arm pulls are generally well approximated by a Gaussian (see Figure 2). Additionally, note that for any fixed problem instance all arm pulls are bounded, and are thus sub-Gaussian.

A.3 Improved estimators

To provide theoretical guarantees in multi-armed bandit problems, stringent assumptions are often required (e.g., Assumption 1), so that we can provide high-probability guarantees on the concentration of the estimator constructed as the empirical mean of the observed samples. These assumptions are often phrased as either that the random variables corresponding to arm pulls are bounded in $[0, 1]$ almost surely, or that they are $\sigma^2$-sub-Gaussian with a known bound on $\sigma^2$. Such analyses have been generalized to bounded random variables with a known bound, where the algorithm is able to adapt to the unknown variance [33, 35]. In this work, as is often done to make multi-armed bandit algorithms more performant [4], we instead directly use the empirical variance of an estimator as its sub-Gaussian parameter.
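The heuristic of plugging the empirical variance into a sub-Gaussian bound can be sketched as follows. The function name is hypothetical; the radius inverts the two-sided Hoeffding bound $P(|\bar X - \mu| \ge t) \le 2\exp(-Tt^2/(2\sigma^2))$ for the mean of $T$ samples, with $\sigma^2$ replaced by the unbiased sample variance $\hat\sigma^2$. Because $\hat\sigma^2$ is itself random, the result is a heuristic rather than a guaranteed $1-\delta$ interval; an empirical Bernstein bound [35] is the standard way to account for that extra randomness.

```python
import math
import statistics

def empirical_hoeffding_interval(samples, delta):
    """Heuristic two-sided confidence interval for the mean of i.i.d. samples.

    Inverts P(|mean - mu| >= t) <= 2 exp(-T t^2 / (2 sigma^2)) with the unknown
    sub-Gaussian parameter sigma^2 replaced by the empirical variance.
    """
    T = len(samples)
    if T < 2:
        raise ValueError("need at least two samples to estimate a variance")
    mean = statistics.fmean(samples)
    sigma2_hat = statistics.variance(samples)      # unbiased sample variance
    radius = math.sqrt(2.0 * sigma2_hat * math.log(2.0 / delta) / T)
    return mean - radius, mean + radius

lo, hi = empirical_hoeffding_interval([1.0, 2.0, 3.0, 4.0], delta=0.05)
print(round(lo, 3), round(hi, 3))  # 0.747 4.253
```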
In the Gaussian case, a random variable's sub-Gaussian parameter $\sigma^2$ matches its variance. Since our arm pulls are constructed as the sum of many ($d$) terms, which can be thought of as weakly dependent, our arm pulls can be thought of as Gaussian random variables with variance $\sigma^2$. Since in practice we do not have a good bound on the magnitude of these arm pulls, we directly use Hoeffding's concentration inequality [44] with the empirical variance $\hat\sigma^2$, as opposed to an empirical Bernstein-type concentration inequality [35].

A.3.1 Randomized Hadamard Transform

We discuss applying a randomized Hadamard transform to reduce the sub-Gaussian parameter of our estimators. Define the rotation matrix $R = \frac{1}{\sqrt{d}}DH$, where $D$ is a diagonal matrix whose diagonal entries are equiprobably $\pm 1$, and $H$ is a Hadamard matrix ($d$ must be a power of 2, zero-padded if necessary). Applying $R$ to $A$ and $R^\top$ to $x$ yields arms with better sub-Gaussian parameters. Concretely, define $Z = AR$ and $y = R^\top x$; since $RR^\top = \frac{1}{d}DHH^\top D = I$, we have $Zy = Ax$. Analyzing $y$ first, writing $R_i$ for the $i$-th column of $R$, no entry of $y$ is too large:
$$P(\|y\|_\infty \ge t) \le \sum_{i=1}^{d} P(|R_i^\top x| \ge t) \quad (6)$$
$$= \sum_{i=1}^{d} P\Big(\Big|\sum_{j=1}^{d}\xi_j d^{-1/2}x_j\Big| \ge t\Big) \quad (7)$$
$$\le \sum_{i=1}^{d} 2\exp\Big(-\frac{2t^2}{4\sum_{j=1}^{d}\frac{1}{d}x_j^2}\Big) \quad (8)$$
$$= 2d\exp\Big(-\frac{t^2 d}{2\|x\|_2^2}\Big). \quad (9)$$
The first inequality is a union bound over the $d$ entries of $y$, plugging in $y = R^\top x$. The second equality uses the fact that multiplying the Hadamard matrix by the random diagonal $\pm 1$ matrix $D$ makes the entries of each $R_i$ i.i.d. $\pm d^{-1/2}$; we use $\xi_j$ to denote these i.i.d. Rademacher ($\pm 1$) random variables. Next we use Hoeffding's inequality, since each summand lies in an interval of length $2|x_j|/\sqrt{d}$. Finally, we simplify. Thus, with probability at least $1-\delta$,
$$\|y\|_\infty < \|x\|_2\sqrt{\frac{2\log(2d/\delta)}{d}}.$$
A similar argument for $Z = AR$ shows that each entry in the $i$-th row of $Z$ is upper bounded by $\|A_i\|_2\sqrt{2\log(2nd/\delta)/d}$, holding simultaneously for all $i, j$. Since $Ax = Zy$, we can use bandits to approximate $Zy$ instead of $Ax$. With the above analysis, we have a bound on the maximum entry of $Z_{iJ}y_J$ (for a sampled coordinate $J$), giving us a bound on the sub-Gaussian parameter of each arm with probability $1-\delta$ (applying each of the two bounds above with failure probability $\delta/2$):
$$\max_{i,j}|Z_{ij}y_j| \le \max_i \|A_i\|_2\|x\|_2\,\frac{2\log(4nd/\delta)}{d} \quad (10)$$
Whereas before:
$$\max_{i,j}|A_{ij}x_j| \le \max_i \|A_i\|_\infty\|x\|_\infty \quad (11)$$
In practice, we did not observe that this transformation yielded much lower variance than simply using importance sampling, so we do not use it in our main algorithm.

A.3.2 Importance sampling

Instead of naively sampling each coordinate uniformly, we can construct improved estimators using importance sampling. Concretely, for a fixed row with entries $A_j$ and mean $\mu = \sum_j A_jx_j$, consider the unbiased estimator $Z$ with $P(Z = \frac{1}{p_j}A_jx_j) = p_j$ for some probability distribution $\{p_j\}$ with $p_j > 0$ for all $j$ and $\sum_j p_j = 1$. We are left with the design choice of how to construct $\{p_j\}$; naively, $p_j = 1/d$. Unpacking the variance of our estimator, we see
$$\mathrm{Var}(Z) = \sum_{j=1}^{d}\frac{1}{p_j}(A_jx_j)^2 - \mu^2. \quad (12)$$
Consider the case where $A_j \overset{\text{i.i.d.}}{\sim} Q$ for some distribution $Q$ with $\mathbb{E}[Q^2] = \lambda^2$ (an empirically not unreasonable assumption), where $\{x_j\}$ are fixed constants. We assume $x$ is known, but $A$ is not. Then the expected variance with respect to this randomness in $A$ simplifies to
$$\mathbb{E}[\mathrm{Var}(Z)] = \mathbb{E}\Big[\sum_{j=1}^{d}\frac{1}{p_j}(A_jx_j)^2\Big] - \mu^2 = \lambda^2\sum_{j=1}^{d}\frac{x_j^2}{p_j} - \mu^2. \quad (13)$$
Simplifying this for $Z_{\text{naive}}$, where $p_j = \frac{1}{d}$, and for $Z_{\text{opt}}$, where $p_j \propto |x_j|$, we have
$$\mathbb{E}[\mathrm{Var}(Z_{\text{naive}})] = \lambda^2 d\|x\|_2^2 - \mu^2 \quad (14)$$
$$\mathbb{E}[\mathrm{Var}(Z_{\text{opt}})] = \lambda^2\|x\|_1^2 - \mu^2 \quad (15)$$
Since $\|x\|_1^2 \le d\|x\|_2^2$ by Cauchy-Schwarz, $Z_{\text{opt}}$ is never worse than $Z_{\text{naive}}$ under this model. In the case where we use the same matrix $A$ for many vectors $x$ (as is often the case, e.g., in LLM inference or multinomial logistic regression), we can make our leverage scores a function of precomputed quantities based on $A$, not just $x$. In this case it makes sense to consider forms of leverage sampling; in this work we consider taking $p_j \propto |x_j|$, i.e., $p_j = |x_j|/\|x\|_1$.

A.3.3 Sampling without replacement using the Gumbel trick

Weighted sampling with replacement becomes wasteful at larger sample sizes, since the same high-weight elements are sampled repeatedly. It therefore becomes desirable to remove elements from consideration once they are sampled, re-weighting the remaining elements accordingly.
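The naive draw-remove-reweight procedure can be sketched as below (hypothetical function name; weights are assumed nonnegative with at least $k$ positive entries). Each of the $k$ draws scans the remaining elements and depends on the outcome of the previous draw, which is what makes this approach inherently sequential.

```python
import random

def sequential_weighted_sample(weights, k, rng):
    """Naive weighted sampling without replacement: repeatedly draw one index
    with probability proportional to its weight among the remaining indices,
    then remove it (so the remaining weights are implicitly renormalized)."""
    if k > len(weights):
        raise ValueError("k cannot exceed the number of elements")
    remaining = list(range(len(weights)))
    chosen = []
    for _ in range(k):
        total = sum(weights[i] for i in remaining)
        u = rng.random() * total
        acc = 0.0
        pick = len(remaining) - 1          # fallback guards against round-off
        for pos, i in enumerate(remaining):
            acc += weights[i]
            if u < acc:
                pick = pos
                break
        chosen.append(remaining.pop(pick))
    return chosen

print(sequential_weighted_sample([5.0, 1.0, 1.0, 1.0], 2, random.Random(0)))
```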
We could naively repeat this iterative process of sampling, removing, and re-weighting until we obtain a sample of the desired size, say $k$. However, this process is sequential and quite slow. Fortunately, as noted in [43], sampling without replacement according to a set of weights is equivalent to perturbing the logits $\lambda_1, \dots, \lambda_n$ of our desired sample weights (i.e., $\lambda_j = \log w_j$) with i.i.d. standard Gumbel draws and taking the elements with the top-$k$ perturbed logits as our sample, as detailed in Algorithms 6 and 7. This process is easily batched and is much faster as a result. Further, taking the $(k+1)$-th largest perturbed logit as an empirical threshold $\tau$, the inclusion of an element $j$ in our sample depends only on whether its perturbed logit exceeds this threshold. This derivation is detailed in [27], and gives the following expression (from the Gumbel CDF) for the inclusion probability of element $j$ in a set $S$ of size $k$ drawn without replacement according to weights $w$:
$$\pi_j = P(j \in S) \quad (16)$$
$$\hat\pi_j = 1 - \exp(-\exp(\lambda_j - \tau)) \quad (17)$$
These empirical estimates of the marginal selection probabilities of each column allow us to construct a sequence of estimators for each arm's mean with improved variance, discussed in the next section.

A.3.4 Variance estimation for Gumbel samples

The Gumbel sampling trick of A.3.3, with the fixed empirical threshold, also provides a different lens on our sampling process. Namely, we can compute the closed-form inclusion probabilities $\pi$, and, by fixing this empirical threshold, we may treat the inclusions of separate elements as independent.
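The batched alternative, together with the inclusion probabilities of Equation (17) and the inverse-probability weighted mean they feed into (Equation (20) below), can be sketched as follows. This is a hedged illustration, not the paper's Algorithms 6 and 7: the helper names are hypothetical, the logits are assumed to be $\lambda_j = \log|x_j|$ (weights proportional to $|x_j|$), and the returned variance is the diagonal estimate that drops the second-order cross terms.

```python
import math
import random

def standard_gumbel(rng):
    """One draw G = -log(-log U) with U uniform on (0, 1)."""
    u = rng.random()
    while u == 0.0:                       # U must lie strictly inside (0, 1)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_topk(logits, k, rng):
    """Gumbel-top-k: perturb each logit with i.i.d. standard Gumbel noise and
    keep the k largest. Returns (sample indices, threshold tau), where tau is
    the (k+1)-th largest perturbed logit (requires k < len(logits))."""
    perturbed = [lam + standard_gumbel(rng) for lam in logits]
    order = sorted(range(len(logits)), key=perturbed.__getitem__, reverse=True)
    return order[:k], perturbed[order[k]]

def inclusion_probs(logits, sample, tau):
    """Equation (17): pi_hat_j = 1 - exp(-exp(lambda_j - tau))."""
    return {j: -math.expm1(-math.exp(logits[j] - tau)) for j in sample}

def weighted_mean_and_var(a, x, sample, pi_hat):
    """Inverse-probability estimate of mu = sum_j a_j x_j, plus the diagonal
    variance estimate sum_j (a_j x_j)^2 (1 - pi_j) / pi_j^2 (cross terms dropped)."""
    mean = sum(a[j] * x[j] / pi_hat[j] for j in sample)
    var = sum((a[j] * x[j]) ** 2 * (1.0 - pi_hat[j]) / pi_hat[j] ** 2 for j in sample)
    return mean, var

rng = random.Random(1)
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
a = [1.0] * len(x)                        # one row of A; true mean is sum(x) = 36
logits = [math.log(abs(v)) for v in x]    # assumed weights proportional to |x_j|
estimates = []
for _ in range(3000):
    S, tau = gumbel_topk(logits, 4, rng)
    est, _ = weighted_mean_and_var(a, x, S, inclusion_probs(logits, S, tau))
    estimates.append(est)
print(sum(estimates) / len(estimates))    # close to 36
```

Averaging many independent estimates should land close to the true value 36, a quick empirical check that the threshold-based inverse-probability weighting is unbiased. A single sort replaces the $k$ dependent draws of the naive scheme, which is what makes the approach easy to batch.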
Given these observations (the sampled set $S$), an unbiased estimator of the variance of $\hat\mu_i$, constructed as the importance-weighted mean of the observations, can be computed following [18]:
$$S \leftarrow \text{sample } k \text{ elements without replacement from } [d] \text{ according to } \pi \quad (19)$$
$$\hat\mu_i = \sum_{j \in S}\frac{1}{\hat\pi_j}A_{ij}x_j \quad (20)$$
$$\mathbb{E}[\hat\mu_i] = \mu_i \quad (21)$$
$$\mathrm{Var}(\hat\mu_i;\tau) = \sum_{j \in S}(A_{i,j}x_j)^2\frac{1-\pi_j}{\pi_j^2} + \sum_{\substack{j,k \in S\\ j \ne k}}(A_{i,j}A_{i,k}x_jx_k)\frac{\pi_{j,k}-\pi_j\pi_k}{\pi_j\pi_k\pi_{j,k}}$$
$$\widehat{\mathrm{Var}}(\hat\mu_i) = \sum_{j \in S}(A_{i,j}x_j)^2\frac{1-\hat\pi_j}{\hat\pi_j^2}$$
Following the analysis of [18] and assuming that our threshold $\tau$ is fixed, we may conclude that the estimate of the variance of $\hat\mu_i$, the quantity $\mathrm{Var}(\hat\mu_i;\tau)$, is an unbiased estimate of the true variance $\mathrm{Var}(\hat\mu_i)$. Further, in all datasets we analyzed, the entries of $A$ were generally symmetric, zero-mean, and uncorrelated with the corresponding entries of $x$, as seen in Figure 3(a). Since $\tau$ and the values of $\pi$ are selected solely based on the values of $x$, these two further assumptions make the second-order term $(A_{i,j}A_{i,k}x_jx_k)\frac{\pi_{j,k}-\pi_j\pi_k}{\pi_j\pi_k\pi_{j,k}}$ zero in expectation. Thus, for simplicity, we discard this latter summation and treat $\widehat{\mathrm{Var}}(\hat\mu_i)$ as an unbiased estimate of the variance throughout our implementation; it can be computed in time linear rather than quadratic in $|S|$ (and updated in constant rather than linear time per new sample). In practice, as observed in Figure 4, this variance estimator $\widehat{\mathrm{Var}}(\hat\mu_i)$ (solid green line) provides a far better estimate of the Gumbel sampler's true variance than other methods.

[Figure 3 panels: (a) "2D histogram of A weights vs. X values" (x-axis: A values), sampled entries of A and x; (b) "Distribution of second order term", sampled values of the second-order term.]

Figure 3: (a) Sampled entries of $A$ and the corresponding entries of $x$ for Mistral on the Wikitext dataset. The values of $A$ are symmetric about 0 and not correlated with $x$.
(b) Sampled values of the second-order term $(A_{i,j}A_{i,k}x_jx_k)\frac{\pi_{j,k}-\pi_j\pi_k}{\pi_j\pi_k\pi_{j,k}}$.

We begin by proving a standard best arm identification result, for a slightly modified version of the round-based algorithm from [16].

[Figure 4 plot: "Error of Sample Mean" versus "Sample Size" (0 to 4000); legend entries pair Variance and MSE curves with the tags [wr], [imp], [wor], and [fpc-sparse], plus "New Variance Est."]

Figure 4: Variance estimates vs. empirical mean squared error. This demonstrates the dramatic improvement afforded by improved estimators and tight confidence intervals. [imp] indicates importance sampling, [wr] with replacement, [wor] without replacement, and [fpc-sparse] includes the finite population correction factor.

Lemma 1 (Best-arm identification). With probability at least $1-\delta$, Algorithm BestArmId correctly identifies the arm with the top softmax value, using a number of observations at most
$$\sum_{i=1}^{n}\min\left\{\frac{32\sigma^2}{\Delta_i^2}\log\left(\frac{4n}{\delta}\log_2^2(4/\Delta_i)\right),\ d\right\}.$$
Proof. Following the proof from [16], since an arm's mean is exactly computed after $d$ pulls, the best arm will be identified with probability at least $1-\delta$, requiring for arm $i$ a number of pulls at most
$$n_i \le \min\left\{\frac{32\sigma^2}{\Delta_i^2}\log\left(\frac{4n}{\delta}\log_2^2(4/\Delta_i)\right),\ d\right\}.$$
Summing over arms yields the desired result, noting that the best arm is pulled at most as many times as it takes to eliminate every other arm (i.e., as many times as the second-best arm, so we set $\Delta_1 = \Delta_2$). Note that if $\Delta_i$ is sufficiently small relative to $d$ (as quantified by a sufficiently small absolute constant $C$), then the second term ($d$) is selected in the minimum. Hence, the total sample complexity on this success event is at most
$$\sum_{i=1}^{n}\min\left\{\frac{32\sigma^2}{\Delta_i^2}\log\left(\frac{4n}{\delta}\log_2^2(4/\Delta_i)\right),\ d\right\}.$$
We additionally require a lemma for estimating the mean of the best arm in a PAC sense.

Lemma 2 (Exponential best arm estimation). Sampling an arm using $T = 32\sigma^2\beta^2\epsilon^{-2}\log(2/\delta)$ samples guarantees that $e^{\beta\hat\mu_k}$ estimates $e^{\beta\max_i\mu_i}$ to multiplicative accuracy $\epsilon$, with probability at least $1-\delta$.

Proof. We estimate the mean of arm $k$, the best arm, after $T$ draws using the plug-in estimator $\hat\mu_k$. For simplicity, assume $\beta = 1$; at the end we scale the sample complexity by $\beta^2$.
Sub-Gaussian concentration provides that, with probability at least $1-\delta$,
$$|\hat\mu_k - \mu_k| \le \sqrt{2\sigma^2\log(2/\delta)/T} = \epsilon/4,$$
which in turn implies
$$\log(1-\epsilon) \le -\epsilon/4 \le \hat\mu_k - \mu_k \le \epsilon/4 \le \log(1+\epsilon) \quad \text{for } 0 < \epsilon < 1.$$
Then, exponentiating all sides yields
$$1-\epsilon \le e^{\hat\mu_k-\mu_k} \le 1+\epsilon \implies (1-\epsilon)e^{\mu_k} \le e^{\hat\mu_k} \le (1+\epsilon)e^{\mu_k}.$$
Scaling the number of samples by $\beta^2$ yields the desired result.

With these two steps in place, we are now ready to tackle the novel technical challenge of this work: estimating the normalization constant of the softmax. We denote the softmax normalization constant (the partition function $Z_\beta$) as $f(\mu) = f_\beta(\mu) = \sum_i e^{\beta\mu_i}$.

Proposition 2 (Softmax normalization estimation: restatement of Proposition 1). Under Assumption 1, Algorithm 2 will, with probability at least $1-\delta$, estimate $f_\beta(\mu) = \sum_j e^{\beta\mu_j}$ to a multiplicative accuracy of $\epsilon$, using a number of samples at most
$$T = 2nT_0 + T_1 = 34\beta^2\sigma^2\log(6n/\delta)\,n + 91\sigma^2\beta^2\log(6n/\delta)\frac{(\sum_{i=1}^{n}\gamma_i)^2}{\epsilon\sum_i\alpha_i} + 16\beta^2\sigma^2\epsilon^{-2}\log(12/\delta).$$
Proof. To prove Proposition 1, we want to show that with high probability we can upper and lower bound our plug-in estimator as $(1-\epsilon)f(\mu) \le f(\hat\mu) \le (1+\epsilon)f(\mu)$, giving us our desired multiplicative error bound. We construct several success events that collectively guarantee our bound holds, and that occur with high probability.

$E_1$ is the event where our estimated optimal sampling frequencies are not too far from the unknown optimal frequencies, i.e., $\hat\alpha_i \ge \alpha_i/2$ and $\beta C_i < 1$ for $i = 1, \dots, n$. On this event, we sample arms sufficiently in the second round.

$E_2$ is the event where all estimators $\hat\mu_i$ are within their two-sided confidence intervals in stage 2, i.e., $|\hat\mu_i - \mu_i| \le \sqrt{2\sigma^2\log(12n/\delta)/n_i}$ for $i = 1, \dots, n$. On this event, we can bound the error in the exponentiated estimator.

$E_3$ is the event where the first- and second-order errors are small, i.e.,
$$\Big|\sum_i e^{\beta\mu_i}\beta(\mu_i-\hat\mu_i)\Big| \le \frac{\epsilon}{2}\sum_i e^{\beta\mu_i}, \qquad \sum_i e^{\beta\mu_i}\beta^2(\mu_i-\hat\mu_i)^2 \le \frac{\epsilon}{2}\sum_i e^{\beta\mu_i}.$$
These two terms arise from bounding $f(\hat\mu)$. We now show that if $E_1$, $E_2$, and $E_3$ all occur, then our desired bound holds.
Lemma 4 shows that when $E_2$ holds,
$$f(\hat\mu) \le \sum_i e^{\beta\mu_i}\big(1 + \beta(\hat\mu_i-\mu_i) + \beta^2(\hat\mu_i-\mu_i)^2\big) = f(\mu) + \underbrace{\sum_i e^{\beta\mu_i}\beta(\hat\mu_i-\mu_i)}_{(1)} + \underbrace{\sum_i e^{\beta\mu_i}\beta^2(\hat\mu_i-\mu_i)^2}_{(2)}.$$
Note that if $E_3$ holds as well, this expression is upper bounded by $(1+\epsilon)f(\mu)$, since each of (1) and (2) is bounded above by $(\epsilon/2)f(\mu)$. All that remains is to show that the lower bound also holds. We use the global inequality $1 + x \le e^x$ to lower bound
$$f(\hat\mu) = \sum_i e^{\beta\mu_i+\beta(\hat\mu_i-\mu_i)} \ge \sum_i e^{\beta\mu_i}\big(1 + \beta(\hat\mu_i-\mu_i)\big) = f(\mu) + \underbrace{\sum_i e^{\beta\mu_i}\beta(\hat\mu_i-\mu_i)}_{(1)}.$$
Again, under $E_3$, (1) is lower bounded by $-(\epsilon/2)f(\mu)$, which implies our desired lower bound $(1-\epsilon)f(\mu) \le f(\hat\mu)$. Thus, all three events holding guarantees our desired bound.

Now that we know our bound holds on the joint success event, it remains to show that this event occurs with sufficiently high probability. To do so, we invoke our lemmas, which characterize the probability of each event. The probability of $E_1$, $E_2$, and $E_3$ all holding is $P(E_1E_2E_3) = P(E_1)P(E_2|E_1)P(E_3|E_2E_1)$. Lemma 3 says that if each arm is sampled $T_0$ times, $P(E_1) \ge 1-\delta/3$. Lemma 4 says $P(E_2|E_1) \ge 1-\delta/3$. Lemmas 5 and 6 together show $P(E_3|E_2E_1) \ge 1-\delta/3$. Thus, taken together,
$$P(E_1E_2E_3) = P(E_1)P(E_2|E_1)P(E_3|E_2E_1) \ge (1-\delta/3)^3 \ge 1-\delta,$$
as desired, and so we are done. The lemmas and their proofs follow.

Lemma 3. For $n_i = T_0 = 17\beta^2\sigma^2\log(6n/\delta)$ for all $i$, the event $E_1$ ($\hat\alpha_i \ge \alpha_i/2$ and $\beta C_i < 1$ for $i = 1, \dots, n$) occurs with probability at least $1-\delta/3$.

Proof. By a standard Chernoff bound, with $\sigma^2$ a bound on the sub-Gaussian parameter of all arms, we have that with probability at least $1-\delta/3$, for all $i = 1, \dots, n$,
$$\hat\alpha_i = \frac{e^{\beta(\hat\mu_i - C_i)}}{\sum_j e^{\beta(\hat\mu_j - C_j)}} \ge \frac{e^{\beta(\mu_i - 2C_i)}}{\sum_j e^{\beta\mu_j}} = \alpha_i e^{-2\beta C_i} \ge \frac{\alpha_i}{2},$$
where $\alpha_i = \frac{e^{\beta\mu_i}}{\sum_j e^{\beta\mu_j}}$, and $C_i = \sqrt{2\sigma^2\log(6n/\delta)/T_0}$ is the Chernoff confidence interval width, constructed such that $C_i < \log(2)/(2\beta)$, so that $\beta C_i < 1$. To simplify constants, we use that $8/\ln^2(2) < 17$.

Lemma 4. Sampling with $n_i \ge T_0$ guarantees that, conditioned on $E_1$, $E_2$ occurs with probability at least $1-\delta/3$, and that on $E_2$,
$$f(\hat\mu) \le f(\mu) + \sum_i e^{\beta\mu_i}\beta(\hat\mu_i-\mu_i) + \sum_i e^{\beta\mu_i}\beta^2(\hat\mu_i-\mu_i)^2.$$
Proof. Suppose we sample each arm $n_i$ times.
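By a Chernoff bound on each arm, $|\hat\mu_i-\mu_i| \le \sqrt{2\sigma^2\log(6n/\delta)/n_i}$ holds for each arm independently with probability at least $1-\delta/(3n)$, so by a union bound all arms are within the two-sided bound with probability at least $1-\delta/3$. We now upper bound the plug-in estimator $f(\hat\mu)$:
$$f(\hat\mu) = \sum_i e^{\beta\hat\mu_i} \qquad (26)$$
$$= \sum_i e^{\beta\mu_i+\beta(\hat\mu_i-\mu_i)} \qquad (27)$$
$$\le \sum_i e^{\beta\mu_i}\left(1+\beta(\hat\mu_i-\mu_i)+\beta^2(\hat\mu_i-\mu_i)^2\right), \qquad (28)$$
where in (28) we use the upper bound $e^x \le 1+x+x^2$ for $x \le 1.79$, which applies on the event $E_2$ since there $|\hat\mu_i-\mu_i| \le 1/\beta$. This is because $n_i \ge T_0$, and $T_0$ samples already guarantee this.

Lemma 5 (First-order error concentration). On the event $E_2\cap E_1$, with a total budget $T \ge 16\beta^2\sigma^2\epsilon^{-2}\log(12/\delta)$ allocated so that $n_i \ge \hat\alpha_i T$, the first-order error satisfies
$$\Big|\sum_i e^{\beta\mu_i}\beta(\hat\mu_i-\mu_i)\Big| \le \frac{\epsilon}{2}\sum_i e^{\beta\mu_i}$$
with probability at least $1-\delta/3$.

Proof. Define $G = \sum_i e^{\beta\mu_i}\beta(\hat\mu_i-\mu_i)$ and let $E_f$ be the failure event $|G| > \frac{\epsilon}{2}\sum_i e^{\beta\mu_i}$. Note that
$$P(E_f) = P(E_f\mid E_2)P(E_2) + P(E_f\mid E_2^c)P(E_2^c) \;\Longrightarrow\; P(E_f\mid E_2) = \frac{P(E_f)-P(E_f\mid E_2^c)P(E_2^c)}{P(E_2)} \le 2P(E_f),$$
where we use that, since $\delta<1$, $P(E_2) \ge 1/2$. Thus, it suffices to show that $P(E_f) \le \delta/6$.

$G$ is a sum of independent sub-Gaussian random variables, each scaled by a constant. Thus, we have the two-sided tail bound that with probability at least $1-\delta/6$,
$$|G| \le \sqrt{2B^2\log(12/\delta)},$$
where the sum has sub-Gaussian parameter $B^2 = \sum_i e^{2\beta\mu_i}\beta^2\sigma^2/n_i$. Plugging in our value of $B^2$, we find
$$|G| \le \sqrt{2\log(12/\delta)\sum_i \frac{e^{2\beta\mu_i}\beta^2\sigma^2}{n_i}} \le \sqrt{2\log(12/\delta)\sum_i \frac{2e^{2\beta\mu_i}\beta^2\sigma^2}{\alpha_i T}} = 2\beta\sigma\left(\frac{\log(12/\delta)}{T}\right)^{1/2} f(\mu),$$
where in the second inequality we use that on $E_1$, $\hat\alpha_i \ge \alpha_i/2$, and so $n_i \ge \alpha_i T/2$. Thus
$$T \ge 16\beta^2\sigma^2\epsilon^{-2}\log(12/\delta) \qquad (30)$$
is sufficient to yield the desired multiplicative error of $\epsilon/2$ with probability at least $1-\delta/3$.

Lemma 6. If arm $i$ is pulled at least $\frac{4\cdot 2\sigma^2\beta^2\log(6n/\delta)\,\gamma_i\sum_{j=1}^n\gamma_j}{\epsilon\sum_j\alpha_j}$ times, then on the success events $E_1, E_2$ where the confidence intervals hold,
$$\sum_{j=1}^n\beta^2e^{\beta\mu_j}(\mu_j-\hat\mu_j)^2 \le \frac{\epsilon}{2}\sum_j e^{\beta\mu_j}, \qquad (31)$$
and Algorithm 2 will require a number of arm pulls at most
$$\frac{91\sigma^2\beta^2\log(6n/\delta)\left(\sum_{i=1}^n\gamma_i\right)^2}{\epsilon\sum_i\alpha_i}. \qquad (32)$$

Proof. To bound the second-order error, we use the fact that we have sampled proportional to $\hat\gamma_i$.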
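On the event $E_2$ (where $n_i \ge \gamma_i T/(2\sum_j\gamma_j)$), the second-order error can be bounded as
$$\sum_{i=1}^n\beta^2e^{\beta\mu_i}(\mu_i-\hat\mu_i)^2 \qquad (33)$$
$$\le \sum_{i=1}^n 2\sigma^2\beta^2e^{\beta\mu_i}\log(6n/\delta)/n_i \qquad (34)$$
$$= 2\sigma^2\beta^2\log(6n/\delta)\sum_{i=1}^n\frac{e^{\beta\mu_i}}{n_i} \qquad (35)$$
$$\le 2\sigma^2\beta^2\log(6n/\delta)\cdot\frac{2\sum_j\gamma_j}{T}\sum_{i=1}^n\frac{e^{\beta\mu_i}}{\gamma_i} \qquad (36)$$
$$= 2\sigma^2\beta^2\log(6n/\delta)\cdot\frac{2}{T}\left(\sum_{i=1}^n e^{\beta\mu_i/2}\right)^2. \qquad (37)$$
We want this second-order error to be at most $\frac{\epsilon}{2}\sum_i e^{\beta\mu_i}$, and so require $T$ to satisfy
$$T \ge \frac{4\cdot 2\sigma^2\beta^2\log(6n/\delta)\left(\sum_{i=1}^n e^{\beta\mu_i/2}\right)^2}{\epsilon\sum_i e^{\beta\mu_i}} = \frac{4\cdot 2\sigma^2\beta^2\log(6n/\delta)\left(\sum_{i=1}^n\gamma_i\right)^2}{\epsilon\sum_i\alpha_i}. \qquad (38)$$
Algorithmically, this is not a valid $T$ to use, since it depends on the unknown $\mu_i$. However, since $\hat\alpha_i$ and $\hat\gamma_i$ are close to their true values on the good event $E_1$, we can use these estimates. Thus, we take as our budget for this second-order error
$$T = \frac{4\cdot 2\sigma^2\beta^2\log(6n/\delta)\left(\sum_{i=1}^n\hat\gamma_i\right)^2}{\epsilon\sum_i\hat\alpha_i}, \qquad (39)$$
which on the event $E_1$ is larger than (38). This is a random quantity, so to analyze the requisite sample complexity, we use the fact that on $E_1$ we can bound the quantity in (39) as
$$T \le \frac{91\sigma^2\beta^2\log(6n/\delta)\left(\sum_{i=1}^n\gamma_i\right)^2}{\epsilon\sum_i\alpha_i}, \qquad (40)$$
where the absolute constant arising from the $E_1$ approximation is smaller than 91, so we use this simple constant in the statement of the lemma.

This can be directly compared to the case where we sample only according to $\alpha_i$, not $\gamma_i$, which would yield a sample complexity of $T \ge 8n\sigma^2\beta^2\log(6n/\delta)/\epsilon$. Since $\left(\sum_{i=1}^n\gamma_i\right)^2 \le n\sum_i\alpha_i$ by Cauchy–Schwarz, sampling according to $\gamma_i$ is always an improvement (up to absolute constants), by up to a factor of $n$.

Now we can provide a proof of the main theorem.

Proof of Theorem 1. We use the adaptive normalization estimation subroutine (Proposition 2) with error probability $\delta' = \delta/3$ and error $\varepsilon' = \varepsilon/4$. With probability at least $1-\delta'$, it requires a number of samples at most
$$T \le 34\beta^2\sigma^2\log(18n/\delta)\,n + \frac{363\sigma^2\beta^2\log(18n/\delta)\left(\sum_{i=1}^n\gamma_i\right)^2}{\varepsilon\sum_i\alpha_i} + \frac{256\beta^2\sigma^2\log(36/\delta)}{\varepsilon^2}.$$
Best-arm identification is called with error probability $\delta' = \delta/3$. From Lemma 1, this requires sample complexity
$$\sum_{i=1}^n\frac{32\sigma^2}{\Delta_i^2}\ln\left(\frac{12n}{\delta}\log_2\frac{4}{\Delta_i}\right).$$
We then estimate the best arm's mean using Lemma 2 to accuracy $\varepsilon' = \varepsilon/4$ with error probability $\delta' = \delta/3$. This requires a number of samples at most
$$\frac{512\sigma^2\beta^2\log(6/\delta)}{\varepsilon^2}.$$
By a union bound, all of these subroutines succeed with probability at least $1-\delta$.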
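Returning to the comparison after Lemma 6, the gap between the two second-order budgets is easy to see numerically. The sketch below evaluates the $\gamma$-proportional budget (38), $8\sigma^2\beta^2\log(6n/\delta)(\sum_i e^{\beta\mu_i/2})^2/(\epsilon\sum_i e^{\beta\mu_i})$, and the $\alpha$-proportional budget $8n\sigma^2\beta^2\log(6n/\delta)/\epsilon$. It assumes the true logits $\mu$ are known, which the algorithm itself cannot assume (it plugs in estimates), and the function name `second_order_budgets` is hypothetical.

```python
import math


def second_order_budgets(mu, beta, sigma, eps, delta):
    """Return (t_gamma, t_alpha), the second-order sample budgets when sampling
    proportional to gamma_i = exp(beta*mu_i/2) versus alpha_i = exp(beta*mu_i).

    Assumes the true logits mu are known (illustration only).
    """
    n = len(mu)
    log_term = math.log(6 * n / delta)
    m = max(mu)  # shift for numerical stability; the ratio below is shift-invariant
    sum_gamma = sum(math.exp(beta * (v - m) / 2) for v in mu)
    sum_alpha = sum(math.exp(beta * (v - m)) for v in mu)
    t_gamma = 8 * sigma**2 * beta**2 * log_term * sum_gamma**2 / (eps * sum_alpha)
    t_alpha = 8 * n * sigma**2 * beta**2 * log_term / eps
    return t_gamma, t_alpha


for name, mu in [("peaked", [10.0, 0.0, 0.0, 0.0]), ("flat", [0.0, 0.0, 0.0, 0.0])]:
    t_gamma, t_alpha = second_order_budgets(mu, beta=1.0, sigma=1.0, eps=0.1, delta=0.1)
    print(name, round(t_alpha / t_gamma, 2))  # peaked 3.84, flat 1.0
```

With one dominant logit the $\gamma$ budget is nearly $n = 4$ times smaller; with equal logits the two coincide, which is the equality case of Cauchy–Schwarz.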
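Analyzing the multiplicative error, we have that $[1/(1+\varepsilon'), 1/(1-\varepsilon')] \subseteq [1-2\varepsilon', 1+2\varepsilon']$ for $0 < \varepsilon' \le 1/2$, and that $(1\pm\varepsilon_1)(1\pm\varepsilon_2) \subseteq [1-(\varepsilon_1+\varepsilon_2+\varepsilon_1\varepsilon_2),\ 1+(\varepsilon_1+\varepsilon_2+\varepsilon_1\varepsilon_2)]$. Thus, on these success events, the numerator is approximated to accuracy $\varepsilon/4$ and the denominator to accuracy $\varepsilon/4$. The denominator error converts to a multiplicative error of $\varepsilon/2$ in the ratio, which combines with the numerator error to yield a total error of $3\varepsilon/4+\varepsilon^2/8 < \varepsilon$, as $\varepsilon < 0.5$.

This allows us to simplify the total sample complexity: since
$$\frac{512\sigma^2\beta^2\log(6/\delta)}{\epsilon^2} + \frac{256\beta^2\sigma^2\log(36/\delta)}{\epsilon^2} \le \frac{768\sigma^2\beta^2\log(36/\delta)}{\epsilon^2},$$
we have
$$T \le 34\beta^2\sigma^2\log(18n/\delta)\,n + \sum_{i=1}^n\frac{32\sigma^2}{\Delta_i^2}\ln\left(\frac{12n}{\delta}\log_2\frac{4}{\Delta_i}\right) + \frac{363\sigma^2\beta^2\log(18n/\delta)\left(\sum_{i=1}^n\gamma_i\right)^2}{\epsilon\sum_i\alpha_i} + \frac{768\sigma^2\beta^2\log(36/\delta)}{\epsilon^2}$$
$$= O\left(\beta^2\sigma^2n\log\frac{n}{\delta} + \sum_{i=1}^n\frac{\sigma^2}{\Delta_i^2}\log\left(\frac{n}{\delta}\log\frac{4}{\Delta_i}\right) + \frac{\beta^2\sigma^2\log(n/\delta)\left(\sum_i\gamma_i\right)^2}{\epsilon\sum_j\alpha_j} + \frac{\beta^2\sigma^2\log(1/\delta)}{\epsilon^2}\right).$$

Algorithmically, in the second stage, arm $i$ needs to be pulled a number of times
$$n_i = 17\beta^2\sigma^2\log(6n/\delta) + \frac{\hat\gamma_i}{\sum_j\hat\gamma_j}\cdot\frac{4\cdot 2\sigma^2\beta^2\log(6n/\delta)\left(\sum_j\hat\gamma_j\right)^2}{\epsilon\sum_j\hat\alpha_j} + \frac{16\beta^2\sigma^2\log(12/\delta)}{\epsilon^2}\cdot\frac{\hat\alpha_i}{\sum_j\hat\alpha_j}$$
$$\lesssim 17\beta^2\sigma^2\log(6n/\delta) + \frac{4\cdot 2\sigma^2\beta^2\log(6n/\delta)\,\gamma_i\sum_j\gamma_j}{\epsilon\sum_j\alpha_j} + \frac{16\beta^2\sigma^2\log(12/\delta)\,\alpha_i}{\epsilon^2\sum_j\alpha_j},$$
where the second line holds up to absolute constants on the success events. Combining this with the initial $T_0$ pulls, the pulls from best-arm identification, and the pulls from estimating the value of the best arm, the overall sample complexity for arm $i$ is upper bounded (up to absolute constants) as
$$n_i \lesssim 34\beta^2\sigma^2\log(6n/\delta) + \frac{4\cdot 2\sigma^2\beta^2\log(6n/\delta)\,\gamma_i\sum_j\gamma_j}{\epsilon\sum_j\alpha_j} + \frac{16\beta^2\sigma^2\log(12/\delta)\,\alpha_i}{\epsilon^2\sum_j\alpha_j} + \frac{32\sigma^2}{\Delta_i^2}\ln\left(\frac{12n}{\delta}\log_2\frac{4}{\Delta_i}\right) + \frac{512\sigma^2\beta^2\log(6/\delta)}{\epsilon^2}\mathbb{1}\{i=1\}.$$
And so, since no arm ever requires more than $d$ pulls (at which point its mean is computed exactly), the total algorithmic sample complexity on the success event is
$$\sum_i\min(n_i, d). \qquad (44)$$

B.1 Interpreting the results
We provide a simplified (looser) bound on the sample complexity when the minimum gap is $\Delta$. The worst case in this setting is when the best arm has mean $\mu_1$ and all the other arms have mean $\mu_1-\Delta$.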
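This allows us to simplify the overall sample complexity as
$$O\left(\beta^2\sigma^2\log\frac{n}{\delta}\left[n + \epsilon^{-1}\left(1+n^2e^{-\beta\Delta}\right)\right] + \beta^2\sigma^2\epsilon^{-2}\log\frac{1}{\delta} + \frac{n\sigma^2}{\Delta^2}\log\left(\frac{n}{\delta}\log\frac{4}{\Delta}\right)\right), \qquad (45)$$
where we use the fact that for $n > 2$,
$$\frac{\left(\sum_i\gamma_i\right)^2}{\sum_i\alpha_i} = \frac{\left(1+(n-1)e^{-\beta\Delta/2}\right)^2}{1+(n-1)e^{-\beta\Delta}} \le \frac{2\left(1+(n-1)^2e^{-\beta\Delta}\right)}{1+(n-1)e^{-\beta\Delta}} \le C\left(1+n^2e^{-\beta\Delta}\right).$$
Evaluating the sample complexity in (45): when $\Delta > \frac{2}{\beta}\log n$, we have $n^2e^{-\beta\Delta} < 1$, so this term is bounded by a constant, and the sample complexity can be more simply bounded as
$$O\left(\beta^2\sigma^2\log\frac{n}{\delta}\left(n+\epsilon^{-1}\right) + \beta^2\sigma^2\epsilon^{-2}\log\frac{1}{\delta} + \frac{n\sigma^2}{\Delta^2}\log\left(\frac{n}{\delta}\log\frac{4}{\Delta}\right)\right) \le C\beta^2\sigma^2\left[\log\frac{n}{\delta}\left(n+\epsilon^{-1}\right) + \epsilon^{-2}\log\frac{1}{\delta}\right]. \qquad (46)$$
In the last step we require that the best-arm identification term be dominated by the others, i.e.,
$$\frac{n\sigma^2}{\Delta^2}\log\left(\frac{n}{\delta}\log\frac{4}{\Delta}\right) \lesssim \beta^2\sigma^2n\log\frac{n}{\delta}, \qquad (47)$$
i.e., $d$ is not doubly exponential in $n$ and $\beta$ is not too small.

Assumption 2 (large gap, moderate $\beta$, moderate $d$). We assume that (47) holds and that $\Delta > \frac{2}{\beta}\log n$.

Under the conditions of Assumption 2, the sample complexity in Theorem 1 can be simplified as in Corollary 1.

B.2 Asymptotic optimality of sampling frequencies
Following the approach of [8], we can show the asymptotic optimality of our sampling frequencies (sampling proportional to $\alpha_i$ to minimize the first-order error, and proportional to $\gamma_i$ to minimize our bound on the second-order error).

B.2.1 First-order frequencies $\alpha_i$
Consider a plug-in estimator $\hat\mu$. A first-order Taylor expansion of its error gives
$$f(\hat\mu)-f(\mu) = \nabla f(\mu)^\top(\hat\mu-\mu) + O\left(\|\hat\mu-\mu\|_2^2\right).$$
Thus, in the high-accuracy regime ($\varepsilon\to0$), we can consider only the first-order term. Assuming Gaussian noise in our arm pulls (an identical result holds for sub-Gaussian noise), a total of $T$ pulls, and that arm $i$ is sampled $p_iT$ times for a probability distribution $p$, the first-order error is distributed as
$$\nabla f(\mu)^\top(\hat\mu-\mu) \sim \mathcal{N}\left(0,\ \sum_{i=1}^n\frac{\beta^2e^{2\beta\mu_i}\sigma^2}{p_iT}\right).$$
Optimizing over probability distributions $p$, using Sion's minimax theorem and strong duality as in [8], gives
$$\alpha = \arg\min_p\sum_{i=1}^n\frac{\beta^2e^{2\beta\mu_i}}{p_i} = \arg\min_p\max_\lambda\ \sum_{i=1}^n\frac{\beta^2e^{2\beta\mu_i}}{p_i} + \lambda\left(\sum_{i=1}^np_i-1\right).$$
Utilizing strong duality, the stationarity condition $-\beta^2e^{2\beta\mu_i}/p_i^2+\lambda = 0$ gives $p_i \propto e^{\beta\mu_i}$, i.e.,
$$\alpha \propto e^{\beta\mu}. \qquad (48)$$
Note that in the limit as $\varepsilon\to0$, multiplicative and additive error objectives are equivalent.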
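The optimization behind (48) is an instance of minimizing $\sum_i c_i/p_i$ over the probability simplex, whose minimizer is $p_i \propto \sqrt{c_i}$ with optimal value $(\sum_i\sqrt{c_i})^2$. As a small numerical illustration (a sketch with illustrative helper names, not part of the paper's code), the snippet below checks that $c_i = \beta^2e^{2\beta\mu_i}$ yields $p = \mathrm{softmax}(\beta\mu) = \alpha$ and that no random allocation does better; passing $c_i = \beta^2e^{\beta\mu_i}$ to the same helper gives the second-order frequencies $\gamma \propto e^{\beta\mu/2}$.

```python
import math
import random


def optimal_allocation(c):
    """Minimize sum_i c_i / p_i over the probability simplex.

    Stationarity (-c_i / p_i**2 + lam = 0) gives p_i proportional to sqrt(c_i),
    and the optimal value is (sum_i sqrt(c_i))**2.
    """
    roots = [math.sqrt(ci) for ci in c]
    total = sum(roots)
    return [r / total for r in roots], total**2


def objective(c, p):
    return sum(ci / pi for ci, pi in zip(c, p))


beta = 1.5
mu = [0.3, -1.2, 2.0, 0.7]
c_first = [beta**2 * math.exp(2 * beta * m) for m in mu]  # first-order weights
p_star, best = optimal_allocation(c_first)

z = sum(math.exp(beta * m) for m in mu)
alpha = [math.exp(beta * m) / z for m in mu]  # softmax(beta * mu)
print(max(abs(p - a) for p, a in zip(p_star, alpha)) < 1e-12)  # True

rng = random.Random(0)
not_better = 0
for _ in range(1000):
    q = [rng.random() + 1e-9 for _ in mu]
    s = sum(q)
    q = [qi / s for qi in q]
    not_better += objective(c_first, q) >= best * (1 - 1e-12)
print(not_better)  # 1000: no random allocation beats p_star
```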
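B.2.2 Second-order sampling frequencies $\gamma_i$
With the first-order error term in hand, it remains to analyze the second-order error term:
$$\sum_{i=1}^n\beta^2e^{\beta\mu_i}(\mu_i-\hat\mu_i)^2. \qquad (49)$$
We minimize a bound on this given by our confidence intervals $|\mu_i-\hat\mu_i| \le c/\sqrt{n_i}$, where $c$ is some constant. We identify the sampling distribution $p$ that minimizes this second-order bound, defining $n_i = p_iT$. By an argument similar to the first-order analysis,
$$\gamma = \arg\min_p\sum_{i=1}^n\frac{\beta^2e^{\beta\mu_i}}{p_i} = \arg\min_p\max_\lambda\ \sum_{i=1}^n\frac{\beta^2e^{\beta\mu_i}}{p_i} + \lambda\left(\sum_{i=1}^np_i-1\right).$$
Utilizing strong duality, the stationarity condition $-\beta^2e^{\beta\mu_i}/p_i^2+\lambda = 0$ gives
$$\gamma \propto e^{\beta\mu/2}. \qquad (50)$$

Gains by sampling according to $\gamma_i$. With $\gamma_i = e^{\beta\mu_i/2}$, sampling according to $\gamma$ gives a second-order error bounded (up to the common factor $\beta^2c^2/T$) by
$$\left(\sum_ie^{\beta\mu_i/2}\right)^2 = \|\gamma\|_1^2, \qquad (51)$$
as opposed to sampling according to $\alpha$, which gives a second-order error bounded by
$$n\sum_ie^{\beta\mu_i} = n\|\gamma\|_2^2. \qquad (52)$$
By standard norm inequalities ($\|\gamma\|_1^2 \le n\|\gamma\|_2^2$), sampling according to $\gamma$ is always at least as good, and gives up to a factor of $n$ improvement when one entry of the softmax is much larger than the rest, which is exactly the case of interest.

B.3 Comparison with [8]
In [8], the problem of estimating a real-valued function $f$ to additive accuracy $\varepsilon$ with probability at least $1-\delta$ is studied under the assumption that $f$ has $L$-Lipschitz gradients. For softmax estimation, the gradients are not globally Lipschitz due to the unbounded nature of the exponential. However, if we evaluate the local Lipschitz constant of the gradient at a point $\mu$, we obtain
$$\|\nabla^2f(\mu)\|_2^2 = \lim_{c\to0}\max_{\|u\|\le c}c^{-2}\left\|\nabla f(\mu+u)-\nabla f(\mu)\right\|_2^2 = \lim_{c\to0}\max_{\|u\|\le c}c^{-2}\sum_i\left(\beta e^{\beta(\mu_i+u_i)}-\beta e^{\beta\mu_i}\right)^2 = \lim_{c\to0}\max_{\|u\|\le c}c^{-2}\beta^2\sum_ie^{2\beta\mu_i}\left(e^{\beta u_i}-1\right)^2 = \beta^4\max_ie^{2\beta\mu_i}.$$
Theorem 1 of [8] states that the number of samples required to achieve additive error $\varepsilon$ with probability at least $1-\delta$ is
$$T = O\left(\frac{\|\nabla f(\mu)\|_1^2\log(1/\delta)}{\varepsilon^2} + \frac{n^2L\log(n/\delta)}{\varepsilon}\right),$$
where the noise variance $\sigma^2$ is assumed to be 1. Since the error in our setting is multiplicative, we are interested in the additive accuracy $\varepsilon = \epsilon f(\mu) = \epsilon\sum_ie^{\beta\mu_i}$. Additionally, $\|\nabla f(\mu)\|_1^2 = \beta^2\left(\sum_ie^{\beta\mu_i}\right)^2$.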
Thus, the number of samples required is
$$T = O\left(\frac{\|\nabla f(\mu)\|_1^2\log(1/\delta)}{\varepsilon^2} + \frac{n^2L\log(n/\delta)}{\varepsilon}\right) = O\left(\frac{\beta^2\left(\sum_ie^{\beta\mu_i}\right)^2\log(1/\delta)}{\epsilon^2\left(\sum_ie^{\beta\mu_i}\right)^2} + \frac{n^2\beta^4\max_ie^{2\beta\mu_i}\log(n/\delta)}{\epsilon\sum_ie^{\beta\mu_i}}\right) = O\left(\frac{\beta^2\log(1/\delta)}{\epsilon^2} + \frac{n^2\beta^4\max_ie^{2\beta\mu_i}\log(n/\delta)}{\epsilon\sum_ie^{\beta\mu_i}}\right).$$
This is to be compared with the sample complexity of the algorithm proposed in this paper for the specific setting of softmax normalization estimation (Proposition 2), taking $\sigma^2 = 1$ to compare results. Our bound contains a constant term, independent of $\epsilon$, which is needed to linearize the exponential. The $\epsilon^{-2}$ term matches between the two settings, as asymptotically the optimal strategy is indeed to sample according to the first derivative. The term scaling with $\epsilon^{-1}$ improves dramatically on that of prior work. Note that in the case of $f(\mu) = \|\mu\|_2^2$, the second-order term (scaling with $\varepsilon^{-1}$) can be improved to $O(n^{3/2}L\varepsilon^{-1})$, as the mean of the second-order error can be removed. Thus, we can see the substantial improvement afforded by our more refined algorithm, tailored to the specific structure of the softmax function.

B.4 Extension to heterogeneous arm variances
Adapting bandit algorithms to settings with heterogeneous variances has been done in both the standard regret [3] and best-arm identification [33] settings. For best-arm identification, sacrificing log factors for the sake of clarity, empirical-Bernstein-based confidence intervals [35] can be constructed, where we iteratively pull each arm once, try to eliminate arms, and proceed. Naively union bounding over the $nd$ possible pulls upper bounds this procedure, yielding a complexity (up to log factors) of
$$\sum_{i=1}^n\min\left(\frac{\sigma_i^2}{\Delta_i^2}+\frac{1}{\Delta_i},\ d\right).$$
This assumes that all arms are bounded in $[0,1]$ with variance $\sigma_i^2$. We additionally require a lemma for estimating the mean of the best arm in a PAC sense.

Lemma 7 (Exponential best-arm estimation). Sampling arm $i$ $T = 32\sigma_i^2\beta^2\log(2/\delta)/\epsilon^2$ times guarantees that $e^{\beta\hat\mu_i}$ estimates $e^{\beta\mu_i}$ to multiplicative accuracy $\epsilon$ with probability at least $1-\delta$. This follows directly from the proof of Lemma 2.
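As a quick sanity check of Lemma 7 (and of Lemma 2 in the homogeneous case), the sketch below computes $T = \lceil 32\sigma_i^2\beta^2\log(2/\delta)/\epsilon^2\rceil$ and simulates how often $e^{\beta\hat\mu_i}$ lands within a factor $1\pm\epsilon$ of $e^{\beta\mu_i}$. Gaussian pulls are an assumption made here for simplicity (one instance of sub-Gaussian noise), and both function names are hypothetical.

```python
import math
import random


def exp_estimation_samples(sigma, beta, eps, delta):
    """Pulls T = ceil(32 sigma^2 beta^2 log(2/delta) / eps^2), as in Lemma 7."""
    return math.ceil(32 * sigma**2 * beta**2 * math.log(2 / delta) / eps**2)


def empirical_success_rate(mu, sigma, beta, eps, delta, trials=2000, seed=0):
    """Fraction of simulated runs where exp(beta*mu_hat) is within a factor
    (1 +/- eps) of exp(beta*mu). Assumes Gaussian pulls N(mu, sigma^2)."""
    rng = random.Random(seed)
    t = exp_estimation_samples(sigma, beta, eps, delta)
    target = math.exp(beta * mu)
    hits = 0
    for _ in range(trials):
        # The mean of t i.i.d. N(mu, sigma^2) pulls is distributed N(mu, sigma^2 / t).
        mu_hat = rng.gauss(mu, sigma / math.sqrt(t))
        estimate = math.exp(beta * mu_hat)
        if (1 - eps) * target <= estimate <= (1 + eps) * target:
            hits += 1
    return hits / trials


print(exp_estimation_samples(sigma=1.0, beta=1.0, eps=0.5, delta=0.1))  # 384
print(empirical_success_rate(mu=0.3, sigma=1.0, beta=1.0, eps=0.5, delta=0.1))
```

In simulations like this the empirical success rate typically sits well above $1-\delta$, reflecting the conservative constants in the bound; the pull count scales linearly with $\sigma_i^2$ and $\beta^2$.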
For softmax normalization estimation, we know from [8] that the optimal first-order sampling frequencies are to sample arm $i$ a number of times proportional to
$$n_i \propto \sigma_ie^{\beta\mu_i}. \qquad (55)$$
However, sampling like this yields an additive first-order error that scales as $\sum_i\sigma_ie^{\beta\mu_i}$, which cannot easily be related to $\sum_ie^{\beta\mu_i}$, as we would need in order to obtain multiplicative error bounds. Thus, to avoid this analysis issue, we instead use suboptimal target first-order sampling frequencies that scale with $\sigma_i^2$:
$$n_i \propto \sigma_i^2e^{\beta\mu_i}. \qquad (56)$$
Scaling the number of pulls for each arm by $\sigma_i^2$ yields the following.

Proposition 3 (Softmax normalization estimation: heterogeneous-variance variant of Proposition 1). Under Assumption 1, Algorithm 2 will, with probability at least $1-\delta$, estimate $f_\beta(\mu) = \sum_je^{\beta\mu_j}$ to a multiplicative accuracy of $\epsilon$, using a number of samples for arm $i$ at most
$$n_i = 34\beta^2\sigma_i^2\log(6n/\delta)\,n + \frac{91\sigma_i^2\beta^2\log(6n/\delta)\left(\sum_{j=1}^n\gamma_j\right)^2}{\epsilon\sum_j\alpha_j} + \frac{16\beta^2\sigma_i^2\log(12/\delta)}{\epsilon^2}.$$
The proof of this proposition follows that of Proposition 2: sampling proportional to $\sigma_i^2$ cancels the differing variances, essentially reducing the problem (suboptimally) to the homogeneous setting.

C Implementation Details
In this appendix, we present the implementation details of our algorithm. In Algorithm 4, we provide more detailed pseudocode for our implementation of the Adaptive Softmax algorithm. Algorithm 4 contains some implementation differences from the original Algorithm 1 presented in Section 4. None of these changes materially affect the output of the algorithm; nonetheless, we discuss them here to enable reproducibility of our experimental results. Our results are also reproducible via a one-line script in our code submission. In the following subsections, we describe each of these implementation details. As an important note, we consider all variables global unless stated otherwise.
That is, calling PullArms updates the state of the arm mean estimates $\hat\mu$, their variance estimates $\hat\sigma^2$, the number of pulls per arm $\{n_i\}$, the inclusion probabilities $\pi$, etc. We define the set of arms $\mathcal{A}$ as the estimators based on Gumbel sampling according to importance weights. This idea of treating an arm simply as a sequence of estimators with confidence intervals was pioneered for the computational setting in [4], and saw further use in [24].

Algorithm 4 Adaptive Softmax (implementation details)
1: Input: Matrix $A$, vector $x$, temperature $\beta$, error $\epsilon$, failure probability $\delta$
2: Output: With probability at least $1-\delta$, the argmax coordinate $i^*$ and an estimate $\hat p_{i^*}$ of its probability such that $(1-\epsilon)p_{i^*} \le \hat p_{i^*} \le (1+\epsilon)p_{i^*}$
3: $w \leftarrow$ GetImportanceWeights($A$, $x$)  ▷ Algorithm 5
4: $P, c \leftarrow$ GumbelPermutation($w$)  ▷ Algorithm 6
5: Construct the set of arms $\mathcal{A}$ from $A$, $x$, $P$, $c$
6: $\sigma^2 \leftarrow 1/\beta$  ▷ Initial variance to pull arms to
7: PullToVariance($\mathcal{A}$, $\sigma^2$)  ▷ Algorithm 8
8: $i^* \leftarrow$ BestArmId($\mathcal{A}$, $\delta/2$, $\hat\sigma^2$)  ▷ Algorithm 3
9: $\hat\mu_{i^*} \leftarrow \langle A_{i^*}, x\rangle$  ▷ Exact computation, so $\hat\mu_{i^*} = \mu_{i^*}$
10: $\hat Z \leftarrow$ NormalizationEstimation($\mathcal{A}$, $\epsilon/2$, $\delta/2$, $\hat\sigma$)  ▷ Algorithm 2
11: Compute the estimated probability $\hat p_{i^*} = e^{\beta\hat\mu_{i^*}}/\hat Z$
12: return $i^*$, $\hat p_{i^*}$

Algorithm 5 GetImportanceWeights
1: Input: Matrix $A\in\mathbb{R}^{n\times d}$, vector $x\in\mathbb{R}^d$
2: Output: $w\in\mathbb{R}^d$, vector of importance weights
3: for all $j = 1,\dots,d$ do
4:   Compute $w_j = |x_j|\sum_{i=1}^n|A_{i,j}|$  ▷ $|x_j|$ times the $\ell_1$ norm of the $j$th column of $A$
5: end for
6: return $w$  ▷ Element-wise product of $|x|$ and the $\ell_1$ norms of the columns of $A$

Algorithm 6 GumbelPermutation
1: Input: Importance weights $w$
2: Output: Permutation $P$, cached outputs $c$ for inclusion probability calculation
3: Draw each $\xi_j$ i.i.d. from Gumbel(0, 1)
4: $L \leftarrow \log w$  ▷ Log importance weights
5: $L' \leftarrow L+\xi$  ▷ Log importance weights perturbed by i.i.d. Gumbel noise
6: $h \leftarrow \mathrm{sorted}(L')$  ▷ Thresholds in decreasing order
7: $P \leftarrow \mathrm{ordering}(L')$  ▷ Sorting order of the thresholds (argsort)
8: $c \leftarrow (L, h)$  ▷ Store log importance weights and sorted thresholds for later use
9: return $P$, $c$

Algorithm 7 InclusionProbabilities
1: Input: Sample size $k$, cached outputs $c = (L, h)$ from GumbelPermutation
2: Output: Inclusion probability vector $\pi$
3: $L, h \leftarrow c$
4: $H \leftarrow -\infty$  ▷ Initialize cutoff threshold
5: if $k < d$ then
6:   $H \leftarrow h_k$  ▷ Set cutoff threshold to the $k$th largest perturbed log importance weight
7: end if
8: $\pi \leftarrow 1-\exp(-\exp(L-H))$  ▷ Inclusion probabilities, derived from the Gumbel CDF
9: return $\pi$

Algorithm 8 PullToVariance
1: Input: Set of arms (sequence of estimators) $\mathcal{A}$, variance $\sigma^2$
2: Set $\zeta = 0.1$  ▷ Multiplicative pull increase factor
3: while there exists an arm $i$ with $\hat\sigma_i^2 > \sigma^2$ do
4:   $\mathcal{A}' \leftarrow \{i\in\mathcal{A} : \hat\sigma_i^2 > \sigma^2\}$, where $n_i$ is the corresponding number of pulls
5:   PullArms($\mathcal{A}'$, $(1+\zeta)n_i$)
6: end while

Algorithm 9 PullArms
1: Input: Set of arms (sequence of estimators) $\mathcal{A}$, target number of pulls per arm $\{N_i\}$
2: for all arms $i$ do
3:   if $n_i \ge N_i$ then
4:     continue  ▷ Do not pull arm $i$ if it has already been pulled $N_i$ times
5:   end if
6:   Compute $k_i = N_i-n_i$
7:   $\pi \leftarrow$ InclusionProbabilities($k_i$, $c$)
8:   Sample $k_i$ coordinates without replacement according to weights $\pi$
9:   Update mean estimate $\hat\mu_i \leftarrow \frac{n_i}{N_i}\hat\mu_i + \frac{k_i}{N_i}\sum_{s=1}^{k_i}\frac{X_{i,s}}{\pi_s}$
10:  Update variance estimate $\hat\sigma_i^2 \leftarrow \frac{n_i^2}{N_i^2}\hat\sigma_i^2 + \frac{1}{N_i^2}\sum_{s=1}^{k_i}\left(A_{i,(s)}x_{(s)}\right)^2\frac{1-\pi_{(s)}}{\pi_{(s)}^2}$
11:  $n_i \leftarrow N_i$  ▷ Record the updated number of pulls
12: end for

Algorithm 10 EstimateArm
1: Input: Arm $i\in\mathcal{A}$, pull variance $\sigma_i^2$, multiplicative error $\epsilon$, failure probability $\delta$
2: Output: Mean estimate $\hat\mu_i$
3: targetVar $\leftarrow \frac{\epsilon^2}{2\sigma_i^2\beta^2\log(2/\delta)}$
4: PullToVariance($\{i\}$, targetVar)
5: return $\hat\mu_i$

C.1 Reusing Arm Pulls
In our theoretical analysis in Appendix B, each phase of Algorithm 1 is handled independently. This allows us to union bound the error probabilities of each phase of the algorithm. In our implementation, however, we re-use arm pulls across different parts of Algorithm 1. Intuitively, once an arm return has been observed (corresponding to the product of an element of $A$ with an element of $x$), it can be used to warm-start the estimates of $\hat\mu$ and $\hat\sigma^2$ in later stages of the algorithm. In practice, we observe that re-using arm pulls does not affect the correctness of our algorithm and yields significant sample complexity improvements compared to cold-starting each stage of the algorithm independently.

C.2 Exact Computation of Best Arm
In Line 9 of Algorithm 4, we compute the mean of our estimated best arm $i^*$ exactly and set the estimate $\hat\mu_{i^*} = \mu_{i^*}$. This allows us to reduce the approximation fidelity required of NormalizationEstimation (Algorithm 2) to $\epsilon/2$ instead of $\epsilon/4$, saving a constant factor in sample complexity. Furthermore, $\mu_{i^*}$ is cheap to compute as a single vector-vector dot product. In practice, we found that Algorithm 1 usually required over $d$ samples for the best arm. As such, performing the computation $\mu_{i^*} = \langle A_{i^*}, x\rangle$ (reusing coordinate-wise samples from previous stages of the algorithm) did not significantly increase sample complexity.

C.3 Initial Pulls ($T_0$)
In NormalizationEstimation (Algorithm 2), the initial number of arm pulls $T_0$ depends on $\sigma^2$. However, as discussed in Appendix A.3, this variance proxy is often unknown a priori. As such, we set $T_0 = d/10$. We observe that this choice of $T_0$ works well in experiments across multiple datasets.

Despite the changes above, some looseness remains in practice. To remedy this, we scale our variance estimates by a constant factor, which reduces the number of pulls needed to reach the target variances and improves the gains of our algorithm as a result. We generate these constant factors for each dataset/model by tuning on a separate training set of queries. Tuning is generally performed via bisection to find the minimal factor that still satisfies our provided failure probability parameter $\delta$. This bisection is performed in geometric (log) space and terminates when the $\log_{10}$ difference between the low and high ends of our interval is within $10^{-2}$. The range of factors we consider is $[10^{-6}, 1]$.
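The geometric-space bisection used for tuning can be sketched as follows. This is a minimal sketch under stated assumptions: `succeeds(factor)` is a hypothetical caller-supplied predicate that runs the algorithm on the training queries with that variance-scaling factor and reports whether the empirical success rate is at least $1-\delta$, and success is assumed to be monotone in the factor (larger factors are more conservative) and to hold at the top of the range. The interval $[10^{-6}, 1]$ and the $10^{-2}$ stopping tolerance on the $\log_{10}$ interval width follow the text.

```python
import math


def tune_variance_factor(succeeds, low=1e-6, high=1.0, tol=1e-2):
    """Bisection in log10 space for the smallest scaling factor that succeeds.

    `succeeds(factor)` is a caller-supplied (hypothetical) predicate, assumed
    monotone in the factor and True at `high`.
    """
    if succeeds(low):
        return low
    lo, hi = math.log10(low), math.log10(high)
    # Invariant: succeeds(10**lo) is False and succeeds(10**hi) is True.
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if succeeds(10**mid):
            hi = mid
        else:
            lo = mid
    return 10**hi  # conservative (succeeding) end of the final interval


calls = []


def fake_succeeds(factor, threshold=3.7e-3):
    """Stand-in predicate with a known threshold, for illustration only."""
    calls.append(factor)
    return factor >= threshold


factor = tune_variance_factor(fake_succeeds)
print(len(calls))  # 11: one check at the low end plus ten halvings of six decades
```

Because each call to `succeeds` would be a full pass over the training queries, a search that needs only about ten evaluations to cover six orders of magnitude is a natural fit.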
We first tune the constant factor for the variance estimates used in the best-arm identification (bandit) stage, independently, so that the algorithm successfully identifies the best arm at a rate of at least $1-\delta$ on our training set. Next, we tune the constant factor for the variance estimate used in log-normalization estimation so that the entire SFTM algorithm succeeds at a rate of at least $1-\delta$ on our training set.

C.5 Wall-clock improvement
The focus of this paper is to develop the first provably adaptive softmax algorithm with PAC guarantees, highlighting its dramatically improved sample complexity across a wide variety of models and tasks. The eventual goal of this method is to enable wall-clock improvements in hardware implementations. Converting our provably and empirically performant method into a hardware-optimized, wall-clock-efficient algorithm is an exciting direction of future work, which we detail below.

In most modern transformer architectures, memory I/O is the primary bottleneck [20]. Adaptive Softmax already presents an opportunity to significantly reduce the number of matrix entries that must be loaded at inference time and, in the future (if memory remains the bottleneck), to improve model bandwidth by a similar factor. This objective appears within reach, since we have designed the components of Adaptive Softmax to be amenable to tiling and parallelization. Most notably, our implementation of Adaptive Softmax uses the same column to generate an importance-weighted sample for each active arm. The reasons for this implementation decision are twofold: first, it takes advantage of the locality of entries in the same column to load samples faster; second, it removes intra-column correlation, which can yield theoretically improved performance [7].
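The shared-column, importance-weighted sampling described here (Algorithms 5–7, and the mean update of Algorithm 9) can be sketched in plain Python. This is a simplified, hypothetical sketch rather than the released implementation: it draws a single batch of $k$ distinct columns with the Gumbel top-$k$ trick, computes inclusion probabilities $\pi_j = 1-\exp(-\exp(L_j-H))$ with $H$ the $k$th largest perturbed log-weight (as in Algorithm 7), and forms Horvitz–Thompson-style estimates $\sum_s A_{i,(s)}x_{(s)}/\pi_{(s)}$ of $\mu_i = \langle A_i, x\rangle$ for every arm from the same sampled columns. It omits batching, warm starts, and variance tracking, skips zero-weight columns (they contribute nothing to any arm), and makes no claim of exact unbiasedness for this cutoff choice.

```python
import math
import random


def importance_weights(A, x):
    """Algorithm 5: w_j = |x_j| * sum_i |A[i][j]|."""
    return [abs(xj) * sum(abs(row[j]) for row in A) for j, xj in enumerate(x)]


def standard_gumbel(rng):
    u = rng.random()
    while u == 0.0:  # avoid log(0); happens with probability about 2**-53
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_topk(w, k, rng):
    """Simplified Algorithms 6-7: choose k distinct columns by Gumbel top-k and
    return them with inclusion probabilities pi_j = 1 - exp(-exp(L_j - H))."""
    support = [j for j, wj in enumerate(w) if wj > 0]  # zero-weight columns add 0 to every arm
    k = min(k, len(support))
    if k == 0:
        return [], {}
    log_w = {j: math.log(w[j]) for j in support}
    perturbed = {j: log_w[j] + standard_gumbel(rng) for j in support}
    order = sorted(support, key=perturbed.__getitem__, reverse=True)
    chosen = order[:k]
    if k == len(support):
        return chosen, {j: 1.0 for j in chosen}  # H = -inf: every column is taken
    cutoff = perturbed[order[k - 1]]  # k-th largest perturbed log-weight, as in Algorithm 7
    # -expm1(-t) computes 1 - exp(-t) accurately; cap the exponent to avoid overflow.
    pi = {j: -math.expm1(-math.exp(min(log_w[j] - cutoff, 700.0))) for j in chosen}
    return chosen, pi


def ht_estimates(A, x, chosen, pi):
    """Horvitz-Thompson-style estimates of mu_i = <A_i, x>, one per arm (row),
    all computed from the same sampled columns."""
    return [sum(row[j] * x[j] / pi[j] for j in chosen) for row in A]


rng = random.Random(0)
A = [[1.0, 2.0, 0.5, -1.0], [0.0, 1.0, 3.0, 2.0], [2.0, -1.0, 1.0, 0.0]]
x = [0.5, -1.0, 2.0, 1.0]
w = importance_weights(A, x)
chosen, pi = gumbel_topk(w, k=2, rng=rng)
print(chosen, ht_estimates(A, x, chosen, pi))
```

Sampling one set of columns and reusing it for every active arm mirrors the column-sharing design discussed in this subsection; when $k$ equals the number of contributing columns, every $\pi_j = 1$ and the estimates reduce to the exact dot products.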
Adjacent column samples can also be combined by simply summing their respective importance weights, admitting a simple tiling of our matrix in which individual tiles could be sized to fit into SRAM on a GPU along with a copy of the vector and the current mean/variance estimates for each arm. We can then dynamically load these tiles into SRAM based on the arm estimates, as we do currently. The successive elimination bandit algorithm used by Adaptive Softmax is also, by design, easily parallelizable. We may also store two copies of our matrix, one with wider tiles and one with taller tiles, to take advantage of this structure at all stages of the Adaptive Softmax algorithm: both when a larger number of samples is needed for fewer arms (in later stages of adaptive estimation) and when a smaller number of samples is needed for many arms (in earlier stages). That said, we observe in our experiments that the bulk of compute is spent on early samples of many arms, so basic parallelization of this step alone could deliver the desired speed improvements.

C.6 Robustness to parameters
The user-specified parameters can take a wide range of values. We kept $\epsilon = 30\%$ constant across all simulations because we observed that varying $\epsilon$ did not result in significant changes to the performance of adaSoftmax. Further, we expect that for most users the target success probability $1-\delta$ will fall in the range we considered, 90–99%. However, to address any concerns and verify our assertion that adaSoftmax is not sensitive to the choice of $\epsilon$ or $\delta$, we include results for adaSoftmax with a much wider range of parameters on the MNIST dataset in Table 3.
| | $\varepsilon = 0.001$ | $\varepsilon = 0.01$ | $\varepsilon = 0.1$ | $\varepsilon = 1.0$ |
|---|---|---|---|---|
| $\delta = 0.01$ | 100%, 5.18x | 100%, 6.63x | 99.38%, 8.13x | 99.38%, 8.14x |
| $\delta = 0.05$ | 99.75%, 6.64x | 99.13%, 8.46x | 95.38%, 8.80x | 93.50%, 8.81x |
| $\delta = 0.2$ | 91.25%, 7.19x | 89.25%, 8.90x | 88.00%, 9.29x | 83.375%, 9.25x |

Table 3: Success rate (%) and FLOP gain (x) for adaSoftmax with varied $\delta$ and $\varepsilon$ on the MNIST dataset, showing improved performance across a wide range of parameters, and that raising $\varepsilon$ past 0.1 makes minimal difference in performance.

D Additional extensions and comments

D.1 Effect of temperature
Temperature is treated as a fixed constant (a fixed parameter of the problem at hand, not tunable by the algorithm), because tuning the temperature fundamentally changes the problem. With higher temperatures (larger $\beta$), the only arms that matter are the best and second-best arms, and so adaptivity is extremely helpful. At low temperatures (small $\beta$), the output is essentially the uniform distribution, the computation is trivial, and adaptivity is unhelpful. With respect to the other parameters, the error probability and FLOP gains of adaSoftmax are insensitive to changes in $\varepsilon$ and vary most with the choice of $\delta$. We demonstrate this trend on the MNIST dataset in Table 3.

D.2 Application to Nucleus Sampling
The analysis in this paper focuses on identifying the largest entry in the softmax output and estimating its associated probability. As discussed above, this naturally extends to identifying the $k$ largest elements of the output vector by replacing the bandit best-arm identification algorithm with any top-$k$ identification algorithm [39]. However, in LLM inference, the goal is often to draw a sample from the softmax output distribution via nucleus sampling [17]. Nucleus sampling avoids specifying $k$ directly; instead, it specifies a cumulative probability $p$ and requires identifying the top $k$ elements, where $k$ is the smallest value such that the sum of the probabilities of the top $k$ elements is greater than $p$.
The next token is then sampled from the renormalized probability distribution over these $k$ elements. Our adaptive sampling algorithm naturally applies to the nucleus sampling setting. Adaptive Softmax can maintain a predicted set of arms $S$ such that the sum of the estimated arm probabilities $\hat p$ is greater than $p$ based on pessimistic arm mean estimates. Then, we iteratively sample: (a) arms in $S$, sampling both the arm with the lowest lower confidence bound (LCB), in an attempt to verify the boundary, and the arm with the widest confidence interval, in order to better estimate $\hat p$; and (b) the top arm in $[n]\setminus S$, to see whether it belongs in $S$. For simplicity and concreteness, in this work we focus on identifying and estimating the probability of the top-1 element, but this is an exciting direction of future work.

NeurIPS Paper Checklist

1. Claims
Question: Do the main claims made in the abstract and introduction accurately reflect the paper's contributions and scope?
Answer: [Yes]
Justification: The abstract and introduction clearly state the theoretical and practical contributions of the paper, highlighting its theoretical guarantees, sample complexity reduction, and potential for wall-clock improvements.

2. Limitations
Question: Does the paper discuss the limitations of the work performed by the authors?
Answer: [Yes]
Justification: We discuss assumptions on known sub-Gaussian parameter bounds in Assumption 1, and the computational difficulties in realizing sample complexity gains as wall-clock gains.

3. Theory Assumptions and Proofs
Question: For each theoretical result, does the paper provide the full set of assumptions and a complete (and correct) proof?
Answer: [Yes]
Justification: All formal claims are stated with their requisite assumptions in the main text, with proof sketches. Detailed theoretical proofs are provided or cited in the appendix.

4. Experimental Result Reproducibility
Question: Does the paper fully disclose all the information needed to reproduce the main experimental results of the paper to the extent that it affects the main claims and/or conclusions of the paper (regardless of whether the code and data are provided or not)?
Answer: [Yes]
Justification: We provide a one-line reproducibility script to reproduce the results in the paper.

5. Open access to data and code
Question: Does the paper provide open access to the data and code, with sufficient instructions to faithfully reproduce the main experimental results, as described in supplemental material?
Answer: [Yes]
Justification: Our codebase is publicly available on GitHub and is reproducible via a one-line reproducibility script: https://github.com/ThrunGroup/adaptiveSoftmax

6. Experimental Setting/Details
Question: Does the paper specify all the training and test details (e.g., data splits, hyperparameters, how they were chosen, type of optimizer, etc.) necessary to understand the results?
Answer: [Yes]
Justification: This is described in the paper.

7. Experiment Statistical Significance
Question: Does the paper report error bars suitably and correctly defined or other appropriate information about the statistical significance of the experiments?
Answer: [Yes]
Justification: Error bars are provided for all plots.

8. Experiments Compute Resources
Question: For each experiment, does the paper provide sufficient information on the computer resources (type of compute workers, memory, time of execution) needed to reproduce the experiments?
Answer: [Yes]
Justification: This is described in the paper.

9. Code Of Ethics
Question: Does the research conducted in the paper conform, in every respect, with the NeurIPS Code of Ethics https://neurips.cc/public/EthicsGuidelines?
Answer: [Yes]
Justification: The research in this paper conforms with the stated Code of Ethics.

10. Broader Impacts
Question: Does the paper discuss both potential positive societal impacts and negative societal impacts of the work performed?
Answer: [Yes]
Justification: We discuss broader impacts in the final section of the paper. There are minimal societal effects, and potential energy and environmental savings from more efficient computation.

11. Safeguards
Question: Does the paper describe safeguards that have been put in place for responsible release of data or models that have a high risk for misuse (e.g., pretrained language models, image generators, or scraped datasets)?
Answer: [NA]
Justification: The algorithms proposed in this paper do not have a high risk for misuse.

12. Licenses for existing assets
Question: Are the creators or original owners of assets (e.g., code, data, models), used in the paper, properly credited and are the license and terms of use explicitly mentioned and properly respected?
Answer: [Yes]
Justification: All resources are publicly available.

13. New Assets
Question: Are new assets introduced in the paper well documented and is the documentation provided alongside the assets?
Answer: [Yes]
Justification: We provide a publicly available GitHub repository, https://github.com/ThrunGroup/adaptiveSoftmax, that is well documented and has a one-line reproducibility script.

14. Crowdsourcing and Research with Human Subjects
Question: For crowdsourcing experiments and research with human subjects, does the paper include the full text of instructions given to participants and screenshots, if applicable, as well as details about compensation (if any)?
Answer: [NA]
Justification: This paper does not use or perform any research with human subjects.

15. Institutional Review Board (IRB) Approvals or Equivalent for Research with Human Subjects
Question: Does the paper describe potential risks incurred by study participants, whether such risks were disclosed to the subjects, and whether Institutional Review Board (IRB) approvals (or an equivalent approval/review based on the requirements of your country or institution) were obtained?
Answer: [NA]
Justification: This paper does not involve any human subjects and does not require IRB approval.