# Switching Autoregressive Low-rank Tensor Models

Hyun Dong Lee, Computer Science Department, Stanford University, hdlee@stanford.edu
Andrew Warrington, Department of Statistics, Stanford University, awarring@stanford.edu
Joshua I. Glaser, Department of Neurology, Northwestern University, j-glaser@northwestern.edu
Scott W. Linderman, Department of Statistics, Stanford University, scott.linderman@stanford.edu

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

Abstract

An important problem in time-series analysis is modeling systems with time-varying dynamics. Probabilistic models with joint continuous and discrete latent states offer interpretable, efficient, and experimentally useful descriptions of such data. Commonly used models include autoregressive hidden Markov models (ARHMMs) and switching linear dynamical systems (SLDSs), each with its own advantages and disadvantages. ARHMMs permit exact inference and easy parameter estimation, but are parameter intensive when modeling long dependencies, and hence are prone to overfitting. In contrast, SLDSs can capture long-range dependencies in a parameter-efficient way through Markovian latent dynamics, but present an intractable likelihood and a challenging parameter estimation task. In this paper, we propose switching autoregressive low-rank tensor (SALT) models, which retain the advantages of both approaches while ameliorating the weaknesses. SALT parameterizes the tensor of an ARHMM with a low-rank factorization to control the number of parameters and allow longer-range dependencies without overfitting. We prove theoretical and discuss practical connections between SALT, linear dynamical systems, and SLDSs. We empirically demonstrate quantitative advantages of SALT models on a range of simulated and real prediction tasks, including behavioral and neural datasets. Furthermore, the learned low-rank tensor provides novel insights into temporal dependencies within each discrete state.

1 Introduction

Many time series analysis problems involve jointly segmenting data and modeling the time-evolution of the system within each segment. For example, a common task in computational ethology [Datta et al., 2019], the study of natural behavior, is segmenting videos of freely moving animals into states that represent distinct behaviors, while also quantifying the differences in dynamics between states [Wiltschko et al., 2015, Costacurta et al., 2022]. Similarly, discrete shifts in the dynamics of neural activity may reflect changes in underlying brain state [Saravani et al., 2019, Recanatesi et al., 2022]. Model-based segmentations are experimentally valuable, providing an unsupervised grouping of neural or behavioral states together with a model of the dynamics within each state.

One common probabilistic state space model for such analyses is the autoregressive hidden Markov model (ARHMM) [Ephraim et al., 1989]. For example, MoSeq [Wiltschko et al., 2015] uses ARHMMs for unsupervised behavioral analysis of freely moving animals. ARHMMs learn a set of linear autoregressive models, indexed by a discrete state, to predict the next observation as a function of previous observations. Inference in ARHMMs then reduces to inferring which AR
process best explains the observed data at each timestep (in turn also providing the segmentation). The simplicity of ARHMMs allows for exact state inference via message passing, and closed-form updates for parameter estimation using expectation-maximization (EM). However, the ARHMM requires high-order autoregressive dependencies to model long-timescale dependencies, and its parameter complexity is quadratic in the data dimension, making it prone to overfitting.

Figure 1: SALT imposes a low-rank constraint on the autoregressive tensor: (A) The probabilistic graphical model of an ARHMM. (B) An example multi-dimensional time series generated from an ARHMM. Background color indicates which discrete state (and hence autoregressive tensor) was selected at each time. (C) In SALT, each autoregressive dynamics tensor of an ARHMM is parameterized as a low-rank tensor.

Switching linear dynamical systems (SLDS) [Ghahramani and Hinton, 2000] ameliorate some of the drawbacks of the ARHMM by introducing a low-dimensional, continuous latent state. These models have been used widely throughout neuroscience [Saravani et al., 2019, Petreska et al., 2011, Linderman et al., 2019, Glaser et al., 2020, Nair et al., 2023]. Unlike the ARHMM, the SLDS can capture long-timescale dependencies through the dynamics of the continuous latent state, while also being much more parameter efficient than ARHMMs. However, exact inference in SLDSs is intractable due to the exponential number of potential discrete state paths governing the time-evolution of the continuous latent variable. This intractability has led to many elaborate and specialized approximate inference techniques [Ghahramani and Hinton, 2000, Barber, 2006, Fox, 2009, Murphy and Russell, 2001, Linderman et al., 2017, Zoltowski et al., 2020]. Thus, the SLDS gains parameter efficiency at the expense of the computational tractability and statistical simplicity of the ARHMM.

We propose a new class of unsupervised probabilistic models that we call switching autoregressive low-rank tensor (SALT) models. Our novel insight is that when you marginalize over the latent states of a linear dynamical system, you obtain an autoregressive model with full history dependence. However, these autoregressive dependencies are not arbitrarily complex: they factor into a low-rank tensor that can be well approximated with a finite-history model. We formalize this connection in Proposition 1. SALT models are constrained ARHMMs that leverage this insight. Rather than allowing for arbitrary autoregressive dependencies, SALT models are constrained to be low-rank. The low-rank property allows us to construct a low-dimensional continuous description of the data, jointly with the discrete segmentation provided by the switching states. Thus, SALT models inherit the experimentally useful representations and parsimonious parameter complexity of an SLDS, as well as the ease of inference and estimation of ARHMMs. We demonstrate the advantages of SALT models empirically using synthetic data as well as real neural and behavioral time series. Finally, in addition to improving predictive performance, we show how the low-rank nature of SALT models can offer new insights into complex systems, like biological neural networks.
2 Background

This section introduces the notation used throughout the paper and describes preliminaries on low-rank tensor decomposition, vector autoregressive models, switching autoregressive models, linear dynamical systems, and switching linear dynamical systems. Source code is available at https://github.com/lindermanlab/salt.

Notation We follow the notation of Kolda and Bader [2009]. We use lowercase letters for scalar variables (e.g. $a$), uppercase letters for scalar constants (e.g. $A$), boldface lowercase letters for vectors (e.g. $\mathbf{a}$), boldface uppercase letters for matrices (e.g. $\mathbf{A}$), and boldface Euler script for tensors of order three or higher (e.g. $\boldsymbol{\mathcal{A}}$). We use $\boldsymbol{\mathcal{A}}_{i::}$, $\boldsymbol{\mathcal{A}}_{:j:}$, and $\boldsymbol{\mathcal{A}}_{::k}$ to denote the horizontal, lateral, and frontal slices respectively of a three-way tensor $\boldsymbol{\mathcal{A}}$. Similarly, we use $\mathbf{a}_{i:}$ and $\mathbf{a}_{:j}$ to denote the $i$th row and $j$th column of a matrix $\mathbf{A}$. $\mathbf{a} \circ \mathbf{b}$ represents the vector outer product between vectors $\mathbf{a}$ and $\mathbf{b}$. The $n$-mode tensor-matrix (tensor-vector) product is represented as $\boldsymbol{\mathcal{A}} \times_n \mathbf{A}$ ($\boldsymbol{\mathcal{A}} \bar{\times}_n \mathbf{a}$). We denote the vectorization of an $n$-way tensor $\boldsymbol{\mathcal{G}}$, with dimensions $D_{1:n}$, as $\mathrm{vec}(\boldsymbol{\mathcal{G}})$. This is performed by successively flattening the last dimensions of the tensor, and results in a vector of size equal to the product of the dimensions of the tensor. We denote the mode-$n$ matricization of a tensor $\boldsymbol{\mathcal{G}}$ as $\mathbf{G}_{(n)}$. This is defined as the stack of vectors resulting from vectorizing the matrix (or tensor) defined by each slice through the $n$th dimension. This results in a matrix with leading dimension $D_n$, and second dimension equal to the product of the sizes of the other dimensions. We will denote a $T$-length time series of $N$-dimensional observed data as $\mathbf{Y} \in \mathbb{R}^{N \times T}$. Note that we will use the shorthand $\mathbf{y}_t \in \mathbb{R}^N$ to denote the observation at time $t$, and $y_{j,t} \in \mathbb{R}$ to denote the $j$th element of the $t$th observation. It will be clear from context which dimension is being indexed.

Tensor Decomposition For $\boldsymbol{\mathcal{A}} \in \mathbb{R}^{N_1 \times N_2 \times N_3}$, the Tucker decomposition is defined as,

$$\boldsymbol{\mathcal{A}} = \sum_{i=1}^{D_1} \sum_{j=1}^{D_2} \sum_{k=1}^{D_3} g_{ijk}\, \mathbf{u}_{:i} \circ \mathbf{v}_{:j} \circ \mathbf{w}_{:k}, \qquad (1)$$

where $\mathbf{u}_{:i}$, $\mathbf{v}_{:j}$, and $\mathbf{w}_{:k}$ are the columns of the factor matrices $\mathbf{U} \in \mathbb{R}^{N_1 \times D_1}$, $\mathbf{V} \in \mathbb{R}^{N_2 \times D_2}$, and $\mathbf{W} \in \mathbb{R}^{N_3 \times D_3}$, respectively, and $g_{ijk}$ are the entries in the core tensor $\boldsymbol{\mathcal{G}} \in \mathbb{R}^{D_1 \times D_2 \times D_3}$. The CANDECOMP/PARAFAC (CP) decomposition is a special case of the Tucker decomposition, with $D_1 = D_2 = D_3$ and a diagonal core tensor $\boldsymbol{\mathcal{G}}$.

Vector autoregressive models Let $\mathbf{Y} \in \mathbb{R}^{N \times T}$ denote a multivariate time series with $\mathbf{y}_t \in \mathbb{R}^N$ for all $t$. An order-$L$ vector autoregressive (VAR) model with Gaussian innovations is defined by,

$$\mathbf{y}_t \sim \mathcal{N}\Big( \sum_{j=1}^{N} \sum_{k=1}^{L} \mathbf{a}_{:jk}\, y_{j,t-k} + \mathbf{b},\; \mathbf{R} \Big), \qquad (2)$$

where $\boldsymbol{\mathcal{A}} \in \mathbb{R}^{N \times N \times L}$ is the autoregressive tensor, whose frontal slice $\boldsymbol{\mathcal{A}}_{::l}$ is the dynamics matrix for lag $l$, $\mathbf{b} \in \mathbb{R}^N$ is the bias, and $\mathbf{R} \in \mathbb{R}^{N \times N}$ is a positive semi-definite covariance matrix. The parameters $\Theta = (\boldsymbol{\mathcal{A}}, \mathbf{b}, \mathbf{R})$ can be estimated via ordinary least squares [Hamilton, 2020].

We note that, to our knowledge, there is no clear consensus on the best way to regularize the potentially large parameter space of vector autoregressive (hidden Markov) models; several possibilities exist, see, e.g., Melnyk and Banerjee [2016] or Ni and Sun [2005]. Many regularizers and priors are difficult to work with, and so are not widely used in practice. Beyond this, even well-regularized ARHMMs do not natively capture interpretable low-dimensional dynamics, as both SALT and SLDS models do (see Figure 3). These low-dimensional continuous representations are as experimentally useful as the discrete segmentation, and hence are a key desideratum for any method we consider.
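As a concrete illustration of this notation, the following NumPy sketch builds a three-way tensor from a Tucker factorization as in (1), shows the CP special case with a diagonal core, and evaluates the VAR predictive mean in (2). All dimensions and the random parameters are illustrative placeholders, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, D = 5, 4, 2                          # observation dim, lags, rank (illustrative)

G = rng.normal(size=(D, D, D))             # core tensor
U = rng.normal(size=(N, D))                # factor matrices
V = rng.normal(size=(N, D))
W = rng.normal(size=(L, D))

# Tucker reconstruction, Eq. (1): A[a,b,c] = sum_{ijk} g_{ijk} u_{a,i} v_{b,j} w_{c,k}.
A_tucker = np.einsum('ijk,ai,bj,ck->abc', G, U, V, W)          # shape (N, N, L)

# CP is the special case of a diagonal core: A = sum_d g_d (u_:d o v_:d o w_:d).
g_diag = rng.normal(size=D)
A_cp = np.einsum('d,ad,bd,cd->abc', g_diag, U, V, W)

# Mode-1 matricization in the paper's row-major convention: shape (N, N*L).
A1 = A_tucker.reshape(N, -1)

# VAR predictive mean, Eq. (2): sum_{j,k} a_{:jk} y_{j,t-k} + b,
# with Y_lags[j, k-1] = y_{j, t-k}.
b = rng.normal(size=N)
Y_lags = rng.normal(size=(N, L))
mean_t = np.einsum('njk,jk->n', A_tucker, Y_lags) + b
```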
Switching autoregressive models One limitation of VAR models is that they assume the time series is stationary; i.e., that one set of parameters holds for all time steps. Time-varying autoregressive models allow the autoregressive process to change at various time points. One such VAR model, referred to as a switching autoregressive model or autoregressive hidden Markov model (ARHMM), switches the parameters over time according to a discrete latent state [Ephraim et al., 1989]. Let $z_t \in \{1, \ldots, H\}$ denote the discrete state at time $t$; an ARHMM defines the following generative model,

$$z_t \sim \mathrm{Cat}\big(\pi^{(z_{t-1})}\big), \qquad \mathbf{y}_t \sim \mathcal{N}\Big( \sum_{j=1}^{N} \sum_{k=1}^{L} \mathbf{a}^{(z_t)}_{:jk}\, y_{j,t-k} + \mathbf{b}^{(z_t)},\; \mathbf{R}^{(z_t)} \Big), \qquad (3)$$

where $\pi^{(h)} \in \{\pi^{(h)}\}_{h=1}^{H}$ is the $h$-th row of the discrete state transition matrix. A switching VAR model is simply a type of hidden Markov model, and as such it is easily fit via the Baum-Welch expectation-maximization (EM) algorithm. The M-step amounts to solving a weighted least squares problem.

Linear dynamical systems The number of parameters in a VAR model grows as $O(N^2 L)$. For high-dimensional time series, this can quickly become intractable. Linear dynamical systems (LDS) [Murphy, 2012] offer an alternative means of modeling time series via a continuous latent state $\mathbf{x}_t \in \mathbb{R}^S$,

$$\mathbf{x}_t \sim \mathcal{N}(\mathbf{A}\mathbf{x}_{t-1} + \mathbf{b}, \mathbf{Q}), \qquad \mathbf{y}_t \sim \mathcal{N}(\mathbf{C}\mathbf{x}_t + \mathbf{d}, \mathbf{R}), \qquad (4)$$

where $\mathbf{Q} \in \mathbb{R}^{S \times S}$ and $\mathbf{R} \in \mathbb{R}^{N \times N}$ are positive semi-definite. Here, the latent states follow a first-order VAR model, and the observations are conditionally independent given the latent states. As we discuss in Section 3.3, marginalizing over the continuous latent states renders $\mathbf{y}_t$ dependent on the preceding observations, just like in a high-order VAR model. Compared to the VAR model, however, the LDS has only $O(S^2 + NS + N^2)$ parameters if $\mathbf{R}$ is a full covariance matrix. This further reduces to $O(S^2 + NS)$ if $\mathbf{R}$ is diagonal. As a result, when $S \ll N$, the LDS has many fewer parameters than a VAR model. Thanks to the linear and Gaussian assumptions of the model, the parameters can be easily estimated via EM, using the Kalman smoother to compute the expected values of the latent states.

Switching linear dynamical systems A switching LDS combines the advantages of the low-dimensional continuous latent states of an LDS with the advantages of discrete switching from an ARHMM. Let $z_t \in \{1, \ldots, H\}$ be a discrete latent state with Markovian dynamics (3), and let it determine some or all of the parameters of the LDS (e.g., $\mathbf{A}$ would become $\mathbf{A}^{(z_t)}$ in (4)). We note that SLDSs often use a single subspace, where $\mathbf{C}$, $\mathbf{d}$ and $\mathbf{R}$ are shared across states, reducing parameter complexity and simplifying the optimization. Unfortunately, parameter estimation is considerably harder in SLDS models. The posterior distribution over all latent states, $p(z_{1:T}, \mathbf{x}_{1:T} \mid \mathbf{y}_{1:T}, \Theta)$, where $\Theta$ denotes the parameters, is intractable [Lerner, 2003]. Instead, these models are fit via approximate inference methods like MCMC [Fox, 2009, Linderman et al., 2017], variational EM [Ghahramani and Hinton, 2000, Zoltowski et al., 2020], particle EM [Murphy and Russell, 2001, Doucet et al., 2001], or other approximations [Barber, 2006]. Selecting the appropriate fitting and inference methodology is itself a non-trivial hyperparameter. Furthermore, each method also brings additional estimation hyperparameters that need to be tuned prior to even fitting the generative model. We look to define a model that enjoys the benefits of SLDSs, but avoids the inference and estimation difficulties.

3 SALT: Switching Autoregressive Low-rank Tensor Models

Here we formally introduce SALT models.
We begin by defining the generative model (also illustrated in Figure 1), and describing how inference and model fitting are performed. We conclude by drawing connections between SALT and SLDS models.

3.1 Generative Model

SALT factorizes each autoregressive tensor $\boldsymbol{\mathcal{A}}^{(h)}$, for $h \in \{1, \ldots, H\}$, of an ARHMM as a product of low-rank factors. Given the current discrete state $z_t$, each observation $\mathbf{y}_t \in \mathbb{R}^N$ is modeled as being normally distributed conditioned on the $L$ previous observations $\mathbf{y}_{t-1:t-L}$,

$$z_t \sim \mathrm{Cat}\big(\pi^{(z_{t-1})}\big), \qquad (5)$$
$$\mathbf{y}_t \sim \mathcal{N}\Big( \sum_{j=1}^{N} \sum_{k=1}^{L} \mathbf{a}^{(z_t)}_{\mathrm{SALT},:jk}\, y_{j,t-k} + \mathbf{b}^{(z_t)},\; \boldsymbol{\Sigma}^{(z_t)} \Big), \qquad (6)$$
$$\boldsymbol{\mathcal{A}}^{(z_t)}_{\mathrm{SALT}} = \sum_{i=1}^{D_1} \sum_{j=1}^{D_2} \sum_{k=1}^{D_3} g^{(z_t)}_{ijk}\, \mathbf{u}^{(z_t)}_{:i} \circ \mathbf{v}^{(z_t)}_{:j} \circ \mathbf{w}^{(z_t)}_{:k}, \qquad (7)$$

where $\mathbf{u}^{(z_t)}_{:i}$, $\mathbf{v}^{(z_t)}_{:j}$, and $\mathbf{w}^{(z_t)}_{:k}$ are the columns of the factor matrices $\mathbf{U}^{(z_t)} \in \mathbb{R}^{N \times D_1}$, $\mathbf{V}^{(z_t)} \in \mathbb{R}^{N \times D_2}$, and $\mathbf{W}^{(z_t)} \in \mathbb{R}^{L \times D_3}$, respectively, and $g^{(z_t)}_{ijk}$ are the entries in the core tensor $\boldsymbol{\mathcal{G}}^{(z_t)} \in \mathbb{R}^{D_1 \times D_2 \times D_3}$. The vector $\mathbf{b}^{(z_t)} \in \mathbb{R}^N$ and positive definite matrix $\boldsymbol{\Sigma}^{(z_t)} \in \mathbb{R}^{N \times N}$ are the bias and covariance for state $z_t$. Without further restriction this decomposition is a Tucker decomposition [Kolda and Bader, 2009]. If $D_1 = D_2 = D_3$ and $\boldsymbol{\mathcal{G}}^{(z_t)}$ is diagonal, it corresponds to a CP decomposition [Kolda and Bader, 2009]. We refer to ARHMM models with these factorizations as Tucker-SALT and CP-SALT respectively. Note that herein we will only consider models where $D_1 = D_2 = D_3 = D$, where we refer to $D$ as the rank of the SALT model (for both Tucker-SALT and CP-SALT). In practice, we find that models constrained in this way perform well, and so this constraint is imposed simply to reduce the search space of models and could easily be relaxed.

Table 1: Comparison of number of parameters for the methods we consider. We exclude covariance matrix parameters, as the parameterization of the covariance matrix is independent of method. Throughout our experiments, we find $S \approx D$.

| Model | Parameter Complexity | Example from Section 5.4 |
| --- | --- | --- |
| SLDS | O(NS + HS²) | 2.8K |
| CP-SALT | O(H(ND + LD)) | 8.1K |
| Tucker-SALT | O(H(ND + LD + D³)) | 17.4K |
| Order-L ARHMM | O(HN²L) | 145.2K |

Table 1 shows the number of parameters for order-L ARHMMs, SLDSs, and SALT. Focusing on the lag dependence, the number of ARHMM parameters grows as $O(HN^2L)$, whereas SALT grows as only $O(HDL)$ with $D \ll N$. SALT can also make a simplifying single-subspace constraint, where certain emission parameters are shared across discrete states.

Low-dimensional Representation Note that SALT implicitly defines a low-dimensional continuous representation, analogous to the continuous latent variable in an SLDS,

$$\boldsymbol{\mathcal{P}}^{(z_t)} = \sum_{j=1}^{D_2} \sum_{k=1}^{D_3} \mathbf{g}^{(z_t)}_{:jk} \circ \mathbf{v}^{(z_t)}_{:j} \circ \mathbf{w}^{(z_t)}_{:k}, \qquad (8)$$
$$\mathbf{x}_t = \sum_{j=1}^{N} \sum_{k=1}^{L} \mathbf{p}^{(z_t)}_{:jk}\, y_{j,t-k}. \qquad (9)$$

The low-dimensional vectors $\mathbf{x}_t \in \mathbb{R}^{D_1}$ can be visualized, similar to the latent states in SLDS models, to further interrogate the learned dynamics, as we show in Figure 3. Note that the vector $\mathbf{x}_t \in \mathbb{R}^{D_1}$, when multiplied by the output factors $\mathbf{U}^{(z_t)}$, is the mean of the next observation.

3.2 Model Fitting and Inference

Since SALT models are ARHMMs, we can apply the expectation-maximization (EM) algorithm to fit model parameters and perform state space inference. We direct the reader to Murphy [2012] for a detailed exposition of EM and include only the key points here. The E-step solves for the distribution over latent variables given observed data and model parameters. For SALT, this is the distribution over $z_t$, denoted $\omega^{(h)}_t = \mathbb{E}[z_t = h \mid \mathbf{y}_{1:T}, \theta]$. This can be computed exactly with the forward-backward algorithm, which is fast and stable. The marginal likelihood can be evaluated exactly by taking the product across $t$ of expectations of (6) under $\omega^{(h)}_t$. The M-step then updates the parameters of the model given the distribution over latent states.
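The following sketch samples from the SALT generative model in (5)-(7) and verifies numerically that the low-dimensional representation in (8)-(9), mapped through the output factors, reproduces the same predictive mean. The state count, rank, and random parameters are illustrative, and the rough parameter counts at the end mirror the scalings in Table 1 (covariances excluded).

```python
import numpy as np

rng = np.random.default_rng(1)
H, N, L, D, T = 3, 10, 5, 2, 400        # states, obs dim, lags, rank, length (illustrative)

# Per-state SALT parameters (Tucker form; a diagonal core would give CP-SALT).
Pi = rng.dirichlet(np.full(H, 5.0), size=H)            # rows are pi^(h)
U = rng.normal(size=(H, N, D))
V = rng.normal(size=(H, N, D))
W = rng.normal(size=(H, L, D))
G = 0.3 * rng.normal(size=(H, D, D, D))
b = 0.1 * rng.normal(size=(H, N))
noise_scale = 0.1

# Low-rank AR tensors, Eq. (7).
A = np.einsum('hijk,hai,hbj,hck->habc', G, U, V, W)    # shape (H, N, N, L)

# Sample from the generative model, Eqs. (5)-(6).
z = np.zeros(T, dtype=int)
Y = np.zeros((T, N))
Y[:L] = rng.normal(size=(L, N))
for t in range(L, T):
    z[t] = rng.choice(H, p=Pi[z[t - 1]])
    lags = np.stack([Y[t - k] for k in range(1, L + 1)], axis=1)   # lags[j, k-1] = y_{j,t-k}
    mean = np.einsum('njk,jk->n', A[z[t]], lags) + b[z[t]]
    Y[t] = mean + noise_scale * rng.normal(size=N)

# Low-dimensional representation, Eqs. (8)-(9): the same mean via x_t and U^(z_t).
h = z[-1]
lags = np.stack([Y[T - 1 - k] for k in range(1, L + 1)], axis=1)
P = np.einsum('ijk,bj,ck->ibc', G[h], V[h], W[h])       # D x N x L
x_t = np.einsum('ibc,bc->i', P, lags)
assert np.allclose(U[h] @ x_t + b[h],
                   np.einsum('njk,jk->n', A[h], lags) + b[h])

# Rough parameter counts behind Table 1 (covariances excluded).
cp_salt = H * (2 * N * D + L * D + D)
tucker_salt = H * (2 * N * D + L * D + D ** 3)
arhmm = H * N * N * L
```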
For SALT, the parameters are $\theta = \{\mathbf{U}^{(h)}, \mathbf{V}^{(h)}, \mathbf{W}^{(h)}, \boldsymbol{\mathcal{G}}^{(h)}, \mathbf{b}^{(h)}, \boldsymbol{\Sigma}^{(h)}, \pi^{(h)}\}_{h=1}^{H}$. We use closed-form coordinate-wise updates to maximize the expected log likelihood evaluated in the E-step. Each factor update amounts to solving a weighted least squares problem. We include just one update step here for brevity, and provide all updates in full in Appendix A. Assuming here that $\mathbf{b}^{(h)} = \mathbf{0}$ for simplicity, the update rule for the lag factors is as follows:

$$\mathbf{w}^{(h)} = \Big( \sum_t \omega^{(h)}_t\, \widetilde{\mathbf{X}}^{(h)\top}_t (\boldsymbol{\Sigma}^{(h)})^{-1} \widetilde{\mathbf{X}}^{(h)}_t \Big)^{-1} \Big( \sum_t \omega^{(h)}_t\, \widetilde{\mathbf{X}}^{(h)\top}_t (\boldsymbol{\Sigma}^{(h)})^{-1} \mathbf{y}_t \Big),$$

where $\widetilde{\mathbf{X}}^{(h)}_t = \mathbf{U}^{(h)} \mathbf{G}^{(h)}_{(1)} (\mathbf{V}^{(h)\top} \mathbf{y}_{t-1:t-L} \otimes \mathbf{I}_{D_3})$ and $\mathbf{w}^{(h)} = \mathrm{vec}(\mathbf{W}^{(h)})$. Crucially, these coordinate-wise updates are exact, and so we recover the fast and monotonic convergence of EM.

3.3 Connections Between SALT and Linear Dynamical Systems

SALT is not only an intuitive regularization for ARHMMs, it is grounded in a mathematical correspondence between autoregressive models and linear dynamical systems.

Proposition 1 (Low-Rank Tensor Autoregressions Approximate Stable Linear Dynamical Systems). Consider a stable linear time-invariant Gaussian dynamical system. We define the steady-state Kalman gain matrix as $\mathbf{K} = \lim_{t \to \infty} \mathbf{K}_t$, and $\boldsymbol{\Gamma} = \mathbf{A}(\mathbf{I} - \mathbf{K}\mathbf{C})$. The matrix $\boldsymbol{\Gamma} \in \mathbb{R}^{S \times S}$ has eigenvalues $\lambda_1, \ldots, \lambda_S$. Let $\lambda_{\max} = \max_s |\lambda_s|$; for a stable LDS, $\lambda_{\max} < 1$ [Davis and Vinter, 1985]. Let $n$ denote the number of real eigenvalues and $m$ the number of complex conjugate pairs. Let $\hat{\mathbf{y}}^{(\mathrm{LDS})}_t = \mathbb{E}[\mathbf{y}_t \mid \mathbf{y}_{1:t-1}]$ denote the predictive mean under a steady-state LDS, and $\hat{\mathbf{y}}^{(\mathrm{SALT})}_t$ the predictive mean under a SALT model. An order-$L$ Tucker-SALT model with rank $n + 2m = S$, or a CP-SALT model with rank $n + 3m$, can approximate the predictive mean of the steady-state LDS with error $\|\hat{\mathbf{y}}^{(\mathrm{LDS})}_t - \hat{\mathbf{y}}^{(\mathrm{SALT})}_t\| = O(\lambda_{\max}^L)$.

Proof. We give a sketch of the proof here and a full proof in Appendix B. The analytic form of $\mathbb{E}[\mathbf{y}_t \mid \mathbf{y}_{1:t-1}]$ is a linear function of $\mathbf{y}_{t-l}$ for $l = 1, \ldots, \infty$. For this sketch, consider the special case where $\mathbf{b} = \mathbf{d} = \mathbf{0}$. Then the coefficients of the linear function are $\mathbf{C}\boldsymbol{\Gamma}^{l}\mathbf{K}$. As all eigenvalues of $\boldsymbol{\Gamma}$ have magnitude less than one, the coefficients decay exponentially in $l$. We can therefore upper bound the approximation error introduced by truncating the linear function to $L$ terms by $O(\lambda_{\max}^L)$. To complete the proof, we show that the truncated linear function can be represented exactly by a tensor regression with at most a specific rank. Thus, only truncated terms contribute to the error.

This proposition shows that the steady-state predictive distribution of a stable LDS can be approximated by a low-rank tensor autoregression, with a rank determined by the eigenspectrum of the LDS. We validate this proposition experimentally in Section 5.1. Note as well that the predictive distribution will converge to a fixed covariance, and hence can also be exactly represented by the covariance matrices $\boldsymbol{\Sigma}^{(h)}$ estimated in SALT models.

Connections with Switching Linear Dynamical Systems With this foundation, it is natural to hypothesize that a switching low-rank tensor autoregression like SALT could approximate a switching LDS. There are two ways this intuition could fail: first, if the dynamics in a discrete state of an SLDS are unstable, then Proposition 1 would not hold; second, after a discrete state transition in an SLDS, it may take some time before the dynamics reach stationarity. We empirically test how well SALT approximates an SLDS in Section 5 and find that, across a variety of datasets, SALT obtains commensurate performance with considerably simpler inference and estimation algorithms.
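To make the M-step concrete, here is a minimal NumPy sketch of the weighted least-squares update for the lag factors given above, assuming $\mathbf{b}^{(h)} = \mathbf{0}$, fixed $\mathbf{U}$, $\mathbf{V}$, $\boldsymbol{\mathcal{G}}$, and responsibilities already computed in the E-step. All sizes and the random data are placeholders, and the row-major reshape conventions follow the notation section.

```python
import numpy as np

rng = np.random.default_rng(2)
N, L, D, T = 8, 4, 3, 300                  # obs dim, lags, rank, time steps (illustrative)

# Fixed factors for one discrete state h; only W is updated here.
U = rng.normal(size=(N, D))
G = rng.normal(size=(D, D, D))
V = rng.normal(size=(N, D))
Sigma_inv = np.eye(N)                      # (Sigma^(h))^{-1}
omega = rng.uniform(size=T)                # E-step responsibilities omega_t^(h)

Y = rng.normal(size=(T, N))                # observations
G1 = G.reshape(D, -1)                      # mode-1 matricization (row-major), D x D*D

def design_matrix(Y, t, L):
    """X_tilde_t = U G_(1) (V^T y_{t-1:t-L} kron I_D), so mean_t = X_tilde_t vec(W)."""
    lags = np.stack([Y[t - k] for k in range(1, L + 1)], axis=1)   # N x L
    return U @ G1 @ np.kron(V.T @ lags, np.eye(D))                 # N x (L*D)

# Weighted least squares for w = vec(W).
lhs = np.zeros((L * D, L * D))
rhs = np.zeros(L * D)
for t in range(L, T):
    Xt = design_matrix(Y, t, L)
    lhs += omega[t] * Xt.T @ Sigma_inv @ Xt
    rhs += omega[t] * Xt.T @ Sigma_inv @ Y[t]
W = np.linalg.solve(lhs, rhs).reshape(L, D)    # row-major reshape matches vec(W)
```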
4 Related Work

Low-rank tensor decompositions of time-invariant autoregressive models Similar to this work, Wang et al. [2021] also modeled the transition matrices as a third-order tensor $\boldsymbol{\mathcal{A}} \in \mathbb{R}^{N \times N \times L}$, where $\boldsymbol{\mathcal{A}}_{::l}$ is the $l$-th dynamics matrix. They then constrained the tensor to be low-rank via a Tucker decomposition, as defined in (1). However, unlike SALT, their model was time-invariant, did not have an ARHMM structure, and did not make connections to the LDS and SLDS, as in Proposition 1.

Low-rank tensor decompositions of time-varying autoregressive models Low-rank tensor-based approaches have also been used to model time-varying AR processes [Harris et al., 2021, Zhang et al., 2021]. Harris et al. [2021] introduced TVART, which first splits the data into T contiguous fixed-length segments, each with its own AR-1 process. TVART can be thought of as defining a $T \times N \times N$ ARHMM dynamics tensor and progressing through discrete states at fixed time points. This tensor is parameterized using the CP decomposition and optimized using an alternating least squares algorithm, with additional penalties such that the dynamics of adjacent windows are similar. By contrast, SALT automatically segments, rather than windows, the time series into learned and re-usable discrete states. Zhang et al. [2021] constructed a Bayesian model of higher-order AR matrices that can vary over time. First, H VAR dynamics tensors are specified, parameterized as third-order tensors with a rank-1 CP decomposition. The dynamics at a given time are then defined as a weighted sum of the tensors, where the weights have a prior density specified by an Ising model. Finally, inference over the weights is performed using MCMC. This method can be interpreted as a factorial ARHMM, hence offering substantial modeling flexibility, but sacrificing computational tractability when H is large.

Figure 2: SALT approximates an LDS: Data simulated from an LDS for which n = 1 and m = 3 (see Proposition 1). (A-B): Average mean squared error of the autoregressive tensor corresponding to the LDS simulation, and the log-likelihood of test data, as a function of SALT rank. According to Proposition 1, to model the LDS, Tucker-SALT and CP-SALT require ranks 7 and 10 respectively (indicated by vertical dashed lines). Note the parameter error increases above the predicted threshold as a result of overfitting. (C-D): Mean squared error of the learned autoregressive tensor and log-likelihood of test data as a function of training data.

Low-rank tensor decompositions of neural networks Low-rank tensor decomposition methods have also been used to make neural networks more parameter efficient. Novikov et al. [2015] used the tensor-train decomposition [Oseledets, 2011] on the dense weight matrices of the fully-connected layers to reduce the number of parameters. Yu et al. [2017] and Qiu et al. [2021] applied the tensor-train decomposition to the weight tensors for polynomial interactions between the hidden states of recurrent neural networks (RNNs) to efficiently capture high-order temporal dependencies. Unlike switching models with linear dynamics, recurrent neural networks have dynamics that are hard to interpret, their state estimates are not probabilistic, and they do not provide experimentally useful data segmentations.

Linear dynamical systems and low-rank linear recurrent neural networks Valente et al. [2022] recently examined the relationship between LDSs and low-rank linear RNNs.
They provide the conditions under which low-rank linear RNNs can exactly model the first-order autoregressive distributions of LDSs, and derive the transformation to convert between model classes under those conditions. This result has close parallels to Proposition 1. Under the conditions identified by Valente et al. [2022], the approximation in Proposition 1 becomes exact with just one lag term. However, when those conditions are not satisfied, we show that one still recovers an LDS approximation with a bounded error that decays exponentially in the number of lag terms.

5 Results

We now empirically validate SALT, first verifying the theoretical claims made in Section 3, and then applying SALT to two synthetic examples to compare it to existing methods. We conclude by using SALT to analyze real mouse behavioral recordings and C. elegans neural recordings.

5.1 SALT Faithfully Approximates LDS

To test the theoretical result that SALT can closely approximate a linear dynamical system, we fit SALT models to data sampled from an LDS. The LDS has $S = 7$ dimensional latent states with random rotational dynamics, where $\boldsymbol{\Gamma}$ has $n = 1$ real eigenvalue and $m = 3$ pairs of complex eigenvalues, and $N = 20$ observations with a random emission matrix. For Figure 2, we trained CP-SALT and Tucker-SALT with $L = 50$ lags and varying ranks. We first analyzed how well SALT reconstructed the parameters of the autoregressive dynamics tensor. As predicted by Proposition 1, Figure 2A shows that the mean squared errors between the SALT tensor and the autoregressive tensor corresponding to the simulated LDS are lowest when the ranks of CP-SALT and Tucker-SALT are $n + 3m = 10$ and $n + 2m = 7$ respectively. We then computed log-likelihoods on 5,000 timesteps of held-out test data (Figure 2B). Interestingly, the predictive performance of both CP-SALT and Tucker-SALT reaches the likelihood of the ground truth LDS model with rank $n + 2m = 7$, suggesting that sometimes smaller tensors than suggested by Proposition 1 may still be able to provide good approximations to the data. We also show in Figures 2C and 2D that, as predicted, SALT models require much less data to fit than ARHMMs. We show extended empirical results and discussion on Proposition 1 in Appendix D.1.

Figure 3: SALT reconstructs simulated SLDS data and a Lorenz attractor: (Top row) Observations generated from a low-dimensional trajectory. (A) shows ten observations generated from a recurrent NASCAR SLDS trajectory [Linderman et al., 2017]. (B) 20-dimensional observations generated from a Lorenz attractor (5 observed dimensions are shown). (Middle and bottom rows): filtered observations and inferred low-dimensional trajectories from SLDS and SALT models. Colors indicate the discrete state for the ground truth (if available) and fitted models. SLDS and SALT find comparable filtered trajectories and observations. It is important to note that the latent spaces in both SLDS and SALT are only identifiable up to a linear transformation. We therefore align the latent trajectories for ease of comparison. This latent structure is reliably found by both SALT and SLDS.
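A small numerical sketch of the quantities examined in this section, under the assumption of a randomly generated stable LDS (so the split of Γ's eigenvalues into real values and complex pairs will vary from run to run, unlike the fixed rotational construction used in the paper): it iterates the Riccati recursion to a steady-state gain, reads off the Tucker and CP ranks predicted by Proposition 1, and confirms that the implied autoregressive coefficients decay geometrically.

```python
import numpy as np

rng = np.random.default_rng(3)
S, N = 7, 20                                    # latent and observation dims, as in Section 5.1

# Random stable LDS: rescale A so its spectral radius is below 1.
A = rng.normal(size=(S, S))
A *= 0.9 / np.max(np.abs(np.linalg.eigvals(A)))
C = rng.normal(size=(N, S))
Q, R = 0.1 * np.eye(S), 0.1 * np.eye(N)

# Iterate the predictive-covariance Riccati recursion to (numerical) steady state.
Sigma = np.eye(S)
for _ in range(1000):
    K = Sigma @ C.T @ np.linalg.inv(C @ Sigma @ C.T + R)       # Kalman gain
    Sigma = A @ (np.eye(S) - K @ C) @ Sigma @ A.T + Q

Gamma = A @ (np.eye(S) - K @ C)
eigvals = np.linalg.eigvals(Gamma)
lam_max = np.max(np.abs(eigvals))

# Predicted SALT ranks from Proposition 1.
n_real = int(np.sum(np.abs(eigvals.imag) < 1e-10))
m_pairs = (S - n_real) // 2
print("Tucker rank:", n_real + 2 * m_pairs, " CP rank:", n_real + 3 * m_pairs)

# Autoregressive coefficient on y_{t-l} implied by the steady-state filter
# (b = d = 0 case, Appendix B): C @ Gamma^{l-1} @ A @ K.  Its norm decays like
# lam_max**l, so truncating at L lags leaves an O(lam_max**L) error.
coef_norms = []
Gamma_pow = np.eye(S)
for l in range(1, 51):
    coef_norms.append(np.linalg.norm(C @ Gamma_pow @ A @ K))
    Gamma_pow = Gamma_pow @ Gamma
print("lambda_max:", lam_max)
print("coefficient norms (first 10 lags):", np.round(coef_norms[:10], 4))
```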
5.2 Synthetic Switching LDS Examples

Proposition 1 quantifies the convergence properties of low-rank tensor regressions when approximating stable LDSs. Next we tested how well SALT can approximate the more expressive switching LDSs. We first applied SALT to data generated from a recurrent SLDS [Linderman et al., 2017], where the two-dimensional ground truth latent trajectory resembles a NASCAR track (Figure 3A). SALT accurately reconstructed the ground truth filtered trajectories and discrete state segmentation, and yielded very similar results to an SLDS model. We also tested the ability of SALT to model nonlinear dynamics, specifically a Lorenz attractor, which SLDSs are capable of modeling. Again, SALT accurately reconstructed the ground truth latents and observations, and closely matched SLDS segmentations. These results suggest that SALT models provide a good alternative to SLDS models. Finally, in Appendix D.3, we used SLDS-generated data to compare SALT and TVART [Harris et al., 2021], another tensor-based method for modeling autoregressive processes, and found that SALT more accurately reconstructed autoregressive dynamics tensors than TVART.

5.3 Modeling Mouse Behavior

Next we considered a video segmentation problem commonly faced in the field of computational neuroethology [Datta et al., 2019]. Wiltschko et al. [2015] collected videos of mice freely behaving in a circular open field. They projected the video data onto the top 10 principal components (Figure 4A) and used an ARHMM to segment the PCA time series into distinct behavioral states. Here, we compared ARHMMs and CP-SALT with data from three mice. We used the first 35,949 timesteps of each recording, which were collected at 30 Hz. We used H = 50 discrete states and fitted ARHMMs and CP-SALT models with varying lags and ranks. The likelihood on a held-out validation set shows that the ARHMM overfitted quickly as the number of lags increased, while CP-SALT was more robust to overfitting (Figure 4B). We compared log-likelihoods of the best model (evaluated on the validation set) on a separate held-out test set and found that CP-SALT consistently outperformed the ARHMM across mice (Figure 4C). We also investigated the quality of SALT segmentations of the behavioral data (Appendix E.3). We found that the PCA trajectories upon transition into a discrete SALT state were highly stereotyped, suggesting that SALT segments the data into consistent behavioral states. Furthermore, CP-SALT used fewer discrete states than the ARHMM, suggesting that the ARHMM may have oversegmented and that CP-SALT offers a more parsimonious description of the data.

Figure 4: CP-SALT consistently outperforms the ARHMM on mouse behavior videos and segments data into distinct behavioral syllables: (A) An example frame from the MoSeq dataset. The models were trained on the top 10 principal components of the video frames from three mice. (B) CP-SALT and ARHMM trained with different ranks and lags. Mean and standard deviation across five seeds evaluated on a validation set are shown. The CP-SALT parameterization prevents overfitting for larger lags. (C) Test log-likelihood, averaged across 5 model fits, computed from the best ARHMM and CP-SALT hyperparameters in (B). CP-SALT outperforms the ARHMM across all three mice.

5.4 Modeling C. elegans Neural Data

Finally, we analyzed neural recordings of an immobilized C. elegans worm from Kato et al. [2015].
SLDSs have previously been used to capture the time-varying low-dimensional dynamics of the neural activity [Linderman et al., 2019, Glaser et al., 2020]. We compared SLDS, ARHMM, and CP-SALT models with 18 minutes of neural traces (recorded at 3 Hz; 3200 timesteps) from one worm, in which 48 neurons were confidently identified. The dataset also contains 7 manually identified state labels based on the neural activity. We used H = 7 discrete states and fitted SLDSs, ARHMMs, and CP-SALT with varying lags and ranks (or continuous latent dimensions for SLDSs). Following Linderman et al. [2019], we searched for sets of hyperparameters that achieve 90% explained variance on a held-out test dataset (see Appendix F for more details). For ARHMMs and CP-SALT, we chose a larger lag (L = 9, equivalent to 3 seconds) to examine the long-timescale correlations among the neurons. We find that SALT can perform as well as SLDSs and ARHMMs in terms of held-out explained variance ratio (a metric used by previous work [Linderman et al., 2019]). As expected, we find that CP-SALT can achieve these results with far fewer parameters than ARHMMs, and with a parameter count closer to the SLDS than the ARHMM (as more continuous latent states were required in an SLDS to achieve 90% explained variance; see Appendix F). Figure 5A shows that SALT, SLDS and ARHMM produce similar segmentations to the given labels, as evidenced by the confusion matrix having high entries on the leading diagonal (Figure 5B and Appendix F). Figure 5C shows the one-dimensional autoregressive filters learned by CP-SALT, defined as $\sum_{i=1}^{D_1} \sum_{j=1}^{D_2} \sum_{k=1}^{D_3} g^{(h)}_{ijk}\, u^{(h)}_{pi}\, v^{(h)}_{qj}\, \mathbf{w}^{(h)}_{:k}$ for neurons $p$ and $q$. We see that neurons believed to be involved in particular behavioral states have high weights in the filter (e.g., SMDV during the Ventral Turn state and SMDD during the Dorsal Turn state [Linderman et al., 2019, Kato et al., 2015, Gray et al., 2005, Kaplan et al., 2020, Yeon et al., 2018]). This highlights how switching autoregressive models can reveal state-dependent functional interactions between neurons (or observed states more generally). In Appendix F, we show the autoregressive filters learned by an ARHMM, an SLDS, and a generalized linear model (GLM), a method commonly used to model inter-neuronal interactions [Pillow et al., 2008]. Interestingly, the GLM does not find many strong functional interactions between neurons, likely because it is averaging over many unique discrete states. In addition to its advantages in parameter efficiency and estimation, SALT thus provides a novel method for finding changing functional interactions across neurons at multiple timescales.

Figure 5: CP-SALT provides good segmentations of C. elegans neural data, and the inferred low-rank tensors give insights into temporal dependencies among neurons in each discrete state: (A) Example data with manually generated labels (Given), as well as segmentations generated by SALT, SLDS and ARHMM models. Learned states are colored based on the permutation of states that best matches the given labels. All methods produce comparable segmentations, with high agreement with the given labels. (B) Confusion matrix of SALT-generated labels. (C) One-dimensional autoregressive filters learned in two states by SALT (identified as ventral and dorsal turns). Colors indicate the area under the curve (red is positive; blue is negative). The first four rows are neurons known to mediate ventral turns, while the last two rows mediate dorsal turns [Kato et al., 2015, Gray et al., 2005, Yeon et al., 2018]. These known behavior-tuned neurons generally have larger magnitude autoregressive filters. Interestingly, AVFL and AVFR also have large filters for dorsal turns. These neurons do not have a well-known function. However, they are associated with motor neurons, and so may simultaneously activate due to factors that co-occur with turning. This highlights how SALT may be used for proposing novel relationships in systems.
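A minimal sketch of how such a state-dependent pairwise filter can be read out of CP- or Tucker-SALT factors; the factors below are random placeholders standing in for learned parameters, with N = 48 neurons, L = 9 lags, and rank D = 11 as in the setup described above.

```python
import numpy as np

rng = np.random.default_rng(5)
N, L, D = 48, 9, 11                    # neurons, lags, rank (matching the C. elegans setup)
U = rng.normal(size=(N, D))            # placeholders for learned per-state factors
V = rng.normal(size=(N, D))
W = rng.normal(size=(L, D))
G = rng.normal(size=(D, D, D))

def pairwise_filter(G, U, V, W, p, q):
    """Length-L autoregressive filter from neuron q onto neuron p in one discrete state:
    sum_{i,j,k} g_{ijk} u_{p,i} v_{q,j} w_{:,k}."""
    return np.einsum('ijk,i,j,lk->l', G, U[p], V[q], W)

filt = pairwise_filter(G, U, V, W, p=0, q=1)       # shape (L,)
auc = filt.sum()                                   # signed area under the filter, as in Fig. 5C
```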
6 Discussion

We introduce switching autoregressive low-rank tensor (SALT) models: a novel model class that parameterizes the autoregressive tensors of an ARHMM with a low-rank factorization. This constraint allows SALT to model time-series data with fewer parameters than ARHMMs and with simpler estimation procedures than SLDSs. We also make theoretical connections between low-rank tensor regressions and LDSs. We then demonstrate, with both synthetic and real datasets, that SALT offers both efficiency and interpretability, striking an advantageous balance between the ARHMM and SLDS. Moreover, SALT offers an enhanced ability to investigate the interactions across observations, such as neurons, across different timescales in a data-efficient manner.

However, SALT is not without limitations. Foremost, SALT cannot readily handle missing observations, or share information between multiple time series with variable observation dimensions. Hierarchical SALT is an interesting extension, where information is shared across time series, but the factors of individual time series are allowed to vary. Furthermore, SALT could be extended to handle non-Gaussian data. For example, neural spike trains are often modeled with Poisson likelihoods instead of SALT's Gaussian noise model. In this case, the E-step would still be exact, but the M-step would no longer have closed-form coordinate updates. Despite these limitations, SALT offers a simple, effective, and complementary modeling and inference methodology for complex, time-varying dynamical systems.

Ethical Concerns We note that there are no new ethical concerns as a result of SALT.

Acknowledgments and Disclosure of Funding This work was supported by grants from the Simons Collaboration on the Global Brain (SCGB 697092), the NIH (U19NS113201, R01NS113119, R01NS130789, and K99NS119787), the Sloan Foundation, and the Stanford Center for Human and Artificial Intelligence. We thank Liam Paninski and the anonymous reviewers for their constructive feedback on the paper. We also thank the members of the Linderman Lab for their support and feedback throughout the project.

References

Sandeep Robert Datta, David J Anderson, Kristin Branson, Pietro Perona, and Andrew Leifer. Computational neuroethology: A call to action. Neuron, 104(1):11–24, 2019.
Alexander B Wiltschko, Matthew J Johnson, Giuliano Iurilli, Ralph E Peterson, Jesse M Katon, Stan L Pashkovski, Victoria E Abraira, Ryan P Adams, and Sandeep Robert Datta. Mapping sub-second structure in mouse behavior. Neuron, 88(6):1121 1135, 2015. Julia Costacurta, Lea Duncker, Blue Sheffer, Winthrop Gillis, Caleb Weinreb, Jeffrey Markowitz, Sandeep R Datta, Alex Williams, and Scott Linderman. Distinguishing discrete and continuous behavioral variability using warped autoregressive HMMs. Advances in Neural Information Processing Systems, 35:23838 23850, 2022. Aram Giahi Saravani, Kiefer J Forseth, Nitin Tandon, and Xaq Pitkow. Dynamic brain interactions during picture naming. e Neuro, 6(4), 2019. Stefano Recanatesi, Ulises Pereira-Obilinovic, Masayoshi Murakami, Zachary Mainen, and Luca Mazzucato. Metastable attractors explain the variable timing of stable behavioral action sequences. Neuron, 110(1):139 153, 2022. Yariv Ephraim, David Malah, and B-H Juang. On the application of hidden Markov models for enhancing noisy speech. IEEE Transactions on Acoustics, Speech, and Signal Processing, 37(12): 1846 1856, 1989. Zoubin Ghahramani and Geoffrey E Hinton. Variational learning for switching state-space models. Neural Computation, 12(4):831 864, 2000. Biljana Petreska, Byron M Yu, John P Cunningham, Gopal Santhanam, Stephen Ryu, Krishna V Shenoy, and Maneesh Sahani. Dynamical segmentation of single trials from population neural data. Advances in Neural Information Processing Systems, 24, 2011. Scott Linderman, Annika Nichols, David Blei, Manuel Zimmer, and Liam Paninski. Hierarchical recurrent state space models reveal discrete and continuous dynamics of neural activity in C. elegans. bio Rxiv, page 621540, 2019. Joshua Glaser, Matthew Whiteway, John P Cunningham, Liam Paninski, and Scott Linderman. Recurrent switching dynamical systems models for multiple interacting neural populations. Advances in Neural Information Processing Systems, 33:14867 14878, 2020. Aditya Nair, Tomomi Karigo, Bin Yang, Surya Ganguli, Mark J Schnitzer, Scott W Linderman, David J Anderson, and Ann Kennedy. An approximate line attractor in the hypothalamus encodes an aggressive state. Cell, 186(1):178 193, 2023. David Barber. Expectation correction for smoothed inference in switching linear dynamical systems. Journal of Machine Learning Research, 7(11), 2006. Emily Beth Fox. Bayesian nonparametric learning of complex dynamical phenomena. Ph D thesis, Massachusetts Institute of Technology, 2009. Kevin Murphy and Stuart Russell. Rao-Blackwellised particle filtering for dynamic Bayesian networks. In Sequential Monte Carlo methods in practice, pages 499 515. Springer, 2001. Scott Linderman, Matthew Johnson, Andrew Miller, Ryan Adams, David Blei, and Liam Paninski. Bayesian learning and inference in recurrent switching linear dynamical systems. In Artificial Intelligence and Statistics, pages 914 922. PMLR, 2017. David Zoltowski, Jonathan Pillow, and Scott Linderman. A general recurrent state space framework for modeling neural dynamics during decision-making. In International Conference on Machine Learning, pages 11680 11691. PMLR, 2020. Tamara G Kolda and Brett W Bader. Tensor decompositions and applications. SIAM review, 51(3): 455 500, 2009. James Douglas Hamilton. Time series analysis. Princeton University Press, 2020. Igor Melnyk and Arindam Banerjee. Estimating structured vector autoregressive models. In Maria Florina Balcan and Kilian Q. 
Weinberger, editors, Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pages 830 839, New York, New York, USA, 20 22 Jun 2016. PMLR. Shawn Ni and Dongchu Sun. Bayesian estimates for vector autoregressive models. Journal of Business & Economic Statistics, 23(1):105 117, 2005. Kevin P Murphy. Machine learning: a probabilistic perspective. MIT press, 2012. Uri Nahum Lerner. Hybrid Bayesian networks for reasoning about complex systems. Ph D thesis, Stanford University, 2003. Arnaud Doucet, Neil J Gordon, and Vikram Krishnamurthy. Particle filters for state estimation of jump Markov linear systems. IEEE Transactions on signal processing, 49(3):613 624, 2001. Mark H A Davis and Richard B Vinter. Stochastic modelling and control. Chapman and Hall London ; New York, 1985. ISBN 0412162008. Di Wang, Yao Zheng, Heng Lian, and Guodong Li. High-dimensional vector autoregressive time series modeling via tensor decomposition. Journal of the American Statistical Association, pages 1 19, 2021. Kameron Decker Harris, Aleksandr Aravkin, Rajesh Rao, and Bingni Wen Brunton. Time-varying autoregression with low-rank tensors. SIAM Journal on Applied Dynamical Systems, 20(4): 2335 2358, 2021. Wei Zhang, Ivor Cribben, Sonia Petrone, and Michele Guindani. Bayesian time-varying tensor vector autoregressive models for dynamic effective connectivity. ar Xiv preprint ar Xiv:2106.14083, 2021. Alexander Novikov, Dmitrii Podoprikhin, Anton Osokin, and Dmitry P Vetrov. Tensorizing neural networks. Advances in Neural Information Processing Systems, 28, 2015. Ivan V Oseledets. Tensor-train decomposition. SIAM Journal on Scientific Computing, 33(5): 2295 2317, 2011. Rose Yu, Stephan Zheng, Anima Anandkumar, and Yisong Yue. Long-term forecasting using higher order tensor RNNs. ar Xiv preprint ar Xiv:1711.00073, 2017. Hejia Qiu, Chao Li, Ying Weng, Zhun Sun, Xingyu He, and Qibin Zhao. On the memory mechanism of tensor-power recurrent models. In International Conference on Artificial Intelligence and Statistics, pages 3682 3690. PMLR, 2021. Adrian Valente, Srdjan Ostojic, and Jonathan W Pillow. Probing the relationship between latent linear dynamical systems and low-rank recurrent neural network models. Neural Computation, 34(9): 1871 1892, 2022. Saul Kato, Harris S Kaplan, Tina Schrödel, Susanne Skora, Theodore H Lindsay, Eviatar Yemini, Shawn Lockery, and Manuel Zimmer. Global brain dynamics embed the motor command sequence of Caenorhabditis elegans. Cell, 163(3):656 669, 2015. Jesse M Gray, Joseph J Hill, and Cornelia I Bargmann. A circuit for navigation in Caenorhabditis elegans. Proceedings of the National Academy of Sciences, 102(9):3184 3191, 2005. Harris S Kaplan, Oriana Salazar Thula, Niklas Khoss, and Manuel Zimmer. Nested neuronal dynamics orchestrate a behavioral hierarchy across timescales. Neuron, 105(3):562 576, 2020. Jihye Yeon, Jinmahn Kim, Do-Young Kim, Hyunmin Kim, Jungha Kim, Eun Jo Du, Kyeong Jin Kang, Hyun-Ho Lim, Daewon Moon, and Kyuhyung Kim. A sensory-motor neuron type mediates proprioceptive coordination of steering in C. elegans via two TRPC channels. PLo S biology, 16(6): e2004929, 2018. Jonathan W Pillow, Jonathon Shlens, Liam Paninski, Alexander Sher, Alan M Litke, EJ Chichilnisky, and Eero P Simoncelli. Spatio-temporal correlations and visual signalling in a complete neuronal population. Nature, 454(7207):995 999, 2008. Tohru Katayama. Subspace Methods for System Identification, volume 1. Springer, 2005. 
Martin Chalfie, John E Sulston, John G White, Eileen Southgate, J Nicol Thomson, and Sydney Brenner. The neural circuit for touch sensitivity in Caenorhabditis elegans. Journal of Neuroscience, 5 (4):956 964, 1985. Beverly J Piggott, Jie Liu, Zhaoyang Feng, Seth A Wescott, and XZ Shawn Xu. The neural circuits and synaptic mechanisms underlying motor initiation in C. elegans. Cell, 147(4):922 933, 2011. Amirreza Farnoosh, Bahar Azari, and Sarah Ostadabbas. Deep switching auto-regressive factorization: Application to time series forecasting. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, pages 7394 7403, 2021. Andreas S Weigend and Neil A Gershenfeld. Time series prediction: Forecasting the future and understanding the past. Santa Fe Institute Studies in the Sciences of Complexity, 1994. Supplementary Materials for: Switching Autoregressive Low-rank Tensor Models Table of Contents Appendix A: SALT Optimization via Tensor Regression. Appendix B: SALT Approximates a (Switching) Linear Dynamical System. Appendix C: Single-subspace SALT. Appendix D: Synthetic Data Experiments. Appendix E: Modeling Mouse Behavior. Appendix F: Modeling C. elegans Neural Data. Appendix G: Deep Switching Autoregressive Model. A SALT Optimization via Tensor Regression Let yt RN1 be the t-th outputs and Xt RN2 N3 be the t-th inputs. The regression weights are a tensor A RN1 N2 N3, which we model via a Tucker decomposition, k=1 gijk u:i v:j w:k, (11) where ui, vj, and wk are columns of the factor matrices U RN1 D1, V RN2 D2, and W RN3 D3, respectively, and gijk are entries in the core tensor G RD1 D2 D3. We define j,k to be a tensor-matrix product over the jth and kth slices of the tensor. For example, given a three-way tensor A RD1 D2 D3 and a matrix X RD2 D3, A 2,3 X = PD2 j=1 PD3 k=1 a:jkxjk. This operation is depicted in Figure 6. Consider the linear model, yt N(A 2,3 Xt, Q) where A 2,3 Xt is defined using the Tucker decomposition of A as, A 2,3 Xt = A(1)vec(Xt) (12) = UG(1)(V W )vec(Xt) (13) = UG(1)vec(V Xt W) (14) where A(1) RN1 N2N3 and G(1) RD1 D2D3 are mode-1 matricizations of the corresponding tensors. Note that these equations assume that matricization and vectorization are performed in row-major order, as in Python but opposite to what is typically used in Wikipedia articles. Equation (14) can be written in multiple ways, and these equivalent forms will be useful for deriving the updates below. We have, A 2,3 Xt = UG(1)(ID2 W X t )vec(V ) (15) = UG(1)(V Xt ID3)vec(W) (16) = U vec(V Xt W) vec(G). (17) We minimize the negative log likelihood by coordinate descent. Optimizing the output factors Let ext = G(1)vec(V Xt W) (18) for fixed V, W, and G. The NLL as a function of U is, t (yt Uext) Q 1(yt Uext). (19) This is a standard least squares problem with solution Optimizing the core tensors Let e Xt = U vec(V Xt W) RN1 D1D2D3 denote the coefficient on vec(G) in eq. (17). The NLL as a function of g = vec(G) is, t (yt e Xtg) Q 1(yt e Xtg). (21) The minimizer of this quadratic form is, t e X t Q 1 e Xt t e X t Q 1yt Mode-1 (column) fibres -:/0 Observation dimension, N !! = # ,,. !!%&:!%( = + + + + + + = 5% 6:'( 8',%*( + 6:'( 8',%*( + 6:'( 8',%*( Figure 6: Depiction of the 2,3 tensor operator we use [Kolda and Bader, 2009] This can be thought of as a generalization of matrix-vector products to tensor-matrix products. Optimizing the input factors Let e Xt = UG(1)(ID2 W X t ) (23) for fixed U, W, and G. The NLL as a function of v = vec(V ) is, t (yt e Xtv) Q 1(yt e Xtv). 
(24) The minimizer of this quadratic form is, t e X t Q 1 e Xt t e X t Q 1yt Optimizing the lag factors Let e Xt = UG(1)(V Xt ID3) (26) for fixed U, V, and G. The NLL as a function of w = vec(W) is, t (yt e Xtw) Q 1(yt e Xtw). (27) The minimizer of this quadratic form is, t e X t Q 1 e Xt t e X t Q 1yt Multiple discrete states If we have discrete states zt {1, . . . , H} and each state has its own parameters (G(h), U(h), V(h), W(h), Q(h)), then letting ω(h) t = E[zt = h] denote the weights from the E-step, the summations in coordinate updates are weighted by ω(h) t . For example, the coordinate update for the core tensors becomes, t ω(h) t e X(h) t Q(h) 1 e X(h) t t ω(h) t e X(h) t Q(h) 1yt B SALT approximates a (Switching) Linear Dynamical System We now re-state and provide a full proof for Proposition 1. Proposition 1 (Low-Rank Tensor Autoregressions Approximate Stable Linear Dynamical Systems). Consider a stable linear time-invariant Gaussian dynamical system. We define the steady-state Kalman gain matrix as K = limt Kt, and Γ = A(I KC). The matrix Γ RS S has eigenvalues λ1, . . . , λS. Let λmax = maxs |λs|; for a stable LDS, λmax < 1 [Davis and Vinter, 1985]. Let n denote the number of real eigenvalues and m the number of complex conjugate pairs. Let ˆy(LDS) t = E[yt | y1:t 1] denote the predictive mean under a steady-state LDS, and ˆy(SALT) t the predictive mean under a SALT model. An order-L Tucker-SALT model with rank n + 2m = S, or a CP-SALT model with rank n + 3m, can approximate the predictive mean of the steady-state LDS with error ˆy(LDS) t ˆy(SALT) t = O(λL max). Proof. A stationary linear dynamical system (LDS) is defined as follows: xt = Axt 1 + b + ϵt (30) yt = Cxt + d + δt (31) where yt RN is the t-th observation, xt RS is the t-th hidden state, ϵt i.i.d. N(0, Q), δt i.i.d. N(0, R), and θ = (A, b, Q, C, d, R) are the parameters of the LDS. Following the notation of Murphy [2012], the one-step-ahead posterior predictive distribution for the observations of the LDS defined above can be expressed as: p(yt|y1:t 1) = N(Cµt|t 1 + d, CΣt|t 1CT + R) (32) µt|t 1 = Aµt 1 + b (33) µt = µt|t 1 + Ktrt (34) Σt|t 1 = AΣt 1AT + Q (35) Σt = (I Kt C)Σt|t 1 (36) p(x1) = N(x1 | µ1|0, Σ1|0) (37) Kt = (Σ 1 t|t 1 + CT RC) 1CT R 1 (38) rt = yt Cµt|t 1 d. (39) We can then expand the mean Cµt|t 1 + d as follows: Cµt|t 1 + d = C l=1 Γl AKt lyt l + C l=1 Γl(b AKt ld) + d (40) i=1 A(I Kt i C) for l {2, 3, . . .} , (41) Γ1 = I. (42) Theorem 3.3.3 of Davis and Vinter [1985] (reproduced with our notation below) states that for a stabilizable and detectable system, the limt Σt|t 1 = Σ, where Σ is the unique solution of the discrete algebraic Riccati equation Σ = AΣAT AΣCT (CΣCT + R) 1CΣAT + Q. (43) As we are considering stable autonomous LDSs here, the system is stabilizable and detectable, as all unobservable states are themselves stable [Davis and Vinter, 1985, Katayama, 2005] Theorem 3.3.3 (Reproduced from Davis and Vinter [1985], updated to our notation and context). The theorem has two parts. (a) If the pair (A, C) is detectable then there exists at least one non-negative solution, Σ, to the discrete algebraic Riccati equation (43). (b) If the pair (A, C) is stabilizable then this solution Σ is unique, and Σt|t 1 Σ as t , where Σt|t 1 is the sequence generated by (33)-(39) with arbitrary initial covariance Σ0. Then, the matrix Γ = A(I KC) is stable, where K is the Kalman gain corresponding to Σ; i.e., K = (Σ 1 + CT RC) 1CT R 1 (44) Proof. See Davis and Vinter [1985]. 
Note that Davis and Vinter [1985] define the Kalman gain as AK. The convergence of the Kalman gain also implies that each term in the sequence Γl converges to i=1 A(I KC) = (A(I KC))l 1 = Γl 1, (45) where, concretely, we define Γ = A(I KC). We can therefore make the following substitution and approximation Cµt|t 1 + d lim t = C l=1 Γl AKyt l + C l=1 Γl(b AKd) + d (46) l=1 Γl AKyt l + C l=1 Γl(b AKd) + d + l=1 Γl AKyt l + C l=1 Γl(b AKd) + d (48) The approximation is introduced as a result of truncating the sequence to consider just the first L terms, and discarding the higher-order terms (indicated in blue). It is important to note that each term in (46) is the sum of a geometric sequence multiplied elementwise with yt. There are two components we prove from here. First, we derive an element-wise bound on the error introduced by the truncation, and verify that under the conditions outlined that the bound decays monotonically in L. We then show that Tucker and CP decompositions can represent the truncated summations in (48), and derive the minimum rank required for this representation to be exact. Bounding The Error Term We first rearrange the truncated terms in (46), where we define xl AKyt l + b AKd l=L+1 Γl AKyt l + C l=L+1 Γl(b AKd) + d, (49) l=L+1 CΓlxl, (50) l=L+1 CEΛl 1E 1xl, (51) l=L+1 PΛl 1ql, (52) where EΛE 1 is the eigendecomposition of Γ, P CE, and ql E 1xl. We now consider the infinity-norm of the error, and apply the triangle and Cauchy-Schwartz inequalities. We can write the bound on the as l=L+1 F Γl ! , where n = arg max k l=L+1 F Γl ! s=1 pnsλl 1 d ql,s s=1 |pns| λl 1 s |ql,s| . (55) Upper bounding the absolute magnitude of ql,s by W provides a further upper bound, which we can then rearrange s=1 |pns| λl 1 s , (56) λl 1 s . (57) The first two terms are constant, and hence the upper bound is determined by the sum of the of the lth power of the eigenvalues. We can again bound this sum by setting all eigenvalues equal to the magnitude of the eigenvalue with the maximum magnitude (spectral norm), denoted λmax: l=L+1 λl 1 max, (58) where these second summation is not a function of s, and W PS s=1 |pns| is constant. This summation is a truncated geometric sequence. Invoking Theorem 3.3.3 of Davis and Vinter [1985] again, the matrix Γ has only stable eigenvalues, and hence λmax < 1. Therefore the sequence sum will converge to X l=L+1 λl 1 max = λL max 1 λmax . (59) Rearranging again, we see that the absolute error on the nth element of yt is therefore bounded according to a power of the spectral norm s=1 |pns| λL max 1 λmax , (60) = O λL max . (61) More specifically, for a stable linear time-invariant dynamical system, and where q and hence y is bounded, then the bound on the error incurred reduces exponentially in the length of the window L. Furthermore, this error bound will reduce faster for systems with a lower spectral norm. Diagonalizing the System We first transform Γ into real modal form, defined as EΛE 1, where E and Λ are the eigenvectors and diagonal matrix of eigenvalues of Γ. Letting Γ have n real eigenvalues and m pairs of complex eigenvalues (i.e., n + 2m = S), we can express E, Λ, and E 1 as: E = [ a1 . . . an b1 c1 . . . bm cm ] (62) λ1 ... λn σ1 ω1 ω1 σ1 ... σm ωm ωm σm d T 1... d T n e T 1 f T 1... e T m f T m where a1 . . . an are the right eigenvectors corresponding to n real eigenvalues λ1 . . . λn, and bi and ci are the real and imaginary parts of the eigenvector corresponding to the complex eigenvalue σi + jωi. 
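For readers following the construction above, here is a small NumPy sketch of the real modal form, assuming distinct eigenvalues; the sign convention of the off-diagonal block entries depends on how the real and imaginary parts of each eigenvector are ordered, and this sketch fixes one consistent choice rather than reproducing the paper's exact layout.

```python
import numpy as np

def real_modal_form(Gamma, tol=1e-10):
    """Return E, Lam with Gamma = E @ Lam @ inv(E), where Lam is block diagonal:
    1x1 blocks for real eigenvalues and 2x2 [[sigma, omega], [-omega, sigma]]
    blocks for complex-conjugate pairs (one choice of real modal form)."""
    eigvals, eigvecs = np.linalg.eig(Gamma)
    cols, blocks = [], []
    used = np.zeros(len(eigvals), dtype=bool)
    for i, lam in enumerate(eigvals):
        if used[i]:
            continue
        used[i] = True
        if abs(lam.imag) < tol:                          # real eigenvalue -> 1x1 block
            cols.append(eigvecs[:, i].real)
            blocks.append(np.array([[lam.real]]))
        else:                                            # complex pair -> 2x2 block
            cols.append(eigvecs[:, i].real)              # b_i = Re(eigenvector)
            cols.append(eigvecs[:, i].imag)              # c_i = Im(eigenvector)
            blocks.append(np.array([[lam.real, lam.imag],
                                    [-lam.imag, lam.real]]))
            j = np.argmin(np.abs(eigvals - np.conj(lam)) + used * 1e9)
            used[j] = True                               # mark the conjugate as handled
    E = np.column_stack(cols)
    S = sum(b.shape[0] for b in blocks)
    Lam = np.zeros((S, S))
    k = 0
    for b in blocks:
        Lam[k:k + b.shape[0], k:k + b.shape[0]] = b
        k += b.shape[0]
    return E, Lam

# Example: one complex pair (0.5 +/- 0.6i) and one real eigenvalue (0.8).
Gamma = np.array([[0.5, -0.6, 0.0], [0.6, 0.5, 0.0], [0.0, 0.0, 0.8]])
E, Lam = real_modal_form(Gamma)
assert np.allclose(E @ Lam @ np.linalg.inv(E), Gamma)
```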
Note that Γl = (A(I KC))l 1 = EΛl 1E 1 (65) The lth power of Λ, Λl, where l 0, can be expressed as: λl 1 ... λl n σ1,l ω1,l ω1,l σ1,l ... σm,l ωm,l ωm,l σm,l where σi,l = σ2 i,l 1 ω2 i,l 1, ωi,l = 2σi,l 1ωi,l 1 for l 2, σi,1 = σi, ωi,1 = ωi, σi,0 = 1, and ωi,0 = 0. Tucker Tensor Regression Let H RS S L be a three-way tensor, whose lth frontal slice H::l = Λl 1. Let G RS S S be a three-way tensor, whose entry gijk = 1i=j=k for 1 k n, and gijk = ( 1)1i+1=j=k+11(i=j=k) (i 1=j 1=k) (i=j+1=k+1) (i+1=j=k+1) for k {n + 1, n + 3, . . . , n + 2m 1}. Let W RL S be a matrix, whose entry wlk = λl 1 k for 1 k n, wlk = σk,l 1 for k {n+1, n+3, . . . , n+2m 1}, and wlk = ωk,l 1 for k {n+2, n+4, . . . , n+2m}. We can then decompose H into G RS S S and W RL S such that H = G 3 W (Figure 7). 1 0 𝜎!,! 𝜎!," 𝜎!,#$! 𝜔!,#$! 1 0 𝜎',! 𝜎'," 𝜎',#$! 𝜔',#$! Figure 7: Decomposition of H into G and W such that H = G 3 W: Given an LDS whose A(I KC) has n real eigenvalues and m pairs of complex eigenvalues, this decomposition illustrates how Tucker-SALT can approximate the LDS well with rank n + 2m. With V = (E 1AK)T , U = CE, m = C PL l=1 Γl(b AKd) + d, and Xt = yt 1:t L, we can rearrange the mean to: Cµt|t 1 + d C l=1 EΛl 1E 1AKyt l + C l=1 Γl(b AKd) + d (67) l=1 H::l VT yt l + m (68) l=1 (G 3wl)VT yt l + m (69) l=1 ((G 2 V) 3wl)yt l + m (70) k=1 g:jk v:j(wlkyt l) + m (71) k=1 g:jk(v :j Xtw:k) + m (72) k=1 u:igijk(v :j Xtw:k) + m (73) k=1 gijku:i v:j w:k 2,3 Xt + m (74) CP Tensor Regression By rearranging E, Λl, and E 1 into J, Pl, and S respectively as follows: J = [ a1 . . . an b1 + c1 b1 c1 . . . bm + cm bm cm ] (75) λl 1 ... λl n σ1,l α1,l β1,l ... σm,l αm,l βm,l d T 1... d T n e T 1 + f T 1 f T 1 e T 1... e T m + f T m f T m e T m where J RS (n+3m), Pl R(n+3m) (n+3m), S R(n+3m) S, αi,l = ωi,l σi,l, and βi,l = ωi,l σi,l, we can diagonalize (A(I KC))l as JPl S. Let V = (SAK)T , U = CJ, m = C PL l=1 Γl(b AKd) + d, and Xt = yt 1:t L. Let W RL (n+3m) be a matrix, whose element in the lth row and kth column is pl 1,kk (i.e., the element in the kth row and kth column of Pl 1), and G R(n+3m) (n+3m) (n+3m) be a diagonal 3-way tensor, where gijk = 1i=j=k. We can then rearrange the mean to: Cµt|t 1 + d C l=1 EΛl 1E 1AKyt l + C l=1 Γl(b AKd) + d (78) l=1 JPl 1SAKyt l + m (79) l=1 Pl 1V yt l + m (80) k gijk u:i v:j(pl 1,kkyt l) + m (81) k gijk u:i v:j(Xtw:k) + m (82) k=1 gijk u:i v:j w:k 2,3 Xt + m (83) And so concludes the proof. C Single-subspace SALT Here we explicitly define the generative model of multi-subspace and single-subspace Tucker-SALT and CP-SALT. Single-subspace SALT is analogous to single-subspace SLDSs (also defined below), where certain emission parameters (e.g., C, d, and R) are shared across discrete states. This reduces the expressivity of the model, but also reduces the number of parameters in the model. Note that both variants of all models have the same structure on the transition dynamics of zt. Multi-subspace SALT Note that the SALT model defined in (6) and (7) in the main text is a multi-subspace SALT. We repeat the definition here for ease of comparison. yt i.i.d. N k=1 g(zt) ijk u(zt) :i v(zt) :j w(zt) :k 2,3 yt 1:t L + b(zt),Σ(zt) D1 = D2 = D3 = D and G is diagonal for CP-SALT. Single-subspace Tucker-SALT In single-subspace methods, the output factors are shared across discrete states yt i.i.d. N k=1 g(zt) :jk v(zt) :j w(zt) :k 2,3 yt 1:t L where m(zt) RD1. 
Single-subspace CP-SALT. Single-subspace CP-SALT requires an extra tensor compared to Tucker-SALT, as this tensor can no longer be absorbed into the core tensor:

$$y_t \sim \mathcal{N}\!\left(U\Big(m^{(z_t)} + P^{(z_t)}\Big[\sum_{j=1}^{D_2}\sum_{k=1}^{D_3} g^{(z_t)}_{:jk} \circ v^{(z_t)}_{:j} \circ w^{(z_t)}_{:k}\Big] \times_{2,3} y_{t-1:t-L}\Big),\; \Sigma\right), \tag{86}$$

where $U \in \mathbb{R}^{N \times D'_1}$, $P^{(z_t)} \in \mathbb{R}^{D'_1 \times D_1}$, $m^{(z_t)} \in \mathbb{R}^{D'_1}$, $D_1 = D_2 = D_3 = D$, and $G$ is diagonal.

Multi-subspace SLDS. Fitting a multi-subspace SLDS is a much harder optimization problem, which we found was often numerically unstable. We therefore do not consider the multi-subspace SLDS in these experiments, but include its definition here for completeness:

$$x_t \sim \mathcal{N}\big(A^{(z_t)} x_{t-1} + b^{(z_t)},\; Q^{(z_t)}\big), \tag{87}$$
$$y_t \sim \mathcal{N}\big(C^{(z_t)} x_t + d^{(z_t)},\; R^{(z_t)}\big). \tag{88}$$

Single-subspace SLDS. The single-subspace SLDS was used in all of our experiments, and is what is typically used in practice [Petreska et al., 2011, Linderman et al., 2017]:

$$x_t \sim \mathcal{N}\big(A^{(z_t)} x_{t-1} + b^{(z_t)},\; Q^{(z_t)}\big), \tag{89}$$
$$y_t \sim \mathcal{N}\big(C x_t + d,\; R\big). \tag{90}$$

D Synthetic Data Experiments

D.1 Extended Experiments for Proposition 1

In Section 5.1 we showed that Proposition 1 can accurately predict the required rank for CP- and Tucker-SALT models. We showed results for a single LDS for clarity; we now extend this analysis across multiple random LDS and SALT models. We randomly sampled LDSs with latent dimensions ranging from 4 to 10 and observation dimensions ranging from 9 to 20. For each LDS, we fit 5 randomly initialized CP-SALT and Tucker-SALT models with L = 50 lags. We varied the rank of the fit SALT models around the rank predicted by Proposition 1. Specifically, we computed the estimated rank for a given LDS, denoted D*, and then fit SALT models with ranks {D* − 2, D* − 1, D*, D* + 1, D* + 2}. According to Proposition 1, we would expect the reconstruction error of the autoregressive tensor to be minimized, and prediction accuracy to saturate, at D = D*.

To analyze these model fits, we first computed the average mean squared error of the autoregressive tensor corresponding to the LDS simulation, as a function of SALT rank relative to the rank required by Proposition 1. We see, as predicted by Proposition 1, that the error in the autoregressive tensor is nearly always minimized at D* (Figure 8A). The Tucker-SALT error was always minimized at D*. Some CP-SALT fits have lower MSE at ranks other than predicted by Proposition 1; we believe this is due to local minima in the optimization. We next investigated the test log-likelihood as a function of the relative rank (Figure 8B). Interestingly, the test log-likelihood shows that Tucker-SALT strongly requires the correct number of ranks for accurate prediction, but CP-SALT can often perform well with fewer ranks than predicted (although still a comparable number of ranks to Tucker-SALT). As in Figure 2, these analyses empirically confirm Proposition 1.

Figure 8: Extended results examining Proposition 1. (a) Normalized MSE of the autoregressive tensor. (b) Normalized log-likelihood on a held-out test set. Results are shown for the ability of SALT to estimate ten randomly generated LDSs, using five SALT repeats for each LDS. MSEs (panel A) and log-likelihoods (panel B) are normalized by the mean MSE and mean log-likelihood of SALT models trained with D = D*. D is the rank of the fit SALT model, and D* is the necessary rank predicted by Proposition 1.
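For reference, a minimal sketch (not the authors' code) of the rank computation used in this analysis, assuming access to the ground-truth LDS parameters: the predicted ranks are D*_Tucker = n + 2m and D*_CP = n + 3m, where n and m count the real eigenvalues and complex-conjugate pairs of Γ = A(I − KC).

```python
import numpy as np

def predicted_salt_ranks(A, K, C, tol=1e-10):
    """Count real eigenvalues (n) and complex-conjugate pairs (m) of
    Gamma = A (I - K C), and return the ranks predicted by Proposition 1:
    n + 2m for Tucker-SALT and n + 3m for CP-SALT."""
    Gamma = A @ (np.eye(A.shape[0]) - K @ C)
    eigvals = np.linalg.eigvals(Gamma)
    n = int(np.sum(np.abs(eigvals.imag) < tol))    # real eigenvalues
    m = (len(eigvals) - n) // 2                    # complex-conjugate pairs
    return {"tucker": n + 2 * m, "cp": n + 3 * m}
```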
D.2 Quantitative Performance: Synthetic Switching LDS Experiments

We include further results and analysis for the NASCAR and Lorenz attractor experiments presented in Section 5.2. We compare the marginal likelihood achieved by single-subspace SALT models of different sizes. We see that SALT outperforms ARHMMs, and can fit larger models (more lags) without overfitting (Figure 9). Note that the SLDS does not admit exact inference, and so we cannot readily compute the exact marginal likelihood for the SLDS.

Figure 9: Quantitative performance of different SALT models and ARHMMs (averaged over 3 different runs) on the synthetic experiments presented in Section 5.2. The test-set log-likelihood is shown as a function of the number of lags in the SALT model, for both (a) the NASCAR and (b) Lorenz synthetic datasets.

D.3 TVART versus SALT in recovering the parameters of SLDSs

We compared SALT to TVART [Harris et al., 2021], another tensor-based method for modeling autoregressive processes. We modified TVART (as briefly described in the original paper, Harris et al. [2021]) so that it can handle AR(p) processes, as opposed to only AR(1) processes. TVART is also not a probabilistic model (i.e., it cannot compute log-likelihoods), and so we focus our comparison on how well these methods recover the parameters of a ground-truth SLDS.

We used the same SLDS that we used to generate the NASCAR dataset in Section 5.2. We then used L = 7 CP-SALT and Tucker-SALT with ranks 3 and 2, respectively, and computed the MSE between the ground-truth tensor and the SALT tensors. For TVART, we used L = 7, a bin size of 10, and ranks 2 and 3 to fit the model to the data. We then clustered the inferred dynamics parameters to assign discrete states. To get the TVART parameter estimates, we computed the mean of the dynamics parameters for each discrete state and computed the MSE against the ground-truth tensor. The MSE results are as follows:

Table 2: Results comparing SALT and TVART [Harris et al., 2021] on the NASCAR example.

Model          Rank   Tensor reconstruction MSE (×10⁻³)   Number of parameters
TVART            2    0.423                               1.4K
TVART            3    0.488                               2.0K
Tucker-SALT      2    0.294                               0.6K
CP-SALT          3    0.297                               0.7K

Table 2 shows that SALT models recover the dynamics parameters of the ground-truth SLDS more accurately. Furthermore, we see that SALT models use fewer parameters than TVART models for this dataset (as the number of parameters in TVART scales linearly with the number of windows). We also note that TVART cannot be applied to held-out data, and, without post-hoc analysis, does not readily have a notion of re-usable dynamics or syllables.

D.4 The effect of the number of switches on the recovery of the parameters of the autoregressive dynamic tensors

We asked how the frequency of discrete state switches affected SALT's ability to recover the autoregressive tensors. We trained SALT and the ARHMM (both with L = 5 lags), as well as the SLDS, on data sampled from an SLDS with a varying number of discrete state switches. The ground-truth SLDS model had H = 2 discrete states, N = 20 observations, and S = 7-dimensional continuous latent states. The matrix A^{(h)}(I − K^{(h)}C^{(h)}) of each discrete state of the ground-truth SLDS had 1 real eigenvalue and 3 pairs of complex eigenvalues. We sampled 5 batches of T = 15,000 timesteps of data from the ground-truth SLDS, with s_n ∈ {1, 10, 25, 75, 125} discrete state switches spaced evenly across the data.
We then computed the mean squared error (MSE) between the SLDS tensors and the tensors reconstructed by SALT, the ARHMM, and the SLDS (Figure 10). More precisely, we combined the 3rd-order autoregressive tensors from each discrete state into a 4th-order tensor, and calculated the MSE based on these 4th-order tensors. As expected, the MSE increased with the number of switches in the data, indicating that the quality of the SALT approximation of SLDSs decreases as the frequency of discrete state switches increases.

Figure 10: The quality of the SALT approximation of SLDSs decreases as the number of discrete state switches increases: The data come from an SLDS with H = 2, N = 20, and S = 7. 15,000 timesteps were generated, with varying numbers of evenly spaced discrete state switches (x-axis). The mean squared error of reconstructing the autoregressive tensors increased as a function of the number of discrete state switches. Note that we combined the 3rd-order autoregressive tensors from each discrete state into a 4th-order tensor, and calculated the MSE based on these 4th-order tensors.

E Modeling Mouse Behavior

We include further details for the mouse experiments in Section 5.3.

E.1 Training Details

We used the first 35,949 timesteps of data from each of the three mice, which were collected at 30 Hz resolution. We used H = 50 discrete states and fitted ARHMMs and CP-SALT models with varying lags and ranks. Similar to Wiltschko et al. [2015], we imposed stickiness on the discrete state transition matrix via a Dirichlet prior with a concentration of 1.1 on the non-diagonals and 6 × 10^4 on the diagonals. These prior hyperparameters were empirically chosen such that the durations of the inferred discrete states and the given labels were comparable. We trained each model 5 times with random initialization for each hyperparameter setting, using 100 iterations of EM on a single NVIDIA Tesla P100 GPU.

E.2 Video Generation

Here we describe how the mouse behavioral videos were generated. We first determined the CP-SALT hyperparameters as those which led to the highest log-likelihood on the validation dataset. Then, using that CP-SALT model, we computed the most likely discrete states on the train and test data. Given a discrete state h, we extracted slices of the data whose most likely discrete state was h. We padded the data by 30 frames (i.e., 1 second) both at the beginning and the end of each slice for the movie. A red dot appears on each mouse for the duration of discrete state h. We generated such videos for all 50 discrete states (as long as there existed at least one slice for the discrete state) on the train and test data. For a given discrete state, the mice in each video behaved very similarly (e.g., the mice in the video for state 18 "pause" when the red dots appear, and those in the video for state 32 "walk" forward), suggesting that CP-SALT is capable of segmenting the data into useful behavioral syllables. See "MoSeq_salt_videos_train" and "MoSeq_salt_videos_test" in the supplementary material for the videos generated from the train and test data, respectively; "salt_crowd_i.mp4" refers to the crowd video for state i. We show the principal components for states 1, 2, 13, 32, 33, and 47 in Figure 11.
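The slice extraction described above can be implemented with a simple run-length scan over the most-likely state sequence. Below is a minimal sketch (not the authors' code); the function name and inputs are illustrative assumptions.

```python
import numpy as np

def extract_state_slices(states, h, pad=30):
    """Return (start, stop) indices of contiguous runs of discrete state h,
    padded by `pad` frames (1 second at 30 Hz) on each side and clipped to
    the data boundaries. `states` is the per-frame most-likely state sequence."""
    mask = (states == h).astype(int)
    # Run boundaries: +1 marks a run start, -1 marks the frame after a run ends.
    edges = np.diff(np.concatenate(([0], mask, [0])))
    starts = np.flatnonzero(edges == 1)
    stops = np.flatnonzero(edges == -1)
    return [(max(0, s - pad), min(len(states), e + pad)) for s, e in zip(starts, stops)]

# Example usage with a toy state sequence (hypothetical data).
states = np.array([3, 3, 18, 18, 18, 3, 18, 18, 3, 3])
print(extract_state_slices(states, h=18, pad=1))   # [(1, 6), (5, 9)]
```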
E.3 Modeling Mouse Behavior: Additional Analyses

We also investigated whether SALT qualitatively led to a good segmentation of the behavioral data into discrete states, shown in Figure 11. Figure 11A shows a 30-second example snippet of the test data from one mouse, colored by the discrete states inferred by CP-SALT. CP-SALT used fewer discrete states to model the data than the ARHMM (Figure 11B). Coupled with the finding that CP-SALT improves test-set likelihoods, this suggests that the ARHMM may have oversegmented the data and that CP-SALT may be better able to capture the number of behavioral syllables. Figure 11C shows average test data (with two standard deviations) for a short time window around the onset of a discrete state (we also include mouse videos corresponding to that state in the supplementary materials). The shrinking gray area around the time of state onset, along with the similar behaviors of the mice in the videos, suggests that CP-SALT is capable of segmenting the data into consistent behavioral syllables.

Figure 11: CP-SALT leads to qualitatively good segmentation of the mouse behavioral data into distinct syllables: (A) 30 seconds of test data (Mouse 1) with the discrete states inferred by CP-SALT as the background color. (B) For one mouse, the cumulative number of frames that are captured by each discrete state, where the discrete states are ordered according to how frequently they occur. (C) The average test data, with two standard deviations, for six states of CP-SALT, aligned to the time of state onset. The shrinkage of the gray region around the state onset tells us that CP-SALT segments the test data consistently.

F Modeling C. elegans Neural Data

We include further details and results for the C. elegans example presented in Section 5.4. This example highlights how SALT can be used to gain scientific insight into the system.

F.1 Training Details

We used 3200 timesteps of data (recorded at 3 Hz) from one worm, for which 48 neurons were confidently identified. The data were manually segmented into seven labels (reverse sustained, slow, forward, ventral turn, dorsal turn, reversal (type 1), and reversal (type 2)). We therefore used H = 7 discrete states in all models (apart from the GLM). After testing multiple lag values, we selected L = 9 for all models, as these longer lags allow us to examine longer-timescale interactions and produced better segmentations across models, with only a small reduction in variance explained. We trained each model 5 times with KMeans initialization, using 100 iterations of EM on a single NVIDIA Tesla V100 GPU. Models that achieved 90% explained variance on a held-out test set were then selected and analyzed (similar to Linderman et al. [2019]).

F.2 Additional Quantitative Results

Figure 12 shows additional results for training different models. In Figure 12A we see that models with larger ranks (or latent dimensions) achieve higher explained variance. Interestingly, longer lags can lead to a slight reduction in the explained variance, likely due to overfitting. This effect is less pronounced in the more constrained single-subspace SALT, but these models achieve lower explained variance ratios throughout. Longer-lag models allow us to inspect longer-timescale dependencies, and so are more experimentally insightful. Figure 12B shows the confusion matrices for discrete states between the learned models and the given labels. The segmentations were similar across all models that achieved 90% explained variance.

Figure 12: SALT and SLDS perform comparably on held-out data: (A) Explained variance on a held-out sequence. Single-subspace (SS) SALT and SLDS perform comparably. Multi-subspace (MS) SALT achieves a higher explained variance with fewer ranks. Multi-subspace SLDS was numerically unstable. (B) Confusion matrices between given labels and predicted labels. All methods produce similar quality segmentations.
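The explained-variance criterion used for model selection in this section can be computed as follows; this is a minimal sketch (not the authors' code), assuming the ratio is computed from pooled one-step-ahead predictions on the held-out sequence.

```python
import numpy as np

def explained_variance_ratio(y_true, y_pred):
    """1 - Var(residual) / Var(data), pooled over timesteps and neurons.
    y_true and y_pred are (T, N) arrays of held-out data and model predictions."""
    resid = y_true - y_pred
    return 1.0 - resid.var() / y_true.var()

# Hypothetical usage: keep models whose held-out explained variance reaches ~0.9.
# keep = explained_variance_ratio(y_test, y_hat) >= 0.9
```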
F.3 Additional Autoregressive Filters

Figures 13 and 14 show extended versions of the autoregressive filters included in Section 5.4. Figure 13 shows the filters learned for ventral and dorsal turns (for which panel A was included in Figure 5), while Figure 14 shows the filters for forward and backward locomotion. Note that the GLM does not have multiple discrete states, and hence the same filters are used across states. We see for the ARHMM and SALT that neurons known to be tuned to the behavior have higher-magnitude filters (determined by area under the curve), whereas the SLDS and GLM do not recover such strong state-specific tuning. Since the learned SLDS did not have stable within-state dynamics, the autoregressive filters could not be computed using Equation (48). We thus show C(A^{(h)})^l C^+ for lag l, where C^+ denotes the Moore-Penrose pseudoinverse of C, as a proxy for the autoregressive filters of discrete state h of the SLDS. Note that this is a post-hoc method and does not capture the true dependencies in the observation space.

We see that SALT consistently assigns high autoregressive weight to neurons known to be involved in certain behaviors (see Figures 13 and 14). In contrast, the ARHMM identifies these relationships less reliably, and the estimate of the SLDS autoregressive filters identifies few strong relationships. As the GLM only has one "state", the autoregressive filters are averaged across states, and so few strong relationships are found. This highlights how the low-rank and switching properties of SALT can be leveraged to glean insight into the system.

Figure 13: Autoregressive tensors learned by different models (Ventral and Dorsal Turns): (A-C) One-dimensional autoregressive filters learned in two states by SALT, SLDS, and the ARHMM (identified as ventral and dorsal turns), and (D) by a GLM. RIV and SMDV are known to mediate ventral turns, while SMDD mediates dorsal turns [Kato et al., 2015, Gray et al., 2005, Yeon et al., 2018].

Figure 14: Autoregressive tensors learned by different models (Forward Locomotion and Reversal): (A-C) One-dimensional autoregressive filters learned in two states by SALT, SLDS, and the ARHMM (identified as forward and reverse), and (D) by a GLM. AVB and RIB are known to mediate forward locomotion, while AVA and AVE are involved in initiating reversals [Kato et al., 2015, Gray et al., 2005, Chalfie et al., 1985, Piggott et al., 2011].
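For reference, the SLDS proxy filters described above, C(A^{(h)})^l C^+, can be computed directly; a minimal sketch (not the authors' code), assuming the fitted SLDS parameters are available as numpy arrays:

```python
import numpy as np

def slds_proxy_filters(C, A_h, num_lags=9):
    """Post-hoc proxy for the observation-space autoregressive filters of one
    SLDS discrete state: C @ A_h^l @ pinv(C) for l = 1, ..., num_lags."""
    C_pinv = np.linalg.pinv(C)          # Moore-Penrose pseudoinverse of the emission matrix
    return np.stack([C @ np.linalg.matrix_power(A_h, l) @ C_pinv
                     for l in range(1, num_lags + 1)])   # shape (num_lags, N, N)
```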
G Deep Switching Autoregressive Model

Here we compare SALT to deep switching auto-regressive factorization (DSARF) models [Farnoosh et al., 2021]. DSARF models construct a deep generative model using a set of low-rank factors, which are then weighted according to inferred discrete states at each timestep. This combination then defines the time evolution of an autoregressive process, which in turn defines the distribution over the observed variables. As such, non-linear function approximators can be used to parameterize many of the link functions, gaining expressivity while retaining the parameter efficiency and interpretability of conventional methods. Variational inference is used to learn the parameters and perform inference. Unlike SALT, DSARF can handle missing data, as the autoregressive processes are defined only in the latent space.

We compare SALT against DSARF on an example drawn from Farnoosh et al. [2021], proposed for studying switching systems by Ghahramani and Hinton [2000] (see also Weigend and Gershenfeld [1994]). This example studies a patient believed to have sleep apnea, typified by periods where normal rhythmic breathing ceases, resulting in periods of low or zero respiratory rates. The data are a one-dimensional time series of a measure of chest volume, such that oscillations in the data correspond to rhythmic breathing (see Figure 15), while periods of constant volume correspond to apnea bouts. Disjoint sequences of length 1000 are used for the train and test sets, such that y_train ∈ R^{1×1000} and y_test ∈ R^{1×1000}.

We apply DSARF, as described in Farnoosh et al. [2021], and SALT. The SALT model uses H = 2 discrete states, D = 5 ranks, L = 10 lags, and a single latent subspace. We use an L2 weight penalty of 10^{-4}. As in the mouse experiments, we add a Dirichlet stickiness prior, with parameters γ = 10^{-2} and κ = 10^3 (see Table 3). Inference results on the test set are shown in Figure 15. SALT hyperparameters were selected through manual tuning. SALT achieves a normalized next-step prediction RMSE of 22.57%, versus 23.86% achieved by DSARF (described by Farnoosh et al. [2021] as "short-term prediction"). This result highlights that SALT is competitive with deep methods that provide the desired discrete-continuous representation.

Figure 15: Comparison of DSARF and SALT-CP on the apnea example as presented in Farnoosh et al. [2021]. Shown are filtering reconstructions of the observed trace and the binary discrete label. SALT achieves a normalized RMSE of 22.57% versus 23.86% achieved by DSARF. The right panel is a zoomed-in version of the left panel. We see that the reconstructions and scores achieved by both models are comparable.

H Experiment Configurations

In this section we provide extended details for the experiments presented in the main text, including the hyperparameters selected and the hyperparameter tuning process.

H.1 List of Hyperparameters

Table 3 outlines the key hyperparameters we consider when comparing and selecting models. Details of the values for each of these hyperparameters are specified in the sections that follow.

Table 3: Outline of the key hyperparameters in the three main model families we consider: SALT, SLDS, and ARHMM.
Hyperparameter | Description | Applicable to models | Permissible Values
Tensor factorization | Factorization structure of the autoregressive tensor. | SALT, ARHMM | {CP, Tucker, None}
Subspace type | Whether the output factors are shared across discrete states. | SALT, SLDS | {Single, Multi}
D, Tensor ranks | Dimension of the core tensor. Assumed throughout that D = D1 = D2 = D3. | SALT | Z≥1
H, Number of discrete states | Number of discrete states in switching models. | SALT, SLDS, ARHMM | Z≥1
L, Number of lags | Number of previous observations in autoregressive models. | SALT, ARHMM | Z≥1
S, Latent space dimension | Dimension of the continuous latent state. | SLDS | Z≥1
Temporal L2 penalty | L2 penalty applied to lag parameters at longer lags. The penalty strength is defined as α·β^{l−1}, where l is the lag. | SALT, ARHMM | α ∈ R≥0, β ∈ R≥1
Stickiness | Dirichlet prior that can be added to penalize discrete state switches. | SALT, SLDS, ARHMM | γ ∈ R>0, κ ∈ R≥0
Global L2 penalty | L2 penalty applied across all parameters (see note below). | SALT, SLDS, ARHMM | R≥0

Unless otherwise specified, we performed a grid search over a range of values within the permissible set. In certain circumstances, the hyperparameter was selected to match known properties of the data; e.g., we used H = 4 for the NASCAR data because we know there are four underlying states in the data. During pilot experiments we experimented with a global L2 penalty applied to all parameters in the model. We found that varying this parameter did not affect the key outcomes of each experiment. We therefore set the global L2 penalty strength to the same value, 0.0001, for all models in all experiments, unless otherwise specified.

H.2 Experiment-Specific Hyperparameters

Here we specify experiment-specific details.

H.2.1 Section 5.1: SALT Faithfully Approximates LDS

For a given LDS with latent dimension S and initialized with a random rotation matrix, we computed the estimated number of CP-SALT and Tucker-SALT ranks, denoted D*_CP = n + 3m and D*_Tucker = n + 2m, respectively, where n is the number of real eigenvalues and m is the number of complex-conjugate pairs of Γ defined in Proposition 1. We then fitted CP-SALT models with ranks {min(S, D*_CP) − 2, ..., max(S, D*_CP) + 2} and Tucker-SALT models with ranks {min(S, D*_Tucker) − 2, ..., max(S, D*_Tucker) + 2}. We chose L = 50, as this was the horizon at which the autoregressive parameters were approximately zero. The temporal L2 penalty was set to 1.0.

H.2.2 Section 5.2: Synthetic Switching LDS Examples

Synthetic NASCAR dataset. All models used the true number of discrete latent states, H = 4. We fitted single-subspace CP-SALT and Tucker-SALT models with ranks D ∈ {1, 2, 3, 4}. For both SALT models and ARHMMs, we used L ∈ {1, 5, 6, 7, 8, 9, 10, 15, 20}. A temporal L2 penalty of 1.0 was used for CP-SALT and Tucker-SALT models. Single-subspace SLDSs were fitted with the true underlying latent dimension, S = 2.

Synthetic Lorenz attractor dataset. All models used the approximate number of discrete latent states in the data, H = 2 (following Linderman et al. [2017]). We fitted single-subspace CP-SALT and Tucker-SALT models with ranks D ∈ {1, 2, 3, 4}. For both SALT models and ARHMMs, we used L ∈ {1, 5, 10, 15, 20}. A temporal L2 penalty of 2.0 was used for CP-SALT and Tucker-SALT models. Single-subspace SLDSs were fitted with the true underlying latent dimension, S = 3.

H.2.3 Section 5.3: Modeling Mouse Behavior

We fitted multi-subspace CP-SALT models with ranks D ∈ {10, 11, 12, 13, 14}.
For both CP-SALT models and ARHMMs, we used L ∈ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20}, and the number of discrete states was set to H = 50 (51 behavioral states explained 95% of videos in Wiltschko et al. [2015]). We used a temporal L2 penalty of 1.0. Similar to Wiltschko et al. [2015], we imposed stickiness on the discrete state transition matrix of both SALT models and ARHMMs via a Dirichlet prior. For discrete state h, the concentration parameters ν ∈ R^H_{>0} of the Dirichlet prior are ν_i = γ for i ≠ h, and ν_h = γ + κ. For this experiment, γ was set to 1.1 and κ to 6 × 10^4, which were empirically chosen such that the durations of the inferred discrete states and the given labels were comparable.

H.2.4 Section 5.4: Modeling C. elegans Neural Data

Following Linderman et al. [2019], we empirically searched for sets of hyperparameters that achieve 90% explained variance on a held-out test dataset. We fitted both single- and multi-subspace CP-SALT models with ranks D ∈ {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}. Similarly, SLDSs were fitted with latent dimensions over the same range. For both CP-SALT models and ARHMMs, we used L ∈ {1, 3, 6, 9}, and the number of discrete states was set to H = 7, which is the number of unique manual labels. After testing multiple lag values, we selected L = 9 for SALT models and ARHMMs, as these longer lags allow us to examine longer-timescale interactions and produced better segmentations across models, with only a small reduction in variance explained. For ARHMMs, we set the global L2 penalty strength to 20.0, with the temporal L2 penalty set to 1.0. For SLDSs, we set the global L2 penalty strength to 100.0. For SALT models, we set the global L2 penalty strength to 0.2 and the temporal L2 penalty to 1.1. We additionally imposed a smoothness penalty on the weights of the lag factors of SALT models by adding $\eta\|FW\|_2^2$ to the NLL when optimizing the lag factors, where $\eta$ is the smoothness penalty strength and $F$ is the first-difference matrix. We set $\eta = 0.8$ for SALT models.
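As an illustration of the lag-factor smoothness penalty just described, here is a minimal sketch (not the authors' code) of the $\eta\|FW\|_2^2$ term, with $F$ the first-difference matrix acting along the lag dimension of a lag factor $W \in \mathbb{R}^{L \times D}$.

```python
import numpy as np

def lag_smoothness_penalty(W, eta=0.8):
    """Smoothness penalty eta * ||F W||_F^2, where F is the (L-1) x L
    first-difference matrix; it encourages lag factors to vary smoothly with lag."""
    L = W.shape[0]
    F = np.eye(L - 1, L, k=1) - np.eye(L - 1, L)   # each row computes w_{l+1} - w_l
    return eta * np.sum((F @ W) ** 2)

# Hypothetical usage: add to the negative log-likelihood when optimizing W.
# total_objective = nll + lag_smoothness_penalty(W, eta=0.8)
```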