Official code for "Long-Horizon Model-Based Offline Reinforcement Learning Without Conservatism", NeurIPS 2025 Workshop on Aligning Reinforcement Learning Experimentalists and Theorists.
Authors: Tianwei Ni, Esther Derman, Vineet Jain, Vincent Taboga, Siamak Ravanbakhsh, Pierre-Luc Bacon.
Most offline RL methods rely on conservatism, either by penalizing out-of-dataset actions or by restricting planning horizons. While effective for stability, we show that conservatism fundamentally hurts generalization.
Using a bandit example, we show that mild uncertainty penalties prevent adaptation and lead to suboptimal decisions.
In contrast, a non-conservative Bayesian agent (penalty = 0) enables test-time generalization by adaptation:
- generalizes to better but unseen arm, or
- re-commits to better and seen arm after exploratory interaction.
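This failure mode can be reproduced in a toy two-armed bandit. Below is a minimal sketch with hypothetical posterior numbers (not the repo's implementation; the actual experiment lives in `offline_bandit.py`):

```python
import numpy as np

# Hypothetical posterior after seeing the offline data: arm 0 is well covered
# (low uncertainty), arm 1 is barely seen (higher true mean, large uncertainty).
post_mean = np.array([0.5, 0.8])
post_std = np.array([0.05, 0.6])

def conservative_choice(penalty_coef):
    # lower-confidence-bound style decision: penalize epistemic uncertainty
    return int(np.argmax(post_mean - penalty_coef * post_std))

# penalty = 0 tries the uncertain but better arm (arm 1); even a mild
# penalty permanently commits to the seen, suboptimal arm (arm 0)
```

With `penalty_coef=0.0` the agent picks arm 1 and can adapt; with `penalty_coef=1.0` the penalized value of arm 1 (0.2) falls below arm 0 (0.45), so the conservative agent never tries it.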
This corresponds to solving the "epistemic POMDP" (Ghosh et al., 2022): construct a POMDP from the offline dataset and train a history-dependent agent on it. Conceptually simple and Bayes-optimal.
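A minimal sketch of this construction for a two-armed bandit (hypothetical names and numbers): each ensemble member is one reward model consistent with the offline data, and the sampled member acts as the latent state of the POMDP that the history-dependent agent must infer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Candidate arm means; the member identity is never observed by the agent
members = [np.array([0.5, 0.2]), np.array([0.5, 0.9])]

def bayes_adaptive_policy(history):
    # try each arm once, then commit to the empirically better one
    tried = {a for a, _ in history}
    for arm in (0, 1):
        if arm not in tried:
            return arm
    best_arm, _ = max(history, key=lambda step: step[1])
    return best_arm

def episode(policy, horizon=10):
    true_means = members[rng.integers(len(members))]  # sample a latent MDP
    history, ret = [], 0.0
    for _ in range(horizon):
        a = policy(history)           # action depends on the full history
        r = float(true_means[a])      # noiseless rewards, for simplicity
        history.append((a, r))
        ret += r
    return ret
```

Training a history-dependent agent to maximize the expected return of `episode` recovers the Bayes-optimal behavior: a brief exploratory phase, then either committing to the seen arm or generalizing to the unseen one, mirroring the two outcomes above.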
When scaling Bayesian agents to MDPs, removing conservatism exposes value overestimation. We find that long-horizon model rollouts are essential to counteract this effect.
Long horizons introduce compounding model errors. Two simple design choices make them practical:
- Layer normalization in the world model
- Uncertainty-based adaptive rollout truncation
Together, they substantially reduce error accumulation.
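A minimal numpy sketch of the two mechanisms, with hypothetical function signatures (the repo's JAX/Equinox implementation differs):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize features per step; inside the world model this keeps hidden
    # activations bounded, which curbs error blow-up over long rollouts
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

def adaptive_rollout(step_fns, s0, policy, max_horizon, threshold):
    """Hypothetical sketch: roll out an ensemble of one-step models and stop
    as soon as member disagreement (epistemic uncertainty) exceeds a
    threshold, e.g. a quantile of disagreement measured on the dataset
    (cf. collect.unc_quantile)."""
    s, traj = s0, [s0]
    for _ in range(max_horizon):
        a = policy(s)
        preds = np.stack([f(s, a) for f in step_fns])  # (n_members, state_dim)
        if preds.std(0).max() > threshold:             # too uncertain: truncate
            break
        s = preds.mean(0)
        traj.append(s)
    return traj
```

The truncation rule lets rollouts run for hundreds of steps inside well-covered regions while cutting them short as soon as the ensemble disagrees.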
Combining these insights with other design choices like small context learning rates yields NEUBAY, a practical algorithm grounded in the neutral Bayesian principle.
On D4RL and NeoRL benchmarks, NEUBAY performs competitively with state-of-the-art conservative baselines, while avoiding pessimism entirely. NEUBAY plans over several hundred steps, challenging the dominant short-horizon practice.
We test our codebase on L40S and A100 GPUs. Our code is written in JAX & Equinox. To install the dependencies:

```shell
conda create -n neubay python=3.10
conda activate neubay
conda env update -f requirements.yml
```

These are mainly standard D4RL dependencies.
To support reproducibility and simplify agent tuning, we provide pretrained world-model ensembles in the Google Drive folder:
After downloading, place the folder under `offline_world/ckpt`. For each dataset, we uploaded six checkpoints, one for each random seed used in our experiments.
- `configs`: configurations for each benchmark and dataset
- `experience`: tape-based replay buffer (`agent_buffer.py`), planner (`collector.py`), evaluation on the true MDP (`evaluator.py`), offline data storage for world modeling and history sampling (`world_buffer.py`)
- `memory`: linear recurrent unit adapted from Memoroids
- `neorl`: adapted from the NeoRL codebase, included directly to avoid installation
- `offline_world`: world model training and the interface for agent interaction
- `online_rl`: online RL algorithms (e.g., REDQ) for training with a mixture of offline and synthetic data
We run 3 seeds in parallel for acceleration, and we recommend setting

```shell
export XLA_PYTHON_CLIENT_PREALLOCATE=false
```

to prevent a single process from allocating all GPU memory. All training logs will be uploaded to wandb.
Offline data for the bandit in Section 3 is provided in `offline_world/data`. You may generate your own dataset by running `python offline_world/bandit_data.py`.
```shell
# train your own reward ensemble; skip this if you use a pretrained one
export PYTHONPATH=${PWD}:$PYTHONPATH
python offline_world/bandit_ensemble.py --seed 0
# train the agent on the reward ensemble with the same seed
python offline_bandit.py seed=0
```

In our experiments, we sweep over `collect.penalty_coef` and `train.enc_lr` and report the best result (see the Appendix).
First, download all datasets by running

```shell
python get_all_datasets.py
```

Then train the world model ensemble (skip this if using the pretrained checkpoints). On some datasets, we cap training with `ensemble.total_epochs` to prevent very long training times.
```shell
export PYTHONPATH=${PWD}:$PYTHONPATH
# d4rl locomotion (datasets with the same suffix share the same total_epochs)
python offline_world/cont_ensemble.py --config-path=../configs/d4rl_loco --config-name=base dataset_name=hopper-random-v2 ensemble.total_epochs=1200 seed=0
python offline_world/cont_ensemble.py --config-path=../configs/d4rl_loco --config-name=base dataset_name=halfcheetah-medium-replay-v2 seed=0
python offline_world/cont_ensemble.py --config-path=../configs/d4rl_loco --config-name=base dataset_name=walker2d-medium-v2 ensemble.total_epochs=1200 seed=0
python offline_world/cont_ensemble.py --config-path=../configs/d4rl_loco --config-name=base dataset_name=halfcheetah-medium-expert-v2 ensemble.total_epochs=600 seed=0
# neorl locomotion (all use the same total_epochs)
python offline_world/cont_ensemble.py --config-path=../configs/neorl --config-name=base dataset_name=Hopper-v3-low ensemble.total_epochs=1200 seed=0
# adroit
python offline_world/cont_ensemble.py --config-path=../configs/adroit --config-name=base dataset_name=pen-human-v1 seed=0
python offline_world/cont_ensemble.py --config-path=../configs/adroit --config-name=base dataset_name=pen-cloned-v1 ensemble.total_epochs=2400 seed=0
python offline_world/cont_ensemble.py --config-path=../configs/adroit --config-name=base dataset_name=hammer-cloned-v1 ensemble.total_epochs=1200 seed=0
# antmaze (all use the same total_epochs)
python offline_world/cont_ensemble.py --config-path=../configs/antmaze --config-name=base dataset_name=antmaze-umaze-v2 ensemble.total_epochs=1200 seed=0
```

Then train the recurrent agent on the pretrained ensemble. We sweep over `train.critic_enc_lr = train.actor_enc_lr` and `train.real_weight`, and save the best hyperparameters in `configs/<domain>/task`.
```shell
# d4rl locomotion
python offline_cont.py --config-path=configs/d4rl_loco --config-name=base dataset_name=halfcheetah-medium-expert-v2 seed=0
python offline_cont.py --config-path=configs/d4rl_loco --config-name=base task=halfcheetah_medium_expert seed=0  # use the per-task tuned hparams
# neorl locomotion
python offline_cont.py --config-path=configs/neorl --config-name=base dataset_name=Hopper-v3-low seed=0
python offline_cont.py --config-path=configs/neorl --config-name=base task=Hopper_v3_low seed=0
# adroit
python offline_cont.py --config-path=configs/adroit --config-name=base dataset_name=pen-cloned-v1 seed=0
python offline_cont.py --config-path=configs/adroit --config-name=base task=pen_cloned seed=0
# antmaze
python offline_cont.py --config-path=configs/antmaze --config-name=base dataset_name=antmaze-umaze-v2 seed=0
python offline_cont.py --config-path=configs/antmaze --config-name=base task=umaze seed=0
```

For the ablation study on the Markov agent (sweeping over `train.real_weight`):
```shell
# d4rl locomotion
python offline_markov.py --config-path=configs/d4rl_loco --config-name=base_markov dataset_name=halfcheetah-medium-expert-v2 seed=0
# neorl locomotion
python offline_markov.py --config-path=configs/neorl --config-name=base_markov dataset_name=Hopper-v3-low seed=0
# adroit
python offline_markov.py --config-path=configs/adroit --config-name=base_markov dataset_name=pen-cloned-v1 seed=0
# antmaze
python offline_markov.py --config-path=configs/antmaze --config-name=base_markov dataset_name=antmaze-umaze-v2 seed=0
```

For the other studies, based on the best config using `task=<dataset>`, change one of the following (defaults in bold):
- Uncertainty quantile as rollout threshold: `collect.unc_quantile` in (0.9, 0.99, 0.999, 1.0)
- Uncertainty penalty as conservative term: `collect.penalty_coef` in (0.0, 0.04, 0.2, 1.0, 5.0)
- Ensemble size used in planning (the total ensemble size is fixed at 128): `collect.ensemble_size` in (5, 20, 100)
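Conceptually, the uncertainty-penalty ablation amounts to MOPO-style reward shaping. A minimal sketch (hypothetical function, not the repo's API):

```python
import numpy as np

def penalized_reward(ensemble_reward_preds, penalty_coef):
    # MOPO-style conservative term: subtract ensemble disagreement from the
    # mean predicted reward; penalty_coef=0.0 recovers the neutral
    # (non-conservative) setting used by NEUBAY
    mean = ensemble_reward_preds.mean(axis=0)
    uncertainty = ensemble_reward_preds.std(axis=0)
    return mean - penalty_coef * uncertainty
```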
We provide world-model checkpoints both with LayerNorm (`d4rl_loco`, 6 seeds) and without LayerNorm (`d4rl_loco_no_ln`, 1 seed). To compare them:
```shell
# collect rollouts: trained on hc-random and evaluated on hc-medium-replay
python eval_error.py --config-path=configs/d4rl_loco --config-name=base dataset_name=halfcheetah-random-v2 +eval_dataset=medium-replay-v2 collect.unc_quantile=-1.0
python eval_error.py --config-path=configs/d4rl_loco --config-name=base_no_ln dataset_name=halfcheetah-random-v2 +eval_dataset=medium-replay-v2 collect.unc_quantile=-1.0
# plot rollout stats
python plot_ln.py --train_dataset halfcheetah-random-v2 --eval_dataset medium-replay-v2
```

- https://github.com/patrick-kidger/equinox for the JAX deep learning framework
- https://github.com/proroklab/memoroids for the LRU-based RL implementation in Equinox
- https://github.com/yihaosun1124/OfflineRL-Kit for offline model-based RL implementation in PyTorch
- https://github.com/Howuhh/sac-n-jax for SAC implementation in Equinox
- https://github.com/kwanyoungpark/LEQ and https://github.com/HxLyn3/ADMPO for AntMaze-related configurations
- https://github.com/FanmingL/Recurrent-Offpolicy-RL for recurrent RL configurations on the encoder learning rate
Please open an issue for technical problems or send an email to Tianwei (twni2016@gmail.com) for questions about the paper.



