PaperHub
Rating: 7.5/10
Spotlight, 4 reviewers
Individual ratings: 8, 8, 8, 6 (min 6, max 8, std. dev. 0.9)
Confidence: 4.3
Soundness: 3.5
Contribution: 3.0
Presentation: 3.5
ICLR 2025

Learning Transformer-based World Models with Contrastive Predictive Coding

Submitted: 2024-09-25; Updated: 2025-02-16
TL;DR

We introduce TWISTER, a Transformer model-based reinforcement learning algorithm using action-conditioned Contrastive Predictive Coding to learn high-level feature representations and improve agent performance.

Abstract

Keywords
model-based reinforcement learning, transformer network, contrastive predictive coding

Reviews and Discussion

Official Review
Rating: 8

In this paper, the authors investigate a Model-Based Reinforcement Learning (MBRL) agent with contrastive predictive coding. There have been many attempts to use a Transformer as the world model backbone for MBRL. However, the authors argue that the impact of these attempts is limited by the predictive loss (one-step next-observation prediction), and propose a new method equipped with a contrastive predictive coding loss. With it, they report state-of-the-art performance on the Atari100k benchmark among MBRL agents without look-ahead search.

Strengths

  • Breakthrough performance on Atari100k among MBRL agents without look-ahead search. The method clearly outperforms previous works, which performed comparably to one another.
  • Good motivation, well validated. Their hypothesis, that a next-frame predictive loss is not sufficient to learn meaningful information, is reasonable, and they validate it well through their proposed agent and experiments.
  • Well-written paper. The paper is enjoyable to read thanks to a well-summarized related work section, a well-discussed motivation and diverse empirical ablation studies.
  • Analysis of the experimental results. The authors properly discuss their results, for instance why their model performs better on Atari100k (it captures small but important objects that a reconstruction loss alone can miss). The architecture is ablated extensively.
  • Detailed hyperparameters are shared in the appendix.

Weaknesses

  • We didn’t find a critical weakness in this paper.

Questions

  • In Figure 2, you showed well why a one-step predictive loss can be insufficient to learn the world model, but what do the shaded lines in the graph represent?
  • Did you test agent performance on the Atari100k benchmark without data augmentation? I think using only temporally nearby samples as positives (without augmentation) would be preferable, if its performance is comparable to the version with data augmentation.
  • Another study applies a contrastive loss to an MBRL agent [1]; considering it as a reference could help improve this paper.
  • For the AC-CPC predictor, how did you input the actions? Did you concatenate them?
  • Can you test the performance without the reconstruction loss for world model training? This would be a purer version of MBRL with a contrastive loss.
  • In lines 426-432, your analysis of why the AC-CPC objective is helpful and what previous works missed makes sense. Following your analysis, it makes sense that the reconstruction-based policy learning methods (IRIS and Δ-IRIS) show better Mean Human Normalized Scores than TWM or STORM. Then why do IRIS and Δ-IRIS perform worse on the Median Human Normalized Score? Interestingly, your method outperforms on both metrics. Can you explain which factors cause this difference?
Comment

We sincerely thank you for your positive and encouraging feedback on our submission. We are pleased to hear that you found our paper engaging and enjoyable to read.

In Figure 2, you showed well why a one-step predictive loss can be insufficient to learn the world model, but what do the shaded lines in the graph represent?

The shaded lines represent the cosine similarities obtained for individual games. The main curve is obtained by averaging the similarities over all 26 games. The game with the fastest-declining state similarity is Up N Down, where the game state is highly dynamic. Conversely, the game with the slowest-declining state similarity is Amidar, where distant observations are highly correlated with current observations.
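The averaging described above can be sketched as follows. This is a minimal illustration, not the authors' code: it assumes each game provides a sequence of latent vectors $z_t$ as plain Python lists, and the helper names (`similarity_vs_distance`, `average_over_games`) are hypothetical. Each per-game curve plays the role of a shaded line, and their element-wise average plays the role of the main curve.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two latent vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def similarity_vs_distance(states: List[Sequence[float]], max_k: int) -> List[float]:
    """Average cos(z_t, z_{t+k}) over all valid t, for k = 1..max_k (one game)."""
    curve = []
    for k in range(1, max_k + 1):
        sims = [cosine_similarity(states[t], states[t + k]) for t in range(len(states) - k)]
        curve.append(sum(sims) / len(sims))
    return curve


def average_over_games(curves: List[List[float]]) -> List[float]:
    """Element-wise mean of per-game curves (main curve over the shaded lines)."""
    return [sum(values) / len(values) for values in zip(*curves)]
```

With real data, `states` would hold the world model's latents along an episode of one game; a curve that decays quickly, as reported for Up N Down, indicates a highly dynamic game.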

Did you test agent performance on the Atari100k benchmark without applying data augmentation? Only giving temporally near samples as the positive sample (without augmentation)

The paper's experiments with data augmentation use the default setting of K=10 future contrastive steps. We demonstrate the utility of data augmentation for making the AC-CPC objective more challenging: random crops require the world model to identify several key elements in the observations in order to accurately predict positive samples. We expect that lowering K and removing data augmentation could significantly simplify the objective and degrade performance.

There is another study to apply contrastive loss to MBRL agent [1], considering it as a reference can be helpful to make this paper better.

Thank you for providing this work as an additional possible reference. Could you share with us the full reference of the paper? The review formatting seems to have removed the reference link.

For AC-CPC Predictor, how did you input the actions? Did you concatenate them?

Yes, we concatenate the sequence of actions along the feature dimension with the model states $s_t$. We modified Section 3.1 of the paper to specify that actions are first concatenated to condition the predictor network.
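A minimal sketch of this conditioning, assuming discrete Atari actions are one-hot encoded before concatenation (the response does not say how actions are embedded) and that prediction step k receives the k actions leading up to its target. `predictor_input` is a hypothetical helper, not part of the paper's code.

```python
from typing import List


def one_hot(index: int, num_classes: int) -> List[float]:
    """One-hot encoding of a discrete action (an assumed embedding)."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


def predictor_input(state: List[float], future_actions: List[int], num_actions: int) -> List[float]:
    """Concatenate s_t with one-hot actions a_t, ..., a_{t+k-1} along the feature axis.

    The input width grows with k, so this sketch assumes one predictor head per step k.
    """
    features = list(state)
    for action in future_actions:
        features.extend(one_hot(action, num_actions))
    return features
```

Because the whole action sequence is given at once, the predictor can output the step-k prediction directly instead of unrolling a recurrent model k times.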

Can you test the performance when not using reconstruction loss for the world model training?

Our paper demonstrates the impact of AC-CPC on learning Transformer-based world model representations. We are studying whether AC-CPC can also learn representations for the encoder network without reconstruction.

why do they (IRIS and Δ-IRIS) show worse performance on the Median Human Normalized Score?

A method can achieve a higher Mean score with a lower Median if its best-performing games reach very high scores. This is why the Median metric is also important for evaluating the general performance of methods. IRIS and Δ-IRIS both achieve very good performance on Breakout, which raises their Mean score compared to TWM and STORM. The two metrics are important and complementary for effectively evaluating the performance of methods.
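A toy example makes the Mean/Median gap concrete. The scores below are hypothetical human-normalized values, not results from any paper: method A excels on a single game and is weak elsewhere, while method B is uniformly moderate.

```python
from statistics import mean, median

# Hypothetical human-normalized scores (1.0 = human level) on five games.
method_a = [0.2, 0.3, 0.3, 0.4, 5.0]  # one outlier game dominates the Mean
method_b = [0.5, 0.5, 0.6, 0.6, 0.7]  # uniformly moderate

for name, scores in [("A", method_a), ("B", method_b)]:
    print(f"method {name}: mean={mean(scores):.2f}, median={median(scores):.2f}")
```

Method A has the higher Mean (1.24 vs 0.58) but the lower Median (0.3 vs 0.6), the same pattern the response attributes to IRIS and Δ-IRIS on Breakout.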

Comment

Thank you for your rebuttal within the limited time. Our concerns are well addressed by your rebuttal, so I will keep my score.

Sorry for missing the reference; it is:

[1] Deng, Fei, Ingook Jang, and Sungjin Ahn. "DreamerPro: Reconstruction-free model-based reinforcement learning with prototypical representations." International Conference on Machine Learning. PMLR, 2022.

Official Review
Rating: 8

This paper introduces a novel world model for deep reinforcement learning that integrates DreamerV3 with transformer-based architectures and contrastive predictive coding. Instead of predicting only the next state, this world model is designed to predict multiple time steps into the future. The method demonstrates state-of-the-art performance on the Atari 100k benchmark.

Strengths

  • [S1] The paper presents a novel combination of established approaches, specifically transformer-based world models and CPC, which leads to strong performance results on the Atari 100k benchmark.
  • [S2] Ablation studies are conducted thoroughly, providing insight into the components of the model. (Though see Weakness [W2] for potential areas of improvement.)
  • [S3] The related work section on world models is thorough, with Table 1 providing a clear and valuable summary.

Weaknesses

  • [W1] The explanation of AC-CPC is somewhat vague, which is problematic given the importance of CPC in the proposed method. Some specific concerns include:
    • (i) The authors do not clarify the necessity of the representation network, which is only briefly mentioned in Eq. (1).
    • (ii) Notation introduced in Eq. (1), specifically $e_t^k$ and $\hat{e}_t^k$, is not referenced further in the text.
    • (iii) Visualizations of the representation networks and AC-CPC predictors in Figure 3(a) could enhance clarity.
  • [W2] The ablation studies section could benefit from improved consistency. Specifically, the order of the studies differs across Table 3, the main text, and Figure 6, which may lead to confusion. I recommend aligning the order across these references. Additionally, inconsistent naming conventions add to the lack of clarity. Suggested improvements:
    • In Figure 6(a), "Number of CPC steps" could be updated to "Number of contrastive steps" to match the main text.
    • The paragraph titled "Future Actions Conditioning" could be revised to "Action Conditioning," aligning with Figure 6(d).
  • [W3] Many figures suffer from small text, making them challenging to read without zooming. Figure 6, in particular, has unnecessary whitespace around plots, which could be adjusted to improve readability.
  • [W4] Including evaluations on other benchmarks, such as the DeepMind Control Suite, would strengthen the soundness of the findings. However, I acknowledge that this may be difficult to implement.

Questions

  • [Q1] Do the authors have any insight into why the performance drops when predicting 15 steps into the future?
Comment

Thank you for reviewing our paper and providing many valuable suggestions. Please find below our response to the concerns and questions that you raised in the review.

[W1] (i) The authors do not clarify the necessity of the representation network, which is only briefly mentioned in Eq. (1).

We updated Section 3.1 (lines 313-319) to further detail the role of the representation network in computing similarities. Unlike the original CPC paper, which experiments with continuous feature states, we use discrete latent states for the world model. The encoder network outputs discrete one-hot encoded representations $z_t$ sampled from a categorical distribution with 32 groups of 32 classes each. This requires learning a representation network that projects the discrete encoded states $z_t$ to contrastive features $e_t^k$ on which similarities are computed.
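To illustrate why a projection is needed, the sketch below flattens a 32×32 one-hot code $z_t$ into a 1024-dimensional vector and maps it to a continuous feature with a single linear layer. The single-layer architecture, the output size of 8 and the initialization scale are assumptions for illustration only; the response does not specify the representation network's architecture.

```python
import random
from typing import List

GROUPS, CLASSES = 32, 32  # categorical latent: 32 groups x 32 classes


def flatten_one_hot(class_indices: List[int]) -> List[float]:
    """Turn one class index per group into a flat 1024-dim one-hot vector z_t."""
    flat = [0.0] * (GROUPS * CLASSES)
    for group, cls in enumerate(class_indices):
        flat[group * CLASSES + cls] = 1.0
    return flat


class LinearRepresentation:
    """Hypothetical representation network: one linear layer from z_t to e_t."""

    def __init__(self, out_dim: int, rng: random.Random) -> None:
        in_dim = GROUPS * CLASSES
        self.weight = [[rng.gauss(0.0, 0.02) for _ in range(in_dim)] for _ in range(out_dim)]

    def __call__(self, flat_z: List[float]) -> List[float]:
        # Dense matrix-vector product; with one-hot input it sums one column per group.
        return [sum(w * x for w, x in zip(row, flat_z)) for row in self.weight]


rng = random.Random(0)
indices = [rng.randrange(CLASSES) for _ in range(GROUPS)]  # stand-in for a sampled z_t
rep = LinearRepresentation(out_dim=8, rng=rng)
e_t = rep(flatten_one_hot(indices))
```

The design point is that similarities are then computed in a continuous, learned feature space rather than between raw one-hot codes.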

[W1] (ii) Notation introduced in Eq. (1), specifically $e_t^k$ and $\hat{e}_t^k$, is not referenced further in the text.

[W1] (iii) Visualizations of the representation networks and AC-CPC predictors in Figure 3(a) could enhance clarity.

We also updated Section 3.1 to reference the notation introduced in Equation 1. We originally found it adequate to introduce the contrastive features as notation in Equation 1 along with other variables; the features were also referenced among other predictions on line 246. We agree that including the AC-CPC predictions and targets in Figure 3 (a) improves clarity by referencing this notation. We updated Figure 3 (a) to include a more detailed visualization of the AC-CPC objective, including the target representations $e_t^k$ and the predicted contrastive features $\hat{e}_t^k$ introduced in Equation 1. We hope this improves clarity.

[W2] The ablation studies section could benefit from improved consistency

Thank you for highlighting this. We updated the paper accordingly to ensure a consistent naming convention and ordering of the ablations. We increased the text size in figures to improve readability, especially Figures 7 and 8. We also increased the size of Figure 6 to limit unnecessary whitespace around plots.

[W4] Including evaluations on other benchmarks, such as the DeepMind Control Suite, would strengthen the soundness of the findings. However, I acknowledge that this may be difficult to implement.

As stated in our global response, our paper reports results on the Atari 100k benchmark, which is commonly acknowledged as sufficiently diverse to verify the effectiveness of algorithms. This is in line with common practice: the state-of-the-art methods we compare with (SimPLe, TWM, IRIS and STORM) all report results on Atari 100k. We are considering the suggested benchmarks and are working on obtaining such results. However, due to limited resources, we cannot guarantee that we will provide new results.

Do the authors have any insight into why the performance drops when predicting 15 steps into the future?

We find that increasing the number of AC-CPC steps has a positive impact on most games. However, results degrade when predicting 15 steps into the future. One possible reason is that some games include uncontrollable randomness that makes distant target states very hard to distinguish from other negative samples. The world model learns to identify the set of possible futures and assigns high similarity to probable future samples. However, if uncontrollable game randomness is such that a large number of negative samples belong to the set of probable futures, the world model cannot accurately predict the correct future sample without additional conditioning. Future work could therefore explore the use of “learned latent actions” to condition world models on environment changes that the agent does not control.
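The argument can be illustrated with a standard InfoNCE loss, the objective family CPC builds on. This sketch assumes plain dot-product scores without a temperature; the exact similarity function used by AC-CPC is not given in this discussion. When the N negatives are indistinguishable from the positive (as uncontrollable randomness can make them), every logit is equal and the loss is stuck at log(N + 1), whatever the prediction.

```python
import math
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def info_nce(prediction: Sequence[float], positive: Sequence[float],
             negatives: List[Sequence[float]]) -> float:
    """-log softmax probability of the positive among {positive} + negatives."""
    logits = [dot(prediction, positive)] + [dot(prediction, n) for n in negatives]
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(logit - m) for logit in logits))
    return log_sum_exp - logits[0]


prediction = [1.0, 0.0]
positive = [1.0, 0.0]
distinct_negatives = [[-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
plausible_negatives = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]  # look exactly like the positive

loss_distinct = info_nce(prediction, positive, distinct_negatives)
loss_ambiguous = info_nce(prediction, positive, plausible_negatives)
```

Here `loss_ambiguous` equals log(4) for 3 negatives, while `loss_distinct` is lower because the prediction can separate the positive from negatives that are clearly different.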

Comment

Thank you for addressing all of my concerns in your revision. I appreciate the effort you have put into clarifying the points I raised and improving the presentation of the paper. I understand that extending the experiments to other benchmarks is difficult due to computational constraints, and I acknowledge the standard practice of evaluating on the Atari 100k benchmark.

After reviewing the updated version, I am satisfied with the revisions, and I have updated my scores accordingly.

Official Review
Rating: 8

The paper explores the use of contrastive predictive coding (modified objective function) for learning efficient world models.

The paper is very well written.

Motivation: The work argues that predicting the next frame may not require a complex model (as successive video frames are quite similar, i.e. changes in distribution are very sparse). Table 1 is very helpful for distinguishing how the different methods relate to one another.

Proposed Method: The work builds upon transformer based world models and CPC loss to learn high level representations for improving the sample efficiency of the learning agent.

Ablations: The work conducts thorough ablations to understand the role of different components (number of steps in the CPC loss, data augmentation, world model architecture, i.e. RNN vs. Transformer) and the effect of action conditioning.

Strengths

  • The paper is very well written.
  • The paper does a good job of ablating various different components used in proposed method.
  • The paper does a good job of citing relevant literature.

Weaknesses

  • There's no weakness as such. The paper validates all the claims made in the introduction via careful experimentation and proper ablations.

Questions

  • "As shown in Figure 2, the cosine similarity between adjacent latent states of the world model is very high, making it relatively straightforward for the world model to predict the next state compared to more distant states."

In the introduction, the paper discusses cosine similarity for the proposed method. It would also be interesting to plot the cosine similarity for all the ablations the paper proposes (number of contrastive steps, future action conditioning, effect of data augmentation), as well as for the Dreamer baseline.

  • It would be helpful if the authors could include another benchmark to ensure that the proposed method also shows gains in other environments (like Crafter, mentioned in Section 2.2 on related work).
Comment

We sincerely thank you for your positive and encouraging feedback on our submission.

it will also be interesting to plot the cosine similarity for all the ablations the paper proposes (number of contrastive steps, future action conditioning, effect of data augmentation) (as well as the dreamer baseline).

We computed the average cosine similarities for the different ablations, including the Dreamer baseline. The plot of cosine similarities for the ablations is in the rebuttal supplementary materials. We find that the ablations do not significantly affect cosine similarities. This is expected, because the cosine similarities are computed from the encoder network outputs $z_t$, which are trained to be compressed encodings of image observations.

It will be helpful if authors can include another benchmark to make sure the proposed method shows gains on other environments too (like Crafter as mentioned in section 2.2 for related work).

As mentioned in our global response, we are working on the application of our method to Crafter and DMC.

Comment

I've read the rebuttal, and keep the same score.

Official Review
Rating: 6

The paper proposes to combine Transformer-based world models with contrastive learning, and outperforms DreamerV3 and recent Transformer-based MBRL agents on the Atari 100k benchmark.

Strengths

  • The paper provides a detailed overview of DreamerV3 and recent Transformer-based MBRL agents. Table 1 provides a good summary of implementation differences.
  • There is an adequate ablation study to justify the design of the action-conditioned CPC loss and the number of CPC steps used. The finding that random crop and resize works better than random shift is quite intriguing.
  • Sufficient implementation details are provided.
  • The paper achieves strong results on the Atari 100k benchmark.

Weaknesses

  • The paper did not discuss similar works that combine contrastive learning with Dreamer, notably DreamingV2 and DreamerPro. While targeting different tasks, DreamingV2 has a quite similar contrastive loss design.
  • From Table 2, it is quite intriguing that the proposed method TWISTER achieves a significantly higher score on Gopher than previous methods. Furthermore, Table 10 shows that without the contrastive loss, the score on Gopher drops much more than on other games. I am a bit concerned that the overall improvement brought by contrastive learning might be dominated by this single game.
  • Based on the first two points, the paper could be strengthened by showing the general applicability of the method. For example, the paper can investigate more tasks (e.g., Crafter, DMC), or more backbones (e.g., S4/S5, see S4WM and R2I)

Questions

Some clarifying questions on implementation details

  • Did you use an EMA encoder or any stop-gradient operation when obtaining the future latents $z'_{t+k}$?
  • Is the augmentation consistent across time (e.g., cropping the same region across time rather than choosing a random cropping region at each time step)?
Comment

Did you use an EMA encoder or any stop-gradient operation when obtaining future latents $z'_{t+k}$?

As in the original CPC paper and related approaches such as vq-wav2vec, we did not use an EMA encoder or a stop-gradient operation. We let the AC-CPC loss gradients back-propagate to the encoder parameters, like the other objectives optimized by the world model.

Is the augmentation consistent across time (e.g., cropping the same region across time rather than choosing a random cropping region at each time step)?

The random crop and resize augmentation is applied independently for each image in the input batch. This prevents the model from easily identifying images from the same sequence, which could simplify the AC-CPC objective.
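A minimal sketch of per-image augmentation, assuming grayscale frames stored as nested lists and nearest-neighbour resizing; the `min_scale` range and the helper names are hypothetical, not taken from the paper. The key point is that a fresh crop is drawn for every frame `batch[b][t]` rather than once per sequence.

```python
import random
from typing import List

Image = List[List[float]]  # H x W grayscale frame (simplifying assumption)


def random_crop_resize(img: Image, min_scale: float, rng: random.Random) -> Image:
    """Crop a random sub-window and resize it back to H x W (nearest neighbour)."""
    h, w = len(img), len(img[0])
    scale = rng.uniform(min_scale, 1.0)
    ch, cw = max(1, int(h * scale)), max(1, int(w * scale))
    top = rng.randint(0, h - ch)
    left = rng.randint(0, w - cw)
    return [[img[top + (i * ch) // h][left + (j * cw) // w] for j in range(w)]
            for i in range(h)]


def augment_batch(batch: List[List[Image]], min_scale: float, seed: int) -> List[List[Image]]:
    """batch[b][t]: an independent crop for each image, not shared across time."""
    rng = random.Random(seed)
    return [[random_crop_resize(frame, min_scale, rng) for frame in seq] for seq in batch]
```

Sharing one crop per sequence would leave identical framing across time steps, a cue that could let the model match frames of the same sequence without understanding their content.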

Comment

Thank you for addressing all my concerns. I have increased my score.

Comment

Thank you for your constructive feedback on our paper. Please find below our response to the concerns and questions that you raised in the review.

The paper did not discuss similar works that combine contrastive learning with Dreamer, notably DreamingV2 and DreamerPro. While targeting different tasks, DreamingV2 has a quite similar contrastive loss design.

Section 2 of our paper discusses three main areas of related work. We start with a general overview of model-based reinforcement learning methods. We then discuss the Transformer-based approaches proposed for world model learning. Finally, we introduce the CPC method and our motivation for applying the CPC objective to model-based reinforcement learning.

We are familiar with the DreamerPro and Dreaming algorithms, which are related to contrastive world model learning. Both methods propose reconstruction-free variants of the Dreamer algorithm to improve robustness on continuous control tasks: DreamerPro learns representations with SwAV, while Dreaming uses a multi-step contrastive objective. Since the DreamerPro paper's appendix applies the method to Atari games, we found it worthwhile to evaluate it on the Atari 100k benchmark using the official code provided by the authors. Using 5 training seeds, we obtained normalized Mean and Median scores of 79% and 28%, which are lower than those of DreamerV3. We provide a modified table with DreamerPro results in the rebuttal supplementary materials for reference.

However, since these works focused on learning representations without a reconstruction loss for continuous control tasks, using the standard recurrent Dreamer world model, we did not find it necessary to include them in the related work section. Although it employs a contrastive objective, the Dreaming method differs from CPC in key aspects. Notably, its prediction layer is applied recurrently into the future, with an overshooting distance K set to 3. In contrast, AC-CPC makes predictions directly by providing the concatenated sequence of future actions as input to the predictor network. This allows TWISTER to make predictions over longer distances without risking vanishing or exploding gradients during back-propagation through time.

without the contrastive loss, the score on Gopher drops much more significantly than on other games. I am a bit concerned that the overall improvement brought by contrastive learning might be dominated by this single game

We find that AC-CPC has a positive impact on Gopher, helping the agent accurately model the gopher’s position and tunnels for planning. As stated in Section 4.2, we also find that AC-CPC improves performance on games with small moving objects such as Breakout, Pong and Asterix. The same holds for Up N Down, Kangaroo and Road Runner, whose improvements significantly raise the Mean and Median scores.

The Median metric reflects performance on the average-performing games in each experiment. It is useful for evaluating the general performance of algorithms because it ignores outlier tasks like Gopher; for instance, Δ-IRIS faces a similar situation with Breakout. TWISTER achieves a higher Median score than previous Transformer-based approaches, which indicates a positive impact not only on Gopher but also on average-performing games like Breakout, Pong, Up N Down, Kangaroo or Road Runner.

Another known metric for comparing the performance of average-performing games is IQM. IQM shows the aggregated performance on the middle 50% of combined runs, ignoring outlier games. We attached an updated ablation figure for the number of contrastive steps in the supplementary materials showing the impact of AC-CPC on IQM metric scores. We can see that AC-CPC representation learning has a significant impact on IQM scores, validating its effectiveness on diverse games.
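IQM is the mean of the middle 50% of scores after pooling all runs. The sketch below mirrors the trimmed-mean definition (25% cut from each end, matching `scipy.stats.trim_mean` with `proportiontocut=0.25`, which is how the rliable library computes IQM, to our understanding). The pooled scores are hypothetical and only show how a single outlier inflates the Mean but not the IQM.

```python
from typing import Sequence


def iqm(scores: Sequence[float]) -> float:
    """Interquartile mean: drop the lowest and highest 25% of pooled runs, average the rest."""
    values = sorted(scores)
    cut = int(0.25 * len(values))  # same truncation as scipy.stats.trim_mean
    middle = values[cut:len(values) - cut]
    return sum(middle) / len(middle)


# Hypothetical pooled human-normalized scores (games x seeds, flattened).
pooled = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 9.0]
print(f"mean={sum(pooled) / len(pooled):.3f}, iqm={iqm(pooled):.3f}")
```

Unlike the Median, IQM still averages over half of the runs, so it is less noisy while remaining robust to outlier games.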

the paper could be strengthened by showing the general applicability of the method. For example, the paper can investigate more tasks (e.g., Crafter, DMC), or more backbones (e.g., S4/S5, see S4WM and R2I)

As stated in our global rebuttal response, our paper reports results on the Atari 100k benchmark, which is commonly acknowledged as sufficiently diverse to verify the effectiveness of algorithms. This is in line with common practice: the state-of-the-art methods we compare with (SimPLe, TWM, IRIS and STORM) all report results on Atari 100k. We are considering the suggested benchmarks and backbones (S4) and are working on obtaining such results. However, due to limited resources, we cannot guarantee that we will provide new results.

Comment

We thank the Reviewers for their insightful comments and valuable feedback. We appreciate the recognition from Reviewers 1PDj, 2k8g and TrKK of the “good” and “excellent” paper presentation. We also thank Reviewer z6Gc for the suggested improvements to the paper presentation. We thank Reviewers 2k8g and TrKK for their positive and encouraging feedback on our paper: “The paper validates all the claims made in the introduction via careful experimentation and proper ablations”, “We can enjoy to read this paper, because of well summarized related works, well discussed their motivation and diverse empirical ablation studies”. We also thank Reviewer 1PDj for the constructive comments on our paper. We address the main feedback from reviewers below and summarize the changes made to the submission.

  1. Following the suggestions of Reviewer z6Gc, we updated Figure 3 (a) to include a more detailed visualization of the AC-CPC objective, including the target representations $e_t^k$ and predicted contrastive features $\hat{e}_t^k$ introduced in Equation 1. We also updated Section 3.1 (lines 313-319) to detail the role of the representation network in computing similarities. We reordered the ablation figures and the results in Table 3 to match the order of the subsections. The text size of figures has also been increased to improve readability.

  2. We provide additional resources illustrating our responses in the supplementary materials. Following the comments of Reviewer 2k8g, we plot cosine similarities of image observations for the different ablations. We also provide a modified results table that includes DreamerPro for additional comparison. Finally, following the concerns raised by Reviewer 1PDj about the impact of outlier tasks, we attached an updated ablation figure for the number of contrastive steps showing IQM scores.

  3. A common point of feedback from Reviewers 1PDj, 2k8g and z6Gc was that the findings of the paper would be strengthened by evaluating on other benchmarks such as the DeepMind Control Suite (DMC) and Crafter. Our paper reports results on the Atari 100k benchmark, which is commonly acknowledged as sufficiently diverse to verify the effectiveness of algorithms. This is in line with common practice: the state-of-the-art methods we compare with (SimPLe, TWM, IRIS and STORM) all report results on Atari 100k. We are considering the suggested benchmarks and are working on obtaining such results. However, due to limited resources, we cannot guarantee that we will provide new results.

Comment

We would like to thank all the reviewers for their valuable feedback and responses. We also appreciate the decision of Reviewers 1PDj and z6Gc to increase their ratings of our paper: “Thank you for addressing all my concerns”, “Thank you for addressing all of my concerns in your revision”.

We performed experiments on the DeepMind Control Suite to evaluate TWISTER under continuous action spaces. Following DreamerV3, we evaluate on 20 tasks using only high-dimensional image observations as inputs and a training budget of 1M environment steps. The following table compares the official DreamerV3 results with TWISTER (3 seeds) under this 1M-step budget. Bold numbers indicate the best-performing method for each task.

| Task | DreamerV3 (2023) | DreamerV3 (2024) | TWISTER (ours) |
|---|---|---|---|
| Acrobot Swingup | 210.0 | **229** | 220.4 |
| Ball In Cup Catch | 957.1 | **972** | 964.7 |
| Cartpole Balance | 996.4 | 993 | **997.9** |
| Cartpole Balance Sparse | **1000.0** | 964 | **1000.0** |
| Cartpole Swingup | 819.1 | **861** | 791.7 |
| Cartpole Swingup Sparse | **792.9** | 759 | 674.3 |
| Cheetah Run | 728.7 | **836** | 704.3 |
| Finger Spin | 818.5 | 589 | **976.9** |
| Finger Turn Easy | 787.7 | 878 | **927.4** |
| Finger Turn Hard | 810.8 | 904 | **940.3** |
| Hopper Hop | **369.6** | 227 | 312.7 |
| Hopper Stand | 900.6 | 903 | **935.6** |
| Pendulum Swingup | 806.3 | 744 | **840.3** |
| Quadruped Run | 352.3 | 617 | **698.9** |
| Quadruped Walk | 352.6 | 811 | **929.1** |
| Reacher Easy | 898.9 | 951 | **967.7** |
| Reacher Hard | 499.2 | **862** | 626.2 |
| Walker Run | **757.8** | 684 | 708.9 |
| Walker Stand | **976.7** | 976 | 976.6 |
| Walker Walk | 955.8 | **961** | 951.8 |
| Mean | 739.6 | 786 | **807.3** |
| Median | 808.5 | 861 | **928.3** |
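The Mean and Median rows aggregate the 20 per-task scores. As a quick consistency check, the sketch below recomputes them from the TWISTER column of the table above; it yields 807.285 and 928.25, in line with the reported 807.3 and 928.3 after rounding.

```python
from statistics import mean, median

# TWISTER (ours) scores copied from the DMC table above (1M environment steps).
twister_dmc = {
    "Acrobot Swingup": 220.4, "Ball In Cup Catch": 964.7,
    "Cartpole Balance": 997.9, "Cartpole Balance Sparse": 1000.0,
    "Cartpole Swingup": 791.7, "Cartpole Swingup Sparse": 674.3,
    "Cheetah Run": 704.3, "Finger Spin": 976.9,
    "Finger Turn Easy": 927.4, "Finger Turn Hard": 940.3,
    "Hopper Hop": 312.7, "Hopper Stand": 935.6,
    "Pendulum Swingup": 840.3, "Quadruped Run": 698.9,
    "Quadruped Walk": 929.1, "Reacher Easy": 967.7,
    "Reacher Hard": 626.2, "Walker Run": 708.9,
    "Walker Stand": 976.6, "Walker Walk": 951.8,
}

scores = list(twister_dmc.values())
twister_mean = mean(scores)
twister_median = median(scores)  # average of the 10th and 11th sorted scores
print(f"tasks={len(scores)} mean={twister_mean:.3f} median={twister_median:.2f}")
```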

We find that TWISTER achieves state-of-the-art performance with stable learning across all tasks. The application of Transformer-based world models to continuous control tasks has so far been very limited. We hope our work will inspire researchers to further study the potential benefits of Transformer-based world models and representation learning for continuous control tasks.

AC Meta-Review

The paper presents a significant advance in model-based reinforcement learning by effectively combining Transformer-based world models with contrastive predictive coding, achieving state-of-the-art human-normalized performance on Atari 100k. The authors thoroughly addressed reviewer concerns by improving figure clarity, providing detailed technical explanations, and demonstrating the method's effectiveness through new DMC results. The comprehensive ablation studies, clear presentation, and strong empirical results across multiple benchmarks, combined with thorough responses during the rebuttal, make this a valuable contribution worthy of acceptance. I recommend accepting this paper for ICLR 2025.

Additional Comments on Reviewer Discussion

During the rebuttal period, the reviewers raised concerns about additional benchmark evaluations, unclear AC-CPC explanations, figure clarity, and similarities to existing approaches. The authors addressed these by providing new DMC results across 20 tasks, detailed technical clarifications, improved visualizations, and explanations distinguishing their method from similar work. All reviewers were satisfied with the responses. Given the strong initial results, thorough technical clarifications, and successful DMC validation, the paper warrants acceptance.

Final Decision

Accept (Spotlight)