# A Tractable Inference Perspective of Offline RL

Xuejie Liu1,3, Anji Liu2, Guy Van den Broeck2, Yitao Liang1
1Institute for Artificial Intelligence, Peking University
2Computer Science Department, University of California, Los Angeles
3School of Intelligence Science and Technology, Peking University
xjliu@stu.pku.edu.cn, liuanji@cs.ucla.edu, guyvdb@cs.ucla.edu, yitaol@pku.edu.cn
Equal contribution. Corresponding author.
38th Conference on Neural Information Processing Systems (NeurIPS 2024).

Abstract

A popular paradigm for offline Reinforcement Learning (RL) tasks is to first fit the offline trajectories to a sequence model, and then prompt the model for actions that lead to high expected return. In addition to obtaining accurate sequence models, this paper highlights that tractability, the ability to exactly and efficiently answer various probabilistic queries, plays an important role in offline RL. Specifically, due to the fundamental stochasticity from the offline data-collection policies and the environment dynamics, highly non-trivial conditional/constrained generation is required to elicit rewarding actions. While it is still possible to approximate such queries, we observe that such crude estimates undermine the benefits brought by expressive sequence models. To overcome this problem, this paper proposes Trifle (Tractable Inference for Offline RL), which leverages modern tractable generative models to bridge the gap between good sequence models and high expected returns at evaluation time. Empirically, Trifle achieves state-of-the-art scores on 7 out of 9 Gym-MuJoCo benchmarks and the highest average score against strong baselines. Further, Trifle significantly outperforms prior approaches in stochastic environments and safe RL tasks with minimal algorithmic modifications. Our code is available at https://github.com/liebenxj/Trifle.git.

1 Introduction

Recent advancements in deep generative models have opened up the possibility of solving offline Reinforcement Learning (RL) [27] tasks with sequence modeling techniques (termed RvS approaches). Specifically, we first fit a sequence model to the trajectories provided in an offline dataset. During evaluation, the model is tasked to sample actions with high expected returns given the current state. Leveraging modern deep generative models such as GPTs [5] and diffusion models [18], RvS algorithms have significantly boosted the performance on various RL problems [1, 6].

Despite its appealing simplicity, it is still unclear whether expressive modeling alone guarantees good performance of RvS algorithms, and if so, on what types of environments. This paper discovers that many common failures of RvS algorithms are not caused by modeling problems. Instead, while useful information is encoded in the model during training, the model is unable to elicit such knowledge during evaluation. Specifically, this issue is reflected in two aspects: (i) inability to accurately estimate the expected return of a state and a corresponding action sequence to be executed, even given near-perfect learned transition dynamics and reward functions; (ii) even when accurate return estimates exist in the offline dataset and are learned by the model, it could still fail to sample rewarding actions during evaluation (both observations are supported by empirical evidence, as illustrated in Section 3). At the heart of such inferior evaluation-time performance is the fact that highly
non-trivial conditional generation is required to stimulate high-return actions [32, 3]. Therefore, other than expressiveness, the ability to efficiently and exactly answer various queries (e.g., computing the expected returns), termed tractability, plays an equally important role in RvS approaches.

Having observed that the lack of tractability is an essential cause of the underperformance of RvS algorithms, this paper studies whether we can gain practical benefits from using Tractable Probabilistic Models (TPMs) [35, 7, 23], which by design support exact and efficient computation of certain queries. We answer the question in the affirmative by showing that we can leverage a class of TPMs that support computing arbitrary marginal probabilities to significantly mitigate the inference-time suboptimality of RvS approaches. The proposed algorithm, Trifle (Tractable Inference for Offline RL), has three main contributions:

Emphasizing the important role of tractable models in offline RL. This is the first paper that demonstrates the possibility of using TPMs on complex offline RL tasks. The superior empirical performance of Trifle suggests that expressive modeling is not the only aspect that determines the performance of RvS algorithms, and motivates the development of better inference-aware RvS approaches.

Competitive empirical performance. Compared against strong offline RL baselines (including RvS, imitation learning, and offline temporal-difference algorithms), Trifle achieves the state-of-the-art result on 7 out of 9 Gym-MuJoCo benchmarks [14] and has the best average score.

Generalizability to stochastic environments and safe-RL tasks. Trifle can be extended to tackle stochastic environments as well as safe RL tasks with minimal algorithmic modifications. Specifically, we evaluate Trifle in 2 stochastic OpenAI Gym [4] environments and action-space-constrained MuJoCo environments, and demonstrate its superior performance against all baselines.

2 Preliminaries

Offline Reinforcement Learning. In Reinforcement Learning (RL), an agent interacts with an environment defined by a Markov Decision Process (MDP) $\langle \mathcal{S}, \mathcal{A}, \mathcal{R}, \mathcal{P}, d_0 \rangle$ to maximize its cumulative reward. Specifically, $\mathcal{S}$ is the state space, $\mathcal{A}$ is the action space, $\mathcal{R} : \mathcal{S} \times \mathcal{A} \to \mathbb{R}$ is the reward function, $\mathcal{P} : \mathcal{S} \times \mathcal{A} \to \Delta(\mathcal{S})$ is the transition dynamics, and $d_0$ is the initial state distribution. Our goal is to learn a policy $\pi(a \mid s)$ that maximizes the expected return $\mathbb{E}[\sum_{t=0}^{T} \gamma^t r_t]$, where $\gamma \in (0, 1]$ is a discount factor and $T$ is the maximum number of steps.

Offline RL [27] aims to solve RL problems where we cannot freely interact with the environment. Instead, we receive a dataset of trajectories collected using unknown policies. An effective learning paradigm for offline RL is to treat it as a sequence modeling problem (termed RL via Sequence Modeling, or RvS, methods) [20, 6, 13]. Specifically, we first learn a sequence model on the dataset, and then sample actions conditioned on past states and high future returns. Since the models typically do not encode the entire trajectory, an estimated value or return-to-go (RTG) (i.e., the Monte Carlo estimate of the sum of future rewards) is also included for every state-action pair, allowing the model to estimate the return at any time step.
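For concreteness, the sketch below shows one way such RTG labels can be computed for an offline trajectory; it is a minimal illustration, and the function name, tuple layout, and discounting convention are assumptions rather than part of any particular RvS codebase.

```python
import numpy as np

def relabel_with_rtg(trajectory, gamma=1.0):
    """Attach a return-to-go (RTG) label to every step of an offline trajectory.

    `trajectory` is assumed to be a list of (state, action, reward) tuples;
    the RTG at step t is the (discounted) Monte Carlo sum of future rewards,
    computed with a single backward pass over the trajectory.
    """
    rewards = np.array([r for (_, _, r) in trajectory], dtype=np.float64)
    rtgs = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtgs[t] = running
    return [(s, a, r, rtg) for (s, a, r), rtg in zip(trajectory, rtgs)]
```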
Figure 1: An example PC over boolean variables $X_1, \ldots, X_4$. Every node's probability given the input $x_1 x_2 x_3 x_4$ is labeled in blue; $p(x_1 x_2 x_3 x_4) = 0.22$.

Tractable Probabilistic Models. Tractable Probabilistic Models (TPMs) are generative models that are designed to efficiently and exactly answer a wide range of probabilistic queries [35, 7, 37]. One example class of TPMs is Hidden Markov Models (HMMs) [36], which support linear-time (w.r.t. model size and input size) computation of marginal probabilities and more. Probabilistic Circuits (PCs) [7] are a general class of TPMs. As shown in Figure 1, PCs consist of input nodes that represent simple distributions (e.g., Gaussian, Categorical) over one or more variables, as well as sum and product nodes that take other nodes as input and gradually form more complex distributions. Specifically, product nodes model factorized distributions over their inputs, and sum nodes build weighted mixtures (mixture weights are labeled on the corresponding edges in Fig. 1) over their input distributions. Please refer to Appx. B for a more detailed introduction to PCs.

Figure 2: RvS approaches suffer from inference-time suboptimality. Left: There is a strong positive correlation between the average estimated returns by Trajectory Transformers (TT) and the actual returns in 6 Gym-MuJoCo environments (MR, M, and ME denote medium-replay, medium, and medium-expert, respectively), which suggests that the sequence model can distinguish rewarding actions from the others. Middle: Despite being able to recognize high-return actions, both TT and DT [6] fail to consistently sample such actions, leading to bad inference-time optimality; Trifle consistently improves the inference-time optimality score. Right: We substantiate the relationship between low inference-time optimality scores and unfavorable environmental outcomes by showing a strong positive correlation between them.

Recent advancements have extensively pushed forward the expressiveness of modern PCs [30, 31, 9], leading to competitive likelihoods on natural image and text datasets compared to even strong Variational Autoencoder [43] and Diffusion model [22] baselines. This paper leverages such advances and explores the benefits brought by PCs in offline RL tasks.

3 Tractability Matters in Offline RL

Practical RvS approaches operate in two main phases: training and evaluation. In the training phase, a sequence model is adopted to learn a joint distribution over trajectories of length $T$: $\{(s_t, a_t, r_t, \mathrm{RTG}_t)\}_{t=0}^{T}$ (to minimize computation cost, we only model truncated trajectories of length $K$, $K < T$, in practice). During evaluation, at every time step $t$, the model is tasked to discover an action sequence $a_{t:T} := \{a_\tau\}_{\tau=t}^{T}$ (or just $a_t$) that has a high expected return as well as a high probability under the prior policy $p(a_{t:T} \mid s_t)$, which prevents it from generating out-of-distribution actions:

$$p(a_{t:T} \mid s_t, \mathbb{E}[V_t] \ge v) := \begin{cases} \frac{1}{Z}\, p(a_{t:T} \mid s_t) & \text{if } \mathbb{E}_{V_t \sim p(\cdot \mid s_t, a_{t:T})}[V_t] \ge v, \\ 0 & \text{otherwise}, \end{cases} \quad (1)$$

where $Z$ is a normalizing constant, $V_t$ is an estimate of the value at time step $t$, and $v$ is a pre-defined scalar chosen to encourage high-return policies. Depending on the problem, $V_t$ could be the labeled RTG from the dataset (e.g., $\mathrm{RTG}_t$) or the sum of future rewards capped with a value estimate (e.g., $\sum_{\tau=t}^{T-1} r_\tau + \mathrm{RTG}_T$) [13, 20].
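To make Equation (1) concrete, the following is a minimal sketch of the approximate, rejection-style sampler that existing RvS methods effectively rely on; `prior_sample` and `expected_value` are hypothetical callables wrapping a learned sequence model, and with an intractable model the expectation itself usually has to be approximated, which is exactly the issue analyzed in the rest of this section.

```python
import numpy as np

def sample_action_eq1(prior_sample, expected_value, v, max_tries=256):
    """Approximately sample from Eq. (1) by rejection:
    draw a_t ~ p(a_t | s_t) and keep it only if its estimated expected
    value E[V_t | s_t, a_t] clears the threshold v.

    `prior_sample()` and `expected_value(a)` are assumed callables wrapping
    the learned sequence model for a fixed current state s_t.
    """
    best_a, best_val = None, -np.inf
    for _ in range(max_tries):
        a = prior_sample()
        val = expected_value(a)
        if val >= v:
            return a                      # accepted: satisfies the constraint
        if val > best_val:                # remember the best rejected candidate
            best_a, best_val = a, val
    return best_a                         # fall back if nothing clears v
```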
The above definition naturally reveals two key challenges in RvS approaches: (i) training-time optimality (i.e., expressivity): how well can we fit the offline trajectories, and (ii) inference-time optimality: whether actions can be unbiasedly and efficiently sampled from Equation (1). While extensive breakthroughs have been achieved to improve the training-time optimality [1, 6, 20], it remains unclear whether the non-trivial constrained generation task of Equation (1) hinders inference-time optimality. In the following, we present two general scenarios where existing RvS approaches underperform as a result of suboptimal inference-time performance. We attribute such failures to the fact that these models are limited to answering certain query classes (e.g., autoregressive models can only compute next-token probabilities), and explore the potential of tractable probabilistic models for offline RL tasks in the following sections.

Scenario #1. We first consider the case where the labeled RTG belongs to a (near-)optimal policy. In this case, Equation (1) can be simplified to $p(a_t \mid s_t, \mathbb{E}[V_t] \ge v)$ (choosing $V_t := \mathrm{RTG}_t$), since one-step optimality implies multi-step optimality. In practice, although the RTGs are suboptimal, the predicted values often match well with the actual returns achieved by the agent. Take Trajectory Transformer (TT) [20] as an example: Figure 2 (left) demonstrates a strong positive correlation between its predicted returns (x-axis) and the actual cumulative rewards (y-axis) on six MuJoCo [42] benchmarks, suggesting that the model has learned the goodness of most actions. In such cases, the performance of RvS algorithms depends mainly on their inference-time optimality, i.e., whether they can efficiently sample actions with high predicted returns. Specifically, let $a_t$ be the action taken by an RvS algorithm at state $s_t$, and let $R_t := \mathbb{E}[\mathrm{RTG}_t]$ be the corresponding estimated expected value. We define a proxy of inference-time optimality as the quantile value of $R_t$ in the estimated state-conditioned value distribution $p(V_t \mid s_t)$ (due to the large action space, it is impractical to compute $p(V_t \mid s_t) := \sum_{a_t} p(V_t \mid s_t, a_t)\, p(a_t \mid s_t)$ exactly; in the following illustrative experiments, we instead train an additional GPT model for $p(V_t \mid s_t)$ on the offline dataset). The higher the quantile value, the more frequently the RvS algorithm samples actions with high estimated returns. We evaluate the inference-time optimality of Decision Transformers (DT) [6] and Trajectory Transformers (TT) [20], two widely used RvS algorithms, on various environments and offline datasets from the Gym-MuJoCo benchmark suite [14]. As shown in Figure 2 (middle), the inference-time optimality averages (only) around 0.7 (the maximum possible value is 1.0) for most settings, and runs with low inference-time optimality scores receive low environment returns (Fig. 2, right).

Scenario #2. Achieving inference-time optimality becomes even harder when the labeled RTGs are suboptimal (e.g., they come from a random policy). In this case, even estimating the expected future return of an action sequence becomes highly intractable, especially when the transition dynamics of the environment are stochastic. Specifically, to evaluate a state-action pair $(s_t, a_t)$, since $\mathrm{RTG}_t$ is uninformative, we need to resort to the multi-step estimate $V_t^m := \sum_{\tau=t}^{t'-1} r_\tau + \mathrm{RTG}_{t'}$ (for some $t' > t$), where the actions $a_{t:t'}$ are jointly chosen to maximize the expected return.
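To make the difficulty concrete, the sketch below shows the Monte Carlo fallback that an intractable sequence model has to rely on when estimating $\mathbb{E}[V_t^m]$; `rollout_step` and `rtg_estimate` are hypothetical wrappers around the learned model, and the autoregressive example in the next paragraph explains why such sampled rollouts degrade in stochastic environments.

```python
import numpy as np

def mc_multistep_value(rollout_step, rtg_estimate, s_t, actions, n_samples=32):
    """Monte Carlo estimate of E[V_t^m] = E[ sum_{tau=t}^{t'-1} r_tau + RTG_{t'} ].

    `rollout_step(s, a)` is assumed to sample (reward, next_state) from the
    learned model, and `rtg_estimate(s, a)` to return E[RTG | s, a].
    Intermediate states are sampled rather than marginalized out, so the
    variance of this estimate grows with the rollout length whenever the
    transition dynamics are stochastic.
    """
    estimates = []
    for _ in range(n_samples):
        s, acc = s_t, 0.0
        for a in actions[:-1]:            # a_t, ..., a_{t'-1}: accumulate rewards
            r, s = rollout_step(s, a)
            acc += r
        estimates.append(acc + rtg_estimate(s, actions[-1]))  # add RTG at step t'
    return float(np.mean(estimates))
```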
Take autoregressive models as an example. Since the variables are arranged following the sequential order $\ldots, s_t, a_t, r_t, \mathrm{RTG}_t, s_{t+1}, \ldots$, we need to explicitly sample $s_{t+1:t'}$ before proceeding to compute the rewards and the RTG in $V_t^m$. In stochastic environments, estimating $\mathbb{E}[V_t^m]$ could suffer from high variance as the stochasticity from the intermediate states accumulates over time. As we shall illustrate in Section 6.2, compared to environments with near-deterministic transition dynamics, estimating the expected returns in stochastic environments using intractable sequence models is hard, and Trifle can significantly mitigate this problem with its ability to marginalize out intermediate states and compute $\mathbb{E}[V_t^m]$ efficiently and exactly.

4 Exploiting Tractable Models

The previous section demonstrates that, apart from modeling, inference-time suboptimality is another key factor that causes the underperformance of RvS approaches. Given such observations, a natural follow-up question is whether and how more tractable models can improve evaluation-time performance in offline RL tasks. While there are different types of tractability (i.e., the ability to compute different types of queries), this paper focuses on studying the additional benefit of exactly computing arbitrary marginal/conditional probabilities. This strikes a proper balance between learning and inference, as we can train such a tractable yet expressive model thanks to recent developments in the TPM community [9, 30]. Note that in addition to proposing a competitive RvS algorithm, we aim to highlight the necessity and benefit of using more tractable models for offline RL tasks, and to encourage future developments on both inference-aware RvS methods and better TPMs. As a direct response to the two failing scenarios identified in Section 3, we first demonstrate how tractability could help even when the labeled RTGs are (near-)optimal (Sec. 4.1). We then move on to the case where we need to use multi-step return estimates to account for biases in the labeled RTGs (Sec. 4.2).

4.1 From the Single-Step Case...

Consider the case where the RTGs are optimal. Recall from Section 3 that our goal is to sample actions from $p(a_t \mid s_t, \mathbb{E}[V_t] \ge v)$ (where $V_t := \mathrm{RTG}_t$). Prior works use two typical ways to approximately sample from this distribution. The first approach directly trains a model to generate return-conditioned actions: $p(a_t \mid s_t, \mathrm{RTG}_t)$ [6]. However, since the RTG given a state-action pair is stochastic (this holds unless (i) the policy that generated the offline dataset, (ii) the transition dynamics, and (iii) the reward function are all deterministic), sampling
from this RTG-conditioned policy could result in actions with only a small probability of achieving a high return but a low expected return [32, 3]. An alternative approach leverages the ability of sequence models to accurately estimate the expected return (i.e., $\mathbb{E}[\mathrm{RTG}_t]$) of state-action pairs [20]. Specifically, we first sample from a prior distribution $p(a_t \mid s_t)$, and then reject actions with low expected returns. Such rejection sampling-based methods typically work well when the action space is small (in which case we can enumerate all actions) or the dataset contains many high-rewarding trajectories (in which case the rejection rate is low). However, in practice the action is typically multi-dimensional and the dataset contains many more low-return trajectories, rendering the inference-time optimality score low (cf. Fig. 2).

Having examined the pros and cons of existing approaches, we are left with the question of whether a tractable model can improve the sampled actions (in this single-step case). We answer it with a mixture of positive and negative results: while computing $p(a_t \mid s_t, \mathbb{E}[V_t] \ge v)$ is NP-hard even when $p(a_t, V_t \mid s_t)$ follows a simple Naive Bayes distribution, we can design an approximation algorithm that samples high-return actions with high probability in practice. We start with the negative result.

Theorem 1. Let $a_t := \{a_t^i\}_{i=1}^{k}$ be a set of $k$ boolean variables and let $V_t$ be a categorical variable with two categories, 0 and 1. For some $s_t$, assume the joint distribution over $a_t$ and $V_t$ conditioned on $s_t$ follows a Naive Bayes distribution: $p(a_t, V_t \mid s_t) := p(V_t \mid s_t) \prod_{i=1}^{k} p(a_t^i \mid V_t, s_t)$, where $a_t^i$ denotes the $i$-th variable of $a_t$. Computing any marginal over the random variables is tractable, yet conditioning on the expectation, i.e., computing $p(a_t \mid s_t, \mathbb{E}[V_t] \ge v)$, is NP-hard.

The proof is given in Appx. A. While it seems hard to directly draw samples from $p(a_t \mid s_t, \mathbb{E}[V_t] \ge v)$, we propose to improve the aforementioned rejection sampling-based method by adding a correction term to the original proposal distribution $p(a_t \mid s_t)$ to reduce the rejection rate. Specifically, the prior is often represented by an autoregressive model such as GPT: $p_{\mathrm{GPT}}(a_t \mid s_t) := \prod_{i=1}^{k} p_{\mathrm{GPT}}(a_t^i \mid s_t, a_t^{1:i-1})$.
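As a rough illustration of what such a corrected proposal can look like, the sketch below reweights each dimension of the GPT proposal by a value-based term before sampling; `gpt_next_dim_probs` and `tpm_high_value_prob` are hypothetical stand-ins (the former for $p_{\mathrm{GPT}}(a_t^i \mid s_t, a_t^{1:i-1})$, the latter for a high-value probability that a tractable model could evaluate exactly on a partial action) and do not reflect Trifle's actual interface.

```python
import numpy as np

def guided_action_sample(gpt_next_dim_probs, tpm_high_value_prob, s_t, candidates, k):
    """Sample a k-dimensional (discretized) action dimension by dimension.

    For each dimension i, the GPT proposal p_GPT(a_t^i | s_t, a_t^{1:i-1}) is
    reweighted by a correction term; here the hypothetical
    `tpm_high_value_prob(s_t, prefix)` plays that role, scoring how likely a
    partial action prefix is to lead to a high value. Both callables are
    illustrative stand-ins rather than Trifle's actual interface.
    """
    prefix = []
    for _ in range(k):
        prior = gpt_next_dim_probs(s_t, prefix)   # maps each candidate value -> probability
        weights = np.array([prior[a] * tpm_high_value_prob(s_t, prefix + [a])
                            for a in candidates], dtype=np.float64)
        weights /= weights.sum()                  # renormalize the corrected proposal
        choice = candidates[np.random.choice(len(candidates), p=weights)]
        prefix.append(choice)
    return prefix
```

Because the correction is applied per dimension, low-value prefixes are down-weighted early instead of being discarded only after a full action has been sampled and rejected.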