Parallelizing Model-based Reinforcement Learning Over the Sequence Length
This paper introduces the PaMoRL framework, which parallelizes MBRL over the sequence length, improving training speed and sample efficiency.
Abstract
Reviews and Discussion
This paper introduces a new model-based reinforcement learning method, called PaMoRL, which parallelizes world model training (PWM) and eligibility trace estimation (PETE). Specifically, the parallelization is achieved by leveraging efficient parallel scan operations. On the commonly used Atari 100K and DM Control benchmarks, PaMoRL achieves competitive performance while enjoying significant improvements in training efficiency.
Strengths
- The writing is generally clear and easy to follow.
- Accelerating eligibility trace computation with parallel scan is a novel design, which can also benefit model-free RL methods.
- The experiments are comprehensive and solid.
Weaknesses
- The core idea of this paper is about parallel scans. However, similar techniques have been used in Mamba [37] to accelerate computation. I noticed the authors cited Mamba, but they did not give much discussion of the connection.
- What's the benefit of the proposed modified linear attention compared to SSM with scan? Appendix B mentioned linear attention is more expressive. Is there any quantitative evidence to show the benefit?
- Could the authors report the results using metrics recommended in [i]? This will give a more reliable comparison between different methods.
- A recent work [ii] also studies how to make world model training fast. How does PaMoRL compare to it?
- How does the GPU overhead introduced by the scan operations scale with different configurations (say, sequence length or other parameters)? Would the overhead become unacceptable when we change the setting, or is it always less than, say, X% of the total GPU memory used?
[i] https://github.com/google-research/rliable
[ii] Cohen, Lior, et al. "Improving Token-Based World Models with Parallel Observation Prediction." arXiv preprint arXiv:2402.05643 (2024).
Minor issues:
- Labels and ticks in Figure 1 can be made larger to improve readability (same for Figure 3 (b,c) and the result figures in Appendix).
Questions
Please see weaknesses above.
Limitations
Yes, the limitations have been discussed in Section 5.
Thank you very much for reviewing our paper and for your many detailed comments. The following are responses to the weaknesses and questions you listed:
W1: More discussion of the connection to Mamba[1].
R1: Both our PaMoRL method and Mamba use parallel scan algorithms for acceleration, and both have data-dependent decay rates (i.e., the selective mechanism described in Mamba). However, Mamba uses a special IO-aware parallel scan algorithm for efficient training, which reduces the number of reads and writes between SRAM and HBM on the GPU through kernel fusion, i.e., by carefully partitioning the parameters that need to be saved from those that need to be recomputed; this approach exploits hardware characteristics to improve training efficiency, but only once the model architecture is fixed. Our PaMoRL method uses a generalized parallel scan algorithm that does not require any manual partitioning of parameters and is therefore applicable to a wide range of model architectures. We thank you for your suggestions and will add more discussion of the connection to Mamba in the "Background" section of the revised version of our paper.
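To make the contrast concrete, both approaches ultimately parallelize a gated linear recurrence of the form h_t = a_t * h_{t-1} + b_t (with a data-dependent decay a_t). The sketch below is purely illustrative and not either paper's implementation; the function names `combine` and `scan_recurrence` are our own. It shows how a generic associative scan (here in the Kogge-Stone pattern) handles such a recurrence without any architecture-specific kernel work:

```python
def combine(left, right):
    """Compose two affine maps h -> a*h + b (left applied first).
    Associativity of this operator is what makes a parallel scan valid."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def scan_recurrence(a, b, h0=0.0):
    """Compute h_t = a_t * h_{t-1} + b_t for all t via an inclusive
    Kogge-Stone-style scan. The inner comprehension over i is fully
    parallel on a GPU; only the log2(n) outer steps are sequential."""
    pairs = list(zip(a, b))
    n, d = len(pairs), 1
    while d < n:
        pairs = [pairs[i] if i < d else combine(pairs[i - d], pairs[i])
                 for i in range(n)]
        d *= 2
    # prefix (A_t, B_t) composes steps 1..t, so h_t = A_t * h0 + B_t
    return [A * h0 + B for A, B in pairs]
```

Because `combine` only requires associativity, the same scan skeleton accepts any architecture whose recurrence can be phrased as composable maps, which is the generality the response above argues for.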
W2: What are the benefits of improved linear attention compared to SSM with scanning? Is there any quantitative evidence that this linear attention is more expressive?
R2: Compared to SSM with scanning, improved linear attention has a data-dependent decay rate (i.e. gating mechanism), which allows the world model to filter out irrelevant information and remember relevant information indefinitely. We added SSM training curves to the ablation experiments in Figure 3 in the PDF attached to "Common Response", and the results show that Improved Linear Attention outperforms SSM on multiple tasks. Notably, DreamerV3 [2], which also has a gating mechanism, outperforms SSM in most environments, again illustrating the importance of data-dependent decay rates.
W3: Can the authors report results using the metrics recommended in [3]?
R3: We appreciate your suggestion, and in Figure 2 in the attached PDF from "Common Response" we use the metrics recommended in [3] to report the results.
W4: Comparison between a recent work [4] and PaMoRL?
R4: Recent work [4] focuses on accelerating token-based MBRL algorithms: the world model is trained with the RetNet [5] architecture for faster training, and the process of predicting the next observation is further accelerated using block parallelization. In contrast, PaMoRL does not need to predict multiple fine-grained tokens, its world model architecture adds a data-dependent decay rate compared to RetNet, and it can use a recurrent mode with minimal computational overhead when predicting the next observation. We have added comparisons to [4] in both Figure 1 and Figure 2 in the PDF attached to "Common Response". Our experiments show that PaMoRL matches or exceeds the method proposed in [4] on mean, median, and interquartile mean (IQM) human-normalized scores as well as the optimality gap, and also trains faster since PaMoRL does not need to predict multiple tokens.
W5: How does the GPU overhead from parallel scanning scale with different configurations? Does the overhead become unacceptable when changing settings?
R5: We have added experiments on GPU memory usage vs. wall-clock time as sequence length varies in Figure 4 (Right) in the PDF attached to "Common Response" (based on an RTX 3090 GPU). The results show that parallel scanning incurs additional GPU memory usage as sequence length increases; however, at the maximum sequence length of 1024, compared with the minimum of 16, the additional GPU memory overhead is less than 14% (~3.36 GB).
W6: The labels and markers in Figure 1 could be made larger to improve readability (as are the result plots in Figure 3 (b, c) and the Appendix).
R6: We appreciate your suggestions. We will increase the font size of Figures 1 and 3 in the revised version of the paper. You can find the corrected versions of Figures 1 and 3 in the PDF attached to "Common Response".
[1] Albert Gu, et al. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv preprint arXiv:2312.00752 (2023).
[2] Danijar Hafner, et al. "Mastering diverse domains through world models." arXiv preprint arXiv:2301.04104 (2023).
[3] https://github.com/google-research/rliable
[4] Cohen, Lior, et al. "Improving Token-Based World Models with Parallel Observation Prediction." ICML (2024).
[5] Yutao Sun, et al. "Retentive Network: A Successor to Transformer for Large Language Models." arXiv preprint arXiv:2307.08621 (2023).
I apologize for the delay in my response. I want to thank the authors for their detailed reply and for addressing my questions. Based on the additional information provided, I've changed my score to lean acceptance. I appreciate the contribution to accelerating MBRL and achieving high performance. However, I remain somewhat unconvinced about the algorithmic novelty. While the PETE component is new, the combination of SSM, parallel scan, and gating is also utilized in Mamba, as confirmed by the authors in their reply.
We appreciate your consideration of our rebuttal. We would like to further note that although components of our PWM and Mamba seem to share the same spirit in the design of the SSM paradigm, parallel scanning, and gating modules, this is only a prerequisite for our PWM and Mamba to be effective in their respective domains. We sincerely believe that "the devil is in the details", meaning that a model's specific implementation determines its sample efficiency or hardware efficiency.
Specifically, Mamba, due to the need to maintain self-consistency with previous work in the SSM family, must cater to the paradigm of classical state space models representing continuous differential equations, and must be parameterized and discretized using special tricks to achieve the gating mechanism implicitly. We recognize this limitation of Mamba and use a more "simple yet effective" gating mechanism to step outside the SSM paradigm.
Furthermore, although both our PWM and Mamba use parallel scanning algorithms for acceleration, parallel scanning is one of the fundamental computational paradigms, along with FFT and convolution. We thus believe that the algorithmic novelty lies in the customization of a specific domain or task. From this point of view, our parallel scanning algorithm is significantly different from Mamba. To satisfy the need for flexibility in model architecture design in the MBRL method, the parallel scanning method we use is inspired by HPC hardware design and is more concerned with generality, and thus is compatible with arbitrary architectures (as long as they satisfy the parallelization conditions) compared to the model architecture-specific parallel scanning method used by Mamba.
Please let us know if you have further considerations or questions, as we are eager to improve the quality of our paper.
Thank you for the explanations. Please include the above discussions in the updated version. I don't have further questions.
Thank you for the reminder. The above discussion will be integrated into the revised version of our paper.
The paper proposes a novel framework to parallelize model-based RL, comprising two improvements: parallelizing the world model and parallelizing eligibility traces. The authors demonstrate a dramatic speed-up in training without sacrificing inference efficiency. The proposed method achieves state-of-the-art score performance and significantly reduces the runtime of two important components.
Strengths
- A significant speed-up of model generation runtime and eligibility trace estimation runtime;
- The first paper to point out that the computational process of eligibility traces can be parallelized over the sequence length, which would be super helpful to the RL community to improve the efficiency of algorithms;
- Novel modifications to the RSSM module in the Dreamer algorithm to eliminate non-linear dependencies, each residual block of the sequence model consists of a modified linear attention module and a GLU module;
- A clear illustration of the proposed parallel algorithm with PyTorch (pseudo-)code;
- Significant reduction in runtimes of the sequence model and eligibility trace estimation: sequence model computation achieves 7.2× and 16.6× speedups compared to sequential rollout using the Kogge-Stone and Odd-Even scan algorithms, respectively, at a sequence length of 64.
Weaknesses
This is an excellent paper and I didn't find any significant flaws.
Questions
NA
Limitations
NA
We greatly appreciate your high evaluation of our paper and your detailed feedback on its strengths. We are very pleased to see that you have recognized the innovation and contribution of our proposed PaMoRL method, and we will further improve and enhance our paper. Please feel free to let us know if you have any further suggestions. We look forward to your further feedback.
This paper proposes a new framework named Parallelized Model-based Reinforcement Learning (PaMoRL) to improve the training speed of MBRL methods. PaMoRL employs a parallel scan technique to parallelize world model learning and eligibility trace estimation. With experiments on the Atari and DMC domain, the paper claims that PaMoRL can improve the sample efficiency of MBRL methods while significantly reducing training time.
Strengths
- The paper tackles a significant problem in MBRL.
- The method is evaluated over two benchmarks, including a sufficient number of tasks.
Weaknesses
- The paper's contribution is limited. Parallelized World Model and Parallelizable Eligibility Trace Estimation merely employ existing methods in this MBRL scenario, which are more like programming tricks than novel techniques.
- Analysis and theoretical derivation of the parallel scan methods are missing.
- Experiment results are reported over only three seeds, which can hardly be considered reliable.
Questions
How do you extrapolate the training FPS of methods other than DreamerV3 and PaMoRL from other GPUs? Could you provide some additional details?
Limitations
Fine
Thank you very much for reviewing our paper and for your many detailed comments. The following are responses to the weaknesses and issues you listed:
W1: Parallelized World Model and Parallelizable Eligibility Trace Estimation use only old methods and are more like programming tricks than novel techniques.
R1: Our PaMoRL method draws inspiration from recent well-known work in the LLM field (e.g., Mamba [1], RWKV [2], and Linear Transformer [3]) and follows the three conditions in the "Parallel Scan" paragraph (L95-L97) of the "Background" section of our paper to design the novel Parallelized World Model architecture. For Parallelizable Eligibility Trace Estimation, we also find that this widely used return estimation method in the RL domain naturally satisfies the parallelizability conditions, further accelerating MBRL training in the policy learning phase. From the perspective of the MBRL domain, we have good reasons to believe that our PaMoRL method uses novel techniques rather than mere programming tricks.
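Why the eligibility trace (λ-return) estimate satisfies the parallelizability conditions can be seen in a few lines: the standard backward recurrence G_t = r_t + γ[(1-λ)v_{t+1} + λG_{t+1}] is affine in G_{t+1}, so the whole sequence reduces to a prefix composition of affine maps, which any parallel scan can handle. The sketch below assumes the standard DreamerV3-style λ-return formulation (not copied from the paper), and the function names are our own:

```python
def lambda_returns_loop(rewards, values, gamma, lam):
    """Sequential backward recursion: G_t = r_t + g*((1-lam)*v_{t+1} + lam*G_{t+1}),
    bootstrapped with the final value estimate (values has length T+1)."""
    T = len(rewards)
    G, nxt = [0.0] * T, values[T]
    for t in reversed(range(T)):
        nxt = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * nxt)
        G[t] = nxt
    return G

def lambda_returns_scan(rewards, values, gamma, lam):
    """Same quantity via an associative scan over affine maps G -> a*G + b,
    processed in reverse time order; the scan itself can run in parallel."""
    T = len(rewards)
    # step t is the affine map G_{t+1} -> (gamma*lam)*G_{t+1} + b_t
    pairs = [(gamma * lam, rewards[t] + gamma * (1 - lam) * values[t + 1])
             for t in reversed(range(T))]
    n, d = len(pairs), 1
    while d < n:  # Kogge-Stone inclusive scan with affine composition
        pairs = [pairs[i] if i < d else
                 (pairs[i - d][0] * pairs[i][0],
                  pairs[i][0] * pairs[i - d][1] + pairs[i][1])
                 for i in range(n)]
        d *= 2
    # prefix k composes the last k+1 steps; apply each to the bootstrap value
    return [A * values[T] + B for A, B in reversed(pairs)]
```

Both functions produce identical returns; the scan version simply trades the O(T) sequential dependency for O(log T) parallel steps.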
Moreover, as you point out, we have empirically demonstrated through sufficient experiments that our PaMoRL method significantly speeds up training while exceeding the baseline in most tasks, achieving the best of both worlds in terms of hardware efficiency and data efficiency, an achievement that we believe represents a significant contribution to the MBRL field.
W2: Lack of analysis and theoretical derivation of the parallel scanning method.
R2: The analysis and theoretical derivation of the two parallel scan methods (Kogge-Stone and Odd-Even) used in our paper can be found in Chapter 1.4.1 of [4] and Chapter 39.2.1 of [5], respectively. We thank you for your suggestion and will add more discussion and analysis of the theoretical derivations of these two parallel scan methods in the "Background" section of the revised version of the paper.
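For readers without [4, 5] at hand, the work-efficient scan analyzed there can be summarized compactly. This is a textbook up-sweep/down-sweep prefix-sum sketch in the style of Harris' GPU Gems chapter [5], not the paper's GPU kernel; the function name is ours:

```python
def work_efficient_scan(xs):
    """Exclusive prefix sum via up-sweep/down-sweep (Blelloch-style),
    as analyzed in [5]. Assumes len(xs) is a power of two; each inner
    loop corresponds to one parallel step on real hardware."""
    a = list(xs)
    n, d = len(a), 1
    while d < n:                      # up-sweep: build a reduction tree
        for i in range(2 * d - 1, n, 2 * d):
            a[i] += a[i - d]
        d *= 2
    a[n - 1] = 0                      # identity element seeds the down-sweep
    d = n // 2
    while d >= 1:                     # down-sweep: distribute partial sums
        for i in range(2 * d - 1, n, 2 * d):
            a[i - d], a[i] = a[i], a[i] + a[i - d]
        d //= 2
    return a
```

Replacing `+` with any associative operator (e.g., composition of the affine maps arising in gated recurrences) yields the general scan; the derivation in [4, 5] shows this variant performs O(n) work over O(log n) steps, versus O(n log n) work for the Kogge-Stone pattern.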
W3: The experimental results are reported for only three seeds, which can hardly be considered confident.
R3: We appreciate your suggestions for the experiments. We aligned the experimental setup with previous well-known work [6-12] by adding two additional random seeds, and we took the suggestion from Reviewer earm to compare different methods using the metrics recommended in [13], which gives more reliable comparisons.
Q1: Additional details of extrapolating training FPS from other GPUs for methods other than DreamerV3 and PaMoRL?
A1: Among the other methods in Figure 1, IRIS [7], REM [9], and TWM [10] were evaluated on A100 GPUs, while SimPLe [11], STORM [12], and the other model-free RL methods were evaluated on P100 GPUs. The extrapolation we employ is consistent with the setup used in DreamerV3 [6], where the V100 is assumed to be twice as fast as the P100 and the A100 twice as fast as the V100.
[1] Albert Gu, et al. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv preprint arXiv:2312.00752 (2023).
[2] Bo Peng, et al. "RWKV: Reinventing RNNs for the Transformer Era." ACL (2023).
[3] Zhen Qin, et al. "The Devil in Linear Transformer." EMNLP (2022).
[4] Blelloch, et al. "Prefix sums and their applications." School of Computer Science, Carnegie Mellon University Pittsburgh (1990).
[5] Harris, et al. "Parallel prefix sum (scan) with CUDA." GPU gems (2007).
[6] Danijar Hafner, et al. "Mastering diverse domains through world models." arXiv preprint arXiv:2301.04104 (2023).
[7] Vincent Micheli, et al. "Transformers are sample efficient world models." ICLR (2023).
[8] Max Schwarzer, et al. "Data-Efficient Reinforcement Learning with Self-Predictive Representations." ICLR (2021).
[9] Cohen, Lior, et al. "Improving Token-Based World Models with Parallel Observation Prediction." ICML (2024).
[10] Jan Robine, et al. "Transformer-based World Models Are Happy With 100k Interactions." ICLR (2023).
[11] Lukasz Kaiser, et al. "Model-Based Reinforcement Learning for Atari." ICLR (2020).
[12] Weipu Zhang, et al. "STORM: Efficient Stochastic Transformer based World Models for Reinforcement Learning." NeurIPS (2023).
[13] https://github.com/google-research/rliable
I appreciate the authors' great effort in addressing my questions. Despite room for improvement in the writing, the empirical results demonstrating efficiency prove the work's contribution to the community. I have raised my score to 5, and I hope the authors can elaborate on the analysis and theoretical derivation of parallel scan methods in a future revision.
We appreciate your consideration of our rebuttals and suggestions on the content and writing of the paper. We will do our best to improve the writing of our paper and incorporate your suggestions for detailing the analysis and theoretical derivation of the parallel scanning method in the revised version. We believe this will provide deeper insights for readers interested in the theory of parallel scanning.
Please let us know if you have any further comments or suggestions, as we are eager to improve the quality of our paper.
Model-based RL algorithms are popular due to their strong data efficiency compared to model-free alternatives. However, most MBRL algorithms learn a recurrent world model whose time complexity scales linearly with the input sequence length. This paper proposes several changes to the DreamerV3 algorithm that allow for parallelization of world model training across the sequence (temporal) dimension, in particular the use of the Odd-Even parallel scan algorithm plus architectural changes that permit the use of parallel scans. Experiments are conducted on Atari 100k and DMControl (proprioceptive + pixel inputs), and results indicate that the proposed method achieves both better data efficiency and better hardware efficiency compared to prior work.
Post-rebuttal update: I believe that the authors have addressed my main concerns. I have increased my score from 4 -> 5 and soundness from 2 -> 3 to reflect that.
Strengths
- The problem studied in this paper is interesting, timely, and highly relevant to the NeurIPS community. This paper extends (at least in spirit) a series of recent papers from the LLM literature (e.g. RWKV, Mamba) to MBRL algorithms that can similarly benefit from parallelization of training while maintaining a low memory footprint, since many MBRL methods leverage RNNs. It is refreshing to see this kind of work in the MBRL space.
- The proposed architectural and algorithmic changes are fairly intuitive and seemingly simple to implement, while there is a big potential upside in terms of training parallelization. While the changes are simple, they are not trivial to come up with. I absolutely consider this a key strength of this paper.
- Experimental results are promising. It appears that the method is effective in terms of parallelizing model training while maintaining a low memory footprint (mostly due to the choice of parallel scan algorithm). Data efficiency / task performance is comparable or better than previous work on MBRL without planning.
Weaknesses
While I find the technical contributions of this paper quite intriguing, there's a few things that I find somewhat problematic:
- The key selling point of the proposed method is its parallelism. However, there's pretty limited evidence that it actually achieves this goal. Sure, runtime (ms) is generally quite a bit lower, but I would have expected to see a comparison in training wall-time vs. sequence length used to train the model. Intuitively, the computational gains would increase with the sequence length vs. vanilla DreamerV3, but I did not see any experiments of this sort.
- Along the same lines, I find it a bit odd that most experiments emphasize better data efficiency when that is not really a central part of the paper. The authors even state themselves that "our main goal is to maximize the hardware efficiency of existing MBRL methods, rather than pursuing extreme sample efficiency." (L188), so why do the experimental results primarily measure data (sample) efficiency rather than hardware efficiency? It seems like the proposed changes improve data efficiency so I understand why the authors would highlight that, but it seems mostly orthogonal to the actual problem that the authors set out to address.
- Limited ablations. The authors make multiple changes to the algorithm, but only two (token mixing, RMSNorm) are ablated. I'm left with a fairly poor intuition for how important the various algorithmic changes really are for both data efficiency and hardware efficiency, and it is not even clear to me how the ablations compare to e.g. vanilla DreamerV3 here (Figure 3).
Questions
I would like the authors to address my comments above. I have a couple of additional follow-up questions:
- I am not entirely sure how to read Figure 1. There appears to be 3 metrics shown in each subfigure, but only two axes. For example, the right-most figure shows GPU memory utilization, mean score, and training FPS. Can the authors please clarify how to read this figure? Additionally, the figure text is very small and difficult to read; I suggest increasing the font size.
- It appears that the two ablations (token mixing, RMSNorm) are conducted on different sets of tasks. What is the reasoning behind that? This makes the results seem somewhat cherry-picked to me.
- Based on the description of the use of batch normalization (L257) + the qualitative results in the appendices, it seems like BN is quite useful. Why do the authors refrain from running any quantitative ablations on the use of BN? Does DreamerV3 benefit equally from BN, given that this is largely orthogonal to the actual contributions of this work?
Limitations
I believe that the authors address limitations adequately in the last paragraph of the paper.
Thank you very much for reviewing our paper and for your many detailed comments. The following are responses to the weaknesses and questions you listed:
W1: Limited evidence of the key selling point of parallelism in our method.
R1: Thank you very much for your suggestions on the experiments. We have added the wall-clock times and GPU memory usage (based on an RTX 3090 GPU) for our PaMoRL, SSM (with the data-dependent decay rate removed), and DreamerV3 [1] trained at different sequence lengths in Figure 4 (Right) in the PDF attached to "Common Response". The experimental results demonstrate that the computational gain of our PaMoRL method in terms of wall-clock time increases with sequence length, while the additional GPU memory overhead at the maximum sequence length of 1024, compared with the minimum of 16, is less than 14% (~3.36 GB).
W2: Why do the experimental results mainly measure data efficiency rather than hardware efficiency?
R2: We believe that the key to determining whether an MBRL method can maximize hardware efficiency is whether it contains operators that cannot be parallelized over the sequence length. For example, if the world model architecture and eligibility trace estimation of an MBRL method built upon DreamerV3 satisfy the three conditions mentioned in the "Parallel Scan" paragraph of the "Background" section of our paper (L95-L97), then in principle this MBRL algorithm can maximize hardware efficiency through the parallel scan algorithm. Furthermore, the observation and action spaces of the individual tasks in each benchmark do not differ much in dimension, and our method does not require individual hyperparameters for each task, which means that our PaMoRL method is computationally highly consistent across tasks. We therefore do not think it necessary to set up many experiments to demonstrate the hardware efficiency of our method.
On the other hand, a key selling point of MBRL methods is their inherent high data efficiency, and most previous works [1-3] concluded that their methods are highly data-efficient based on superior performance in various tasks that measure data efficiency. We therefore set up various experiments measuring data efficiency to empirically demonstrate that our PaMoRL method can also achieve MBRL-level data efficiency.
W3 & Q2: Limited ablation of the importance of various modules on data efficiency and hardware efficiency. Reasons for existing ablation experiments on different task sets?
R3: According to experimental results from previous work [4, 5], both the Token Mixing and RMSNorm modules play a key role in model training. Therefore, we choose "Alien", "Boxing" and "MsPacman", which are tasks focusing on sequence prediction, to demonstrate the benefits of Token Mixing in sequence prediction. We also choose "Amidar", "UpNDown" and "Qbert", which are tasks with more scattered observations, to measure the improvement of RMSNorm on training stability.
In addition, we supplemented the training curves of the above 6 tasks for PWM with the Token Mixing and RMSNorm modules removed, and added SSM (PWM without the data-dependent decay rate) and vanilla DreamerV3. You can check the results in Figure 3 in the attached PDF of "Common Response". The experimental results show that Token Mixing, RMSNorm, and the data-dependent decay rate are all beneficial to data efficiency. The results of the experiments on hardware efficiency can be found in "R1".
Q1: Clarification in Figure 1.
A1: We apologize for the incorrect title of Figure 1 due to a plotting error. We have corrected Figure 1 in the PDF attached to "Common Response" and increased the font size.
Q2: Why are existing ablation experiments being performed on different task sets?
A2: Please find the response to this question in "R3".
Q3: Quantitative ablation on the role of batch normalization techniques?
A3: Thank you very much for your suggestion; we have added quantitative ablations to Figure 4 (Left) in the PDF attached to "Common Response". The results show that PWM benefits from the batch normalization trick, whereas DreamerV3 does not. We believe this is because PWM's decoder has only stochastic states as inputs, which makes it difficult for training samples to be distinguished from each other in the early stages of training, resulting in the notorious "posterior collapse" [6]. The DreamerV3 decoder mitigates this problem by having additional deterministic states as conditional inputs.
[1] Danijar Hafner, et al. "Mastering diverse domains through world models." arXiv preprint arXiv:2301.04104 (2023).
[2] Vincent Micheli, et al. "Transformers are sample efficient world models." ICLR (2023).
[3] Max Schwarzer, et al. "Data-Efficient Reinforcement Learning with Self-Predictive Representations." ICLR (2021).
[4] Bo Peng, et al. "RWKV: Reinventing RNNs for the Transformer Era." ACL (2023).
[5] Zhen Qin, et al. "The Devil in Linear Transformer." EMNLP (2022).
[6] Samuel R. Bowman, et al. "Generating Sentences from a Continuous Space." CONLL (2016).
Thank you for responding to my questions in detail. I believe that my main concerns have been addressed in the rebuttal. I still hold the opinion that it is important to include experiments on hardware efficiency in a paper about hardware efficiency. I hope that the authors will take my comments (as well as those of fellow reviewers) into account when revising their paper. I have increased my score from 4 -> 5 and soundness from 2 -> 3 contingent on the authors incorporating all feedback into the camera-ready version.
We appreciate your consideration of our rebuttal. We will seriously consider your (and other reviewers') suggestions on hardware efficiency experiments to provide more insights for readers.
Please also let us know if you have further comments or suggestions, as we are eager to improve the quality of our paper.
Common Response
We thank all reviewers for their valuable feedback, reviewers (MVJ1, wbEe, earm) for recognizing the efficiency and novelty of PaMoRL, and reviewers (MVJ1, 5drW, earm) for the promising results and comprehensiveness of the paper's experiments. We summarize the main updates and some frequently asked questions below:
Why do the experimental results primarily measure data efficiency rather than hardware efficiency?
We believe that the core of determining whether an MBRL method can maximize hardware efficiency is whether it contains operators that cannot be parallelized over the sequence length. The operators that make up the world model and eligibility trace estimation of our method fully satisfy the three conditions mentioned in the "Parallel Scan" paragraph of the "Background" section of our paper (L95-L97), and thus our method can maximize hardware efficiency through the parallel scan algorithm. In addition, the observation and action spaces of individual tasks in each benchmark do not differ much in dimension, and our method does not need different hyperparameters for different tasks in the same benchmark, which means that our PaMoRL method is computationally highly consistent across tasks. We therefore do not think it necessary to set up many experiments to prove the hardware efficiency of our method.
On the other hand, a key selling point of MBRL methods is their inherent high data efficiency, and most previous works [1-3] concluded that their methods are highly data-efficient based on superior performance in multiple tasks that measure data efficiency. We therefore consider it necessary to set up various experiments measuring data efficiency to empirically demonstrate that our PaMoRL method can achieve MBRL-level data efficiency.
Additions to the ablation studies.
We accepted the suggestions of Reviewer earm and Reviewer 5drW to add two additional random seeds (aligning with the experimental setups in previous well-known work [1-3]), to compare different methods using the metrics recommended in [4], and to add comparisons with recent work [5]. Our PaMoRL matches or exceeds the baseline on mean, median, and interquartile mean (IQM) human-normalized scores and the optimality gap. We also accepted Reviewer MVJ1's and Reviewer earm's suggestions to report the wall-clock time and GPU memory usage (based on an RTX 3090 GPU) for our PWM, SSM (equivalent to PWM without the data-dependent decay rate), and DreamerV3 [1] trained under different sequence length settings. The experimental results demonstrate that the computational gain of our PaMoRL method in wall-clock time increases with sequence length, while the additional GPU memory overhead at the maximum sequence length of 1024, compared with the minimum of 16, is less than 14% (~3.36 GB), demonstrating the superior hardware efficiency of our PaMoRL method.
In addition, we selected "Alien", "Boxing" and "MsPacman", which are tasks focusing on sequence prediction, and "Amidar", "UpNDown", and "Qbert", which focus on observation dispersion, to ablate the RMSNorm, Token Mixing, and Gating modules, and compared them with DreamerV3. The experimental results show that these modules all play a key role in improving data efficiency.
Clarification of Figure 1.
We apologize for the incorrect title of Figure 1 due to a plotting error; Figure 1 in our attached PDF provides a corrected figure with an increased font size. We have accepted Reviewer 5drW's suggestion to detail the GPUs used to train each baseline, and have elaborated on the method for extrapolating the training speeds of the different GPU models to NVIDIA V100 GPUs, which is consistent with the setup used in DreamerV3, where the V100 is assumed to be twice as fast as the P100 and the A100 twice as fast as the V100.
[1] Danijar Hafner, et al. "Mastering diverse domains through world models." arXiv preprint arXiv:2301.04104 (2023).
[2] Vincent Micheli, et al. "Transformers are sample efficient world models." ICLR (2023).
[3] Max Schwarzer, et al. "Data-Efficient Reinforcement Learning with Self-Predictive Representations." ICLR (2021).
[4] https://github.com/google-research/rliable
[5] Cohen, Lior, et al. "Improving Token-Based World Models with Parallel Observation Prediction." ICML (2024).
This paper initially did not convince all Reviewers, who expressed concerns about its novelty and empirical evaluation. The discussion period was useful in clarifying these concerns and led some Reviewers to increase their scores, resulting in unanimously positive reviews. Remarkably, the speed-up provided by the presented method w.r.t. other baselines has been praised by the Reviewers, making this a significant contribution.
I encourage the Authors to incorporate Reviewers' feedback in the final version.