# Understanding Self-Supervised Learning Dynamics without Contrastive Pairs

Yuandong Tian 1, Xinlei Chen 1, Surya Ganguli 1 2

Abstract. While contrastive approaches of self-supervised learning (SSL) learn representations by minimizing the distance between two augmented views of the same data point (positive pairs) and maximizing the distance between views from different data points (negative pairs), recent non-contrastive SSL methods (e.g., BYOL and SimSiam) show remarkable performance without negative pairs, using an extra learnable predictor and a stop-gradient operation. A fundamental question arises: why do these methods not collapse into trivial representations? We answer this question via a simple theoretical study and propose a novel approach, DirectPred, that directly sets the linear predictor based on the statistics of its inputs, without gradient training. On ImageNet, it performs comparably with more complex two-layer non-linear predictors that employ BatchNorm, and outperforms a linear predictor by 2.5% in 300-epoch training (and 5% in 60-epoch training). DirectPred is motivated by our theoretical study of the nonlinear learning dynamics of non-contrastive SSL in simple linear networks. Our study yields conceptual insights into how non-contrastive SSL methods learn, how they avoid representational collapse, and how multiple factors, like predictor networks, stop-gradients, exponential moving averages, and weight decay, all come into play. Our simple theory recapitulates the results of real-world ablation studies on both STL-10 and ImageNet. Code is released¹.

1 Facebook AI Research. 2 Stanford University. Correspondence to: Yuandong Tian. Proceedings of the 38th International Conference on Machine Learning, PMLR 139, 2021. Copyright 2021 by the author(s).
¹ https://github.com/facebookresearch/luckmatters/tree/master/ssl

## 1. Introduction

Self-supervised learning (SSL) has emerged as a powerful method for learning useful representations without requiring expensive target labels (Devlin et al., 2018). Many state-of-the-art SSL methods in computer vision employ the principle of contrastive learning (Oord et al., 2018; Tian et al., 2019; He et al., 2020; Chen et al., 2020a; Bachman et al., 2019), whereby the hidden representations of two augmented views of the same object (positive pairs) are brought closer together, while those of different objects (negative pairs) are encouraged to be further apart. Minimizing differences between positive pairs encourages modeling invariances, while contrasting negative pairs is thought to be required to prevent representational collapse (i.e., mapping all data to the same representation).

However, some recent SSL work, notably BYOL (Grill et al., 2020) and SimSiam (Chen & He, 2020), has shown the remarkable capacity to learn powerful representations using only positive pairs, without ever contrasting negative pairs. These methods employ a dual pair of Siamese networks (Bromley et al., 1994) (Fig. 1): the representations of two views are trained to match, one obtained by the composition of an online and a predictor network, and the other by a target network. The target network is not trained via gradient descent; it is either a direct copy of the online network (e.g., SimSiam (Chen & He, 2020)) or a momentum encoder that slowly follows the online network in a delayed fashion through an exponential moving average (EMA) (e.g., MoCo (He et al., 2020; Chen et al., 2020b) and BYOL (Grill et al., 2020)).
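Before the formal analysis, a minimal PyTorch-style sketch of one such non-contrastive update may help fix ideas. This is our own toy illustration, not the paper's implementation: the encoders are bare linear layers, the loss is the un-normalized ℓ2 form analyzed later in Section 2 (the practical BYOL recipe additionally ℓ2-normalizes features and symmetrizes the loss over the two views), and all dimensions and hyperparameter values are placeholders.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins: a linear "online" encoder, a linear predictor, and an EMA target.
n_in, n_out = 32, 16
online = nn.Linear(n_in, n_out, bias=False)       # W in the paper's linear model
predictor = nn.Linear(n_out, n_out, bias=False)   # W_p
target = copy.deepcopy(online)                    # W_a, updated by EMA, not by gradients
for p in target.parameters():
    p.requires_grad_(False)

opt = torch.optim.SGD([
    {"params": online.parameters(), "lr": 0.03},
    {"params": predictor.parameters(), "lr": 0.03 * 1.0},  # alpha_p scales the predictor lr
], momentum=0.9, weight_decay=4e-4)

gamma_a = 0.996  # EMA parameter

x1 = torch.randn(256, n_in)  # augmented view 1 (placeholder for real augmentations)
x2 = torch.randn(256, n_in)  # augmented view 2

# Online branch predicts the target branch; stop-gradient via torch.no_grad().
pred = predictor(online(x1))
with torch.no_grad():
    tgt = target(x2)
loss = F.mse_loss(pred, tgt)

opt.zero_grad()
loss.backward()
opt.step()

# EMA update of the target: W_a <- gamma_a * W_a + (1 - gamma_a) * W
with torch.no_grad():
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(gamma_a).add_(p_o, alpha=1.0 - gamma_a)
```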
Compared to contrastive learning, these non-contrastive SSL methods do not require a large batch size (e.g., 4096 in SimCLR (Chen et al., 2020a)) or a memory queue (e.g., MoCo (He et al., 2020; Chen et al., 2020b)) to provide negative pairs. They are therefore generally more efficient and conceptually simpler, while maintaining state-of-the-art performance.

Since the entire procedure in non-contrastive SSL encourages the online+predictor network and the target network to become similar to each other, this overall scheme raises several fundamental unsolved theoretical questions. Why/how does it avoid collapsed representations? What is the nature of the learned representations? How do multiple design choices and hyperparameters interact nonlinearly in the learning dynamics? While there are interesting theoretical studies of contrastive SSL (Arora et al., 2019; Lee et al., 2020; Tosh et al., 2020), any theoretical understanding of the nonlinear learning dynamics of non-contrastive SSL remains open.

Figure 1. Two-layer setting with a linear, bias-free predictor (the online branch receives augmentations and predicts the stop-gradient target branch).

In this paper, we make a first attempt to analyze the behavior of non-contrastive SSL training and the empirical effects of multiple hyperparameters, including (1) the exponential moving average (EMA), or momentum encoder, (2) a higher relative learning rate $\alpha_p$ of the predictor, and (3) the weight decay η. We explain all these empirical findings with an exceedingly simple theory based on analyzing the nonlinear learning dynamics of simple linear networks. Note that deep linear networks have provided a useful tractable theoretical model of nonconvex loss landscapes (Kawaguchi, 2016; Du & Hu, 2019; Laurent & Brecht, 2018) and nonlinear learning dynamics (Saxe et al., 2013; 2019; Lampinen & Ganguli, 2018; Arora et al., 2018) in these landscapes, yielding insights like dynamical isometry (Saxe et al., 2013; Pennington et al., 2017; 2018) that lead to improved training of nonlinear deep networks.

Despite the simplicity of our theory, it can still predict how various hyperparameter choices affect performance in an extensive set of real-world ablation studies. Moreover, the simplicity also enables us to provide conceptual and analytic insights into why performance patterns vary the way they do. Specifically, our theory accounts for the following diverse empirical findings:

Essential parts of non-contrastive SSL. The existence of the predictor and of the stop-gradient is absolutely essential: removing either of them leads to representational collapse in BYOL and SimSiam.

EMA. While the original BYOL needs EMA to work, its authors later confirmed that EMA is not necessary (i.e., the online and target networks can be identical) if a higher $\alpha_p$ is used. This is also confirmed with SimSiam, as long as the predictor is updated more often or has a larger learning rate (i.e., a larger $\alpha_p$). However, the performance is slightly lower.

Predictor optimality and relative learning rate $\alpha_p$. Both BYOL and SimSiam suggest that the predictor should always be optimal, in the sense of always achieving minimal ℓ2 error in predicting the target network's outputs from the online network's outputs.

| Plug-in frequency (every N minibatches) | 1 | 2 | 3 | 5 |
|---|---|---|---|---|
| EMA | 40.67 ± 0.50 | 35.29 ± 2.49 | 34.60 ± 0.98 | 35.63 ± 2.66 |
| no EMA | 39.45 ± 1.26 | 34.01 ± 1.54 | 34.58 ± 2.93 | 32.22 ± 2.94 |

Table 1. Simply plugging in the optimal solution to the linear predictor shows poor performance after 100 BYOL epochs (Top-1 accuracy on the STL-10 (Coates et al., 2011) downstream classification task). The optimal solution is obtained by solving (with regularization) $W_p\,\mathbb{E}[ff^\top] = \tfrac{1}{2}\left(\mathbb{E}[f_a f^\top] + \mathbb{E}[f f_a^\top]\right)$, in which the two expectations are estimated with an exponential moving average. In comparison, with gradient descent, BYOL with a single linear-layer predictor reaches 74%–75% Top-1 on STL-10 after 100 epochs.

Unless explicitly stated, in all our experiments we use ResNet-18 (He et al., 2016) as the backbone network for CIFAR-10/STL-10 experiments and SGD as the optimizer, with learning rate α = 0.03, momentum 0.9, weight decay $\tilde\eta$ = 0.0004 and EMA parameter $\gamma_a$ = 0.996. Each setting is repeated 5 times.
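To make Table 1's "plug-in" setup concrete, the sketch below solves the regularized linear system $W_p\,\mathbb{E}[ff^\top] = \tfrac{1}{2}(\mathbb{E}[f_a f^\top] + \mathbb{E}[f f_a^\top])$ for the predictor from moving-average estimates of the two expectations. It is a minimal sketch of the idea; the EMA rate, ridge term, and toy features are illustrative assumptions, not values from the paper.

```python
import numpy as np

def update_moments(A, B, f, f_a, rho=0.99):
    """EMA estimates of A = E[f f^T] and B = 0.5 * (E[f_a f^T] + E[f f_a^T]).

    f, f_a: (batch, d) online and target features for the current minibatch.
    rho: EMA rate (illustrative choice).
    """
    batch_A = f.T @ f / len(f)
    batch_B = 0.5 * (f_a.T @ f + f.T @ f_a) / len(f)
    return rho * A + (1 - rho) * batch_A, rho * B + (1 - rho) * batch_B

def plugin_predictor(A, B, ridge=1e-4):
    """Solve W_p (A + ridge I) = B, i.e. W_p = B (A + ridge I)^{-1}."""
    d = A.shape[0]
    M = A + ridge * np.eye(d)
    return np.linalg.solve(M.T, B.T).T

# Toy usage: random features standing in for network outputs.
d = 8
A, B = np.eye(d), np.zeros((d, d))
f = np.random.randn(256, d)
f_a = 0.9 * f + 0.1 * f @ np.random.randn(d, d)   # correlated "target" features
A, B = update_moments(A, B, f, f_a)
W_p = plugin_predictor(A, B)
```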
This optimality conjecture was motivated by the superior performance observed when the predictor had a large learning rate and/or was allowed more frequent updates than the rest of the network. However, (Chen & He, 2020) also showed that if the predictor is updated too often, performance drops, which calls into question the importance of an always-optimal predictor as a key requirement for learning good representations.

Weight decay. Table 15 in BYOL (Grill et al., 2020) indicates that no weight decay may lead to unstable results. A recent blog post (Fetterman & Albrecht, 2020) also mentions that using weight decay leads to stable learning in BYOL.

Finally, motivated by our theoretical analysis, we propose a new method, DirectPred, that directly sets the predictor weights based on a principal component analysis of the predictor's input, thereby avoiding complicated predictor dynamics and initialization issues. We show that this simple DirectPred method nevertheless yields comparable performance on CIFAR-10 and outperforms gradient training of the linear predictor by +5% Top-1 accuracy under the linear evaluation protocol on both STL-10 and ImageNet (60 epochs). On the standard ImageNet benchmark (300 epochs), DirectPred achieves 72.4%/91.0% Top-1/Top-5, 2.5% higher than BYOL with a linear predictor (69.9%/89.6%) and comparable with the default BYOL setting with a 2-layer predictor (72.5%/90.8%).

## 2. Two-layer linear model

To obtain analytic and conceptual insights into non-contrastive SSL, we analyze a simple, bias-free linear BYOL model in which the online, predictor and target networks are specified by the weight matrices $W \in \mathbb{R}^{n_2\times n_1}$, $W_p \in \mathbb{R}^{n_2\times n_2}$ and $W_a \in \mathbb{R}^{n_2\times n_1}$, respectively (Fig. 1).

Let $x \in \mathbb{R}^{n_1}$ be a data point drawn from the data distribution $p(x)$, and let $x_1$ and $x_2$ be two augmented views of $x$: $x_1, x_2 \sim p_{\mathrm{aug}}(\cdot|x)$, where $p_{\mathrm{aug}}(\cdot|x)$ is the augmentation distribution. In practice such data augmentations correspond to random crops, blurs or color distortions of images (Chen et al., 2020a). Let $f_1 = Wx_1 \in \mathbb{R}^{n_2}$ be the online representation of view 1, and $f_{2a} = W_a x_2 \in \mathbb{R}^{n_2}$ be the target representation of view 2. In BYOL, the learning dynamics of $W$ and $W_p$ are obtained by minimizing

$$J(W, W_p) := \tfrac{1}{2}\,\mathbb{E}_{x_1,x_2}\!\left[\,\|W_p f_1 - \mathrm{StopGrad}(f_{2a})\|_2^2\,\right], \tag{1}$$

while the dynamics of $W_a$ is obtained differently, via an exponential moving average (EMA) of $W$. We analyze this combined dynamics for $W$, $W_p$ and $W_a$, in the presence of additional weight decay, in the limit of large batch sizes and small discrete-time learning rates. This limit is well approximated by the gradient flow (see Supplementary Material (SM) for all derivations):
Lemma 1 (BYOL learning dynamics following Eqn. 1).

$$\dot W_p = \alpha_p\left(-W_p W (X + X') + W_a X\right)W^\top - \eta W_p \tag{2}$$
$$\dot W = W_p^\top\left(-W_p W (X + X') + W_a X\right) - \eta W \tag{3}$$
$$\dot W_a = \beta\left(-W_a + W\right) \tag{4}$$

Here, $X := \mathbb{E}_x[\bar x \bar x^\top]$, where $\bar x(x) := \mathbb{E}_{x'\sim p_{\mathrm{aug}}(\cdot|x)}[x']$ is the average augmented view of a data point $x$, and $X' := \mathbb{E}_x\!\left[\mathbb{V}_{x'|x}[x']\right]$ is the covariance matrix $\mathbb{V}_{x'|x}[x']$ of augmented views $x'$ conditioned on $x$, subsequently averaged over the data $x$. Note that $\alpha_p$ and β reflect multiplicative learning-rate ratios of the predictor and target networks relative to the online network. Finally, the terms involving η reflect weight decay. As a gradient-flow formulation, the learning rate α does not appear in Lemma 1. In the actual finite-time update, the learning rate for $W_p$ is $\alpha\alpha_p$, the EMA rate is $\alpha\beta = 1 - \gamma_a$, where $\gamma_a$ is the usual EMA parameter (e.g., BYOL uses 0.996), and the weight decay for actual training is $\tilde\eta := \alpha\eta$.

We note that since SimSiam is an ablation of BYOL that removes the EMA computation, the underlying dynamics of SimSiam can also be obtained from Lemma 1 simply by setting $W_a = W$, inserting this relation into Eqn. 2 and Eqn. 3, and ignoring Eqn. 4. Importantly, the stop-gradient on the target branch is still present. Overall, Eqns. 2–4 constitute our starting point for analyzing the combined roles of the relative learning rates $\alpha_p$ and β, the weight decay rate η, and various ablations in determining the performance of both BYOL and SimSiam. We first derive two very general results (see SM).

Theorem 1 (Weight decay promotes balancing of the predictor and online networks). Completely independent of the particular dynamics of $W_a$ in Eqn. 4, the update rules (Eqn. 2 and Eqn. 3) possess the invariance

$$W(t)W^\top(t) = \alpha_p^{-1} W_p^\top(t) W_p(t) + e^{-2\eta t}\, C, \tag{5}$$

where $C$ is a symmetric matrix that depends only on the initialization of $W$ and $W_p$.

This theorem implies that for both BYOL and SimSiam, there exists a balancing that ensures that any matching between the online and target representations will not be attributable solely to the predictor weights, which would render the online weights useless. Instead, what the predictor learns, the online network will also learn, which is important because the online network's representations are what is used for downstream tasks. We note that similar weight-balancing dynamics have been discovered in multi-layer linear networks and matrix factorization (Arora et al., 2018; Du et al., 2018); our results generalize this to SSL dynamics. Second, a nonzero weight decay could help remove the extra constant $C$ due to initialization, further balancing the predictor and online network weights and possibly leading to better performance on downstream tasks (Tbl. 2).

| EMA + no-bias | EMA + bias | no EMA + no-bias | no EMA + bias |
|---|---|---|---|
| 70.62 ± 1.05 | 70.99 ± 1.01 | 71.36 ± 0.44 | 71.37 ± 0.77 |

Table 2. Top-1 accuracy of BYOL on STL-10 under the linear evaluation protocol, trained for 100 epochs with no weight decay (η = 0) and $\alpha_p$ = 1. It is worse than the baseline (74.51 ± 0.47 without predictor bias) obtained when the weight decay is set to $\tilde\eta$ = 0.0004. "No-bias" means the linear predictor does not have a bias term.

Theorem 2 (The stop-gradient signal is essential for success). With $W_a = W$ (the SimSiam case), removing the stop-gradient signal yields a gradient update for $W$ governed by the positive semi-definite (PSD) matrix $H(t) := X' \otimes \left(W_p^\top W_p + I_{n_2}\right) + X \otimes \tilde W_p^\top \tilde W_p + \eta I_{n_1 n_2}$ (here $\tilde W_p := W_p - I_{n_2}$ and $\otimes$ is the Kronecker product):

$$\frac{d}{dt}\,\mathrm{vec}(W) = -H(t)\,\mathrm{vec}(W). \tag{6}$$

If the minimal eigenvalue $\lambda_{\min}(H(t))$ is bounded below over time, $\inf_{t\ge 0}\lambda_{\min}(H(t)) \ge \lambda_0 > 0$, then $W(t)\to 0$.
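As a numerical sanity check on Lemma 1 and Theorem 1, the following small NumPy sketch (our own, with arbitrary dimensions and hyperparameters) Euler-integrates the flow in Eqns. 2–4 and tracks the balancing residual $\|WW^\top - \alpha_p^{-1}W_p^\top W_p\|$, which Eqn. 5 predicts should shrink like $e^{-2\eta t}$.

```python
import numpy as np

rng = np.random.default_rng(0)
n1, n2 = 6, 4
alpha_p, beta, eta, sigma2 = 1.0, 0.4, 0.04, 0.25
dt, steps = 1e-2, 20000

# Data statistics: X = E[x_bar x_bar^T] (identity here), X' = sigma^2 I (isotropic augmentations).
X = np.eye(n1)
Xp = sigma2 * np.eye(n1)

W  = 0.1 * rng.standard_normal((n2, n1))   # online
Wp = 0.1 * rng.standard_normal((n2, n2))   # predictor
Wa = W.copy()                              # target

def balancing_residual(W, Wp):
    return np.linalg.norm(W @ W.T - Wp.T @ Wp / alpha_p)

for t in range(steps):
    G = -Wp @ W @ (X + Xp) + Wa @ X        # shared term in Eqns. 2 and 3
    dWp = alpha_p * G @ W.T - eta * Wp     # Eqn. 2
    dW  = Wp.T @ G - eta * W               # Eqn. 3
    dWa = beta * (-Wa + W)                 # Eqn. 4
    Wp, W, Wa = Wp + dt * dWp, W + dt * dW, Wa + dt * dWa
    if t % 5000 == 0:
        print(f"t={t*dt:6.1f}  balancing residual={balancing_residual(W, Wp):.5f}")
```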
Thus we have proven analytically, in this simple setting, that removing the stop-gradient leads to representational collapse, as observed in more complex settings in SimSiam (Chen & He, 2020). Similarly, with $W_a = W$ and no predictor ($W_p = I_{n_2}$), the dynamics in Eqn. 3 reduces to a similar form and $W(t)\to 0$ (see SM).

## 3. How multiple factors affect learning dynamics

The learning dynamics in Eqns. 2–4 constitute a set of high-dimensional coupled nonlinear differential equations that can be difficult to solve analytically in general.

Figure 2. Training BYOL on STL-10 for 100 epochs with EMA. Top row: no symmetric regularization imposed on $W_p$; bottom row: symmetric regularization on $W_p$. From left to right: (1) evolution of the eigenvalues of $F$; since $F$ is PSD and its eigenvalues $s_j$ vary across scales, we plot $\log(s_j)$, and we see that some eigenvalues grow while others shrink to zero over training. (2) Similar step-function behavior for the eigenvalues $p_j$ of the predictor $W_p$: its negative eigenvalues shrink towards zero and its leading eigenvalues become larger. (3) The eigenspaces of $F$ and $W_p$ gradually align with each other (Theorem 3); for each eigenvector $u_j$ of $F$, we compute the cosine angle (normalized correlation) between $u_j$ and $W_p u_j$ to measure alignment. (4) The asymmetry measure $\|W_p - W_p^\top\|/\|W_p\|$: $W_p$ gradually becomes symmetric and PSD during training.

Therefore, to obtain analytic insights into the functional roles of the relative learning rates $\alpha_p$ and β and the weight decay η, we make a series of simplifying assumptions. Intriguingly, under these simplifying assumptions we obtain a rich set of analytic predictions, which we then test experimentally in more realistic scenarios. We find that these predictions still hold qualitatively even when the simplifying assumptions required for obtaining analytic results do not.

Assumption 1 (Proportional EMA). We first reduce the dimensionality of the dynamics in Eqns. 2–4 by enforcing that the target network $W_a$ undergoes EMA but is always proportional to the online network via the relation $W_a(t) = \tau(t)W(t)$. Inserting this relation into the EMA dynamics in Eqn. 4 yields $\dot\tau W + \tau\dot W = \beta(1-\tau)W$. Thus we obtain a reduced dynamics for $W$, $W_p$ and τ. By not enforcing the stronger SimSiam constraint $W_a = W$, we can still model EMA dynamics. Intuitively, τ = τ(t) is a dynamic parameter that depends on how quickly $W = W(t)$ grows over time.
If $W$ is constant, then $\dot W = 0$ and τ stabilizes to 1. On the other hand, if $W$ grows rapidly, then τ becomes small. While Assumption 1 is a simplification, as we shall see, it still yields interesting, verifiable predictions about the functional role of EMA.

Assumption 2 (Isotropic data and augmentation). We assume the data distribution $p(x)$ has zero mean and identity covariance, while the augmentation distribution $p_{\mathrm{aug}}(\cdot|x)$ has mean $x$ and covariance $\sigma^2 I$. This simplifies the dynamics in Eqns. 2–4 by reducing the augmentation-averaged data covariance to $X = I$ and the data-averaged augmentation covariance to $X' = \sigma^2 I$. Many previous studies of deep learning dynamics made simplifying isotropic assumptions about data (Tian, 2017; Brutzkus & Globerson, 2017; Du et al., 2019; Bartlett et al., 2018; Safran & Shamir, 2018). Since our fundamental goal is to obtain a first analytic understanding of the dynamics of non-contrastive SSL methods, it is useful to first achieve this in the simplest possible isotropic setting. Interestingly, we will find that our final conclusions generalize to non-isotropic real-world settings.

Assumption 3 (Symmetric predictor). We enforce symmetry in $W_p$ by initializing it to be a symmetric matrix and then symmetrizing the flow for $W_p$ in Eqn. 2 (see SM). This symmetry assumption is motivated by both fixed-point analysis and empirical findings. First, the fixed point of Eqn. 2 under Assumptions 1 and 2 with η > 0 is always a symmetric matrix, and in numerical simulation the asymmetric part $W_p - W_p^\top$ eventually vanishes (see Appendix for the proof and numerical simulations). Moreover, during BYOL training without a symmetry constraint on the predictor, $W_p$ gradually moves towards symmetry (Fig. 2). Second, a set of experiments reveals that whether the predictor is symmetric or not has a dramatic effect on both performance and the interaction with EMA. In our STL-10 experiment, enforcing a symmetric $W_p$ in the presence of EMA improves performance on downstream tasks (Tbl. 3). In contrast, in the absence of EMA, a symmetric $W_p$ fails while an asymmetric $W_p$ works reasonably well. Similar behavior holds on ImageNet: a symmetric one-layer linear predictor $W_p$ in SimSiam (i.e., without EMA) achieves performance no better than random guessing (Top-1/5: 0.1%/0.5%), while an asymmetric $W_p$ achieves a Top-1/5 accuracy of 68.1%/88.2%. Our theory will explain this, and will also show how to obtain good performance with a symmetric predictor without EMA by increasing its relative learning rate $\alpha_p$.

| | sym $W_p$ (no bias) | regular $W_p$ (no bias) | sym $W_p$ (bias) | regular $W_p$ (bias) |
|---|---|---|---|---|
| One-layer linear predictor, EMA | 75.09 ± 0.48 | 74.51 ± 0.47 | 74.52 ± 0.29 | 74.16 ± 0.33 |
| One-layer linear predictor, no EMA | 36.62 ± 1.85 | 72.85 ± 0.16 | 36.04 ± 2.74 | 72.13 ± 0.53 |
| Two-layer predictor (BatchNorm + ReLU), EMA | 71.58 ± 6.46 | 78.85 ± 0.25 | 77.64 ± 0.41 | 78.53 ± 0.34 |
| Two-layer predictor (BatchNorm + ReLU), no EMA | 35.59 ± 2.10 | 65.98 ± 0.71 | 41.92 ± 4.25 | 65.59 ± 0.66 |

Table 3. The effect of symmetrizing $W_p$ on the downstream classification task (BYOL Top-1 on STL-10). A symmetric $W_p$ leads to slightly better performance than a regular $W_p$ in the presence of EMA. On the other hand, without EMA, a symmetric $W_p$ crashes. The same effect occurs for the two-layer predictor with BatchNorm and ReLU. Weight decay $\tilde\eta$ = 0.0004 and $\alpha_p$ = 1.

### 3.1. Dynamical alignment of eigenspaces between the predictor and its input correlation matrix

Under the three assumptions stated above, we analyze the coupled dynamics of $F := WXW^\top$ and $W_p$.
Note that $F$ is the correlation matrix of the outputs of the online network, which also serve as inputs to the predictor. By Assumption 2, $\mathbb{E}[x] = 0$, so $F$ is also a covariance matrix. We find that $F$ and $W_p$ obey the following dynamics (see SM):

$$\dot W_p = -\frac{\alpha_p}{2}(1+\sigma^2)\{W_p, F\} + \alpha_p\tau F - \eta W_p, \qquad \dot F = -(1+\sigma^2)\{W_p^2, F\} + \tau\{W_p, F\} - 2\eta F. \tag{7}$$

This dynamics reveals that the eigenspace of $W_p$ gradually aligns with that of $F$ under certain conditions (see SM for the derivation):

Theorem 3 (Eigenspace alignment). Under Eqn. 7, the commutator $[F, W_p] := FW_p - W_pF$ satisfies

$$\frac{d}{dt}[F, W_p] = -[F, W_p]K - K[F, W_p], \tag{8}$$

$$K(t) = (1+\sigma^2)\left[\frac{\alpha_p}{2}F(t) + W_p^2(t) - \frac{\tau}{1+\sigma^2}W_p(t)\right] + \frac{3}{2}\eta I. \tag{9}$$

If $\inf_{t\ge 0}\lambda_{\min}[K(t)] = \lambda_0 > 0$, then

$$\|[F(t), W_p(t)]\|_F \le e^{-2\lambda_0 t}\,\|[F(0), W_p(0)]\|_F \to 0. \tag{10}$$

For symmetric $W_p$, when $W_p$ and $F$ commute they can be simultaneously diagonalized. Thus this shows that the eigenspace of $W_p$ gradually aligns with that of $F$. To test this prediction, we performed extensive experiments showing that training BYOL with ResNet-18 on STL-10 yields eigenspace alignment, as demonstrated in Fig. 2.

Now, if the eigenspaces of $W_p$ and $F$ do align, we obtain fully decoupled dynamics. Let the columns of the matrix $U$ be the common eigenvectors, so that $W_p = U\Lambda_{W_p}U^\top$ with $\Lambda_{W_p} = \mathrm{diag}[p_1, p_2, \ldots, p_d]$, and $F = U\Lambda_F U^\top$ with $\Lambda_F = \mathrm{diag}[s_1, s_2, \ldots, s_d]$. For each mode $j$, we have (see SM for the derivation):

$$\dot p_j = \alpha_p s_j\left[\tau - (1+\sigma^2)p_j\right] - \eta p_j \tag{11}$$
$$\dot s_j = 2 p_j s_j\left[\tau - (1+\sigma^2)p_j\right] - 2\eta s_j \tag{12}$$
$$\dot\tau s_j = \beta(1-\tau)s_j - \tau\dot s_j/2. \tag{13}$$

This decoupled dynamics constitutes a dramatically simplified set of three-dimensional nonlinear dynamical systems for BYOL learning, and two-dimensional nonlinear systems (obtained by constraining τ = 1) for SimSiam. As expected, each mode's dynamics is equivalent to the three-dimensional dynamics obtained by setting $n_1 = n_2 = 1$ in Eqns. 2–4 and making the replacements $W^2 = s_j$, $W_p = p_j$, and $W_a/W = \tau$ (see SM). Thus the decoupled dynamics in Eqns. 11–13 reduce to the scalar case of the BYOL dynamics in Eqns. 2–4 after a change of variables, and the condition in Thm. 3 reveals when this decoupled regime is reachable.

Non-symmetric $W_p$. When Assumption 3 is absent, the analysis is much more involved. One possible approach is to decompose $W_p = A + B$, where $A = A^\top$ is symmetric and $B = -B^\top$ is skew-symmetric. We leave this for future work.

### 3.2. Analysis of decoupled dynamics

The simplified three- (two-) dimensional dynamics of BYOL (SimSiam) yields significant insights. First, there is clearly a collapsed fixed point at $p_j(t) = s_j(t) = 0$, with τ taking any value. We wish to understand the conditions under which $p_j$ and $s_j$ can avoid this collapsed fixed point and grow from small random initial conditions. Since $s_j$ is an eigenvalue of $WW^\top$, we are particularly interested in conditions under which $s_j$ achieves large final values, corresponding to a non-collapsed online network, that are moreover sensitive to the statistics of the data, governed by $\sigma^2$.

Exact integral. First, an important observation, similar to Theorem 1, is that the dynamics possesses an exact integral of motion, obtained by multiplying Eqn. 11 by $2\alpha_p^{-1}p_j$, subtracting Eqn. 12, and integrating over time, yielding

$$s_j(t) = \alpha_p^{-1}p_j^2(t) + e^{-2\eta t}c_j, \tag{14}$$

where $c_j = s_j(0) - \alpha_p^{-1}p_j^2(0)$ is fixed by the initial conditions. In the absence of weight decay (η = 0), this integral reveals that the initial condition encoded in $c_j$ is never forgotten, and the dynamics of $p_j$ and $s_j$ is confined to parabolas of the form $s_j(t) = \alpha_p^{-1}p_j^2(t) + c_j$, as can be seen from the blue flow lines in Fig. 3 (left).
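To make the decoupled per-mode dynamics concrete, the following small sketch (our own illustration; the initial values, β, η and σ² are arbitrary) Euler-integrates Eqns. 11–13 for a single mode and checks the exact integral in Eqn. 14 by monitoring $(s_j - \alpha_p^{-1}p_j^2)\,e^{2\eta t}$, which should stay equal to $c_j$.

```python
import numpy as np

alpha_p, beta, eta, sigma2 = 1.0, 0.4, 0.005, 0.25
dt, T = 1e-3, 100.0

p, s, tau = 0.1, 0.1, 0.0      # small initialization; tau(0) = 0 as in Eqn. 15
c0 = s - p**2 / alpha_p        # the constant c_j of Eqn. 14

for step in range(int(T / dt)):
    t = step * dt
    drive = tau - (1.0 + sigma2) * p
    dp = alpha_p * s * drive - eta * p                 # Eqn. 11
    ds = 2.0 * p * s * drive - 2.0 * eta * s           # Eqn. 12
    dtau = beta * (1.0 - tau) - tau * ds / (2.0 * s)   # Eqn. 13, divided by s_j
    p, s, tau = p + dt * dp, s + dt * ds, tau + dt * dtau
    if step % int(10.0 / dt) == 0:
        invariant = (s - p**2 / alpha_p) * np.exp(2 * eta * t)
        print(f"t={t:6.1f}  p={p:.3f}  s={s:.3f}  tau={tau:.3f}  "
              f"(s - p^2/alpha_p) e^(2 eta t) = {invariant:.4f}  (c_j = {c0:.4f})")
```

With these illustrative values the mode escapes collapse: τ grows from 0 towards 1 (the "automatic curriculum" discussed below), $p_j$ approaches $\tau/(1+\sigma^2)$-scale values, and the printed invariant stays at $c_j$ up to integration error.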
Figure 3. State-space dynamics of Eqns. 11 and 12 for no (η = 0), weak (η = 0.01) and strong (η = 1) weight decay at fixed τ = 1 and $\alpha_p$ = 1. Red (green) points indicate stable (unstable) fixed points, blue curves indicate flow lines, and the dashed black curve indicates the parabola $s_j = p_j^2/\alpha_p$.

With weight decay (η > 0), the initial condition is forgotten over time and the dynamics approaches the invariant parabola $s_j = \alpha_p^{-1}p_j^2$, as can be seen from the approach of the blue flow lines to the black dashed parabola in Fig. 3 (middle and right). We discuss these two cases in turn. First, we note that in both cases, since the EMA computation is often very slow (Grill et al., 2020), corresponding to small β, the dynamics of τ in Eqn. 13 is slow relative to that of $p_j$ and $s_j$. Therefore, to understand the combined dynamics, we can search for the fixed points that $p_j$ and $s_j$ rapidly approach at fixed τ. Over time, τ will then either slowly approach 1 (BYOL) or always equal 1 (SimSiam), and $s_j$ and $p_j$ will follow their τ-dependent fixed points.

No weight decay. When η = 0, Eqns. 11 and 12 at a fixed value of τ yield a branch of collapsed fixed points, given by $s_j = 0$ with $p_j$ taking any value, and a branch of non-collapsed fixed points, with $p_j = \tau/(1+\sigma^2)$ and $s_j$ taking any value (horizontal and vertical red/green lines in Fig. 3, left). A sufficient criterion on the initial conditions to avoid the collapsed branch is $s_j(0) > p_j^2(0)/\alpha_p$, corresponding to lying above the dashed black parabola in Fig. 3, left. This restriction on initial conditions reveals why a fast predictor (large $\alpha_p$) is advantageous (Obs#1): a larger $\alpha_p$ leads to a smaller basin of attraction of the collapsed branch by flattening the dashed parabola. Indeed, both BYOL and SimSiam have noted that a fast predictor can help avoid collapse. On the other hand, $\alpha_p$ cannot be infinitely large (Obs#2): since $s_j(+\infty) = s_j(0) + \alpha_p^{-1}\left(p_j^2(+\infty) - p_j^2(0)\right)$, a very large $\alpha_p$ implies that $s_j$, the final value of the online network characterizing the learned representation, does not grow even if $p_j$ does. This is consistent with results showing that optimizing the predictor too often does not work in SimSiam (Chen & He, 2020), and that directly setting an optimal predictor fails as well (Tbl. 1). The online network needs to grow along with the predictor, and that cannot happen if the predictor is too fast.

| | Positive effects | Negative effects |
|---|---|---|
| Relative predictor lr $\alpha_p$ | #1, #6 | #2 |
| Weight decay η | #3, #7 | #4, #5 |
| EMA β | #8 | #9, #10 |

Table 4. Summary of the positive/negative effects of various hyperparameter choices (EMA β, relative predictor learning rate $\alpha_p$, and weight decay η). "#1" refers to (Obs#1) in the text, and so on.

Figure 4. Fixed points of $\dot p_j \propto -p_j(p_j - p_j^-)(p_j - p_j^+)$. Stable fixed points are shown in red, unstable in green, and saddles in black. When the weight decay η = 0, the trivial solution $p_j = 0$ is a saddle. When η > 0, the trivial solution becomes stable near the origin, and the initial $p_j$ needs to be large enough to converge to the stable non-collapsed solution $p_j^+$.

Advantage of weight decay. On the non-collapsed branch of fixed points without weight decay (vertical red line in Fig. 3, left), the predictor $p_j$ takes the exact value $\tau/(1+\sigma^2)$, which models the invariance to augmentation correctly: a large data-augmentation variance σ² should lead to a small magnitude of the learned representation. Ideally, we want $s_j$ to have the same property.
With weight decay η > 0 in Eqn. 14, the memory of the initial condition $c_j$ fades away, yielding convergence to some point on the invariant parabola $s_j = \alpha_p^{-1}p_j^2$. (Obs#3): Therefore, by tying the online network to the predictor, weight decay allows $s_j$ to also model invariance to augmentations correctly whenever the predictor does, regardless of the random initial condition $c_j$.

Dynamics on the invariant parabola. Because weight decay forces convergence to the invariant parabola $s_j = \alpha_p^{-1}p_j^2$, we next focus on the dynamics along this parabola (i.e., $c_j = 0$ in Eqn. 14). In this case, Eqn. 13 has the solution

$$\tau(t) = p_j^{-1}(t)\,\beta e^{-\beta t}\int_0^t p_j(t')\,e^{\beta t'}\,dt', \tag{15}$$

with initial condition τ(0) = 0. Inserting the invariant $s_j = \alpha_p^{-1}p_j^2$ into Eqn. 11, the dynamics of $p_j$ is given by

$$\dot p_j = p_j^2\left[\tau(t) - (1+\sigma^2)p_j\right] - \eta p_j. \tag{16}$$

We first analyze the fixed points where $\dot p_j = 0$ at fixed τ. When the weight decay satisfies $0 < \eta \le \frac{\tau^2}{4(1+\sigma^2)}$, $p_j$ has three fixed points (Fig. 4(b)):

$$p_j^{\pm} = \frac{\tau \pm \sqrt{\tau^2 - 4\eta(1+\sigma^2)}}{2(1+\sigma^2)} > 0, \qquad p_{j0} = 0,$$

where both $p_{j0}$ and $p_j^+$ are stable and $p_j^-$ is unstable, as shown in Fig. 4(b). The basin of attraction of the collapsed fixed point $p_{j0} = 0$ is $p_j < p_j^-$, while the basin of attraction of the useful non-collapsed fixed point $p_j^+$ is $p_j > p_j^-$, yielding an important constraint on initial conditions to avoid collapse. Note that $p_j^-$ is a decreasing function of τ and an increasing function of η (see SM). This means that with larger η, $p_j^-$ moves right and the basin of collapse expands (Obs#4). When $\eta > \frac{\tau^2}{4(1+\sigma^2)}$ there is only one stable fixed point, $p_{j0} = 0$ (Fig. 4(c)). Under such strong weight decay, collapse is unavoidable (Obs#5).

We now discuss the dynamics. First we define the quantity $\Delta_j := p_j\left[\tau - (1+\sigma^2)p_j\right] - \eta$, which must satisfy two criteria. Note that Eqn. 16 can be written as $\dot p_j = p_j\Delta_j$, so $\Delta_j$ must at some point be positive to drive $p_j(t)$ to any positive non-collapsed fixed point $p_j^+$. Second, for the eigenspace alignment in Theorem 3 to remain stable (even if the alignment has already happened), $K(t)$ must be positive definite (PD) in Eqn. 9. Using the eigenspace-alignment condition and the invariance $s_j = \alpha_p^{-1}p_j^2$, the positive-definiteness condition on $K(t)$ can be written as

$$\Delta_j < \frac{1}{2}\left[\alpha_p(1+\sigma^2)s_j + \eta\right]. \tag{17}$$

This criterion and the criterion $\Delta_j > 0$ yield interesting insights into the roles of various hyperparameter choices. First (Obs#6), a larger predictor learning rate $\alpha_p$ can play an advantageous role by loosening the upper bound in Eqn. 17, making it easier to satisfy. Second (Obs#7), increasing η has the same effect.

Role of EMA. Without EMA, τ ≡ 1 and Eqn. 17 may not hold initially, when $p_j$ is small. The reason is that $\Delta_j$ is, to leading order, linear in $p_j$ when τ = 1, while the right-hand side is, to leading order, $s_j \propto p_j^2$, so the left-hand side has a larger contribution from $p_j$ than the right. EMA resolves this as follows. When training begins, $s_j$ is often quite small, and τ remains small since $W$ changes rapidly. When $p_j$ grows to the fixed point $p_j^+ \approx \tau/(1+\sigma^2)$, the growth of $s_j$ stops, making τ larger. This in turn sets a higher fixed-point goal for $p_j$. The process continues until the feature is stabilized and τ = 1 (see Fig. 5 for details). Therefore, EMA can serve as an automatic curriculum (Obs#8): it sets an initial small goal of $\tau/(1+\sigma^2)$ for $p_j$, so $\Delta_j$ need only be small and positive to both drive $p_j$ larger and satisfy Eqn. 17. Then EMA gradually sets a higher goal for $p_j$ by increasing τ, so that $p_j$ and $s_j$ can grow while keeping the eigenspaces of $W_p$ and $F$ aligned.
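A tiny helper (our own, with illustrative numbers) makes the fixed-point structure of Eqn. 16 at fixed τ concrete: it returns the collapse threshold $p_j^-$ and the non-collapsed fixed point $p_j^+$, or reports that only the collapsed solution survives when $\eta > \tau^2/(4(1+\sigma^2))$.

```python
import math

def predictor_fixed_points(tau, eta, sigma2):
    """Roots of (1 + sigma2) p^2 - tau p + eta = 0, i.e. the nonzero fixed points of Eqn. 16."""
    disc = tau**2 - 4.0 * eta * (1.0 + sigma2)
    if disc < 0:
        return None  # only the collapsed fixed point p_j = 0 survives (Obs#5)
    root = math.sqrt(disc)
    p_minus = (tau - root) / (2.0 * (1.0 + sigma2))  # basin boundary of collapse
    p_plus = (tau + root) / (2.0 * (1.0 + sigma2))   # useful non-collapsed fixed point
    return p_minus, p_plus

# Example with tau = 1 and sigma^2 = 0.25 (threshold tau^2 / (4 (1 + sigma^2)) = 0.2):
for eta in (0.0, 0.01, 0.25):
    print(eta, predictor_fixed_points(tau=1.0, eta=eta, sigma2=0.25))
# eta = 0      -> p_minus = 0 (saddle at the origin), p_plus = 0.8
# eta = 0.01   -> a small collapse basin appears (p_minus ~ 0.01)
# eta = 0.25   -> None: weight decay too strong, collapse unavoidable
```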
As a trade-off, a very slow EMA schedule (small β) yields a slow training procedure (Obs#9) (see Fig. 5). Also, a small τ leads to a larger $p_j^-$, so more eigen-modes can be trapped in the collapsed basin (Obs#10).

### 3.3. Summarizing the effects of hyperparameters

We summarize the positive and negative effects of multiple hyperparameters in Tbl. 4. We next provide additional ablations and experiments to further justify our reasoning.

Different weight decay $\tilde\eta_p$ and $\tilde\eta_s$. If we set a higher weight decay for the predictor ($\tilde\eta_p$) than for the online network ($\tilde\eta_s$), then $p_j$ grows more slowly than $s_j$, and it is possible that the condition of Theorem 3 can still be satisfied without using EMA. Indeed, Tbl. 5 shows this is the case.

| | sym $W_p$ (no bias) | regular $W_p$ (no bias) | sym $W_p$ (bias) | regular $W_p$ (bias) |
|---|---|---|---|---|
| Weight decay only for the predictor ($\tilde\eta_p$ = 0.0004, $\tilde\eta_s$ = 0): EMA | 71.91 ± 0.70 | 70.54 ± 0.93 | 73.67 ± 0.47 | 70.89 ± 0.98 |
| Weight decay only for the predictor ($\tilde\eta_p$ = 0.0004, $\tilde\eta_s$ = 0): no EMA | 71.12 ± 0.71 | 71.34 ± 0.63 | 73.01 ± 0.37 | 71.70 ± 0.83 |
| No weight decay at all ($\tilde\eta_p$ = $\tilde\eta_s$ = 0): EMA | 71.76 ± 0.28 | 70.62 ± 1.05 | 71.86 ± 0.39 | 70.99 ± 1.01 |
| No weight decay at all ($\tilde\eta_p$ = $\tilde\eta_s$ = 0): no EMA | 43.04 ± 2.32 | 71.36 ± 0.44 | 41.36 ± 3.33 | 71.37 ± 0.77 |

Table 5. A symmetric predictor works without EMA if we set weight decay for the predictor ($\tilde\eta_p$ = 0.0004) but not for the trunk ($\tilde\eta_s$ = 0), in a BYOL experiment on STL-10; we report Top-1 accuracy after 100 epochs. If there is no weight decay for any layer, then a symmetric predictor again does not work without EMA.

Larger predictor learning rate $\alpha_p > 1$. Our analysis predicts that one way to make a symmetric $W_p$ work with no EMA is to use $\alpha_p > 1$ (i.e., the condition of Theorem 3 is more easily satisfied). Fig. 6 verifies this prediction. Moreover, Table 22 in Appendix I.2 of BYOL (Grill et al., 2020) shows a similar trend: when EMA is absent, the learning rate of its (2-layer) predictor needs to be higher than that of the projector to obtain strong performance on ImageNet.

## 4. Optimization-free Predictor $W_p$

A direct consequence of our theory is a new method for choosing the predictor that avoids gradient descent altogether. Instead, we estimate the correlation matrix $F$ of the predictor inputs and directly set $W_p$ as a function of it, thereby avoiding both the need to align the eigenspaces of $F$ and $W_p$ through optimization and the need to initialize $W_p$ outside the basin of collapse. As we shall see, this exceedingly simple, theory-motivated method also yields better performance in practice than gradient-based optimization of a linear predictor. Our method, DirectPred, simply estimates $F$, computes its eigendecomposition $\hat F = \hat U\hat\Lambda_F\hat U^\top$ with $\hat\Lambda_F = \mathrm{diag}[s_1, s_2, \ldots, s_d]$, and sets $W_p$ via

$$p_j = \sqrt{s_j} + \epsilon\,\max_j\sqrt{s_j}, \qquad W_p = \hat U\,\mathrm{diag}[p_j]\,\hat U^\top. \tag{18}$$

This choice is theoretically motivated by the eigenspace alignment between $W_p$ and $F$ (Theorem 3) and by the convergence to the invariant parabola $s_j \propto p_j^2$ in Eqn. 14 with weight decay (η > 0). Here the estimated correlation matrix $\hat F$ can be obtained by a moving average:

$$\hat F \leftarrow \rho\hat F + (1-\rho)\,\mathbb{E}_B\!\left[ff^\top\right], \tag{19}$$

where $\mathbb{E}_B[\cdot]$ is the expectation over a batch.
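The following is a minimal NumPy sketch of the DirectPred rule in Eqns. 18–19; the batch of features and the ρ, ϵ values are illustrative, and the released implementation linked in the footnote is the authoritative version.

```python
import numpy as np

def directpred_update(F_hat, f_batch, rho=0.3, eps=0.1):
    """One DirectPred step: update the EMA correlation matrix F_hat (Eqn. 19)
    and return the directly-set linear predictor W_p (Eqn. 18).

    f_batch: (batch, d) online-network outputs (predictor inputs), not zero-centered.
    """
    # Eqn. 19: moving-average estimate of the correlation matrix F = E[f f^T].
    F_batch = f_batch.T @ f_batch / len(f_batch)
    F_hat = rho * F_hat + (1.0 - rho) * F_batch

    # Eqn. 18: eigendecompose F_hat and set p_j = sqrt(s_j) + eps * max_j sqrt(s_j).
    s, U = np.linalg.eigh(F_hat)                 # F_hat is symmetric PSD
    sqrt_s = np.sqrt(np.clip(s, 0.0, None))      # guard against tiny negative eigenvalues
    p = sqrt_s + eps * sqrt_s.max()
    W_p = U @ np.diag(p) @ U.T
    return F_hat, W_p

# Toy usage: d-dimensional features standing in for the online network's outputs.
d = 16
F_hat = np.eye(d)
f = np.random.randn(512, d)
F_hat, W_p = directpred_update(F_hat, f)
```

In the full method, $W_p$ is re-set this way every minibatch (or every few minibatches in the hybrid variant described next), while the BYOL loss and EMA updates of the other networks proceed unchanged.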
Note that Understanding Self-Supervised Learning Dynamics without Contrastive Pairs 0 20 40 60 80 100 Time t Eigenvalue pj of Wp Eigenvalue pj of Wp 0 20 40 60 80 100 Time t Eigenvalue sj of F Eigenvalue sj of F 0 20 40 60 80 100 Time t EMA coefficent τ EMA coefficent τ 0 20 40 60 80 100 Time t Eigenvalue of K(t) Eigenvalue of K(t) 0 20 40 60 80 100 Time t Eigenvalue pj of Wp Eigenvalue pj of Wp 0 20 40 60 80 100 Time t Eigenvalue sj of F Eigenvalue sj of F 0 20 40 60 80 100 Time t EMA coefficent τ EMA coefficent τ 0 20 40 60 80 100 Time t Eigenvalue of K(t) Eigenvalue of K(t) Figure 5. The role played by weight decay η and EMA β when applying symmetric regularization on Wp on synthetic experiments simulating decoupled dynamics (Eqn. 11-13). The learning rate α = 0.01. Both terms boost the eigenvalue of K(t) to above 0 so that eigen space alignment could happen (Theorem 3), but also come with different trade-offs. Here β = 0.4 so that αβ = 0.004 = 1 γa where γa = 0.996 as in BYOL. Top row (Weight Decay η): A large η boost the eigenvalue of K(t) up, but substantially decreases the final converging eigenvalues pj and sj (i.e., the final features are not salient), or even drags them to zero (no training happens). Bottom row (EMA β). A small EMA β also boost the eigenvalue of K(t), but the training converges much slower. Here η = 0.04 so that ηα equals to the weight decay ( η = 0.0004) in our STL-10 experiments. 0 5 10 15 Relative predictor learning rate αp Top-1 Accuracy in Linear Evaluation 0 5 10 15 Relative predictor learning rate αp Top-1 Accuracy in Linear Evaluation regular, bias=false sym, bias=false regular, bias=true sym, bias=true Figure 6. The effects of relative learning rate αp without EMA. If αp > 1, symmetric Wp with no EMA can also work. Experiments on STL-10 and CIFAR-10 (Krizhevsky et al., 2009) (100 epochs with 5 random seeds). where f is not zero-mean, we keep ˆF a correlation matrix (rather than a covariance) without zero-centering f, otherwise the performance deteriorates. We also added a regularization factor proportional to a small ϵ to boost the small eigenvalues sj so they can learn faster. In all our experiments on real-world datasets, we use ℓ2-normalization so the absolute magnitude of sj doesn t matter. Hyper-parameter freq. Besides, we also evaluate a hybrid approach by introducing freq, which is how frequently eigen-decomposition is conducted for matrix ˆF to set Wp. For example, freq = 5 means that eigen decomposition is run every 5 minibatches. When Wp is not set by eigen decomposition, it is updated by regular gradient updates. freq = 1 means the eigen-decomposition is performed at every minibatch. Tbl. 6 shows that directly computing Wp through Direct Pred works better (76.77%) than training via gradient descent (74.51% in Tbl. 3, regular Wp with EMA). Additional regularization through ϵ yields even better perfor- Regularization factor ϵ 0 0.01 0.1 0.5 ρ = 0.3 76.77 0.24 77.11 0.35 77.86 0.16 75.06 1.10 ρ = 0.5 76.65 0.20 76.76 0.33 77.56 0.25 75.22 0.81 Table 6. STL-10 Top-1 after BYOL training for 100 epochs, if we use Direct Pred (Eqn. 18). It outperforms training Wp using gradient descent (74.51% in Tbl. 3, regular Wp with EMA). EMA is used in all experiments. No predictor bias. ρ defined in Eqn. 19. Initial constant cj 0.1 0.05 0.05 0.1 freq=1 46.57 18.43 65.31 18.22 77.11 0.66 76.46 0.55 freq=2 75.01 0.48 75.10 0.35 76.83 0.52 76.31 0.27 Table 7. STL-10 Top-1 Accuracy after BYOL training for 100 epochs. With different cj. ρ = 0.3 and ϵ = 0. 
Additional regularization through ϵ yields even better performance (77.38%). Different ways to estimate $F$ (moving average or simple average) yield only small differences. The performance of DirectPred also remains good over many more training epochs (Tbl. 8). Moreover, if we allow some gradient steps in between directly setting $W_p$ (i.e., freq > 1), performance becomes even better (80.28%). This might occur because the estimated $\hat F$ may not be accurate enough and SGD can help correct it. It also mitigates the computational cost of the eigendecomposition.

The constant $c_j$. What happens if $p_j = \sqrt{\max(s_j - c_j, 0)}$ with $c_j \neq 0$? If $c_j$ is a small negative number, performance is still fine, but a positive $c_j$ leads to very poor performance (Tbl. 7), likely due to many small eigenvalues $s_j$ becoming zero and therefore being trapped in the collapsed basin.

Feature-dependent $W_p$. Note that one advantage of using two-layer predictors is that $W_p$ can depend on the input features. We explored this idea by using a few random partitions of the input space; within each random partition we estimated a different correlation matrix $\hat F$, and the final $\hat F$ is the sum of all the correlation matrices. With 6 random partitions, DirectPred achieves 78.20 ± 0.16 Top-1 accuracy after 100 epochs, closing the performance gap to two-layer predictors (78.85% in Tbl. 3). We leave a thorough analysis of the two-layer setting to future work.

| Number of epochs | 100 | 300 | 500 |
|---|---|---|---|
| STL-10: DirectPred | 77.86 ± 0.16 | 78.77 ± 0.97 | 78.86 ± 1.15 |
| STL-10: DirectPred (freq = 5) | 77.54 ± 0.11 | 79.90 ± 0.66 | 80.28 ± 0.62 |
| STL-10: SGD baseline | 75.06 ± 0.52 | 75.25 ± 0.74 | 75.25 ± 0.74 |
| CIFAR-10: DirectPred | 85.21 ± 0.23 | 88.88 ± 0.15 | 89.52 ± 0.04 |
| CIFAR-10: DirectPred (freq = 5) | 84.93 ± 0.29 | 88.83 ± 0.10 | 89.56 ± 0.13 |
| CIFAR-10: SGD baseline | 84.49 ± 0.20 | 88.57 ± 0.15 | 89.33 ± 0.27 |

Table 8. STL-10/CIFAR-10 Top-1 accuracy of DirectPred after training for longer epoch budgets. ρ = 0.3, ϵ = 0.1, with EMA.

ImageNet experiments. We conducted additional experiments on ImageNet (Deng et al., 2009) with our own BYOL (Grill et al., 2020) implementation. We used ResNet-50 (He et al., 2016) as the backbone to produce features for a linear probe, followed by a projector and a predictor. The architecture design (e.g., feature dimensions), augmentation strategies (e.g., color jittering, blur (Chen et al., 2020a), solarization, etc.) and linear classification protocol strictly follow BYOL (Grill et al., 2020). We experimented with two different training settings to study the generalization ability of DirectPred. In the first setting, we employ an asymmetric loss (given two views, only one view is used as the prediction target); the loss is optimized with standard SGD for 60 epochs with a batch size of 256. The second setting follows BYOL more closely: we use a symmetrized loss, a 4096 batch size and the LARS optimizer (You et al., 2017), and train for 300 epochs.

The results are summarized in Tbl. 9. Both settings exhibit similar behavior, and we take the 300-epoch results as our highlights in the following. As a baseline, the default 2-layer predictor from BYOL (with BatchNorm and ReLU, 4096 hidden dimensions, 256 input/output dimensions) achieves 72.5% top-1 and 90.8% top-5 accuracy with 300-epoch pre-training; this reproduces the accuracy reported in BYOL (Grill et al., 2020). We find that DirectPred can match this performance (72.4% top-1 and 91.0% top-5) without any gradient-based training of the predictor, by instead directly setting the (256 × 256) linear predictor weights every mini-batch.
In particular, for top-5 accuracy DirectPred is even 0.2% better.

| BYOL variants | Top-1 (60 ep) | Top-5 (60 ep) | Top-1 (300 ep) | Top-5 (300 ep) |
|---|---|---|---|---|
| 2-layer predictor* | 64.7 | 85.8 | 72.5 | 90.8 |
| linear predictor | 59.4 | 82.3 | 69.9 | 89.6 |
| DirectPred | 64.4 | 85.8 | 72.4 | 91.0 |

* The 2-layer predictor is the BYOL default setting.

Table 9. ImageNet experiments comparing DirectPred with BYOL (Grill et al., 2020). Without gradient-based training of the predictor, DirectPred is able to match the performance of the default 2-layer predictor introduced by BYOL, and significantly outperforms the linear predictor, by 5% (60 epochs) and 2.5% (300 epochs).

For a fair comparison, we also run BYOL with a learned linear predictor. We find the performance drops to 69.9% and 89.6%, respectively (a 2.5% gap to our method). The gap is even bigger in the 60-epoch setting, up to 5.0% in top-1 (59.4% vs. 64.4%). These experiments demonstrate that the success of DirectPred on STL-10 and CIFAR also generalizes and scales to ImageNet.

## 5. Discussion

Summary. Remarkably, our theoretical analysis of non-contrastive SSL, primarily centered around a three-dimensional nonlinear dynamical system, not only yields conceptual insights into the functional roles of complex ingredients like EMA, stop-gradients, predictors, predictor symmetry, diverse learning rates, weight decay, and all their interactions, but also predicts the performance patterns of many ablation studies, and suggests an exceedingly simple DirectPred method that rivals the performance of more complex predictor dynamics in real-world settings.

Two-layer non-linear predictor. With only a linear predictor, our results on ImageNet (Tbl. 9) already show strong performance, on par with the default BYOL setting with a 2-layer predictor. One interesting question is how the dynamics change if the predictor has two layers. While we do not provide a formal analysis, and the math can become quite complicated, the intuition is that the wide 2-layer predictor used in practice (e.g., more hidden dimensions (4096) than input/output dimensions (256), with a ReLU in between) essentially provides a large pool of initial weight directions to start from, some of which could be "lucky draws" that make eigenspace alignment faster. On the other hand, a 1-layer predictor with gradient updates may get stuck in local minima. Therefore, with the same number of epochs, a 2-layer predictor outperforms a 1-layer predictor, and is comparable with DirectPred, which does not suffer from local-minima issues.

## Acknowledgements

We thank Lantao Yu for helpful discussions.

## References

Arora, S., Cohen, N., and Hazan, E. On the optimization of deep networks: Implicit acceleration by overparameterization. In ICML. PMLR, 2018.

Arora, S., Khandeparkar, H., Khodak, M., Plevrakis, O., and Saunshi, N. A theoretical analysis of contrastive unsupervised representation learning. 2019.

Bachman, P., Hjelm, R. D., and Buchwalter, W. Learning representations by maximizing mutual information across views. arXiv preprint arXiv:1906.00910, 2019.

Bartlett, P., Helmbold, D., and Long, P. Gradient descent with identity initialization efficiently learns positive definite linear transformations by deep residual networks. In ICML, 2018.

Bromley, J., Guyon, I., LeCun, Y., Säckinger, E., and Shah, R. Signature verification using a siamese time delay neural network. NeurIPS, 1994.

Brutzkus, A. and Globerson, A. Globally optimal gradient descent for a convnet with gaussian inputs. In ICML, 2017.
Chen, T., Kornblith, S., Norouzi, M., and Hinton, G. A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709, 2020a.

Chen, X. and He, K. Exploring simple siamese representation learning. arXiv preprint arXiv:2011.10566, 2020.

Chen, X., Fan, H., Girshick, R., and He, K. Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297, 2020b.

Coates, A., Ng, A., and Lee, H. An analysis of single-layer networks in unsupervised feature learning. In International Conference on Artificial Intelligence and Statistics, 2011.

Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In CVPR, 2009.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.

Du, S. and Hu, W. Width provably matters in optimization for deep linear neural networks. In ICML, 2019.

Du, S. S., Hu, W., and Lee, J. D. Algorithmic regularization in learning deep homogeneous models: Layers are automatically balanced. arXiv preprint arXiv:1806.00900, 2018.

Du, S. S., Lee, J. D., Li, H., Wang, L., and Zhai, X. Gradient descent finds global minima of deep neural networks. ICML, 2019.

Fetterman, A. and Albrecht, J. Understanding self-supervised and contrastive learning with bootstrap your own latent (BYOL), 2020. https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html.

Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., Doersch, C., Pires, B. A., Guo, Z. D., Azar, M. G., et al. Bootstrap your own latent: A new approach to self-supervised learning. arXiv preprint arXiv:2006.07733, 2020.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In CVPR, 2016.

He, K., Fan, H., Wu, Y., Xie, S., and Girshick, R. Momentum contrast for unsupervised visual representation learning. In CVPR, 2020.

Kawaguchi, K. Deep learning without poor local minima. NeurIPS, 2016.

Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. 2009.

Lampinen, A. K. and Ganguli, S. An analytic theory of generalization dynamics and transfer learning in deep linear networks. In ICLR, 2018.

Laurent, T. and Brecht, J. Deep linear networks with arbitrary loss: All local minima are global. In ICML, pp. 2902–2907. PMLR, 2018.

Lee, J. D., Lei, Q., Saunshi, N., and Zhuo, J. Predicting what you already know helps: Provable self-supervised learning. arXiv preprint arXiv:2008.01064, 2020.

Oord, A. v. d., Li, Y., and Vinyals, O. Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748, 2018.

Pennington, J., Schoenholz, S., and Ganguli, S. Resurrecting the sigmoid in deep learning through dynamical isometry: theory and practice. In NeurIPS, 2017.

Pennington, J., Schoenholz, S. S., and Ganguli, S. The emergence of spectral universality in deep networks. In AISTATS, 2018.

Safran, I. and Shamir, O. Spurious local minima are common in two-layer ReLU neural networks. In ICML. PMLR, 2018.

Saxe, A. M., McClelland, J. L., and Ganguli, S. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013.

Saxe, A. M., McClelland, J. L., and Ganguli, S. A mathematical theory of semantic development in deep neural networks. Proc. Natl. Acad. Sci. U.S.A., 2019.
Tian, Y. An analytical formula of population gradient for two-layered ReLU network and its applications in convergence and critical point analysis. In ICML, 2017.

Tian, Y., Krishnan, D., and Isola, P. Contrastive multiview coding. arXiv preprint arXiv:1906.05849, 2019.

Tosh, C., Krishnamurthy, A., and Hsu, D. Contrastive learning, multi-view redundancy, and linear models. arXiv preprint arXiv:2008.10150, 2020.

You, Y., Gitman, I., and Ginsburg, B. Large batch training of convolutional networks. arXiv preprint arXiv:1708.03888, 2017.