# Unit Scaling: Out-of-the-Box Low-Precision Training

Charlie Blake, Douglas Orr, Carlo Luschi (Graphcore Research, United Kingdom)

## 1. Abstract

We present unit scaling, a paradigm for designing deep learning models that simplifies the use of low-precision number formats. Training in FP16 or the recently proposed FP8 formats offers substantial efficiency gains, but can lack sufficient range for out-of-the-box training. Unit scaling addresses this by introducing a principled approach to model numerics: seeking unit variance of all weights, activations and gradients at initialisation. Unlike alternative methods, this approach neither requires multiple training runs to find a suitable scale nor has significant computational overhead. We demonstrate the efficacy of unit scaling across a range of models and optimisers. We further show that existing models can be adapted to be unit-scaled, training BERTLARGE in FP16 and then FP8 with no degradation in accuracy.

## 2. Introduction

The development of algorithms that efficiently leverage available hardware has been key to the substantial advances seen in deep learning over the last decade (Sutton, 2019; Hooker, 2021). With the increase in size of state-of-the-art models, hardware efficiency is also motivated by the need to lower the costs of training. These have grown to become substantial in terms of money, time, and environmental impact (Strubell et al., 2019; Chowdhery et al., 2022; Luccioni et al., 2022). However, with the end of Moore's law and Dennard scaling (Esmaeilzadeh et al., 2011; Theis and Wong, 2017), increased transistor density can no longer be relied upon to provide a simple path towards greater efficiency, and other techniques must be leveraged.

One such technique is the use of low-precision number formats. The gains to be had here are considerable: compute, memory and bandwidth usage all depend on the bit-width of a format.

Figure 1. Above: Unit scaling of an FFN layer. We multiply each tensor by a fixed scalar to achieve consistent scale, no longer requiring a loss scale to control the scale of x4. Hyperparameters here are the same as those in our BERTLARGE experiments (Table A.5). Below: A histogram of exponent values at initialisation for the above FFN, with shade indicating bin density. The y-axis reflects the exponent values available in FP16, while dashed lines show the max/min exponents of the FP8 E4 format of Noune et al. (2022).

Unlike inference, where integer quantisation is possible (Jacob et al., 2018), for training, floating-point formats are required (Noune et al., 2022; Micikevicius et al., 2022; Kuzmin et al., 2022). The traditional approach of using 32-bit floats is being superseded by mixed precision strategies, which place many values into 16-bit formats (Micikevicius et al., 2018). Furthermore, 8-bit floating-point hardware is becoming available (Graphcore, 2022; Nvidia, 2022), with the potential for accurate 8-bit training already demonstrated (Wang et al., 2018; Sun et al., 2019; Noune et al., 2022; Micikevicius et al., 2022).
However, the use of low-precision formats introduces new difficulties, reducing the absolute range of representable values and increasing quantisation noise. Existing techniques to address these issues either introduce additional overhead or require manual tuning. An approach is needed which is both accurate and places minimal burden on the user.

Figure 2. The signal-to-noise ratio (SNR) of samples from a normal distribution, quantised in FP16 and FP8 (E4 and E5), as a function of the distribution's scale σ.

To this end, we present unit scaling: a technique for model design that operates on the principle of ideal scaling at initialisation (unit variance for activations, weights and gradients). This is achieved by considering how each operation in the model affects the variance of different tensors, and introducing fixed scaling factors to counteract changes. Empirically, we show that unit scaling aligns values much closer to the centre of the representable range than conventional loss scaling (Micikevicius et al., 2018), and removes the need for a scaling hyperparameter to be swept. None of our experiments require dynamic re-scaling of values, indicating robustness to shifting distributions during training.

### 2.1. Contributions

In this paper we make the following contributions:

1. We provide an analysis of how scale changes as a result of operations within a typical model, and the challenges this introduces for low-precision training.
2. We present unit scaling: a method for combating changes in scale, along with an implementation recipe and code examples.
3. We validate unit scaling empirically across a range of models and optimisers.
4. For the first time, we show training of BERTBASE and BERTLARGE in FP16 without loss scaling. We then go a step further, training successfully in FP8, still without degradation.

We emphasise that our method works out-of-the-box, with no extra sweeps or hyperparameters, demonstrating the effectiveness of unit scaling for simplifying the use of low-precision formats.

## 3. Background

### 3.1. Floating-point formats for deep learning

**Definition** The conventional representation used for floating-point numbers is defined by the IEEE 754 standard (IEEE, 2019). In this standard, a binary floating-point format can be defined by specifying the number of exponent bits, $E$, and the number of mantissa bits, $M$. A value within such a format is defined by a sign bit, exponent and mantissa value. Each is represented using a bit-string of the requisite length (with values $b_{\text{sign}}$, $b_{\text{exp}}$, $b_{\text{mant}}$ respectively), which are interpreted as follows:

$$
\begin{aligned}
\text{exponent} &= b_{\text{exp}} - \text{bias}, \qquad \text{bias} = 2^{E-1} - 1,\\
\text{mantissa} &= 1 + b_{\text{mant}} \cdot 2^{-M},\\
\text{value} &= (-1)^{b_{\text{sign}}} \cdot 2^{\text{exponent}} \cdot \text{mantissa}.
\end{aligned}
$$

There are also a small number of special values which denote bit-strings to which the above interpretation does not apply. These represent infinities, NaN (not-a-number) and a range of subnormal numbers which allow for the representation of even smaller (absolute) values.

Common floating-point formats used in machine learning that implement the IEEE 754 standard are shown in Table A.1. The term low precision typically refers to all formats requiring fewer than 32 bits. More recently, two kinds of FP8 format have been proposed, which we term E4 and E5, i.e. $(E, M) = (4, 3)$ or $(5, 2)$. These are similar to the IEEE 754 standard, but contain differences, especially for the representation of special values. These formats are covered in detail in Appendix B.
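The interpretation above can be written out directly in a few lines of code. The helper below is an illustrative sketch only (it covers normal values, ignoring the special values and subnormals discussed above):

```python
def decode_float(bits: str, E: int, M: int) -> float:
    """Decode a sign/exponent/mantissa bit-string of a normal value (no specials or subnormals)."""
    assert len(bits) == 1 + E + M
    b_sign = int(bits[0], 2)
    b_exp = int(bits[1:1 + E], 2)
    b_mant = int(bits[1 + E:], 2)
    bias = 2 ** (E - 1) - 1
    exponent = b_exp - bias
    mantissa = 1 + b_mant * 2 ** -M
    return (-1) ** b_sign * 2 ** exponent * mantissa

# FP16 has (E, M) = (5, 10): the pattern 0 01111 0000000000 encodes 1.0
print(decode_float("0011110000000000", E=5, M=10))  # 1.0
```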
**Quantisation error** Formats with more exponent bits are able to represent a wider range of values, whereas those with more mantissa bits have smaller gaps between represented values. This trade-off between range and precision can be framed in terms of quantisation error. This consists of two terms: the loss of accuracy due to values lying outside the absolute range of a format (overflow or underflow) is termed the clipping error (or saturation error), whereas the loss of accuracy due to values lying between representable numbers is termed the rounding error.

We demonstrate the effect quantisation error has for different formats in Figure 2. This shows the signal-to-noise ratio (SNR) of normally distributed values $X \sim \mathcal{N}(0, \sigma^2)$ quantised in FP16 and FP8 as $\sigma$ varies. SNR measures the faithful reproduction of an input (signal) versus the error (noise) introduced, defined as $\mathbb{E}[X^2]\,/\,\mathbb{E}[(q(X) - X)^2]$, where $q(\cdot)$ is the quantisation function mapping an input to the nearest representable value. The heights of the SNR curves reflect the level of rounding error incurred by each format, and the widths reflect the range in which they are free of clipping error. With the exception of subnormal numbers (which slope away on the left-hand side), the height of each format's SNR curve is roughly constant. This reflects the fact that exponents are evenly distributed, giving a relative rounding error that is approximately uniform.

Table 1. A comparison of techniques for low-precision training. (✓) indicates that the method ideally requires no tuning, but in practice may introduce hyperparameters that need to be swept.

| Method | Fine-grained scaling | No tuning required | Adapts during training |
|---|---|---|---|
| Loss scaling | ✗ | ✗ | ✗ |
| Automatic loss scaling | ✗ | (✓) | ✓ |
| Automatic per-tensor scaling | ✓ | (✓) | ✓ |
| Unit scaling | ✓ | ✓ | ✗ |

### 3.2. Trade-offs of low-precision training

**Drawbacks** The two common 16-bit formats, FP16 and BFLOAT16, offer different trade-offs: FP16 has more precision, but BFLOAT16 has more range. As a result FP16 is more prone to clipping error, requiring careful scaling, and BFLOAT16 suffers more from rounding error, which in some cases can degrade model accuracy (e.g. Rae et al., 2021). For FP8 there is a reduction in both range and precision. For range, the same techniques used to train in FP16 are required, and for precision, the use of FP8 has thus far been restricted to only the inputs of matmul (matrix multiply) operations (Sun et al., 2019; Noune et al., 2022; Micikevicius et al., 2022), with 3 mantissa bits typically required for weights and activations, and 2 mantissa bits for gradients.

**Benefits** The potential efficiency gains when using low-precision formats are substantial. These include memory usage (often a limiting factor for large models), bandwidth usage (the main overhead for low-arithmetic-intensity ops), compute (the main overhead for high-arithmetic-intensity ops) and cross-device communication (a substantial overhead for distributed training).

### 3.3. Low-precision training techniques

Here we analyse existing techniques for addressing the challenges of low-precision training. Table 1 provides a summary of their trade-offs and a comparison with unit scaling.

**Mixed precision** Mixed precision is the use of multiple number formats with different bit-widths. This differs from the traditional approach of placing all values in FP32, with Micikevicius et al. (2018) showing that most activations, weights and gradients (collectively, tensors) can be put in FP16 with no loss in accuracy, with the exception of master weights that are often kept in FP32.
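In PyTorch, for example, this pattern is commonly realised with an autocast region plus dynamic loss scaling, keeping the optimiser's parameters in FP32. The sketch below is illustrative only (it assumes a CUDA device, and is not the configuration used for this paper's experiments):

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()        # parameters ("master weights") remain FP32
optimiser = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()              # dynamic loss scaling

for _ in range(10):
    x = torch.randn(64, 1024, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):   # most ops run in FP16
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()   # multiply the loss by the current scale
    scaler.step(optimiser)          # unscale gradients; skip the step if they overflowed
    scaler.update()                 # adjust the scale for the next step
    optimiser.zero_grad()
```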
Mixed precision training is also possible in BFLOAT16 (Kalamkar et al., 2019). By training in FP8 we mean that matmuls are performed in FP8 (inputs are cast down to FP8, with outputs in higher precision) with wider formats typically used elsewhere, following the lead of Sun et al. (2019), Noune et al. (2022) and Micikevicius et al. (2022). FP8 reduces both precision and range, and has not generally been used for other operations, as matmuls benefit most from using low-precision formats. Mixed precision training is complementary to unit scaling: all of our experiments use some form of mixed precision.

**Loss scaling** Reduced range in FP16 and FP8 is particularly challenging for the backward pass, where standard model-design practices lead to gradients that risk underflow. To combat this, Micikevicius et al. (2018) observe that the loss can be multiplied by a scalar to increase the scale of gradients, with weight gradients then divided by the same scalar in the optimiser. This is valid due to the linearity of the backward pass implicit in the chain rule. Loss scaling is often essential to accurate mixed precision training in FP16 and FP8. However, there is no theoretical motivation for the choice of loss scale, which instead must be found empirically. This comes with a number of downsides. Firstly, a hyperparameter sweep must be conducted to find the loss scale value. This can require multiple full runs, as insufficient loss scales may only become apparent later in training. Secondly, it's not clear ahead of time what changes require the loss scale to be re-swept. Thirdly, as loss scaling only applies a single, global scaling factor, it has no mechanism to combat differences in scale between gradient tensors. For some models this difference may be too large for effective training.

**Automatic loss scaling** The dynamic adjustment of the loss scale during training is termed automatic loss scaling (Kuchaiev et al., 2018). This can remove the need to sweep the initial loss scale, and combats shifts in tensor distributions during training. The combination of automatic loss scaling and automatic selection of number formats is termed automatic mixed precision (PyTorch, 2023). Unit scaling doesn't specify tensor formats, so it can be used in systems that automate their selection.

**Per-tensor scaling** To address the inherent scaling difficulties of FP8 training, Micikevicius et al. (2022) propose a per-tensor scaling system, re-scaling locally based on runtime statistics. Like unit scaling, at the beginning of training this technique may be able to achieve well-scaled tensors throughout the model. However, additional compute, memory, bandwidth and cross-device communication costs may be incurred by the recording of statistics (see Section 8 for a more detailed discussion of the potential compute overheads incurred by each of these schemes).

## 4. Analysis

For normally distributed tensors we use the term scale to refer to standard deviation. We observe minimal change (relative to the range of our formats) of the mean. Scale therefore characterises the probability of clipping error given a format, as too large or small a scale will lead to values that lie outside of the representable range.

**Ideal scaling** Given we are able to influence the scale of tensors at the start of training, the question arises: what scale should we aim for? As suggested by Figure 2, we argue that unit scale, $\sigma = 1$, is a sweet spot representing a sensible compromise between several competing factors.
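This sweet spot can be seen in a small simulation in the spirit of Figure 2. The sketch below is illustrative only (it uses NumPy's float16 cast as the quantiser, so it reproduces just the FP16 curve):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000)

def snr(values, quantise):
    q = quantise(values).astype(np.float64)
    return (values ** 2).mean() / ((q - values) ** 2).mean()

with np.errstate(over="ignore"):   # large scales overflow to inf, driving SNR to zero
    for log2_scale in range(-24, 25, 4):
        scaled = 2.0 ** log2_scale * x
        print(f"scale = 2^{log2_scale:+d}   FP16 SNR = {snr(scaled, lambda v: v.astype(np.float16)):.3g}")
```

The printed SNR is roughly flat for scales comfortably inside FP16's range, and collapses once the distribution underflows into subnormals or overflows, mirroring the shape of the curves in Figure 2.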
We address this question further in Appendix C.

**Is scale predictable?** The ability to predict the scales of tensors in a deep learning model would give us a powerful tool to address clipping error. This is hard in general, but the problem is simpler at initialisation. Before any training steps, parameters are drawn from known initialisation distributions, so if the input distribution is known, analysis or simulation can derive the scale of each tensor. A further simplification is to make local distributional assumptions for a single layer in the model and consider the propagation of scale through the model. This permits a methodical analysis: first, characterise the scaling effect of each operation independently; second, propagate scales through the computational graph, forwards and backwards. We provide an example of such analysis in Appendix E.1.

**Scaling at initialisation** Since the initial distribution of parameters is directly controlled by the model designer, the dominant approach to scaling is to select initial parameter variance to trade off forward and backward pass variance scaling (Glorot and Bengio, 2010; He et al., 2015). Such schemes were developed to avoid exploding/vanishing gradients in deep multilayer perceptrons. As such, they do not seek to constrain the scale of parameters and parameter gradients. They are also limited to computations where scale factors can be moved into trainable parameters.

**Example: BERT (Devlin et al., 2019)** BERT's initialisation scheme does not use the rules of Glorot and Bengio (2010), instead initialising all non-bias parameters from $\mathcal{N}(0, 0.02^2)$. It also adopts a scaling factor from the Transformer (Vaswani et al., 2017), which scales the product of activation matrices $QK^\top$, with $Q, K \in \mathbb{R}^{s \times d}$, by $1/\sqrt{d}$. We instrument the model to record histograms of all tensors at the start and end of training, and plot the results in Figures A.4 and A.6.

In light of this analysis, we can understand loss scaling as simply enacting a shift of the grad_x and grad_w histograms by $\log_2(\text{loss scale})$ bits to the right, trading off underflow and overflow globally across gradient tensors. BERT with loss scaling illustrates the drawbacks of having just three scales: weight initialisation scale, loss scale, and $QK^\top$ scale. These are not sufficient to centre most tensors' distributions in the representable range.

## 5. Unit Scaling

Based on our analysis of the scaling within typical models and the limitations of existing methods for managing scale, we present unit scaling. A model is said to be unit-scaled if its activations, weights and gradients have approximately unit variance at initialisation. We achieve this by inserting scaling factors into the forward and backward passes.

Like loss scaling, our modification of the backward pass still ensures correct gradients up to a constant multiplicative factor. However, unlike loss scaling, unit scaling determines these scales based on a set of rules for each operation, rather than a single hyperparameter to be found empirically, or via an adaptive algorithm. The scales chosen enable each operation to approximately preserve the variance of its inputs. This effect then propagates through the model, giving global unit scaling.

By concentrating values in approximately the centre of the exponent range at initialisation, we give tensors headroom to potentially shift during training without going out-of-range. Unit scaling does not address the issue of adapting scales during training.
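Shifts of this kind are cheap to monitor in practice. The sketch below is illustrative (it is not the instrumentation used for the histograms in this paper); it records the scale of each linear layer's output and of the gradient flowing back into it:

```python
import torch
from torch import nn

def track_scales(model):
    """Record the std ("scale") of activations and their gradients, per module."""
    scales = {}

    def make_hook(name):
        def hook(module, inputs, output):
            scales[f"{name}.act"] = output.detach().std().item()
            # Attach a tensor hook to capture the gradient arriving at this output.
            output.register_hook(
                lambda grad: scales.__setitem__(f"{name}.grad", grad.std().item()))
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.LayerNorm)):
            module.register_forward_hook(make_hook(name))
    return scales

model = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
scales = track_scales(model)
model(torch.randn(64, 256)).pow(2).mean().backward()
print(scales)   # e.g. {"0.act": ..., "0.grad": ..., "2.act": ..., "2.grad": ...}
```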
We anticipate that unit scale is sufficient to avoid numerical instability for many models, and observe this in all our experiments. We leave to further work a full investigation of where dynamic re-scaling is required, and how to integrate such a scheme into unit scaling.

### 5.1. A framework for scaling computational graphs

**Computational graphs** We take our model to be represented by the differentiable function $f_{\mathrm{model}}(x_1, \ldots, x_m)$, itself a composition of differentiable functions $f_1, \ldots, f_n$. We can describe the structure of such a model using a directed acyclic graph (DAG) denoted $G = (V, E)$, with the property that the vertex $v_i \in V$ corresponds to the function $f_i$ for each $i \in \{1, \ldots, n\}$, and where the vector-valued output of function $f_a$ used as an input to function $f_b$ is represented by the edge $(v_a, v_b) \in E$. This kind of graph is commonly known as a computational graph, with vertices as nodes and their corresponding functions as ops.

**Forward and backward graphs** We refer to the computational graph corresponding to $f_{\mathrm{model}}$ as the forward graph. In deep learning we typically apply reverse-mode automatic differentiation to the forward graph to create a second computational graph whose output nodes represent the partial derivatives of the model with respect to its inputs: $\partial f_{\mathrm{model}} / \partial x_i$, $i \in [1..m]$. We call this the backward graph. The backward graph mirrors the structure of the forward graph, but with edge directions reversed. Thus each op $f$ in the forward graph corresponds to a new op $f_{\mathrm{grad}}$ in the backward graph. This op computes the gradient of the model up to $f$ by calculating the product of the incoming gradient $g$ from the previous grad op and the partial derivatives of $f$ evaluated at its inputs:

$$f_{\mathrm{grad}}(x_1, \ldots, x_k, g)_j \triangleq g\, \frac{\partial f}{\partial x_j}(x_1, \ldots, x_k), \qquad j \in [1..k].$$

**Scaled ops** Given an op $f(x_1, \ldots, x_k)$, we define the scaled op $f^*(x_1, \ldots, x_k, \alpha, \beta_1, \ldots, \beta_k)$ with scaling factors $\alpha, \beta_1, \ldots, \beta_k \in \mathbb{R}^+$, such that:

$$f^* \triangleq \alpha\, f(x_1, \ldots, x_k), \qquad f^*_{\mathrm{grad}}(x_1, \ldots, x_k, g)_i \triangleq \beta_i\, f_{\mathrm{grad}}(x_1, \ldots, x_k, g)_i, \quad i \in [1..k].$$

**Proposition 5.1.** For any scaled op, there is an equivalent unscaled op with the same training dynamics under a first-order optimiser.

We demonstrate this for SGD and Adam in Appendix E.2.

**Scaled computational graph** A scaled computational graph is one where every op $f$ in the forward graph is replaced by a scaled equivalent $f^*$, with the backward graph then generated to produce $f^*_{\mathrm{grad}}$ for each $f_{\mathrm{grad}}$, using any choice of scaling factors. If we can show that a scaled computational graph represents a scaled op, by Proposition 5.1, we are within a reparameterisation of regular training. Unfortunately, this is not true for scaled computational graphs in general: for example, $h^*(x) \triangleq x + f^*(x, \alpha, \beta)$ is not a scaled op for some choices of the scaled op $f^*$ when $\alpha \neq \beta$ (see Appendix E.3).

**Constraint-scaled computational graphs** We denote the set of edges in the forward graph that are cut-edges (a cut-edge is an edge in the equivalent undirected graph where the number of connected components increases upon its deletion) as $C \subseteq E$. A constraint-scaled computational graph is a scaled computational graph where we restrict the scaling factors of ops that consume non-cut-edge variables in the following way: for any edge $e \notin C$, we require the op consuming the variable $x_e$ to have scaling factors $\alpha = \beta_e$.

**Theorem 5.2.** A constraint-scaled computational graph itself represents a scaled op.

Proven in Appendix E.4.
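In a framework such as PyTorch, a scaled op of this form can be built from a primitive that applies independent constants on the forward and backward passes. The sketch below is illustrative (the paper's own implementation of its scaled helper is given in its Figure A.2):

```python
import torch

class Scaled(torch.autograd.Function):
    """y = alpha * x on the forward pass; grad_x = beta * grad_y on the backward pass."""

    @staticmethod
    def forward(ctx, X, alpha, beta):
        ctx.beta = beta
        return alpha * X

    @staticmethod
    def backward(ctx, grad_Y):
        # Only X needs a gradient; alpha and beta are fixed Python scalars.
        return ctx.beta * grad_Y, None, None

def scaled(X, alpha=1.0, beta=1.0):
    return Scaled.apply(X, alpha, beta)

# The forward and backward factors are decoupled:
x = torch.ones(3, requires_grad=True)
y = scaled(x, alpha=0.5, beta=2.0)
y.sum().backward()
print(y.detach(), x.grad)   # y = 0.5 * x, x.grad = 2.0 * ones
```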
Theorem 5.2 is sufficient to show that we've achieved the property we set out to: valid gradients, up to a constant multiplicative factor.

### 5.2. A scaling strategy for unit variance

**Unit-scaled computational graphs** We define a unit-scaled computational graph as an instance of a constraint-scaled computational graph, with scales selected via the following:

1. Initially set aside any scale constraints, and calculate the scaling factors that give each op expected unit-variance outputs (this process is covered below).
2. Now resolve any scale constraints by taking each constrained group $\{\alpha, \beta_1, \ldots, \beta_l\}$ and selecting the geometric mean $(\alpha\, \beta_1 \cdots \beta_l)^{1/(l+1)}$.

This compromise is necessary to ensure valid gradients, but diverges from strict unit scale. In practice though, we observe that the scales going into our geometric mean are often similar enough to preserve approximate unit variance.

**Selecting scaling factors** Assuming unit-scaled inputs to $y = f(x_1, \ldots, x_k)$, derive the output scale $\sigma_Y$ and set the forward scaling factor $\alpha = 1/\sigma_Y$. Repeat this process for $x'_i = f_{\mathrm{grad}}(\ldots)_i$, $i \in [1..k]$, to obtain the gradient scale $\sigma_{x'_i}$ and set the backward scaling factor $\beta_i = 1/\sigma_{x'_i}$. (See Table A.2 for the scaling factors of common ops.)

Note that our assumption of unit-scaled inputs above is justified by inductive reasoning: we assume that a given op has unit-scaled inputs, which allows us to unit scale its outputs. In this way, unit scale propagates through the graph. The base cases here are the model's initial inputs, corresponding to parameters and input data. As we initialise parameters to have unit scale, the only extra step we require is to normalise the input data.

### 5.3. Weighted addition

For the most part, the scale of tensors at initialisation in unscaled deep learning models does not play a critical role. A notable exception is when tensors of different scales are added, for example residual layers, losses and positional encodings. If we naïvely convert these add ops to unit-scaled equivalents, they place equal weight on their inputs, which can be detrimental to performance.

```python
# Figure 3, left: scaled projection op, which implicitly constrains beta_X.
def scaled(X, alpha=1, beta=1):
    # Forward:  Y = X * alpha
    # Backward: grad_X = grad_Y * beta
    ...

def scaled_projection(X, W):
    (b, _), (m, n) = X.shape, W.shape
    alpha = beta_X = (m * n) ** -(1/4)
    beta_W = b ** -(1/2)
    X = scaled(X, beta=beta_X)
    W = scaled(W, beta=beta_W)
    return scaled(matmul(X, W), alpha)
```

```python
# Figure 3, centre: unscaled Transformer FFN layer.
class FFN(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.norm = LayerNorm(d)
        sigma = (d * h) ** -(1/4)
        self.W_1 = Parameter(randn(d, h) * sigma)
        self.W_2 = Parameter(randn(h, d) * sigma)

    def forward(self, X):
        Z = self.norm(X)
        Z = matmul(Z, self.W_1)
        Z = gelu(Z)
        Z = matmul(Z, self.W_2)
        return X + Z
```

```python
# Figure 3, right: scaled Transformer FFN layer.
class ScaledFFN(nn.Module):
    def __init__(self, d, h, tau):
        super().__init__()
        self.norm = ScaledLayerNorm(d)
        self.W1 = Parameter(randn(d, h))
        self.W2 = Parameter(randn(h, d))
        self.tau = tau

    def forward(self, X):
        a = (1 - self.tau) ** (1/2)
        b = self.tau ** (1/2)
        Z = self.norm(scaled(X, beta=b))
        Z = scaled_projection(Z, self.W1)
        Z = scaled_gelu(Z)
        Z = scaled_projection(Z, self.W2)
        return X * a + scaled(Z, b)
```

Figure 3. PyTorch examples. Left: scaled projection op, which implicitly constrains $\beta_X$. Centre vs right: unscaled vs scaled Transformer FFN layers. Changes: a) initialise weights with unit scale, b) replace unscaled with scaled ops, c) replace the residual add with interpolation according to $\tau$, moving the backward-pass scale as in Section 5.2. See Figure A.2 for the implementation of scaled and further ops.
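As a quick numerical check (an illustrative sketch, not part of the paper), the scaling factors used in scaled_projection do give approximately unit-variance outputs and gradients when the inputs, weights and incoming gradients are themselves unit-variance:

```python
import torch

torch.manual_seed(0)
b, m, n = 4096, 1024, 1024
X = torch.randn(b, m)   # unit-scaled input
W = torch.randn(m, n)   # unit-scaled weight
G = torch.randn(b, n)   # unit-scaled incoming gradient

alpha = beta_X = (m * n) ** -0.25   # constrained forward / input-gradient factor
beta_W = b ** -0.5                  # unconstrained weight-gradient factor

Y      = alpha  * (X @ W)       # forward output
grad_X = beta_X * (G @ W.T)     # backward pass for the input
grad_W = beta_W * (X.T @ G)     # backward pass for the weight

print(Y.std(), grad_X.std(), grad_W.std())   # all approximately 1.0
```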
We propose using a weighted add (Table A.2) to resolve this equal-weighting problem. This introduces new hyperparameters into the model, which can be chosen by design principle, empirically by sweep, or selected to match a reference model (see Appendix H). For residual layers, there are existing design principles in the literature. We consider the following residual layers based on NF-ResNets (Brock et al., 2021):

- default: $x_{l+1} = x_l + f(x_l)$ (not suitable for unit scaling)
- fixed ($\tau$): $x_{l+1} = \sqrt{1 - \tau}\, x_l + \sqrt{\tau}\, f(x_l)$
- running-mean: $x_{l+1} = \sqrt{l/(l+1)}\, x_l + \sqrt{1/(l+1)}\, f(x_l)$

An issue with these weighting rules is that they may produce small gradient scales in the residual branch, which isn't a cut-edge so can't be independently rescaled. To resolve this, we perform a special-case rewrite to replace $\gamma\, f(x)$ with $\mathrm{id}^*(f(\mathrm{id}^*(x, 1, \gamma)), \gamma, 1)$, where $\mathrm{id}^*(x, \alpha, \beta)$ is the scaled identity function. This maintains unit scale for the backward pass $f_{\mathrm{grad}}$, while preserving $G$ as a scaled op.

### 5.4. Recipe

We now outline a high-level recipe for a unit-scaled model:

1. Initialise non-bias parameters with unit variance.
2. Calculate scaling factors for all scaled ops.
3. Identify non-cut-edges, and constrain the ops consuming them to have $\alpha = \beta$ by taking the geometric mean.
4. Replace adds with weighted adds.

Unconstrained scaling factors are as outlined in Appendix G. Identifying cut-edges may sound challenging, but in practice is similar across models. The set of cut-edges commonly contains parameters and any encoder/decoder layers (anything before/after a stack of residual layers). After applying this recipe, training and inference proceed as usual.

To align a unit-scaled model with an existing model, there are some additional considerations. We cover these in Appendix H. One notable difference is that unit-scaled models have different effective optimiser step sizes across their parameters versus unscaled models (for instance, a larger effective step size for bias parameters when using unit scaling; effective step size considers the effect of an optimiser update on model output, rather than on parameters). While this difference can be compensated by per-tensor step size modifiers, it means that the training dynamics may be different by default.

### 5.5. Example

Using the unit scaling recipe, we first build a scaled op, and then a full scaled layer. Consider a scaled projection op with learnable weights:

$$
\begin{aligned}
\mathrm{matmul}^*(X, W) &= \alpha\, X W,\\
\mathrm{matmul}^*_{\mathrm{grad}}(X, W, G)_1 &= \beta_1\, G W^\top,\\
\mathrm{matmul}^*_{\mathrm{grad}}(X, W, G)_2 &= \beta_2\, X^\top G,
\end{aligned}
$$

for input $X \in \mathbb{R}^{b \times m}$, weight $W \in \mathbb{R}^{m \times n}$, output in $\mathbb{R}^{b \times n}$ and incoming gradients $G \in \mathbb{R}^{b \times n}$. Assuming large $b, m, n$, the analysis of Appendix E.1 gives unconstrained scaling factors $\alpha = m^{-1/2}$, $\beta_1 = n^{-1/2}$, $\beta_2 = b^{-1/2}$. Typically, the edge connecting the weights $W$ is a cut-edge, while the edge connecting the inputs $X$ is not. Given that assumption, we constrain $\alpha = \beta_1$, satisfied by setting both to the geometric mean of the unconstrained values: $\alpha = \beta_1 = (m\, n)^{-1/4}$. We leave $\beta_2$ unchanged. We show code for the above in Figure 3, which also gives a scaled layer for the Transformer FFN of Figure 1.

Figure 4. Character language modelling, showing validation bits per character over a wide range of models. Panels compare the regular FP32 baseline against: regular FP16 with no scaling, regular FP16 with loss scaling (2048), unit scaling in FP32, and unit scaling in FP16.
Each point represents one combination of: {Conv, RNN, Attention}, {Pre, Post, No norm}, {Fixed, Running-mean residual}, {SGD, Adam}, {2, 8 layers}. Each point is the best final value over a learning rate sweep.

## 6. Experiments

### 6.1. Character language modelling

**Experimental setup** To evaluate unit scaling for multiple model architectures and optimisers, we perform small-scale experiments on WikiText-103 raw character language modelling (Merity et al., 2017). We train causal language models, using a cross-entropy loss during training, and evaluate on bits per character (BPC). All models follow the pattern of a Transformer decoder layer (Vaswani et al., 2017), with the following variants:

- Sequence layer type: Attention, RNN and Convolution.
- Norm placement: Pre-Norm, Post-Norm and No Norm.
- Residual scaling: default, fixed and running-mean (as defined in Section 5.2).

Over the product of these settings, we compare the performance of regular (baseline) and unit scaling in both FP32 and FP16. For this, we also evaluate the regular model in FP16 with loss scaling. For full hyperparameters and details, see Appendix J.1.

**Results** The above configurations amount to a 2092-run sweep, the results of which are shown in Figure 4. First, these demonstrate the need for scaling when using FP16. This is due to gradient underflow, since loss scaling with a factor of 2048 resolves the issue. Second, they demonstrate that unit scaling, despite changing the training behaviour of the model beyond just numerics, matches or even slightly improves upon baseline performance in almost all cases. Finally, they show that no tuning is necessary when switching unit scaling to FP16.

We also explore the effect of using different residual scaling schemes, with results shown in Figure A.3. We find that performance is not sensitive to the choice of scheme, and suggest that running-mean or fixed are reasonable choices when using unit scaling.

### 6.2. Masked language modelling

**Experimental setup** To evaluate unit scaling against a standard baseline known for challenging numerics, where loss scaling is conventionally required (Lin et al., 2020), we train unit-scaled BERTBASE and BERTLARGE models. We use the standard BERT masked language model pretraining objective over English Wikipedia articles, and demonstrate downstream performance on SQuAD v1.1 and SQuAD v2.0 (Rajpurkar et al., 2016; 2018). We follow the unit scaling recipe, along with our guide on aligning a unit-scaled model with a regular model (Appendix H). Full hyperparameters and details are covered in Appendix J.2. Note that we do not sweep any additional hyperparameters for our unit-scaled BERT (or character language models) relative to the baselines.

**Results** We report our results in Table 2. For unit scaling in FP16, we are able to attain the same performance as the baseline model, and whereas the baseline requires sweeping a loss scale, unit scaling works in all cases out-of-the-box. Due to differences in the effective optimiser step size across parameters (Section 5.4), our regular and unit-scaled models aren't exactly equivalent, but deviations in their downstream performance are minor (BERTBASE is slightly below the baseline, and BERTLARGE is slightly above).

For FP8, we build on the results of Noune et al. (2022), who demonstrate the training of loss-scaled BERT in FP8 with no degradation relative to FP16.
We show that the same can also be achieved with unit scaling, with no additional techniques required to make FP8 work over FP16: we simply quantise our matmul inputs into FP8 and are able to train accurately. These results represent the first time BERTBASE or BERTLARGE have been trained in either FP16 or FP8 without requiring a form of loss scaling.

To highlight the precise effects of unit scaling, we show histograms for activations, weights and gradients for unit-scaled FP16 BERT. These can be found in Figures A.5 and A.7, alongside equivalent plots for a regular FP16 BERT.

Table 2. Downstream performance of regular and unit-scaled BERT models. We pretrain 3 models for every model-method-format combination, then fine-tune 5 SQuAD v1.1 and 5 v2.0 runs for each (i.e. 15 runs per downstream task). The values shown represent the mean across the 15 runs, with ± indicating the standard deviation across the mean scores of the 3 sub-groups. † published result from Devlin et al. (2019). ‡ published result from Noune et al. (2022); this model also adds an activation scale alongside the loss scale.

| Model | Method | Precision | SQuAD v1.1 EM | SQuAD v1.1 F1 | SQuAD v2.0 EM | SQuAD v2.0 F1 |
|---|---|---|---|---|---|---|
| BERTBASE | No Scaling† | FP32 | 80.8 | 88.5 | | |
| BERTBASE | Loss Scaling | FP16 | 80.55 (± 0.16) | 88.19 (± 0.16) | 73.36 (± 0.27) | 76.47 (± 0.23) |
| BERTBASE | Unit Scaling | FP16 | 79.96 (± 0.31) | 87.86 (± 0.44) | 72.31 (± 0.60) | 75.70 (± 0.53) |
| BERTBASE | Unit Scaling | FP8 | 80.15 (± 0.18) | 88.04 (± 0.12) | 72.28 (± 0.02) | 75.67 (± 0.01) |
| BERTLARGE | No Scaling† | FP32 | 84.1 | 90.9 | 78.7 | 81.9 |
| BERTLARGE | Loss Scaling | FP16 | 84.23 (± 0.20) | 90.93 (± 0.14) | 77.52 (± 0.63) | 80.54 (± 0.61) |
| BERTLARGE | Loss Scaling‡ | FP8 | 83.40 (± 0.23) | 90.69 (± 0.16) | | |
| BERTLARGE | Unit Scaling | FP16 | 85.67 (± 0.10) | 92.14 (± 0.08) | 79.94 (± 0.10) | 82.97 (± 0.09) |
| BERTLARGE | Unit Scaling | FP8 | 85.22 (± 0.03) | 91.77 (± 0.10) | 79.29 (± 0.31) | 82.29 (± 0.29) |

The code used in these experiments can be found at https://github.com/graphcore-research/unit-scaling-demo, alongside a separate notebook implementing a unit-scaled nanoGPT model. We recommend this resource for those looking to understand unit scaling through a simple example implementation. For those interested in using unit scaling in their own models, we also provide a PyTorch library: https://graphcore-research.github.io/unit-scaling. The documentation includes a practical guide to developing and optimising a unit-scaled model. This implementation should be considered a definitive reference for unit scaling.

## 7. Related Work

**Variance scaling analysis** Klambauer et al. (2017) and Peiwen and Changsheng (2022) propose activation functions that encourage unit-variance activations and gradients, which are complementary to unit scaling. He et al. (2016) introduce residual networks, using skip connections and explicit normalisation to stabilise forward and backward passes. Variants on normalisation (Ioffe and Szegedy, 2015; Ba et al., 2016; Labatie et al., 2021; Salimans and Kingma, 2016) are complementary to unit scaling, which considers the norm of the gradients as well as activations and does not constrain activation norms after initialisation. Alternative residual schemes (Zhang et al., 2019; Brock et al., 2021) can be incorporated into unit-scaled models, although the residual layer output variance should not be allowed to grow with depth. The reparameterisation implied by unit scaling is also used by Jacot et al. (2018), later broadened by Yang and Hu (2020) and exploited by Yang et al. (2022) in their work analysing the training behaviour of deep networks.
Motivated by low-precision computation rather than training dynamics, unit scaling applies scaling factors locally throughout the compute graph, but the effect on training hyperparameter scaling is similar.

**FP8 inference** Although there has been little hardware support for FP8 training, accelerated 8-bit inference is increasingly common via the use of integer quantisation (Jacob et al., 2018) to the INT8 format. This process typically results in degraded accuracy, requiring additional techniques such as quantisation-aware training (see Nagel et al. (2021) for a thorough discussion of this topic). Though recent efforts have been made to improve efficient INT8 quantisation (Yao et al., 2022; Park et al., 2022; Dettmers et al., 2022; Xiao et al., 2022), the use of FP8 enables accelerated inference in the same format as training, promising a substantial improvement in the simplicity and accuracy of 8-bit inference (Kuzmin et al., 2022).

## 8. Discussion

**Compute overhead** Unit scaling relies solely on the addition of scaling operations of the form $\gamma X$, where $\gamma$ is a fixed scalar and $X$ is a tensor. These scaling factors can be fused into the preceding ops (e.g. via torch.jit, torch.compile or jax.jit). By doing this we observe that the increase in memory-access cost is negligible. For models with reasonably large hidden sizes, the compute overhead is also minimal. For example, the FLOPs required to train our unit-scaled BERTLARGE are only 0.2% greater than the baseline (explained further in Appendix I.2). Basic loss scaling operates on a similar principle, and only introduces a single scaling factor. From this we conclude that both techniques have low overall overhead, assuming a fused implementation.

Automatic loss scaling has an additional feature which increases overhead: its requirement to occasionally discard batches. This assumes that re-scaling is determined by tracking gradient overflows (the standard approach, as used in PyTorch (2023)). When overflows occur, batches must not be used to update parameters. The overhead of dropping batches is tolerable for FP16 but may not be for FP8 (Micikevicius et al., 2022).

Proposed automatic per-tensor scaling schemes take a different approach, and have the potential to add overhead in other areas (how much depends largely on software and hardware characteristics). Micikevicius et al. (2022) reject scaling based on gradient overflows, instead opting for heuristics based on properties of the tensors being scaled. Their preferred training heuristic is not specified, but for inference they choose between max, percentile, and minimum-MSE methods. These approaches trade off overhead for accuracy. At one extreme, max is likely easy to fuse but may be distorted by outliers; at the other extreme, minimum MSE may be more robust but is challenging to implement efficiently (e.g. Sakr et al. (2022)). Distributed training adds further challenges, potentially requiring the communication of statistics across devices to keep scales synchronised. It remains to be seen whether effective automatic scaling methods can be implemented efficiently given these complexities. This will likely be an important future research objective. In contrast, unit scaling, with fixed precomputed scaling factors, offers a simpler alternative.

**Broader impact** The potential for unit scaling to simplify the use of 8-bit number formats may lead to increased adoption, and in turn facilitate training larger models.
At scale, new capabilities emerge (Wei et al., 2022), potentially exacerbating known harms (Weidinger et al., 2021) such as toxicity (Nadeem et al., 2020), misinformation (Lin et al., 2021), privacy concerns (Carlini et al., 2021) and environmental damage (Strubell et al., 2019). To mitigate these outcomes, a variety of methods have been proposed, including reinforcement learning from human (Ouyang et al., 2022) or AI (Bai et al., 2022) feedback, anti-experts (Liu et al., 2021) and baked-in safety models (Xu et al., 2020), all of which are applicable to unit-scaled models. Conclusion We have demonstrated that unit scaling addresses the complexities of low-precision training, providing a simpler and more granular solution. This is demonstrated by our training of BERTLARGE for the first time without loss scaling, in FP16 and even FP8. The community s transition to FP8 training will see new capabilities emerge as a result of improved efficiency, and this transition can be accelerated by unit scaling. Acknowledgements We would like to thank the following people for their contributions to the paper at the various stages of its development: Daniel Justus, Alberto Cattaneo, Andrew Fitzgibbon, Paul Balanca, Luke Prince, Ivan Chelombiev, Luka Ribar and Zach Eaton-Rosen. Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. ar Xiv preprint ar Xiv:1607.06450, 2016. Yuntao Bai, Saurav Kadavath, Sandipan Kundu, Amanda Askell, Jackson Kernion, Andy Jones, Anna Chen, Anna Goldie, Azalia Mirhoseini, Cameron Mc Kinnon, et al. Constitutional ai: Harmlessness from ai feedback. ar Xiv preprint ar Xiv:2212.08073, 2022. Yelysei Bondarenko, Markus Nagel, and Tijmen Blankevoort. Understanding and overcoming the challenges of efficient transformer quantization. ar Xiv preprint ar Xiv:2109.12948, 2021. Andy Brock, Soham De, Samuel L. Smith, and Karen Simonyan. High-performance large-scale image recognition without normalization. 38th International Conference on Machine Learning, ICML 2021, 2021. Nicholas Carlini, Florian Tramer, Eric Wallace, Matthew Jagielski, Ariel Herbert-Voss, Katherine Lee, Adam Roberts, Tom Brown, Dawn Song, Ulfar Erlingsson, et al. Extracting training data from large language models. 30th USENIX Security Symposium (USENIX Security 21), pages 2633 2650, 2021. Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, and Sebastian et al. Gehrmann. Palm: Scaling language modeling with pathways. ar Xiv preprint ar Xiv:2204.02311, 2022. Zihang Dai, Zhilin Yang, Yiming Yang, Jaime G. Carbonell, Quoc Viet Le, and Ruslan Salakhutdinov. Transformer-xl: Attentive language models beyond a fixed-length context. Proceedings of the 57th Conference of the Association for Computational Linguistics, ACL, 2019. Tim Dettmers, Mike Lewis, Younes Belkada, and Luke Zettlemoyer. Llm.int8(): 8-bit matrix multiplication for transformers at scale. ar Xiv preprint ar Xiv:2208.07339, 2022. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. 2019 Conference Of The North American Chapter Of The Association Unit Scaling For Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), 2019. Hadi Esmaeilzadeh, Emily Blem, Renee St. Amant, Karthikeyan Sankaralingam, and Doug Burger. Dark silicon and the end of multicore scaling. 
Proceedings of the 38th annual international symposium on Computer architecture, pages 365 376, 2011. Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. 13th International Conference on Artificial Intelligence and Statistics, AISTATS 2010, 2010. Graphcore. Graphcore launches C600 PCIe card for AI compute. https://www.graphcore.ai/posts/ graphcore-launches-c600-pcie-card-for -ai-compute, 2022. (Online: accessed 25 January 2023). Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. IEEE International Conference on Computer Vision, ICCV 2015, 2015. Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. IEEE/CVF Conference on Computer Vision and Pattern Recognition, CVPR 2016, 2016. Sara Hooker. The hardware lottery. Communications of the Association for Computing Machinery, 2021. Xiao Shi Huang, Felipe Perez, Jimmy Ba, and Maksims Volkovs. Improving transformer optimization through better initialization. Proceedings of the 37th International Conference on Machine Learning, 2020. Computer Society IEEE. IEEE standard for floating-point arithmetic. IEEE Std 754-2019, pages 1 84, 2019. Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. 32nd International Conference on Machine Learning, ICML 2015, 2015. Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew G. Howard, Hartwig Adam, and Dmitry Kalenichenko. Quantization and training of neural networks for efficient integer-arithmetic-only inference. IEEE/CVF Conference on Computer Vision and Pattern Recognition, CVPR 2018, 2018. Arthur Jacot, Franck Gabriel, and Cl ement Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in Neural Information Processing Systems 31, Neur IPS 2018, 2018. Zhe Jia, Blake Tillman, Marco Maggioni, and Daniele Paolo Scarpazza. Dissecting the graphcore ipu architecture via microbenchmarking. ar Xiv preprint ar Xiv:1912.03413, 2019. Dhiraj Kalamkar, Dheevatsa Mudigere, Naveen Mellempudi, Dipankar Das, Kunal Banerjee, Sasikanth Avancha, Dharma Teja Vooturi, Nataraj Jammalamadaka, Jianyu Huang, Hector Yuen, Jiyan Yang, Jongsoo Park, Alexander Heinecke, Evangelos Georganas, Sudarshan Srinivasan, Abhisek Kundu, Misha Smelyanskiy, Bharat Kaul, and Pradeep Dubey. A study of BFLOAT16 for deep learning training. ar Xiv preprint ar Xiv:1905.12322, 2019. Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. 3rd International Conference on Learning Representations, ICLR 2015, 2015. G unter Klambauer, Thomas Unterthiner, Andreas Mayr, and Sepp Hochreiter. Self-normalizing neural networks. Advances in Neural Information Processing Systems 30, Neur IPS 2017, 2017. Oleksii Kuchaiev, Boris Ginsburg, Igor Gitman, Vitaly Lavrukhin, Jason Li, Huyen Nguyen, Carl Case, and Paulius Micikevicius. Mixed-precision training for nlp and speech recognition with openseq2seq. ar Xiv preprint ar Xiv:1805.10387, 2018. Andrey Kuzmin, Mart Van Baalen, Yuwei Ren, Markus Nagel, Jorn Peters, and Tijmen Blankevoort. Fp8 quantization: The power of the exponent. ar Xiv preprint ar Xiv:2208.09225, 2022. Antoine Labatie, Dominic Masters, Zach Eaton-Rosen, and Carlo Luschi. Proxy-normalizing activations to match batch normalization while removing batch dependence. 
Advances in Neural Information Processing Systems 34, Neur IPS 2021, 2021. Jiahuang Lin, Xin Li, and Gennady Pekhimenko. Multinode BERT-pretraining: Cost-efficient approach. ar Xiv preprint ar Xiv:2008.00177, 2020. Stephanie Lin, Jacob Hilton, and Owain Evans. Truthful QA: Measuring how models mimic human falsehoods. ar Xiv preprint ar Xiv:2109.07958, 2021. Alisa Liu, Maarten Sap, Ximing Lu, Swabha Swayamdipta, Chandra Bhagavatula, Noah A Smith, and Yejin Choi. Dexperts: Decoding-time controlled text generation with experts and anti-experts. ar Xiv preprint ar Xiv:2105.03023, 2021. Alexandra Sasha Luccioni, Sylvain Viguier, and Anne Laure Ligozat. Estimating the carbon footprint of bloom, a 176b parameter language model. ar Xiv preprint ar Xiv:2211.02001, 2022. Unit Scaling Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. 5th International Conference on Learning Representations, ICLR 2017, 2017. Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, et al. Mixed precision training. 6th International Conference on Learning Representations, ICLR 2018, 2018. Paulius Micikevicius, Dusan Stosic, Patrick Judd, John Kamalu, Stuart Oberman, Mohammad Shoeybi, Michael Siu, and Hao Wu. FP8 formats for deep learning. ar Xiv preprint ar Xiv:2209.05433, 2022. Moin Nadeem, Anna Bethke, and Siva Reddy. Stereo Set: Measuring stereotypical bias in pretrained language models. ar Xiv preprint ar Xiv:2004.09456, 2020. Markus Nagel, Marios Fournarakis, Rana Ali Amjad, Yelysei Bondarenko, Mart van Baalen, and Tijmen Blankevoort. A white paper on neural network quantization. ar Xiv preprint ar Xiv:2106.08295, 2021. Badreddine Noune, Philip Jones, Daniel Justus, Dominic Masters, and Carlo Luschi. 8-bit numerical formats for deep neural networks. ar Xiv preprint ar Xiv:2206.02915, 2022. Nvidia. Nvidia H100 Tensor Core GPU Architecture. https://resources.nvidia.com/en-ustensor-core, 2022. (Online: accessed 25 January 2023). Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll L. Wainwright, Pamela Mishkin, Chong Zhang, Sandhini Agarwal, Katarina Slama, Alex Ray, et al. Training language models to follow instructions with human feedback. ar Xiv preprint ar Xiv:2203.02155, 2022. Gunho Park, Baeseong Park, Se Jung Kwon, Byeongwook Kim, Youngjoo Lee, and Dongsoo Lee. nu Qmm: Quantized matmul for efficient inference of large-scale generative language models. ar Xiv preprint ar Xiv:2206.09557, 2022. Yuan Peiwen and Zhu Changsheng. Normalized activation function: Toward better convergence. ar Xiv preprint ar Xiv:2208.13315, 2022. Py Torch. Automatic mixed precision package - torch.amp. https://pytorch.org/docs/stable/amp. html, 2023. (Online: accessed 25 January 2023). Jack W. Rae, Sebastian Borgeaud, Trevor Cai, Katie Millican, Jordan Hoffmann, Francis Song, John Aslanides, Sarah Henderson, Roman Ring, Susannah Young, et al. Scaling language models: Methods, analysis & insights from training Gopher. ar Xiv preprint ar Xiv:2112.11446, 2021. Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. SQu AD: 100,000+ questions for machine comprehension of text. ar Xiv preprint ar Xiv:1606.05250, 2016. Pranav Rajpurkar, Robin Jia, and Percy Liang. Know what you don t know: Unanswerable questions for SQu AD. ar Xiv preprint ar Xiv:1806.03822, 2018. Charbel Sakr, Steve Dai, Rangha Venkatesan, Brian Zimmer, William Dally, and Brucek Khailany. 
Optimal clipping and magnitude-aware differentiation for improved quantization-aware training. 39th International Conference on Machine Learning, ICML 2022, 2022. Tim Salimans and Durk P Kingma. Weight normalization: A simple reparameterization to accelerate training of deep neural networks. Advances in Neural Information Processing Systems 29, Neur IPS 2016, 2016. Emma Strubell, Ananya Ganesh, and Andrew Mc Callum. Energy and policy considerations for deep learning in nlp. ar Xiv preprint ar Xiv:1906.02243, 2019. Xiao Sun, Jungwook Choi, Chia-Yu Chen, Naigang Wang, Swagath Venkataramani, Vijayalakshmi Srinivasan, Xiaodong Cui, Wei Zhang, and Kailash Gopalakrishnan. Hybrid 8-bit floating point (HFP8) training and inference for deep neural networks. Advances in Neural Information Processing Systems 32, Neur IPS 2019, 2019. Richard S. Sutton. The bitter lesson. http: //www.incompleteideas.net/Inc Ideas/ Bitter Lesson.html, 2019. (Online: accessed 25 January 2023). Tesla. A guide to tesla s configurable floating point formats & arithmetic. https://tesla-cdn.thron.com/ static/MXMU3S_tesla-dojo -technology_1WDVZN.pdf, 2021. (Online: accessed 25 January 2023). Thomas N. Theis and H.-S. Philip Wong. The end of Moore s law: A new beginning for information technology. Computing in Science & Engineering, 19(2):41 50, 2017. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems 30, Neur IPS 2017, 2017. Unit Scaling Naigang Wang, Jungwook Choi, Daniel Brand, Chia-Yu Chen, and Kailash Gopalakrishnan. Training deep neural networks with 8-bit floating point numbers. ar Xiv preprint ar Xiv:1812.08011, 2018. Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abilities of large language models. ar Xiv preprint ar Xiv:2206.07682, 2022. Laura Weidinger, John Mellor, Maribeth Rauh, Conor Griffin, Jonathan Uesato, Po-Sen Huang, Myra Cheng, Mia Glaese, Borja Balle, Atoosa Kasirzadeh, et al. Ethical and social risks of harm from language models. ar Xiv preprint ar Xiv:2112.04359, 2021. Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V. Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, Jeff Klingner, Apurva Shah, Melvin Johnson, Xiaobing Liu, Lukasz Kaiser, Stephan Gouws, Yoshikiyo Kato, Taku Kudo, Hideto Kazawa, Keith Stevens, George Kurian, Nishant Patil, Wei Wang, Cliff Young, Jason Smith, Jason Riesa, Alex Rudnick, Oriol Vinyals, Greg Corrado, Macduff Hughes, and Jeffrey Dean. Google s neural machine translation system: Bridging the gap between human and machine translation. ar Xiv preprint ar Xiv:1609.08144, 2016. Guangxuan Xiao, Ji Lin, Micka el Seznec, Julien Demouth, and Song Han. Smoothquant: Accurate and efficient posttraining quantization for large language models. ar Xiv preprint ar Xiv:2211.10438, 2022. XLA and Tensor Flow teams. XLA Tensor Flow, compiled. https://developers.googleblog.com/ 2017/03/xla-tensorflow-compiled.html, 2017. (Online: accessed 26 January 2023). Jing Xu, Da Ju, Margaret Li, Y-Lan Boureau, Jason Weston, and Emily Dinan. Recipes for safety in open-domain chatbots. ar Xiv preprint ar Xiv:2010.07079, 2020. Greg Yang and Edward J. Hu. Feature learning in infinitewidth neural networks. ar Xiv preprint ar Xiv:2011.14522, 2020. Greg Yang, Edward J. 
Hu, Igor Babuschkin, Szymon Sidor, Xiaodong Liu, David Farhi, Nick Ryder, Jakub Pachocki, Weizhu Chen, and Jianfeng Gao. Tensor programs V: Tuning large neural networks via zero-shot hyperparameter transfer. ar Xiv preprint ar Xiv:2203.03466, 2022. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, and Yuxiong He. Zeroquant: Efficient and affordable post-training quantization for large- scale transformers. ar Xiv preprint ar Xiv:2206.01861, 2022. Yang You, Jing Li, Sashank Reddi, Jonathan Hseu, Sanjiv Kumar, Srinadh Bhojanapalli, Xiaodan Song, James Demmel, Kurt Keutzer, and Cho-Jui Hsieh. Large batch optimization for deep learning: Training BERT in 76 minutes. ar Xiv preprint ar Xiv:1904.00962, 2019. Hongyi Zhang, Yann N. Dauphin, and Tengyu Ma. Residual learning without normalization via better initialization. 7th International Conference on Learning Representations, ICLR 2019, 2019. Julian Georg Zilly, Rupesh Kumar Srivastava, Jan Koutnık, and J urgen Schmidhuber. Recurrent highway networks. 34th International Conference on Machine Learning, ICML 2017, 2017. Unit Scaling A. Floating point format specification Table A.1. Common floating point formats for deep learning. E refers to the number of exponent bits, and M the number of mantissa bits of a given format. Max exp. and Min exp. refer to the maximum and minimum values that can be represented by the exponent, excluding special values. E5 (a) and E4 (a) refer to the FP8 formats introduced by Noune et al. (2022), whereas E5 (b) and E4 (b) refer to those introduced by Micikevicius et al. (2022) Format E M Max exp. Min exp. FP32 8 23 127 -126 TF32 8 10 127 -126 BFLOAT16 8 7 127 -126 FP16 5 10 15 -14 FP8 E5 (a) 5 2 15 -15 FP8 E5 (b) 5 2 15 -14 FP8 E4 (a) 4 3 7 -7 FP8 E4 (b) 4 3 8 -6 B. Proposed FP8 formats Here we analyse the recently-proposed FP8 formats. We cover two proposals for 8-bit floating point formats (Noune et al., 2022; Micikevicius et al., 2022) (other proposals include Tesla (2021); Kuzmin et al. (2022)), each of which introduce one format with 4 exponent bits and a second format with 5. We refer to these here as E4 and E5 respectively (with the implication that the remaining bits represent the sign and mantissa). To compensate for the low number of representable values, all of the proposed formats except the Micikevicius et al. (2022) E5 format deviate from the IEEE 754 standard by reducing the number of special values available. Both Noune et al. (2022) formats also increment the IEEE 754 bias by one. This slightly alters the maximum and minimum (absolute normal) values that each can represent. FP8 formats in the literature are sometimes presented as having an explicit bias value, to be defined by the user (Noune et al., 2022; Kuzmin et al., 2022). The bias is subtracted from the exponent, just as in the IEEE 754 standard. This approach is equivalent to multiplying by 2 bias, and hence is no different from using a scaling factor to control the range of values represented. Micikevicius et al. (2022) explore both interpretations, with a preference for the scaling-factor viewpoint which aligns better with software implementations, whereas the exponent-bias viewpoint is more hardware aligned and in practice is likely to restrict bias values to integers. These caveats aside, the proposed FP8 formats do not differ significantly from a standard-compliant 8-bit format. C. Is unit standard deviation the correct criterion? 
Here we justify the position that aiming for unit standard deviation of normally distributed tensors at initialisation is a sensible strategy. When considering the scale of a floating-point tensor, we aim to keep absolute values within the upper and lower absolute normal bounds defined by a given format. To analyse the absolute values generated by a normal distribution, we instead consider a folded normal distribution with zero mean and unit variance. Here, the central 90% of all probability mass falls within $[2^{-4}, 2^{1}]$. As a point of comparison, for an IEEE 754 float the absolute range of normal values we can represent is approximately $\left[2^{2 - 2^{E-1}},\ 2^{2^{E-1}}\right]$, giving a centre-point (in log-space) of $2^{1}$. From the perspective of clipping error, one might suggest scaling values to be as close as possible to this point, as we are then equidistant from the upper and lower boundaries. Hence, we can conclude that unit standard deviation will concentrate most values very near to, but slightly below, the centre of the numerical range.

Whether centrality within the normal floating-point range is the correct criterion for normally distributed tensors during the training of deep learning models is a much harder question to answer. In favour of sub-central scaling is the argument that subnormal values provide us with extra range at the lower end of the spectrum, albeit with reduced precision. Additionally, underflow in deep learning models tends to be less detrimental to training than overflow. In favour of super-central scaling is the argument that we might expect values such as gradients to decrease in magnitude during the course of training (our results in Section K suggest that this is true for BERT's grad_w values, though not for grad_x), and so we ought to up-scale values to compensate.

In light of these arguments, we argue that in situations where we can control scale, aiming for unit scaling is a sensible compromise. If we wished to precisely align the 90%-probability-mass range with the centre point calculated above, we might aim for a slightly larger scale. But given the confounding factors outlined, the difference is small enough that $\sigma = 2^0$ is still a strong choice, and keeps us aligned with other techniques in the literature with the same aim (e.g. Glorot and Bengio (2010)).

## D. Unit scaling and emergent outliers

Recent work on inference quantisation for large language models (>1B parameters) has highlighted the importance of special techniques for accommodating outliers. These are large-magnitude values concentrated in particular sequence elements (Bondarenko et al., 2021) or feature dimensions (Dettmers et al., 2022), emerging as model size increases. The main difficulty with accommodating tensors with outliers is that a single outlier can reduce the quantisation precision of all other values (Dettmers et al., 2022). These outliers have been shown to degrade INT8 quantisation accuracy at the 6.7B parameter model size and above, which leads to a key question: what impact do we expect outliers to have for unit scaling when applied to models of this size?

Firstly, we do not expect unit scaling to have a significant effect on the magnitude of outliers. This is because outliers occur in activation tensors, and these typically have a similar scale in unit and non-unit-scaled models (primarily due to the frequent presence of layernorms, which control scale). However, we still expect unit scaling to be less impaired by outliers than the examples seen in recent literature.
D. Unit scaling and emergent outliers

Recent work on inference quantisation for large language models (>1B parameters) has highlighted the importance of special techniques for accommodating outliers. These are large-magnitude values concentrated in particular sequence elements (Bondarenko et al., 2021) or feature dimensions (Dettmers et al., 2022), emerging as model size increases. The main difficulty with accommodating tensors with outliers is that a single outlier can reduce the quantisation precision of all other values (Dettmers et al., 2022). These outliers have been shown to degrade INT8 quantisation accuracy at the 6.7B parameter model size and above, which leads to a key question: what impact do we expect outliers to have for unit scaling when applied to models of this size?

Firstly, we do not expect unit scaling to have a significant effect on the magnitude of outliers. This is because outliers occur in activation tensors, and these typically have a similar scale in unit and non-unit-scaled models (primarily due to the frequent presence of layernorms, which control scale). However, we still expect unit scaling to be less impaired by outliers than the examples seen in recent literature.

The key consideration here is that unit scaling is a training method and uses floating-point formats. In contrast, the literature on emergent outliers has all been in the integer quantisation setting. Integer formats lack the dynamic range for training (Noune et al., 2022), and the same problem arises in the presence of outliers. We anticipate that using FP8 over INT8 will mitigate the difficulties presented to unit scaling by outliers. An analysis of the relative SNRs of the formats is insightful.

We first make some assumptions about the problem setting. We take the work of Dettmers et al. (2022) as our starting point, who show that the median outlier magnitude is around 60 as accuracy begins to degrade. The distribution of non-outlier values is not clear, though the authors define non-outliers to have a magnitude of < 6. Hence, we assume that these have remained approximately unit-scaled.

To represent values in INT8 we will assume that they are scaled throughout such that outliers are (just) within the range of the format. This involves dividing by the expected maximum outlier value, and multiplying by the maximum INT8 value (127). We will assume a maximum outlier value of 3× the median, giving a scaling of 127/(3 × 60). To represent values in FP8 (E4) we do not need to re-scale values to accommodate outliers, as the maximum FP8 E4 value is already larger than the maximum outlier, at 240.

Having scaled INT8 to accommodate outliers, the key question is what effect this has on the representation of non-outlier values. As observed in the literature, the range of the quantisation distribution is too large, so that most quantisation bins are empty and small values are quantised to zero, essentially extinguishing information (Dettmers et al., 2022).

Figure A.1. The signal to noise ratio (SNR) of a quantised normal distribution, as a function of the distribution's scale. This plot is the same as Figure 2, but with the addition of scaled INT8 quantisation (×127/180) and vertical lines marking the scale of non-outliers and the median outlier value.

We model this scenario, calculating an SNR for the non-outlier values of only 2.03 (this rises to 14.8 if we scale for the median outlier rather than the max). In contrast, the SNR calculated in FP8 E4 is 635× higher, at 1.29 × 10³. This is due to the exponential distribution of values in floating-point formats, which gives a small number of large values (suitable for outliers) and a large number of small values (suitable for non-outliers). This can be observed in Figure A.1, where we plot the SNR for this INT8 quantisation applied to a normally distributed tensor across different scales. Although INT8 gives a good representation of the outlier values (as does FP8 E4), the non-outlier values have low signal. One challenge for FP8 is the scenario in which outlier magnitude increases; in this case we would have to either re-scale or switch to the less precise E5 format.

Another way of viewing this is to look at the number of quantisation bins each format makes use of in this setting. For INT8 the lower 95% of non-outlier values are assigned to just 3 out of 256 quantisation bins. In contrast, for FP8 90 bins are utilised. This modelling gives us cause for optimism when applying unit scaling in the presence of outliers, though we acknowledge there may still be challenges.
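The comparison above can be reproduced approximately with the rough simulation below. The quantisers here are our own simplified stand-ins (round-to-nearest INT8 with the 127/(3 × 60) scale, and an FP8 E4-like rounder that ignores subnormals), so the resulting SNRs illustrate the gap between the formats rather than exactly matching the figures quoted above.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)  # non-outlier values, assumed approximately unit-scaled

def snr(signal, quantised):
    return np.mean(signal ** 2) / np.mean((signal - quantised) ** 2)

# INT8: scale so an assumed max outlier of 3 * 60 = 180 maps to 127, round, clip, unscale.
int8_scale = 127 / 180
x_int8 = np.clip(np.round(x * int8_scale), -127, 127) / int8_scale

# FP8 E4 (rough simulation): 3 mantissa bits, normal exponents in [-7, 7],
# round-to-nearest within each binade, flush-to-zero below the smallest normal.
def quantise_fp8_e4(v, mant_bits=3, min_exp=-7, max_exp=7):
    sign, mag = np.sign(v), np.abs(v)
    exp = np.clip(np.floor(np.log2(np.where(mag > 0, mag, 1.0))), min_exp, max_exp)
    step = 2.0 ** (exp - mant_bits)
    q = np.minimum(np.round(mag / step) * step, (2 - 2.0 ** -mant_bits) * 2.0 ** max_exp)
    return sign * np.where(mag < 2.0 ** min_exp, 0.0, q)

print(f"INT8 SNR: {snr(x, x_int8):.2f}   FP8 E4 SNR: {snr(x, quantise_fp8_e4(x)):.1f}")
```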
E. Theoretical results

E.1. Example scaling analysis

We reproduce a simple version of the scaling analysis of Glorot and Bengio (2010), for a multilayer perceptron (MLP). Consider an MLP which transforms inputs $X_0$ to outputs $X_L$ using $X_{l+1} = f(X_l W_l)$ for $l \in [0, \ldots, L-1]$, where $f(\cdot)$ is an elementwise activation function. We separate the analysis of a single layer into $Z = XW$ and $Y = f(Z)$.

Projection First, $Z = XW$, where $Z \in \mathbb{R}^{b \times n}$, $X \in \mathbb{R}^{b \times m}$, $W \in \mathbb{R}^{m \times n}$, and $X$, $W$ each have independently distributed elements with zero mean and variance $\sigma^2_X$ and $\sigma^2_W$ respectively. The values in $Z$ follow $Z_{ik} = \sum_j X_{ij} W_{jk}$, which is a sum over $m$ uncorrelated products, each with variance $\sigma^2_X \sigma^2_W$. Then, by the variance of an independent sum, the output variance $\sigma^2_Z = m\, \sigma^2_X \sigma^2_W$.

When computing the partial derivative of a scalar loss $L$ with respect to $X$, $\nabla_X L = (\nabla_Z L)\, W^\top$, assuming $\nabla_Z L$ is zero mean with variance $\sigma^2_{\nabla_Z L}$ and is not correlated with $W$,³ then by the same reasoning as above $\sigma^2_{\nabla_X L} = n\, \sigma^2_{\nabla_Z L}\, \sigma^2_W$. And again $\sigma^2_{\nabla_W L} = b\, \sigma^2_{\nabla_Z L}\, \sigma^2_X$.

³This is likely to be a very bad assumption, since $W$ was used to generate $Z$ and therefore $\nabla_Z L$. But it is hard to avoid this assumption without doing a global analysis of the model.

Activation Consider $f(Z) = \mathrm{relu}(Z) = \max(Z, 0)$, with $Z \sim N(0, 1)$. Then, in the forward pass $P(f(Z) = y) = \frac{1}{2}\delta(y) + H(y)\, P_N(y)$, where $P_N(\cdot)$ is the pdf of a standard normal distribution, and $H(\cdot)$ is the Heaviside step function. This gives variance $\sigma^2_Y = \frac{1}{2}(1 - 1/\pi)$. In the backward pass, $P(\nabla_Z L = z') = \frac{1}{2}\delta(z') + \frac{1}{2}P_N(z')$, with variance $\sigma^2_{\nabla_Z L} = \frac{1}{2}$.

He et al. (2015) note that the activation function can break the local distributional assumption for the first step: for example, the ReLU function $f(Z) = \max(Z, 0)$ does not produce zero mean output, invalidating our previous assumption on $X_l$. However, the corrections for such invalid assumptions are often small, and can be ignored for the sake of expedience, permitting local scaling analysis.

For an example of extending scale analysis to training, Huang et al. (2020) consider the training dynamics of a Transformer under Adam, using this to derive an initialisation scheme that avoids vanishing updates.

E.2. Proofs in support of Proposition 5.1

For two common choices of optimiser, SGD and Adam, we show that there is an unscaled model with identical training dynamics to any unit-scaled model. We define a model as an op with scalar output and a subset of inputs denoted as trainable parameters $\theta_i$, written $f(\theta_{i\,1\ldots n}, x_{j\,1\ldots k})$. A training trajectory is defined as a sequence $\theta^{(t)}_i$ for all parameters in a model, given initial settings $\theta^{(0)}_i$ and an optimiser.

E.2.1. SGD

Under SGD, the trajectory follows

$$\theta^{(t+1)}_i = \theta^{(t)}_i - \eta\, \frac{\partial f(\ldots)}{\partial \theta_i} = \theta^{(t)}_i - \eta\, f_{\mathrm{grad}}(\ldots, 1)_i\,,$$

where $\eta$ is a constant learning rate hyperparameter. We define the trajectory under a scaled op similarly, using $f^*_{\mathrm{grad}}$:

$$\theta^{*,(t+1)}_i = \theta^{*,(t)}_i - \eta\, f^*_{\mathrm{grad}}(\ldots, 1)_i\,.$$

Proposition E.1. For any scaled op with training trajectory $\theta^{*,(t)}_i$ under SGD, there exists an equivalent unscaled op with training trajectory $\theta^{(t)}_i = \sqrt{\alpha/\beta_i}\, \theta^{*,(t)}_i$.

We consider the evolution of the following unscaled op under SGD on $\theta$:

$$\hat f(\theta_{i\,1\ldots n}, x_{j\,1\ldots k}) \triangleq \alpha\, f\!\left(\sqrt{\beta_i/\alpha}\, \theta_{i\,1\ldots n},\, x_{j\,1\ldots k}\right).$$

Applying the chain rule to obtain gradients,

$$\frac{\partial \hat f(\theta_{i\,1\ldots n}, \ldots)}{\partial \theta_i} = \alpha \sqrt{\beta_i/\alpha}\, \frac{\partial f(\sqrt{\beta_i/\alpha}\, \theta_{i\,1\ldots n}, \ldots)}{\partial \theta_i}\,.$$

Substituting to get the evolution of $\theta_i$ under SGD,

$$\theta^{(t+1)}_i = \theta^{(t)}_i - \eta\, \sqrt{\alpha/\beta_i}\, \beta_i\, \frac{\partial f(\sqrt{\beta_i/\alpha}\, \theta^{(t)}_{i\,1\ldots n}, \ldots)}{\partial \theta_i}\,.$$

We can now define $\theta^*_i \triangleq \sqrt{\beta_i/\alpha}\, \theta_i$ and obtain

$$\theta^{*,(t+1)}_i = \theta^{*,(t)}_i - \eta\, f^*_{\mathrm{grad}}(\theta^{*,(t)}_{i\,1\ldots n}, \ldots, 1)_i\,.$$

Therefore if the initial condition $\theta^{(0)}_i = \sqrt{\alpha/\beta_i}\, \theta^{*,(0)}_i$ is satisfied, then $\theta^{(t)}_i = \sqrt{\alpha/\beta_i}\, \theta^{*,(t)}_i$ thereafter.
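As a sanity check on Proposition E.1, the toy sketch below simulates a hypothetical one-parameter linear op f(θ) = θ·x alongside the constructed unscaled op f̂, and confirms that the two SGD trajectories stay related by √(α/β). All names and constants here are illustrative, not part of the paper's code.

```python
import numpy as np

alpha, beta, eta, x = 0.5, 4.0, 0.1, 1.3
f_grad = lambda theta: x                                # d f / d theta (f is linear in theta)
fs_grad = lambda theta: beta * f_grad(theta)            # scaled op's gradient
fhat_grad = lambda theta: alpha * np.sqrt(beta / alpha) * f_grad(np.sqrt(beta / alpha) * theta)

theta_star = 0.7                                        # scaled-op parameter
theta = np.sqrt(alpha / beta) * theta_star              # unscaled-op parameter
for _ in range(5):
    theta_star -= eta * fs_grad(theta_star)
    theta -= eta * fhat_grad(theta)
    assert np.isclose(theta, np.sqrt(alpha / beta) * theta_star)
```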
E.2.2. Adam

As noted by Kingma and Ba (2015), Adam is invariant to diagonal rescaling of the gradients. Defining the function $\mathrm{adam}$ that computes a single update thus:

$$\theta^{(t+1)} = \mathrm{adam}\!\left(\theta^{(t)}, \frac{\partial f}{\partial \theta}\right),$$

invariance to diagonal rescaling gives

$$\mathrm{adam}\!\left(\theta^{(t)}, \frac{\partial f}{\partial \theta}\right) = \mathrm{adam}\!\left(\theta^{(t)}, s \odot \frac{\partial f}{\partial \theta}\right)$$

for any positive-valued scaling vector $s \in (\mathbb{R}^+)^{|\theta|}$ that is constant over all timesteps $t$.

Proposition E.2. For any scaled op with training trajectory $\theta^{*,(t)}_i$ under Adam with $\epsilon = 0$, there exists an equivalent unscaled op with the same training trajectory.

Consider the unscaled op $\hat f(\ldots) = \alpha f(\ldots)$. This follows the trajectory

$$\theta^{(t+1)}_i = \mathrm{adam}\!\left(\theta^{(t)}_i,\, \alpha \frac{\partial f}{\partial \theta_i}\right).$$

Now consider the scaled op $f^*$ with the same $\alpha$, $f$. This follows:

$$\theta^{*,(t+1)}_i = \mathrm{adam}\!\left(\theta^{*,(t)}_i,\, \beta_i \frac{\partial f}{\partial \theta_i}\right) = \mathrm{adam}\!\left(\theta^{*,(t)}_i,\, \alpha \frac{\partial f}{\partial \theta_i}\right),$$

where the second equality uses invariance to diagonal rescaling with $s_i = \alpha/\beta_i$. Therefore if $\theta^{*,(0)} = \theta^{(0)}$, we conclude $\theta^{*,(t)} = \theta^{(t)}$.

E.3. Example: a scaled computational graph does not necessarily represent a scaled op

Let $f(x_1, \ldots, x_n)$ be an unscaled operation with values in $\mathbb{R}^n$ and consider the scaled computational graph defined by $x + f^*(x, \alpha, \beta_1, \ldots, \beta_n)$. If this scaled computational graph represented a scaled op $h^*(x_1, \ldots, x_n)$ for some function $h(x_1, \ldots, x_n)$, there would exist scalars $\alpha', \beta'_1, \ldots, \beta'_n$ such that:

$$\alpha'\, h(x) = x + f^*(x, \alpha, \beta)\,,$$
$$\beta'_i\, g_i\, \frac{\partial h(\ldots)}{\partial x_i} = g_i + f^*_{\mathrm{grad}}(x, \alpha, \beta, g)_i \quad \forall i \in \{1, \ldots, n\}\,.$$

Consider $f(x) = x^2$, so that

$$f^*(x, \alpha, \beta) = \alpha\, x^2\,, \qquad f^*_{\mathrm{grad}}(x, \alpha, \beta, g)_i = 2\beta_i\, x_i\, g_i \quad \forall i \in \{1, \ldots, n\}\,.$$

This implies

$$\frac{\beta'_i}{\alpha'}\, g_i\, (1 + 2\alpha x_i) = g_i + 2\beta_i\, g_i\, x_i \quad \forall i \in \{1, \ldots, n\}\,.$$

Assuming $g_i \neq 0$, in the case $\alpha \neq \beta_i$ these two expressions cannot be made to match by any choice of $(\alpha', \beta'_i)$. Therefore the scaled graph does not implement a scaled op.

E.4. Proof of Theorem 5.2

We first define how a computational graph represents an op. Then we show that an unscaled graph correctly represents an unscaled op. Finally, we proceed to show that a constraint-scaled graph with a single output correctly represents a scaled op.

Graph op We adopt a generalisation of the earlier definition of an op, to permit multiple outputs. An op defines mappings from $k$ vector-valued inputs to $m$ vector-valued outputs via $f(x_{i\,1\ldots k})_{j\,1\ldots m}$, and corresponding gradient mappings,

$$f_{\mathrm{grad}}(x_{i\,1\ldots k}, g_{j\,1\ldots m})_i \triangleq \sum_j g_j^\top \frac{\partial f(x_{i\,1\ldots k})_j}{\partial x_i}\,.$$

We use $f_G$ to denote the graph op represented by the computational graph $G$. To evaluate the function and the vector-Jacobian product $f_{\mathrm{grad},G}$, we assign inputs and outputs to edges in the graph.⁴ Define a list of input edges, $\mathrm{in}_{i\,1\ldots k} \in E$, and output edges, $\mathrm{out}_{j\,1\ldots m} \in E$. Define the forward value of an edge using $z : E \to \mathbb{R}^{(\cdot)}$, via the recursive relations:

$$z(\mathrm{in}_i) \triangleq x_i\,, \qquad z((u, v)) \triangleq f_u(\{z((w, u)) \mid (w, u) \in E\})_v\,, \qquad f_G(x_{i\,1\ldots k})_j \triangleq z(\mathrm{out}_j)\,,$$

where $f_u(\ldots)_v$ evaluates node $u$'s output corresponding to the edge $(u, v)$. Similarly, define the backward value of an edge using $h : E \to \mathbb{R}^{(\cdot)}$ via:

$$h(\mathrm{out}_j) \triangleq g_j\,, \qquad h((u, v)) \triangleq f_{\mathrm{grad},v}(\{z((u', v))\}, \{h((v, r))\})_u\,, \qquad f_{\mathrm{grad},G}(\ldots, g_{j\,1\ldots m})_i \triangleq h(\mathrm{in}_i)\,,$$

where $f_{\mathrm{grad},v}(\ldots)_u$ evaluates the grad op for node $v$ for the input $x_{v,u}$ corresponding to the edge $(u, v)$. Note that we use the shorthand $\{z((u', v))\}$ to denote $\{z((u', v)) \mid (u', v) \in E\}$.

Unscaled graph op To show that $(f_G, f_{\mathrm{grad},G})$ represent an op, we must show they are consistent with the definition of $f_{\mathrm{grad}}$. We expand the backward value using the definition of $f_{\mathrm{grad},v}$,

$$h((u, v)) = \sum_w h((v, w))^\top \frac{\partial f_v(\{z((u', v))\})_w}{\partial x_{v,u}}\,.$$

Using the base case for $h(\mathrm{out}_j)$ and the chain rule,

$$h((u, v)) = \sum_j g_j^\top \frac{\partial f_G(\ldots)_j}{\partial z((u, v))}\,.$$

Therefore $h(\mathrm{in}_i)$ gives the correct gradient, so $G$ correctly represents an op.
⁴It is often natural to assign inputs and outputs to nodes, but we use edges in our analysis for notational convenience. Such edges imply the existence of dummy nodes.

Constraint-scaled graph → scaled op Again, generalising the earlier definition to multiple outputs,

$$f^*(x_{i\,1\ldots k})_j \triangleq \alpha\, f(x_{i\,1\ldots k})_j\,, \qquad f^*_{\mathrm{grad}}(x_{i\,1\ldots k}, g_{j\,1\ldots m})_i \triangleq \beta_i \sum_j g_j^\top \frac{\partial f(x_{i\,1\ldots k})_j}{\partial x_i}\,.$$

Note that all outputs are scaled using a single value $\alpha$. Using the same definitions for $z$ and $h$,

$$h((u, v)) = \beta_{v,u} \sum_w h((v, w))^\top \frac{\partial f_v(\{z((u', v))\})_w}{\partial x_{v,u}} = \frac{\beta_{v,u}}{\alpha_v} \sum_w h((v, w))^\top \frac{\partial f^*_v(\{z((u', v))\})_w}{\partial x_{v,u}}\,.$$

In order to apply the chain rule here, we must first deal with the scale ratio $\beta_{v,u}/\alpha_v$. To do this, we define the unscaled backward value, $\hat h$, in terms of a single reachable output $\mathrm{out}$ and a rescaling function $s : E \times E \to \mathbb{R}$, thus:

$$\hat h((u, v)) \triangleq \frac{h((u, v))}{s((u, v), \mathrm{out})}\,, \qquad s(a, b) \triangleq \prod_{(u,v) \in E_{\mathrm{cut}}(a,b)} \frac{\beta_{v,u}}{\alpha_v}\,,$$

where $E_{\mathrm{cut}}(a,b)$ is the set of edges where, after the removal of any one, there is no path connecting the head of $a$ and the head of $b$ in $G$. We observe this property for adjacent edges:

$$\frac{s((v, w), \mathrm{out})}{s((u, v), \mathrm{out})} = \begin{cases} \alpha_v/\beta_{v,u} & \text{if } (u, v) \text{ is a cut-edge} \\ 1 & \text{otherwise} \end{cases}$$

which follows directly from the definition of $s$. Now we substitute into our grad,

$$\hat h((u, v)) = \sum_w \gamma(u, v, w)\, \hat h((v, w))^\top \frac{\partial f^*_v(\ldots)_w}{\partial x_{v,u}}\,, \qquad \gamma(u, v, w) \triangleq \frac{\beta_{v,u}}{\alpha_v} \cdot \frac{s((v, w), \mathrm{out})}{s((u, v), \mathrm{out})}\,.$$

Consider two cases:

Case 1: $(u, v)$ is not a cut-edge. The rules of constraint-scaled computational graphs ensure $\beta_{v,u} = \alpha_v$. From the aforementioned property, $s((u, v), \mathrm{out}) = s((v, w), \mathrm{out})$. So we conclude $\gamma(u, v, w) = 1$.

Case 2: $(u, v)$ is a cut-edge. From the same property, we conclude $\gamma(u, v, w) = 1$.

Since in either case $\gamma(u, v, w) = 1$, we can simplify:

$$\hat h((u, v)) = \sum_w \hat h((v, w))^\top \frac{\partial f^*_v(\ldots)_w}{\partial x_{v,u}}\,,$$

which is the correct form for the chain rule, permitting induction from the base case as previously, noting that $s(\mathrm{out}, \mathrm{out}) = 1$ so $\hat h(\mathrm{out}) = g$. We can therefore conclude that $\hat h$ gives true gradients and:

$$\hat h((u, v)) = g^\top \frac{\partial f^*_G(\ldots)}{\partial z((u, v))}\,, \qquad h((u, v)) = s((u, v), \mathrm{out})\, \hat h((u, v))\,.$$

So $G$ represents a scaled op with $\beta_i = s(\mathrm{in}_i, \mathrm{out})$.

F. Constraint-scaled computational graphs for other schemes

For the sake of comparison, it can be instructive to consider other scaling schemes within the constraint-scaled computational graph framework.

Glorot initialisation (Glorot and Bengio, 2010) For a layer $Y = f(XW)$, consider the scales $\sigma_Y$ and $\sigma_{\nabla_X L}$, ignoring $\sigma_{\nabla_W L}$. Apply full constraints, and typically use the arithmetic mean rather than the geometric mean to combine scales. Finally, push the combined scale into the initialisation of $W$, so that no multiplication is required at execution time.

Loss scaling (Micikevicius et al., 2018) Introduce a single scaled identity op before the loss: $f^*(x) = \alpha\, x$, $f^*_{\mathrm{grad}}(x, g) = \beta\, g$. Since this edge is always a cut-edge, set $\alpha = 1$, and use $\beta$ to generate gradients that all share a single scale. Unlike unit scaling, there are no local distributional assumptions that can inform the choice of loss scale: it must be chosen empirically or heuristically (a minimal code sketch of this op is given at the end of this section).

Scaled dot product self attention (Vaswani et al., 2017) When computing the similarity matrix $A = QK^\top$, $Q, K \in \mathbb{R}^{s \times d}$, consider the scale $\sigma_A$, ignoring $\sigma_{\nabla_Q L}$, $\sigma_{\nabla_K L}$. Apply fully constrained scaling, yielding $\alpha = \beta_1 = \beta_2 = 1/\sqrt{d}$. This is perhaps the best pre-existing example of a commonly employed scheme similar to unit scaling.
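Loss scaling expressed as the scaled identity op described above can be written directly as a custom autograd function. The following is an illustrative PyTorch sketch (the class name is ours), with α = 1 in the forward pass and an empirically chosen β in the backward pass.

```python
import torch

class LossScaleIdentity(torch.autograd.Function):
    """Loss scaling viewed as a scaled identity op: alpha = 1, beta = loss_scale."""

    @staticmethod
    def forward(ctx, x, loss_scale):
        ctx.loss_scale = loss_scale
        return x.view_as(x)                       # forward: f*(x) = 1 * x

    @staticmethod
    def backward(ctx, grad_y):
        return ctx.loss_scale * grad_y, None      # backward: beta * g
```

Inserting `x = LossScaleIdentity.apply(x, 512.0)` immediately before the loss would then rescale every gradient in the model by the single loss scale.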
G. Unit scaled ops compendium

Unit scaling relies on the correct selection of the scaling factors $\alpha, \beta_1, \ldots, \beta_k$ for a given op. These scaling factors are derived from an analysis of the scaling of a given operation and its corresponding grad op, as outlined in Section 5.2, with an example of analysing the scaling of a multilayer perceptron given in Appendix E. To avoid practitioners having to analyse the scaling characteristics of each op in their model by hand, we provide a reference for common ops in Table A.2, giving scaled versions of each op alongside the necessary scaling factors.

We provide further details on the derivation of certain non-trivial scaled operations below.

Activations We calculate the scaling of ReLU analytically, based on the analysis in Appendix E.1. The other activation functions given are not amenable to the same procedure, so we calculate their scaling empirically (this is done through the use of short programs, which only need consider functions in isolation rather than within a larger model).

Softmax (followed by matmul) We make the simplifying assumption in our analysis that the output of a softmax over $s$ normally-distributed elements is uniformly $1/s$. In practice, there is some variance across output elements but this is small enough to ignore for our purposes. This deviates from our standard unit scaling assumption of zero mean and unit variance, with $1/s$ mean and zero variance instead. Hence we require a different strategy for scaling softmax if we wish to still propagate unit scale.

We assume in this scenario that the softmax is followed by a matmul (as in multi-head self-attention). Based on this assumption, we scale by a factor of $s$, meaning the output is approximately a vector of ones. From the perspective of the subsequent matmul, its ideal choice of scaling factor is then identical to the scaling factor it would have required if its input were sampled from a unit normal distribution: $m^{-1/2}$, where $m$ is the size of the dimension reduced over. The subsequent matmul op can then be implemented using our standard scaling without any special-case behaviour. We also find through empirical analysis that the backward pass of softmax requires $s$ scaling, though in this direction it generates normally distributed values, conforming to our standard assumption.

Softmax cross-entropy We now consider a softmax going into a cross-entropy function, treating this composition as a single operation: $\mathrm{softmax\_xent}(x, t) = -\log \mathrm{softmax}(x)_t$ (where $t$ is the index of the target label), and assume that this is the final layer in a model used to generate a loss. On this basis, we need not consider forward scaling, and focus on the backward operation $\nabla x = \mathrm{softmax\_xent}_{\mathrm{grad}}(x, t)$ and the calculation of its scaling factor $\beta = 1/\sigma(\nabla x)$.

Assuming again that at the beginning of training the output of the softmax over $s$ inputs is uniformly $1/s$, the gradient of softmax cross-entropy is given by

$$\nabla x_i = \mathrm{softmax\_xent}_{\mathrm{grad}}(x, t)_i = \begin{cases} \frac{1}{s} - 1\,, & \text{if } i = t \\ \frac{1}{s}\,, & \text{otherwise} \end{cases}$$

where $x \in \mathbb{R}^s$. To calculate $\sigma(\nabla x)$ we first observe that

$$\mathbb{E}[\nabla x] = \frac{1}{s}\left(\left(\frac{1}{s} - 1\right) + (s - 1)\frac{1}{s}\right) = 0\,,$$

from which we derive

$$\sigma(\nabla x)^2 = \mathbb{E}\left[(\nabla x)^2\right] - \mathbb{E}[\nabla x]^2 = \frac{1}{s}\left(\left(\frac{1}{s} - 1\right)^2 + (s - 1)\frac{1}{s^2}\right) = \frac{s - 1}{s^2}\,.$$

This gives us our scaling factor, $\beta = s/\sqrt{s - 1}$.
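The β = s/√(s − 1) factor can be checked empirically. The sketch below uses PyTorch's cross-entropy (which combines the softmax and the negative log-likelihood) over an assumed batch of random logits at initialisation; the vocabulary size and batch size are illustrative.

```python
import torch
import torch.nn.functional as F

s, batch = 512, 4096
x = torch.randn(batch, s, requires_grad=True)          # logits at initialisation
t = torch.randint(0, s, (batch,))                      # random target labels
F.cross_entropy(x, t, reduction="sum").backward()      # grad_x = softmax(x) - onehot(t)

beta = s / (s - 1) ** 0.5
print(x.grad.std().item(), 1 / beta)                   # both ~ sqrt(s - 1) / s ~ 0.044
```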
Table A.2. Table of unit scaling factors, based on simple distributional assumptions on inputs and gradients, most often that they are unit normal.

Op                                                                  Unit scaling factors
matmul(X_{b×m}, W_{m×n})_{b×n} = XW                                 α = m^(-1/2), β_X = n^(-1/2), β_W = b^(-1/2)
sum(x) = Σ_{i=1}^{n} x_i                                            α = n^(-1/2), β = 1
weighted_add(x_{i 1...n}, γ_{i 1...n}) = Σ_{i=1}^{n} γ_i x_i        α = (Σ_i γ_i²)^(-1/2), β_i = γ_i^(-1)
ACTIVATIONS
relu(x) = max(x, 0)                                                 α = √(2 / (1 − 1/π)), β = √2
gelu(x) = x Φ(x)                                                    α = 1.701, β = 1.481
tanh(x) = (e^{2x} − 1)/(e^{2x} + 1)                                 α = 1.593, β = 1.467
sigmoid(x) = (1 + e^{−x})^{−1}                                      α = 4.802, β = 4.722
softmax(x)_i = e^{x_i} / Σ_{j=1}^{s} e^{x_j}                        α = s, β = s
softmax_xent(x, t) = −log softmax(x)_t                              α = 1, β = s/√(s − 1)
layer_norm(X_{b×n}, w, c)_{ij} = c_j + w_j (X_{ij} − μ_i)/σ_i,
  μ_i = (1/n) Σ_{j=1}^{n} X_{ij}, σ_i = √((1/n) Σ_{j=1}^{n} X_{ij}² − μ_i²)   α = 1, β_x = 1, β_w = β_c = b^(-1/2)

H. Aligning unit scaling with existing models

Our presentation of unit scaling in Section 5 assumes the design of a model from scratch. However, we anticipate there will be cases in which practitioners will wish to unit-scale existing models, such that their unit-scaled model and base model are either equivalent or similar enough to give matching performance. Here we outline the additional considerations required to do so. We follow this approach for our BERT experiments in Section 6.2.

H.1. Activation functions

We take "activation function" to mean any non-linear element-wise function in a model. Due to non-linearity, the behaviour of an activation function $f(x)$ depends on the scale of its input. Therefore a base model's activation functions may not have the same effect on their inputs as a unit-scaled version, as the unit-scaled model alters the scale of inputs. To counter this, one can introduce a scaling factor immediately before an activation function (temporarily breaking unit scale), and a second un-scaling factor immediately afterwards (restoring unit scale):

$$\hat f(\hat x) = f(s_1\, \hat x)\, s_2\,,$$

where $\hat f$ is our new aligned activation function, $\hat x$ is assumed to be normally distributed with unit scale (not necessarily true for $x$ in the base model), and $s_1, s_2 \in \mathbb{R}$ are our new scaling factors.

We select the first scaling factor such that $s_1 = \sigma(x)$, giving identical-scale inputs to both activation functions: $\sigma(s_1 \hat x) = \sigma(x)$. The second scaling factor is selected to restore unit scale: $s_2 = 1/\sigma(f(x))$, giving

$$\sigma(\hat f(\hat x)) = \frac{\sigma(f(\sigma(x)\, \hat x))}{\sigma(f(x))} = 1\,.$$

All that remains is the estimation of $\sigma(x)$ and $\sigma(f(x))$ in the base model. This can be done either analytically (by stepping through operations in the base model and calculating the expected scale at each) or empirically (via instrumentation of the base model). The latter method tends to be simpler and less error-prone, but the former is more mathematically rigorous and has the advantage of generating scaling factors that are a function of the model's hyperparameters.

Note that although we temporarily break the assumption of unit scale in the above analysis, in practice the scaling factors here are close enough to 1 that this momentary mis-scaling is negligible from a numerics perspective.

H.2. Softmax functions

The above analysis also applies to softmax functions. Although softmax is not an element-wise function, the same approach is still valid and $s_1$, $s_2$ should be chosen in the same way. Note that the standard softmax function is sometimes introduced with a temperature scalar $T$, by which all inputs are divided. Hence our method can be seen as tuning the effective temperature of the softmax to align the unit-scaled model with the base model.
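In code, this alignment is just a pair of fixed multipliers around the op. A minimal sketch for a gelu (the function name is ours; s1 and s2 stand for the base-model scales σ(x) and 1/σ(gelu(x)) estimated as described above):

```python
import torch.nn.functional as F

def aligned_gelu(x_hat, s1, s2):
    # Scale the unit-scaled input back to the base model's scale (s1 = sigma(x)),
    # apply the activation, then restore unit scale (s2 = 1 / sigma(gelu(x))).
    return F.gelu(s1 * x_hat) * s2
```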
H.3. Residual weighted add

In Section 5.3 we recommended that practitioners introduce a weighted addition into their models between residual and skip branches, in order to actively select how much each contributes to the output. A typical unscaled base model implicitly makes this choice via the scaling effect of the residual branch (i.e. the ratio $\sigma(f(x))/\sigma(x)$, which is typically ≠ 1). For our unit-scaled model to be equivalent to the base model, we need the output of our addition to be equal up to a constant (unit) scaling factor $\alpha$. Taking a fixed(τ) residual layer, this means we must maintain:

$$\sqrt{1 - \tau}\, \hat x + \sqrt{\tau}\, \hat f(\hat x) = \alpha\, (x + f(x))\,,$$

where $\hat f(\cdot)$ is the residual branch and $\hat x$ the input in our unit-scaled model. Thanks to unit scaling, we have $\hat x = x/\sigma(x)$ and $\hat f(\hat x) = f(x)/\sigma(f(x))$, giving

$$\sqrt{1 - \tau}\, \hat x + \sqrt{\tau}\, \hat f(\hat x) = \sqrt{1 - \tau}\, \frac{x}{\sigma(x)} + \sqrt{\tau}\, \frac{f(x)}{\sigma(f(x))}\,.$$

Our desired form requires the terms multiplying $x$ and $f(x)$ to be equal, meaning:

$$\frac{\sqrt{1 - \tau}}{\sigma(x)} = \frac{\sqrt{\tau}}{\sigma(f(x))} \quad \Longrightarrow \quad \tau = \frac{\sigma(f(x))^2}{\sigma(x)^2 + \sigma(f(x))^2}\,, \qquad \alpha = \frac{1}{\sqrt{\sigma(x)^2 + \sigma(f(x))^2}}\,,$$

recalling that our original definition of a fixed(τ) residual layer ensures that this still maintains a unit-scaled output. Hence to align the residual add operation with a base model, we first need to use a fixed(τ) residual layer, and secondly calculate $\sigma(x)$ and $\sigma(f(x))$ for the base model, plugging them into the above equation for τ. This calculation of σ in the base model can again be done analytically or empirically. For typical models, the correct value of τ is the same across layers.

H.4. Shared parameters

Weights used in multiple operations in the forward pass sum the weight gradients coming from those operations in the backward pass. The same argument used for the residual add applies to the alignment of this summation too: for a unit-scaled model to be equivalent, it must match the ratio of scales going into this sum as in the base model. Unit scaling will normalise these all to have σ = 1, but this is not guaranteed in the base model. The same analysis as used for the residual add op can be applied here, with the same outcome. The calculation of the scale of residual branches in the base model should be substituted with the scale of each weight gradient. In the case that the weight is used more than twice, the above argument will have to be generalised to multiple operands.

H.5. Example: aligning BERT

We follow the steps above in our experiments for Section 6.2, where we align unit-scaled BERT models against standard baseline models, to match performance. Here we outline how we apply the above rules in practice, along with a few additional considerations required due to specifics of the BERT architecture. Where these rules require the calculation of the standard deviation of tensors in the base model, we always calculate them analytically, rather than relying on empirical measurements (though we have then used empirical measurements to check the correctness of our calculations).

Embedding layer BERT contains three separate embeddings: a general word embedding, along with segment and positional embeddings. These are all combined using a summation at the beginning of the model. For unit scaling we must implement this using:

$$x_{\mathrm{emb}} = \mathrm{weighted\_add}\big((x_{\mathrm{word}}, x_{\mathrm{seg}}, x_{\mathrm{pos}}),\ (1, 1, 1)\big)\,.$$

Weights are equal here as the initial scales of the embeddings in the base model are unchanged from their initialisation, and all are initialised with the same scale.

FFN For the FFN, alignment need not be considered for the matmul and layernorm ops, which we scale using the set of scaling factors for common ops given in Table A.2. For the gelu activation function, we must follow the alignment process outlined above, applying scaling factors immediately before and after.

Multi-head self-attention For multi-head self-attention, we employ the rule for aligning softmax (followed by a matmul) given above. Again, matmuls do not require alignment with the base model.
We note that in the particular case of the matmul with the V tensor, our standard distributional assumption of independent elements no longer strictly holds, due to correlation across the sequence dimension introduced by the segment embedding. This requires a slight correction to ensure unit scaling is maintained.

Residual connection Both the FFN and multi-head self-attention layers are residuals, and as such employ the rule above for aligning weighted addition with a base model.

Loss heads We train BERT according to the standard procedure of using two heads: one for the masked-language-modelling (MLM) task, and one for the next-sentence-prediction (NSP) task. The NSP head uses a tanh activation function which requires alignment, and the MLM head reuses the weights of the word embedding for a matmul, requiring the above rule for aligning shared parameters. Each head is terminated by a softmax cross-entropy, which we also tune to match the base model.

Sequence length considerations Care must be taken when unit-scaling sequence-based models to account for the role of the sequence dimension. For many ops this effectively becomes an extra batch dimension, and must be handled as such when applying unit scaling. In our experiments we use padding to compensate for uneven-length input sequences. In this case the value used for our sequence calculations is not the length of the sequence dimension, but the average number of non-padding tokens in a sequence (for our experiments, this was approximately 77% of the padded length).

One additional complication specific to BERT is that the gradients flowing back into the final transformer layer are sparse, as they only come via the subset of tokens used in the two heads (specifically, the [CLASS] token, and those tokens masked for the MLM head). As a result, backward-pass sequence length calculations for this layer must be adapted to assume a smaller sequence length, according to the level of sparsity in the gradient.

I. Implementation

Unit scaling is straightforward to implement in deep learning frameworks such as PyTorch, JAX and TensorFlow that support user-defined custom-gradient (autograd) operations. A convenient way to do this is via a scaled identity op $\mathrm{id}^*(x, \alpha, \beta)$, which can be used to implement scaled ops without defining custom gradients for each.

I.1. Code examples

We show an example implementation in Figure 3, with additional code listings in Figure A.2, demonstrating basic tools for constructing unit-scaled models in PyTorch. Note:

scaled is the basic building block of unit-scaled models. It enables independent control of forward and backward pass scaling factors, and as such must be used with care: it could be used to define a scaled graph with incorrect constraints, leading to gradients that are inconsistent with the forward pass of the model.

scaled_matmul demonstrates how to combine multiple constraints using the geometric mean.

scaled_gelu implements only fully constrained scaling, for brevity. When scales are fully constrained, custom gradients via scaled are optional. Note that it may still be useful in certain situations for improving the scale of intermediate values.

ScaledLayerNorm uses the usual assumption for scaled layers: weights are cut-edges, activations are not. This permits independent scales for the weight and bias parameters.
import numpy as np
import torch
from torch import autograd, matmul, nn, tensor
from torch.nn.functional import gelu


class ScaledGrad(autograd.Function):
    @staticmethod
    def forward(ctx, X, alpha, beta):
        ctx.save_for_backward(tensor(beta, dtype=X.dtype))
        return alpha * X

    @staticmethod
    def backward(ctx, grad_Y):
        beta, = ctx.saved_tensors
        return beta * grad_Y, None, None


def scaled(X, alpha=1, beta=1):
    # Forward: Y = X * alpha
    # Backward: grad_X = grad_Y * beta
    return ScaledGrad.apply(X, alpha, beta)


def scaled_matmul(A, B, constrain_A=True, constrain_B=True):
    (m, k), (_, n) = A.shape, B.shape
    alpha = k ** -(1 / 2)
    beta_A = n ** -(1 / 2)
    beta_B = m ** -(1 / 2)
    # Combine constrained scales using the geometric mean
    if constrain_A and constrain_B:
        alpha = beta_A = beta_B = (alpha * beta_A * beta_B) ** (1 / 3)
    elif constrain_A:
        alpha = beta_A = (alpha * beta_A) ** (1 / 2)
    elif constrain_B:
        alpha = beta_B = (alpha * beta_B) ** (1 / 2)
    A = scaled(A, beta=beta_A)
    B = scaled(B, beta=beta_B)
    return scaled(matmul(A, B), alpha)


def scaled_gelu(X):
    return 1.5876 * gelu(X)


class ScaledLayerNorm(nn.LayerNorm):
    def forward(self, x):
        beta = (np.prod(self.normalized_shape) / x.nelement()) ** 0.5
        return nn.functional.layer_norm(
            x,
            self.normalized_shape,
            scaled(self.weight, beta=beta),
            scaled(self.bias, beta=beta),
            self.eps,
        )

Figure A.2. Definition of scaled in PyTorch, as a custom autograd function, plus additional scaled ops and layers required for a Transformer FFN. See Table A.2 for a reference of scaling factors.

I.2. Computational overhead

Unit scaling typically introduces one extra function invocation per invocation in the equivalent unscaled model. For example, matmul typically involves 3 function invocations during training, corresponding to 1 forward and 2 backward functions (one for each input). Using unit scaling, there are 3 additional rescaling function invocations of the form $f(x, \gamma) = \gamma\, x$, where $\gamma \in \mathbb{R}$, $x \in \mathbb{R}^n$.

FLOPs Considering the typical theoretical metric for computational effort, floating point operations (FLOPs), the overhead appears much smaller. For the matmul op with forward pass $\mathrm{matmul} : \mathbb{R}^{b \times n} \times \mathbb{R}^{n \times m} \to \mathbb{R}^{b \times m}$, the amount of computational effort due to the 3 matmuls is $6\, b\, n\, m$ (note this is ×2 because multiply and add are counted separately), while rescaling consumes $bn + nm + bm$. Therefore the ratio of rescaling to matmul FLOPs follows:

$$\frac{\mathrm{FLOP}_{\mathrm{rescaling}}}{\mathrm{FLOP}_{\mathrm{matmul}}} = \frac{1}{6}\left(b^{-1} + m^{-1} + n^{-1}\right).$$

Note that this is bounded above by $(2 \min(b, n, m))^{-1}$. For the matmuls that dominate compute in many models, this minimum dimension corresponds to the hidden size.

There are also operations other than matmuls that require scaling, but these contribute negligible FLOPs. To simplify analysis, we'll assume that there are $(\mathrm{ops\_per\_matmul} - 1)$ additional ops for every matmul in the model. So we write $\mathrm{FLOP}_{\mathrm{matmul}+} = \mathrm{FLOP}_{\mathrm{matmul}}$ and $\mathrm{FLOP}_{\mathrm{rescaling}+} = \mathrm{ops\_per\_matmul} \times \mathrm{FLOP}_{\mathrm{rescaling}}$. This gives the following adjusted estimate for the FLOP overhead of unit scaling a model:

$$\frac{\mathrm{FLOP}_{\mathrm{rescaling}+}}{\mathrm{FLOP}_{\mathrm{unscaled}}} = \frac{\mathrm{ops\_per\_matmul}}{2 \times \mathrm{hidden\_size}}\,.$$

In the example of BERTLARGE, we set hidden_size = 1024, pessimistically estimate ops_per_matmul = 4, and obtain a FLOP overhead of 0.2%. Other large models should behave in a similar manner, so we conclude that the theoretical FLOP overhead of unit scaling is small for large models. Actual performance will depend on many other factors, and we anticipate that FLOP-based measures are likely to be optimistic in predicting runtime overhead on typical deep learning hardware. However, we expect the efficiency gains of low-precision formats to vastly outweigh the scaling overhead.
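The closing estimate can be reproduced with a one-line helper (names are ours, following the quantities defined above):

```python
def rescaling_flop_overhead(hidden_size, ops_per_matmul=4):
    # FLOP_rescaling+ / FLOP_unscaled ~= ops_per_matmul / (2 * hidden_size)
    return ops_per_matmul / (2 * hidden_size)

print(f"{rescaling_flop_overhead(1024):.2%}")  # ~0.20% for the BERT-LARGE hidden size
```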
Fusing scale factors We anticipate substantial efficiency gains from fusing the fixed scale factors from unit scaling into preceding ops. This yields two potential benefits. First, fusing avoids the communication overhead of an extra round-trip to memory. Second, it may permit low-precision outputs and even intermediate values. This may be particularly valuable for distributed aggregation ops, where partial results are aggregated on separate workers before sharing them to compute a final result. Transformations implementing automatic fusing of ops are widely available using optimising compilers such as XLA (XLA and TensorFlow teams, 2017). These are particularly effective at fusing consecutive elementwise ops, which should encompass most unit scaling factors (since matmul outputs are typically first used in add or activation functions).

J. Additional experimental details and results

J.1. Character language modelling

The WikiText-103 raw dataset consists of approximately 500 million characters of text extracted from Wikipedia articles. We do not perform any additional preprocessing beyond that of the published dataset. All results correspond to the best value over a learning rate sweep starting from a low value, in multiplicative steps of 2. A complete set of hyperparameters used is shown in Table A.3.

Mixed precision When running in FP16, all activations, parameters and gradients are stored in FP16. Optimiser state is also stored in FP16, with the exception of Adam's second-moment state, which is stored in FP32, since squared values are more prone to clipping.

Model architectures All models are based on causal Transformer-like stacks that interleave contextual (i.e. token-mixing) layers and FFN layers. Input tokens are embedded by indexing into a trainable embedding table, and output token probabilities are generated by $\mathrm{softmax}(W_{\mathrm{proj}}\, \mathrm{layernorm}(x_L) + b_{\mathrm{proj}})$, where $x_L$ is the final hidden state from the Transformer stack. The basic unscaled layer definition follows:

$$x_{l+1} = \mathrm{res}(\mathrm{ffn}, \mathrm{res}(\mathrm{context}, x_l))$$
$$\mathrm{res}_{\mathrm{NoNorm}}(f, z) = \mathrm{interp}(z, f(z))$$
$$\mathrm{res}_{\mathrm{PreNorm}}(f, z) = \mathrm{interp}(z, f(\mathrm{layernorm}(z)))$$
$$\mathrm{res}_{\mathrm{PostNorm}}(f, z) = \mathrm{layernorm}(\mathrm{interp}(z, f(z)))$$
$$\mathrm{interp}_{\mathrm{default}}(a, b) = a + b$$
$$\mathrm{interp}_{\mathrm{fixed}}(a, b; \tau) = \sqrt{1 - \tau}\, a + \sqrt{\tau}\, b$$
$$\mathrm{interp}_{\mathrm{mean}}(a, b; l) = \sqrt{l/(l+1)}\, a + \sqrt{1/(l+1)}\, b$$
$$\mathrm{ffn}(z) = W_2 \max(0, W_1 z + b_1) + b_2$$

Figure A.3. Comparison of residual scaling approaches (validation BPC for attention-2L, attention-8L, conv-2L, conv-8L and rnn-2L models; left panel: Default vs Fixed (τ = 0.5); right panel: Fixed (τ = 0.5) vs Running-mean). We observe (a) for regular models, default scaling performs similarly to fixed interpolation τ = 0.5; (b) in most cases, running-mean scaling is similar to or better than fixed interpolation. The exception is 2-layer attention models, where we hypothesise that running mean places too much weight on the first layer, which is detrimental in such a shallow model.

The contextual layers are as follows (a code sketch of the interpolation functions is given after this list):

1. context_Attention: multi-head dot product self-attention using causal masking (Vaswani et al., 2017), with relative-positional encoding using sinusoidal bases following Dai et al. (2019),

2. context_Conv: 1D grouped causal convolution with relu nonlinearity,

3. context_RNN: recurrent highway network (Zilly et al., 2017) with tied transform and carry gates $x_{t+1} = (1 - g(x_t)) \odot x_t + g(x_t) \odot f(x_t)$, where $g(x)$ is a projection with sigmoid nonlinearity, and $f(x)$ is a projection with tanh nonlinearity.
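For concreteness, the fixed and running-mean interpolation functions defined above can be written as the following sketch (function names follow the definitions; the layer index l is as above). Both preserve unit scale when the skip and residual inputs are independent and unit-scaled, since the squared weights sum to one.

```python
def interp_fixed(skip, residual, tau=0.5):
    # interp_fixed(a, b; tau) = sqrt(1 - tau) * a + sqrt(tau) * b
    return (1 - tau) ** 0.5 * skip + tau ** 0.5 * residual

def interp_mean(skip, residual, l):
    # interp_mean(a, b; l) = sqrt(l / (l + 1)) * a + sqrt(1 / (l + 1)) * b
    return (l / (l + 1)) ** 0.5 * skip + (1 / (l + 1)) ** 0.5 * residual
```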
When applying unit scaling, we also reduce the learning rate for non-projection parameters by 1/hidden size to compensate for the relative step size increase implied by unit scaling.

Additional results Test set results, with multiple runs per learning rate, are shown in Table A.4. These support the main findings shown for the wider sweep of Figure 4: unit-scaled models perform comparably to regular models, and can be trained in FP16 without modification or additional hyperparameter selection.

Figure A.3 shows the effect of employing the residual scaling schemes described in Section 5.2. This supports the claim that fixed and running-mean residual scaling are viable alternatives to default scaling, since both perform well in regular and unit-scaled models.

Table A.3. Character language modelling hyperparameters.

Parameter                                   Value
Sequence length                             256 characters
Sequence mask                               32 characters
Batch size                                  2048 characters
Training duration                           2^19 steps
Learning rate decay half-life               2^16 steps
Adam (β1, β2)                               (0.9, 0.999)
SGD momentum                                0.9
Vocabulary size                             5008 characters (100% coverage, no OOV)
Hidden size                                 128
FFN size                                    512
Depth                                       [2, 8] layers
Attention heads                             2
Attention head size                         64
Relative positional frequency components    128 bases, period [1 ... 1024] characters
Convolution kernel size                     7
Convolution group size                      16
Typical learning rate ranges:
  Regular, SGD                              2^-8 ... 2^-4
  Regular, Adam                             2^-12 ... 2^-8
  Unit, SGD                                 2^-14 ... 2^-10
  Unit, Adam                                2^-8 ... 2^-4

Table A.4. Character language modelling, test BPC with 3 runs per learning rate. The best learning rate is chosen according to validation BPC. 95% confidence interval is ±0.010. All models use PreNorm and 8 layers, except where noted.

Model                   Regular FP32    Unit scaling FP32    Unit scaling FP16
Attention (PostNorm)    1.548           1.540                1.540
Attention               1.582           1.562                1.567
Convolution             1.625           1.620                1.622
RNN (2 layers)          1.674           1.677                1.673

J.2. Masked language modelling

We follow the standard practice of splitting BERT pretraining into two phases. For the first phase we use a sequence length of 128 tokens, and for the second we use 384. Tokens are derived using the WordPiece tokeniser (Wu et al., 2016), with a vocabulary of 30400 tokens. Our masking approach is consistent with that used in Devlin et al. (2019). A complete set of pretraining hyperparameters used is shown in Table A.5.

Table A.5. BERT pre-training hyperparameters.

Parameter                       Value
Sequence length                 [128, 384] tokens (phase 1/2)
Depth                           [12, 24] (base/large)
Hidden size                     [768, 1024] (base/large)
FFN size                        [3072, 4096] (base/large)
Attention heads                 [12, 16] (base/large)
Attention head size             64
Vocabulary size                 30400
Total batch size                [16320, 4080] seqs (phase 1/2)
Micro-batch size                [8, 2] (phase 1/2)
Data-parallel count             4
Gradient accumulation count     510
Training duration               [28266, 8437] steps (phase 1/2)
Learning rate                   [0.0045, 0.0015] (phase 1/2)
Warmup steps                    [2827, 275] steps (phase 1/2)
Learning rate decay             linear
Optimiser                       LAMB
LAMB beta1                      0.9
LAMB beta2                      0.999
LAMB epsilon                    1e-06
Weight decay                    0.01
Weight init std                 0.02 (unit scaling: n/a)
Loss scaling                    [512, 512, 32768, 128] (base phase 1/2, large phase 1/2; unit scaling: n/a)

Mixed precision For FP16, we follow the same approach here as in our character language modelling experiments (Appendix J.1), storing all tensors and optimiser state in FP16, apart from the optimiser second-moment state, which is stored in FP32 (note, we use the LAMB optimiser (You et al., 2019) here rather than Adam). For FP8, we modify our FP16 mixed precision strategy by quantising the inputs to all matmul operations.
Note that our experiments do not utilise hardware FP8 support; we instead simulate FP8 training by quantising from FP16 to the set of supported values in a given FP8 format. In this, we are following the approach taken by Noune et al. (2022) and Micikevicius et al. (2022). As recommended in both studies, we also use E4 for activations and weights, and E5 for all gradients. Again following the precedent set in these studies, the one matmul operation we exclude from FP8 quantisation is the vocabulary embedding matmul, which has been known to cause numerical instabilities.

Hardware & distributed training Models were trained on IPU hardware (Jia et al., 2019), using either Bow Pod16 or IPU-POD16 Classic machines. On each machine we distribute training across 16 IPUs, using 4-way model parallelism and 4-way pipeline parallelism, with gradient accumulation across pipeline stages.

K. Histograms of tensor-scaling within BERT

To give readers a better intuitive sense of how loss scaling and unit scaling operate for a standard model, we provide histograms of absolute tensor values taken from FP16 BERTBASE. Figures A.4 and A.5 show the beginning of training for loss and unit scaling respectively, and Figures A.6 and A.7 show the end of training. We use 9 transformer layers rather than the standard 12 in order to accommodate the overheads of tracking histograms across all tensors in the model. For the sake of concision we omit histograms of the middle layers, which are substantially similar to layers 0 and 7 in both the forward and backward pass. A small number of numerically insignificant ops are also omitted.

The first two figures can be understood as the full-model equivalent to the plot in Figure 1, with the second two showing how values shift as a result of training. The x-axis is labelled slightly differently to Figure 1, showing the log of absolute values rather than the exponent value, but by the definition of floating point values given in Section 3.1, these two are approximately equivalent. We also have a special bin for the range $[2^{-24}, 2^{-14})$, which represents all subnormal values in the FP16 range, and bins on either end to hold zero and infinity values.

There are some surprising features in the shapes of these plots, resulting from the design of BERT. We provide a brief analysis here of our key plot: Figure A.5 (unit scaling at initialisation).

K.1. Analysis of Figure A.5

Impact of unit scaling A comparison with Figure A.4 demonstrates the effectiveness of unit scaling. Whereas the loss-scaled model has to tune a hyperparameter to centre the two gradient sub-plots, unit scaling does this naturally. Furthermore, values in the unit-scaled model are typically closer to the centre of the range. Loss scaling also has the problem of very large gradx values in its NSP and MLM heads.

Effect of aligning with regular BERT As outlined in Appendix H.5, we take a range of measures to align our unit-scaled model more closely with the regular BERT base model, so that their performance is similar. This has the impact of temporarily mis-scaling certain operations. This can be seen most clearly in the case of gelu, which requires a scaling factor for alignment, but as a result is slightly below unit scale in the diagram.

Sparse gradients for layer 8 The gradx values for layer 8 in all plots have most of their values set to zero. This is a consequence of sparse gradients flowing back into this layer from the NSP and MLM heads, as described in Appendix H.5.
The cross-sequence mixing of gradients in the multi-head self-attention layer has the effect of removing this sparsity, giving a strong signal for all subsequent layers.

Three groups of gradient scales Our final observation is somewhat subtle, but key to understanding both the shape of the gradx plots, and the particular difficulties encountered when training BERT in low precision. We note that in the gradx plots there are in effect three separate columns visible: a strong signal (i.e. many values) on the left, a faint signal through the centre, and a very small number of values on the right. This is a consequence of BERT's design, rather than of any scaling technique.

The right-hand column is a result of the natural up-scaling of gradients flowing from BERT's NSP head. BERT naturally has larger gradients flowing out of this head. Note that these gradients are sparse, representing only a single token-gradient in each sequence, but the signal is kept alive throughout the layers by the residual connection, resulting in this feature of the plot. The central column comes out of the MLM head in a similar fashion. This is still sparse, but contains more token-gradients and hence gives a stronger signal. Finally, the main left-hand column results from the mixing of gradients in the multi-head self-attention layer. This removes sparsity in the tensor, giving a stronger signal. However, the attention mechanism in BERT naturally lowers the scale of values, meaning this third signal is shifted to the left.

The existence of these three groups of gradients creates a trimodal distribution of exponent values. As most values are still concentrated in the left-hand column, our assumption of a single normal distribution is still sufficient, but we effectively have to balance the positions of these three columns, meaning that the backward pass does not fall into a single, neat column.
[Figures A.4 and A.5 each show histograms of log2(|value|) for every tensor in the model (embeddings, per-layer attention and FFN activations, weights, gradx and gradw, and the MLM/NSP heads), with shade indicating the percentage of values per bin.]

Figure A.4. A histogram of absolute values in regular BERTBASE at initialisation. Here a loss scale of 2^15 was required for stable training. We can understand loss scaling in light of this plot as enacting a shift of the gradx and gradw histograms by log2(loss scale) to the right.

Figure A.5. A histogram of absolute values in unit-scaled BERTBASE at initialisation. Unit scaling naturally places values in approximately the centre of the range without requiring a tuned hyperparameter. See Appendix K.1 for specific details of this plot.
[Figures A.6 and A.7 follow the same layout as Figures A.4 and A.5: histograms of log2(|value|) per tensor, with shade indicating the percentage of values per bin.]

Figure A.6. A histogram of absolute values in regular BERTBASE at the end of training. Compare with Figure A.4 to see the shift in distributions during training and the implications for numerics.

Figure A.7. A histogram of absolute values in unit-scaled BERTBASE at the end of training. Compare with Figure A.5 to see the shift in distributions during training and the implications for numerics.