Discrete Neural Algorithmic Reasoning
Abstract
Reviews and Discussion
This paper addresses the challenges of generalization and interpretability in neural algorithmic reasoning (NAR) by introducing a novel architecture that guides the model to follow algorithmic execution as a series of finite, predefined states. The proposed architecture has three main components: feature discretization, hard attention, and separation of discrete and continuous data flows. The empirical results show that the model achieves perfect scores in both single-task and multi-task experiments across various algorithms. Additionally, the architecture enhances interpretability, allowing validation of the correct execution of desired algorithms.
Strengths
- The writing is generally clear, though some areas could be further elaborated.
- The paper is well-motivated, as both OOD generalization and interpretability in NAR are important questions. The proposed architecture separating the discrete and continuous data flows is novel and effective.
- The perfect scores across algorithms are impressive, especially given the model’s capacity for size generalization on graphs 100 times larger, outperforming strong baselines.
- The proposed architecture consists of three design components, each of which is validated through ablation studies to demonstrate its effectiveness.
Weaknesses
- While the perfect scores achieved in the experiments are impressive, the paper could be strengthened by testing on a wider range of algorithms. Although both parallel (e.g., BFS) and sequential (e.g., Prim) algorithms are covered, most algorithms studied are graph-based, where previous NAR methods have already proven effective. The CLRS-30 dataset includes a broader variety of algorithms (e.g., sorting and search), where many NAR methods can struggle. Although SALSA-CLRS was chosen for its thorough OOD evaluation, testing the model on the CLRS-30 dataset with larger graph sizes would add valuable insights due to the dataset’s extensive algorithm coverage.
- Another limitation is the lack of application to real-world datasets, as the authors note in the Future Work section. A significant advantage of NAR methods is their ability to operate on high-dimensional data by utilizing a pretrained model that mimics algorithms. This presents an additional OOD challenge with potential distribution shifts in real-world data. Evaluating the proposed architecture on real-world datasets would demonstrate its practical value; even a single real-world experiment, as seen in related works (e.g., Numeroso et al., 2023), could significantly strengthen the method's implications.
- Although interpretability is a valuable strength of the proposed method, it is unclear how this model’s interpretability, achieved through analyzing state transitions and attention blocks, differs from other NAR approaches, which also allow interpretation of intermediate executions (e.g., using decoder outputs to indicate a node’s current predecessor in a sorting algorithm).
Questions
- For the multi-task experiments, do the results in Table 2 correspond to these settings? It is unclear which algorithms were trained simultaneously and how this compares to the baselines’ configurations.
- Without hints, DNAR struggles to generalize. Could you clarify the reverse-engineering issue described? Was any intermediate approach, such as noisy teacher forcing or teacher forcing decay, tested to examine the impact of hints on DNAR?
- How were GIN and PGN used as baselines? Were they employed to treat this as an end-to-end node-level prediction task?
We thank the reviewer for their review and constructive comments!
the paper could be strengthened by testing on a wider range of algorithms
We agree that a more general architecture is of great interest for future work. In its current form, the proposed model is not capable of executing some algorithms from the CLRS-30 benchmark; however, simple modifications can enhance its expressivity, e.g., supporting more complex aggregations with top-k attention (with fixed k). Also, as we mention in the paper, the model can be extended with additional modifications, such as edge-based reasoning. For example, edge-based top-2 attention (where each edge chooses the top 2 adjacent edges to receive the message from) can implement a form of triplet reasoning.
In our work, we aim to describe the key principles that help to build perfectly generalizable models and do not focus on a general architecture capable of executing a wide range of algorithms.
Another example of a simple modification that extends the list of algorithms the model can be tested on is supporting multiple different scalars in ScalarUpdater, which can allow the model to execute several sorting algorithms. While this does not fully address the mentioned limitation, please let us know if you consider such an experiment important, and we will add it during the discussion period.
It is unclear how this model’s interpretability … differs from other NAR approaches
The main difference is that with discrete states we “fully observe” the inner computations, while simply looking at the intermediate hint predictions might be misleading.
For example (this particular example is motivated by our analysis of the no-hint model; please see the general response), let us consider a model for the BFS task that internally uses K discrete node states as distances from the starting node and is able to correctly predict all intermediate hints when evaluated on graphs with diameter less than K. If we try to interpret such a model via the intermediate updates of hints on graphs with small diameter, we will simply observe perfect mimicking of the ground truth algorithm without understanding how exactly the model works: does it use the notion of visited nodes directly, or does it implement the “distance concept” (or some more complex logic)?
In other words, if there is some information flow beyond the predicted hints inside the model (which is true for all covered baselines), then simple probing of hints might not give a full picture of the underlying computations inside the model.
For the multi-task experiments, do the results in Table 2 correspond to these settings?
For the multitask experiments, we train all 6 algorithms from Table 1 simultaneously. The motivation for the multitask experiment is to check if the proposed processor is capable of multitask learning without losing the generalization capabilities. We did not test the baselines in the multitask experiment. We will add more details regarding the multitask experiment to the text and a separate column to Table 1 to highlight the perfect scores in the multitask setup too.
Without hints, DNAR struggles to generalize. Could you clarify the reverse-engineering issue described? Was any intermediate approach, such as noisy teacher forcing or teacher forcing decay, tested to examine the impact of hints on DNAR?
Please see the second part of the general response for the reverse-engineering details.
We indeed tested some additional approaches, but all of them are more about different hint usage than about finding a good intermediate approach: e.g., we checked whether it is possible to train the proposed models without teacher forcing at all and got perfect scores for all problems besides DFS. We did not test teacher forcing decay for the DFS problem; instead, we achieved perfect scores for DFS without teacher forcing by weighting the hint losses along the execution trajectory (annealing the hint losses along the processor steps) to enforce sequential learning of the problem.
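For concreteness, a minimal sketch of one such weighting, assuming per-step hint losses are already computed; the exponential schedule and its decay rate are our illustrative choices, not the exact configuration used:

```python
import torch

def weighted_hint_loss(step_losses: torch.Tensor, gamma: float = 0.9) -> torch.Tensor:
    # step_losses: per-step hint losses along the rollout, shape (T,).
    # Down-weighting later processor steps encourages the model to master
    # the early steps of the trajectory first (sequential learning).
    weights = gamma ** torch.arange(step_losses.shape[0], dtype=step_losses.dtype)
    return (weights * step_losses).sum() / weights.sum()
```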
An interesting intermediate approach that leaves the freedom for the model to learn something different is to supervise the model only on a part of the hints, but we did not test that.
How were GIN and PGN used as baselines?
We reused the implementation from SALSA-CLRS; we refer to the corresponding paper and public source code for the details. In short, they add a GRU cell after the message-passing step of GIN (GINE is used for the problems with edge weights), and the baselines are trained with hint supervision without teacher forcing. They only decode the hint predictions to calculate the loss. Hints of type pointer are predicted from pairs of node features, and node masks are predicted from node features. Please let us know if some particular details remain unclear.
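As a schematic sketch of the recurrent baseline step we describe (this is our illustration, not the SALSA-CLRS source; module choices and shapes are assumptions):

```python
import torch.nn as nn
from torch_geometric.nn import GINConv

class RecurrentGINStep(nn.Module):
    # One processor step: GIN message passing, then a GRU cell that
    # updates the persistent per-node hidden state.
    def __init__(self, hidden_size: int):
        super().__init__()
        mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU(),
                            nn.Linear(hidden_size, hidden_size))
        self.conv = GINConv(mlp)
        self.gru = nn.GRUCell(hidden_size, hidden_size)

    def forward(self, h, edge_index):
        m = self.conv(h, edge_index)  # message-passing step
        return self.gru(m, h)         # recurrent node-state update
```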
We hope that our response addresses your concerns and will be happy to answer any additional questions during the discussion period if needed.
Thank you for the detailed response. All my questions are well answered. I chose to maintain my current score to accept the paper. That said, while I understand the difficulties involved, I still believe the paper could be further strengthened by a wider range of algorithms or real-world experiments.
This paper studies neural algorithmic reasoning and proposes a discrete transformer-based processor architecture. The authors show that the proposed model empirically achieves perfect size generalization on several algorithms picked from the CLRS benchmark.
Strengths
- The problem of learning neural algorithmic reasoning is important. I like the overall idea of learning discrete neural reasoners, since previous efforts to learn neural networks with continuous states and/or continuous operators failed miserably.
- The architecture design has some interesting components. For example, the separation of the data and the computational flow that manipulates the data is interesting. In particular, the scalars (continuous input) only affect the computation of attention weights and do not affect the node or edge states. Although designs with the same spirit have appeared before, e.g., in Neural Execution Engines, this part has its own merits.
- The paper is easy to follow and well-written, except that a few technical details are sparse.
Weaknesses
- I am concerned about the significance of the contributions. If the claims of this paper are all correct, then we just obtain a recipe for learning a neural reasoner to perfectly mimic a known algorithm. However, to make the learning successful, we need to first run the algorithm to collect the full trace of the execution, i.e., the sequence of intermediate states generated by the algorithm (the so-called hints). In other words, we just perfectly fit the “correct algorithm” using a neural network.
Moreover, I do not see a theoretical guarantee of when this perfect fitting would happen. For example, would it happen on a specific class of algorithms or any algorithm? See my next comment on the claim about “the guarantee”.
This is unsatisfying since the goal of neural algorithmic reasoning is to learn correct algorithms from data without knowing the algorithms. From this perspective, the really interesting and valuable part is the exploration under the no-hints setting. However, in Section 6, the authors did not provide any experimental results and just stated that their model never achieved perfect validation scores. In short, the authors should provide a thorough empirical study of their proposed model under the no-hints setting and compare it with other approaches on benchmarks like CLRS.
- In Section 5, the claim “we can guarantee that for any graph size, the model will mirror the desired algorithm, which is correct for any test size” is quite strong. In my opinion, such a strong claim needs rigorous theoretical proof. However, the authors only provide some vague arguments to support this claim.
In particular, can you elaborate on the following two questions?
- How do you confirm that the attention block indeed operates as a “select_best” selector? I see neither an empirical investigation nor a theoretical proof on this point.
- Even if the attention block operates as a “select_best” selector, why does this condition lead to the above claim? You should at least demonstrate your detailed logic using an example algorithm like DFS.
- I think the size of the discretized state space would matter a lot to the expressivity of the proposed neural reasoner, i.e., what algorithms can be represented by the designed class of neural networks. However, I did not see an experimental study and discussion of its effect.
- An interesting experiment is to check if there is some sort of phase transition in terms of problem size. Specifically, I would imagine the phenomenon of perfect fitting to the correct algorithm would disappear if we decreased the problem size in training. Then, what is the minimum problem size to ensure it fits perfectly for a particular algorithm?
- The details of how to train task-dependent encoders and decoders are not provided, and how they affect the processor's training is not discussed at all. I would imagine the quality of the encoded embedding is quite important to the success of learning the processor, even with teacher forcing.
- In Section 3.3, I get the high-level idea that the scalars (continuous input) only affect the computation of attention weights and do not affect the node or edge states. However, the description of the idea in the 2nd paragraph of Section 3.3 is not clear enough for me to figure out how exactly the computation is designed. It would be great to either write the equations and/or illustrate the computational graph.
Questions
Please see my comments in the weakness part.
Thank you for the detailed review! We addressed the questions and concerns below.
If the claims of this paper are all correct, then we just obtain a recipe for learning a neural reasoner to perfectly mimic a known algorithm. … This is unsatisfying since the goal of neural algorithmic reasoning is to learn correct algorithms from data without knowing the algorithms
First, let us note that the goal of NAR is to train a network to execute algorithms in the latent space, and it is not restricted to the setup where the ground truth algorithm is unknown. For example, see the blueprint of NAR (Veličković & Blundell, 2021): “An algorithmic reasoner is trained to imitate A, optimising it to be close to ground-truth abstract outputs, A(x). P is a processor network operating in a high-dimensional latent space, which, if trained correctly, will be able to imitate the individual steps of A”.
Also, we note that the ground truth hints (and different forms of supervision on them during training) are the key component of the CLRS-30 benchmark.
Moreover, in the current state of the field, learning with hints is an important and unsolved problem. E.g., a large body of research (Section 2.1 of the paper), including state-of-the-art approaches (Bevilacqua et al., 2023; Bohde et al., 2024), heavily relies on different forms of carefully designed step-by-step hints and is not applicable to no-hint learning without additional modifications.
However, we fully agree that learning without hints is an important and challenging problem for further developments of neural algorithmic reasoning.
How do you confirm that the attention block indeed operates as “select_best” selector? I neither see any empirical investigation nor theoretical proof on this point.
In short, we confirm this empirically by directly computing attention weights. As we described, the unnormalized attention weight between nodes U and V depends only on the states of U and V and an indicator of whether V has the smallest scalar value among the neighbors of U. Thus, the “select_best” property follows from the fact that AttentionWeight(U_state, V_state, is_best=True) > AttentionWeight(U_state, V_state, is_best=False) for all possible combinations of (U_state, V_state).
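For concreteness, a minimal sketch of this check; `attention_weight` is a hypothetical helper that queries the trained attention block on a pair of discrete states:

```python
from itertools import product

STATES = ["visited", "not_visited"]  # discrete node states in the BFS example

def verify_select_best(attention_weight):
    # attention_weight(u_state, v_state, is_best) -> unnormalized score.
    # The select_best property holds iff, for every combination of states,
    # the neighbor holding the smallest scalar gets a strictly larger score.
    for u_state, v_state in product(STATES, STATES):
        assert (attention_weight(u_state, v_state, is_best=True)
                > attention_weight(u_state, v_state, is_best=False))
```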
Even if the attention block operates as “select_best” selector, why does this condition lead to the above claim? You should at least demonstrate your detailed logic using an example algorithm like DFS.
Please, see part 1 of our general response where we describe the detailed logic. We will also extend the description in the paper. Please let us know if some details remain unclear.
I think the size of the discretized states would matter a lot to the expressivity of the proposed neural reasoner,
We agree that a model with a greater number of states can express strictly more functions. However, for models trained with hints, the state count is directly specified by the hint structure. There is the possibility of designing alternative hints to utilize more states, but this requires describing the new hint structure. For no-hint models, the state count is a simple hyperparameter. However, as we preliminarily investigated in Section 6, this count should be as small as possible. For example, one can think of extreme cases where the state count is close to the number of training data points; the resulting model can demonstrate perfect training loss but zero generalization capabilities.
An interesting experiment is to check if there is some sort of phase transition in terms of problem size. Specifically, I would imagine the phenomenon of perfect fitting to the correct algorithm would disappear if we decreased the problem size in training
Intuitively, from the perspective of state updates, such a phase transition should occur if the training size/distribution is not enough to cover all possible state combinations/subtasks.
For example, for the BFS problem, it is enough to use graphs with only 2 nodes to observe that a not_visited node becomes visited or not depending on the received message. However, the subtask of selecting the parent from the multiple visited neighbors requires at least 4 nodes (where the minimum sufficient example is the complete bipartite graph K(2, 2)).
We conducted additional experiments to empirically check whether the described transition occurs and what the smallest training size is for perfect fitting of each covered algorithm. We train the DNAR models for each problem on ER(n, 0.5) graphs for different n, keeping the remaining hyperparameters the same as in the main experiments, and test the resulting models on graphs with 160 nodes (same as in Table 1 of the paper).
Node-level scores on graphs with 160 nodes for different training sizes:

| Algorithm | 3 | 4 | 5 |
|---|---|---|---|
| BFS | 41 | 100 | 100 |
| DFS | 38 | 100 | 100 |
| Dijkstra | 13 | 26 | 100 |
| MST | 11 | 14 | 100 |
| MIS | 79 | 100 | 100 |
| Eccentricity | 45 | 100 | 100 |
Note that the empirical bound is around 4-5 nodes.
The details of how to train task-dependent encoders and decoders are not provided, and how they affect the processor's training is not discussed at all. I would imagine the quality of the encoded embedding is quite important to the success of learning the processor, even with teacher forcing
In our work, encoders and decoders are simple linear layers trained simultaneously with the processor (which is similar to prior work mentioned in Section 2.1).
In section 3.3, I get the high-level idea that the scalars (continuous input) only affect the computation of attention weights and do not affect the node or edge states. However, the description of the idea in the 2nd paragraph of Section 3.3 is not so clear that I could not figure out how exactly the computation is designed. It would be great to either write the equations and or illustrate the computational graph.
We allow the scalars to affect only the Key vector in the attention mechanism:
For each node u, we group the neighbors of u by their state, and for each neighbor v we decide whether the edge (u, v) has the smallest scalar value among all neighbors of u with the same state as v.
Then, in the attention block, we treat the obtained 0-1 indicator for each edge as a part of the edge state when computing the Key vector for that edge.
You can find the computation of the described 0-1 indicators (variable best_in_group) in lines 30-38 here: https://anonymous.4open.science/r/F4CA/processors.py. In short, this code doubles the number of used states by splitting each state depending on the discussed indicator, and the Key vector of each edge is obtained from these states.
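As a rough illustration (this is our sketch, not the repository code; tensor layouts are assumptions), the indicator can be computed as follows:

```python
import torch

def best_in_group(edge_index, edge_scalars, node_states, num_states):
    # edge_index: (2, E) tensor, rows are (sender v, receiver u);
    # edge_scalars: (E,) scalar per edge; node_states: (N,) integer state ids.
    senders, receivers = edge_index[0], edge_index[1]
    # Group edges by (receiver node, sender state): within each group, only
    # the edge carrying the smallest scalar gets its indicator set to True.
    group = receivers * num_states + node_states[senders]
    indicator = torch.zeros(edge_scalars.shape[0], dtype=torch.bool)
    for g in group.unique():
        idx = (group == g).nonzero(as_tuple=True)[0]
        indicator[idx[edge_scalars[idx].argmin()]] = True
    return indicator
```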
We will add an illustration to the text.
We hope that our response addresses your concerns and we are open to further discussions.
I thank the authors for the detailed response. Some of my concerns and confusion have been resolved, e.g., the details of architecture and experiments. However, my main concerns remain:
- The strong claim “we can guarantee that for any graph size, the model will mirror the desired algorithm, which is correct for any test size” is not well grounded. In the response, the authors said this is mainly confirmed empirically, which is rather weak. To make this part more convincing, a serious statistical analysis is needed. For example, how frequently would the successful learning of the individual steps, like the select_best selector, happen?
- I still think presenting the results of your method under the without-hints setting is necessary. Given that your method achieves almost perfect performance under the with-hints setting, people would expect it to perform better under the without-hints setting. If not, then it suggests the performance gain under the with-hints setting is not brought by the improved inductive bias of the model (like your architecture design) but rather by the hints.
Therefore, I am still on the borderline.
Thank you for your response!
The strong claim “we can guarantee that for any graph size, the model will mirror the desired algorithm, which is correct for any test size” is not well grounded
Let us emphasize the main claim we are trying to make: the proposed architecture allows us to fully interpret the trained model independently of graph sizes/distributions and, if learned correctly, the model can perfectly imitate algorithms on any test data. This property is non-trivial and summarizes our contribution (and, importantly, it does not need to be statistically grounded). To the best of our knowledge, none of the models in the field possesses this property. We will make our claim clearer to highlight that it is related to the architecture itself.

We also note that we cannot guarantee that training will converge to the correct model for any initialization/training data distribution (the simplest example is the “phase transition” experiment from the previous reply). On the other hand, while we consider the mentioned property of the proposed architecture important, it would be of less practical interest if such “correct learning” occurred only for specific data distributions. In our work, we only experiment with the standard data distributions from the benchmark and several seeds, but evaluating the robustness of the training process for different distributions and hyperparameters is indeed important. We will conduct an additional experiment to highlight optimization challenges that might occur during training.
I still think presenting the results of your method under the without-hints setting is necessary. Given that your method achieves almost perfect performance under the with-hints setting, people would expect it to perform better under the without-hints setting
We consider poor no-hint performance to be one of the main limitations of the proposed method. To avoid the mentioned misconception, we added a note to the paper that simply removing hints breaks the learning, and the resulting models are only slightly better than untrained ones.
If not, then it suggests the performance gain under the with-hints setting is not brought by the improved inductive bias of the model (like your architecture design) but rather by the hints.
In general, we consider the proposed model in its current form as a potential answer to the question of where perfect generalization might come from. In this sense, the performance gain is from the architecture design, as we compare with hint-based models. On the other hand, the proposed model can be considered a particular way to use hints more precisely and accurately.
While achieving similar performance with no-hint models would be of greater interest, end-to-end learning of discrete models, especially with long rollouts, is a great challenge. Importantly, there are several ways to relax some discrete constraints to balance generalization with the applicability of continuous optimization. Also, there are several potential strategies to improve training without hints, e.g. (this particular example is from our discussion with Reviewer P5ow) we guess that some sort of iterative self-distillation might be useful in this setting for BFS (and similar problems):
- Train a no-hint model with good validation performance (similar to our experiments): such a model can capture the desired (but not the simplest) dynamics of the problem;
- Train a new model with fewer states, supervised only on the intermediate updates of the outputs of the model from the first step (e.g., predicted pointers after each step).
However, for sequential problems, such as DFS, obtaining a good model for stage 1 can be difficult. One possible way to overcome this limitation is to design a curriculum learning setup.
But, as we mentioned in the paper, we leave a deeper investigation of learning interpretable and generalizable neural reasoners without hints for future work.
Thanks again for your involvement in the discussion; we will be happy to answer any additional questions if needed.
The authors make some modifications to Transformer training to make it more effective for algorithmic reasoning tasks. In particular, they constrain it to learn hard attention, and separate discrete and continuous "flows" to prevent information loss. Overall, these modifications result in a model that performs better than the GIN and PGN baselines they consider.
Strengths
The approach seems to work much better than the baselines considered, and achieves perfect performance on the two datasets studied in this article.
Weaknesses
I found this paper to be difficult to read, but I'm not an expert in the area. It would be useful to see each of the design decisions in this paper ablated, with their corresponding effect on the datasets covered in this research.
Questions
- I'm very much not an expert in this area, so I did a bit of a literature search on related papers in the field. I found one such paper [1], which used a text-based version of CLRS-30, called TextCLRS-text. Does it make sense to use that dataset here? Or to compare to the proposed TransNAR architecture?
- As mentioned earlier, it would be useful to ablate the different algorithmic decisions made in this paper. In its current form, I can't ascertain which of the architectural decisions is responsible for which gains on these datasets.
Thank you for your review of our paper! We address your questions and concerns below.
I found one such paper [1], which used a text-based version of CLRS-30, called TextCLRS-text. Does it make sense to use that dataset here? Or to compare to the proposed TransNAR architecture?
There is some connection between our work and the work by Bounsi et al. (2024). In short, TransNAR discusses a method to enhance the reasoning capabilities of language models with a task-specific GNN-based NAR model, where any NAR model can be used as an “internal tool”. Thus, replacing the baseline GNN with the proposed DNAR model improves the quality of the tool that the language model can use, and the overall performance will be limited only by the correctness of the “tool usage”, not by inaccuracies of the tool itself. However, there are some difficulties in measuring the direct effect of including the proposed DNAR model in the TransNAR pipeline, as the source code for TransNAR is not yet publicly available.
It would be useful to ablate the different algorithmic decisions made in this paper. In its current form, I can't ascertain which of the architectural decisions is responsible for which gains on these datasets.
Please note the ablation experiments in Appendix A. Recall that we have three key components of our contribution: feature discretization, hard attention, and separating discrete and continuous data flows.
In short, we demonstrate that removing each of these components yields a model without provable guarantees of perfect generalization. However, these components differ in terms of their impact on performance. E.g., using regular attention instead of hard attention yields perfect test scores on the given datasets, but it is possible to construct adversarial examples with large neighborhood sizes where performance drops. On the other hand, removing discretization from the scalar updater significantly affects performance even on the small test graphs.
If you have any other questions or concerns, we are happy to discuss them further.
Dear reviewer, please let us know if you find the mentioned ablation experiments useful and if you have any other concerns. We also note that we added an experiment to investigate if the proposed ScalarUpdate module can be extended to support more complex continuous manipulations with scalars (please see Appendix C of the paper). In short, we find our results positive and consider the proposed discretization as a promising way to overcome the difficulties of value generalization in an interpretable way.
We hope that the revised version of our work with additional details from the general responses is now more clear to readers.
We will be happy to answer any additional questions if needed.
Thanks for the conversation and for adding additional experiments. Overall, I agree with reviewer 865w that, even with the supplementary experiments, the results presented in this paper are below the standards set by ICLR. Proper ablations and analysis could be done to better understand the perfect performance reported. Further, I’m concerned with reviewer 2sp9’s point around the claim that the model will mirror the desired algorithm for any graph size. I also agree that the no-hint setting deserves more attention, as right now it’s not clear what the source of the gains is. I’m electing to keep my score at 5.
The paper proposes the use of a finite set of predefined discrete states and manipulation of continuous inputs to address limitations of current paradigms in neural algorithmic reasoning, such as redundant dependencies in algorithmic data. The method achieves perfect performance on chosen tasks in the SALSA-CLRS and CLRS-30 benchmarks and introduces a starting perspective on the use of discrete states to pave the way for future methods in discrete neural algorithmic reasoning and their interpretability.
Strengths
- Impressive performance achieved across selected tasks in SALSA-CLRS and CLRS-30 benchmarks
- Interpretability and simplicity of the proposed model, which utilizes a set of discrete states to capture continuous inputs using edge priorities
- Further perspectives on paving the way for discrete neural algorithmic reasoners
Weaknesses
- There is limited explanation of how the discrete states are computed; the paper extensively discusses related work in Sections 1 and 2, and perhaps this space would be better used to provide further detail on the proposed architecture
- The work is limited in providing mathematical/theoretical definitions or explanations of the discrete state space and the manipulation of continuous inputs, which is the main novelty
Questions
- Results show perfect performance on SALSA-CLRS and CLRS-30 benchmarks. Could you elaborate on the reported results and zero variance between your seeds, given that this is not the case in the majority of comparable baselines or ablation experiments? Could you please report variance on the results in Table 1?
- Could you please explain the differences in computed discrete states used for SALSA-CLRS and CLRS-30 tasks as well as how the model performs on out-of-distribution data?
- Could you provide pseudocode for the Discretize_nodes and Discretize_edges functions shown on page 4?
- The work mentions modifying hints from the benchmark; could you clarify whether this was included for the tested baselines, as well as how the hints are modified?
- Could you further elaborate on the addition of a virtual node in your model (section 4.3)?
- Could you provide results of the hyperparameter search mentioned in section 6?
- Could you provide further detail on reverse-engineering (perhaps a toy example) mentioned in section 6?
Could you provide pseudocode for the Discretize_nodes and Discretize_edges functions shown on page 4?
Let h be the features of a given node (a vector of size hidden_size). First, we apply a linear layer to get the unnormalized logits representing each of the num_states possible states: states_logits = LinearProjection(h). Then, we select the state from these logits (this step differs between training, training with teacher forcing, and inference), and finally we use the (learnable) embedding of the corresponding state.
During training with hints, we use the ground truth states (we optimize the logits with a CE loss and use teacher forcing); for no-hint learning, we use Gumbel softmax with an annealed temperature. At inference, we always use argmax.
```python
import torch.nn.functional as F

def discretize(node_features):
    # (num_nodes, hidden_size) -> (num_nodes, num_states)
    states_logits = linear_projection(node_features)
    if training_with_teacher_forcing:
        states_one_hot = ground_truth_states  # logits are supervised with a CE loss
    elif training_without_teacher_forcing:
        # differentiable sampling with an annealed temperature
        states_one_hot = F.gumbel_softmax(states_logits, tau=temperature, hard=True)
    else:  # inference
        states_one_hot = F.one_hot(states_logits.argmax(dim=1), num_states).float()
    # map each (one-hot) state to its learnable embedding
    return states_embeddings(states_one_hot)
```
The work mentions modifying hints from the benchmark; could you clarify whether this was included for the tested baselines, as well as how the hints are modified?
We only modify hints for the DFS problem, keeping the original execution flow. In particular, the original hints for the DFS problem (Veličković P. et al., 2022) use several additional concepts, such as the global time counter and nodes’ discovery/finish times, which are not directly used in the DFS problem itself but are needed for other DFS-based problems from the CLRS-30 benchmark (e.g., Bridges and SCC).
As graph-level features are not needed directly for any of the covered problems, we keep the model without the graph-level features and remove graph-level hints for the DFS problem.
These changes were not included for the tested baselines.
Could you further elaborate on the addition of a virtual node in your model (section 4.3)?
The virtual node is needed to select a unique node from the whole graph (for example, to select the current node at each step of the MST algorithm). In prior works, this type of computation is done by using graph-level features (which are updated depending on all node features and used in all node feature updates) or by using complete graphs for message passing. We use a simpler alternative that does not maintain graph-level features: the virtual node is not updated from the node features and is only capable of selecting important nodes via the attention mechanism. In other words, the “unique” node becomes the one from which the virtual node received the message.
Could you provide results of the hyperparameter search mentioned in section 6? Could you provide further detail on reverse-engineering (perhaps a toy example) mentioned in section 6?
Please see the second part of the general response, where we give a more detailed description of our analysis. Please let us know if some details remain unclear.
We hope that our response addresses your concerns and will be happy to answer any additional questions during the discussion period if needed.
Thank you very much for your detailed response to my questions as well as the general response in the rebuttal. It is useful to see a deeper analysis of the conducted experiments, particularly of the perfect performance on the used benchmarks, as well as clarification of the algorithms used to compute the discrete set of states. Despite the useful commentary provided, I still believe the work is below the requirements for the acceptance threshold and would therefore maintain my current score. I would recommend further complementing the work with a more detailed theoretical portion, particularly in the area of training convergence, as it is mentioned that this is challenging to guarantee. I would also recommend searching for potential counter-examples (perhaps beyond the used benchmarks) which could help test the limits of the proposed direction. I hope that this feedback is helpful to the authors for the continuation of their work on this problem.
We thank the reviewer for a thoughtful review! Let us address the raised concerns and questions.
Could you elaborate on the reported results and zero variance between your seeds, given that this is not the case in the majority of comparable baselines or ablation experiments?
The proposed model indeed converges to perfect mimicking of the ground truth algorithms and demonstrates perfect scores across different seeds. The main reason why this is not the case for the baselines and ablation experiments is discretization: due to the design of the proposed model, it only needs to learn a finite number of discrete state transitions, and (as we discuss in the paper and in the first part of the general response) the same state transitions are required to execute the algorithm on any test data. On the other hand, the baselines and the models from the ablation experiments do not force the node and edge features to come from fixed sets and usually encounter OOD values when given larger test graphs. Performance on such data can vary across network initializations and other randomness in the training process.
Note that while we report the perfect scores for the covered tasks, we cannot theoretically guarantee that the training will converge to the correct model for any initialization/training data distribution.
Could you please report variance on the results in Table 1?
Standard deviation of the node-level scores for the baselines across 5 different seeds (test size is 160 nodes; we will add the full version of this table to the paper):

| Algorithm | GIN | PGN |
|---|---|---|
| BFS | 8.9 | 0.1 |
| DFS | 3.1 | 2.5 |
| Dijkstra | 8.3 | 7.0 |
| MST | 6.5 | 4.5 |
| MIS | 3.3 | 0.3 |
| Eccentricity | 18.9 | 0.1 |
Could you please explain the differences in computed discrete states used for SALSA-CLRS and CLRS-30 tasks as well as how the model performs on out-of-distribution data?
Discrete states are the same between the benchmarks, as SALSA-CLRS uses the CLRS-30 hints for the problems in CLRS-30 (BFS, DFS, Dijkstra, MST). The main difference is the training and test data: as described in Sections 4.1 and 4.2, the test data for SALSA-CLRS contains sparser graphs.
As described in Section 5, the performance of the proposed model does not depend on the size and distribution of the test data (please see the first part of the general response for more details).
On the other hand, the performance of the baselines may vary depending on particular test data distributions, and achieving a good generalization for every target distribution is considered challenging in prior work (Georgiev et al., 2023).
Thank you for your response! Let us reply to the raised concerns.
I would recommend further complementing the work with a more detailed theoretical portion, particularly in the area of training convergence, as it is mentioned that this is challenging to guarantee
Let us highlight the main part of our contribution: we consider the proposed model in its current form as a potential answer to the question of where perfect generalization might come from. The proposed architecture allows us to fully interpret the trained model independently of graph sizes/distributions and, if learned correctly, the model can perfectly imitate algorithms on any test data. To the best of our knowledge, none of the models in the field possesses this property (nor do they offer theoretical guarantees of training convergence).
I would also recommend searching for potential counter-examples (perhaps beyond the used benchmarks) which could help test the limits of the proposed direction
As we mentioned in the paper, one important limitation of our work is the reduced expressiveness of the proposed architecture. For example, the hard-attention mechanism cannot compute averaging over neighbors in a single message-passing step. In this sense, finding potential counter-examples is trivial. On the other hand, we also mentioned that there are several ways to improve expressiveness without losing generalization and, importantly, ways to improve expressiveness by reducing generalization. Thus, we believe that the discrete building blocks of the proposed model might arise in different forms as building blocks of a more general architecture.
For example, our ablation experiments (Appendix A) demonstrate that several components (e.g., hard attention) can be removed to relax the strong generalization guarantees while still yielding good performance. Also, we conducted an additional experiment to investigate whether the proposed ScalarUpdate module can be extended to support more complex continuous manipulations with scalars (please see Appendix C of the paper). While learning to find the correct manipulations from the discrete set might be challenging for particular data distributions or update sets (e.g., such a decomposition might not be unique), we find our results positive and consider the proposed discretization a promising way to overcome the difficulties of value generalization in an interpretable way.
We hope that our response addresses your concerns.
This paper introduces a novel approach to neural algorithmic reasoning by enforcing neural networks to operate with discrete states and separating discrete and continuous data flows. The authors propose a model that integrates hard attention mechanisms, feature discretization, and a separation between discrete computations and continuous inputs (scalars). This design aims to align neural network computations closely with classical algorithms and thus improves out-of-distribution generalization and interpretability. The method is evaluated on several algorithmic tasks from the SALSA-CLRS benchmark, including BFS, DFS, Prim's algorithm, Dijkstra's algorithm, Maximum Independent Set (MIS), and Eccentricity calculations. The proposed Discrete Neural Algorithmic Reasoner (DNAR) achieves perfect test scores on these tasks, even on graphs significantly larger than those seen during training. The authors also discuss the limitations of their approach and potential directions for future work.
Strengths
- Originality: The paper presents a novel method that enforces discrete state transitions in neural networks, which is a significant departure from traditional continuous representations. By integrating hard attention and separating discrete and continuous data flows, the authors address key challenges in neural algorithmic reasoning, particularly out-of-distribution generalization, and interpretability.
- Quality: The experimental results are strong, with DNAR achieving perfect test scores across multiple algorithmic tasks and graph sizes. The comparison with baseline models and state-of-the-art methods demonstrates the effectiveness of the proposed approach.
- Clarity: The paper is generally well-written and structured. The authors explain their methodology, including architectural choices and training procedures. The inclusion of diagrams and tables aids in understanding the proposed model and its performance.
- Significance: The work contributes to the field by showing that neural networks can be designed to mimic classical algorithms with perfect generalization and interpretability. This has implications for developing reliable and trustworthy AI systems that can be formally verified.
Weaknesses
- Expressiveness Limitations: The enforced constraints, such as hard attention and discrete state transitions, limit the model's expressiveness. For instance, in a single message-passing step, the model cannot compute certain aggregate functions, like averaging over neighbors. This restricts the method's applicability to algorithms that fit within these constraints.
- Scope of Evaluation: The experimental evaluation focuses on specific algorithmic tasks where the proposed method aligns well. It remains unclear how the model would perform on more complex algorithms that require different computational primitives or continuous manipulations beyond simple increments.
- Training Without Hints: While the method achieves excellent results with hint supervision, training without hints is identified as challenging. This limitation reduces the method's applicability in scenarios where intermediate algorithmic steps (hints) are unavailable.
- Sensitivity to Hyperparameters: The paper mentions that certain hyperparameters, like the number of discrete states, significantly impact performance, especially when training without hints. However, there is limited discussion on how sensitive the model is to these hyperparameters and the implications for generalization.
- Presentation Details: While the paper is generally clear, some sections could benefit from additional explanations. For example, the scalar updater's mechanism could be elaborated further for readers who are not in the NAR field.
Questions
- How does the proposed method handle algorithms that require more complex continuous manipulations or aggregate functions beyond simple increments and selections? Can the model be extended to support such algorithms without losing the benefits of discretization and interpretability?
- Are there potential strategies to improve training without hint supervision? How does the model perform on tasks without hints when compared to continuous models?
- The constraints improve generalization but reduce expressiveness. Is there a way to balance this trade-off by selectively relaxing some constraints?
- Can the proposed separation between discrete and continuous data flows be applied to other neural network architectures beyond attention-based models? What challenges might arise in such adaptations?
The constraints improve generalization but reduce expressiveness. Is there a way to balance this trade-off by selectively relaxing some constraints?
While there are several ways to improve expressiveness without losing generalization at all, we can also improve expressiveness by reducing generalization:
1. Removing hard attention: as we demonstrate in our ablation experiments, using regular attention instead of hard attention yields perfect test scores for the BFS problem, but it is possible to construct adversarial examples with large neighborhood sizes where performance drops. While for more complex attention patterns (beyond strictly attending to a single node) the OOD performance might be less robust, the expressivity gain is significant.
2. Removing feature discretization, but updating scalars with discrete operations: as shown by prior work and our ablation experiments, learning precise continuous manipulations is non-trivial, and small inaccuracies in such manipulations can significantly affect the overall performance of NARs. Thus, we can use the proposed separation between the discrete and continuous data flows and not discretize node/edge features at all (using discretization only in the ScalarUpdater). However, for non-attention-based models, one needs to specify how scalars will affect the discrete flow.
Can the proposed separation between discrete and continuous data flows be applied to other neural network architectures beyond attention-based models? What challenges might arise in such adaptations?
The general idea is to update scalars with predefined operations and let them affect the other flow only in a discrete manner. The simplest implementation of the latter is to update node/edge features depending on simple comparisons/selections of scalars, as we did in the paper. When not aiming for generalization guarantees, we can simply use such discrete updates of the features in the discrete flow and use any GNN as a processor. The node/edge features in the resulting model will still be discrete (and could, e.g., express complex substructures of the graph).
We guess that the main challenge arises when we try to answer the question of where perfect generalization might come from and how different GNN models perform at the limits of their generalization. E.g., for attention-based models, this leads to the notion of attention weight annealing. We guess that such challenges are architecture-specific. For example, we expect that the analysis of an MPNN with MAX aggregation (and predefined fixed sets of node/edge features) is somewhat similar to that of the hard-attention model (as it is invariant to the exact number of neighbors in each state; only the presence of each state matters), but analyzing the possibility of perfect generalization with other aggregation functions might be more challenging.
We thank the reviewer for thoughtful questions and we are happy to discuss further.
I am happy with the answers and would like to have one more experiment for this: "Please let us know if you are interested in a particular type of scalar manipulations in a particular algorithm, we are ready to conduct an additional experiment during the discussion period."
Authors can choose the most challenging one.
We thank the reviewer for the constructive review and positive feedback! We address the questions below.
How does the proposed method handle algorithms that require more complex continuous manipulations or aggregate functions beyond simple increments and selections? Can the model be extended to support such algorithms without losing the benefits of discretization and interpretability?
First, we note that simple manipulations with scalars cover a significant part of classical algorithms. In our work, we use the minimum set of required functions, but it can be directly extended with other functions such as pow/sin/cos/exp. However, this complicates the optimization problem of selecting the correct operations/operands from the operations' results (e.g., such a decomposition might not be unique). Importantly, as ScalarUpdater can be viewed as a separate module, we can separately check whether it is possible to train ScalarUpdater with any given set of predefined manipulations for any problem using only MSE on the results. Please let us know if you are interested in a particular type of scalar manipulation in a particular algorithm; we are ready to conduct an additional experiment during the discussion period.
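To make the idea concrete, here is a hedged sketch of how such an extended operation set could be wired into a ScalarUpdater-style module; the operation set, module name, and hard state-conditioned selection are our illustrative assumptions, not the paper's exact design:

```python
import torch
import torch.nn as nn

OPS = [lambda x: x + 1.0, torch.sin, torch.cos, torch.exp]  # predefined manipulations

class ScalarUpdaterSketch(nn.Module):
    def __init__(self, state_dim: int):
        super().__init__()
        self.selector = nn.Linear(state_dim, len(OPS))  # discrete state -> op logits

    def forward(self, scalars: torch.Tensor, state_emb: torch.Tensor) -> torch.Tensor:
        # Apply every predefined operation, then pick exactly one result per
        # node based on its discrete state (hard argmax at inference).
        candidates = torch.stack([op(scalars) for op in OPS], dim=-1)  # (N, num_ops)
        choice = self.selector(state_emb).argmax(dim=-1)               # (N,)
        return candidates.gather(-1, choice.unsqueeze(-1)).squeeze(-1)
```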
The proposed selection with hard attention indeed significantly reduces expressivity. The simplest way to support more complex aggregations in a single processor step is to use top-k attention (with a fixed k). Also, as we mention in the paper, the model can be extended with additional architectural modifications, such as edge-based reasoning. For example, edge-based top-2 attention (where each edge chooses the top 2 adjacent edges to receive the message from) can implement a form of triplet reasoning.
Are there potential strategies to improve training without hint supervision? How does the model perform on tasks without hints when compared to continuous models?
We perform a hyperparameter search only for the BFS problem, achieving a node-level score of up to 79 when tested on graphs with 64 nodes; please see the second part of the general response for more details. Without a hyperparameter search, discrete models are only slightly better than untrained models.
As we demonstrate in our no-hint experiment, the proposed discrete model is potentially capable of capturing the desired dynamics of the algorithm (e.g., using the concept of distance for the BFS task) when its architecture is well aligned with the underlying problem. However, directly regularizing the simplicity of the resulting model (such as reducing the state count) performs poorly, as it significantly changes the optimization problem. We guess that some sort of iterative self-distillation might be useful in this setting for BFS (and similar problems):
- Train a no-hint model with good validation performance (similar to our experiments): such a model can capture the desired (but not the simplest) dynamics of the problem;
- Train a new model with fewer states, supervised only on the intermediate updates of the outputs of the model from the first step (e.g., predicted pointers after each step).
However, for sequential problems, such as DFS, obtaining a good model for stage 1 can be difficult. One possible way to overcome this limitation is to design a curriculum learning setup.
Another intermediate approach (similar to step 2) is to supervise the model only on a part of the hints (e.g., pointers) of the ground truth algorithm.
Thank you for being involved in the discussion!
We conduct an additional experiment to investigate if the proposed ScalarUpdate module can be extended to support more complex continuous manipulations with scalars (please see Appendix C of the paper). While learning to find correct manipulations from the discrete set might be challenging for particular data distributions or updates sets, we find our results positive and consider the proposed discretization as a promising way to overcome the difficulties of value generalization in an interpretable way.
We are open to further discussion!
Dear Authors,
I am pleased with your revisions and the clarity of your experimental results. The paper is now excellent, and I have updated my score to Accept (8).
Details on no-hint discrete reasoners
In the paper, we briefly discussed the challenges of no-hint learning of the proposed DNAR model (Section 6). Following the questions of the reviewers, in this part of the general response we extend this discussion and we will update the paper accordingly.
Recall that for the no-hint experiments we focus on the BFS problem. We perform a hyperparameter search over the training sizes (using ER(n, 0.5) graphs for different values of n), the discrete node state count (from 2 to 6 states), and the softmax temperature annealing schedules ([3, 0.01], [3, 0.1], [3, 1]). For each hyperparameter choice, we train 5 models with different seeds. We validate the resulting models on graphs of size 16. The best resulting model is obtained with training size 5 and 4 node states.
Node-level / graph-level scores of the best no-hint model for different graph sizes:

| Graph size | 5 | 16 | 64 |
|---|---|---|---|
| best no-hint model | 97/86 | 94/34 | 79/0 |
Then we tried to analyze the mistakes of the resulting models and to understand how they utilize the given states. First, we look at the node states after the last step of the processor and note that the states correspond to distances from the starting node. More formally, we note that the model with 4 states uses the first state for the starting node, the second state for its neighbors, the third state for nodes at distance 2 from the starting node, and the last state for all other nodes; such state-based classification of distance has accuracy > 98% when tested on 1000 random graphs with 16 nodes.
Then, we note that for the nodes within the first 4 distance layers from the starting node, the pointers are predicted with 100% accuracy, and these pointers are computed layer-by-layer as in the ground truth algorithm. The mistakes of the model occur at distance >= 4 from the starting node (we did not reverse-engineer the specific logic of computations at larger distances).
We’ve added several illustrations of the node states and the dynamics of the pointer prediction updates. (Note that for simplicity of illustration, we use the model with 3 discrete states.)
We believe that with a more extensive hyperparameter search one could obtain a no-hint BFS model with 5 states that would have perfect scores on graphs with small diameters; however, such a model would not generalize to larger diameters. On the other hand, hint-based models utilize only 2 node states in a slightly different manner and achieve perfect generalization for any test data. Thus, we highlight the need to achieve perfect validation performance with models that use as few states as possible, as we discussed in the paper.
We would like to thank all the reviewers for their valuable feedback. In this general response, we first provide a more detailed description of how we prove the correctness of the learned algorithms on any test data and then extend our discussion on no-hint discrete reasoners.
Proving correctness
This part of the general response extends our discussion in Section 5 of the paper.
As an example, consider the BFS algorithm. First, recall the pseudocode of the algorithm:
```
Starting node <- visited
All other nodes <- not visited
Self-loops <- pointers
All other edges <- not pointers
For step in range(T):
    For each node U in the graph:
        If U was visited on previous steps: continue
        If U has a neighbor P that was visited on previous steps:
            U becomes visited on this step
            U selects the smallest-indexed such neighbor P as its parent:
                Edge (U, P) becomes a pointer
                Self-loop (U, U) becomes not a pointer
Return the BFS tree described by the pointers
```
Now let us describe how we can verify that the trained DNAR model will perfectly imitate this algorithm for any test data.
First, we note that for each node U, the node state at step t+1 (denoted by state_U^{t+1}) is a function of state_U^t and the received message, where V is the node that sends the message to U at step t: StateUpdate(state_U^t, message_from_V^t) = state_U^{t+1}.
How does the node U select a node that will send a message to it? For any node V connected to U, the node U computes an attention score that depends on the discrete states of U and V and on a discrete indicator of whether V has the smallest (or largest) scalar among all neighbors of U with the same discrete state as V. Then, the node U selects the node with the largest attention score.
In our (slightly simplified) case, the attention scores depend only on the tuples (state_U, state_V, indicator_if_V_has_the_smallest_index), and there are only 8 such tuples. We can directly compute these attention scores and verify the required invariants: e.g., AttentionScore(not_visited, visited, smallest) > AttentionScore(not_visited, *any other combination*), which implies that a not_visited node will receive the message from its smallest-indexed visited neighbor if one exists, independently of the graph size and distribution. If there is no such neighbor, the node will receive the message from another not_visited node (or from itself).
After verifying the correctness of the message flows, we need to verify that the state updates are computed correctly: e.g., StateUpdate(not_visited, message_from_visited) == visited, StateUpdate(visited, *any*) == visited, StateUpdate(not_visited, message_from_not_visited) == not_visited, etc.
The main idea is that, due to the finite number of states and the discrete manipulations with scalars, only a finite number of such checks is needed to cover all possible state transitions, and each of them has to be evaluated only once.
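As a sketch, the whole verification reduces to a finite table of assertions of the following kind (BFS-specific; `state_update` is a hypothetical helper that evaluates the trained StateUpdate block on one input combination):

```python
STATES = ["visited", "not_visited"]

def verify_bfs_state_updates(state_update):
    # state_update(u_state, sender_state) -> next state of u.
    assert state_update("not_visited", "visited") == "visited"
    assert state_update("not_visited", "not_visited") == "not_visited"
    for sender_state in STATES:  # a visited node stays visited
        assert state_update("visited", sender_state) == "visited"
```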
We will add the example above to the revised paper.
We are happy to provide further clarifications if needed.
We thank all reviewers for carefully reading our work and providing constructive feedback. We deeply appreciate the time and effort spent providing detailed reviews.
We are currently working on updating the paper with clarifications and additional experiments, described in our replies. Please let us know if you have any additional questions or comments, this will help us to properly address the raised concerns and update the paper according to your feedback.
We look forward to sharing the updated manuscript in the coming days.
We thank the reviewers for taking part in the discussion!
We have uploaded a new revision of the paper:
- We conducted an additional experiment to investigate if the proposed ScalarUpdate module can be extended to support more complex continuous manipulations with scalars (Appendix C);
- We extended our discussion in Section 5 (interpretability and testing) with the particular example from the general response;
- We extended our discussion in Section 6 (no-hint discrete reasoners) with the additional details and illustrations from the general response;
- We added several minor improvements and clarifications according to the reviewers’ suggestions.
We will be happy to answer any additional questions during the discussion period if needed.
This paper presents an exciting approach to neural algorithmic reasoning which directly relies on constraining the model's conditioning to a finite, discrete set of states.
I believe the idea is important and definitely has promise. However, as it stands, even though the Authors' rebuttal resulted in some scores being raised, it appears the majority of the reviewers are leaning more toward rejection than acceptance.
From my read of the reviews and the paper, it appears that the Authors could definitely benefit from either (a) evaluating their method in more diverse experimental settings (beyond the specific algorithmic tasks considered here), or (b) having a more solid theoretical grounding of why the method is effective.
Upholding this opinion, I have decided to recommend this paper for rejection. I reiterate that the work is very interesting, but accepting it in its current form (without a more thorough investigation beyond the tightly controlled experimental environment studied so far) may do the work and its reach a disservice. The Reviewers did not oppose my decision.
Additional Comments from Reviewer Discussion
No additional comments beyond the main meta-review.
Reject