# CONCURRENT ADVERSARIAL LEARNING FOR LARGE-BATCH TRAINING

Published as a conference paper at ICLR 2022

Yong Liu1, Xiangning Chen2, Minhao Cheng3, Cho-Jui Hsieh2, Yang You1

1Department of Computer Science, National University of Singapore
2Department of Computer Science, University of California, Los Angeles
3Department of Computer Science and Engineering, Hong Kong University of Science and Technology

{liuyong, youy}@comp.nus.edu.sg, {xiangning, chohsieh}@cs.ucla.edu, minhaocheng@ust.hk

ABSTRACT

Large-batch training has become a widely used technique when training neural networks with a large number of GPU/TPU processors. As the batch size increases, stochastic optimizers tend to converge to sharp local minima, leading to degraded test performance. Current methods usually rely on extensive data augmentation as a remedy for increasing the batch size, but we found that the performance gain brought by data augmentation decreases as the batch size increases. In this paper, we propose to leverage adversarial learning to increase the batch size in large-batch training. Despite being a natural choice for smoothing the decision surface and biasing towards a flat region, adversarial learning has not been successfully applied to large-batch training, since it requires at least two sequential gradient computations at each step. To overcome this issue, we propose a novel Concurrent Adversarial Learning (ConAdv) method that decouples the sequential gradient computations in adversarial learning by utilizing stale parameters. Experimental results demonstrate that ConAdv can successfully increase the batch size of ResNet-50 training on ImageNet while maintaining high accuracy. This is the first work that successfully scales the ResNet-50 training batch size to 96K.

1 INTRODUCTION

As larger datasets and bigger models are proposed, training deep neural networks has become quite time-consuming.
For instance, training BERT (Devlin et al., 2019) takes 3 days on 16 v3 TPUs. GPT-2 (Radford et al., 2019) contains 1,542M parameters and requires 168 hours of training on 16 v3 TPUs. With the development of high-performance computing clusters (e.g., Google and NVIDIA have built clusters with thousands of TPU or GPU chips), how to fully utilize these computing resources to accelerate training has become an important research topic. Data parallelism is a commonly used technique for distributed neural network training: each processor computes the gradient of a local batch, and the gradients across processors are aggregated at each iteration for a parameter update. Training with hundreds or thousands of processors under data parallelism is thus equivalent to running a stochastic gradient optimizer (e.g., SGD or Adam) with a very large batch size, also known as large-batch training. For example, Google and NVIDIA show that by increasing the batch size to 64k on ImageNet, they can finish 90-epoch ResNet training within one minute (Kumar et al., 2021; Mattson et al., 2019).

But why can't we increase the batch size indefinitely as long as more computing resources are available? Large-batch training faces two challenges. First, given a fixed number of training epochs, increasing the batch size reduces the number of training iterations. Second, and worse, large-batch training has been observed to converge to solutions with poor generalization performance (also known as sharp local minima) (Keskar et al., 2017), possibly due to the lack of inherent noise in each stochastic gradient update. Although this problem can be partially mitigated by optimizers such as LARS (You et al., 2017) and LAMB (You et al., 2019), a limit on the batch size still exists.
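The equivalence between data parallelism and large-batch SGD described above can be checked on a toy problem. The sketch below is illustrative only: it assumes a 1-D least-squares loss averaged over examples and equal-size shards per worker, and the helper names (`mean_grad`, `data_parallel_grad`) are hypothetical. Averaging the per-worker mean gradients reproduces the full-batch gradient.

```python
def mean_grad(w, batch):
    """Mean gradient of the toy loss 0.5 * (w * x - y) ** 2 over (x, y) pairs."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)


def data_parallel_grad(w, batch, num_workers):
    """Split the global batch into equal shards, compute one gradient per worker,
    then average them (the role played by all-reduce or a parameter server)."""
    if len(batch) % num_workers != 0:
        raise ValueError("this sketch assumes equal-size shards")
    shard = len(batch) // num_workers
    local = [mean_grad(w, batch[k * shard:(k + 1) * shard]) for k in range(num_workers)]
    return sum(local) / num_workers


batch = [(float(i), 2.0 * i + 1.0) for i in range(8)]
w = 0.5
full = mean_grad(w, batch)
parallel = data_parallel_grad(w, batch, num_workers=4)
print(full, parallel)  # -29.75 -29.75
```

With unequal shard sizes, the per-worker gradients would have to be weighted by shard size for the equivalence to hold.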
Even with additional system-level techniques, this limit remains. For instance, Google uses several techniques, such as distributed batch normalization and mixed-precision training, to further scale the training of ResNet-50 on 4096 v3 TPU chips; however, it can only expand the batch size to 64k (Kumar et al., 2021; Ying et al., 2018).

To mitigate the generalization gap brought by large-batch training, data augmentation has become an indispensable component. For instance, researchers at Facebook use augmentation to scale the training of ResNet-50 to 256 NVIDIA P100 GPUs with a batch size of 8k on ImageNet (Goyal et al., 2017). You et al. (2018) also use data augmentation to expand the batch size to 32k on 2048 KNL nodes. However, in this paper we observe that when the batch size is large enough (i.e., larger than 32k), data augmentation also increases the difficulty of training and can even hurt test accuracy.

This motivates us to study the application of adversarial training, which trains the model on a perturbation found within a bounded set around each sample, to large-batch training. Previous works find that, in small-batch settings, adversarial training can significantly decrease the curvature of the loss surface and make the network behave more linearly, which can be used to improve generalization (Xie et al., 2020; Moosavi-Dezfooli et al., 2019). However, adversarial training has not been used in large-batch training, because it requires a series of sequential gradient computations within each update to find an adversarial example. Even when only one gradient ascent step is used to find the adversarial example, adversarial training requires two sequential gradient computations (one for the adversarial example and one for the weight update) that cannot be parallelized. Therefore, even with infinite computing resources, adversarial training is at least two times slower than standard training, and increasing the batch size cannot compensate for that.
To resolve this issue and make adversarial training applicable to large-batch training, we propose a novel Concurrent Adversarial Learning (ConAdv) algorithm. We show that by allowing adversarial examples to be computed with stale weights, the two sequential gradient computations in adversarial training can be decoupled, leading to fully parallelized computations at each step. As a result, extra processors can be fully utilized to achieve the same iteration throughput as the original SGD or Adam optimizers. Comprehensive experimental results on large-batch training demonstrate that ConAdv is a better choice than existing augmentations. Our main contributions are listed below:

- This is the first work showing that adversarial learning can significantly increase the batch size limit of large-batch training without using data augmentation.
- The proposed algorithm, ConAdv, decouples the two sequential gradient computations in adversarial training and makes them parallelizable. This lets adversarial training achieve efficiency similar to standard stochastic optimizers when sufficient computing resources are available. Furthermore, we empirically show that ConAdv achieves almost identical performance to the original adversarial training. We also provide a theoretical analysis of ConAdv.
- Comprehensive experimental studies demonstrate that the proposed method can push the limit of large-batch training on various tasks. For ResNet-50 training on ImageNet, ConAdv alone achieves 75.3% accuracy with a 96K batch size, and the accuracy rises to 76.2% when combined with data augmentation. This is the first method that scales the ResNet-50 batch size to 96K with accuracy matching the MLPerf standard (75.9%), while previous methods fail to scale beyond a 64K batch size.

2 BACKGROUND

2.1 LARGE-BATCH TRAINING

Using data parallelism with SGD naturally leads to large-batch training on distributed systems.
However, it has been shown that extremely large batches are difficult to converge and suffer from a generalization gap (Keskar et al., 2017; Hoffer et al., 2017). Therefore, related works carefully tune hyper-parameters such as the learning rate and momentum to bridge the gap (You et al., 2018; Goyal et al., 2017; Li, 2017; Shallue et al., 2018; Xue et al., 2021; Lou et al., 2021). Goyal et al. (2017) try to narrow the generalization gap with learning-rate scaling heuristics, but there is still large room to increase the batch size. Several recent works use adaptive learning rates to reduce hyper-parameter tuning and further scale the batch size (You et al., 2018; Iandola et al., 2016; Codreanu et al., 2017; Akiba et al., 2017; Jia et al., 2018; Smith et al., 2017; Martens & Grosse, 2015; Devarakonda et al., 2017; Osawa et al., 2018; You et al., 2019; Yamazaki et al., 2019; Liu et al., 2022). You et al. (2017) propose Layer-wise Adaptive Rate Scaling (LARS) for better optimization, scaling to a batch size of 32k on ImageNet without a performance penalty. Related works also try to bridge the gap through augmentation. Goyal et al. (2017) use data augmentation to further scale the training of ResNet-50 on ImageNet. Yao et al. (2018a) propose an adaptive batch size method based on Hessian information that gradually increases the batch size during training, and use vanilla adversarial training to regularize against sharp minima. However, adversarial training is time-consuming, and they only use a batch size of 16k in the second half of training (the initial batch size is 256). How to further accelerate adversarial training and reduce its computational burden remains an open problem.
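The learning-rate scaling heuristic of Goyal et al. (2017) mentioned above is commonly summarized as the linear scaling rule: when the batch size is multiplied by k, multiply the learning rate by k, usually combined with a short warmup. The sketch below assumes a reference learning rate of 0.1 for a batch of 256 and a simple linear warmup; these values and the function names are illustrative assumptions, not the settings used in this paper (which trains with LARS).

```python
def linear_scaled_lr(batch_size, base_lr=0.1, base_batch=256):
    """Linear scaling rule: the learning rate grows proportionally with the batch size."""
    return base_lr * batch_size / base_batch


def warmup_lr(step, warmup_steps, target_lr):
    """Hypothetical linear warmup: ramp up to target_lr over warmup_steps, then hold it."""
    if step < warmup_steps:
        return target_lr * (step + 1) / warmup_steps
    return target_lr


for bs in (256, 8192, 32768):
    print(bs, linear_scaled_lr(bs))
```

At very large batch sizes the linearly scaled learning rate becomes large enough to destabilize early training, which is one motivation for the layer-wise adaptive methods cited above.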
2.2 ADVERSARIAL LEARNING

Adversarial training has shown great success in improving model robustness by collecting adversarial examples and injecting them into the training data (Goodfellow et al., 2015; Papernot et al., 2016; Wang et al., 2019). Madry et al. (2017) formulate it as the following min-max optimization problem:

$$\min_{\theta}\ \mathbb{E}_{(x_i,y_i)\sim\mathcal{D}}\Big[\max_{\|\delta\|_p\le\epsilon} L(\theta_t, x_i+\delta, y_i)\Big], \qquad (1)$$

where $\mathcal{D}=\{(x_i,y_i)\}_{i=1}^{n}$ denotes the training samples with $x_i\in\mathbb{R}^d$ and $y_i\in\{1,\dots,Z\}$, $\delta$ is the adversarial perturbation, $\|\cdot\|_p$ denotes an $L_p$-norm distance metric, $\theta_t$ denotes the parameters at time $t$, and $Z$ is the number of classes. Goodfellow et al. (2015) propose FGSM to collect adversarial data, which performs a one-step update along the sign of the gradient of the loss function. The Projected Gradient Descent (PGD) algorithm (Madry et al., 2017) first starts from a random point in the allowed range (a spherical noise region) around the original input and then iterates FGSM several times to generate adversarial examples. Recently, several papers (Shafahi et al., 2019; Wong et al., 2020; Andriushchenko & Flammarion, 2020) aim to reduce the computational overhead of adversarial training. Specifically, FreeAdv (Shafahi et al., 2019) updates both the weights $\theta$ and the adversarial example $x$ at the same time by exploiting the correlation between the gradient with respect to the input and the gradient with respect to the model weights. Similar to FreeAdv, Zhang et al. (2019) further restrict most of the forward and backward propagation to the first layer to speed up computation. Wong et al. (2020) find that the overhead can be further reduced by using single-step FGSM with random initialization. While these works improve the efficiency of adversarial training, they still require at least two sequential gradient computations at every step. Our concurrent framework decouples the two sequential gradient computations to further boost efficiency, which makes it more suitable for large-batch training.
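To make the two attacks concrete, the sketch below implements FGSM and PGD on a toy linear classifier with logistic loss under an $L_\infty$ threat model. The model, the loss, the choice $p=\infty$, and all function names are illustrative assumptions; practical implementations operate on image tensors and typically also clip to the valid pixel range.

```python
import math
import random


def margin(w, x, y):
    """y * <w, x> for a linear classifier with label y in {-1, +1}."""
    return y * sum(wi * xi for wi, xi in zip(w, x))


def loss(w, x, y):
    """Logistic loss log(1 + exp(-y <w, x>))."""
    return math.log1p(math.exp(-margin(w, x, y)))


def grad_x(w, x, y):
    """Gradient of the logistic loss with respect to the input x."""
    coef = -y / (1.0 + math.exp(margin(w, x, y)))  # equals -y * sigmoid(-margin)
    return [coef * wi for wi in w]


def sign(v):
    return (v > 0) - (v < 0)


def fgsm(w, x, y, eps):
    """FGSM: one step of size eps along the sign of the input gradient."""
    return [xi + eps * sign(gi) for xi, gi in zip(x, grad_x(w, x, y))]


def pgd(w, x, y, eps, alpha, steps, rng):
    """PGD: random start inside the eps-ball, then repeated signed steps,
    each followed by projection back onto the L-inf ball around x."""
    x_adv = [xi + rng.uniform(-eps, eps) for xi in x]
    for _ in range(steps):
        g = grad_x(w, x_adv, y)
        x_adv = [xa + alpha * sign(gi) for xa, gi in zip(x_adv, g)]
        x_adv = [min(max(xa, xi - eps), xi + eps) for xa, xi in zip(x_adv, x)]
    return x_adv


w, x, y = [1.0, -2.0], [0.5, 0.25], 1
x_fgsm = fgsm(w, x, y, eps=0.1)
x_pgd = pgd(w, x, y, eps=0.1, alpha=0.05, steps=5, rng=random.Random(0))
print(loss(w, x, y), loss(w, x_fgsm, y), loss(w, x_pgd, y))
```

Both attacks raise the loss while staying within distance $\epsilon$ of the clean input; each PGD iteration costs one extra forward/backward pass, which is the overhead the works above try to reduce.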
Recently, several works (Xie et al., 2020; Cheng et al., 2021; Chen et al., 2021; Mei et al., 2022) show that adversarial examples can serve as an augmentation that benefits clean accuracy in the small-batch setting. However, whether adversarial training can improve the performance of large-batch training is still an open problem.

MLPerf (Mattson et al., 2019) is an industry-standard performance benchmark for machine learning that aims to evaluate system performance fairly. It currently includes several representative tasks from major ML areas, such as vision, language, and recommendation. In this paper, we use ResNet-50 (He et al., 2016) as our baseline model, and the convergence baseline is 75.9% top-1 accuracy on ImageNet.

3 PROPOSED ALGORITHM

In this section, we present our findings and the proposed algorithm. We first study the limitation of data augmentation in large-batch training. Then we discuss the bottleneck of adversarial training in distributed systems and propose a novel Concurrent Adversarial Learning (ConAdv) method for large-batch training.

Figure 1: (a) Distributed Adversarial Learning (DisAdv), (b) Concurrent Adversarial Learning (ConAdv). Both panels show the update process of step t; for ease of understanding, only a system with two workers is shown.

3.1 DOES DATA AUGMENTATION IMPROVE THE PERFORMANCE OF LARGE-BATCH TRAINING?

Data augmentation usually improves the generalization of models and is a commonly used technique for raising the batch size limit in large-batch training. To formally study the effect of data augmentation in large-batch training, we train ResNet-50 on ImageNet (Deng et al., 2009) with AutoAugment (AA) (Cubuk et al., 2019).
The results in Figure 2 reveal that although AA improves generalization at batch sizes up to 64K, the performance gain decreases as the batch size increases. Further, AA can have a negative effect when the batch size is large enough (e.g., 128K or 256K). For instance, the top-1 accuracy increases from 76.9% to 77.5% when using AA with a 1k batch size. However, it decreases from 73.2% to 72.9% with data augmentation when the batch size is 128k, and drops from 64.7% to 62.5% when the batch size is 256k. The main reason is that augmented data increases the diversity of the training data, which slows convergence when fewer training iterations are available. Recent work concatenates the original data and augmented data to jointly train the model and improve accuracy (Berman et al., 2019). However, we find that concatenating them hurts accuracy when the batch size is large. Therefore, we only use the augmented data to train the model. These results motivate us to explore a new method for large-batch training.

Figure 2: Augmentation analysis (top-1 accuracy of ResNet-50 and ResNet-50+AA at batch sizes from 1k to 256k).

3.2 ADVERSARIAL LEARNING IN THE DISTRIBUTED SETTING

Adversarial learning can be viewed as a way to conduct data augmentation automatically. Instead of defining fixed rules to augment data, adversarial learning conducts gradient-based adversarial attacks to find adversarial examples. As a result, adversarial learning leads to a smoother decision boundary (Karimi et al., 2019; Madry et al., 2017), which often comes with flatter local minima (Yao et al., 2018b). Instead of solving the original empirical risk minimization problem, adversarial learning solves a min-max objective that minimizes the loss under the worst-case perturbation of samples within a small radius.
In this paper, since our main goal is to improve clean accuracy rather than robustness, we consider the following training objective, which includes the loss on both natural and adversarial samples:

$$\min_{\theta}\ \mathbb{E}_{(x_i,y_i)\sim\mathcal{D}}\Big[L(\theta_t; x_i, y_i) + \max_{\|\delta\|_p\le\epsilon} L(\theta_t; x_i+\delta, y_i)\Big], \qquad (2)$$

where $L$ is the loss function and $\epsilon$ bounds the magnitude of the perturbation. Although many previous works on adversarial training focus on improving the trade-off between accuracy and robustness (Shafahi et al., 2019; Wong et al., 2020), Xie et al. (2020) recently showed that using split BatchNorm for adversarial and clean data can improve test performance on clean data. We therefore also adopt this split BatchNorm approach.

Figure 3: Vanilla training and concurrent training (the panels contrast sequential scheduling of the first and second gradient computations with concurrent, independent scheduling).

Algorithm 1 ConAdv
for $t=1,\dots,T$ do
  for each worker $k=1,\dots,K$ in parallel do
    for $(x_i, y_i)\in B^k_{c,t}$ do
      Compute loss: $L(\theta_t; x_i, y_i)$ using the main BN and $L^k_a(\theta_t; \hat{x}_i(\theta_{t-\tau}), y_i)$ using the adversarial BN
      $L_B(\theta_t)=\mathbb{E}_{B^k_{c,t}}L(\theta_t; x_i, y_i)+\mathbb{E}_{B^k_{a,t}}L(\theta_t; \hat{x}_i(\theta_{t-\tau}), y_i)$
      Minimize $L_B(\theta_t)$ and obtain $\hat{g}^k_t(\theta_t)$
    end for
    for $(x_i, y_i)\in B^k_{c,t+\tau}$ do
      Calculate the adversarial gradient $g^k_a(\theta_t)$ on $B^k_{c,t+\tau}$
      Obtain adversarial examples $(\hat{x}_i(\theta_t), y_i)$
    end for
  end for
  Aggregate: $\hat{g}_t(\theta_t)=\frac{1}{K}\sum_{k=1}^{K}\hat{g}^k_t(\theta_t)$
  Update the weights $\theta_{t+1}$ on the parameter server
end for

For Distributed Adversarial Learning (DisAdv), the training data $\mathcal{D}$ is partitioned into $K$ local datasets $\mathcal{D}^k$, with $\mathcal{D}=\bigcup_{k=1}^{K}\mathcal{D}^k$. At each step $t$, worker $k$ first samples a mini-batch of clean data $B^k_{c,t}$ from its local dataset $\mathcal{D}^k$. Each worker then downloads the weights $\theta_t$ from the parameter server and uses $\theta_t$ to obtain the adversarial gradients $g^k_a(\theta_t)=\nabla_x L(\theta_t; x_i, y_i)$ on the input examples $x_i\in B^k_{c,t}$. Note that we only use the local loss $\mathbb{E}_{(x_i,y_i)\sim\mathcal{D}^k} L(\theta_t; x_i, y_i)$ to calculate the adversarial gradient $g^k_a(\theta_t)$, rather than the global loss $\mathbb{E}_{(x_i,y_i)\sim\mathcal{D}} L(\theta_t; x_i, y_i)$, since we aim to reduce the communication cost between workers.
In addition, we use 1-step Projected Gradient Descent (PGD) to compute $\hat{x}^*_i(\theta_t)=x_i+\alpha\nabla_x L(\theta_t; x_i, y_i)$ as an approximation of the optimal adversarial example $x^*_i$. We can then collect the adversarial mini-batch $B^k_{a,t}=\{(\hat{x}^*_i(\theta_t), y_i)\}$ and use both the clean examples $(x_i, y_i)\in B^k_{c,t}$ and the adversarial examples $(\hat{x}^*_i(\theta_t), y_i)\in B^k_{a,t}$ to update the weights $\theta_t$. More specifically, we use the main BatchNorm to compute the statistics of clean data and an auxiliary BatchNorm to compute the statistics of adversarial data.

We show the workflow of adversarial learning on distributed systems (DisAdv) in Figure 1(a). More importantly, it requires two sequential gradient computations at each step, which is time-consuming and thus not suitable for large-batch training. Specifically, we first need to compute the gradient $g^k_a(\theta_t)$ to collect the adversarial examples $\hat{x}^*$. After that, we use these examples to update the weights $\theta_t$, which requires computing a second gradient. In addition, collecting the adversarial examples $\hat{x}^*_i$ and using them to update the model are tightly coupled: a worker cannot compute the local losses $\mathbb{E}_{(x_i,y_i)\sim\mathcal{D}^k} L(\theta_t; x_i, y_i)$ and $\mathbb{E}_{(x_i,y_i)\sim\mathcal{D}^k} L(\theta_t; \hat{x}^*_i, y_i)$ to update the weights $\theta_t$ until all adversarial examples $\hat{x}^*_i$ have been obtained.

3.3 CONCURRENT ADVERSARIAL LEARNING FOR LARGE-BATCH TRAINING

As mentioned in the previous section, vanilla DisAdv requires two sequential gradient computations at each step: the first obtains $\hat{x}^*_i$ from $L(\theta_t, x_i, y_i)$, and the second computes the gradient of $L(\theta_t, \hat{x}^*_i, y_i)$ to update $\theta_t$. Because of this sequential dependency, the overhead cannot be reduced by adding processors: even with an infinite number of processors, two sequential computations take twice as long as one parallel update. This makes adversarial learning unsuitable for large-batch training.
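The sequential dependency in DisAdv can be made explicit with a single-process simulation of one step. This is a sketch under stated assumptions: a toy linear classifier with logistic loss, workers simulated by a Python loop, no BatchNorm, and the unsigned 1-step update $\hat{x}^*_i=x_i+\alpha\nabla_x L$ from the text; `disadv_step` and its helpers are hypothetical names. The clean and adversarial gradients are averaged, matching the equal weighting of the two losses in objective (2).

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def grad_x(w, x, y):
    """Input gradient of the logistic loss log(1 + exp(-y <w, x>))."""
    c = -y * sigmoid(-y * dot(w, x))
    return [c * wi for wi in w]


def grad_w(w, x, y):
    """Weight gradient of the same logistic loss."""
    c = -y * sigmoid(-y * dot(w, x))
    return [c * xi for xi in x]


def mean_grad_w(w, batch):
    grads = [grad_w(w, x, y) for x, y in batch]
    return [sum(g[j] for g in grads) / len(batch) for j in range(len(w))]


def disadv_step(w, local_batches, alpha, eta):
    """One DisAdv step. Inside each worker the two gradient computations are
    sequential: the weight gradient needs the adversarial examples first."""
    worker_grads = []
    for batch in local_batches:  # conceptually one iteration per worker, in parallel
        # 1st gradient (input): 1-step attack with the *current* weights w_t
        adv = [([xi + alpha * gi for xi, gi in zip(x, grad_x(w, x, y))], y) for x, y in batch]
        # 2nd gradient (weights): clean + adversarial loss, can only start now
        g_clean, g_adv = mean_grad_w(w, batch), mean_grad_w(w, adv)
        worker_grads.append([(a + b) / 2 for a, b in zip(g_clean, g_adv)])
    # aggregate over workers (parameter server / all-reduce), then take an SGD step
    g = [sum(gk[j] for gk in worker_grads) / len(worker_grads) for j in range(len(w))]
    return [wj - eta * gj for wj, gj in zip(w, g)]


batches = [[([1.0, 0.0], 1)], [([0.0, 1.0], -1)]]  # two workers, one example each
w1 = disadv_step([0.0, 0.0], batches, alpha=0.1, eta=1.0)
print(w1)
```

Adding workers shortens neither of the two phases inside the loop body, which is why the per-step latency stays at two gradient computations.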
In the following, we propose a simple but novel method to resolve this issue and provide a theoretical analysis of it.

Concurrent Adversarial Learning (ConAdv). As shown in Figure 3, our main finding is that if we use stale weights $\theta_{t-\tau}$ to generate adversarial examples, the two sequential computations can be decoupled, and the parameter update step can run concurrently with the adversarial example generation step for future batches. We now formally define the ConAdv procedure. Assume $x_i$ is sampled at iteration $t$. Instead of the current weights $\theta_t$, we use the stale weights $\theta_{t-\tau}$ (where $\tau$ is the delay) to calculate the gradient and obtain an approximate adversarial example $\hat{x}_i(\theta_{t-\tau})$:

$$g_a(\theta_{t-\tau})=\nabla_x L(\theta_{t-\tau}; x_i, y_i), \qquad \hat{x}_i(\theta_{t-\tau})=x_i+\alpha\, g_a(\theta_{t-\tau}). \qquad (3)$$

In this way, we obtain the adversarial sample $\hat{x}_i(\theta_{t-\tau})$ from stale weights before updating the model at step $t$, which improves training efficiency. The structure of ConAdv is shown in Figure 1(b): at each step $t$, each worker $k$ can directly concatenate the clean and adversarial mini-batches to calculate the gradient $\hat{g}^k_t(\theta_t)$ and update the model, because the system has already obtained the approximate adversarial examples $\hat{x}_i$ from the stale weights $\theta_{t-\tau}$ before iteration $t$. In practice, we set $\tau=1$, so the adversarial examples $\hat{x}_i$ are computed at iteration $t-1$. Therefore, each iteration computes both the current weight update and the adversarial examples for the next batch:

$$\theta_{t+1}=\theta_t-\frac{\eta}{2}\nabla_\theta\Big(\mathbb{E}_{(x_i,y_i)\in B_{c,t}} L(\theta_t; x_i, y_i)+\mathbb{E}_{(\hat{x}_i,y_i)\in B_{a,t}} L(\theta_t; \hat{x}_i(\theta_{t-1}), y_i)\Big), \qquad (4)$$

$$\hat{x}_i(\theta_t)=x_i+\alpha\nabla_x L(\theta_t; x_i, y_i), \quad \text{where } (x_i, y_i)\in B_{c,t+1}, \qquad (5)$$

where $B_{c,t}=\bigcup_{k=1}^{K}B^k_{c,t}$ denotes the clean mini-batch of all workers and $B_{a,t}=\bigcup_{k=1}^{K}B^k_{a,t}$ denotes the adversarial mini-batch of all workers. These two computations can be parallelized, so there are no longer two sequential computations at each step.
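The following sketch simulates Eqs. (3) to (5) with delay $\tau=1$ on a single worker. It is illustrative only: it reuses a toy logistic-regression model, omits BatchNorm and cross-worker aggregation, and bootstraps the very first batch by attacking it with the initial weights, which is an assumption since the text does not specify that step; `attack` and `conadv_train` are hypothetical names. The key property is that both jobs inside the loop read only the current weights, so they could run on different processors at the same time.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def grad_x(w, x, y):
    c = -y * sigmoid(-y * dot(w, x))
    return [c * wi for wi in w]


def grad_w(w, x, y):
    c = -y * sigmoid(-y * dot(w, x))
    return [c * xi for xi in x]


def mean_grad_w(w, batch):
    grads = [grad_w(w, x, y) for x, y in batch]
    return [sum(g[j] for g in grads) / len(batch) for j in range(len(w))]


def attack(w, batch, alpha):
    """Eq. (3)/(5): x_hat = x + alpha * grad_x L(w; x, y), a single unsigned step."""
    return [([xi + alpha * gi for xi, gi in zip(x, grad_x(w, x, y))], y) for x, y in batch]


def conadv_train(w, batches, alpha, eta):
    """ConAdv with delay tau = 1 (single worker). Step t updates w with adversarial
    examples built from w_{t-1}; the attack on batch t+1 uses w_t. Both jobs read
    only w_t, so they could run concurrently on separate processors."""
    adv = attack(w, batches[0], alpha)  # assumption: bootstrap batch 0 with w_0
    for t, batch in enumerate(batches):
        # job A, Eq. (4): weight gradient on clean + stale adversarial examples
        g_clean, g_adv = mean_grad_w(w, batch), mean_grad_w(w, adv)
        # job B, Eq. (5): adversarial examples for the next batch from current w_t
        next_adv = attack(w, batches[t + 1], alpha) if t + 1 < len(batches) else None
        w = [wj - eta * (a + b) / 2 for wj, a, b in zip(w, g_clean, g_adv)]
        adv = next_adv
    return w


batches = [[([1.0, 0.0], 1)], [([0.0, 1.0], -1)]]
print(conadv_train([0.0, 0.0], batches, alpha=0.1, eta=1.0))
```

Compared with the DisAdv step, the attack on a batch happens one iteration earlier with slightly older weights; the convergence analysis below bounds the error this staleness introduces.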
In the large-batch setting, when the number of workers reaches the limit that each batch size can use, ConAdv is about as fast as standard optimizers such as SGD or Adam. The pseudocode of ConAdv is shown in Algorithm 1.

3.4 CONVERGENCE ANALYSIS

In this section, we show that despite using stale gradients, ConAdv still enjoys good convergence properties. For simplicity, we use $L(\theta, x_i)$ as shorthand for $L(\theta; x_i, y_i)$, and $\|\cdot\|$ denotes the $\ell_2$ norm. We define the optimal adversarial example as $x^*_i=\arg\max_{x'_i\in\mathcal{X}_i} L(\theta_t, x'_i)$. To present our main theorem, we need the following assumptions.

Assumption 1. The function $L(\theta, x)$ satisfies the Lipschitzian conditions:

$$\begin{aligned}
\|\nabla_x L(\theta_1; x)-\nabla_x L(\theta_2; x)\| &\le L_{x\theta}\|\theta_1-\theta_2\|,\\
\|\nabla_\theta L(\theta_1; x)-\nabla_\theta L(\theta_2; x)\| &\le L_{\theta\theta}\|\theta_1-\theta_2\|,\\
\|\nabla_\theta L(\theta; x_1)-\nabla_\theta L(\theta; x_2)\| &\le L_{\theta x}\|x_1-x_2\|,\\
\|\nabla_x L(\theta; x_1)-\nabla_x L(\theta; x_2)\| &\le L_{xx}\|x_1-x_2\|.
\end{aligned} \qquad (6)$$

Assumption 2. $L(\theta, x)$ is locally $\mu$-strongly concave in $\mathcal{X}_i=\{x:\|x-x_i\|\le\epsilon\}$ for all $i\in[n]$, i.e., for any $x_1, x_2\in\mathcal{X}_i$,

$$L(\theta, x_1)\le L(\theta, x_2)+\langle\nabla_x L(\theta, x_2), x_1-x_2\rangle-\frac{\mu}{2}\|x_1-x_2\|^2. \qquad (7)$$

Assumption 2 can be justified by the relationship between robust optimization and distributionally robust optimization (Sinha et al., 2017; Lee & Raginsky, 2017).

Assumption 3. The concurrent stochastic gradient $\hat{g}(\theta_t)=\frac{1}{2|B|}\sum_{i=1}^{|B|}\big(\nabla_\theta L(\theta_t; x_i)+\nabla_\theta L(\theta_t, \hat{x}_i)\big)$ is bounded by a constant $M$:

$$\|\hat{g}(\theta_t)\|\le M. \qquad (8)$$

Assumption 4. Suppose $L_D(\theta_t)=\frac{1}{2n}\sum_{i=1}^{n}\big(L(\theta_t, x^*_i)+L(\theta_t, x_i)\big)$, $g(\theta_t)=\frac{1}{2|B|}\sum_{i=1}^{|B|}\big(\nabla_\theta L(\theta_t; x_i)+\nabla_\theta L(\theta_t, x^*_i)\big)$, and $\mathbb{E}[g(\theta_t)]=\nabla L_D(\theta_t)$, where $|B|$ is the batch size. The variance of $g(\theta_t)$ is bounded by $\sigma^2$:

$$\mathbb{E}\big[\|g(\theta_t)-\nabla L_D(\theta_t)\|^2\big]\le\sigma^2. \qquad (9)$$

Based on these assumptions, we can upper-bound the distance between the original adversarial example $x^*_i(\theta_t)$ and the concurrent adversarial example $x^*_i(\theta_{t-\tau})$, where $\tau$ is the delay.

Lemma 1. Under Assumptions 1 and 2,

$$\|x^*_i(\theta_t)-x^*_i(\theta_{t-\tau})\|\le\frac{L_{x\theta}}{\mu}\|\theta_t-\theta_{t-\tau}\|\le\frac{L_{x\theta}}{\mu}\eta\tau M. \qquad (10)$$

Lemma 1 shows that the distance between $x^*_i(\theta_t)$ and $x^*_i(\theta_{t-\tau})$ is bounded in terms of the delay $\tau$.
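The second inequality in (10) uses only the update rule and the gradient bound of Assumption 3. A short derivation, assuming a constant step size $\eta$ and writing each update as $\theta_{s+1}=\theta_s-\eta\,\hat{g}(\theta_s)$ (consistent with Eq. (4), since $\hat{g}$ in Assumption 3 already carries the factor $\frac{1}{2}$):

```latex
\begin{aligned}
\|\theta_t-\theta_{t-\tau}\|
  &= \Big\|\sum_{s=t-\tau}^{t-1}\big(\theta_{s+1}-\theta_s\big)\Big\|
   = \eta\,\Big\|\sum_{s=t-\tau}^{t-1}\hat{g}(\theta_s)\Big\| \\
  &\le \eta\sum_{s=t-\tau}^{t-1}\|\hat{g}(\theta_s)\|
   \;\le\; \eta\,\tau M .
\end{aligned}
```

Substituting this into the first inequality of (10) gives the bound $\frac{L_{x\theta}}{\mu}\eta\tau M$; with the delay $\tau=1$ used in practice, the optimal adversarial example moves by at most $\eta M L_{x\theta}/\mu$ between consecutive steps.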
When the delay is small enough, $x^*_i(\theta_{t-\tau})$ can be regarded as an approximation of $x^*_i(\theta_t)$. We now establish the convergence rate in the following theorem.

Theorem 1. Suppose Assumptions 1, 2, 3 and 4 hold. Let the loss function be $L_D(\theta_t)=\frac{1}{2n}\sum_{i=1}^{n}\big(L(\theta_t; x^*_i, y_i)+L(\theta_t; x_i, y_i)\big)$, and let $\hat{x}_i(\theta_{t-\tau})$ be a $\lambda$-solution of $x^*_i(\theta_{t-\tau})$: $\langle x^*_i(\theta_{t-\tau})-\hat{x}_i(\theta_{t-\tau}), \nabla_x L(\theta_{t-\tau}; \hat{x}_i(\theta_{t-\tau}), y_i)\rangle\le\lambda$. Consider the concurrent stochastic gradient $\hat{g}(\theta)$. If the step size of the outer minimization is set to $\eta_t=\eta=\min\big(1/L, \sqrt{\Delta/(L\sigma^2T)}\big)$, then the output of Algorithm 1 satisfies

$$\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\big[\|\nabla L_D(\theta_t)\|_2^2\big]\le 2\sigma\sqrt{\frac{L\Delta}{T}}+\frac{L_{\theta x}^2}{2}\Big(\frac{\tau M L_{x\theta}}{L\mu}+\sqrt{\frac{\lambda}{\mu}}\Big)^2, \qquad (11)$$

where $L=L_{\theta\theta}+\frac{L_{x\theta}L_{\theta x}}{\mu}$.

Our result gives a formal convergence rate for ConAdv: it converges to a first-order stationary point at a sublinear rate, up to a precision of $\frac{L_{\theta x}^2}{2}\big(\frac{\tau M L_{x\theta}}{L\mu}+\sqrt{\frac{\lambda}{\mu}}\big)^2$, which grows with the delay $\tau$. In practice, we use the smallest delay, $\tau=1$, as discussed in the previous subsection.

4 EXPERIMENTAL RESULTS

4.1 EXPERIMENTAL SETUP

Architectures and Datasets. We select ResNet as our default architecture; specifically, we use the mid-weight version (ResNet-50) to evaluate the proposed algorithm. The dataset used in this paper is ImageNet-1k, which consists of 1.28 million training images and 50k test images. The MLPerf convergence baseline for ResNet-50 is 75.9% top-1 accuracy in 90 epochs (i.e., ResNet-50 version 1.5 (Goyal et al., 2017)).

Implementation Details. We use TPU-v3 for all experiments, with the same settings as the baseline. We consider 90-epoch training for ResNet-50. For data augmentation, we mainly consider AutoAugment (AA). In addition, we use LARS (You et al., 2017) to train all models. Finally, for adversarial training, we always use a 1-step PGD attack with random initialization.
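Since all models are trained with LARS, a minimal sketch of its layer-wise update may help. Following You et al. (2017), each layer $l$ gets a local learning rate $\eta_{\text{trust}}\|w^l\|/(\|\nabla L(w^l)\|+\beta\|w^l\|)$ that multiplies the global learning rate, where $\beta$ is the weight decay. This sketch omits momentum and uses plain Python lists as layers; the function name, the trust coefficient 0.001, and the fallback when a norm is zero are illustrative assumptions rather than this paper's hyperparameters.

```python
import math


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def lars_step(layers, grads, global_lr, trust_coef=0.001, weight_decay=0.0):
    """One LARS update (no momentum). Each layer's step is scaled by a trust
    ratio, so its size depends on ||w|| rather than on the raw gradient scale."""
    updated = []
    for w, g in zip(layers, grads):
        w_norm, g_norm = l2_norm(w), l2_norm(g)
        if w_norm > 0.0 and g_norm > 0.0:
            local_lr = trust_coef * w_norm / (g_norm + weight_decay * w_norm)
        else:
            local_lr = 1.0  # assumption: fall back to the global learning rate
        updated.append([wi - global_lr * local_lr * (gi + weight_decay * wi)
                        for wi, gi in zip(w, g)])
    return updated


layers = [[3.0, 4.0]]  # one "layer" with ||w|| = 5
small = lars_step(layers, [[0.0, 5.0]], global_lr=1.0)
large = lars_step(layers, [[0.0, 500.0]], global_lr=1.0)
print(small, large)  # in both cases the step has norm about trust_coef * ||w|| = 0.005
```

Normalizing each layer's step by its own weight norm keeps layers with very different gradient scales from being over- or under-updated when the global learning rate is raised for large batches.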
Figure 4: (a) Throughput (images/ms) of DisAdv and ConAdv when scaling up the batch size from 4k to 64k for ResNet-50; (b) throughput of DisAdv, ConAdv, and the baseline at batch sizes 512, 1k, and 2k when the number of processors reaches the limit that each batch size can use for ResNet-50.

4.2 IMAGENET TRAINING WITH RESNET

We train ResNet-50 with ConAdv and compare it with vanilla training and DisAdv. The results of scaling up the batch size in Table 1 show that ConAdv obtains accuracy similar to DisAdv while speeding up training. More specifically, the top-1 accuracy of all methods is stable when the batch size increases from 4k to 32k. After that, performance starts to drop, which illustrates the bottleneck of large-batch training. However, ConAdv improves the top-1 accuracy, and, like DisAdv, the improvement remains stable once the batch size reaches the bottleneck (e.g., 32k, 64k, and 96k), whereas AutoAugment gradually reaches its limits. For instance, the top-1 accuracy increases from 74.3% to 75.3% when using ConAdv with a batch size of 96k, and the improvements are 0.7%, 0.9%, and 1.0% at 32k, 64k, and 96k, respectively. In contrast, AutoAugment cannot improve the top-1 accuracy at a batch size of 96k. These results show that adversarial learning can maintain good test performance in the large-batch setting and can outperform data augmentation.

Table 1: Top-1 accuracy (%) for ResNet-50 on ImageNet

| Method | 1k | 4k | 8k | 16k | 32k | 64k | 96k |
|---|---|---|---|---|---|---|---|
| ResNet-50 | 76.9 | 76.9 | 76.6 | 76.6 | 76.6 | 75.3 | 74.3 |
| ResNet-50+AA | 77.5 | 77.5 | 77.4 | 77.1 | 76.9 | 75.6 | 74.3 |
| ResNet-50+DisAdv | 77.4 | 77.4 | 77.4 | 77.4 | 77.3 | 76.2 | 75.3 |
| ResNet-50+ConAdv | 77.4 | 77.4 | 77.4 | 77.4 | 77.3 | 76.2 | 75.3 |

In addition, Figure 4(a) presents the throughput (images/ms) when scaling up the batch size.
ConAdv further increases throughput and accelerates training. To obtain accurate BatchNorm statistics, we make sure each worker has at least 64 examples (the normal setting). Thus, the number of cores is Batch Size / 64. For example, we use TPU v3-256 to train DisAdv when the batch size is 32k, which has 512 cores (32k/64 = 512). As shown in Figure 4(a), the throughput of DisAdv increases from 10.3 at 4k to 81.7 at 32k, and ConAdv achieves about a 1.5x speedup over DisAdv, which verifies that ConAdv maintains the accuracy of large-batch training while accelerating training. To simulate the speedup when the number of workers reaches the limit that each batch size can use, we use a large enough distributed system to train the model with batch sizes of 512, 1k, and 2k on TPU v3-128, TPU v3-256, and TPU v3-512, respectively. As shown in Figure 4(b), ConAdv achieves about a 2x speedup over DisAdv. Furthermore, in this scenario ConAdv achieves throughput similar to the baseline (vanilla ResNet-50 training). For example, with a batch size of 2k, the throughput increases from 36.1 (DisAdv) to 71.3 (ConAdv), while the baseline throughput is 75.7, showing that ConAdv runs at a speed similar to the baseline. Moreover, ConAdv can scale to larger batch sizes than the baseline. Therefore, ConAdv can further accelerate the training of deep neural networks.

4.3 IMAGENET TRAINING WITH DATA AUGMENTATION

To explore the limit of our method and evaluate whether adversarial learning can be combined with data augmentation for large-batch training, we further apply data augmentation to the proposed adversarial learning algorithm; the results are shown in Table 2. ConAdv further improves the performance of large-batch training on ImageNet when combined with AutoAugment (AA).
Under this setting, we can expand the batch size to 96k, which improves algorithmic efficiency and benefits machine utilization. For instance, for ResNet-50 at 96k, the top-1 accuracy increases from 74.3% to 76.2% when using ConAdv together with AutoAugment.

Table 2: Top-1 accuracy (%) with AutoAugment on ImageNet

| Method | 1k | 4k | 8k | 16k | 32k | 64k | 96k |
|---|---|---|---|---|---|---|---|
| ResNet-50 | 76.9 | 76.9 | 76.6 | 76.6 | 76.6 | 75.3 | 74.3 |
| ResNet-50+AA | 77.5 | 77.5 | 77.4 | 77.1 | 76.9 | 75.6 | 74.3 |
| ResNet-50+ConAdv+AA | 78.5 | 78.5 | 78.5 | 78.5 | 78.3 | 77.3 | 76.2 |

4.4 TRAINING TIME

The wall-clock training times for ResNet-50 are shown in Table 3. Training time decreases as the batch size increases. For example, the training time of ConAdv decreases from 1277s to 227s when scaling the batch size from 16k to 64k. In addition, DisAdv needs about 1.5x the training time of vanilla ResNet-50, whereas ConAdv reduces the training time of DisAdv to a level similar to vanilla ResNet-50. For instance, at 32k the training time is reduced from 1191s (DisAdv) to 677s (ConAdv). Note that we do not report the wall-clock time for vanilla ResNet-50 at 64k, since its top-1 accuracy is below the MLPerf standard of 75.9%. The number of machines required to measure the maximum speed at 96K exceeds our current resources; the comparisons at 32k and 64k still reflect the runtime improvement.

Table 3: Training time analysis (wall-clock time)

| Method | 16k | 32k | 64k |
|---|---|---|---|
| ResNet-50 | 1194s | 622s | N/A |
| DisAdv | 1657s | 1191s | 592s |
| ConAdv | 1277s | 677s | 227s |

4.5 GENERALIZATION GAP

To the best of our knowledge, theoretical analysis of generalization error in the large-batch setting is still an open problem. However, we empirically find that our method reduces the generalization gap in large-batch training. The results in Table 4 indicate that ConAdv narrows the generalization gap.
For example, the generalization gap of vanilla ResNet-50 at 96k is 4.6, and ConAdv narrows it to 2.7. In addition, when ConAdv is combined with AutoAugment, both training and test accuracy increase further while the generalization gap stays similar.

Table 4: Generalization gap of large-batch training on ImageNet-1k (accuracy in %)

| Method | Batch size | Training accuracy | Test accuracy | Generalization gap |
|---|---|---|---|---|
| Vanilla ResNet-50 | 16k | 81.4 | 76.6 | 4.8 |
| Vanilla ResNet-50 | 32k | 82.5 | 76.6 | 5.9 |
| Vanilla ResNet-50 | 64k | 79.6 | 75.3 | 4.3 |
| Vanilla ResNet-50 | 96k | 78.9 | 74.3 | 4.6 |
| ConAdv | 16k | 80.3 | 77.4 | 2.9 |
| ConAdv | 32k | 80.8 | 77.3 | 3.5 |
| ConAdv | 64k | 78.2 | 76.2 | 2.0 |
| ConAdv | 96k | 78.0 | 75.3 | 2.7 |
| ConAdv + AA | 16k | 81.6 | 78.5 | 3.1 |
| ConAdv + AA | 32k | 81.7 | 78.3 | 3.4 |
| ConAdv + AA | 64k | 79.6 | 77.3 | 2.3 |
| ConAdv + AA | 96k | 78.4 | 76.2 | 2.2 |

4.6 ANALYSIS OF ADVERSARIAL PERTURBATION

Adversarial learning computes an adversarial perturbation on the input data to smooth the decision boundary and help the model converge to flat minima. In this section, we analyze how different perturbation values affect the performance of large-batch training on ImageNet. The results are shown in Table 5. They suggest that the attack intensity should increase with the batch size. For example, the best perturbation value increases from 3 (32k) to 7 (96k) for ResNet-50 and from 8 (16k) to 12 (64k). In addition, the perturbation value should be increased when data augmentation is used. For example, at 32K the best perturbation value is 3 for the original ResNet-50 but 5 when data augmentation is applied.

Table 5: Top-1 accuracy (%) with different adversarial perturbation values p.
| Method | Batch size | p=0 | p=1 | p=2 | p=3 | p=4 | p=5 | p=6 | p=7 | p=8 | p=9 | p=10 | p=12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ResNet-50 + ConAdv | 32k | 76.8 | 77.2 | 77.3 | 77.4 | 77.3 | 77.3 | 77.3 | 77.3 | 77.3 | 77.3 | 77.2 | 77.2 |
| ResNet-50 + ConAdv + AA | 32k | 77.8 | 78.0 | 78.1 | 78.1 | 78.0 | 78.3 | 78.2 | 78.2 | 78.2 | 78.2 | 78.2 | 78.1 |
| ResNet-50 + ConAdv | 64k | 75.7 | 76.2 | 76.3 | 76.3 | 76.4 | 76.7 | 76.4 | 76.4 | 76.4 | 76.4 | 76.4 | 76.3 |
| ResNet-50 + ConAdv + AA | 64k | 76.8 | 77.0 | 76.8 | 77.0 | 77.1 | 77.1 | 77.2 | 77.4 | 77.2 | 77.1 | 77.1 | 77.1 |
| ResNet-50 + ConAdv | 96k | 74.6 | 75.1 | 75.1 | 75.1 | 75.3 | 75.1 | 75.1 | 75.1 | 75.3 | 75.2 | 75.1 | 75.1 |
| ResNet-50 + ConAdv + AA | 96k | 75.8 | 75.9 | 75.8 | 76.0 | 76.0 | 76.0 | 76.0 | 76.1 | 76.2 | 76.2 | 76.0 | 76.0 |

5 CONCLUSIONS

We first analyze the effect of data augmentation on large-batch training and propose a novel distributed adversarial learning algorithm that scales to larger batch sizes. To reduce the overhead of adversarial learning, we further propose concurrent adversarial learning (ConAdv), which decouples the two sequential gradient computations in adversarial learning. We evaluate our method on ResNet-50 trained on ImageNet. The experimental results show that it benefits large-batch training; for example, combined with AutoAug, it raises the top-1 accuracy at a batch size of 96k from 74.3% to 76.2%.

6 ACKNOWLEDGEMENTS

We thank Google TFRC for providing access to Cloud TPUs. We thank CSCS (Swiss National Supercomputing Centre) for providing access to the Piz Daint supercomputer. We thank TACC (Texas Advanced Computing Center) for providing access to the Longhorn and Frontera supercomputers. We thank LuxProvide (the Luxembourg national supercomputer HPC organization) for providing access to the MeluXina supercomputer. CJH and XC are partially supported by NSF under IIS-2008173 and IIS-2048280, and by the Army Research Laboratory under agreement number W911NF-20-2-0158.

7 ETHICS STATEMENT

We are not aware of any potential ethical issues in this paper. Our goal is to propose a novel distributed adversarial learning algorithm that accelerates large-batch training.
8 REPRODUCIBILITY STATEMENT

We list our main hyperparameters for large-batch training in Appendix A.4 (Table 6). Section 4 describes our experimental settings, including the dataset, model architecture, data augmentation and optimizer.

REFERENCES

Takuya Akiba, Shuji Suzuki, and Keisuke Fukuda. Extremely large minibatch SGD: Training ResNet-50 on ImageNet in 15 minutes. arXiv preprint arXiv:1711.04325, 2017.

Maksym Andriushchenko and Nicolas Flammarion. Understanding and improving fast adversarial training. In Advances in Neural Information Processing Systems, 2020.

Maxim Berman, Hervé Jégou, Andrea Vedaldi, Iasonas Kokkinos, and Matthijs Douze. MultiGrain: A unified image embedding for classes and instances. arXiv preprint arXiv:1902.05509, 2019.

Xiangning Chen, Cihang Xie, Mingxing Tan, Li Zhang, Cho-Jui Hsieh, and Boqing Gong. Robust and accurate object detection via adversarial learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 16622-16631, June 2021.

Minhao Cheng, Zhe Gan, Yu Cheng, Shuohang Wang, Cho-Jui Hsieh, and Jingjing Liu. Adversarial masking: Towards understanding robustness trade-off for generalization, 2021. URL https://openreview.net/forum?id=LNtTXJ9XXr.

Valeriu Codreanu, Damian Podareanu, and Vikram Saletore. Scale out for large minibatch SGD: Residual network training on ImageNet-1k with improved accuracy and reduced time to train. arXiv preprint arXiv:1711.04291, 2017.

Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning augmentation strategies from data. In IEEE Conference on Computer Vision and Pattern Recognition, 2019.

J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei. ImageNet: A large-scale hierarchical image database. In IEEE Conference on Computer Vision and Pattern Recognition, 2009.

Aditya Devarakonda, Maxim Naumov, and Michael Garland.
AdaBatch: Adaptive batch sizes for training deep neural networks. arXiv preprint arXiv:1712.02029, 2017.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Jill Burstein, Christy Doran, and Thamar Solorio (eds.), Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT, 2019.

Ian Goodfellow, Jonathon Shlens, and Christian Szegedy. Explaining and harnessing adversarial examples. In International Conference on Learning Representations, 2015.

Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch SGD: Training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In IEEE Conference on Computer Vision and Pattern Recognition, 2016.

Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: Closing the generalization gap in large batch training of neural networks. arXiv preprint arXiv:1705.08741, 2017.

Forrest N Iandola, Matthew W Moskewicz, Khalid Ashraf, and Kurt Keutzer. FireCaffe: Near-linear acceleration of deep neural network training on compute clusters. In IEEE Conference on Computer Vision and Pattern Recognition, 2016.

Xianyan Jia, Shutao Song, Wei He, Yangzihao Wang, Haidong Rong, Feihu Zhou, Liqiang Xie, Zhenyu Guo, Yuanzhou Yang, Liwei Yu, et al. Highly scalable deep learning training system with mixed-precision: Training ImageNet in four minutes. arXiv preprint arXiv:1807.11205, 2018.

Hamid Karimi, Tyler Derr, and Jiliang Tang. Characterizing the decision boundary of deep neural networks. arXiv preprint arXiv:1912.11460, 2019.
Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. In International Conference on Learning Representations, 2017.

Sameer Kumar, Yu Wang, Cliff Young, James Bradbury, Naveen Kumar, Dehao Chen, and Andy Swing. Exploring the limits of concurrency in ML training on Google TPUs. Machine Learning and Systems, 2021.

Jaeho Lee and Maxim Raginsky. Minimax statistical learning with Wasserstein distances. arXiv preprint arXiv:1705.07815, 2017.

Mu Li. Scaling distributed machine learning with system and algorithm co-design. PhD thesis, Carnegie Mellon University, 2017.

Yong Liu, Siqi Mai, Xiangning Chen, Cho-Jui Hsieh, and Yang You. Towards efficient and scalable sharpness-aware minimization. arXiv preprint arXiv:2203.02714, 2022.

Yuxuan Lou, Fuzhao Xue, Zangwei Zheng, and Yang You. Sparse-MLP: A fully-MLP architecture with conditional computation. arXiv preprint arXiv:2109.02008, 2021.

Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, and Adrian Vladu. Towards deep learning models resistant to adversarial attacks. arXiv preprint arXiv:1706.06083, 2017.

James Martens and Roger Grosse. Optimizing neural networks with Kronecker-factored approximate curvature. In International Conference on Machine Learning. PMLR, 2015.

Peter Mattson, Christine Cheng, Cody Coleman, Greg Diamos, Paulius Micikevicius, David Patterson, Hanlin Tang, Gu-Yeon Wei, Peter Bailis, Victor Bittorf, et al. MLPerf training benchmark. arXiv preprint arXiv:1910.01500, 2019.

Jieru Mei, Yucheng Han, Yutong Bai, Yixiao Zhang, Yingwei Li, Xianhang Li, Alan Yuille, and Cihang Xie. Fast AdvProp. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=hcoswsDHNAW.

Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Jonathan Uesato, and Pascal Frossard. Robustness via curvature regularization, and vice versa.
In IEEE Conference on Computer Vision and Pattern Recognition, 2019.

Kazuki Osawa, Yohei Tsuji, Yuichiro Ueno, Akira Naruse, Rio Yokota, and Satoshi Matsuoka. Second-order optimization method for large mini-batch: Training ResNet-50 on ImageNet in 35 epochs. arXiv preprint arXiv:1811.12019, 2018.

Nicolas Papernot, Patrick McDaniel, and Ian Goodfellow. Transferability in machine learning: From phenomena to black-box attacks using adversarial samples. arXiv preprint arXiv:1605.07277, 2016.

Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. OpenAI blog, 2019.

Ali Shafahi, Mahyar Najibi, Amin Ghiasi, Zheng Xu, John Dickerson, Christoph Studer, Larry S Davis, Gavin Taylor, and Tom Goldstein. Adversarial training for free! arXiv preprint arXiv:1904.12843, 2019.

Christopher J Shallue, Jaehoon Lee, Joseph Antognini, Jascha Sohl-Dickstein, Roy Frostig, and George E Dahl. Measuring the effects of data parallelism on neural network training. arXiv preprint arXiv:1811.03600, 2018.

Aman Sinha, Hongseok Namkoong, and John Duchi. Certifiable distributional robustness with principled adversarial training. arXiv preprint arXiv:1710.10571, 2017.

Samuel L Smith, Pieter-Jan Kindermans, Chris Ying, and Quoc V Le. Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489, 2017.

Yisen Wang, Xingjun Ma, James Bailey, Jinfeng Yi, Bowen Zhou, and Quanquan Gu. On the convergence and robustness of adversarial training. In International Conference on Machine Learning (ICML), 2019.

Eric Wong, Leslie Rice, and J Zico Kolter. Fast is better than free: Revisiting adversarial training. arXiv preprint arXiv:2001.03994, 2020.

Cihang Xie, Mingxing Tan, Boqing Gong, Jiang Wang, Alan L Yuille, and Quoc V Le. Adversarial examples improve image recognition. In IEEE Conference on Computer Vision and Pattern Recognition, 2020.
Fuzhao Xue, Ziji Shi, Futao Wei, Yuxuan Lou, Yong Liu, and Yang You. Go wider instead of deeper. arXiv preprint arXiv:2107.11817, 2021.

Masafumi Yamazaki, Akihiko Kasagi, Akihiro Tabuchi, Takumi Honda, Masahiro Miwa, Naoto Fukumoto, Tsuguchika Tabaru, Atsushi Ike, and Kohta Nakashima. Yet another accelerated SGD: ResNet-50 training on ImageNet in 74.7 seconds. arXiv preprint arXiv:1903.12650, 2019.

Zhewei Yao, Amir Gholami, Daiyaan Arfeen, Richard Liaw, Joseph Gonzalez, Kurt Keutzer, and Michael Mahoney. Large batch size training of neural networks with adversarial training and second-order information. arXiv preprint arXiv:1810.01021, 2018a.

Zhewei Yao, Amir Gholami, Qi Lei, Kurt Keutzer, and Michael W Mahoney. Hessian-based analysis of large batch training and robustness to adversaries. arXiv preprint arXiv:1802.08241, 2018b.

Chris Ying, Sameer Kumar, Dehao Chen, Tao Wang, and Youlong Cheng. Image classification at supercomputer scale. arXiv preprint arXiv:1811.06992, 2018.

Yang You, Igor Gitman, and Boris Ginsburg. Scaling SGD batch size to 32k for ImageNet training. arXiv preprint arXiv:1708.03888, 2017.

Yang You, Zhao Zhang, Cho-Jui Hsieh, James Demmel, and Kurt Keutzer. ImageNet training in minutes. In International Conference on Parallel Processing, 2018.

Yang You, Jonathan Hseu, Chris Ying, James Demmel, Kurt Keutzer, and Cho-Jui Hsieh. Large-batch training for LSTM and beyond. In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, 2019.

Dinghuai Zhang, Tianyuan Zhang, Yiping Lu, Zhanxing Zhu, and Bin Dong. You only propagate once: Accelerating adversarial training via maximal principle. arXiv preprint arXiv:1905.00877, 2019.

A.1 THE PROOF OF LEMMA 1

The proof is inspired by Sinha et al. (2017); Wang et al. (2019).
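Lemma 1. Under Assumptions 1 and 2, we have

$$\|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\| \le \frac{L_{x\theta}}{\mu}\|\theta_t - \theta_{t-\tau}\| \le \frac{L_{x\theta}}{\mu}\eta\tau M,$$

where $x_i^*(\theta_t)$ and $x_i^*(\theta_{t-\tau})$ denote the adversarial examples of $x_i$ computed with the current weight $\theta_t$ and the stale weight $\theta_{t-\tau}$, respectively.

Proof. According to Assumption 2, we have

$$\begin{aligned} L(\theta_t, x_i^*(\theta_{t-\tau})) &\le L(\theta_t, x_i^*(\theta_t)) + \langle \nabla_x L(\theta_t, x_i^*(\theta_t)), x_i^*(\theta_{t-\tau}) - x_i^*(\theta_t)\rangle - \frac{\mu}{2}\|x_i^*(\theta_{t-\tau}) - x_i^*(\theta_t)\|^2 \\ &\le L(\theta_t, x_i^*(\theta_t)) - \frac{\mu}{2}\|x_i^*(\theta_{t-\tau}) - x_i^*(\theta_t)\|^2, \end{aligned} \tag{13}$$

where the second inequality holds because $x_i^*(\theta_t)$ maximizes $L(\theta_t, \cdot)$, so $\langle \nabla_x L(\theta_t, x_i^*(\theta_t)), x - x_i^*(\theta_t)\rangle \le 0$ for any feasible $x$. In addition, we have

$$L(\theta_t, x_i^*(\theta_t)) \le L(\theta_t, x_i^*(\theta_{t-\tau})) + \langle \nabla_x L(\theta_t, x_i^*(\theta_{t-\tau})), x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\rangle - \frac{\mu}{2}\|x_i^*(\theta_{t-\tau}) - x_i^*(\theta_t)\|^2. \tag{14}$$

Combining (13) and (14), we obtain

$$\begin{aligned} \mu\|x_i^*(\theta_{t-\tau}) - x_i^*(\theta_t)\|^2 &\le \langle \nabla_x L(\theta_t, x_i^*(\theta_{t-\tau})), x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\rangle \\ &\le \langle \nabla_x L(\theta_t, x_i^*(\theta_{t-\tau})) - \nabla_x L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau})), x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\rangle \\ &\le \|\nabla_x L(\theta_t, x_i^*(\theta_{t-\tau})) - \nabla_x L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau}))\|\,\|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\| \\ &\le L_{x\theta}\|\theta_t - \theta_{t-\tau}\|\,\|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\|, \end{aligned} \tag{15}$$

where:

- the second inequality is due to $\langle \nabla_x L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau})), x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\rangle \le 0$;
- the third inequality follows from the Cauchy-Schwarz inequality;
- the last inequality follows from Assumption 1.

Dividing both sides of (15) by $\|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\|$ (the claim is trivial when it is zero) gives

$$\begin{aligned} \|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\| &\le \frac{L_{x\theta}}{\mu}\|\theta_t - \theta_{t-\tau}\| = \frac{L_{x\theta}}{\mu}\Big\|\sum_{j\in[1,\tau]}(\theta_{t-j+1} - \theta_{t-j})\Big\| \\ &= \frac{L_{x\theta}}{\mu}\Big\|\sum_{j\in[1,\tau]}\eta\hat g_{t-j}(x_i)\Big\| \le \frac{L_{x\theta}}{\mu}\sum_{j\in[1,\tau]}\eta\|\hat g_{t-j}(x_i)\| \le \frac{L_{x\theta}}{\mu}\eta\tau M, \end{aligned}$$

where:

- the first equality writes the delayed weight as a telescoping sum;
- the second equality holds because each weight difference is an update with the gradient $\hat g_{t-j}$ ($j\in[1,\tau]$);
- the last two inequalities follow from the triangle inequality and Assumption 3.

This completes the proof.

Lemma 2. Under Assumptions 1 and 2, $L_D(\theta)$ is $L$-smooth with $L = L_{\theta\theta} + \frac{L_{x\theta}}{2\mu}L_{\theta x}$. That is, for any $\theta_t$ and $\theta_{t-\tau}$,

$$\|\nabla_\theta L_D(\theta_t) - \nabla_\theta L_D(\theta_{t-\tau})\| \le L\|\theta_t - \theta_{t-\tau}\|, \tag{18}$$

$$L_D(\theta_t) \le L_D(\theta_{t-\tau}) + \langle \nabla L_D(\theta_{t-\tau}), \theta_t - \theta_{t-\tau}\rangle + \frac{L}{2}\|\theta_t - \theta_{t-\tau}\|^2. \tag{19}$$

Proof. Based on Lemma 1, we obtain

$$\|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\| \le \frac{L_{x\theta}}{\mu}\|\theta_t - \theta_{t-\tau}\|. \tag{20}$$

For $i\in[n]$, we have

$$\begin{aligned} &\|\nabla_\theta L(\theta_t, x_i^*(\theta_t)) - \nabla_\theta L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau}))\| \\ &\le \|\nabla_\theta L(\theta_t, x_i^*(\theta_t)) - \nabla_\theta L(\theta_t, x_i^*(\theta_{t-\tau}))\| + \|\nabla_\theta L(\theta_t, x_i^*(\theta_{t-\tau})) - \nabla_\theta L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau}))\| \\ &\le L_{\theta\theta}\|\theta_t - \theta_{t-\tau}\| + L_{\theta x}\|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\| \\ &\le L_{\theta\theta}\|\theta_t - \theta_{t-\tau}\| + L_{\theta x}\frac{L_{x\theta}}{\mu}\|\theta_t - \theta_{t-\tau}\| \\ &= \Big(L_{\theta\theta} + \frac{L_{\theta x}L_{x\theta}}{\mu}\Big)\|\theta_t - \theta_{t-\tau}\|, \end{aligned}$$

where:

- the first inequality is the triangle inequality;
- the second inequality holds because of Assumption 1;
- the third inequality follows from Lemma 1.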
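Hence,

$$\begin{aligned} \|\nabla L_D(\theta_t) - \nabla L_D(\theta_{t-\tau})\| &= \Big\|\frac{1}{2n}\sum_{i=1}^n\big(\nabla_\theta L(\theta_t, x_i) + \nabla_\theta L(\theta_t, x_i^*(\theta_t))\big) - \frac{1}{2n}\sum_{i=1}^n\big(\nabla_\theta L(\theta_{t-\tau}, x_i) + \nabla_\theta L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau}))\big)\Big\| \\ &\le \frac{1}{2n}\sum_{i=1}^n\|\nabla_\theta L(\theta_t, x_i) - \nabla_\theta L(\theta_{t-\tau}, x_i)\| + \frac{1}{2n}\sum_{i=1}^n\|\nabla_\theta L(\theta_t, x_i^*(\theta_t)) - \nabla_\theta L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau}))\| \\ &\le \frac{1}{2}L_{\theta\theta}\|\theta_t - \theta_{t-\tau}\| + \frac{1}{2}\Big(L_{\theta\theta} + \frac{L_{\theta x}L_{x\theta}}{\mu}\Big)\|\theta_t - \theta_{t-\tau}\| \\ &= \Big(L_{\theta\theta} + \frac{L_{x\theta}}{2\mu}L_{\theta x}\Big)\|\theta_t - \theta_{t-\tau}\|. \end{aligned}$$

This completes the proof.

Lemma 3. Let $\hat x_i(\theta_t)$ be a $\lambda$-solution of $x_i^*(\theta_t)$, i.e., $\langle x_i^*(\theta_t) - \hat x_i(\theta_t), \nabla_x L(\theta_t, \hat x_i(\theta_t))\rangle \le \lambda$. Under Assumptions 1 and 2, the concurrent stochastic gradient $\hat g(\theta)$ satisfies

$$\|g(\theta_t) - \hat g(\theta_{t-\tau})\| \le L_{\theta x}\Big(\frac{\eta\tau M L_{x\theta}}{\mu} + \sqrt{\frac{\lambda}{\mu}}\Big).$$

Proof. We have

$$\begin{aligned} \|g(\theta_t) - \hat g(\theta_{t-\tau})\| &= \Big\|\frac{1}{|B|}\sum_{i\in B}\Big[\big(\nabla_\theta L(\theta_t, x_i) + \nabla_\theta L(\theta_t, x_i^*(\theta_t))\big) - \big(\nabla_\theta L(\theta_t, x_i) + \nabla_\theta L(\theta_t, \hat x_i(\theta_{t-\tau}))\big)\Big]\Big\| \\ &= \Big\|\frac{1}{|B|}\sum_{i\in B}\big(\nabla_\theta L(\theta_t, x_i^*(\theta_t)) - \nabla_\theta L(\theta_t, \hat x_i(\theta_{t-\tau}))\big)\Big\| \\ &\le \frac{1}{|B|}\sum_{i\in B}\|\nabla_\theta L(\theta_t, x_i^*(\theta_t)) - \nabla_\theta L(\theta_t, \hat x_i(\theta_{t-\tau}))\| \\ &\le \frac{1}{|B|}\sum_{i\in B}L_{\theta x}\|x_i^*(\theta_t) - \hat x_i(\theta_{t-\tau})\| \\ &\le \frac{1}{|B|}\sum_{i\in B}\big(L_{\theta x}\|x_i^*(\theta_t) - x_i^*(\theta_{t-\tau})\| + L_{\theta x}\|x_i^*(\theta_{t-\tau}) - \hat x_i(\theta_{t-\tau})\|\big) \\ &\le \frac{1}{|B|}\sum_{i\in B}\Big(L_{\theta x}\frac{\eta\tau M L_{x\theta}}{\mu} + L_{\theta x}\|x_i^*(\theta_{t-\tau}) - \hat x_i(\theta_{t-\tau})\|\Big), \end{aligned}$$

where the last inequality follows from Lemma 1. Since $\hat x_i(\theta_{t-\tau})$ is a $\lambda$-approximate solution of $x_i^*(\theta_{t-\tau})$, we obtain

$$\langle x_i^*(\theta_{t-\tau}) - \hat x_i(\theta_{t-\tau}), \nabla_x L(\theta_{t-\tau}, \hat x_i(\theta_{t-\tau}))\rangle \le \lambda. \tag{25}$$

In addition, by the optimality of $x_i^*(\theta_{t-\tau})$, we obtain

$$\langle \hat x_i(\theta_{t-\tau}) - x_i^*(\theta_{t-\tau}), \nabla_x L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau}))\rangle \le 0. \tag{26}$$

Combining (25) and (26), we have

$$\langle x_i^*(\theta_{t-\tau}) - \hat x_i(\theta_{t-\tau}), \nabla_x L(\theta_{t-\tau}, \hat x_i(\theta_{t-\tau})) - \nabla_x L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau}))\rangle \le \lambda. \tag{27}$$

Based on Assumption 2, we have

$$\mu\|x_i^*(\theta_{t-\tau}) - \hat x_i(\theta_{t-\tau})\|^2 \le \langle \nabla_x L(\theta_{t-\tau}, x_i^*(\theta_{t-\tau})) - \nabla_x L(\theta_{t-\tau}, \hat x_i(\theta_{t-\tau})), \hat x_i(\theta_{t-\tau}) - x_i^*(\theta_{t-\tau})\rangle. \tag{28}$$

Combining (28) with (27), we obtain

$$\mu\|x_i^*(\theta_{t-\tau}) - \hat x_i(\theta_{t-\tau})\|^2 \le \lambda. \tag{29}$$

Therefore, $\|x_i^*(\theta_{t-\tau}) - \hat x_i(\theta_{t-\tau})\| \le \sqrt{\lambda/\mu}$, and thus

$$\|g(\theta_t) - \hat g(\theta_{t-\tau})\| \le L_{\theta x}\Big(\frac{\eta\tau M L_{x\theta}}{\mu} + \sqrt{\frac{\lambda}{\mu}}\Big).$$

This completes the proof.

A.3 THE PROOF OF THEOREM 1

Theorem 1. Suppose Assumptions 1, 2, 3 and 4 hold. Let $\Delta = L_D(\theta_0) - \min_\theta L_D(\theta)$ and $L_D(\theta) = \frac{1}{2n}\sum_{i=1}^n\big(L(\theta, x_i, y_i) + L(\theta, x_i^*, y_i)\big)$. Set the step size of the outer minimization to $\eta_t = \eta = \min\big(1/L, \sqrt{\Delta/(L\sigma^2 T)}\big)$, where $L = L_{\theta\theta} + \frac{L_{x\theta}}{2\mu}L_{\theta x}$. Then the output of Algorithm 1 satisfies

$$\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\big[\|\nabla L_D(\theta_t)\|^2\big] \le 2\sigma\sqrt{\frac{L\Delta}{T}} + \frac{L_{\theta x}^2}{2}\Big(\frac{\tau M L_{x\theta}}{L\mu} + \sqrt{\frac{\lambda}{\mu}}\Big)^2.$$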
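To make the roles of the staleness $\tau$ and the inner-solution accuracy $\lambda$ in Theorem 1 concrete, the following Python sketch evaluates the step size and the right-hand side of the bound, exactly as stated above, for user-supplied constants. The function name `theorem1_bound` and all numeric values are hypothetical. The constants of Assumptions 1-4 are not known for a real network, so the output only illustrates how each term scales.

```python
import math


def theorem1_bound(L_tt, L_tx, L_xt, mu, M, tau, lam, sigma, delta, T):
    """Evaluate the step size and right-hand side of Theorem 1 as stated.

    L_tt, L_tx, L_xt: smoothness constants L_{theta theta}, L_{theta x}, L_{x theta}
    mu: strong-concavity constant of the inner maximization
    M: bound on the stochastic gradient norm, tau: staleness (delay in steps)
    lam: lambda-accuracy of the inner solution, sigma: gradient-noise level
    delta: Delta = L_D(theta_0) - min L_D, T: number of outer iterations
    """
    L = L_tt + L_xt * L_tx / (2 * mu)  # smoothness constant of L_D (Lemma 2)
    eta = min(1 / L, math.sqrt(delta / (L * sigma**2 * T)))  # step size in Theorem 1
    optimization_term = 2 * sigma * math.sqrt(L * delta / T)
    # Error introduced by stale adversarial examples and inexact inner maximization.
    staleness_term = (L_tx**2 / 2) * (tau * M * L_xt / (L * mu) + math.sqrt(lam / mu)) ** 2
    return eta, optimization_term + staleness_term


if __name__ == "__main__":
    for tau in (0, 1, 2):
        eta, bound = theorem1_bound(1, 1, 1, 1, 1, tau, 0.0, 1, 1, 100)
        print(f"tau={tau}: eta={eta:.4f}, bound={bound:.4f}")
```

For example, suppose all smoothness constants, $\mu$, $M$, $\sigma$ and $\Delta$ are 1, with $\lambda = 0$ and $T = 100$. Then $L = 1.5$, and moving from $\tau = 0$ to $\tau = 1$ adds $\frac{1}{2}(1/1.5)^2 \approx 0.22$ to the bound. The first term shrinks as $1/\sqrt{T}$, but the staleness term does not shrink with $T$.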
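Proof. By the $L$-smoothness of $L_D$ (Lemma 2) and the update $\theta_{t+1} = \theta_t - \eta\hat g(\theta_t)$, we have

$$\begin{aligned} L_D(\theta_{t+1}) &\le L_D(\theta_t) + \langle \nabla L_D(\theta_t), \theta_{t+1} - \theta_t\rangle + \frac{L}{2}\|\theta_{t+1} - \theta_t\|^2 \\ &= L_D(\theta_t) - \eta\|\nabla L_D(\theta_t)\|^2 + \frac{L\eta^2}{2}\|\hat g(\theta_t)\|^2 + \eta\langle \nabla L_D(\theta_t), \nabla L_D(\theta_t) - \hat g(\theta_t)\rangle \\ &= L_D(\theta_t) - \eta\Big(1 - \frac{L\eta}{2}\Big)\|\nabla L_D(\theta_t)\|^2 + \eta(1 - L\eta)\langle \nabla L_D(\theta_t), \nabla L_D(\theta_t) - \hat g(\theta_t)\rangle + \frac{L\eta^2}{2}\|\hat g(\theta_t) - \nabla L_D(\theta_t)\|^2 \\ &= L_D(\theta_t) - \eta\Big(1 - \frac{L\eta}{2}\Big)\|\nabla L_D(\theta_t)\|^2 + \eta(1 - L\eta)\langle \nabla L_D(\theta_t), g(\theta_t) - \hat g(\theta_t)\rangle + \eta(1 - L\eta)\langle \nabla L_D(\theta_t), \nabla L_D(\theta_t) - g(\theta_t)\rangle \\ &\quad + \frac{L\eta^2}{2}\|\hat g(\theta_t) - g(\theta_t) + g(\theta_t) - \nabla L_D(\theta_t)\|^2 \\ &\le L_D(\theta_t) - \frac{\eta}{2}\|\nabla L_D(\theta_t)\|^2 + \frac{\eta}{2}(1 - L\eta)\|\hat g(\theta_t) - g(\theta_t)\|^2 + \eta(1 - L\eta)\langle \nabla L_D(\theta_t), \nabla L_D(\theta_t) - g(\theta_t)\rangle \\ &\quad + L\eta^2\big(\|\hat g(\theta_t) - g(\theta_t)\|^2 + \|g(\theta_t) - \nabla L_D(\theta_t)\|^2\big) \\ &= L_D(\theta_t) - \frac{\eta}{2}\|\nabla L_D(\theta_t)\|^2 + \frac{\eta}{2}(1 + L\eta)\|\hat g(\theta_t) - g(\theta_t)\|^2 + \eta(1 - L\eta)\langle \nabla L_D(\theta_t), \nabla L_D(\theta_t) - g(\theta_t)\rangle + L\eta^2\|g(\theta_t) - \nabla L_D(\theta_t)\|^2, \end{aligned} \tag{33}$$

where the inequality uses two facts:

- Young's inequality $(1 - L\eta)\langle a, b\rangle \le \frac{1}{2}(1 - L\eta)(\|a\|^2 + \|b\|^2)$, which is valid since $L\eta \le 1$;
- the bound $\|a + b\|^2 \le 2\|a\|^2 + 2\|b\|^2$.

Taking expectations on both sides of (33) conditioned on $\theta_t$, we obtain

$$\begin{aligned} \mathbb{E}[L_D(\theta_{t+1}) - L_D(\theta_t) \mid \theta_t] &\le -\frac{\eta}{2}\|\nabla L_D(\theta_t)\|^2 + \frac{\eta}{2}(1 + L\eta)\,\mathbb{E}\big[\|\hat g(\theta_t) - g(\theta_t)\|^2 \mid \theta_t\big] + L\eta^2\sigma^2 \\ &\le -\frac{\eta}{2}\|\nabla L_D(\theta_t)\|^2 + \frac{\eta L_{\theta x}^2}{8}(1 + L\eta)\Big(\frac{\eta\tau M L_{x\theta}}{\mu} + \sqrt{\frac{\lambda}{\mu}}\Big)^2 + L\eta^2\sigma^2, \end{aligned} \tag{34}$$

where we used the following:

- $\mathbb{E}[g(\theta_t)] = \nabla L_D(\theta_t)$;
- the variance bound $\mathbb{E}[\|g(\theta_t) - \nabla L_D(\theta_t)\|^2 \mid \theta_t] \le \sigma^2$;
- Assumption 2, Lemma 2 and Lemma 3.

The constant $\frac{1}{8}$ arises because $g$ is normalized by $\frac{1}{2|B|}$, matching the factor $\frac{1}{2n}$ in $L_D$, so that $g$ is an unbiased estimate of $\nabla L_D$. Hence $\|\hat g(\theta_t) - g(\theta_t)\|$ is at most half the bound of Lemma 3.

Summing (34) over $t = 0, \dots, T-1$ and taking total expectation, we obtain

$$\sum_{t=0}^{T-1}\frac{\eta}{2}\mathbb{E}\big[\|\nabla L_D(\theta_t)\|^2\big] \le \mathbb{E}[L_D(\theta_0) - L_D(\theta_T)] + \frac{T\eta L_{\theta x}^2}{8}(1 + L\eta)\Big(\frac{\eta\tau M L_{x\theta}}{\mu} + \sqrt{\frac{\lambda}{\mu}}\Big)^2 + TL\eta^2\sigma^2.$$

Choose $\eta = \min\big(1/L, \sqrt{\Delta/(TL\sigma^2)}\big)$, where $\Delta = L_D(\theta_0) - \min_\theta L_D(\theta) \ge \mathbb{E}[L_D(\theta_0) - L_D(\theta_T)]$ and $L = L_{\theta\theta} + \frac{L_{x\theta}}{2\mu}L_{\theta x}$. We can then show that

$$\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\big[\|\nabla L_D(\theta_t)\|^2\big] \le 2\sigma\sqrt{\frac{L\Delta}{T}} + \frac{L_{\theta x}^2}{2}\Big(\frac{\tau M L_{x\theta}}{L\mu} + \sqrt{\frac{\lambda}{\mu}}\Big)^2.$$

A.4 HYPERPARAMETERS

Our main hyperparameters are listed in Table 6.

Table 6: Hyperparameters of ResNet-50 on ImageNet.

| Hyperparameter | 32k | 64k | 96k |
|---|---|---|---|
| Peak LR | 35.0 | 41.0 | 43.0 |
| Epochs | 90 | 90 | 90 |
| Weight decay | 5E-4 | 5E-4 | 5E-4 |
| Warmup | 40 | 41 | 41 |
| LR decay | POLY | POLY | POLY |
| Optimizer | LARS | LARS | LARS |
| Momentum | 0.9 | 0.9 | 0.9 |
| Label smoothing | 0.1 | 0.1 | 0.1 |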
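Table 6 gives only the peak learning rate, the warmup length and the decay type (POLY). A schedule consistent with those entries can be sketched as below. Three details are our assumptions, not settings stated in the paper:

- the warmup is linear from 0;
- the warmup value is measured in epochs;
- the polynomial decay has power 2 and reaches 0 at epoch 90.

The layer-wise scaling that LARS applies on top of this global learning rate is not modeled.

```python
def poly_lr(epoch: float, peak_lr: float, warmup_epochs: float,
            total_epochs: float = 90.0, power: float = 2.0) -> float:
    """Hypothetical warmup + polynomial-decay learning-rate schedule.

    ASSUMPTIONS (not given in Table 6): linear warmup from 0, measured in epochs,
    followed by polynomial decay with exponent `power` that reaches 0 at `total_epochs`.
    """
    if epoch < warmup_epochs:
        # Linear warmup from 0 to the peak learning rate.
        return peak_lr * epoch / warmup_epochs
    # Fraction of the post-warmup phase already completed, clipped at 1.
    progress = min((epoch - warmup_epochs) / (total_epochs - warmup_epochs), 1.0)
    return peak_lr * (1.0 - progress) ** power


# (peak LR, warmup) per batch size, taken from Table 6.
TABLE6 = {"32k": (35.0, 40), "64k": (41.0, 41), "96k": (43.0, 41)}

if __name__ == "__main__":
    for batch, (peak, warmup) in TABLE6.items():
        samples = [round(poly_lr(e, peak, warmup), 2) for e in (0, 20, warmup, 65, 90)]
        print(batch, samples)
```

With these assumptions, the 32k setting ramps up to 35.0 at epoch 40, falls to 8.75 at epoch 65 and reaches 0 at epoch 90.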