Predicting masked tokens in stochastic locations improves masked image modeling
Given an incomplete picture of a dog, can we precisely determine the location of its tail? Masked image models like I-JEPA and MAE do not account for this location uncertainty. We propose a way to model the uncertainty via stochastic positional embeddings.
Abstract
Reviews and Discussion
This paper proposes stochastic positional embeddings (StoP) to improve masked image modeling (MIM), incorporating location uncertainty by conditioning the model on stochastic masked token positions drawn from a Gaussian distribution. Experimental results demonstrate that using StoP reduces overfitting to location features and guides the model toward learning features that are more robust to location uncertainty, which also leads to better performance on a variety of downstream tasks.
Strengths
- The idea of stochastic positional embedding proposed here is novel to me
- Experiments are sufficient to support the proposed method, showing that the proposed method can achieve significant improvements on various downstream tasks
Weaknesses
Several parts of the proposed method are not properly introduced and may cause confusion; details can be found in the Questions part.
Questions
- I am a bit confused about step 11 in Algorithm 1. As in Figure 2, the context and masked representations are computed by adding their tokens and positional embeddings together. Then for step 11, I suppose $\psi$ should refer to the positional embedding, and $s_x$ should refer to the context token? Why do we need an additional linear transformation on $s_x$? Some explanation may be needed for this part.
- Based on the above concern, I am also confused by the later explanations in Sections 3.2 and 4.3. The authors seem to treat the projected context features and the projected noise as playing the same role, with the positional embeddings acting as the bias for the context and masked tokens respectively. However, I suppose $A$ should simply be used to compute the stochastic positional embeddings as in (2), while $s_x$ is computed by the encoder to encode context information. How can they have the same role?
- Moreover, with the above correspondence, the positional embedding of the context (resp. masked) tokens would simply be the deterministic one, so where is the stochasticity? I suppose there might be some misunderstanding.
- I would also like to see more discussion of the connection between StoP and vanilla MIM. I suppose we can replace steps 10 and 11 with their deterministic counterparts to reduce to vanilla MIM, is that correct? Such a discussion may make it easier to understand the proposed method.
- While the authors have mentioned the necessity of regularizing $A$, regularizing it through the context tokens is a bit confusing. I note that the authors have conducted additional experiments in Section 4.3 that use $\ell_1$ regularization on $A$. Nevertheless, $\ell_1$ regularization should aim to obtain a sparse matrix $A$, which seems to contradict the original aim of avoiding $A = 0$. The authors may consider using some other regularization (and also removing $A$ from the computation of the context tokens) and see how such a modification works compared to Algorithm 1.
Minor: the authors may also need to pay more attention to notation and typos. An example is at the top of page 5 under "Context Encoding", where "Where" is wrongly capitalized (in fact, capitalization is used very arbitrarily and may require careful proof-reading). Also, the notation throughout the paper is not consistent, especially for the representations. Some revisions may be needed as well.
Thank you for the thoughtful consideration of the paper and constructive feedback. We’ve incorporated your feedback and uploaded a new revision of the paper.
Q: “As in Figure 2, the context and masked representations are computed by adding their tokens and positional embeddings together”
Thank you for pointing this out; there is a mistake in Figure 2 and we apologize for the confusion. To clarify, the masked and context tokens are computed as follows (as in Algorithm 1, lines 11-13):
11: $\hat{\psi}_i = A n_i + \psi_{m_i}$, with noise $n_i \sim \mathcal{N}(0, \sigma I)$ (stochastic positional embedding of masked token $i$)
12: $\hat{m}_i = \hat{\psi}_i + \tilde{m}$ (masked token, where $\tilde{m}$ is the learned mask token)
13: $\hat{x}_j = A s_{x_j} + \psi_{x_j}$ (context token, where $s_{x_j}$ is the encoder output for context patch $j$)
We uploaded a new paper revision and fixed Figure 2.
Q: “For step 11, I suppose $\psi$ should refer to the positional embedding, and $s_x$ should refer to the context token?”
You are right.
Q: Why do we need an additional linear transformation on $s_x$? Some explanation may be needed for this part.
This mapping projects $s_x$ from the output dimension of the encoder to the input dimension of the predictor (such a projection is standard in other approaches like MAE and I-JEPA).
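For concreteness, a minimal PyTorch sketch of such a projection; the dimensions and names below are illustrative assumptions, not taken from the paper's code:

```python
import torch.nn as nn

encoder_dim, predictor_dim = 768, 384   # illustrative sizes (e.g., a ViT-B encoder and a narrower predictor)
to_predictor = nn.Linear(encoder_dim, predictor_dim, bias=False)

# s_x: (num_context_tokens, encoder_dim) features from the context encoder
# to_predictor(s_x): (num_context_tokens, predictor_dim) inputs for the predictor
```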
Q: “The authors seem to treat the projected context features and the projected noise as playing the same role, with the positional embeddings acting as the bias for the context and masked tokens respectively. However, I suppose $A$ should simply be used to compute the stochastic positional embeddings as in (2), while $s_x$ is computed by the encoder to encode context information. How can they have the same role?”
We assume that you are asking why both the context features and the noise are linearly projected by the matrix $A$. In the original MIM formulation there are no stochastic positions, and the masked tokens and context tokens are computed as follows:
12: $\hat{m}_i = \psi_{m_i} + \tilde{m}$
13: $\hat{x}_j = W s_{x_j} + \psi_{x_j}$
When applying StoP with the reparameterization trick (Eq. 3, revised manuscript), the noise is now linearly projected by the matrix $A$ (defined in Eq. 1) and summed with the positional embeddings. Therefore, both the sampled noise and the context tokens are now linearly projected:
11: $\hat{\psi}_i = A n_i + \psi_{m_i}$, with $n_i \sim \mathcal{N}(0, \sigma I)$
12: $\hat{m}_i = \hat{\psi}_i + \tilde{m}$
13: $\hat{x}_j = W s_{x_j} + \psi_{x_j}$
Note that here the context tokens and the noise use different projections, $W$ and $A$. However, we find that the weights of $A$ are quickly scaled down during training, effectively setting $A \approx 0$, which suppresses the noise. This reverts to the original MIM introduced above and yields the same empirical downstream accuracy. We discussed this in Sections 3.1 and 3.2 (see "Avoiding a degenerate deterministic solution" and "Masked tokens in stochastic locations").
To avoid this, we use the same matrix $A$ to also project the context tokens (line 13), instead of using a different projection matrix $W$:
11: $\hat{\psi}_i = A n_i + \psi_{m_i}$
12: $\hat{m}_i = \hat{\psi}_i + \tilde{m}$
13: $\hat{x}_j = A s_{x_j} + \psi_{x_j}$
The motivation for using $A$ to project both the context features and the noise can be understood by considering two extreme cases. When $A = 0$, there is complete certainty about the positional embeddings, but all context is lost ($A s_{x_j} = 0$). On the other hand, when $\|A\|$ is large, the context information is preserved, but the large magnitude amplifies the noise, which camouflages the positional embedding features of the masked tokens: $A n_i + \psi_{m_i} \approx A n_i$. This dual role of the matrix $A$ forces the model to balance location certainty against the influence of context features in its predictions: it optimizes the trade-off per feature, balancing its presence in the prediction against the need for precise spatial locations.
We discussed this in Section 3.2 in the original submission but following the comments we revised the manuscript to make this more clear.
Q: I would also like to see more discussions on the connection between StoP and vanilla MIM.
The above answer should clarify this comment as well. Following your advice, we clarified this in the caption of Algorithm 1 and highlighted the differences within the algorithm (see revised manuscript).
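To make the contrast with vanilla MIM concrete, here is a minimal PyTorch sketch of the two token constructions discussed above; all names (`A`, `W`, `mask_token`), dimensions, and the noise scale are illustrative assumptions rather than the paper's implementation:

```python
import torch
import torch.nn as nn

D_ENC, D_PRED, SIGMA = 768, 384, 0.25              # illustrative dimensions and noise level

A = nn.Linear(D_ENC, D_PRED, bias=False)           # StoP: one matrix projects both noise and context
W = nn.Linear(D_ENC, D_PRED, bias=False)           # vanilla MIM: projection used only for context
mask_token = nn.Parameter(torch.zeros(D_PRED))     # learned mask token (m_tilde)

def vanilla_mim_tokens(s_x, pos_ctx, pos_masked):
    """Deterministic positions: line 12 (masked) and line 13 (context)."""
    masked = pos_masked + mask_token               # 12: psi_m + m_tilde
    context = W(s_x) + pos_ctx                     # 13: W s_x + psi_x
    return masked, context

def stop_tokens(s_x, pos_ctx, pos_masked):
    """StoP: lines 11-13, sharing the matrix A between noise and context."""
    noise = torch.randn(pos_masked.shape[0], D_ENC) * SIGMA ** 0.5   # n ~ N(0, sigma * I)
    psi_hat = A(noise) + pos_masked                # 11: stochastic positional embeddings
    masked = psi_hat + mask_token                  # 12: masked tokens
    context = A(s_x) + pos_ctx                     # 13: context projected by the same A
    return masked, context
```

Driving `A` to zero would remove the noise in line 11 but would also wipe out `A(s_x)` in line 13; this is exactly the degenerate solution the shared projection rules out.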
Q: While the authors have mentioned the necessity of regularization on A, the regularization with context token is a bit confusing. I note that the authors have conducted additional experiments in section 4.3 that uses L1 regularization on A. Nevertheless, L1 regularization should aim to obtain a sparse matrix A, which seems to contradict with the original aim to avoid zero A. The authors may consider using some other regularization (and also remove A in computing context tokens) and see how such modification works compared to Algorithm 1.
To clarify, in the regularization experiments we followed a setting similar to what you suggested (no stochasticity). We used the basic MIM recipe plus $\ell_1$ regularization over the projection matrix $A$ that projects the context tokens from the encoder output dimension to the predictor input dimension, i.e., the training loss is the MIM reconstruction loss plus $\lambda \|A\|_1$.
Since there is no stochasticity, we do not need to worry about $A$ being driven to zero to scale down the noise. There is still a trade-off between the MIM reconstruction loss and the regularization loss, but this is always the case with regularization.
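A minimal sketch of what this regularized, non-stochastic baseline objective could look like; `A_weight`, `reg_coeff`, and the MSE stand-in for the masked-prediction loss are placeholder assumptions, not the paper's code:

```python
import torch.nn.functional as F

def regularized_mim_loss(pred, target, A_weight, reg_coeff=0.01):
    """Deterministic MIM loss plus an l1 penalty on the context projection matrix A."""
    recon = F.mse_loss(pred, target)       # stand-in for the usual masked-prediction objective
    l1_penalty = A_weight.abs().sum()      # ||A||_1, encourages a small/sparse A
    return recon + reg_coeff * l1_penalty
```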
Dear reviewer, towards the end of the discussion phase, we trust that our response has successfully addressed your inquiries. We look forward to receiving your feedback regarding whether our reply sufficiently resolves any concerns you may have, or if further clarification is needed.
I suppose some of my previous concerns have been successfully resolved, mainly regarding the connection between StoP and vanilla MIM. Nevertheless, I still have some confusion about certain details of StoP:
- The use of the same matrix $A$ for both the projection and the covariance still sounds strange to me. I understand that currently StoP effectively prevents $A = 0$, as it would lead to $A s_x = 0$ (no context). Nevertheless, I suppose there might be other implementations; a straightforward idea is to use an additional matrix $B$ and compute the context tokens through both $A$ and $B$, in which case $A = 0$ also leads to no context. I wonder if the authors can provide some discussion on that.
- I am now a bit confused about the experiments on regularization. I suppose you are trying to prove that the improvements of StoP do not solely come from regularizing $A$ (which is used as the projection matrix for the context tokens)? However, I am not sure if the matrix $A$ in StoP is really regularized towards a sparse matrix. Given that you observed that the norm of $A$ decreases with increasing $\sigma$, I suppose you should try $\ell_2$ regularization (which regularizes the $\ell_2$ norm of the matrix $A$) instead of $\ell_1$, and see if that can lead to much improvement.
Dear reviewer, thank you for the reply and we are happy that some of your previous concerns are resolved.
Q: I suppose there might be other implementations; a straightforward idea is to use an additional matrix B
Thank you for this suggestion. Using an additional matrix $B$ would allow the model to cancel the noise and reach a deterministic solution (i.e., removing our novel noise component), which is undesirable. For example: if $A$ is scaled down towards zero, then the noise is scaled down via $A n \approx 0$ and the positional embedding is essentially unaffected; $B$ can then be scaled up to compensate, and this preserves the context token information.
There might be other ways to regularize $A$ that could be explored, for example, incorporating additional (multiple) loss terms that ensure $A$ has a large enough norm and that it is full rank. However, our solution is simpler, as it doesn't require additional losses or hyperparameter tuning.
Q: I am now a bit confused about the experiments on regularization. I suppose you are trying to prove that the improvements of StoP do not solely come from regularizing $A$ (which is used as the projection matrix for the context tokens)? However, I am not sure if the matrix $A$ in StoP is really regularized towards a sparse matrix. Given that you observed that the norm of $A$ decreases with increasing $\sigma$, I suppose you should try $\ell_2$ regularization (which regularizes the $\ell_2$ norm of the matrix $A$) instead of $\ell_1$, and see if that can lead to much improvement.
Indeed, we wanted to show that the improvements of StoP are not just due to reducing the norm of $A$. Clearly, there are several notions of norm, and these can be explored. We focused on $\ell_1$ because it is a standard approach to regularizing the rank of $A$ in the diagonal case. Furthermore, it is well known that optimization with SGD implicitly regularizes the $\ell_2$ norm (e.g., see https://arxiv.org/abs/1906.05890), so we wanted to test a norm that is not implicitly regularized.
We note that our regularization experiments also resulted in a low norm of $A$ (the higher the regularization loss coefficient, the lower the norm; see the table below).
| $\ell_1$ loss coeff. | Norm of $A$ |
|---|---|
| 1.0 | 0.00002 |
| 0.1 | 0.00007 |
| 0.01 | 0.00010 |
| 0.001 | 0.00020 |
Q: The authors may consider using some other regularization (and also removing $A$ from the computation of the context tokens)
Q: I suppose you should try $\ell_2$ regularization (which regularizes the $\ell_2$ norm of the matrix $A$) instead of $\ell_1$, and see if that can lead to much improvement.
Dear reviewer, following up on your suggestion, we ran additional experiments applying $\ell_2$ regularization on $A$.
Specifically, we trained ViT-B/16 baseline models using deterministic sine-cosine positional embeddings for 150 epochs while adding an $\ell_2$ regularization loss on $A$ with several weighting coefficients. We then applied the ImageNet linear-probing protocol and report the results below.
These results indicate that StoP cannot simply be replaced by $\ell_2$ regularization over $A$. Please let us know if there are any other concerns; we are happy to hear more feedback or provide further clarification if needed.
| Model | Top-1 Acc |
|---|---|
| Baseline, $\ell_2$ coeff. #1 | 61.7 |
| Baseline, $\ell_2$ coeff. #2 | 62.7 |
| Baseline, $\ell_2$ coeff. #3 | 61.9 |
| Baseline, $\ell_2$ coeff. #4 | 59.8 |
| StoP | 64.8 (+2.1) |
Dear reviewer, we would greatly appreciate it if you could review our new response. We believe that we have effectively addressed all of your previous concerns. We remain on standby for the last few hours of the discussion phase.
The paper proposes modeling a distribution over positional embeddings instead of learning/using deterministic ones which is compatible with any Masked Image Modeling (MIM) framework.
Strengths
The authors propose a smart modeling design choice to avoid the model collapsing to just learning deterministic embeddings. The experimental evaluation shows consistent improvements compared to deterministic MIM (i.e., I-JEPA) for models of different sizes. Also, the ablation study is great: the authors ablate and deeply study different aspects of the model.
Weaknesses
Honestly, I don't see any obvious weaknesses of the work.
Questions
To strengthen the evaluation, it would be nice to see linear-probing/finetuning results on a larger set of downstream datasets. Also, it would be nice to have a model pretrained on a larger dataset than ImageNet-1k, as this could lead to a stronger model and enable better transfer to downstream problems; having such representations would be important for the community.
Thank you for the thoughtful consideration of the paper and very positive feedback.
Q: To strengthen the evaluation, it would be nice to see linear probes/finetuning results on the larger set of downstream datasets.
Thank you for the comment. We’ve evaluated StoP on 5 different datasets (ImageNet, iNat, Places, DAVIS 2017, CLEVR). Following your comment, we will run additional evaluations on CUB-200, Flowers-102 and IN-100 and will include them in the final manuscript.
Q: it could be nice to have a model pretrained on a larger dataset rather than Imagenet-1000 as it could lead to stronger model and will enable better transfer to downstream problems which is important to have such representations for the community.
Thank you for the suggestion. We agree that running large-scale experiments (e.g., on LAION-5B) with StoP would be exciting. Since this would require non-trivial engineering effort and substantial resources, we leave it for future work.
I would like to thank the authors for the clarifications and will maintain my initial assessment of the paper.
Thank you very much; we truly value your support in accepting the paper.
The paper proposes Stochastic Positional Embeddings (StoP) for MIM in order to perturb the location information of images as a form of regularization, which helps avoid overfitting. The paper motivates and derives the empirical training loss for such perturbations, enabling end-to-end training via the well-known reparametrization trick. Empirical evidence shows that the proposed method improves on the existing SOTA method by an evident margin.
Strengths
The paper has several strengths including:
S1. It introduces Stochastic Positional Embeddings (StoP) for the purpose of adding perturbations to the location information of images within the MIM framework, thus serving as a means of regularization. This measure intuitively can prevent the model from overfitting.
S2. By employing a reparametrization trick, the paper trivially both justifies and develops the empirical training loss associated with this form of perturbation, enabling end-to-end training.
S3. Empirical results highlight that this proposed technique significantly enhances the state-of-the-art method, demonstrating a noticeable improvement.
Weaknesses
However, there are also several concerning points that need to be addressed:
W1: It is unclear to me why it is necessary to learn an optimal $\Sigma$ via additional parameterization. What is the benefit of introducing an additional degree of freedom here to learn $\Sigma$? What if we fix $\Sigma$ without learning it? Isn't that a simpler way to avoid degeneracy of the matrix $A$? Please explain the motivation.
W2: I understand that adding stochastic perturbations to the positions of the images makes sense for regularizing the model. However, why is the same spectral decomposition applied to the features $s_x$ (by multiplying with $A$)? This step also lacks motivation and seems heuristic; please clarify this point.
W3: What exact architecture did the paper use to parameterize the matrix $\Sigma$? An architecture flow illustration would help better illustrate this mechanism. Currently, I am not sure how the back-propagation through $\Sigma$ flows back to the network (Figure 1 does not show this part) and how it affects the SSL learning with a positive gain.
W4: I am not sure of the significance of proposition 1. I do not see why using this optimal predictor can help achieve better generalization ability of the SSL pretraining on downstream tasks.
Questions
Please see above for the four questions (W1-W4) to be addressed.
Details of Ethics Concerns
None.
Thank you for the thoughtful consideration of the paper and constructive feedback. We’ve incorporated your feedback and uploaded a new revision of the paper.
W1: It is unclear to me why it is necessary to learn an optimal $\Sigma$ via additional parameterization. What is the benefit of introducing an additional degree of freedom here to learn $\Sigma$? What if we fix $\Sigma$ without learning it? Isn't that a simpler way to avoid degeneracy of the matrix $A$? Please explain the motivation.
We compare fixed $\Sigma$ to learned $\Sigma$ in Figure 3. As you mentioned, using a fixed $\Sigma$ indeed prevents the degeneracy of $A$. However, a learned $\Sigma$ works better empirically (see Figure 3). Implementation-wise, both approaches are very simple.
The motivation for using a learned $\Sigma$ is to avoid having to perform an extensive grid search to find the optimal values; it is easier to let the model find them itself.
W2: I understand that adding stochastic perturbations to the positions of the images makes sense for regularizing the model. However, why is the same spectral decomposition applied to the features $s_x$ (by multiplying with $A$)? This step also lacks motivation and seems heuristic; please clarify this point.
Without imposing any constraint on $A$, we find that its weights are quickly scaled down during training, setting $A \approx 0$ to suppress the noise:
11: $\hat{\psi}_i = A n_i + \psi_{m_i} \approx \psi_{m_i}$
12: $\hat{m}_i = \hat{\psi}_i + \tilde{m} \approx \psi_{m_i} + \tilde{m}$
Therefore, we revert to the basic MIM without stochasticity. To avoid this, we use $A$ to project both the noise $n_i$ and the context tokens $s_{x_j}$:
11: $\hat{\psi}_i = A n_i + \psi_{m_i}$
12: $\hat{m}_i = \hat{\psi}_i + \tilde{m}$
13: $\hat{x}_j = A s_{x_j} + \psi_{x_j}$
The motivation for using $A$ to project both the context features and the noise can be understood by considering two extreme cases. When $A = 0$, there is complete certainty about the positional embeddings of the masked tokens, but all context is lost ($A s_{x_j} = 0$), which makes the MIM prediction task impossible. On the other hand, when $\|A\|$ is large, the context information is preserved, but the large magnitude amplifies the noise, which camouflages the positional embedding features of the masked tokens, $A n_i + \psi_{m_i} \approx A n_i$, and makes the prediction task hard as well. This dual role of the matrix $A$ forces the model to balance location certainty against the influence of context features in its predictions: it optimizes the trade-off per feature, balancing its presence in the prediction against the need for precise spatial locations.
We discussed this in Section 3.1 (see "Avoiding a degenerate deterministic solution") and in Section 3.2 of the original submission, but following the comment we revised the manuscript to make this clearer.
W3: What exact architecture did the paper use to parameterize the matrix $\Sigma$? An architecture flow illustration would help better illustrate this mechanism (Figure 1 does not have this part). Currently, I am not sure how the back-propagation through $\Sigma$ flows back to the network and how it affects the SSL learning with a positive gain.
We defined $\Sigma = \sigma A A^\top$ (see revised manuscript, Eq. 2), where $\sigma$ is a scalar hyperparameter and $A$ is a learned matrix. However, instead of sampling $\hat{\psi} \sim \mathcal{N}(\psi, \Sigma)$ directly from Eq. 1 (where we cannot backpropagate through $\Sigma$), we use the reparametrization trick: we sample noise $n \sim \mathcal{N}(0, \sigma I)$ and multiply it by $A$ to obtain the stochastic positional embeddings $\hat{\psi} = A n + \psi$ (see revised manuscript, Eq. 3). This is differentiable w.r.t. $A$ because the sampling distribution does not depend on $A$. Note that $A n \sim \mathcal{N}(0, \sigma A A^\top)$.
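A minimal sketch of this reparameterization; the function and argument names are ours for illustration, assuming $A$ is stored as a $(d_{\text{pred}} \times d_{\text{enc}})$ matrix:

```python
import torch

def sample_stochastic_pos_emb(psi, A, sigma):
    """psi: (N, d_pred) deterministic positional embeddings of the masked tokens.
    A: (d_pred, d_enc) learned matrix; sigma: scalar noise level.
    Returns psi_hat distributed as N(psi, sigma * A A^T)."""
    n = torch.randn(psi.shape[0], A.shape[1], device=psi.device) * sigma ** 0.5  # n ~ N(0, sigma * I)
    return psi + n @ A.T  # reparameterization: gradients flow into A, not through the sampling
```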
We followed your suggestion and revised the architecture figure (Figure 2) to include the reparameterization trick to make it more clear (see new paper revision).
W4: I am not sure of the significance of proposition 1. I do not see why using this optimal predictor can help achieve better generalization ability of the SSL pretraining on downstream tasks.
The main goal of Proposition 1 is to provide insight into what is learned with StoP in a simple setting (one input and one output). In this case, we show that the optimal predictor explicitly models location uncertainty by performing spatial smoothing. We do not claim this property leads to better generalization (although we do see empirical downstream gains).
To summarize, we think Proposition 1 provides a nice further analysis, but we are open to moving this into the appendix.
Dear reviewer, towards the end of the discussion phase, we trust that our response has successfully addressed your inquiries. We look forward to receiving your feedback regarding whether our reply sufficiently resolves any concerns you may have, or if further clarification is needed.
Dear reviewer, thank you for the reply and we are happy that some of your previous concerns are resolved.
Q: "However, in terms of reusing the matrix $A$ to project $s_x$, I am still not convinced (W2). The current version of this projection lacks clear motivation, and it is thus hard to judge its correctness."
The reason for applying the matrix $A$ to $s_x$ is to prevent the stochastic positional embeddings from collapsing into deterministic positional embeddings. Let's begin by describing why this collapse happens, and then outline how applying $A$ to $s_x$ provides an effective solution.
Stochastic positional embeddings collapse to deterministic
By using the reparametrization trick, we generate stochastic positions as $\hat{\psi} = A n + \psi$, with $n \sim \mathcal{N}(0, \sigma I)$ (Eq. 3).
It's important to note that $A$ regulates the noise level, and this noise disrupts the positional embeddings of the masked tokens. Therefore, better MIM predictions may be achieved without the presence of noise. Consequently, during training there is a risk of collapse into deterministic positional embeddings by setting $A = 0$.
Our experimental results confirm this. Without introducing a mechanism to prevent collapse, the empirical results resemble those obtained using deterministic Sine-Cosine features (see, for example, Table 6, under "Sine Cosine").
Preventing Collapse
Hence, in order to effectively capture location uncertainty through stochastic positional embeddings, it is crucial to prevent the occurrence of this collapse. While there might exist other ideas to address this issue, we employ a simple yet effective approach. This approach stands out as it doesn't necessitate additional losses, hyperparameters, or even learned weights.
The idea is to use the matrix $A$ to project both $n$ and $s_x$. It's worth noting that MIM models already contain a linear projection of $s_x$ from the encoder's dimension to the predictor's dimension, and we simply replace it with $A$. For a detailed view of the differences between StoP and MIM, please refer to Algorithm 1 in the revised paper.
How does reusing $A$ for both $n$ and $s_x$ prevent collapse while promoting the modeling of location uncertainty?
Using $A$ to project $s_x$ serves as a preventive measure against setting $A = 0$, as doing so would eliminate crucial context information, making the MIM prediction task impossible. Or, as pointed out by reviewer gTwQ: "StoP effectively prevents $A = 0$ as it will lead to $A s_x = 0$ (no context)".
However, the model has to learn a matrix $A$ that doesn't excessively amplify $s_x$, as this would amplify the noise as well. Excessive amplification of the noise would camouflage the positional embeddings of the masked tokens, making their locations very uncertain.
Summary
To summarize, without a mechanism to prevent collapse, the positional embeddings become deterministic. We proposed to mitigate this by applying the same matrix $A$ both to the noise and to $s_x$. To learn a good $A$, the model has to trade off the importance of the input context against the certainty in the masked tokens' locations.
Please let us know if there are any other concerns; we are happy to hear more feedback and provide further clarification if needed. We discuss this topic at length in the recent revision (Section 3.2: "Avoiding a degenerate deterministic solution").
Dear reviewer, we would greatly appreciate it if you could review our new response. We believe that we have effectively addressed all of your previous concerns. We remain on standby for the last few hours of the discussion phase.
Thanks for the response! After reading the rebuttal, I think some of my concerns are addressed (the empirical evidence showing the benefits of a learned $\Sigma$ vs. a fixed $\Sigma$). However, in terms of reusing the matrix $A$ to project $s_x$, I am still not convinced (W2). The current version of this projection lacks clear motivation, and it is thus hard to judge its correctness. In this regard, I am afraid I agree with reviewer gTwQ, and I look forward to a better justification of the formulation.
We thank the reviewers for their insightful and positive comments. The reviewers mentioned that stochastic positional embeddings (StoP) is “novel” (gTwQ) and an “intuitive idea to prevent the model from overfitting” (Hczn). Furthermore, reviewer 92Mr mentioned that the authors’ idea to prevent degenerate solution of the covariance matrix is a “smart modelling design choice”. Lastly, all reviewers are satisfied with the experimental study. Specifically, they mentioned StoP “significantly enhances the state-of-the-art method, demonstrating a noticeable improvement” (Hczn), “show consistent improvements” (92Mr), and that the experiments are “sufficient to support the proposed method” (gTwQ).
We addressed all the reviewers’ comments and incorporated their feedback to the new paper revision (new text highlighted in red).
The reviewers maintain concerns regarding the approach, despite the author rebuttal. The paper needs to improve the clarity of its presentation before acceptance. In particular, the use of the same matrix A to project both the noise tokens and the context tokens was not well justified in the submission. The authors provided further intuition in their responses, but neither of the two reviewers who raised this concern was satisfied by the additional justification. For this reason, the authors are urged to solidify the justification of this choice in their method and potentially also explore alternative choices, as outlined in one of their responses.
Why not a higher score
Two reviewers raised major concerns about the methodology, despite author responses.
Why not a lower score
N/A
Reject