# Controllable Invariance through Adversarial Feature Learning

Qizhe Xie, Zihang Dai, Yulun Du, Eduard Hovy, Graham Neubig
Language Technologies Institute, Carnegie Mellon University
{qizhex, dzihang, yulund, hovy, gneubig}@cs.cmu.edu

**Abstract.** Learning meaningful representations that maintain the content necessary for a particular task while filtering away detrimental variations is a problem of great interest in machine learning. In this paper, we tackle the problem of learning representations invariant to a specific factor or trait of data. The representation learning process is formulated as an adversarial minimax game. We analyze the optimal equilibrium of such a game and find that it amounts to maximizing the uncertainty of inferring the detrimental factor given the representation while maximizing the certainty of making task-specific predictions. On three benchmark tasks, namely fair and bias-free classification, language-independent generation, and lighting-independent image classification, we show that the proposed framework induces an invariant representation and leads to better generalization, as evidenced by improved performance.

## 1 Introduction

How to produce a data representation that maintains meaningful variations of data while eliminating noisy signals is a consistent theme of machine learning research. In the last few years, the dominant paradigm for finding such a representation has shifted from manual feature engineering based on specific domain knowledge to representation learning that is fully data-driven, and often powered by deep neural networks [Bengio et al., 2013]. Being universal function approximators [Gybenko, 1989], deep neural networks can easily uncover the complicated variations in data [Zhang et al., 2017], leading to powerful representations. However, how to systematically incorporate a desired invariance into the learned representation in a controllable way remains an open problem.

A possible avenue towards a solution is to devise a dedicated neural architecture that by construction has the desired invariance property. As a typical example, the parameter-sharing scheme and pooling mechanism in modern deep convolutional neural networks (CNNs) [LeCun et al., 1998] take advantage of the spatial structure of image processing problems, allowing them to induce more generic feature representations than fully connected networks. Since the invariance we care about can vary greatly across tasks, this approach requires us to design a new architecture each time a new invariance desideratum shows up, which is time-consuming and inflexible.

When our belief about invariance is specific to some attribute of the input data, an alternative approach is to build a probabilistic model with a random variable corresponding to that attribute, and explicitly reason about the invariance. For instance, the variational fair auto-encoder (VFAE) [Louizos et al., 2016] employs the maximum mean discrepancy (MMD) to eliminate the negative influence of specific nuisance variables, such as removing the lighting conditions of images to predict a person's identity. Similarly, under the setting of domain adaptation, the standard binary adversarial cost [Ganin and Lempitsky, 2015, Ganin et al., 2016] and the central moment discrepancy (CMD) [Zellinger et al., 2017] have been utilized to learn features that are domain invariant.
However, all these invariance-inducing criteria suffer from a similar drawback: they are defined to measure the divergence between a pair of distributions. Consequently, they can only express an invariance belief w.r.t. a pair of values of the random variable at a time. When the attribute is a multinomial variable that takes more than two values, a combinatorial number of pairs (specifically, $O(n^2)$) has to be added to express the belief that the representation should be invariant to the attribute. The problem is even more dramatic when the attribute represents a structure with exponentially many possible values (e.g., the parse tree of a sentence), or when the attribute is simply a continuous variable.

Motivated by these drawbacks and difficulties, in this work we consider the problem of learning a feature representation with a desired invariance. We aim at creating a unified framework that is (1) generic enough that it can be easily plugged into different models, and (2) flexible enough to express an invariance belief in quantities beyond discrete variables with a limited set of values. Specifically, inspired by recent advances in adversarial learning [Goodfellow et al., 2014], we formulate representation learning as a minimax game among three players: an encoder, which maps the observed data deterministically into a feature space; a discriminator, which looks at the representation and tries to identify a specific type of variation we hope to eliminate from the feature; and a predictor, which makes use of the invariant representation to make predictions, as in typical discriminative models. We provide a theoretical analysis of the equilibrium condition of the minimax game and give an intuitive interpretation. On three benchmark tasks from different domains, we show that the proposed approach not only improves upon vanilla discriminative approaches that do not encourage invariance, but also outperforms existing approaches that enforce invariant features.

## 2 Adversarial Invariant Feature Learning

In this section, we formulate our problem and then present the proposed framework for learning invariant features.

*Figure 1: Dependencies between x, s, and y, where x is the observation and y is the target to be predicted; s is the attribute to which the prediction should be invariant. (a) y and s are marginally independent. (b) y and s are not marginally independent.*

Given an observation/input x, we are interested in the task of predicting the target y based on the value of x using a discriminative approach. In addition, we have access to some intrinsic attribute s of x, as well as a prior belief that the prediction result should be invariant to s. There are two possible dependency scenarios for x, s, and y:

(1) s and y can be marginally independent. For example, in image classification, the lighting condition s and the identity of a person y are independent. The data generation process is $s \sim p(s),\; y \sim p(y),\; x \sim p(x \mid s, y)$.

(2) In some cases, s and y are not marginally independent. For example, in fair classification, s is a sensitive factor such as age or gender, and y can be the savings, credit, or health condition of a person; s and y are related due to the inherent bias within the data. Using a latent variable z to model the dependency between s and y, the data generation process is $z \sim p(z),\; s \sim p(s \mid z),\; y \sim p(y \mid z),\; x \sim p(x \mid s, y)$.
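To make the two generative scenarios concrete, the following toy NumPy sketch samples (x, s, y) under each dependency graph. The specific distributions (Gaussians, coin flips) are invented purely for illustration and are not part of the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_independent():
    # Figure 1a: s and y are drawn independently, then x depends on both.
    s = rng.integers(0, 2)                               # e.g., lighting condition
    y = rng.integers(0, 10)                              # e.g., person identity
    x = rng.normal(loc=float(y), scale=1.0) + 2.0 * s    # x ~ p(x | s, y)
    return x, s, y

def sample_dependent():
    # Figure 1b: a latent z couples s and y (e.g., bias in fairness data),
    # so s carries information about y even before x is observed.
    z = rng.normal()
    s = int(z > 0.0)                                     # s ~ p(s | z)
    y = int(z + rng.normal() > 0.0)                      # y ~ p(y | z)
    x = rng.normal(loc=float(y), scale=1.0) + 2.0 * s    # x ~ p(x | s, y)
    return x, s, y
```

In the first scenario, removing s from a representation of x costs nothing for predicting y; in the second, s is genuinely predictive of y, which is exactly the tension the equilibrium analysis in Section 3 formalizes.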
We show the corresponding dependency graphs in Figure 1. Unlike vanilla discriminative models that output the conditional distribution p(y | x), we model p(y | x, s) to make predictions invariant to s. Our intuition is that, due to the explaining-away effect, y and s are not independent when conditioned on x, even though they may be marginally independent. Consequently, p(y | x, s) is a more accurate estimate of y than p(y | x). Intuitively, this can inform and guide the model to remove information about undesired variations. For example, if we want to learn a representation of an image x that is invariant to the lighting condition s, the model can learn to brighten the input if it knows the original picture is dark, and vice versa. Also, in multi-lingual machine translation, a word with the same surface form may have different meanings in different languages. For instance, *Gift* means a present in English but poison in German. Hence, knowing the language of a source sentence helps in inferring the meaning of the sentence and conducting the translation.

As the input x can have a highly complicated structure, we employ a dedicated model or algorithm to extract an expressive representation h from x. When we extract h, we want it to preserve the variations that are necessary to predict y while eliminating the information of s. To achieve this goal, we employ a deterministic encoder E to obtain the representation by encoding x and s into h, namely h = E(x, s); note that s serves as an additional input. Given the obtained representation h, the target y is predicted by a predictor M, which effectively models the distribution $q_M(y \mid h)$. By construction, instead of modeling p(y | x) directly, the discriminative model we formulate captures the conditional distribution p(y | x, s), with the additional information coming from s.

Of course, feeding s into the encoder by no means guarantees that the induced feature h will be invariant to s. Thus, in order to enforce the desired invariance and eliminate variations of the factor s from h, we set up an adversarial game by introducing a discriminator D that inspects the representation h and ensures that it is invariant to s. Concretely, the discriminator D is trained to predict s based on the encoded representation h, which effectively maximizes the likelihood $q_D(s \mid h)$. Simultaneously, the encoder fights to minimize the likelihood of the discriminator inferring the correct s. Intuitively, the discriminator and the encoder form an adversarial game in which the discriminator tries to detect an attribute of the data while the encoder learns to conceal it.

Note that under our framework, in theory, s can be any type of data as long as it represents an attribute of x. For example, s can be a real-valued scalar or vector, which may take many possible values, or a complex sub-structure such as the parse tree of a natural language sentence. In this paper, however, we focus mainly on instances where s is a discrete label with multiple choices; we plan to extend our framework to deal with continuous and structured s in the future.

Formally, E, M, and D jointly play the following minimax game:

$$\min_{E, M} \max_{D} \; J(E, M, D)$$

$$J(E, M, D) = \mathbb{E}_{x, s, y \sim p(x, s, y)} \big[ \gamma \log q_D(s \mid h = E(x, s)) - \log q_M(y \mid h = E(x, s)) \big] \tag{1}$$

where γ is a hyper-parameter that adjusts the strength of the invariance constraint, and p(x, s, y) is the true underlying distribution from which the empirical observations are drawn.
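To make the training dynamics of objective (1) concrete, here is a minimal PyTorch sketch of the three-player game for a categorical attribute s and label y. The architectures, layer sizes, and optimizer settings below are our own illustrative assumptions, not the configuration used in the experiments.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Deterministic encoder h = E(x, s); s enters as an embedded extra input."""
    def __init__(self, x_dim, n_attrs, h_dim):
        super().__init__()
        self.s_emb = nn.Embedding(n_attrs, h_dim)
        self.net = nn.Sequential(nn.Linear(x_dim + h_dim, h_dim), nn.ReLU())

    def forward(self, x, s):
        return self.net(torch.cat([x, self.s_emb(s)], dim=-1))

x_dim, n_attrs, n_classes, h_dim, gamma = 32, 3, 10, 64, 1.0
E = Encoder(x_dim, n_attrs, h_dim)
M = nn.Linear(h_dim, n_classes)  # predictor, models q_M(y | h)
D = nn.Linear(h_dim, n_attrs)    # discriminator, models q_D(s | h)

opt_EM = torch.optim.Adam(list(E.parameters()) + list(M.parameters()), lr=1e-3)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-3)

def train_step(x, s, y):
    # (i) Discriminator ascends J: maximize log q_D(s | h), i.e. minimize its
    # cross-entropy on s; h is detached so E receives no gradient here.
    d_loss = F.cross_entropy(D(E(x, s).detach()), s)
    opt_D.zero_grad(); d_loss.backward(); opt_D.step()

    # (ii) Encoder and predictor descend J: minimize
    # gamma * log q_D(s | h) - log q_M(y | h). Only E and M are updated; any
    # gradients that reach D's parameters are cleared at the next call to (i).
    h = E(x, s)
    em_loss = F.cross_entropy(M(h), y) - gamma * F.cross_entropy(D(h), s)
    opt_EM.zero_grad(); em_loss.backward(); opt_EM.step()
    return d_loss.item(), em_loss.item()
```

The two updates are alternated every mini-batch; as γ grows, the encoder sacrifices more predictive certainty about y to raise the discriminator's uncertainty about s, mirroring the equilibrium analysis in Section 3.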
Note that the problem of domain adaptation can be seen as a special case of our problem, where s is a Bernoulli variable representing the domain and the model only has access to the target y when s = source domain during training.

## 3 Theoretical Analysis

In this section, we theoretically analyze whether, given enough capacity and training time, such a minimax game converges to an equilibrium where variations of y are preserved and variations of s are removed. The analysis is carried out in a non-parametric limit, i.e., we assume a model with infinite capacity. In addition, we discuss the equilibria of the minimax game when s is independent of, or dependent on, y.

Since both the discriminator and the predictor only use h, which is transformed deterministically from x and s, we can substitute x with h and define a joint distribution p(h, s, y) of h, s, and y as follows:

$$p(h, s, y) = \int_x p(x, s, h, y)\, dx = \int_x p(x, s, y)\, p_E(h \mid x, s)\, dx = \int_x p(x, s, y)\, \delta(E(x, s) = h)\, dx$$

Here, we have used the fact that the encoder is a deterministic transformation, so the distribution $p_E(h \mid x, s)$ is merely a delta function, denoted by δ(·). Intuitively, h absorbs the randomness in x and has an implicit distribution of its own. Also, note that the joint distribution p(h, s, y) depends on the transformation defined by the encoder. Thus, we can equivalently rewrite objective (1) as

$$J(E, M, D) = \mathbb{E}_{h, s, y \sim p(h, s, y)} \big[ \gamma \log q_D(s \mid h) - \log q_M(y \mid h) \big] \tag{2}$$

To analyze the equilibrium condition of the new objective (2), we first deduce the optimal discriminator D and the optimal predictor M for a given encoder E, and then prove the global optimality of the minimax game.

**Claim 1.** Given a fixed encoder E, the optimal discriminator outputs $q_D(s \mid h) = p(s \mid h)$ and the optimal predictor corresponds to $q_M(y \mid h) = p(y \mid h)$.

*Proof.* The proof uses the fact that the objective is functionally convex w.r.t. each distribution, and by taking the variations we can obtain the stationary points for $q_D$ and $q_M$ as functions of the joint distribution p(h, s, y). The detailed proof is included in supplementary material A; a one-step sketch of the discriminator case is given at the end of this section.

Note that the optimal $q_D(s \mid h)$ and $q_M(y \mid h)$ given in Claim 1 are both functions of the encoder E. Thus, by plugging $q_D$ and $q_M$ into the original minimax objective (2), it can be simplified into a minimization problem w.r.t. the encoder E alone, with the following form:

$$\min_E J(E) = \min_E \mathbb{E}_{h, s, y \sim p(h, s, y)} \big[ \gamma \log p(s \mid h) - \log p(y \mid h) \big] = \min_E \; -\gamma H(p(s \mid h)) + H(p(y \mid h)) \tag{3}$$

where $H(p(s \mid h))$ denotes the conditional entropy of s given h under the distribution p.

**Equilibrium analysis.** As we can see, objective (3) consists of two conditional entropies with opposite signs. Optimizing the first term amounts to maximizing the uncertainty of inferring s from h, which essentially filters out any information about s from the representation. On the contrary, optimizing the second term leads to increasing the certainty of predicting y from h. Implicitly, the objective defines the equilibrium of the minimax game.

**Win-win equilibrium.** Firstly, for cases where the attribute s is entirely irrelevant to the prediction task (corresponding to the dependency graph shown in Figure 1a), the two terms can reach their optima at the same time, leading to a win-win equilibrium. For example, with the lighting condition of an image removed, we can still classify, indeed better classify, the identity of the person in that image. With enough model capacity, the optimal equilibrium solution is the same regardless of the value of γ.
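As a sanity check on Claim 1 (whose full proof is in supplementary material A), the discriminator case admits the following one-step sketch: for each fixed h, the discriminator's expected payoff decomposes as

$$\mathbb{E}_{s \sim p(s \mid h)}\big[\log q_D(s \mid h)\big] = -H\big(p(s \mid h)\big) - \mathrm{KL}\big(p(s \mid h) \,\big\|\, q_D(s \mid h)\big),$$

which is maximized exactly when the KL term vanishes, i.e., when $q_D(s \mid h) = p(s \mid h)$; the predictor case is symmetric. Substituting both optima into (2) then yields the entropy form (3).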
**Competing equilibrium.** However, there are cases where the two optimization objectives compete. For example, in fair classification, sensitive factors such as gender and age may improve overall prediction accuracy due to inherent biases within the data. In other words, knowing s may help in predicting y, since s and y are not marginally independent (corresponding to the dependency graph shown in Figure 1b), and learning a fair/invariant representation comes at a cost in prediction accuracy. In this case, the optima of the two entropies cannot be achieved simultaneously, and γ defines the relative strength of the two objectives at the final equilibrium.

## 4 Parametric Instantiation of the Proposed Framework

To show the general applicability of our framework, we experiment on three different tasks: sentence generation, image classification, and fair classification. Because the nature of the data x and y differs across tasks, we present here the specific model instantiations we use.

**Sentence generation.** We use multi-lingual machine translation as the testbed for sentence generation. Concretely, we have translation pairs between several source languages and a target language: x is the source sentence to be translated, s is a scalar denoting which source language x belongs to, and y is the translated sentence in the target language. Recall that s is used as an input to E to obtain a language-invariant representation. To make full use of s, we employ a separate encoder Enc_s for the sentences in each language s. In other words, h = E(s, x) = Enc_s(x), where each Enc_s is a different encoder (see the sketch at the end of this section). The representation of a sentence is captured by the hidden states of an LSTM encoder [Hochreiter and Schmidhuber, 1997] at each time step. We employ a single LSTM predictor shared across the different encoders. As is common in language generation, the probability $q_M$ output by the predictor is parametrized by an autoregressive process, i.e.,

$$q_M(y_{1:T} \mid h) = \prod_{t=1}^{T} q_M(y_t \mid y_{1:t-1}, h)$$
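Below is a minimal PyTorch sketch of this per-language encoder scheme. Vocabulary handling, attention, and the shared LSTM predictor are omitted, and all names and sizes are illustrative assumptions rather than the paper's exact setup.

```python
import torch
import torch.nn as nn

class MultiSourceEncoder(nn.Module):
    """h = E(s, x) = Enc_s(x): one LSTM encoder per source language s."""
    def __init__(self, n_langs, vocab_size, emb_dim, h_dim):
        super().__init__()
        self.embs = nn.ModuleList(
            nn.Embedding(vocab_size, emb_dim) for _ in range(n_langs))
        self.lstms = nn.ModuleList(
            nn.LSTM(emb_dim, h_dim, batch_first=True) for _ in range(n_langs))

    def forward(self, x, s):
        # x: (batch, time) token ids from language s; the representation h is
        # the sequence of hidden states of the language-specific LSTM.
        h, _ = self.lstms[s](self.embs[s](x))
        return h  # (batch, time, h_dim)

enc = MultiSourceEncoder(n_langs=3, vocab_size=10000, emb_dim=128, h_dim=256)
x = torch.randint(0, 10000, (4, 12))   # a toy batch of 4 length-12 sentences
h = enc(x, s=1)                        # encode with the language-1 encoder
```

The discriminator would then inspect h (e.g., after mean-pooling over time) to predict the source language, while the single shared predictor decodes the target sentence autoregressively from h.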