# CASTLE: Regularization via Auxiliary Causal Graph Discovery

Trent Kyono, University of California, Los Angeles (tmkyono@ucla.edu); Yao Zhang, University of Cambridge (yz555@cam.ac.uk); Mihaela van der Schaar, University of Cambridge, University of California, Los Angeles, and The Alan Turing Institute (mv472@cam.ac.uk)

## Abstract

Regularization improves generalization of supervised models to out-of-sample data. Prior works have shown that prediction in the causal direction (effect from cause) results in lower testing error than prediction in the anti-causal direction. However, existing regularization methods are agnostic of causality. We introduce Causal Structure Learning (CASTLE) regularization and propose to regularize a neural network by jointly learning the causal relationships between variables. CASTLE learns the causal directed acyclic graph (DAG) as an adjacency matrix embedded in the neural network's input layers, thereby facilitating the discovery of optimal predictors. Furthermore, CASTLE efficiently reconstructs only the features in the causal DAG that have a causal neighbor, whereas reconstruction-based regularizers suboptimally reconstruct all input features. We provide a theoretical generalization bound for our approach and conduct experiments on a plethora of synthetic and real publicly available datasets, demonstrating that CASTLE consistently leads to better out-of-sample predictions than other popular benchmark regularizers.

## 1 Introduction

A primary concern of machine learning, and deep learning in particular, is generalization performance on out-of-sample data. Over-parameterized deep networks efficiently learn complex models and are therefore susceptible to overfitting the training data. Common regularization techniques to mitigate overfitting include data augmentation [1, 2], dropout [3, 4, 5], adversarial training [6], label smoothing [7], and layer-wise strategies [8, 9, 10], to name a few. However, these methods are agnostic of the causal relationships between variables, limiting their potential to identify optimal predictors based on graphical topology, such as the causal parents of the target variable.

An alternative approach to regularization leverages supervised reconstruction, which has been proven theoretically and demonstrated empirically to improve generalization performance by obligating hidden bottleneck layers to reconstruct input features [11, 12]. However, supervised auto-encoders suboptimally reconstruct all features, including those without causal neighbors, i.e., adjacent cause or effect nodes. Naively reconstructing these variables does not improve regularization and representation learning for the predictive model. In some cases, it may even harm generalization performance, e.g., when reconstructing a random noise variable.

Although causality has been a topic of research for decades, cause-and-effect relationships have only recently been incorporated into machine learning methodologies and research. Researchers at the confluence of machine learning and causal modeling have advanced causal discovery [13, 14], causal inference [15, 16], model explainability [17], domain adaptation [18, 19, 20], and transfer learning [21], among countless others.
The existing synergy between these two disciplines has been recognized for some time [22], and recent work suggests that causality can improve and complement machine learning regularization [23, 24, 25]. Furthermore, many recent causal works have demonstrated and acknowledged the optimality of predicting in the causal direction, i.e., predicting effect from cause, which results in less test error than predicting in the anti-causal direction [21, 26, 27, 28].

**Contributions.** In this work, we introduce a novel regularization method called CASTLE (CAusal STructure LEarning) regularization. CASTLE regularization uses causal graph discovery as an auxiliary task when training a supervised model, to improve the generalization performance of the primary prediction task. Specifically, CASTLE learns the causal directed acyclic graph (DAG) under continuous optimization as an adjacency matrix embedded in a feed-forward neural network's input layers. By jointly learning the causal graph, CASTLE can surpass the benefits provided by feature-selection regularizers by identifying optimal predictors, such as the target variable's causal parents. Additionally, CASTLE improves upon auto-encoder-based regularization [12] by reconstructing only the input features that have neighbors (adjacent nodes) in the causal graph. Regularizing a predictive model to satisfy the causal relationships among the feature and target variables effectively guides the model towards better out-of-sample generalization guarantees. We provide a theoretical generalization bound for CASTLE and demonstrate improved performance against a variety of benchmark methods on a plethora of real and synthetic datasets.

## 2 Related Works

Table 1: Comparison of related works. The columns indicate whether a method performs feature selection (FEAT. SEL.), structure learning (STRUCT. LEARNING), causal prediction (CAUSAL PRED.), and target selection (TARGET SEL.); the rows compare capacity-based regularizers, SAE, and CASTLE.

We compare to the related work in the simplest supervised learning setting, where we wish to learn a function from some features X to a target variable Y, given data of the variables X and Y, so as to improve out-of-sample generalization within the same distribution. This is a significant departure from branches of machine learning, such as semi-supervised learning and domain adaptation, where the regularizer is constructed with information other than the variables X and Y.

Regularization controls model complexity and mitigates overfitting. ℓ1 [29] and ℓ2 [30] regularization are commonly used approaches, where the former is preferred when a sparse model is desired. For deep neural networks, dropout regularization [3, 4, 5] has been shown to be superior in practice to ℓp regularization techniques. Other capacity-based regularization techniques commonly used in practice include early stopping [31], parameter sharing [31], gradient clipping [32], batch normalization [33], data augmentation [2], weight noise [34], and MixUp [35], to name a few. Norm-based regularizers with sparsity, e.g., Lasso [29], are used to guide feature selection for supervised models. The work of [12] on supervised auto-encoders (SAE) theoretically and empirically shows that adding a reconstruction loss of the input features acts as a regularizer for predictive models. However, this method does not select which features to reconstruct and therefore suffers performance degradation when tasked with reconstructing features that are noise or unrelated to the target variable. Two existing works [25, 23] attempt to draw the connection between causality and regularization.
Based on an analogy between overfitting and confounding in linear models, [25] proposed a method to determine the regularization hyperparameter in linear Ridge or Lasso regression models by estimating the strength of confounding. [23] use causality detectors [36, 27] to weight a sparsity regularizer, e.g. ℓ1, for performing non-linear causality analysis and generating multivariate causal hypotheses. Neither work shares our objective of improving the generalization performance of supervised learning models, nor do they overlap methodologically by using causal DAG discovery.

Causal discovery is an NP-hard problem that requires a brute-force search through a non-convex combinatorial search space, limiting existing algorithms to reaching global optima only for small problems. Recent approaches have successfully accelerated these methods by using a novel acyclicity constraint and formulating the causal discovery problem as a continuous optimization over real matrices (avoiding combinatorial search) in the linear [37] and nonlinear [38, 39] cases. CASTLE incorporates the recent causal discovery approaches of [37, 38] to improve regularization for prediction problems in general.

As shown in Table 1, CASTLE regularization provides two additional benefits: causal prediction and target selection. First, CASTLE identifies causal predictors (e.g., causal parents if they exist) rather than merely correlated features. Furthermore, CASTLE improves upon reconstruction regularization by only reconstructing features that have neighbors in the underlying DAG. We refer to this advantage as "target selection". Collectively, these benefits contribute to the improved generalization of CASTLE. Next, we introduce our notation (Section 3.1) and provide more details on these benefits (Section 3.2).

## 3 Methodology

In this section, we provide a problem formulation with causal preliminaries for CASTLE. We then provide a motivational discussion, the regularizer methodology, and generalization theory for CASTLE.

### 3.1 Problem Formulation

In the standard supervised learning setting, we denote the input feature variables and target variable by $X = [X_1, \dots, X_d] \in \mathcal{X}$ and $Y \in \mathcal{Y}$, respectively, where $\mathcal{X} \subseteq \mathbb{R}^d$ is a $d$-dimensional feature space and $\mathcal{Y} \subseteq \mathbb{R}$ is a one-dimensional target space. Let $P_{X,Y}$ denote the joint distribution of the features and target, and let $[N]$ denote the set $\{1, \dots, N\}$. We observe a dataset $\mathcal{D} = \{(X_i, Y_i) : i \in [N]\}$ consisting of $N$ i.i.d. samples drawn from $P_{X,Y}$. The goal of a supervised learning algorithm $\mathcal{A}$ is to find a predictive model $f_Y : \mathcal{X} \to \mathcal{Y}$ in a hypothesis space $\mathcal{H}$ that can explain the association between the features and the target variable. In the learning algorithm $\mathcal{A}$, the predictive model $\hat{f}_Y$ is trained on a finite number of samples in $\mathcal{D}$ to predict well on out-of-sample data generated from the same distribution $P_{X,Y}$. However, overfitting, a mismatch between the training and testing performance of $\hat{f}_Y$, can occur if the hypothesis space $\mathcal{H}$ is too complex and the training data fails to represent the underlying distribution $P_{X,Y}$. This motivates the use of regularization to reduce the complexity of the hypothesis space $\mathcal{H}$ so that the learning algorithm $\mathcal{A}$ will only find the desired function to explain the data. Assumptions on the underlying distribution dictate the choice of regularization. For example, if we believe that only a subset of features is associated with the label $Y$, then ℓ1 regularization [29] can be beneficial in creating sparsity for feature selection.
CASTLE regularization is based on the assumption that a causal DAG exists among the input features and the target variable. In the causal framework of [40], a causal structure over a set of variables X is a DAG in which each vertex $v \in V$ corresponds to a distinct element of X, and each edge $e \in E$ represents a direct functional relationship between two neighboring variables. Formally, we assume the variables in our dataset satisfy a nonparametric structural equation model (NPSEM), as defined in Definition 1. The word nonparametric means that we make no assumption on the underlying functions $f_i$ in the NPSEM. In this work, we characterize optimal learning by a predictive model as discovering the function $Y = f_Y(\mathrm{Pa}(Y), u_Y)$ in the NPSEM [40].

**Definition 1 (NPSEMs).** Given a DAG $G = (V = [d+1], E)$, the random variables $\tilde{X} = [Y, X]$ satisfy an NPSEM if $\tilde{X}_i = f_i(\mathrm{Pa}(\tilde{X}_i), u_i)$, $i \in [d+1]$, where $\mathrm{Pa}(\tilde{X}_i)$ denotes the parents (direct causes) of $\tilde{X}_i$ in $G$ and $u_1, \dots, u_{d+1}$ are some random noise variables.

### 3.2 Why CASTLE regularization matters

We now present a graphical example to explain the two benefits of CASTLE mentioned in Section 2: causal prediction and target selection. Consider Figure 1, where we are given nine feature variables $X_1, \dots, X_9$ and a target variable $Y$.

**Causal Prediction.** The target variable $Y$ is generated by a function $f_Y(\mathrm{Pa}(Y), u_Y)$ from Definition 1, where the parents of $Y$ are $\mathrm{Pa}(Y) = \{X_2, X_3\}$. In CASTLE regularization, we train a predictive model $\hat{f}_Y$ jointly with learning the DAG among $X$ and $Y$. The features that the model uses to predict $Y$ are the causal parents of $Y$ in the learned DAG. Such a model is sample efficient in uncovering the true function $f_Y(\mathrm{Pa}(Y), u_Y)$ and generalizes well to out-of-sample data. Our theoretical analysis in Section 3.4 validates this advantage when there exists a DAG structure among the variables $X$ and $Y$. However, there may exist other variables that predict $Y$ more accurately than the causal parents $\mathrm{Pa}(Y)$. For example, if the function from $Y$ to $X_8$ is a one-to-one linear mapping, we can predict $Y$ trivially from the feature $X_8$. In our objective function, introduced later, the prediction loss of $Y$ is weighted more heavily than the causal regularizer. Among predictive models with a similar prediction loss for $Y$, our objective function still prefers the model that minimizes the causal regularizer and uses the causal parents. However, it favors the easier predictor if one exists and gives a much lower prediction loss for $Y$. In this case, the learned DAG may differ from the true DAG, but we reiterate that we are focused on the problem of generalization rather than causal discovery.

Figure 1: Example DAG.

**Target Selection.** Consider the variables $X_5$, $X_6$, and $X_7$, which share parents ($X_2$ and $X_3$) with $Y$ in Figure 1. The functions $X_5 = f_5(X_2, u_5)$, $X_6 = f_6(X_3, u_6)$, and $X_7 = f_7(X_3, u_7)$ may have some learnable similarity (e.g., basis functions and representations) with $Y = f_Y(X_2, X_3, u_Y)$ that we can exploit by training a shared predictive model of $Y$ with the auxiliary task of predicting $X_5$, $X_6$, and $X_7$. From the causal graph topology, CASTLE discovers the optimal features that should act as the auxiliary task for learning $f_Y$. CASTLE learns the related functions jointly in a shared model, which is proven to improve the generalization performance of predicting $Y$ by learning shared basis functions and representations [41]. A minimal sampling sketch in the spirit of Definition 1 and Figure 1 is given below.
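To make Definition 1 concrete, the following minimal Python sketch samples from an NPSEM over a small graph resembling Figure 1. Only the edges discussed in the text are kept; the remaining roots, the topological order, and the tanh structural functions are illustrative assumptions rather than the paper's exact data-generating process.

```python
import numpy as np

rng = np.random.default_rng(0)

# Node 0 is the target Y; nodes 1-9 are X1-X9. Assumed edges: Pa(Y) = {X2, X3},
# X5/X6/X7 share parents with Y, and X8 is a child (effect) of Y.
parents = {0: [2, 3], 5: [2], 6: [3], 7: [3], 8: [0]}
order = [1, 2, 3, 4, 9, 0, 5, 6, 7, 8]   # a topological order for this graph

def sample_npsem(n):
    """Draw n samples of [Y, X1, ..., X9] following Definition 1, with the
    assumed structural functions f_k(Pa, u) = sum(tanh(Pa)) + u."""
    data = np.zeros((n, 10))
    for k in order:
        u = rng.normal(size=n)                 # exogenous noise u_k
        pa = parents.get(k, [])                # direct causes of node k
        data[:, k] = np.tanh(data[:, pa]).sum(axis=1) + u if pa else u
    return data

samples = sample_npsem(1000)                   # columns are [Y, X1, ..., X9]
```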
### 3.3 CASTLE regularization

Let $\tilde{\mathcal{X}} = \mathcal{Y} \times \mathcal{X}$ denote the data space, $P_{X,Y} = P_{\tilde{X}}$ the data distribution, and $\|\cdot\|_F$ the Frobenius norm. We define the random variables $\tilde{X} = [\tilde{X}_1, \tilde{X}_2, \dots, \tilde{X}_{d+1}] := [Y, X_1, \dots, X_d] \in \tilde{\mathcal{X}}$. Let $\mathbf{X} = [\mathbf{X}_1, \dots, \mathbf{X}_d]$ denote the $N \times d$ input data matrix, $\mathbf{Y}$ the $N$-dimensional label vector, and $\tilde{\mathbf{X}} = [\mathbf{Y}, \mathbf{X}]$ the $N \times (d+1)$ matrix that contains the data of all the variables in the DAG.

To facilitate exposition, we first introduce CASTLE in the linear setting. Here, the parameters are a $(d+1) \times (d+1)$ adjacency matrix $W$ with zeros on the diagonal. The objective function is given as

$$\hat{W} = \arg\min_{W} \ \frac{1}{N} \big\| \mathbf{Y} - \tilde{\mathbf{X}} W_{:,1} \big\|^2 + \lambda \, \mathcal{R}_{\mathrm{DAG}}(\tilde{\mathbf{X}}, W), \tag{1}$$

where $W_{:,1}$ is the first column of $W$. We define the DAG regularization loss $\mathcal{R}_{\mathrm{DAG}}(\tilde{\mathbf{X}}, W)$ as

$$\mathcal{R}_{\mathrm{DAG}}(\tilde{\mathbf{X}}, W) = \mathcal{L}_W + \mathcal{R}_W + \beta \mathcal{V}_W, \tag{2}$$

where $\mathcal{L}_W = \frac{1}{N} \| \tilde{\mathbf{X}} - \tilde{\mathbf{X}} W \|_F^2$, $\mathcal{R}_W = \big( \mathrm{Tr}\big(e^{W \circ W}\big) - d - 1 \big)^2$, $\mathcal{V}_W$ is the $\ell_1$ norm of $W$, $\circ$ is the Hadamard product, and $e^M$ is the matrix exponential of $M$. The DAG loss $\mathcal{R}_{\mathrm{DAG}}(\tilde{\mathbf{X}}, W)$ is introduced in [37] for learning a linear DAG by continuous optimization. Here we use it as the regularizer for our linear regression model $\mathbf{Y} = \tilde{\mathbf{X}} W_{:,1} + \epsilon$. From Theorem 1 in [37], we know that the graph given by $W$ is a DAG if and only if $\mathcal{R}_W = 0$. The prediction $\hat{\mathbf{Y}} = \tilde{\mathbf{X}} W_{:,1}$ is the projection of $\mathbf{Y}$ onto the parents of $Y$ in the learned DAG. This increases the stability of linear regression when issues pertaining to collinearity or multicollinearity among the input features appear. A sketch of this linear objective follows.
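As a concrete reference point, here is a minimal numpy sketch of the linear objective in (1)-(2). It assumes `X_tilde` is the $N \times (d+1)$ matrix $[\mathbf{Y}, \mathbf{X}]$ and `W` a $(d+1) \times (d+1)$ weight matrix with a zero diagonal; the hyperparameter names `lam` and `beta` are placeholders, and a real implementation would minimize this loss with a gradient-based optimizer rather than merely evaluate it.

```python
import numpy as np
from scipy.linalg import expm

def dag_regularizer(X_tilde, W, beta=1.0):
    """R_DAG in (2): reconstruction loss L_W, acyclicity penalty R_W,
    and l1 sparsity term V_W, for an N x (d+1) data matrix X_tilde = [Y, X]."""
    N, dp1 = X_tilde.shape
    L_W = np.sum((X_tilde - X_tilde @ W) ** 2) / N
    R_W = (np.trace(expm(W * W)) - dp1) ** 2   # zero iff W encodes a DAG
    V_W = np.abs(W).sum()
    return L_W + R_W + beta * V_W

def castle_linear_loss(X_tilde, W, lam=1.0, beta=1.0):
    """Objective (1): prediction loss for Y (column 0 here, i.e. W_{:,1} in the
    paper's 1-based indexing) plus the weighted DAG regularizer."""
    N = X_tilde.shape[0]
    residual = X_tilde[:, 0] - X_tilde @ W[:, 0]
    return np.sum(residual ** 2) / N + lam * dag_regularizer(X_tilde, W, beta)
```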
Continuous optimization for learning nonparametric causal DAGs has been proposed in the prior work of [38]. In a similar manner, we adapt CASTLE to the nonlinear case. Suppose the predictive model for $Y$ and the function generating each feature $X_k$ in the causal DAG are parameterized by an $M$-layer feed-forward neural network $f_\Theta : \tilde{\mathcal{X}} \to \tilde{\mathcal{X}}$ with ReLU activations and layer size $h$. Figure 2 shows the network architecture of $f_\Theta$. This joint network can be instantiated as $d+1$ sub-networks $f_k$ with shared hidden layers, where $f_k$ is responsible for reconstructing the feature $\tilde{X}_k$. We let $W_1^k$ denote the $(d+1) \times h$ weight matrix in the input layer of $f_k$, $k \in [d+1]$. We set the $k$-th row of $W_1^k$ to zero so that $f_k$ does not utilize $\tilde{X}_k$ in its prediction of $\tilde{X}_k$. We let $W_m$, $m = 2, \dots, M-1$, denote the weight matrices in the network's shared hidden layers, and $W_M = [W_M^1, \dots, W_M^{d+1}]$ denotes the $h \times (d+1)$ weight matrix in the output layer. Explicitly, we define the sub-network $f_k$ as

$$f_k(\tilde{X}) = \phi\big( \cdots \phi\big( \phi\big( \tilde{X} W_1^k \big) W_2 \big) \cdots W_{M-1} \big) W_M^k, \tag{3}$$

where $\phi(\cdot)$ is the ReLU activation function.

Figure 2: Schematic of CASTLE regularization. Our goal is to have the following tasks: (1) a prediction of the target variable $Y$, and (2) the discovered causal DAG for the input features $X$ and $Y$.

The function $f_\Theta$ is given as $f_\Theta(\tilde{X}) = [f_1(\tilde{X}), \dots, f_{d+1}(\tilde{X})]$. Let $f_\Theta(\tilde{\mathbf{X}})$ denote the prediction for the $N$-sample matrix $\tilde{\mathbf{X}}$, where $[f_\Theta(\tilde{\mathbf{X}})]_{i,k} = f_k(\tilde{X}_i)$ for $i \in [N]$ and $k \in [d+1]$. All network parameters are collected into the sets

$$\Theta_1 = \{W_1^k\}_{k=1}^{d+1}, \qquad \Theta = \Theta_1 \cup \{W_m\}_{m=2}^{M}. \tag{4}$$

The training objective function of $f_\Theta$ is

$$\hat{\Theta} = \arg\min_{\Theta} \ \frac{1}{N} \big\| \mathbf{Y} - [f_\Theta(\tilde{\mathbf{X}})]_{:,1} \big\|^2 + \lambda \, \mathcal{R}_{\mathrm{DAG}}\big(\tilde{\mathbf{X}}, f_\Theta\big). \tag{5}$$

The DAG loss $\mathcal{R}_{\mathrm{DAG}}(\tilde{\mathbf{X}}, f_\Theta)$ is given as

$$\mathcal{R}_{\mathrm{DAG}}\big(\tilde{\mathbf{X}}, f_\Theta\big) = \mathcal{L}_N(f_\Theta) + \mathcal{R}_{\Theta_1} + \beta \mathcal{V}_{\Theta_1}. \tag{6}$$

Because the $k$-th row of the input weight matrix $W_1^k$ is set to zero, $\mathcal{L}_N(f_\Theta) = \frac{1}{N} \| \tilde{\mathbf{X}} - f_\Theta(\tilde{\mathbf{X}}) \|_F^2$ differs from the standard reconstruction loss in auto-encoders (e.g., SAE) by only allowing the model to reconstruct each feature and the target variable from the other variables. In contrast, auto-encoders reconstruct each feature using all the features, including itself. $\mathcal{V}_{\Theta_1}$ is the $\ell_1$ norm of the weight matrices $W_1^k$ in $\Theta_1$, and the term $\mathcal{R}_{\Theta_1}$ is given as

$$\mathcal{R}_{\Theta_1} = \big( \mathrm{Tr}\big(e^{\mathcal{M} \circ \mathcal{M}}\big) - d - 1 \big)^2, \tag{7}$$

where $\mathcal{M}$ is a $(d+1) \times (d+1)$ matrix such that $[\mathcal{M}]_{k,j}$ is the $\ell_2$-norm of the $k$-th row of the matrix $W_1^j$.

When the acyclicity loss $\mathcal{R}_{\Theta_1}$ is minimized, the sub-networks $f_1, \dots, f_{d+1}$ form a DAG among the variables; $\mathcal{R}_{\Theta_1}$ obligates the sub-networks to reconstruct only the input features that have neighbors (adjacent nodes) in the learned DAG. We note that converting the nonlinear version of CASTLE into the linear form can be accomplished by removing all the hidden and output layers and setting the dimension $h$ of the input weight matrices to 1 in (3), i.e., $f_k(\tilde{X}) = \tilde{X} W_1^k$ and $f_\Theta(\tilde{X}) = [\tilde{X} W_1^1, \dots, \tilde{X} W_1^{d+1}] = \tilde{X} W$, which is the linear model in (1)-(2).

**Managing computational complexity.** If the number of features is large, it is computationally expensive to train all the sub-networks simultaneously. We can mitigate this by sub-sampling. At each iteration of gradient descent, we randomly sample a subset of features to reconstruct and only minimize the prediction loss and reconstruction loss on these sub-sampled features. Note that we do not have a hidden-confounder issue here, since $Y$ and each sub-sampled feature are predicted from all the variables except itself. The sparsity and DAG constraints on the weight matrices are unchanged at each iteration. In this case, we keep the training complexity per iteration at a manageable level, approximately the computational time and space complexity of training a few networks jointly. We include experiments on CASTLE scalability with respect to input feature size in Appendix C. A sketch of the acyclicity term (7) computed from the input-layer weights, together with this sub-sampling scheme, is given below.
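To illustrate how the acyclicity term (7) can be computed from the input-layer weights, here is a hedged numpy sketch. Shapes and variable names follow the notation above under the assumption that each $W_1^k$ is stored as a $(d+1) \times h$ array with its $k$-th row masked to zero; the feature sub-sampling mentioned above is indicated only as a comment on where the reconstruction loss would be restricted.

```python
import numpy as np
from scipy.linalg import expm

def acyclicity_penalty(input_weights):
    """R_{Theta_1} in (7). `input_weights` is a list of d+1 arrays W_1^k of
    shape (d+1, h); [M]_{k,j} is the l2 norm of the k-th row of W_1^j, so
    column j of M scores the candidate parents of variable j."""
    dp1 = len(input_weights)
    M = np.zeros((dp1, dp1))
    for j, W1j in enumerate(input_weights):
        M[:, j] = np.linalg.norm(W1j, axis=1)   # per-input-variable weight norms
        M[j, j] = 0.0                           # the j-th input is masked out
    return (np.trace(expm(M * M)) - dp1) ** 2

# Sub-sampling sketch: at each gradient step one could draw, e.g.,
#   idx = np.random.choice(d + 1, size=k_sub, replace=False)
# and restrict the reconstruction loss L_N in (6) to the columns in `idx`,
# while the acyclicity and sparsity terms above are left unchanged.
```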
### 3.4 Generalization bound for CASTLE regularization

In this section, we analyze theoretically why CASTLE regularization can improve generalization performance by introducing a generalization bound for our model in Figure 2. Our bound is based on PAC-Bayesian learning theory [42, 43, 44]. Here, we re-interpret the DAG regularizer as a special prior or assumption on the input weight matrices of our model and use existing PAC-Bayes theory to prove the generalization of our algorithm. Traditionally, PAC-Bayes bounds apply only to randomized models, such as Bayesian or Gibbs classifiers. Here, our bound applies to our deterministic model via the recent derandomization formalism of [45, 46]. We acknowledge that developing tighter and non-vacuous generalization bounds for deep neural networks is still a challenging and evolving topic in learning theory; such bounds are often stated with many constants arising from different steps of the proof. For the reader's convenience, we provide a simplified version of our bound in Theorem 1. The proof, the details (e.g., the constants), and a discussion of the assumptions are provided in Appendix A. We begin with a few assumptions before stating our bound.

**Assumption 1.** For any sample $\tilde{X} = (Y, X) \sim P_{\tilde{X}}$, $\tilde{X}$ has bounded $\ell_2$ norm such that $\|\tilde{X}\|_2 \le B$ for some $B > 0$.

**Assumption 2.** The loss function $\mathcal{L}(f_\Theta) = \|f_\Theta(\tilde{X}) - \tilde{X}\|^2$ is sub-Gaussian under the distribution $P_{\tilde{X}}$ with variance factor $s^2$, i.e., for all $t > 0$, $\mathbb{E}_{P_{\tilde{X}}} \big[ \exp\big( t \, (\mathcal{L}(f_\Theta) - \mathcal{L}_P(f_\Theta)) \big) \big] \le \exp\big( \tfrac{t^2 s^2}{2} \big)$.

**Theorem 1.** Let $f_\Theta : \tilde{\mathcal{X}} \to \tilde{\mathcal{X}}$ be an $M$-layer ReLU feed-forward network with layer size $h$, each of whose weight matrices has spectral norm bounded by $\kappa$. Then, under Assumptions 1 and 2, for any $\delta, \gamma > 0$, with probability $1 - \delta$ over a training set of $N$ i.i.d. samples, for any $\Theta$ in (4), we have

$$\mathcal{L}_P(f_\Theta) \le 4 \mathcal{L}_N(f_\Theta) + \frac{1}{N} \Big[ \mathcal{R}_{\Theta_1} + C_1 \big( \mathcal{V}_{\Theta_1} + \mathcal{V}_{\Theta_2} \big) + \log \tfrac{8}{\delta} \Big] + C_2, \tag{8}$$

where $\mathcal{L}_P(f_\Theta)$ is the expected reconstruction loss of $\tilde{X}$ under $P_{\tilde{X}}$; $\mathcal{L}_N(f_\Theta)$, $\mathcal{V}_{\Theta_1}$, and $\mathcal{R}_{\Theta_1}$ are defined in (6)-(7); $\mathcal{V}_{\Theta_2}$ is the $\ell_2$ norm of the network weights in the output and shared hidden layers; and $C_1$ and $C_2$ are constants depending on $\gamma$, $d$, $h$, $B$, $s$, and $M$.

The statistical properties of the reconstruction loss in learning linear DAGs, e.g., $\mathcal{L}_W = \frac{1}{N}\|\tilde{\mathbf{X}} - \tilde{\mathbf{X}} W\|_F^2$, have been well studied in the literature: the loss minimizer provably recovers a true DAG with high probability on finite samples, and hence is consistent for both Gaussian SEMs [47] and non-Gaussian SEMs [48, 49]. Note also that the regularizers $\mathcal{R}_W$ and $\mathcal{R}_{\Theta_1}$ are not part of the results in [47, 48, 49]. However, the works of [37, 38] empirically show that using $\mathcal{R}_W$ or $\mathcal{R}_{\Theta_1}$ on top of the reconstruction loss leads to more efficient and more accurate DAG learning than existing approaches. Our theoretical result on the reconstruction loss explains the benefit of $\mathcal{R}_W$ or $\mathcal{R}_{\Theta_1}$ for the generalization performance of predicting $Y$. This provides theoretical support for our CASTLE regularizer in supervised learning. However, the objectives of DAG discovery, e.g., identifying the Markov blanket of $Y$, are beyond the scope of our analysis.

The bound in (8) justifies $\mathcal{R}_{\Theta_1}$ in general, in both the linear and nonlinear cases, if the underlying distribution $P_{\tilde{X}}$ is factorized according to some causal DAG. We note that the expected loss $\mathcal{L}_P(f_\Theta)$ is upper bounded by the empirical loss $\mathcal{L}_N(f_\Theta)$, $\mathcal{V}_{\Theta_1}$, $\mathcal{V}_{\Theta_2}$, and $\mathcal{R}_{\Theta_1}$, the last of which measures how close (via the acyclicity constraint) the model is to a DAG. From (8) it is clear that not minimizing $\mathcal{R}_{\Theta_1}$ is an acceptable strategy asymptotically, i.e., in the large-sample limit (large $N$), because $\mathcal{R}_{\Theta_1}/N$ becomes negligible. This aligns with the consistency theory in [47, 48, 49] for linear models. For small $N$, however, the preferred strategy is to train a model $f_\Theta$ by minimizing $\mathcal{L}_N(f_\Theta)$ and $\mathcal{R}_{\Theta_1}$ jointly. This comes at little cost because the samples are generated under the DAG structure in $P_{\tilde{X}}$. Minimizing $\mathcal{R}_{\Theta_1}$ can decrease the upper bound of $\mathcal{L}_P(f_\Theta)$ in (8), improve the generalization performance of $f_\Theta$, and facilitate the convergence of $f_\Theta$ to the true model. If $P_{\tilde{X}}$ does not correspond to any causal DAG, as with image data, then there is a tradeoff between minimizing $\mathcal{R}_{\Theta_1}$ and $\mathcal{L}_N(f_\Theta)$. In this case, $\mathcal{R}_{\Theta_1}$ becomes harder to minimize, and generalization may not benefit from adding CASTLE. However, this is a rare case, since causal structure exists inherently in most datasets. Our experiments in the next section demonstrate that CASTLE regularization outperforms popular regularizers on a variety of datasets.

## 4 Experiments

In this section, we empirically evaluate CASTLE as a regularization method for improving generalization performance. We present our benchmark methods and training architecture, followed by our results on synthetic and publicly available data.

Figure 3: Experiments on synthetic data. The y-axis is the average rank (± standard deviation) of each regularizer on the test set over each synthetic DAG. We show the average rank as we increase the number of features or vertex cardinality |G| (left), as we increase the dataset size normalized by the vertex cardinality |G| (center), and as we increase the number of noise (neighborless) variables (right). Methods compared: Baseline, L1, L2, DO(0.2), DO(0.5), SAE, BN, IN, MU, and CASTLE.

**Benchmarks.** We benchmark CASTLE against common regularizers, in no particular order: early stopping (Baseline) [31], L1 [29], L2 [30], dropout [3] with drop rates of 20% and 50%, denoted DO(0.2) and DO(0.5) respectively, SAE [12], batch normalization (BN) [33], data augmentation or input noise (IN) [2], and MixUp (MU) [35].
For each regularizer with tunable hyperparameters, we performed a standard grid search. For the weight-decay regularizers L1 and L2 we searched over $\lambda_{\ell_p} \in \{0.1, 0.01, 0.001\}$, and for input noise we used Gaussian noise with a mean of 0 and standard deviation $\sigma \in \{0.1, 0.01, 0.001\}$. L1 and L2 were applied at every dense layer. BN and DO were applied after every dense layer and were active only during training. Because each regularization method converges at a different rate, we use early stopping on a validation set to terminate each benchmark's training, which we refer to as our Baseline.

**Network architecture and training.** We implemented CASTLE in TensorFlow; code is provided at https://bitbucket.org/mvdschaar/mlforhealthlabpub. Our proposed architecture is comprised of $d+1$ sub-networks with shared hidden layers, as shown in Figure 2. In the linear case, $\mathcal{V}_W$ is the $\ell_1$ norm of $W$. In the nonlinear case, $\mathcal{V}_{\Theta_1}$ is the $\ell_1$ norm of the input weight matrices $W_1^k$, $k \in [d+1]$. To make a clear comparison with L2 regularization, we exclude the capacity term $\mathcal{V}_{\Theta_2}$ from CASTLE, although it is part of our generalization bound in (8). Since we predict the target variable as our primary task, we benchmark CASTLE against a common network architecture: a network with two hidden layers of $d+1$ neurons with ReLU activation. Each benchmark method is initialized and seeded identically with the same random weights. For dataset preprocessing, all continuous variables are standardized to a mean of 0 and a variance of 1. Each model is trained using the Adam optimizer with a learning rate of 0.001 for up to a maximum of 200 epochs, and an early stopping regime halts training with a patience of 30 epochs. A sketch of this benchmark training setup is given below.
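The benchmark training setup described above can be sketched with tf.keras as follows. This is only an illustration of the shared experimental settings (two hidden layers of $d+1$ ReLU units, Adam with learning rate 0.001, up to 200 epochs, early stopping with patience 30), not the CASTLE model itself; names such as `X_train` are placeholders.

```python
import tensorflow as tf

def build_benchmark(d):
    """Common benchmark network: two hidden layers of d+1 ReLU units."""
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(d,)),
        tf.keras.layers.Dense(d + 1, activation="relu"),
        tf.keras.layers.Dense(d + 1, activation="relu"),
        tf.keras.layers.Dense(1),   # regression head; a sigmoid head would suit the AUROC tasks
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  loss="mse")
    return model

# Example usage with standardized splits X_train, y_train, X_val, y_val:
# early_stop = tf.keras.callbacks.EarlyStopping(patience=30,
#                                               restore_best_weights=True)
# model = build_benchmark(X_train.shape[1])
# model.fit(X_train, y_train, validation_data=(X_val, y_val),
#           epochs=200, callbacks=[early_stop])
```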
### 4.1 Regularization on Synthetic Data

**Synthetic data generation.** Given a DAG $G$, we generate functional relationships between each variable and its respective parent(s), with additive Gaussian noise applied to each variable with a mean of 0 and variance of 1. In the linear case, each variable is equal to the sum of its parents plus noise. In the nonlinear case, each variable is equal to the sum of the sigmoid of its parents plus noise. We provide further details on our synthetic DGP and pseudocode in Appendix B.

Table 2: Experiments on nonlinear synthetic data of size $n$ generated according to Figure 1, in terms of MSE (± standard deviation).

| Regularizer | n = 500 | n = 1000 | n = 5000 |
| --- | --- | --- | --- |
| Baseline | 0.83 ± 0.03 | 0.80 ± 0.04 | 0.73 ± 0.02 |
| L1 | 0.81 ± 0.05 | 0.79 ± 0.03 | 0.71 ± 0.02 |
| L2 | 0.81 ± 0.05 | 0.77 ± 0.02 | 0.71 ± 0.01 |
| DO(0.2) | 0.80 ± 0.04 | 0.79 ± 0.01 | 0.70 ± 0.02 |
| DO(0.5) | 0.79 ± 0.02 | 0.78 ± 0.04 | 0.70 ± 0.02 |
| SAE | 0.79 ± 0.03 | 0.77 ± 0.04 | 0.69 ± 0.02 |
| BN | 0.81 ± 0.04 | 0.79 ± 0.03 | 0.72 ± 0.02 |
| IN | 0.82 ± 0.05 | 0.78 ± 0.04 | 0.71 ± 0.02 |
| MU | 0.79 ± 0.05 | 0.78 ± 0.04 | 0.72 ± 0.08 |
| CASTLE | 0.77 ± 0.02 | 0.75 ± 0.04 | 0.68 ± 0.02 |

Consider Table 2: using our nonlinear DGP, we generated 1000 test samples according to the DAG in Figure 1. We then used 10-fold cross-validation to train and validate each benchmark on training sets of varying size $n$. Each model was evaluated on the test set using the weights saved at the lowest validation error. Table 2 shows that CASTLE improves over all experimental benchmarks. We present similar results for our linear experiments in Appendix B.

Table 3: Comparison of benchmark regularizers on regression and classification in terms of test MSE and AUROC (± standard deviation), respectively, for experiments on real datasets using 10-fold cross-validation. Bold denotes best performance. For conciseness we show only a subset of the benchmarks; the full version of this table is in Appendix C, along with results on additional datasets.

Regression (MSE):

| Dataset | Baseline | L1 | Dropout 0.2 | SAE | Batch Norm | Input Noise | MixUp | CASTLE |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BH | 0.141 ± 0.023 | 0.137 ± 0.025 | 0.168 ± 0.032 | 0.148 ± 0.027 | 0.139 ± 0.021 | 0.137 ± 0.018 | 0.194 ± 0.064 | **0.123 ± 0.016** |
| WQ | 0.747 ± 0.038 | 0.747 ± 0.043 | 0.738 ± 0.029 | 0.727 ± 0.030 | 0.723 ± 0.039 | 0.771 ± 0.036 | 0.712 ± 0.018 | **0.708 ± 0.030** |
| FB | 0.758 ± 1.017 | 0.663 ± 0.796 | 0.429 ± 0.449 | 0.372 ± 0.168 | 0.705 ± 0.396 | 0.609 ± 0.511 | 0.385 ± 0.208 | **0.246 ± 0.153** |
| BC | 0.359 ± 0.061 | 0.342 ± 0.037 | 0.334 ± 0.030 | 0.322 ± 0.021 | 0.325 ± 0.024 | 0.319 ± 0.022 | 0.322 ± 0.030 | **0.318 ± 0.036** |
| SP | 0.416 ± 0.108 | 0.421 ± 0.181 | 0.285 ± 0.042 | 0.228 ± 0.022 | 0.318 ± 0.062 | 0.389 ± 0.095 | 0.267 ± 0.072 | **0.200 ± 0.020** |
| CM | 0.536 ± 0.103 | 0.574 ± 0.125 | 0.327 ± 0.025 | 0.387 ± 0.034 | 0.470 ± 0.047 | 0.495 ± 0.081 | 0.376 ± 0.030 | **0.326 ± 0.031** |

Classification (AUROC):

| Dataset | Baseline | L1 | Dropout 0.2 | SAE | Batch Norm | Input Noise | MixUp | CASTLE |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| CC | 0.764 ± 0.009 | 0.766 ± 0.007 | 0.776 ± 0.009 | 0.774 ± 0.012 | 0.773 ± 0.009 | 0.772 ± 0.012 | 0.778 ± 0.009 | **0.787 ± 0.007** |
| PD | 0.799 ± 0.008 | 0.793 ± 0.013 | 0.797 ± 0.010 | 0.796 ± 0.010 | 0.773 ± 0.024 | 0.796 ± 0.013 | 0.802 ± 0.016 | **0.817 ± 0.004** |
| BC | 0.721 ± 0.018 | 0.726 ± 0.011 | 0.718 ± 0.024 | 0.605 ± 0.068 | 0.727 ± 0.012 | 0.722 ± 0.026 | 0.700 ± 0.055 | **0.731 ± 0.010** |
| LV | 0.559 ± 0.061 | 0.594 ± 0.020 | 0.579 ± 0.053 | 0.542 ± 0.095 | 0.583 ± 0.026 | **0.597 ± 0.041** | 0.553 ± 0.092 | 0.595 ± 0.032 |
| SH | 0.915 ± 0.015 | 0.921 ± 0.006 | 0.922 ± 0.017 | 0.701 ± 0.205 | 0.913 ± 0.013 | 0.922 ± 0.005 | 0.921 ± 0.005 | **0.929 ± 0.007** |
| RP | 0.782 ± 0.071 | 0.801 ± 0.013 | 0.743 ± 0.052 | 0.774 ± 0.103 | 0.802 ± 0.018 | 0.796 ± 0.009 | 0.730 ± 0.043 | **0.814 ± 0.014** |

**Dissecting CASTLE.** In the synthetic environment, we know the causal relationships with certainty. We analyze three aspects of CASTLE regularization using synthetic data. Because we are comparing across randomly simulated DAGs with differing functional relationships, the magnitude of the regression testing error varies between runs. To normalize this, we examine model performance in terms of each model's average rank over each fold. If we have $r$ regularizers, the worst and best possible ranks are one and $r$, respectively (i.e., the higher the rank, the better); see the sketch at the end of this subsection. We used 10-fold cross-validation to terminate model training and tested each model on a held-out test set of 1000 samples.

First, we examine the impact of increasing the feature size, or DAG vertex cardinality $|G|$. We do this by randomly generating a DAG of size $|G| \in \{10, 50, 100, 150\}$ with $50|G|$ training samples. We repeat this ten times for each DAG cardinality. On the left-hand side of Figure 3, CASTLE has the highest rank of all benchmarks and does not degrade with increasing $|G|$.

Second, we analyze the impact of increasing dataset size. We randomly generate DAGs of size $|G| \in \{10, 50, 100, 150\}$, which we use to create datasets of $\alpha|G|$ samples, where $\alpha \in \{20, 50, 100, 150, 200\}$. We repeat this ten times for each dataset size. In the middle plot of Figure 3, we see that CASTLE has superior performance for all dataset sizes and, as expected, all benchmark methods (except SAE) start to converge towards the average rank at large data sizes ($\alpha = 200$).

Third, we analyze our method's sensitivity to noise variables, i.e., variables disconnected from the target variable in $G$. We randomly generate DAGs of size $|G| = 50$ to create datasets with $50|G|$ samples. We randomly add $v \in \{20i\}_{i=0}^{5}$ noise variables, normally distributed with zero mean and unit variance. We repeat this process for ten different DAG instantiations. The results on the right-hand side of Figure 3 show that our method is not sensitive to the existence of disconnected noise variables, whereas SAE performance degrades as the number of uncorrelated input features increases. This highlights the benefit of target selection based on the DAG topology.
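The average-rank metric reported in Figure 3 can be computed in a few lines. In this sketch, `errors` is an assumed `(n_folds, n_methods)` array of test errors, and the convention that the lowest error receives the highest rank follows the description above.

```python
import numpy as np
from scipy.stats import rankdata

def average_rank(errors):
    """Rank methods within each fold (highest rank = lowest error),
    then average the ranks over folds."""
    n_methods = errors.shape[1]
    ranks = n_methods + 1 - rankdata(errors, axis=1)   # 1 = worst, r = best
    return ranks.mean(axis=0), ranks.std(axis=0)
```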
In Appendix C, we provide an analysis of the adjacency matrix weights learned under various random DAG configurations, e.g., a target with parents, an orphaned target, etc. There, we highlight CASTLE in comparison to SAE for target selection by showing that the adjacency matrix weights for noise variables are near zero. We also provide a sensitivity analysis on the parameter $\lambda$ from (5) and results from additional experiments demonstrating that CASTLE does not reconstruct noisy (neighborless) variables in the underlying causal DAG.

### 4.2 Regularization on Real Data

We perform regression and classification experiments on a spectrum of publicly available datasets from [50], including Boston Housing (BH), Wine Quality (WQ), Facebook Metrics (FB), Bioconcentration (BC), Student Performance (SP), Community (CM), Contraception Choice (CC), Pima Diabetes (PD), Las Vegas Ratings (LV), Statlog Heart (SH), and Retinopathy (RP). For each dataset, we randomly reserve 20% of the samples as a testing set and perform 10-fold cross-validation on the remaining 80%. As the results in Table 3 show, CASTLE provides improved regularization across all datasets for both regression and classification tasks. Additionally, CASTLE consistently ranks as the top regularizer (shown graphically in Appendix C.3), with no benchmark method emerging as a consistent runner-up. This emphasizes the stability of CASTLE as a reliable regularizer. In Appendix C, we provide additional experiments on several other datasets, an ablation study highlighting our sources of gain, and statistics of the real-world datasets.

## 5 Conclusion

We have introduced CASTLE regularization, a novel regularization method that jointly learns the causal graph to improve generalization performance in comparison to existing capacity-based and reconstruction-based regularization methods. We used existing PAC-Bayes theory to provide a theoretical generalization bound for CASTLE. We have shown experimentally that CASTLE is insensitive to increasing feature dimensionality, dataset size, and uncorrelated noise variables. Furthermore, we have shown that CASTLE regularization improves performance on a plethora of real datasets and, in the worst case, never degrades performance. We hope that CASTLE will serve as a general-purpose regularizer that can be leveraged by the entire machine learning community.

## Broader Impact

One of the big challenges of machine learning, and deep learning in particular, is generalization to out-of-sample data. Regularization is necessary to prevent overfitting, thereby promoting generalization. In this work, we have presented a novel regularization method inspired by causality. Since the applicability of our approach spans all problems where causal relationships exist between variables, there are countless beneficiaries of our research. Apart from the general machine learning community, the beneficiaries of our research include practitioners in the social sciences (sociology, psychology, etc.), the natural sciences (physics, biology, etc.), and healthcare, among countless others. These fields have already been exploiting causality for some time and serve as a natural launch-pad for deploying and leveraging CASTLE. With that said, our method does not immediately apply to certain architectures, such as CNNs, where causal relationships between variables are ambiguous or perhaps non-existent.
## Acknowledgments

This work was supported by GlaxoSmithKline (GSK), the US Office of Naval Research (ONR), and the National Science Foundation (NSF) grant 1722516. We thank all reviewers for their invaluable comments and suggestions.

## References

[1] Larry S. Yaeger, Richard F. Lyon, and Brandyn J. Webb. Effective training of a neural network character classifier for word recognition. In Advances in Neural Information Processing Systems 9, pages 807-816. MIT Press, 1997.

[2] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. ImageNet classification with deep convolutional neural networks. In Proceedings of the 25th International Conference on Neural Information Processing Systems, NIPS'12, pages 1097-1105. Curran Associates Inc., 2012.

[3] Geoffrey E. Hinton, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Improving neural networks by preventing co-adaptation of feature detectors. arXiv preprint arXiv:1207.0580, 2012.

[4] Stefan Wager, Sida Wang, and Percy Liang. Dropout training as adaptive regularization. In Advances in Neural Information Processing Systems, 2013.

[5] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(56):1929-1958, 2014.

[6] Sebastian Lunz, Ozan Öktem, and Carola-Bibiane Schönlieb. Adversarial regularizers in inverse problems. In Advances in Neural Information Processing Systems 31, pages 8507-8516. Curran Associates, Inc., 2018.

[7] Bin-Bin Gao, Chao Xing, Chen-Wei Xie, Jianxin Wu, and Xin Geng. Deep label distribution learning with label ambiguity. IEEE Transactions on Image Processing, 26:2825-2838, 2017.

[8] Yoshua Bengio, Pascal Lamblin, Dan Popovici, and Hugo Larochelle. Greedy layer-wise training of deep networks. In Proceedings of the 19th International Conference on Neural Information Processing Systems, NIPS'06, pages 153-160. MIT Press, 2006.

[9] Marc'Aurelio Ranzato and Martin Szummer. Semi-supervised learning of compact document representations with deep networks. In Proceedings of the 25th International Conference on Machine Learning, pages 792-799, 2008.

[10] Tianyu He, Xu Tan, Yingce Xia, Di He, Tao Qin, Zhibo Chen, and Tie-Yan Liu. Layer-wise coordination between encoder and decoder for neural machine translation. In Advances in Neural Information Processing Systems 31, pages 7944-7954. Curran Associates, Inc., 2018.

[11] Pascal Vincent, Hugo Larochelle, Yoshua Bengio, and Pierre-Antoine Manzagol. Extracting and composing robust features with denoising autoencoders. In Proceedings of the 25th International Conference on Machine Learning, pages 1096-1103. ACM, 2008.

[12] Lei Le, Andrew Patterson, and Martha White. Supervised autoencoders: Improving generalization performance with unsupervised regularizers. In Advances in Neural Information Processing Systems, pages 107-117, 2018.

[13] Shengyu Zhu and Zhitang Chen. Causal discovery with reinforcement learning. CoRR, abs/1906.04477, 2019.

[14] Ruichu Cai, Feng Xie, Clark Glymour, Zhifeng Hao, and Kun Zhang. Triad constraints for learning causal structure of latent variables. In Advances in Neural Information Processing Systems 32, pages 12883-12892. Curran Associates, Inc., 2019.
[15] Uri Shalit, Fredrik D. Johansson, and David Sontag. Estimating individual treatment effect: generalization bounds and algorithms. In Proceedings of the 34th International Conference on Machine Learning, pages 3076-3085. JMLR.org, 2017.

[16] Ahmed Alaa and Mihaela van der Schaar. Validating causal inference models via influence functions. In Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 191-201. PMLR, 2019.

[17] Patrick Schwab and Walter Karlen. CXPlain: Causal explanations for model interpretation under uncertainty. In Advances in Neural Information Processing Systems 32, pages 10220-10230. Curran Associates, Inc., 2019.

[18] Kun Zhang, Bernhard Schölkopf, Krikamol Muandet, and Zhikun Wang. Domain adaptation under target and conditional shift. In Proceedings of the 30th International Conference on Machine Learning (ICML), volume 28 of Proceedings of Machine Learning Research, pages 819-827, 2013.

[19] Jonas Peters, Peter Bühlmann, and Nicolai Meinshausen. Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):947-1012, 2016.

[20] Sara Magliacane et al. Domain adaptation by using causal inference to predict invariant conditional distributions. In Advances in Neural Information Processing Systems 31, pages 10846-10856. Curran Associates, Inc., 2018.

[21] Mateo Rojas-Carulla, Bernhard Schölkopf, Richard Turner, and Jonas Peters. Invariant models for causal transfer learning. Journal of Machine Learning Research, 19(36):1-34, 2018.

[22] Bernhard Schölkopf, Dominik Janzing, Jonas Peters, Eleni Sgouritsa, Kun Zhang, and Joris Mooij. On causal and anticausal learning. In Proceedings of the 29th International Conference on Machine Learning (ICML), 2012.

[23] Mohammad Taha Bahadori, Krzysztof Chalupka, Edward Choi, Robert Chen, Walter F. Stewart, and Jimeng Sun. Causal regularization. CoRR, abs/1702.02604, 2017.

[24] Dominik Rothenhäusler, Nicolai Meinshausen, Peter Bühlmann, and Jonas Peters. Anchor regression: heterogeneous data meet causality. CoRR, abs/1801.06229, 2018.

[25] Dominik Janzing. Causal regularization. In Advances in Neural Information Processing Systems 32, pages 12704-12714. Curran Associates, Inc., 2019.

[26] Bernhard Schölkopf, Dominik Janzing, Jonas Peters, Eleni Sgouritsa, Kun Zhang, and Joris Mooij. On causal and anticausal learning. In Proceedings of the 29th International Conference on Machine Learning, ICML'12, pages 459-466. Omnipress, 2012.

[27] David Lopez-Paz, Robert Nishihara, Soumith Chintala, Bernhard Schölkopf, and Léon Bottou. Discovering causal signals in images. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017.

[28] Dominik Janzing and Bernhard Schölkopf. Semi-supervised interpolation in an anticausal learning scenario. Journal of Machine Learning Research, 16(1):1923-1948, 2015.

[29] Robert Tibshirani. Regression shrinkage and selection via the lasso. Journal of the Royal Statistical Society: Series B (Methodological), 58(1):267-288, 1996.

[30] Arthur Hoerl and Robert Kennard. Ridge regression: Biased estimation for nonorthogonal problems. Technometrics, 12:55-67, 2012.
[31] Ian Goodfellow, Yoshua Bengio, and Aaron Courville. Deep Learning. MIT Press, 2016.

[32] Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In ICML, 2012.

[33] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Proceedings of the 32nd International Conference on Machine Learning, ICML'15, pages 448-456. JMLR.org, 2015.

[34] Hyeonwoo Noh, Tackgeun You, Jonghwan Mun, and Bohyung Han. Regularizing deep neural networks by noise: Its interpretation and optimization. In Proceedings of the 31st International Conference on Neural Information Processing Systems, NIPS'17, pages 5115-5124. Curran Associates Inc., 2017.

[35] Hongyi Zhang, Moustapha Cisse, Yann Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. In Proceedings of the 6th International Conference on Learning Representations (ICLR), 2018.

[36] Krzysztof Chalupka, Frederick Eberhardt, and Pietro Perona. Estimating causal direction and confounding of two discrete variables. arXiv preprint arXiv:1611.01504, 2016.

[37] Xun Zheng, Bryon Aragam, Pradeep Ravikumar, and Eric P. Xing. DAGs with NO TEARS: Continuous optimization for structure learning. In Advances in Neural Information Processing Systems, 2018.

[38] Xun Zheng, Chen Dan, Bryon Aragam, Pradeep Ravikumar, and Eric P. Xing. Learning sparse nonparametric DAGs. arXiv preprint arXiv:1909.13189, 2019.

[39] Sébastien Lachapelle, Philippe Brouillard, Tristan Deleu, and Simon Lacoste-Julien. Gradient-based neural DAG learning. In Proceedings of the 8th International Conference on Learning Representations (ICLR), 2020.

[40] Judea Pearl. Causality: Models, Reasoning, and Inference. Cambridge University Press, 2009.

[41] Andreas Maurer, Massimiliano Pontil, and Bernardino Romera-Paredes. The benefit of multitask representation learning. Journal of Machine Learning Research, 17(1):2853-2884, 2016.

[42] John Langford and John Shawe-Taylor. PAC-Bayes & margins. In Advances in Neural Information Processing Systems, pages 439-446, 2003.

[43] John Shawe-Taylor and Robert C. Williamson. A PAC analysis of a Bayesian estimator. In Proceedings of the Tenth Annual Conference on Computational Learning Theory, pages 2-9, 1997.

[44] David A. McAllester. Some PAC-Bayesian theorems. In Machine Learning, pages 230-234. ACM Press, 1998.

[45] Behnam Neyshabur, Srinadh Bhojanapalli, and Nathan Srebro. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. arXiv preprint arXiv:1707.09564, 2017.

[46] Vaishnavh Nagarajan and J. Zico Kolter. Deterministic PAC-Bayesian generalization bounds for deep networks via generalizing noise-resilience. arXiv preprint arXiv:1905.13344, 2019.

[47] Po-Ling Loh and Peter Bühlmann. High-dimensional learning of linear causal networks via inverse covariance estimation. Journal of Machine Learning Research, 15(1):3065-3105, 2014.

[48] Bryon Aragam, Arash A. Amini, and Qing Zhou. Learning directed acyclic graphs with penalized neighbourhood regression. arXiv preprint arXiv:1511.08963, 2015.

[49] Sara van de Geer and Peter Bühlmann. ℓ0-penalized maximum likelihood for sparse directed acyclic graphs. The Annals of Statistics, 41(2):536-567, 2013.

[50] Dheeru Dua and Casey Graff. UCI machine learning repository, 2020.

[51] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
[52] Pascal Germain, Francis Bach, Alexandre Lacoste, and Simon Lacoste-Julien. PAC-Bayesian theory meets Bayesian inference. In Advances in Neural Information Processing Systems, pages 1884-1892, 2016.

[53] Joel A. Tropp. User-friendly tail bounds for sums of random matrices. Foundations of Computational Mathematics, 12(4):389-434, 2012.