# NAS-X: Neural Adaptive Smoothing via Twisting

Dieterich Lawson*, Google Research (dieterichl@google.com)
Michael Y. Li*, Stanford University (michaelyli@stanford.edu)
Scott W. Linderman, Stanford University (scott.linderman@stanford.edu)

Sequential latent variable models (SLVMs) are essential tools in statistics and machine learning, with applications ranging from healthcare to neuroscience. As their flexibility increases, analytic inference and model learning can become challenging, necessitating approximate methods. Here we introduce neural adaptive smoothing via twisting (NAS-X), a method that extends reweighted wake-sleep (RWS) to the sequential setting by using smoothing sequential Monte Carlo (SMC) to estimate intractable posterior expectations. Combining RWS and smoothing SMC allows NAS-X to provide low-bias and low-variance gradient estimates, and to fit both discrete and continuous latent variable models. We illustrate the theoretical advantages of NAS-X over previous methods and explore these advantages empirically in a variety of tasks, including a challenging application to mechanistic models of neuronal dynamics. These experiments show that NAS-X substantially outperforms previous VI- and RWS-based methods in inference and model learning, achieving lower parameter error and tighter likelihood bounds.

## 1 Introduction

Sequential latent variable models (SLVMs) are a foundational model class in statistics and machine learning, propelled by the success of hidden Markov models [1] and linear dynamical systems [2]. To model more complex data, SLVMs have incorporated nonlinear conditional dependencies, resulting in models such as sequential variational autoencoders [3-7], financial volatility models [8], and biophysical models of neural activity [9]. While these nonlinear dependencies make SLVMs more flexible, they also frustrate inference and model learning, motivating the search for approximate methods.
One popular method for inference in nonlinear SLVMs is sequential Monte Carlo (SMC), which provides a weighted particle approximation to the true posterior and an unbiased estimator of the marginal likelihood. For most SLVMs, SMC is a significant improvement over standard importance sampling, providing an estimator of the marginal likelihood with variance that grows linearly in the length of the sequence rather than exponentially [10]. SMC's performance, however, depends on having a suitable proposal. The optimal proposal is often intractable, so in practice, proposal parameters are learned from data. Broadly, there are two approaches to proposal learning: variational inference [11, 12] and reweighted wake-sleep [13, 14].

*Equal contribution. Work performed while at Stanford University.

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

Variational inference (VI) methods for SLVMs optimize the model and proposal parameters by ascending a lower bound on the log marginal likelihood that can be estimated with SMC, an approach called variational SMC [15-17]. Recent advances in variational SMC have extended it to work with smoothing SMC [18, 19], a crucial development for models where future observations are strongly related to previous model states. Without smoothing SMC, particle degeneracy can cause high variance of the lower bound and its gradients, resulting in unstable learning [18, 10].

Reweighted wake-sleep (RWS) methods instead use SMC's posterior approximation to directly estimate gradients of the log marginal likelihood [14, 20]. This approach can be interpreted as descending the inclusive KL divergence from the true posterior to the proposal. Notably, RWS methods work with discrete latent variables, providing a compelling advantage over VI methods, which typically resort to high-variance score function estimates for discrete latent variables.
In this work, we combine recent advances in smoothing variational SMC with the benefits of reweighted wake-sleep. To this end, we introduce neural adaptive smoothing via twisting (NAS-X), a method for inference and model learning in nonlinear SLVMs that uses smoothing SMC to approximate intractable posterior expectations. The result is a versatile, low-bias and low-variance estimator of the gradients of the log marginal likelihood suitable for fitting proposal and model parameters in both continuous and discrete SLVMs. After introducing NAS-X, we present two theoretical results that highlight the advantages of NAS-X over other RWS-based methods. We also demonstrate NAS-X's performance empirically in model learning and inference in linear Gaussian state-space models, discrete latent variable models, and high-dimensional ODE-based mechanistic models of neural dynamics. In all experiments, we find that NAS-X substantially outperforms several VI and RWS alternatives, including the ELBO, IWAE, FIVO, and SIXO. Furthermore, we empirically show that our method enjoys lower-bias and lower-variance gradients, requires minimal additional computational overhead, and is robust and easy to train.

## 2 Background

This work considers model learning and inference in nonlinear sequential latent variable models with Markovian structure, i.e. models that factor as

$$p_\theta(x_{1:T}, y_{1:T}) = p_\theta(x_1)\, p_\theta(y_1 \mid x_1) \prod_{t=2}^{T} p_\theta(x_t \mid x_{t-1})\, p_\theta(y_t \mid x_t), \tag{1}$$

with latent variables $x_{1:T} \in \mathcal{X}^T$, observations $y_{1:T} \in \mathcal{Y}^T$, and global parameters $\theta \in \Theta$. By nonlinear, we mean latent variable models where the parameters of the conditional distributions $p_\theta(x_t \mid x_{t-1})$ and $p_\theta(y_t \mid x_t)$ depend nonlinearly on $x_{t-1}$ and $x_t$, respectively.
Estimating the marginal likelihood $p_\theta(y_{1:T})$ and posterior $p_\theta(x_{1:T} \mid y_{1:T})$ for this model class is difficult because it requires computing an intractable integral over the latents,

$$p_\theta(y_{1:T}) = \int_{\mathcal{X}^T} p_\theta(y_{1:T}, x_{1:T})\, dx_{1:T}.$$

We begin by introducing two algorithms, reweighted wake-sleep [14, 13] and smoothing sequential Monte Carlo [19], that are crucial for understanding our approach.

### 2.1 Reweighted Wake-Sleep

Reweighted wake-sleep (RWS, [13, 14]) is a method for maximum marginal likelihood in LVMs that estimates the gradients of the marginal likelihood using self-normalized importance sampling (SNIS). This is motivated by Fisher's identity, which allows us to write the gradients of the marginal likelihood as a posterior expectation,

$$\nabla_\theta \log p_\theta(y_{1:T}) = \mathbb{E}_{p_\theta(x_{1:T} \mid y_{1:T})}\left[\nabla_\theta \log p_\theta(x_{1:T}, y_{1:T})\right], \tag{2}$$

as proved in Appendix 8.1. The term inside the expectation is computable with modern automatic differentiation tools [21, 22], but the posterior $p_\theta(x_{1:T} \mid y_{1:T})$ is unavailable. Thus, SNIS is used to form a biased but consistent Monte Carlo estimate of Eq. (2) [23]. Specifically, SNIS draws $N$ IID samples from a proposal distribution $q_\phi(x_{1:T} \mid y_{1:T})$ and weights them to form the estimator

$$\nabla_\theta \log p_\theta(y_{1:T}) \approx \sum_{i=1}^{N} \bar{w}^{(i)}\, \nabla_\theta \log p_\theta(x^{(i)}_{1:T}, y_{1:T}), \qquad x^{(i)}_{1:T} \sim q_\phi(x_{1:T} \mid y_{1:T}), \qquad \bar{w}^{(i)} \propto \frac{p_\theta(x^{(i)}_{1:T}, y_{1:T})}{q_\phi(x^{(i)}_{1:T} \mid y_{1:T})}, \tag{3}$$

where $\bar{w}^{(i)}$ are normalized weights, i.e. $\sum_{i=1}^{N} \bar{w}^{(i)} = 1$. The variance of this estimator is reduced as $q_\phi$ approaches the posterior [23], so RWS also updates $q_\phi$ by minimizing the inclusive Kullback-Leibler (KL) divergence from the posterior to the proposal. Crucially, the gradient for this step can also be written as a posterior expectation,

$$\nabla_\phi \mathrm{KL}\left(p_\theta(x_{1:T} \mid y_{1:T})\, \|\, q_\phi(x_{1:T} \mid y_{1:T})\right) = -\mathbb{E}_{p_\theta(x_{1:T} \mid y_{1:T})}\left[\nabla_\phi \log q_\phi(x_{1:T} \mid y_{1:T})\right], \tag{4}$$

as derived in Appendix 8.2. This allows RWS to estimate Eq. (4) using SNIS with the same set of samples and weights as Eq. (3),

$$\mathbb{E}_{p_\theta(x_{1:T} \mid y_{1:T})}\left[\nabla_\phi \log q_\phi(x_{1:T} \mid y_{1:T})\right] \approx \sum_{i=1}^{N} \bar{w}^{(i)}\, \nabla_\phi \log q_\phi(x^{(i)}_{1:T} \mid y_{1:T}). \tag{5}$$

Importantly, any method that provides estimates of expectations w.r.t. the posterior can be used for gradient estimation within the RWS framework, as we will see in the next section.

### 2.2 Estimating Posterior Expectations with Smoothing Sequential Monte Carlo

As we saw in Eqs. (2) and (4), key quantities in RWS can be expressed as expectations under the posterior. Standard RWS uses SNIS to approximate these expectations, but in sequence models the variance of the SNIS estimator can scale exponentially in the sequence length. In this section, we review sequential Monte Carlo (SMC) [10, 24], an inference algorithm that can produce estimators of posterior expectations with linear or even sub-linear variance scaling.

SMC approximates the posterior $p_\theta(x_{1:T} \mid y_{1:T})$ with a set of $N$ weighted particles $x^{1:N}_{1:T}$ constructed by sampling from a sequence of target distributions $\{\pi_t(x_{1:t})\}_{t=1}^{T}$. Since these intermediate targets are often only known up to some unknown normalizing constant $Z_t$, SMC uses the unnormalized targets $\{\gamma_t(x_{1:t})\}_{t=1}^{T}$, where $\pi_t(x_{1:t}) = \gamma_t(x_{1:t})/Z_t$. Provided mild technical conditions are met and $\gamma_T(x_{1:T}) \propto p_\theta(x_{1:T}, y_{1:T})$, SMC returns weighted particles that approximate the posterior $p_\theta(x_{1:T} \mid y_{1:T})$ [10, 24]. These weighted particles can be used to compute biased but consistent estimates of expectations under the posterior, similar to SNIS. SMC repeats the following steps for each time $t$:

1. Sample latents $x^{1:N}_{1:t}$ from a proposal distribution $q_\phi(x_{1:t} \mid y_{1:T})$.
2. Weight each particle using the unnormalized target $\gamma_t$ to form an empirical approximation $\hat{\pi}_t$ to the normalized target distribution $\pi_t$.
3. Draw new particles $x^{1:N}_{1:t}$ from the approximation $\hat{\pi}_t$ (the resampling step).

By resampling away latent trajectories with low weight and focusing on promising particles, SMC can produce lower-variance estimates than SNIS. For a thorough review of SMC, see Doucet and Johansen [10], Naesseth et al. [24], and Del Moral [25].
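The sample-weight-resample loop above can be sketched in a few lines. The following is a minimal illustration for a one-dimensional random-walk model with a bootstrap (prior) proposal and multinomial resampling; it is a toy sketch, not the implementation used in the paper, and the function name `smc` is ours:

```python
import numpy as np

def smc(y, sigma_x=1.0, sigma_y=1.0, n_particles=100, seed=0):
    """Minimal filtering SMC (bootstrap particle filter) for the random-walk
    model x_1 ~ N(0, sigma_x^2), x_t ~ N(x_{t-1}, sigma_x^2),
    y_t ~ N(x_t, sigma_y^2). Returns per-step particles and normalized
    weights, plus an estimate of log p(y_{1:T})."""
    rng = np.random.default_rng(seed)
    x = rng.normal(0.0, sigma_x, size=n_particles)  # t = 1: propose from the prior
    particles, weights = [], []
    log_Z = 0.0
    for t in range(len(y)):
        if t > 0:
            x = rng.normal(x, sigma_x)              # step 1: propose x_t | x_{t-1}
        # step 2: weight by the target-to-proposal ratio (here just the likelihood)
        logw = -0.5 * ((y[t] - x) / sigma_y) ** 2 - np.log(sigma_y * np.sqrt(2.0 * np.pi))
        log_Z += np.log(np.mean(np.exp(logw)))       # marginal-likelihood contribution
        w = np.exp(logw - logw.max())
        w /= w.sum()                                 # self-normalized weights
        particles.append(x.copy())
        weights.append(w)
        # step 3: multinomial resampling (a complete implementation resamples
        # full trajectories; we keep only the current state here)
        idx = rng.choice(n_particles, size=n_particles, p=w)
        x = x[idx]
    return particles, weights, log_Z
```

A filtering estimate of a posterior expectation such as $\mathbb{E}[x_t \mid y_{1:t}]$ is then `np.sum(weights[t] * particles[t])`; smoothing SMC modifies step 2 by multiplying the weight by a twist $r_\psi(y_{t+1:T}, x_t)$.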
**Filtering vs. Smoothing.** The most common choice of unnormalized targets $\gamma_t$ is the filtering distributions $p_\theta(x_{1:t}, y_{1:t})$, resulting in the algorithm known as filtering SMC, or a particle filter. Filtering SMC has been used to estimate posterior expectations within the RWS framework in neural adaptive sequential Monte Carlo (NASMC) [20], but a major disadvantage of filtering SMC is that it ignores future observations $y_{t+1:T}$. Ignoring future observations can lead to particle degeneracy and high-variance estimates, which in turn cause poor model learning and inference [17, 18, 26, 27]. We could avoid these issues by using the smoothing distributions as unnormalized targets, choosing $\gamma_t(x_{1:t}) = p_\theta(x_{1:t}, y_{1:T})$, but unfortunately the smoothing distributions are not readily available from the model. We can approximate them, however, by observing that $p_\theta(x_{1:t}, y_{1:T})$ is proportional to the filtering distributions $p_\theta(x_{1:t}, y_{1:t})$ times the lookahead distributions $p_\theta(y_{t+1:T} \mid x_t)$. If the lookahead distributions are well approximated by a sequence of twists $\{r_\psi(y_{t+1:T}, x_t)\}_{t=1}^{T}$, then running SMC with targets $\gamma_t(x_{1:t}) = p_\theta(x_{1:t}, y_{1:t})\, r_\psi(y_{t+1:T}, x_t)$ approximates smoothing SMC [26].

**Twist Learning.** We have reduced the challenge of obtaining the smoothing distributions to learning twists that approximate the lookahead distributions. Previous twist-learning approaches include maximum likelihood training on samples from the model [18, 28] and Bellman-type losses motivated by writing the twist at time $t$ recursively in terms of the twist at time $t+1$ [18, 29]. For NAS-X we use density ratio estimation (DRE) via classification to learn the twists, as introduced in Lawson et al. [19]. This method is motivated by observing that the lookahead distribution is proportional to a ratio of densities up to a constant independent of $x_t$,

$$p_\theta(y_{t+1:T} \mid x_t) = \frac{p_\theta(x_t \mid y_{t+1:T})\, p_\theta(y_{t+1:T})}{p_\theta(x_t)} \propto \frac{p_\theta(x_t \mid y_{t+1:T})}{p_\theta(x_t)}. \tag{6}$$

Results from the DRE-via-classification literature [30] provide a way to approximate this density ratio: train a classifier to distinguish between samples from the numerator $p_\theta(x_t \mid y_{t+1:T})$ and denominator $p_\theta(x_t)$. Then, the pre-sigmoid output of the classifier will approximate the log of the ratio in Eq. (6). For an intuitive argument for this fact see Appendix 8.3, and for a full proof see Sugiyama et al. [30]. In practice, it is not possible to sample directly from $p_\theta(x_t \mid y_{t+1:T})$. Instead, Lawson et al. [19] sample full trajectories from the model's joint distribution, i.e. draw $x_{1:T}, y_{1:T} \sim p_\theta(x_{1:T}, y_{1:T})$, and discard unneeded timesteps, leaving only $x_t$ and $y_{t+1:T}$, which are distributed marginally as $p_\theta(x_t, y_{t+1:T})$. Training the DRE classifier on data sampled in this manner will approximate the ratio $p_\theta(x_t, y_{t+1:T}) / (p_\theta(x_t)\, p_\theta(y_{t+1:T}))$, which is equivalent to Eq. (6); see Appendix 8.3.

## 3 NAS-X: Neural Adaptive Smoothing via Twisting

The goal of NAS-X is to combine recent advances in smoothing SMC with the advantages of reweighted wake-sleep. Because SMC is a self-normalized importance sampling algorithm, it can be used to estimate posterior expectations, and therefore the model and proposal gradients, within a reweighted wake-sleep framework. In particular, NAS-X repeats the following steps:

1. Draw a set of $N$ trajectories $x^{(1:N)}_{1:T}$ and weights $w^{(1:N)}_{1:T}$ from a smoothing SMC run with model $p_\theta$, proposal $q_\phi$, and twist $r_\psi$.
2. Use those trajectories and weights to form estimates of gradients for the model $p_\theta$ and proposal $q_\phi$, as in reweighted wake-sleep. Specifically, NAS-X computes the gradients of the inclusive KL divergence for learning the proposal $q_\phi$ as

$$\sum_{t=1}^{T}\sum_{i=1}^{N} w^{(i)}_t\, \nabla_\phi \log q_\phi(x^{(i)}_t \mid x^{(i)}_{t-1}, y_{t:T}) \tag{7}$$

and computes the gradients of the model $p_\theta$ as

$$\sum_{t=1}^{T}\sum_{i=1}^{N} w^{(i)}_t\, \nabla_\theta \log p_\theta(x^{(i)}_t, y_t \mid x^{(i)}_{t-1}). \tag{8}$$

3. Update the twists $r_\psi$ using density ratio estimation via classification.

A full description is available in Algorithms 1 and 2.
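The twist update in step 3, density ratio estimation via classification (Section 2.2), can be sketched as follows. This is a toy version with a linear logistic classifier trained by gradient descent; `train_twist_classifier` and its feature inputs are illustrative names, not the paper's architecture:

```python
import numpy as np

def train_twist_classifier(joint, marginals, steps=500, lr=0.1):
    """Density ratio estimation via classification (sketch).
    `joint` holds features of (x_t, y_{t+1:T}) pairs sampled together from the
    model; `marginals` holds the same features with the pairing broken, so
    x_t and y_{t+1:T} are effectively independent. A logistic classifier's
    pre-sigmoid output then approximates
        log p(x_t, y_{t+1:T}) / (p(x_t) p(y_{t+1:T}))
    which equals log p(y_{t+1:T} | x_t) up to a constant in x_t."""
    X = np.concatenate([joint, marginals])
    labels = np.concatenate([np.ones(len(joint)), np.zeros(len(marginals))])
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # classifier's probability of "joint"
        g = p - labels                           # gradient of the logistic loss
        w -= lr * (X.T @ g) / len(X)
        b -= lr * g.mean()
    return lambda feats: feats @ w + b           # pre-sigmoid logit = learned log ratio
```

In practice the pairing is broken by shuffling one coordinate across a batch of model samples, and the returned logit function plays the role of $\log r_\psi(y_{t+1:T}, x_t)$.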
A key design decision in NAS-X is the specific form of the gradient estimators. Smoothing SMC provides two ways to estimate expectations of test functions with respect to the posterior: both the timestep-$t$ and timestep-$T$ approximations of the target distribution could be used, in the latter case by discarding timesteps after $t$. Specifically,

$$p_\theta(x_{1:t} \mid y_{1:T}) \approx \sum_{i=1}^{N} w^{(i)}_t\, \delta(x_{1:t}\,;\, x^{(i)}_{1:t}) \approx \sum_{i=1}^{N} w^{(i)}_T\, \delta(x_{1:t}\,;\, (x^{(i)}_{1:T})_{1:t}), \tag{9}$$

where $\delta(a\,;\,b)$ is a Dirac delta of $a$ located at $b$ and $(x_{1:T})_{1:t}$ denotes selecting the first $t$ timesteps of a timestep-$T$ particle; due to SMC's ancestral resampling step these are not in general equivalent. For NAS-X we choose the time-$t$ approximation of the posterior to lessen particle degeneracy, as in NASMC [20]. Note, however, that in the case of NASMC this amounts to approximating the posterior with the filtering distributions, which ignores information from future observations. In the case of NAS-X, the intermediate distributions directly approximate the posterior distributions because of the twists, a key advantage that we explore theoretically in Section 3.1 and empirically in Section 5.

```text
Algorithm 1: NAS-X
Procedure NAS-X(θ₀, φ₀, ψ₀, y_{1:T}):
    θ ← θ₀, φ ← φ₀, ψ ← ψ₀
    while not converged do
        x^{1:N}_{1:T}, w^{1:N}_{1:T} ← SMC({p_θ(x_{1:t}, y_{1:t}), q_φ(x_t | x_{t-1}, y_{t:T}), r_ψ(x_t, y_{t+1:T})}_{t=1}^T)
        ∇θ = Σ_{t=1}^T Σ_{i=1}^N w^{(i)}_t ∇_θ log p_θ(x^{(i)}_t, y_t | x^{(i)}_{t-1})
        ∇φ = Σ_{t=1}^T Σ_{i=1}^N w^{(i)}_t ∇_φ log q_φ(x^{(i)}_t | x^{(i)}_{t-1}, y_{t:T})
        θ ← grad-step(θ, ∇θ)
        φ ← grad-step(φ, ∇φ)
        ψ ← twist-training(θ, ψ)
    end
    return θ, φ, ψ

Procedure twist-training(θ, ψ₀): see Algorithm 2 in Appendix 8.3.
```

### 3.1 Theoretical Analysis of NAS-X

In this section, we state two theoretical results that illustrate NAS-X's advantages over NASMC, with proofs given in Appendix 7.

**Proposition 1 (Consistency of NAS-X's gradient estimates).** Suppose the twists are optimal, so that $r_\psi(y_{t+1:T}, x_t) \propto p(y_{t+1:T} \mid x_t)$ up to a constant independent of $x_t$ for $t = 1, \ldots, T-1$. Let $\hat{\nabla}_\theta \log p_\theta(y_{1:T})$ be NAS-X's weighted particle approximation to the true gradient of the log marginal likelihood, $\nabla_\theta \log p_\theta(y_{1:T})$. Then $\hat{\nabla}_\theta \log p_\theta(y_{1:T}) \xrightarrow{\text{a.s.}} \nabla_\theta \log p_\theta(y_{1:T})$ as $N \to \infty$.

**Proposition 2 (Unbiasedness of NAS-X's gradient estimates).** Assume that the proposal distribution $q_\phi(x_t \mid x_{1:t-1}, y_{1:T})$ is optimal, so that $q_\phi(x_t \mid x_{1:t-1}, y_{1:T}) = p(x_t \mid x_{1:t-1}, y_{1:T})$ for $t = 1, \ldots, T$, and that the twists $r_\psi(y_{t+1:T}, x_t)$ are optimal, so that $r_\psi(y_{t+1:T}, x_t) \propto p(y_{t+1:T} \mid x_t)$ up to a constant independent of $x_t$ for $t = 1, \ldots, T-1$. Let $\hat{\nabla}_\theta \log p_\theta(y_{1:T})$ be NAS-X's weighted particle approximation to the true gradient of the log marginal likelihood, $\nabla_\theta \log p_\theta(y_{1:T})$. Then, for any number of particles, $\mathbb{E}[\hat{\nabla}_\theta \log p_\theta(y_{1:T})] = \nabla_\theta \log p_\theta(y_{1:T})$.

## 4 Related Work

**VI Methods.** There is a large literature on model learning via stochastic gradient ascent on an evidence lower bound (ELBO) [31, 32, 4, 33]. Subsequent works have considered ELBOs defined by the normalizing constant estimates from multiple importance sampling [34], nested importance sampling [35, 36], rejection sampling, and Hamiltonian Monte Carlo [37]. Most relevant to our work is the literature that uses SMC's estimates of the normalizing constant as a surrogate objective. There are a number of VI methods based on filtering [17, 16, 15] and smoothing SMC [18, 38, 39, 19, 27], but filtering SMC approaches can suffer from particle degeneracy and high variance [18, 19].

**Reweighted Wake-Sleep Methods.** The wake-sleep algorithm was introduced in Hinton et al. [13] as a way to train deep directed graphical models. Bornschein and Bengio [14] interpreted the wake-sleep algorithm as self-normalized importance sampling and proposed reweighted wake-sleep, which uses SNIS to approximate gradients of the inclusive KL divergence and log marginal likelihood. Neural adaptive sequential Monte Carlo (NASMC) extends RWS by using filtering SMC to approximate posterior expectations instead of SNIS [20].
To combat particle degeneracy, NASMC approximates the posterior with the filtering distributions, which introduces bias.

**NAS-X vs. SIXO.** Both NAS-X and SIXO [19] leverage smoothing SMC with DRE-learned twists, but NAS-X uses smoothing SMC to estimate gradients in an RWS-like framework while SIXO uses it within a VI-like framework. Thus, NAS-X follows biased but consistent estimates of the log marginal likelihood while SIXO follows unbiased estimates of a lower bound on the log marginal likelihood. It is not clear a priori which approach would perform better, but we provide empirical evidence in Section 5 that shows NAS-X is more stable than SIXO and learns better models and proposals. In addition to these empirical advantages, NAS-X can fit discrete latent variable models, while SIXO would require high-variance score function estimates.

Figure 1: Comparison of NAS-X and baseline methods (NASMC, ELBO, IWAE, RWS, SIXO, FIVO) on inference in an LG-SSM. (left) Log-marginal likelihood bounds over proposal training steps (lower is better), (middle) proposal parameter relative error (lower is better), and (right) learned proposal variances over model timesteps. NAS-X outperforms all baseline methods and recovers the true posterior marginals.

## 5 Experiments

We empirically validate the following advantages of NAS-X:

- By using the approximate smoothing distributions as targets for proposal learning, NAS-X can learn proposals that match the true posterior marginals, while NASMC and other baseline methods cannot. We illustrate this in Section 5.1, in a setting where the true posterior is tractable. We illustrate the practical benefits on inference in nonlinear mechanistic models in Section 5.3.
- By optimizing the proposal within the RWS framework (e.g., descending the inclusive KL), NAS-X can perform learning and inference in discrete latent variable models, which SIXO cannot. We explore this in Section 5.2. We explore the practical benefits of this combination in a challenging setting in Section 5.3, where we show NAS-X can fit ODE-based mechanistic models of neural dynamics with 38 model parameters and an 18-dimensional latent state.

In addition to these experiments, we analyze the computational complexity and wall-clock speed of each method and the bias and variance of the gradient estimates in Sections 15 and 16 of the Appendix.

### 5.1 Linear Gaussian State Space Model

We first consider a one-dimensional linear Gaussian state space model with joint distribution

$$p(x_{1:T}, y_{1:T}) = \mathcal{N}(x_1; 0, \sigma^2_x) \prod_{t=2}^{T} \mathcal{N}(x_t; x_{t-1}, \sigma^2_x) \prod_{t=1}^{T} \mathcal{N}(y_t; x_t, \sigma^2_y). \tag{10}$$

In Figure 1, we compare NAS-X against several baselines (NASMC, FIVO, SIXO, RWS, IWAE, and ELBO) by evaluating log marginal likelihood estimates (left panel) and recovery of the true posterior (middle and right panels). For all methods we learn a mean-field Gaussian proposal factored over time, $q(x_{1:T}) = \prod_{t=1}^{T} q_t(x_t) = \prod_{t=1}^{T} \mathcal{N}(x_t; \mu_t, \sigma^2_t)$, with parameters $\mu_{1:T}$ and $\sigma^2_{1:T}$ corresponding to the means and variances at each time-step. For twist-based methods, we parameterize the twist as a quadratic function in $x_t$ whose coefficients are functions of the observations and time step. We chose this form to match the functional form of the analytic log density ratio. For details, see Section 9 in the Appendix.

NAS-X outperforms all baseline methods, achieving a tighter lower bound on the log-marginal likelihood and lower parameter error. In the right panel of Figure 1, we compare the learned proposal variances against the true posterior variance, which can be computed in closed form.
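For this model the closed-form posterior marginal variances follow from a Kalman filter and a Rauch-Tung-Striebel smoother pass. A minimal sketch (the function name is ours, not from the paper):

```python
import numpy as np

def smoothed_variances(T, sigma_x2=1.0, sigma_y2=1.0):
    """Closed-form posterior marginal variances Var(x_t | y_{1:T}) for the
    random-walk LG-SSM of Eq. (10), via a Kalman filter followed by an RTS
    smoother. For this model the variances do not depend on the observed y."""
    filt = np.empty(T)    # filtered variances P_{t|t}
    pred = np.empty(T)    # predicted variances P_{t|t-1}
    P = sigma_x2          # P_{1|0}: prior variance of x_1
    for t in range(T):
        if t > 0:
            P = filt[t - 1] + sigma_x2       # predict (transition matrix = 1)
        pred[t] = P
        K = P / (P + sigma_y2)               # Kalman gain (observation matrix = 1)
        filt[t] = (1.0 - K) * P              # measurement update
    smooth = np.empty(T)                     # smoothed variances P_{t|T}
    smooth[-1] = filt[-1]
    for t in range(T - 2, -1, -1):
        J = filt[t] / pred[t + 1]            # RTS smoother gain
        smooth[t] = filt[t] + J ** 2 * (smooth[t + 1] - pred[t + 1])
    return smooth
```

These are the values the learned proposal variances in Figure 1 (right) are compared against; the final timestep has the largest smoothed variance, since no future observations constrain it.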
See Section 9 for a comparison of proposal means; we do not report this comparison in the main text since all methods recover the posterior mean. This comparison gives insight into NAS-X's better performance. NASMC's learned proposal overestimates the posterior variance and fails to capture the true posterior distribution because it employs a filtering approximation to the gradients of the proposal distribution. On the other hand, by using the twisted targets, which approximate the smoothing distributions, to estimate proposal gradients, NAS-X recovers the true posterior.

### 5.2 Switching Linear Dynamical Systems

To explore NAS-X's ability to handle discrete latent variables, we consider a switching linear dynamical system (SLDS) model [40, 41]. Specifically, we adapt the recurrent SLDS (rSLDS) example from Linderman et al. [41] in which the latent dynamics trace ovals in a manner that resembles cars racing on a NASCAR track. There are two coupled sets of latent variables: a discrete state $z_t$, with $K = 4$ possible values, and a two-dimensional continuous state $x_t$ that follows linear dynamics that depend on $z_t$. The observations are a noisy projection of $x_t$ into a ten-dimensional observation space. There are 1000 observations in total. For the proposal, we factor $q$ over both time and the continuous and discrete states. The continuous distributions are parameterized by Gaussians, and categorical distributions are used for the discrete latent variables. For additional details on the proposal and twist, see Section 10 in the Appendix, and for details on the generative model see Linderman et al. [41].

(a) Learned dynamics of NAS-X, NASMC, and Laplace EM on the rSLDS, compared with the ground truth. (b) Train $\mathcal{L}^{1024}_{\mathrm{BPF}}$ for the rSLDS.
| Method | $\sigma^2_O = 0.001$ | $\sigma^2_O = 0.01$ | $\sigma^2_O = 0.1$ |
|---|---|---|---|
| NAS-X | 19.837 ± 0.0234 | 8.63 ± 0.0015 | 2.79 ± 0.0009 |
| NASMC | 19.834 ± 0.0018 | 8.53 ± 0.001 | 2.874 ± 0.0007 |
| Laplace EM | 19.154 ± 0.057 | 8.54 ± 0.0039 | 2.765 ± 0.0012 |
| RWS | 17.148 ± 0.087 | 6.314 ± 0.023 | 5.78 ± 0.0026 |

Figure 2: Inference and model learning in switching linear dynamical systems (SLDS). (top) Comparison of learned dynamics and inferred latent states in model learning. Laplace EM sometimes learns incorrect segmentations, as seen in the rightmost panel. (bottom) Quantitative comparison of log marginal likelihood lower bounds obtained by running a bootstrap particle filter (BPF) with the learned models.

We present qualitative results from model learning and inference in the top panel of Figure 2. We compare the learned dynamics for NAS-X, NASMC, and a Laplace-EM algorithm designed specifically for recurrent state space models [42]. In each panel, we plot the vector field of the learned dynamics and the posterior means, with different colors corresponding to the four discrete states. NAS-X recovers the true dynamics accurately. In the table in Figure 2, we quantitatively compare model learning performance across these three approaches by running a bootstrap proposal with the learned models and the true dynamics and observation variances. We normalize the bounds by the sequence length. NAS-X outperforms or performs on par with both NASMC and Laplace EM across the different observation noises $\sigma^2_O$. See Section 10 for additional results on inference.

### 5.3 Biophysical Models of Neuronal Dynamics

For our final set of experiments we consider inference and parameter learning in Hodgkin-Huxley (HH) models [9, 43], mechanistic models of voltage dynamics in neurons. These models use systems of coupled nonlinear differential equations to describe the evolution of the voltage difference across a neuronal membrane as it changes in response to external stimuli such as injected current.
Understanding how voltage propagates throughout a cell is central to understanding electrical signaling and computation in the brain. Voltage dynamics are governed by the flow of charged ions across the cell membrane, which is in turn mediated by the opening and closing of ion channels and pumps. HH models capture the states of these ion channels as well as the concentrations of ions and the overall voltage, resulting in a complex dynamical system with many free parameters and a high-dimensional latent state space. Model learning and inference in this setting can be extremely challenging due to the dimensionality, noisy data, and expensive and brittle simulators that fail for many parameter settings.

**Model Description.** We give a brief introduction to the models used in this section and defer a full description to the appendix. We are concerned with modeling the potential difference across a neuronal cell membrane, $v$, which changes in response to currents flowing through a set of ion channels, $c \in C$. Each ion channel $c$ has an activation which represents a percentage of the maximum current that can flow through the channel and is computed as a nonlinear function $g_c$ of the channel state $\lambda_c$, with $g_c(\lambda_c) \in [0, 1]$. This activation specifies the time-varying conductance of the channel as a fraction of the maximum conductance of the channel, $\bar{g}_c$. Altogether, the dynamics for the voltage $v$ can be written as

$$c_m \frac{dv}{dt} = \frac{I_{\text{ext}}}{S} - \sum_{c \in C} \bar{g}_c\, g_c(\lambda_c)\,(v - E_{c_{\text{ion}}}), \tag{11}$$

where $c_m$ is the specific membrane capacitance, $I_{\text{ext}}$ is the external current applied to the cell, $S$ is the cell membrane surface area, $c_{\text{ion}}$ is the ion transported by channel $c$, and $E_{c_{\text{ion}}}$ is that ion's reversal potential. In addition to the voltage dynamics, the ion channel states $\{\lambda_c\}_{c \in C}$ evolve as

$$\frac{d\lambda_c}{dt} = A(v)\,\lambda_c + b(v), \qquad c \in C, \tag{12}$$

where $A(v)$ and $b(v)$ are nonlinear functions of the membrane potential that produce matrices and vectors, respectively.
Together, Equations (11) and (12) define a conditionally linear system of first-order ordinary differential equations (ODEs), meaning that the voltage dynamics are linear if the channel states are known, and vice-versa. Following Lawson et al. [19], we augment the deterministic dynamics with zero-mean, additive Gaussian noise on the voltage $v$ and the unconstrained gate states $\mathrm{logit}(\lambda_c)$ at each integration time-step. The observations are produced by adding Gaussian noise to the instantaneous membrane potential.

**Proposal and Twist Parameterization.** For all models in this section we amortize proposal and twist learning across datapoints. SIXO and NAS-X proposals used bidirectional recurrent neural networks (RNNs) [44, 45] with a hidden size of 64 units to process the raw observations and external current stimuli, and then fed the processed observations, previous latent state, and a transformer positional encoding [46] into a 64-unit single-layer MLP that produced the parameters of an isotropic Gaussian distribution over the current latent state. Twists were similarly parameterized with an RNN run in reverse across the observations, combined with an MLP that accepts the RNN outputs and latent state and produces the twist values.

**Integration via Strang Splitting.** The HH ODEs are stiff, meaning they are challenging to integrate at large step sizes because their state can change rapidly. While small step sizes can ensure numerical stability, they also make methods prohibitively slow. For example, many voltage dynamics of interest unfold over hundreds of milliseconds, which could take upwards of 40,000 integration steps at the standard 0.005 milliseconds per step. Because running our models for even 10,000 steps would be too costly, we developed new numerical integration techniques based on an explicit Strang splitting approach that allowed us to stably integrate at 0.1 milliseconds per step, a 20× speedup [47]. For details, see Section 13 in the Appendix.
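The idea behind such a splitting can be sketched for a conditionally linear system. The gate and voltage right-hand sides below are generic stand-ins (`gate_rates` and `volt_coeffs` are our illustrative names, not the paper's 18-dimensional model); the point is that each frozen subsystem has an exact exponential solution, so the composed half/full/half update is second-order accurate and stable at step sizes where explicit Euler struggles:

```python
import numpy as np

def strang_step(v, m, dt, gate_rates, volt_coeffs):
    """One Strang-splitting step for a conditionally linear HH-style system.
    gate_rates(v) -> (alpha, beta): per-gate rates, so that for fixed v
        dm/dt = alpha * (1 - m) - beta * m     is linear in m.
    volt_coeffs(m) -> (a, b): conductance-weighted terms, so that for fixed m
        dv/dt = b - a * v                       is linear in v.
    Each frozen subsystem is advanced with its exact exponential solution."""
    def gate_flow(m, v, h):
        alpha, beta = gate_rates(v)
        tau = 1.0 / (alpha + beta)             # gate time constant
        m_inf = alpha * tau                    # gate steady state at this voltage
        return m_inf + (m - m_inf) * np.exp(-h / tau)
    def volt_flow(v, m, h):
        a, b = volt_coeffs(m)
        v_inf = b / a                          # voltage steady state at these gates
        return v_inf + (v - v_inf) * np.exp(-a * h)
    m = gate_flow(m, v, dt / 2.0)              # half step on the gates
    v = volt_flow(v, m, dt)                    # full step on the voltage
    m = gate_flow(m, v, dt / 2.0)              # half step on the gates
    return v, m
```

Because each sub-step uses the analytic solution of a linear ODE, the gate states stay in [0, 1] and the update cannot blow up however large `dt` is, which is what permits the 0.1 ms step size.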
#### 5.3.1 Hodgkin-Huxley Inference Results

First, we evaluated NAS-X, NASMC, and SIXO on their ability to infer underlying voltages and channel states from noisy voltage observations. For this task we sampled 10,000 noisy voltage traces from a probabilistic Hodgkin-Huxley model of the squid giant axon [9], and used each method to train proposals (and twists for NAS-X and SIXO) to compute the marginal likelihood assigned to the data under the true model. As in [19], we sampled trajectories of length 50 milliseconds, with a single noisy voltage observation every millisecond. The stability of the Strang-splitting-based ODE integrator allowed us to integrate at dt = 0.1 ms, meaning there were 10 latent states per observation.

Figure 3: Inference in a mechanistic HH model. (a) HH inference performance: log-likelihood lower bounds for IWAE, FIVO, NASMC, SIXO, and NAS-X. Proposals were trained with 4 particles and evaluated across a range of particle numbers (4 to 256); RWS performed too poorly to be included. (b) Inferred voltage traces for SIXO and NAS-X. (top) SIXO generates a high number of resampling events, leading to particle degeneracy and a single mistimed spike. (bottom) NAS-X perfectly infers the latent voltage with no mistimed spikes, and resamples infrequently. See Fig. 9 in the Appendix for NASMC's traces.

In Figure 3a we plot the performance of proposals and twists trained with 4 particles and evaluated across a range of particle numbers. All methods perform roughly the same when evaluated with 256 particles, but with lower numbers of evaluation particles the smoothing methods emerge as more particle-efficient than the filtering methods. To achieve NAS-X's inference performance with 4 particles, NASMC would need 256 particles, a 64× increase, and NAS-X is also on average 2× more particle-efficient than SIXO.
In Figure 3b we further investigate these results by examining the inferred voltage traces of NAS-X and SIXO. SIXO accurately infers the timing of most spikes but resamples at a high rate, which can lead to particle degeneracy and poor bound performance. NAS-X correctly infers the voltage across the whole trace with no spurious or mistimed spikes and almost no resampling events, indicating it has learned a high-quality proposal that does not generate poor particles that must be resampled away. These qualitative results support the quantitative results in Figure 3a: SIXO's high resampling rate and NASMC's filtering approach lead to lower bound values.

These results highlight a benefit of RWS-based methods over VI methods: when the model is correctly specified, it can be beneficial to have a more deterministic proposal. Empirically, we find that maximizing the variational lower bound encourages the proposal to have high entropy, which in this case resulted in SIXO's poorer performance relative to NAS-X. In the next section, we explore the implications of this for model learning.

#### 5.3.2 Hodgkin-Huxley Model Learning Results

In this section, we assess NAS-X's and SIXO's ability to fit model parameters in a more complex, biophysically realistic model of a pyramidal neuron from the mouse visual cortex. This model was taken from the Allen Institute Brain Atlas [48] and includes 9 different voltage-gated ion channels as well as a calcium pump/buffer subsystem and a calcium-gated potassium ion channel. In total, the model had 38 free parameters and an 18-dimensional latent state space, in contrast to the 1 free parameter and 4-dimensional state space of the model considered by Lawson et al. [19]. For full details of the models, see Appendix Section 12. We fit these models to voltage traces gathered from a real mouse neuron by the Allen Institute, but downsampled and noised the data to simulate a more common voltage imaging setting.
We ran a hyperparameter sweep over learning rates and initial values of the voltage and observation noise variances (270 hyperparameter settings in all), and selected the best-performing model via early stopping on the training log marginal likelihood lower bound. Each hyperparameter setting was run for 5 seeds, and each seed was run for 2 days on a single CPU core with 7 GB of memory. Because of the inherent instability of these models, many seeds failed, and we discarded hyperparameter settings with more than 2 failed runs.

[Figure 4: Model learning in HH model of a mouse pyramidal neuron. (top) Samples drawn from learned models over 200 ms when stimulated with a square pulse of 250 picoamps beginning at 20 milliseconds. NAS-X's samples are noisier than SIXO's, but spike more consistently. (bottom) A comparison of NAS-X- and SIXO-trained models along various evaluation metrics. SIXO's models achieve higher bounds, but are less stable and capture overall spike count more poorly than NAS-X-trained models. All errors are absolute errors.]

Method   L_32^BPF       # Spikes Err.   Rest Voltage Err.   Cross-Corr.   % Runs Failed
NAS-X    686.4 ± 6.8    0.76 ± 0.15     2.74 ± 0.1          6258 ± 11     18.9
SIXO     660.6 ± 4.4    1.88 ± 0.41     1.8 ± 0.2           6055 ± 22     25.2

In Figure 4 (bottom), we compare NAS-X- and SIXO-trained models with respect to test-set log-likelihood lower bounds as well as biophysically relevant metrics. To compute the biophysical metrics, we sampled 32 voltage traces for each input stimulus trace in the test set, and averaged the feature errors over the samples and the test set. NAS-X better captures the number of spikes, an important feature of the traces, and attains a higher cross-correlation. Both methods capture the resting voltage well, although SIXO attains a slightly lower error and outperforms NAS-X in terms of log-likelihood lower bound.
Training instability is a significant practical challenge when fitting mechanistic models. Therefore, we also include the percentage of runs that failed for each method. SIXO's more entropic proposals more frequently generate biophysically implausible latent states, causing the ODE integrator to return NaNs. In contrast, fewer of NAS-X's runs suffer from numerical instability issues, a great advantage when working with mechanistic models.

6 Conclusion

In this work we presented NAS-X, a new method for model learning and inference in sequential latent variable models that combines the reweighted wake-sleep framework with approximate smoothing SMC. Our approach involves learning twist functions for use in smoothing SMC, and then running smoothing SMC to approximate gradients of the log marginal likelihood with respect to the model parameters and gradients of the inclusive KL divergence with respect to the proposal parameters. We validated our approach in experiments including model learning and inference for discrete latent variable models and mechanistic models of neural dynamics, demonstrating that NAS-X offers compelling advantages in many settings.

Acknowledgements

This work was supported by the Simons Collaboration on the Global Brain (SCGB 697092), NIH (U19NS113201, R01NS113119, and R01NS130789), NSF (2223827), the Sloan Foundation, the McKnight Foundation, the Stanford Center for Human and Artificial Intelligence, and the Stanford Data Science Institute.

References

[1] Leonard E. Baum and Ted Petrie. Statistical inference for probabilistic functions of finite state Markov chains. The Annals of Mathematical Statistics, 37(6):1554-1563, 1966.
[2] Rudolph Emil Kalman. A new approach to linear filtering and prediction problems. ASME Journal of Basic Engineering, 82(1):35-45, March 1960.
[3] Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models.
In International Conference on Machine Learning, pages 1278-1286. PMLR, 2014.
[4] Diederik Kingma and Max Welling. Auto-encoding variational Bayes. In 2nd International Conference on Learning Representations, 2014.
[5] Rahul G. Krishnan, Uri Shalit, and David Sontag. Deep Kalman filters. arXiv preprint arXiv:1511.05121, 2015.
[6] Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C. Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. Advances in Neural Information Processing Systems, 28, 2015.
[7] Marco Fraccaro, Søren Kaae Sønderby, Ulrich Paquet, and Ole Winther. Sequential neural models with stochastic layers. Advances in Neural Information Processing Systems, 29, 2016.
[8] Siddhartha Chib, Yasuhiro Omori, and Manabu Asai. Multivariate stochastic volatility. In Handbook of Financial Time Series, pages 365-400. Springer, 2009.
[9] Alan L. Hodgkin and Andrew F. Huxley. A quantitative description of membrane current and its application to conduction and excitation in nerve. The Journal of Physiology, 117(4):500, 1952.
[10] Arnaud Doucet and Adam M. Johansen. A tutorial on particle filtering and smoothing: Fifteen years later. In Dan Crisan and Boris Rozovsky, editors, The Oxford Handbook of Nonlinear Filtering, pages 656-704. Oxford University Press, 2011.
[11] David Blei, Alp Kucukelbir, and Jon D. McAuliffe. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859-877, 2017.
[12] Martin J. Wainwright and Michael I. Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2):1-305, 2008.
[13] Geoffrey E. Hinton, Peter Dayan, Brendan J. Frey, and Radford M. Neal. The wake-sleep algorithm for unsupervised neural networks. Science, 268(5214):1158-1161, 1995.
[14] Jörg Bornschein and Yoshua Bengio. Reweighted wake-sleep. arXiv preprint arXiv:1406.2751, 2014.
[15] Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, and Frank Wood. Auto-encoding sequential Monte Carlo. In 6th International Conference on Learning Representations, 2018.
[16] Christian A. Naesseth, Scott Linderman, Rajesh Ranganath, and David Blei. Variational sequential Monte Carlo. In International Conference on Artificial Intelligence and Statistics, pages 968-977. PMLR, 2018.
[17] Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess, Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, and Yee Whye Teh. Filtering variational objectives. Advances in Neural Information Processing Systems, 30, 2017.
[18] Dieterich Lawson, George Tucker, Christian A. Naesseth, Chris Maddison, Ryan P. Adams, and Yee Whye Teh. Twisted variational sequential Monte Carlo. In Third Workshop on Bayesian Deep Learning (NeurIPS), 2018.
[19] Dieterich Lawson, Allan Raventós, Andrew Warrington, and Scott Linderman. SIXO: Smoothing inference with twisted objectives. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022.
[20] Shixiang Shane Gu, Zoubin Ghahramani, and Richard E. Turner. Neural adaptive sequential Monte Carlo. Advances in Neural Information Processing Systems, 28, 2015.
[21] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
[22] Martín Abadi, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen, Craig Citro, Greg S.
Corrado, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Ian Goodfellow, Andrew Harp, Geoffrey Irving, Michael Isard, Yangqing Jia, Rafal Jozefowicz, Lukasz Kaiser, Manjunath Kudlur, Josh Levenberg, Dandelion Mané, Rajat Monga, Sherry Moore, Derek Murray, Chris Olah, Mike Schuster, Jonathon Shlens, Benoit Steiner, Ilya Sutskever, Kunal Talwar, Paul Tucker, Vincent Vanhoucke, Vijay Vasudevan, Fernanda Viégas, Oriol Vinyals, Pete Warden, Martin Wattenberg, Martin Wicke, Yuan Yu, and Xiaoqiang Zheng. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. URL https://www.tensorflow.org/. Software available from tensorflow.org.
[23] Art B. Owen. Monte Carlo Theory, Methods and Examples. Stanford, 2013.
[24] Christian A. Naesseth, Fredrik Lindsten, Thomas B. Schön, et al. Elements of sequential Monte Carlo. Foundations and Trends in Machine Learning, 12(3):307-392, 2019.
[25] Pierre Del Moral. Feynman-Kac Formulae: Genealogical and Interacting Particle Systems with Applications, volume 88. Springer, 2004.
[26] Nick Whiteley and Anthony Lee. Twisted particle filters. The Annals of Statistics, 42(1):115-141, 2014.
[27] Mark Briers, Arnaud Doucet, and Simon Maskell. Smoothing algorithms for state space models. Annals of the Institute of Statistical Mathematics, 62(1):61-89, 2010.
[28] Ming Lin, Rong Chen, and Jun S. Liu. Lookahead strategies for sequential Monte Carlo. Statistical Science, 28(1):69-94, 2013.
[29] Vasileios Lioutas, Jonathan Wilder Lavington, Justice Sefas, Matthew Niedoba, Yunpeng Liu, Berend Zwartsenberg, Setareh Dabiri, Frank Wood, and Adam Scibior. Critic sequential Monte Carlo. arXiv preprint arXiv:2205.15460, 2022.
[30] Masashi Sugiyama, Taiji Suzuki, and Takafumi Kanamori. Density Ratio Estimation in Machine Learning. Cambridge University Press, 2012.
[31] Rajesh Ranganath, Sean Gerrish, and David Blei. Black box variational inference. In International Conference on Artificial Intelligence and Statistics, pages 814-822.
PMLR, 2014.
[32] Matthew D. Hoffman, David Blei, Chong Wang, and John Paisley. Stochastic variational inference. Journal of Machine Learning Research, 2013.
[33] Andriy Mnih and Danilo Rezende. Variational inference for Monte Carlo objectives. In International Conference on Machine Learning, pages 2188-2196. PMLR, 2016.
[34] Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. In 4th International Conference on Learning Representations, 2016.
[35] Christian Naesseth, Fredrik Lindsten, and Thomas Schön. Nested sequential Monte Carlo methods. In International Conference on Machine Learning, pages 1292-1301. PMLR, 2015.
[36] Heiko Zimmermann, Hao Wu, Babak Esmaeili, and Jan-Willem van de Meent. Nested variational inference. Advances in Neural Information Processing Systems, 34:20423-20435, 2021.
[37] Dieterich Lawson, George Tucker, Bo Dai, and Rajesh Ranganath. Energy-inspired models: Learning with sampler-induced distributions. Advances in Neural Information Processing Systems, 32, 2019.
[38] Antonio Moretti, Zizhao Wang, Luhuan Wu, Iddo Drori, and Itsik Pe'er. Variational objectives for Markovian dynamics with backward simulation. In ECAI 2020, pages 1371-1378. IOS Press, 2020.
[39] Antonio Moretti, Zizhao Wang, Luhuan Wu, and Itsik Pe'er. Smoothing nonlinear variational objectives with sequential Monte Carlo. ICLR Workshop: Deep Generative Models for Highly Structured Data, 2019.
[40] Emily Fox, Erik Sudderth, Michael Jordan, and Alan Willsky. Nonparametric Bayesian learning of switching linear dynamical systems. In D. Koller, D. Schuurmans, Y. Bengio, and L. Bottou, editors, Advances in Neural Information Processing Systems, volume 21. Curran Associates, Inc., 2008.
[41] Scott W. Linderman, Matthew J. Johnson, Andrew C. Miller, Ryan P. Adams, David M. Blei, and Liam Paninski. Bayesian learning and inference in recurrent switching linear dynamical systems.
In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics (AISTATS), 2017.
[42] David Zoltowski, Jonathan Pillow, and Scott Linderman. A general recurrent state space framework for modeling neural dynamics during decision-making. In Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 11680-11691. PMLR, 13-18 Jul 2020.
[43] Peter Dayan and Laurence F. Abbott. Theoretical Neuroscience: Computational and Mathematical Modeling of Neural Systems. MIT Press, 2005.
[44] David E. Rumelhart, Geoffrey E. Hinton, and Ronald J. Williams. Learning internal representations by error propagation. Technical report, California Univ San Diego La Jolla Inst for Cognitive Science, 1985.
[45] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735-1780, 1997.
[46] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
[47] Zhengdao Chen, Baranidharan Raman, and Ari Stern. Structure-preserving numerical integrators for Hodgkin-Huxley-type systems. SIAM Journal on Scientific Computing, 42(1):B273-B298, 2020.
[48] Quanxin Wang, Song-Lin Ding, Yang Li, Josh Royall, David Feng, Phil Lesnar, Nile Graddis, Maitham Naeemi, Benjamin Facer, Anh Ho, et al. The Allen mouse brain common coordinate framework: a 3D reference atlas. Cell, 181(4):936-953, 2020.
[49] Michael I. Jordan. Serial order: A parallel distributed processing approach. In Advances in Psychology, volume 121, pages 471-495. Elsevier, 1997.
[50] Diederik Kingma, Jimmy Ba, Yoshua Bengio, and Yann LeCun. Adam: A method for stochastic optimization. In 3rd International Conference on Learning Representations, 2015.
[51] Tom Rainforth, Adam Kosiorek, Tuan Anh Le, Chris Maddison, Maximilian Igl, Frank Wood, and Yee Whye Teh. Tighter variational bounds are not necessarily better. In International Conference on Machine Learning, pages 4277-4285. PMLR, 2018.
[52] AIBS. Biophysical modeling - perisomatic. Technical report, Allen Institute for Brain Science, 10 2017. URL http://help.brain-map.org/display/celltypes/Documentation.
[53] Costa M. Colbert and Enhui Pan. Ion channel properties underlying axonal action potential initiation in pyramidal neurons. Nature Neuroscience, 5(6):533-538, 2002.
[54] Jacopo Magistretti and Angel Alonso. Biophysical properties and slow voltage-dependent inactivation of a sustained sodium current in entorhinal cortex layer-II principal neurons: a whole-cell and single-channel study. The Journal of General Physiology, 114(4):491-509, 1999.
[55] Maarten H. P. Kole, Stefan Hallermann, and Greg J. Stuart. Single Ih channels in pyramidal neuron dendrites: properties, distribution, and impact on action potential output. Journal of Neuroscience, 26(6):1677-1687, 2006.
[56] I. Reuveni, A. Friedman, Y. Amitai, and Michael J. Gutnick. Stepwise repolarization from Ca2+ plateaus in neocortical pyramidal cells: evidence for nonhomogeneous distribution of HVA Ca2+ channels in dendrites. Journal of Neuroscience, 13(11):4609-4621, 1993.
[57] Robert B. Avery and Daniel Johnston. Multiple channel types contribute to the low-voltage-activated calcium current in hippocampal CA3 pyramidal neurons. Journal of Neuroscience, 16(18):5567-5582, 1996.
[58] A. D. Randall and R. W. Tsien. Contrasting biophysical and pharmacological properties of T-type and R-type calcium channels. Neuropharmacology, 36(7):879-893, 1997.
[59] P. R. Adams, D. A. Brown, and A. Constanti. M-currents and other potassium currents in bullfrog sympathetic neurones. The Journal of Physiology, 330(1):537-572, 1982.
[60] Alon Korngreen and Bert Sakmann.
Voltage-gated K+ channels in layer 5 neocortical pyramidal neurones from young rats: subtypes and gradients. The Journal of Physiology, 525(3):621-639, 2000.
[61] M. Köhler, B. Hirschberg, C. T. Bond, John Mark Kinzie, N. V. Marrion, James Maylie, and J. P. Adelman. Small-conductance, calcium-activated potassium channels from mammalian brain. Science, 273(5282):1709-1714, 1996.
[62] Alastair J. Walker. New fast method for generating discrete random numbers with arbitrary frequency distributions. Electronics Letters, 8(10):127-128, 1974.
[63] Richard A. Kronmal and Arthur V. Peterson Jr. On the alias method for generating random variables from a discrete distribution. The American Statistician, 33(4):214-218, 1979.

7 Theoretical Analyses

Proposition 1 (Consistency of NAS-X's gradient estimates). Suppose the twists are optimal, i.e. $r_\psi(y_{t+1:T}, x_t) \propto p(y_{t+1:T} \mid x_t)$ up to a constant independent of $x_t$, for $t = 1, \ldots, T-1$. Let $\widehat{\nabla}_\theta \log p_\theta(y_{1:T})$ be NAS-X's weighted particle approximation to the true gradient of the log marginal likelihood, $\nabla_\theta \log p_\theta(y_{1:T})$. Then $\widehat{\nabla}_\theta \log p_\theta(y_{1:T}) \to \nabla_\theta \log p_\theta(y_{1:T})$ almost surely as the number of particles $N \to \infty$.

Proof. This is a direct application of Theorem 7.4.3 from Del Moral [25]. SMC methods provide strongly consistent estimates of expectations of test functions with respect to a normalized target distribution. That is, for a test function $h$ and SMC particle weights $w_t^k$, we have
$$\sum_{k=1}^{K} w_t^k \, h(x_{1:t}^k) \xrightarrow{a.s.} \int h(x_{1:t}) \, \pi_t(x_{1:t}) \, dx_{1:t} \quad \text{as } K \to \infty,$$
where $\pi_t(x_{1:t})$ is the normalized target distribution. NAS-X sets $\pi_t(x_{1:t}) \propto \gamma_t(x_{1:t}) = p_\theta(x_{1:t}, y_{1:t}) \, r_\psi(y_{t+1:T}, x_t)$, which by assumption is proportional to $p_\theta(x_{1:t}, y_{1:t}) \, p_\theta(y_{t+1:T} \mid x_t) = p_\theta(x_{1:t}, y_{1:T})$. The desired result follows immediately.

Proposition 2 (Unbiasedness of NAS-X's gradient estimates). Assume the proposal distribution is optimal, i.e. $q_\phi(x_t \mid x_{1:t-1}, y_{1:T}) = p(x_t \mid x_{1:t-1}, y_{1:T})$ for $t = 1, \ldots, T$, and the twists are optimal, i.e. $r_\psi(y_{t+1:T}, x_t) \propto p(y_{t+1:T} \mid x_t)$ up to a constant independent of $x_t$, for $t = 1, \ldots, T-1$. Let $\widehat{\nabla}_\theta \log p_\theta(y_{1:T})$ be NAS-X's weighted particle approximation to the true gradient of the log marginal likelihood, $\nabla_\theta \log p_\theta(y_{1:T})$. Then, for any number of particles, $\mathbb{E}[\widehat{\nabla}_\theta \log p_\theta(y_{1:T})] = \nabla_\theta \log p_\theta(y_{1:T})$.

Proof. We first provide a proof sketch and then give a more detailed derivation, adapting the proof of a similar theorem in Lawson et al. [19]. We will prove that particles produced by running SMC with the smoothing targets and the optimal proposal are exact samples from the posterior distribution of interest; the claim then follows immediately. Under the stated assumptions, both NAS-X and NASMC propose particles from the true posterior. However, the different intermediate target distributions affect how these particles are distributed after reweighting. For NAS-X, the particles have equal weight because they are reweighted using the smoothing targets, so after reweighting the particles are still samples from the true posterior. In contrast, NASMC's samples are reweighted by the filtering targets and are therefore no longer distributed as samples from the true posterior.

We first consider NAS-X, which uses the smoothing targets. We will show that the particles drawn at each timestep are samples from the posterior distribution. This follows from the fact that the particle weights at each timestep of the SMC sweep are equal; that is, $w_t^k = 1$ or $w_t^k = p(y_{1:T})$ for $k = 1, \ldots, K$, depending on whether resampling has occurred. We proceed by induction on $t$, the timestep of the SMC sweep. For $t = 1$, note that

1. $\gamma_1(x_1^k) = p(x_1^k, y_1) \, p(y_{2:T} \mid x_1^k)$, and
2. $q_1(x_1^k) = p(x_1^k \mid y_{1:T})$.

This implies that the incremental weight $\alpha_1^k$ is
$$\alpha_1^k = \frac{p(x_1^k, y_1) \, p(y_{2:T} \mid x_1^k)}{p(x_1^k \mid y_{1:T})} = \frac{p(x_1^k, y_{1:T})}{p(x_1^k \mid y_{1:T})} = p(y_{1:T}), \quad (13)$$
which does not depend on $k$. Because $w_0^k \equiv 1$, we have $w_1^k = w_0^k \alpha_1^k = p(y_{1:T})$ for all $k$.
Note that, since the proposal is optimal, prior to SMC's reweighting step the particles were distributed as $x_1^k \sim p(x_1 \mid y_{1:T})$. Since the incremental weights are equal, the distribution of the particles is unchanged by reweighting.

For the induction step, assume that $w_{t-1}^{1:K}$ equals 1 or $p(y_{1:T})$ and that $x_{1:t-1}^k \sim p(x_{1:t-1} \mid y_{1:T})$. The new particles are distributed as $x_t^k \sim p(x_t \mid x_{1:t-1}, y_{1:T})$, which implies that $(x_{1:t-1}^k, x_t^k) \sim p(x_t \mid x_{1:t-1}, y_{1:T}) \, p(x_{1:t-1} \mid y_{1:T}) = p(x_{1:t} \mid y_{1:T})$. We now show that the incremental particle weights $\alpha_t^k$ are equal, using the following identities and assumptions:

1. $\gamma_t(x_{1:t}^k) = p(x_{1:t}^k, y_{1:t}) \, p(y_{t+1:T} \mid x_t^k)$,
2. $\gamma_{t-1}(x_{1:t-1}^k) = p(x_{1:t-1}^k, y_{1:t-1}) \, p(y_{t:T} \mid x_{t-1}^k)$,
3. $r_\psi(y_{t+1:T}, x_t) \propto p(y_{t+1:T} \mid x_t)$ up to a constant independent of $x_t$ for $t = 1, \ldots, T-1$, and
4. $q_t(x_t^k) = p(x_t^k \mid x_{1:t-1}^k, y_{1:T})$.

These show that $\alpha_t^k$ is given by
$$\alpha_t^k = \frac{p(x_{1:t}^k, y_{1:t}) \, p(y_{t+1:T} \mid x_t^k)}{p(x_{1:t-1}^k, y_{1:t-1}) \, p(y_{t:T} \mid x_{t-1}^k) \, p(x_t^k \mid x_{1:t-1}^k, y_{1:T})} \quad (14)$$
$$= \frac{p(x_{1:t}^k, y_{1:T})}{p(x_{1:t-1}^k, y_{1:T}) \, p(x_t^k \mid x_{1:t-1}^k, y_{1:T})} \quad (15)$$
$$= \frac{p(x_{1:t-1}^k, y_{1:T}) \, p(x_t^k \mid x_{1:t-1}^k, y_{1:T})}{p(x_{1:t-1}^k, y_{1:T}) \, p(x_t^k \mid x_{1:t-1}^k, y_{1:T})} = 1 \quad (16)$$
for $k = 1, \ldots, K$. There are two cases, depending on the value of the weights at the previous timestep: if $w_{t-1}^{1:K} = 1$, then $w_t^k = w_{t-1}^k \alpha_t^k = 1$ for all $k$; if $w_{t-1}^{1:K} = p(y_{1:T})$, then $w_t^k = p(y_{1:T})$ for all $k$. Therefore, even after reweighting, the particles are still drawn from the true posterior distribution. If resampling occurs, the distribution of the particles remains unchanged because the weights are equal in both cases.

To conclude, we note that the incremental particle weights are, in general, not equal for NASMC. To see this, consider NASMC's incremental weight at timestep 1:
$$\alpha_1^k = \frac{p(x_1^k, y_1)}{p(x_1^k \mid y_{1:T})}. \quad (18)$$
After reweighting, the particles will be distributed according to the filtering distributions; the distribution of particles after reweighting is proportional to $p(x_1^k \mid y_1)$.
This is because the distribution of the reweighted particles is proportional to the incremental weights times the optimal proposal density, so the proposal term in the denominator cancels with the proposal density. Interestingly, the particle weights will be equal at each iteration under a certain dependency structure for the model, $p(x_{1:t-1} \mid y_{1:t}) = p(x_{1:t-1} \mid y_{1:t-1})$, identified in Maddison et al. [17]. However, this condition is not satisfied in general, and therefore NASMC's gradients are not unbiased.

8 Derivations

8.1 Gradient of the Marginal Likelihood

We derive the gradient of the log marginal likelihood; this identity is known as Fisher's identity.
$$\nabla_\theta \log p_\theta(y_{1:T}) = \nabla_\theta \log \int p_\theta(x_{1:T}, y_{1:T}) \, dx_{1:T}$$
$$= \frac{1}{p_\theta(y_{1:T})} \nabla_\theta \int p_\theta(x_{1:T}, y_{1:T}) \, dx_{1:T}$$
$$= \frac{1}{p_\theta(y_{1:T})} \int \nabla_\theta \, p_\theta(x_{1:T}, y_{1:T}) \, dx_{1:T}$$
$$= \frac{1}{p_\theta(y_{1:T})} \int p_\theta(x_{1:T}, y_{1:T}) \, \nabla_\theta \log p_\theta(x_{1:T}, y_{1:T}) \, dx_{1:T}$$
$$= \int p_\theta(x_{1:T} \mid y_{1:T}) \, \nabla_\theta \log p_\theta(x_{1:T}, y_{1:T}) \, dx_{1:T}$$
$$= \int p_\theta(x_{1:T} \mid y_{1:T}) \, \nabla_\theta \sum_t \log p_\theta(y_t, x_t \mid x_{t-1}) \, dx_{1:T}$$
$$= \sum_t \int p_\theta(x_{1:T} \mid y_{1:T}) \, \nabla_\theta \log p_\theta(y_t, x_t \mid x_{t-1}) \, dx_{1:T}$$
$$= \sum_t \mathbb{E}_{p_\theta(x_{1:T} \mid y_{1:T})} \left[ \nabla_\theta \log p_\theta(y_t, x_t \mid x_{t-1}) \right]$$
The key steps were the log-derivative trick and Bayes' rule.

8.2 Gradient of the Inclusive KL Divergence

Below, we derive the gradient of the inclusive KL divergence for a generic Markovian model. In this derivation, we assume there are no shared parameters between the proposal and the model, so the $\log p_\theta$ term of the divergence does not depend on $\phi$ and drops in the first equality:
$$\nabla_\phi \mathrm{KL}(p_\theta \,\|\, q_\phi) = -\nabla_\phi \int p_\theta(x_{1:T} \mid y_{1:T}) \log q_\phi(x_{1:T} \mid y_{1:T}) \, dx_{1:T}$$
$$= -\int p_\theta(x_{1:T} \mid y_{1:T}) \, \nabla_\phi \log q_\phi(x_{1:T} \mid y_{1:T}) \, dx_{1:T}$$
$$= -\int p_\theta(x_{1:T} \mid y_{1:T}) \, \nabla_\phi \sum_t \log q_\phi(x_t \mid x_{t-1}, y_{t:T}) \, dx_{1:T}$$
$$= -\sum_t \mathbb{E}_{p_\theta(x_{1:T} \mid y_{1:T})} \left[ \nabla_\phi \log q_\phi(x_t \mid x_{t-1}, y_{t:T}) \right]$$

8.3 Density Ratio Estimation via Classification

Here we briefly summarize density ratio estimation (DRE) via classification. For a full treatment, see Sugiyama et al. [30].
Let $a(x)$ and $b(x)$ be two distributions defined over the same space $\mathcal{X}$, and consider a classifier $g : \mathcal{X} \to \mathbb{R}$ trained to predict whether a given $x \in \mathcal{X}$ was sampled from $a(x)$ or $b(x)$. The raw outputs (logits) of this classifier will approximate $\log(a(x)/b(x))$ up to a constant that does not depend on $x$.

To see this, define an expanded generative model in which we first sample $z \in \{0, 1\}$ from a Bernoulli random variable with probability $0 < \rho < 1$, and then sample $x$ from $a(x)$ if $z = 1$ and from $b(x)$ if $z = 0$. This defines the joint distribution
$$p(x, z) = p(z) \, p(x \mid z) = \mathrm{Bernoulli}(z; \rho) \, a(x)^z \, b(x)^{1-z}, \quad (19)$$
where $p(x \mid z = 1) = a(x)$ and $p(x \mid z = 0) = b(x)$. Let $g : \mathcal{X} \to \mathbb{R}$ be a function that maps $x \in \mathcal{X}$ to the logit of a Bernoulli distribution over $z$. This function parameterizes a classifier via the sigmoid function, meaning that the classifier's conditional distribution over $z$ is
$$p_g(z \mid x) = \sigma(g(x))^z \, (1 - \sigma(g(x)))^{1-z}, \quad z \in \{0, 1\}, \quad (20)$$
where $\sigma$ is the sigmoid function and $\sigma^{-1}$ is its inverse, the logit function:
$$\sigma(\ell) = \frac{1}{1 + e^{-\ell}}, \qquad \sigma^{-1}(p) = \log \frac{p}{1 - p}. \quad (21)$$
The optimal function $g^*$ is selected by solving the maximum likelihood problem
$$g^* \in \arg\max_g \; \mathbb{E}_{p(x,z)} \left[ \log p_g(z \mid x) \right]. \quad (22)$$
The solution to this problem is the true $p(z \mid x)$; because we have not restricted $g$, this solution is attainable. Thus,
$$p(z = 1 \mid x) = p_{g^*}(z = 1 \mid x) = \sigma(g^*(x)). \quad (23\text{-}25)$$
This in turn implies that
$$g^*(x) = \sigma^{-1}(p(z = 1 \mid x)) = \log \frac{p(z = 1 \mid x)}{p(z = 0 \mid x)} = \log \frac{p(z = 1, x)}{p(z = 0, x)} = \log \frac{p(x \mid z = 1) \, p(z = 1)}{p(x \mid z = 0) \, p(z = 0)} = \log \frac{a(x)}{b(x)} + \log \frac{\rho}{1 - \rho}. \quad (26\text{-}33)$$
Thus, the optimal solution to the classification problem, $g^*$, equals $\log(a(x)/b(x))$ up to a constant that does not depend on $x$.
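The identity above can be checked numerically on a small discrete space, where the optimal classifier's posterior $p(z = 1 \mid x)$ is available in closed form (a minimal sketch; the distributions $a$, $b$ and the mixing weight $\rho$ below are arbitrary choices for illustration):

```python
import numpy as np

# Two discrete distributions over X = {0, 1, 2} and a mixing weight rho.
a = np.array([0.2, 0.5, 0.3])
b = np.array([0.6, 0.1, 0.3])
rho = 0.7

# Exact posterior of the optimal classifier:
# p(z=1|x) = rho*a(x) / (rho*a(x) + (1-rho)*b(x)).
p_z1_given_x = rho * a / (rho * a + (1 - rho) * b)

# The optimal logit g*(x) = sigma^{-1}(p(z=1|x)) ...
g_star = np.log(p_z1_given_x / (1 - p_z1_given_x))

# ... should equal log(a(x)/b(x)) plus the constant log(rho/(1-rho)).
predicted = np.log(a / b) + np.log(rho / (1 - rho))

print(np.allclose(g_star, predicted))  # True
```

The constant $\log(\rho/(1-\rho))$ vanishes when $\rho = 1/2$, i.e. when positive and negative examples are drawn in equal proportion, as in the twist-training setup.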
In practice we observe empirically that, as long as a sufficiently flexible parametric family for $g$ is selected, $g$ will closely approximate the desired density ratio. In the case of learning the ratio required for smoothing SMC,
$$\frac{p_\theta(x_t \mid y_{t+1:T})}{p_\theta(x_t)}, \quad (34)$$
Lawson et al. [19] instead learn the equivalent ratio
$$\frac{p_\theta(x_t, y_{t+1:T})}{p_\theta(x_t) \, p_\theta(y_{t+1:T})}. \quad (35)$$

[Figure 5: Comparison of NAS-X and baseline methods (NASMC, ELBO, IWAE, RWS, SIXO, FIVO) on inference in the LG-SSM. (left) Log marginal likelihood bounds over proposal training steps, (middle) proposal parameter relative error (lower is better), and (right) learned proposal means across 50 model timesteps. NAS-X outperforms the baseline methods and recovers the true posterior marginals.]

As per the previous derivation, it suffices to train a classifier to distinguish between samples from the numerator and denominator of Eq. 35. To accomplish this, Lawson et al. [19] draw paired and unpaired samples from the model that are distributed marginally according to the desired densities. Specifically, consider drawing
$$x_{1:T}, y_{1:T} \sim p_\theta(x_{1:T}, y_{1:T}), \qquad \tilde{x}_{1:T} \sim p_\theta(x_{1:T}), \quad (36)$$
and note that any $(x_t, y_{t+1:T})$ selected from the paired sample will be distributed marginally according to $p_\theta(x_t, y_{t+1:T})$. Similarly, any $(\tilde{x}_t, y_{t+1:T})$ will be distributed marginally as $p_\theta(\tilde{x}_t) \, p_\theta(y_{t+1:T})$. In this way, $T - 1$ positive and $T - 1$ negative training examples for the DRE classifier can be drawn from a single set of samples as in Eq. (36). The twist training process is summarized in Algorithm 2.
Algorithm 2: Twist Training

    procedure twist-training(θ, ψ0):
        ψ ← ψ0
        while not converged:
            x_{1:T}, y_{1:T} ~ p_θ(x_{1:T}, y_{1:T})
            x̃_{1:T} ~ p_θ(x_{1:T})
            L(ψ) = −(1/(T−1)) Σ_{t=1}^{T−1} [ log σ(r_ψ(x_t, y_{t+1:T})) + log(1 − σ(r_ψ(x̃_t, y_{t+1:T}))) ]
            ψ ← grad-step(ψ, ∇_ψ L(ψ))
        return ψ

Model Details. We consider a one-dimensional linear Gaussian state space model with joint distribution
$$p(x_{1:T}, y_{1:T}) = \mathcal{N}(x_1; 0, \sigma_x^2) \prod_{t=2}^{T} \mathcal{N}(x_t; x_{t-1}, \sigma_x^2) \prod_{t=1}^{T} \mathcal{N}(y_t; x_t, \sigma_y^2). \quad (37)$$
In our experiments we set the dynamics variance $\sigma_x^2 = 1.0$ and the observation variance $\sigma_y^2 = 1.0$.

Proposal Parameterization. For both NAS-X and NASMC, we use a mean-field Gaussian proposal factored over time,
$$q(x_{1:T}) = \prod_{t=1}^{T} q_t(x_t) = \prod_{t=1}^{T} \mathcal{N}(x_t; \mu_t, \sigma_t^2), \quad (38)$$
with parameters $\mu_{1:T}$ and $\sigma_{1:T}^2$ corresponding to the means and variances at each timestep. In total, we learn $2T$ proposal parameters.

Twist Parameterization. We parameterize the twist as a quadratic function of $x_t$ whose coefficients are functions of the observations and the timestep and are learned via the density ratio estimation procedure described in [19]. We chose this form to match the analytic log density ratio for the model defined in Eq. 10. Given that $p(x_{1:T}, y_{1:T})$ is multivariate Gaussian, we know that $p(x_t \mid y_{t+1:T})$ and $p(x_t)$ are both marginally Gaussian. Let
$$p(x_t \mid y_{t+1:T}) = \mathcal{N}(\mu_1, \sigma_1^2), \qquad p(x_t) = \mathcal{N}(0, \sigma_2^2).$$
Then, writing $Z(\sigma) = 1/(\sigma\sqrt{2\pi})$ for the Gaussian normalizing constant,
$$\log \frac{p(x_t \mid y_{t+1:T})}{p(x_t)} = \log \mathcal{N}(x_t; \mu_1, \sigma_1^2) - \log \mathcal{N}(x_t; 0, \sigma_2^2) = \left( \frac{1}{2\sigma_2^2} - \frac{1}{2\sigma_1^2} \right) x_t^2 + \frac{\mu_1}{\sigma_1^2} x_t + c,$$
a quadratic in $x_t$ with constant term
$$c = -\frac{\mu_1^2}{2\sigma_1^2} + \log Z(\sigma_1) - \log Z(\sigma_2) = -\frac{\mu_1^2}{2\sigma_1^2} - \log(\sigma_1\sqrt{2\pi}) + \log(\sigma_2\sqrt{2\pi}).$$
We explicitly model $\log \sigma_1^2$, $\log \sigma_2^2$, and $\mu_1$. Both $\log \sigma_1^2$ and $\log \sigma_2^2$ are functions only of $t$, not of $y_{t+1:T}$, so they can be parameterized as vectors of length $T$ initialized at 0. $\mu_1$ is a linear function of $y_{t+1:T}$ and $t$, so it can be parameterized by a $T \times T$ matrix of weights, initialized to $1/T$, and $T$ biases, initialized to 0.

Training Details. We use a batch size of 32 for the density ratio estimation step.
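The loss in Algorithm 2 can be sketched for the linear Gaussian model above (a minimal sketch: `r` stands in for the twist r_ψ, and the zero-twist check uses the fact that σ(0) = 0.5, so every term contributes log 2 regardless of the samples):

```python
import numpy as np

def sample_lgssm(T, sigma_x=1.0, sigma_y=1.0, rng=None):
    """Sample x_{1:T}, y_{1:T} from the LG-SSM in Eq. 37."""
    rng = rng or np.random.default_rng(0)
    x = np.cumsum(rng.normal(0.0, sigma_x, size=T))  # random-walk latents
    y = x + rng.normal(0.0, sigma_y, size=T)         # noisy observations
    return x, y

def twist_dre_loss(r, x, x_tilde, y):
    """Binary cross-entropy DRE loss from Algorithm 2.

    r(t, x_t, y_future) -> scalar logit. Positive examples pair x_t with
    y_{t+1:T}; negative examples use the independently drawn x_tilde_t.
    """
    T = len(x)
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    loss = 0.0
    for t in range(T - 1):
        y_future = y[t + 1:]
        loss -= np.log(sigmoid(r(t, x[t], y_future)))           # positive
        loss -= np.log(1.0 - sigmoid(r(t, x_tilde[t], y_future)))  # negative
    return loss / (T - 1)

rng = np.random.default_rng(0)
x, y = sample_lgssm(T=10, rng=rng)
x_tilde, _ = sample_lgssm(T=10, rng=rng)  # unpaired latents

# With a zero twist the loss is exactly 2 log 2, independent of the samples.
zero_twist = lambda t, xt, y_future: 0.0
print(twist_dre_loss(zero_twist, x, x_tilde, y))  # ~1.3863
```

In the actual method, `r` would be the quadratic twist described above and `grad-step` an Adam update on its coefficients; a trained twist drives this loss below 2 log 2.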
Since we do not perform model learning, we do not repeatedly alternate between twist training and proposal training for NAS-X. Instead, we first train the twist for 3,000,000 iterations with a batch size of 32 using samples from the model. We then train the proposal for 750,000 iterations. For the twist, we used Adam with a learning rate schedule that starts at a constant learning rate of 1e-3 and decays the learning rate by factors of 0.3 and 0.33 at 100,000 and 300,000 iterations, respectively. For the proposal, we used Adam with a constant learning rate of 1e-3. For NASMC, we only train the proposal.

Evaluation. In the right panel of Figure 1, we compare the bound gaps of NAS-X and NASMC averaged across 20 different samples from the generative model. To obtain the bound gap for NAS-X, we run SMC 16 times with 128 particles using the learned proposal and twists, and record the average log marginal likelihood. For NASMC, we run SMC with the current learned proposal (without any twists).

Model Details. The generative model is as follows. At each time $t$, there is a discrete latent state $z_t \in \{1, \ldots, 4\}$ as well as a two-dimensional continuous latent state $x_t \in \mathbb{R}^2$.

[Figure 6: Inference in NASCAR experiments. Absolute error in the inferred $z_t$ (left) and $x_t$ (right) for NAS-X and NASMC with $\sigma_O^2 = 0.1$, plotted over proposal training steps.]

The discrete state transition probabilities are given by
$$p(z_{t+1} = i \mid z_t = j, x_t) \propto \exp\{ r_i + R_i^\top x_t \}, \quad (39)$$
where $R_i$ and $r_i$ are weights for discrete state $i$. The discrete latent states dictate the dynamics of the two-dimensional continuous latent state $x_t \in \mathbb{R}^2$, which evolves according to linear Gaussian dynamics:
$$x_{t+1} = A_{z_{t+1}} x_t + b_{z_{t+1}} + v_t, \qquad v_t \overset{iid}{\sim} \mathcal{N}(0, Q_{z_{t+1}}), \quad (40)$$
where $A_k, Q_k \in \mathbb{R}^{2 \times 2}$ and $b_k \in \mathbb{R}^2$. Importantly, from Equations 39 and 40 we see that the dynamics of the continuous latent states and the discrete latent states are coupled.
The discrete latent states index into specific linear dynamics, and the discrete transition probabilities depend on the continuous latent state. The observations $y_t \in \mathbb{R}^{10}$ are linear projections of the continuous latent state $x_t$ with additive Gaussian noise:
$$y_t = C x_t + d + w_t, \qquad w_t \overset{iid}{\sim} \mathcal{N}(0, S), \quad (41)$$
where $C \in \mathbb{R}^{10 \times 2}$, $S \in \mathbb{R}^{10 \times 10}$, and $d \in \mathbb{R}^{10}$.

Proposal Parameterization. We use a mean-field proposal distribution factorized over the discrete and continuous latent variables (i.e. $q(z_{1:T}, x_{1:T}) = q(z_{1:T}) \, q(x_{1:T})$). For the continuous states, $q(x_{1:T})$ is a Gaussian factorized over time with parameters $\mu_{1:T}$ and $\sigma_{1:T}^2$. For the discrete states, $q(z_{1:T})$ is a categorical distribution over $K$ categories factorized over time with parameters $p_{1:T}^{1:K}$. In total, we learn $2T + TK$ proposal parameters.

Twist Parameterization. We parameterize the twists using a recurrent neural network (RNN) trained via density ratio estimation. To produce the twist values at each timestep, we first run an RNN backwards over the observations $y_{1:T}$ to produce a sequence of encodings $e_{1:T-1}$. We then concatenate each encoding with $x_t$ and $z_t$ into a single vector and pass that vector into an MLP, which outputs the twist value at each timestep. The RNN has one layer with 128 hidden units. The MLP has 131 hidden units and ReLU activations.

Model Parameter Initialization. We closely follow the parameter initialization strategy employed by Linderman et al. [41]. First, we use PCA to obtain a set of continuous latent states and to initialize $C$ and $d$. We then fit an autoregressive HMM to the estimated continuous latent states in order to initialize the dynamics parameters $\{A_k, b_k\}$. Importantly, we do not initialize the proposal with the continuous latent states described above.

Training Details. We use a batch size of 32 for the density ratio estimation step. We alternate between 100 steps of twist training and 100 steps of proposal training, for 50,000 training steps in total.
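Sampling from the generative model in Eqs. 39-41 can be sketched as follows (a minimal sketch with arbitrary, illustrative parameter values, not the fitted parameters; following Eq. 39 as written, the transition logits depend on $x_t$ but not directly on $z_t$):

```python
import numpy as np

def sample_rslds(T, K=4, d_x=2, d_y=10, rng=None):
    """Sample z_{1:T}, x_{1:T}, y_{1:T} from the recurrent SLDS (Eqs. 39-41)."""
    rng = rng or np.random.default_rng(0)
    # Arbitrary illustrative parameters.
    r = rng.normal(size=K)                     # discrete-state biases
    R = rng.normal(size=(K, d_x))              # recurrent weights
    A = np.stack([0.9 * np.eye(d_x)] * K)      # stable linear dynamics
    b = rng.normal(scale=0.1, size=(K, d_x))
    Q = np.stack([0.01 * np.eye(d_x)] * K)
    C = rng.normal(size=(d_y, d_x))            # emission projection
    d = rng.normal(size=d_y)
    S = 0.1 * np.eye(d_y)

    z = np.zeros(T, dtype=int)
    x = np.zeros((T, d_x))
    y = np.zeros((T, d_y))
    z[0] = rng.integers(K)
    x[0] = rng.normal(size=d_x)
    y[0] = C @ x[0] + d + rng.multivariate_normal(np.zeros(d_y), S)
    for t in range(T - 1):
        # Eq. 39: p(z_{t+1} = i | x_t) proportional to exp(r_i + R_i^T x_t).
        logits = r + R @ x[t]
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        z[t + 1] = rng.choice(K, p=probs)
        # Eq. 40: linear Gaussian dynamics selected by z_{t+1}.
        k = z[t + 1]
        x[t + 1] = A[k] @ x[t] + b[k] + rng.multivariate_normal(np.zeros(d_x), Q[k])
        # Eq. 41: linear Gaussian emissions.
        y[t + 1] = C @ x[t + 1] + d + rng.multivariate_normal(np.zeros(d_y), S)
    return z, x, y

z, x, y = sample_rslds(T=100)
print(z.shape, x.shape, y.shape)  # (100,) (100, 2) (100, 10)
```

Subtracting the maximum logit before exponentiating is the standard numerically stable softmax; the coupling between $z$ and $x$ is visible in the loop, where $x_t$ enters the transition logits and $z_{t+1}$ selects the dynamics.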
We used Adam and performed a grid search over the model, proposal, and twist learning rates, considering values of 1e-4, 1e-3, and 1e-2 for each.

Bootstrap Bound Evaluation  To obtain the log marginal likelihood bounds and standard deviations in Table 5.2, we ran a bootstrap particle filter (BPF) with the learned model parameters for all three methods (NAS-X, NASMC, Laplace EM) using 1024 particles, repeated across 30 random seeds. Initialization of the latent states was important for a fair comparison. For NAS-X and NASMC, we initialized the latent states by sampling from the learned proposal at time t = 0. For Laplace EM, we initialized by sampling from a Gaussian distribution with the learned dynamics variance at t = 0.

Inference Comparison  In the top panel of Figure 6, we compare NAS-X and NASMC on inference in the SLDS model. We report average posterior parameter recovery for the continuous and discrete latent states across 5 random samples from the generative model. NAS-X systematically recovers better estimates of both the discrete and continuous latent states.

11 Inference in Squid Giant Axon Model

11.1 HH Model Definition

For the inference experiments (Section 5.3.1) we used a probabilistic version of the squid giant axon model [9, 43].
Our experimental setup was constructed to broadly match [19], and used a single-compartment model with dynamics defined by

C_m dv/dt = I_ext − g_Na m³h(v − E_Na) − g_K n⁴(v − E_K) − g_leak(v − E_leak)   (42)
dm/dt = α_m(v)(1 − m) − β_m(v)m   (43)
dh/dt = α_h(v)(1 − h) − β_h(v)h   (44)
dn/dt = α_n(v)(1 − n) − β_n(v)n   (45)

where C_m is the membrane capacitance; v is the potential difference across the membrane; I_ext is the external current; g_Na, g_K, and g_leak are the maximum conductances for the sodium, potassium, and leak channels; E_Na, E_K, and E_leak are the reversal potentials for the sodium, potassium, and leak channels; m and h are subunit states for the sodium channels and n is the subunit state for the potassium channels. The functions α and β that define the dynamics of m, h, and n are

α_m(v) = (4 + v/10) / (1 − exp(−(4 + v/10))),  β_m(v) = 4 exp(−(v + 65)/18)   (46)
α_h(v) = 0.07 exp(−(v + 65)/20),  β_h(v) = 1 / (exp(−(3.5 + v/10)) + 1)   (47)
α_n(v) = 0.1 (5.5 + v/10) / (1 − exp(−(5.5 + v/10))),  β_n(v) = 0.125 exp(−(v + 65)/80)   (48)

This system of ordinary differential equations defines a nonlinear dynamical system with a four-dimensional state space: the instantaneous membrane potential v and the ion gate subunit states n, m, and h. As in [19], we use a probabilistic version of the original HH model that adds zero-mean Gaussian noise to both the membrane voltage v and the unconstrained subunit states. The observations are produced by adding Gaussian noise with variance σ²_y to the membrane potential v. Specifically, let x_t be the state vector of the system at time t containing (v_t, m_t, h_t, n_t), and let φ_dt(x) be a function that integrates the system of ODEs defined above for a step of length dt. Then the probabilistic HH model can be written as

p(x_{1:T}, y_{1:T}) = p(x_1) ∏_{t=2}^T p(x_t | φ_dt(x_{t−1})) ∏_{t=1}^T N(y_t; x_{t,1}, σ²_y)   (49)

where the four-dimensional state distributions p(x_1) and p(x_t | φ_dt(x_{t−1})) are defined as

p(x_t | φ_dt(x_{t−1})) = N(x_{t,1}; φ_dt(x_{t−1})_1, σ²_{x,1}) ∏_{i=2}^4 LogitNormal(x_{t,i}; φ_dt(x_{t−1})_i, σ²_{x,i})   (50)

Figure 7: HH inference performance across different numbers of particles (50 ms trace, 1 obs/ms, σ²_y = 20; methods: FIVO-BS, FIVO, NASMC, SIXO, NAS-X). (left) Log-likelihood lower bounds for proposals trained with 4 particles and evaluated across a range of particle numbers. NAS-X's inference performance decays only minimally as the number of particles is decreased, while all other methods experience significant performance degradation. (right) A comparison of SIXO and NAS-X containing the same values as the left panel, but zoomed in. NAS-X is roughly twice as particle-efficient as SIXO, and outperforms SIXO by roughly 34 nats at 4 particles.

In words, we add Gaussian noise to the voltage (x_{t,1}) and logit-normal noise to the gate states n, m, and h. The logit-normal is defined as the distribution of a random variable whose logit has a Gaussian distribution, or equivalently it is a Gaussian transformed by the sigmoid function and renormalized. We chose the logit-normal because its values are bounded between 0 and 1, which is necessary for the gate states.

Problem Setting  For the inference experiments we sampled 10,000 noisy voltage traces from a fixed model and used each method to train proposals (and possibly twists) to compute the marginal likelihood assigned to the data under the true model. As in [19], we sampled trajectories of length 50 milliseconds, with a single noisy voltage observation every millisecond. The stability of our ODE integrator allowed us to integrate at dt = 0.1 ms, meaning that there were 10 latent states per observation.
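One transition of the probabilistic model (Equations 49–50) can be sketched as follows. A simple Euler step stands in for the integrator φ_dt, the rate functions are those of Equations 46–48, and the conductance and noise values are classical HH constants chosen for illustration:

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def logit(p):
    return np.log(p) - np.log1p(-p)

def hh_deriv(x, I_ext, gNa=120.0, gK=36.0, gL=0.3,
             ENa=50.0, EK=-77.0, EL=-54.4, Cm=1.0):
    """Deterministic HH vector field for x = (v, m, h, n) (Eqs. 42-48).
    The alpha rates have removable singularities (e.g. v = -40); a stable
    exprel-style implementation handles these in practice."""
    v, m, h, n = x
    am = (4.0 + v / 10.0) / (1.0 - np.exp(-(4.0 + v / 10.0)))
    bm = 4.0 * np.exp(-(v + 65.0) / 18.0)
    ah = 0.07 * np.exp(-(v + 65.0) / 20.0)
    bh = 1.0 / (1.0 + np.exp(-(3.5 + v / 10.0)))
    an = 0.1 * (5.5 + v / 10.0) / (1.0 - np.exp(-(5.5 + v / 10.0)))
    bn = 0.125 * np.exp(-(v + 65.0) / 80.0)
    dv = (I_ext - gNa * m**3 * h * (v - ENa)
          - gK * n**4 * (v - EK) - gL * (v - EL)) / Cm
    return np.array([dv,
                     am * (1.0 - m) - bm * m,
                     ah * (1.0 - h) - bh * h,
                     an * (1.0 - n) - bn * n])

def hh_step(x, I_ext, dt, sig_v, sig_gate, rng):
    """One stochastic transition (Eqs. 49-50): Euler step for phi_dt, then
    Gaussian noise on the voltage and logit-normal noise on the gates."""
    mean = x + dt * hh_deriv(x, I_ext)
    v = mean[0] + sig_v * rng.standard_normal()
    gates = sigmoid(logit(mean[1:]) + sig_gate * rng.standard_normal(3))
    return np.concatenate([[v], gates])

rng = np.random.default_rng(1)
x0 = np.array([-65.0, 0.05, 0.6, 0.32])   # (v, m, h, n) near rest
x1 = hh_step(x0, I_ext=10.0, dt=0.1, sig_v=0.5, sig_gate=0.05, rng=rng)
```

Because the gate noise is applied in logit space and pushed back through the sigmoid, the sampled gate states remain in (0, 1) by construction.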
Proposal and Twist Details  Each proposal was parameterized by the combination of a bidirectional recurrent neural network (RNN) that conditioned on all observed noisy voltages and a dense network that conditioned on the RNN hidden state and the previous latent state x_{t−1} [45, 49]. The twists for SIXO and NAS-X were parameterized by an RNN run in reverse over the observations, combined with a dense network that conditioned on the reverse RNN hidden state and the latent being twisted, x_t. Both the proposals and twists were learned in an amortized manner, i.e. they were shared across all trajectories. All RNNs had a single hidden layer of size 64, as did the dense networks. All models were fit with Adam [50] with a proposal learning rate of 1e-4 and a twist learning rate of 1e-3.

A crucial aspect of fitting the proposals was defining them in terms of a residual from the prior, a technique known as ResQ [7]. In our setting, we defined the proposal density as proportional to the product of a unit-variance Gaussian centered at φ(x_{t−1}) and a Gaussian with parameters output by the RNN proposal.

11.2 Experimental Results

In Figure 7 we plot the performance of proposals and twists trained with 4 particles and evaluated across a range of particle numbers. All methods except FIVO perform roughly the same when evaluated with 256 particles, but at lower numbers of evaluation particles the smoothing methods emerge as more particle-efficient than the filtering methods. To achieve NAS-X's inference performance with 4 particles, NASMC would need 256 particles, a 64-fold increase. NAS-X is also more particle-efficient than SIXO, achieving on average a 2x particle-efficiency improvement. We show the effect of changing the number of training particles in Figure 8.

Figure 8: Training HH proposals with increasing numbers of particles (4, 8, and 16; methods: IWAE, FIVO, NASMC, SIXO, NAS-X). HH proposal performance plots similar to Figure 7, but trained with varying numbers of particles. Increasing the number of particles at training time has a negligible effect on NAS-X's performance in this setting, but caused many VI-based methods to perform worse. This could be due to signal-to-noise issues in proposal gradients, as discussed in Rainforth et al. [51].

The FIVO method with a parametric proposal drastically underperformed all smoothing methods as well as NASMC, indicating that the combination of filtering SMC and the exclusive KL divergence leads to problems optimizing the proposal parameters. To compensate, we also evaluated the performance of "FIVO-BS", a filtering method that uses a bootstrap proposal. This method is identical to a bootstrap particle filter, i.e. it proposes from the model and has no trainable parameters. FIVO-BS far outperforms standard FIVO, and is only marginally worse than NASMC in this setting.

Figure 9: Inferred voltage traces for NASMC, SIXO, and NAS-X. (top) NASMC exhibits poor performance, incorrectly inferring the timing of most spikes. (middle) SIXO's inferred voltage traces are more accurate than NASMC's, with only a single mistimed spike, but SIXO generates a high number of resampling events, leading to particle degeneracy. (bottom) NAS-X perfectly infers the latent voltage with no mistimed spikes, and resamples very infrequently.

In Figure 9 we investigate these results qualitatively by examining the inferred voltage traces of each method.
We see that NASMC struggles to produce accurate spike timings and generates many spurious spikes, likely because it is unable to incorporate future information into its proposal or resampling method. SIXO performs better than NASMC, accurately inferring the timing of most spikes but resampling at a high rate. High numbers of resampling events can lead to particle degeneracy and poor inferences. NAS-X correctly infers the voltage across the whole trace with no spurious or mistimed spikes. Furthermore, NAS-X rarely resamples, indicating it has learned a high-quality proposal that does not generate low-quality particles that must be resampled away. These qualitative results support the quantitative results in Figure 7: SIXO's high resampling rate and NASMC's filtering approach lead to lower bound values.

Table 1: Train bound comparison.

Metric        NAS-X      SIXO
L256 BPF      660.7003   636.2579
L4 train      664.3528   668.6865
L8 train      662.8712   653.6352
L16 train     662.0753   644.8764
L32 train     661.5387   639.5388
L64 train     660.8040   636.5131
L128 train    660.5102   633.7875
L256 train    660.3423   632.1377

12 Model Learning in Mouse Pyramidal Neuron Model

12.1 Model Definition

For the model learning experiments in Section 5.3.2 we used a generalization of the Hodgkin-Huxley model developed for modeling mouse visual cortex neurons by the Allen Institute for Brain Science [48, 52]. Specifically, we used the perisomatic model with ID 482657528 developed to model cell 480169178. The model is detailed in the whitepaper [52] and the accompanying code, but we reproduce the details here to ensure our work is self-contained.

Similar to the squid giant axon model, the mouse visual cortex model is composed of ion channels that affect the current flowing in and out of the cell. Let I be the set of ions {Na+, Ca2+, K+}. Each ion i ∈ I has associated with it:

1. A set of channels that transport that ion, denoted C_i.
2. A reversal potential, E_i.
3. An instantaneous current density, I_i, computed by summing the current density flowing through each channel that transports that ion.

Correspondingly, let C = ∪_{i∈I} C_i be the set of all ion channels. Each channel c ∈ C has associated with it:

1. A maximum conductance density, ḡ_c.
2. A set of subunit states, referred to collectively as the vector λ_c ∈ [0, 1]^{d_c}, i.e. λ_c is a d_c-dimensional vector of values in the interval [0, 1].
3. A function g_c that combines the gate values to produce a number in [0, 1] that weights the maximum conductance density, ḡ_c g_c(λ_c).
4. Functions A_c(·) and b_c(·) that compute the matrix and vector used in the ODE describing the λ_c dynamics. A_c and b_c are functions of both the current membrane voltage v and the calcium concentration inside the cell, [Ca2+]_i. If the number of subunits (i.e. the dimensionality of λ_c) is d_c, then A_c(v, [Ca2+]_i) is a d_c × d_c diagonal matrix and b_c(v, [Ca2+]_i) is a d_c-dimensional vector.

With this notation we can write the system of ODEs as

C_m dv/dt = I_ext/SA − g_leak(v − E_leak) − Σ_{i∈I} I_i   (51)
I_i = Σ_{c∈C_i} ḡ_c g_c(λ_c)(v − E_i)   (52)
dλ_c/dt = A_c(v, [Ca2+]_i) λ_c + b_c(v, [Ca2+]_i),  c ∈ C   (53)
d[Ca2+]_i/dt = −k I_{Ca2+} − ([Ca2+]_i − [Ca2+]_min)/τ   (54)

Most symbols are as described earlier; SA is the membrane surface area of the neuron, [Ca2+]_i is the calcium concentration inside the cell, [Ca2+]_min is the minimum interior calcium concentration with a value of 1 nanomolar, τ is the rate of removal of calcium with a value of 80 milliseconds, and k is a constant with value

k = 10000 γ / (2 F depth)   (55)

where 10000 is a dimensional constant, γ is the percent of unbuffered free calcium, F is Faraday's constant, and depth is the depth of the calcium buffer with a value of 0.1 microns.
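Equations 52, 54, and 55 can be sketched numerically as follows; the conductance, gate, and γ values here are illustrative placeholders, not the fitted Allen Institute parameters:

```python
import numpy as np

def channel_current_density(v, gbar, gate_open, E_rev):
    """Eq. 52: current density through one channel, gbar * g_c(lambda_c) * (v - E)."""
    return gbar * gate_open * (v - E_rev)

def calcium_rate(ca_i, I_ca, ca_min=1e-6, tau=80.0, gamma=0.05,
                 F=96485.33, depth=0.1):
    """Eq. 54: d[Ca]/dt = -k * I_Ca - ([Ca]_i - [Ca]_min) / tau,
    with k = 10000 * gamma / (2 * F * depth) as in Eq. 55."""
    k = 10000.0 * gamma / (2.0 * F * depth)
    return -k * I_ca - (ca_i - ca_min) / tau

# Toy usage: an inward (negative) calcium current raises [Ca2+]_i.
I_ca = channel_current_density(v=-65.0, gbar=0.1, gate_open=0.5, E_rev=50.0)
dca = calcium_rate(ca_i=1e-4, I_ca=I_ca)
```

The sign convention matches Equation 54: below the reversal potential the calcium current is negative (inward), so −k·I_Ca is positive and the interior concentration increases until the decay term −([Ca]_i − [Ca]_min)/τ balances it.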
Because the concentration of calcium changes over time, this model calculates the reversal potential for calcium, E_{Ca2+}, using the Nernst equation

E_{Ca2+} = (G T / (2 F)) log([Ca2+]_o / [Ca2+]_i)   (56)

where G is the gas constant, T is the temperature in Kelvin (308.15), F is Faraday's constant, and [Ca2+]_o is the extracellular calcium ion concentration, which was set to 2 millimolar.

Probabilistic Model  The probabilistic version of the deterministic ODEs was constructed similarly to the probabilistic squid giant axon model: Gaussian noise was added to the voltage and the unconstrained gate states. One difference is that the system state now includes [Ca2+]_i, which is constrained to be greater than 0. To add noise to [Ca2+]_i we added Gaussian noise in log space, analogous to the logit-space noise for the gate states.

Model Size  The 38 learnable parameters of the model include:

1. Conductances ḡ for all ion channels (10 parameters).
2. Reversal potentials of sodium, potassium, and the non-specific cation: E_{K+}, E_{Na+}, and E_{NSC+}.
3. The membrane surface area and specific capacitance.
4. The leak channel reversal potential and maximum conductance density.
5. The calcium decay rate and free calcium percent.
6. Gaussian noise variances for the voltage v and interior calcium concentration [Ca2+]_i.
7. Gaussian noise variances for all subunit states (16 parameters).
8. The observation noise variance.

The 18-dimensional state includes:

1. The voltage v.
2. The interior calcium concentration [Ca2+]_i.
3. All subunit states (16 dimensions).

12.2 Channel Definitions

In this section we provide a list of all ion channels used in the model. In the following equations we often use the function exprel, defined as

exprel(x) = 1 if x = 0,  (exp(x) − 1)/x otherwise   (57)

A numerically stable implementation of this function was critical to training our models.
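A minimal numerically stable implementation (scipy.special.exprel provides an equivalent) uses expm1 to avoid catastrophic cancellation for small |x| and returns the limit 1 exactly at zero:

```python
import numpy as np

def exprel(x):
    """Numerically stable exprel(x) = (exp(x) - 1) / x with exprel(0) = 1.

    np.expm1 computes exp(x) - 1 accurately near zero, where the naive
    expression exp(x) - 1 loses all significant digits.
    """
    x = np.atleast_1d(np.asarray(x, dtype=float))
    out = np.ones_like(x)          # limit value at x = 0
    nz = x != 0.0
    out[nz] = np.expm1(x[nz]) / x[nz]
    return out
```

The naive formula (np.exp(x) - 1) / x returns 0 (or garbage) for x below machine epsilon, which is exactly the regime the channel rate equations hit near their removable singularities.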
Additionally, many of the channel equations below contain a temperature correction q_t that adjusts for the fact that the original experiments and the Allen Institute experiments were not performed at the same temperature. In those equations, T is the temperature in Celsius, which was 35.

12.2.1 Transient Na+

From Colbert and Pan [53].

λ_c = (m, h),  g_c(λ_c) = m³h
(1/q_t) dm/dt = α_m(v)(1 − m) − β_m(v)m
(1/q_t) dh/dt = α_h(v)(1 − h) − β_h(v)h
q_t = 2.3^{(T−23)/10}
α_m(v) = 0.182 · 6 / exprel(−(v + 40)/6),  β_m(v) = 0.124 · 6 / exprel((v + 40)/6)
α_h(v) = 0.015 · 6 / exprel((v + 66)/6),  β_h(v) = 0.015 · 6 / exprel(−(v + 66)/6)

12.2.2 Persistent Na+

From Magistretti and Alonso [54].

λ_c = h,  g_c(λ_c) = m∞ h,  m∞ = 1 / (1 + exp(−(v + 52.6)/4.6))
(1/q_t) dh/dt = α_h(v)(1 − h) − β_h(v)h
q_t = 2.3^{(T−21)/10}
α_h(v) = 2.88e-6 · 4.63 / exprel((v + 17.013)/4.63),  β_h(v) = 6.94e-6 · 2.63 / exprel(−(v + 64.4)/2.63)

12.2.3 Hyperpolarization-activated cation conductance

From Kole et al. [55]. This channel carries a nonspecific cation current, meaning it can transport any cation. In practice, this is modeled by giving it its own special ion NSC+ with reversal potential E_{NSC+} = −45.0.

λ_c = m,  g_c(λ_c) = m
dm/dt = α_m(v)(1 − m) − β_m(v)m
α_m(v) = 0.001 · 6.43 · 11.9 / exprel((v + 154.9)/11.9),  β_m(v) = 0.001 · 193 exp(v/33.1)

12.2.4 High-voltage-activated Ca2+ conductance

From Reuveni et al. [56].

λ_c = (m, h),  g_c(λ_c) = m²h
dm/dt = α_m(v)(1 − m) − β_m(v)m
dh/dt = α_h(v)(1 − h) − β_h(v)h
α_m(v) = 0.055 · 3.8 / exprel(−(v + 27)/3.8),  β_m(v) = 0.94 exp(−(v + 75)/17)
α_h(v) = 0.000457 exp(−(v + 13)/50),  β_h(v) = 0.0065 / (exp(−(v + 15)/28) + 1)

12.2.5 Low-voltage-activated Ca2+ conductance

From Avery and Johnston [57], Randall and Tsien [58].

λ_c = (m, h),  g_c(λ_c) = m²h
(1/q_t) dm/dt = (m∞ − m)/m_τ,  (1/q_t) dh/dt = (h∞ − h)/h_τ
q_t = 2.3^{(T−21)/10}
m∞ = 1 / (1 + exp(−(v + 40)/6)),  m_τ = 5 + 20 / (1 + exp((v + 35)/5))
h∞ = 1 / (1 + exp((v + 90)/6.4)),  h_τ = 20 + 50 / (1 + exp((v + 50)/7))

12.2.6 M-type (Kv7) K+ conductance

From Adams et al. [59].
λ_c = m,  g_c(λ_c) = m
(1/q_t) dm/dt = α_m(v)(1 − m) − β_m(v)m
q_t = 2.3^{(T−21)/10}
α_m(v) = 0.0033 exp(0.1(v + 35)),  β_m(v) = 0.0033 exp(−0.1(v + 35))

12.2.7 Kv3-like K+ conductance

λ_c = m,  g_c(λ_c) = m
dm/dt = (m∞ − m)/m_τ
m∞ = 1 / (1 + exp(−(v − 18.7)/9.7)),  m_τ = 4 / (1 + exp(−(v + 46.56)/44.14))

12.2.8 Fast inactivating (transient, Kv4-like) K+ conductance

From Korngreen and Sakmann [60].

λ_c = (m, h),  g_c(λ_c) = m⁴h
(1/q_t) dm/dt = (m∞ − m)/m_τ,  (1/q_t) dh/dt = (h∞ − h)/h_τ
q_t = 2.3^{(T−21)/10}
m∞ = 1 / (1 + exp(−(v + 47)/29)),  m_τ = 0.34 + 0.92 exp(−((v + 71)/59)²)
h∞ = 1 / (1 + exp((v + 66)/10)),  h_τ = 8 + 49 exp(−((v + 73)/23)²)

12.2.9 Slow inactivating (persistent) K+ conductance

From Korngreen and Sakmann [60].

λ_c = (m, h),  g_c(λ_c) = m²h
(1/q_t) dm/dt = (m∞ − m)/m_τ,  (1/q_t) dh/dt = (h∞ − h)/h_τ
q_t = 2.3^{(T−21)/10}
m∞ = 1 / (1 + exp(−(v + 14.3)/14.6))
m_τ = 1.25 + 175.03 e^{0.026v} if v < −50,  1.25 + 13 e^{−0.026v} if v ≥ −50
h∞ = 1 / (1 + exp((v + 54)/11))
h_τ = (24v + 2690) exp(−((v + 75)/48)²)

12.2.10 SK-type calcium-activated K+ conductance

From Köhler et al. [61]. Note this is the only calcium-gated ion channel in the model.

λ_c = z,  g_c(λ_c) = z
dz/dt = (z∞ − z)/z_τ
z∞ = 1 / (1 + (0.00043/[Ca2+]_i)^{4.8}),  z_τ = 1

12.3 Training Details

Dataset  The dataset used to fit the model was a subset of the stimulus/response pairs available from the Allen Institute. First, all stimuli and responses were downloaded for cell 480169178. Then, sections of length 200 milliseconds were extracted from a subset of the stimulus types. The stimulus types and sections were chosen so that the neuron was at rest and unstimulated at the beginning of the trace. We list the selection criteria below.

1. Any Hold stimuli: excluded because these traces were collected under voltage clamp conditions, which we did not model.
2. Test: excluded because the stimulus is 0 mV for the entire trace.
3. Ramp/Ramp to Rheobase: excluded because the cell is only at rest at the very beginning of the trace.
4. Short Square: 250 ms to 450 ms.
5. Short Square Triple: 1250 ms to 1450 ms.
6. Noise 1 and Noise 2: 1250 ms to 1450 ms, 9250 ms to 9450 ms, 17250 ms to 17450 ms.
7. Long Square: 250 ms to 450 ms.
8. Square 0.5 ms Subthreshold: the entire trace.
9. Square 2 s Suprathreshold: 250 ms to 450 ms.
10. All others: excluded.

For cell 480169178, the criteria above selected 95 stimulus/response pairs of 200 milliseconds each. Each trace pair was then downsampled to 1 ms per step (from the original 0.005 ms per step) and corrupted with mean-zero Gaussian noise of variance 20 mV² to simulate voltage imaging conditions. Finally, the 95 traces were randomly split into 72 training traces and 23 test traces.

Proposal and Twist  The proposal and twist hyperparameters were broadly similar to the squid axon experiments: the proposal was parameterized by a bidirectional RNN with a single hidden layer of size 64 and an MLP with a single hidden layer of size 64. The RNN was conditioned on the observed response and stimulus voltages at each timestep, and the MLP accepted the RNN hidden state, the previous latent state, and a transformer positional encoding of the number of steps since the last voltage response observation. The twist was similarly parameterized by an RNN run in reverse across the stimulus and response, combined with an MLP that accepted the RNN hidden state, the latent state being evaluated, and a transformer positional encoding of the number of steps elapsed since the last voltage response observation. The positional encodings informed the twist and proposal of the number of steps elapsed since the last observation because the model was integrated with a step size of 0.1 ms while observations were received once every millisecond.

Hyperparameter Sweeps  To evaluate the methods we swept over:

1. Initial observation variance: e², e³, e⁵.
2. Initial voltage dynamics variance: e, e², e³.
3. Bias added to the scales produced by the proposal: e², e⁵.

We also evaluated the models across three different data noise variances (20, 10, and 5), but the results were similar for all values, so we report only the results for variance 20.
This amounted to 3 · 3 · 3 · 2 = 54 different hyperparameter settings, and 5 seeds were run for each setting, yielding a total of 270 runs. When computing final performance, a hyperparameter setting was only evaluated if it had at least 3 runs that reached 250,000 steps without NaNs. For each hyperparameter setting selected for evaluation, all successful seeds were evaluated using early stopping on the training 4-particle log-likelihood lower bound.

13 Strang Splitting for Hodgkin-Huxley Models

Because the Hodgkin-Huxley model is a stiff ODE, integrating it can be challenging, especially at large step sizes. The traditional solution is to use an implicit integration scheme with a varying step size, allowing the algorithm to take large steps when the voltage is not spiking. However, because our model adds noise to the ODE state at each timestep, adaptive step-size methods are not viable: different step sizes would change the noise distribution. Instead, we sought an explicit, fixed step-size method that could be stably integrated at relatively large step sizes.

Figure 10: Twist learning in LG-SSM. (left) Twist parameter error relative to the optimal twist parameters over twist training steps; (right) classification accuracy of the learned twist. With an appropriate twist parameterization, twist learning via density ratio estimation is robust.

Inspired by Chen et al. [47], we developed a splitting approach that exploits the conditional linearity of the system. The system of ODEs describing the model can be split into two subsystems, each of which is a system of linear first-order ODEs when conditioned on the state of the other subsystem. Specifically, the dynamics of the channel subunit states {λ_c | c ∈ C} are a system of linear first-order ODEs when conditioned on the voltage v and the interior calcium concentration [Ca2+]_i.
Similarly, the dynamics of v and [Ca2+]_i are a system of linear first-order ODEs when conditioned on the subunit states. Because the conditional dynamics of each subsystem are linear first-order ODEs, an exact solution to each subsystem is available under the assumption that the states being conditioned on are constant for the duration of the step. Our integration approach applies these exact updates in an alternating fashion: first performing an exact update to the voltage and interior calcium concentration while holding the subunit states constant, and then performing an exact update to the subunit states while holding the voltage and interior calcium concentration constant. For details on Strang and other splitting methods applied to Hodgkin-Huxley-type ODEs, see [47].

14 Robustness of Twist Learning

NAS-X uses SIXO's twist learning framework to approximate the smoothing distributions. The twist learning approach involves density ratio estimation: a classifier is trained to distinguish between samples from p_θ(x_t | y_{t+1:T}) and p_θ(x_t), both of which can be obtained from the generative model. For details see Section 2.2.

In principle, incorporating twists complicates the overall learning problem, and traditional methods for twist learning can indeed be challenging. In practice, however, twist learning using the SIXO framework is robust and easy. In Figure 10, we present twist parameter recovery and classification accuracy for the Gaussian SSM experiments (Section 5.1); in this setting, the optimal twists have a known parametric form. The optimal twist parameters are recovered quickly, the classification accuracy is high, and training is stable. This suggests that, with an appropriate twist parameterization, twist learning via density ratio estimation is tractable and straightforward.
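The density ratio estimation procedure can be illustrated with a 1-D toy analogue (hand-picked quadratic features rather than our twist architecture): positive examples pair a state with a correlated "future" summary, negative examples pair it with an independent one, and the learned classifier logit estimates the log density ratio up to a constant.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4000

# Positives: x paired with a correlated future summary (joint samples).
x_pos = rng.standard_normal(n)
y_pos = x_pos + 0.5 * rng.standard_normal(n)

# Negatives: x paired with a future from an independent draw
# (same marginals, no dependence).
x_neg = rng.standard_normal(n)
y_neg = rng.standard_normal(n) + 0.5 * rng.standard_normal(n)

def feats(x, y):
    # Quadratic features: the true log ratio is quadratic in (x, y) here.
    return np.stack([x * y, x**2, y**2, np.ones_like(x)], axis=1)

X = np.concatenate([feats(x_pos, y_pos), feats(x_neg, y_neg)])
t = np.concatenate([np.ones(n), np.zeros(n)])

# Logistic regression by gradient descent. The learned logit estimates
# log p(x, y) - log p(x) p(y), i.e. the log twist up to a constant.
w = np.zeros(X.shape[1])
for _ in range(2000):
    p = 1.0 / (1.0 + np.exp(-np.clip(X @ w, -30.0, 30.0)))
    w -= 0.1 * X.T @ (p - t) / len(t)

acc = np.mean(((X @ w) > 0) == t)
```

Because the true log ratio for jointly Gaussian (x, y) is quadratic, logistic regression with these features can represent the optimal classifier exactly; in the paper's setting the same role is played by the RNN-parameterized twist.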
15 Computational Complexity and Wall-clock Time

Theoretically, all multi-particle methods considered (NAS-X, SIXO, FIVO, NASMC, RWS, IWAE) have O(KT) time complexity, where K is the number of particles and T is the number of time steps. One concern is that the resampling operation in SMC could require super-linear time in the number of particles, but drawing K samples from a K-category discrete distribution can be accomplished in O(K) time using the alias method [62, 63]. Additionally, for NAS-X, evaluating the twists is amortized across timesteps as in Lawson et al. [19], giving time linear in T.

NAS-X and SIXO have similar wall-clock speeds but are slower than FIVO and NASMC, primarily because of twist training; see Table 2. Specifically, SIXO and NAS-X take 3.5x longer per step than NASMC and FIVO, and 2.5x longer per step than RWS and IWAE. Even if FIVO and NASMC were run with more particles to equalize wall-clock times, they would still far underperform NAS-X in log marginal likelihood lower bounds: Figure 7 shows that FIVO, NASMC, IWAE, and RWS cannot match NAS-X's performance even with 64 times more computation (4 vs. 256 particles), and SIXO only matches NAS-X's performance with 4x as many particles (4 vs. 16 particles). NAS-X therefore uses computational resources much more effectively than the other methods.

Table 2: Wall-clock speeds of various methods during HH inference.

Method   ms / global step   ms / proposal step   ms / twist step
IWAE     70.3 ± 15.9        70.3 ± 15.9          N/A
RWS      71.6 ± 8.2         71.6 ± 8.2           N/A
NASMC    53.9 ± 6.6         53.9 ± 6.6           N/A
FIVO     51.4 ± 13.3        51.4 ± 13.3          N/A
NAS-X    163.2 ± 39.8       73.5 ± 18.5          89.7 ± 21.3
SIXO     175.3 ± 42.4       85.6 ± 21.1          89.7 ± 21.3
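The alias method referenced above can be sketched as follows (Walker/Vose construction: O(K) table build, O(1) per sample):

```python
import numpy as np

def build_alias(probs):
    """Walker/Vose alias-table construction. Each of the K columns holds a
    threshold and an 'alias' category, so a draw needs one uniform column
    choice plus one biased coin flip."""
    K = len(probs)
    scaled = np.asarray(probs, dtype=float) * K
    prob_table = np.zeros(K)
    alias = np.zeros(K, dtype=int)
    small = [i for i in range(K) if scaled[i] < 1.0]
    large = [i for i in range(K) if scaled[i] >= 1.0]
    while small and large:
        s, l = small.pop(), large.pop()
        prob_table[s] = scaled[s]
        alias[s] = l
        scaled[l] -= 1.0 - scaled[s]       # donate mass to fill column s
        (small if scaled[l] < 1.0 else large).append(l)
    for i in large + small:                # leftovers are 1 up to rounding
        prob_table[i] = 1.0
    return prob_table, alias

def alias_sample(prob_table, alias, n, rng):
    """Draw n samples: pick a column uniformly, keep it or take its alias."""
    K = len(prob_table)
    cols = rng.integers(0, K, size=n)
    keep = rng.random(n) < prob_table[cols]
    return np.where(keep, cols, alias[cols])

rng = np.random.default_rng(0)
probs = np.array([0.1, 0.2, 0.3, 0.4])
prob_table, alias = build_alias(probs)
samples = alias_sample(prob_table, alias, 200_000, rng)
freq = np.bincount(samples, minlength=4) / len(samples)
```

In an SMC resampling step, probs would be the normalized particle weights, so drawing K ancestors costs O(K) total.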
Figure 11: Hodgkin-Huxley gradient variances (left), gradient bias (middle), and log marginal likelihood lower bounds (right) over proposal training, for NAS-X, NASMC, SIXO, RWS, FIVO, and IWAE.

16 Empirical analysis of bias and variance of gradients

In Figure 11, we analyze the gradient variance and bias for the Hodgkin-Huxley experiments, supplementing our theoretical analyses in Section 3.1. Figure 11 (left) shows that NAS-X attains lower-variance gradient estimates than IWAE, FIVO, and SIXO, with variance comparable to RWS. We also studied the bias (middle) by approximating the true gradient with a bootstrap particle filter run with 256 particles using the best proposal from the inference experiments. NAS-X's gradients are lower bias than all methods except FIVO, but FIVO's gradients are also the highest variance. We hypothesize that FIVO's gradients appear less biased because its parameters are pushed towards degenerate values where gradient estimation is "easier". We illustrate this in the right panel, where we plot log marginal likelihood bounds.
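The quantities plotted in Figure 11 can be computed by treating each stochastic gradient estimator as a random variable: draw it repeatedly, report the average elementwise variance, and compare its mean against a reference gradient. A toy sketch with a known gradient (in the paper, the reference instead comes from the 256-particle bootstrap particle filter):

```python
import numpy as np

def bias_and_variance(draws, reference):
    """Bias: norm of (mean estimate - reference).
    Variance: average elementwise variance across repeated draws."""
    bias = np.linalg.norm(draws.mean(axis=0) - reference)
    variance = draws.var(axis=0).mean()
    return bias, variance

# Toy gradient estimator: unbiased, with per-coordinate noise std 0.5.
true_grad = np.array([1.0, -2.0])
rng = np.random.default_rng(0)
draws = np.stack([true_grad + 0.5 * rng.standard_normal(2)
                  for _ in range(10_000)])
bias, variance = bias_and_variance(draws, true_grad)
```

For this unbiased estimator the measured bias shrinks toward zero as draws accumulate, while the variance estimate converges to the true per-coordinate variance of 0.25.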