# LEARNING THE POSITIONS IN COUNTSKETCH

Published as a conference paper at ICLR 2023

Yi Li (Nanyang Technological University, yili@ntu.edu.sg); Honghao Lin, Simin Liu (Carnegie Mellon University, {honghaol, siminliu}@andrew.cmu.edu); Ali Vakilian (Toyota Technological Institute at Chicago, vakilian@ttic.edu); David P. Woodruff (Carnegie Mellon University, dwoodruf@andrew.cmu.edu)

## ABSTRACT

We consider sketching algorithms which first compress data by multiplication with a random sketch matrix, and then apply the sketch to quickly solve an optimization problem, e.g., low-rank approximation and regression. In the learning-based sketching paradigm proposed by Indyk, Vakilian, and Yuan (2019), the sketch matrix is found by choosing a random sparse matrix, e.g., Count Sketch, and then the values of its non-zero entries are updated by running gradient descent on a training data set. Despite the growing body of work on this paradigm, a noticeable omission is that the locations of the non-zero entries of previous algorithms were fixed, and only their values were learned. In this work, we propose the first learning-based algorithms that also optimize the locations of the non-zero entries. Our first proposed algorithm is based on a greedy algorithm. However, one drawback of the greedy algorithm is its slower training time. We fix this issue and propose approaches for learning a sketching matrix for both low-rank approximation and Hessian approximation for second-order optimization. The latter is helpful for a range of constrained optimization problems, such as LASSO and matrix estimation with a nuclear norm constraint. Both approaches achieve good accuracy with a fast running time. Moreover, our experiments suggest that our algorithm can still reduce the error significantly even if we only have a very limited number of training matrices.
## 1 INTRODUCTION

Indyk et al. (2019) investigated learning-based sketching algorithms for low-rank approximation. A sketching algorithm is a method of constructing approximate solutions for optimization problems via summarizing the data. In particular, linear sketching algorithms compress data by multiplication with a sparse sketch matrix and then use just the compressed data to find an approximate solution. Generally, this technique results in much faster or more space-efficient algorithms for a fixed approximation error. The pioneering work of Indyk et al. (2019) shows it is possible to learn sketch matrices for low-rank approximation (LRA) with better average performance than classical sketches.

In this model, we assume inputs come from an unknown distribution and learn a sketch matrix with strong expected performance over the distribution. This distributional assumption is often realistic: there are many situations where a sketching algorithm is applied to a large batch of related data. For example, genomics researchers might sketch DNA from different individuals, which is known to exhibit strong commonalities. The high-performance computing industry also uses sketching; e.g., researchers at NVIDIA have created standard implementations of sketching algorithms for CUDA, a widely used GPU library. They investigated the (classical) sketched singular value decomposition (SVD), but found that the solutions were not accurate enough across a spectrum of inputs (Chien & Bernabeu, 2019). This is precisely the issue addressed by the learned sketch paradigm, where we optimize for good average performance across a range of inputs.

All authors contributed equally.

While promising results have been shown using previous learned sketching techniques, notable gaps remain.
In particular, all previous methods work by initializing the sketching matrix with a random sparse matrix, e.g., each column of the sketching matrix has a single non-zero value chosen at a uniformly random position. Then, the values of the non-zero entries are updated by running gradient descent on a training data set, or via other methods. However, the locations of the non-zero entries are held fixed throughout the entire training process.

Clearly this is sub-optimal. Indeed, suppose the input matrix $A$ is an $n \times d$ matrix whose first $d$ rows form the $d \times d$ identity matrix and whose remaining rows are $0$. A random sketching matrix $S$ with a single non-zero per column is known to require $m = \Omega(d^2)$ rows in order for $S \cdot A$ to preserve the rank of $A$ (Nelson & Nguyên, 2014); this follows by a birthday paradox argument. On the other hand, if $S$ is a $d \times n$ matrix whose first $d$ columns form the identity matrix, then $\|SAx\|_2 = \|Ax\|_2$ for all vectors $x$, so $S$ preserves not only the rank of $A$ but all important spectral properties. A random matrix would be very unlikely to choose the non-zero entries in the first $d$ columns of $S$ so perfectly, whereas an algorithm trained to optimize the locations of the non-zero entries would notice and correct for this. This is precisely the gap in our understanding that we seek to fill.

Learned Count Sketch Paradigm of Indyk et al. (2019). Throughout the paper, we assume our data $A \in \mathbb{R}^{n \times d}$ is sampled from an unknown distribution $\mathcal{D}$. Specifically, we have a training set $\mathrm{Tr} = \{A_1, \ldots, A_N\} \sim \mathcal{D}$. The generic form of our optimization problems is $\min_X f(A, X)$, where $A \in \mathbb{R}^{n \times d}$ is the input matrix. For a given optimization problem and a set $S$ of sketching matrices, define $\mathrm{ALG}(S, A)$ to be the output of the classical sketching algorithm resulting from using $S$; this uses the sketching matrices in $S$ to map the given input $A$ and construct an approximate solution $\hat{X}$.
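The identity-matrix example above can be checked numerically. The sketch below is illustrative only: the sizes `d = 8`, `n = 100`, the seed, and the helper name `cs_apply` are our own choices. It represents a Count Sketch by its position vector $p$ and value vector $v$ and computes $SA$ by hashing row $i$ of $A$ (times $v_i$) into row $p_i$. We deliberately use only $m = d$ buckets so that collisions are visible. With $m = d$ buckets, a collision-free random draw has probability $d!/d^d$, which is about $0.0024$ for $d = 8$. The $\Omega(d^2)$ bound concerns how large $m$ must be before collisions become unlikely.

```python
import numpy as np

def cs_apply(p, v, A, m):
    """Compute S A for S = CS(p, v) without forming S (0-based positions):
    row i of A, scaled by v[i], is added into row p[i] of the m-row output."""
    SA = np.zeros((m, A.shape[1]))
    np.add.at(SA, p, v[:, None] * A)   # unbuffered, so repeated buckets accumulate
    return SA

d, n = 8, 100
A = np.zeros((n, d))
A[:d] = np.eye(d)                      # first d rows: identity; remaining rows: 0

rng = np.random.default_rng(0)
m = d                                  # deliberately few buckets to expose collisions

# Classical CS: uniformly random positions and random signs.
p_rand = rng.integers(0, m, size=n)
v_rand = rng.choice([-1.0, 1.0], size=n)
rank_rand = np.linalg.matrix_rank(cs_apply(p_rand, v_rand, A, m))

# Hand-placed positions: row i < d goes to bucket i, so S A is exactly I_d.
p_opt = np.concatenate([np.arange(d), rng.integers(0, m, size=n - d)])
v_opt = np.ones(n)
rank_opt = np.linalg.matrix_rank(cs_apply(p_opt, v_opt, A, m))

# rank_opt equals d; rank_rand equals the number of distinct buckets among p_rand[:d].
print(rank_rand, rank_opt)
```

Rows $d, \ldots, n-1$ of $A$ are zero, so their positions do not affect $SA$. This is why only the first $d$ positions matter in the hand-placed sketch.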
We remark that the number of sketches used by an algorithm can vary. In the simplest case, $S$ is a single sketch, but in more complicated sketching approaches we may need to apply sketching more than once; hence $S$ may also denote a set of more than one sketching matrix.

The learned sketch framework has two parts: (1) offline sketch learning and (2) online sketching (i.e., applying the learned sketch and some sketching algorithm to possibly unseen data). In offline sketch learning, the goal is to construct a Count Sketch matrix (abbreviated as CS matrix) with the minimum expected error for the problem of interest. Formally,
$$\arg\min_{\text{CS } S} \mathbb{E}_{A \in \mathrm{Tr}}\left[f(A, \mathrm{ALG}(S, A)) - f(A, X^*)\right] = \arg\min_{\text{CS } S} \mathbb{E}_{A \in \mathrm{Tr}}\, f(A, \mathrm{ALG}(S, A)),$$
where $X^*$ denotes the optimal solution, and the minimum is taken over all possible constructions of CS. The two objectives have the same minimizer because $f(A, X^*)$ does not depend on $S$. We remark that when ALG needs more than one CS to be learned (e.g., in the sketching algorithm we consider for LRA), we optimize each CS independently using a surrogate loss function.

In the second part of the learned sketch paradigm, we take the sketch from part one and use it within a sketching algorithm. This learned sketch and sketching algorithm can be applied, again and again, to different inputs. Finally, we augment the sketching algorithm to provide worst-case guarantees when used with learned sketches. The goal is to have good performance on $A \sim \mathcal{D}$ while the worst-case performance on $A \not\sim \mathcal{D}$ remains comparable to the guarantees of classical sketches. We remark that the learned matrix $S$ is trained offline only once using the training data. Hence, no additional computational cost is incurred when solving the optimization problem on the test data.

Our Results. In this work, in addition to learning the values of the non-zero entries, we learn the locations of the non-zero entries. Namely, we propose three algorithms that learn the locations of the non-zero entries in Count Sketch. Our first algorithm (Section 4) is based on a greedy search.
The empirical results show that this approach can achieve good performance. Further, we show that the greedy algorithm is provably beneficial for LRA when inputs follow certain input distributions (Appendix F). However, one drawback of the greedy algorithm is its much slower training time. We then fix this issue and propose two specific approaches for optimizing the positions of the sketches for low-rank approximation and second-order optimization. These approaches run much faster than all previous algorithms while achieving better performance.

For low-rank approximation, our approach first samples a small set of rows based on their ridge leverage scores and assigns each sampled row to a unique hash bucket. It then places each remaining (non-sampled) row in the hash bucket containing the sampled row to which it is most similar, i.e., the sampled row with which it has the largest dot product. We also show that the worst-case guarantee of this approach is strictly better than that of the classical Count-Sketch (see Section 5).

For sketch-based second-order optimization, we focus on the case $n \gg d$. Here we observe that the property of the sketch matrix we actually need is the subspace embedding property, and we optimize this property directly. We prove that, with optimized positions of the non-zero entries, the sketch matrix $S$ needs fewer rows when the input matrix $A$ has a small number of rows with heavy leverage scores. More precisely, Count Sketch takes $O(d^2/(\delta\epsilon^2))$ rows with failure probability $\delta$. In our construction, $S$ requires only $O((d\,\mathrm{polylog}(1/\epsilon) + \log(1/\delta))/\epsilon^2)$ rows if $A$ has at most $d\,\mathrm{polylog}(1/\epsilon)/\epsilon^2$ rows with leverage score at least $\epsilon/d$. This is a quadratic improvement in $d$ and an exponential improvement in $\delta$. In practice, it is not necessary to calculate the leverage scores.
Instead, we show in our experiments that the indices of the rows of heavy leverage score can be learned and the induced $S$ is still accurate. We also consider a new learning objective: we directly optimize the subspace embedding property of the sketching matrix instead of optimizing the error in the objective function of the optimization problem at hand. This demonstrates a significant advantage over non-learned sketches and has a fast training time (Section 6).

We show strong empirical results for real-world datasets. For low-rank approximation, our methods reduce the error by around 70% compared with classical sketches of the same size, and by around 30% compared with previous learning-based sketches. For second-order optimization, we show that the convergence rate can be reduced by 87% over the non-learned Count Sketch for the LASSO problem on a real-world dataset. We also evaluate our approaches in the few-shot learning setting, where we only have a limited amount of training data (Indyk et al., 2021). We show our approach reduces the error significantly even if we only have one training matrix (Sections 7 and 8). This approach also runs faster than all previous methods.

Additional Related Work. In the last few years, there has been much work on leveraging machine learning techniques to improve classical algorithms. We only mention a few examples here which are based on learned sketches. One related body of work is data-dependent dimensionality reduction. Examples include an approach for pair-wise/multi-wise similarity preservation for indexing big data (Wang et al., 2017) and learned sketching for streaming problems (Indyk et al., 2019; Aamand et al., 2019; Jiang et al., 2020; Cohen et al., 2020; Eden et al., 2021; Indyk et al., 2021). Others are learned algorithms for nearest neighbor search (Dong et al., 2020) and a method for learning linear projections for general applications (Hegde et al., 2015).
While we also learn linear embeddings, our embeddings are optimized for the specific application of low-rank approximation. In fact, one of our central challenges is that the theory and practice of learned sketches generally need to be tailored to each application. Our work builds on (Indyk et al., 2019), which introduced gradient descent optimization for LRA, but a major difference is that we also optimize the locations of the non-zero entries.

## 2 PRELIMINARIES

Notation. Denote the canonical basis vectors of $\mathbb{R}^n$ by $e_1, \ldots, e_n$. Suppose that $A$ has singular value decomposition (SVD) $A = U\Sigma V^\top$. Define $[A]_k = U_k \Sigma_k V_k^\top$ to be the optimal rank-$k$ approximation to $A$, computed by the truncated SVD. Also, define the Moore-Penrose pseudo-inverse of $A$ to be $A^\dagger = V\Sigma^{-1}U^\top$, where $\Sigma^{-1}$ is constructed by inverting the non-zero diagonal entries. Let $\mathrm{row}(A)$ and $\mathrm{col}(A)$ be the row space and the column space of $A$, respectively.

Count Sketch. We define $S_C \in \mathbb{R}^{m \times n}$ as a classical Count Sketch (abbreviated as CS). It is a sparse matrix with one non-zero entry from $\{\pm 1\}$ per column. The position and value of this non-zero entry are chosen uniformly at random. Count Sketch matrices can be succinctly represented by two vectors. We define $p \in [m]^n$ and $v \in \mathbb{R}^n$ as the positions and values of the non-zero entries, respectively. Further, we let $\mathrm{CS}(p, v)$ be the Count Sketch constructed from vectors $p$ and $v$.

Below we define the objective function $f(\cdot, \cdot)$ and a classical sketching algorithm $\mathrm{ALG}(S, A)$ for each individual problem.

Low-rank approximation (LRA). In LRA, we find a rank-$k$ approximation of our data that minimizes the Frobenius norm of the approximation error. For $A \in \mathbb{R}^{n \times d}$,
$$\min_{\text{rank-}k\ B} f_{\mathrm{LRA}}(A, B) = \min_{\text{rank-}k\ B} \|A - B\|_F^2.$$
Usually, instead of outputting a whole $B \in \mathbb{R}^{n \times d}$, the algorithm outputs two factors $Y \in \mathbb{R}^{n \times k}$ and $X \in \mathbb{R}^{k \times d}$ such that $B = YX$, for efficiency.

Indyk et al. (2019) considered Algorithm 1, which only compresses one side of the input matrix $A$.
However, in practice both dimensions of the matrix $A$ are often large. Hence, in this work we consider Algorithm 2, which compresses both sides of $A$.

Constrained regression. Given a vector $b \in \mathbb{R}^n$, a matrix $A \in \mathbb{R}^{n \times d}$ ($n \gg d$), and a convex set $\mathcal{C}$, we want to find $x$ to minimize the squared error
$$\min_{x \in \mathcal{C}} f_{\mathrm{REG}}([A\ b], x) = \min_{x \in \mathcal{C}} \|Ax - b\|_2^2. \tag{2.1}$$

Iterative Hessian Sketch. The Iterative Hessian Sketch (IHS) method (Pilanci & Wainwright, 2016) solves the constrained least-squares problem by iteratively performing the update
$$x^{t+1} = \arg\min_{x \in \mathcal{C}} \left\{ \frac{1}{2}\left\|S^{t+1}A(x - x^t)\right\|_2^2 - \left\langle A^\top(b - Ax^t),\, x - x^t \right\rangle \right\}, \tag{2.2}$$
where $S^{t+1}$ is a sketching matrix. It is not difficult to see that for the unsketched version of (2.2) ($S^{t+1}$ is the identity matrix), the optimal solution $x^{t+1}$ coincides with the optimal solution to the original constrained regression problem (2.1). The IHS approximates the Hessian $A^\top A$ by a sketched version $(S^{t+1}A)^\top(S^{t+1}A)$ to improve runtime, as $S^{t+1}A$ typically has very few rows.

Algorithm 1 Rank-$k$ approximation of $A$ using a sketch $S$ (see (Clarkson & Woodruff, 2009, Sec. 4.1.1))
Input: $A \in \mathbb{R}^{n \times d}$, $S \in \mathbb{R}^{m \times n}$
1: $U, \Sigma, V \leftarrow \mathrm{COMPACTSVD}(SA)$ {$r = \mathrm{rank}(SA)$, $U \in \mathbb{R}^{m \times r}$, $V \in \mathbb{R}^{d \times r}$}
2: Return: $[AV]_k V^\top$

Algorithm 2 $\mathrm{ALG}_{\mathrm{LRA}}$(SKETCH-LOWRANK) (Sarlos, 2006; Clarkson & Woodruff, 2017; Avron et al., 2017)
Input: $A \in \mathbb{R}^{n \times d}$, $S \in \mathbb{R}^{m_S \times n}$, $R \in \mathbb{R}^{m_R \times d}$, $V \in \mathbb{R}^{m_V \times n}$, $W \in \mathbb{R}^{m_W \times d}$
1: $U_C \begin{bmatrix} T_C & T_C' \end{bmatrix} \leftarrow VAR^\top$, $\begin{bmatrix} T_D^\top \\ T_D'^\top \end{bmatrix} U_D^\top \leftarrow SAW^\top$ with $U_C, U_D$ orthogonal
2: $G \leftarrow VAW^\top$, $Z_L' Z_R' \leftarrow [U_C^\top G U_D]_k$
3: $Z_L \leftarrow \begin{bmatrix} Z_L'(T_D^{-1})^\top & 0 \end{bmatrix}$, $Z_R \leftarrow \begin{bmatrix} T_C^{-1} Z_R' \\ 0 \end{bmatrix}$
4: $Z \leftarrow Z_L Z_R$
5: return: $AR^\top Z SA$ in form $P_{n \times k}, Q_{k \times d}$

Learning-Based Algorithms in the Few-Shot Setting. Recently, Indyk et al. (2021) studied learning-based algorithms for LRA in the setting where we have access to limited data or computing resources. We provide a brief explanation of learning-based algorithms in the few-shot setting in Appendix A.3.

Leverage Scores and Ridge Leverage Scores.
Given a matrix $A$, the leverage score of the $i$-th row $a_i$ of $A$ is defined to be $\tau_i := a_i (A^\top A)^\dagger a_i^\top$. This equals the squared $\ell_2$-norm of the $i$-th row of $U$, where $A = U\Sigma V^\top$ is the singular value decomposition of $A$. Given a regularization parameter $\lambda$, the ridge leverage score of the $i$-th row $a_i$ of $A$ is defined to be $\tau_i := a_i (A^\top A + \lambda I)^\dagger a_i^\top$. Our learning-based algorithms employ the ridge leverage score sampling technique proposed in (Cohen et al., 2017), which shows that sampling proportional to ridge leverage scores gives a good solution to LRA.

## 3 DESCRIPTION OF OUR APPROACH

We describe our contribution to the learning-based sketching paradigm which, as mentioned, is to learn the locations of the non-zero values in the sketch matrix. To learn a Count Sketch for the given training data set, we locally optimize
$$\min_S \mathbb{E}_{A \sim \mathcal{D}}\left[f(A, \mathrm{ALG}(S, A))\right] \tag{3.1}$$
in two stages: (1) compute the positions of the non-zero entries, then (2) fix the positions and optimize their values.

Stage 1: Optimizing Positions. In Section 4, we provide a greedy search algorithm for this stage, as our starting point. In Sections 5 and 6, we provide our specific approaches for optimizing the positions of the sketches for low-rank approximation and second-order optimization.

Stage 2: Optimizing Values. This stage is similar to the approach of Indyk et al. (2019). However, instead of the power method, we use an automatic differentiation package, PyTorch (Paszke et al., 2019). We pass it our objective $\min_{v \in \mathbb{R}^n} \mathbb{E}_{A \sim \mathcal{D}}\left[f(A, \mathrm{ALG}(\mathrm{CS}(p, v), A))\right]$, implemented as a chain of differentiable operations, and it automatically computes the gradient using the chain rule. We also consider new approaches to optimize the values for LRA (proposed in Indyk et al. (2021); see Appendix A.3 for details) and for second-order optimization (proposed in Section 6).

Worst-Case Guarantees.
In Appendix D, we show that both of our approaches for the above two problems perform no worse than a classical sketching matrix when $A$ does not follow the distribution $\mathcal{D}$. In particular, for LRA, we show that the sketch monotonicity property holds for the time-optimal sketching algorithm for low-rank approximation. For second-order optimization, we propose an algorithm which runs in input-sparsity time and can test for and use the better of a random sketch and a learned sketch.

## 4 SKETCH LEARNING: GREEDY SEARCH

Algorithm 3 POSITION OPTIMIZATION: GREEDY SEARCH
Input: $f$, ALG, $\mathrm{Tr} = \{A_1, \ldots, A_N \in \mathbb{R}^{n \times d}\}$; sketch dimension $m$
1: initialize $S_L = O^{m \times n}$
2: for $i = 1$ to $n$ do
3: $\quad j^* \leftarrow \arg\min_{j \in [m]} \sum_{A \in \mathrm{Tr}} f(A, \mathrm{ALG}(S_L + e_j e_i^\top, A))$
4: $\quad S_L \leftarrow S_L + e_{j^*} e_i^\top$
5: end for
6: return $p$ for $S_L = \mathrm{CS}(p, v)$

When $S$ is a Count Sketch, computing $SA$ amounts to hashing the $n$ rows of $A$ into the $m \ll n$ rows of $SA$. The optimization is a combinatorial optimization problem with an empirical risk minimization (ERM) objective. The naïve solution is to compute the objective value of all $m^n$ possible placements, which is clearly intractable. Instead, we iteratively construct a full placement in a greedy fashion. We start with $S$ as a zero matrix. Then, we iterate through the columns of $S$ in an order determined by the algorithm, adding a non-zero entry to each. The best position in each column is the one that minimizes Eq. (3.1) if an entry were to be added there. For each column, we evaluate Eq. (3.1) $O(m)$ times, once for each prospective half-built sketch.

While this greedy strategy is simple to state, additional tactics are required for each problem to make it more tractable. Usually the objective evaluation (Algorithm 3, line 3) is too slow, so we must leverage our insight into the sketching algorithms to pick a proxy objective. Note that we can reuse these proxies for value optimization, since they may make gradient computation faster too.

Proxy objective for LRA.
For the two-sided sketching algorithm, we can assume that the two factors $X, Y$ have the form $Y = AR^\top \tilde{Y}$ and $X = \tilde{X}SA$, where $S$ and $R$ are both CS matrices, so we optimize the positions in both $S$ and $R$. We cannot use $f(A, \mathrm{ALG}(S, R, A))$ as our objective, because then we would have to consider combinations of placements between $S$ and $R$. To find a proxy, we note that a prerequisite for good performance is that $\mathrm{row}(SA)$ and $\mathrm{col}(AR^\top)$ both contain a good rank-$k$ approximation to $A$ (see the proof of Lemma C.5). Thus, we can decouple the optimization of $S$ and $R$. The proxy objective for $S$ is $\|[AV]_k V^\top - A\|_F^2$, where $SA = U\Sigma V^\top$. In this expression, $\hat{X} = [AV]_k V^\top$ is the best rank-$k$ approximation to $A$ in $\mathrm{row}(SA)$. The proxy objective for $R$ is defined analogously.

In Appendix F, we show the greedy algorithm is provably beneficial for LRA when inputs follow the spiked covariance or the Zipfian distribution. Despite the good empirical performance we present in Section 7, one drawback is its much slower training time. Also, for the iterative sketching method for second-order optimization, it is non-trivial to find a proxy objective. This is because the input of the $i$-th iteration depends on the solution to the $(i-1)$-th iteration, and for this reason the greedy approach sometimes does not give a good solution. In the next section, we propose our specific approaches for optimizing the positions of the sketches for low-rank approximation and second-order optimization. Both achieve very high accuracy and finish in a very short amount of time.

## 5 SKETCH LEARNING: LOW-RANK APPROXIMATION

Now we present a conceptually new algorithm which runs much faster and empirically achieves error similar to that of the greedy search approach. Moreover, we show that this algorithm has strictly better guarantees than the classical Count-Sketch. To achieve this, we need a more careful analysis.
To provide some intuition, suppose $\mathrm{rank}(SA) = k$ and $SA = U\Sigma V^\top$. Then the rank-$k$ approximation cost is exactly $\|AVV^\top - A\|_F^2$, the projection cost onto $\mathrm{col}(V)$. Minimizing it is equivalent to maximizing the sum of squared projection coefficients:
$$\arg\min_S \|AVV^\top - A\|_F^2 = \arg\min_S \sum_{i \in [n]} \Big(\|A_i\|_2^2 - \sum_{j \in [k]} \langle A_i, v_j \rangle^2\Big) = \arg\max_S \sum_{i \in [n]} \sum_{j \in [k]} \langle A_i, v_j \rangle^2.$$
As mentioned, computing $SA$ amounts to hashing the $n$ rows of $A$ to the $m$ rows of $SA$. Hence, intuitively, if we can put similar rows into the same bucket, we may get a smaller error.

Algorithm 4 Position optimization: Inner Product
Input: $A \in \mathbb{R}^{n \times d}$: average of Tr; sketch dim. $m$
1: initialize $S_1, S_2 = O^{m \times n}$
2: Sample a set $C = \{C_1, \ldots, C_m\}$ of rows using ridge leverage score sampling (see Section 2).
3: for $i = 1$ to $n$ do
4: $\quad p_i, v_i \leftarrow \arg\max_{p \in [m],\, v \in \{\pm 1\}} \left\langle \frac{C_p}{\|C_p\|_2},\, v\frac{A_i}{\|A_i\|_2} \right\rangle$
5: $\quad S_1[p_i, i] \leftarrow v_i$
6: end for
7: for $i = 1$ to $m$ do
8: $\quad I_i \leftarrow \{j \mid p_j = i\}$
9: $\quad A^{(i)} \leftarrow$ restriction of $A$ to rows in $I_i$
10: $\quad u_i \leftarrow$ the top left singular vector of $A^{(i)}$
11: $\quad S_1[i, I_i] \leftarrow u_i^\top$
12: end for
13: for $i = 1$ to $m$ do
14: $\quad q_i \leftarrow$ index such that $C_i$ is the $q_i$-th row of $A$
15: $\quad S_2[i, q_i] \leftarrow 1$
16: end for
17: return $S_1$ or $\begin{bmatrix} S_1 \\ S_2 \end{bmatrix}$

Our algorithm is given in Algorithm 4. Suppose that we want to form a matrix $S$ with $m$ rows. At the beginning of the algorithm, we sample $m$ rows according to the ridge leverage scores of $A$. By the property of the ridge leverage score, the subspace spanned by this set of sampled rows contains an approximately optimal solution to the low-rank approximation problem. Hence, we map these rows to separate buckets of $SA$. Then, we need to decide the locations of the remaining (non-sampled) rows. Ideally, we want similar rows to be mapped into the same bucket. To achieve this, we use the $m$ sampled rows as reference points. We assign each non-sampled row $A_i$ to the $p$-th bucket of $SA$ if the normalized rows $A_i$ and $C_p$ have the largest inner product among all possible buckets.
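The steps of Algorithm 4 can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions, not the authors' implementation. The helper names (`ridge_leverage_scores`, `inner_product_sketch`) are hypothetical. The ridge parameter $\lambda = \|A - A_k\|_F^2/k$ is an assumed choice that is common for ridge leverage score sampling in LRA. Step 2 is simplified to drawing $m$ distinct rows with probability proportional to the ridge scores, so that each sampled row gets its own bucket. The sign $v_i$ from step 4 is omitted because steps 7 to 12 overwrite the bucket values with $u_i^\top$.

```python
import numpy as np

def ridge_leverage_scores(A, k):
    """tau_i = a_i (A^T A + lam I)^{-1} a_i^T via the SVD:
    tau_i = sum_j U_ij^2 * s_j^2 / (s_j^2 + lam), with lam = ||A - A_k||_F^2 / k (assumed)."""
    U, s, _ = np.linalg.svd(A, full_matrices=False)
    lam = np.sum(s[k:] ** 2) / k
    w = np.divide(s ** 2, s ** 2 + lam, out=np.zeros_like(s), where=(s ** 2 + lam) > 0)
    return (U ** 2) @ w

def inner_product_sketch(A, m, k, rng):
    """Minimal sketch of Algorithm 4; returns S1, S2, bucket ids p, sampled row ids."""
    n = A.shape[0]
    tau = ridge_leverage_scores(A, k)
    # Step 2 (simplified): m distinct rows drawn with probability proportional to tau.
    sampled = rng.choice(n, size=m, replace=False, p=tau / tau.sum())
    # Steps 3-6: bucket of row i = sampled row with the largest |cosine similarity|.
    norms = np.linalg.norm(A, axis=1, keepdims=True)
    An = A / np.where(norms > 0, norms, 1.0)
    p = np.argmax(np.abs(An @ An[sampled].T), axis=1)
    p[sampled] = np.arange(m)              # each reference row C_i keeps its own bucket i
    # Steps 7-12: values in bucket i = top left singular vector u_i of A^{(i)}.
    S1 = np.zeros((m, n))
    for i in range(m):
        I_i = np.flatnonzero(p == i)
        U, _, _ = np.linalg.svd(A[I_i], full_matrices=False)
        S1[i, I_i] = U[:, 0]
    # Steps 13-16: S2 is the sampling matrix that picks out the rows C_i.
    S2 = np.zeros((m, n))
    S2[np.arange(m), sampled] = 1.0
    return S1, S2, p, sampled

rng = np.random.default_rng(0)
A = rng.standard_normal((40, 8)) @ rng.standard_normal((8, 8))   # stand-in for the averaged training matrix
S1, S2, p, sampled = inner_product_sketch(A, m=6, k=3, rng=rng)
S = np.vstack([S1, S2])                                            # the [S1; S2] output
```

Choosing $u_i$ makes $\|S_1[i, :]\,A\|_2$ equal to the top singular value of $A^{(i)}$. This is the largest value any unit vector supported on $I_i$ can achieve, which matches the "preserve as much of the Frobenius norm as possible" motivation.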
Once the locations of the non-zero entries are fixed, the next step is to determine the values of these entries. We follow the idea proposed in (Indyk et al., 2021). For each block $A^{(i)}$, one natural approach is to choose the unit vector $s_i \in \mathbb{R}^{|I_i|}$ that preserves as much of the Frobenius norm of $A^{(i)}$ as possible, i.e., to maximize $\|s_i^\top A^{(i)}\|_2^2$. Hence, we set $s_i$ to be the top left singular vector of $A^{(i)}$. In our experiments, we observe that this step reduces the error of downstream value optimizations performed by SGD.

To obtain a worst-case guarantee, we show that w.h.p. the row span of the sampled rows $C_i$ is a good subspace. We set the matrix $S_2$ to be the sampling matrix that samples the rows $C_i$. The final output of our algorithm is the vertical concatenation of $S_1$ and $S_2$. Here $S_1$ performs well empirically, while $S_2$ has a worst-case guarantee for any input.

Combining Lemma E.2 and the sketch monotonicity for low-rank approximation in Appendix D, we get that $O(k\log k + k/\epsilon)$ rows suffice for a $(1 + \epsilon)$-approximation for the input matrix $A$ induced by Tr. This is better than the $\Omega(k^2)$ rows required by a non-learned Count-Sketch, even if its non-zero values have been further improved by the previous learning-based algorithms in (Indyk et al., 2019; 2021). As a result, under the assumption on the input data, we may expect that $S$ will still be good for the test data. We defer the proof to Appendix E.1.

In Appendix A, we show that the assumptions we make in Theorem 5.1 are reasonable. We also provide an empirical comparison between Algorithm 4 and some of its variants, as well as some adaptive sketching methods on the training sample. The evaluation shows that only our algorithm yields a significant improvement on the test data, which suggests that both ridge leverage score sampling and row bucketing are essential.

Theorem 5.1. Let $S \in \mathbb{R}^{2m \times n}$ be given by concatenating the sketching matrices $S_1, S_2$ computed by Algorithm 4 with input $A$ induced by Tr, and let $B \in \mathbb{R}^{n \times d}$.
Then with probability at least $1 - \delta$, we have
$$\min_{\text{rank-}k\ X:\ \mathrm{row}(X) \subseteq \mathrm{row}(SB)} \|B - X\|_F^2 \le (1 + \epsilon)\|B - B_k\|_F^2$$
if one of the following holds:

1. $m = O(\beta(k\log k + k/\epsilon))$, $\delta = 0.1$, and $\frac{1}{\beta}\tau_i(B) \le \tau_i(A)$ for all $i \in [n]$.
2. $m = O(k\log k + k/\epsilon)$, $\delta = 0.1 + 1.1\beta$, and the total variation distance satisfies $d_{tv}(p, q) \le \beta$, where $p, q$ are the sampling probabilities defined as $p_i = \frac{\tau_i(A)}{\sum_i \tau_i(A)}$ and $q_i = \frac{\tau_i(B)}{\sum_i \tau_i(B)}$.

Time Complexity. As mentioned, an advantage of our second approach is that it significantly reduces the training time. We now discuss the training times of different algorithms. For the value-learning algorithms in (Indyk et al., 2019), each iteration requires computing a differentiable SVD to perform gradient descent. Hence the runtime is at least $\Omega(n_{it} \cdot T)$, where $n_{it}$ is the number of iterations (usually set $> 500$) and $T$ is the time to compute an SVD. For the greedy algorithm, there are $m$ choices for each column, so the runtime is at least $\Omega(mn \cdot T)$. For our second approach, the most complicated steps are computing the ridge leverage scores of $A$ and then the SVD of each submatrix. Hence, the total runtime is at most $O(T)$. The time complexities discussed here are all for training time; there is no additional runtime cost for the test data.

## 6 SKETCH LEARNING: SECOND-ORDER OPTIMIZATION

In this section, we consider optimizing the sketch matrix in the context of second-order methods. The key observation is that for many sketching-based second-order methods, the crucial property of the sketching matrix is the so-called subspace embedding property. For a matrix $A \in \mathbb{R}^{n \times d}$, we say a matrix $S \in \mathbb{R}^{m \times n}$ is a $(1 \pm \epsilon)$-subspace embedding for the column space of $A$ if $(1 - \epsilon)\|Ax\|_2 \le \|SAx\|_2 \le (1 + \epsilon)\|Ax\|_2$ for all $x \in \mathbb{R}^d$. For example, consider the iterative Hessian sketch, which performs the update (2.2) to compute $\{x^t\}_t$. Pilanci & Wainwright (2016) showed that if $S^1, \ldots, S^{t+1}$ are $(1 \pm O(\rho))$-subspace embeddings of $A$, then $\|A(x^t - x^*)\|_2 \le \rho^t \|Ax^*\|_2$.
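The subspace embedding property defined above can be measured directly for a given sketch. The idea is to take an orthonormal basis $Q$ of $\mathrm{col}(A)$. For $x = Qy$, we have $\|Sx\|_2/\|x\|_2 = \|SQy\|_2/\|y\|_2$, which ranges exactly over the singular values of $SQ$. The smallest valid $\epsilon$ is therefore $\max(\sigma_{\max}(SQ) - 1,\ 1 - \sigma_{\min}(SQ))$. The helper names below are hypothetical. The code assumes $A$ has full column rank and $S$ has at least $d$ rows, and the sizes are arbitrary.

```python
import numpy as np

def embedding_distortion(S, A):
    """Smallest eps with (1 - eps)||Ax|| <= ||SAx|| <= (1 + eps)||Ax|| for all x.

    Assumes A has full column rank and S has at least d rows."""
    Q, _ = np.linalg.qr(A)                         # orthonormal basis of col(A)
    sig = np.linalg.svd(S @ Q, compute_uv=False)   # ||SQy|| / ||y|| ranges over [sig.min(), sig.max()]
    return float(max(sig.max() - 1.0, 1.0 - sig.min()))

def random_count_sketch(m, n, rng):
    """Dense classical Count Sketch: one random +/-1 per column at a random row."""
    S = np.zeros((m, n))
    S[rng.integers(0, m, size=n), np.arange(n)] = rng.choice([-1.0, 1.0], size=n)
    return S

rng = np.random.default_rng(0)
n, d = 2000, 5
A = rng.standard_normal((n, d))
for m in (50, 200, 800):
    eps = embedding_distortion(random_count_sketch(m, n, rng), A)
    print(m, round(eps, 3))                        # distortion of one random draw per sketch size
```

In the IHS setting, a smaller measured $\epsilon$ for each $S^{t}$ corresponds to a smaller contraction factor $\rho$ in the bound quoted above.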
Thus, if each $S^i$ is a good subspace embedding of $A$, we will have a good convergence guarantee. Therefore, unlike (Indyk et al., 2019), which treats the training objective in a black-box manner, we shall optimize the subspace embedding property of the matrix $A$.

Optimizing positions. We consider the case that $A$ has a few rows of large leverage score, as well as access to an oracle which reveals a superset of the indices of such rows. Formally, let $\tau_i(A)$ be the leverage score of the $i$-th row of $A$ and $I^* = \{i : \tau_i(A) \ge \nu\}$ be the set of rows with large leverage score. Suppose that a superset $I \supseteq I^*$ is known to the algorithm. In the experiments we train an oracle to predict such rows. We can maintain all rows in $I$ explicitly and apply a Count-Sketch to the remaining rows, i.e., the rows in $[n] \setminus I$. Up to permutation of the rows, we can write
$$A = \begin{pmatrix} A_I \\ A_{[n]\setminus I} \end{pmatrix} \quad\text{and}\quad S = \begin{pmatrix} I & 0 \\ 0 & S' \end{pmatrix}, \tag{6.1}$$
where $S'$ is a random Count-Sketch matrix of $m$ rows. Clearly $S$ has a single non-zero entry per column. We have the following theorem, whose proof is postponed to Appendix E.2. Intuitively, the proof for Count-Sketch in (Clarkson & Woodruff, 2017) handles rows of large leverage score and rows of small leverage score separately. The rows of large leverage score are to be perfectly hashed, while the rows of small leverage score will concentrate in the sketch by the Hanson-Wright inequality.

Theorem 6.1. Let $\nu = \epsilon/d$. Suppose that $m = O((d/\epsilon^2)(\mathrm{polylog}(1/\epsilon) + \log(1/\delta)))$, $\delta \in (0, 1/m]$ and $d = \Omega((1/\epsilon)\,\mathrm{polylog}(1/\epsilon)\log^2(1/\delta))$. Then there exists a distribution on $S$ of the form in (6.1) with $m + |I|$ rows such that
$$\Pr\left[\exists x \in \mathrm{col}(A):\ \left|\|Sx\|_2^2 - \|x\|_2^2\right| > \epsilon\|x\|_2^2\right] \le \delta.$$
In particular, when $\delta = 1/m$, the sketching matrix $S$ has $O((d/\epsilon^2)\,\mathrm{polylog}(d/\epsilon))$ rows.
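The block structure in (6.1) can be built explicitly as follows. This is a minimal sketch with several assumptions. It computes the leverage scores exactly from a QR factorization, where the paper's experiments instead learn an oracle for the heavy indices. It uses the threshold $\nu = \epsilon/d$ from Theorem 6.1 with an arbitrary $\epsilon = 0.5$. The rows are kept in their original order rather than permuted, so the "identity block" appears as one row per heavy index. The names `leverage_scores` and `heavy_row_sketch` are hypothetical.

```python
import numpy as np

def leverage_scores(A):
    """tau_i = squared norm of row i of an orthonormal basis of col(A) (full column rank assumed)."""
    Q, _ = np.linalg.qr(A)
    return np.sum(Q ** 2, axis=1)

def heavy_row_sketch(n, heavy, m, rng):
    """Sketch of the form (6.1): heavy rows copied exactly, the remaining rows
    hashed by a random Count-Sketch S' with m rows."""
    heavy = np.unique(np.asarray(heavy, dtype=int))   # sorted, deduplicated superset I
    light = np.setdiff1d(np.arange(n), heavy)
    h = heavy.size
    S = np.zeros((h + m, n))
    S[np.arange(h), heavy] = 1.0                      # identity block on the rows in I
    S[h + rng.integers(0, m, size=light.size), light] = rng.choice([-1.0, 1.0], size=light.size)
    return S

rng = np.random.default_rng(0)
n, d, eps_target = 2000, 5, 0.5
A = rng.standard_normal((n, d))
A[:3] *= 50.0                                          # a few rows with large leverage score
tau = leverage_scores(A)
heavy = np.flatnonzero(tau >= eps_target / d)          # I* = {i : tau_i >= nu}, nu = eps / d
S = heavy_row_sketch(n, heavy, m=100, rng=rng)         # m + |I| rows, one non-zero per column
```

Because the heavy rows are copied exactly, they are never hashed together with other rows. The collision problem that forces a classical Count Sketch to use many rows then only concerns the light rows.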
Hence, suppose there happen to be at most $d\,\mathrm{polylog}(1/\epsilon)/\epsilon^2$ rows of leverage score at least $\epsilon/d$. Then the overall sketch length for embedding $\mathrm{colsp}(A)$ can be reduced to $O((d\,\mathrm{polylog}(1/\epsilon) + \log(1/\delta))/\epsilon^2)$. This is a quadratic improvement in $d$ and an exponential improvement in $\delta$ over the original sketch length of $O(d^2/(\epsilon^2\delta))$ for Count-Sketch. In the worst case there could be $O(d^2/\epsilon)$ such rows, though empirically we do not observe this. In Section 8, we show it is possible to learn the indices of the heavy rows for real-world data.

Optimizing values. When we fix the positions of the non-zero entries, we aim to optimize the values by gradient descent. The previous black-box approach in (Indyk et al., 2019) minimizes $\sum_i f(A, \mathrm{ALG}(S, A))$. Instead, we propose the following loss function for the learning algorithm:
$$\mathcal{L}(S, \mathcal{A}) = \sum_{A_i \in \mathcal{A}} \left\|(A_i R_i)^\top A_i R_i - I\right\|_F,$$
over all the training data, where $R_i$ comes from the QR decomposition $SA_i = Q_i R_i^{-1}$. The intuition for this loss function is given by the lemma below, whose proof is deferred to Appendix E.3.

Lemma 6.2. Suppose that $\epsilon \in (0, \frac{1}{2})$, $S \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{n \times d}$ has full column rank, and $SA = QR$ is the QR decomposition of $SA$. If $\|(AR^{-1})^\top AR^{-1} - I\|_{op} \le \epsilon$, then $S$ is a $(1 \pm \epsilon)$-subspace embedding of $\mathrm{col}(A)$.

Lemma 6.2 implies the following. If the loss function over $\mathcal{A}_{\mathrm{train}}$ is small and the distribution of $\mathcal{A}_{\mathrm{test}}$ is similar to that of $\mathcal{A}_{\mathrm{train}}$, it is reasonable to expect that $S$ is a good subspace embedding of $\mathcal{A}_{\mathrm{test}}$. Here we use the Frobenius norm rather than the operator norm in the loss function because it makes the optimization problem easier to solve. Our empirical results also show that the Frobenius norm performs better than the operator norm.

## 7 EXPERIMENTS: LOW-RANK APPROXIMATION

In this section, we evaluate the empirical performance of our learning-based approach for LRA on three datasets. For each, we fix the sketch size and compare the approximation error $\|A - X\|_F - \|A - A_k\|_F$ averaged over 10 trials.
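The evaluation metric above can be computed for the one-sided sketch of Algorithm 1 as follows. This sketch uses a random classical Count Sketch and a synthetic matrix with decaying column scales in place of the Logo, Friends, and Hyper data. The helper names, sizes, and seed are hypothetical and chosen only for illustration.

```python
import numpy as np

def sketch_lra(A, S, k):
    """Algorithm 1: best rank-k approximation [AV]_k V^T of A inside row(SA)."""
    _, sig, Vt = np.linalg.svd(S @ A, full_matrices=False)
    r = int(np.sum(sig > sig.max(initial=0.0) * 1e-10))   # numerical rank of SA
    if r == 0:
        return np.zeros_like(A)                            # row(SA) = {0}
    V = Vt[:r].T                                           # d x r orthonormal basis of row(SA)
    U, s, Wt = np.linalg.svd(A @ V, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Wt[:k] @ V.T

def lra_error(A, S, k):
    """||A - X||_F - ||A - A_k||_F for X = sketch_lra(A, S, k)."""
    s = np.linalg.svd(A, compute_uv=False)
    opt = np.sqrt(np.sum(s[k:] ** 2))                      # ||A - A_k||_F from the tail singular values
    return float(np.linalg.norm(A - sketch_lra(A, S, k)) - opt)

def random_count_sketch(m, n, rng):
    S = np.zeros((m, n))
    S[rng.integers(0, m, size=n), np.arange(n)] = rng.choice([-1.0, 1.0], size=n)
    return S

rng = np.random.default_rng(0)
A = rng.standard_normal((200, 30)) @ np.diag(0.8 ** np.arange(30))
k, m = 5, 10
errors = [lra_error(A, random_count_sketch(m, 200, rng), k) for _ in range(10)]
mean_error = float(np.mean(errors))                        # averaged over 10 trials
```

The metric is never negative, because $X$ has rank at most $k$ and $A_k$ is optimal among rank-$k$ matrices. It is zero when $\mathrm{row}(SA)$ contains the top-$k$ right singular subspace, e.g., when $S$ is the identity.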
To make position optimization more efficient in line 3 of Algorithm 3, we avoid recomputing the SVD from scratch for every candidate position. Instead, we use formulas for fast rank-1 SVD updates (Brand, 2006). For the greedy method, we used several NVIDIA GeForce GTX 1080 Ti machines. For the maximum inner product method, the experiments are conducted on a laptop with a 1.90GHz CPU and 16GB RAM.

Datasets. We use the three datasets from (Indyk et al., 2019). (1, 2) Friends, Logo (image): frames from a short video of the TV show Friends and of a logo being painted. (3) Hyper (image): hyperspectral images from natural scenes. Additional details are in Table A.1.

Baselines. We compare our approach to the following baselines.
Classical CS: a random Count Sketch.
IVY19: a sparse sketch with learned values and random positions for the non-zero entries.
Ours (greedy): a sparse sketch where both the values and positions of the non-zero entries are learned. The positions are learned by Algorithm 3, and the values are learned similarly to (Indyk et al., 2019).
Ours (inner product): a sparse sketch where both the values and the positions of the non-zero entries are learned. The positions are learned by $S_1$ in Algorithm 4.

IVY19 and the greedy algorithm use the full training set, while Algorithm 4 takes as input the average of all training matrices. We also give a sensitivity analysis for our algorithm, comparing it with the following variants:
Only row sampling: perform projection by ridge leverage score sampling.
$\ell_2$ sampling: replace leverage score sampling with $\ell_2$-norm row sampling and keep the same downstream step.
Randomly Grouping: use ridge leverage score sampling but randomly distribute the remaining rows.

The results show that none of these variants outperforms non-learned sketching. We defer the results of this part to Appendix A.1.

Result Summary.
Our empirical results are provided in Table 7.1 for both Algorithm 2 and Algorithm 1, where the errors are averaged over 10 trials. We use the average of all training matrices from Tr as the input to Algorithm 4. We note that all the steps of our training algorithms are done on the training data. Hence, no additional computational cost is incurred for the sketching algorithm on the test data. Experimental parameters (e.g., the learning rate for gradient descent) can be found in Appendix G. For both sketching algorithms, our sketches are always the best of the four. They are significantly better than Classical CS, obtaining improvements of around 70%, and they obtain a roughly 30% improvement over IVY19.

Wall-Clock Times. The offline learning runtime, i.e., the time to train a sketch on $\mathcal{A}_{\mathrm{train}}$, is reported in Table 7.2. Although the greedy method takes much longer (1h 45min), our second approach is much faster (5 seconds) than the previous algorithm of Indyk et al. (2019) (3 min) and still achieves an error similar to that of the greedy algorithm. The reason is that Algorithm 4 only needs to compute the ridge leverage scores on the training matrix once, which is much cheaper than IVY19, which needs to compute a differentiable SVD many times during gradient descent. In Section A.4, we also study the performance of our approach in the few-shot learning setting, which has been studied in Indyk et al. (2021).

| Sketch | Offline learning | Online solving |
|---|---|---|
| Ours (inner product) | 5 | 0.166 |
| Ours (greedy) | 6300 (1.75h) | 0.172 |
| IVY19 | 193 (3min) | 0.168 |
| Classical CS | – | 0.166 |

Table 7.2: Runtime (in seconds) of LRA on Logo with k = 30, m = 60

Left (two-sided sketch):

| k, m | Sketch | Logo | Friends | Hyper |
|---|---|---|---|---|
| 20, 40 | Classical CS | 2.371 | 4.073 | 6.344 |
| 20, 40 | IVY19 | 0.687 | 1.048 | 3.764 |
| 20, 40 | Ours (greedy) | 0.500 | 0.899 | 2.497 |
| 20, 40 | Ours (inner product) | 0.532 | 0.733 | 2.975 |
| 30, 60 | Classical CS | 1.642 | 2.683 | 5.390 |
| 30, 60 | IVY19 | 0.734 | 1.077 | 3.748 |
| 30, 60 | Ours (greedy) | 0.492 | 0.794 | 2.492 |
| 30, 60 | Ours (inner product) | 0.436 | 0.733 | 2.409 |

Right (one-sided sketch):

| k, m | Sketch | Logo | Friends | Hyper |
|---|---|---|---|---|
| 20, 40 | Classical CS | 0.930 | 1.542 | 2.971 |
| 20, 40 | IVY19 | 0.255 | 0.723 | 1.273 |
| 20, 40 | Ours (greedy) | 0.196 | 0.407 | 0.784 |
| 20, 40 | Ours (inner product) | 0.205 | 0.407 | 1.223 |
| 30, 60 | Classical CS | 0.650 | 1.0575 | 2.315 |
| 30, 60 | IVY19 | 0.290 | 0.713 | 1.274 |
| 30, 60 | Ours (greedy) | 0.197 | 0.406 | 0.717 |
| 30, 60 | Ours (inner product) | 0.201 | 0.340 | 0.943 |

Table 7.1: Test errors for LRA. (Left: two-sided sketch. Right: one-sided sketch.)

8 EXPERIMENTS: SECOND-ORDER OPTIMIZATION

In this section, we consider the Iterative Hessian Sketch (IHS) on the following instance of LASSO regression:
$$x^* = \arg\min_{\|x\|_1 \le \lambda} f(x) = \arg\min_{\|x\|_1 \le \lambda} \frac{1}{2}\|Ax - b\|_2^2, \tag{8.1}$$
where λ is a parameter. We also study the performance of the sketches on matrix estimation with a nuclear norm constraint, on the fast regression solver (van den Brand et al., 2021), and for first-order methods. The results can be found in Appendix B. All of our experiments are conducted on a laptop with a 1.90GHz CPU and 16GB RAM. The offline training is done separately using a single GPU. The details of the implementation are deferred to Appendix G.

Figure 7.1: Test error of LASSO in Electric dataset. (Three panels; x-axis: runtime (seconds); y-axis: log_10(error); curves: count-sketch, learned (value-only), learned (position and value).)

Dataset. We use the Electric¹ dataset of residential electric load measurements. Each row of the matrix corresponds to a different residence, and the columns are consecutive measurements at different times. Here $A_i \in \mathbb{R}^{370 \times 9}$, $b_i \in \mathbb{R}^{370 \times 1}$, $|(A, b)_{\mathrm{train}}| = 320$, and $|(A, b)_{\mathrm{test}}| = 80$. We set λ = 15.

Experiment Setting. We compare the learned sketch against the classical Count-Sketch².
We choose m = 6d, 8d, 10d and consider the error $f(x) - f(x^*)$. For the heavy-row Count-Sketch, we allocate 30% of the sketch space to the heavy-row candidates. For this dataset, each row represents a specific residence, and hence there is a strong pattern in the distribution of the heavy rows. We select the heavy rows according to the number of times each row is heavy in the training data; we give a detailed discussion in Appendix B.1. We highlight that it is still possible to recognize the pattern of the rows even if the rows of the test data are permuted. We also consider optimizing the non-zero values after identifying the heavy rows, using our new approach in Section 6.

Results. We plot the mean errors on a logarithmic scale in Figure 7.1. The average offline training time is 3.67s to find a superset of the heavy rows over the training data and 66s to optimize the values when m = 10d; both are faster than the runtime of Indyk et al. (2019) with the same parameters. Note that the learned matrix S is trained offline only once using the training data. Hence, no additional computational cost is incurred when solving the optimization problem on the test data.

All methods display linear convergence: letting $e_k$ denote the error in the k-th iteration, we have $e_k \approx \rho^k e_1$ for some convergence rate ρ, where a smaller ρ implies faster convergence. We estimate the rate of convergence as $\rho = (e_k/e_1)^{1/k}$ with k = 7. Both learned sketches, especially the one that optimizes both the positions and the values, show significant improvements. When the sketch size is small (6d), this sketch has a convergence rate that is just 13.2% of that of the classical Count-Sketch, and when the sketch size is large (10d), its convergence rate is just 12.1% of that of the classical Count-Sketch.

¹ https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014

² The framework of Indyk et al. (2019) does not apply to iterative sketching methods in a straightforward manner, so here we only compare with the classical Count-Sketch. For more details, please refer to Section B.

ACKNOWLEDGEMENTS

Yi Li would like to acknowledge partial support from the Singapore Ministry of Education under Tier 1 grant RG75/21. Honghao Lin and David Woodruff were supported in part by an Office of Naval Research (ONR) grant N00014-18-1-2562. Ali Vakilian was supported by NSF award CCF-1934843.

REFERENCES

Anders Aamand, Piotr Indyk, and Ali Vakilian. (Learned) frequency estimation algorithms under Zipfian distribution. arXiv preprint arXiv:1908.05198, 2019.

Akshay Agrawal, Brandon Amos, Shane T. Barratt, Stephen P. Boyd, Steven Diamond, and J. Zico Kolter. Differentiable convex optimization layers. In Advances in Neural Information Processing Systems, pp. 9558–9570, 2019.

Haim Avron, Kenneth L. Clarkson, and David P. Woodruff. Sharper bounds for regularized data fitting. In Approximation, Randomization, and Combinatorial Optimization. Algorithms and Techniques (APPROX/RANDOM), pp. 27:1–27:22, 2017.

Jean Bourgain, Sjoerd Dirksen, and Jelani Nelson. Toward a unified theory of sparse dimensionality reduction in Euclidean space. Geometric and Functional Analysis, pp. 1009–1088, 2015.

Matthew Brand. Fast low-rank modifications of the thin singular value decomposition. Linear Algebra and its Applications, 415.1, 2006.

Lung-Sheng Chien and Samuel Rodriguez Bernabeu. Fast singular value decomposition on GPU. NVIDIA presentation at GPU Technology Conference, 2019. URL https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf.

Kenneth L Clarkson and David P Woodruff. Numerical linear algebra in the streaming model. In Proceedings of the forty-first annual symposium on Theory of computing (STOC), pp. 205–214, 2009.

Kenneth L Clarkson and David P Woodruff.
Low-rank approximation and regression in input sparsity time. Journal of the ACM (JACM), 63(6):54, 2017.

Edith Cohen, Ofir Geri, and Rasmus Pagh. Composable sketches for functions of frequencies: Beyond the worst case. In International Conference on Machine Learning, pp. 2057–2067. PMLR, 2020.

Michael B. Cohen, Cameron Musco, and Christopher Musco. Input sparsity time low-rank approximation via ridge leverage score sampling. In Proceedings of the Twenty-Eighth Annual ACM-SIAM Symposium on Discrete Algorithms (SODA), pp. 1758–1777, 2017.

Graham Cormode and Charlie Dickens. Iterative Hessian sketch in input sparsity time. In Proceedings of 33rd Conference on Neural Information Processing Systems (NeurIPS), Vancouver, Canada, 2019.

Yihe Dong, Piotr Indyk, Ilya Razenshteyn, and Tal Wagner. Learning sublinear-time indexing for nearest neighbor search. In International Conference on Learning Representations, 2020.

Talya Eden, Piotr Indyk, Shyam Narayanan, Ronitt Rubinfeld, Sandeep Silwal, and Tal Wagner. Learning-based support estimation in sublinear time. In International Conference on Learning Representations, 2021.

Chinmay Hegde, Aswin C. Sankaranarayanan, Wotao Yin, and Richard G. Baraniuk. Numax: A convex approach for learning near-isometric linear embeddings. In IEEE Transactions on Signal Processing, pp. 6109–6121, 2015.

Piotr Indyk, Ali Vakilian, and Yang Yuan. Learning-based low-rank approximations. In Advances in Neural Information Processing Systems, pp. 7400–7410, 2019.

Piotr Indyk, Tal Wagner, and David Woodruff. Few-shot data-driven algorithms for low rank approximation. Advances in Neural Information Processing Systems, 34, 2021.

Tanqiu Jiang, Yi Li, Honghao Lin, Yisong Ruan, and David P. Woodruff. Learning-augmented data stream algorithms. In International Conference on Learning Representations, 2020.

Xiangrui Meng and Michael W Mahoney.
Low-distortion subspace embeddings in input-sparsity time and applications to robust linear regression. In Proceedings of the forty-fifth annual ACM symposium on Theory of computing, pp. 91–100, 2013.

Jelani Nelson and Huy L Nguyên. OSNAP: Faster numerical linear algebra algorithms via sparser subspace embeddings. In Foundations of Computer Science (FOCS), 2013 IEEE 54th Annual Symposium on, pp. 117–126, 2013.

Jelani Nelson and Huy L. Nguyên. Lower bounds for oblivious subspace embeddings. In Automata, Languages, and Programming - 41st International Colloquium (ICALP), pp. 883–894, 2014.

Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, and Trevor Killeen. PyTorch: An imperative style, high-performance deep learning library. 2019.

Mert Pilanci and Martin J. Wainwright. Iterative Hessian sketch: Fast and accurate solution approximation for constrained least-squares. J. Mach. Learn. Res., 17:53:1–53:38, 2016.

Tamas Sarlos. Improved approximation algorithms for large matrices via random projections. In 47th Annual IEEE Symposium on Foundations of Computer Science (FOCS), pp. 143–152, 2006.

Jan van den Brand, Binghui Peng, Zhao Song, and Omri Weinstein. Training (overparametrized) neural networks in near-linear time. In James R. Lee (ed.), 12th Innovations in Theoretical Computer Science Conference, ITCS, volume 185, pp. 63:1–63:15, 2021.

Roman Vershynin. Introduction to the non-asymptotic analysis of random matrices. In Yonina C. Eldar and Gitta Kutyniok (eds.), Compressed Sensing: Theory and Applications, pp. 210–268. Cambridge University Press, 2012. doi: 10.1017/CBO9780511794308.006.

Jingdong Wang, Ting Zhang, Nicu Sebe, and Heng Tao Shen Wang. A survey on learning to hash. In IEEE Transactions on Pattern Analysis and Machine Intelligence, pp. 769–790, 2017.

A ADDITIONAL EXPERIMENTS: LOW-RANK APPROXIMATION

The details (data dimension, $N_{\mathrm{train}}$, etc.) are presented in Table A.1.
| Name | Type | Dimension | N_train | N_test |
|---|---|---|---|---|
| Friends | Image | 5760 × 1080 | 400 | 100 |
| Logo | Image | 5760 × 1080 | 400 | 100 |
| Hyper | Image | 1024 × 768 | 400 | 100 |

Table A.1: Data set descriptions

A.1 SENSITIVITY ANALYSIS OF ALGORITHM 4

| Sketch | Logo | Friends | Hyper |
|---|---|---|---|
| Ours (inner product) | 0.311 | 0.470 | 1.232 |
| ℓ2 Sampling | 0.698 | 0.935 | 1.293 |
| Only Ridge | 0.994 | 1.493 | 4.155 |
| Randomly Grouping | 0.659 | 1.069 | 2.070 |

Table A.2: Sensitivity analysis for our approach (using Algorithm 1 from Indyk et al. (2019) with one sketch)

In this section we explore how sensitive the performance of Algorithm 4 is to the ridge leverage score sampling and to the maximum inner product grouping process. We consider the following baselines:
- ℓ2 norm sampling: we sample the rows according to their squared length instead of doing ridge leverage score sampling.
- Only ridge leverage score sampling: we use the subspace spanned by only the rows sampled by ridge leverage score sampling.
- Randomly grouping: we put the sampled rows into different buckets as before, but randomly divide the non-sampled rows into buckets.

The results are shown in Table A.2, where we set k = 30, m = 60 as an example. To show the difference between the initialization methods more clearly, we compare the error using the one-sided sketching Algorithm 1 and do not further optimize the non-zero values. From the table we can see that both ridge leverage score sampling and the downstream grouping process are necessary; otherwise the error is similar to or even worse than that of the classical Count-Sketch.

A.2 TOTAL VARIATION DISTANCE

As we have shown in Theorem 5.1, if the total variation distance between the row sampling probability distributions p and q is O(1), we have a worst-case guarantee of $O(k \log k + k/\epsilon)$, which is strictly better than the $\Omega(k^2)$ lower bound for the random Count Sketch, even when its non-zero values have been optimized. We now study the total variation distance between the training and test matrices in our datasets. The result is shown in Figure A.1.
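A minimal NumPy sketch of this comparison, under assumptions that this excerpt does not spell out: we take p and q to be the ridge leverage scores of the averaged training matrix and of one test matrix, each normalized to sum to one, with the ridge parameter $\lambda = \|A - A_k\|_F^2 / k$ of Cohen et al. (2017). The function names are hypothetical; the precise definitions of p and q are those of Theorem 5.1.

```python
import numpy as np


def ridge_leverage_scores(A, k):
    """tau_i = a_i^T (A^T A + lam I)^{-1} a_i with lam = ||A - A_k||_F^2 / k.

    Assumes lam > 0 or A^T A invertible, so the linear system is solvable.
    """
    s = np.linalg.svd(A, compute_uv=False)
    lam = np.sum(s[k:] ** 2) / k
    M = A.T @ A + lam * np.eye(A.shape[1])
    X = np.linalg.solve(M, A.T)            # column i is M^{-1} a_i
    return np.einsum("ij,ji->i", A, X)     # a_i^T M^{-1} a_i for every row i


def row_sampling_distribution(A, k):
    tau = ridge_leverage_scores(A, k)
    return tau / tau.sum()


def total_variation(p, q):
    """d_TV(p, q) = (1/2) * sum_i |p_i - q_i|, which always lies in [0, 1]."""
    return 0.5 * np.abs(p - q).sum()


rng = np.random.default_rng(0)
train = [rng.standard_normal((300, 20)) for _ in range(5)]   # stand-in training matrices
test = rng.standard_normal((300, 20))                        # stand-in test matrix
p = row_sampling_distribution(np.mean(train, axis=0), k=10)
q = row_sampling_distribution(test, k=10)
print(total_variation(p, q))
```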
From Figure A.1 we can see that for all three datasets, the total variation distance is bounded by a constant, which suggests that the assumptions are reasonable for real-world data.

Figure A.1: Total variation distance between train and test matrices. Left: Logo, middle: Friends, right: Hyper. (Each panel: x-axis: test matrix index; y-axis: total variation distance.)

A.3 LEARNING-BASED ALGORITHMS FOR LOW-RANK APPROXIMATION IN THE FEW-SHOT SETTING

In this section, we briefly explain the two algorithms proposed in Indyk et al. (2021). Both algorithms aim to optimize the non-zero values of a Count-Sketch matrix under fixed locations of the non-zero entries.

One-shot closed-form algorithm. Given a sparsity pattern of a Count-Sketch matrix $S \in \mathbb{R}^{m \times n}$, it partitions the rows of A into m blocks $A^{(1)}, \ldots, A^{(m)}$ as follows: let $I_i = \{j : S_{ij} = 1\}$. The block $A^{(i)} \in \mathbb{R}^{|I_i| \times d}$ is the submatrix of A that contains the rows whose indices are in $I_i$. The goal is to choose, for each block $A^{(i)}$, a (non-sparse) one-dimensional sketching vector $s_i \in \mathbb{R}^{|I_i|}$. The first approach is to set $s_i$ to be the top left singular vector of $A^{(i)}$, which is the algorithm 1Shot1Vec. Another approach is to set $s_i$ to be a left singular vector of $A^{(i)}$ chosen at random with probability proportional to its squared singular value. The main advantage of the latter approach over the former is that it endows the algorithm with provable guarantees on the LRA error. The 1Shot2Vec algorithm combines both, obtaining the benefits of both approaches. The advantage of these two algorithms is that they extract a sketching matrix by an analytic computation, requiring neither GPU access nor auto-gradient functionality.

Few-shot SGD algorithm.
In this algorithm, the authors propose a new loss function for LRA, namely,
$$\min_{\mathrm{CS}\ S} \mathbb{E}_{A \in \mathrm{Tr}} \left\| U_k^\top S^\top S U - I_0 \right\|_F^2,$$
where $A = U\Sigma V^\top$ is the singular value decomposition of A and $U_k \in \mathbb{R}^{n \times k}$ denotes the submatrix of U that contains its first k columns. $I_0 \in \mathbb{R}^{k \times d}$ denotes the result of augmenting the identity matrix of order k with d − k additional zero columns on the right. This loss function is motivated by the analysis of prior LRA algorithms that use random sketching matrices. It is faster to compute and differentiate than the previous empirical loss in Indyk et al. (2019). In their experiments, the authors also show that this loss function can achieve a smaller error in a shorter amount of time using a small number of randomly sampled training matrices, though the final error is larger than that of the previous algorithm of Indyk et al. (2019) if a longer training time and access to the whole training set Tr are allowed.

A.4 EXPERIMENTS: LRA IN THE FEW-SHOT SETTING

In the rest of this section, we study the performance of our second approach in the few-shot learning setting. We first consider the case where we only have one training matrix randomly sampled from Tr. Here, we compare our method with the 1Shot2Vec method proposed in (Indyk et al., 2021) in the same setting (k = 10, m = 40) as in their empirical evaluation. The result is shown in Table A.3. Compared to 1Shot2Vec, our method reduces the error by around 50% and has an even slightly faster runtime.

Indyk et al. (2021) also proposed a Few Shot SGD algorithm, which further optimizes the non-zero values of a sketch after its initialization. We compare the performance of this approach under different initialization methods: for every initialization method we use only one training matrix, and we use three training matrices for the Few Shot SGD step. The results are shown in Table A.4. We report the minimum error over 50 iterations of Few Shot SGD because we aim to compare the computational efficiency of the different methods.
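To make the Few Shot SGD objective recalled in Section A.3 concrete, the following NumPy sketch evaluates $\|U_k^\top S^\top S U - I_0\|_F^2$ for a fixed sketch and averages it over a few training matrices. It is an illustration only: the function names are hypothetical, it assumes n ≥ d so that the thin SVD gives $U \in \mathbb{R}^{n \times d}$, and it does not reproduce the SGD updates of Indyk et al. (2021).

```python
import numpy as np


def few_shot_loss(S, A, k):
    """||U_k^T S^T S U - I_0||_F^2 for A = U Sigma V^T (thin SVD, n >= d)."""
    U, _, _ = np.linalg.svd(A, full_matrices=False)   # U: n x d
    d = U.shape[1]
    SU = S @ U                                        # m x d
    M = SU[:, :k].T @ SU                              # U_k^T S^T S U, shape k x d
    I0 = np.eye(k, d)                                 # [I_k | 0]: d - k zero columns on the right
    return np.linalg.norm(M - I0, "fro") ** 2


def mean_few_shot_loss(S, train, k):
    """Empirical version of the expectation over A in Tr."""
    return float(np.mean([few_shot_loss(S, A, k) for A in train]))


rng = np.random.default_rng(0)
train = [rng.standard_normal((120, 8)) for _ in range(3)]   # e.g. three training matrices
S = np.zeros((40, 120))                                     # random Count Sketch with 40 rows
S[rng.integers(0, 40, size=120), np.arange(120)] = rng.choice([-1.0, 1.0], size=120)
print(mean_few_shot_loss(S, train, k=4))
```

With $S = I$ the product $U_k^\top U$ is exactly $I_0$, so the loss is zero; any genuine compression makes it positive, and SGD would then adjust the non-zero values of S to shrink it.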
From Table A.4 we see that our approach plus the Few Shot SGD method achieves a much smaller error, with around a 50% improvement upon (Indyk et al., 2021). Moreover, even without further optimization by Few Shot SGD, our initialization method for learning the non-zero locations in Count Sketch obtains a smaller error than the other methods (even when their values are further optimized by Few Shot SGD).

| Algorithm | Dataset | Few-shot error | Training time |
|---|---|---|---|
| Classical CS | Logo | | |
| Classical CS | Friends | 0.524 | |
| Classical CS | Hyper | 1.082 | |
| 1Shot2Vec | Logo | 0.171 | 5.682 |
| 1Shot2Vec | Friends | 0.306 | 5.680 |
| 1Shot2Vec | Hyper | 0.795 | 1.054 |
| Ours (inner product) | Logo | 0.065 | 4.515 |
| Ours (inner product) | Friends | 0.139 | 4.773 |
| Ours (inner product) | Hyper | 0.535 | 0.623 |

Table A.3: Test errors and training times for LRA in the one-shot setting (using Alg. 1 with one sketch)

| Sketch | Logo | Friends | Hyper |
|---|---|---|---|
| Ours (Initialization only) | 0.065 | 0.139 | 0.535 |
| Ours + Few Shot SGD | 0.048 | 0.125 | 0.443 |
| 1Shot1Vec only | 0.171 | 0.306 | 0.795 |
| 1Shot1Vec + Few Shot SGD | 0.104 | 0.229 | 0.636 |
| Classical CS | 0.331 | 0.524 | 1.082 |
| Classical CS + Few Shot SGD | 0.173 | 0.279 | 0.771 |

Table A.4: Test errors for LRA in the few-shot setting (using Alg. 1 from Indyk et al. (2019) with one sketch)

B ADDITIONAL EXPERIMENTS: SECOND-ORDER OPTIMIZATION

As we mentioned in Section 8, despite the number of problems that learned sketches have been applied to, they have not been applied to convex optimization or, more generally, to iterative sketching algorithms. To demonstrate the difficulty, we consider the Iterative Hessian Sketch (IHS) as an example. In that scheme, suppose that we have k iterations of the algorithm. Then we need k independent sketching matrices (otherwise the solution may diverge). A natural way is to follow the method in (Indyk et al., 2019), which is to minimize the following quantity:
$$\min_{S_1, \ldots, S_k} \mathbb{E}_{A \sim \mathcal{D}} \left[ f(A, \mathrm{ALG}(S_1, \ldots, S_k, A)) \right],$$
where the minimization is taken over k Count-Sketch matrices $S_1, \ldots, S_k$.
In this case, however, calculating the gradient with respect to $S_1$ would involve all iterations, and in each iteration we need to solve a constrained optimization problem. Hence, computing the gradients would be intractable. An alternative is to train the k sketching matrices sequentially, that is, learn the sketching matrix for the i-th iteration using a local loss function for the i-th iteration, and then use the learned matrix in the i-th iteration to generate the training data for the (i + 1)-st iteration. However, the empirical results suggest that this works for the first iteration only, because the training data for the (i + 1)-st iteration depends on the solution to the i-th iteration and may drift farther away from the test data in later iterations. The core problem is that the method proposed in Indyk et al. (2019) treats the training process as a black box, which is difficult to extend to iterative methods.

B.1 THE DISTRIBUTION OF THE HEAVY ROWS

In our experiments, we hypothesize that real-world data may have an underlying pattern that can help us identify the heavy rows. In the Electric dataset, each row of the matrix corresponds to a specific residence, and the heavy rows are concentrated on a few specific rows. To exemplify this, we study the distribution of the rows with heavy leverage scores over the Electric dataset. For a row $i \in [370]$, let $f_i$ denote the number of times that row i is heavy out of the 320 training data points from the Electric dataset, where we say row i is heavy if $\ell_i \ge 5d/n$. Below we list all 74 pairs $(i, f_i)$ with $f_i > 0$.
(195,320), (278,320), (361,320), (207,317), (227,285), (240,284), (219,270), (275,232), (156,214), (322,213), (193,196), (190,192), (160,191), (350,181), (63,176), (42,168), (162,148), (356,129), (363,110), (362,105), (338,95), (215,94), (234,93), (289,81), (97,80), (146,70), (102,67), (98,58), (48,57), (349,53), (165,46), (101,41), (352,40), (293,34), (344,29), (268,21), (206,20), (217,20), (327,20), (340,19), (230,18), (359,18), (297,14), (357,14), (161,13), (245,10), (100,8), (85,6), (212,6), (313,6), (129,5), (130,5), (366,5), (103,4), (204,4), (246,4), (306,4), (138,3), (199,3), (222,3), (360,3), (87,2), (154,2), (209,2), (123,1), (189,1), (208,1), (214,1), (221,1), (224,1), (228,1), (309,1), (337,1), (343,1)

Observe that the heavy rows are concentrated on a set of specific row indices: there are only 30 rows i with $f_i \ge 50$. We view this as strong evidence for our hypothesis.

Heavy Rows Distribution Under Permutation. We note that even if the order of the rows is changed, we can still recognize the patterns of the rows. We continue to use the Electric dataset as an example. To address the concern that a permutation may break the sketch, we can measure the similarity between vectors: after processing the training data, we test similarity on the rows of the test matrix and use it to select the heavy rows, rather than relying on row indices, which may simply be permuted. To illustrate this method, we use the following example on the Electric dataset, based on locality sensitive hashing. After processing the training data, we obtain a set I of indices of heavy rows. For each $i \in I$, we pick q = 3 independent standard Gaussian vectors $g_1, g_2, g_3 \in \mathbb{R}^d$ and compute $f(r_i) = (g_1^\top r_i, g_2^\top r_i, g_3^\top r_i) \in \mathbb{R}^3$, where $r_i$ is the average of the i-th rows over all training matrices. Let A be the test matrix. For each $i \in I$, let $j_i = \arg\min_j \|f(A_j) - f(r_i)\|_2$. We take the $j_i$-th row to be a heavy row in our learned sketch.
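The two steps described in this appendix, counting how often each row is heavy and re-identifying heavy rows in a row-permuted test matrix by Gaussian projections, can be sketched in NumPy as follows. This is a hypothetical illustration rather than the authors' code: the function names are ours, leverage scores are read off a thin QR factorization (assuming full column rank), and building the heavy-row Count-Sketch from the matched indices is not shown.

```python
import numpy as np


def leverage_scores(A):
    """l_i = ||Q_i||_2^2 for the thin QR A = QR (assumes full column rank)."""
    Q, _ = np.linalg.qr(A)
    return np.sum(Q ** 2, axis=1)


def heavy_counts(train):
    """f_i = number of training matrices in which row i satisfies l_i >= 5d/n."""
    n, d = train[0].shape
    f = np.zeros(n, dtype=int)
    for A in train:
        f += leverage_scores(A) >= 5 * d / n
    return f


def match_heavy_rows(train, heavy, A_test, q=3, rng=None):
    """For each i in I, return j_i = argmin_j ||f(A_j) - f(r_i)||_2."""
    rng = np.random.default_rng() if rng is None else rng
    r = np.mean(train, axis=0)             # r_i: average of the i-th rows over training matrices
    d = A_test.shape[1]
    matched = []
    for i in heavy:
        G = rng.standard_normal((q, d))    # g_1, ..., g_q drawn for this i
        f_test = A_test @ G.T              # f(A_j) for every test row j, shape n x q
        f_ri = G @ r[i]                    # f(r_i), shape q
        matched.append(int(np.argmin(np.linalg.norm(f_test - f_ri, axis=1))))
    return matched


rng = np.random.default_rng(0)
train = [rng.standard_normal((370, 9)) for _ in range(20)]   # stand-in for Electric training data
f = heavy_counts(train)
heavy = np.argsort(-f)[:5]                                   # hypothetical choice: 5 most frequently heavy rows
perm = rng.permutation(370)
print(match_heavy_rows(train, heavy, train[0][perm], rng=rng))
```

Each heavy row costs one $n \times d$ by $d \times q$ product, so with q = 3 the matching touches every entry of the test matrix a constant number of times per heavy row.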
This matching step only needs O(1) additional passes over the entries of A, and hence the extra time cost is negligible. To test the performance of the method, we randomly pick a matrix from the test set and permute its rows. The result shows that when k is small, we can recover roughly 70% of the top-k heavy rows. Figure B.1 plots the regression error using the learned Count-Sketch matrix generated this way, where we set m = 90 and k = 0.3m = 27. The learned method still obtains a significant improvement.

Figure B.1: Test error of LASSO on Electric dataset. (x-axis: iteration round; y-axis: log_10(error); curves: count-sketch, learned (position-only).)

B.2 MATRIX ESTIMATION WITH A NUCLEAR NORM CONSTRAINT

In many applications, for the problem
$$X^* := \arg\min_{X \in \mathbb{R}^{d_1 \times d_2}} \|AX - B\|_F^2,$$
it is reasonable to model the matrix $X^*$ as having low rank. Similar to $\ell_1$-minimization for compressive sensing, a standard relaxation of the rank constraint is to minimize the nuclear norm of X, defined as $\|X\|_* := \sum_{j=1}^{\min\{d_1, d_2\}} \sigma_j(X)$, where $\sigma_j(X)$ is the j-th largest singular value of X. Hence, the matrix estimation problem we consider here is
$$X^* := \arg\min_{X \in \mathbb{R}^{d_1 \times d_2}} \|AX - B\|_F^2 \quad \text{such that } \|X\|_* \le \rho,$$
where ρ > 0 is a user-defined radius acting as a regularization parameter.

We conduct Iterative Hessian Sketch (IHS) experiments on the following dataset. Tunnel³: the data set is a time series of gas concentrations measured by eight sensors in a wind tunnel. Each (A, B) corresponds to a different data collection trial. $A_i \in \mathbb{R}^{13530 \times 5}$, $B_i \in \mathbb{R}^{13530 \times 6}$, $|(A, B)|_{\mathrm{train}} = 144$, $|(A, B)|_{\mathrm{test}} = 36$. In our nuclear norm constraint, we set ρ = 10.

Experiment Setting: We choose m = 7d, 10d for the Tunnel dataset. We consider the error $\frac{1}{2}\|AX - B\|_F^2 - \frac{1}{2}\|AX^* - B\|_F^2$. The leverage scores of this dataset are very uniform. Hence, for this experiment we only consider optimizing the values of the non-zero entries.

Results of Our Experiments: We plot on a logarithmic scale the mean errors of the dataset in Figure B.2.
We can see that when m = 7d, the gradient-based sketch, based on the first 6 iterations, has a rate of convergence that is 48% of that of the random sketch, and when m = 10d, the gradient-based sketch has a rate of convergence that is 29% of that of the random sketch.

Figure B.2: Test error of matrix estimation with a nuclear norm constraint on the Tunnel dataset. (Two panels; x-axis: iteration round; y-axis: log_10(error); curves: Ours (learned CS), Count-sketch.)

Figure B.3: Test error of the subroutine in fast regression on Electric dataset. (Three panels; x-axis: iteration round; curves: Gaussian (η = 1), Gaussian (η = 0.2), sparse-JL (η = 1), sparse-JL (η = 0.2), Ours (learned CS).)

Figure B.4: Test error of fast regression on Electric dataset. (x-axis: iteration round; y-axis: log_10(error); curves: Gaussian, Sparse-JL, Ours (learned CS).)

³ https://archive.ics.uci.edu/ml/datasets/Gas+sensor+array+exposed+to+turbulent+gas+mixtures

B.3 FAST REGRESSION SOLVER

Consider an unconstrained convex optimization problem $\min_x f(x)$, where f is smooth and strongly convex, and its Hessian $\nabla^2 f$ is Lipschitz continuous. This problem can be solved by Newton's method, which iteratively performs the update
$$x_{t+1} = x_t - \arg\min_z \left\| (\nabla^2 f(x_t)^{1/2})^\top (\nabla^2 f(x_t)^{1/2}) z - \nabla f(x_t) \right\|_2, \tag{B.1}$$
provided it is given a good initial point $x_0$. In each step, it requires solving a regression problem of the form $\min_z \|A^\top A z - y\|_2$, with $A = \nabla^2 f(x_t)^{1/2}$ and $y = \nabla f(x_t)$, which, with access to A, can be solved with the fast regression solver of (van den Brand et al., 2021).
The regression solver first computes a preconditioner R via a QR decomposition such that SAR has orthonormal columns, where S is a sketching matrix, then solves $\hat z = \arg\min_z \|(AR)^\top (AR) z - y\|_2$ by gradient descent, and finally returns $R\hat z$. The point of sketching is that the QR decomposition of SA can be computed much more efficiently than that of A, since S has only a small number of rows. In this section, we consider the unconstrained least squares problem $\min_x f(x)$ with $f(x) = \frac{1}{2}\|Ax - b\|_2^2$ on the Electric dataset, using the above fast regression solver.

Training: Note that $\nabla^2 f(x) = A^\top A$, independent of x. In the t-th round of Newton's method, by (B.1), we need to solve a regression problem $\min_z \|A^\top A z - y\|_2^2$ with $y = \nabla f(x_t)$. Hence, we can use the same methods as in the preceding subsection to optimize the learned sketch $S_t$. For a general problem where $\nabla^2 f(x)$ depends on x, one can take $x_t$ to be the solution obtained from the learned sketch $S_t$ to generate A and y for the (t + 1)-st round, train a learned sketch $S_{t+1}$, and repeat this process.

Experiment Setting: For the Electric dataset, we set m = 10d = 90. We observe that the classical Count-Sketch matrix makes the solution diverge badly in this setting. To make a clearer comparison, we consider the following sketch matrices:
- Gaussian sketch: $S = \frac{1}{\sqrt{m}} G$, where $G \in \mathbb{R}^{m \times n}$ has i.i.d. N(0, 1) entries.
- Sparse Johnson-Lindenstrauss Transform (SJLT): S is the vertical concatenation of s independent Count-Sketch matrices, each of dimension $m/s \times n$.

We note that these sketching matrices require more time to compute SA but need fewer rows than the classical Count-Sketch matrix to be a subspace embedding. For the step length η in gradient descent, we set η = 1 in all iterations for the learned sketches.
For classical random sketches, we set η in the following two ways: (a) η = 1 in all iterations and (b) η = 1 in the first iteration and η = 0.2 in all subsequent iterations.

Experimental Results: We examine the accuracy of the subproblem $\min_z \|A^\top A z - y\|_2^2$ and define the error to be $\|A^\top A R z_t - y\|_2 / \|y\|_2$. We consider the subproblems in the first three iterations of the global Newton method. The results are plotted in Figure B.3. Note that Count-Sketch causes a severe divergence of the subroutine and is thus omitted from the plots. Still, we observe that in setting (a) of η, the other two classical sketches cause the subroutine to diverge. In setting (b) of η, the other two classical sketches lead to convergence, but their error is significantly larger than that of the learned sketches in each of the first three calls to the subroutine. The error of the learned sketch is less than 0.01 in all iterations of all three subroutine calls, in both settings (a) and (b) of η.

We also plot the convergence of the global Newton method. Here, for each subroutine, we only run one iteration and plot the error of the original least squares problem. The result is shown in Figure B.4, which clearly displays a significantly faster decay with learned sketches. The rate of convergence using heavy-row sketches is 80.6% of that using Gaussian or sparse JL sketches.

B.4 FIRST-ORDER OPTIMIZATION

In this section, we study the use of the sketch in first-order methods. In particular, let $QR^{-1} = SA$ be the QR decomposition of SA, where S is a sketch matrix. We use R as an (approximate) preconditioner and use gradient descent to solve the problem $\min_x \|ARx - b\|_2$. Here we use the Electric dataset, where A is 370 × 9, and we set S to have 90 rows. The result is shown in Table B.1, where the time includes the time to compute R.
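The B.4 setup can be sketched in NumPy: factor $SA = QR^{-1}$, run plain gradient descent on $\frac{1}{2}\|ARy - b\|_2^2$, and map back with $x = Ry$. This is a minimal illustration under stated assumptions, not the code behind Table B.1: the function names are hypothetical, S here is a random Count Sketch rather than a learned one, the data are synthetic with the Electric shape, and we measure the distance to the exact least-squares solution, which may differ from the error reported in the table.

```python
import numpy as np


def count_sketch(m, n, rng):
    """Dense m x n Count Sketch with one random +-1 per column."""
    S = np.zeros((m, n))
    S[rng.integers(0, m, size=n), np.arange(n)] = rng.choice([-1.0, 1.0], size=n)
    return S


def sketch_preconditioned_gd(A, b, S, lr=1.0, iters=100):
    """Gradient descent on min_y 0.5 * ||A R y - b||^2 with R from SA = Q R^{-1}."""
    _, T = np.linalg.qr(S @ A)              # SA = Q T, T upper triangular
    R = np.linalg.inv(T)                    # R = T^{-1}, so S A R = Q has orthonormal columns
    AR = A @ R                              # well conditioned when S embeds col(A)
    y = np.zeros(A.shape[1])
    for _ in range(iters):
        y -= lr * (AR.T @ (AR @ y - b))     # gradient step on the preconditioned problem
    return R @ y                            # map back to the original variable x = R y


rng = np.random.default_rng(0)
A = rng.standard_normal((370, 9))           # synthetic data with the Electric shape
b = rng.standard_normal(370)
x = sketch_preconditioned_gd(A, b, count_sketch(90, 370, rng), lr=0.3, iters=300)
x_star = np.linalg.lstsq(A, b, rcond=None)[0]
print(np.linalg.norm(x - x_star) / np.linalg.norm(x_star))
```

Gradient descent on the preconditioned problem is stable whenever lr times the largest squared singular value of AR is below 2. A good subspace embedding keeps those singular values near 1, which is why a step of 1 can work for a learned sketch; the smaller step here is a conservative choice for a random sketch.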
We can see that with a learned sketch matrix, the error converges very fast when the learning rate is set to 1 or 0.1, while the classical Count-Sketch leads to divergence.

| | Iteration 1 | Iteration 10 | Iteration 100 | Iteration 500 |
|---|---|---|---|---|
| Error (learned, lr = 1) | 2.73 | 1.5e-7 | | |
| Error (learned, lr = 0.1) | 4056 | 605 | 4.04e-6 | |
| Error (learned, lr = 0.01) | 4897 | 4085 | 667 | 0.217 |
| Error (random, lr = 1) | N.A | N.A | N.A | N.A |
| Error (random, lr = 0.1) | | | | |
| Error (random, lr = 0.01) | 4881 | 3790 | 685 | 1.52 |
| Time | 0.00048 | 0.00068 | 0.0029 | 0.0013 |

Table B.1: Test Error for Gradient Descent

C PRELIMINARIES: THEOREMS AND ADDITIONAL ALGORITHMS

In this section, we provide the full description of the time-optimal sketching algorithm for LRA in Algorithm 2. We also provide several definitions and lemmas that are used in the proofs of our results for LRA.

Definition C.1 (Affine Embedding). Given a pair of matrices A and B, a matrix S is an affine ϵ-embedding if for all X of the appropriate shape, $\|S(AX - B)\|_F^2 = (1 \pm \epsilon)\|AX - B\|_F^2$.

Lemma C.2 (Clarkson & Woodruff (2017); Lemma 40). Let A be an n × d matrix and let $S \in \mathbb{R}^{O(1/\epsilon^2) \times n}$ be a Count Sketch matrix. Then with constant probability, $\|SA\|_F^2 = (1 \pm \epsilon)\|A\|_F^2$.

The following result is shown in Clarkson & Woodruff (2017) and sharpened by Nelson & Nguyên (2013); Meng & Mahoney (2013).

Lemma C.3. Given matrices A, B with n rows, a Count Sketch with $O(\mathrm{rank}(A)^2/\epsilon^2)$ rows is an affine ϵ-embedding matrix with constant probability. Moreover, the matrix product SA can be computed in O(nnz(A)) time, where nnz(A) denotes the number of non-zero entries of matrix A.

Lemma C.4 (Sarlos (2006); Clarkson & Woodruff (2017)). Suppose that $A \in \mathbb{R}^{n \times d}$ and $B \in \mathbb{R}^{n \times d'}$. Let $S \in \mathbb{R}^{m \times n}$ be a Count Sketch with $m = \mathrm{rank}(A)^2/\epsilon^2$. Let $\tilde X = \arg\min_{\text{rank-}k\ X} \|SAX - SB\|_F^2$. Then,
1. With constant probability, $\|A\tilde X - B\|_F^2 \le (1 + \epsilon)\min_{\text{rank-}k\ X}\|AX - B\|_F^2$.
In other words, in $O(\mathrm{nnz}(A) + \mathrm{nnz}(B) + m(d + d'))$ time, we can reduce the problem to a smaller (multi-response regression) problem with m rows whose optimal solution is a (1 + ϵ)-approximate solution to the original instance.
2. The (1 + ϵ)-approximate solution $\tilde X$ can be computed in time $O(\mathrm{nnz}(A) + \mathrm{nnz}(B) + mdd' + \min(m^2d, md^2))$.

Now we turn our attention to the time-optimal sketching algorithm for LRA. The next lemma is known (Avron et al., 2017); we include its proof for completeness.

Lemma C.5. Suppose that $S \in \mathbb{R}^{m_S\times n}$ and $R \in \mathbb{R}^{m_R\times d}$ are sparse affine ϵ-embedding matrices for $(A_k, A)$ and $((SA)^\top, A^\top)$, respectively. Then,
$$\min_{\text{rank-}k\ X}\|AR^\top XSA - A\|_F^2 \le (1+\epsilon)\|A_k - A\|_F^2.$$

Proof. Consider the following multiple-response regression problem:
$$\min_{\text{rank-}k\ X}\|A_kX - A\|_F^2. \tag{C.1}$$
Note that since $X = I_k$ is a feasible solution to Eq. (C.1), $\min_{\text{rank-}k\ X}\|A_kX - A\|_F^2 = \|A_k - A\|_F^2$. Let $S \in \mathbb{R}^{m_S\times n}$ be a sketching matrix that satisfies the condition of Lemma C.4 (Item 1) for $A := A_k$ and $B := A$. By the normal equations, the rank-k minimizer of $\|SA_kX - SA\|_F^2$ is $(SA_k)^+SA$. Hence,
$$\|A_k(SA_k)^+SA - A\|_F^2 \le (1+\epsilon)\|A_k - A\|_F^2, \tag{C.2}$$
which in particular shows that a (1 + ϵ)-approximate rank-k approximation of A exists in the row space of SA. In other words,
$$\min_{\text{rank-}k\ X}\|XSA - A\|_F^2 \le (1+\epsilon)\|A_k - A\|_F^2. \tag{C.3}$$
Next, let $R \in \mathbb{R}^{m_R\times d}$ be a sketching matrix that satisfies the condition of Lemma C.4 (Item 1) for $A := (SA)^\top$ and $B := A^\top$. Let Y denote the rank-k minimizer of $\|R(SA)^\top X - RA^\top\|_F^2$. Hence,
$$\|(SA)^\top Y - A^\top\|_F^2 \le (1+\epsilon)\min_{\text{rank-}k\ X}\|XSA - A\|_F^2 \le (1+O(\epsilon))\|A_k - A\|_F^2, \tag{C.4}$$
where the first inequality follows from Lemma C.4 (Item 1) and the second from Eq. (C.3). Note that by the normal equations, again $\mathrm{rowsp}(Y) \subseteq \mathrm{rowsp}(RA^\top)$, and we can write $Y^\top = AR^\top Z$ where rank(Z) = k. Thus,
$$\min_{\text{rank-}k\ X}\|AR^\top XSA - A\|_F^2 \le \|AR^\top ZSA - A\|_F^2 = \|(SA)^\top Y - A^\top\|_F^2 \le (1+O(\epsilon))\|A_k - A\|_F^2,$$
where the last inequality is Eq. (C.4).

Lemma C.6 (Avron et al. (2017); Lemma 27). For $C \in \mathbb{R}^{p\times m'}$, $D \in \mathbb{R}^{m\times p'}$, $G \in \mathbb{R}^{p\times p'}$, the problem
$$\min_{\text{rank-}k\ Z}\|CZD - G\|_F^2 \tag{C.5}$$
can be solved in $O(pm'r_C + p'mr_D + pp'(r_D + r_C))$ time, where $r_C = \mathrm{rank}(C) \le \min\{m', p\}$ and $r_D = \mathrm{rank}(D) \le \min\{m, p'\}$.
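Lemma C.6 relies on the algorithm of Avron et al. (2017). As an illustration only, the sketch below solves problem (C.5) with a different, well-known closed form: $Z = C^{+}[P_C G P_D]_k D^{+}$, where $P_C$ and $P_D$ are the orthogonal projectors onto the column space of C and the row space of D, and $[\cdot]_k$ denotes the best rank-k truncation. The sketch does not attempt to match the running time in the lemma, and the function names are hypothetical.

```python
import numpy as np


def best_rank_k(M, k):
    """Best rank-k approximation of M in Frobenius norm (truncated SVD)."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vt[:k]


def rank_k_generalized_regression(C, D, G, k):
    """Rank-<=k Z minimizing ||C Z D - G||_F, via Z = C^+ [P_C G P_D]_k D^+."""
    C_pinv, D_pinv = np.linalg.pinv(C), np.linalg.pinv(D)
    P_C = C @ C_pinv          # orthogonal projector onto colsp(C)
    P_D = D_pinv @ D          # orthogonal projector onto rowsp(D)
    return C_pinv @ best_rank_k(P_C @ G @ P_D, k) @ D_pinv


rng = np.random.default_rng(1)
p, m_c, m_d, p2, k = 8, 5, 6, 7, 2      # C: p x m_c, D: m_d x p2, G: p x p2
C = rng.standard_normal((p, m_c))
D = rng.standard_normal((m_d, p2))
G = rng.standard_normal((p, p2))
Z = rank_k_generalized_regression(C, D, G, k)
cost = np.linalg.norm(C @ Z @ D - G) ** 2
```

The optimal cost splits into the part of G outside colsp(C) × rowsp(D), which no Z can fit, plus the tail singular values of $P_C G P_D$. In Lemma C.7, the inputs would be the small sketched matrices C, D and G computed from A and the sketches.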
Lemma C.7. Let $S \in \mathbb{R}^{m_S\times n}$ and $R \in \mathbb{R}^{m_R\times d}$ be Count Sketch (CS) matrices such that
$$\min_{\text{rank-}k\ X}\|AR^\top XSA - A\|_F^2 \le (1+\gamma)\|A_k - A\|_F^2. \tag{C.6}$$
Let $V \in \mathbb{R}^{(m_R^2/\beta^2)\times n}$ and $W \in \mathbb{R}^{(m_S^2/\beta^2)\times d}$ be CS matrices. Then, Algorithm 2 gives a (1 + O(β + γ))-approximation in time
$$\mathrm{nnz}(A) + O\Big(\frac{m_S^4}{\beta^2} + \frac{m_R^4}{\beta^2} + \frac{m_S^2m_R^2(m_S+m_R)}{\beta^4} + k(nm_S + dm_R)\Big)$$
with constant probability.

Proof. The approximation guarantee follows from Eq. (C.6) and the fact that V and W are affine β-embedding matrices of $AR^\top$ and $SA$, respectively (see Lemma C.3). The algorithm first computes $C = VAR^\top$, $D = SAW^\top$, $G = VAW^\top$, which can be done in O(nnz(A)) time. As an example, we bound the time to compute $C = VAR^\top$. Since V is a CS, VA can be computed in O(nnz(A)) time, and the number of non-zero entries in the resulting matrix is at most nnz(A). Hence, since R is a CS as well, C can be computed in time O(nnz(A) + nnz(VA)) = O(nnz(A)). Then, it takes an extra $O((m_S^3 + m_R^3 + m_S^2m_R^2)/\beta^2)$ time to store C, D and G in matrix form. Next, by Lemma C.6, the time to compute Z in Algorithm 2 is $O\big(\frac{m_S^4}{\beta^2} + \frac{m_R^4}{\beta^2} + \frac{m_S^2m_R^2(m_S+m_R)}{\beta^4}\big)$. Finally, it takes $O(\mathrm{nnz}(A) + k(nm_S + dm_R))$ time to compute $P = AR^\top Z_L$ and $Q = Z_RSA$ and to return the solution in the form $P_{n\times k}Q_{k\times d}$. Hence, the total runtime is
$$O\Big(\mathrm{nnz}(A) + \frac{m_S^4}{\beta^2} + \frac{m_R^4}{\beta^2} + \frac{m_S^2m_R^2(m_S+m_R)}{\beta^4} + k(nm_S + dm_R)\Big).$$

D ATTAINING WORST-CASE GUARANTEES

D.1 LOW-RANK APPROXIMATION

We provide the following two methods to achieve worst-case guarantees: Mixed Sketch, whose guarantee follows from the sketch monotonicity property, and the approximate comparison method (a.k.a. Approx Check), which approximately evaluates the costs of two solutions and takes the better one. These methods asymptotically achieve the same worst-case guarantee.
However, for any input matrix A and any pair of sketches S, T, the performance of the Mixed Sketch method on (A, S, T) is never worse than that of the corresponding Approx Check method on (A, S, T), and it can be much better.

Remark D.1. Let $A = \mathrm{diag}(2, 2, \sqrt{2}, \sqrt{2})$, and suppose the goal is to find a rank-2 approximation of A. Consider two sketches S and T such that SA and TA capture $\mathrm{span}(e_1, e_3)$ and $\mathrm{span}(e_2, e_4)$, respectively. Then for both SA and TA, the best solution in the corresponding subspace is a (3/2)-approximation: $\|A - A_2\|_F^2 = 4$ and $\|A - P_{SA}\|_F^2 = \|A - P_{TA}\|_F^2 = 6$, where $P_{SA}$ and $P_{TA}$ denote the best approximations of A in the spaces spanned by SA and TA, respectively. However, if we find the best rank-2 approximation Z of A inside the span of the union of SA and TA, then $\|A - Z\|_F^2 = 4$. Since Approx Check just chooses the better of SA and TA by evaluating their costs, it misses the opportunity to do as well as Mixed Sketch.

Here, we show the sketch monotonicity property for LRA.

Theorem D.2. Let $A \in \mathbb{R}^{n\times d}$ be an input matrix, let V and W be η-affine embeddings, and let $S_1 \in \mathbb{R}^{m_S\times n}$, $R_1 \in \mathbb{R}^{m_R\times d}$ be arbitrary matrices. Consider arbitrary extensions S, R of $S_1, R_1$ (e.g., S is the concatenation of $S_1$ with an arbitrary matrix with the same number of columns). Then,
$$\|A - \mathrm{ALG}_{\mathrm{LRA}}((S,R,V,W),A)\|_F^2 \le (1+\eta)^2\|A - \mathrm{ALG}_{\mathrm{LRA}}((S_1,R_1,V,W),A)\|_F^2.$$

Proof. We have
$$\|A - \mathrm{ALG}_{\mathrm{LRA}}((S,R,V,W),A)\|_F^2 \le (1+\eta)\min_{\text{rank-}k\ X}\|AR^\top XSA - A\|_F^2 = (1+\eta)\min_{\substack{\text{rank-}k\ X:\ \mathrm{row}(X)\subseteq\mathrm{row}(SA),\\ \mathrm{col}(X)\subseteq\mathrm{col}(AR^\top)}}\|X - A\|_F^2,$$
which is in turn at most
$$(1+\eta)\min_{\substack{\text{rank-}k\ X:\ \mathrm{row}(X)\subseteq\mathrm{row}(S_1A),\\ \mathrm{col}(X)\subseteq\mathrm{col}(AR_1^\top)}}\|X - A\|_F^2 = (1+\eta)\min_{\text{rank-}k\ X}\|AR_1^\top XS_1A - A\|_F^2 \le (1+\eta)^2\|A - \mathrm{ALG}_{\mathrm{LRA}}((S_1,R_1,V,W),A)\|_F^2,$$
where we use the fact that V, W are affine η-embeddings (Definition C.1), as well as the facts that $\mathrm{col}(AR_1^\top) \subseteq \mathrm{col}(AR^\top)$ and $\mathrm{row}(S_1A) \subseteq \mathrm{row}(SA)$.

Approx Check for LRA.
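We give the pseudocode for the Approx Check method and prove that its runtime for LRA is of the same order as that of the classical time-optimal sketching algorithm for LRA.

Algorithm 5 LRA APPROXCHECK
Input: learned sketches $S_L, R_L, V_L, W_L$; classical sketches $S_C, R_C, V_C, W_C$; β; $A \in \mathbb{R}^{n\times d}$
1: $P_L, Q_L \leftarrow \mathrm{ALG}_{\mathrm{LRA}}(S_L, R_L, V_L, W_L, A)$, $P_C, Q_C \leftarrow \mathrm{ALG}_{\mathrm{LRA}}(S_C, R_C, V_C, W_C, A)$
2: Let $S \in \mathbb{R}^{O(1/\beta^2)\times n}$, $R \in \mathbb{R}^{O(1/\beta^2)\times d}$ be classical Count Sketch matrices
3: $\Delta_L \leftarrow \|S(P_LQ_L - A)R^\top\|_F^2$, $\Delta_C \leftarrow \|S(P_CQ_C - A)R^\top\|_F^2$
4: if $\Delta_L \le \Delta_C$ then
5:   return $P_LQ_L$
6: end if
7: return $P_CQ_C$

Theorem D.3. Assume we have data $A \in \mathbb{R}^{n\times d}$; learned sketches $S_L \in \mathbb{R}^{\mathrm{poly}(k/\epsilon)\times n}$, $R_L \in \mathbb{R}^{\mathrm{poly}(k/\epsilon)\times d}$, $V_L \in \mathbb{R}^{\mathrm{poly}(k/\epsilon)\times n}$, $W_L \in \mathbb{R}^{\mathrm{poly}(k/\epsilon)\times d}$ which attain a (1 + O(γ))-approximation; classical sketches of the same size, $S_C, R_C, V_C, W_C$, which attain a (1 + O(ϵ))-approximation; and a tradeoff parameter β. Then, Algorithm 5 attains a (1 + β + min(γ, ϵ))-approximation in $O(\mathrm{nnz}(A) + (n + d + k^4/\beta^4)\,\mathrm{poly}(k/\epsilon))$ time.

Proof. Let $(P_L, Q_L)$ and $(P_C, Q_C)$ be the approximate rank-k approximations of A in factored form computed using $(S_L, R_L)$ and $(S_C, R_C)$, respectively. Then, clearly,
$$\min(\|P_LQ_L - A\|_F^2, \|P_CQ_C - A\|_F^2) = (1 + O(\min(\epsilon,\gamma)))\|A_k - A\|_F^2. \tag{D.1}$$
Let $\Gamma_L = P_LQ_L - A$, $\Gamma_C = P_CQ_C - A$ and $\Gamma_M = \arg\min(\|S\Gamma_LR^\top\|_F, \|S\Gamma_CR^\top\|_F)$. Then,
$$\|\Gamma_M\|_F^2 \le (1+O(\beta))\|S\Gamma_MR^\top\|_F^2 \le (1+O(\beta))\min(\|\Gamma_L\|_F^2, \|\Gamma_C\|_F^2) \le (1+O(\beta+\min(\epsilon,\gamma)))\|A_k - A\|_F^2,$$
where the first two inequalities follow from Lemma C.2 and the choice of $\Gamma_M$, and the last follows from Eq. (D.1).

Runtime analysis. By Lemma C.7, Algorithm 2 computes $P_L, Q_L$ and $P_C, Q_C$ in time $O\big(\mathrm{nnz}(A) + \frac{k^{16}(\beta^2+\epsilon^2)}{\epsilon^2}(n + \frac{dk^2}{\epsilon^4})\big)$. Next, once we have $P_L, Q_L$ and $P_C, Q_C$, it takes $O(\mathrm{nnz}(A) + k/\beta^4)$ time to compute $\Delta_L$ and $\Delta_C$. The total time is therefore
$$O\Big(\mathrm{nnz}(A) + \frac{k^{16}(\beta^2+\epsilon^2)}{\epsilon^2}\big(n + \frac{dk^2}{\beta^4}\big)\Big) = O\big(\mathrm{nnz}(A) + (n + d + k^4/\beta^4)\,\mathrm{poly}(k/\epsilon)\big).$$

To interpret the above theorem, note that when $\epsilon \gg k(n+d)^{-1/4}$, we can set $\beta = \Theta(k(n+d)^{-1/4})$, so that $k^4/\beta^4 = O(n+d)$ and Algorithm 5 has the same asymptotic runtime as the best (1 + ϵ)-approximation algorithm for LRA with the classical Count Sketch. Moreover, Algorithm 5 is a (1 + o(ϵ))-approximation when the learned sketch outperforms classical sketches, i.e., when γ = o(ϵ).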
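Lines 2–7 of Algorithm 5 can be sketched in NumPy as follows. The sketch assumes the two candidate factorizations are already available; in the algorithm they would come from ALG_LRA, which is not reproduced here. The check sketches are classical Count Sketches with a hypothetical fixed number of rows `m_check` standing in for O(1/β²). The estimate is formed as $(SP)(QR^\top) - SAR^\top$, so the n × d matrix PQ is never materialized. The helper names are hypothetical, and the dense matrices are for illustration only; a real implementation would exploit sparsity to get O(nnz(A)) time.

```python
import numpy as np


def countsketch(m, n, rng):
    """Dense m x n Count Sketch: each column has one random +/-1 entry in a random row."""
    S = np.zeros((m, n))
    S[rng.integers(0, m, size=n), np.arange(n)] = rng.choice([-1.0, 1.0], size=n)
    return S


def approx_check(A, candidates, m_check, rng):
    """Index of the (P, Q) pair with the smallest sketched cost ||S (P Q - A) R^T||_F^2."""
    n, d = A.shape
    S = countsketch(m_check, n, rng)   # plays the role of S in Algorithm 5
    R = countsketch(m_check, d, rng)   # plays the role of R in Algorithm 5
    SAR = S @ A @ R.T
    costs = [np.linalg.norm((S @ P) @ (Q @ R.T) - SAR) ** 2 for P, Q in candidates]
    return int(np.argmin(costs)), costs


rng = np.random.default_rng(2)
n, d, k = 500, 200, 5
A = rng.standard_normal((n, k)) @ rng.standard_normal((k, d)) + 0.01 * rng.standard_normal((n, d))
U, s, Vt = np.linalg.svd(A, full_matrices=False)
good = (U[:, :k] * s[:k], Vt[:k])                                   # near-optimal rank-k factors
bad = (rng.standard_normal((n, k)), rng.standard_normal((k, d)))    # arbitrary rank-k factors
best, costs = approx_check(A, [good, bad], m_check=50, rng=rng)
```

Because $\mathbb{E}\|SMR^\top\|_F^2 = \|M\|_F^2$ for independent Count Sketches S and R, each sketched cost is an unbiased estimate of the true cost. The candidate with the clearly smaller error is therefore selected with high probability.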
When the learned sketches perform poorly, i.e., γ = Ω(ϵ), the worst-case guarantee of Algorithm 5 still remains (1 + O(ϵ)).

D.2 SECOND-ORDER OPTIMIZATION

For the sketches used in second-order optimization, the monotonicity property does not hold. Below we provide an input-sparsity-time algorithm that can test for and use the better of a random sketch and a learned sketch. Our theorem is as follows.

Theorem D.4. Let ϵ ∈ (0, 0.09) be a constant and $S_1$ a learned Count-Sketch matrix. Suppose that A is of full rank. There is an algorithm whose output is a solution $\hat x$ which, with probability at least 0.98, satisfies
$$\|A(\hat x - x^*)\|_2 \le O\Big(\min\Big\{\frac{Z_2(S_1)}{Z_1(S_1)}, \epsilon\Big\}\Big)\|Ax^*\|_2,$$
where $x^* = \arg\min_{x\in\mathcal{C}}\|Ax - b\|_2$ is the least-squares solution. Furthermore, the algorithm runs in $O(\mathrm{nnz}(A)\log(1/\epsilon) + \mathrm{poly}(d/\epsilon))$ time.

Algorithm 6 Solver for (D.2)
1: $S_1 \leftarrow$ learned sketch, $S_2 \leftarrow$ random sketch with $\Theta(d^2/\epsilon^2)$ rows
2: $(\hat Z_{i,1}, \hat Z_{i,2}) \leftarrow \mathrm{ESTIMATE}(S_i, A)$, i = 1, 2
3: $i^* \leftarrow \arg\min_{i=1,2}(\hat Z_{i,2}/\hat Z_{i,1})$
4: $\hat x \leftarrow$ solution of (D.2) with $S = S_{i^*}$
5: return $\hat x$
6: function ESTIMATE(S, A)
7:   T ← sparse (1 ± η)-subspace embedding matrix for d-dimensional subspaces
8:   (Q, R) ← QR(TA)
9:   $\hat Z_1 \leftarrow \sigma_{\min}(SAR^{-1})$
10:  $\hat Z_2 \leftarrow$ (1 ± η)-approximation to $\|(SAR^{-1})^\top(SAR^{-1}) - I\|_{op}$
11:  return $(\hat Z_1, \hat Z_2)$

Consider the minimization problem
$$\min_{x\in\mathcal{C}}\ \frac{1}{2}\|SAx\|_2^2 - \langle A^\top y, x\rangle, \tag{D.2}$$
which is used as a subroutine for the IHS (cf. (2.2)). We note that in this subroutine, if we let $x \leftarrow x - x_{i-1}$, $b \leftarrow b - Ax_{i-1}$ and $\mathcal{C} \leftarrow \mathcal{C} - x_{i-1}$, we obtain the guarantee of the i-th iteration of the original IHS. To analyze the performance of the learned sketch, we define the following quantities (corresponding exactly to the unconstrained case in (Pilanci & Wainwright, 2016)):
$$Z_1(S) = \inf_{v\in\mathrm{colsp}(A)\cap S^{n-1}}\|Sv\|_2^2,\qquad Z_2(S) = \sup_{u,v\in\mathrm{colsp}(A)\cap S^{n-1}}\langle u, (S^\top S - I_n)v\rangle.$$
When S is a (1 ± ϵ)-subspace embedding of colsp(A), we have $Z_1(S) \ge 1 - \epsilon$ and $Z_2(S) \le 2\epsilon$. For a general sketching matrix S, Lemma D.5 below gives the approximation guarantees of $\hat Z_1$ and $\hat Z_2$, which are estimates of $Z_1(S)$ and $Z_2(S)$, respectively.
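The following dense NumPy sketch illustrates ESTIMATE and the selection rule (line 3) of Algorithm 6 under simplifying assumptions. T is a classical Count Sketch with a hypothetical, generously chosen number of rows. The quantity $\|(SAR^{-1})^\top SAR^{-1} - I\|_{op}$ is computed exactly from singular values instead of by the power method: the eigenvalues of $M^\top M - I$ are $\sigma_i(M)^2 - 1$, so the operator norm is $\max_i|\sigma_i^2 - 1|$. The second candidate is a deliberately tiny Count Sketch standing in for a poorly performing sketch; it is not a learned sketch. Function names are hypothetical.

```python
import numpy as np


def countsketch(m, n, rng):
    """Dense m x n Count Sketch: each column has one random +/-1 entry in a random row."""
    S = np.zeros((m, n))
    S[rng.integers(0, m, size=n), np.arange(n)] = rng.choice([-1.0, 1.0], size=n)
    return S


def estimate(S, A, T):
    """Return (Z1_hat, Z2_hat) as in ESTIMATE of Algorithm 6 (exact SVD, no power method)."""
    _, T_factor = np.linalg.qr(T @ A)            # TA = Q R
    M = S @ A @ np.linalg.inv(T_factor)          # S A R^{-1}
    sv = np.linalg.svd(M, compute_uv=False)
    z1_hat = sv.min()                            # sigma_min(S A R^{-1})
    z2_hat = np.max(np.abs(sv ** 2 - 1.0))       # ||M^T M - I||_op
    return z1_hat, z2_hat


def choose_sketch(sketches, A, T):
    """Index of the sketch minimizing Z2_hat / Z1_hat (line 3 of Algorithm 6)."""
    ratios = []
    for S in sketches:
        z1, z2 = estimate(S, A, T)
        ratios.append(z2 / z1)
    return int(np.argmin(ratios)), ratios


rng = np.random.default_rng(3)
n, d = 3000, 8
A = rng.standard_normal((n, d))
T = countsketch(1500, n, rng)
S_large = countsketch(1000, n, rng)   # many rows: close to a subspace embedding of colsp(A)
S_small = countsketch(12, n, rng)     # very few rows: large distortion
best, ratios = choose_sketch([S_large, S_small], A, T)
```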
The main idea behind Lemma D.5 is that $AR^{-1}$ is well-conditioned, where R is as computed in Algorithm 6.

Lemma D.5. Suppose that η ∈ (0, 1/3) is a small constant, A is of full rank and S has poly(d/η) rows. The function ESTIMATE(S, A) returns, in $O(\mathrm{nnz}(A)\log(1/\eta))$ time, $\hat Z_1, \hat Z_2$ which with probability at least 0.99 satisfy
$$\frac{Z_1(S)}{1+\eta} \le \hat Z_1 \le \frac{Z_1(S)}{1-\eta}\quad\text{and}\quad\frac{Z_2(S)}{(1+\eta)^2} - 3\eta \le \hat Z_2 \le \frac{Z_2(S)}{(1-\eta)^2} + 3\eta.$$

Proof. Suppose that $AR^{-1} = UW$, where $U \in \mathbb{R}^{n\times d}$ has orthonormal columns, which form an orthonormal basis of the column space of A. Since T is a subspace embedding of the column space of A with probability 0.99, it holds for all $x \in \mathbb{R}^d$ that
$$\frac{1}{1+\eta}\|TAR^{-1}x\|_2 \le \|AR^{-1}x\|_2 \le \frac{1}{1-\eta}\|TAR^{-1}x\|_2.$$
Since
$$\|TAR^{-1}x\|_2 = \|Qx\|_2 = \|x\|_2\quad\text{and}\quad\|Wx\|_2 = \|UWx\|_2 = \|AR^{-1}x\|_2, \tag{D.3}$$
we have that
$$\frac{1}{1+\eta}\|x\|_2 \le \|Wx\|_2 \le \frac{1}{1-\eta}\|x\|_2,\quad\forall x \in \mathbb{R}^d. \tag{D.4}$$
It is easy to see that
$$Z_1(S) = \min_{x\in S^{d-1}}\|SUx\|_2 = \min_{y\ne 0}\frac{\|SUWy\|_2}{\|Wy\|_2},$$
and thus
$$\min_{y\ne 0}(1-\eta)\frac{\|SUWy\|_2}{\|y\|_2} \le Z_1(S) \le \min_{y\ne 0}(1+\eta)\frac{\|SUWy\|_2}{\|y\|_2}.$$
Recall that $SUW = SAR^{-1}$. We see that $(1-\eta)\sigma_{\min}(SAR^{-1}) \le Z_1(S) \le (1+\eta)\sigma_{\min}(SAR^{-1})$.

By definition, $Z_2(S) = \|U^\top(S^\top S - I_n)U\|_{op}$. It follows from (D.4) that
$$(1-\eta)^2\|W^\top U^\top(S^\top S - I_n)UW\|_{op} \le Z_2(S) \le (1+\eta)^2\|W^\top U^\top(S^\top S - I_n)UW\|_{op},$$
and from (D.4), (D.3) and Lemma 5.36 of Vershynin (2012) that $\|(AR^{-1})^\top(AR^{-1}) - I\|_{op} \le 3\eta$. Since $\|W^\top U^\top(S^\top S - I_n)UW\|_{op} = \|(AR^{-1})^\top(S^\top S - I_n)AR^{-1}\|_{op}$ and
$$\|(AR^{-1})^\top S^\top SAR^{-1} - I\|_{op} - \|(AR^{-1})^\top(AR^{-1}) - I\|_{op} \le \|(AR^{-1})^\top(S^\top S - I_n)AR^{-1}\|_{op} \le \|(AR^{-1})^\top S^\top SAR^{-1} - I\|_{op} + \|(AR^{-1})^\top(AR^{-1}) - I\|_{op},$$
it follows that
$$(1-\eta)^2\|(SAR^{-1})^\top SAR^{-1} - I\|_{op} - 3(1-\eta)^2\eta \le Z_2(S) \le (1+\eta)^2\|(SAR^{-1})^\top SAR^{-1} - I\|_{op} + 3(1+\eta)^2\eta.$$

We have so far proved the correctness of the approximation; we next analyze the runtime. Since S and T are sparse, computing SA and TA takes O(nnz(A)) time. The QR decomposition of TA, which is a matrix of size poly(d/η) × d, can be computed in poly(d/η) time. The matrix $SAR^{-1}$ can be computed in poly(d) time. Since it has size poly(d/η) × d, its smallest singular value can be computed in poly(d/η) time.
To approximate $Z_2(S)$, we can use the power method to estimate $\|(SAR^{-1})^\top SAR^{-1} - I\|_{op}$ up to a (1 ± η)-factor in $O((\mathrm{nnz}(A) + \mathrm{poly}(d/\eta))\log(1/\eta))$ time.

Now we are ready to prove Theorem D.4.

Proof of Theorem D.4. By Lemma D.5 (applied with η = ϵ), we have with probability at least 0.99 that
$$\frac{\hat Z_2}{\hat Z_1} \ge \frac{\frac{1}{(1+\epsilon)^2}Z_2(S) - 3\epsilon}{\frac{1}{1-\epsilon}Z_1(S)} \ge \frac{1-\epsilon}{(1+\epsilon)^2}\frac{Z_2(S)}{Z_1(S)} - 3\epsilon(1-\epsilon)$$
and similarly,
$$\frac{\hat Z_2}{\hat Z_1} \le \frac{\frac{1}{(1-\epsilon)^2}Z_2(S) + 3\epsilon}{\frac{1}{1+\epsilon}Z_1(S)} \le \frac{1+\epsilon}{(1-\epsilon)^2}\frac{Z_2(S)}{Z_1(S)} + 3\epsilon(1+\epsilon).$$
Note that since $S_2$ is an ϵ-subspace embedding with probability at least 0.99, we have that $Z_1(S_2) \ge 1 - \epsilon$ and $Z_2(S_2)/Z_1(S_2) \le 2.2\epsilon$. Consider $Z_1(S_1)$. First, we consider the case where $Z_1(S_1) < 1/2$. Observe that $Z_2(S) \ge 1 - Z_1(S)$. We have in this case $\hat Z_{1,2}/\hat Z_{1,1} > 1/5 > 2.2\epsilon \ge Z_2(S_2)/Z_1(S_2)$, so our algorithm will correctly choose $S_2$. Next, assume that $Z_1(S_1) \ge 1/2$. Now we have with probability at least 0.98 that
$$(1-3\epsilon)\frac{Z_2(S_i)}{Z_1(S_i)} - 3\epsilon \le \frac{\hat Z_{i,2}}{\hat Z_{i,1}} \le (1+4\epsilon)\frac{Z_2(S_i)}{Z_1(S_i)} + 4\epsilon,\quad i = 1, 2.$$
Therefore, when $Z_2(S_1)/Z_1(S_1) \le c_1Z_2(S_2)/Z_1(S_2)$ for some small absolute constant $c_1 > 0$, we will have $\hat Z_{1,2}/\hat Z_{1,1} < \hat Z_{2,2}/\hat Z_{2,1}$, and our algorithm will correctly choose $S_1$. If $Z_2(S_1)/Z_1(S_1) \ge C_1\epsilon$ for some absolute constant $C_1 > 0$, our algorithm will correctly choose $S_2$. In the remaining case, both ratios $Z_2(S_2)/Z_1(S_2)$ and $Z_2(S_1)/Z_1(S_1)$ are at most $\max\{C_2, 3\}\epsilon$, and the guarantee of the theorem holds automatically. The correctness of our claim then follows from Proposition 1 of Pilanci & Wainwright (2016), together with the fact that $S_2$ is a random subspace embedding. The runtime follows from Lemma D.5 and Theorem 2.2 of Cormode & Dickens (2019).

E SKETCH LEARNING: OMITTED PROOFS

E.1 PROOF OF THEOREM 5.1

We need the following lemmas on ridge leverage score sampling from (Cohen et al., 2017).

Lemma E.1 ((Cohen et al., 2017, Lemma 4)). Let $\lambda = \|A - A_k\|_F^2/k$. Then we have $\sum_i\tau_i(A) \le 2k$.

Lemma E.2 ((Cohen et al., 2017, Theorem 7)). Let $\lambda = \|A - A_k\|_F^2/k$ and let $\tilde\tau_i \ge \tau_i(A)$ be an overestimate of the i-th ridge leverage score of A.
Let $p_i = \tilde\tau_i/\sum_i\tilde\tau_i$. If C is a matrix constructed by sampling $t = O((\log k + \log(1/\delta)/\epsilon)\sum_i\tilde\tau_i)$ rows of A, each set to $a_i$ with probability $p_i$, then with probability at least 1 − δ we have
$$\min_{\text{rank-}k\ X:\ \mathrm{row}(X)\subseteq\mathrm{row}(C)}\|A - X\|_F^2 \le (1+\epsilon)\|A - A_k\|_F^2.$$

Recall that sketch monotonicity for low-rank approximation says that concatenating two sketching matrices $S_1$ and $S_2$ does not increase the error compared to either single sketch matrix $S_1$ or $S_2$, no matter how $S_1$ and $S_2$ are constructed (see Section D.1 and Section 4 in (Indyk et al., 2019)).

Proof. We first consider the first condition. From the condition that $\tau_i(A) \ge \frac{1}{\beta}\tau_i(B)$, we know that if we sample $m = O(\beta(k\log k + k/\epsilon))$ rows according to $\tau_i(A)$, then the actual probability that the i-th row of B gets sampled is
$$1 - (1 - \tau_i(A))^m = \Omega(m\cdot\tau_i(A)) = \Omega((k\log k + k/\epsilon)\cdot\tau_i(B)).$$
Since $\sum_i\tau_i(B) \le 2k$ (Lemma E.1), by Lemma E.2 (recall the sketch monotonicity property for LRA) we have that with probability at least 9/10, $S_2$ is a matrix such that
$$\min_{\text{rank-}k\ X:\ \mathrm{row}(X)\subseteq\mathrm{row}(S_2B)}\|B - X\|_F^2 \le (1+\epsilon)\|B - B_k\|_F^2.$$
Hence, since $S = \begin{bmatrix}S_1\\ S_2\end{bmatrix}$, the sketch monotonicity property for LRA gives
$$\min_{\text{rank-}k\ X:\ \mathrm{row}(X)\subseteq\mathrm{row}(SB)}\|B - X\|_F^2 \le (1+\epsilon)\|B - B_k\|_F^2.$$

Now we consider the second condition. Suppose that $\{X_i\}_{i\le m}$ and $\{Y_i\}_{i\le m}$ are sequences of $m = O(k\log k + k/\epsilon)$ samples from [n] drawn according to the sampling probability distributions p and q, respectively, where $p_i = \tau_i(A)/\sum_i\tau_i(A)$ and $q_i = \tau_i(B)/\sum_i\tau_i(B)$. Let $\mathcal{S}$ be the set of indices i such that $X_i \ne Y_i$. From the property of the total variation distance, we get that $\Pr[X_i \ne Y_i] \le d_{tv}(p, q) = \beta$, and $\mathbb{E}[|\mathcal{S}|] = \sum_i\Pr[X_i \ne Y_i] \le \beta m$. From Markov's inequality, we get that with probability at least 1 − 1.1β, $|\mathcal{S}| \le \frac{1}{1.1\beta}\cdot\beta m = \frac{10}{11}m$. Let T be the set of indices i such that $X_i = Y_i$. We have that with probability at least 1 − 1.1β, $|T| \ge m - \frac{10}{11}m = \Omega(k\log k + k/\epsilon)$. Moreover, $\{Y_i\}_{i\in T}$ are i.i.d. samples according to q, and they coincide with the samples $\{X_i\}_{i\in T}$ that we actually take.
From Lemma E.2 we get that with probability at least 9/10, the row space of $B_T$ satisfies
$$\min_{\text{rank-}k\ X:\ \mathrm{row}(X)\subseteq\mathrm{row}(B_T)}\|B - X\|_F^2 \le (1+\epsilon)\|B - B_k\|_F^2.$$
Similarly, from the sketch monotonicity property, we have that with probability at least 0.9 − 1.1β,
$$\min_{\text{rank-}k\ X:\ \mathrm{row}(X)\subseteq\mathrm{row}(SB)}\|B - X\|_F^2 \le (1+\epsilon)\|B - B_k\|_F^2.$$

E.2 PROOF OF THEOREM 6.1

First we prove the following lemma.

Lemma E.3. Let δ ∈ (0, 1/m]. It holds with probability at least 1 − δ that
$$\sup_{x\in\mathrm{colsp}(A)}\big|\|Sx\|_2^2 - \|x\|_2^2\big| \le \epsilon\|x\|_2^2,$$
provided that
$$m \gtrsim \epsilon^{-2}\big((d + \log m)\min\{\log^2(d/\epsilon), \log^2 m\} + d\log(1/\delta)\big),$$
$$1 \gtrsim \epsilon^{-2}\nu\big((\log m)\min\{\log^2(d/\epsilon), \log^2 m\} + \log(1/\delta)\big)\log(1/\delta).$$

Proof. We shall adapt the proof of Theorem 5 in (Bourgain et al., 2015) to our setting. Let T denote the unit sphere in colsp(A) and set the sparsity parameter s = 1. Observe that $\|Sx\|_2^2 = \|x_I\|_2^2 + \|Sx_{I^c}\|_2^2$, and so it suffices to show that
$$\Pr\big\{\big|\|Sx_{I^c}\|_2^2 - \|x_{I^c}\|_2^2\big| > \epsilon\big\} \le \delta$$
for x ∈ T. We make the following definition, as in (2.6) of (Bourgain et al., 2015):
$$A_{\delta,x} := \sum_{i=1}^{m}\sum_{j\in I^c}\delta_{ij}x_j\,e_i\otimes e_j,$$
and thus $Sx_{I^c} = A_{\delta,x}\sigma$. Also, since $\mathbb{E}\|Sx_{I^c}\|_2^2 = \|x_{I^c}\|_2^2$, one has
$$\sup_{x\in T}\big|\|Sx_{I^c}\|_2^2 - \|x_{I^c}\|_2^2\big| = \sup_{x\in T}\big|\|A_{\delta,x}\sigma\|_2^2 - \mathbb{E}\|A_{\delta,x}\sigma\|_2^2\big|. \tag{E.1}$$
Now, in (2.7) of (Bourgain et al., 2015), we instead define a semi-norm
$$\|x\|_\delta = \max_{1\le i\le m}\Big(\sum_{j\in I^c}\delta_{ij}x_j^2\Big)^{1/2}.$$
Then (2.8) continues to hold, and (2.9) as well as (2.10) continue to hold if the supremum on the left-hand side is replaced with the left-hand side of (E.1). At the beginning of Theorem 5, we define $U^{(i)}$ to be U, but with each row $j \in I^c$ multiplied by $\delta_{ij}$ and each row $j \in I$ zeroed out. Then, in the first step of (4.5), the expression involving $\langle\sum_k g_kf_k, e_j\rangle$ satisfies an inequality instead of an equality. One can verify that the rest of (4.5) goes through. It remains true that $\|\cdot\|_\delta \le (1/\sqrt{s})\|\cdot\|_2$, and thus (4.6) holds.
One can verify that the rest of the proof of Theorem 5 in (Bourgain et al., 2015) continues to hold if we replace $\sum_{j=1}^{n}$ with $\sum_{j\in I^c}$ and $\max_{1\le j\le n}$ with $\max_{j\in I^c}$, noting that
$$\mathbb{E}\sum_{j\in I^c}\delta_{ij}\|P_Ee_j\|_2^2 = \frac{s}{m}\sum_{j\in I^c}\langle P_Ee_j, e_j\rangle \le \frac{s}{m}\,\mathrm{tr}(P_E),$$
$$\mathbb{E}(U^{(i)})^\top U^{(i)} = \sum_{j\in I^c}(\mathbb{E}\delta_{ij})u_ju_j^\top \preceq \frac{s}{m}I.$$
Thus, the symmetrization inequalities on $\sum_{j\in I^c}\delta_{ij}\|P_Ee_j\|_2^2$ and $\|\sum_{j\in I^c}\delta_{ij}u_ju_j^\top\|$ continue to hold. The result then follows, observing that $\max_{j\in I^c}\|P_Ee_j\|_2 \le \nu$.

The subspace embedding guarantee now follows as a corollary.

Theorem 6.1. Let ν = ϵ/d. Suppose that $m = \Omega((d/\epsilon^2)(\mathrm{polylog}(1/\epsilon) + \log(1/\delta)))$, δ ∈ (0, 1/m) and $d = \Omega((1/\epsilon)\,\mathrm{polylog}(1/\epsilon)\log^2(1/\delta))$. Then there exists a distribution on S with m + |I| rows such that
$$\Pr\big\{\exists x \in \mathrm{colsp}(A),\ \big|\|Sx\|_2^2 - \|x\|_2^2\big| > \epsilon\|x\|_2^2\big\} \le \delta.$$

Proof. One can verify that the two conditions in Lemma E.3 are satisfied if … The last condition is satisfied if …

E.3 PROOF OF LEMMA 6.2

Proof. On the one hand, since Q = SAR has orthonormal columns, we have
$$\|x\|_2 = \|Qx\|_2 = \|SARx\|_2. \tag{E.2}$$
On the other hand, the assumption implies that $|(ARx)^\top(ARx) - x^\top x| \le \epsilon\|x\|_2^2$, that is,
$$(1-\epsilon)\|x\|_2^2 \le \|ARx\|_2^2 \le (1+\epsilon)\|x\|_2^2. \tag{E.3}$$
Combining (E.2) and (E.3) leads to
$$\sqrt{1-\epsilon}\,\|SARx\|_2 \le \|ARx\|_2 \le \sqrt{1+\epsilon}\,\|SARx\|_2.$$
Equivalently, it can be written as
$$\frac{1}{\sqrt{1+\epsilon}}\|Ay\|_2 \le \|SAy\|_2 \le \frac{1}{\sqrt{1-\epsilon}}\|Ay\|_2,\quad\forall y \in \mathbb{R}^d.$$
The claimed result follows from the facts that $1/\sqrt{1+\epsilon} \ge 1 - \epsilon$ and $1/\sqrt{1-\epsilon} \le 1 + \epsilon$ whenever ϵ ∈ (0, 1/2].

F LOCATION OPTIMIZATION IN COUNTSKETCH: GREEDY SEARCH

While the position optimization idea is simple, one particularly interesting aspect is that it is provably better than a random placement in some scenarios (Theorem F.1). Specifically, it is provably beneficial for LRA when inputs follow the spiked covariance model or Zipfian distributions, which are common for real data.

Spiked covariance model. Every matrix $A \in \mathbb{R}^{n\times d}$ from the distribution $\mathcal{A}_{sp}(s, \ell)$ has s < k heavy rows $A_{r_1}, \ldots, A_{r_s}$ of norm ℓ > 1.
The indices of the heavy rows can be arbitrary, but must be the same for all members of $\mathcal{A}_{sp}(s, \ell)$ and are unknown to the algorithm. The remaining ("light") rows have unit norm. In other words, let $R = \{r_1, \ldots, r_s\}$. For all rows $A_i$, i ∈ [n], $A_i = \ell\cdot v_i$ if i ∈ R and $A_i = v_i$ otherwise, where $v_i$ is a uniformly random unit vector.

Zipfian on squared row norms. Every $A \in \mathbb{R}^{n\times d} \sim \mathcal{A}_{zipf}$ has rows which are uniformly random and orthogonal. Each A has $2^{i+1}$ rows of squared norm $n^2/2^{2i}$ for i ∈ [1, …, O(log n)]. We also assume that each row has the same squared norm for all members of $\mathcal{A}_{zipf}$.

Theorem F.1. Consider a matrix A from either the spiked covariance model or a Zipfian distribution. Let $S_L$ denote a Count Sketch constructed by Algorithm 3 that optimizes the positions of the non-zero values with respect to A. Let $S_C$ denote a classical (random) Count Sketch matrix. Then there is a fixed η > 0 such that
$$\min_{\text{rank-}k\ X\in\mathrm{rowsp}(S_LA)}\|X - A\|_F^2 \le (1-\eta)\min_{\text{rank-}k\ X\in\mathrm{rowsp}(S_CA)}\|X - A\|_F^2.$$

Remark F.2. Note that the above theorem implicitly provides an upper bound on the generalization error of the greedy placement method on the two distributions that we consider in this paper. More precisely, for each of these two distributions, if Π is learned via our greedy approach over a set of sampled training matrices, the solution returned by the sketching algorithm using Π on any (test) matrix A sampled from the distribution has error at most $(1-\eta)\min_{\text{rank-}k\ X\in\mathrm{rowsp}(S_CA)}\|X - A\|_F^2$.

A key structural property of the matrices from these two distributions, crucial to our analysis, is the ϵ-almost orthogonality of their rows (i.e., the (normalized) pairwise inner products are at most ϵ in magnitude). Hence, we can find a QR factorization of the matrix of such vectors in which the upper triangular matrix R has diagonal entries close to 1 and entries above the diagonal close to 0.

To state our result, we first interpret the location optimization task as the selection of a hash function for the rows of A.
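The hashing view can be made concrete with a short NumPy sketch; the helper name `countsketch_apply` is hypothetical. A Count Sketch with m rows is determined by a bucket map $h: [n] \to [m]$ and signs $g \in \{\pm 1\}^n$. SA is obtained by adding $g_i A_i$ into row $h(i)$, which touches each non-zero of A once. Choosing the positions of the non-zero entries of S, as the greedy method does, amounts to choosing h.

```python
import numpy as np


def countsketch_apply(A, h, g, m):
    """Compute S @ A for the Count Sketch with S[h[i], i] = g[i], without forming S."""
    SA = np.zeros((m, A.shape[1]))
    np.add.at(SA, h, g[:, None] * A)   # row i of A is added, with sign g[i], into bin h[i]
    return SA


rng = np.random.default_rng(4)
n, d, m = 10, 4, 3
A = rng.standard_normal((n, d))
h = rng.integers(0, m, size=n)          # position (bin) of the non-zero in column i of S
g = rng.choice([-1.0, 1.0], size=n)     # value (sign) of that non-zero
S = np.zeros((m, n))
S[h, np.arange(n)] = g                  # explicit matrix, built only for comparison
SA = countsketch_apply(A, h, g, m)
```

`np.add.at` is used instead of `SA[h] += ...` because the unbuffered version correctly accumulates rows that hash to the same bin.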
Note that left-multiplying A by a Count Sketch $S \in \mathbb{R}^{m\times n}$ is equivalent to hashing the rows of A to m bins with coefficients in {±1}. The greedy algorithm proceeds through the rows of A (in some order) and decides which bin to hash each row to, recording this by adding an entry to S. The intuition is that our greedy approach separates heavy-norm rows (which are important directions in the row space) into different bins.

Proof Sketch of Theorem F.1. The first step is to observe that in the greedy algorithm, when rows are examined in non-increasing order of squared norms, the algorithm will isolate rows into their own singleton bins until all bins are filled. In particular, this means that the heavy-norm rows will all be isolated; for the spiked covariance model, Lemma F.8 gives the formal statement. Next, we show that none of the rows left to be processed (all light rows) will be assigned to the same bin as a heavy row. The main proof idea is to compare the cost of colliding with a heavy row to the cost of avoiding the heavy rows. This is the main place where we use the properties of the aforementioned distributions and the fact that each heavy row is already mapped to a singleton bin. Overall, we show that at the end of the algorithm no light row will be assigned to the bins that contain heavy rows; the formal statement and proof for the spiked covariance model are in Lemma F.12. Finally, we can interpret the randomized construction of Count Sketch as a balls-and-bins experiment. In particular, considering the heavy rows, we compute the expected number of bins (i.e., rows in $S_CA$) that contain a heavy row. Note that the expected number of rows in $S_CA$ that do not contain any heavy row is $k(1 - \frac{1}{k})^s \ge k\cdot e^{-\frac{s}{k-1}}$. Hence, the expected number of rows in $S_CA$ that contain a heavy row of A is at most $k(1 - e^{-\frac{s}{k-1}})$. Thus, in expectation, at least $s - k(1 - e^{-\frac{s}{k-1}})$ heavy rows are not mapped to an isolated bin (i.e., they collide with some other heavy rows).
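The balls-and-bins quantities above can be checked numerically. The following NumPy sketch uses illustrative parameter values. It hashes s heavy rows uniformly into k bins and compares the empirical mean number of bins without a heavy row to $k(1 - 1/k)^s$. It also checks the lower bound $k e^{-s/(k-1)}$, which holds because $1 - 1/k \ge e^{-1/(k-1)}$. Finally, it checks that the mean number of non-isolated heavy rows is at least $s - k(1 - e^{-s/(k-1)})$.

```python
import numpy as np

k, s, trials = 10, 6, 200_000
rng = np.random.default_rng(5)

bins = rng.integers(0, k, size=(trials, s))                   # bin of each heavy row, per trial
counts = np.zeros((trials, k), dtype=np.int64)
np.add.at(counts, (np.arange(trials)[:, None], bins), 1)      # heavy rows per bin, per trial

empty_mc = (counts == 0).sum(axis=1).mean()                   # bins with no heavy row
non_isolated_mc = (s - (counts == 1).sum(axis=1)).mean()      # heavy rows sharing a bin

empty_exact = k * (1 - 1 / k) ** s
empty_lower = k * np.exp(-s / (k - 1))
# Heavy rows not in a singleton bin, in expectation: at least s - E[# occupied bins].
colliding_lower = s - k * (1 - np.exp(-s / (k - 1)))
print(empty_mc, empty_exact, empty_lower, non_isolated_mc, colliding_lower)
```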
Consequently, it is straightforward to show that the squared loss of the solution corresponding to $S_C$ is larger than the squared loss of the solution corresponding to $S_L$, the Count Sketch constructed by Algorithm 3; see Lemma F.14 for the formal statement and its proof.

Preliminaries and notation. Left-multiplying A by a Count Sketch $S \in \mathbb{R}^{m\times n}$ is equivalent to hashing the rows of A to m bins with coefficients in {±1}. The greedy algorithm proceeds through the rows of A (in some order) and decides which bin to hash each row to, which we can think of as adding an entry to S. We denote the bins by $b_i$ and their summed contents by $w_i$.

F.1 SPIKED COVARIANCE MODEL WITH SPARSE LEFT SINGULAR VECTORS.

To recap, every matrix $A \in \mathbb{R}^{n\times d}$ from the distribution $\mathcal{A}_{sp}(s, \ell)$ has s < k heavy rows $(A_{r_1}, \ldots, A_{r_s})$ of norm ℓ > 1. The indices of the heavy rows can be arbitrary, but must be the same for all members of the distribution and are unknown to the algorithm. The remaining rows (called light rows) have unit norm. In other words, let $R = \{r_1, \ldots, r_s\}$. For all rows $A_i$, i ∈ [n]:
$$A_i = \begin{cases}\ell\cdot v_i & \text{if } i \in R,\\ v_i & \text{otherwise,}\end{cases}$$
where $v_i$ is a uniformly random unit vector.

We also assume that $S_r, S_g \in \mathbb{R}^{k\times n}$ and that the greedy algorithm proceeds in non-increasing order of row norms.

Proof sketch. First, we show that the greedy algorithm using a non-increasing row-norm ordering will isolate heavy rows (i.e., each is alone in a bin). Then, we conclude by showing that this yields a better rank-k approximation error when d is sufficiently large compared to n. We begin with some preliminary observations that will be of use later. It is well known that a set of uniformly random vectors is ϵ-almost orthogonal (i.e., the magnitudes of their pairwise inner products are at most ϵ).

Observation F.3. Let $v_1, \ldots, v_n \in \mathbb{R}^d$ be a set of random unit vectors. Then with probability 1 − 1/poly(n), we have $|\langle v_i, v_j\rangle| \le 2\sqrt{\frac{\log n}{d}}$ for all $i < j \le n$.

We define $\epsilon = 2\sqrt{\frac{\log n}{d}}$.

Observation F.4.
Let $u_1, \ldots, u_t$ be a set of vectors such that for each pair $i < j \le t$, $\big|\big\langle\frac{u_i}{\|u_i\|}, \frac{u_j}{\|u_j\|}\big\rangle\big| \le \epsilon$, and let $g_1, \ldots, g_t \in \{-1, 1\}$. Then,
$$\Big\|\sum_{i=1}^{t}g_iu_i\Big\|_2^2 \ge \sum_{i=1}^{t}\|u_i\|_2^2 - 2\epsilon\sum_{i<j}\|u_i\|_2\|u_j\|_2 \ldots$$
… $> 0$. This exists by Lemma F.5. Let $\{e_1, \ldots, e_{i-2}, e\}$ be an orthonormal basis for $\{A_1, \ldots, A_{i-2}, A_{i-1} + A_i\}$. Now, $e = c_0e_{i-1} + c_1e_i$ for some $c_0, c_1$, because $(A_{i-1} + A_i) - \mathrm{proj}_{\{e_1,\ldots,e_{i-2}\}}(A_{i-1} + A_i) \in \mathrm{span}(e_{i-1}, e_i)$. We note that $c_0^2 + c_1^2 = 1$ because e is a unit vector. We can find $c_0, c_1$ to be:
$$c_0 = \frac{a_{i-1,i-1} + a_{i,i-1}}{\sqrt{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2}},\qquad c_1 = \frac{a_{i,i}}{\sqrt{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2}}.$$
1. $j \le i - 2$: The cost is zero in both cases because $A_j \in \mathrm{span}(\{e_1, \ldots, e_{i-2}\})$.
2. $j \ge i + 1$: We compare the rewards (sums of squared projection coefficients) and find that $\{e_1, \ldots, e_{i-2}, e\}$ is no better than $\{e_1, \ldots, e_i\}$:
$$\langle A_j, e\rangle^2 = (c_0\langle A_j, e_{i-1}\rangle + c_1\langle A_j, e_i\rangle)^2 \le (c_1^2 + c_0^2)(\langle A_j, e_{i-1}\rangle^2 + \langle A_j, e_i\rangle^2) = \langle A_j, e_{i-1}\rangle^2 + \langle A_j, e_i\rangle^2,$$
where the inequality is the Cauchy–Schwarz inequality.
3. $j \in \{i-1, i\}$: We compute the sum of squared projection coefficients of $A_{i-1}$ and $A_i$ onto e:
$$\frac{a_{i-1,i-1}^2(a_{i-1,i-1} + a_{i,i-1})^2 + (a_{i,i-1}(a_{i-1,i-1} + a_{i,i-1}) + a_{i,i}a_{i,i})^2}{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2} = \frac{(a_{i-1,i-1} + a_{i,i-1})^2(a_{i-1,i-1}^2 + a_{i,i-1}^2) + a_{i,i}^4 + 2a_{i,i-1}a_{i,i}^2(a_{i-1,i-1} + a_{i,i-1})}{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2}. \tag{F.2}$$
On the other hand, the sum of squared projection coefficients of $A_{i-1}$ and $A_i$ onto $e_{i-1} \cup e_i$ is:
$$\frac{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2}{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2}\cdot(a_{i-1,i-1}^2 + a_{i,i-1}^2 + a_{i,i}^2). \tag{F.3}$$
Hence, the difference between the sums of squared projections of $A_{i-1}$ and $A_i$ onto $e_{i-1} \cup e_i$ and onto e is ((F.3) − (F.2))
$$\frac{a_{i,i}^2\big((a_{i-1,i-1} + a_{i,i-1})^2 + a_{i-1,i-1}^2 + a_{i,i-1}^2 - 2a_{i,i-1}(a_{i-1,i-1} + a_{i,i-1})\big)}{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2} = \frac{2a_{i,i}^2a_{i-1,i-1}^2}{(a_{i-1,i-1} + a_{i,i-1})^2 + a_{i,i}^2} > 0.$$
Thus, we find that $\{e_1, \ldots, e_i\}$ is a strictly better basis than $\{e_1, \ldots, e_{i-2}, e\}$. This means the greedy algorithm will choose to place $A_i$ in an empty bin.

Next, we show that none of the rows left to be processed (all light rows) will be assigned to the same bin as a heavy row.
The main proof idea is to compare the cost of colliding with a heavy row to the cost of avoiding the heavy rows. Specifically, we compare the decrease (before and after the bin assignment of a light row) in the sum of squared projection coefficients, lower-bounding it in the former case and upper-bounding it in the latter. We introduce some results that will be used in Lemma F.12.

Claim F.9. Let $A_{k+r}$, r ∈ [1, …, n − k], be a light row not yet processed by the greedy algorithm. Let $\{e_1, \ldots, e_k\}$ be the Gram–Schmidt basis for the current $\{w_1, \ldots, w_k\}$. Let $\beta = O(n^{-1}k^{-3})$ upper bound the inner products of the normalized $A_{k+r}, w_1, \ldots, w_k$. Then, for any bin i, $\langle e_i, A_{k+r}\rangle^2 \le \beta^2\cdot k^2$.

Proof. This is a straightforward application of Lemma F.5. From that, we have $\langle A_{k+r}, e_i\rangle^2 \le i^2\beta^2$ for i ∈ [1, …, k], which means $\langle A_{k+r}, e_i\rangle^2 \le k^2\beta^2$.

Claim F.10. Let $A_{k+r}$ be a light row that has been processed by the greedy algorithm. Let $\{e_1, \ldots, e_k\}$ be the Gram–Schmidt basis for the current $\{w_1, \ldots, w_k\}$. If $A_{k+r}$ is assigned to bin $b_{k-1}$ (w.l.o.g.), the squared projection coefficient of $A_{k+r}$ onto $e_i$, $i \ne k-1$, is at most $4\beta^2k^2$, where $\beta = O(n^{-1}k^{-3})$ upper bounds the inner products of the normalized $A_{k+r}, w_1, \ldots, w_k$.

Proof. Without loss of generality, it suffices to bound the squared projection of $A_{k+r}$ onto the direction of $w_k$ that is orthogonal to the subspace spanned by $w_1, \ldots, w_{k-1}$. Let $e_1, \ldots, e_k$ be an orthonormal basis of $w_1, \ldots, w_k$ guaranteed by Lemma F.5. Next, we expand the orthonormal basis to include $e_{k+1}$ so that we can write the normalized vector of $A_{k+r}$ as $v_{k+r} = \sum_{j=1}^{k+1}b_je_j$. By a similar approach to the proof of Lemma F.5, for each $j \le k-2$, $b_j^2 \le \beta^2j^2$.
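Next, since $|\langle\bar w_k, v_{k+r}\rangle| \le \beta$,
$$|b_k| \le \frac{1}{|\langle\bar w_k, e_k\rangle|}\Big(|\langle\bar w_k, v_{k+r}\rangle| + \sum_{j=1}^{k-1}|b_j\langle\bar w_k, e_j\rangle|\Big) \le \frac{1}{\sqrt{1 - \sum_{j=1}^{k-1}\beta^2j^2}}\Big(\beta + \sum_{j=1}^{k-2}\beta^2j^2 + (k-1)\beta\cdot|b_{k-1}|\Big) \le \frac{\beta + \sum_{j=1}^{k-2}\beta^2j^2 + (k-1)\beta}{\sqrt{1 - \sum_{j=1}^{k-1}\beta^2j^2}} \le 2k\beta,$$
where the second inequality is similar to the proof of Lemma F.5 and the third uses $|b_{k-1}| \le 1$. Hence, the squared projection of $A_{k+r}$ onto $e_k$ is at most $4\beta^2k^2\|A_{k+r}\|_2^2$. We assumed $\|A_{k+r}\|_2 = 1$; hence, the squared projection of $A_{k+r}$ onto $e_k$ is at most $4\beta^2k^2$.

Claim F.11. Assume that the absolute values of the inner products of vectors in $v_1, \ldots, v_n$ are at most $\epsilon < 1/(n^2\sum_{A_i\in b}\|A_i\|_2)$ and the absolute values of the inner products of the normalized vectors of $w_1, \ldots, w_k$ are at most $\beta = O(n^{-3}k^{-3/2})$. Suppose that bin b contains the row $A_{k+r}$. Then, the squared projection of $A_{k+r}$ onto the direction of w orthogonal to $\mathrm{span}(\{w_1, \ldots, w_k\}\setminus\{w\})$ is at most $\frac{\|A_{k+r}\|_2^4}{\|w\|_2^2} + O(n^{-2})$ and at least $\frac{\|A_{k+r}\|_2^4}{\|w\|_2^2} - O(n^{-2})$.

Proof. Without loss of generality, we assume that $A_{k+r}$ is mapped to $b_k$, so $w = w_k$. First, we provide an upper and a lower bound for $|\langle v_{k+r}, \bar w_k\rangle|$, where for each $i \le k$ we let $\bar w_i = \frac{w_i}{\|w_i\|_2}$ denote the normalized vector of $w_i$. Recall that by definition $v_{k+r} = \frac{A_{k+r}}{\|A_{k+r}\|_2}$. Then
$$|\langle\bar w_k, v_{k+r}\rangle| \le \frac{\|A_{k+r}\|_2 + \sum_{A_i\in b_k}\epsilon\|A_i\|_2}{\|w_k\|_2} \le \frac{\|A_{k+r}\|_2 + n^{-2}}{\|w_k\|_2} \le \frac{\|A_{k+r}\|_2}{\|w_k\|_2} + n^{-2}, \tag{F.4}$$
where the second inequality uses $\epsilon < n^{-2}/\sum_{A_i\in b_k}\|A_i\|_2$ and the third uses $\|w_k\|_2 \ge 1$, and
$$|\langle\bar w_k, v_{k+r}\rangle| \ge \frac{\|A_{k+r}\|_2 - \sum_{A_i\in b_k}\|A_i\|_2\epsilon}{\|w_k\|_2} \ge \frac{\|A_{k+r}\|_2}{\|w_k\|_2} - n^{-2}. \tag{F.5}$$
Now, let $\{e_1, \ldots, e_k\}$ be an orthonormal basis for the subspace spanned by $\{w_1, \ldots, w_k\}$ guaranteed by Lemma F.5. Next, we expand the orthonormal basis to include $e_{k+1}$ so that we can write $v_{k+r} = \sum_{j=1}^{k+1}b_je_j$. By a similar approach to the proof of Lemma F.5, we can show that for each $j \le k-1$, $b_j^2 \le \beta^2j^2$.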
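Moreover,
$$\begin{aligned}|b_k| &\le \frac{1}{|\langle\bar w_k, e_k\rangle|}\Big(|\langle\bar w_k, v_{k+r}\rangle| + \sum_{j=1}^{k-1}|b_j\langle\bar w_k, e_j\rangle|\Big)\\ &\le \frac{1}{\sqrt{1 - \sum_{j=1}^{k-1}\beta^2j^2}}\Big(|\langle\bar w_k, v_{k+r}\rangle| + \sum_{j=1}^{k-1}\beta^2j^2\Big) &&\text{by Lemma F.5}\\ &\le \frac{1}{\sqrt{1 - \sum_{j=1}^{k-1}\beta^2j^2}}\Big(n^{-2} + \frac{\|A_{k+r}\|_2}{\|w_k\|_2} + \sum_{j=1}^{k-1}\beta^2j^2\Big) &&\text{by (F.4)}\\ &< \frac{1}{\sqrt{1 - \beta^2k^3}}\Big(n^{-2} + \frac{\|A_{k+r}\|_2}{\|w_k\|_2} + \beta^2k^3\Big) &&\text{similar to the proof of Lemma F.5}\\ &\le O(n^{-2}) + (1 + O(n^{-2}))\frac{\|A_{k+r}\|_2}{\|w_k\|_2} &&\text{by } \beta = O(n^{-3}k^{-3/2})\\ &\le \frac{\|A_{k+r}\|_2}{\|w_k\|_2} + O(n^{-2}),\end{aligned}$$
and
$$\begin{aligned}|b_k| &\ge \frac{1}{|\langle\bar w_k, e_k\rangle|}\Big(|\langle\bar w_k, v_{k+r}\rangle| - \sum_{j=1}^{k-1}|b_j\langle\bar w_k, e_j\rangle|\Big)\\ &\ge |\langle\bar w_k, v_{k+r}\rangle| - \sum_{j=1}^{k-1}\beta^2j^2 &&\text{since } |\langle\bar w_k, e_k\rangle| \le 1\\ &\ge \frac{\|A_{k+r}\|_2}{\|w_k\|_2} - n^{-2} - \sum_{j=1}^{k-1}\beta^2j^2 &&\text{by (F.5)}\\ &\ge \frac{\|A_{k+r}\|_2}{\|w_k\|_2} - O(n^{-2}) &&\text{by } \beta = O(n^{-3}k^{-3/2}).\end{aligned}$$
Hence, the squared projection of $A_{k+r}$ onto $e_k$ is at most $\frac{\|A_{k+r}\|_2^4}{\|w_k\|_2^2} + O(n^{-2})$ and at least $\frac{\|A_{k+r}\|_2^4}{\|w_k\|_2^2} - O(n^{-2})$.

Now, we show that at the end of the algorithm no light row will be assigned to the bins that contain heavy rows.

Lemma F.12. Assume that the absolute values of the inner products of vectors in $v_1, \ldots, v_n$ are at most $\epsilon < \min\{n^{-2}k^{-5/3}, (n\sum_{A_i\in w}\|A_i\|_2)^{-1}\}$. At iteration k + r, the greedy algorithm will assign the light row $A_{k+r}$ to a bin that does not contain a heavy row.

Proof. The proof is by induction. Lemma F.8 implies that no light row has been mapped to a bin that contains a heavy row during the first k iterations. Next, we assume that this holds for the first k + r − 1 iterations and show that it must also hold for the (k + r)-th iteration. To this end, we compare the sum of squared projection coefficients when $A_{k+r}$ avoids and when it collides with a heavy row.

First, we upper bound $\beta = \max_{i\ne j\le k}|\langle w_i, w_j\rangle|/(\|w_i\|_2\|w_j\|_2)$. Let $c_i$ and $c_j$ respectively denote the number of rows assigned to $b_i$ and $b_j$. Then
$$\beta = \max_{i\ne j\le k}\frac{|\langle w_i, w_j\rangle|}{\|w_i\|_2\|w_j\|_2} \le \frac{c_ic_j\epsilon}{\sqrt{c_i - 2\epsilon c_i^2}\sqrt{c_j - 2\epsilon c_j^2}} \le 16\epsilon\sqrt{c_ic_j},$$
where the first inequality follows from Observation F.4 and the second uses $\epsilon \le n^{-2}k^{-5/3}$.

1. If $A_{k+r}$ is assigned to a bin that contains c light rows and no heavy rows. In this case, the projection loss of the heavy rows $A_1, \ldots, A_s$ onto row(SA) remains zero. Thus, we only need to bound the change in the sum of squared projection coefficients of the light rows before and after iteration k + r. Without loss of generality, let $w_k$ denote the bin that contains $A_{k+r}$.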
Since $S_{k-1} = \mathrm{span}(\{w_1, \dots, w_{k-1}\})$ has not changed, we only need to bound the difference in cost between projecting onto the component of $w_k - A_{k+r}$ orthogonal to $S_{k-1}$ and projecting onto the component of $w_k$ orthogonal to $S_{k-1}$, respectively denoted by $e'_k$ and $e_k$.

1. By Claim F.9, for the light rows that are not yet processed (i.e., $A_j$ for $j > k+r$), the squared projection of each onto $e'_k$ is at most $\beta^2k^2$. Hence, the total decrease in the squared projection is at most $(n-k-r)\cdot\beta^2k^2$.

2. By Claim F.10, for the processed light rows that are not mapped to the last bin, the squared projection of each onto $e'_k$ is at most $4\beta^2k^2$. Hence, the total decrease in the squared projection is at most $(r-1)\cdot 4\beta^2k^2$.

3. For each row $A_i \ne A_{k+r}$ that is mapped to the last bin, by Claim F.11 and the fact that $\|A_i\|_2^4 = \|A_i\|_2^2 = 1$, the squared projection of $A_i$ onto $e'_k$ is at most $\frac{\|A_i\|_2^2}{\|w_k - A_{k+r}\|_2^2} + O(n^{-2})$, and the squared projection of $A_i$ onto $e_k$ is at least $\frac{\|A_i\|_2^2}{\|w_k\|_2^2} - O(n^{-2})$. Moreover, the squared projection of $A_{k+r}$ onto $e_k$, compared to its projection onto $e'_k$, increases by at least $\big(\frac{\|A_{k+r}\|_2^2}{\|w_k\|_2^2} - O(n^{-2})\big) - O(n^{-2}) = \frac{\|A_{k+r}\|_2^2}{\|w_k\|_2^2} - O(n^{-2})$. Hence, the total squared projection of the rows in bin $b_k$ decreases by at most
$$\Big(\sum_{A_i\in w_k\setminus\{A_{k+r}\}}\frac{\|A_i\|_2^2}{\|w_k - A_{k+r}\|_2^2} + O(n^{-1})\Big) - \Big(\sum_{A_i\in w_k}\frac{\|A_i\|_2^2}{\|w_k\|_2^2} - O(n^{-1})\Big) \le \frac{\|w_k - A_{k+r}\|_2^2 + O(n^{-1})}{\|w_k - A_{k+r}\|_2^2} - \frac{\|w_k\|_2^2 - O(n^{-1})}{\|w_k\|_2^2} + O(n^{-1}) = O(n^{-1}) \quad (\text{by Observation F.4}).$$

Hence, summing up the bounds in items 1 to 3 above, the total decrease in the sum of squared projection coefficients is at most $O(n^{-1})$.

2. Suppose $A_{k+r}$ is assigned to a bin that contains a heavy row. Without loss of generality, we can assume that $A_{k+r}$ is mapped to $b_k$, which contains the heavy row $A_s$. In this case, the distance of the heavy rows $A_1, \dots, A_{s-1}$ to the space spanned by the rows of $SA$ is zero. Next, we bound the change in the squared distance of $A_s$ and of the light rows to the space spanned by the rows of $SA$. Note that the $(k-1)$-dimensional space corresponding to $w_1, \dots, w_{k-1}$ has not changed.
Hence, we only need to bound the change in the squared projections onto $e_k$ compared to $e'_k$ (where $e_k$ and $e'_k$ are defined as in the previous case).

1. For the light rows other than $A_{k+r}$, the squared projection of each onto $e_k$ is at most $\beta^2k^2$. Hence, the total increase in the squared projection of the light rows onto $e_k$ is at most $(n-k)\cdot\beta^2k^2 = O(n^{-1})$.

2. By Claim F.11, the sum of the squared projections of $A_s$ and $A_{k+r}$ onto $e_k$ decreases by at least
$$\|A_s\|_2^2 - \Big(\frac{\|A_s\|_2^4 + \|A_{k+r}\|_2^4}{\|A_s + A_{k+r}\|_2^2} + O(n^{-1})\Big) \ge \|A_s\|_2^2 - \Big(\frac{\|A_s\|_2^4 + \|A_{k+r}\|_2^4}{\|A_s\|_2^2 + \|A_{k+r}\|_2^2 - n^{-O(1)}} + O(n^{-1})\Big) \quad (\text{by Observation F.4})$$
$$\ge \frac{\|A_{k+r}\|_2^2(\|A_s\|_2^2 - \|A_{k+r}\|_2^2)}{\|A_s\|_2^2 + \|A_{k+r}\|_2^2} - O(n^{-1}) = \frac{\|A_{k+r}\|_2^2\big(1 - \|A_{k+r}\|_2^2/\|A_s\|_2^2\big)}{1 + \|A_{k+r}\|_2^2/\|A_s\|_2^2} - O(n^{-1}) \ge \|A_{k+r}\|_2^2\Big(1 - \frac{\|A_{k+r}\|_2}{\|A_s\|_2}\Big) - O(n^{-1}),$$
where the last step uses $(1-x^2)/(1+x^2) \ge 1-x$ for $x = \|A_{k+r}\|_2/\|A_s\|_2 \in [0,1]$.

Hence, in this case, the total decrease in the squared projection is at least
$$\|A_{k+r}\|_2^2\Big(1 - \frac{\|A_{k+r}\|_2}{\|A_s\|_2}\Big) - O(n^{-1}) = 1 - \frac{\|A_{k+r}\|_2}{\|A_s\|_2} - O(n^{-1}) \quad (\|A_{k+r}\|_2 = 1) = 1 - \frac{1}{\sqrt{\ell}} - O(n^{-1}) \quad (\|A_s\|_2 = \sqrt{\ell}).$$
Thus, for a sufficiently large value of $\ell$, the greedy algorithm assigns $A_{k+r}$ to a bin that only contains light rows. This completes the inductive proof and, in particular, implies that at the end of the algorithm the heavy rows are assigned to isolated bins.

Corollary F.13. The approximation loss of the best rank-$k$ approximate solution in the rowspace of $S_gA$ for $A \in \mathcal{A}_{sp}(s,\ell)$, where $A \in \mathbb{R}^{n\times d}$ for $d = \Omega(n^4k^4\log n)$ and $S_g$ is the Count Sketch constructed by the greedy algorithm with non-increasing order, is at most $n - s$.

Proof. First, we need to show that the absolute values of the inner products of vectors in $v_1, \dots, v_n$ are at most $\epsilon < \min\{n^{-2}k^{-2}, (n\sum_{A_i\in w}\|A_i\|_2)^{-1}\}$, so that we can apply Lemma F.12. To show this, note that by Observation F.3, $\epsilon \le 2\sqrt{\log n/d} \le n^{-2}k^{-2}$ since $d = \Omega(n^4k^4\log n)$. The proof follows from Lemma F.8 and Lemma F.12.
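To make Corollary F.13 concrete, the pure-Python sketch below runs a greedy position search on a tiny spiked instance. It is an illustration under stated assumptions, not the authors' implementation: the rows are exactly orthogonal (scaled standard basis vectors), every non-zero of $S$ is $+1$ instead of a random sign, rows not yet placed still count toward the loss, and the helpers `projection_loss` and `greedy_positions` are hypothetical names. Because $SA$ has only $k$ rows, the best rank-$k$ approximation in its rowspace is the orthogonal projection onto that rowspace, which is what `projection_loss` computes.

```python
import math


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def orthonormal_basis(vectors, tol=1e-12):
    """Gram-Schmidt; (near-)zero vectors such as empty bins are dropped."""
    basis = []
    for v in vectors:
        w = list(v)
        for q in basis:
            c = dot(w, q)
            w = [wi - c * qi for wi, qi in zip(w, q)]
        norm = math.sqrt(dot(w, w))
        if norm > tol:
            basis.append([wi / norm for wi in w])
    return basis


def projection_loss(A, bins, k):
    """||A - A P||_F^2, where P projects onto rowspace(SA).

    Row i of A is added to row bins[i] of SA with sign +1 (an assumption
    of this sketch); rows with bins[i] None are not placed yet.
    """
    d = len(A[0])
    SA = [[0.0] * d for _ in range(k)]
    for row, b in zip(A, bins):
        if b is not None:
            SA[b] = [x + y for x, y in zip(SA[b], row)]
    Q = orthonormal_basis(SA)
    return sum(dot(a, a) - sum(dot(a, q) ** 2 for q in Q) for a in A)


def greedy_positions(A, k):
    """Greedy with non-increasing order: put each row in the bin that
    minimizes the current loss (ties go to the smallest bin index)."""
    order = sorted(range(len(A)), key=lambda i: -dot(A[i], A[i]))
    bins = [None] * len(A)
    for i in order:
        best_b, best_loss = None, math.inf
        for b in range(k):
            bins[i] = b
            loss = projection_loss(A, bins, k)
            if loss < best_loss - 1e-9:
                best_b, best_loss = b, loss
        bins[i] = best_b
    return bins


# Hypothetical spiked instance with exactly orthogonal rows:
# s heavy rows of squared norm ell, n - s light rows of norm 1.
n, s, k, ell = 8, 2, 4, 25.0
heavy = {2, 5}  # heavy rows may sit at any index
A = [[(math.sqrt(ell) if i in heavy else 1.0) if j == i else 0.0
      for j in range(n)] for i in range(n)]

bins = greedy_positions(A, k)
loss = projection_loss(A, bins, k)
print(bins, round(loss, 6))  # each heavy row should sit alone in its bin
```

On this instance an empty bin is always preferred, adding a unit light row to a bin of light rows costs nothing extra, and adding it to a heavy bin costs $(\ell-1)/(\ell+1)$, which mirrors the case analysis of Lemma F.12; the final loss is $n - k = 4 \le n - s$.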
Since all heavy rows are mapped to isolated bins, the projection loss of the light rows is at most $n - s$.

Next, we bound the Frobenius norm error of the best rank-$k$ approximate solution constructed by the standard Count Sketch with a randomly chosen sparsity pattern.

Lemma F.14. Let $s = \alpha(k-1)$ where $0.7 < \alpha < 1$. The expected squared loss of the best rank-$k$ approximate solution in the rowspace of $S_rA$ for $A \in \mathbb{R}^{n\times d} \sim \mathcal{A}_{sp}(s,\ell)$, where $d = \Omega(n^6\ell^2)$ and $S_r$ is a Count Sketch whose sparsity pattern is chosen uniformly at random, is at least $n + \frac{\ell k}{4e} - (1+\alpha)k - n^{-O(1)}$.

Proof. We can interpret the randomized construction of the Count Sketch as a balls-and-bins experiment. In particular, considering the heavy rows, we compute the expected number of bins (i.e., rows of $S_rA$) that contain a heavy row. Note that the expected number of rows of $S_rA$ that do not contain any heavy row is $k(1-\frac{1}{k})^s \ge k e^{-\frac{s}{k-1}}$. Hence, the expected number of rows of $S_rA$ that contain a heavy row of $A$ is at most $k(1 - e^{-\frac{s}{k-1}})$. Thus, in expectation, at least $s - k(1 - e^{-\frac{s}{k-1}})$ heavy rows are not mapped to an isolated bin (i.e., they collide with some other heavy rows). Then, it is straightforward to show that the squared loss of each such row is at least $\ell - n^{-O(1)}$.

Claim F.15. Suppose that the heavy rows $A_{r_1}, \dots, A_{r_c}$ are mapped to the same bin via a Count Sketch $S$. Then, the total squared distance of these rows from the subspace spanned by $SA$ is at least $(c-1)\ell - O(n^{-1})$.

Proof. Let $b$ denote the bin that contains the rows $A_{r_1}, \dots, A_{r_c}$, let $w$ be its vector, and suppose that $b$ contains $c'$ light rows as well. By Claim F.10 and Claim F.11, the squared projection of each row $A_{r_i}$ onto the subspace spanned by the $k$ bins is at most
$$\frac{\|A_{r_i}\|_2^4}{\|w\|_2^2} + O(n^{-1}) \le \frac{\ell^2}{c\ell + c' - 2\epsilon(c^2\ell + cc'\sqrt{\ell} + c'^2)} + O(n^{-1}) \le \frac{\ell^2}{c\ell - n^{-O(1)}} + O(n^{-1}) \quad (\text{by } \epsilon \le n^{-3}\ell^{-1}) \le \frac{\ell}{c} + O(n^{-1}).$$
Hence, the total squared loss of these $c$ heavy rows is at least $c\ell - c\cdot\big(\frac{\ell}{c} + O(n^{-1})\big) \ge (c-1)\ell - O(n^{-1})$.
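The balls-and-bins count in the proof of Lemma F.14 can be checked numerically. By Claim F.15, a bin holding $c$ heavy rows costs at least $(c-1)\ell$, so the heavy-row loss is governed by $s$ minus the number of occupied bins. The pure-Python sketch below (hypothetical helper names, arbitrary choice $k = 11$, $\alpha = 0.8$) compares the exact expectation $s - k(1-(1-1/k)^s)$ with the lemma's lower bound $s - k(1-e^{-s/(k-1)})$ and with a Monte Carlo estimate, and also checks that $\alpha - 1 + e^{-\alpha} \ge 1/(2e)$ on $[0.7, 1)$.

```python
import math
import random


def expected_excess_heavy(s, k):
    """Exact E[s - #occupied bins] when s heavy rows pick one of k bins uniformly."""
    return s - k * (1 - (1 - 1 / k) ** s)


def lemma_f14_bound(s, k):
    """Lower bound s - k(1 - e^{-s/(k-1)}) used in the proof of Lemma F.14."""
    return s - k * (1 - math.exp(-s / (k - 1)))


def simulate_excess_heavy(s, k, trials=20000, seed=0):
    """Monte Carlo estimate of E[s - #bins that receive a heavy row]."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        hit = {rng.randrange(k) for _ in range(s)}
        total += s - len(hit)
    return total / trials


k, alpha = 11, 0.8
s = round(alpha * (k - 1))  # s = alpha (k - 1) = 8
exact = expected_excess_heavy(s, k)
bound = lemma_f14_bound(s, k)
mc = simulate_excess_heavy(s, k)

# The constant step alpha - 1 + e^{-alpha} >= 1/(2e) for alpha in [0.7, 1).
alpha_ok = all(a - 1 + math.exp(-a) >= 1 / (2 * math.e)
               for a in (0.7 + 0.01 * t for t in range(30)))
print(f"exact={exact:.4f} bound={bound:.4f} simulated={mc:.4f} alpha_ok={alpha_ok}")
```

With $s = \alpha(k-1)$ the bound equals $k(\alpha - 1 + e^{-\alpha}) - \alpha$ exactly, which is the form used in the next step of the proof.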
Thus, the expected total squared loss of the heavy rows is at least
$$\ell\cdot\big(s - k(1 - e^{-\frac{s}{k-1}})\big) - s\cdot n^{-O(1)} \ge \ell k(\alpha - 1 + e^{-\alpha}) - \ell\alpha - n^{-O(1)} \quad (s = \alpha(k-1) \text{ where } 0.7 < \alpha < 1)$$
$$\ge \frac{\ell k}{2e} - \ell - n^{-O(1)} \quad (\alpha \ge 0.7) \ge \frac{\ell k}{4e} - O(n^{-1}) \quad (\text{assuming } k > 4e).$$
Next, we compute a lower bound on the expected squared loss of the light rows. Note that Claim F.10 and Claim F.11 imply that when a light row collides with other rows, its contribution to the total squared loss (which also accounts for the amount by which it decreases the squared projection of the other rows in the bin) is at least $1 - O(n^{-1})$. Hence, the expected total squared loss of the light rows is at least
$$(n - s - k)(1 - O(n^{-1})) \ge (n - (1+\alpha)k) - O(n^{-1}).$$
Hence, the expected squared loss of a Count Sketch whose sparsity pattern is picked at random is at least
$$\frac{\ell k}{4e} - O(n^{-1}) + n - (1+\alpha)k - O(n^{-1}) \ge n + \frac{\ell k}{4e} - (1+\alpha)k - O(n^{-1}).$$

Corollary F.16. Let $s = \alpha(k-1)$ where $0.7 < \alpha < 1$, and let $\ell \ge \frac{(4e+1)n}{\alpha k}$. Let $S_g$ be the Count Sketch whose sparsity pattern is learned over a training set drawn from $\mathcal{A}_{sp}$ via the greedy approach. Let $S_r$ be a Count Sketch whose sparsity pattern is picked uniformly at random. Then, for an $n\times d$ matrix $A \in \mathcal{A}_{sp}$ where $d = \Omega(n^6\ell^2)$, the expected loss of the best rank-$k$ approximation of $A$ returned by $S_r$ is worse than the approximation loss of the best rank-$k$ approximation of $A$ returned by $S_g$ by at least a constant factor:
$$\mathbb{E}_{S_r}\Big[\min_{\text{rank-}k\ X\in\mathrm{rowsp}(S_rA)}\|X - A\|_F^2\Big] \ge n + \frac{\ell k}{4e} - (1+\alpha)k - n^{-O(1)} \quad (\text{Lemma F.14})$$
$$\ge (1 + 1/\alpha)(n - s) \quad \Big(\ell \ge \frac{(4e+1)n}{\alpha k}\Big) \ge (1 + 1/\alpha)\min_{\text{rank-}k\ X\in\mathrm{rowsp}(S_gA)}\|X - A\|_F^2 \quad (\text{Corollary F.13}).$$

F.2 ZIPFIAN ON SQUARED ROW NORMS

Each matrix $A \in \mathbb{R}^{n\times d}$ in $\mathcal{A}_{zipf}$ has rows which are uniformly random and orthogonal. Each $A$ has $2^{i+1}$ rows of squared norm $n^2/2^{2i}$ for $i \in [1, \dots, O(\log n)]$. We also assume that each row has the same squared norm for all members of $\mathcal{A}_{zipf}$.

In this section, the $s$ rows with the largest norms are called the heavy rows and the remaining rows are the light rows. For convenience, we number the heavy rows 1, . . .
, $s$; however, the heavy rows can appear at any indices, as long as any row of a given index has the same norm for all members of $\mathcal{A}_{zipf}$. Also, we assume that $s \le k/2$ and, for simplicity, $s = \sum_{i=1}^{h_s}2^{i+1}$ for some $h_s \in \mathbb{Z}^+$. This means that the minimum squared norm of a heavy row is $n^2/2^{2h_s}$ and the maximum squared norm of a light row is $n^2/2^{2h_s+2}$.

The analysis of the greedy algorithm ordered by non-increasing row norms on this family of matrices is similar to our analysis for the spiked covariance model. Here we analyze the case in which the rows are orthogonal. By continuity, if the rows are close enough to being orthogonal, all decisions made by the greedy algorithm will be the same. As a first step, by Lemma F.8, at the end of iteration $k$ the first $k$ rows are assigned to different bins. Then, via a similar inductive proof, we show that none of the light rows are mapped to a bin that contains one of the top $s$ heavy rows.

Lemma F.17. At each iteration $k+r$, the greedy algorithm picks the position of the non-zero value in the $(k+r)$-th column of the Count Sketch matrix $S$ so that the light row $A_{k+r}$ is mapped to a bin that does not contain any of the top $s$ heavy rows.

Proof. We prove the statement by induction. The base case $r = 0$ trivially holds, as the first $k$ rows are assigned to distinct bins. Next, we assume that in none of the first $k+r-1$ iterations is a light row assigned to a bin that contains a heavy row. Now, we consider the following cases:

1. Suppose $A_{k+r}$ is assigned to a bin that only contains light rows. Without loss of generality, we can assume that $A_{k+r}$ is assigned to $b_k$. Since the vectors are orthogonal, we only need to bound the difference in the projection of $A_{k+r}$ and of the light rows assigned to $b_k$ onto the direction of $w_k$ before and after adding $A_{k+r}$ to $b_k$.
In this case, the total squared loss corresponding to the rows in $b_k$ and to $A_{k+r}$ before and after adding $A_{k+r}$ to $b_k$ is, respectively,
$$\text{before adding } A_{k+r} \text{ to } b_k:\quad \|A_{k+r}\|_2^2 + \sum_{A_j\in b_k}\|A_j\|_2^2 - \frac{\sum_{A_j\in b_k}\|A_j\|_2^4}{\sum_{A_j\in b_k}\|A_j\|_2^2},$$
$$\text{after adding } A_{k+r} \text{ to } b_k:\quad \|A_{k+r}\|_2^2 + \sum_{A_j\in b_k}\|A_j\|_2^2 - \frac{\|A_{k+r}\|_2^4 + \sum_{A_j\in b_k}\|A_j\|_2^4}{\|A_{k+r}\|_2^2 + \sum_{A_j\in b_k}\|A_j\|_2^2}.$$
Thus, the amount of increase in the squared loss is
$$\frac{\sum_{A_j\in b_k}\|A_j\|_2^4}{\sum_{A_j\in b_k}\|A_j\|_2^2} - \frac{\|A_{k+r}\|_2^4 + \sum_{A_j\in b_k}\|A_j\|_2^4}{\|A_{k+r}\|_2^2 + \sum_{A_j\in b_k}\|A_j\|_2^2} = \frac{\|A_{k+r}\|_2^2\sum_{A_j\in b_k}\|A_j\|_2^4 - \|A_{k+r}\|_2^4\sum_{A_j\in b_k}\|A_j\|_2^2}{\big(\sum_{A_j\in b_k}\|A_j\|_2^2\big)\big(\|A_{k+r}\|_2^2 + \sum_{A_j\in b_k}\|A_j\|_2^2\big)}$$
$$\le \frac{\sum_{A_j\in b_k}\|A_j\|_2^4}{\sum_{A_j\in b_k}\|A_j\|_2^2}\cdot\frac{\|A_{k+r}\|_2^2}{\sum_{A_j\in b_k}\|A_j\|_2^2 + \|A_{k+r}\|_2^2} \le \frac{\sum_{A_j\in b_k}\|A_j\|_2^2\cdot\|A_{k+r}\|_2^2}{\sum_{A_j\in b_k}\|A_j\|_2^2 + \|A_{k+r}\|_2^2}. \qquad \text{(F.6)}$$

2. Suppose $A_{k+r}$ is assigned to a bin that contains a heavy row. Without loss of generality and by the induction hypothesis, we assume that $A_{k+r}$ is assigned to a bin $b$ that only contains a heavy row $A_j$. Since the rows are orthogonal, we only need to bound the difference in the projection of $A_{k+r}$ and $A_j$. In this case, the total squared loss corresponding to $A_j$ and $A_{k+r}$ before and after adding $A_{k+r}$ to $b$ is, respectively,
$$\text{before adding } A_{k+r} \text{ to } b:\quad \|A_{k+r}\|_2^2,$$
$$\text{after adding } A_{k+r} \text{ to } b:\quad \|A_{k+r}\|_2^2 + \|A_j\|_2^2 - \frac{\|A_{k+r}\|_2^4 + \|A_j\|_2^4}{\|A_{k+r}\|_2^2 + \|A_j\|_2^2}.$$
Thus, the amount of increase in the squared loss is
$$\|A_j\|_2^2 - \frac{\|A_{k+r}\|_2^4 + \|A_j\|_2^4}{\|A_{k+r}\|_2^2 + \|A_j\|_2^2} = \frac{\|A_{k+r}\|_2^2\big(\|A_j\|_2^2 - \|A_{k+r}\|_2^2\big)}{\|A_j\|_2^2 + \|A_{k+r}\|_2^2}. \qquad \text{(F.7)}$$
Then (F.7) is larger than (F.6) if $\|A_j\|_2^2 \ge \sum_{A_i\in b_k}\|A_i\|_2^2$.

Next, we show that at every inductive iteration, there exists a bin $b$ which only contains light rows and whose squared norm is smaller than the squared norm of any heavy row. For each value $m$, define $h_m$ so that $m = \sum_{i=1}^{h_m}2^{i+1} = 2^{h_m+2} - 4$. Recall that all heavy rows have squared norm at least $\frac{n^2}{2^{2h_s}}$.
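There must be a bin $b$ that only contains light rows and has squared norm at most
$$\sum_{A_i\in b}\|A_i\|_2^2 \le \frac{n^2}{2^{2(h_s+1)}} + \frac{1}{k-s}\sum_{i=h_k+1}^{h_n}2^{i+1}\cdot\frac{n^2}{2^{2i}} \le \frac{n^2}{2^{2(h_s+1)}} + \frac{4n^2}{k\cdot 2^{h_k}} \quad (s \le k/2)$$
$$\le \frac{n^2}{2^{2(h_s+1)}} + \frac{2n^2}{2^{2h_k}} \quad (k > 2^{h_k+1}) \le \frac{3n^2}{2^{2(h_s+1)}} \quad (h_k \ge h_s + 1) < \frac{n^2}{2^{2h_s}}.$$
Here the first term bounds the single row among the first $k$ rows that $b$ received (a light row), and the second term bounds the average squared norm, over the $k-s$ bins that only contain light rows, of the rows processed after iteration $k$. Hence, the greedy algorithm maps $A_{k+r}$ to a bin that only contains light rows.

Corollary F.18. The squared loss of the best rank-$k$ approximate solution in the rowspace of $S_gA$ for $A \in \mathbb{R}^{n\times d}$ in $\mathcal{A}_{zipf}$, where $S_g$ is the Count Sketch constructed by the greedy algorithm with non-increasing order, is less than $\frac{n^2}{2^{h_k-2}}$.

Proof. At the end of iteration $k$, the total squared loss is $\sum_{i=h_k+1}^{h_n}2^{i+1}\cdot\frac{n^2}{2^{2i}}$. After that, in each iteration $k+r$, by (F.6), the squared loss increases by at most $\|A_{k+r}\|_2^2$. Hence, the total squared loss of the solution returned by $S_g$ is at most
$$2\sum_{i=h_k+1}^{h_n}2^{i+1}\cdot\frac{n^2}{2^{2i}} < \frac{4n^2}{2^{h_k}} = \frac{n^2}{2^{h_k-2}}.$$

Next, we bound the squared loss of the best rank-$k$ approximate solution constructed by the standard Count Sketch with a randomly chosen sparsity pattern.

Observation F.19. Assume that the orthogonal rows $A_{r_1}, \dots, A_{r_c}$ are mapped to the same bin and that for each $i \le c$, $\|A_{r_1}\|_2^2 \ge \|A_{r_i}\|_2^2$. Then, the total squared loss of $A_{r_1}, \dots, A_{r_c}$ after projecting onto $A_{r_1} \pm \cdots \pm A_{r_c}$ is at least $\|A_{r_2}\|_2^2 + \cdots + \|A_{r_c}\|_2^2$.

Proof. Since $A_{r_1}, \dots, A_{r_c}$ are orthogonal, for each $i \le c$ the squared projection of $A_{r_i}$ onto $A_{r_1} \pm \cdots \pm A_{r_c}$ is $\|A_{r_i}\|_2^4/\sum_{j=1}^c\|A_{r_j}\|_2^2$. Hence, the sum of the squared projection coefficients of $A_{r_1}, \dots, A_{r_c}$ onto $A_{r_1} \pm \cdots \pm A_{r_c}$ is
$$\frac{\sum_{j=1}^c\|A_{r_j}\|_2^4}{\sum_{j=1}^c\|A_{r_j}\|_2^2} \le \|A_{r_1}\|_2^2.$$
Hence, the total projection loss of $A_{r_1}, \dots, A_{r_c}$ onto $A_{r_1} \pm \cdots \pm A_{r_c}$ is at least
$$\sum_{j=1}^c\|A_{r_j}\|_2^2 - \|A_{r_1}\|_2^2 = \|A_{r_2}\|_2^2 + \cdots + \|A_{r_c}\|_2^2.$$
In particular, Observation F.19 implies that whenever two rows are mapped to the same bin, the squared norm of the row with the smaller norm fully contributes to the total squared loss of the solution.

Lemma F.20. For $k > 2^{10} - 2$, the expected squared loss of the best rank-$k$ approximate solution in the rowspace of $S_rA$ for $A_{n\times d} \in \mathcal{A}_{zipf}$, where $S_r$ is the sparsity pattern of a Count Sketch chosen uniformly at random, is at least $1.095\,\frac{n^2}{2^{h_k-2}}$.

Proof.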
In light of Observation F.19, we need to compute the expected number of collisions between rows with large norm. We can interpret the randomized construction of the Count Sketch as a balls-and-bins experiment. For each $0 \le j \le h_k$, let $\mathcal{A}_j$ denote the set of rows with squared norm $\frac{n^2}{2^{2(h_k-j)}}$ and let $\mathcal{A}_{>j} = \bigcup_{j<i\le h_k}\mathcal{A}_i$. Note that $|\mathcal{A}_{>j}| = \sum_{i=j+1}^{h_k}2^{h_k-i+1} = \sum_{i=1}^{h_k-j}2^i = 2(2^{h_k-j}-1)$. Moreover, note that $k = 2(2^{h_k+1}-1)$.

Next, for a row $A_r \in \mathcal{A}_j$ ($0 \le j < h_k$), we compute the probability that at least one row in $\mathcal{A}_{>j}$ collides with $A_r$:
$$\Pr[\text{at least one row in } \mathcal{A}_{>j} \text{ collides with } A_r] = 1 - \Big(1 - \frac{1}{k}\Big)^{|\mathcal{A}_{>j}|} \ge 1 - e^{-\frac{|\mathcal{A}_{>j}|}{k}} = 1 - e^{-\frac{2^{h_k-j}-1}{2^{h_k+1}-1}} \ge 1 - e^{-2^{-j-2}},$$
since $\frac{2^{h_k-j}-1}{2^{h_k+1}-1} > 2^{-j-2}$. Hence, by Observation F.19, the contribution of the rows in $\mathcal{A}_j$ to the total squared loss is at least
$$(1 - e^{-2^{-j-2}})\cdot|\mathcal{A}_j|\cdot\frac{n^2}{2^{2(h_k-j)}} = (1 - e^{-2^{-j-2}})\cdot\frac{n^2}{2^{h_k-j-1}} = (1 - e^{-2^{-j-2}})\cdot\frac{n^2}{2^{h_k-2}}\cdot 2^{j-1}.$$
Thus, summing over $0 \le j < h_k$, the contribution of the rows with large squared norm to the total squared loss is at least⁴
$$\frac{n^2}{2^{h_k-2}}\sum_{j=0}^{h_k-1}2^{j-1}\big(1 - e^{-2^{-j-2}}\big) \ge 1.095\,\frac{n^2}{2^{h_k-2}} \quad \text{for } h_k > 8.$$

Corollary F.21. Let $S_g$ be a Count Sketch whose sparsity pattern is learned over a training set drawn from $\mathcal{A}_{zipf}$ via the greedy approach. Let $S_r$ be a Count Sketch whose sparsity pattern is picked uniformly at random. Then, for an $n\times d$ matrix $A \in \mathcal{A}_{zipf}$ and a sufficiently large value of $k$, the expected loss of the best rank-$k$ approximation of $A$ returned by $S_r$ is worse than the approximation loss of the best rank-$k$ approximation of $A$ returned by $S_g$ by at least a constant factor.

Proof. The proof follows from Lemma F.20 and Corollary F.18.

Remark F.22. We have provided evidence that the greedy algorithm that examines the rows of $A$ in non-increasing order of their norms (i.e., greedy with non-increasing order) results in a better rank-$k$ solution than the Count Sketch whose sparsity pattern is chosen at random. However, other implementations of the greedy algorithm may still result in a better solution than the greedy algorithm with non-increasing order.
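The Zipfian analysis (Lemma F.17, Corollary F.18, Observation F.19, and the constant in the proof of Lemma F.20) can be explored at toy scale. The pure-Python sketch below is an illustration under stated assumptions, not the authors' code: rows are exactly orthogonal, so only squared row norms matter and each bin's loss has the closed form from the proof of Observation F.19; the instance (4 norm levels, so $n = 60$; $h_k = 2$, so $k = 12$; $h_s = 1$, so $s = 4$) and the helper names are hypothetical. Since this toy $h_k$ is far below the $h_k > 8$ regime of Lemma F.20, the random baseline is only compared with the greedy loss, while the sum $\sum_{j=0}^{8}2^{j-1}(1-e^{-2^{-j-2}})$ behind the constant 1.095 is evaluated separately.

```python
import math
import random


def bin_loss(sq_norms):
    """Loss of orthogonal rows sharing one bin (proof of Observation F.19):
    sum ||a||^2 - sum ||a||^4 / sum ||a||^2."""
    total = sum(sq_norms)
    if total == 0:
        return 0.0
    return total - sum(x * x for x in sq_norms) / total


def total_loss(sq_norms, bins, k):
    groups = [[] for _ in range(k)]
    for x, b in zip(sq_norms, bins):
        groups[b].append(x)
    return sum(bin_loss(g) for g in groups)


def greedy_bins(sq_norms, k):
    """Greedy with non-increasing order for orthogonal rows: each row joins
    the bin whose loss grows the least (the first such bin on ties)."""
    order = sorted(range(len(sq_norms)), key=lambda i: -sq_norms[i])
    groups = [[] for _ in range(k)]
    bins = [None] * len(sq_norms)
    for i in order:
        x = sq_norms[i]
        growth = [bin_loss(g + [x]) - bin_loss(g) for g in groups]
        b = growth.index(min(growth))
        groups[b].append(x)
        bins[i] = b
    return bins


# Hypothetical Zipfian instance: 2^(i+1) rows of squared norm n^2 / 2^(2i).
levels = 4
n = 2 ** (levels + 2) - 4  # 60 rows
sq_norms = [n ** 2 / 2 ** (2 * i) for i in range(1, levels + 1)
            for _ in range(2 ** (i + 1))]
h_k = 2
k = 2 ** (h_k + 2) - 4  # 12 bins
s = 4                   # heavy rows: the top level (h_s = 1)

bins = greedy_bins(sq_norms, k)
greedy = total_loss(sq_norms, bins, k)

rng = random.Random(0)
trials = 2000
rand = sum(total_loss(sq_norms, [rng.randrange(k) for _ in sq_norms], k)
           for _ in range(trials)) / trials

# Sum from the proof of Lemma F.20 with h_k = 9, i.e., j = 0, ..., 8.
const = sum(2 ** (j - 1) * (1 - math.exp(-2 ** (-j - 2))) for j in range(9))
print(f"greedy={greedy:.1f} random~{rand:.1f} const={const:.4f}")
```

In this toy run, the four heavy rows stay in singleton bins, the greedy loss stays below the bound $4n^2/2^{h_k}$ of Corollary F.18, and the averaged random-pattern loss is larger than the greedy loss.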
To give an example, in the following simple instance the greedy algorithm that checks the rows of $A$ in a random order (i.e., greedy with random order) achieves a rank-$k$ solution whose cost is a constant factor better than that of the solution returned by greedy with non-increasing order. Let $A$ be a matrix with four rows $u, u, v, w$, where $u$, $v$, $w$ are orthogonal, $\|u\|_2 = 1$, and $\|v\|_2 = \|w\|_2 = 1+\epsilon$, and suppose that the goal is to compute a rank-2 approximation of $A$. Note that in the greedy algorithm with non-increasing order, $v$ and $w$ are assigned to different bins, and by a simple calculation we can show that the two copies of $u$ are also assigned to different bins. Hence, the squared loss of the computed rank-2 solution is $\frac{4(1+\epsilon)^2}{2+(1+\epsilon)^2}$. However, the optimal solution assigns $v$ and $w$ to one bin and the two copies of $u$ to the other bin, which results in a squared loss of $(1+\epsilon)^2$; for sufficiently small values of $\epsilon$, this is a constant factor smaller than the loss of the solution returned by greedy with non-increasing order. On the other hand, with a constant probability (at least $1/8$), greedy with random order computes the optimal solution. Otherwise, greedy with random order returns the same solution as greedy with non-increasing order. Hence, in expectation, the solution returned by greedy with random order is better than the solution returned by greedy with non-increasing order by a constant factor.

G EXPERIMENT DETAILS

G.1 LOW-RANK APPROXIMATION

In this section, we describe the parameters used in our experiments. We first introduce some parameters of Stage 2 of our approach proposed in Section 3.
bs: batch size, the number of training samples used in one iteration.
lr: learning rate of gradient descent.
iter: the number of iterations of gradient descent.
Table 7.1: Test errors for LRA (using Algorithm 2 with four sketches). For a given $m$, the dimensions of the four sketches were: $S \in \mathbb{R}^{m\times n}$, $R \in \mathbb{R}^{m\times d}$, $S_2 \in \mathbb{R}^{5m\times n}$, $R_2 \in \mathbb{R}^{5m\times d}$. Parameters of the algorithm: bs = 1, lr = 1.0 and 10.0 for hyper and video respectively, num_it = 1000. For our Algorithm 4, we use the average of all training matrices as the input to the algorithm.

Footnote 4: The numerical calculation was done using Wolfram Alpha.

Table 7.1: Test errors for LRA (using Algorithm 1 with one sketch). Parameters of the algorithm: bs = 1, lr = 1.0 and 10.0 for hyper and video respectively, num_it = 1000. For our Algorithm 4, we use the sum of all training matrices as the input to the algorithm.

G.2 SECOND-ORDER OPTIMIZATION

As stated in Section 6, once we fix the positions of the non-zero entries (chosen uniformly in each column or sampled according to the heavy leverage score distribution), we optimize their values by gradient descent, as described in Section 3. The loss function is given in Section 6. In our implementation, we use PyTorch (Paszke et al. (2019)), which computes gradients automatically (here we can use torch.qr() and torch.svd() to define our loss function). For a more nuanced loss function, which may be beneficial, one can use the package released in Agrawal et al. (2019), where the authors studied the problem of computing the gradient of functions that involve the solution to certain convex optimization problems.

As mentioned in Section 2, each column of the sketch matrix $S$ has exactly one non-zero entry. Hence, the $i$-th coordinate of $p$ can be seen as the position of the non-zero entry in the $i$-th column of $S$. In the implementation, to sample $p$ randomly, we sample a random integer in $\{1, \dots, m\}$ for each coordinate of $p$. For the heavy rows mentioned in Section 6, we allocate positions $1, \dots, k$ to the $k$ heavy rows, and for the other rows, we randomly sample an integer in $\{k+1, \dots, m\}$.
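The sampling of $p$ described above fits in a few lines. This is a minimal pure-Python sketch, not the authors' PyTorch code: it uses 0-based positions (so the reserved heavy positions are $0, \dots, k-1$ rather than $1, \dots, k$), the helper names `sample_positions` and `dense_sketch` are hypothetical, and the random-sign initial values are an assumption; the text only states that the values are subsequently optimized by gradient descent with $p$ held fixed.

```python
import random


def sample_positions(n, m, heavy=(), seed=None):
    """Sample p, where p[i] is the row of the non-zero entry in column i
    of an m x n CountSketch (0-based).

    Rows listed in `heavy` get the reserved positions 0..len(heavy)-1;
    every other row draws uniformly from len(heavy)..m-1.
    """
    rng = random.Random(seed)
    k = len(heavy)
    if k > m or (k == m and n > k):
        raise ValueError("need m > len(heavy) to place the other rows")
    slot = {row: j for j, row in enumerate(heavy)}
    return [slot[i] if i in slot else rng.randrange(k, m) for i in range(n)]


def dense_sketch(p, m, seed=None):
    """Dense m x n matrix with one non-zero per column, at row p[i].

    Initial values are random signs (an assumption of this sketch); the
    learning stage would then update these values while p stays fixed.
    """
    rng = random.Random(seed)
    S = [[0.0] * len(p) for _ in range(m)]
    for i, row in enumerate(p):
        S[row][i] = rng.choice((-1.0, 1.0))
    return S


p = sample_positions(n=10, m=6, heavy=(3, 7), seed=0)
S = dense_sketch(p, m=6, seed=1)
print(p)
```

Keeping $p$ as a plain integer vector, separate from the values, matches the description that $p$ is chosen once and left unchanged while only the non-zero values are trained.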
We note that once the vector $p$, which stores the position of the non-zero entry in each column of $S$, is chosen, it is not changed during the optimization process of Section 3. Next, we introduce the parameters for our experiments:
bs: batch size, the number of training samples used in one iteration.
lr: learning rate of gradient descent.
iter: the number of iterations of gradient descent.
In our experiments, we set bs = 20 and iter = 1000 for all datasets. We set lr = 0.1 for the Electric dataset.