# Neural Processes with Stability

Huafeng Liu, Liping Jing, Jian Yu
Beijing Key Lab of Traffic Data Analysis and Mining, Beijing Jiaotong University
The School of Computer and Information Technology, Beijing Jiaotong University
{hfliu1, lpjing, jianyu}@bjtu.edu.cn

## Abstract

Unlike traditional statistical models that depend on hand-specified priors, neural processes (NPs) have recently emerged as a class of powerful neural statistical models that combine the strengths of neural networks and stochastic processes. NPs can define a flexible class of stochastic processes well suited for highly non-trivial functions by encoding contextual knowledge into the function space. However, noisy context points pose a challenge to algorithmic stability: small changes in the training data may significantly change the model and yield lower generalization performance. In this paper, we provide theoretical guidelines for deriving stable solutions with high generalization by introducing the notion of algorithmic stability into NPs. The resulting solution works with various NPs and achieves a less biased approximation with theoretical guarantees. To illustrate the superiority of the proposed model, we perform experiments on both synthetic and real-world data, and the results demonstrate that our approach not only achieves more accurate performance but also improves model robustness.

## 1 Introduction

Neural processes (NPs) [9, 10] constitute a family of variational approximation models for stochastic processes with promising properties in computational efficiency and uncertainty quantification. Unlike traditional statistical modeling, in which a user typically hand-specifies a prior (e.g., smoothness of functions quantified by a Gaussian distribution in a Gaussian process [25]), NPs implicitly define a broad class of stochastic processes with neural networks in a data-driven manner.
When appropriately trained, NPs can define a flexible class of stochastic processes well suited for highly non-trivial functions that are not easily represented by existing stochastic processes. NPs meta-learn a distribution over predictors and provide a way to select an inductive bias from data so as to adapt quickly to a new task. By incorporating the data prior into the model as an inductive bias, NPs can reduce model complexity and improve generalization. Usually, an NP predictor is described as predicting a set of data (the target set) given a set of labeled data (the context set).

However, noise in the data poses challenges to algorithmic stability. In NPs, models are biased toward the meta-datasets (a dataset of datasets), so small changes in a dataset (noisy or missing points) may significantly change the model. As demonstrated in previous work [9, 10, 14, 11], existing NPs cannot provide stable predictions under noisy conditions. This may introduce high training error variance, and minimizing the training error may then not guarantee a consistent error reduction on the test set, i.e., low generalization performance [2]. Algorithmic stability and generalization performance are thus strongly connected: an unstable NP model has low generalization performance. A stable model is one for which the learned solution does not change much under small changes in the training set [2].

In general, heuristic techniques such as cross-validation and ensemble learning can be adopted to improve generalization performance. However, cross-validation sacrifices part of the limited training data, while ensemble learning is computationally expensive because it trains several submodels. Recently, several improved NPs have considered model stability and improved generalization performance empirically, using a hierarchical prior [27], a stochastic attention mechanism [15], bootstrap [13], or a mixture of experts [28]. However, most of them do not investigate theoretical bounds on the generalization performance of NPs. It is desirable to develop robust algorithms with low generalization error and high efficiency.

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

In this paper, we investigate NP-related models and explore more expressive stability toward general stochastic processes by proposing a stable solution. Specifically, by introducing the notion of stability into NPs, we focus on developing theoretical guidelines for deriving a stable NP solution. We propose a method to find subsets that are harder to predict than average, which is a key step in constructing the optimization problem. Based on it, we formulate a new extension of NPs with stability guarantees, which works with various NPs and achieves a less biased approximation with theoretical guarantees. To account for model adaptivity, we also propose an adaptive weighting strategy. To illustrate the superiority of the proposed stable solution, we perform experiments on synthetic 1D regression, system identification of physics engines, and real-world image completion tasks. The results demonstrate that NPs with our stable solution are much more robust than the original NPs.

## 2 Related Work

In this section, we briefly review two areas that are highly relevant to the proposed method: neural processes and algorithmic stability.

**Neural Processes.** Neural processes are a well-known family of stochastic process models that directly capture uncertainty with deep neural networks; they are computationally efficient while retaining a probabilistic interpretation of the model [9, 10, 14, 13]. Starting with conditional neural processes (CNP) [9], there have been several follow-up works that improve NPs in various aspects [6].
Vanilla CNP combines neural networks with the Gaussian process to extract prior knowledge from training data. NP [10] introduces a global latent variable to model uncertainty in a variational manner. To address underfitting in the vanilla NP, Attentive NP [14] introduces an attention mechanism to improve the model's reconstruction quality. [11] introduced the convolutional conditional neural process (ConvCNP), which models translation equivariance in the data. Wang and van Hoof [27] presented a doubly stochastic variational process (DSVNP), which combines global and local latent variables. Lee et al. [18] extended NPs with bootstrap and proposed the bootstrapping neural process (BNP). Kawano et al. [13] presented a group equivariant conditional neural process by incorporating group equivariance into CNPs in a meta-learning manner. Wang and van Hoof [28] proposed to combine mixture-of-experts models with NPs to develop more expressive exchangeable stochastic processes. Kim et al. [15] proposed a stochastic attention mechanism for NPs to capture appropriate context information. Although many NP variants improve model performance, they do not consider stability as a means to achieve high generalization performance.

**Algorithmic Stability.** Stability, also known as algorithmic stability, is a notion from computational learning theory that describes how a machine learning algorithm is perturbed by small changes to its inputs [2]. Many efforts have been made to analyze various notions of algorithmic stability and to prove that a broad spectrum of learning algorithms are stable in some sense [2, 3, 29, 12]. [3] proved that $\ell_2$-regularized learning algorithms are uniformly stable and obtained new bounds on generalization performance. [29] generalized the results of [3] and proved that regularized learning algorithms with strongly convex penalty functions on bounded domains are also stable. Hardt et al. [12] showed that parametric models trained by stochastic gradient descent are uniformly stable. Li et al. [19] introduced the notion of stability to low-rank matrix approximation. Liu et al. [22] proved that tasks in multi-task learning can act as regularizers and that multi-task learning in a very general setting is therefore uniformly stable under mild assumptions. This is the first work to investigate the stability of NPs from a theoretical perspective and to derive NP solutions with high stability.

## 3 Preliminary

Let calligraphic letters (e.g., $\mathcal{A}$) indicate sets, capital letters (e.g., $A$) indicate scalars, lower-case bold letters (e.g., $\mathbf{a}$) denote vectors, and capital bold letters (e.g., $\mathbf{A}$) indicate matrices. Suppose there is a dataset $\mathcal{D} = (\mathbf{X}, \mathbf{y}) = \{(\mathbf{x}_i, y_i)\}_{i=1}^{N}$ with $N$ data points $\mathbf{X} = [\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_N] \in \mathbb{R}^{N \times D}$ and corresponding labels $\mathbf{y} = [y_1, y_2, \dots, y_N] \in \mathbb{R}^{N}$. Given an arbitrary number of data points $\mathcal{D}_C = (\mathbf{X}_C, \mathbf{y}_C) = \{(\mathbf{x}_i, y_i)\}_{i \in C}$, where $C \subseteq \{1, 2, \dots, N\}$ is an index set defining the context information, neural processes model the conditional predictive distribution of the target values $\mathbf{y}_T = \{y_i\}_{i \in T}$ at some target data points $\mathbf{X}_T = \{\mathbf{x}_i\}_{i \in T}$ based on the context $\mathcal{D}_C$, i.e., $P(\mathbf{y}_T \mid \mathbf{X}_T, \mathcal{D}_C)$. Usually, the target set is defined as $T = \{1, 2, \dots, N\}$. Only in CNP [9], $T \subset \{1, 2, \dots, N\}$ and $T \cap C = \emptyset$. In this paper, we define $T = \{1, 2, \dots, N\}$ for all NPs, i.e., the conditional predictive distribution is $P(\mathbf{y} \mid \mathbf{X}, \mathcal{D}_C) = \prod_{i=1}^{N} P(y_i \mid \mathbf{x}_i, \mathcal{D}_C)$.

Fundamentally, there are two NP variants: deterministic and probabilistic. The deterministic NP [9], i.e., CNP, models the conditional distribution as $P(\mathbf{y} \mid \mathbf{X}, \mathcal{D}_C) = P(\mathbf{y} \mid \mathbf{X}, \mathbf{r}_C)$, where $\mathbf{r}_C \in \mathbb{R}^{d}$ is an aggregated feature vector produced by a function that maps $\mathcal{D}_C$ into a finite-dimensional vector space in a permutation-invariant way.
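The permutation-invariant aggregation behind $\mathbf{r}_C$ can be illustrated with a minimal sketch. It assumes a hypothetical one-layer tanh encoder followed by mean pooling, which is one common choice of aggregator; the weights, dimensions, and function names below are illustrative and are not taken from this paper.

```python
import math
import random


def encode_point(x, y, weights, bias):
    """Hypothetical one-layer encoder h(x_i, y_i) -> R^d with tanh activation."""
    inputs = list(x) + [y]
    return [
        math.tanh(sum(w * v for w, v in zip(row, inputs)) + b)
        for row, b in zip(weights, bias)
    ]


def aggregate_context(xs, ys, weights, bias):
    """Mean-pool the encoded context pairs into a single vector r_C."""
    encoded = [encode_point(x, y, weights, bias) for x, y in zip(xs, ys)]
    n = len(encoded)
    return [sum(column) / n for column in zip(*encoded)]


# Illustrative sizes only: D input features, d latent features, 5 context points.
random.seed(0)
D, d, n_context = 2, 4, 5
weights = [[random.gauss(0, 1) for _ in range(D + 1)] for _ in range(d)]
bias = [random.gauss(0, 1) for _ in range(d)]
xs = [[random.gauss(0, 1) for _ in range(D)] for _ in range(n_context)]
ys = [random.gauss(0, 1) for _ in range(n_context)]

r_c = aggregate_context(xs, ys, weights, bias)

# Shuffling the context set should leave r_C unchanged (up to float rounding).
order = list(range(n_context))
random.shuffle(order)
r_c_shuffled = aggregate_context(
    [xs[i] for i in order], [ys[i] for i in order], weights, bias
)
```

Because the mean does not depend on the order of its arguments, any reordering of the context pairs yields the same $\mathbf{r}_C$, which is the property the text relies on.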
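In probabilistic NPs [10], a latent variable $\mathbf{z} \in \mathbb{R}^{d}$ is introduced to capture model uncertainty. The NP infers $P_\theta(\mathbf{z} \mid \mathcal{D}_C)$ given the context set using the reparameterization trick [16] and models the conditional distribution as $P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) = \int P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C, \mathbf{z}) P_\theta(\mathbf{z} \mid \mathcal{D}_C)\, d\mathbf{z}$. It is trained by maximizing an ELBO:

$$\mathbb{E}_{\mathbf{z} \sim P_\theta(\mathbf{z} \mid \mathbf{X}, \mathbf{y})}\left[\log P_\theta(\mathbf{y} \mid \mathbf{X}, \mathbf{z})\right] - \mathrm{KL}\left[P_\theta(\mathbf{z} \mid \mathbf{X}, \mathbf{y}) \,\|\, P_\theta(\mathbf{z} \mid \mathcal{D}_C)\right].$$

**Meta Training NP Prediction.** To achieve fast prediction on a new context set at test time, NPs meta-learn a distribution over predictors. To perform meta-learning, we require a meta-dataset (a dataset of datasets). We consider an unknown distribution $\mu$ on an instance space $\mathcal{X} \times \mathcal{Y}$ and a set of independent samples $\mathcal{D} = \{(\mathbf{x}_i, y_i)\}_{i=1}^{N}$ drawn from $\mu$: $(\mathbf{x}_i, y_i) \sim \mu$ and $\mathcal{D} \sim \mu^{N}$. Suppose the meta-dataset contains $M$ datasets $\mathcal{D}_{1:M} = \{\mathcal{D}_m\}_{m=1}^{M}$ with $\mathcal{D}_m = \{\mathcal{D}_m^{C}, \mathcal{D}_m^{T}\}$. We assume that all $M$ datasets are drawn from a common environment $\tau$, which is a probability measure on the set of probability measures on $\mathcal{X} \times \mathcal{Y}$. The draw $\mu \sim \tau$ indicates the encounter of a specific learning task $\mu$ in the environment $\tau$. For simplicity, we assume that each dataset has the same sample size $N$. Following previous work on multi-task learning [24] and meta-learning [5], the environment $\tau$ induces a measure $\mu_{N,\tau}$ on $(\mathcal{X} \times \mathcal{Y})^{N}$ such that $\mu_{N,\tau}(A) = \mathbb{E}_{\mu \sim \tau}[\mu^{N}(A)]$, $\forall A \subseteq (\mathcal{X} \times \mathcal{Y})^{N}$. Thus a dataset $\mathcal{D}_m$ is independently sampled from a task $\mu$ encountered in $\tau$, which is denoted as $\mathcal{D}_m \sim \mu_{N,\tau}$ for $m \in [M]$.

Suppose there exists a meta parameter $\theta$ indicating the shared knowledge among different tasks. In this case, a meta-learning algorithm $\mathcal{A}_{\mathrm{meta}}$ for NPs takes the meta-dataset $\mathcal{D}_{1:M}$ as input and outputs a meta parameter $\theta = \mathcal{A}_{\mathrm{meta}}(\mathcal{D}_{1:M}) \sim P_{\theta \mid \mathcal{D}_{1:M}}$. Given a new test dataset $\mathcal{D}$, we can evaluate the quality of the meta parameter $\theta$ by the following true risk:

$$R_\tau(\theta) = \mathbb{E}_{\mathcal{D} \sim \mu_{N,\tau}}\, \mathbb{E}_{\theta \sim P_{\theta \mid \mathcal{D}_{1:M}}}\left[R_\mu(\theta)\right] \tag{1}$$

where $R_\mu(\theta) = \mathbb{E}_{(\mathbf{x}_i, y_i) \sim \mu}\left[-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C)\right]$. Since $\tau$ and $\mu$ are usually unknown, we can only estimate the meta parameter $\theta$ from the observed data $\mathcal{D}_{1:M}$.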
In this case, the empirical risk with respect to $\theta$ is:

$$R_{\mathcal{D}_{1:M}}(\theta) = \frac{1}{M} \sum_{m=1}^{M} \mathbb{E}_{\theta \sim P_{\theta \mid \mathcal{D}_m^{C}}}\, R_{\mathcal{D}_m}(\theta) \tag{2}$$

where $R_{\mathcal{D}_m}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C)$.

NPs have various strengths: 1) Efficiency: meta-learning allows NPs to incorporate information from a new context set and make predictions with a single forward pass. The complexity is linear or quadratic in the context size instead of cubic as with Gaussian process regression; 2) Flexibility: NPs can define a conditional distribution over an arbitrary number of target points, conditioned on an arbitrary number of observations; 3) Permutation invariance: the encoders of NPs use set properties [32] to make the target prediction permutation invariant. Thanks to these properties, NPs are widely used in many tasks, e.g., Bayesian optimization [8], recommendation [20, 21], and control of physics engines [27]. While many NP variants improve the performance of NPs [9, 10, 14, 13, 15, 28], they do not yet take the model's stability into account, which is the key to the robustness of the model.

## 4 Problem Formulation

**Stability of NP.** A stable learning algorithm has the property that replacing one element in the training set does not result in a significant change to the algorithm's output [2]. Therefore, if we take the training error as a random variable, the training error of a stable learning algorithm should have a small variance. This implies that, for stable algorithms, the training error is close to the testing error [2]. Based on the risks defined above, the algorithmic stability of approximating $\{y_i\}_{i \in T}$ in NPs is defined as follows.

**Definition 4.1.** (Algorithmic Stability of Neural Processes) For any measure $\mu_{N,\tau}$ on $(\mathcal{X} \times \mathcal{Y})^{N}$ such that $\mu_{N,\tau}(A) = \mathbb{E}_{\mu \sim \tau}[\mu^{N}(A)]$, $\forall A \subseteq (\mathcal{X} \times \mathcal{Y})^{N}$, sample $M$ datasets $\mathcal{D}_{1:M}$ from $\mu_{N,\tau}$ randomly. For a given $\epsilon > 0$, we say that $R_{\mathcal{D}_{1:M}}(\theta)$ is $\delta$-stable if the following holds:

$$P\left(\left|R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta)\right| \le \epsilon\right) \ge 1 - \delta. \tag{3}$$

Under this notion of stability, the generalization error is bounded, which indicates that minimizing the training error has a high probability of also minimizing the testing error. The notion also makes it possible to compare the generalization performance of different NP approximations. For instance, take any two meta-datasets $\mathcal{D}^{1}_{1:M}$ and $\mathcal{D}^{2}_{1:M}$ from $\mu_{N,\tau}$, and suppose the NPs trained on $\mathcal{D}^{1}_{1:M}$ and $\mathcal{D}^{2}_{1:M}$ are $\delta_1$-stable and $\delta_2$-stable, respectively. Then $R_{\mathcal{D}^{1}_{1:M}}(\theta)$ is more stable than $R_{\mathcal{D}^{2}_{1:M}}(\theta)$ if $\delta_1 < \delta_2$. This implies that $R_{\mathcal{D}^{1}_{1:M}}(\theta)$ is close to $R_\tau(\theta)$ with higher probability than $R_{\mathcal{D}^{2}_{1:M}}(\theta)$, i.e., minimizing $R_{\mathcal{D}^{1}_{1:M}}(\theta)$ leads, with high probability, to solutions with better generalization performance than minimizing $R_{\mathcal{D}^{2}_{1:M}}(\theta)$. Based on the above analysis, the reliability of data points is crucial to the success of NPs, and frail NPs are susceptible to noise.

Figure 1: Stability vs. generalization error with different numbers of replaced noisy datasets (x-axis: the number of replaced datasets, from 1 to 100).

**Stability vs. Generalization Error.** Sparse, incomplete, and noisy data pose challenges to algorithmic stability. NP models are biased toward the quality of the context data and target data, so small changes in the training data (noise) may significantly change the model. In this case, unstable solutions introduce high training error variance, and minimizing the training error may not guarantee a consistent error reduction on the testing dataset, i.e., low generalization performance. In other words, algorithmic stability has a direct impact on generalization performance, and an unstable NP solution has low generalization performance. We take NPs on the 1D regression task [9] as an example to investigate the relationship between generalization performance and stability of NPs. There are 200 training datasets and 100 testing datasets.
We trained the NP model on curves generated from a Gaussian process with RBF kernels and replaced normal datasets with noisy datasets, varying the number of replaced datasets in {1, 10, 20, 50, 80, 100}. We quantify how the stability of NPs changes with the generalization error as the number of replaced datasets increases from 1 to 100. The generalization error is measured as the difference between the test error and the training error, $R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta)$, and stability is measured by computing $P(|R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta)| \le \epsilon)$ over 100 different runs. Here we choose $\epsilon$ in Definition 4.1 as 0.0015 to cover all error differences when the number of replaced datasets is 1. As shown in Figure 1, the generalization error increases as the number of replaced datasets increases, since the test performance degrades. In contrast, the stability of the NP decreases as the number of replaced datasets increases. This indicates that stability decreases as the generalization error increases.

This study demonstrates that existing NPs suffer from lower generalization performance due to low algorithmic stability. Therefore, it is important to develop a stable solution for NPs that offers good generalization performance.

In this section, inspired by previous work [19], we present a solution for NPs with stability and high generalization. Algorithmic stability provides an intuitive way to measure the changes in the outputs of a learning algorithm when the input is changed, and various ways have been introduced to measure it. Following the definition of uniform stability [2], the approximation results of a stable NP remain stable under changes to the datasets. For instance, we can remove a subset of easily predictable data points from $\mathcal{D}_{1:M}$ to obtain $\mathcal{D}'_{1:M}$. It is desirable that the solution obtained by minimizing over both $\mathcal{D}_{1:M}$ and $\mathcal{D}'_{1:M}$ together is more stable than the solution obtained by minimizing over $\mathcal{D}_{1:M}$ only.
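The stability measurement used in the study above, i.e., the fraction of runs whose train/test gap stays within $\epsilon$, can be sketched as follows. This is a minimal Monte Carlo estimate under the assumption that one test error and one training error are recorded per run; the function name and the numbers are made up for illustration and are not results from the paper.

```python
def empirical_stability(test_errors, train_errors, epsilon):
    """Estimate P(|R_tau - R_D| <= epsilon) as the fraction of runs whose
    test/train gap is within epsilon. The result approximates 1 - delta."""
    if len(test_errors) != len(train_errors) or not test_errors:
        raise ValueError("need one test error and one train error per run")
    gaps = [abs(te - tr) for te, tr in zip(test_errors, train_errors)]
    return sum(gap <= epsilon for gap in gaps) / len(gaps)


# Made-up errors from four hypothetical runs (illustration only).
test_errors = [0.5010, 0.5020, 0.5040, 0.5005]
train_errors = [0.5000, 0.5000, 0.5000, 0.5000]

stability = empirical_stability(test_errors, train_errors, epsilon=0.0015)
delta_hat = 1.0 - stability  # estimated delta in Definition 4.1
```

With more runs (the study uses 100), the estimate becomes less noisy; a larger `stability` value corresponds to a smaller $\delta$ in Definition 4.1.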
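The following theorem formalizes the claim that minimizing over both $\mathcal{D}_{1:M}$ and $\mathcal{D}'_{1:M}$ is more stable than minimizing over $\mathcal{D}_{1:M}$ alone.

**Theorem 5.1.** Let $\mathcal{D}_{1:M}$ ($M \ge 2$) be a sampled meta-dataset of measure $\mu_{N,\tau}$. Let $\mathcal{D}_s \subset \mathcal{D}_{1:M}$ be a subset of the meta-dataset that satisfies $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_s$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$. Let $\mathcal{D}'_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_s$. For any $\epsilon > 0$ and $1 > w_0 > 0$, $1 > w_1 > 0$ ($w_0 + w_1 = 1$), if $w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}'_{1:M}}(\theta)$ and $R_{\mathcal{D}_{1:M}}(\theta)$ are $\delta_1$-stable and $\delta_2$-stable, respectively, then $\delta_1 \le \delta_2$.

*Proof.* Assume that $R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta) \in [-a_1, a_1]$ and $R_\tau(\theta) - (w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}'_{1:M}}(\theta)) \in [-a_2, a_2]$ are two random variables with zero mean, where $a_1 = \sup\{R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta)\}$ and $a_2 = \sup\{R_\tau(\theta) - (w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}'_{1:M}}(\theta))\}$. Based on Markov's inequality¹, for any $t > 0$, we have

$$P\left(R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta) \ge \epsilon\right) \le \frac{\mathbb{E}\left[e^{t(R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta))}\right]}{e^{t\epsilon}}.$$

Based on Hoeffding's lemma², we have $\mathbb{E}[e^{t(R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta))}] \le e^{\frac{1}{2} t^2 a_1^2}$, i.e.,

$$P\left(R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta) \ge \epsilon\right) \le \frac{e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}}.$$

Similarly, we have

$$P\left(R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta) \le -\epsilon\right) \le \frac{e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}}.$$

Combining these two inequalities gives $P(|R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta)| \ge \epsilon) \le \frac{2 e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}}$, i.e.,

$$P\left(\left|R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta)\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}}.$$

Similarly, we have

$$P\left(\left|R_\tau(\theta) - \left(w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}'_{1:M}}(\theta)\right)\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_2^2}}{e^{t\epsilon}}.$$

Using $w_0 = 1 - w_1$, the relationship between $a_1$ and $a_2$ is

$$\begin{aligned} a_2 &= \sup\left\{R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta) + w_1\left(R_{\mathcal{D}_{1:M}}(\theta) - R_{\mathcal{D}'_{1:M}}(\theta)\right)\right\} \\ &= \sup\left\{R_\tau(\theta) - R_{\mathcal{D}_{1:M}}(\theta)\right\} + w_1 \sup\left\{R_{\mathcal{D}_{1:M}}(\theta) - R_{\mathcal{D}'_{1:M}}(\theta)\right\} \\ &= a_1 + w_1 \sup\left\{R_{\mathcal{D}_{1:M}}(\theta) - R_{\mathcal{D}'_{1:M}}(\theta)\right\}. \end{aligned} \tag{7}$$

Since $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_s$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_s^{C}) \le R_{\mathcal{D}_{1:M}}(\theta)$, we have $-\frac{1}{N} \sum_{i=1}^{N} \log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_s^{C}) \le R_{\mathcal{D}_{1:M}}(\theta)$. Then, since $\mathcal{D}_{1:M} = \mathcal{D}_s \cup \mathcal{D}'_{1:M}$, the average loss over the remaining points cannot be below the overall average, i.e., $R_{\mathcal{D}'_{1:M}}(\theta) \ge R_{\mathcal{D}_{1:M}}(\theta)$. This means that $\sup\{R_{\mathcal{D}_{1:M}}(\theta) - R_{\mathcal{D}'_{1:M}}(\theta)\} \le 0$. Thus, we have $a_2 \le a_1$. It follows that $\frac{2 e^{\frac{1}{2} t^2 a_2^2}}{e^{t\epsilon}} \le \frac{2 e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}}$, i.e., $\delta_1 \le \delta_2$. $\square$

Theorem 5.1 indicates that, if we remove a subset that is easier to predict than average from $\mathcal{D}_{1:M}$ to form $\mathcal{D}'_{1:M}$, then $w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}'_{1:M}}(\theta)$ has a higher probability of being close to $R_\tau(\theta)$ than $R_{\mathcal{D}_{1:M}}(\theta)$. Therefore, minimizing $w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}'_{1:M}}(\theta)$ leads to solutions that have better generalization performance than minimizing $R_{\mathcal{D}_{1:M}}(\theta)$.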
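The bound in the proof of Theorem 5.1 holds for every $t > 0$, so it can be tightened by minimizing over $t$. The short derivation below is a worked illustration rather than part of the original proof; it writes $a$ for either $a_1$ or $a_2$ and shows that the resulting $\delta$ is increasing in $a$, which is why $a_2 \le a_1$ gives $\delta_1 \le \delta_2$. The numerical values are made up for illustration.

```latex
% Tightening the bound 2 e^{t^2 a^2 / 2} / e^{t \epsilon} over the free parameter t > 0.
\begin{align*}
g(t) &= \tfrac{1}{2} t^{2} a^{2} - t\epsilon, \qquad
g'(t) = t a^{2} - \epsilon = 0 \;\Rightarrow\; t^{\ast} = \frac{\epsilon}{a^{2}}, \\
\delta(a) &= 2 e^{g(t^{\ast})} = 2 \exp\!\left(-\frac{\epsilon^{2}}{2 a^{2}}\right).
\end{align*}
% Illustrative numbers with \epsilon = 0.2:
%   a = 0.10:  \delta = 2 e^{-2} \approx 0.271
%   a = 0.05:  \delta = 2 e^{-8} \approx 6.7 \times 10^{-4}
```

Halving $a$ in this example shrinks the failure probability by more than two orders of magnitude, which illustrates how sensitive the guarantee is to the range of the train/test gap.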
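However, Theorem 5.1 only proves that it is beneficial to remove an easily predictable subset from $\mathcal{D}_{1:M}$ to obtain $\mathcal{D}'_{1:M}$; it does not show how many data points we should remove from $\mathcal{D}_{1:M}$. In fact, removing more data points that satisfy $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$ yields a better $\mathcal{D}'_{1:M}$, as shown in Theorem 5.2.

**Theorem 5.2.** Let $\mathcal{D}_{1:M}$ ($M \ge 2$) be a sampled meta-dataset of measure $\mu_{N,\tau}$. Let $\mathcal{D}_{s_1}$ and $\mathcal{D}_{s_2}$ be two subsets of $\mathcal{D}_{1:M}$ that satisfy $\mathcal{D}_{s_2} \subset \mathcal{D}_{s_1} \subset \mathcal{D}_{1:M}$ and $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_{s_1}$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$. Let $\mathcal{D}^{1}_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_{s_1}$ and $\mathcal{D}^{2}_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_{s_2}$. For any $\epsilon > 0$ and $1 > w_0 > 0$, $1 > w_1 > 0$ ($w_0 + w_1 = 1$), if $w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{1}_{1:M}}(\theta)$ and $w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{2}_{1:M}}(\theta)$ are $\delta_1$-stable and $\delta_2$-stable, respectively, then $\delta_1 \le \delta_2$.

*Proof.* Assume that $R_\tau(\theta) - (w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{1}_{1:M}}(\theta)) \in [-a_1, a_1]$ and $R_\tau(\theta) - (w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{2}_{1:M}}(\theta)) \in [-a_2, a_2]$ are two random variables with zero mean, where $a_1 = \sup\{R_\tau(\theta) - (w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{1}_{1:M}}(\theta))\}$ and $a_2 = \sup\{R_\tau(\theta) - (w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{2}_{1:M}}(\theta))\}$. Then, based on Markov's inequality and Hoeffding's lemma, we have

$$P\left(\left|R_\tau(\theta) - \left(w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{1}_{1:M}}(\theta)\right)\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}},$$

$$P\left(\left|R_\tau(\theta) - \left(w_0 R_{\mathcal{D}_{1:M}}(\theta) + w_1 R_{\mathcal{D}^{2}_{1:M}}(\theta)\right)\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_2^2}}{e^{t\epsilon}}.$$

Since $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_{s_1}$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$ and $\mathcal{D}_{s_2} \subset \mathcal{D}_{s_1} \subset \mathcal{D}_{1:M}$, we have $R_{\mathcal{D}^{1}_{1:M}}(\theta) \ge R_{\mathcal{D}^{2}_{1:M}}(\theta)$. Thus, $\sup\{R_{\mathcal{D}^{2}_{1:M}}(\theta) - R_{\mathcal{D}^{1}_{1:M}}(\theta)\} \le 0$, and since $a_1 = a_2 + w_1 \sup\{R_{\mathcal{D}^{2}_{1:M}}(\theta) - R_{\mathcal{D}^{1}_{1:M}}(\theta)\}$, we obtain $a_1 \le a_2$. We can then conclude that $\delta_1 \le \delta_2$. $\square$

¹ (Extended version of Markov's inequality) Let $x$ be a real-valued non-negative random variable and $\phi(\cdot)$ be a nondecreasing nonnegative function with $\phi(\epsilon) > 0$. Then, for any $\epsilon > 0$, $P(x \ge \epsilon) \le \frac{\mathbb{E}[\phi(x)]}{\phi(\epsilon)}$.

² (Hoeffding's lemma) Let $x$ be a real-valued random variable with zero mean and $P(x \in [a, b]) = 1$. Then, for any $z \in \mathbb{R}$, $\mathbb{E}[e^{zx}] \le \exp\left(\frac{1}{8} z^2 (b - a)^2\right)$.

Theorem 5.2 indicates that removing more data points that are easy to predict yields more stable NPs. Therefore, it is desirable to choose $\mathcal{D}'_{1:M}$ (i.e., $\mathcal{D}_{1:M} - \mathcal{D}_s$) as the whole set that is harder to predict than average, i.e., the whole set satisfying $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}'_{1:M}$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) > R_{\mathcal{D}_{1:M}}(\theta)$. Without loss of generality, we can extend Theorem 5.2 by considering several hard-to-predict sets and obtain a more stable solution by minimizing over them all together. In this case, we need to prove that the stability of a model with $K$ subsets is better than that of a model with $K - 1$ subsets, as shown in Theorem 5.3.

**Theorem 5.3.** Let $\mathcal{D}_{1:M}$ ($M \ge 2$) be a sampled meta-dataset of measure $\mu_{N,\tau}$. Let $\mathcal{D}_{s_1}, \mathcal{D}_{s_2}, \dots, \mathcal{D}_{s_K} \subset \mathcal{D}_{1:M}$ be $K$ subsets that satisfy $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_{s_k}$ ($k \in [K]$), $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$. Let $\mathcal{D}^{k}_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_{s_k}$ for all $k \in [K]$. For any $\epsilon > 0$ and $1 > w_k > 0$ for all $k$ ($w_0 + w_1 + \dots + w_K = 1$), if $w_0 R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta)$ and $(w_0 + w_K) R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K-1} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta)$ are $\delta_1$-stable and $\delta_2$-stable, respectively, then $\delta_1 \le \delta_2$.

*Proof.* For simplicity, we denote $w_0 R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta)$ as $R_1$ and $(w_0 + w_K) R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K-1} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta)$ as $R_2$. Assume that $R_\tau(\theta) - R_1 \in [-a_1, a_1]$ and $R_\tau(\theta) - R_2 \in [-a_2, a_2]$ are two random variables with zero mean, where $a_1 = \sup\{R_\tau(\theta) - R_1\}$ and $a_2 = \sup\{R_\tau(\theta) - R_2\}$. Then, based on Markov's inequality and Hoeffding's lemma, we have

$$P\left(\left|R_\tau(\theta) - R_1\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}}, \qquad P\left(\left|R_\tau(\theta) - R_2\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_2^2}}{e^{t\epsilon}}.$$

Similar to the proof of Theorem 5.1, we have $R_{\mathcal{D}^{K}_{1:M}}(\theta) \ge R_{\mathcal{D}_{1:M}}(\theta)$, which indicates that $\sup\{R_{\mathcal{D}_{1:M}}(\theta) - R_{\mathcal{D}^{K}_{1:M}}(\theta)\} \le 0$. Since $R_\tau(\theta) - R_1 = R_\tau(\theta) - R_2 + w_K (R_{\mathcal{D}_{1:M}}(\theta) - R_{\mathcal{D}^{K}_{1:M}}(\theta))$, we have $a_1 = a_2 + w_K \sup\{R_{\mathcal{D}_{1:M}}(\theta) - R_{\mathcal{D}^{K}_{1:M}}(\theta)\}$, so $a_1 \le a_2$. Thus we can conclude that $\frac{2 e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}} \le \frac{2 e^{\frac{1}{2} t^2 a_2^2}}{e^{t\epsilon}}$, i.e., $\delta_1 \le \delta_2$. $\square$

Based on Theorem 5.3, optimizing over $\mathcal{D}_{1:M}$ together with more than one hard-to-predict subset of $\mathcal{D}_{1:M}$ achieves more stable prediction. However, selecting data points that are difficult to predict from $\mathcal{D}_{1:M}$ remains a challenging problem, especially since we need to select $K$ subsets.

### 5.1 The Proposed Solution

According to the analysis of stability in NPs, we propose a stable solution for NPs that achieves model stability with the aid of hard-to-predict subset selection.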
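Specifically, we introduce a solution that obtains those hard-to-predict subsets from only one set of easily predicted data points. It can be broken into four steps:

(1) Select an existing NP model (e.g., CNP [9], NP [10]) and train it with the meta-dataset $\mathcal{D}_{1:M}$;
(2) Select an easily predicted subset $\mathcal{D}_s \subset \mathcal{D}_{1:M}$ that satisfies $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_s$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$;
(3) Divide $\mathcal{D}_s$ into $K$ non-overlapping subsets $\mathcal{D}_{s_1}, \mathcal{D}_{s_2}, \dots, \mathcal{D}_{s_K}$ satisfying $\cup_{k=1}^{K} \mathcal{D}_{s_k} = \mathcal{D}_s$;
(4) Define $K$ subsets that are difficult to predict, i.e., $\mathcal{D}^{k}_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_{s_k}$ for all $k \in [K]$.

Thus, a new extension of NPs is given as

$$\theta^{*} = \arg\min_{\theta}\; w_0 R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta), \tag{10}$$

where $w_0, w_1, \dots, w_K$ indicate the contributions of each component and satisfy $\sum_{k=0}^{K} w_k = 1$.

The whole learning algorithm is given in Algorithm 1. Lines 1 to 12 produce $K$ different hard-to-predict subsets, and the selection and splitting in these lines cost $O(MN)$. The complexity of the optimization in line 13 depends on the applied NP model (such as CNP, NP, ANP, etc.). Thus, the computational complexity of NPs with our stable solution is similar to that of the original NPs. As shown in Algorithm 1, we need to pre-train the base model in order to select samples. The pre-trained model is useful not only for sample selection: its parameters can also initialize the stable version of the model, in which case training of the stable version converges faster.

Algorithm 1 Learning algorithm for stable NPs
Input: Meta-dataset $\mathcal{D}_{1:M}$; $\beta > 0.5$, a predefined probability for data selection; $\mathcal{D}_s = \emptyset$.
1: Train an NP model with parameter $\theta$ based on $R_{\mathcal{D}_{1:M}}(\theta)$;
2: for $(\mathbf{x}_i, y_i) \in \mathcal{D}_{1:M}$ do
3: &nbsp;&nbsp;randomly generate split parameter $\rho_i \in [0, 1]$;
4: &nbsp;&nbsp;if ($-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$ and $\rho_i \le \beta$)
5: &nbsp;&nbsp;or ($-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) > R_{\mathcal{D}_{1:M}}(\theta)$ and $\rho_i \le 1 - \beta$) then
6: &nbsp;&nbsp;&nbsp;&nbsp;$\mathcal{D}_s \leftarrow \mathcal{D}_s \cup (\mathbf{x}_i, y_i)$;
7: &nbsp;&nbsp;end if
8: end for
9: divide $\mathcal{D}_s$ into $\mathcal{D}_{s_1}, \mathcal{D}_{s_2}, \dots, \mathcal{D}_{s_K}$ with $\cup_{k=1}^{K} \mathcal{D}_{s_k} = \mathcal{D}_s$;
10: for $k = 1, 2, \dots, K$ do
11: &nbsp;&nbsp;$\mathcal{D}^{k}_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_{s_k}$;
12: end for
13: update parameters by optimizing $\theta^{*} = \arg\min_{\theta} w_0 R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta)$;
Output: The learned optimal parameters $\theta^{*}$.

**Stability Guarantee.** Here we give a theoretical guarantee for the proposed stable solution.

**Theorem 5.4.** Let $\mathcal{D}_{1:M}$ ($M \ge 2$) be a sampled meta-dataset of measure $\mu_{N,\tau}$. Let $\mathcal{D}_s \subset \mathcal{D}_{1:M}$ satisfy $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_s$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$. Divide $\mathcal{D}_s$ into $K$ subsets $\mathcal{D}_{s_1}, \mathcal{D}_{s_2}, \dots, \mathcal{D}_{s_K}$ that satisfy $\cup_{k=1}^{K} \mathcal{D}_{s_k} = \mathcal{D}_s$. Let $\mathcal{D}^{0}_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_s$ and $\mathcal{D}^{k}_{1:M} = \mathcal{D}_{1:M} - \mathcal{D}_{s_k}$ for all $k \in [K]$. For any $\epsilon > 0$ and $1 > w_k > 0$ for all $k$ ($w_0 + w_1 + \dots + w_K = 1$), if $w_0 R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta)$ and $w_0 R_{\mathcal{D}_{1:M}}(\theta) + (1 - w_0) R_{\mathcal{D}^{0}_{1:M}}(\theta)$ are $\delta_1$-stable and $\delta_2$-stable, respectively, then $\delta_1 \le \delta_2$.

*Proof.* For simplicity, we denote $w_0 R_{\mathcal{D}_{1:M}}(\theta) + \sum_{k=1}^{K} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta)$ as $R_1$ and $w_0 R_{\mathcal{D}_{1:M}}(\theta) + (1 - w_0) R_{\mathcal{D}^{0}_{1:M}}(\theta)$ as $R_2$. Assume that $R_\tau(\theta) - R_1 \in [-a_1, a_1]$ and $R_\tau(\theta) - R_2 \in [-a_2, a_2]$ are two random variables with zero mean, where $a_1 = \sup\{R_\tau(\theta) - R_1\}$ and $a_2 = \sup\{R_\tau(\theta) - R_2\}$. Then, based on Markov's inequality and Hoeffding's lemma, we have

$$P\left(\left|R_\tau(\theta) - R_1\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_1^2}}{e^{t\epsilon}}, \qquad P\left(\left|R_\tau(\theta) - R_2\right| \le \epsilon\right) \ge 1 - \frac{2 e^{\frac{1}{2} t^2 a_2^2}}{e^{t\epsilon}}.$$

Since $\forall k \in [K]$, $\mathcal{D}_{s_k} \subset \mathcal{D}_s$ and $\forall (\mathbf{x}_i, y_i) \in \mathcal{D}_s$, $-\log P_\theta(y_i \mid \mathbf{x}_i, \mathcal{D}_C) \le R_{\mathcal{D}_{1:M}}(\theta)$, we have $R_{\mathcal{D}^{k}_{1:M}}(\theta) \ge R_{\mathcal{D}^{0}_{1:M}}(\theta)$. By combining the above inequalities over all $k \in [K]$, we have

$$\sum_{k=1}^{K} w_k R_{\mathcal{D}^{k}_{1:M}}(\theta) \ge \sum_{k=1}^{K} w_k R_{\mathcal{D}^{0}_{1:M}}(\theta) = (1 - w_0) R_{\mathcal{D}^{0}_{1:M}}(\theta). \tag{12}$$

Thus, $\sup\{R_\tau(\theta) - R_1\} \le \sup\{R_\tau(\theta) - R_2\}$, i.e., $a_1 \le a_2$, and hence $\delta_1 \le \delta_2$. $\square$

According to the above theorem, we can achieve model stability by selecting only one easily predicted subset.

## 6 Experiments

We started with learning predictive functions on synthetic datasets. We then performed high-dimensional tasks, e.g., system identification on physics engines, image completion, and Bayesian optimization, to evaluate the properties of the NP-related models.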
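Returning to Algorithm 1 in Section 5.1, its selection, splitting, and objective steps can be sketched in a few lines. This is a minimal sketch under stated assumptions: per-point negative log-likelihoods from the pre-trained model are taken as given in a flat list, the comparisons $\rho_i \le \beta$ and $\rho_i \le 1 - \beta$ are an assumed reading of lines 4 and 5, and the round-robin split and all function names are illustrative choices, not the authors' implementation.

```python
import random


def mean(values):
    return sum(values) / len(values)


def select_easy_subset(losses, beta, rng):
    """Lines 2-8: draw rho_i uniformly per point; keep an easy point
    (loss <= average) if rho_i <= beta and a hard point if rho_i <= 1 - beta."""
    average = mean(losses)
    selected = []
    for index, loss in enumerate(losses):
        rho = rng.random()
        if (loss <= average and rho <= beta) or (loss > average and rho <= 1 - beta):
            selected.append(index)
    return selected


def split_into_k(indices, k):
    """Line 9: partition D_s into K non-overlapping parts (round-robin here)."""
    return [indices[j::k] for j in range(k)]


def stable_objective(losses, subsets, weights):
    """Equation (10): w_0 R_D + sum_k w_k R_{D^k}, where D^k = D minus D_{s_k}."""
    total = weights[0] * mean(losses)
    for w_k, subset in zip(weights[1:], subsets):
        removed = set(subset)
        kept = [loss for i, loss in enumerate(losses) if i not in removed]
        total += w_k * mean(kept)
    return total


# Made-up per-point negative log-likelihoods (illustration only).
point_losses = [0.1, 0.2, 0.3, 0.9, 1.2, 0.15, 0.25, 1.5]
easy = select_easy_subset(point_losses, beta=0.9, rng=random.Random(0))
parts = split_into_k(easy, k=2)
objective = stable_objective(point_losses, parts, weights=[0.5, 0.25, 0.25])
```

In an actual training loop the objective would be differentiated with respect to $\theta$ and minimized as in line 13; here it is only evaluated once to show how removing easy points from each $\mathcal{D}^{k}_{1:M}$ shifts weight toward hard-to-predict points.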
### 6.1 1D Regression

To verify the proposed stable solution, we combined it with different baseline NP classes (CNP [9], NP [10], ANP [14], ConvCNP [11], ConvNP [6], and their bootstrapping versions [18]) and compared them on the 1D regression task. Among them, BCNP, BNP, BANP, BConvCNP, and BConvNP are recently proposed stable strategies for NPs based on bootstrap. Specifically, a stochastic process (SP) initialized with a zero-mean Gaussian process (GP) $y^{(0)} \sim \mathcal{GP}(0, k(\cdot, \cdot))$ indexed on the interval $x \in [-2.0, 2.0]$ was used to generate data, where the radial basis function kernel and the Matern kernel were adopted for the model-data mismatch scenario. More detailed information can be found in the Appendix.

We investigated the model performance under different noise settings. We introduced Gaussian noise $\mathcal{N}(0, 1)$ and added it to different proportions of the data, namely {0%, 5%, 10%, 15%}. Table 1 lists the average log-likelihoods for the different noise proportions. The best result is marked in bold. First, when the stable solution is adopted in a baseline, the model achieves the best results on all the datasets, showing the effectiveness of the stable solution. Besides, all methods become less accurate in more complicated settings, while our solution is less affected. One interesting observation is that the improvements over the base model are less significant for CNP than for NP and ANP. A possible reason is that CNP only predicts points outside the context set.

Table 1: Average log-likelihoods over all context and target points on realizations from the synthetic stochastic process for different percentages of added noise. Here we set the context size to 20 (mean ± std). A model name prefixed with S denotes the model with our stable solution.

| Kernel | Method | Original context | Original target | Noise(+5%) context | Noise(+5%) target | Noise(+10%) context | Noise(+10%) target | Noise(+15%) context | Noise(+15%) target |
|---|---|---|---|---|---|---|---|---|---|
| RBF | CNP | 0.8724 ± 0.008 | 0.4334 ± 0.007 | 0.8522 ± 0.005 | 0.4001 ± 0.010 | 0.8014 ± 0.006 | 0.3552 ± 0.004 | 0.7152 ± 0.006 | 0.2853 ± 0.005 |
| | BCNP | 0.9042 ± 0.009 | 0.4589 ± 0.006 | 0.8774 ± 0.006 | 0.4278 ± 0.008 | 0.8316 ± 0.006 | 0.3767 ± 0.005 | 0.7487 ± 0.007 | 0.3017 ± 0.006 |
| | SCNP | 0.9255 ± 0.008 | 0.4733 ± 0.004 | 0.8935 ± 0.005 | 0.4478 ± 0.006 | 0.8517 ± 0.004 | 0.3986 ± 0.005 | 0.7621 ± 0.005 | 0.3279 ± 0.006 |
| | NP | 0.8215 ± 0.004 | 0.3853 ± 0.005 | 0.8011 ± 0.004 | 0.3511 ± 0.006 | 0.7611 ± 0.005 | 0.3042 ± 0.008 | 0.6722 ± 0.005 | 0.2435 ± 0.007 |
| | BNP | 0.8722 ± 0.004 | 0.4211 ± 0.004 | 0.8321 ± 0.003 | 0.3876 ± 0.004 | 0.7922 ± 0.004 | 0.3389 ± 0.005 | 0.7189 ± 0.004 | 0.2776 ± 0.007 |
| | SNP | 0.8955 ± 0.003 | 0.4356 ± 0.004 | 0.8567 ± 0.003 | 0.4046 ± 0.005 | 0.8165 ± 0.004 | 0.3568 ± 0.006 | 0.7356 ± 0.005 | 0.2955 ± 0.006 |
| | ANP | 1.2563 ± 0.002 | 0.5763 ± 0.004 | 1.2245 ± 0.007 | 0.5347 ± 0.006 | 0.1742 ± 0.005 | 0.4871 ± 0.007 | 0.9821 ± 0.005 | 0.4151 ± 0.004 |
| | BANP | 1.2722 ± 0.004 | 0.5887 ± 0.006 | 1.2411 ± 0.005 | 0.5471 ± 0.005 | 0.1886 ± 0.006 | 0.4917 ± 0.006 | 1.0642 ± 0.006 | 0.4327 ± 0.005 |
| | SANP | 1.2831 ± 0.000 | 0.5994 ± 0.004 | 1.2564 ± 0.004 | 0.5578 ± 0.004 | 1.2052 ± 0.004 | 0.5025 ± 0.006 | 1.1243 ± 0.005 | 0.4356 ± 0.006 |
| | ConvCNP | 1.2631 ± 0.002 | 0.6421 ± 0.002 | 1.2333 ± 0.005 | 0.5415 ± 0.005 | 0.1827 ± 0.004 | 0.4936 ± 0.005 | 1.0241 ± 0.006 | 0.4262 ± 0.005 |
| | BConvCNP | 1.2761 ± 0.004 | 0.6531 ± 0.005 | 1.2476 ± 0.004 | 0.5533 ± 0.004 | 0.1931 ± 0.007 | 0.4986 ± 0.005 | 1.0716 ± 0.006 | 0.4396 ± 0.005 |
| | SConvCNP | 1.3991 ± 0.001 | 0.6793 ± 0.004 | 1.2651 ± 0.003 | 0.5623 ± 0.005 | 1.2126 ± 0.004 | 0.5096 ± 0.005 | 1.1331 ± 0.006 | 0.4461 ± 0.005 |
| | ConvNP | 1.2874 ± 0.003 | 0.6503 ± 0.004 | 1.2371 ± 0.006 | 0.5451 ± 0.006 | 0.1865 ± 0.004 | 0.4965 ± 0.006 | 0.9915 ± 0.005 | 0.4335 ± 0.006 |
| | BConvNP | 1.2922 ± 0.004 | 0.6627 ± 0.006 | 1.2505 ± 0.006 | 0.5583 ± 0.004 | 0.1971 ± 0.005 | 0.5025 ± 0.007 | 1.0731 ± 0.006 | 0.4436 ± 0.004 |
| | SConvNP | 1.4036 ± 0.002 | 0.6831 ± 0.003 | 1.2671 ± 0.005 | 0.5675 ± 0.004 | 1.2188 ± 0.003 | 0.5157 ± 0.005 | 1.1389 ± 0.004 | 0.4505 ± 0.005 |
| Matern | CNP | 0.8531 ± 0.005 | 0.2431 ± 0.010 | 0.8231 ± 0.005 | 0.2144 ± 0.010 | 0.7761 ± 0.008 | 0.1784 ± 0.007 | 0.7052 ± 0.005 | 0.1452 ± 0.006 |
| | BCNP | 0.8778 ± 0.005 | 0.2762 ± 0.009 | 0.8487 ± 0.006 | 0.2477 ± 0.009 | 0.8015 ± 0.007 | 0.2051 ± 0.007 | 0.7378 ± 0.005 | 0.1766 ± 0.006 |
| | SCNP | 0.8963 ± 0.003 | 0.2953 ± 0.006 | 0.8689 ± 0.005 | 0.2658 ± 0.007 | 0.8268 ± 0.006 | 0.2258 ± 0.006 | 0.7567 ± 0.004 | 0.1936 ± 0.005 |
| | NP | 0.7643 ± 0.015 | 0.2041 ± 0.015 | 0.7342 ± 0.002 | 0.1725 ± 0.008 | 0.6892 ± 0.004 | 0.1542 ± 0.006 | 0.6235 ± 0.008 | 0.1342 ± 0.007 |
| | BNP | 0.8156 ± 0.005 | 0.2689 ± 0.007 | 0.7789 ± 0.004 | 0.2215 ± 0.005 | 0.7421 ± 0.005 | 0.2117 ± 0.007 | 0.6715 ± 0.006 | 0.1828 ± 0.006 |
| | SNP | 0.8368 ± 0.006 | 0.2844 ± 0.005 | 0.8036 ± 0.003 | 0.2483 ± 0.003 | 0.7635 ± 0.004 | 0.2325 ± 0.006 | 0.6973 ± 0.006 | 0.2016 ± 0.005 |
| | ANP | 1.2421 ± 0.002 | 0.6366 ± 0.004 | 1.2115 ± 0.001 | 0.6001 ± 0.008 | 1.1784 ± 0.004 | 0.1622 ± 0.006 | 1.1252 ± 0.007 | 0.5274 ± 0.008 |
| | BANP | 1.3456 ± 0.003 | 0.6514 ± 0.005 | 1.3125 ± 0.005 | 0.6115 ± 0.002 | 1.2672 ± 0.004 | 0.1711 ± 0.005 | 1.2236 ± 0.006 | 0.5306 ± 0.006 |
| | SANP | 1.3721 ± 0.002 | 0.6653 ± 0.004 | 1.3461 ± 0.003 | 0.6256 ± 0.004 | 1.3011 ± 0.003 | 0.1782 ± 0.004 | 1.2457 ± 0.005 | 0.5356 ± 0.002 |
| | ConvCNP | 1.2515 ± 0.003 | 0.6418 ± 0.004 | 1.2226 ± 0.006 | 0.6085 ± 0.005 | 1.1832 ± 0.005 | 0.1871 ± 0.007 | 1.1326 ± 0.005 | 0.5351 ± 0.004 |
| | BConvCNP | 1.3527 ± 0.005 | 0.6616 ± 0.006 | 1.3252 ± 0.005 | 0.6235 ± 0.007 | 1.2767 ± 0.006 | 0.1952 ± 0.005 | 1.1315 ± 0.006 | 0.5417 ± 0.006 |
| | SConvCNP | 1.3852 ± 0.003 | 0.6731 ± 0.004 | 1.3364 ± 0.004 | 0.6335 ± 0.005 | 1.2831 ± 0.003 | 0.2037 ± 0.005 | 1.1521 ± 0.005 | 0.5557 ± 0.004 |
| | ConvNP | 1.2746 ± 0.002 | 0.6557 ± 0.005 | 1.2345 ± 0.007 | 0.6015 ± 0.005 | 1.1865 ± 0.006 | 0.1943 ± 0.005 | 1.1358 ± 0.005 | 0.5397 ± 0.003 |
| | BConvNP | 1.3356 ± 0.004 | 0.6787 ± 0.006 | 1.3305 ± 0.005 | 0.6383 ± 0.006 | 1.2851 ± 0.005 | 0.2015 ± 0.005 | 1.1415 ± 0.006 | 0.5338 ± 0.003 |
| | SConvNP | 1.3878 ± 0.002 | 0.6836 ± 0.004 | 1.3435 ± 0.005 | 0.6417 ± 0.004 | 1.2866 ± 0.004 | 0.2025 ± 0.005 | 1.1521 ± 0.005 | 0.5363 ± 0.006 |

Table 2: Experiments on 1D regression data with Periodic kernel.
| Method | Original (context) | Original (target) | Noise(+15%) (context) | Noise(+15%) (target) |
|---|---|---|---|---|
| ANP | 0.5730 ± 0.006 | -4.2345 ± 0.005 | 0.3521 ± 0.005 | -5.3211 ± 0.007 |
| ConvCNP | 0.5983 ± 0.006 | -4.0215 ± 0.005 | 0.3658 ± 0.005 | -4.5233 ± 0.008 |
| ConvNP | 0.6125 ± 0.005 | -3.8952 ± 0.006 | 0.3756 ± 0.006 | -4.3413 ± 0.008 |
| BANP | 0.6253 ± 0.003 | -3.5413 ± 0.005 | 0.3651 ± 0.006 | -4.2511 ± 0.015 |
| BConvCNP | 0.6342 ± 0.005 | -3.4142 ± 0.004 | 0.3712 ± 0.004 | -4.1750 ± 0.008 |
| BConvNP | 0.6355 ± 0.004 | -3.3627 ± 0.005 | 0.3768 ± 0.008 | -4.0116 ± 0.007 |
| SANP | 0.6315 ± 0.002 | -3.3317 ± 0.004 | 0.3748 ± 0.004 | -4.0515 ± 0.003 |
| SConvCNP | 0.6433 ± 0.002 | -3.1515 ± 0.004 | 0.3866 ± 0.005 | -3.9851 ± 0.004 |
| SConvNP | 0.6551 ± 0.000 | -3.1062 ± 0.004 | 0.3981 ± 0.002 | -3.8895 ± 0.004 |

To investigate the models' ability to handle model-data mismatch, we conducted experiments on 1D regression tasks with the Periodic kernel. Following the setting of BANP and the noise settings used for the previous kernels, we list the results on both the original data and the Noise(+15%) data in Table 2. Even under this model-data mismatch, the stable versions still significantly outperform their corresponding original versions.

6.2 Image Completion

Image completion can be regarded as a 2D function regression task, and an image can be interpreted as being generated from a stochastic process (since there are dependencies between pixel values). Following the setting in previous work [14], we trained the NPs on EMNIST [4] and 32 × 32 CELEBA [23] using the standard train/test split with up to 200 context/target points at training. Detailed experiment settings are given in the Appendix. We evaluated average log-likelihoods over all points on realizations from image completion. Table 3 compares NPs with and without our stable solution under the original and noisy settings, and the results demonstrate the advantage of our stable solution.

6.3 System Identification on Physics Engines

The second synthetic experiment focuses on evaluating model dynamics on a classical simulator, the Cart-Pole system, which is detailed in [7, 27].
The Cart-Pole swing-up task is a standard benchmark for nonlinear control due to the non-linearity in the dynamics and the requirement for nonlinear controllers to successfully swing up and balance the pendulum. More detailed experimental settings are given in the Appendix.

Table 3: Average log-likelihoods over all context and target points on EMNIST and CELEBA.

| Dataset | Method | Original (context) | Original (target) | Noise(+5%) (context) | Noise(+5%) (target) | Noise(+10%) (context) | Noise(+10%) (target) | Noise(+15%) (context) | Noise(+15%) (target) |
|---|---|---|---|---|---|---|---|---|---|
| EMNIST | CNP | 0.9522 ± 0.023 | 0.7515 ± 0.0015 | 0.8977 ± 0.0016 | 0.6336 ± 0.017 | 0.8242 ± 0.0018 | 0.5784 ± 0.009 | 0.6566 ± 0.0017 | 0.5341 ± 0.016 |
| EMNIST | BCNP | 0.9678 ± 0.010 | 0.8058 ± 0.008 | 0.9015 ± 0.008 | 0.6711 ± 0.009 | 0.8415 ± 0.007 | 0.6089 ± 0.009 | 0.6788 ± 0.006 | 0.5715 ± 0.006 |
| EMNIST | SCNP | 0.9716 ± 0.008 | 0.8343 ± 0.006 | 0.9251 ± 0.008 | 0.6971 ± 0.007 | 0.8674 ± 0.006 | 0.6343 ± 0.007 | 0.6986 ± 0.005 | 0.5877 ± 0.005 |
| EMNIST | NP | 0.9678 ± 0.004 | 0.7756 ± 0.005 | 0.9011 ± 0.009 | 0.6941 ± 0.006 | 0.8544 ± 0.009 | 0.6455 ± 0.007 | 0.7034 ± 0.009 | 0.5865 ± 0.006 |
| EMNIST | BNP | 0.9757 ± 0.005 | 0.8358 ± 0.005 | 0.8116 ± 0.007 | 0.7625 ± 0.006 | 0.8759 ± 0.007 | 0.6773 ± 0.007 | 0.7451 ± 0.005 | 0.6237 ± 0.005 |
| EMNIST | SNP | 0.9847 ± 0.005 | 0.8562 ± 0.006 | 0.8368 ± 0.006 | 0.7844 ± 0.005 | 0.8984 ± 0.005 | 0.6984 ± 0.005 | 0.7653 ± 0.005 | 0.6456 ± 0.004 |
| EMNIST | ANP | 1.1125 ± 0.002 | 1.0321 ± 0.004 | 0.9815 ± 0.002 | 0.6366 ± 0.006 | 0.9021 ± 0.004 | 0.7053 ± 0.008 | 0.8454 ± 0.002 | 0.7034 ± 0.005 |
| EMNIST | BANP | 1.1355 ± 0.003 | 1.0615 ± 0.005 | 1.0236 ± 0.002 | 0.6549 ± 0.005 | 0.9155 ± 0.004 | 0.7521 ± 0.006 | 0.8612 ± 0.003 | 0.7515 ± 0.005 |
| EMNIST | SANP | 1.1531 ± 0.000 | 1.0877 ± 0.004 | 1.0421 ± 0.002 | 0.6776 ± 0.005 | 0.9321 ± 0.002 | 0.7843 ± 0.006 | 0.8732 ± 0.003 | 0.7657 ± 0.005 |
| EMNIST | ConvCNP | 1.1363 ± 0.002 | 1.0461 ± 0.004 | 1.0252 ± 0.006 | 0.6448 ± 0.005 | 0.9116 ± 0.005 | 0.7115 ± 0.006 | 0.8621 ± 0.005 | 0.7246 ± 0.004 |
| EMNIST | BConvCNP | 1.1425 ± 0.004 | 1.0787 ± 0.006 | 1.0311 ± 0.005 | 0.6626 ± 0.005 | 0.9252 ± 0.006 | 0.7617 ± 0.006 | 0.8717 ± 0.006 | 0.7627 ± 0.005 |
| EMNIST | SConvCNP | 1.1631 ± 0.003 | 1.0894 ± 0.004 | 1.0563 ± 0.004 | 0.6778 ± 0.004 | 0.9356 ± 0.004 | 0.7885 ± 0.005 | 0.8813 ± 0.005 | 0.7692 ± 0.006 |
| EMNIST | ConvNP | 1.1415 ± 0.002 | 1.0563 ± 0.004 | 1.0286 ± 0.006 | 0.6536 ± 0.006 | 0.9168 ± 0.005 | 0.7171 ± 0.007 | 0.8675 ± 0.005 | 0.7368 ± 0.005 |
| EMNIST | BConvNP | 1.1526 ± 0.004 | 1.0837 ± 0.005 | 1.0415 ± 0.005 | 0.6684 ± 0.005 | 0.9285 ± 0.006 | 0.7762 ± 0.006 | 0.8742 ± 0.006 | 0.7727 ± 0.005 |
| EMNIST | SConvNP | 1.1753 ± 0.003 | 1.0934 ± 0.004 | 1.0641 ± 0.004 | 0.6837 ± 0.005 | 0.9402 ± 0.005 | 0.7925 ± 0.006 | 0.8923 ± 0.005 | 0.7853 ± 0.005 |
| CELEBA | CNP | 1.0323 ± 0.016 | 0.7845 ± 0.013 | 1.0177 ± 0.016 | 0.7438 ± 0.017 | 0.8956 ± 0.009 | 0.7344 ± 0.011 | 0.7677 ± 0.012 | 0.6096 ± 0.009 |
| CELEBA | BCNP | 1.0452 ± 0.009 | 0.8015 ± 0.008 | 1.0275 ± 0.009 | 0.7726 ± 0.008 | 0.9351 ± 0.006 | 0.8376 ± 0.009 | 0.8015 ± 0.010 | 0.6816 ± 0.008 |
| CELEBA | SCNP | 1.0525 ± 0.008 | 0.8243 ± 0.006 | 1.0348 ± 0.008 | 0.7868 ± 0.006 | 0.9562 ± 0.004 | 0.8545 ± 0.008 | 0.8344 ± 0.006 | 0.7045 ± 0.005 |
| CELEBA | NP | 1.1333 ± 0.004 | 0.8766 ± 0.005 | 1.1043 ± 0.015 | 0.8355 ± 0.015 | 1.0034 ± 0.008 | 0.8456 ± 0.006 | 0.8935 ± 0.009 | 0.6893 ± 0.006 |
| CELEBA | BNP | 1.1732 ± 0.005 | 0.8901 ± 0.006 | 1.1378 ± 0.007 | 0.8678 ± 0.006 | 1.0411 ± 0.008 | 0.8711 ± 0.005 | 0.9256 ± 0.008 | 0.7671 ± 0.006 |
| CELEBA | SNP | 1.1952 ± 0.005 | 0.9062 ± 0.006 | 1.1565 ± 0.005 | 0.8846 ± 0.005 | 1.0542 ± 0.006 | 0.8956 ± 0.004 | 0.9425 ± 0.006 | 0.7985 ± 0.005 |
| CELEBA | ANP | 1.1633 ± 0.002 | 1.0163 ± 0.004 | 1.1377 ± 0.004 | 0.9866 ± 0.006 | 1.0418 ± 0.004 | 0.8845 ± 0.006 | 0.9363 ± 0.004 | 0.7346 ± 0.008 |
| CELEBA | BANP | 1.1751 ± 0.002 | 1.0389 ± 0.005 | 1.1488 ± 0.004 | 1.0155 ± 0.005 | 1.0602 ± 0.005 | 0.9255 ± 0.006 | 0.9489 ± 0.004 | 0.8415 ± 0.007 |
| CELEBA | SANP | 1.1854 ± 0.000 | 1.0594 ± 0.004 | 1.1685 ± 0.002 | 1.0353 ± 0.004 | 1.0772 ± 0.002 | 0.9455 ± 0.004 | 0.9655 ± 0.003 | 0.8673 ± 0.005 |
| CELEBA | ConvCNP | 1.1697 ± 0.004 | 1.0366 ± 0.004 | 1.1445 ± 0.006 | 0.9947 ± 0.006 | 1.0542 ± 0.005 | 0.8971 ± 0.007 | 0.9521 ± 0.005 | 0.7563 ± 0.006 |
| CELEBA | BConvCNP | 1.1822 ± 0.004 | 1.0467 ± 0.004 | 1.1511 ± 0.005 | 1.0271 ± 0.005 | 1.0686 ± 0.006 | 0.9317 ± 0.005 | 0.9542 ± 0.005 | 0.8527 ± 0.005 |
| CELEBA | SConvCNP | 1.1889 ± 0.003 | 1.0615 ± 0.005 | 1.1753 ± 0.004 | 1.0478 ± 0.004 | 1.0752 ± 0.004 | 0.9485 ± 0.005 | 0.9693 ± 0.007 | 0.8714 ± 0.006 |
| CELEBA | ConvNP | 1.1767 ± 0.004 | 1.0451 ± 0.004 | 1.1485 ± 0.005 | 1.0156 ± 0.006 | 1.0594 ± 0.005 | 0.9066 ± 0.007 | 0.9573 ± 0.004 | 0.7779 ± 0.005 |
| CELEBA | BConvNP | 1.1822 ± 0.004 | 1.0555 ± 0.004 | 1.1573 ± 0.004 | 1.0351 ± 0.005 | 1.0762 ± 0.005 | 0.9461 ± 0.003 | 0.9661 ± 0.006 | 0.8636 ± 0.003 |
| CELEBA | SConvNP | 1.1867 ± 0.003 | 1.0668 ± 0.004 | 1.1846 ± 0.004 | 1.0487 ± 0.005 | 1.0863 ± 0.004 | 0.9511 ± 0.005 | 0.9743 ± 0.006 | 0.8836 ± 0.003 |

Figure 2: The predictive log-likelihood (LL) and mean average error (MAE) on the Cart-Pole state transition testing dataset.

For each configuration of the simulator, including training and testing environments, we sampled 400 trajectories with a horizon of 10 steps using a random controller. During testing, 100 state transition pairs were randomly selected for each configuration of the environment, serving as the maximum set of context points for identifying the configuration of the dynamics. Figure 2 shows the predictive log-likelihood (LL) and mean average error (MAE) on the Cart-Pole state transition testing dataset. Our stable NPs achieve better likelihood and lower prediction error than the original ones, and the variances of the stable versions are consistently smaller than those of all original baselines.

Table 4: Bayesian optimization experiments on data generated by different GP kernels.

| Kernel | ANP | BANP | SANP | ConvCNP | BConvCNP | SConvCNP | ConvNP | BConvNP | SConvNP |
|---|---|---|---|---|---|---|---|---|---|
| RBF | 0.1245 ± 0.003 | 0.1341 ± 0.003 | 0.1142 ± 0.002 | 0.1215 ± 0.002 | 0.1168 ± 0.003 | 0.1037 ± 0.002 | 0.1197 ± 0.002 | 0.1156 ± 0.003 | 0.1015 ± 0.003 |
| Matern | 0.1518 ± 0.003 | 0.1316 ± 0.004 | 0.1201 ± 0.002 | 0.1489 ± 0.003 | 0.1301 ± 0.002 | 0.1216 ± 0.002 | 0.1446 ± 0.002 | 0.1242 ± 0.003 | 0.1204 ± 0.004 |
| Periodic | 0.1892 ± 0.002 | 0.1788 ± 0.005 | 0.1672 ± 0.001 | 0.1652 ± 0.002 | 0.1526 ± 0.004 | 0.1487 ± 0.003 | 0.1611 ± 0.002 | 0.1498 ± 0.004 | 0.1446 ± 0.002 |

6.4 Bayesian Optimization

Following the setting in BANP [18], we conducted a Bayesian optimization experiment. Taking GP data with RBF, Matern, and Periodic prior functions as examples, we report the results of ANP, ConvCNP, ConvNP, their bootstrapping versions, and our stable versions. To keep the comparison consistent, we standardized the initializations and normalized the results. We report the best simple regret, which is the difference between the current best observation and the global optimum.
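The simple regret described above can be illustrated with a short sketch. It is not the evaluation code used in the experiments: the function name `simple_regret_curve`, the `minimize` flag, and the sample values are hypothetical, and the global optimum is assumed to be known, as it is for synthetic objectives.

```python
from itertools import accumulate


def simple_regret_curve(observations, global_optimum, minimize=True):
    """Return the best simple regret after each evaluation.

    Assumption: `observations` holds objective values in query order and
    `global_optimum` is the known optimum of the objective.
    """
    if minimize:
        best_so_far = accumulate(observations, min)  # running minimum
        return [best - global_optimum for best in best_so_far]
    best_so_far = accumulate(observations, max)  # running maximum
    return [global_optimum - best for best in best_so_far]


# Hypothetical objective values from five optimization steps (minimization).
values = [0.9, 0.4, 0.6, 0.15, 0.3]
regret = simple_regret_curve(values, global_optimum=0.1)
print([round(r, 3) for r in regret])
```

The curve is non-increasing by construction, so a method with lower final regret has found a better point within the same evaluation budget.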
As shown in Table 4, our stable solutions consistently achieve lower regret than the other NPs.

Table 5: Predator-prey model results.

| Setting | ANP | BANP | SANP | ConvCNP | BConvCNP | SConvCNP | ConvNP | BConvNP | SConvNP |
|---|---|---|---|---|---|---|---|---|---|
| Simulated-context | 2.5801 ± 0.003 | 2.5912 ± 0.002 | 2.6127 ± 0.004 | 2.5912 ± 0.002 | 2.6068 ± 0.003 | 2.6164 ± 0.002 | 2.5925 ± 0.002 | 2.6125 ± 0.005 | 2.6231 ± 0.003 |
| Simulated-target | 1.8265 ± 0.003 | 1.8635 ± 0.004 | 1.8844 ± 0.002 | 1.8352 ± 0.004 | 1.8524 ± 0.003 | 1.9516 ± 0.002 | 1.9228 ± 0.005 | 1.9442 ± 0.003 | 1.9662 ± 0.004 |
| Real-context | 1.7234 ± 0.002 | 1.8496 ± 0.005 | 1.8412 ± 0.001 | 1.7956 ± 0.002 | 1.8026 ± 0.004 | 1.8744 ± 0.003 | 1.8342 ± 0.002 | 1.8476 ± 0.005 | 1.8796 ± 0.002 |
| Real-target | -7.8042 ± 0.002 | -5.4836 ± 0.004 | -5.2527 ± 0.001 | -5.3414 ± 0.005 | -5.1526 ± 0.004 | -5.3513 ± 0.004 | -5.3155 ± 0.002 | -5.2615 ± 0.004 | -5.2145 ± 0.003 |

6.5 Predator-prey Models

Following [18] and [11], we consider the Lotka-Volterra model [30], which describes the evolution of predator-prey populations. We first trained the models on simulated data generated from a Lotka-Volterra model and then tested them on real-world data (the Hudson's Bay hare-lynx data), which differs considerably from the simulated data and can be regarded as a mismatch scenario. Table 5 lists the results on both simulated and real data. Consistent with the previous observations, our stable versions still outperform the original versions. Among the stable versions, SConvNP achieves the best performance.

6.6 Ablation Study

The key parameter in our stable solution is the number of hard predictable subsets $K$. Taking SANP as an example, we investigated the average log-likelihood for different $K$ on the 1D regression task, as shown in Figure 3. SANP performs better as $K$ increases, reaches its best value at around $K = 4$, and then remains stable as $K$ grows larger. As proved in Theorem 5.3, optimizing on $D_{1:M}$ together with more than one hard predictable subset of $D_{1:M}$ yields more stable predictions. We also conducted experiments to explore the impact of different $w_k$.
Taking the 1D regression task with RBF-GP data as an example, we set $K = 3$ and varied $w_k$, as shown in Table 6. For every weight setting, the model using the stable strategy is better than the original model, and performance is best when the weights are equal. In addition, when the values of $w_k$ ($k \geq 1$) differ significantly from $w_0$, as in (0.625, 0.125, 0.125, 0.125) and (0.0625, 0.3125, 0.3125, 0.3125), performance drops more noticeably relative to (0.25, 0.25, 0.25, 0.25), but it is still a significant improvement over the original non-stable model.

Figure 3: Log-likelihood (SANP) comparisons with different K on the 1D regression task. Panels: (a) RBF/context, (b) RBF/target, (c) Matern/context, (d) Matern/target. The horizontal axis of each panel is K, with values 0, 1, 2, 4, 8, 10.

Table 6: Log-likelihood comparisons with different $w_k$ on the 1D regression task.

| Method | (0.25, 0.25, 0.25, 0.25) | (0.4, 0.2, 0.2, 0.2) | (0.625, 0.125, 0.125, 0.125) | (0.1, 0.3, 0.3, 0.3) | (0.0625, 0.3125, 0.3125, 0.3125) |
|---|---|---|---|---|---|
| SCNP | (0.9255, 0.4125) | (0.9127, 0.4019) | (0.9035, 0.3998) | (0.9149, 0.4086) | (0.8927, 0.3991) |
| SNP | (0.8955, 0.3925) | (0.8737, 0.3817) | (0.8716, 0.3809) | (0.8831, 0.3859) | (0.8657, 0.3775) |
| SANP | (1.2831, 0.5215) | (1.2776, 0.5187) | (1.2738, 0.5196) | (1.2791, 0.5203) | (1.2712, 0.5193) |
| SConvCNP | (1.3991, 0.5996) | (1.3916, 0.5933) | (1.3841, 0.5915) | (1.3934, 0.5946) | (1.3854, 0.5919) |
| SConvNP | (1.4036, 0.6015) | (1.3991, 0.6004) | (1.3931, 0.5992) | (1.4012, 0.6015) | (1.3957, 0.5995) |

7 Conclusion and Future Work

In this paper, we provided theoretical guidelines for deriving stable solutions for NPs, which can obtain good generalization performance. Experiments demonstrated that the proposed stable solution helps NPs achieve more accurate and stable predictions.
Although our theoretical analysis is based on regression models, it remains an open question whether the conclusions also hold for classification models. We are therefore interested in extending our theory so that it applies to more types of tasks.

Acknowledgement

This work was partly supported by the National Natural Science Foundation of China under Grant 62176020; the National Key Research and Development Program (2020AAA0106800); the Joint Foundation of the Ministry of Education (8091B042235); the Beijing Natural Science Foundation under Grant L211016; the Fundamental Research Funds for the Central Universities (2019JBZ110); and Chinese Academy of Sciences (OEIP-O-202004).

References

[1] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
[2] Olivier Bousquet and André Elisseeff. Algorithmic stability and generalization performance. Advances in Neural Information Processing Systems, pages 196-202, 2001.
[3] Olivier Bousquet and André Elisseeff. Stability and generalization. The Journal of Machine Learning Research, 2:499-526, 2002.
[4] Gregory Cohen, Saeed Afshar, Jonathan Tapson, and Andre Van Schaik. EMNIST: Extending MNIST to handwritten letters. In 2017 International Joint Conference on Neural Networks (IJCNN), pages 2921-2926. IEEE, 2017.
[5] Alec Farid and Anirudha Majumdar. Generalization bounds for meta-learning via PAC-Bayes and uniform stability. Advances in Neural Information Processing Systems, 34:2173-2186, 2021.
[6] Andrew Foong, Wessel Bruinsma, Jonathan Gordon, Yann Dubois, James Requeima, and Richard Turner. Meta-learning stationary stochastic process prediction with convolutional neural processes. Advances in Neural Information Processing Systems, 33:8284-8295, 2020.
[7] Yarin Gal, Rowan McAllister, and Carl Edward Rasmussen. Improving PILCO with Bayesian neural network dynamics models. In Data-Efficient Machine Learning workshop, ICML, volume 4, page 25, 2016.
[8] Alexandre Galashov, Jonathan Schwarz, Hyunjik Kim, Marta Garnelo, David Saxton, Pushmeet Kohli, SM Eslami, and Yee Whye Teh. Meta-learning surrogate models for sequential decision making. arXiv preprint arXiv:1903.11907, 2019.
[9] Marta Garnelo, Dan Rosenbaum, Christopher Maddison, Tiago Ramalho, David Saxton, Murray Shanahan, Yee Whye Teh, Danilo Rezende, and SM Ali Eslami. Conditional neural processes. In International Conference on Machine Learning, pages 1704-1713. PMLR, 2018.
[10] Marta Garnelo, Jonathan Schwarz, Dan Rosenbaum, Fabio Viola, Danilo J Rezende, SM Eslami, and Yee Whye Teh. Neural processes. arXiv preprint arXiv:1807.01622, 2018.
[11] Jonathan Gordon, Wessel P Bruinsma, Andrew YK Foong, James Requeima, Yann Dubois, and Richard E Turner. Convolutional conditional neural processes. arXiv preprint arXiv:1910.13556, 2019.
[12] Moritz Hardt, Ben Recht, and Yoram Singer. Train faster, generalize better: Stability of stochastic gradient descent. In International Conference on Machine Learning, pages 1225-1234. PMLR, 2016.
[13] Makoto Kawano, Wataru Kumagai, Akiyoshi Sannai, Yusuke Iwasawa, and Yutaka Matsuo. Group equivariant conditional neural processes. arXiv preprint arXiv:2102.08759, 2021.
[14] Hyunjik Kim, Andriy Mnih, Jonathan Schwarz, Marta Garnelo, Ali Eslami, Dan Rosenbaum, Oriol Vinyals, and Yee Whye Teh. Attentive neural processes. arXiv preprint arXiv:1901.05761, 2019.
[15] Mingyu Kim, Kyeongryeol Go, and Se-Young Yun. Neural processes with stochastic attention: Paying more attention to the context dataset. In International Conference on Learning Representations, 2022.
[16] Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.
[17] Volodymyr Kuleshov, Nathan Fenner, and Stefano Ermon. Accurate uncertainties for deep learning using calibrated regression. In International Conference on Machine Learning, pages 2796-2804. PMLR, 2018.
[18] Juho Lee, Yoonho Lee, Jungtaek Kim, Eunho Yang, Sung Ju Hwang, and Yee Whye Teh. Bootstrapping neural processes. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 6606-6615. Curran Associates, Inc., 2020.
[19] Dongsheng Li, Chao Chen, Qin Lv, Junchi Yan, Li Shang, and Stephen Chu. Low-rank matrix approximation with stability. In International Conference on Machine Learning, pages 295-303. PMLR, 2016.
[20] Xixun Lin, Jia Wu, Chuan Zhou, Shirui Pan, Yanan Cao, and Bin Wang. Task-adaptive neural process for user cold-start recommendation. In Proceedings of the Web Conference, pages 1306-1316, 2021.
[21] Huafeng Liu, Liping Jing, Dahai Yu, Mingjie Zhou, and Michael Ng. Learning intrinsic and extrinsic intentions for cold-start recommendation with neural stochastic processes. In Proceedings of the 30th ACM International Conference on Multimedia, pages 491-500, 2022.
[22] Tongliang Liu, Dacheng Tao, Mingli Song, and Stephen J Maybank. Algorithm-dependent generalization bounds for multi-task learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 39(2):227-241, 2016.
[23] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), December 2015.
[24] Andreas Maurer, Massimiliano Pontil, and Bernardino Romera-Paredes. The benefit of multitask representation learning. Journal of Machine Learning Research, 17(81):1-32, 2016.
[25] Carl Edward Rasmussen. Gaussian processes in machine learning. In Summer School on Machine Learning, pages 63-71. Springer, 2003.
[26] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pages 5998-6008, 2017.
[27] Qi Wang and Herke Van Hoof. Doubly stochastic variational inference for neural processes with hierarchical latent variables. In International Conference on Machine Learning, pages 10018-10028. PMLR, 2020.
[28] Qi Wang and Herke van Hoof. Learning expressive meta-representations with mixture of expert neural processes. In Advances in Neural Information Processing Systems, 2022.
[29] Andre Wibisono, Lorenzo Rosasco, and Tomaso Poggio. Sufficient conditions for uniform stability of regularization algorithms. 2009.
[30] Darren J Wilkinson. Stochastic modelling for systems biology. CRC Press, 2018.
[31] Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Ruslan Salakhutdinov, and Alexander Smola. Deep sets. arXiv preprint arXiv:1703.06114, 2017.
[32] Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Russ R Salakhutdinov, and Alexander J Smola. Deep sets. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.

A Inductive biases

Here, we revisit some properties that help in understanding NPs. First, we define the permutation invariant function, a basic property of stochastic processes such as NPs.

Definition A.1. (Permutation Invariant Function) A function $f(\cdot): \times_{i=1}^{N} \mathbb{R}^D \to \mathbb{R}^d$ mapping a set of data points $\{x_i\}_{i=1}^{N}$ is a permutation invariant function if

$$x = [x_1, x_2, \dots, x_N] \mapsto f = [f_1(x_{\pi(1:N)}), f_2(x_{\pi(1:N)}), \dots, f_d(x_{\pi(1:N)})], \tag{13}$$

where $x_i \in \mathbb{R}^D$ and the function output is a $d$-dimensional vector. The operation $\pi: [1, 2, \dots, N] \to [\pi_1, \pi_2, \dots, \pi_N]$ is a permutation of the order of the elements in the set.

Definition A.2. (Permutation Equivariant Function) A function $f(\cdot): \times_{i=1}^{N} \mathbb{R}^D \to \mathbb{R}^N$ mapping a set of data points $\{x_i\}_{i=1}^{N}$ is a permutation equivariant function if

$$X_\pi = [x_{\pi_1}, x_{\pi_2}, \dots, x_{\pi_N}] \mapsto f_\pi = \pi \circ f(x_{1:N}), \tag{14}$$

where the function output contains $N$ elements that keep the order of the inputs.
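The two definitions can be checked numerically on a toy example. The sketch below is only an illustration: the feature maps inside `invariant_mean_pool`, the element-wise function in `equivariant_map`, and the sample points are hypothetical choices, not functions used in the paper.

```python
import math


def invariant_mean_pool(points):
    """Mean of per-point features: the output ignores the order of `points`."""
    feats = [(math.sin(x), x * x) for x in points]  # illustrative phi_1, phi_2
    n = len(points)
    return tuple(sum(f[j] for f in feats) / n for j in range(2))


def equivariant_map(points):
    """Apply one function to every point: the output order follows the input."""
    return [math.tanh(x) for x in points]


xs = [0.3, -1.2, 2.0, 0.7]
perm = [2, 0, 3, 1]  # a permutation pi of the indices
xs_perm = [xs[i] for i in perm]

# Invariance: f(x_pi) equals f(x) up to floating-point rounding.
a, b = invariant_mean_pool(xs), invariant_mean_pool(xs_perm)
is_invariant = all(math.isclose(u, v, abs_tol=1e-12) for u, v in zip(a, b))

# Equivariance: f(x_pi) equals pi applied to f(x).
out = equivariant_map(xs)
is_equivariant = equivariant_map(xs_perm) == [out[i] for i in perm]

print(is_invariant, is_equivariant)
```

Pooling with a mean (or a sum) is what makes the first function order-independent, while applying the same map to each element independently is what makes the second one order-preserving.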
A permutation equivariant function keeps the order of the output elements consistent with the order of the input elements under any permutation $\pi$. Permutation invariant functions are natural candidates for learning embeddings of a set, or of any other order-independent data structure $\{x_i\}_{i=1}^{N}$, and the invariance property is easy to verify. As an example, consider a mean operation over the outputs:

$$F(X_{\pi(1:N)}) = \left[\frac{1}{N}\sum_{i=1}^{N}\phi_1(x_i),\ \frac{1}{N}\sum_{i=1}^{N}\phi_2(x_i),\ \dots,\ \frac{1}{N}\sum_{i=1}^{N}\phi_d(x_i)\right]. \tag{15}$$

B Model Architecture

We show the architectural details of the CNP, NP, and ANP models used for the 1D and 2D function regression experiments. A neural process aims to learn a stochastic process (random function) mapping target features $x_i$ to predictions $y_i$ given the context set $D_C$ as training data (a realization from the stochastic process), i.e., learning

$$\log P(y_T | X_T, D_C) = \log P(y | X, D_C) = \sum_{i=1}^{|T|} \log P(y_i | x_i, D_C). \tag{16}$$

The conditional neural process (CNP) [9] describes $P(y_i | x_i, D_C)$ with a deterministic neural network that takes $D_C$ and outputs the parameters of $P(y_i | x_i, D_C)$. CNP consists of an encoder $f_{enc}(\cdot)$, an aggregator $f_{agg}(\cdot)$, and a decoder $f_{dec}(\cdot)$. The encoder maps the points of $D_C$ into latent representations $[r_1, \dots, r_{|C|}] \in \mathbb{R}^{|C| \times d}$ via a permutation-invariant neural network [31], where $d$ is the number of latent dimensions. The aggregator summarizes the encoded context features into a single representation $r_C$. The decoder takes the aggregated representation $r_C$ and $x_i$ as input and outputs the mean $\mu_i$ and variance $\sigma_i^2$ for the corresponding value of $y_i$:

$$\begin{aligned} r_i &= f_{enc}(x_i, y_i), \quad i \in C \\ \phi &= f_{agg}(r_C) \\ (\mu_i, \sigma_i) &= f_{dec}(\phi, x_i), \quad p(y_i | x_i, D_C) = N(y_i; \mu_i, \sigma_i^2), \quad i \in T \end{aligned} \tag{17}$$

where $f_{enc}(\cdot)$ and $f_{dec}(\cdot)$ are feed-forward neural networks. The decoder outputs $\mu_i$ and $\sigma_i^2$ are the predicted mean and variance, and we use the Gaussian distribution $N(y_i; \mu_i, \sigma_i^2)$ as the predictive distribution. CNP is trained to maximize the expected likelihood $E_{P(T)}[P(y_i | x_i, D_C)]$. The neural process [10] further models functional uncertainty using a global latent variable.
Unlike CNP, which maps a context into a deterministic representation $\tilde{r}_i$, NP encodes a context into a Gaussian latent variable $z$, giving additional stochasticity in function construction. Following [14], we consider an NP with both a deterministic path and a latent path, where the deterministic path models the overall skeleton of the function $\tilde{r}_i$, and the latent path models the functional uncertainty:

$$\begin{aligned} r_i &= f^{(1)}_{enc}(x_i, y_i), \quad i \in C \\ \phi &= f_{agg}(r) \\ (\mu_z, \sigma_z) &= f^{(2)}_{enc}(D_C), \quad q(z | D_C) = N(z; \mu_z, \sigma_z^2) \\ (\mu_i, \sigma_i) &= f_{dec}(\phi, z, x_i), \quad p(y_i | x_i, z, D_C) = N(y_i; \mu_i, \sigma_i^2), \quad i \in T \end{aligned} \tag{18}$$

with $f^{(1)}_{enc}(\cdot)$ and $f^{(2)}_{enc}(\cdot)$ having the same structure as $f_{enc}(\cdot)$ in Eq. (17). In this scenario, the conditional distribution is lower bounded as:

$$\log P(y | X, D_C) \geq \sum_{i=1}^{|T|} E_{q(z | D_C)}\left[\log \frac{P(y_i | x_i, z, D_C)\, P(z | D_C)}{q(z | X, y)}\right]. \tag{19}$$

We further approximate $q(z | D_C) \approx P(z | D_C)$ and train the model by maximizing this expected lower bound over tasks. Furthermore, ANP introduces attention mechanisms into NP to resolve the issue of under-fitting. The architectural details of the CNP, NP, and ANP are the same as in [14]. Here we give the detailed architectures of the encoder and decoder of NPs.

B.1 Encoder without attention

The encoder learns an embedding for each data point in the context set. Its basic component is the multi-layer perceptron, defined by

$$\mathrm{MLP}(l, d_{in}, d_h, d_{out}) = \mathrm{LINEAR}(d_h, d_{out}) \circ \underbrace{(\mathrm{RELU} \circ \mathrm{LINEAR}(d_h, d_h) \circ \cdots)}_{(l-2)} \circ \mathrm{LINEAR}(d_h, d_{in}) \tag{20}$$

where $l$ is the number of layers, and $d_{in}$, $d_h$, and $d_{out}$ are the dimensionalities of the inputs, hidden units, and outputs. Here $\mathrm{RELU}(\cdot)$ is adopted as the activation function. The vanilla CNP uses a deterministic encoder that learns an embedding for each data point in the context set:

$$\begin{aligned} r_i &= \mathrm{MLP}(l_{e1}, d_x + d_y, d_h, d_h)([x_i, y_i]), \qquad r_C = \sum_{i \in C} r_i, \\ \phi &= \mathrm{MLP}(l_{e2}, d_h, d_h)(r_C) \end{aligned} \tag{21}$$

where $d_x$ and $d_y$ are the dimensionalities of $x_i$ and $y_i$.
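A deterministic encoder of this shape can be sketched in plain Python. The code is a toy under stated assumptions rather than the authors' implementation: the helper names `linear`, `mlp`, and `encode_context` are hypothetical, weights are random instead of trained, a ReLU follows every hidden layer, and the context representations are aggregated with a mean (a sum would behave the same way for this check).

```python
import random


def linear(d_in, d_out, rng):
    """Build a dense layer x -> W x + b with small random weights."""
    w = [[rng.uniform(-0.5, 0.5) for _ in range(d_in)] for _ in range(d_out)]
    b = [0.0] * d_out
    return lambda x: [sum(wi * xi for wi, xi in zip(row, x)) + bi
                      for row, bi in zip(w, b)]


def relu(x):
    return [max(0.0, v) for v in x]


def mlp(n_layers, d_in, d_h, d_out, rng):
    """Stack `n_layers` linear layers with ReLU between them (n_layers >= 2)."""
    dims = [d_in] + [d_h] * (n_layers - 1) + [d_out]
    layers = [linear(a, b, rng) for a, b in zip(dims[:-1], dims[1:])]

    def forward(x):
        for layer in layers[:-1]:
            x = relu(layer(x))
        return layers[-1](x)  # no activation on the output layer

    return forward


def encode_context(context, point_net, post_net):
    """Embed each (x_i, y_i), average over the context set, then post-process."""
    reps = [point_net(list(x) + list(y)) for x, y in context]
    d = len(reps[0])
    r_c = [sum(r[j] for r in reps) / len(reps) for j in range(d)]
    return post_net(r_c)


rng = random.Random(0)
d_x, d_y, d_h = 1, 1, 8  # toy sizes, not the values used in the experiments
point_net = mlp(2, d_x + d_y, d_h, d_h, rng)
post_net = mlp(2, d_h, d_h, d_h, rng)

context = [((0.1,), (0.5,)), ((-0.4,), (0.2,)), ((1.3,), (-0.7,))]
phi = encode_context(context, point_net, post_net)
phi_reversed = encode_context(list(reversed(context)), point_net, post_net)
print(len(phi))
```

Because the per-point embeddings are pooled before the second network is applied, `phi` does not depend on the order in which the context points are listed.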
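To mirror the encoder structure of NP, we add a second deterministic encoder alongside the original one so that both models have the same number of parameters, i.e.,

$$\begin{aligned} r^{(1)}_i &= \mathrm{MLP}(l_{e1}, d_x + d_y, d_h, d_h)([x_i, y_i]), \quad r^{(1)}_C = \sum_{i \in C} r^{(1)}_i, \quad \phi_1 = \mathrm{MLP}(l_{e2}, d_h, d_h)(r^{(1)}_C) \\ r^{(2)}_i &= \mathrm{MLP}(l_{e1}, d_x + d_y, d_h, d_h)([x_i, y_i]), \quad r^{(2)}_C = \sum_{i \in C} r^{(2)}_i, \quad \phi_2 = \mathrm{MLP}(l_{e2}, d_h, d_h)(r^{(2)}_C) \\ \phi &= [\phi_1, \phi_2] \end{aligned} \tag{22}$$

The encoder in NP contains a deterministic path and a latent path, i.e.,

$$\begin{aligned} r^{(1)}_i &= \mathrm{MLP}(l_{de1}, d_x + d_y, d_h, d_h)([x_i, y_i]), \quad r^{(1)}_C = \sum_{i \in C} r^{(1)}_i, \quad \phi = \mathrm{MLP}(l_{de2}, d_h, d_h)(r^{(1)}_C) \\ r^{(2)}_i &= \mathrm{MLP}(l_{la1}, d_x + d_y, d_h, d_h)([x_i, y_i]), \quad r^{(2)}_C = \sum_{i \in C} r^{(2)}_i, \quad [\mu_z, \sigma'_z] = \mathrm{MLP}(l_{la2}, d_h, d_h)(r^{(2)}_C) \\ \sigma_z &= 0.1 + 0.9 \cdot \mathrm{SIGMOID}(\sigma'_z), \quad z \sim N(\mu_z, \mathrm{diag}(\sigma_z^2)). \end{aligned} \tag{23}$$

In this case, the encoder outputs the deterministic representation $\phi$ and the latent representation $z$.

B.2 Encoder with attention

The attention mechanism is widely used in NPs. Specifically, multi-head attention [26] is adopted, which is defined by

$$\begin{aligned} Q' &= \{\mathrm{LINEAR}(d_q, d_{out})(q)\}_{q \in Q}, \quad \{Q'_i\}_{i=1}^{n_{head}} = \mathrm{SPLIT}(Q', n_{head}), \\ K' &= \{\mathrm{LINEAR}(d_k, d_{out})(k)\}_{k \in K}, \quad \{K'_i\}_{i=1}^{n_{head}} = \mathrm{SPLIT}(K', n_{head}), \\ V' &= \{\mathrm{LINEAR}(d_v, d_{out})(v)\}_{v \in V}, \quad \{V'_i\}_{i=1}^{n_{head}} = \mathrm{SPLIT}(V', n_{head}), \\ H_i &= \mathrm{SOFTMAX}\left(Q'_i (K'_i)^{\top} / \sqrt{d_{out}}\right) V'_i, \quad H = \mathrm{CONCAT}(\{H_i\}_{i=1}^{n_{head}}), \\ H' &= \mathrm{LAYERNORM}(Q' + H), \\ \mathrm{MHA}(d_{out})(Q, K, V) &= \mathrm{LAYERNORM}(H' + \mathrm{RELU}(\mathrm{LINEAR}(d_{out}, d_{out})(H'))) \end{aligned} \tag{24}$$

where $d_q$, $d_k$, and $d_v$ are the dimensionalities of the query $Q$, key $K$, and value $V$, respectively, and $n_{head}$ is the number of heads. Layer normalization [1], $\mathrm{LAYERNORM}(\cdot)$, is adopted here.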
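The core of the multi-head attention above (the split into heads, the scaled softmax weighting, and the concatenation) can be sketched as follows. This is a simplified illustration, not the implementation used in the paper: the input projections, the residual connections, and layer normalization are deliberately omitted, and the helper names and toy inputs are hypothetical.

```python
import math


def softmax(row):
    m = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attend(queries, keys, values, scale):
    """softmax(Q K^T / scale) V for lists of row vectors."""
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, values))
                    for j in range(len(values[0]))])
    return out


def split_heads(rows, n_head):
    """Split each row vector into `n_head` equal chunks, one list of rows per head."""
    width = len(rows[0]) // n_head
    return [[r[h * width:(h + 1) * width] for r in rows] for h in range(n_head)]


def multi_head_attention(q, k, v, n_head):
    d_out = len(q[0])
    scale = math.sqrt(d_out)  # the definition above scales by sqrt(d_out)
    heads = [attend(qh, kh, vh, scale)
             for qh, kh, vh in zip(split_heads(q, n_head),
                                   split_heads(k, n_head),
                                   split_heads(v, n_head))]
    # Concatenate the head outputs feature-wise for every query.
    return [sum((head[i] for head in heads), []) for i in range(len(q))]


# Toy inputs: identical keys give uniform weights, so the output is the mean value.
q = [[1.0, 0.0, 0.5, -0.5]]
k = [[0.2, 0.2, 0.2, 0.2]] * 3
v = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0], [5.0, 6.0, 7.0, 8.0]]
out = multi_head_attention(q, k, v, n_head=2)
print(out)
```

When a target query is closer to some context keys than to others, the softmax shifts weight toward the matching context values, which is how attention lets each target point use a target-specific summary of the context set.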
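Self-attention is obtained by setting $Q = K = V$, i.e.,

$$\mathrm{SA}(d_{out})(X) = \mathrm{MHA}(d_{out})(X, X, X) \tag{25}$$

For CNP, the encoder with attention still contains two deterministic paths:

$$\begin{aligned} f_{qk} &= \mathrm{MLP}(l_{qk}, d_x, d_h, d_h) \\ Q &= \{f_{qk}(x_i)\}_{i \in T}, \quad K = \{f_{qk}(x_i)\}_{i \in C} \\ V &= \mathrm{SA}(d_h)(\{\mathrm{MLP}(l_v, d_x + d_y, d_h, d_h)([x_i, y_i])\}_{i \in C}) \\ \phi_1 &= \mathrm{MHA}(d_h)(Q, K, V) \\ H &= \mathrm{SA}(d_h)(\{\mathrm{RELU} \circ \mathrm{MLP}(l_{e1}, d_x + d_y, d_h, d_h)([x_i, y_i])\}_{i \in C}) \\ \phi_2 &= \mathrm{MLP}(l_e, d_h, d_h) \\ \phi &= [\phi_1, \phi_2] \end{aligned} \tag{26}$$

Similarly, the encoder with attention in NP contains a deterministic path and a latent path, i.e.,

$$\begin{aligned} f_{qk} &= \mathrm{MLP}(l_{qk}, d_x, d_h, d_h) \\ Q &= \{f_{qk}(x_i)\}_{i \in T}, \quad K = \{f_{qk}(x_i)\}_{i \in C} \\ V &= \mathrm{SA}(d_h)(\{\mathrm{MLP}(l_v, d_x + d_y, d_h, d_h)([x_i, y_i])\}_{i \in C}) \\ \phi &= \mathrm{MHA}(d_h)(Q, K, V) \end{aligned} \tag{27}$$

and

$$\begin{aligned} H &= \mathrm{SA}(d_h)(\{\mathrm{RELU} \circ \mathrm{MLP}(l_{e1}, d_x + d_y, d_h, d_h)([x_i, y_i])\}_{i \in C}) \\ [\mu_z, \sigma'_z] &= \mathrm{MLP}(l_{la}, d_h, d_h) \\ \sigma_z &= 0.1 + 0.9 \cdot \mathrm{SIGMOID}(\sigma'_z), \quad z \sim N(\mu_z, \mathrm{diag}(\sigma_z^2)). \end{aligned} \tag{28}$$

B.3 Decoder

The decoder predicts the outputs for the target points based on the encoder's output $\phi$. For target points $\{x_i\}_{i \in T}$, the decoder of CNP is defined by

$$\begin{aligned} [\mu_i, \sigma'_i] &= \mathrm{MLP}(d_{dec}, 2d_h + d_x, d_h, 2d_y)([\phi, x_i]), \quad i \in T \\ \sigma_i &= 0.1 + 0.9 \cdot \mathrm{SOFTPLUS}(\sigma'_i) \\ y_i &\sim N(\mu_i, \sigma_i) \end{aligned} \tag{29}$$

The decoder of NP is defined by

$$\begin{aligned} [\mu_i, \sigma'_i] &= \mathrm{MLP}(d_{dec}, d_h + d_z + d_x, d_h, 2d_y)([\phi, x_i, z]), \quad i \in T \\ \sigma_i &= 0.1 + 0.9 \cdot \mathrm{SOFTPLUS}(\sigma'_i) \\ y_i &\sim N(\mu_i, \sigma_i) \end{aligned} \tag{30}$$

C Implementation Details and Experiments

For CNP [9], BCNP [18], and our SCNP, we apply the encoder with attention described in Eq. (26) and the decoder described in Eq. (29). For NP [10], ANP [14], BNP [18], BANP [18], and our SNP and SANP models, we apply the encoder with attention described in Eqs. (27) and (28), and the decoder described in Eq. (30).

C.1 1D Regression

For the synthetic 1D regression experiments, the neural architectures for CNP, NP, ANP, BCNP, BNP, BANP, and our SCNP/SNP/SANP follow Appendix B. The number of hidden units is $d_h = 128$ and the latent representation size is $d_z = 128$. The numbers of layers are $l_e = l_{de} = l_{la} = l_{qk} = l_v = 2$. We generate datasets for synthetic 1D regression.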
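As a concrete reference for the kernels used when generating the synthetic 1D data in this subsection, the sketch below evaluates an RBF kernel and a Matern kernel on scalar inputs. It assumes the standard squared-exponential form and the standard Matern-5/2 form; the function names, the argument names `sigma` and `length`, and the example hyperparameter values are hypothetical and are not taken from the authors' code.

```python
import math


def rbf_kernel(x1, x2, sigma, length):
    """sigma^2 * exp(-(x1 - x2)^2 / (2 * length^2)) for scalar inputs."""
    return sigma ** 2 * math.exp(-((x1 - x2) ** 2) / (2.0 * length ** 2))


def matern52_kernel(x1, x2, sigma, length):
    """Standard Matern-5/2 kernel for scalar inputs."""
    d = abs(x1 - x2)
    a = math.sqrt(5.0) * d / length
    return sigma ** 2 * (1.0 + a + 5.0 * d ** 2 / (3.0 * length ** 2)) * math.exp(-a)


def gram_matrix(kernel, xs, sigma, length):
    """Covariance matrix of a zero-mean GP evaluated at the inputs `xs`."""
    return [[kernel(a, b, sigma, length) for b in xs] for a in xs]


# Five evenly spaced inputs on [-2, 2]; the hyperparameter values are illustrative.
xs = [-2.0 + i for i in range(5)]
gram_rbf = gram_matrix(rbf_kernel, xs, sigma=0.5, length=0.4)
gram_matern = gram_matrix(matern52_kernel, xs, sigma=0.5, length=0.4)
print(gram_rbf[0][0], gram_matern[0][1])
```

Drawing a function from the GP then amounts to sampling a multivariate Gaussian with zero mean and one of these Gram matrices as its covariance.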
Specifically, the stochastic process (SP) is initialized with a zero-mean Gaussian process (GP) $y^{(0)} \sim GP(0, k(\cdot, \cdot))$ indexed on the interval $x \in [-2.0, 2.0]$, with the radial basis function kernel $k(x, x') = \sigma^2 \exp(-\|x - x'\|^2 / 2l^2)$, where $s \sim U(0.1, 1.0)$ and $\sigma \sim U(0.1, 0.6)$. Furthermore, a GP with the Matern kernel is adopted for the model-data mismatch scenario, defined as $k(x, x') = \sigma^2 (1 + \sqrt{5} d / l + 5 d^2 / (3 l^2)) \exp(-\sqrt{5} d / l)$ with $d = \|x - x'\|$, $s \sim U(0.1, 1.0)$, and $\sigma \sim U(0.1, 0.6)$.

For a fair comparison, we use the same data generation, training, and testing procedures for all models. We trained all models for 100,000 steps, with each step computing updates on a batch of 100 tasks. We used the Adam optimizer with an initial learning rate of $5 \times 10^{-4}$ and decayed the learning rate with a cosine annealing scheme for the baselines. For SCNP/SNP/SANP, we set $K = 3$. The size of the context $C$ was drawn as $|C| \sim U(3, 200)$. Testing was done on 3,000 batches, with each batch containing 16 tasks (48,000 tasks in total).

We investigate model stability with respect to the size of the context set and the percentage of added noise. First, we conduct experiments on different context set sizes, i.e., $|C| \in \{20, 50, 100, 200\}$. Table 7 compares the average log-likelihoods of the different methods across context sizes. Performance improves as $|C|$ increases, and NPs with the stable solution still achieve better performance. Second, we investigate model performance under different noise settings. Here we introduce Gaussian noise $N(0, 1)$ and add it to different proportions of the data, namely {0%, 5%, 10%, 15%}. Table 8 lists the average log-likelihoods for the different noise proportions.

Table 7: Average log-likelihoods over all context and target points on realizations from the synthetic stochastic process for different context set sizes (mean ± std).
| Kernel | Method | 20 (context) | 20 (target) | 50 (context) | 50 (target) | 100 (context) | 100 (target) | 200 (context) | 200 (target) |
|---|---|---|---|---|---|---|---|---|---|
| RBF | CNP | 0.8724 ± 0.008 | 0.4334 ± 0.007 | 0.9533 ± 0.005 | 0.4854 ± 0.010 | 1.1224 ± 0.005 | 0.5322 ± 0.007 | 0.1563 ± 0.004 | 0.5783 ± 0.004 |
| RBF | BCNP | 0.9015 ± 0.009 | 0.4579 ± 0.007 | 0.9787 ± 0.007 | 0.5215 ± 0.008 | 1.1687 ± 0.007 | 0.5716 ± 0.009 | 0.1985 ± 0.005 | 0.6086 ± 0.005 |
| RBF | SCNP | 0.9255 ± 0.008 | 0.4733 ± 0.004 | 0.9944 ± 0.005 | 0.5433 ± 0.006 | 1.1833 ± 0.006 | 0.5918 ± 0.005 | 0.2111 ± 0.004 | 0.6333 ± 0.004 |
| RBF | NP | 0.8215 ± 0.004 | 0.3853 ± 0.005 | 0.9124 ± 0.006 | 0.4234 ± 0.003 | 1.0855 ± 0.003 | 0.4767 ± 0.005 | 1.1225 ± 0.002 | 0.5233 ± 0.004 |
| RBF | BNP | 0.8714 ± 0.004 | 0.4122 ± 0.004 | 0.9712 ± 0.004 | 0.4718 ± 0.004 | 1.1426 ± 0.005 | 0.5269 ± 0.006 | 1.1716 ± 0.004 | 0.5698 ± 0.004 |
| RBF | SNP | 0.8955 ± 0.003 | 0.4356 ± 0.004 | 0.9866 ± 0.004 | 0.4934 ± 0.005 | 1.1637 ± 0.004 | 0.5434 ± 0.004 | 1.1958 ± 0.003 | 0.5933 ± 0.003 |
| RBF | ANP | 1.2563 ± 0.002 | 0.5763 ± 0.004 | 1.3233 ± 0.002 | 0.6322 ± 0.003 | 1.4633 ± 0.003 | 0.6866 ± 0.006 | 1.4982 ± 0.006 | 0.7322 ± 0.004 |
| RBF | BANP | 1.2715 ± 0.003 | 0.5878 ± 0.005 | 1.3325 ± 0.004 | 0.6465 ± 0.004 | 1.4778 ± 0.003 | 0.6915 ± 0.004 | 1.5115 ± 0.005 | 0.7436 ± 0.005 |
| RBF | SANP | 1.2831 ± 0.000 | 0.5994 ± 0.004 | 1.3452 ± 0.002 | 0.6577 ± 0.003 | 1.4898 ± 0.001 | 0.7043 ± 0.002 | 1.5285 ± 0.002 | 0.7534 ± 0.005 |
| Matern | CNP | 0.8531 ± 0.005 | 0.2431 ± 0.010 | 0.9123 ± 0.005 | 0.2984 ± 0.003 | 1.0522 ± 0.004 | 0.3542 ± 0.002 | 1.0984 ± 0.008 | 0.4022 ± 0.005 |
| Matern | BCNP | 0.8765 ± 0.006 | 0.2788 ± 0.009 | 0.9411 ± 0.006 | 0.3266 ± 0.005 | 1.0752 ± 0.004 | 0.3762 ± 0.004 | 1.1245 ± 0.007 | 0.4326 ± 0.006 |
| Matern | SCNP | 0.8963 ± 0.003 | 0.2953 ± 0.006 | 0.9555 ± 0.004 | 0.3467 ± 0.003 | 1.0967 ± 0.003 | 0.3967 ± 0.003 | 1.1467 ± 0.008 | 0.4556 ± 0.005 |
| Matern | NP | 0.7643 ± 0.015 | 0.2041 ± 0.015 | 0.8221 ± 0.002 | 0.2547 ± 0.003 | 0.9322 ± 0.003 | 0.3155 ± 0.004 | 1.0452 ± 0.005 | 0.3563 ± 0.005 |
| Matern | BNP | 0.8052 ± 0.008 | 0.2651 ± 0.007 | 0.8678 ± 0.003 | 0.3163 ± 0.004 | 0.9672 ± 0.003 | 0.3656 ± 0.005 | 1.1052 ± 0.005 | 0.4015 ± 0.005 |
| Matern | SNP | 0.8368 ± 0.006 | 0.2844 ± 0.005 | 0.8956 ± 0.002 | 0.3326 ± 0.004 | 0.9959 ± 0.003 | 0.3849 ± 0.004 | 1.1215 ± 0.003 | 0.4313 ± 0.004 |
| Matern | ANP | 1.2421 ± 0.002 | 0.6366 ± 0.004 | 1.3022 ± 0.001 | 0.6881 ± 0.004 | 1.4211 ± 0.004 | 0.7331 ± 0.003 | 1.4631 ± 0.002 | 0.7753 ± 0.008 |
| Matern | BANP | 1.3452 ± 0.007 | 0.6513 ± 0.004 | 1.4056 ± 0.003 | 0.7015 ± 0.004 | 1.4505 ± 0.004 | 0.7531 ± 0.005 | 1.4986 ± 0.004 | 0.7996 ± 0.006 |
| Matern | SANP | 1.3721 ± 0.002 | 0.6653 ± 0.004 | 1.4322 ± 0.003 | 0.7126 ± 0.004 | 1.4633 ± 0.004 | 0.7644 ± 0.003 | 1.5153 ± 0.005 | 0.8125 ± 0.004 |

Table 8: Average log-likelihoods over all context and target points on realizations from the synthetic stochastic process for different percentages of added noise (mean ± std). Here we set the context size to 20. A model name prefixed with S denotes the version with our stable solution.

| Kernel | Method | Original (context) | Original (target) | Noise(+5%) (context) | Noise(+5%) (target) | Noise(+10%) (context) | Noise(+10%) (target) | Noise(+15%) (context) | Noise(+15%) (target) |
|---|---|---|---|---|---|---|---|---|---|
| RBF | CNP | 0.8724 ± 0.008 | 0.4334 ± 0.007 | 0.8522 ± 0.005 | 0.4001 ± 0.010 | 0.8014 ± 0.006 | 0.3552 ± 0.004 | 0.7152 ± 0.006 | 0.2853 ± 0.005 |
| RBF | BCNP | 0.9042 ± 0.009 | 0.4589 ± 0.006 | 0.8774 ± 0.006 | 0.4278 ± 0.008 | 0.8316 ± 0.006 | 0.3767 ± 0.005 | 0.7487 ± 0.007 | 0.3017 ± 0.006 |
| RBF | SCNP | 0.9255 ± 0.008 | 0.4733 ± 0.004 | 0.8935 ± 0.005 | 0.4478 ± 0.006 | 0.8517 ± 0.004 | 0.3986 ± 0.005 | 0.7621 ± 0.005 | 0.3279 ± 0.006 |
| RBF | NP | 0.8215 ± 0.004 | 0.3853 ± 0.005 | 0.8011 ± 0.004 | 0.3511 ± 0.006 | 0.7611 ± 0.005 | 0.3042 ± 0.008 | 0.6722 ± 0.005 | 0.2435 ± 0.007 |
| RBF | BNP | 0.8722 ± 0.004 | 0.4211 ± 0.004 | 0.8321 ± 0.003 | 0.3876 ± 0.004 | 0.7922 ± 0.004 | 0.3389 ± 0.005 | 0.7189 ± 0.004 | 0.2776 ± 0.007 |
| RBF | SNP | 0.8955 ± 0.003 | 0.4356 ± 0.004 | 0.8567 ± 0.003 | 0.4046 ± 0.005 | 0.8165 ± 0.004 | 0.3568 ± 0.006 | 0.7356 ± 0.005 | 0.2955 ± 0.006 |
| RBF | ANP | 1.2563 ± 0.002 | 0.5763 ± 0.004 | 1.2245 ± 0.007 | 0.5347 ± 0.006 | 0.1742 ± 0.005 | 0.4871 ± 0.007 | 0.9821 ± 0.005 | 0.4151 ± 0.004 |
| RBF | BANP | 1.2722 ± 0.004 | 0.5887 ± 0.006 | 1.2411 ± 0.005 | 0.5471 ± 0.005 | 0.1886 ± 0.006 | 0.4917 ± 0.006 | 1.0642 ± 0.006 | 0.4327 ± 0.005 |
| RBF | SANP | 1.2831 ± 0.000 | 0.5994 ± 0.004 | 1.2564 ± 0.004 | 0.5578 ± 0.004 | 1.2052 ± 0.004 | 0.5025 ± 0.006 | 1.1243 ± 0.005 | 0.4356 ± 0.006 |
| Matern | CNP | 0.8531 ± 0.005 | 0.2431 ± 0.010 | 0.8231 ± 0.005 | 0.2144 ± 0.010 | 0.7761 ± 0.008 | 0.1784 ± 0.007 | 0.7052 ± 0.005 | 0.1452 ± 0.006 |
| Matern | BCNP | 0.8778 ± 0.005 | 0.2762 ± 0.009 | 0.8487 ± 0.006 | 0.2477 ± 0.009 | 0.8015 ± 0.007 | 0.2051 ± 0.007 | 0.7378 ± 0.005 | 0.1766 ± 0.006 |
| Matern | SCNP | 0.8963 ± 0.003 | 0.2953 ± 0.006 | 0.8689 ± 0.005 | 0.2658 ± 0.007 | 0.8268 ± 0.006 | 0.2258 ± 0.006 | 0.7567 ± 0.004 | 0.1936 ± 0.005 |
| Matern | NP | 0.7643 ± 0.015 | 0.2041 ± 0.015 | 0.7342 ± 0.002 | 0.1725 ± 0.008 | 0.6892 ± 0.004 | 0.1542 ± 0.006 | 0.6235 ± 0.008 | 0.1342 ± 0.007 |
| Matern | BNP | 0.8156 ± 0.005 | 0.2689 ± 0.007 | 0.7789 ± 0.004 | 0.2215 ± 0.005 | 0.7421 ± 0.005 | 0.2117 ± 0.007 | 0.6715 ± 0.006 | 0.1828 ± 0.006 |
| Matern | SNP | 0.8368 ± 0.006 | 0.2844 ± 0.005 | 0.8036 ± 0.003 | 0.2483 ± 0.003 | 0.7635 ± 0.004 | 0.2325 ± 0.006 | 0.6973 ± 0.006 | 0.2016 ± 0.005 |
| Matern | ANP | 1.2421 ± 0.002 | 0.6366 ± 0.004 | 1.2115 ± 0.001 | 0.6001 ± 0.008 | 1.1784 ± 0.004 | 0.5622 ± 0.006 | 1.1252 ± 0.007 | 0.5274 ± 0.008 |
| Matern | BANP | 1.3456 ± 0.003 | 0.6514 ± 0.005 | 1.3125 ± 0.005 | 0.6115 ± 0.002 | 1.2672 ± 0.004 | 0.5711 ± 0.005 | 1.2236 ± 0.006 | 0.5306 ± 0.006 |
| Matern | SANP | 1.3721 ± 0.002 | 0.6653 ± 0.004 | 1.3461 ± 0.003 | 0.6256 ± 0.004 | 1.3011 ± 0.003 | 0.5782 ± 0.004 | 1.2457 ± 0.005 | 0.5356 ± 0.002 |

C.2 System Identification on Physics Engines

The second synthetic experiment focuses on evaluating model dynamics on a classical simulator, the Cart-Pole system, which is detailed in [7, 27]. The Cart-Pole swing-up task is a standard benchmark for nonlinear control due to the non-linearity in the dynamics and the requirement for nonlinear controllers to successfully swing up and balance the pendulum. A pendulum of length $l$ is attached to a cart by a frictionless pivot. The system begins with the cart at position $x_c = 0$ and the pendulum hanging down at angle $\theta$. The goal is to accelerate the cart by applying a horizontal force $u_t$ at each time step $t$ to invert the pendulum and then stabilize its endpoint at the goal. Several parameters need to be known, such as the cart mass $m_c$, the pendulum mass $m_p$, the acceleration of gravity $g = 9.82\,\mathrm{m/s^2}$, the time horizon $T$, the time discretization $\Delta t$, and the ground friction coefficient $f_c$. In this case, the Cart-Pole swing-up task aims to forecast the transited state $[x_c, \theta, \dot{x}_c, \dot{\theta}]$ at time step $t + 1$ based on the state-action pair $[x_c, \theta, \dot{x}_c, \dot{\theta}, a]$ at time step $t$.

For the system identification task on physics engines, the neural architectures for CNP, NP, ANP, BANP and our RNP follow Appendix B. The number of hidden units is $d_h = 32$ and the latent representation size is $d_z = 32$. The numbers of layers are $l_e = l_{de} = l_{la} = l_{qk} = l_v = 2$. To generate a variety of trajectories under a random policy for this experiment, the mass $m_c$ and the ground friction coefficient $f_c$ are varied over the discrete choices $m_c \in \{0.3, 0.4, 0.5, 0.6, 0.7\}$ and $f_c \in \{0.06, 0.08, 0.1, 0.12\}$.
Each pair of $[m_c, f_c]$ values specifies a dynamics environment. We use all pairs with $m_c \in \{0.3, 0.5, 0.7\}$ and $f_c \in \{0.08, 0.12\}$ as training environments and the remaining 14 pairs as testing environments. For each simulator configuration, in both training and testing environments, we sample 400 trajectories with a horizon of 10 steps using a random controller; more details are given in the supplementary material. During testing, 100 state transition pairs are randomly selected for each environment configuration and serve as the maximum set of context points for identifying the configuration of the dynamics.

Figure 4: Cart-Pole dynamical system. The cart and the pole have masses $m_c$ and $m_p$, and the pole has length $l$. In this experiment the simulator configuration is determined by the cart-pole mass and the ground friction coefficient, with all other hyper-parameters fixed.

Table 9: Average log-likelihoods over all context and target points on EMNIST and CELEBA.

| Dataset | Method | Original context | Original target | Noise(+5%) context | Noise(+5%) target | Noise(+10%) context | Noise(+10%) target | Noise(+15%) context | Noise(+15%) target |
|---|---|---|---|---|---|---|---|---|---|
| EMNIST | CNP | 0.9522 ± 0.023 | 0.7515 ± 0.0015 | 0.8977 ± 0.0016 | 0.6336 ± 0.017 | 0.8242 ± 0.0018 | 0.5784 ± 0.009 | 0.6566 ± 0.0017 | 0.5341 ± 0.016 |
| EMNIST | BCNP | 0.9678 ± 0.010 | 0.8058 ± 0.008 | 0.9015 ± 0.008 | 0.6711 ± 0.009 | 0.8415 ± 0.007 | 0.6089 ± 0.009 | 0.6788 ± 0.006 | 0.5715 ± 0.006 |
| EMNIST | SCNP | 0.9716 ± 0.008 | 0.8343 ± 0.006 | 0.9251 ± 0.008 | 0.6971 ± 0.007 | 0.8674 ± 0.006 | 0.6343 ± 0.007 | 0.6986 ± 0.005 | 0.5877 ± 0.005 |
| EMNIST | NP | 0.9678 ± 0.004 | 0.7756 ± 0.005 | 0.9011 ± 0.009 | 0.6941 ± 0.006 | 0.8544 ± 0.009 | 0.6455 ± 0.007 | 0.7034 ± 0.009 | 0.5865 ± 0.006 |
| EMNIST | BNP | 0.9757 ± 0.005 | 0.8358 ± 0.005 | 0.8116 ± 0.007 | 0.7625 ± 0.006 | 0.8759 ± 0.007 | 0.6773 ± 0.007 | 0.7451 ± 0.005 | 0.6237 ± 0.005 |
| EMNIST | SNP | 0.9847 ± 0.005 | 0.8562 ± 0.006 | 0.8368 ± 0.006 | 0.7844 ± 0.005 | 0.8984 ± 0.005 | 0.6984 ± 0.005 | 0.7653 ± 0.005 | 0.6456 ± 0.004 |
| EMNIST | ANP | 1.1125 ± 0.002 | 1.0321 ± 0.004 | 0.9815 ± 0.002 | 0.6366 ± 0.006 | 0.9021 ± 0.004 | 0.7053 ± 0.008 | 0.8454 ± 0.002 | 0.7034 ± 0.005 |
| EMNIST | BANP | 1.1355 ± 0.003 | 1.0615 ± 0.005 | 1.0236 ± 0.002 | 0.6549 ± 0.005 | 0.9155 ± 0.004 | 0.7521 ± 0.006 | 0.8612 ± 0.003 | 0.7515 ± 0.005 |
| EMNIST | SANP | 1.1531 ± 0.000 | 1.0877 ± 0.004 | 1.0421 ± 0.002 | 0.6776 ± 0.005 | 0.9321 ± 0.002 | 0.7843 ± 0.006 | 0.8732 ± 0.003 | 0.7657 ± 0.005 |
| CELEBA | CNP | 1.0323 ± 0.016 | 0.7845 ± 0.013 | 1.0177 ± 0.016 | 0.7438 ± 0.017 | 0.8956 ± 0.009 | 0.7344 ± 0.011 | 0.7677 ± 0.012 | 0.6096 ± 0.009 |
| CELEBA | BCNP | 1.0452 ± 0.009 | 0.8015 ± 0.008 | 1.0275 ± 0.009 | 0.7726 ± 0.008 | 0.9351 ± 0.006 | 0.8376 ± 0.009 | 0.8015 ± 0.010 | 0.6816 ± 0.008 |
| CELEBA | SCNP | 1.0525 ± 0.008 | 0.8243 ± 0.006 | 1.0348 ± 0.008 | 0.7868 ± 0.006 | 0.9562 ± 0.004 | 0.8545 ± 0.008 | 0.8344 ± 0.006 | 0.7045 ± 0.005 |
| CELEBA | NP | 1.1333 ± 0.004 | 0.8766 ± 0.005 | 1.1043 ± 0.015 | 0.8355 ± 0.015 | 1.0034 ± 0.008 | 0.8456 ± 0.006 | 0.8935 ± 0.009 | 0.6893 ± 0.006 |
| CELEBA | BNP | 1.1732 ± 0.005 | 0.8901 ± 0.006 | 1.1378 ± 0.007 | 0.8678 ± 0.006 | 1.0411 ± 0.008 | 0.8711 ± 0.005 | 0.9256 ± 0.008 | 0.7671 ± 0.006 |
| CELEBA | SNP | 1.1952 ± 0.005 | 0.9062 ± 0.006 | 1.1565 ± 0.005 | 0.8846 ± 0.005 | 1.0542 ± 0.006 | 0.8956 ± 0.004 | 0.9425 ± 0.006 | 0.7985 ± 0.005 |
| CELEBA | ANP | 1.2633 ± 0.002 | 1.0163 ± 0.004 | 1.1377 ± 0.004 | 0.9866 ± 0.006 | 1.0418 ± 0.004 | 0.8845 ± 0.006 | 0.9363 ± 0.004 | 0.7346 ± 0.008 |
| CELEBA | BANP | 1.2751 ± 0.002 | 1.0389 ± 0.005 | 1.1488 ± 0.004 | 1.0155 ± 0.005 | 1.0602 ± 0.005 | 0.9255 ± 0.006 | 0.9489 ± 0.004 | 0.8415 ± 0.007 |
| CELEBA | SANP | 1.2854 ± 0.000 | 1.0594 ± 0.004 | 1.1685 ± 0.002 | 1.0353 ± 0.004 | 1.0772 ± 0.002 | 0.9455 ± 0.004 | 0.9655 ± 0.003 | 0.8673 ± 0.005 |

C.3 Image Completion

Analogous to the 1D experiments, we take random pixels of a given image at training time as targets and select a subset of them as contexts, again choosing the numbers of contexts and targets randomly ($n \sim U[3, 200]$, $m \sim n + U[0, 200 - n]$). The $x_i$ are rescaled to $[-1, 1]$ and the $y_i$ are rescaled to $[-0.5, 0.5]$. We use a batch size of 16 for both EMNIST and CelebA, i.e., each batch contains 16 randomly selected images.

For the image completion experiments on the EMNIST and CelebA datasets, the neural architectures for CNP, NP, ANP, BCNP, BNP, BANP, and our SCNP, SNP, and SANP are described in Appendix B. The number of hidden units is $d_h = 128$ and the latent representation size is $d_z = 128$. The numbers of layers are $l_e = l_{de} = 4$ and $l_{la} = l_{qk} = l_v = 5$, with $h_{head} = 8$.

C.4 Uncertainty Measuring

Methods for reasoning under uncertainty are a key building block of accurate and reliable machine learning systems.
We further analyze the learned models using the framework introduced in [17], quantifying uncertainty through the calibration error and sharpness of the models. Assuming the predictive distribution $P_\theta(y_i \mid x_i, D_C)$ is a Gaussian $\mathcal{N}(y_i; \mu_i, \sigma_i^2)$, we obtain the probabilistic forecast $F_i(x_i)$. More formally, $m$ confidence levels $0 \le p_1 < p_2 < \dots < p_m \le 1$ are chosen. For each threshold $p_j$, we compute the empirical frequency

$$\hat{p}_j = \frac{\left|\{\, y_i \mid F_i(y_i) \le p_j,\ i \in T \,\}\right|}{|T|}.$$

The calibration error is then defined as a numerical score describing the quality of forecast calibration:

$$\mathrm{CAL}(F_1, y_1, \dots, F_{|T|}, y_{|T|}) = \sum_{j=1}^{m} w_j \,(p_j - \hat{p}_j)^2 \qquad (32)$$

where $w_j$ is a weight, which we set to $w_j = 1$. Sharpness is measured by the variance $\mathrm{var}(F_i) = \sigma_i^2$ of the random variable whose CDF is $F_i$. Low-variance predictions are tightly centered around one value. A sharpness score can be defined by

$$\mathrm{SHA}(F_1, \dots, F_{|T|}) = \frac{1}{|T|} \sum_{i \in T} \sigma_i^2 \qquad (33)$$

We evaluated the calibration error and sharpness of CNP, NP, ANP, BCNP, BNP, BANP, and the corresponding stable versions SCNP, SNP, and SANP trained in the experiments. Tables 10, 11, and 12 list the calibration error and sharpness scores on the 1D regression, system identification, and image completion tasks. Models with our stable solution achieve better calibration and sharpness in several settings but perform worse in others; for example, their calibration error is worse than that of NP and ANP on the 1D regression task with the Matern kernel. A possible reason is that our method tends to produce conservative credible intervals, so the models become under-confident or less over-confident in some settings.

Table 10: Calibration error and sharpness of the models for 1D regression experiments (mean ± std).

| Method | RBF CAL | RBF SHA | Matern CAL | Matern SHA |
|---|---|---|---|---|
| CNP | 0.0724 ± 0.008 | 0.0887 ± 0.007 | 0.0514 ± 0.005 | 0.0831 ± 0.010 |
| BCNP | 0.1015 ± 0.007 | 0.1151 ± 0.005 | 0.0724 ± 0.005 | 0.1152 ± 0.009 |
| SCNP | 0.1225 ± 0.005 | 0.1357 ± 0.006 | 0.0425 ± 0.005 | 0.1325 ± 0.008 |
| NP | 0.0615 ± 0.004 | 0.0715 ± 0.005 | 0.0343 ± 0.015 | 0.0717 ± 0.015 |
| BNP | 0.0871 ± 0.005 | 0.1052 ± 0.006 | 0.0325 ± 0.009 | 0.0715 ± 0.008 |
| SNP | 0.1155 ± 0.004 | 0.1257 ± 0.005 | 0.0347 ± 0.008 | 0.0718 ± 0.007 |
| ANP | 0.1532 ± 0.002 | 0.0616 ± 0.004 | 0.0921 ± 0.004 | 0.0871 ± 0.006 |
| BANP | 0.2353 ± 0.002 | 0.0689 ± 0.004 | 0.0752 ± 0.005 | 0.0741 ± 0.006 |
| SANP | 0.2633 ± 0.000 | 0.0741 ± 0.004 | 0.0415 ± 0.002 | 0.0667 ± 0.004 |

Table 11: Calibration error and sharpness of the models for system identification experiments (mean ± std).

| Method | CAL | SHA |
|---|---|---|
| CNP | 0.0872 ± 0.008 | 0.0415 ± 0.007 |
| BCNP | 0.1051 ± 0.005 | 0.0765 ± 0.006 |
| SCNP | 0.1124 ± 0.005 | 0.0833 ± 0.005 |
| NP | 0.0821 ± 0.003 | 0.0581 ± 0.005 |
| BNP | 0.0952 ± 0.003 | 0.0616 ± 0.005 |
| SNP | 0.1001 ± 0.001 | 0.0668 ± 0.005 |
| ANP | 0.0863 ± 0.001 | 0.0673 ± 0.004 |
| BANP | 0.1235 ± 0.001 | 0.0815 ± 0.005 |
| SANP | 0.1431 ± 0.000 | 0.1094 ± 0.004 |

Table 12: Calibration error and sharpness of the models for image completion experiments (mean ± std).

| Method | EMNIST CAL | EMNIST SHA | CELEBA CAL | CELEBA SHA |
|---|---|---|---|---|
| CNP | 0.0182 ± 0.002 | 0.0574 ± 0.002 | 0.0253 ± 0.002 | 0.0743 ± 0.001 |
| BCNP | 0.0415 ± 0.003 | 0.0716 ± 0.002 | 0.0412 ± 0.002 | 0.0981 ± 0.002 |
| SCNP | 0.0543 ± 0.001 | 0.0846 ± 0.002 | 0.0457 ± 0.002 | 0.1136 ± 0.002 |
| NP | 0.0163 ± 0.001 | 0.0671 ± 0.002 | 0.0261 ± 0.001 | 0.0711 ± 0.001 |
| BNP | 0.0352 ± 0.002 | 0.0815 ± 0.002 | 0.0463 ± 0.001 | 0.0981 ± 0.002 |
| SNP | 0.0446 ± 0.002 | 0.0918 ± 0.000 | 0.0532 ± 0.001 | 0.1046 ± 0.001 |
| ANP | 0.0156 ± 0.002 | 0.0656 ± 0.002 | 0.0261 ± 0.004 | 0.0815 ± 0.006 |
| BANP | 0.0412 ± 0.002 | 0.0871 ± 0.002 | 0.0513 ± 0.003 | 0.1125 ± 0.003 |
| SANP | 0.0558 ± 0.000 | 0.0954 ± 0.001 | 0.0632 ± 0.002 | 0.1265 ± 0.004 |

Figure 5: Log-likelihood comparisons with different K on the 1D regression task. Panels: (a) RBF/context, (b) RBF/target, (c) Matern/context, (d) Matern/target.

Figure 6: LL and MAE comparisons with different K on the system identification task.

Figure 7: LL comparisons with different K on the image completion task. Panels: (a) EMNIST, (b) CELEBA.
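As a rough illustration of Eqs. (32) and (33), the sketch below computes the calibration error and the sharpness for Gaussian forecasts in plain Python. The function names (`gaussian_cdf`, `calibration_error`, `sharpness`) and the particular confidence levels used in the usage note are illustrative assumptions, not the authors' evaluation code; the empirical frequency is normalized by the number of target points.

```python
import math


def gaussian_cdf(y, mu, sigma):
    """CDF of N(mu, sigma^2) evaluated at y, i.e. F_i(y_i) for a Gaussian forecast."""
    return 0.5 * (1.0 + math.erf((y - mu) / (sigma * math.sqrt(2.0))))


def calibration_error(ys, mus, sigmas, levels, weights=None):
    """CAL = sum_j w_j * (p_j - p_hat_j)^2, with w_j = 1 unless weights are given."""
    if weights is None:
        weights = [1.0] * len(levels)
    # Evaluate each forecast CDF at its observed target value.
    cdf_values = [gaussian_cdf(y, m, s) for y, m, s in zip(ys, mus, sigmas)]
    n = len(cdf_values)
    cal = 0.0
    for p, w in zip(levels, weights):
        # Empirical frequency of targets whose CDF value falls at or below p_j.
        p_hat = sum(1 for v in cdf_values if v <= p) / n
        cal += w * (p - p_hat) ** 2
    return cal


def sharpness(sigmas):
    """SHA = mean predictive variance over the target points."""
    return sum(s * s for s in sigmas) / len(sigmas)
```

A call such as `calibration_error(ys, mus, sigmas, levels=[0.1 * j for j in range(1, 10)])` would score a model on nine evenly spaced confidence levels (an arbitrary choice for illustration); lower values of both scores indicate better calibrated and sharper forecasts, respectively.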
C.5 Ablation Study

The key parameter in our stable solution is the number of hard predictive subsets $K$. Taking SANP as an example, we investigated the average log-likelihood for different values of $K$ on the 1D regression task, as shown in Figure 5. SANP performs better as $K$ increases, reaches its best value at around $K = 4$, and then remains stable as $K$ grows larger. As proved in Theorem 5.4 of the main manuscript, optimizing on $D_C$ together with more than one hard predictive subset of $D_C$ yields more stable predictions. We also conducted experiments on the effect of $K$ for the system identification and image completion tasks, and similar observations can be made in Figures 6 and 7.
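Returning to the image completion setup of Section C.3, the random context/target split and the rescaling of inputs and outputs can be sketched as follows. This is a minimal sketch under stated assumptions: the image is a single-channel grid with raw intensities in [0, 1], pixel coordinates are mapped linearly from indices to [-1, 1], and the names `rescale_coord` and `sample_task` are hypothetical helpers rather than part of the authors' code.

```python
import random


def rescale_coord(index, size):
    """Map a pixel index in [0, size - 1] linearly to [-1, 1]."""
    return 2.0 * index / (size - 1) - 1.0


def sample_task(image, rng, n_min=3, n_max=200):
    """Draw one context/target split from a grayscale image.

    Assumes `image` is a list of rows with intensities in [0, 1]
    and that it has at least `n_max` pixels.
    """
    height, width = len(image), len(image[0])
    n = rng.randint(n_min, n_max)       # number of contexts, n ~ U[3, 200]
    m = n + rng.randint(0, n_max - n)   # number of targets, m ~ n + U[0, 200 - n]

    # Targets are random pixels; contexts are a subset of the targets.
    coords = [(r, c) for r in range(height) for c in range(width)]
    target_coords = rng.sample(coords, m)
    context_coords = target_coords[:n]

    def encode(r, c):
        x = (rescale_coord(r, height), rescale_coord(c, width))  # x_i in [-1, 1]^2
        y = image[r][c] - 0.5                                    # y_i in [-0.5, 0.5]
        return x, y

    contexts = [encode(r, c) for r, c in context_coords]
    targets = [encode(r, c) for r, c in target_coords]
    return contexts, targets
```

Because `random.Random.sample` returns the selected pixels in random order, taking the first `n` of them gives a uniformly random subset of the targets as contexts. A batch in the sense of Section C.3 would repeat this for 16 randomly selected images.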