# Memory-Efficient Backpropagation Through Time

Audrūnas Gruslys (audrunas@google.com), Rémi Munos (munos@google.com), Ivo Danihelka (danihelka@google.com), Marc Lanctot (lanctot@google.com), Alex Graves (gravesa@google.com), Google DeepMind

## Abstract

We propose a novel approach to reducing the memory consumption of the backpropagation through time (BPTT) algorithm when training recurrent neural networks (RNNs). Our approach uses dynamic programming to balance a trade-off between caching of intermediate results and recomputation. The algorithm is capable of fitting tightly within almost any user-set memory budget while finding an optimal execution policy that minimizes the computational cost. Computational devices have limited memory capacity, and maximizing computational performance given a fixed memory budget is a practical use case. We provide asymptotic computational upper bounds for various regimes. The algorithm is particularly effective for long sequences: for sequences of length 1000, our algorithm saves 95% of memory usage while using only one third more time per iteration than standard BPTT.

## 1 Introduction

Recurrent neural networks (RNNs) are artificial neural networks where connections between units can form cycles. They are often used for sequence mapping problems, as they can propagate hidden state information from early parts of the sequence to later points. LSTM [9] in particular is an RNN architecture that has excelled in sequence generation [3, 13, 4], speech recognition [5] and reinforcement learning [12, 10] settings. Other successful RNN architectures include the differentiable neural computer (DNC) [6], the DRAW network [8], and Neural Transducers [7].

The backpropagation through time (BPTT) algorithm [11, 14] is typically used to obtain gradients during training. One important problem is the large memory consumption required by BPTT. This is especially troublesome when using graphics processing units (GPUs) due to the limitations of GPU memory. The memory budget is typically known in advance. Our algorithm balances the trade-off between memorization and recomputation by finding an optimal memory usage policy which minimizes the total computational cost for any fixed memory budget. The algorithm exploits the fact that the same memory slots may be reused multiple times. The idea of using dynamic programming to find a provably optimal policy is the main contribution of this paper. Our approach is largely architecture agnostic and works with most recurrent neural networks. Being able to fit within memory-limited devices such as GPUs will typically compensate for any increase in computational cost.

## 2 Background and related work

In this section, we describe the key terms and relevant previous work on memory saving in RNNs.

Definition 1. An RNN core is a feed-forward neural network which is cloned (unfolded in time) repeatedly, where each clone represents a particular time point in the recurrence. For example, if an RNN has a single hidden layer whose outputs feed back into the same hidden layer, then for a sequence of length t the unfolded network is feed-forward and contains t RNN cores.

Definition 2. The hidden state of the recurrent network is the part of the output of the RNN core which is passed into the next RNN core as an input. In addition to the initial hidden state, there exists a single hidden state per time step once the network is unfolded.
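To make Definitions 1 and 2 concrete, here is a minimal sketch that unrolls a toy tanh core over a short sequence; the core itself, its parameter shapes, and the names `rnn_core` and `unroll` are illustrative assumptions rather than anything prescribed by the paper.

```python
import numpy as np

def rnn_core(params, h_prev, x):
    """One clone of the RNN core: maps (previous hidden state, input) to a new
    hidden state plus the cached values needed later for backpropagation."""
    W_h, W_x, b = params
    h = np.tanh(W_h @ h_prev + W_x @ x + b)   # new hidden state (Definition 2)
    internal = (h_prev, x, h)                 # see Definition 3 below
    return h, internal

def unroll(params, h0, xs):
    """Unfold the core in time (Definition 1): one clone per input step."""
    h, internals = h0, []
    for x in xs:
        h, internal = rnn_core(params, h, x)
        internals.append(internal)            # standard BPTT stores all of these
    return h, internals

# Example: hidden size 4, input size 3, sequence length 5.
rng = np.random.default_rng(0)
params = (rng.normal(size=(4, 4)), rng.normal(size=(4, 3)), np.zeros(4))
h_final, internals = unroll(params, np.zeros(4), rng.normal(size=(5, 3)))
```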
Definition 3. The internal state of the RNN core for a given time point is all the information required to backpropagate gradients over that time step once an input vector, a gradient with respect to the output vector, and a gradient with respect to the output hidden state are supplied. We define it to also include the output hidden state. An internal state can be (re)evaluated by executing a single forward operation taking the previous hidden state and the respective entry of the input sequence as inputs. For most network architectures, the internal state of the RNN core will include the input hidden state, as this is normally required to evaluate gradients. This particular choice of definition will be useful later in the paper.

Definition 4. A memory slot is a unit of memory which is capable of storing a single hidden state or a single internal state (depending on the context).

### 2.1 Backpropagation through Time

Backpropagation through time (BPTT) [11, 14] is one of the most commonly used techniques to train recurrent networks. BPTT unfolds the neural network in time by creating several copies of the recurrent units, which can then be treated like a (deep) feed-forward network with tied weights. Once this is done, standard forward propagation can be used to evaluate network fitness over the whole sequence of inputs, while standard backpropagation can be used to evaluate partial derivatives of the loss criterion with respect to all network parameters. This approach, while computationally efficient, is fairly memory intensive, because the standard version of the algorithm effectively requires storing the internal states of the unfolded network core at every time step in order to evaluate correct partial derivatives.

### 2.2 Trading memory for computation time

The general idea of trading computation time for memory consumption in general computation graphs has been investigated in the automatic differentiation community [2]. Recently, the rise of deep architectures and recurrent networks has increased interest in the less general case where the graph of forward computation is a chain and gradients have to be chained in reverse order. This simplification leads to relatively simple memory-saving strategies and heuristics. In the context of BPTT, instead of storing hidden network states, some of the intermediate results can be recomputed on demand by executing an extra forward operation.

Chen et al. proposed subdividing the sequence of length t into √t equal parts and memorizing only the hidden states between the subsequences and all internal states within each segment [1]. This uses O(√t) memory at the cost of making one additional forward pass on average, since once the errors are backpropagated through the right side of the sequence, the second-to-last subsequence has to be restored by repeating a number of forward operations. We refer to this as Chen's √t algorithm. The authors also suggest applying the same technique recursively several times, subdividing the sequence into k equal parts and terminating the recursion once the subsequence length becomes less than k. They establish that this leads to memory consumption of O(k log_{k+1}(t)) and computational complexity of O(t log_k(t)). This algorithm has a minimum possible memory usage of log_2(t) in the case when k = 1. We refer to this as Chen's recursive algorithm.
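As a concrete illustration, here is a simplified sketch of the √t strategy (not the implementation from [1]): one hidden state is checkpointed at the start of each segment of length roughly √t during the forward pass, and each segment's internal states are then rebuilt from its checkpoint, last segment first, before being backpropagated through. The interfaces `forward_core(h, x) -> (h_next, internal)` and `backward_core(internal, dL_dh) -> dL_dh_prev` are assumed placeholders (for instance, the toy core above with its parameters bound); parameter gradients are treated as a side effect of `backward_core`.

```python
import math

def sqrt_t_bptt(forward_core, backward_core, h0, xs, dL_dh_T):
    """Backpropagate over xs while storing only ~sqrt(t) hidden-state checkpoints
    plus the internal states of one segment at a time."""
    t = len(xs)
    seg = max(1, int(math.isqrt(t)))           # segment length ~ sqrt(t)
    # Forward pass: keep one checkpoint at the start of every segment.
    checkpoints, h = {}, h0
    for i, x in enumerate(xs):
        if i % seg == 0:
            checkpoints[i] = h
        h, _ = forward_core(h, x)              # internal state discarded here
    # Backward pass: rebuild each segment from its checkpoint, newest first.
    dL_dh = dL_dh_T
    for start in sorted(checkpoints, reverse=True):
        end = min(start + seg, t)
        h, internals = checkpoints[start], []
        for i in range(start, end):            # extra forward pass for this segment
            h, internal = forward_core(h, xs[i])
            internals.append(internal)
        for i in reversed(range(start, end)):
            dL_dh = backward_core(internals[i - start], dL_dh)
    return dL_dh                               # gradient w.r.t. the initial hidden state
```

For clarity this sketch re-runs even the final segment; Chen's algorithm instead keeps the internal states of the segment currently being processed during the forward pass, which is what yields an average of one extra forward pass.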
## 3 Memory-efficient backpropagation through time

We first discuss two simple examples: when memory is very scarce, and when it is somewhat limited.

When memory is very scarce, it is straightforward to design a simple but computationally inefficient algorithm for backpropagation of errors on RNNs which uses only a constant amount of memory. Every time the state of the network at time t has to be restored, the algorithm simply re-evaluates that state by forward-propagating inputs from the beginning of the sequence until time t. As backpropagation happens in the reverse temporal order, results from previous forward steps cannot be reused (as there is no memory to store them). This requires repeating up to t forward steps before backpropagating gradients one step backwards (we only remember the inputs and the initial state), producing an algorithm that needs t(t + 1)/2 forward passes to backpropagate errors over t time steps. The algorithm is O(1) in space and O(t²) in time.

When memory is somewhat limited (but not very scarce), we may store only the hidden RNN states at all time points. When errors have to be backpropagated from time t to t − 1, an internal RNN core state can be re-evaluated by executing another forward operation taking the previous hidden state as an input. The backward operation can follow immediately. This approach can lead to fairly significant memory savings, as typically the recurrent network hidden state is much smaller than the internal state of the network core itself. On the other hand, it requires an extra forward operation during the backpropagation stage.

### 3.1 Backpropagation through time with selective hidden state memorization (BPTT-HSM)

The idea behind the proposed algorithm is to compromise between the two previous extremes. Suppose that we want to forward- and backpropagate a sequence of length t, but we are only able to store m hidden states in memory at any given time. We may reuse the same memory slots to store different hidden states during backpropagation. Also, suppose that we have a single RNN core available for the purposes of intermediate calculations, which is able to store a single internal state. Define C(t, m) as the computational cost of backpropagation, measured in terms of how many forward operations one has to make in total during the forward and backpropagation steps combined, when following an optimal memory usage policy minimizing the computational cost. One can easily set the boundary conditions: C(t, 1) = t(t + 1)/2 is the cost of the minimal-memory approach, while C(t, m) = 2t − 1 for all m ≥ t when memory is plentiful (as shown in Fig. 3 a).

Our approach is illustrated in Figure 1. Once we start forward-propagating steps at time t = t0, at any given point y > t0 we can choose to put the current hidden state into memory (step 1). This step has the cost of y forward operations. States will be read in the reverse of the order in which they were written, which allows the algorithm to store states in a stack. Once the state is put into memory at time y = D(t, m), we can reduce the problem into two parts by using a divide-and-conquer approach: running the same algorithm on the part of the sequence with time greater than y while using m − 1 of the remaining memory slots, at the cost of C(t − y, m − 1) (step 2), and then reusing all m memory slots when backpropagating on the part with time at most y, at the cost of C(y, m) (step 3). We use the full size-m memory capacity when performing step 3 because the hidden state stored at y can be released immediately after finishing step 2.

Figure 1: The proposed divide-and-conquer approach. Step 1 (cost y): the hidden state at position y is saved in memory. Step 2 (cost C(t − y, m − 1)): recursive application of the algorithm on the right-hand part of the sequence. Step 3 (cost C(y, m)): the stored hidden state is read back, removed from memory, and the algorithm is applied recursively to the left-hand part.

The base case of the recursion is simply a sequence of length t = 1, for which forward and backward propagation may be done trivially on the single available RNN core. This step has cost C(1, m) = 1.
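Steps 1–3 and the two boundary regimes translate directly into a recursive procedure. The following is a schematic sketch under assumed interfaces, not the authors' implementation: it expects a policy table D, keyed by (t, m) for t > m ≥ 2, whose construction by dynamic programming is derived next (and sketched after the equations), and it reuses the placeholder `forward_core`/`backward_core` interfaces from the earlier √t sketch. Only the hidden-state gradient chain is shown, and the accounting of the m memory slots is implicit in the recursion.

```python
def bptt_hsm(h_start, xs, dL_dh_end, m, D, forward_core, backward_core):
    """Backpropagate through a segment starting from hidden state h_start, given the
    gradient w.r.t. the hidden state at the segment's end, using at most m stored
    hidden states for this segment."""
    t = len(xs)
    if t <= m:
        # Plentiful memory, C(t, m) = 2t - 1: keep the hidden state before every step,
        # then rebuild each internal state with one forward op while backpropagating.
        hiddens, h = [h_start], h_start
        for x in xs[:-1]:
            h, _ = forward_core(h, x)
            hiddens.append(h)
        dL_dh = dL_dh_end
        for i in reversed(range(t)):
            _, internal = forward_core(hiddens[i], xs[i])
            dL_dh = backward_core(internal, dL_dh)
        return dL_dh
    if m == 1:
        # Scarce memory, C(t, 1) = t(t + 1)/2: re-run the forward pass from the start
        # of the segment for every step that has to be backpropagated.
        dL_dh = dL_dh_end
        for j in reversed(range(t)):
            h = h_start
            for i in range(j + 1):
                h, internal = forward_core(h, xs[i])
            dL_dh = backward_core(internal, dL_dh)
        return dL_dh
    # Step 1: forward-propagate y = D(t, m) steps and store the hidden state reached.
    y = D[(t, m)]
    h = h_start
    for i in range(y):
        h, _ = forward_core(h, xs[i])
    # Step 2: recurse on the right part (length t - y) with one slot now occupied.
    dL_dh_y = bptt_hsm(h, xs[y:], dL_dh_end, m - 1, D, forward_core, backward_core)
    # Step 3: the state stored at y is no longer needed, so the left part (length y)
    # may again use all m slots.
    return bptt_hsm(h_start, xs[:y], dL_dh_y, m, D, forward_core, backward_core)
```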
Figure 2: Computational cost per time step when the algorithm is allowed to remember 10 (red), 50 (green), 100 (blue), 500 (violet), or 1000 (cyan) hidden states. (a) Theoretical computational cost, measured in number of forward operations per time step. (b) Measured computational cost in milliseconds. The grey line shows the performance of standard BPTT without memory constraints; (b) also includes a large constant caused by the single backward step per time step, which was excluded from the theoretical computation and which makes the relative performance loss much less severe in practice than in theory.

Having established the protocol, we may find an optimal policy D(t, m). Define the cost of choosing the first state to be pushed at position y and thereafter following the optimal policy as:

Q(t, m, y) = y + C(t − y, m − 1) + C(y, m)    (1)

C(t, m) = Q(t, m, D(t, m))    (2)

D(t, m) = argmin_{1 ≤ y < t} Q(t, m, y)    (3)
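The recurrence can be tabulated directly. Below is a minimal sketch of the dynamic program defined by the boundary conditions and Eqs. (1)–(3); memoization via `functools.lru_cache` stands in for an explicit table, the returned dictionary D plugs into the `bptt_hsm` sketch above, and all names are illustrative.

```python
from functools import lru_cache

def optimal_policy(T, M):
    """Tabulate C(t, m), the optimal cost in forward operations, and D(t, m), the
    position of the first hidden state to store, for all t <= T and m <= M."""

    def Q(t, m, y):                         # Eq. (1)
        return y + C(t - y, m - 1) + C(y, m)

    @lru_cache(maxsize=None)
    def C(t, m):                            # Eq. (2), plus the two boundary conditions
        if m >= t:
            return 2 * t - 1                # plentiful-memory regime
        if m == 1:
            return t * (t + 1) // 2         # constant-memory, quadratic-time extreme
        return min(Q(t, m, y) for y in range(1, t))

    D = {}
    for m in range(2, M + 1):
        for t in range(m + 1, T + 1):       # a split point is only needed when t > m
            D[(t, m)] = min(range(1, t), key=lambda y: Q(t, m, y))   # Eq. (3)
    return C, D

# Example: optimal policy for sequences up to length 100 with 10 memory slots.
C, D = optimal_policy(100, 10)
print(C(100, 10), D[(100, 10)])            # total forward operations and first split point
```

An explicit two-dimensional array over (t, m) filled in order of increasing t would serve equally well and avoids deep recursion for very long sequences.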