# Learning to Teach with Dynamic Loss Functions

Lijun Wu¹, Fei Tian², Yingce Xia², Yang Fan³, Tao Qin², Jianhuang Lai¹, Tie-Yan Liu²

¹Sun Yat-sen University, Guangzhou, China
²Microsoft Research, Beijing, China
³University of Science and Technology of China, Hefei, China

¹wulijun3@mail2.sysu.edu.cn, stsljh@mail.sysu.edu.cn
²{fetia, yingce.xia, taoqin, tie-yan.liu}@microsoft.com
³fyabc@mail.ustc.edu.cn

Abstract

Teaching is critical to human society: it is through teaching that prospective students are educated and human civilization is inherited and advanced. A good teacher not only provides his/her students with qualified teaching materials (e.g., textbooks), but also sets appropriate learning objectives (e.g., course projects and exams) that take the student's situation into account. In artificial intelligence, if we treat machine learning models as students, the loss functions they optimize are natural counterparts of the learning objectives set by a teacher. In this work, we explore imitating human teaching behavior by dynamically and automatically outputting appropriate loss functions to train machine learning models. Unlike typical learning settings, in which the loss function of a machine learning model is predefined and fixed, in our framework the loss function of a machine learning model (which we call the student) is defined by another machine learning model (which we call the teacher). The ultimate goal of the teacher model is to cultivate a student with better performance, measured on a development dataset. To that end, much like a human teacher, the teacher, itself a parametric model, dynamically outputs different loss functions for its student model to optimize at different training stages.
We develop an efficient learning method for the teacher model that enables gradient-based optimization, avoiding less efficient approaches such as policy optimization. We name our method learning to teach with dynamic loss functions (L2T-DLF for short). Extensive experiments on real-world tasks, including image classification and neural machine translation, demonstrate that our method significantly improves the quality of various student models.

The work was done when the first and fourth authors were interns at Microsoft Research Asia. The first two authors contributed equally to this work.

32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montréal, Canada.

1 Introduction

Teaching, which aims to help students learn new knowledge or skills effectively and efficiently, is important to the advancement of modern human civilization. In human society, the rapid growth of qualified students relies not only on their intrinsic learning capability but also, even more importantly, on substantial guidance from their teachers. The duties of teachers cover a wide spectrum: defining the scope of learning (e.g., the knowledge and skills that students are expected to demonstrate by the end of a course), choosing appropriate instructional materials (e.g., textbooks), and assessing the progress of students (e.g., through course projects or exams). Effective teaching involves progressively and dynamically refining the teaching strategy based on reflection and feedback from students.

Recently, the concept of teaching has been introduced into artificial intelligence (AI) to improve the learning process of machine learning models. Currently, teaching in AI mainly focuses on training data selection. For example, machine teaching [56, 34, 35] aims to identify the smallest training set that is capable of producing the optimal learner model.
The very recent work learning to teach (L2T for short) [13] demonstrates how to automatically design teacher models for a better machine learning process. While L2T can conceptually cover different aspects of teaching in AI, [13] only studies the problem of training data teaching. In this work, inspired by learning to teach, we study loss function teaching in a formal and concrete manner for the first time. The main motivation of our work is a natural analogy between loss functions in machine learning and exams in the education of human students: appropriate exams reflect the progress of students and urge them to improve accordingly, while the loss values output by a loss function evaluate the performance of the current machine learning model and set the optimization direction for the model parameters.

In our loss function teaching framework, a teacher model is responsible for outputting loss functions for the student model (i.e., the ordinary machine learning model that solves a task) to minimize. Inspired by human teaching, we design the teacher model according to the following principles. First, just as exams in human education vary in difficulty with the progress of the student, the loss function set by the teacher model should be dynamic, i.e., it should adapt to the different phases of the student model's training process. To achieve this, we require the teacher model to take the status of the student model into consideration when setting the loss function, and to change the loss function dynamically as the student model grows. This process is shown in Fig. 1. Second, the teacher model should be able to improve itself, just as a human teacher can accumulate knowledge and improve his/her teaching skills through more teaching practice.
To achieve that, we assume the loss function takes the form of a neural network whose coefficients are determined by a parametric teacher model, which is also a neural network. The parameters of the teacher model can be optimized automatically during the teaching process. Through optimization, the teacher keeps improving its teaching model and, consequently, the quality of the loss functions it outputs. We name our method learning to teach with dynamic loss functions (L2T-DLF).

Figure 1: The student model is trained by minimizing the dynamic loss functions taught by the teacher model (yellow curve). The bottom black plane represents the parameter space of the student model, and the four colored mesh surfaces denote different loss functions output by the teacher model at different phases of student model training.

The eventual goal of the teacher model is that its output, serving as the loss function of the student model, maximizes the long-term performance of the student, measured on a stand-alone development dataset via a task-specific objective such as 0-1 accuracy in classification or BLEU score in sequence prediction [41]. Learning a good teacher model is not trivial: on the one hand, the task-specific objective is usually non-smooth w.r.t. the student model outputs; on the other hand, the final evaluation of the student model takes place on the dev set, which is disjoint from the training dataset where the teaching process actually happens. We design an efficient gradient-based algorithm to optimize teacher models. Specifically, to tackle the first challenge, we smooth the task-specific measure into its expected version, where the expectation is taken over the direct output of the student model.
To address the second challenge, inspired by Reverse-Mode Differentiation (RMD) [6, 7, 38], we reverse the stochastic gradient descent training process of the student model and obtain the derivatives w.r.t. the parameters of the teacher model by chaining backwards the error signals incurred on the development dataset. We demonstrate the effectiveness of L2T-DLF on various real-world tasks, including image classification and neural machine translation, with different student models such as multi-layer perceptron networks, convolutional neural networks, and sequence-to-sequence models with attention. The improvements clearly demonstrate the effectiveness of the new loss functions learnt by L2T-DLF.

2 Related Work

The study of teaching for AI, inspired by the human teaching process, has a long history [1, 17]. Most recent efforts on teaching focus on the level of training data selection. For example, the machine teaching literature [56, 34, 35] aims to build the smallest training set that yields a pre-given optimal student model. A teaching strategy is designed in [18, 19] to iteratively select unlabeled data to label within the context of multi-label propagation, in a manner similar to curriculum learning [8, 27]. Furthermore, there is research on pedagogical teaching inspired by cognitive science [44, 23, 39], in which a teacher module is responsible for providing informative examples to the learner so that it can understand a concept rapidly. The recent work learning to teach (L2T) [13] offers a more comprehensive view of teaching for AI, including training data teaching, loss function teaching, and hypothesis space teaching. Furthermore, L2T breaks the strong assumption, adopted by previous machine teaching literature [56, 35], that an optimal off-the-shelf student model exists.
Our work belongs to the general framework of L2T, with a particular focus on a thorough treatment of loss function teaching, including the detailed problem setup and an efficient solution for dynamically setting loss functions for training machine learning models. Our work, and L2T more generally, leverages automatic techniques to bypass human prior knowledge as much as possible, in line with the principles of learning to learn and meta learning [43, 50, 2, 57, 37, 29, 10, 14]. From a technical point of view, what distinguishes our work from others is that: 1) we leverage a gradient-based optimization method rather than reinforcement learning [57, 13]; 2) we need to handle the difficulty that the error information cannot be directly back-propagated from the loss function, since we aim at discovering the best loss function for the machine learning models. We design an algorithm based on Reverse-Mode Differentiation (RMD) [7, 38, 15] to tackle this difficulty.

Specially designed loss functions play important roles in boosting performance on real-world tasks, either by approximating a non-smooth task-specific objective, such as 0-1 accuracy in classification [40], NDCG in ranking [49], BLEU in machine translation [45, 3], and MAP in object detection [22, 46], or by easing the optimization process of the student model, e.g., overcoming the difficulties brought by data imbalance [30, 32] and numerous local optima [20]. L2T-DLF differs from prior works in that: 1) the loss functions are learned automatically, cover a large space, and do not require a heuristic understanding of the task-specific objective or the optimization process; 2) the loss function evolves dynamically during training, leading to a more coherent interaction between the loss and the student model.

In the following, we introduce the details of L2T-DLF, including the student model and the teacher model, as well as the training strategy for optimizing the teacher model.
3.1 Student Model

For a task of interest, we denote its input space and output space as $\mathcal{X}$ and $\mathcal{Y}$, respectively. The student model for this task is denoted as $f_\omega: \mathcal{X} \to \mathcal{Y}$, with $\omega$ as its weight parameters. Training the student model $f_\omega$ is an optimization process that discovers a good weight parameter $\omega^*$ within a hypothesis space $\Omega$ by minimizing a loss function $l$ on the training data $D_{train} = \{(x_i, y_i)\}_{i=1}^{M}$, which contains $M$ data points. Specifically, $\omega^*$ is obtained by solving $\min_{\omega \in \Omega} \sum_{(x,y) \in D_{train}} l(f_\omega(x), y)$. For convenience, we define the notation $L(f_\omega, D) = \sum_{(x,y) \in D} l(f_\omega(x), y)$, where $D$ is a dataset, and we also refer to $L$ as the loss function when the context is clear. The learnt student model $f_{\omega^*}$ is then evaluated on a test set $D_{test} = \{(x_i, y_i)\}_{i=1}^{N}$ to obtain a score $\mathcal{M}(f_{\omega^*}, D_{test}) = \sum_{(x,y) \in D_{test}} m(f_{\omega^*}(x), y)$ as its performance. Here the task-specific objective $m(y_1, y_2)$ measures the similarity between two output candidates $y_1$ and $y_2$.

The loss function $l(\hat{y}, y)$, taking the model prediction $\hat{y} = f_\omega(x)$ and the ground truth $y$ as inputs, acts as a surrogate of $m$ to evaluate the student model $f_\omega$ during its training process, just as exams do in real-world human teaching. We assume $l(\hat{y}, y)$ is a neural network with some coefficients $\Phi$, denoted as $l_\Phi(\hat{y}, y)$. It can be a simple linear model or a deep neural network (concrete examples are provided in Section 4.1 and Section 4.2). With such a loss function $l_\Phi(\hat{y}, y)$ (and the induced notation $L_\Phi$), the student model is updated sequentially by minimizing the output value of $l_\Phi$, for example with stochastic gradient descent (SGD):

$$\omega_{t+1} = \omega_t - \eta_t \frac{\partial L_\Phi(f_{\omega_t}, D^t_{train})}{\partial \omega_t}, \quad t \in \{1, 2, \dots, T\},$$

where $D^t_{train} \subseteq D_{train}$, $\omega_t$, and $\eta_t$ are, respectively, the mini-batch of training data, the student model weight parameter, and the learning rate at the $t$-th timestep. For ease of statement, we simply set $\omega^* = \omega_T$.
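To make the notation $L_\Phi$, $\mathcal{M}$ and the SGD update above concrete, the following is a minimal sketch of our own, not the paper's implementation. It assumes a 1-D logistic student $f_\omega(x) = \sigma(\omega x)$ for binary labels, and a hypothetical choice of $l_\Phi$: a class-weighted cross-entropy with coefficients $\Phi = (\Phi_0, \Phi_1)$, i.e., a simple linear model over log-probabilities, where $\Phi = (1, 1)$ recovers the standard cross-entropy. The score uses the 0-1 accuracy as $m$. The toy data, learning rate, mini-batch schedule, and all function names are illustrative assumptions.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def loss_phi(phi, p_hat, y, eps=1e-12):
    """Hypothetical l_Phi: class-weighted cross-entropy; phi = (1, 1) gives plain CE."""
    return -(phi[1] * y * math.log(p_hat + eps)
             + phi[0] * (1 - y) * math.log(1.0 - p_hat + eps))

def loss_phi_grad(phi, w, x, y):
    """d l_Phi(sigma(w x), y) / d w for the 1-D logistic student."""
    p = sigmoid(w * x)
    return (-phi[1] * y * (1.0 - p) + phi[0] * (1 - y) * p) * x

def loss_L(phi, w, data):
    """L_Phi(f_w, D): sum of the per-example loss over a dataset."""
    return sum(loss_phi(phi, sigmoid(w * x), y) for x, y in data)

def measure_M(w, data):
    """M(f_w, D) with m = 0-1 accuracy (predict class 1 iff p >= 0.5)."""
    return sum(1.0 if (1 if sigmoid(w * x) >= 0.5 else 0) == y else 0.0
               for x, y in data)

def sgd(phi, data, w0=0.0, lr=0.5, T=50, batch_size=2):
    """w_{t+1} = w_t - eta * dL_Phi(f_{w_t}, D^t_train)/dw_t, cycling over mini-batches."""
    w = w0
    for t in range(T):
        start = (t * batch_size) % len(data)
        batch = [data[(start + i) % len(data)] for i in range(batch_size)]
        w -= lr * sum(loss_phi_grad(phi, w, x, y) for x, y in batch)
    return w  # omega* = omega_T

# Hypothetical toy data: positive inputs belong to class 1.
TRAIN = [(1.0, 1), (2.0, 1), (-1.0, 0), (-1.5, 0)]
TEST = [(0.5, 1), (-0.7, 0), (3.0, 1)]
w_star = sgd((1.0, 1.0), TRAIN)
print(w_star, measure_M(w_star, TEST))
```

Changing `phi` changes which errors the student is pushed hardest to fix, while the SGD loop itself stays the same; this separation of "what is optimized" from "how it is optimized" is what the teacher model exploits.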
3.2 Teacher Model

A teacher model is responsible for setting the proper loss function $l$ for the student model by outputting appropriate loss function coefficients $\Phi$. To cater for the different statuses of student model training, we ask the teacher model to output a different loss function $l_t$ at each training step $t$. To achieve that, the status of the student model at timestep $t$ is represented by a state vector $s_t$, which contains, for example, the current training/dev accuracy and the iteration number. The teacher model, denoted as $\mu$, takes $s_t$ as input to compute the loss function coefficients $\Phi_t$ at the $t$-th timestep as $\Phi_t = \mu_\theta(s_t)$, where $\theta$ denotes the parameters of the teacher model. We provide some examples of $\mu_\theta$ in Section 4.1 and Section 4.2. The actual loss function for the student model is then $l_t = l_{\Phi_t}$, and the learning process of the student model becomes:

$$\omega_{t+1} = \omega_t - \eta_t \frac{\partial L_{\Phi_t}(f_{\omega_t}, D^t_{train})}{\partial \omega_t} = \omega_t - \eta_t \frac{\partial L_{\mu_\theta(s_t)}(f_{\omega_t}, D^t_{train})}{\partial \omega_t}. \qquad (1)$$

This sequential procedure of obtaining $f_{\omega^*}$ (i.e., $f_{\omega_T}$) is the learning process of the student model with training data $D_{train}$ and the loss functions provided by the teacher model $\mu_\theta$; we use an abstract operator $\mathcal{F}$ to denote it: $f_{\omega^*} = \mathcal{F}(D_{train}, \mu_\theta)$.

Just as in the training and testing setup of typical machine learning scenarios, the teacher model follows a two-phase setup. Specifically, in the training process of the teacher model, just as qualified human teachers are good at improving the quality of exams, the teacher model in L2T-DLF refines the loss function it sets up by optimizing its own parameters $\theta$. The ultimate goal of the teacher model is to maximize the performance of the induced student model on a stand-alone development dataset $D_{dev}$:

$$\max_\theta \mathcal{M}(f_{\omega^*}, D_{dev}) = \max_\theta \mathcal{M}(\mathcal{F}(D_{train}, \mu_\theta), D_{dev}). \qquad (2)$$

We introduce the detailed training process (i.e., how to efficiently optimize Eqn. (2)) in Section 3.3.
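The operator $\mathcal{F}(D_{train}, \mu_\theta)$, i.e., Eqn. (1) with a fresh $\Phi_t = \mu_\theta(s_t)$ at every step, can be sketched as below. This is an illustrative toy under our own assumptions, not the teacher used in the experiments: the student is a 1-D logistic model, the loss is the class-weighted cross-entropy $l_{\Phi_t}$ with two coefficients, the state is the hypothetical $s_t = (1, t/T, \text{training accuracy})$, and the hypothetical teacher `teacher` is one linear unit per coefficient followed by `exp` to keep $\Phi_t$ positive.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def teacher(theta, state):
    """Hypothetical mu_theta: one linear unit per loss coefficient; exp keeps Phi_t > 0."""
    return [math.exp(sum(a * s for a, s in zip(row, state))) for row in theta]

def grad_weighted_ce(phi, w, x, y):
    """d/dw of the class-weighted cross-entropy l_Phi(sigma(w x), y)."""
    p = sigmoid(w * x)
    return (-phi[1] * y * (1.0 - p) + phi[0] * (1 - y) * p) * x

def accuracy(w, data):
    """Fraction of examples classified correctly (predict 1 iff p >= 0.5)."""
    return sum((1 if sigmoid(w * x) >= 0.5 else 0) == y for x, y in data) / len(data)

def F(train, theta, w0=0.0, lr=0.5, T=30, batch_size=2):
    """Abstract operator F(D_train, mu_theta): run Eqn. (1) with Phi_t = mu_theta(s_t)."""
    w, phis = w0, []
    for t in range(T):
        # Hypothetical state s_t = (bias, normalized iteration number, training accuracy).
        state = (1.0, t / T, accuracy(w, train))
        phi_t = teacher(theta, state)
        phis.append(phi_t)
        start = (t * batch_size) % len(train)
        batch = [train[(start + i) % len(train)] for i in range(batch_size)]
        w -= lr * sum(grad_weighted_ce(phi_t, w, x, y) for x, y in batch)
    return w, phis  # f_{omega_T} and the sequence of loss coefficients it was taught with

TRAIN = [(1.0, 1), (2.0, 1), (-1.0, 0), (-1.5, 0)]
DEV = [(0.5, 1), (-0.7, 0)]
# With this theta, Phi_t[0] decays and Phi_t[1] grows as training proceeds.
theta = [[0.0, -1.0, 0.0], [0.0, 1.0, 0.0]]
w_T, phis = F(TRAIN, theta)
print(phis[0], phis[-1], accuracy(w_T, DEV))
```

Setting every entry of `theta` to zero makes the teacher output $\Phi_t = (1, 1)$ at every step, which reduces Eqn. (1) to ordinary training with a fixed cross-entropy loss; optimizing `theta` is exactly the problem posed in Eqn. (2).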
In the testing process of the teacher model, $\theta$ is fixed and the student model $f_\omega$ is updated under the guidance of the teacher model $\mu_\theta$, as specified in Eqn. (1).

3.3 Training Process of Teacher Model

There are two challenges in optimizing the teacher model: 1) the evaluation measure $m$ is typically non-smooth and non-differentiable w.r.t. the parameters of the student model; 2) the error is incurred on the dev set, while the teacher model takes effect in the training phase.

We use a continuous relaxation of $m$ to tackle the first challenge. The main idea is to inject randomness into $m$ to form an approximated version $\tilde{m}$, where the randomness comes from the student model [49]. Since many student models output probability distributions over $\mathcal{Y}$, the randomness comes naturally from the direct outputs of $f_\omega$. Specifically, to approximate the performance of $f_\omega$ on a test data sample $(x, y)$, we have $\tilde{m}(f_\omega(x), y) = \sum_{y' \in \mathcal{Y}} m(y', y)\, p_\omega(y'|x)$, where $p_\omega(y'|x)$ is the probability of predicting $y'$ given $x$ using $f_\omega$. The gradient w.r.t. $\omega$ is then easy to obtain via $\frac{\partial \tilde{m}(f_\omega(x), y)}{\partial \omega} = \sum_{y' \in \mathcal{Y}} m(y', y) \frac{\partial p_\omega(y'|x)}{\partial \omega}$. We further introduce the notation $\tilde{\mathcal{M}}(f_\omega, D_{dev}) = \sum_{(x,y) \in D_{dev}} \tilde{m}(f_\omega(x), y)$, which approximates the objective of the teacher model $\mathcal{M}(f_{\omega_T}, D_{dev})$.

We use Reverse-Mode Differentiation (RMD) [6, 7, 38] to bridge the gap between the training data and the development data. To better illustrate the RMD process, we can view the sequential process in Eqn. (1) as a special feed-forward process of a deep neural network in which each $t$ corresponds to one layer, and RMD corresponds to the backpropagation process that loops over the SGD process backwards from $T$ to $1$. Specifically, denote by $d\theta$ the gradient of $\tilde{\mathcal{M}}(f_{\omega_T}, D_{dev})$ w.r.t. the teacher model parameters $\theta$, with initial value $d\theta = 0$. On the dev dataset $D_{dev}$, the gradient of $\tilde{\mathcal{M}}(f_{\omega}, D_{dev})$ w.r.t. the parameters of the student model $\omega_T$ is calculated as

$$d\omega_T = \frac{\partial \tilde{\mathcal{M}}(f_{\omega_T}, D_{dev})}{\partial \omega_T} = \sum_{(x,y) \in D_{dev}} \frac{\partial \tilde{m}(f_{\omega_T}(x), y)}{\partial \omega_T}. \qquad (3)$$

Then, looping backwards from $T$ and corresponding to Eqn. (1), at each step $t \in \{T-1, \dots, 1\}$ we have

$$d\omega_t = \frac{\partial \tilde{\mathcal{M}}(f_{\omega_T}, D_{dev})}{\partial \omega_t} = d\omega_{t+1} - \eta_t \frac{\partial^2 L_{\mu_\theta(s_t)}(f_{\omega_t}, D^t_{train})}{\partial \omega_t^2}\, d\omega_{t+1}. \qquad (4)$$

At the same time, the gradient of $\tilde{\mathcal{M}}$ w.r.t. $\theta$ is accumulated at this time step as:

$$d\theta \leftarrow d\theta - \eta_t \frac{\partial^2 L_{\mu_\theta(s_t)}(f_{\omega_t}, D^t_{train})}{\partial \theta\, \partial \omega_t}\, d\omega_{t+1}. \qquad (5)$$

We leave the detailed derivations of Eqns. (4) and (5) to the Appendix. It is also worth noting that computing $d\omega_t$ and $d\theta$ involves Hessian-vector products, which can be computed efficiently via $\frac{\partial^2 g}{\partial x\, \partial y} v = \frac{\partial \left(\frac{\partial g}{\partial y} v\right)}{\partial x}$ without explicitly calculating the Hessian matrix. Going backwards from $t = T$ to $t = 1$, we obtain $d\theta$, and $\theta$ is then updated using any gradient-based optimization algorithm such as momentum SGD; this forms one optimization step for $\theta$, which we call a teacher optimization step. By iterating teacher optimization steps, we obtain the final teacher model. The details are listed in Algorithm 1.

Algorithm 1: Training Teacher Model $\mu_\theta$
Input: continuous relaxation $\tilde{m}$; initial value of $\theta$.
while teacher model parameter $\theta$ not converged do ▷ one teacher optimization step
  Randomly initialize student model parameter $\omega_0$.
  for each time step $t = 0, \dots, T-1$ do ▷ teach the student model
    Conduct a student model training step via Eqn. (1).
  end for
  $d\theta = 0$.
  Compute $d\omega_T$ via Eqn. (3).
  for each time step $t = T-1, \dots, 0$ do ▷ reversely calculate the gradient $d\theta$
    Update $d\theta$ as in Eqn. (5).
    Compute $d\omega_t$ as in Eqn. (4).
  end for
  Update $\theta$ using $d\theta$ via a gradient-based optimization algorithm.
end while
Output: the final teacher model $\mu_\theta$.

3.4 Discussion

Another possible way to optimize the teacher model is through deep reinforcement learning. By treating the teacher model as a policy that outputs a continuous action (i.e., the loss function), one can leverage a continuous control algorithm such as DDPG [31] to optimize the teacher model. However, reinforcement learning algorithms, including Q-learning-based ones such as DDPG, are sample-inefficient, likely requiring a huge number of sampled trajectories to approximate the reward with a critic network.
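For contrast with such sample-based approaches, the gradient-based route of Section 3.3 can be made concrete on a toy problem: the sketch below runs the forward pass of Eqn. (1), the relaxed dev objective $\tilde{\mathcal{M}}$, the reverse recursions of Eqns. (3)-(5), and one teacher optimization step of Algorithm 1. It is a minimal sketch under strong simplifying assumptions of our own, not the authors' implementation: the student is a 1-D logistic model with scalar weight $w$; $l_{\Phi_t}$ is the cross-entropy scaled by a scalar coefficient $\Phi_t$; the hypothetical teacher is $\Phi_t = \theta_0 + \theta_1 t/T$, so $s_t$ does not depend on $w$; every step uses the full training set; $\omega_0 = 0$ is fixed rather than random; and the whole trajectory $w_0, \dots, w_T$ is stored instead of being reconstructed in reverse. Because $w$ is a scalar, the Hessian-vector products reduce to ordinary products.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

# Hypothetical toy data: 1-D inputs, binary labels, scalar student weight w.
TRAIN = [(1.0, 1), (2.0, 1), (-1.0, 0), (-0.5, 1), (0.5, 0)]
DEV = [(1.5, 1), (-2.0, 0), (0.3, 1)]
ETA, T = 0.1, 20

def phi(theta, t):
    """Hypothetical teacher: Phi_t = mu_theta(s_t) = theta0 + theta1 * t / T."""
    return theta[0] + theta[1] * t / T

def dphi_dtheta(t):
    return (1.0, t / T)

def ce_grad(w):
    """Sum over TRAIN of dCE/dw = (sigma(wx) - y) x; also d^2 L_t / (dPhi_t dw)."""
    return sum((sigmoid(w * x) - y) * x for x, y in TRAIN)

def ce_hess(w):
    """Sum over TRAIN of d^2 CE / dw^2 = sigma(wx) (1 - sigma(wx)) x^2."""
    return sum(sigmoid(w * x) * (1.0 - sigmoid(w * x)) * x * x for x, _ in TRAIN)

def relaxed_acc(w):
    """M~ = sum over DEV of p_w(y|x), the expected 0-1 accuracy; p_w(y|x) = sigma((2y-1) w x)."""
    return sum(sigmoid((2 * y - 1) * w * x) for x, y in DEV)

def relaxed_acc_grad(w):
    """dM~/dw, i.e. d omega_T of Eqn. (3)."""
    total = 0.0
    for x, y in DEV:
        s = 2 * y - 1
        p = sigmoid(s * w * x)
        total += p * (1.0 - p) * s * x
    return total

def train_student(theta, w0=0.0):
    """Forward pass: Eqn. (1) with full-batch steps, storing w_0, ..., w_T."""
    ws = [w0]
    for t in range(T):
        ws.append(ws[-1] - ETA * phi(theta, t) * ce_grad(ws[-1]))
    return ws

def teacher_gradient(theta, w0=0.0):
    """Reverse pass of Algorithm 1: returns d theta = dM~(w_T) / d theta."""
    ws = train_student(theta, w0)
    d_w = relaxed_acc_grad(ws[T])                          # Eqn. (3)
    d_theta = [0.0, 0.0]
    for t in range(T - 1, -1, -1):
        mixed = ce_grad(ws[t])
        for k, dk in enumerate(dphi_dtheta(t)):
            d_theta[k] -= ETA * dk * mixed * d_w           # Eqn. (5), uses d w_{t+1}
        d_w -= ETA * phi(theta, t) * ce_hess(ws[t]) * d_w  # Eqn. (4)
    return d_theta

def teacher_step(theta, lr):
    """One teacher optimization step: gradient ascent, since M~ is maximized."""
    g = teacher_gradient(theta)
    return [th + lr * gk for th, gk in zip(theta, g)]

theta = [1.0, 0.5]
print(teacher_gradient(theta), relaxed_acc(train_student(theta)[-1]))
```

Comparing `teacher_gradient` with central finite differences of $\tilde{\mathcal{M}}(f_{\omega_T}, D_{dev})$ is a quick way to check the signs in Eqns. (4) and (5). Each teacher step here costs one forward and one backward sweep over the $T$ student updates, with no sampled trajectories; for vector-valued $\omega$, the two products in the loop would become Hessian-vector products computed with a second differentiation pass.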
Considering that training the student model is typically costly, we resort to gradient-based optimization algorithms instead. Furthermore, there are similarities between L2T-DLF and the actor-critic (AC) method [5, 48] in reinforcement learning (RL), in which a critic (corresponding to the parametric loss function) guides the optimization of an actor (corresponding to the student model). Apart from the difference in application domain (supervised learning versus RL), the design principles of L2T-DLF and AC differ. For AC, by treating the student model as the actor, the student model output (e.g., $f_{\omega_t}(x_t)$) is essentially the action at timestep $t$, which is fed into the critic to output an approximation of the future reward (e.g., dev set accuracy). This is typically difficult because: 1) the student model output (i.e., the action) at a particular step $t$ is only weakly related to the final dev performance, so optimizing this action under the guidance of the critic network is largely meaningless; 2) approximating the future reward is hard given that the performance measure is highly non-smooth. In comparison, L2T-DLF is more general in that, at each timestep: 1) the teacher model considers the overall status of the student model when optimizing its parameters, rather than the instant action (i.e., the direct output); 2) the teacher model outputs a loss function with the goal of maximizing, rather than approximating, the future reward. In that sense, L2T-DLF is more appropriate for real-world applications.

Figure 2: Left, (a) loss function: the bilinear neural network specifying the loss function $l_{\Phi_t}(p_\omega, y) = \sigma(-\mathbf{y}^\top \Phi_t \log p_\omega)$. Right, (b) teacher model: the teacher model outputting $\Phi_t$ via an attention mechanism: $\Phi_t = \mu_\theta(s_t) = W\,\mathrm{softmax}(V s_t)$.

4 Experiments

We conduct comprehensive empirical verification of the proposed L2T-DLF in automatically discovering the most appropriate loss functions for student model training.
The tasks in our experiments come from two domains: image classification and neural machine translation.

4.1 Image Classification

The evaluation measure $m$ here is the 0-1 accuracy: $m(y_1, y_2) = 1_{y_1 = y_2}$, where $1$ is the 0-1 indicator function. The student model $f_\omega$ can be a logistic classifier specifying a softmax distribution $p_\omega(y|x) = \exp(w_y^\top x + b_y) / \sum_{y' \in \mathcal{Y}} \exp(w_{y'}^\top x + b_{y'})$ with $\omega = \{w_{y'}, b_{y'}\}_{y' \in \mathcal{Y}}$. The class label is predicted as $\hat{y} = \arg\max_{y' \in \mathcal{Y}} p_\omega(y'|x)$ given input data $x$. Instead of imposing the loss on $\hat{y}$ and the ground truth $y$, for the sake of efficient optimization $l$ typically takes the direct model output $p_\omega$ and $y$ as inputs. For example, the most widely adopted loss function $l$ is the cross-entropy loss $l(p_\omega, y) = -\log p_\omega(y|x)$, which can be rewritten in vector form as $l(p_\omega, y) = -\mathbf{y}^\top \log p_\omega$, where $\mathbf{y} \in \{0, 1\}^{|\mathcal{Y}|}$ is the one-hot representation of the true label $y$, i.e., $\mathbf{y}_j = 1_{j = y}, \forall j \in \mathcal{Y}$, $\mathbf{y}^\top$ is the transpose of $\mathbf{y}$, and $p_\omega \in \mathbb{R}^{|\mathcal{Y}|}$ is the vector of class probabilities output by $f_\omega$. Generalizing the cross-entropy loss, we set the loss function coefficients $\Phi$ as a matrix interacting between $\log p_\omega$ and $\mathbf{y}$, which turns the loss function at the $t$-th timestep into $l_{\Phi_t}(p_\omega, y) = \sigma(-\mathbf{y}^\top \Phi_t \log p_\omega)$, $\Phi_t \in \mathbb{R}^{|\mathcal{Y}| \times |\mathcal{Y}|}$, as shown in Fig. 2(a). Here $\sigma$ is the sigmoid function.

The teacher model $\mu_\theta$ is then responsible for setting $\Phi_t$ according to the state feature vector $s_t$ of the student model: $\Phi_t = \mu_\theta(s_t)$. One possible form of the teacher model is a neural network with an attention mechanism (shown in Fig. 2(b)): $\Phi_t = \mu_\theta(s_t) = W\,\mathrm{softmax}(V s_t)$, where $W \in \mathbb{R}^{|\mathcal{Y}| \times |\mathcal{Y}| \times N}$ and $V \in \mathbb{R}^{N \times |s_t|}$ constitute the teacher model parameter set $\theta$, and $N = 10$ is the number of keys in the attention mechanism. The state vector $s_t$ is a 13-dimensional vector consisting of 1) the current iteration number $t$; 2) the current training accuracy of $f_\omega$; 3) the current dev accuracy of $f_\omega$; 4) the current precision of $f_\omega$ for the 10 classes on the dev set, all normalized into $[0, 1]$.

We choose three widely adopted datasets: MNIST, CIFAR-10, and CIFAR-100.
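The bilinear loss $l_{\Phi_t}(p_\omega, y) = \sigma(-\mathbf{y}^\top \Phi_t \log p_\omega)$ and the attention-based teacher $\Phi_t = W\,\mathrm{softmax}(V s_t)$ described in Section 4.1 can be sketched in plain Python as follows. This is an illustrative sketch, not the authors' code: $W$ is stored as a list of $N$ matrices of size $|\mathcal{Y}| \times |\mathcal{Y}|$ (the same tensor as above with the key index moved to the front), probabilities are assumed strictly positive, and all function names are our own.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def softmax(v):
    """Softmax with max-subtraction for numerical stability."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]

def dynamic_loss(phi, probs, label):
    """l_Phi(p, y) = sigma(-y^T Phi log p); for one-hot y, y^T Phi log p = sum_j Phi[y][j] log p_j."""
    inner = sum(phi[label][j] * math.log(probs[j]) for j in range(len(probs)))
    return sigmoid(-inner)

def attention_teacher(W, V, state):
    """Phi = W softmax(V s): a convex combination of the N key matrices W[k]."""
    scores = [sum(v * s for v, s in zip(row, state)) for row in V]
    attn = softmax(scores)
    n = len(W[0])
    return [[sum(attn[k] * W[k][i][j] for k in range(len(W))) for j in range(n)]
            for i in range(n)]

identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
probs = [0.7, 0.2, 0.1]
print(dynamic_loss(identity, probs, 0))  # 1 / (1 + 0.7), about 0.588
```

With $\Phi_t = I$ the inner term is $\log p_\omega(y|x)$, so the loss equals $1 / (1 + p_\omega(y|x))$, a monotone transform of the cross-entropy; off-diagonal entries of $\Phi_t$ additionally reward (positive) or penalize (negative) probability mass placed on other classes, which is the behavior analyzed in Section 4.1.2.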
To show the robustness of L2T-DLF, the student models we choose cover a wide range, including multi-layer perceptron (MLP), plain convolutional neural network (CNN) following the LeNet architecture [28], and advanced CNN architectures including ResNet [21], Wide-ResNet [55], and DenseNet [24]. For all student models, we use momentum stochastic gradient descent for training. The network structures of the student models are described in the Appendix.

Table 1: The recognition results (error rate %) on the MNIST dataset.

| Student Model / Loss | Cross Entropy [11] | Smooth [40] | Large-Margin Softmax [36] | L2T-DLF |
|---|---|---|---|---|
| MLP | 1.94 | 1.89 | 1.83 | 1.69 |
| LeNet | 0.98 | 0.94 | 0.88 | 0.77 |

Table 2: The recognition results (error rate %) on the CIFAR-10 (C10) and CIFAR-100 (C100) datasets, reported as C10/C100.

| Student Model / Loss | Cross Entropy [11] | Smooth [40] | Large-Margin Softmax [36] | L2T-DLF |
|---|---|---|---|---|
| ResNet-8 | 12.45/39.79 | 12.08/39.52 | 11.34/38.93 | 10.82/38.27 |
| ResNet-20 | 8.75/32.33 | 8.53/32.01 | 8.02/31.65 | 7.63/30.97 |
| ResNet-32 | 7.51/30.38 | 7.42/30.12 | 7.01/29.56 | 6.95/29.25 |
| WRN | 3.80/- | 3.81/- | 3.69/- | 3.42/- |
| DenseNet-BC | 3.54/- | 3.48/- | 3.37/- | 3.08/- |

The loss functions we compare include:
1) The cross-entropy loss $L_{ce}(p_\omega(x), y) = -\log p_\omega(y|x)$, which is the most widely adopted loss function for training neural network models;
2) The smooth 0-1 loss proposed in [40], which optimizes a smooth version of 0-1 accuracy in binary classification. We extend it to the multi-class case by modifying the loss function as $L_{smooth}(p_\omega(x), y) = -\log \sigma\left(K\left(\log p_\omega(y|x) - \max_{y' \neq y} \log p_\omega(y'|x)\right)\right)$. It is not difficult to observe that when $K \to +\infty$, $L_{smooth}$ exactly matches the 0-1 accuracy. We choose $K = 50$ according to the performance on the dev set;
3) The large-margin softmax loss in [36], denoted as $L_{lm}$, which aims to enhance discrimination between different classes by maximizing the margin induced by the angle between $x$ and a target class representation $w_y$.
We use the open-sourced code released by the authors in our experiments;
4) The loss function discovered by the teacher in L2T-DLF. The teacher models are optimized with Adam [26], and the detailed settings are given in the Appendix.

The classification results on MNIST, CIFAR-10, and CIFAR-100 are shown in Table 1 and Table 2, respectively. As can be observed, on all three tasks, the dynamic loss functions output by the teacher model help cultivate better student models. For example, the teacher model helps WRN achieve a 3.42% classification error rate on CIFAR-10, which is on par with the result discovered via automatic architecture search (e.g., 3.41% for NASNet [57]). Furthermore, our dynamic loss functions reduce the error rate of DenseNet-BC (k=40) on CIFAR-10 from 3.54% to 3.08%, a non-trivial margin.

4.1.1 Teacher Optimization

In Fig. 3, we show the dev measure along with the teacher model optimization in the MNIST experiment, where the student model is LeNet. It can be observed that the dev measure increases as the teacher model is optimized and finally converges to a high score.

Figure 3 (LeNet on MNIST): Measure score on the MNIST dev set along the teacher model optimization (x-axis: teacher optimization step; y-axis: dev measure). The student model is LeNet.

4.1.2 Analysis Towards the Loss Functions

To better understand the loss functions output by the teacher model, we visualize the coefficients of some loss functions output by the teacher model for training ResNet-8 on the CIFAR-100 classification task. Specifically, note that the loss function $l_{\Phi_t}(p_\omega, y) = \sigma(-\mathbf{y}^\top \Phi_t \log p_\omega)$ essentially characterizes the correlations among different classes via the coefficients $\Phi_t$. A positive value of $\Phi_t(i, j)$ means a positive correlation between classes $i$ and $j$, i.e., their probabilities should be jointly maximized, whereas a negative value imposes a negative correlation and higher discrimination between the two classes $i$ and $j$. We choose two classes in CIFAR-100, Otter and Baby, as class $i$, and for each of them pick several representative classes as class $j$. The corresponding $\Phi_t(i, j)$ values are visualized in Fig. 4, with $t = 20, 40, 60$ denoting the coefficients output by the teacher model at the $t$-th epoch of student model training. As can be observed, at the initial phase of student model training ($t = 20$), the teacher model chooses to enhance the correlation between two similar classes, e.g., Otter and Dolphin, or Baby and Boy, for the sake of speeding up training. In contrast, when the student model is powerful enough ($t = 60$), the teacher model forces it to discriminate better between two similar classes, as indicated by the more negative coefficient values $\Phi_t(i, j)$. The variation of the $\Phi_t(i, j)$ values w.r.t. $t$ demonstrates that the teacher model captures the status of the student model when outputting correspondingly appropriate loss functions.

Figure 4: Coefficient matrix $\Phi_t$ output by the teacher model, for (a) class Otter and (b) class Baby. The y-axis (20, 40, 60) corresponds to different epochs of student model training. Darker color means more negative coefficient values, while lighter color means more positive values. In each panel, the leftmost two columns denote similar classes and the rightmost three columns represent dissimilar classes.

4.2 Neural Machine Translation

In the task of neural machine translation (NMT), the evaluation measure $m(\hat{y}, y)$ is typically the BLEU score [41] between the translated sentence $\hat{y}$ and the ground-truth reference $y$. The student model $f_\omega$ is a neural network performing sequence-to-sequence generation based on models including RNN [47], CNN [16], and self-attention networks [51]. The decoding process of $f_\omega$ is typically autoregressive, in that $f_\omega$ factorizes the translation probability as pω(y|x) = ∏_{r=1}^{|y|} pω(y_r|x, y