Published as a conference paper at ICLR 2021

FAIRBATCH: BATCH SELECTION FOR MODEL FAIRNESS

Yuji Roh (1), Kangwook Lee (2), Steven Euijong Whang (1), Changho Suh (1)
(1) KAIST, {yuji.roh, swhang, chsuh}@kaist.ac.kr
(2) University of Wisconsin-Madison, kangwook.lee@wisc.edu

ABSTRACT

Training a fair machine learning model is essential to prevent demographic disparity. Existing techniques for improving model fairness require broad changes in either data preprocessing or model training, rendering them difficult to adopt for potentially already complex machine learning systems. We address this problem via the lens of bilevel optimization. While keeping the standard training algorithm as an inner optimizer, we incorporate an outer optimizer that equips the inner problem with an additional functionality: adaptively selecting minibatch sizes for the purpose of improving model fairness. Our batch selection algorithm, which we call FairBatch, implements this optimization and supports prominent fairness measures: equal opportunity, equalized odds, and demographic parity. FairBatch comes with a significant implementation benefit: it does not require any modification to data preprocessing or model training. For instance, a single-line change of PyTorch code that replaces the batch selection part of model training suffices to employ FairBatch. Our experiments, conducted on both synthetic and benchmark real data, demonstrate that FairBatch provides these functionalities while achieving comparable (or even better) performance than the state of the art. Furthermore, FairBatch can readily improve the fairness of any pre-trained model simply via fine-tuning. It is also compatible with existing batch selection techniques intended for different purposes, such as faster convergence, thus gracefully achieving multiple purposes.

1 INTRODUCTION

Model fairness is becoming essential in a wide variety of machine learning applications. Fairness issues often arise in sensitive applications like healthcare and finance, where a trained model must not discriminate among individuals based on age, gender, or race. While many fairness techniques have recently been proposed, they require a range of changes in either data generation or algorithmic design. There are two popular fairness approaches: (i) pre-processing, where training data is debiased (Choi et al., 2020) or re-weighted (Jiang and Nachum, 2020), and (ii) in-processing, where the model of interest is retrained via fairness objectives (Zafar et al., 2017a;b), adversarial training (Zhang et al., 2018), or boosting (Iosifidis and Ntoutsi, 2019); see the related work discussed in depth in Sec. 5. However, these approaches may require nontrivial re-configurations of modern machine learning systems, which often consist of many complex components.

In an effort to enable easier-to-reconfigure implementations for fair machine learning, we address the problem via the lens of bilevel optimization, where one problem is embedded within another. While keeping the standard training algorithm as the inner optimizer, we design an outer optimizer that equips the inner problem with an added functionality of improving fairness through batch selection. Our main contribution is to develop a batch selection algorithm (called FairBatch) that implements this optimization via adjusting the batch sizes w.r.t.
sensitive groups, based on the fairness measure of an intermediate model evaluated in the current epoch. For example, consider the task of predicting whether individual criminals will re-offend in the future, subject to satisfying equalized odds (Hardt et al., 2016), where the model accuracies must be the same across sensitive groups. If the model is less accurate for a certain group, FairBatch increases the batch-size ratio of that group in the next batch; see Sec. 3 for our adjusting mechanism described in detail.

Figure 1: (a) Accuracy difference across sensitive groups in the sense of equalized odds (which we denote as "ED disparity") when running FairBatch on the ProPublica COMPAS dataset; the black path shows how FairBatch adjusts the batch-size ratios of sensitive groups using two reweighting parameters λ1 and λ2 (hyperparameters employed in our framework, described in Sec. 2), thus minimizing their ED disparity, i.e., achieving equalized odds. (b) PyTorch code for model training where the batch selection is replaced with FairBatch; it requires a single-line change to replace the existing sampler with FairBatch:

```python
fairsampler = FairBatch(model, criterion, train_data, batch_size, alpha, target_fairness)
loader = DataLoader(train_data, sampler=fairsampler)

for epoch in range(epochs):
    for i, data in enumerate(loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        ... model forward, backward, and optimization ...
```

Fig. 1a shows FairBatch's behavior when running on the ProPublica COMPAS dataset (Angwin et al., 2016). For equalized odds, our framework (described in Sec. 2) introduces two reweighting parameters (λ1, λ2) for adjusting the batch-size ratios of two sensitive groups (in this experiment, men and women). After a few epochs, FairBatch indeed achieves equalized odds, i.e., the accuracy disparity between sensitive groups conditioned on the true label (denoted as "ED disparity") is minimized. FairBatch also supports other prominent group fairness measures: equal opportunity (Hardt et al., 2016) and demographic parity (Feldman et al., 2015).

A key feature of FairBatch is its great usability and simplicity. It only requires a slight modification in the batch selection part of model training, as demonstrated in Fig. 1b (see also the sampler sketch below), and does not require any other changes in data preprocessing or model training. Experiments conducted on both synthetic and benchmark real datasets (ProPublica COMPAS (Angwin et al., 2016), AdultCensus (Kohavi, 1996), and UTKFace (Zhang et al., 2017)) show that FairBatch exhibits better (or at least comparable) performance relative to the state of the art (spanning both pre-processing (Kamiran and Calders, 2011; Jiang and Nachum, 2020) and in-processing (Zafar et al., 2017a;b; Zhang et al., 2018; Iosifidis and Ntoutsi, 2019) techniques) w.r.t. all aspects in consideration: accuracy, fairness, and runtime. In addition, FairBatch can improve the fairness of any pre-trained model via fine-tuning. For example, Sec. 4.2 shows how FairBatch reduces the ED disparities of pre-trained ResNet18 (He et al., 2016) and GoogLeNet (Szegedy et al., 2015) models.
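To give a concrete picture of what such a sampler can look like, the following is a minimal, hypothetical sketch (ours, not the authors' released implementation): a batch sampler that keeps an index pool per sensitive group and composes each batch according to prescribed group ratios, which an outer loop in the spirit of FairBatch would adapt between epochs. The class name GroupRatioSampler and its arguments are illustrative assumptions.

```python
import numpy as np
from torch.utils.data import Sampler

class GroupRatioSampler(Sampler):
    """Yield index batches whose group composition follows given ratios.

    `groups` maps a group key, e.g. (y, z), to the example indices of that
    group; `ratios` maps the same keys to the fraction of each batch drawn
    from that group: the quantity a FairBatch-style outer loop adapts.
    """
    def __init__(self, groups, ratios, batch_size, num_batches):
        self.groups, self.ratios = groups, ratios
        self.batch_size, self.num_batches = batch_size, num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for key, idx in self.groups.items():
                # number of examples drawn from this group in the current batch
                k = int(round(self.ratios[key] * self.batch_size))
                batch.extend(np.random.choice(idx, size=k, replace=True).tolist())
            yield batch

    def __len__(self):
        return self.num_batches
```

Because this sampler yields whole batches of indices, it would be passed to the loader as `DataLoader(train_data, batch_sampler=...)`; adapting `ratios` after each epoch is what turns this skeleton into an adaptive batch selector.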
Finally, FairBatch can be gracefully merged with other batch selection techniques typically used for faster convergence, thereby improving fairness faster as well.

Notation. Let $w$ be the parameter of the classifier of interest. Let $x \in \mathcal{X}$ be an input feature to the classifier, and let $\hat{y} \in \mathcal{Y}$ be the predicted class; note that $\hat{y}$ is a function of $(x, w)$. We consider group fairness, which intends to ensure fairness across distinct sensitive groups (e.g., men versus women). Let $z \in \mathcal{Z}$ be a sensitive attribute (e.g., gender). Consider the 0/1 loss $\ell(y, \hat{y}) = \mathbb{1}(y \neq \hat{y})$, and let $m$ be the total number of training samples. Let $L_{y,z}(w)$ be the empirical risk aggregated over samples with $y_i = y$ and $z_i = z$:
$$L_{y,z}(w) := \frac{1}{m_{y,z}} \sum_{i:\, y_i = y,\, z_i = z} \ell(y_i, \hat{y}_i), \quad \text{where } m_{y,z} := |\{i : y_i = y, z_i = z\}|.$$
Similarly, we define
$$L_{y,\cdot}(w) := \frac{1}{m_{y,\cdot}} \sum_{i:\, y_i = y} \ell(y_i, \hat{y}_i) \quad \text{and} \quad L_{\cdot,z}(w) := \frac{1}{m_{\cdot,z}} \sum_{i:\, z_i = z} \ell(y_i, \hat{y}_i),$$
where $m_{y,\cdot} := |\{i : y_i = y\}|$ and $m_{\cdot,z} := |\{i : z_i = z\}|$. The overall empirical risk is written as $L(w) = \frac{1}{m} \sum_i \ell(y_i, \hat{y}_i)$. We use $\nabla$ for gradients and $\partial$ for subdifferentials.

2 BILEVEL OPTIMIZATION FOR FAIRNESS

In order to systematically design an adaptive batch selection algorithm, we formalize an implicit connection between adaptive batch selection and bilevel optimization. Bilevel optimization consists of an outer optimization problem and an inner optimization problem. The inner optimizer solves the inner optimization problem, and the outer optimizer solves the outer optimization problem based on the outcomes of the inner optimization.

Algorithm 1: Bilevel optimization with minibatch SGD
  Minibatch sampling distribution <- uniform sampling
  for each epoch do
      Draw minibatches according to the minibatch sampling distribution
      for each minibatch do
          w <- MinibatchSGD(w, minibatch)
      Update the minibatch sampling distribution

By viewing the standard training algorithm, such as stochastic gradient descent (SGD) (Bottou, 2010), as an inner optimizer and viewing the batch selection algorithm as an outer optimizer, the process of training a fair classifier can be seen as a process of solving a bilevel optimization problem.

Batch selection + minibatch SGD = bilevel optimization solver. Consider a scenario where one minimizes the overall empirical risk $L(w)$ via minibatch SGD. The minibatch SGD algorithm picks $b$ of the $m$ indices uniformly at random, say $j_1, j_2, \ldots, j_b$, and updates its iterate with $\frac{1}{b} \sum_{i=1}^{b} \nabla \ell(y_{j_i}, \hat{y}_{j_i})$, called a batch gradient. Note that a batch gradient is an unbiased estimate of the true gradient $\nabla L(w)$. Since the empirical risk minimization (ERM) formulation does not take a fairness criterion into account, its minimizer usually does not satisfy the desired fairness criterion. To address this limitation of ERM, we adjust the way minibatches are drawn so that the desired fairness guarantee is satisfied. For instance, as described in the introduction, we can draw minibatches with a larger number of training samples from a certain sensitive group so as to achieve a higher accuracy w.r.t. that group. Once the minibatch distribution deviates from the uniform distribution, the batch gradient is no longer an unbiased estimate of the gradient of the overall empirical risk. Instead, it is an unbiased estimate of the gradient of a reweighted empirical risk. In other words, if we draw training example $i$ with probability $p_i$ (with $\sum_i p_i = 1$), the batch gradient is an unbiased estimate of the gradient of $L'(w) = \sum_i p_i \ell(y_i, \hat{y}_i)$.
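As a quick sanity check of this claim (our own toy example, not from the paper), one can verify numerically that the average minibatch loss under a non-uniform sampling distribution $p$ matches the reweighted risk $L'(w) = \sum_i p_i \ell(y_i, \hat{y}_i)$; the same argument carries over to batch gradients term by term.

```python
import numpy as np

rng = np.random.default_rng(0)
m, b = 1000, 32
losses = rng.random(m)            # stand-ins for per-example losses l(y_i, yhat_i) at a fixed w
p = rng.random(m)
p /= p.sum()                      # a non-uniform sampling distribution with sum_i p_i = 1

reweighted_risk = float(np.sum(p * losses))   # L'(w) = sum_i p_i * l(y_i, yhat_i)

# Monte Carlo average of the minibatch estimate (1/b) * sum over a batch drawn from p
estimates = [losses[rng.choice(m, size=b, p=p)].mean() for _ in range(20000)]
print(reweighted_risk, float(np.mean(estimates)))   # the two agree up to Monte Carlo noise
```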
This observation enables the following bilevel-optimization-based interpretation of how batch selection interacts with the inner optimization algorithm. At initialization, minibatch SGD optimizes the (unweighted) empirical risk. Based on the outcome of the inner optimization, the outer optimizer refines $p := (p_1, p_2, \ldots, p_m)$, the sampling probability of each training example. The inner optimizer now takes minibatches drawn from the new distribution and reoptimizes the inner objective function. Due to the new minibatch distribution, the inner objective becomes a reweighted empirical risk w.r.t. $p$. This procedure is repeated until convergence; see Algorithm 1 for pseudocode. Therefore, a batch selection algorithm together with an inner optimization algorithm can be viewed as a pair of outer and inner optimizers for the following bilevel optimization problem:
$$\min_{p} \ \mathrm{Cost}(w_p), \qquad w_p = \arg\min_{w} L'(w),$$
where $\mathrm{Cost}(\cdot)$ captures the goal of the optimization.

Two questions arise. First, how can we design the cost function to capture a desired fairness criterion? Second, how can we design an update rule for the outer optimizer, and can we develop an algorithm with a provable guarantee? In the rest of this section, we show how one can design proper cost functions to capture various fairness criteria. In Sec. 3, we develop an efficient update rule for FairBatch.

Equal opportunity. For illustrative purposes, assume for now the binary setting ($\mathcal{Y} = \mathcal{Z} = \{0, 1\}$). A model satisfies equal opportunity (Hardt et al., 2016) if the positive prediction rates conditioned on $y = 1$ are equal, i.e., $L_{1,0}(w) = L_{1,1}(w)$. Since the ERM formulation does not take the fairness criterion into account, these two quantities differ in general. To mitigate this, we adjust the sampling probability between $L_{1,0}(w)$ and $L_{1,1}(w)$. More specifically, we propose the following procedure to draw a sample. First, we randomly pick which subset of the data to sample from: we pick the set $\{y = 1, z = 0\}$ with probability $\lambda$, the set $\{y = 1, z = 1\}$ with probability $\frac{m_{1,\cdot}}{m} - \lambda$, and the set $\{y = 0\}$ with probability $\frac{m_{0,\cdot}}{m}$. We then pick a sample from the chosen set uniformly at random. This leaves us with a single-dimensional outer optimization variable $\lambda$, which controls the sampling bias between data with $y = 1, z = 0$ and data with $y = 1, z = 1$. Also, we design the cost function as $|L_{1,0}(w_\lambda) - L_{1,1}(w_\lambda)|$ to capture the equal opportunity criterion. Thus, we have the following bilevel optimization problem:
$$\min_{\lambda \in [0, \frac{m_{1,\cdot}}{m}]} \ |L_{1,0}(w_\lambda) - L_{1,1}(w_\lambda)|, \qquad w_\lambda = \arg\min_{w} \ \lambda L_{1,0}(w) + \Big(\tfrac{m_{1,\cdot}}{m} - \lambda\Big) L_{1,1}(w) + \tfrac{m_{0,\cdot}}{m} L_{0,\cdot}(w).$$

Equalized odds. Similarly, we can design a bilevel optimization problem to capture equalized odds (Hardt et al., 2016), which requires the prediction to be independent of the sensitive attribute conditional on the true label, i.e., $L_{0,0}(w) = L_{0,1}(w)$ and $L_{1,0}(w) = L_{1,1}(w)$. Again, the empirical risk minimizer does not satisfy these two conditions in general. To mitigate these disparities, we adjust (i) the sampling probability between $L_{0,0}(w)$ and $L_{0,1}(w)$ and (ii) the sampling probability between $L_{1,0}(w)$ and $L_{1,1}(w)$. To achieve this, we use the following procedure to draw a sample. First, we pick the set $\{y = 0, z = 0\}$ with probability $\lambda_1$, the set $\{y = 0, z = 1\}$ with probability $\frac{m_{0,\cdot}}{m} - \lambda_1$, the set $\{y = 1, z = 0\}$ with probability $\lambda_2$, and the set $\{y = 1, z = 1\}$ with probability $\frac{m_{1,\cdot}}{m} - \lambda_2$. We then pick one data point from the chosen set uniformly at random. This leaves us with a two-dimensional outer optimization variable $\lambda = (\lambda_1, \lambda_2)$; a minimal sampling sketch is given below.
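To make the two-stage sampling concrete, here is a minimal sketch of the equalized-odds case (our illustration; the containers `idx` and `m` and the function name are assumptions, not the paper's code): first pick a group $(y, z)$ with the probabilities above, then pick an example uniformly within it.

```python
import numpy as np

def draw_index_eq_odds(idx, m, lam1, lam2, rng):
    """Two-stage draw for the equalized-odds formulation.

    idx[(y, z)] holds the example indices of group (y, z); m[y] is the number
    of examples with label y (m_{0,.} and m_{1,.}); lam1 and lam2 are the outer
    variables lambda_1 in [0, m_{0,.}/m] and lambda_2 in [0, m_{1,.}/m].
    """
    total = m[0] + m[1]
    groups = [(0, 0), (0, 1), (1, 0), (1, 1)]
    probs = [lam1, m[0] / total - lam1, lam2, m[1] / total - lam2]
    y, z = groups[rng.choice(4, p=probs)]          # stage 1: pick a group
    return int(rng.choice(idx[(y, z)]))            # stage 2: pick uniformly within it
```

Drawing a batch of indices this way and averaging their loss gradients yields an unbiased estimate of the gradient of the weighted inner objective below, which is exactly why the batch-size ratios act as the reweighting knobs λ1 and λ2.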
To capture the equalized odds criterion, we design the outer objective function as $\max\{|L_{0,0}(w) - L_{0,1}(w)|, |L_{1,0}(w) - L_{1,1}(w)|\}$. This gives us the following bilevel optimization problem:
$$\min_{\lambda \in [0, \frac{m_{0,\cdot}}{m}] \times [0, \frac{m_{1,\cdot}}{m}]} \ \max\{|L_{0,0}(w_\lambda) - L_{0,1}(w_\lambda)|, \ |L_{1,0}(w_\lambda) - L_{1,1}(w_\lambda)|\},$$
$$w_\lambda = \arg\min_{w} \ \lambda_1 L_{0,0}(w) + \Big(\tfrac{m_{0,\cdot}}{m} - \lambda_1\Big) L_{0,1}(w) + \lambda_2 L_{1,0}(w) + \Big(\tfrac{m_{1,\cdot}}{m} - \lambda_2\Big) L_{1,1}(w).$$

Demographic parity. Demographic parity (Feldman et al., 2015) is satisfied if the two sensitive groups have equal positive prediction rates. If the $m_{y,z}$'s are all equal, then $L_{0,0}(w) = L_{1,0}(w)$ and $L_{0,1}(w) = L_{1,1}(w)$ serve as a sufficient condition for demographic parity; see Sec. A.1 for why, and for how to handle demographic parity when this condition does not hold. To satisfy this sufficient condition, we now adjust (i) the sampling probability between $L_{0,0}(w)$ and $L_{1,0}(w)$ and (ii) the sampling probability between $L_{0,1}(w)$ and $L_{1,1}(w)$. This gives us the following bilevel optimization problem:
$$\min_{\lambda \in [0, \frac{m_{\cdot,0}}{m}] \times [0, \frac{m_{\cdot,1}}{m}]} \ \max\{|L_{0,0}(w_\lambda) - L_{1,0}(w_\lambda)|, \ |L_{0,1}(w_\lambda) - L_{1,1}(w_\lambda)|\},$$
$$w_\lambda = \arg\min_{w} \ \lambda_1 L_{0,0}(w) + \Big(\tfrac{m_{\cdot,0}}{m} - \lambda_1\Big) L_{1,0}(w) + \lambda_2 L_{0,1}(w) + \Big(\tfrac{m_{\cdot,1}}{m} - \lambda_2\Big) L_{1,1}(w).$$

Beyond binary labels/sensitive attributes. While the previous examples assumed binary-valued labels and sensitive attributes, our framework is applicable to cases where the alphabet sizes are beyond binary. As an example, consider the equal opportunity criterion when $\mathcal{Z} = \{0, 1, \ldots, n_z - 1\}$. The condition reads $L_{1,0}(w) = L_{1,1}(w) = \cdots = L_{1,n_z-1}(w)$. To satisfy this condition, we adjust the sampling probabilities between the $L_{1,j}(w)$'s by introducing a $\binom{n_z}{2}$-dimensional outer optimization variable $\lambda$ (one coordinate per pairwise disparity), and design the outer objective function as $\max_{j_1, j_2 \in \mathcal{Z}} |L_{1,j_1}(w) - L_{1,j_2}(w)|$. In our implementation, however, we only use $(n_z - 1)$ sequential disparity objectives as an approximation (i.e., $\max_{j \in \{0, 1, \ldots, n_z - 2\}} |L_{1,j}(w) - L_{1,j+1}(w)|$) for better efficiency. Suppose the level of disparity is $\epsilon$ when FairBatch compares all possible pairs of sensitive groups. If we only optimize the sequential $(n_z - 1)$ disparity objectives, we no longer ensure that other pairwise objectives such as $|L_{1,3}(w) - L_{1,1}(w)|$ are within $\epsilon$. In the worst case, the objective $|L_{1,n_z-1}(w) - L_{1,0}(w)|$ may be as large as $(n_z - 1)\epsilon$, since we only guarantee that each $|L_{1,j}(w) - L_{1,j+1}(w)| \le \epsilon$. If $\epsilon$ is small enough, the disparity of our approximation remains reasonable as well. One can handle other fairness criteria in a similar way.

3 UPDATE RULE OF FAIRBATCH

We design efficient update rules of FairBatch for different numbers of disparities. Let us define $d$ as the dimension of the outer optimization variable $\lambda$, which is the same as the total number of disparities. We first analyze the simplest case, $d = 1$, and show that a simple gradient descent algorithm can provably solve the outer optimization problem; the equal opportunity example in the previous section falls into this category. We then extend the algorithm developed for the one-dimensional case to the multi-dimensional ($d > 1$) case; equalized odds and demographic parity fall into this category.

3.1 UPDATE RULE FOR d = 1

When $d = 1$, the general form of our bilevel optimization problem can be written as follows:
$$\min_{\lambda \in [0, c_1]} \ |f_1(w_\lambda) - g_1(w_\lambda)|, \qquad w_\lambda = \arg\min_{w} \ \lambda f_1(w) + (c_1 - \lambda) g_1(w) + h(w),$$
where $c_1 > 0$ is a constant. Let $F(\lambda) = |f_1(w_\lambda) - g_1(w_\lambda)|$. The following lemma shows that $F(\lambda)$ is quasi-convex in $\lambda$ under some mild conditions, and that its signed gradient can be efficiently computed.

Lemma 1 (Quasi-convexity of $F(\lambda)$). For $d = 1$, suppose $f_1(\cdot)$, $g_1(\cdot)$, and $h(\cdot)$ satisfy either (1) $h(w) = 0$, or (2)
$f_1(\cdot)$, $g_1(\cdot)$, and $h(\cdot)$ are twice differentiable and $\lambda \nabla^2 f_1(w_\lambda) + (c_1 - \lambda) \nabla^2 g_1(w_\lambda) + \nabla^2 h(w_\lambda) \succ 0$ for every $\lambda \in [0, c_1]$. Then $F(\lambda)$ is quasi-convex, i.e., $F(t\lambda + (1 - t)\lambda') \le \max\{F(\lambda), F(\lambda')\}$ for all $t \in [0, 1]$ and all $\lambda, \lambda'$. Also, if $F(\lambda) \neq 0$, then $\partial_\lambda F(\lambda) = \{v\}$ and $\mathrm{sign}(v) = \mathrm{sign}(g_1(w_\lambda) - f_1(w_\lambda))$.

Remark 1. The quasiconvexity of $F(\lambda)$ is valid when at least one of the conditions in Lemma 1 holds. For the second condition, if $f_1(\cdot)$, $g_1(\cdot)$, and $h(\cdot)$ are convex, the condition holds unless all three functions share their stationary points, which is very unlikely. While there is no theoretical guarantee for non-convex settings, FairBatch still shows on-par or better results than the other fairness approaches in general settings where the functions may not be convex (see Sec. 4).

The proof of Lemma 1 can be found in Sec. A.2. Note that quasiconvexity immediately implies a unique minimum (Boyd et al., 2004). Thus, we design the following signed-gradient optimization algorithm:
$$\forall t \in \{0, 1, \ldots\}: \quad \lambda^{(t+1)} = \lambda^{(t)} - \alpha \, \mathrm{sign}\big(g_1(w_{\lambda^{(t)}}) - f_1(w_{\lambda^{(t)}})\big).$$
This algorithm increases $\lambda$ by $\alpha$ if $f_1(w_\lambda) \ge g_1(w_\lambda)$ and decreases $\lambda$ by $\alpha$ otherwise. This is consistent with our intuition: it increases the sampling probability of a disadvantaged group and decreases that of an advantaged group. The following proposition shows that the proposed algorithm converges to the optimal solution.

Proposition 1. Let $\lambda^* = \arg\min_\lambda F(\lambda)$ and $t \in \mathbb{Z}_{\ge 0}$. Then $|\lambda^{(t)} - \lambda^*| \le \max\{|\lambda^{(0)} - \lambda^*| - t\alpha, \ \alpha\}$.

Remark 2. $F(\lambda)$ is not necessarily convex even when the inner objective functions $f_1(\cdot)$ and $g_1(\cdot)$ are convex or even strongly convex. See Sec. A.3 for a counterexample.

3.2 UPDATE RULE FOR d ≥ 1

We now develop an efficient update algorithm for the following general bilevel optimization:
$$\min_{\lambda \in \Lambda} \ \max_{i = 1, \ldots, d} |f_i(w_\lambda) - g_i(w_\lambda)|, \qquad w_\lambda = \arg\min_{w} \ \sum_{i=1}^{d} \big[\lambda_i f_i(w) + (c_i - \lambda_i) g_i(w)\big] + h(w).$$
Here, $\Lambda = [0, c_1] \times [0, c_2] \times \cdots \times [0, c_d]$, where the $c_i$'s are positive constants. Denoting the outer objective function by $F(\lambda)$, let us first derive its gradient. Under some mild conditions (see Sec. A.4) on the $f_i(\cdot)$'s, $g_i(\cdot)$'s, and $h(\cdot)$:
$$\gamma_i := \mathrm{sign}\big(g_{i^*}(w) - f_{i^*}(w)\big) \big(\nabla f_{i^*}(w) - \nabla g_{i^*}(w)\big)^\top H_\lambda^{-1} \big(\nabla f_i(w) - \nabla g_i(w)\big) \in \partial_{\lambda_i} F(\lambda), \quad \forall i,$$
where $i^* = \arg\max_i |f_i(w) - g_i(w)|$ and $H_\lambda$ is positive definite. See Sec. A.4 for the derivation. Since a subdifferential is always a convex set, it follows that $\gamma := (\gamma_1, \gamma_2, \ldots, \gamma_d) \in \partial_\lambda F(\lambda)$. Computing the subgradient $\gamma$ requires computing $H_\lambda$, which involves the Hessian matrices of the inner objective function. To avoid this expensive computation, we approximate $\gamma \approx (0, 0, \ldots, \gamma_{i^*}, \ldots, 0)$; see Sec. A.5 for the rationale and intuition behind this approximation. Then, similarly to the case $d = 1$, we have $\mathrm{sign}(\gamma) \approx (0, \ldots, \mathrm{sign}(g_{i^*}(w_\lambda) - f_{i^*}(w_\lambda)), \ldots, 0)$. This gives us the general update rule of FairBatch (see Sec. A.6 for pseudocode and the short sketch below):
$$\forall t \in \{0, 1, \ldots\}: \quad \lambda_{i^*}^{(t+1)} = \lambda_{i^*}^{(t)} - \alpha \, \mathrm{sign}\big(g_{i^*}(w_\lambda) - f_{i^*}(w_\lambda)\big), \qquad \lambda_i^{(t+1)} = \lambda_i^{(t)}, \ \forall i \neq i^*.$$

4 EXPERIMENTS

We use logistic regression in all experiments except for Sec. 4.2, where we fine-tune ResNet18 (He et al., 2016) and GoogLeNet (Szegedy et al., 2015) to demonstrate FairBatch's ability to improve the fairness of pre-trained models. We evaluate all models on separate test sets and repeat all experiments with 10 different random seeds. We use PyTorch, and our experiments are performed on a server with Intel i7-6850 CPUs and NVIDIA TITAN Xp GPUs. See Sec. B.1 for more details.
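Before turning to the measurements, the signed-gradient outer update of Sec. 3.2 fits in a few lines. The following is a minimal sketch under our own naming (the function `update_lambda` and the clipping to the feasible box are our assumptions, not the paper's released code); the pseudocode counterparts are given in Sec. A.6.

```python
import numpy as np

def update_lambda(lam, f_vals, g_vals, alpha, c):
    """One signed-gradient outer step (Sec. 3.2): locate the largest disparity
    i* = argmax_i |f_i(w_lam) - g_i(w_lam)| and move only lambda_{i*} by +/- alpha,
    clipping to the feasible box [0, c_i]."""
    lam = np.asarray(lam, dtype=float)
    gaps = np.asarray(f_vals, dtype=float) - np.asarray(g_vals, dtype=float)
    i_star = int(np.argmax(np.abs(gaps)))
    # lambda_{i*} <- lambda_{i*} - alpha * sign(g_{i*} - f_{i*}), i.e., +alpha when f_{i*} > g_{i*}
    lam[i_star] += alpha * np.sign(gaps[i_star])
    return np.clip(lam, 0.0, np.asarray(c, dtype=float))
```

Here `f_vals` and `g_vals` would be the group-conditional losses of the intermediate model at the current epoch; for equalized odds, for instance, one would take $f = (L_{0,0}, L_{1,0})$ and $g = (L_{0,1}, L_{1,1})$, exactly the quantities compared in Algorithms 2-4.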
Measuring fairness. We first focus on the equal opportunity (EO) and demographic parity (DP) measures in Sec. 4.1 and Sec. 4.3; the equalized odds (ED) measure is used in Sec. 4.2 and Sec. B.2. To quantify EO, ED, and DP, we compute the disparity between sensitive groups:
$$\text{EO disparity} = \max_{z \in \mathcal{Z}} \big|\Pr(\hat{y} = 1 \mid z = z, y = 1) - \Pr(\hat{y} = 1 \mid y = 1)\big|,$$
$$\text{ED disparity} = \max_{z \in \mathcal{Z},\, y \in \mathcal{Y},\, \hat{y} \in \hat{\mathcal{Y}}} \big|\Pr(\hat{y} = \hat{y} \mid z = z, y = y) - \Pr(\hat{y} = \hat{y} \mid y = y)\big|,$$
$$\text{DP disparity} = \max_{z \in \mathcal{Z}} \big|\Pr(\hat{y} = 1 \mid z = z) - \Pr(\hat{y} = 1)\big|.$$
As discussed in Sec. 3, EO has a single-dimensional outer optimization, where the number of disparities is $d = 1$, while ED and DP have multi-dimensional outer optimizations, where $d > 1$.

Datasets. We generate a synthetic dataset of 3,000 examples with two non-sensitive attributes $(x_1, x_2)$, a binary sensitive attribute $z$, and a binary label $y$, using a method similar to the one in (Zafar et al., 2017a). A tuple $(x_1, x_2, y)$ is randomly generated from two Gaussian distributions: $(x_1, x_2) \mid y = 0 \sim \mathcal{N}([-2; -2], [10, 1; 1, 3])$ and $(x_1, x_2) \mid y = 1 \sim \mathcal{N}([2; 2], [5, 1; 1, 5])$. For $z$, we generate biased data via an unfair scenario: $\Pr(z = 1) = \Pr((x_1', x_2') \mid y = 1) / [\Pr((x_1', x_2') \mid y = 0) + \Pr((x_1', x_2') \mid y = 1)]$, where $(x_1', x_2') = (x_1 \cos(\pi/4) - x_2 \sin(\pi/4), \ x_1 \sin(\pi/4) + x_2 \cos(\pi/4))$. We also use the real benchmark datasets ProPublica COMPAS (Angwin et al., 2016) and AdultCensus (Kohavi, 1996), with 5,278 and 43,131 examples, respectively. We use the same pre-processing as in IBM's AI Fairness 360 (Bellamy et al., 2019) and use GENDER as the sensitive attribute. We also employ the UTKFace dataset (Zhang et al., 2017), with 23,708 images, to demonstrate the fine-tuning ability of FairBatch in Sec. 4.2.

Baselines. We employ three types of baselines: (1) non-fair training with logistic regression (LR); (2) fair training via pre-processing; and (3) fair training via in-processing. For pre-processing methods, we first consider a simple approach that we call Cutting, which evens out the data sizes of the sensitive groups by saturating them to the smallest group's size. One can think of a similar alternative that boosts all of the smaller-group data sizes to the largest one, but we do not report it here because its performance is similar to Cutting's. The other two are the state of the art: Reweighing (RW) (Kamiran and Calders, 2011) and Label Bias Correction (LBC) (Jiang and Nachum, 2020). RW balances importance levels across sensitive groups via example weighting, but sticks with these weights throughout the entire model training, unlike FairBatch. LBC iteratively trains an entire model with example weighting towards an unbiased data distribution. For in-processing methods, we compare with the following three: Fairness Constraints (FC) (Zafar et al., 2017a;b), Adversarial Debiasing (AD) (Zhang et al., 2018), and AdaFair (Iosifidis and Ntoutsi, 2019). FC incorporates a regularization term in an effort to reduce the disparities among sensitive groups. AD is an adversarial learning approach that intends to maximize the independence between the predicted labels and the sensitive attributes. In our experiments, we make a slight modification to AD to improve training stability: we do not employ one regularization term used for restricting the training direction. AdaFair is an ensemble technique that equips the prominent AdaBoost (Friedman et al., 2000) with a fairness aspect; examples that lead to unfair and inaccurate performance are considered to be the difficult instances.
In our experiments, a natural generalization of AdaFair, originally intended for ED, is made to encompass EO and DP; see Sec. B.3 for the generalization. While AdaFair bears spiritual similarity to FairBatch in the sense that mistreated examples are weighted progressively, it comes with a significant distinction in update scale: it is fundamentally a boosting technique, so such updates are done in distinct predictors through different rounds; see Sec. 5 for details.

FairBatch settings. To set $\alpha$, we start from a candidate set of values within the range [0.0001, 0.05] and use cross-validation on the training set to choose the value that results in the highest accuracy with low fairness violation. The default batch sizes are: 100 (synthetic), 200 (COMPAS), 1,000 (AdultCensus), and 32 (UTKFace).

Table 1: Performances on the synthetic, COMPAS, and AdultCensus test sets w.r.t. equal opportunity (EO). We compare FairBatch with three types of baselines: (1) a non-fair method: LR; (2) fair training via pre-processing: Cutting, RW (Kamiran and Calders, 2011), and LBC (Jiang and Nachum, 2020); (3) fair training via in-processing: FC (Zafar et al., 2017b), AD (Zhang et al., 2018), and AdaFair (Iosifidis and Ntoutsi, 2019). Experiments are repeated 10 times.

| Method | Synthetic Acc. | Synthetic EO Disp. | Synthetic Epochs | COMPAS Acc. | COMPAS EO Disp. | COMPAS Epochs | AdultCensus Acc. | AdultCensus EO Disp. | AdultCensus Epochs |
|---|---|---|---|---|---|---|---|---|---|
| LR | .885±.000 | .115±.000 | 400 | .681±.002 | .239±.006 | 300 | .845±.001 | .054±.005 | 300 |
| Cutting | .858±.001 | .028±.002 | 800 | .674±.005 | .055±.018 | 600 | .802±.002 | .054±.007 | 600 |
| RW | .858±.000 | .020±.000 | 800 | .685±.000 | .137±.000 | 300 | .835±.001 | .134±.006 | 100 |
| LBC | .858±.001 | .022±.000 | 11200 | .673±.002 | .031±.006 | 3900 | .841±.003 | .011±.003 | 6300 |
| FC | .833±.001 | .007±.000 | 700 | .656±.006 | .059±.028 | 100 | .844±.001 | .021±.004 | 300 |
| AD | .837±.010 | .026±.007 | 200 | .683±.001 | .067±.029 | 300 | .841±.003 | .016±.005 | 400 |
| AdaFair | .868±.000 | .043±.001 | 16000 | .664±.004 | .018±.004 | 9600 | .844±.001 | .038±.004 | 9000 |
| FairBatch | .855±.000 | .012±.001 | 300 | .681±.001 | .022±.005 | 100 | .844±.001 | .011±.003 | 400 |

Table 2: Performances on the synthetic, COMPAS, and AdultCensus test sets w.r.t. demographic parity (DP). The other settings are identical to those in Table 1.

| Method | Synthetic Acc. | Synthetic DP Disp. | Synthetic Epochs | COMPAS Acc. | COMPAS DP Disp. | COMPAS Epochs | AdultCensus Acc. | AdultCensus DP Disp. | AdultCensus Epochs |
|---|---|---|---|---|---|---|---|---|---|
| LR | .885±.000 | .257±.000 | 400 | .681±.002 | .192±.006 | 300 | .845±.001 | .125±.001 | 300 |
| Cutting | .885±.001 | .258±.001 | 500 | .677±.004 | .205±.025 | 400 | .846±.001 | .123±.002 | 300 |
| RW | .857±.000 | .164±.001 | 400 | .685±.000 | .103±.000 | 300 | .835±.001 | .052±.003 | 300 |
| LBC | .768±.000 | .042±.001 | 16000 | .671±.002 | .032±.009 | 7800 | .815±.003 | .011±.002 | 12600 |
| FC | .785±.013 | .058±.010 | 600 | .684±.001 | .083±.015 | 70 | .812±.009 | .025±.006 | 100 |
| AD | .812±.008 | .063±.014 | 700 | .683±.002 | .054±.019 | 550 | .815±.008 | .018±.004 | 400 |
| AdaFair | .784±.001 | .089±.001 | 52000 | .642±.004 | .033±.011 | 6300 | .825±.002 | .040±.001 | 27000 |
| FairBatch | .794±.001 | .040±.001 | 450 | .681±.001 | .036±.023 | 300 | .823±.001 | .010±.005 | 600 |

4.1 ACCURACY, FAIRNESS, AND RUNTIME

Table 1 compares FairBatch against the other approaches on the synthetic, COMPAS, and AdultCensus test sets w.r.t. accuracy, EO disparity, and complexity (reflected in the number of epochs). In Sec. B.4, we also present the convergence plots of EO disparity as a function of the number of epochs. LR in row 1 is logistic regression without any fairness technique. The pre-processing techniques in rows 2-4 reduce EO disparity, but at the cost of accuracy. The in-processing techniques in rows 5-7 further reduce EO disparity, yet still sacrifice accuracy.
FairBatch, presented in the last row, offers comparable (or even better) fairness while sacrificing less accuracy. We also present accuracy-fairness trade-off curves of FairBatch in Sec. B.5. One key implementation benefit is reflected in the small number of epochs; we also obtain consistent wall clock times, presented in Sec. B.6. As mentioned earlier, AdaFair is the most similar in spirit to FairBatch, as it adjusts example weights based on the fairness performance of prior models. We demonstrate in Sec. B.7 that FairBatch and AdaFair indeed show similar convergence behaviors, yet on different scales (rounds for AdaFair vs. epochs for FairBatch). One distinctive feature of FairBatch relative to AdaFair is the use of a single model training, which enables much faster speed (22.5-96x).

We also make similar comparisons w.r.t. another fairness measure: DP disparity. See Table 2. Recall that minimizing DP disparity involves adjusting two hyperparameters (λ1, λ2), which also means that d = 2. Although FairBatch's theoretical guarantees hold only when using one hyperparameter (i.e., d = 1), we nonetheless see similar results where FairBatch is on par with or better than the other approaches, while being the most robust in all aspects.

Table 3: Performances of the pre-trained models fine-tuned with FairBatch on the UTKFace test set w.r.t. equalized odds (ED) for two fairness scenarios. While Tables 1 and 2 already demonstrate FairBatch's performance against the state of the art, the emphasis here is more on FairBatch's usability: it is easy to adopt and yet improves the fairness of existing models.

| Pre-trained model | Method | Acc. (z: RACE, y: GENDER) | ED Disp. (z: RACE, y: GENDER) | Epochs (z: RACE, y: GENDER) | Acc. (z: RACE, y: AGE) | ED Disp. (z: RACE, y: AGE) | Epochs (z: RACE, y: AGE) |
|---|---|---|---|---|---|---|---|
| ResNet18 | Original | .893±.002 | .086±.012 | 19 | .722±.011 | .311±.053 | 10 |
| ResNet18 | Cutting | .592±.020 | .099±.014 | 18 | .466±.018 | .139±.021 | 20 |
| ResNet18 | FairBatch | .894±.002 | .063±.013 | 30 | .758±.004 | .220±.016 | 10 |
| GoogLeNet | Original | .888±.003 | .105±.016 | 20 | .746±.006 | .294±.034 | 14 |
| GoogLeNet | Cutting | .606±.010 | .076±.017 | 20 | .495±.017 | .168±.033 | 9 |
| GoogLeNet | FairBatch | .891±.002 | .061±.006 | 11 | .741±.018 | .202±.019 | 8 |

4.2 FINE-TUNING PRE-TRAINED UNFAIR MODELS FOR FAIRNESS

While Tables 1 and 2 already demonstrate FairBatch's performance against the state of the art, in this section we emphasize the usability of FairBatch by showing how it can improve the fairness of any pre-trained unfair model via fine-tuning; we only compare it with Cutting, which is also easy to adopt. Table 3 shows how FairBatch improves the fairness of pre-trained models (ResNet18 (He et al., 2016) and GoogLeNet (Szegedy et al., 2015)) on the UTKFace dataset (Zhang et al., 2017). Each image has three types of attributes: GENDER, RACE, and AGE. We use RACE as the sensitive attribute and consider two scenarios where the label attribute is GENDER or AGE. While GENDER is binary, AGE is multi-valued (<21, 21-40, 41-60, and >60), so we extend FairBatch in a straightforward fashion; see Sec. B.8 for details. Both Cutting and FairBatch reduce the ED disparities of the original pre-trained models. However, only FairBatch does so without sacrificing accuracy.

4.3 COMPATIBILITY WITH OTHER BATCH SELECTION TECHNIQUES

We demonstrate another key aspect of FairBatch: compatibility with existing batch selection approaches that use importance sampling for faster convergence in training.
The key functionality of the prior batch selection techniques is that examples considered important are given higher weights so as to be sampled more frequently. FairBatch can easily be tuned to accommodate such functionality: it determines the batch ratios of the sensitive groups and then samples using the importance weights within each group. We evaluate FairBatch combined with one prominent technique, loss-based weighting (Loshchilov and Hutter, 2016), on our synthetic dataset using EO and DP. We find that FairBatch indeed converges more quickly: it uses about 50 fewer epochs with similar fairness performance; see Sec. B.9 for the EO and DP convergence plots.

5 RELATED WORK

Model fairness. Various fairness measures have been proposed to reflect legal and social issues (Narayanan, 2018). Among them, we focus on group fairness measures: equal opportunity (Hardt et al., 2016), equalized odds (Hardt et al., 2016), and demographic parity (Feldman et al., 2015). A variety of techniques have been proposed and can be categorized into (1) pre-processing techniques (Kamiran and Calders, 2011; Zemel et al., 2013; Feldman et al., 2015; du Pin Calmon et al., 2017; Choi et al., 2020; Jiang and Nachum, 2020), which debias or reweight data, (2) in-processing techniques (Kamishima et al., 2012; Zafar et al., 2017a;b; Agarwal et al., 2018; Zhang et al., 2018; Cotter et al., 2019; Roh et al., 2020), which tailor the model training for fairness, and (3) post-processing techniques (Kamiran et al., 2012; Hardt et al., 2016; Pleiss et al., 2017; Chzhen et al., 2019), which perturb only the model output without touching the model internals. Most of these methods require broad changes in data preprocessing, model training, or model outputs in machine learning systems (Venkatasubramanian, 2019). In contrast, FairBatch only requires a single-line change in code to replace the batch selection while achieving comparable or even better performance than the state of the art.

Among the fairness techniques, AdaFair (Iosifidis and Ntoutsi, 2019) is the most similar in spirit to FairBatch. AdaFair extends the well-known AdaBoost (Friedman et al., 2000): examples that lead to poor accuracy or fairness are boosted, i.e., given higher weights during the next round of training a new model that is added to the ensemble. In comparison, FairBatch is based on the theoretical foundations of bilevel optimization and effectively performs the reweighting during each epoch (not through rounds), which leads to an order-of-magnitude improvement in speed, as shown in Sec. 4.1.

Although not our immediate focus, there are other noteworthy fairness measures: (1) individual fairness (Dwork et al., 2012), where close examples should be treated similarly, (2) causality-based fairness (Kilbertus et al., 2017; Kusner et al., 2017; Zhang and Bareinboim, 2018; Nabi and Shpitser, 2018; Khademi et al., 2019), which aims to overcome the limitations of non-causal approaches by understanding the causal relationships between attributes, and (3) distributionally robust optimization (DRO) (Sinha et al., 2017) based fairness (Hashimoto et al., 2018), which achieves accuracy parity without knowledge of the sensitive attribute by balancing the risks across all distributions. Extending FairBatch to support these measures is interesting future work.
Finally, Chouldechova and Roth (2018) describe three causes of unfairness that help clarify FairBatch's fairness contributions: (1) minimizing average error fits majority populations, (2) bias encoded in the data, and (3) the need to explore and gather more data. FairBatch addresses cause (1) by balancing the sensitive-group ratios within a batch. FairBatch also addresses (2) in some cases. For example, consider the recidivism prediction problem described in (Chouldechova and Roth, 2018), where minority populations have biased labels. In this case, FairBatch can be configured to make the recidivism prediction rate for the minority population similar to those of other populations. There may be other types of data bias that FairBatch is not able to address. Finally, FairBatch does not directly address (3), where one must gather more data for better fairness; instead, there is a recent line of work that studies data collection techniques for fairness (Tae and Whang, 2021).

Batch selection. The batch selection literature for SGD focuses on analyzing the effect of batch sizes (Keskar et al., 2017; Masters and Luschi, 2018) and various sampling techniques (Shamir, 2016; Gürbüzbalaban et al., 2019). More recently, importance sampling techniques have been proposed for faster convergence (Loshchilov and Hutter, 2016; Alain et al., 2016; Stich et al., 2017; Csiba and Richtárik, 2018; Katharopoulos and Fleuret, 2018; Johnson and Guestrin, 2018). In comparison, FairBatch takes the novel approach of using batch selection for better fairness and is compatible with the existing techniques.

6 CONCLUSION

We addressed model fairness through the lens of bilevel optimization and proposed the FairBatch batch selection algorithm. Bilevel optimization provides a natural framework where the inner optimizer is SGD and the outer optimizer performs adaptive batch selection to improve fairness. We presented FairBatch as an implementation of this optimization and showed how its underlying theory supports the fairness measures equal opportunity, equalized odds, and demographic parity. We showed that FairBatch offers performance that is on par with or better than the state of the art w.r.t. all aspects in consideration: accuracy, fairness, and runtime. FairBatch can also be readily adopted in machine learning systems with a minimal, single-line change that replaces the batch selection, and it can be gracefully merged with other batch selection techniques used for faster convergence.

ACKNOWLEDGEMENTS

Yuji Roh and Steven E. Whang were supported by a Google AI Focused Research Award and by the Engineering Research Center Program through the National Research Foundation of Korea (NRF) funded by the Korean Government MSIT (NRF-2018R1A5A1059921). Kangwook Lee was supported by the NSF/Intel Partnership on Machine Learning for Wireless Networking Program under Grant No. CNS-2003129. Changho Suh was supported by the Institute for Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-01396, Development of framework for analyzing, detecting, mitigating of bias in AI model and training data).

REFERENCES

Kristy Choi, Aditya Grover, Trisha Singh, Rui Shu, and Stefano Ermon. Fair generative modeling via weak supervision. In ICML, 2020.

Heinrich Jiang and Ofir Nachum. Identifying and correcting label bias in machine learning. In AISTATS, pages 702-712, 2020.
Muhammad Bilal Zafar, Isabel Valera, Manuel Gomez-Rodriguez, and Krishna P. Gummadi. Fairness constraints: Mechanisms for fair classification. In AISTATS, pages 962-970, 2017a.

Muhammad Bilal Zafar, Isabel Valera, Manuel Gomez-Rodriguez, and Krishna P. Gummadi. Fairness beyond disparate treatment & disparate impact: Learning classification without disparate mistreatment. In WWW, pages 1171-1180, 2017b.

Brian Hu Zhang, Blake Lemoine, and Margaret Mitchell. Mitigating unwanted biases with adversarial learning. In AIES, pages 335-340, 2018.

Vasileios Iosifidis and Eirini Ntoutsi. AdaFair: Cumulative fairness adaptive boosting. In CIKM, pages 781-790, 2019.

Moritz Hardt, Eric Price, and Nati Srebro. Equality of opportunity in supervised learning. In NeurIPS, pages 3315-3323, 2016.

Julia Angwin, Jeff Larson, Surya Mattu, and Lauren Kirchner. Machine bias: There's software used across the country to predict future criminals. And it's biased against blacks. ProPublica, 2016.

Michael Feldman, Sorelle A. Friedler, John Moeller, Carlos Scheidegger, and Suresh Venkatasubramanian. Certifying and removing disparate impact. In KDD, pages 259-268, 2015.

Ron Kohavi. Scaling up the accuracy of naive-Bayes classifiers: A decision-tree hybrid. In KDD, pages 202-207, 1996.

Zhifei Zhang, Yang Song, and Hairong Qi. Age progression/regression by conditional adversarial autoencoder. In CVPR, pages 4352-4360, 2017.

Faisal Kamiran and Toon Calders. Data preprocessing techniques for classification without discrimination. Knowl. Inf. Syst., 33(1):1-33, 2011.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770-778, 2016.

Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott E. Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In CVPR, pages 1-9, 2015.

Léon Bottou. Large-scale machine learning with stochastic gradient descent. In COMPSTAT, pages 177-186. Springer, 2010.

Stephen Boyd and Lieven Vandenberghe. Convex optimization. Cambridge University Press, 2004.

Rachel K. E. Bellamy, Kuntal Dey, Michael Hind, Samuel C. Hoffman, Stephanie Houde, Kalapriya Kannan, Pranay Lohia, Jacquelyn Martino, Sameep Mehta, Aleksandra Mojsilovic, Seema Nagar, Karthikeyan Natesan Ramamurthy, John T. Richards, Diptikalyan Saha, Prasanna Sattigeri, Moninder Singh, Kush R. Varshney, and Yunfeng Zhang. AI Fairness 360: An extensible toolkit for detecting and mitigating algorithmic bias. IBM J. Res. Dev., 63(4/5):4:1-4:15, 2019.

Jerome Friedman, Trevor Hastie, and Robert Tibshirani. Additive logistic regression: A statistical view of boosting. The Annals of Statistics, 28(2), 2000.

Ilya Loshchilov and Frank Hutter. Online batch selection for faster training of neural networks. In ICLR 2016 Workshop Track, 2016.

Arvind Narayanan. Translation tutorial: 21 fairness definitions and their politics. In FAccT, volume 1170, 2018.

Richard S. Zemel, Yu Wu, Kevin Swersky, Toniann Pitassi, and Cynthia Dwork. Learning fair representations. In ICML, pages 325-333, 2013.

Flávio du Pin Calmon, Dennis Wei, Bhanukiran Vinzamuri, Karthikeyan Natesan Ramamurthy, and Kush R. Varshney. Optimized pre-processing for discrimination prevention. In NeurIPS, pages 3995-4004, 2017.

Toshihiro Kamishima, Shotaro Akaho, Hideki Asoh, and Jun Sakuma. Fairness-aware classifier with prejudice remover regularizer. In ECML PKDD, pages 35-50, 2012.
Alekh Agarwal, Alina Beygelzimer, Miroslav Dudík, John Langford, and Hanna M. Wallach. A reductions approach to fair classification. In ICML, pages 60-69, 2018.

Andrew Cotter, Heinrich Jiang, and Karthik Sridharan. Two-player games for efficient non-convex constrained optimization. In ALT, pages 300-332, 2019.

Yuji Roh, Kangwook Lee, Steven Euijong Whang, and Changho Suh. FR-Train: A mutual information-based approach to fair and robust training. In ICML, 2020.

Faisal Kamiran, Asim Karim, and Xiangliang Zhang. Decision theory for discrimination-aware classification. In ICDM, pages 924-929, 2012.

Geoff Pleiss, Manish Raghavan, Felix Wu, Jon M. Kleinberg, and Kilian Q. Weinberger. On fairness and calibration. In NeurIPS, pages 5684-5693, 2017.

Evgenii Chzhen, Christophe Denis, Mohamed Hebiri, Luca Oneto, and Massimiliano Pontil. Leveraging labeled and unlabeled data for consistent fair binary classification. In NeurIPS, pages 12760-12770, 2019.

Suresh Venkatasubramanian. Algorithmic fairness: Measures, methods and representations. In PODS, page 481, 2019.

Cynthia Dwork, Moritz Hardt, Toniann Pitassi, Omer Reingold, and Richard S. Zemel. Fairness through awareness. In ITCS, pages 214-226, 2012.

Niki Kilbertus, Mateo Rojas-Carulla, Giambattista Parascandolo, Moritz Hardt, Dominik Janzing, and Bernhard Schölkopf. Avoiding discrimination through causal reasoning. In NeurIPS, pages 656-666, 2017.

Matt J. Kusner, Joshua Loftus, Chris Russell, and Ricardo Silva. Counterfactual fairness. In NeurIPS, pages 4066-4076, 2017.

Junzhe Zhang and Elias Bareinboim. Fairness in decision-making: The causal explanation formula. In AAAI, 2018.

Razieh Nabi and Ilya Shpitser. Fair inference on outcomes. In AAAI, pages 1931-1940, 2018.

Aria Khademi, Sanghack Lee, David Foley, and Vasant Honavar. Fairness in algorithmic decision making: An excursion through the lens of causality. In WWW, pages 2907-2914, 2019.

Aman Sinha, Hongseok Namkoong, and John C. Duchi. Certifying some distributional robustness with principled adversarial training. In ICLR, 2017.

Tatsunori Hashimoto, Megha Srivastava, Hongseok Namkoong, and Percy Liang. Fairness without demographics in repeated loss minimization. In ICML, pages 1929-1938, 2018.

Alexandra Chouldechova and Aaron Roth. The frontiers of fairness in machine learning. CoRR, abs/1810.08810, 2018.

Ki Hyun Tae and Steven Euijong Whang. Slice Tuner: A selective data acquisition framework for accurate and fair machine learning models. In SIGMOD, 2021.

Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. In ICLR, 2017.

Dominic Masters and Carlo Luschi. Revisiting small batch training for deep neural networks. CoRR, abs/1804.07612, 2018.

Ohad Shamir. Without-replacement sampling for stochastic gradient methods. In NIPS, pages 46-54, 2016.

Mert Gürbüzbalaban, Asu Ozdaglar, and P. A. Parrilo. Why random reshuffling beats stochastic gradient descent. Mathematical Programming, pages 1-36, 2019.

Guillaume Alain, Alex Lamb, Chinnadhurai Sankar, Aaron Courville, and Yoshua Bengio. Variance reduction in SGD by distributed importance sampling. In ICLR Workshop, 2016.

Sebastian U. Stich, Anant Raj, and Martin Jaggi. Safe adaptive importance sampling. In NIPS, pages 4381-4391, 2017.

Dominik Csiba and Peter Richtárik. Importance sampling for minibatches. The Journal of Machine Learning Research, 19(1):962-982, 2018.
Angelos Katharopoulos and François Fleuret. Not all samples are created equal: Deep learning with importance sampling. In ICML, pages 2530-2539, 2018.

Tyler B. Johnson and Carlos Guestrin. Training deep models faster with robust, approximate importance sampling. In NIPS, pages 7265-7275, 2018.

A APPENDIX: THEORY AND ALGORITHMS

A.1 DEMOGRAPHIC PARITY

We continue from Sec. 2 and provide more details on how we capture demographic parity using our bilevel optimization framework.

Proposition 2. If $m_{0,0} = m_{0,1} = m_{1,0} = m_{1,1}$, then $L_{0,0}(w) = L_{1,0}(w)$ and $L_{0,1}(w) = L_{1,1}(w)$ serve as a sufficient condition for demographic parity.

Proof. Slightly abusing notation, we denote by $\Pr(\cdot)$ the empirical probability. Demographic parity is satisfied when $\Pr(\hat{y} = 1 \mid z = 0) = \Pr(\hat{y} = 1 \mid z = 1)$ holds. Thus,
$$\Pr(\hat{y} = 1, y = 0 \mid z = 0) + \Pr(\hat{y} = 1, y = 1 \mid z = 0) = \Pr(\hat{y} = 1, y = 0 \mid z = 1) + \Pr(\hat{y} = 1, y = 1 \mid z = 1).$$
Since $\ell(|1 - y|, \cdot) = 1 - \ell(y, \cdot)$, this is equivalent to
$$\frac{1}{m_{\cdot,0}} \sum_{i:\, y_i = 0, z_i = 0} (1 - \ell(y_i, \hat{y}_i)) + \frac{1}{m_{\cdot,0}} \sum_{i:\, y_i = 1, z_i = 0} \ell(y_i, \hat{y}_i) = \frac{1}{m_{\cdot,1}} \sum_{i:\, y_i = 0, z_i = 1} (1 - \ell(y_i, \hat{y}_i)) + \frac{1}{m_{\cdot,1}} \sum_{i:\, y_i = 1, z_i = 1} \ell(y_i, \hat{y}_i).$$
By replacing $\sum_{i:\, y_i = y, z_i = z} \ell(y_i, \hat{y}_i) = m_{y,z} L_{y,z}(w)$, we obtain
$$\frac{m_{0,0}}{m_{\cdot,0}} (1 - L_{0,0}(w)) + \frac{m_{1,0}}{m_{\cdot,0}} L_{1,0}(w) = \frac{m_{0,1}}{m_{\cdot,1}} (1 - L_{0,1}(w)) + \frac{m_{1,1}}{m_{\cdot,1}} L_{1,1}(w).$$
If $m_{0,0} = m_{0,1} = m_{1,0} = m_{1,1}$, this reduces to $L_{1,0}(w) - L_{0,0}(w) = L_{1,1}(w) - L_{0,1}(w)$. A sufficient condition for this is $L_{0,0}(w) = L_{1,0}(w)$ and $L_{0,1}(w) = L_{1,1}(w)$.

In general, the condition of the above proposition does not hold. Observe that another sufficient condition for demographic parity is
$$\frac{m_{1,0}}{m_{\cdot,0}} L_{1,0}(w) - \frac{m_{1,1}}{m_{\cdot,1}} L_{1,1}(w) = 0 \quad \text{and} \quad \frac{m_{0,0}}{m_{\cdot,0}} L_{0,0}(w) - \frac{m_{0,1}}{m_{\cdot,1}} L_{0,1}(w) = \frac{m_{0,0}}{m_{\cdot,0}} - \frac{m_{0,1}}{m_{\cdot,1}}.$$
Let us define $L'_{1,0}(w) = \frac{m_{1,0}}{m_{\cdot,0}} L_{1,0}(w)$, $L'_{1,1}(w) = \frac{m_{1,1}}{m_{\cdot,1}} L_{1,1}(w)$, $L'_{0,0}(w) = \frac{m_{0,0}}{m_{\cdot,0}} L_{0,0}(w)$, $L'_{0,1}(w) = \frac{m_{0,1}}{m_{\cdot,1}} L_{0,1}(w)$, and $c = \frac{m_{0,0}}{m_{\cdot,0}} - \frac{m_{0,1}}{m_{\cdot,1}}$. Also, define $|x|_c = \max\{x - c, c - x\}$. Then, we have the following bilevel optimization problem:
$$\min_{\lambda \in [0,1] \times [0,1]} \ \max\{|L'_{1,0}(w_\lambda) - L'_{1,1}(w_\lambda)|, \ |L'_{0,0}(w_\lambda) - L'_{0,1}(w_\lambda)|_c\},$$
$$w_\lambda = \arg\min_{w} \ \lambda_1 L'_{0,0}(w) + (1 - \lambda_1) L'_{1,0}(w) + \lambda_2 L'_{0,1}(w) + (1 - \lambda_2) L'_{1,1}(w).$$

A.2 PROOF OF LEMMA 1

We continue from Sec. 3.1 and provide a full proof of Lemma 1, which we first recall.

Lemma 1 (Quasi-convexity of $F(\lambda)$). For $d = 1$, suppose $f_1(\cdot)$, $g_1(\cdot)$, and $h(\cdot)$ satisfy either (1) $h(w) = 0$, or (2) $f_1(\cdot)$, $g_1(\cdot)$, and $h(\cdot)$ are twice differentiable and $\lambda \nabla^2 f_1(w_\lambda) + (c_1 - \lambda) \nabla^2 g_1(w_\lambda) + \nabla^2 h(w_\lambda) \succ 0$ for every $\lambda \in [0, c_1]$. Then $F(\lambda)$ is quasi-convex, i.e., $F(t\lambda + (1 - t)\lambda') \le \max\{F(\lambda), F(\lambda')\}$ for all $t \in [0, 1]$ and all $\lambda, \lambda'$. Also, if $F(\lambda) \neq 0$, then $\partial_\lambda F(\lambda) = \{v\}$ and $\mathrm{sign}(v) = \mathrm{sign}(g_1(w_\lambda) - f_1(w_\lambda))$.
Proof. It is known that a continuous function $f: \mathbb{R} \to \mathbb{R}$ is quasiconvex if and only if at least one of the following holds: it is (1) nondecreasing, (2) nonincreasing, or (3) nonincreasing and then nondecreasing (Boyd et al., 2004). We prove the lemma by showing that $F(\lambda)$ is nonincreasing and then nondecreasing. More precisely, we show that $f_1(w_\lambda) - g_1(w_\lambda)$ is a nonincreasing function of $\lambda$; it is easy to see that this directly implies that $|f_1(w_\lambda) - g_1(w_\lambda)|$ is nonincreasing and then nondecreasing.

Case 1 ($h(w) = 0$). Consider $\lambda_1$ and $\lambda_2$ such that $\lambda_1 > \lambda_2$. If we can show $f_1(w^*_{\lambda_1}) \le f_1(w^*_{\lambda_2})$ and $g_1(w^*_{\lambda_1}) \ge g_1(w^*_{\lambda_2})$, this implies that $f_1(w_\lambda) - g_1(w_\lambda)$ is a nonincreasing function. Indeed, this is intuitive: if we increase $\lambda$, the inner optimization problem puts a higher weight on $f_1(\cdot)$, resulting in a lower value of $f_1(w^*)$ and a higher value of $g_1(w^*)$. We show this formally by contradiction. By the definition of $w_\lambda$, we have the following two conditions:
$$\lambda_1 f_1(w^*_{\lambda_1}) + (c_1 - \lambda_1) g_1(w^*_{\lambda_1}) \le \lambda_1 f_1(w) + (c_1 - \lambda_1) g_1(w), \quad \forall w, \qquad (1)$$
$$\lambda_2 f_1(w^*_{\lambda_2}) + (c_1 - \lambda_2) g_1(w^*_{\lambda_2}) \le \lambda_2 f_1(w) + (c_1 - \lambda_2) g_1(w), \quad \forall w. \qquad (2)$$
If the claim is false, one of the following three events must occur:

1. $f_1(w^*_{\lambda_1}) > f_1(w^*_{\lambda_2})$ and $g_1(w^*_{\lambda_1}) \ge g_1(w^*_{\lambda_2})$: By adding these two inequalities with respective weights $\lambda_1$ and $c_1 - \lambda_1$, we have $\lambda_1 f_1(w^*_{\lambda_1}) + (c_1 - \lambda_1) g_1(w^*_{\lambda_1}) > \lambda_1 f_1(w^*_{\lambda_2}) + (c_1 - \lambda_1) g_1(w^*_{\lambda_2})$. This contradicts (1).

2. $f_1(w^*_{\lambda_1}) \le f_1(w^*_{\lambda_2})$ and $g_1(w^*_{\lambda_1}) < g_1(w^*_{\lambda_2})$: Similarly, by adding these two inequalities with respective weights $\lambda_2$ and $c_1 - \lambda_2$, we have $\lambda_2 f_1(w^*_{\lambda_2}) + (c_1 - \lambda_2) g_1(w^*_{\lambda_2}) > \lambda_2 f_1(w^*_{\lambda_1}) + (c_1 - \lambda_2) g_1(w^*_{\lambda_1})$. This contradicts (2).

3. $f_1(w^*_{\lambda_1}) > f_1(w^*_{\lambda_2})$ and $g_1(w^*_{\lambda_1}) < g_1(w^*_{\lambda_2})$: By adding (1) with $w = w^*_{\lambda_2}$ and (2) with $w = w^*_{\lambda_1}$, we have
$$\lambda_1 f_1(w^*_{\lambda_1}) + (c_1 - \lambda_1) g_1(w^*_{\lambda_1}) + \lambda_2 f_1(w^*_{\lambda_2}) + (c_1 - \lambda_2) g_1(w^*_{\lambda_2}) \le \lambda_1 f_1(w^*_{\lambda_2}) + (c_1 - \lambda_1) g_1(w^*_{\lambda_2}) + \lambda_2 f_1(w^*_{\lambda_1}) + (c_1 - \lambda_2) g_1(w^*_{\lambda_1}).$$
By rearranging terms, we have $(\lambda_1 - \lambda_2)(f_1(w^*_{\lambda_1}) - f_1(w^*_{\lambda_2})) \le (\lambda_1 - \lambda_2)(g_1(w^*_{\lambda_1}) - g_1(w^*_{\lambda_2}))$. Dividing both sides by $\lambda_1 - \lambda_2 > 0$ gives $f_1(w^*_{\lambda_1}) - f_1(w^*_{\lambda_2}) \le g_1(w^*_{\lambda_1}) - g_1(w^*_{\lambda_2})$. This is a contradiction, as the left-hand side is positive while the right-hand side is negative.

This completes the proof of the first claim by contradiction. The second claim follows immediately from the first. Since $F(\lambda) = |f_1(w_\lambda) - g_1(w_\lambda)|$, we have $\frac{dF(\lambda)}{d\lambda} = \mathrm{sign}(f_1(w_\lambda) - g_1(w_\lambda)) \, \frac{d}{d\lambda}(f_1(w_\lambda) - g_1(w_\lambda))$. As shown above, $f_1(w_\lambda) - g_1(w_\lambda)$ is nonincreasing, i.e., $\frac{d}{d\lambda}(f_1(w_\lambda) - g_1(w_\lambda)) \le 0$. Thus, $\mathrm{sign}\big(\frac{dF(\lambda)}{d\lambda}\big) = \mathrm{sign}(g_1(w_\lambda) - f_1(w_\lambda))$.

Case 2 ($f_1(\cdot)$, $g_1(\cdot)$, and $h(\cdot)$ are twice differentiable and $\lambda \nabla^2 f_1(w_\lambda) + (c_1 - \lambda) \nabla^2 g_1(w_\lambda) + \nabla^2 h(w_\lambda) \succ 0$ for every $\lambda \in [0, c_1]$). In this part of the proof, we denote $w_\lambda$ by $w$ for simplicity. To show that $f_1(w) - g_1(w)$ is a nonincreasing function of $\lambda$, consider the derivative
$$\frac{d}{d\lambda}(f_1(w) - g_1(w)) = (\nabla f_1(w) - \nabla g_1(w))^\top \frac{dw}{d\lambda}.$$
To compute $\frac{dw}{d\lambda}$, we implicitly differentiate (with respect to $\lambda$) the stationarity equation $\lambda \nabla f_1(w) + (c_1 - \lambda) \nabla g_1(w) + \nabla h(w) = 0$:
$$\nabla f_1(w) + \lambda \nabla^2 f_1(w) \frac{dw}{d\lambda} - \nabla g_1(w) + (c_1 - \lambda) \nabla^2 g_1(w) \frac{dw}{d\lambda} + \nabla^2 h(w) \frac{dw}{d\lambda} = 0.$$
By rearranging terms, we have
$$\big(\lambda \nabla^2 f_1(w) + (c_1 - \lambda) \nabla^2 g_1(w) + \nabla^2 h(w)\big) \frac{dw}{d\lambda} = -(\nabla f_1(w) - \nabla g_1(w)).$$
By the assumption, $\lambda \nabla^2 f_1(w) + (c_1 - \lambda) \nabla^2 g_1(w) + \nabla^2 h(w)$ is positive definite and hence invertible. Thus,
$$\frac{dw}{d\lambda} = -\big(\lambda \nabla^2 f_1(w) + (c_1 - \lambda) \nabla^2 g_1(w) + \nabla^2 h(w)\big)^{-1} (\nabla f_1(w) - \nabla g_1(w)),$$
$$\frac{d}{d\lambda}(f_1(w) - g_1(w)) = -(\nabla f_1(w) - \nabla g_1(w))^\top \big(\lambda \nabla^2 f_1(w) + (c_1 - \lambda) \nabla^2 g_1(w) + \nabla^2 h(w)\big)^{-1} (\nabla f_1(w) - \nabla g_1(w)).$$
Note that $\big(\lambda \nabla^2 f_1(w) + (c_1 - \lambda) \nabla^2 g_1(w) + \nabla^2 h(w)\big)^{-1}$ is also positive definite. Thus, $\frac{d}{d\lambda}(f_1(w) - g_1(w))$ is always nonpositive, and hence $f_1(w) - g_1(w)$ is a nonincreasing function. Now, observe that $\mathrm{sign}(f_1(w) - g_1(w)) \, \frac{d}{d\lambda}(f_1(w) - g_1(w)) \in \partial_\lambda F(\lambda)$. Therefore, if $F(\lambda) \neq 0$, then $\partial_\lambda F(\lambda) = \{v\}$ and $\mathrm{sign}(v) = \mathrm{sign}(g_1(w) - f_1(w))$.

A.3 INNER OBJECTIVE'S CONVEXITY DOES NOT IMPLY OUTER OBJECTIVE'S CONVEXITY

We continue from Sec. 3.2 and provide an example where the convexity of the inner objective does not imply the convexity of the outer objective. Consider the following strongly convex functions $f_1(\cdot)$ and $g_1(\cdot)$:
$$f_1(w) = \frac{e^w + e^{-w}}{5}, \qquad g_1(w) = (w - 1)^2.$$
Fig. 2 plots the resulting outer objective function $F(\lambda)$; one can observe that it is not convex. Note that it is quasiconvex by Lemma 1.

Figure 2: $F(\lambda)$ is not convex, but quasi-convex.

A.4 GRADIENT WHEN d ≥ 1
We continue from Sec. 3.2 and derive the gradient of the outer objective function. Recall the general bilevel optimization problem:
$$\min_{\lambda \in \Lambda} \ \max_{i = 1, \ldots, d} |f_i(w_\lambda) - g_i(w_\lambda)|, \qquad w_\lambda = \arg\min_{w} \ \sum_{i=1}^{d} \big[\lambda_i f_i(w) + (c_i - \lambda_i) g_i(w)\big] + h(w).$$
In this section, we prove that
$$\mathrm{sign}\big(g_{i^*}(w) - f_{i^*}(w)\big) \big(\nabla f_{i^*}(w) - \nabla g_{i^*}(w)\big)^\top H_\lambda^{-1} \big(\nabla f_i(w) - \nabla g_i(w)\big) \in \partial_{\lambda_i} F(\lambda), \quad \forall i.$$
Assume that $\sum_{i=1}^{d} \big[\lambda_i \nabla^2 f_i(w_\lambda) + (c_i - \lambda_i) \nabla^2 g_i(w_\lambda)\big] + \nabla^2 h(w_\lambda) \succ 0$ for every $\lambda \in \Lambda$. In this part of the proof, we denote $w_\lambda$ by $w$ for simplicity. To compute $\frac{\partial w}{\partial \lambda_i}$, we implicitly differentiate (with respect to $\lambda_i$) the stationarity equation $\sum_{j=1}^{d} \big[\lambda_j \nabla f_j(w) + (c_j - \lambda_j) \nabla g_j(w)\big] + \nabla h(w) = 0$:
$$\nabla f_i(w) - \nabla g_i(w) + \Big(\sum_{j=1}^{d} \big[\lambda_j \nabla^2 f_j(w) + (c_j - \lambda_j) \nabla^2 g_j(w)\big] + \nabla^2 h(w)\Big) \frac{\partial w}{\partial \lambda_i} = 0.$$
By the assumption, $H_\lambda := \sum_{j=1}^{d} \big[\lambda_j \nabla^2 f_j(w) + (c_j - \lambda_j) \nabla^2 g_j(w)\big] + \nabla^2 h(w)$ is positive definite and hence invertible. Thus,
$$\frac{\partial w}{\partial \lambda_i} = -H_\lambda^{-1} \big(\nabla f_i(w) - \nabla g_i(w)\big).$$
Now observe that $F(\lambda) = |f_{i^*}(w_\lambda) - g_{i^*}(w_\lambda)|$. Therefore,
$$\mathrm{sign}\big(f_{i^*}(w) - g_{i^*}(w)\big) \frac{\partial}{\partial \lambda_i}\big(f_{i^*}(w) - g_{i^*}(w)\big) \in \partial_{\lambda_i} F(\lambda).$$
Since
$$\frac{\partial}{\partial \lambda_i}\big(f_{i^*}(w) - g_{i^*}(w)\big) = -\big(\nabla f_{i^*}(w) - \nabla g_{i^*}(w)\big)^\top H_\lambda^{-1} \big(\nabla f_i(w) - \nabla g_i(w)\big),$$
we obtain
$$\mathrm{sign}\big(g_{i^*}(w) - f_{i^*}(w)\big) \big(\nabla f_{i^*}(w) - \nabla g_{i^*}(w)\big)^\top H_\lambda^{-1} \big(\nabla f_i(w) - \nabla g_i(w)\big) \in \partial_{\lambda_i} F(\lambda). \qquad (4)$$

A.5 RATIONALE AND INTUITION BEHIND THE APPROXIMATION

We continue from Sec. 3.2 and provide more justification for the gradient approximation technique. Assume that $\sum_{i=1}^{d} \big[\lambda_i \nabla^2 f_i(w_\lambda) + (c_i - \lambda_i) \nabla^2 g_i(w_\lambda)\big] + \nabla^2 h(w_\lambda) \succ 0$ for every $\lambda \in \Lambda$. Then the gradient can be fully characterized as in equation 4. The rationale behind the approximation $\gamma \approx (0, 0, \ldots, \gamma_{i^*}, \ldots, 0)$ is that $|\gamma_i|$ is maximized at $i = i^*$ when the gradient differences $\nabla f_1(w) - \nabla g_1(w), \ldots, \nabla f_d(w) - \nabla g_d(w)$ are of comparable magnitude. This is because $(\nabla f_{i^*}(w) - \nabla g_{i^*}(w))^\top H_\lambda^{-1} (\nabla f_i(w) - \nabla g_i(w))$ is an inner product between $H_\lambda^{-1/2}(\nabla f_{i^*}(w) - \nabla g_{i^*}(w))$ and $H_\lambda^{-1/2}(\nabla f_i(w) - \nabla g_i(w))$, and the two are always perfectly aligned when $i = i^*$. The approximation is also intuitive. Recall that changing $\lambda_{i^*}$ affects the weights associated with $f_{i^*}(w)$ and $g_{i^*}(w)$ in the inner optimization problem. Thus, changes in $\lambda_{i^*}$ directly affect $F(\lambda) = |f_{i^*}(w) - g_{i^*}(w)|$. On the other hand, changing $\lambda_i$ for $i \neq i^*$ does not affect the weights associated with $f_{i^*}(w)$ and $g_{i^*}(w)$ but only the weights of other terms, so it affects $F(\lambda)$ only indirectly and weakly.

A.6 FAIRBATCH ALGORITHMS IN PSEUDOCODE

We continue from Sec. 3.2 and present the FairBatch algorithms in pseudocode. Algorithms 2, 3, and 4 show how $\lambda$ is adjusted for equal opportunity, equalized odds, and demographic parity, respectively. From the intermediate model at each epoch (or after a certain number of iterations), we first obtain $f(w)$ and $g(w)$, which correspond to the losses conditioned on each class. We then update the current value of $\lambda$ by comparing $f(w)$ and $g(w)$.

Algorithm 2: Adaptive adjustment of λ w.r.t. equal opportunity
  Input: intermediate model, criterion, train data (x_train, z_train, y_train), previous lambda λ^(t-1), and FairBatch's learning rate α
  output = model(x_train)
  loss = criterion(output, y_train)
  λ^(t) = λ^(t-1) + α, if mean(loss[(y = 1, z = 0)]) > mean(loss[(y = 1, z = 1)])
  λ^(t) = λ^(t-1) - α, if mean(loss[(y = 1, z = 0)]) < mean(loss[(y = 1, z = 1)])
  λ^(t) = λ^(t-1),     otherwise
  Output: next lambda λ^(t)

Algorithm 3: Adaptive adjustment of λ w.r.t. equalized odds
  Input: intermediate model, criterion, train data (x_train, z_train, y_train), previous lambda λ^(t-1), and FairBatch's learning rate α
  output = model(x_train)
  loss = criterion(output, y_train)
  d_{y=0} = mean(loss[(y = 0, z = 0)]) - mean(loss[(y = 0, z = 1)])
  d_{y=1} = mean(loss[(y = 1, z = 0)]) - mean(loss[(y = 1, z = 1)])
  if |d_{y=0}| > |d_{y=1}| then
      λ1^(t) = λ1^(t-1) + α, if d_{y=0} > 0;  λ1^(t-1) - α, if d_{y=0} < 0;  λ1^(t-1), otherwise
  else
      λ2^(t) = λ2^(t-1) + α, if d_{y=1} > 0;  λ2^(t-1) - α, if d_{y=1} < 0;  λ2^(t-1), otherwise
  Output: next lambda λ^(t)

Algorithm 4: Adaptive adjustment of λ w.r.t. demographic parity
  Input: intermediate model, criterion, train data (x_train, z_train, y_train), previous lambda λ^(t-1), and FairBatch's learning rate α
  output = model(x_train)
  loss = criterion(output, 1)
  d_{y=0} = sum(loss[(y = 0, z = 0)]) / len(z = 0) - sum(loss[(y = 0, z = 1)]) / len(z = 1)
  d_{y=1} = sum(loss[(y = 1, z = 0)]) / len(z = 0) - sum(loss[(y = 1, z = 1)]) / len(z = 1)
  if |d_{y=0}| > |d_{y=1}| then
      λ1^(t) = λ1^(t-1) - α, if d_{y=0} > 0;  λ1^(t-1) + α, if d_{y=0} < 0;  λ1^(t-1), otherwise
  else
      λ2^(t) = λ2^(t-1) + α, if d_{y=1} > 0;  λ2^(t-1) - α, if d_{y=1} < 0;  λ2^(t-1), otherwise
  Output: next lambda λ^(t)

B APPENDIX: EXPERIMENTS

B.1 OTHER EXPERIMENTAL SETTINGS

We continue from Sec. 4 and provide more details on the experimental settings. We use the Adam optimizer for all trainings. We perform cross-validation on the training sets to find the best hyperparameters for each algorithm. We evaluate models on separate test sets, and the ratios of train to test data for the synthetic and real datasets are 2:1 and 4:1, respectively.

B.2 EQUALIZED ODDS RESULTS

We continue from Sec. 4.1 and show Table 4, which compares the performances of all the fair training techniques on the synthetic, COMPAS, and AdultCensus test sets w.r.t. equalized odds. The key observations are the same as in Table 1: overall, FairBatch has the most robust performance against the state of the art w.r.t. accuracy, fairness, and runtime.

Table 4: Performances on the synthetic, COMPAS, and AdultCensus test sets w.r.t. equalized odds (ED). The other settings are identical to those in Table 1.

| Method | Synthetic Acc. | Synthetic ED Disp. | Synthetic Epochs | COMPAS Acc. | COMPAS ED Disp. | COMPAS Epochs | AdultCensus Acc. | AdultCensus ED Disp. | AdultCensus Epochs |
|---|---|---|---|---|---|---|---|---|---|
| LR | .885±.000 | .115±.000 | 400 | .681±.002 | .239±.006 | 300 | .845±.001 | .056±.003 | 300 |
| Cutting | .859±.001 | .036±.004 | 650 | .665±.005 | .066±.018 | 400 | .802±.001 | .062±.005 | 600 |
| RW | .856±.000 | .038±.002 | 350 | .685±.000 | .137±.000 | 300 | .835±.001 | .134±.006 | 100 |
| LBC | .858±.001 | .026±.000 | 8800 | .673±.002 | .063±.005 | 9000 | .840±.002 | .027±.004 | 3300 |
| FC | .865±.000 | .030±.001 | 800 | .677±.004 | .101±.024 | 50 | .841±.001 | .038±.003 | 300 |
| AD | .857±.000 | .030±.001 | 1200 | .683±.000 | .082±.027 | 450 | .843±.002 | .033±.002 | 500 |
| AdaFair | .868±.001 | .029±.002 | 22400 | .675±.000 | .066±.002 | 9600 | .843±.001 | .038±.002 | 7800 |
| FairBatch | .856±.001 | .038±.002 | 400 | .682±.001 | .052±.014 | 100 | .843±.001 | .036±.002 | 500 |

B.3 EXTENSION OF ADAFAIR

We continue from Sec. 4 and provide more details on how we extend AdaFair, which already supports ED, to also support EO and DP. The extension to EO is straightforward, as EO is a relaxed version of ED where only the y = 1 class is considered when measuring disparity; hence, we only reweight examples in the y = 1 class as well. The extension to DP is done by giving more weight to the positive examples of a sensitive group z that suffers from a lower positive prediction rate than the other groups.

B.4 FAIRNESS CURVES
B APPENDIX EXPERIMENTS

B.1 OTHER EXPERIMENTAL SETTINGS

We continue from Sec. 4 and provide more details on the experimental settings. We use the Adam optimizer for all training runs. We perform cross-validation on the training sets to find the best hyperparameters for each algorithm. We evaluate models on separate test sets; the train-to-test ratios for the synthetic and real datasets are 2:1 and 4:1, respectively.

B.2 EQUALIZED ODDS RESULTS

We continue from Sec. 4.1 and show Table 4, which compares the performances of all the fair training techniques on the synthetic, COMPAS, and Adult Census test sets w.r.t. equalized odds. The key observations are the same as in Table 1: overall, Fair Batch shows the most robust performance relative to the state of the arts w.r.t. accuracy, fairness, and runtime.

Table 4: Performances on the synthetic, COMPAS, and Adult Census test sets w.r.t. equalized odds (ED). The other settings are identical to Table 1.

| Method | Acc. (Synthetic) | ED Disp. (Synthetic) | Epochs (Synthetic) | Acc. (COMPAS) | ED Disp. (COMPAS) | Epochs (COMPAS) | Acc. (Adult Census) | ED Disp. (Adult Census) | Epochs (Adult Census) |
|---|---|---|---|---|---|---|---|---|---|
| LR | .885 ±.000 | .115 ±.000 | 400 | .681 ±.002 | .239 ±.006 | 300 | .845 ±.001 | .056 ±.003 | 300 |
| Cutting | .859 ±.001 | .036 ±.004 | 650 | .665 ±.005 | .066 ±.018 | 400 | .802 ±.001 | .062 ±.005 | 600 |
| RW | .856 ±.000 | .038 ±.002 | 350 | .685 ±.000 | .137 ±.000 | 300 | .835 ±.001 | .134 ±.006 | 100 |
| LBC | .858 ±.001 | .026 ±.000 | 8800 | .673 ±.002 | .063 ±.005 | 9000 | .840 ±.002 | .027 ±.004 | 3300 |
| FC | .865 ±.000 | .030 ±.001 | 800 | .677 ±.004 | .101 ±.024 | 50 | .841 ±.001 | .038 ±.003 | 300 |
| AD | .857 ±.000 | .030 ±.001 | 1200 | .683 ±.000 | .082 ±.027 | 450 | .843 ±.002 | .033 ±.002 | 500 |
| Ada Fair | .868 ±.001 | .029 ±.002 | 22400 | .675 ±.000 | .066 ±.002 | 9600 | .843 ±.001 | .038 ±.002 | 7800 |
| Fair Batch | .856 ±.001 | .038 ±.002 | 400 | .682 ±.001 | .052 ±.014 | 100 | .843 ±.001 | .036 ±.002 | 500 |

B.3 EXTENSION OF ADAFAIR

We continue from Sec. 4 and provide more details on how we extend Ada Fair, which already supports ED, to also support EO and DP. The extension to EO is straightforward, as EO is a relaxed version of ED in which only the y = 1 class is considered when measuring disparity. Hence, we reweight only the examples in the y = 1 class. The extension to DP is done by giving more weight to the positive examples of the sensitive group z = z* that suffers from a lower positive prediction rate than the other groups.

B.4 FAIRNESS CURVES

We continue from Sec. 4.1 and show in Figures 3 and 4 the EO and DP disparity curves against the number of epochs for each fairness technique on the synthetic dataset. We also compare the curves of all fairness techniques directly in one graph, as shown in Figure 5. Since LBC and Ada Fair require more than 10x as many epochs as the other methods, we only show their first 1,000 epochs. As the figures show, Fair Batch is one of the fastest methods to converge to low EO or DP disparities.

Figure 3: EO disparity curves of the algorithms on the synthetic dataset (one panel per method, e.g., Cutting, Ada Fair, Fair Batch; x-axis: epochs).
Figure 4: DP disparity curves of the algorithms on the synthetic dataset (one panel per method, e.g., Cutting, Ada Fair, Fair Batch; x-axis: epochs).

B.5 TRADE-OFF CURVES OF FAIRBATCH

We continue from Sec. 4.1 and show in Fig. 6 the accuracy-fairness trade-off curves of Fair Batch for EO and DP on the synthetic dataset. Fair Batch can be tuned by making it less sensitive to disparity. In Algorithms 2 and 4, notice that the λ parameters are updated whenever there is any disparity among sensitive groups. We modify this logic so that the λ parameters are updated only if the disparity is above some threshold T. The trade-off curves in Fig. 6 are thus generated by adjusting T. For both EO and DP, we observe a clear trade-off between accuracy and disparity.

B.6 WALL CLOCK TIMES

We continue from Sec. 4.1 and show in Table 5 the wall clock times (in seconds) of the experiments in Table 1, where we compare Fair Batch against all the fairness techniques on the synthetic, COMPAS, and Adult Census datasets. As the table shows, each runtime is roughly proportional to the number of epochs reported in Table 1. When comparing the runtimes of individual batches, a Fair Batch batch takes about 1.5x longer to run than an LR batch.

Table 5: Wall clock times (in seconds) of the experiments in Table 1 using the same settings.

| Dataset | LR | Cutting | RW | LBC | FC | AD | Ada Fair | Fair Batch |
|---|---|---|---|---|---|---|---|---|
| Synthetic | 5.71 | 5.67 | 17.24 | 208.47 | 16.05 | 3.97 | 294.31 | 5.25 |
| COMPAS | 6.07 | 3.34 | 7.48 | 94.10 | 2.76 | 6.93 | 215.39 | 3.00 |
| Adult Census | 22.96 | 7.70 | 10.02 | 558.31 | 28.71 | 31.76 | 791.58 | 46.79 |

Figure 5: Epochs-fairness disparity curves of all algorithms together. Panels: (a) EO disparity curves of Fair Batch and the baselines; (b) DP disparity curves of Fair Batch and the baselines (legend: LR, Cutting, Reweighing, LBC, FC, AD, Ada Fair, Fair Batch; x-axis: epochs).
Figure 6: Accuracy-fairness disparity trade-off curves of Fair Batch on the synthetic dataset. Panels: (a) accuracy-EO disparity trade-off curve; (b) accuracy-DP disparity trade-off curve.

B.7 COMPARISON WITH ADAFAIR

We continue from Sec. 4.1 and compare the class weights of Fair Batch and the Ada Fair algorithm. For Ada Fair, the class weights are calculated by adding up all example weights within each class. Fig. 7 shows the weight changes of each algorithm. Overall, the trends of the weights are similar. Again, the advantage of Fair Batch is that it runs within a single model training instead of requiring multiple model trainings as in Ada Fair.
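As a side note on how the curves in Fig. 7 can be produced, the per-(y, z) class weights can be aggregated from example weights as in the sketch below. This is our own illustration (including the normalization by the total weight), not code from either implementation.

```python
import numpy as np

def class_weights(example_weights, y, z):
    """Aggregate per-example weights into the four (y, z) class weights plotted in Fig. 7."""
    total = example_weights.sum()
    return {
        (yy, zz): example_weights[(y == yy) & (z == zz)].sum() / total
        for yy in (0, 1) for zz in (0, 1)
    }
```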
Figure 7: Comparison of the weight changes of Ada Fair and Fair Batch w.r.t. equalized odds on the synthetic dataset. Panels: (a) number of models vs. class weights for Ada Fair; (b) epochs vs. class weights for Fair Batch (legend: the four (y, z) classes).

B.8 EXTENSION OF FAIRBATCH TO MULTI-CLASS CLASSIFICATION

We continue from Sec. 4.2 and explain how Fair Batch can be extended to support multi-class classification by adjusting more λ parameters. For example, for ED, the label attribute has n classes, and each class is associated with m λs. At each epoch, we adjust the m λs of the class y = i* that has the highest ED disparity at that epoch.

B.9 FAIRBATCH WITH IMPORTANCE SAMPLING

We continue from Sec. 4.3 and show Fig. 8, which plots the convergence of Fair Batch when merged with loss-based weighting batch selection. As the figure shows, Fair Batch needs about 50 fewer epochs to converge to low disparities than when loss-based weighting is not used.

Figure 8: Fairness curves of Fair Batch on the synthetic dataset, with and without loss-based weighting (Loshchilov and Hutter, 2016). Panels: (a) EO disparity curve; (b) DP disparity curve.
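To make this combination concrete, one possible way to merge the two samplers is sketched below: the per-group batch sizes implied by Fair Batch's current λs determine how many examples each group contributes to a batch, and within each group examples are drawn with probability increasing in their current loss, as a simple stand-in for the rank-based weighting of Loshchilov and Hutter (2016). Every name here (sample_fair_loss_batch, group_sizes, and so on) is hypothetical; this is not the paper's released implementation.

```python
import numpy as np

def sample_fair_loss_batch(losses, group_ids, group_sizes, rng=None):
    """Draw one batch: per-group counts from Fair Batch, loss-weighted sampling within groups."""
    rng = rng or np.random.default_rng()
    batch = []
    for g, n_g in group_sizes.items():        # g: integer group id, n_g: count from current lambdas
        idx = np.where(group_ids == g)[0]
        p = losses[idx] + 1e-12               # favor high-loss examples within the group
        p = p / p.sum()
        batch.extend(rng.choice(idx, size=n_g, replace=False, p=p))
    return np.asarray(batch)
```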