# Controllable Text Generation with Neurally-Decomposed Oracle

Tao Meng (University of California, Los Angeles) tmeng@cs.ucla.edu
Sidi Lu (University of California, Los Angeles) sidilu@cs.ucla.edu
Nanyun Peng (University of California, Los Angeles) violetpeng@cs.ucla.edu
Kai-Wei Chang (University of California, Los Angeles) kwchang@cs.ucla.edu

Abstract

We propose a general and efficient framework to control auto-regressive generation models with a NeurAlly-Decomposed Oracle (NADO). Given a pre-trained base language model and a sequence-level boolean oracle function, we propose to decompose the oracle function into token-level guidance to steer the base model in text generation. Specifically, the token-level guidance is approximated by a neural model trained with examples sampled from the base model, demanding no additional auxiliary labeled data. Based on posterior regularization, we present the closed-form optimal solution for incorporating the token-level guidance into the base model for controllable generation. We further provide a theoretical analysis of how the approximation quality of NADO affects the controllable generation results. Experiments on two tasks, (1) text generation with lexical constraints and (2) machine translation with formality control, demonstrate that our framework efficiently guides the base model towards the given control factors while maintaining high generation quality.

1 Introduction

Auto-regressive language models have been widely used for text generation.
With the recent development of large-scale pre-trained language models (Radford et al., 2019; Brown et al., 2020; Raffel et al., 2020; Lewis et al., 2020), they have achieved state-of-the-art performance in applications such as machine translation (Bahdanau et al., 2015; Luong et al., 2015), image captioning (Anderson et al., 2018; You et al., 2016), and open-domain text generation (Zhang and Lapata, 2014; Yao et al., 2019; Vinyals and Le, 2015; Shang et al., 2015; Lu et al., 2018). However, many applications such as open-domain creative generation (Yao et al., 2019; Goldfarb-Tarrant et al., 2020; Tian and Peng, 2022; Han et al., 2022; Chen et al., 2022; Spangher et al., 2022) require controlling the model output with specific sequence-level attributes. The attributes can be specified by a set of rules[2] or by an abstract concept (e.g., the generated text follows a particular writing style). How to control auto-regressive language models to satisfy these attributes is an open challenge. In this paper, we propose a general and flexible framework for controllable text generation. Given a base pre-trained language model and a sequence-level oracle function indicating whether an attribute is satisfied, our goal is to guide the text generation to satisfy certain attributes using the oracle.

* Equal contribution.
[2] For example, lexical constraints require certain words to appear in the generated text (Hokamp and Liu, 2017; Lin et al., 2020).

36th Conference on Neural Information Processing Systems (NeurIPS 2022).

[Figure 1: diagram contrasting a base-model output ("A girl with hair sits at a table around.") with a NADO-guided output ("A woman is sitting and looking down.").]

(a) Take lexically constrained generation as an example, where the oracle checks whether all keywords in the input x are incorporated in the generated text y. With proper training using samples from the base model p (dashed arrow) labeled by the oracle, we decompose the oracle into token-level guidance and parameterize it by an auxiliary model Rθ (NADO).
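The closed-form combination mentioned in the abstract, in which the token-level guidance reweights the base model's next-token distribution, can be illustrated in a few lines. This is only a minimal sketch with toy numbers: the function name, the vocabulary size, and all probability values are hypothetical, not taken from the paper.

```python
import numpy as np

def guided_next_token_dist(p_next, r_prefix, r_next):
    """Combine base-model token probabilities with token-level oracle guidance.

    p_next:   base model's distribution over the next token, shape (V,)
    r_prefix: NADO's estimate that the oracle will be satisfied given the
              current prefix (a scalar in (0, 1])
    r_next:   NADO's per-candidate estimates after appending each token, shape (V,)
    """
    weights = p_next * (r_next / r_prefix)  # reweight each candidate token
    return weights / weights.sum()          # renormalize into a distribution

# Toy example: 4-token vocabulary; token 2 is strongly favored by the guidance.
p_next = np.array([0.4, 0.3, 0.2, 0.1])
r_next = np.array([0.1, 0.1, 0.9, 0.1])
q = guided_next_token_dist(p_next, r_prefix=0.25, r_next=r_next)
assert abs(q.sum() - 1.0) < 1e-9
assert q.argmax() == 2  # guidance shifts probability mass toward token 2
```

Note that the scalar r_prefix cancels under normalization; it is kept here only to mirror the ratio form of the guidance.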
We use Rθ to provide guidance when generating text with the base model (see details in Fig. 1(b)).

[Figure 1(b): Illustration of the controlled generation process. Both the base model (auto-regressive) and the auxiliary model (NADO) take the input x and the generated sequence (prefix) y_{<i} [...].]

[...] a coefficient λ > 0 to balance these losses. The final training loss is

    L(x, y, R^C_θ) = L_CE(x, y, R^C_θ) + λ L_reg(x, y, R^C_θ).    (8)

3.5 Sampling

In Sec. 3.4 we describe how NADO is trained on data sampled from the base model p. One advantage of this is that we can leverage different sampling strategies to better adapt to different application scenarios. It is also possible to train R^C_θ with reinforcement learning; we discuss the connection to reinforcement learning in the appendix. In this section, we introduce two sampling strategies and their corresponding properties.

Sampling with Temperature Control. In some tasks the output sequences are not very diverse; in other words, the token distribution at each step is peaky. Since NADO is trained on sampled examples, we want those examples to cover as many token combinations as possible to avoid overfitting. Therefore, we add a temperature factor T to smooth the distribution (Ackley et al., 1985). Specifically, we sample y from the distribution p(y|x)^{1/T} and weight each example by the coefficient p(y|x)^{1 - 1/T} when computing the cross-entropy loss. Formally, the expected loss is

    E_{y ~ p(y|x)^{1/T}} [ p(y|x)^{1 - 1/T} L_CE(x, y, R^C_θ) ] = Σ_{y ∈ Y} p(y|x) L_CE(x, y, R^C_θ),

which is the same as the original expected loss in Eq. (7).

Importance Sampling. In practice, training NADO can be extraordinarily difficult when samples generated by the base model p hardly ever satisfy C, i.e., when E_{y ~ p(y|x)}[p(C|x, y)] ≈ 0. Hence, we introduce importance sampling (Hammersley and Morton, 1954) to tackle this issue. Basically, we leverage a partially trained ˆRθ to form a proposal distribution ˆq.
Although ˆRθ is not well trained, it can still provide positive guidance towards samples satisfying C. Note that ˆq does not have to be updated in every training epoch. With the coefficient p(y|x)/ˆq(y|x), the expected loss is the same as the original expected loss:

    E_{y ~ ˆq(y|x)} [ (p(y|x)/ˆq(y|x)) L_CE(x, y, R^C_θ) ] = Σ_{y ∈ Y} p(y|x) L_CE(x, y, R^C_θ).

4 Experiments

We conduct experiments on two tasks: lexically constrained generation (LCG) and machine translation (MT) with formality change. For the former we use GPT-2 (Radford et al., 2019) as the base model, and for the latter we use a sequence-to-sequence model, MarianMT (Junczys-Dowmunt et al., 2018). We demonstrate that our framework is effective in both scenarios. In the LCG task the boolean oracle is a rule-based function checking whether all lexical constraints are satisfied, while in MT it is a classifier, trained on an external dataset, that identifies the formality of the text. All details about hyper-parameter settings are given in the appendix.

4.1 Text Generation with Lexical Constraints

We evaluate our model on two general classes of LCG problems. Unsupervised LCG: annotations for lexical constraints are not available during training, but the constraints are expected to appear in their exact order and lexical form during inference. Supervised LCG: annotations for lexical constraints are available, yet the words may appear in a different lexical form (e.g., look can appear in the past tense looked) or in a different order in the generated text. In both cases, we define the oracle C as a boolean function indicating whether the generated sequence satisfies all of the lexical constraints. In neither setting do we naturally have negative samples (i.e., sequences that do not satisfy all constraints) for training the auxiliary model; thus, it is non-trivial to compare against methods such as FUDGE and GeDi that require both positive and negative labeled data to train the auxiliary model.
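The boolean oracle for LCG can be implemented directly from its definition: it returns True iff every constraint word occurs in the generated text. Below is a minimal sketch; the whitespace tokenization and the `normalize` hook (which a supervised-LCG variant could replace with a lemmatizer so that "looked" matches "look") are illustrative assumptions, not the paper's exact implementation.

```python
def lexical_oracle(constraints, text, normalize=str.lower):
    """Return True iff every constraint word occurs in the generated text.

    `normalize` maps surface forms to a canonical form; here we only
    lowercase, but a lemmatizer could be plugged in for supervised LCG.
    """
    tokens = {normalize(tok) for tok in text.split()}
    return all(normalize(w) in tokens for w in constraints)

assert lexical_oracle(["girl", "table"], "A girl sits at a table")
assert not lexical_oracle(["girl", "table"], "A woman sits on a chair")
```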
Data Setup. For unsupervised LCG, we follow the settings in POINTER (Zhang et al., 2020) and conduct our experiments on the Yelp! Review and News datasets. Each unsupervised LCG dataset contains a large number of unannotated raw sequences for training (160K for Yelp! Review and 268,586 for News). During inference, the model is expected to generate text lexically constrained, in the exact order and form, by a specific number of keywords (7 for Yelp! Review and 4 for News). For supervised LCG, we evaluate the proposed method on CommonGen (Lin et al., 2020). CommonGen is a supervised LCG task that aims to examine the commonsense of neural text generation models. For training, it contains 32,651 unique key concepts (i.e., the constraints) with 67,389 complete sequences in total. It also contains a validation set with 993 concepts and 4,018 reference sequences. For a more robust evaluation, the dataset maintains an open leaderboard that benchmarks different approaches on a withheld test set. We follow most of the data configurations specified in the original paper that introduced the datasets.

General Model Setup. We investigate the effectiveness of different factors in our framework by enumerating different combinations of them. We implement two types of base model: (Seq2seq base model) a sequence-to-sequence model p(y|x) that takes the lexical constraints as a conditioning input sequence; (DA base model) a language model that is only domain-adapted to p(y) but is not conditioned on anything. The latter is a challenging setting, since we impose the lexical constraints only through NADO; it serves to better verify the effectiveness and efficiency of the proposed method while controlling for irrelevant factors. Under both the p(y|x) and p(y) settings, we fine-tune the base model from pre-trained GPT2-Large.
During training, NADO is trained as a Seq2seq-like model[5] that takes in the keys (for unsupervised LCG, these are generated by randomly sampling a specific number of natural words from the original sentence) and produces the token-level guidance R^C_θ(x, y_{≤i}). For each pseudo key, we sample 32 target sequences from the base model p with top-p (p = 0.8) random sampling. We conduct experiments to test different training setups for NADO: (NADO training) the proposed training process described in Sec. 3.4; (Warmup) we warm up NADO by maximizing the likelihood of positive samples, backpropagating the gradient only to the parameters of Rθ. The warmed-up R^C_θ is used for the importance sampling described in Sec. 3.5. With DA base models, the warmup process is always incorporated, as it is necessary for the practical success of training (see the results for DA pretrained w/o warmup). We also consider a warmup-only setting, which can be treated as a stronger baseline to verify that the major improvement of our framework does not come from the extra capacity of NADO.

Results and Analysis. We compare the performance under different setups of our model to previous state-of-the-art methods on LCG tasks, including insertion-based models (Levenshtein Transformer (Gu et al., 2019) with lexical constraints (Susanto et al., 2020), InsNet (Lu et al., 2022a), etc.) and decoding-based algorithms. We also compare against a simple baseline that addresses the problems with a standard Seq2seq pipeline. The results are shown in Table 1. NADO consistently improves the BLEU score and coverage in different setups. Furthermore, under the best setting of each task (see bolded items in the table), NADO performs significantly better than most baselines in generation quality and achieves a very high lexical-constraint coverage rate.

[5] In this experiment, the input x only describes the lexical constraint C. However, our framework also supports general inputs in other Seq2seq tasks with constraints.
For example, machine translation with lexical constraints, where the constraint C is different from the input x.

[Figure 2: Comparative study of the effectiveness of regularization in NADO training. (a) BLEU-4 comparison of NADO training with/without the regularization of Eq. (6). (b) Coverage comparison of NADO training with/without the regularization of Eq. (6).]

Compared to InsNet, it is much easier for an auto-regressive model with NADO to handle flexible reordering/transformation of lexical constraints. This is reflected in the performance comparison of InsNet and NADO on the CommonGen dataset. Under most settings, a Seq2seq base model makes it easier for the framework to perform well, as it guarantees a reasonable level of lexical-constraint coverage even in the initial state of the model. Using a DA pretrained base model is an even more challenging setup, since the lexical constraints are imposed only through NADO. As a result, the base model distribution is very different from the one filtered by the oracle, which is reflected in the poor performance on both metrics. However, with warmup and NADO training under importance sampling, we show that it is still possible to obtain a powerful model with the proposed method. To further study the correlation between base model quality and the improvement from NADO, we conduct experiments with the GPT-2 base model. The GPT-2 base model has lower scores with and without NADO compared with GPT-2 Large, while the coverage improvements are similar. This shows that NADO is capable of pushing the base model distribution towards the oracle as long as the base model has decent quality. We also conduct a human evaluation of the base model (GPT-2 Large, fine-tuned) and the best NADO system, together with the gold reference for comparison. The results are shown in Tab. 2. The evaluation metrics are described in detail in the appendix. Some qualitative examples are shown in Tab. 3.
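The top-p (nucleus) sampling used to draw the 32 targets per pseudo key in the setup above keeps the smallest set of highest-probability tokens whose cumulative mass reaches p and renormalizes over that set. A minimal sketch over a toy distribution (the function name and the probability values are illustrative, not from the paper):

```python
import numpy as np

def top_p_filter(probs, p=0.8):
    """Zero out tokens outside the nucleus (smallest high-probability set
    whose cumulative mass reaches p), then renormalize."""
    order = np.argsort(probs)[::-1]              # tokens sorted by probability, descending
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, p) + 1  # number of tokens to keep
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()             # renormalized nucleus distribution

probs = np.array([0.6, 0.25, 0.1, 0.05])
q = top_p_filter(probs, p=0.8)
assert abs(q.sum() - 1.0) < 1e-9
assert np.count_nonzero(q) == 2  # tokens 0 and 1 already cover >= 0.8 of the mass
# a token would then be drawn via np.random.choice(len(q), p=q)
```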
To study the importance of the regularization term, we conduct an ablation study under the optimal setting on the CommonGen dataset (Seq2seq base model with NADO only). The results are shown in Figure 2. While the success of achieving lexical control does not degrade when NADO without regularization overfits, adding the regularization significantly improves the robustness of NADO's generation quality when NADO is trained for more epochs.

4.2 Machine Translation with Formality Change

Datasets and Setup. We follow the experimental setting in FUDGE (Yang and Klein, 2021) for controlling the formality of machine translation outputs. Given an informal source sentence, our goal is to translate it into a formal sentence in the target language. We conduct our experiments on the Fisher and CALLHOME Spanish-English Speech Translation Corpus (Post et al., 2013), in which both the Spanish source and the English reference are informal and casual. Instead of evaluating the translation against the original references, we use the formal and fluent rewritten version of the references (Salesky et al., 2019) to evaluate translation quality by BLEU score. During training, the formal version of the references is unseen by the models. We also evaluate formality scores with a discriminator trained on the GYAFC formality dataset (Rao and Tetreault, 2018), as in the FUDGE paper. In this experiment, the pre-trained MarianMT model (Junczys-Dowmunt et al., 2018) is used as the base model. In FUDGE, the authors also train an auxiliary model on GYAFC, modeling the token-level guidance P(formal|y