Published as a conference paper at ICLR 2021

SPATIALLY STRUCTURED RECURRENT MODULES

Nasim Rahaman¹,² Anirudh Goyal² Muhammad Waleed Gondal¹ Manuel Wüthrich¹ Stefan Bauer¹ Yash Sharma³ Yoshua Bengio²,⁴ Bernhard Schölkopf¹

ABSTRACT

Capturing the structure of a data-generating process by means of appropriate inductive biases can help in learning models that generalize well and are robust to changes in the input distribution. While methods that harness spatial and temporal structures find broad application, recent work has demonstrated the potential of models that leverage sparse and modular structure using an ensemble of sparingly interacting modules. In this work, we take a step towards dynamic models that are capable of simultaneously exploiting both modular and spatiotemporal structures. To this end, we model the dynamical system as a collection of autonomous but sparsely interacting sub-systems that interact according to a learned topology, which is informed by the spatial structure of the underlying system. This gives rise to a class of models that are well suited for capturing the dynamics of systems that only offer local views into their state, along with the corresponding spatial locations of those views. On the tasks of video prediction from cropped frames and multi-agent world modeling from partial observations in the challenging Starcraft2 domain, we find our models to be more robust to the number of available views and better at generalizing to novel tasks without additional training than strong baselines that perform equally well or better on the training distribution.

1 INTRODUCTION

Many complex spatiotemporal systems can be abstracted as a collection of autonomous but sparsely interacting sub-systems, where sub-systems tend to interact if they are in each other's vicinity. As an illustrative example, consider a grid of traffic intersections.
Traffic flows from a given intersection to the adjacent ones, and the actions taken by some agent, say an autonomous vehicle, may at first only affect its immediate surroundings. Now suppose we want to forecast the future state of the traffic grid (say, for the purpose of avoiding traffic jams). There is a spectrum of possible strategies for modeling the system at hand. At one extreme lies the most general strategy, which considers all intersections simultaneously to predict the next state of the grid (Figure 1c). The resulting model class can in principle account for interactions between any two intersections, irrespective of their spatial distance. However, the number of pairwise interactions such models must consider grows quadratically with the number of intersections, and this strategy might be rendered infeasible for large grids with hundreds of intersections. At the other end of the spectrum is a strategy which abstracts the dynamics of each intersection as an autonomous sub-system, with each sub-system interacting only with its immediate neighbors (Figure 1a). The interactions may manifest as messages that one sub-system passes to another, possibly containing information about how many vehicles are headed in which direction, resulting in a collection of message-passing entities (i.e., sub-systems) that collectively model the entire grid. By adopting this strategy, one assumes that the immediate future of any given intersection is affected only by the present states of the neighboring intersections, and not by some intersection at the opposite end of the grid. The resulting class of models scales well with the size of the grid, but is possibly unable to model certain long-range interactions that could be leveraged to efficiently distribute traffic flow. The spectrum above parameterizes the extent to which the spatial structure of the underlying system informs the design of the model.
¹Max-Planck Institute for Intelligent Systems Tübingen, ²Mila, Québec, ³Bethgelab, Eberhard Karls Universität Tübingen, ⁴Université de Montréal. Correspondence to: .

Figure 1: A schematic representation of the spectrum of modeling strategies. Panels: (a) fully localized sub-systems; (b) middle ground; (c) single, monolithic system. Solid arrows with speech bubbles denote (dynamic) messages being passed between sub-systems (dotted arrows denote the lack thereof). Gist: at one end of the spectrum (Figure 1a), we have the strategy of abstracting each intersection as a sub-system that interacts with neighboring sub-systems. At the other end of the spectrum (Figure 1c), we have the strategy of modeling the entire grid with one monolithic system. The middle ground we explore (Figure 1b) involves letting the model develop a notion of locality by (say) abstracting entire avenues with a single sub-system.

The former extreme ignores spatial structure altogether, resulting in a class of models that can be expressive but whose sample and computational complexity do not scale well with the size of the system. The latter extreme results in a class of models that can scale well, but whose adequacy (in terms of expressivity) is contingent on a predefined notion of locality (in the example above: the immediate four-neighborhood of an intersection). In this work, we aim to explore a middle ground between the two extremes by proposing a class of models that learns a notion of locality instead of relying on a predefined one (Figure 1b). Reconsidering the traffic grid example: the proposed strategy results in a model that may learn to abstract (say) entire avenues with a single sub-system, if doing so is useful for solving the prediction task.
This yields a scheme where a single sub-system might account for events that are spatially distant (such as those at opposite ends of an avenue), while events that are spatially closer together (like those on two adjacent avenues of the same street, where streets run perpendicular to avenues) might be assigned to different sub-systems. To implement this scheme, we build on a framework wherein the sub-systems are modelled as independent recurrent neural network (RNN) modules that interact sparsely via a bottleneck of attention (a variant of which is explored in Goyal et al. (2019)), and we extend it along two salient dimensions. First, we learn an interaction topology between the sub-systems, instead of assuming that every sub-system interacts with every other in an all-to-all topology. We achieve this by learning to embed each sub-system in a space endowed with a metric, and attenuating the interaction between two given sub-systems according to their distance in this space (i.e., sub-systems that are too far away from each other in this space are not allowed to interact). Second, we relax the common assumption that the entire system is perceived simultaneously; instead, we only assume access to local (partial) observations alongside the associated spatial locations, resulting in a setting that partially resembles that of Eslami et al. (2018). Expressed in the language of the example above: we do not expect a bird's-eye view of the traffic grid, but only (say) LIDAR observations from autonomous vehicles at known GPS coordinates, or video streams from traffic cameras at known locations. The spatial location associated with an observation plays a crucial role in the proposed architecture: we map it to the embedding space of sub-systems and address the corresponding observation only to sub-systems whose embeddings lie in its close vicinity.
Likewise, to predict future observations at a queried spatial location, we again map said location to the embedding space and poll the states of sub-systems situated nearby. The result is a model that can learn which spatial locations are to be associated with each other and accounted for by the same sub-system. As an added benefit, the parameterization we obtain is agnostic not only to the number of available observations and query locations, but also to the number of sub-systems. To evaluate the proposed model, we choose a problem setting where (a) the task is composed of different sub-systems or processes that interact locally, both spatially and temporally, and (b) the environment offers local views into its state paired with their corresponding spatial locations. The challenge here lies in building and maintaining a consistent representation of the global state of the system given only a set of partial observations. To succeed, a model must learn to efficiently capture the available observations and place them in an appropriate spatial context. The first problem we consider is that of video prediction from crops, analogous to the problem faced by the visual systems of many animals: given a set of small crops of the video frames centered around stochastically sampled pixels (corresponding to where the fovea is focused), the task is to predict the content of a crop around any queried pixel position at a future time. The second problem is that of multi-agent world modeling from partial observations in spatial domains, such as the challenging Starcraft2 domain (Samvelyan et al., 2019; Vinyals et al., 2017). The task here is to model the dynamics of the global state of the environment given local observations made by cooperating agents and their corresponding actions. Finally, we also include visualizations on a multi-agent grid-world environment designed for simulating railroad traffic (Eichenberger et al., 2019).
Importantly, and unlike prior work (Sun et al., 2019), our parameterization is agnostic to the number of agents in the environment, which can be flexibly adjusted on the fly as new agents become available or existing agents retire. This is beneficial for generalization in settings where the number of agents differs between training and testing.

Contributions. (a) We propose a new class of models, which we call Spatially Structured Recurrent Modules or S2RMs, that perform attention-driven modular computations according to a learned spatial topology. (b) We evaluate S2RMs (along with several strong baselines) on a selection of challenging problems and find that S2RMs are robust to the number of available observations and can generalize to novel tasks.

2 PROBLEM STATEMENT

In this section, we build on the intuition from the previous section to formally specify the problem we aim to approach with the methods described in the later sections. To that end, let $\mathcal{X}$ be a metric space, $\mathcal{O}$ some set of possible observations, and $\mathcal{O}^{\mathcal{X}}$ the set of mappings $\mathcal{X} \to \mathcal{O}$. Now, consider the evolution function of a discrete-time dynamical system, $\varphi: \mathbb{Z} \times \mathcal{O}^{\mathcal{X}} \to \mathcal{O}^{\mathcal{X}}$, satisfying:

$$\varphi(0, o) = o \quad \text{for } o \in \mathcal{O}^{\mathcal{X}}, \qquad \varphi(t_2, \varphi(t_1, o)) = \varphi(t_1 + t_2, o) \quad \text{for } t_1, t_2 \in \mathbb{Z} \tag{1}$$

Informally, $o$ can be interpreted as the world state of the system; together with a spatial location $x \in \mathcal{X}$, it gives the local observation $O = o(x) \in \mathcal{O}$. Given an initial world state $o$, the mapping $\varphi(t, o)$ yields the world state $o_t$ at some (future) time $t$, thereby characterizing the dynamics of the system (which might be stochastic). The problem we consider is the following:

Problem: At every time step $t = 0, \ldots, T$, we are given a set of positions $\{x_t^a\}_{a=1}^{A}$ and the corresponding observations $\{O_t^a\}_{a=1}^{A}$, where $O_t^a := o_t(x_t^a)$. The task is to infer the world state $o_{t'}$ at some future time step $t' > T$ in order to predict $O_{t'}^q = o_{t'}(x^q)$ at an arbitrary query position $x^q$.
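As a concrete (and deliberately tiny) illustration of the problem statement, the sketch below assumes a ring of $N = 8$ locations as $\mathcal{X}$, integer-valued observations, and deterministic dynamics that rotate the world state by one cell per time step. The ring, the rotation dynamics and the helper names `evolve` and `observe` are hypothetical choices made for illustration only; they are not part of the paper. The sketch checks the two properties in Equation (1) and gathers local observations $O_t^a = o_t(x_t^a)$ at a few positions per step, which is all a model would get to see.

```python
from typing import List, Sequence, Tuple

N = 8  # hypothetical number of locations on a ring, so X = Z_N

# A world state o: X -> O, stored as a tuple where o[x] is the value at location x.
WorldState = Tuple[int, ...]


def evolve(t: int, o: WorldState) -> WorldState:
    """Toy evolution function phi(t, o): rotate the world state by t cells (t may be negative)."""
    return tuple(o[(x - t) % N] for x in range(N))


def observe(o: WorldState, positions: Sequence[int]) -> List[int]:
    """Local observations O^a = o(x^a) at the given positions."""
    return [o[x % N] for x in positions]


o0: WorldState = (1, 0, 0, 2, 0, 0, 0, 0)

# First property of Eq. (1): phi(0, o) = o
assert evolve(0, o0) == o0

# Second property of Eq. (1): phi(t2, phi(t1, o)) = phi(t1 + t2, o), also for negative times
for t1 in range(-3, 4):
    for t2 in range(-3, 4):
        assert evolve(t2, evolve(t1, o0)) == evolve(t1 + t2, o0)

# A model never sees o_t directly, only observations at a few positions x_t^a per step.
positions_per_step = [[0, 3], [1, 5], [2, 4]]
observations = [observe(evolve(t, o0), xs) for t, xs in enumerate(positions_per_step)]
print(observations)  # [[1, 2], [1, 0], [1, 0]]
```

In this toy setting, the task from the problem statement would be to predict `observe(evolve(t_future, o0), [x_q])` for some future step from `observations` alone, without ever accessing `o0`.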
In the traffic grid example of Section 1, one could imagine $a$ as indexing traffic cameras or autonomous vehicles (i.e., observers), $x_t^a$ as the GPS coordinates of observer $a$, and $O_t^a$ as the corresponding sensor feed (e.g., LIDAR observations or video streams from vehicles or traffic cameras).

3 MODELLING ASSUMPTIONS

Given the problem in Section 2, we now make certain modelling assumptions. These assumptions will ultimately inform the inductive biases we select for the model (proposed in Section 4); nevertheless, we remark beforehand that, as with any inductive bias, their applicability is subject to the properties of the system being modeled and the objectives being optimized (OOD generalization, etc.).

Recurrent Dynamics Modeling. While there exist multiple ways of modeling dynamical systems, we shall focus on recurrent neural networks (RNNs). Typically, RNN-based dynamics models are expressed as functions of the form:

$$h_{t+1} = F(O_t, h_t), \qquad O_t = D(h_t) \tag{2}$$

where $O_t$ is the observation at time $t \in \mathbb{Z}$ and $h_{t+1}$ is the hidden state of the model. $F$ can be thought of as the parameterized forward-evolution function of the hidden state $h$ conditioned on the observation $O$, whereas $D$ is a decoder that maps the hidden state to observations.

Decomposition into Locally Interacting Sub-systems. We make the assumption that the dynamical system $\varphi$ can be decomposed into constituent sub-systems $(\varphi_1, \varphi_2, \ldots, \varphi_M)$ that dynamically and sparsely interact with each other while respecting some interaction topology. By interaction topology, we mean that each module $\varphi_i$ can be identified with an embedding $p_i$ in a topological space $\mathcal{S}$ equipped with a similarity kernel $Z$, and that the sub-system $\varphi_i$ may preferentially interact with another sub-system $\varphi_j$ if their respective embeddings are close in $\mathcal{S}$ with respect to $Z$, i.e.
if $Z(p_i, p_j)$ is large. Intuitively, one may think of $Z$ as inducing a notion of locality between sub-systems, according to which $\varphi_j$ lies in the local vicinity of $\varphi_i$.

Figure 2: Schematic illustration of the proposed architecture (diagram labels: observation at location, positional embedding, encoder, input attention, set of interacting RNNs, output attention, query location, decoder, predicted observation at queried location). An observation is addressed to modules with embeddings situated in the vicinity of its embedded location. Likewise, modules with embeddings in the vicinity of an embedded query location are polled to produce a prediction.

Locality of Observations. The notion of locality between sub-systems induced by $Z$ is distinct from that induced by the metric of the space $\mathcal{X}$ of locations in the environment (cf. Section 2), and one important modelling decision is how these two should interact. We propose to embed the position $x \in \mathcal{X}$ associated with an observation $O$ into the metric space of sub-systems $\mathcal{S}$ via a continuous¹ and injective mapping $P: \mathcal{X} \to \mathcal{S}$. This allows us to match the observation $O$ to all sub-systems $\varphi_m$ that are in the vicinity of $P(x) \in \mathcal{S}$, i.e., where $Z(P(x), p_m)$ is sufficiently large. Each sub-system $\varphi_m$ therefore accounts for observations made at a set of positions $\mathcal{X}_m \subseteq \mathcal{X}$, which we call its enclave.

4 SPATIALLY STRUCTURED RECURRENT MODULES (S2RMS)

Informed² by the modelling assumptions detailed in the previous section, we now describe the proposed model (Figure 2), which comprises the following components.

Model Inputs. Recall from Section 2 that for every time step $t = 0, \ldots, T$ we have a set of tuples of positions and observations $\{(x_t^a, O_t^a)\}_{a=1}^{A}$, where $x_t^a \in \mathcal{X}$ and $O_t^a \in \mathcal{O}$ for all $t$ and $a$. To simplify, we assume that $\mathcal{X} \subseteq \mathbb{R}^n$, and denote by $x_i$ the $i$-th component of the vector $x \in \mathcal{X}$.

Encoder. The encoder $E$ is a parameterized function mapping observations $O$ to a corresponding vector representation $e = E(O)$.
Here, $E$ processes all observations in parallel across $t$ and $a$ to yield representations $e_t^a$.

Positional Embedding. The positional embedding $P$ is a fixed mapping from $\mathcal{X}$ to $\mathcal{S}$. We choose $\mathcal{S}$ to be the unit sphere in $d$ dimensions, $d$ being a multiple of $2n$, and the positional embedding to be the following function:

$$P(x) = \frac{s}{\|s\|} \in \mathcal{S}, \quad \text{where} \quad (s_{i+m}, s_{i+1+m}) = \left(\sin\!\left(\frac{x_m}{10000^{i}}\right), \cos\!\left(\frac{x_m}{10000^{i}}\right)\right) \tag{3}$$

with $m = 0, \ldots, n-1$ and $i = 0, 2, \ldots, d/n - 1$. The above function finds common use (Vaswani et al., 2017; Mildenhall et al., 2020; Zhong et al., 2020) and can be motivated from the perspective of Reproducing Kernel Hilbert Spaces (Rahimi & Recht, 2007) (see Appendix C.3 for a discussion). We henceforth refer to $P(x)$ as $s$ and to $P(x_t^a)$ as $s_t^a$.

Set of Interacting RNNs. To model the dynamics of the world state, we use a set of $M$ independent RNN modules, which we denote as $\{F_m\}_{m=1}^{M}$. To each $F_m$, we associate an embedding vector $p_m \in \mathcal{S}$, where all $\{p_m\}_{m=1}^{M}$ are learnable parameters. The RNNs $F_m$ interact with each other via an inter-cell attention, and with the input representations $e_t^a$ via an input attention. Precisely, at a given time step $t$, each $F_m$ expects an input $u_t^m$, an aggregated hidden state $h_t^m$ and, optionally, a memory state $c_t^m$ to yield the hidden and memory states at the next time step:

$$(h_{t+1}^m, c_{t+1}^m) = F_m(u_t^m, h_t^m, c_t^m) \tag{4}$$

where the input $u_t^m$ results from the input attention and $h_t^m$ from the inter-cell attention (see below).

¹The continuity of $P$ ties the two notions of locality together by requiring that an infinitesimal change in $x$ corresponds to an infinitesimal change in $\mathcal{S}$. Injectivity ensures that no two points in $\mathcal{X}$ are mapped to the same point in $\mathcal{S}$.

²In doing so, we use the assumptions merely as guiding principles; we do not claim that we infer the true decomposition of the ground-truth system, even if all assumptions are satisfied.

Kernel Modulated Dot-Product Attention.
A central component of the proposed architecture is the kernel modulated dot-product attention (KMDPA), which we now define. First, we let $Z: \mathcal{S} \times \mathcal{S} \to [0, 1]$ be the following kernel:

$$Z(p, s) = \begin{cases} \exp\left[-2\epsilon\,(1 - p \cdot s)\right] & \text{if } p \cdot s \geq \tau \\ 0 & \text{otherwise} \end{cases} \tag{5}$$

where $\epsilon \in (0, \infty)$ is the kernel bandwidth and $\tau \in [-1, 1)$ is the truncation parameter (additional details in Appendix C.1). Now, KMDPA maps two sets $A$ and $B$ to a third set $\hat{A}$, where:

$$A = \{(a_i, y_i)\}_{i=1}^{I}; \quad B = \{(b_j, z_j)\}_{j=1}^{J}; \quad \hat{A} = \{(a_k, \hat{y}_k)\}_{k=1}^{I} = \mathrm{KMDPA}(A, B) \tag{6}$$

Here, $a_i, b_j \in \mathcal{S}$, and $y_i$, $z_j$ are vectors, not necessarily of the same dimension. In order to evaluate $\hat{A}$, we first compute the interaction weights $W_{ij}$ between every pair of entities $(a_i, y_i)$ and $(b_j, z_j)$, which depend on a local term $W_{ij}^{(L)}$ and a non-local term $\tilde{W}_{ij}$. We have:

$$\tilde{W}_{ij} = \mathrm{softmax}_j\left(\Theta^{(Query)}(y_i) \cdot \Theta^{(Key)}(z_j)\right); \quad W_{ij}^{(L)} = Z(a_i, b_j); \quad W_{ij} = \tilde{W}_{ij}\, W_{ij}^{(L)} \tag{7}$$

where $\Theta^{(Query)}$ and $\Theta^{(Key)}$ are learnable linear mappings that project $y_i$ and $z_j$ to the same space. The penultimate step computes the following two quantities:

$$\bar{y}_i = \sum_j W_{ij}^{(L)}\, y_j; \qquad \tilde{y}_i = \sum_j W_{ij}\, \Theta^{(Value)}(z_j) \tag{8}$$

where $\Theta^{(Value)}$ is another learnable linear function mapping from the vector space of $z$ to that of $y$. Finally, we have:

$$\hat{y}_i = G(\tilde{y}_i, \bar{y}_i)\, \tilde{y}_i + \left(1 - G(\tilde{y}_i, \bar{y}_i)\right) \bar{y}_i \tag{9}$$

where $G$ is a gating layer with a sigmoid non-linearity that implements a soft selection between the linear combination of values ($\tilde{y}_i$) and the inputs weighted by the local weights ($\bar{y}_i$). In what follows, we refer to the set $A$ as the query set, $B$ as the key set and $\hat{A}$ as the output set.

Input Attention. The input attention mechanism is a KMDPA mapping the set of observation tuples $\{(s_t^a, e_t^a)\}_{a=1}^{A}$ (key set) and the current RNN states $\{(p_m, h_t^m)\}_{m=1}^{M}$ (query set) to the set of RNN inputs $\{(p_n, u_t^n)\}_{n=1}^{M}$ (output set). On the one hand, we observe that the input $u_t^m$ to RNN $F_m$ can contain information about an observation $e_t^a$ only if the embedded location $s_t^a$ of that observation is close enough to the embedding $p_m$ of the RNN in $\mathcal{S}$ (i.e.
if $Z(s_t^a, p_m) > 0$), thereby implementing the assumption of locality of observations. On the other hand, the non-local term allows a module $F_m$ to reject (or accept) an observation based on its content, which can be beneficial if two modules attend to overlapping regions in the environment but specialize in different aspects of the dynamics. Please refer to Appendix C.1 for a precise description of the mechanism, in particular the (optional) use of multiple dot-product attention heads.

Inter-cell Attention. The inter-cell attention mechanism is another KMDPA, mapping two copies of the current RNN states $\{(p_m, h_t^m)\}_{m=1}^{M}$ (one as the query set and the other as the key set) to the set of aggregated hidden states $\{(p_n, h_t^n)\}_{n=1}^{M}$ (output set). This enables local interaction between the RNNs $F_m$, in that the local term ensures that RNN $F_m$ interacts with RNN $F_n$ only if their respective embeddings $p_m$ and $p_n$ are close enough in $\mathcal{S}$ (i.e., if $Z(p_m, p_n) > 0$), thereby implementing the assumption of local interactions between sub-systems. The non-local term allows two modules to interact with each other based on their hidden states, i.e., it provides the mechanism for a module to (not) interact with another based on their respective states, even if their embeddings are similar enough in $\mathcal{S}$. Appendix C.2 contains a precise description of the attention mechanism.

Output Attention. The output attention mechanism, together with the decoder (described below), serves as an apparatus to evaluate the world state modeled implicitly by the set of RNNs $\{F_m\}_{m=1}^{M}$ at time $t+1$ (for one-step forward models). Given a query location $x^q \in \mathcal{X}$ and its corresponding embedding $s^q$, the output attention mechanism polls the RNNs $F_m$ whose embeddings $p_m$ are similar enough to $s^q$, as measured by the kernel $Z$.
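The locality machinery of this section can be sketched in a few lines of plain Python. The sketch below assumes a specific layout for the sinusoidal embedding (one block of $d/n$ entries per coordinate, with transformer-style frequencies $10000^{-k/(d/n)}$) rather than reproducing the exact indexing of Equation (3); because $\sin a \sin b + \cos a \cos b = \cos(a - b)$, the dot product of two such embeddings depends only on $x - x'$. It then implements the kernel of Equation (5), the output-attention readout of Equation (10), and the inter-cell KMDPA of Equations (7) to (9), where the query and key sets coincide, so the locally weighted inputs $\sum_j W^{(L)}_{ij} y_j$ are well defined. The gate $G$ is assumed to be a single sigmoid unit applied to the concatenation of its two arguments, the projections $\Theta$ are passed in as plain matrices, and all function names are hypothetical. One useful observation: for unit vectors, $\|p - s\|^2 = 2(1 - p \cdot s)$, so Equation (5) is a Gaussian kernel $\exp(-\epsilon \|p - s\|^2)$ truncated to zero whenever $p \cdot s < \tau$.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = Sequence[Sequence[float]]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def matvec(m: Matrix, v: Sequence[float]) -> Vector:
    return [dot(row, v) for row in m]


def softmax(xs: Sequence[float]) -> Vector:
    mx = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def positional_embedding(x: Sequence[float], d: int) -> Vector:
    """Sinusoidal embedding of x in R^n, normalized onto the unit sphere in R^d.

    Assumed layout: each coordinate x_m gets its own block of d/n entries holding
    (sin, cos) pairs at frequencies 1 / 10000^(k / (d/n)), k = 0, 2, 4, ...
    """
    n = len(x)
    if d % (2 * n) != 0:
        raise ValueError("d must be a multiple of 2n")
    block = d // n
    s: Vector = []
    for x_m in x:
        for k in range(0, block, 2):
            freq = 1.0 / 10000 ** (k / block)
            s.extend([math.sin(x_m * freq), math.cos(x_m * freq)])
    norm = math.sqrt(dot(s, s))
    return [v / norm for v in s]


def kernel(p: Sequence[float], s: Sequence[float], eps: float, tau: float) -> float:
    """Truncated kernel of Eq. (5): exp(-2 eps (1 - p.s)) if p.s >= tau, else 0."""
    c = dot(p, s)
    return math.exp(-2.0 * eps * (1.0 - c)) if c >= tau else 0.0


def output_attention(s_q, embeddings, states, eps, tau) -> Vector:
    """Eq. (10): d^q_j = sum_m Z(s^q, p_m) h_mj, i.e. poll the modules near the query."""
    d_q = [0.0] * len(states[0])
    for p_m, h_m in zip(embeddings, states):
        w = kernel(s_q, p_m, eps, tau)
        d_q = [acc + w * h for acc, h in zip(d_q, h_m)]
    return d_q


def intercell_kmdpa(embeddings, states, theta_q, theta_k, theta_v, gate_w, gate_b, eps, tau):
    """Inter-cell KMDPA (Eqs. 7-9): module states serve as both query and key set.

    theta_v must map a state back to the state dimension, since here y and z coincide.
    """
    queries = [matvec(theta_q, h) for h in states]
    keys = [matvec(theta_k, h) for h in states]
    values = [matvec(theta_v, h) for h in states]
    dim = len(states[0])
    outputs = []
    for p_i, q_i in zip(embeddings, queries):
        w_nonlocal = softmax([dot(q_i, k_j) for k_j in keys])          # Eq. (7), non-local term
        w_local = [kernel(p_i, p_j, eps, tau) for p_j in embeddings]   # Eq. (7), local term
        w = [a * b for a, b in zip(w_nonlocal, w_local)]               # combined weights W_ij
        y_tilde = [sum(w_j * v_j[k] for w_j, v_j in zip(w, values)) for k in range(dim)]
        y_bar = [sum(l_j * h_j[k] for l_j, h_j in zip(w_local, states)) for k in range(dim)]
        # Assumed gate: one sigmoid unit over the concatenation [y_tilde; y_bar], Eq. (9)
        g = 1.0 / (1.0 + math.exp(-(dot(gate_w, y_tilde + y_bar) + gate_b)))
        outputs.append([g * a + (1.0 - g) * b for a, b in zip(y_tilde, y_bar)])
    return outputs
```

For instance, with two modules whose embeddings are antipodal and $\tau = 0$, the kernel between them is exactly zero, so neither module's output contains any component of the other's state, however large the content-based scores are. The truncation is what turns the learned embeddings into a hard interaction topology.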
Denoting by $h_{mj}$ the $j$-th component of $h_{t+1}^m$ and by $d_j^q$ the $j$-th component of the vector $d_{t+1}^q$ associated with the query location $x^q$, we have:

$$d_j^q = \sum_m Z(s^q, p_m)\, h_{mj} \tag{10}$$

Decoder. The decoder $D$ is a parameterized function that predicts the observation $\hat{O}_{t+1}^q \in \mathcal{O}$ at $x^q$ given the representation $d_{t+1}^q$ from the output attention.

This concludes the description of the generic architecture, which allows for flexibility in the choice of the RNN architecture (i.e., the internal architecture of $F_m$). In practice, we find Gated Recurrent Units (GRUs) (Cho et al., 2014) to work well, and call the resulting model Spatially Structured GRU or S2GRU. Moreover, Relational Memory Cores (RMCs) (Santoro et al., 2018) also profit from our architecture (with a modification detailed in Appendix E.3), and we call the resulting model S2RMC.

5 RELATED WORK

Problem Setting. Recall that the problem setting we consider is one where the environment offers local (partial) views into its global state paired with the corresponding spatial locations. With Generative Query Networks (GQNs), Eslami et al. (2018) investigate a similar setting where 2D images of 3D scenes are paired with the corresponding viewpoint (camera position, yaw, pitch and roll). Given that GQNs are feedforward models, they do not consider the dynamics of the underlying scene and as such cannot be expected to be consistent over time (Kumar et al., 2018). Singh et al. (2019) and Kumar et al. (2018) propose variants that are temporally consistent but, unlike us, they do not focus on the problem of predicting the future state of the system.

Modularity.
Modularity has been a recurring topic in the context of meta-learning (Alet et al., 2018; Bengio et al., 2019; Ke et al., 2019), sequence modeling (Ghahramani & Jordan, 1996; Henaff et al., 2016; Li et al., 2018; Goyal et al., 2019; Mei et al., 2020; Mittal et al., 2020) and beyond (Jacobs et al., 1991; Shazeer et al., 2017; Parascandolo et al., 2017). In the context of RNNs, Li et al. (2018) explore a setting where the recurrent units operate entirely independently of each other. Closer to our work, Goyal et al. (2019) explore the setting where autonomous RNN modules interact with each other via the bottleneck of sparse attention. However, instead of leveraging the spatial structure of the environment, they induce sparsity using a scheme inspired by the k-winners-take-all principle (Majani et al., 1988), where only the k modules that attend the most to the input are activated and propagate their state forward, whereas the remaining modules that do not receive an input follow default dynamics in that their hidden states are not updated. This can be contrasted with S2RMs, where the modules that do not receive inputs may still evolve their states forward in time, reflecting that the environment may evolve even when no observations are available.

Attention Mechanisms and Information Flow. Attention mechanisms have been used to attenuate the flow of information between components of the network, e.g., (Graves et al., 2014; 2016; Santoro et al., 2018; Ke et al., 2018; Veličković et al., 2017; Battaglia et al., 2018). There is a growing interest in efficient attention mechanisms for use in transformers (Vaswani et al., 2017), and, like KMDPA, some recently proposed methods rely on learned sparsity (Kitaev et al., 2020; Tay et al., 2020). However, these induce sparsity by dynamically clustering or sorting based on content, while we make explicit use of the spatial information accompanying observations to learn a spatially grounded sparsity pattern.
Moreover, mechanisms for spatial attention have also been studied (Jaderberg et al., 2015; Wang et al., 2017; Zhang et al., 2018; Parmar et al., 2018), but they typically operate on image pixels. Our setting is different in that we do not assume that the world state (from which we sample local observations) can be represented as an image.

6 EXPERIMENTS

In this section, we present a selection of experiments to quantitatively evaluate S2RMs and gauge their performance against strong baselines on two data domains, namely video prediction from crops on the well-known bouncing-balls domain and multi-agent world modelling from partial observations in the challenging Starcraft2 domain. We also include qualitative visualizations on a grid-world task in Appendix A. Additional tables, results and supporting plots can be found in Appendix F.

Baselines. To draw fair comparisons between various RNN architectures, we require an architectural scaffolding that is agnostic to the number of observations $A$, is invariant to the ordering of $\{(x_t^a, O_t^a)\}_{a=1}^{A}$ with respect to $a$, and features a querying mechanism to extract a predicted observation $O_{t'}^q$ at a given query location $x^q$ at a future time step $t' > t$. Fortunately, it is possible to obtain a performant class of models fulfilling our requirements by extending prior work on Generative Query Networks or GQNs (Eslami et al., 2018). The resulting model has three components: an encoder, an RNN, and a decoder, which we describe in detail in Appendix D. In our experiments, we fix the encoder and decoder to be essentially identical to those in S2RMs, but vary the architecture of the RNN, where we experiment with LSTMs (Hochreiter & Schmidhuber, 1997), RMCs (Santoro et al., 2018) and RIMs (Goyal et al., 2019). As a sanity check, we also show results with a Time Travelling Oracle (TTO), which at time step $t$ has access to the (partially observed) state at $t+1$. Its purpose is to verify that the architectural scaffolding around the baseline RNNs (defined in Appendix D) does not constrain their performance and that the comparison to S2RMs is indeed fair.

Figure 3: Rollouts (OOD) with 5 bouncing balls, from top to bottom: ground truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state is reconstructed by stitching together 16 patches of size 11×11 produced by the models (queried on a 4×4 grid). Gist: S2GRU succeeds at keeping track of all bouncing balls over long rollout horizons (25 frames).

Figure 4: Performance metrics on the OOD one-step forward prediction task (panels: binary F1-score and balanced accuracy vs. number of bouncing balls, 1 to 6, for S2GRU*, TTO, RMC, LSTM and RIMs, with the training distribution marked). Gist: S2GRU outperforms all RNN baselines OOD.

Video Prediction from Crops. We consider the problem of predicting the future frames of simulated videos of balls bouncing in a closed box, given only crops from past video frames that are centered at known pixel positions. Using the notation introduced in Section 2: at every time step $t$, we sample $A = 10$ pixel positions $\{x_t^a\}_{a=1}^{10}$ from the $t$-th full video frame $o_t$ of size 48×48. Around the sampled central pixel positions $x_t^a$, we extract 11×11 crops, which we use as the local observations $O_t^a$. The task is then to predict 11×11 crops $O_{t'}^q$ corresponding to query central pixel positions $x_{t'}^q$ at a future time step $t' > t$. Observe that at any given time step $t$, the model has access to at most 52% of the global video frame, assuming that the crops never overlap (which is rather unlikely). Having trained on the training dataset with 3 bouncing balls, we evaluate the forward-prediction performance on all test datasets with 1 to 6 bouncing balls.
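The sampling of views described above, and the 52% bound, can be illustrated with the following sketch. It assumes crop centres drawn uniformly at random over the 48×48 frame and zero padding wherever a crop extends past the frame border; neither detail is specified in this section, and the function names are hypothetical. The bound itself is simple arithmetic: ten non-overlapping 11×11 crops cover $10 \cdot 121 = 1210$ of the $48^2 = 2304$ pixels, i.e., about 52.5%.

```python
import random
from typing import List, Optional, Sequence, Tuple

Frame = List[List[int]]
Position = Tuple[int, int]


def extract_crop(frame: Frame, center: Position, size: int = 11) -> Frame:
    """Return the size x size crop centred at `center`; pixels outside the frame are 0."""
    half = size // 2
    rows, cols = len(frame), len(frame[0])
    r0, c0 = center
    return [
        [frame[r][c] if 0 <= r < rows and 0 <= c < cols else 0
         for c in range(c0 - half, c0 + half + 1)]
        for r in range(r0 - half, r0 + half + 1)
    ]


def sample_views(frame: Frame, num_views: int = 10, size: int = 11,
                 rng: Optional[random.Random] = None) -> List[Tuple[Position, Frame]]:
    """Sample crop centres x^a uniformly and return the (x^a, O^a) pairs."""
    if rng is None:
        rng = random.Random(0)
    rows, cols = len(frame), len(frame[0])
    centers = [(rng.randrange(rows), rng.randrange(cols)) for _ in range(num_views)]
    return [(x, extract_crop(frame, x, size)) for x in centers]


def visible_fraction(centers: Sequence[Position], frame_size: int = 48, size: int = 11) -> float:
    """Fraction of frame pixels covered by the union of all crops."""
    half = size // 2
    covered = {
        (r, c)
        for r0, c0 in centers
        for r in range(max(0, r0 - half), min(frame_size, r0 + half + 1))
        for c in range(max(0, c0 - half), min(frame_size, c0 + half + 1))
    }
    return len(covered) / frame_size ** 2


UPPER_BOUND = 10 * 11 * 11 / 48 ** 2  # 1210 / 2304, about 0.525

frame = [[0] * 48 for _ in range(48)]  # placeholder frame; a real one would contain balls
views = sample_views(frame)
print(len(views), visible_fraction([x for x, _ in views]) <= UPPER_BOUND)  # 10 True
```

Because overlapping crops and crops clipped at the border cover fewer distinct pixels, the fraction actually observed at a given step is typically well below the bound.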
Given that we treat the prediction problem as a pixel-wise binary classification problem, we report the balanced accuracy (i.e., the arithmetic mean of recall and specificity) or F1-scores (i.e., the harmonic mean of precision and recall) to account for class imbalance. In Figure 4, we see that S2GRUs outperform all non-oracle baselines on the one-step forward prediction task and strike a good balance between in-distribution and OOD performance. In Figure 3, we qualitatively show reconstructions from 25-step rollouts on the out-of-distribution dataset with 5 balls to demonstrate that S2GRUs can perform OOD rollouts over long horizons without losing track of balls.

Figure 5: Visualization of the spatial locations each module is responsible for modeling (i.e., the enclaves $\mathcal{X}_m$, defined in Section 3). The central ball does not bounce, i.e., it is stationary in all sequences. Gist: the modules focus attention on challenging regions, e.g., the corners of the arena and the surface of the fixed ball.

Figure 6: Ablation over the local and the non-local terms in the input and inter-cell attention mechanisms (KMDPAs). For a set number of bouncing balls (one sub-plot each for 1 to 6 balls), each sub-plot shows how the balanced accuracy changes with the fraction of views (crops) available to the model (0.2 to 1.0). Models: original; without non-local term at input attention; without local term at input attention; without non-local term at inter-cell attention; without local term at inter-cell attention. Gist: both local and non-local terms in KMDPA contribute to the overall performance. The non-local term is more important for the input attention, whereas the local term is more important for the inter-cell attention.

Figure 7: The effect of removing random modules at test time (balanced accuracy vs. fraction of modules active, 0.2 to 1.0, for 1 to 6 balls). Gist: performance degrades gracefully as modules are removed, suggesting that modules can function even when their counterparts are removed, and that there is limited co-adaptation between them.

Figure 8: Performance metrics (larger is better) as a function of the probability that an agent will not supply information to the world model but still query it (panels: unit-type macro F1-score, friendly-marker F1-score, HECS negative MSE and log likelihood vs. agent drop probability, 0.0 to 0.8, for S2GRU*, S2RMC*, TTO, RMC and LSTM). Gist: while all models lose performance as fewer agents share observations, we find S2RMs to be the most robust.

Figure 6 shows the result of an ablation study where we disable the local and non-local terms in each of the two KMDPAs while keeping everything else the same. We see that both local and non-local terms contribute to the overall performance; moreover, the input attention relies on the non-local term and its performance is severely affected by its absence, whereas the inter-cell attention depends on the local term to yield good performance.
This suggests that the modules indeed rely on the content of the input observations as they select their inputs, and that learning an interaction topology between modules is a strong contributor to the final performance. In Figure 5, we show for each module its corresponding enclave, which is the spatial region it is responsible for modeling; i.e., for pixels at position $x$, we plot $\{Z(P(x), p_m)\}_{m=1}^{10}$ (cf. Section 4). We find that the modules learn to share the responsibility of modeling the entire spatial domain. Finally, in Figure 7 we see the effect of removing (randomly sampled) modules at test time, i.e. without additional retraining. The performance degrades gracefully as fewer modules are available, suggesting that the individual modules can function while other modules are missing. We include details and additional results in Appendix F.1.

Multi-Agent World Modeling on Starcraft2. In Section 2, we formulated the problem of modeling what we called the world state $o$ of a dynamical system $\varphi$ given local observations $\{(x^a_t, O^a_t)\}_{a=1}^{A}$, where $O^a_t = \varphi(t, o)(x^a_t)$. Under certain restrictions, this problem can be mapped to that of multi-agent world modeling from partial and local observations, allowing us to evaluate the proposed model in a rich and challenging setting. In particular, we consider environments in which (a) space is well defined, i.e. all agents $a$ have a well-defined and known location $x^a_t$ (at time $t$); (b) the agents' actions $u^a_t$ are local, in that their effects propagate away (from the agent) only at a finite speed; and (c) the observations $O^a_t$ are local and centered around agents, in the sense that each agent only observes the events in its local vicinity. Observe that we do not fix the number of agents in the environment, and allow for agents to dynamically enter or exit the environment.
Now, the task is: given observations $O^a_t$ from a team of (cooperating) agents at positions $x^a_t$ and their corresponding actions $u^a_t$, predict the observation $O^q_{t'}$ that would be made by an agent at time $t' = t + 1$ if it were at position $x^q$.

Scenario  Model   UT-F1   FM-F1   NMSE     LL
1s2z      LSTM    0.6267  0.8464  -0.0040  -0.0382
1s2z      RMC     0.6839  0.8597  -0.0033  -0.0334
1s2z      S2GRU   0.7488  0.8627  -0.0023  -0.0233
1s2z      S2RMC   0.7317  0.8563  -0.0026  -0.0261
1s2z      (TTO)   0.7518  0.8883  -0.0025  -0.0259
5s3z      LSTM    0.4975  0.7123  -0.0134  -0.1251
5s3z      RMC     0.5414  0.7486  -0.0132  -0.1167
5s3z      S2GRU   0.5310  0.7058  -0.0119  -0.1108
5s3z      S2RMC   0.5114  0.6945  -0.0124  -0.1205
5s3z      (TTO)   0.6115  0.7872  -0.0107  -0.0940
Table 1: Performance metrics on the OOD scenarios 1s2z and 5s3z (larger numbers are better): unit-type macro F1 score (UT-F1), friendly-marker F1 score (FM-F1), HECS negative mean squared error (NMSE) and log likelihood (LL).

Starcraft2 unit micromanagement (Samvelyan et al., 2019) is a multi-agent reinforcement learning benchmark in which teams of heterogeneously typed units must defeat a team of opponents in melee and ranged combat. The observations $O^a_t$ and actions $u^a_t$ are both multi-channel images represented in polar coordinates centered on the agent position $x^a_t$. The field of view (FOV) of each agent is therefore a circle of fixed radius centered on it. The channels of the image correspond to (a) a binary indicator marking whether a position in the FOV is occupied by a living friendly agent (friendly marker), (b) a categorical indicator marking the type of living units at a given position in the FOV (unit-type marker), and (c) four channels marking the health, energy, weapon cooldown and shields (HECS markers) of all agents in the FOV. With a heuristic, we gather a total of 9K trajectories $(\{x^a_t, O^a_t, u^a_t\}_{a=1}^{A})_{t=1}^{100}$ spread over three training scenarios, corresponding to 1c3s5z³, 3s5z and 2s3z in Samvelyan et al. (2019). We also sample 1K trajectories each from two OOD scenarios, 1s2z and 5s3z. Details are given in Appendix B.1.
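Both the bouncing-balls metrics (balanced accuracy and F1, with pixels as samples) and the Starcraft2 metrics in Table 1 (macro-averaged F1 over unit types, F1 for friendly markers) reduce to counts of true and false positives and negatives. The sketch below computes them in plain Python on flattened label sequences. The function names are hypothetical, and averaging the macro F1 over the union of true and predicted classes is our assumption (a common convention), not a detail stated in the paper.

```python
from collections import Counter
from typing import Sequence


def _safe_div(num: float, den: float) -> float:
    # Convention (assumed): an undefined ratio 0 / 0 counts as 0.
    return num / den if den else 0.0


def binary_counts(y_true: Sequence[int], y_pred: Sequence[int]) -> Counter:
    """Confusion counts (tp, fp, fn, tn) for 0/1 labels, e.g. flattened pixels."""
    counts = Counter(tp=0, fp=0, fn=0, tn=0)
    for t, p in zip(y_true, y_pred):
        if t and p:
            counts["tp"] += 1
        elif p:
            counts["fp"] += 1
        elif t:
            counts["fn"] += 1
        else:
            counts["tn"] += 1
    return counts


def balanced_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Arithmetic mean of recall (true-positive rate) and specificity (true-negative rate)."""
    c = binary_counts(y_true, y_pred)
    recall = _safe_div(c["tp"], c["tp"] + c["fn"])
    specificity = _safe_div(c["tn"], c["tn"] + c["fp"])
    return 0.5 * (recall + specificity)


def f1_score(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Harmonic mean of precision and recall for the positive class (label 1)."""
    c = binary_counts(y_true, y_pred)
    precision = _safe_div(c["tp"], c["tp"] + c["fp"])
    recall = _safe_div(c["tp"], c["tp"] + c["fn"])
    return _safe_div(2 * precision * recall, precision + recall)


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Unweighted mean of one-vs-rest F1 scores over all classes that occur."""
    classes = set(y_true) | set(y_pred)
    scores = [
        f1_score([int(y == cls) for y in y_true], [int(y == cls) for y in y_pred])
        for cls in classes
    ]
    return sum(scores) / len(scores)
```

As a sanity check on why the imbalance-aware metrics matter: for a frame where 10% of the pixels are foreground, a predictor that outputs background everywhere reaches a plain accuracy of 0.9, but a balanced accuracy of only 0.5 and an F1 of 0.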
Having trained all models on scenarios 1c3s5z, 3s5z and 2s3z, we test their robustness to dropped agents (Figure 8) and their performance on OOD scenarios (Table 1). We only include baselines that achieve similar or better validation scores than S2RMs. Figure 8 shows that S2RMCs remain robust when fewer agents supply their observations to the world model, whereas Table 1 shows that S2GRUs outperform the baselines in the OOD scenario 1s2z but are matched by RMCs in 5s3z (see Appendix F.2 for details). The strong performance of RMCs suggests that the task benefits from the inductive bias of relational memory. One hypothesis as to why is that the pace of the considered environments requires fast communication between agents, which can be achieved by a shared memory that all agents may read from and write to. Further, we observe that while the oracle (TTO) generalizes well out of distribution, Figure 8 shows that it is less robust to the number of available observations. This is explained by the fact that, unlike recurrent models, TTO does not leverage the temporal dynamics to fill in the information that is missing when fewer observations are available. This pattern also holds for the bouncing balls task, cf. Figures 21e and 20e in Appendix F.1.

CONCLUSIONS, LIMITATIONS AND FUTURE WORK

We proposed Spatially Structured Recurrent Modules, a new class of models constructed to jointly leverage both spatial and modular structure in data, and explored its potential in the challenging problem setting of predicting the forward dynamics from partial observations at known spatial locations. In the tasks of video prediction from crops and multi-agent world modeling in the Starcraft2 domain, we found that it compares favorably against strong baselines in terms of out-of-distribution generalization and robustness to the number of available observations.
Future work may attempt to extend the idea to parallel-in-time methods like universal transformers (Dehghani et al., 2018) and thereby address the computational bottleneck of recurrent processing, which is a current limitation. Another interesting avenue of research could be to explore how latent random variables can be used in tandem with the spatial structure to obtain a variational version of S2RMs. Finally, efficient implementations using block-sparse methods (Gray et al., 2017) might hold the key to unlocking applications to significantly larger-scale spatiotemporal forecasting problems encountered in domains like climate change research (Rolnick et al., 2019).

³ Here, the code 1c3s5z refers to a scenario where each team comprises 1 colossus (1c), 3 stalkers (3s), and 5 zealots (5z).

ACKNOWLEDGEMENTS

The authors would like to thank Georgios Arvanitidis, Luigi Gresele and Michael Cobos for their feedback on the paper, and Murray Shanahan for the discussions. The authors also acknowledge the important role played by their colleagues at the Empirical Inference Department of MPI-IS Tübingen and Mila throughout the duration of this work.

REFERENCES

Ferran Alet, Tomás Lozano-Pérez, and Leslie P Kaelbling. Modular meta-learning. arXiv preprint arXiv:1806.10166, 2018.
Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, et al. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.
Yoshua Bengio, Tristan Deleu, Nasim Rahaman, Rosemary Ke, Sébastien Lachapelle, Olexa Bilaniuk, Anirudh Goyal, and Christopher Pal. A meta-transfer objective for learning to disentangle causal mechanisms. arXiv preprint arXiv:1901.10912, 2019.
Alberto Cenzato, Alberto Testolin, and Marco Zorzi. On the difficulty of learning and predicting the long-term dynamics of bouncing objects. arXiv preprint arXiv:1907.13494, 2019.
Kyunghyun Cho, Bart Van Merriënboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078, 2014.
Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. Universal transformers. arXiv preprint arXiv:1807.03819, 2018.
Christian Eichenberger, Adrian Egli, Mattias Ljungström, Sharada Mohanty, Guillaume Mollard, Erik Nygren, Giacomo Spigler, and Jeremy Watson. Flatland, 2019. URL http://flatland-rl-docs.s3-website.eu-central-1.amazonaws.com/readme.html.
SM Ali Eslami, Danilo Jimenez Rezende, Frederic Besse, Fabio Viola, Ari S Morcos, Marta Garnelo, Avraham Ruderman, Andrei A Rusu, Ivo Danihelka, Karol Gregor, et al. Neural scene representation and rendering. Science, 360(6394):1204–1210, 2018.
Gregory E Fasshauer. Positive definite kernels: past, present and future. 2011.
Marco Fraccaro, Simon Kamronn, Ulrich Paquet, and Ole Winther. A disentangled recognition and nonlinear dynamics model for unsupervised learning. In Advances in Neural Information Processing Systems, pp. 3601–3610, 2017.
Marta Garnelo, Jonathan Schwarz, Dan Rosenbaum, Fabio Viola, Danilo J Rezende, SM Eslami, and Yee Whye Teh. Neural processes. arXiv preprint arXiv:1807.01622, 2018.
Zoubin Ghahramani and Michael I Jordan. Factorial hidden Markov models. In Advances in Neural Information Processing Systems, pp. 472–478, 1996.
Anirudh Goyal, Alex Lamb, Jordan Hoffmann, Shagun Sodhani, Sergey Levine, Yoshua Bengio, and Bernhard Schölkopf. Recurrent independent mechanisms. arXiv preprint arXiv:1909.10893, 2019.
Alex Graves, Greg Wayne, and Ivo Danihelka. Neural Turing machines. arXiv preprint arXiv:1410.5401, 2014.
Alex Graves, Greg Wayne, Malcolm Reynolds, Tim Harley, Ivo Danihelka, Agnieszka Grabska-Barwińska, Sergio Gómez Colmenarejo, Edward Grefenstette, Tiago Ramalho, John Agapiou, et al. Hybrid computing using a neural network with dynamic external memory. Nature, 538(7626):471–476, 2016.
Scott Gray, Alec Radford, and Diederik P Kingma. GPU kernels for block-sparse weights. arXiv preprint arXiv:1711.09224, 3, 2017.
Mikael Henaff, Jason Weston, Arthur Szlam, Antoine Bordes, and Yann LeCun. Tracking the world state with recurrent entity networks, 2016.
Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation, 9(8):1735–1780, 1997.
Robert A Jacobs, Michael I Jordan, Steven J Nowlan, and Geoffrey E Hinton. Adaptive mixtures of local experts. Neural computation, 3(1):79–87, 1991.
Max Jaderberg, Karen Simonyan, Andrew Zisserman, and Koray Kavukcuoglu. Spatial transformer networks, 2015.
Nan Rosemary Ke, Anirudh Goyal, Olexa Bilaniuk, Jonathan Binas, Michael C Mozer, Chris Pal, and Yoshua Bengio. Sparse attentive backtracking: Temporal credit assignment through reminding. In Advances in neural information processing systems, pp. 7640–7651, 2018.
Nan Rosemary Ke, Olexa Bilaniuk, Anirudh Goyal, Stefan Bauer, Hugo Larochelle, Chris Pal, and Yoshua Bengio. Learning neural causal models from unknown interventions. arXiv preprint arXiv:1910.01075, 2019.
Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
Jannik Kossen, Karl Stelzner, Marcel Hussing, Claas Voelcker, and Kristian Kersting. Structured object-aware physics prediction for video modeling and planning. arXiv preprint arXiv:1910.02425, 2019.
Ananya Kumar, SM Eslami, Danilo J Rezende, Marta Garnelo, Fabio Viola, Edward Lockhart, and Murray Shanahan. Consistent generative query networks. arXiv preprint arXiv:1807.02033, 2018.
Shuai Li, Wanqing Li, Chris Cook, Ce Zhu, and Yanbo Gao. Independently recurrent neural network (IndRNN): Building a longer and deeper RNN. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5457–5466, 2018.
E Majani, Ruth Erlanson, and Yaser Abu-Mostafa. On the k-winners-take-all network. Advances in neural information processing systems, 1:634–642, 1988.
Hongyuan Mei, Guanghui Qin, Minjie Xu, and Jason Eisner. Informed temporal modeling via logical specification of factorial LSTMs, 2020. URL https://openreview.net/forum?id=S1ghzlHFPS.
Djordje Miladinović, Muhammad Waleed Gondal, Bernhard Schölkopf, Joachim M Buhmann, and Stefan Bauer. Disentangled state space representations. arXiv preprint arXiv:1906.03255, 2019.
Ben Mildenhall, Pratul P Srinivasan, Matthew Tancik, Jonathan T Barron, Ravi Ramamoorthi, and Ren Ng. NeRF: Representing scenes as neural radiance fields for view synthesis. arXiv preprint arXiv:2003.08934, 2020.
Sarthak Mittal, Alex Lamb, Anirudh Goyal, Vikram Voleti, Murray Shanahan, Guillaume Lajoie, Michael Mozer, and Yoshua Bengio. Learning to combine top-down and bottom-up signals in recurrent neural networks with attention over modules. arXiv preprint arXiv:2006.16981, 2020.
Giambattista Parascandolo, Niki Kilbertus, Mateo Rojas-Carulla, and Bernhard Schölkopf. Learning independent causal mechanisms. arXiv preprint arXiv:1712.00961, 2017.
Niki Parmar, Ashish Vaswani, Jakob Uszkoreit, Łukasz Kaiser, Noam Shazeer, Alexander Ku, and Dustin Tran. Image transformer, 2018.
Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Kopf, Edward Yang, Zachary DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. PyTorch: An imperative style, high-performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett (eds.), Advances in Neural Information Processing Systems 32, pp. 8024–8035. Curran Associates, Inc., 2019.
Ali Rahimi and Benjamin Recht. Random features for large-scale kernel machines. Advances in neural information processing systems, 20:1177–1184, 2007.
David Rolnick, Priya L Donti, Lynn H Kaack, Kelly Kochanski, Alexandre Lacoste, Kris Sankaran, Andrew Slavin Ross, Nikola Milojevic-Dupont, Natasha Jaques, Anna Waldman-Brown, et al. Tackling climate change with machine learning. arXiv preprint arXiv:1906.05433, 2019.
Mikayel Samvelyan, Tabish Rashid, Christian Schroeder de Witt, Gregory Farquhar, Nantas Nardelli, Tim G. J. Rudner, Chia-Man Hung, Philip H. S. Torr, Jakob Foerster, and Shimon Whiteson. The StarCraft multi-agent challenge, 2019.
Adam Santoro, David Raposo, David G Barrett, Mateusz Malinowski, Razvan Pascanu, Peter Battaglia, and Timothy Lillicrap. A simple neural network module for relational reasoning. In Advances in neural information processing systems, pp. 4967–4976, 2017.
Adam Santoro, Ryan Faulkner, David Raposo, Jack Rae, Mike Chrzanowski, Theophane Weber, Daan Wierstra, Oriol Vinyals, Razvan Pascanu, and Timothy Lillicrap. Relational recurrent neural networks. In Advances in Neural Information Processing Systems, pp. 7299–7310, 2018.
Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538, 2017.
Gautam Singh, Jaesik Yoon, Youngsung Son, and Sungjin Ahn. Sequential neural processes, 2019.
Chen Sun, Per Karlsson, Jiajun Wu, Joshua B Tenenbaum, and Kevin Murphy. Predicting the present and future states of multi-agent systems from partially-observed visual data. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=r1xdH3CcKX.
Yi Tay, Dara Bahri, Liu Yang, Donald Metzler, and Da-Cheng Juan. Sparse Sinkhorn attention. arXiv preprint arXiv:2002.11296, 2020.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in neural information processing systems, pp. 5998–6008, 2017.
Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. Graph attention networks. arXiv preprint arXiv:1710.10903, 2017.
Oriol Vinyals, Timo Ewalds, Sergey Bartunov, Petko Georgiev, Alexander Sasha Vezhnevets, Michelle Yeo, Alireza Makhzani, Heinrich Küttler, John Agapiou, Julian Schrittwieser, et al. StarCraft II: A new challenge for reinforcement learning. arXiv preprint arXiv:1708.04782, 2017.
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. Non-local neural networks, 2017.
Nicholas Watters, Daniel Zoran, Theophane Weber, Peter Battaglia, Razvan Pascanu, and Andrea Tacchetti. Visual interaction networks: Learning a physics simulator from video. In Advances in neural information processing systems, pp. 4539–4547, 2017.
Han Zhang, Ian Goodfellow, Dimitris Metaxas, and Augustus Odena. Self-attention generative adversarial networks. arXiv preprint arXiv:1805.08318, 2018.
Ellen D. Zhong, Tristan Bepler, Joseph H. Davis, and Bonnie Berger. Reconstructing continuous distributions of 3D protein structure from cryo-EM images, 2020.
A QUALITATIVE VISUALIZATIONS

A GRID-WORLD NAVIGATION TASK

In this section, we show qualitative results on a grid-world task defined in Eichenberger et al. (2019), which formulates the problem of navigation on a railway network in a multi-agent reinforcement learning framework. The environment comprises a network of railroads on which agents (trains) may move in order to reach their destinations. In our experiments, the entire railway network is defined on a 60 × 60 grid-world, and we let each agent only observe a partial and local view of the environment, namely a 5 × 5 crop centered on itself. We gather 10000 multi-agent trajectories with 10 agents and a maximum length of 128, of which we use 8000 for training and reserve 2000 for validation. We train an S2GRU with 10 modules for 100 epochs and early-stop when the validation loss is at its minimum.

With the trained model, we visualize the following two things. First, for each module $F_m$, we visualize the spatial locations it may attend to. To this end, we consider all 60 × 60 = 3600 pixel locations in the grid-world, say $x_{ij}$ where $i, j \in \{1, \dots, 60\}$. For each such $x_{ij}$, we evaluate the quantity

$$X^m_{ij} = Z(p_m, s_{ij}) \qquad (11)$$

where $s_{ij} = P(x_{ij})$ (see Eqn 3), $p_m$ is the embedding of module $F_m$, and $X^m$ is a 60 × 60 image indexed by $i$ and $j$, which we call the enclave of module $F_m$. Note that this is identical to what we visualize in Figure 5. Next, we identify each module with its enclave and visualize the graph of interactions between them. In Figure 9, we plot the enclaves $X^m$ as nodes. Further, we draw an edge between enclaves $X^m$ and $X^n$ iff $Z(p_m, p_n) > 0$.

We make the following two observations. First, the images in Figure 9 show that each module learns to account for a spatial region in the environment, as we imagined in Figure 1b in Section 1.
Second, we find that the modules interact sparsely with each other: while some modules learn to interact with up to five other modules, others learn to operate independently.

B DETAILED TASK DESCRIPTIONS

B.1 STARCRAFT2

The Starcraft2 environment we use is a modified version of the SMAC-Env proposed in Samvelyan et al. (2019) and built on the PySC2 wrapper around the Blizzard SC2 API (Vinyals et al., 2017). Starcraft2 is a real-time strategy (RTS) game where players are tasked with manufacturing and controlling armies of units (airborne or land-based) to defeat the opponent's army (where the opponent can be an AI or another human). The players must choose their alien race⁴ before starting the game; available options are Protoss, Terran and Zerg. All unit types (of all races) have their strengths and weaknesses against other unit types, be it in terms of maximum health, shields (Protoss), energy (Terran), DPS (damage per second, related to weapon cooldown), splash damage, or manufacturing costs (measured in minerals and vespene gas, which must be mined). The key engineering contribution of Samvelyan et al. (2019) is to repurpose the RTS game as a multi-agent environment, where the individual units in the army become individual agents⁵. The result is a rich and challenging environment where heterogeneous teams of agents must defeat each other in melee and ranged combat. The composition of teams varies between scenarios, of which Samvelyan et al. (2019) provide a selection. Further, new scenarios can easily be created with the SC2Map Editor, which allows for practically endless possibilities. We build on Samvelyan et al. (2019) by modifying their environment to better expose the transfer and out-of-distribution aspects of the domain by (a) standardizing the state and action space across a large class of scenarios and (b) standardizing the unit stats to better reflect the game-defined notion of hit-points.

⁴ Please note that this is a game-specific notion.
⁵ Note that this is rather unconventional, since each player usually controls entire armies and must switch between macro- and micro-management of units or unit groups.

Figure 9: Joint visualization of spatial enclaves and the interaction graph between modules in the grid-world environment of Eichenberger et al. (2019), as detailed in Appendix A. The images show which spatial locations a module attends to via the local attention (spatial enclaves), whereas the presence of an edge indicates that the corresponding modules may interact via inter-cell attention. Gist: The modules indeed learn a notion of spatial locality, while interacting sparsely with each other.

(a) 1s2z (1 Stalker and 2 Zealots per team). (b) 5s3z (5 Stalkers and 3 Zealots per team). (c) 2s3z (2 Stalkers and 3 Zealots per team). (d) 3s5z (3 Stalkers and 5 Zealots per team). (e) 1c3s5z (1 Colossus, 3 Stalkers and 5 Zealots per team).
Figure 10: Human-readable illustrations of the Starcraft2 (SMAC) scenarios we consider in this work. Figures 10a and 10b show the OOD scenarios, whereas Figures 10c, 10d and 10e show the training scenarios (provided by Samvelyan et al. (2019)).

B.1.1 STANDARDIZED STATE SPACE FOR ALL SCENARIOS

In the environment provided by Samvelyan et al. (2019), the dimensionality of the vector state space varies with the number of friendly and enemy agents, which in turn varies with the scenario. While this is not an issue in the typical use case of training MARL agents in a fixed scenario, it is inconvenient for designing models that seamlessly handle multiple scenarios. In the following, we propose an alternate state representation that preserves the spatial structure and is consistent across multiple scenarios.
Instead of representing the state of an agent $a$ with a vector of variable dimension, we represent it with a multi-channel polar image $I_a$ of shape $C \times I \times J$, where $C$ is the number of channels and $(I, J)$ is the image size. Given the radial and angular resolutions $\rho$ and $\phi$ (respectively), the pixel coordinate $(i, j)$, with $i = 0, \dots, I-1$ and $j = 0, \dots, J-1$, corresponds to the coordinates $(i\rho, j\phi)$ with respect to a polar coordinate system centered on agent $a$, where the positive x-axis ($j = 0$) points towards the east. Further, the field of view (FOV) of an agent is characterized by a circle of radius $I\rho$ centered on the agent at 2D game coordinates $x^a = (x^a_1, x^a_2)$, to which the Starcraft2 API (Vinyals et al., 2017) provides raw access. The polar image $I_a$ therefore provides an agent-centric view of the environment, where pixel coordinates $(i, j)$ in $I_a$ can be mapped to global game coordinates $x = (x_1, x_2)$ in the FOV via:

$$x_1 = i\rho\cos(j\phi) + x^a_1 \qquad (12)$$
$$x_2 = i\rho\sin(j\phi) + x^a_2 \qquad (13)$$

In what follows, we denote this transformation with $T_a$, as in $T_a(i, j) = (x_1, x_2)$. Now, the channels in the polar image can encode various aspects of the observation; in our case: friendly markers (one channel), unit-type markers (nine channels, one-hot), health-energy-cooldown-shields (HECS, four channels) and terrain height (one channel). As an example, consider the friendly markers, a binary indicator marking units that are friendly. If there is an agent at game position $(x_1, x_2)$ that is friendly to agent $a$, then we expect the pixel at $(i, j) = T_a^{-1}(x_1, x_2)$ in the corresponding channel of the polar image $I_a$ to be 1, and 0 otherwise. Likewise, the value of $I_a$ at the channels corresponding to HECS at pixel position $(i, j)$ gives the HECS of the corresponding unit⁶ at $T_a(i, j)$.
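The transformation $T_a$ of Equations 12 and 13, together with an approximate inverse $T_a^{-1}$, can be sketched as follows. We assume (this is not specified above) that the angular resolution is $\phi = 2\pi / J$, so that the $J$ angular bins cover the full circle, and that the inverse snaps a game position to the nearest pixel; the function names are hypothetical. The helper `neighbor_spacing` computes the chord length $2 i \rho \sin(\phi/2)$ between the centres of pixels $(i, j)$ and $(i, j+1)$, which grows linearly with $i$ and is the source of the quantization drawback discussed below.

```python
import math
from typing import Optional, Tuple

Point = Tuple[float, float]


def polar_to_game(i: int, j: int, agent_xy: Point, rho: float, phi: float) -> Point:
    """T_a(i, j): polar pixel (i, j) -> global game coordinates (Eqs. 12-13)."""
    radius, angle = i * rho, j * phi  # angle measured from the positive x-axis (east)
    return (radius * math.cos(angle) + agent_xy[0],
            radius * math.sin(angle) + agent_xy[1])


def game_to_polar(x: Point, agent_xy: Point, rho: float, phi: float,
                  I: int, J: int) -> Optional[Tuple[int, int]]:
    """Approximate T_a^{-1}: snap a game position to the nearest polar pixel.

    Returns None for positions outside the field of view (radius I * rho).
    """
    dx, dy = x[0] - agent_xy[0], x[1] - agent_xy[1]
    radius = math.hypot(dx, dy)
    if radius >= I * rho:
        return None
    angle = math.atan2(dy, dx) % (2.0 * math.pi)  # map the angle to [0, 2*pi)
    i = min(int(round(radius / rho)), I - 1)      # clamp to the outermost ring
    j = int(round(angle / phi)) % J               # wrap angles close to 2*pi back to bin 0
    return i, j


def neighbor_spacing(i: int, rho: float, phi: float) -> float:
    """Euclidean distance between the centres of pixels (i, j) and (i, j + 1)."""
    return 2.0 * i * rho * math.sin(phi / 2.0)
```

For example, with $I = 8$, $J = 16$, $\rho = 1$ and an agent at (30, 12), mapping pixel (3, 5) to game coordinates and back recovers (3, 5). Pixels on ring $i = 0$ all collapse onto the agent's position, so the inverse is only well defined for $i \geq 1$.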
This representation has the following advantages: (a) it does not depend on the number of units in the field of view, and (b) it exposes the spatial structure in the arrangement of units, which can naturally be processed by convolutional neural networks (e.g. with circular convolutions). Nevertheless, it has the disadvantage that positions are quantized to pixels, and the Euclidean distance between the locations represented by pixels $(i, j)$ and $(i, j+1)$ increases with increasing $i$. Consequently, this representation may not remain suitable for larger FOVs. Further, this representation is also appropriate for the action space. Given an agent, we represent the one-hot categorical actions of all friendly agents in the FOV as a multi-channel polar image. In this representation, the pixel position $(i, j)$ gives the action taken by an agent at position $T_a(i, j)$. Unfriendly agents are assigned an "unknown" action, whereas positions not occupied by a living agent are assigned a "no-op" action.

B.1.2 STANDARDIZED UNIT STATS

At any given point in time, an active unit in Starcraft2 has certain stats, e.g. its health, energy (Terran), shields (Protoss) and weapon cooldown (for armed units). A large and expensive unit type like the Colossus has more max-health (hit-points) than smaller units like Stalkers and Marines⁷. Likewise, unit types differ in the rate at which they deal damage (measured in damage per second or DPS, excluding splash damage), which in turn depends on the cooldown duration of the active weapon. Now, the environment provided by Samvelyan et al. (2019) normalizes the stats by their respective maximum values, resulting in values between 0 and 1. However, given that different unit types may have different normalization constants, the stats are rendered incomparable between unit types (without additionally accounting for the unit type). We address this by standardizing the stats (instead of normalizing them), i.e. by dividing them by a fixed value.
In this scheme, the stats are scaled uniformly across all unit types, enabling models to rely on them directly instead of having to account for the respective unit types.

B.2 VIDEO PREDICTION FROM CROPS ON THE BOUNCING BALLS TASK

The bouncing balls task is a well-known test-bed for evaluating the performance of video prediction models (Fraccaro et al., 2017; Watters et al., 2017; Miladinović et al., 2019; Kossen et al., 2019; Cenzato et al., 2019). We modify the problem by introducing partial observability: concretely, instead of providing the model with the full image frames, we only provide it with crops at randomly sampled locations. As mentioned in Section 6, at every time step $t$ we sample $A = 10$ pixel positions $\{x^a_t\}_{a=1}^{10}$ from the $t$-th full video frame $o_t$ of size 48 × 48. Around the sampled central pixel positions $x^a_t$, we extract 11 × 11 crops, which we use as the local observations $O^a_t$. The task now is to predict 11 × 11 crops $O^q_{t'}$ corresponding to query central pixel positions $x^q_{t'}$ at a future time step $t' > t$. Observe that at any given time step $t$, the model has access to at most 52% of the global video frame (10 × 11 × 11 = 1210 of 48 × 48 = 2304 pixels), a bound that is attained only if the crops never overlap (which is rather unlikely).

We train all models on a training dataset of 20K video sequences with 100 frames of 3 balls bouncing in an arena of size 48 × 48. We also include an additional fixed ball in the center to make the task more challenging. We use another 1K video sequences of the same length and the same number of balls as a held-out validation set. In addition, we also have 5 out-of-distribution (OOD) test sets with varying numbers of bouncing balls (ranging from 1 to 6), each containing 1K sequences of length 100.

⁶ If health drops to zero, the unit is considered dead, and the representation does not differentiate between dead and absent units.
⁷ These stats may change with game versions, and are catalogued here: https://liquipedia.net/starcraft2/Units_(StarCraft).
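The view-sampling procedure of Appendix B.2 can be sketched in plain Python as follows. Two details are our assumptions rather than statements from the paper: crop centres are drawn uniformly over the whole 48 × 48 frame, and crops that extend past the frame border are zero-padded. The hypothetical `coverage` helper checks the 52% bound: ten non-overlapping 11 × 11 crops cover exactly 1210 of 2304 pixels, and any overlap or clipping at the border only lowers this fraction.

```python
import random
from typing import List, Optional, Sequence, Tuple

Frame = List[List[float]]
Center = Tuple[int, int]  # (row, column)


def extract_crop(frame: Frame, center: Center, size: int = 11, pad: float = 0.0) -> Frame:
    """Return a size x size crop centred on `center`; pixels outside the frame are `pad`."""
    half = size // 2
    height, width = len(frame), len(frame[0])
    r0, c0 = center
    return [
        [frame[r][c] if 0 <= r < height and 0 <= c < width else pad
         for c in range(c0 - half, c0 + half + 1)]
        for r in range(r0 - half, r0 + half + 1)
    ]


def sample_views(frame: Frame, num_views: int = 10, size: int = 11,
                 rng: Optional[random.Random] = None) -> Tuple[List[Center], List[Frame]]:
    """Sample `num_views` crop centres uniformly over the frame and extract the crops."""
    rng = rng or random.Random()
    height, width = len(frame), len(frame[0])
    centers = [(rng.randrange(height), rng.randrange(width)) for _ in range(num_views)]
    return centers, [extract_crop(frame, c, size) for c in centers]


def coverage(centers: Sequence[Center], height: int = 48, width: int = 48, size: int = 11) -> float:
    """Fraction of frame pixels that fall inside at least one crop."""
    half = size // 2
    covered = set()
    for r0, c0 in centers:
        for r in range(max(0, r0 - half), min(height, r0 + half + 1)):
            for c in range(max(0, c0 - half), min(width, c0 + half + 1)):
                covered.add((r, c))
    return len(covered) / (height * width)
```

Tiling ten centres on an 11-pixel grid, e.g. at rows and columns 5, 16, 27 and 38, reaches the bound exactly, whereas uniformly sampled centres typically overlap and cover less.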
In Figure 4, for each number of balls (i.e. point on the x-axis), we plot the respective metrics, aggregated over 10 randomly selected 11 × 11 crops per frame, for a total of 100,000 frames spread over 1,000 trajectories of 100 frames each.

C PRECISE DESCRIPTION OF ATTENTION MECHANISMS

C.1 INPUT ATTENTION

Recall from Section 4 that the input attention mechanism is a mapping between sets: namely, from the set of observation encodings $\{e^a_t\}_{a=1}^{A}$ to that of RNN inputs $\{u^m_t\}_{m=1}^{M}$. In what follows, we use the einsum notation⁸ to succinctly describe the exact mechanism. But before that, we repeat the definition of the truncated spherical Gaussian kernel (Fasshauer, 2011) used to quantify the similarity between two points $p, s \in \mathcal{S}$:

$$Z(p, s) = \begin{cases} \exp[-2\epsilon(1 - p \cdot s)], & \text{if } p \cdot s \geq \tau \\ 0, & \text{otherwise} \end{cases} \qquad (14)$$

where $\epsilon \in \mathbb{R}^+$ and $\tau \in [-1, 1)$ are hyper-parameters (kernel bandwidth and truncation parameter, respectively), and $0 \leq Z \leq 1$ since $p$ and $s$ are unit vectors. We observe that both $\tau$ and $\epsilon$ control the sparsity of the kernel: $\tau$ determines the size of the neighborhood of $p$, i.e. the size of the set $B(p) \subseteq \mathcal{S}$ of all $s$ such that $p \cdot s \geq \tau$ and accordingly $Z(p, s) > 0$, whereas the bandwidth $\epsilon$ controls how the attention decays inside $B(p)$. Intuitively, $\tau$ determines a lower bound on the amount of sparsity that the kernel induces (irrespective of the bandwidth $\epsilon$), whereas for fixed $\tau$, sparsity can be increased by increasing $\epsilon$. We find $\tau \in [-1, 0.6]$ and $\epsilon \in [0.9, 2]$ to work well; setting $\tau$ and $\epsilon$ to much larger values destabilizes training due to excessive sparsity, whereas setting $\epsilon$ to much smaller values makes $Z$ flat inside $B(p)$ and therefore leads to poor propagation of gradients. Now, we use $k$ to index the attention heads and $d$ to index the dimension of the key and query vectors, and denote with $e_{ai}$ the $i$-th component of $e^a_t$ and with $h_{mj}$ the $j$-th component of $h^m_t$.
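A minimal plain-Python sketch of the truncated spherical Gaussian kernel of Equation 14 follows; it also builds a module interaction graph of the kind shown in Appendix A (an edge iff $Z(p_m, p_n) > 0$, i.e. iff $p_m \cdot p_n \geq \tau$). One identity is worth noting: for unit vectors, $2(1 - p \cdot s) = \|p - s\|^2$, so inside $B(p)$ the kernel is an ordinary Gaussian $\exp(-\epsilon\|p - s\|^2)$ in the Euclidean distance between embeddings. The function names are illustrative and not taken from the authors' code.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def normalize(v: Vector) -> List[float]:
    """Project a non-zero vector onto the unit sphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def kernel_z(p: Vector, s: Vector, eps: float, tau: float) -> float:
    """Truncated spherical Gaussian kernel Z(p, s) of Eq. 14 for unit vectors p, s."""
    dot = sum(a * b for a, b in zip(p, s))
    if dot < tau:  # s lies outside the neighbourhood B(p): exactly zero
        return 0.0
    return math.exp(-2.0 * eps * (1.0 - dot))  # larger eps -> faster decay inside B(p)


def interaction_graph(embeddings: Sequence[Vector], eps: float, tau: float) -> List[Tuple[int, int]]:
    """Undirected edges (m, n), m < n, between modules with Z(p_m, p_n) > 0."""
    edges = []
    for m in range(len(embeddings)):
        for n in range(m + 1, len(embeddings)):
            if kernel_z(embeddings[m], embeddings[n], eps, tau) > 0.0:
                edges.append((m, n))
    return edges
```

For three 2D module embeddings at angles of 0°, 30° and 120° with $\tau = 0.6$, only the first pair (cosine similarity of about 0.87) is connected; lowering $\tau$ to $-1$ connects all three pairs, illustrating how $\tau$ alone sets the sparsity of the interaction topology.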
Given learnable parameters $\Theta^{(K)}$, $\Theta^{(Q)}$, $\Theta^{(V)}$, we obtain:

$$Q_{akd} = e_{ai}\Theta^{(Q)}_{ikd}, \qquad K_{mkd} = h_{mj}\Theta^{(K)}_{jkd} \qquad (15)$$
$$V_{akv} = e_{ai}\Theta^{(V)}_{ikv}, \qquad \tilde{W}_{mak} = Q_{akd}K_{mkd} \qquad (16)$$
$$\bar{W}_{mak} = \mathrm{sm}_a(\tilde{W}_{mak}), \qquad W^{(L)}_{ma} = Z(p_m, s_a) \qquad (17)$$
$$W_{mak} = W^{(L)}_{ma}\bar{W}_{mak}, \qquad \tilde{u}_{m(kv)} = W_{mak}V_{akv} \qquad (18)$$

where $\mathrm{sm}_a$ denotes the softmax along the $a$-dimension, $W^{(L)}$ is what we will call the local weights, we omit the time subscript in $s_a$ for notational clarity, and $\tilde{u}_{m(kv)}$ is the $(kv)$-th component of a vector $\tilde{u}_m$. Finally, we obtain the components $u_{mi}$ of the RNN inputs $u^m_t$ via a gating operation:

$$u_{mi} = G^{(inp)}_m b_{mi} + (1 - G^{(inp)}_m)\tilde{u}_{mi} \qquad (19)$$

where the gating weight $G^{(inp)}_m \in (0, 1)$ is obtained by passing $\tilde{u}_{mi}$ and $b_{mi} = W^{(L)}_{ma}e_{ai}$ through a two-layer MLP with sigmoidal output (in parallel across $m$). Now, observe that by weighting the MHDPA attention outputs ($\bar{W}$ in Equation 18) by the kernel $Z$ (via $W^{(L)}$), we construct a scheme where the interaction between input $O^a_t$ and RNN $F_m$ is allowed only if the embedding $s^a_t$ of the corresponding position $x^a_t$ has a large enough cosine similarity ($\geq \tau$) to the embedding $p_m$ of $F_m$. This partially implements the assumption of Locality of Observation detailed in Section 3.

C.2 INTER-CELL ATTENTION

Recall from Section 4 that the inter-cell attention maps the hidden states of each RNN $\{h^m_t\}_{m=1}^{M}$ to the set of aggregated hidden states $\{\bar{h}^m_t\}_{m=1}^{M}$, thereby enabling interaction between the RNNs $F_m$. While its mechanism is identical to that of the input attention, we formulate it below for completeness.

⁸ Indices not appearing on both sides of an equation are summed over; this is implemented as einsum in most DL frameworks.

To proceed, we denote with $h_{li}$ the $i$-th component of $h^l_t$ (in addition to the notation introduced before Equation 15), and take $\Phi^{(Q)}$, $\Phi^{(K)}$ and $\Phi^{(V)}$ to be learnable parameters.
With this notation, we have:

$$Q_{mkd} = h_{mj}\Phi^{(Q)}_{jkd}, \qquad K_{lkd} = h_{li}\Phi^{(K)}_{ikd} \qquad (20)$$
$$V_{lkv} = h_{li}\Phi^{(V)}_{ikv}, \qquad \tilde{W}_{mlk} = Q_{mkd}K_{lkd} \qquad (21)$$
$$\hat{W}_{mlk} = \mathrm{sm}_l(\tilde{W}_{mlk}), \qquad W^{(L)}_{ml} = Z(p_m, p_l) \qquad (22)$$
$$W_{mlk} = \hat{W}_{mlk}W^{(L)}_{ml}, \qquad \tilde{h}_{m(kv)} = W_{mlk}V_{lkv} \qquad (23)$$

where $\mathrm{sm}_l$ denotes the softmax along the $l$-dimension and $\tilde{h}_{m(kv)}$ is the $(kv)$-th component of a vector $\tilde{h}_m$. Finally, the $j$-th component $\bar{h}_{mj}$ of the aggregated hidden state $\bar{h}^m_t$ in Equation 4 is given by a gating operation:

$$\bar{h}_{mj} = G^{(ic)}_m c_{mj} + (1 - G^{(ic)}_m)\tilde{h}_{mj} \qquad (24)$$

where the gating weight $G^{(ic)}_m \in (0, 1)$ is obtained by passing $\tilde{h}_{mj}$ and $c_{mj} = W^{(L)}_{ml}h_{lj}$ through a two-layer MLP with sigmoidal output (in parallel across $m$). The weighting by $Z$ (Equation 23, left) ensures that interaction is only possible between RNNs whose embeddings in $S$ are similar enough, thereby implementing the assumption of Local Interactions between Sub-systems in Section 3.

C.3 POSITIONAL ENCODING

Recall that in Section 4 we used the following positional embedding $P$:

$$P(x) = s / \|s\| \in S, \quad \text{where} \quad (s_{i+m}, s_{i+1+m}) = \left(\sin(x_m / 10000^{i}), \cos(x_m / 10000^{i})\right) \qquad (25)$$

In this section, we explore how the choice of the positional embedding function $P$ determines the space of spatial functions (defined on $X$) that the local attention can represent. To this end, consider the similarity in $S$ between a module with embedding $p$ and an observation made at location $x$, as a function of $x$:

$$w^{(L)}(x) = p \cdot P(x) \qquad (26)$$

The local weight of interaction between the module at $p$ and an observation made at $x$ is then given by:

$$Z(p, P(x)) = \begin{cases} \exp\left[-2\epsilon(1 - w^{(L)}(x))\right], & \text{if } w^{(L)}(x) \geq \tau \\ 0, & \text{otherwise} \end{cases} \qquad (27)$$

In particular, for two locations $x$ and $y$ to be connected by the module, $w^{(L)}$ must be flexible enough that $w^{(L)}(x) \geq \tau$ and $w^{(L)}(y) \geq \tau$ for the chosen $\tau$. This flexibility stems from the fact that we implicitly express $w^{(L)}$ as a linear combination of sinusoidal basis functions with learned weights:

$$w^{(L)}(x) = \sum_{j=0}^{J-1}\left[p_{2j}\cos(\omega_j \cdot x) + p_{2j+1}\sin(\omega_j \cdot x)\right] \qquad (28)$$

Here, $p_{2j}$ and $p_{2j+1}$ are learnable parameters (the components of the learnable vector $p$ of dimension $2J$), and the $\omega_j$ are frequency vectors.
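The plain-Python sketch below illustrates Equations 25 and 26: it builds a normalized sinusoidal embedding $P(x)$ for a point $x \in \mathbb{R}^n$ and evaluates $w^{(L)}(x) = p \cdot P(x)$ for a module embedding $p$. The frequency schedule `base ** (-i / num_freqs)` and the function names are assumptions for illustration; only the sin/cos pairing, the projection onto the unit sphere $S$ and the dot-product form come from the equations above.

```python
import math


def positional_embedding(x, num_freqs=4, base=10000.0):
    """Sketch of P(x) in Equation 25 for a point x given as a list of coordinates.

    ASSUMPTION: frequencies omega_i = base ** (-i / num_freqs), i.e. their
    logarithms lie on a uniform grid; the exact exponent in the paper may differ.
    """
    s = []
    for x_m in x:
        for i in range(num_freqs):
            omega = base ** (-i / num_freqs)
            s.extend([math.sin(omega * x_m), math.cos(omega * x_m)])
    # Each (sin, cos) pair has unit norm, so ||s|| = sqrt(len(x) * num_freqs).
    norm = math.sqrt(sum(v * v for v in s))
    return [v / norm for v in s]


def local_weight(p, x, **kwargs):
    """w^(L)(x) = p . P(x) (Equation 26): a linear combination of sinusoids in x."""
    return sum(p_j * s_j for p_j, s_j in zip(p, positional_embedding(x, **kwargs)))


# A module placed at the embedding of x0 has maximal similarity w^(L)(x0) = 1.
x0 = [3.0, -1.5]
p = positional_embedding(x0)
print(len(p), local_weight(p, x0), local_weight(p, [10.0, 7.0]))
```

Since $w^{(L)}$ is linear in $p$, learning a module embedding amounts to learning the coefficients of a fixed sinusoidal basis, which is what Equation 28 expresses.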
Now, if the dimension of the embedding vector $p$ were to tend to infinity, we could use a growing number of frequencies $\omega_j$ and gradually recover the full Fourier basis of $L^2(X)$ (assuming $X$ is Euclidean for simplicity). In the limit, $w^{(L)}$ can be an arbitrary function lying on the unit sphere in $L^2(X)$ (i.e. $\int |w^{(L)}|^2 = 1$; recall that $p$ is normalized to unit length). In other words, in a high-dimensional embedding space, the system is afforded a large amount of flexibility to learn any spatial structure or topology on $X$, by connecting pairs of points $x$ and $y$ in $X$ via a module. For computational tractability, however, we require $p$ to be finite-dimensional, which implies that the $\omega_j$ must be sampled. By using $P$ as defined in Equation 25, we essentially sample the $\omega_j$ as scaled coordinate (one-hot) vectors whose log-magnitudes lie on a uniform grid. This sampling step is not only computationally favorable, but is also justified by the theory of reproducing kernel Hilbert spaces (RKHS): Rahimi & Recht (2007) use Bochner's theorem to show that any proper distribution $p(\omega)$ (from which $\omega$ can be sampled) leads to a feature map whose inner product, in expectation over $p(\omega)$, corresponds to a shift-invariant positive-definite kernel. The convergence to such a kernel is exponential in the number of samples (equivalently, the dimension of the embedding). Further, while the sampling constrains the function space in which $w^{(L)}$ can lie, we find empirically that this can in fact have a regularizing effect. Nevertheless, this raises the question of whether other choices of basis functions are viable. We speculate that a polynomial basis (e.g. the feature maps of a degree-$d$ polynomial kernel) might also be viable, but leave an extensive exploration to future work.

D BASELINE ARCHITECTURE

As mentioned in Section 6, to ensure a fair comparison between the baselines and our method, we describe a baseline architecture constructed to satisfy a few critically important desiderata that are naturally satisfied by S2RMs.
Namely, (a) it must be parametrically agnostic to the number of available observations, and (b) it must be invariant to permutations of the observations. To this end, we extend the framework of Generative Query Networks (Eslami et al., 2018) by predicting the forward dynamics of an aggregated representation. While we invest effort in ensuring that the resulting class of models can perform at least as well as S2RMs on in-distribution (validation) data, we do not consider it a novel contribution of this work.

Encoder. At a given time step $t$, the encoder $E$ jointly maps the embedding $s^a_t \in S$ of the position $x^a_t \in X$ and the corresponding observation $O^a_t$ to an encoding $e^a_t$; the encodings are then summed over $a$ to obtain an aggregated representation:

$$r_t = \sum_{a=1}^{A} E(O^a_t, s^a_t) \qquad (29)$$

The additive aggregation scheme we use is well known from prior work (Santoro et al., 2017; Eslami et al., 2018; Garnelo et al., 2018) and makes the model agnostic to $A$ and to permutations of $(x^a_t, O^a_t)$ over $a$. The encoder $E$ is a seven-layer CNN with residual layers, and the positional embedding $s^a_t$ is injected after the second convolutional layer via concatenation with the feature tensor. The exact architectures can be found in Appendices E.1 and E.2.

RNN. The aggregated representation $r_t$ is used as the input to an RNN model $F$ as follows:

$$h_{t+1}, c_{t+1} = F(r_t, h_t, c_t) \qquad (30)$$

where $h_t$ and $c_t$ are the hidden and memory states of the RNN $F$, respectively. We experiment with various RNN models, including LSTMs (Hochreiter & Schmidhuber, 1997), RMCs (Santoro et al., 2018) and Recurrent Independent Mechanisms (RIMs) (Goyal et al., 2019). As a sanity check, we also show results with a Time Travelling Oracle (TTO), which has access to $r_{t+1}$ (but at time step $t$) and produces $h_{t+1} = F_{TTO}(r_{t+1})$ with a two-layer MLP $F_{TTO}$. TTO therefore does not model the dynamics, but merely verifies that the additive aggregation scheme (Equation 29) and the querying mechanism (Equation 31) are sufficient for the task at hand.

Decoder.
Given the embedding $s_q$ of the query position $x_q$, the decoder $D$ predicts the corresponding observation $\hat{O}^q_{t+1}$:

$$\hat{O}^q_{t+1} = D(h_{t+1}, s_q) \qquad (31)$$

We parameterize $D$ with a deconvolutional network with residual layers, and inject the positional embedding of the query $s_q$ after a single convolutional layer by concatenating it with the layer features (see Appendices E.1 and E.2).

E HYPERPARAMETERS AND ARCHITECTURES

E.1 ENCODER AND DECODER FOR BOUNCING BALLS

The architectures of the image encoder and decoder were fixed for all models after initial experimentation. We converged to the following architectures.

Figure 11 (a), Encoder (Observation → Representation):
Convolution (IC → 128) (kernel size 5, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Concatenate (Positional Embedding)
Convolution (128 + CC → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Convolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Convolution (128 → 128) (kernel size 3, padding 0, stride 1)

Figure 11 (b), Decoder (Representation → Observation):
Deconvolution (RC → 128) (kernel size 5, padding 0, stride 1)
Concatenate (Positional Embedding)
Deconvolution (128 + CC → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Deconvolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Deconvolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → IC) (kernel size 3, padding 1, stride 1)

Figure 11: Baseline encoder and decoder architectures for the Bouncing Ball task.
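As a sanity check on the layer settings listed in Figure 11 (the Figure 12 stacks use the same kernel sizes, paddings and strides), the following sketch traces the spatial size of an 11×11 crop through the encoder and of a 1×1 representation through the decoder, using the standard output-size formulas for convolutions and transposed convolutions (dilation 1, no output padding). That the representation is spatially 1×1 is our inference from these settings, not something the figures state; the helper names are hypothetical.

```python
def conv_out(n, kernel, padding, stride=1):
    """Spatial output size of a convolution (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def deconv_out(n, kernel, padding, stride=1):
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (n - 1) * stride - 2 * padding + kernel


def trace(n, layers):
    """Return the spatial size after each (kind, kernel, padding, stride) layer."""
    sizes = [n]
    for kind, kernel, padding, stride in layers:
        step = conv_out if kind == "conv" else deconv_out
        n = step(n, kernel, padding, stride)
        sizes.append(n)
    return sizes


# Layer settings transcribed from Figure 11.
ENCODER = [("conv", 5, 0, 1), ("conv", 3, 1, 1), ("conv", 3, 0, 1), ("conv", 3, 1, 1),
           ("conv", 3, 0, 1), ("conv", 3, 1, 1), ("conv", 3, 0, 1)]
DECODER = [("deconv", 5, 0, 1), ("deconv", 3, 0, 1), ("conv", 3, 1, 1), ("deconv", 3, 0, 1),
           ("conv", 3, 1, 1), ("deconv", 3, 0, 1), ("conv", 3, 1, 1)]

print(trace(11, ENCODER))  # [11, 7, 7, 5, 5, 3, 3, 1]
print(trace(1, DECODER))   # [1, 5, 7, 7, 9, 9, 11, 11]
```

The padding-1 layers preserve size while the padding-0 layers shrink (or, transposed, grow) it by 2, so the encoder and decoder are mirror images that map between an 11×11 crop and a 1×1 feature vector.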
Figure 12 (a), Encoder (Observation → Representation):
Convolution (IC → 128) (kernel size 5, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Convolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Convolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Convolution (128 → 128) (kernel size 3, padding 0, stride 1)

Figure 12 (b), Decoder (Representation → Observation):
Deconvolution (RC → 128) (kernel size 5, padding 0, stride 1)
Deconvolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Deconvolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → 128) (kernel size 3, padding 1, stride 1)
Deconvolution (128 → 128) (kernel size 3, padding 0, stride 1)
Convolution (128 → IC) (kernel size 3, padding 1, stride 1)

Figure 12: S2RM encoder and decoder architectures for the Bouncing Ball task.

E.1.1 S2RMS

The encoder (decoder) is a (de)convolutional network with residual connections (Figure 12).

E.1.2 BASELINES

As for S2RMs, the encoder (decoder) is a (de)convolutional network with residual connections (Figure 11), but with the positional embeddings injected after the second convolutional layer. This is loosely inspired by the encoders used in Eslami et al. (2018).

E.2 ENCODER AND DECODER FOR STARCRAFT2

E.2.1 S2RMS

Recall from Appendix B.1 that the states are polar images. We therefore use polar convolutions, which entail zero-padding the input image along the first (radial) dimension but circular padding along the second (angular) dimension. The encoder and decoder architectures can be found in Figure 14.

E.2.2 BASELINES

As for S2RMs, we use polar convolutions, but inject the positional embeddings further downstream in the network.
The corresponding encoder and decoder architectures are illustrated in Figure 13.

E.3 SPATIALLY STRUCTURED RELATIONAL MEMORY CORES (S2RMCS)

Embedding Relational Memory Cores (Santoro et al., 2018) naïvely in the S2RM architecture did not result in a working model. We therefore adapted it by first projecting the memory matrix ($M$ in Santoro et al. (2018)) of the $m$-th RMC to a message $h^m_t$. This message is then processed by the inter-cell attention to obtain $\bar{h}^m_t$, which is finally concatenated with the memory matrix and the current input before applying the attention mechanism (i.e. in Equation 2 of Santoro et al. (2018), we replace $[M; x]$ with $[M; x; \bar{h}^m_t]$).

E.4 HYPERPARAMETERS

E.4.1 BOUNCING BALL MODELS

The hyperparameters we used can be found in Table 2. Further, note that in Equation 5, we pass the gradients through the constant region of the kernel as if the kernel had not been truncated.

E.4.2 STARCRAFT2 MODELS

The hyperparameters we used can be found in Table 3. Note that we only report models that attained a validation loss similar to or better than that of S2RMs.

E.4.3 TRAINING

All models were trained using Adam (Kingma & Ba, 2014) with an initial learning rate of 0.0003 (footnote 9). We use PyTorch's (Paszke et al., 2019) ReduceLROnPlateau learning rate scheduler to decay the learning rate by a factor of 2 if the validation loss does not improve by at least 0.01% over the span of 5 epochs. We initially train all models for 100 epochs, select the best of three successful runs, fine-tune it for another 100 epochs, and finally select the checkpoint with the lowest validation loss (i.e. we early stop). We train all models with batch size 8 (Starcraft2) or 32 (Bouncing Balls), each on a single V100-32GB GPU.
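The plateau rule in E.4.3 can be sketched in plain Python as below. It mirrors our understanding of PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` with `mode="min"`, `factor=0.5`, `patience=5`, `threshold=1e-4` and `threshold_mode="rel"`; those argument values are our reading of the text, not taken from the authors' code. The class name is hypothetical, and the initial learning rate is passed in by the caller (the value in the usage lines is only an example).

```python
class PlateauHalving:
    """Halve the learning rate when the validation loss stops improving by a
    relative 0.01%; a decay happens once the number of consecutive
    non-improving epochs exceeds `patience` (hypothetical sketch)."""

    def __init__(self, lr, factor=0.5, patience=5, rel_threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.rel_threshold = rel_threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss):
        # An epoch only counts as an improvement if it beats the best loss by 0.01%.
        if val_loss < self.best * (1.0 - self.rel_threshold):
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0  # restart the count after each decay
        return self.lr


scheduler = PlateauHalving(lr=3e-4)
for epoch, loss in enumerate([1.0, 0.9] + [0.9] * 6):
    print(epoch, loss, scheduler.step(loss))
```

Counting a tiny improvement (below the relative threshold) as a non-improving epoch is what keeps the schedule from being reset by noise in the validation loss.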
Footnote 9: https://twitter.com/karpathy/status/801621764144971776?s=20

Figure 13 (a), Encoder (Observation → Representation):
Polar Convolution (IC → 128) (kernel size 3x5, padding 1x2, stride 2x3)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Concatenate (Positional Embedding)
Polar Convolution (128 + CC → 128) (kernel size 3x5, padding 0x2, stride 2x2)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Polar Convolution (128 → 128) (kernel size 3x3, padding 0x1, stride 1x2)
Polar Convolution (128 → 128) (kernel size 1x3, padding 0x1, stride 1x1)
Convolution (128 → 128) (kernel size 1x3, padding 0x0, stride 1x1)

Figure 13 (b), Decoder (Representation → Observation):
Fully Connected
Concatenate (Positional Embedding)
Deconvolution (128 + CC → 128) (kernel size 3x5, padding 0x0, stride 1x1)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Deconvolution (128 → 128) (kernel size 3x5, padding 0x0, stride 2x3)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Deconvolution (128 → 128) (kernel size 3x5, padding 0x0, stride 2x2)
Polar Convolution (128 → IC) (kernel size 3x5, padding 1x2, stride 1x1)

Figure 13: Baseline encoder and decoder architectures for the Starcraft2 task.

Figure 14 (a), Encoder (Observation → Representation):
Polar Convolution (IC → 128) (kernel size 3x5, padding 1x2, stride 2x3)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Polar Convolution (128 + CC → 128) (kernel size 3x5, padding 0x2, stride 2x2)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Polar Convolution (128 → 128) (kernel size 3x3, padding 0x1, stride 1x2)
Polar Convolution (128 → 128) (kernel size 1x3, padding 0x1, stride 1x3)
Convolution (128 → 128) (kernel size 1x3, padding 0x0, stride 1x1)
Figure 14 (b), Decoder (Representation → Observation):
Fully Connected
Deconvolution (128 → 128) (kernel size 3x5, padding 0x0, stride 1x1)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Deconvolution (128 → 128) (kernel size 3x5, padding 0x0, stride 2x3)
Polar Convolution (128 → 128) (kernel size 3x5, padding 1x2, stride 1x1)
Deconvolution (128 → 128) (kernel size 3x5, padding 0x0, stride 2x2)
Polar Convolution (128 → IC) (kernel size 3x5, padding 1x2, stride 1x1)

Figure 14: S2RM encoder and decoder architectures for the Starcraft2 task.

Model | Hyperparameter | Value
S2GRU | Number of modules (M) | 10
S2GRU | GRU: hidden size per module | 128
S2GRU | Module embedding size (d) | 16
S2GRU | Kernel bandwidth (ϵ) | 1
S2GRU | Kernel truncation (τ) | 0.6
S2GRU | shape Θ(Q/K) | (128, 2, 016)
S2GRU | shape Θ(V) | (128, 2, 128)
S2GRU | shape Φ(Q/K) | (128, 4, 016)
S2GRU | shape Φ(V) | (128, 4, 128)
RMC (Santoro et al., 2018) | Number of attention heads | 4
RMC (Santoro et al., 2018) | Size of attention head | 128
RMC (Santoro et al., 2018) | Number of memory slots | 1
RMC (Santoro et al., 2018) | Key size | 128
LSTM (Hochreiter & Schmidhuber, 1997) | Hidden size | 512
RIMs (Goyal et al., 2019) | Number of RIMs (kT) | 6
RIMs (Goyal et al., 2019) | Update Top-k (kA) | 5
RIMs (Goyal et al., 2019) | Hidden size (hsize) | 510
RIMs (Goyal et al., 2019) | Input key size | 32
RIMs (Goyal et al., 2019) | Input value size | 400
TTO | MLP hidden size | 512

Table 2: Hyperparameters used for various models on the Bouncing Ball task. Hyperparameters not listed here were left at their respective default values.
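Appendix E.2.1 describes polar convolutions as zero padding along the radial axis combined with circular padding along the angular axis. The single-channel plain-Python sketch below illustrates that padding scheme followed by an unpadded cross-correlation (what deep-learning frameworks call a convolution). The function names are hypothetical, and a real implementation would operate on batched multi-channel tensors; the sketch only shows why features near angle 0 see their neighbours near angle 2π.

```python
def polar_pad(image, pad_r, pad_a):
    """Zero-pad rows (radial axis) and circularly pad columns (angular axis).

    image: list of rows; rows index the radius, columns index the angle.
    """
    if pad_a:
        # The angle wraps around, so the last columns precede the first ones.
        rows = [row[-pad_a:] + row + row[:pad_a] for row in image]
    else:
        rows = [list(row) for row in image]
    width = len(rows[0])
    top = [[0.0] * width for _ in range(pad_r)]
    bottom = [[0.0] * width for _ in range(pad_r)]
    return top + rows + bottom


def polar_conv2d(image, kernel, padding, stride=(1, 1)):
    """Single-channel polar convolution: polar padding, then valid cross-correlation."""
    padded = polar_pad(image, *padding)
    kh, kw = len(kernel), len(kernel[0])
    sh, sw = stride
    out_h = (len(padded) - kh) // sh + 1
    out_w = (len(padded[0]) - kw) // sw + 1
    return [[sum(padded[i * sh + u][j * sw + v] * kernel[u][v]
                 for u in range(kh) for v in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]


image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
# A kernel that copies each pixel's angular (left) neighbour.
shift = [[0.0, 0.0, 0.0],
         [1.0, 0.0, 0.0],
         [0.0, 0.0, 0.0]]
print(polar_conv2d(image, shift, padding=(1, 1)))  # [[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]]
```

With the 3x5 kernels and 1x2 padding listed in Figures 13 and 14, a stride-1 layer of this kind preserves both image dimensions.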
Model | Hyperparameter | Value
S2GRU | Number of modules (M) | 10
S2GRU | GRU: hidden size per module | 128
S2GRU | Module embedding size (d) | 8
S2GRU | Kernel bandwidth (ϵ) | 1
S2GRU | Kernel truncation (τ) | 0.5
S2GRU | shape Θ(Q/K) | (128, 2, 016)
S2GRU | shape Θ(V) | (128, 2, 128)
S2GRU | shape Φ(Q/K) | (128, 4, 016)
S2GRU | shape Φ(V) | (128, 4, 128)
S2RMC | Number of modules (M) | 10
S2RMC | RMC: number of attention heads | 4
S2RMC | RMC: size of attention head | 64
S2RMC | RMC: number of memory slots | 4
S2RMC | RMC: key size | 64
S2RMC | Module embedding size (d) | 8
S2RMC | Kernel bandwidth (ϵ) | 1
S2RMC | Kernel truncation (τ) | 0.5
S2RMC | shape Θ(Q/K) | (128, 2, 016)
S2RMC | shape Θ(V) | (128, 2, 128)
S2RMC | shape Φ(Q/K) | (128, 4, 016)
S2RMC | shape Φ(V) | (128, 4, 128)
RMC (Santoro et al., 2018) | Number of attention heads | 4
RMC (Santoro et al., 2018) | Size of attention head | 128
RMC (Santoro et al., 2018) | Number of memory slots | 1
RMC (Santoro et al., 2018) | Key size | 16
LSTM (Hochreiter & Schmidhuber, 1997) | Hidden size | 2048
TTO | MLP hidden size | 512

Table 3: Hyperparameters used for various models on the Starcraft2 task. Hyperparameters not listed here were left at their respective default values.

Figure 15: Rollouts (OOD) with 1 bouncing ball; from top to bottom: ground truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11×11 patches from the models (queried on a 4×4 grid).

Figure 16: Rollouts (OOD) with 2 bouncing balls; from top to bottom: ground truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11×11 patches from the models (queried on a 4×4 grid).

E.4.4 OBJECTIVE FUNCTIONS

In the Starcraft2 task, predicting the next state entails predicting images of binary friendly markers, categorical unit-type markers and real-valued HECS markers.
Accordingly, the loss function is a sum of a binary cross-entropy term (on friendly markers), a categorical cross-entropy term (on unit-type markers) and a mean squared error term (on HECS markers). In the Bouncing Balls task, the model output is a binary image; accordingly, we use a pixel-wise binary cross-entropy loss.

F ADDITIONAL RESULTS

F.1 BOUNCING BALLS

F.1.1 ROLLOUTS

To obtain the rollouts in Figure 3, we adopt the following strategy. For the first 20 prompt steps, we present all models with exactly the same 11×11 crops around randomly sampled pixel positions. For the next 25 steps, all models are queried at random pixel positions (footnote 10), and the resulting predictions (on crops) are thresholded at 0.5 and fed back into the model at the next step (at the known pixel positions from the previous step). In addition, at every time step, the models are queried for their predictions at 16 pixel locations placed on a 4×4 grid. The resulting predictions are stitched together and shown in Figures 15, 16, 17, 18, 3 and 19.

Figure 17: Rollouts (ID) with 3 bouncing balls; from top to bottom: ground truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11×11 patches from the models (queried on a 4×4 grid).

Figure 18: Rollouts (OOD) with 4 bouncing balls; from top to bottom: ground truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11×11 patches from the models (queried on a 4×4 grid).

Figure 19: Rollouts (OOD) with 6 bouncing balls; from top to bottom: ground truth, S2GRU, RIMs, RMC, LSTM. Note that all models were trained on sequences with 3 bouncing balls, and the global state was reconstructed by stitching together 11×11 patches from the models (queried on a 4×4 grid).
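The thresholding applied to fed-back predictions and the stitching used for the rollout figures can be sketched as follows: predicted 11×11 patches are binarized at 0.5 and tiled into a single frame according to their position on the 4×4 query grid. This is a plain-Python illustration under assumptions: non-overlapping tiles (giving a 44×44 canvas) and ties at exactly 0.5 mapping to 1 are our choices, and the function names and dummy predictions are hypothetical.

```python
def threshold(patch, level=0.5):
    """Binarize a predicted patch; ties at `level` map to 1 (an assumption)."""
    return [[1 if value >= level else 0 for value in row] for row in patch]


def stitch(patches, grid=4, size=11):
    """Tile size x size patches, keyed by (grid_row, grid_col), into one image.

    ASSUMPTION: the 4x4 query grid tiles the frame without overlap.
    """
    canvas = [[0] * (grid * size) for _ in range(grid * size)]
    for (gi, gj), patch in patches.items():
        for u in range(size):
            for v in range(size):
                canvas[gi * size + u][gj * size + v] = patch[u][v]
    return canvas


# Hypothetical predictions: diagonal patches are confident, the rest are not.
predictions = {(gi, gj): [[1.0 if gi == gj else 0.2] * 11 for _ in range(11)]
               for gi in range(4) for gj in range(4)}
frame = stitch({key: threshold(patch) for key, patch in predictions.items()})
print(len(frame), len(frame[0]))  # 44 44
```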
[Figure 20: five heatmap panels, one per evaluated model; x-axis: Fraction of Available Views (0.2 to 1.0); y-axis: Number of Bouncing Balls (1 to 6).]

Figure 20: Balanced accuracy (arithmetic mean of recall and specificity) achieved by all evaluated models on the one-step forward prediction task with various numbers of balls and fractions of available views. All models were trained on video sequences with 3 balls and a constant number of crops / views (10 views, corresponding to the right-most columns labelled 1.0). The color map is consistent across all plots.

[Figure 21: five heatmap panels, one per evaluated model; x-axis: Fraction of Available Views (0.2 to 1.0); y-axis: Number of Bouncing Balls (1 to 6).]

Figure 21: F1 score (harmonic mean of precision and recall) achieved by all evaluated models on the one-step forward prediction task with various numbers of balls and fractions of available views. All models were trained on video sequences with 3 balls and a constant number of crops / views (10 views, corresponding to the right-most columns labelled 1.0). The color map is consistent across all plots.

% of Active Agents | LSTM | RMC | S2GRU | S2RMC | TTO
20% | 0.570565 | 0.586541 | 0.642292 | 0.637618 | 0.550806
30% | 0.599391 | 0.606114 | 0.660127 | 0.653950 | 0.578965
40% | 0.630606 | 0.640435 | 0.678752 | 0.671476 | 0.605867
50% | 0.638374 | 0.657472 | 0.688528 | 0.685988 | 0.627444
60% | 0.681040 | 0.704552 | 0.713851 | 0.708786 | 0.671961
70% | 0.709861 | 0.737436 | 0.734256 | 0.727980 | 0.723238
80% | 0.721041 | 0.748138 | 0.738611 | 0.732114 | 0.740936
90% | 0.750449 | 0.778647 | 0.755476 | 0.747613 | 0.786931
100% | 0.765592 | 0.795049 | 0.763126 | 0.754637 | 0.813504

Table 4: Friendly marker F1 scores on the validation set of the training distribution. Larger numbers are better.
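The metrics reported in Figures 20 and 21 and in Tables 4 and 5 follow directly from the definitions in the captions. The plain-Python sketch below computes balanced accuracy (mean of recall and specificity), F1 (harmonic mean of precision and recall) and macro-averaged F1 (unweighted mean of one-vs-rest F1 over classes). How the authors aggregated over pixels and samples, and the convention of scoring an undefined ratio as 0, are assumptions; the function names are hypothetical.

```python
def confusion_counts(y_true, y_pred):
    """Return (tp, tn, fp, fn) for binary labels given as 0/1 or booleans."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    tn = sum(1 for t, p in zip(y_true, y_pred) if not t and not p)
    fp = sum(1 for t, p in zip(y_true, y_pred) if not t and p)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t and not p)
    return tp, tn, fp, fn


def safe_div(num, den):
    # Assumed convention: an undefined ratio counts as 0.
    return num / den if den else 0.0


def balanced_accuracy(y_true, y_pred):
    """Arithmetic mean of recall and specificity (Figure 20)."""
    tp, tn, fp, fn = confusion_counts(y_true, y_pred)
    return 0.5 * (safe_div(tp, tp + fn) + safe_div(tn, tn + fp))


def f1_score(y_true, y_pred):
    """Harmonic mean of precision and recall (Figure 21, Table 4)."""
    tp, _, fp, fn = confusion_counts(y_true, y_pred)
    precision, recall = safe_div(tp, tp + fp), safe_div(tp, tp + fn)
    return safe_div(2 * precision * recall, precision + recall)


def macro_f1(y_true, y_pred, classes):
    """Unweighted mean of one-vs-rest F1 scores over classes (Table 5)."""
    scores = [f1_score([t == c for t in y_true], [p == c for p in y_pred])
              for c in classes]
    return sum(scores) / len(scores)


y_true = [1, 1, 0, 0, 0, 0]
y_pred = [1, 0, 1, 0, 0, 0]
print(balanced_accuracy(y_true, y_pred))  # 0.625
print(f1_score(y_true, y_pred))           # 0.5
```

Balanced accuracy is useful here because ball pixels are rare, so plain pixel accuracy would reward predicting an empty frame.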
% of Active Agents | LSTM | RMC | S2GRU | S2RMC | TTO
20% | 0.323482 | 0.326685 | 0.435318 | 0.377538 | 0.297192
30% | 0.345108 | 0.350621 | 0.491934 | 0.433945 | 0.323736
40% | 0.373612 | 0.387733 | 0.540163 | 0.485278 | 0.350733
50% | 0.385550 | 0.406048 | 0.552589 | 0.510371 | 0.371088
60% | 0.430793 | 0.481986 | 0.599470 | 0.566149 | 0.435724
70% | 0.497964 | 0.590214 | 0.635928 | 0.606039 | 0.539652
80% | 0.579952 | 0.649277 | 0.650682 | 0.623040 | 0.617973
90% | 0.657643 | 0.694158 | 0.675294 | 0.655581 | 0.699008
100% | 0.677952 | 0.715929 | 0.689669 | 0.672186 | 0.737745

Table 5: Unit-type marker (macro-averaged) F1 scores on the validation set of the training distribution. Larger numbers are better.

F.1.2 ROBUSTNESS TO DROPPED VIEWS

In this section, we evaluate the robustness of all models to dropped crops on in-distribution and OOD data. We measure the performance metrics on the one-step forward prediction task on all datasets (with 1 to 6 balls), but drop a given fraction of the available input observations. Figures 20 and 21 visualize the performance of all evaluated models. We find that S2GRU maintains its performance on OOD data even with fewer views (or crops) than it was trained on. Interestingly, we find that the time-travelling oracle (TTO), while robust OOD, is adversely affected when fewer views are available. This could be because, unlike the other models, it cannot leverage temporal information to compensate for the missing observations.

F.2 STARCRAFT2

F.2.1 TABULAR RESULTS

The results used to plot Figure 8 are tabulated in Tables 4, 5, 6 and 7.
Footnote 10: These random pixel positions are the same for all models, but change between time steps.

% of Active Agents | LSTM | RMC | S2GRU | S2RMC | TTO
20% | -0.014035 | -0.013569 | -0.011491 | -0.011921 | -0.014174
30% | -0.013355 | -0.012747 | -0.010631 | -0.011101 | -0.013539
40% | -0.012567 | -0.011808 | -0.009906 | -0.010367 | -0.012916
50% | -0.012220 | -0.011305 | -0.009637 | -0.009887 | -0.012481
60% | -0.010888 | -0.009799 | -0.008751 | -0.009034 | -0.010929
70% | -0.009738 | -0.008469 | -0.008068 | -0.008359 | -0.009184
80% | -0.009081 | -0.008027 | -0.007873 | -0.008162 | -0.008466
90% | -0.007970 | -0.007180 | -0.007347 | -0.007615 | -0.007038
100% | -0.007638 | -0.006823 | -0.007103 | -0.007362 | -0.006401

Table 6: HECS negative MSE on the validation set of the training distribution. Larger numbers are better.

% of Active Agents | LSTM | RMC | S2GRU | S2RMC | TTO
20% | -0.303051 | -0.300892 | -0.141989 | -0.146553 | -0.434099
30% | -0.258878 | -0.256878 | -0.126037 | -0.137025 | -0.347899
40% | -0.216924 | -0.211048 | -0.113317 | -0.126882 | -0.276596
50% | -0.206582 | -0.191644 | -0.108643 | -0.113293 | -0.245019
60% | -0.158170 | -0.142643 | -0.094380 | -0.099989 | -0.175233
70% | -0.126446 | -0.109634 | -0.084129 | -0.089527 | -0.120694
80% | -0.111735 | -0.099229 | -0.081624 | -0.086723 | -0.104135
90% | -0.082463 | -0.074518 | -0.073439 | -0.078197 | -0.071243
100% | -0.070488 | -0.063183 | -0.069856 | -0.074041 | -0.057276

Table 7: Log-likelihood (negative loss) on the validation set of the training distribution. Larger numbers are better.