Flexible task abstractions emerge in linear networks with fast and bounded units
We train neural networks in changing environments and show that task abstractions emerge in parameters trained with a fast learning rate and heavy regularization. These task abstractions can then support cognitive flexibility.
Abstract
Reviews and Discussion
The paper investigates fast task switching/adaptation in a gated linear network. Specifically, it shows that neuron-like properties, including regularization, fast learning, and non-negativity, lead the model to adapt quickly to tasks and to generalize compositionally. The authors provide a detailed analysis of the learning dynamics to show which components are responsible for fast adaptation and the conditions under which a flexible scheme arises instead of a lazy, forgetful one. They also replicate similar patterns in a more complex task and model setting, and reproduce human learning characteristics reported in prior studies.
Strengths
This work provides extensive model behavior/weight analyses and a thorough mathematical analysis of the learning dynamics to understand the driving forces of the observed fast adaptation.
The authors go beyond simple controlled tasks and a linear model to investigate the same phenomenon in a nonlinear model using MNIST.
This work draws a nice parallel with human cognition that is backed both theoretically and empirically.
The paper is dense but well-written.
Weaknesses
No major weaknesses identified. One comment is that because the paper is very dense, some details are omitted (e.g., in figures) and need to be inferred.
Questions
How many training instances were in the spiky loss part, in both the synthetic task and the MNIST task? E.g. does the model achieve single-shot adaptation in the task block?
Would the gates adapt to a new task replacing an old task quickly?
Must gating-like properties be at the 2nd layer in the deep monolithic network? What happens when you insert regularization and faster learning in layer 1 units? In general, for larger-scale applications, is it better to have gating-like mechanisms in later parts of the model?
What happens when non-gating weights are also equipped with some amount of neuron-like properties? Are there neuroscientific reasons to maintain separate groups of task weights and gate weights?
Limitations
Noted in the paper.
Thank you for your thoughtful comments. We are glad to hear that you find the theoretical analysis of the underlying mechanism insightful. Moreover, we are encouraged that you value our aim to test our framework against behavioral experiments as well as to demonstrate a basic usefulness beyond synthetic tasks.
Q1: How many training instances were in the spiky loss part, in both the synthetic task and the MNIST task? E.g. does the model achieve single-shot adaptation in the task block?
The reviewer brings up the important question of how quickly this model is able to adapt to previous and new tasks. As we run our simulations in continuous time (gradient flow), each gradient is calculated over a large number of samples (200 in both cases). We made this choice to permit theoretical analysis where a sample average is required. Figure 1F shows the number of time steps needed until adaptation for the synthetic task, and we supply a similar version for fashionMNIST in Fig. R2.
To investigate the sample efficiency of the model on the synthetic task, we now provide a simulation in Fig. R3 with coarsely discretized time, where only a single sample is used in every update. While this makes the simulation noisier within a single timestep, its qualitative features persist: in late blocks, the loss drops significantly on the first sample and reaches its minimum after a few samples (few-shot adaptation). More generally, the persistence of fast adaptation in the equivalent model in Fig. 3 (where noisy input is absent) suggests that fast adaptation is indeed facilitated by the gates.
We will include both figures in the appendix of the final revision.
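To make the discretized setting concrete, here is a minimal sketch of single-sample updates in a two-path gated linear student. This is our simplification for illustration only: the hyperparameters, the blocked curriculum, and the quadratic order-1 gate penalty are assumptions rather than the exact simulation code.

```python
import numpy as np

rng = np.random.default_rng(0)
d, P = 20, 2                                   # input dimension, number of paths
teachers = rng.standard_normal((P, d))         # one linear teacher per task
W = 0.01 * rng.standard_normal((P, d))         # student weights: slow timescale
c = np.ones(P)                                 # gates: fast timescale, non-negative

eta_w, eta_c, lam = 1e-3, 1e-1, 1e-2           # slow weights, fast gates, gate penalty

for step in range(20_000):
    task = (step // 2_000) % P                 # blocked curriculum: one task per block
    x = rng.standard_normal(d)                 # a single sample per update
    h = np.array([W[p] @ x for p in range(P)]) # per-path outputs
    err = float(c @ h - teachers[task] @ x)    # scalar prediction error
    grad_c = err * h + lam * (c.sum() - 1.0)   # assumed order-1 gate penalty (coupled)
    W -= eta_w * err * c[:, None] * x[None, :] # gradient of 0.5*err**2 w.r.t. W
    c = np.maximum(c - eta_c * grad_c, 0.0)    # fast gate step + non-negativity clamp
```

Under settings like these, the gates carry most of the adaptation after a block switch while the weights change slowly, which is the behavior Fig. R3 quantifies.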
Would the gates adapt to a new task replacing an old task quickly?
If there are shared components between the old task(s) and the new task(s), the gates adapt to appropriately modulate the parts of the network that correspond to the new task. For example, we illustrate this in our generalization experiments, where we switch the network to a new set of tasks (A+B, B+C, A+C) composed of different combinations of the tasks previously encountered (A, B, C). In these cases, the network adapts the gates to turn on the parts of the network corresponding to the new tasks and to turn off the rest, allowing it to generalize its learned knowledge rapidly.
Must gating-like properties be at the 2nd layer in the deep monolithic network? What happens when you insert regularization and faster learning in layer 1 units? In general, for larger-scale applications, is it better to have gating-like mechanisms in later parts of the model?
Because the deep monolithic model is linear, there is no meaningful difference between gating occurring in the first versus the second layer. The main distinction is that rather than gating the output of computations already performed in the network, inducing gating-like behavior in the first layer would gate off certain computations in advance (by zeroing their input) and only perform computations on specific paths within the network; the sketch below illustrates this equivalence. If regularization and a faster learning rate are instead applied to layer-1 units, we would expect similar gating-like behavior to emerge in the first layer rather than the second. Whether there is an advantage to gating earlier or later in a network is an interesting open question, which may only arise once additional depth and nonlinearity are added to a model.
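To make the linearity argument concrete, here is a toy sketch (ours, not the paper's code): in a two-layer linear network, gates applied at the first layer mask inputs before any computation happens, while gates applied at the second layer mask already-computed hidden units, and in both cases the network simply realizes a different effective linear map.

```python
import numpy as np

rng = np.random.default_rng(1)
d_in, d_h, d_out = 8, 8, 3
W1 = rng.standard_normal((d_h, d_in))          # layer-1 weights
W2 = rng.standard_normal((d_out, d_h))         # layer-2 weights
x = rng.standard_normal(d_in)

g1 = (rng.random(d_in) > 0.5).astype(float)    # gates on the inputs (layer 1)
g2 = (rng.random(d_h) > 0.5).astype(float)     # gates on the hidden units (layer 2)

y_first = W2 @ (W1 @ (g1 * x))    # gating in advance: zeroed inputs skip computation
y_second = W2 @ (g2 * (W1 @ x))   # gating afterwards: computed features are masked

# Either placement yields an effective linear map of x with the mask absorbed:
assert np.allclose(y_first, W2 @ W1 @ np.diag(g1) @ x)
assert np.allclose(y_second, W2 @ np.diag(g2) @ W1 @ x)
```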
What happens when non-gating weights are also equipped with some amount of neuron-like properties?
Thank you for raising the question of the minimal conditions needed for gating. We can distinguish two settings:
First, what happens if we equip the weights with a fast timescale? We find that with very fast weight learning rates, specialization eventually disappears: the weight representations become fast enough to adapt quickly after a block change, without requiring any gating (Fig. 5). This, however, is an unnatural assumption, as synaptic plasticity is slow.
Second, we can ask whether an architecture with two layers of weights, but without any scalar gates, will show gating-like behavior. By regularizing one layer and increasing its learning rate, we indeed observe that this fully-connected weight layer exhibits gating-like behavior, while the other layer learns to specialize to the corresponding tasks, emulating the behavior we see in our gated model (Fig. 6).
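Concretely, this manipulation amounts to giving one weight layer its own faster learning rate and regularization. A hypothetical PyTorch sketch (layer sizes, hyperparameters, and the choice of which layer is regularized are ours, and weight decay stands in for the paper's regularizer):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(20, 16, bias=False),  # slow, unregularized "task" weights
    nn.Linear(16, 1, bias=False),   # fast, regularized layer where gating-like
)                                   # behavior would be expected to emerge

optimizer = optim.SGD([
    {"params": model[0].parameters(), "lr": 1e-3, "weight_decay": 0.0},
    {"params": model[1].parameters(), "lr": 1e-1, "weight_decay": 1e-2},
])  # per-group learning rates and decay implement the two timescales
```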
Are there neuroscientific reasons to maintain separate groups of task weights and gate weights?
The reviewer brings up an important question as to how the model structure maps onto brain structures. There are multiple gating circuits in the brain; one is the prefrontal cortex (PFC) gating activity in earlier cortices. We think of the weights as the synaptic strengths in motor and sensory cortices as they learn a task, while the gating that switches between distinct computations is projected top-down from the PFC. In addition, cognitive flexibility and the switching of task representations in the PFC occur through rapid changes in neural activity rather than synaptic updates. As such, we see weights and gates implemented in different regions and through different substrates (i.e., synaptic plasticity vs. changes in neural activity).
One comment is that because the paper is very dense, some details are omitted (e.g. in figures) and need to be inferred.
We thank the reviewer for this valuable comment. We reviewed the manuscript and found that some important details were indeed omitted, especially regarding the setup of the experiments shown in each figure. The final version of the manuscript now addresses this.
Thank you for providing thorough answers to these questions; they all make sense. Great work!
Thank you for your support and encouraging words. In particular, your prompt to clarify the speed at which adaptation between tasks happens was quite valuable. Thank you for your thoughtful comments and questions.
This paper looks at the problem of having an agent learn a series of tasks sequentially by receiving supervised data for each task. Training a NN this way suffers from catastrophic forgetting; NNs learn best with shuffled data. Their model attempts to be more similar to humans, who do best with task data presented unshuffled. The idea is to have subnetworks which can learn each task, and then learn a high-level router to the subnetworks. The router uses a gating variable trained with regularization to encourage only one subnetwork to be active at a time. As a result, when trained on the unshuffled data, the model can pick up on the different tasks and assign them to the subnetworks. The authors show experiments where the model does just that, with some analysis of the dynamics. They also provide theoretical analysis of what is causing this to happen, which is possible because it is a pretty simple setup.
Strengths
- The paper models an important observed difference between humans and neural networks.
- The model is remarkably simple and easy to understand. It is tempting to label the results as "predictable" or not surprising, but simple models are good.
- The authors provide an impressive barrage of experiments analyzing their model. In particular, I am impressed by the experiments in Section 5, where they show how specialization emerges as different hyperparameter choices reach their optimal values.
Weaknesses
- It seems like parts of the architecture, such as the number of subnetworks and the size of the subnetworks, are tuned to match those given in the system.
- A lot of the paper structure seems different from what is typically done for NeurIPS. For example, there is no related work discussed in the main text, but it is in the appendix. Many of the experiment details are also in the appendix, which is weird because without the details, it's somewhat hard to understand the results presented in the main text.
- As a somewhat theoretical and simplified model, it's unclear how much the approach relates to more "real world" tasks. The one task shown is MNIST, which is still a pretty artificial example because the second task is also MNIST, but with a weird permutation. This leads to my main question and reservation, which is that the model might only work if the input tasks are orthogonal. If so, that's a huge limitation, and it dramatically changes my view of the paper. My current score assumes that this limitation is true, but if it were not, I would give a higher score (maybe an 8).
Questions
Questions:
- What is P set to for the experiments?

Suggestions:
- It seems like you need to explain the experiments more in the main text. The results are just shown without fully explaining the experiment setup. Are there two tasks that are cycled repeatedly?
- Why not use some other image recognition task besides permuted MNIST? Is it because the tasks have to be orthogonal?
- Line 116: I would explain the regularization constraints in words in addition to presenting the equation. You should also link to where they are explained further in the appendix.
Limitations
It seems like the model only works if the input tasks have orthogonal solutions; is that true? If so, that's a huge limitation.
Thank you for your review and thoughtful comments. We are delighted that you appreciate the simplicity of our model and the thorough experiments we provide.
The idea is to have subnetworks which can learn each task, and then learn a high level router to the subnetworks.
We thank the reviewer for the accurate summary of the routing function of the gating variables in our setup. In contrast to other methods that learn gating of task representations, we update weights and gates simultaneously through gradient descent, reflecting real-world conditions in which learning needs to be segmented and structured at the same time. Thus, weights and gates are fundamentally treated on equal footing. Critically, our suggested regularization (alleviating overspecification by favoring c to be of order 1) does not strictly encourage one task to be active at a time, since mixed solutions are given the same penalty.
We are excited about this model because it solves several computational problems with one mechanism: It is able to discover boundaries between tasks, retrieve task representations for previously learned tasks, and move representations for new tasks away from previous ones.
it seems like parts of the architecture, such as the number of subnetworks and the size of the subnetworks are tuned to match those given in the system.
So far, we have indeed only considered the case where the number of subnetworks P equals the number of tasks M (P=M=2) for simplicity. We fully agree that analysing P>M (i.e., multiple paths) improves the understanding of the model. To this end, we conduct a new experiment in the attached Figure R4. The system still learns to solve both tasks and adapts flexibly to context switches. In this setting, however, students only partially specialize, since they are now underconstrained. This disappears when regularization of the student weight magnitude is imposed, which can physiologically be interpreted as a “representational cost”: over time, we would expect a biological system to prune the additional components.
We note that, under certain conditions, the flexible regime can also emerge in the fully-connected architecture which is inherently free of these assumptions (Fig. 6).
On this basis, we adjust the presentation of our base model to the more general case P>M.
The paper structure seems different from what is typically done for neurips
Many of the experiment details are also in the appendix, which is weird because without the details, it's somewhat hard to understand the results presented in the main text.
Thank you for pointing this out! We had deferred our related work section to the appendix due to space constraints, but we see its value in the main text and will incorporate a shortened version there. We will move necessary experimental details back to the main text as well.
is it because the tasks have to be orthogonal?
This leads to my main question and reservation, which is that the model might only work if the input tasks are orthogonal?
We thank the reviewer for this important comment. Prior related work (Lee et al., 2024) provided an extensive investigation of the orthogonality of representations in linear networks, and for theoretical simplicity we had considered only orthogonal tasks. To identify how orthogonality impacts the model, we ran three new simulations, shown in the PDF attached to this rebuttal.
First, we plot how parameterizing the overlap between teachers affects student specialization (Fig. R1A). Ideally, we expect the model to maintain a shared representation according to the similarity between the tasks while separating the non-similar parts. Indeed, we see such graded specialization in our model, proportional to the overlap between the tasks.
Second, we train our model on a set of three compositional tasks, with pairwise overlap (e.g., A+B, B+C, A+C), making them non-orthogonal. Our model learns to specialize to the underlying shared components and appropriately gate them to solve each of the three non-orthogonal tasks (Figure R1B).
Third, we move from MNIST to fashionMNIST to operate on more naturalistic images, and we use two permutations with varying degrees of orthogonality. We use permutations because we are strictly interested in the setting of cognitive flexibility, where the same input can be mapped to different responses. The new permutations are based on delineating (orthogonal) upper-vs.-lower-body items or (correlated) warm-to-cold-weather items. We find that the system works well in both cases, with specialization occurring only marginally later in the non-orthogonal setting (Figure R2).
In summary, the studied phenomenon is robust to correlations between tasks.
My current score is assuming that this limitation is true, but if it were not true, I would give a higher score (maybe an 8).
We appreciate the reviewer voicing their concern and clearly stating how it influenced their decision. We hope that our responses above and the additional simulations fully address the concern about orthogonality.
I would explain the regularization constraints in words in addition to presenting the equation.
We discuss two kinds of regularization:
- Non-negativity of gates. It is motivated by the fact that gating behavior can only amplify or attenuate a variable, but not change its sign.
- Alleviating overspecification. Due to the linear nature of our architecture, the output scale can be changed by amplifying either the weights or the gates. However, the gates should not need to account for learning the scale of the task. The regularization encourages the gates to be of order 1, with the effect that upscaling one gate above equilibrium drives the other gates to be downscaled (see the illustration below).
We will add short explanations in words alongside the mathematical expressions in-place and point to the more extensive treatment in Appendix B.
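As an illustration, one plausible functional form with exactly this coupling property (our assumption for exposition; the precise expression used in the paper is given in Appendix B) is a quadratic penalty on the total gate magnitude:

$$\mathcal{R}(c) = \frac{\lambda}{2}\Big(\sum_{p=1}^{P} c_p - 1\Big)^{2}, \qquad \frac{\partial \mathcal{R}}{\partial c_q} = \lambda\Big(\sum_{p=1}^{P} c_p - 1\Big),$$

with $c_p \ge 0$ enforced separately. Whenever one gate rises above the equilibrium $\sum_p c_p = 1$, every gate, including the others, receives the same positive gradient pushing it back down.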
Thank you for answering my questions and concerns. I am impressed by the experiments on nonorthogonality, and will raise my score to an 8.
Thank you for the thoughtful comments. The non-orthogonal tasks experiment was quite informative and will be a valuable addition to the final version. Thank you for the updated score.
This paper presents a method aimed at learning generalizable task abstractions by constraining neuron dynamics in artificial neural networks. It takes inspiration from biological neurons by constraining artificial neurons to be non-negative, imposing a faster timescale, and adding regularization.
The paper goes through a problem setup and then investigates various aspects of the approach. First, it argues that task specialization emerges through joint gradient descent, showing generalization results on compositional tasks. Next, it gives an extensive theoretical analysis of the mechanisms of this task adaptation. It then briefly shows evidence for the emergence of specialization in the model and visualizes weights, and finally discusses applications.
Strengths
Originality
- Applying these constraints to neurons does not seem to be particularly novel. However, I have not seen this specific combination or analysis from the perspective of task adaptation.
Quality
- Experiments are extensive. Each claim in the paper is backed up by lots of relevant evidence.
- Motivation is interesting at a high level
- Investigating these general, simple changes as drivers of flexible task abstractions is creative and exciting. The authors have clearly thought about various aspects of this.
Weaknesses
Quality
- Though the motivation is interesting, it is confusing because the mechanisms appear to be inspired by catastrophic forgetting but are then applied to the related but distinct problem of generalizable task adaptation
- The central claim/benefit of the paper, that this architectural adaptation might make neural networks more brain-like, is barely supported; the authors even admit this. However, that seems to be one of the main value-adds of the paper
Clarity
- Paper setup is very confusing. It's okay to have a nontraditional structure, but it ends up being a prose-heavy laundry list of various properties of the method and results.
Significance
- A proven (at least, reliable) model of some biological brain phenomenon is valuable. However, this paper has only shown vague similarity, not much comparative predictive power of this approach vs an unconstrained approach on neural data
- A machine learning result can also have impact, but the paper doesn't clearly argue for that. It does show that the representations can generalize across two training settings in MNIST, but this isn't enough to make a real ML argument.
Questions
None
Limitations
Not really
Thank you for your helpful comments. We appreciate that you found the conducted experiments relevant to our claims.
Investigating these general, simple changes as drivers of flexible task abstractions is creative and exciting.
We are glad that this contrast came across well: Despite the simplicity of the proposed mechanism, the effect on inducing flexible task abstractions is remarkable.
Paper setup is very confusing. It's okay to have a nontraditional structure, but it ends up being a prose-heavy laundry list of various properties of the method and results.
Thank you for this valuable feedback on the organization of our manuscript. We agree that it failed to guide the reader through the narrative.
On a high level, our paper has this structure:
Hypothesis: Are neuronal gates updated through gradient descent sufficient to separate tasks into dedicated representations? First, we answer through simulations whether and how implicit context switches are learned (Fig. 2). Second, we analytically address why task representations arise (Fig. 4). Third, we apply the model to cognitive science (Fig. 7). The remaining figures are not essential to this narrative, but function as controls for the necessary conditions for flexible gating.
To address this point, we now strictly keep to this structure:
1. We reorganized the order of the mathematical analysis and grouped subsections therein. We now first precisely define flexible adaptation, and then discuss the main driver for its emergence.
2. We moved the Related works section back to the main text to better state our contribution.
3. We moved Figure 3, which shows how the learned abstractions generalize, to the supplementary material, as it does not directly support the main narrative. We clearly motivate the remaining figures as controls.
Contribution
The reviewer raises concerns about the contribution of our work.
We aim to contribute to theoretical cognitive science by developing a minimal neural network model of cognitive flexibility. We build on a line of recent theoretical NeurIPS papers studying the learning dynamics in linear networks and relating them to cognition, starting with Saxe et al. (2013, Exact solutions to the nonlinear dynamics of learning). Later studies examined how gating alleviates interference, but their gating was static and handed to the network (Saxe et al., 2022, Dynamics of abstraction in gated networks). We generalize this line of work with an analytically tractable model of how appropriate gating emerges dynamically.
Our model contributes to cognitive science:
1. We offer, to our knowledge, the first neural network model that benefits from data distribution shifts and struggles in the shuffled data regime, similar to humans.
2. We provide a direct comparison to humans, whose task-switching behavior accelerates as they practice the tasks involved. This result emerged from both our simulations and our theoretical analysis, and we provide a mechanistic explanation for the phenomenon.
3. Not only does the model infer tasks and retrieve the suitable task abstraction, but, similar to humans, further learning and credit assignment are gated by the inferred task, i.e., only the parameters for the inferred task are changed.
We believe that such a model can make additional behavioral predictions and provide a neural basis for explaining further experimental findings, such as Heald et al.'s influential study on sensorimotor repertoires (2021, Nature).
A proven (at least, reliable) model of some biological brain phenomenon is valuable. However, this paper has only shown vague similarity, not much comparative predictive power of this approach vs an unconstrained approach on neural data
We aim to present a minimal hypothesis of how task abstractions in (biological) neural networks might emerge during training.
The reviewer highlights the importance of connecting directly to neural data. Unfortunately, almost no full recordings from animals during the training phase on multiple tasks exist. Recordings are mostly from well-trained animals switching between the tasks (discussed by Bouchacourt et al. (2022), Fast rule switching and slow rule updating). Neuroscientists point to a technical difficulty: animals take weeks to months to learn two opposing tasks, while recording electrodes slip and drift within days.
Still, technical improvements, especially in calcium imaging, might soon make recordings during learning possible. In addition to behavioral signatures, our model predicts a gradual, context-modulated differentiation of the task-specific activity vectors in neural space over the course of learning. We expect these simple and interpretable signatures to be robustly observable across tasks and species, an advantage over purely Bayesian models and more expressive machine learning architectures.
A machine learning result can also have impact, but the paper doesn't clearly argue for that.
The reviewer identified that the simulated model might be extensible to an ML method for larger, nonlinear datasets. We view constructing an ML system that incorporates these principles to improve machine learning results as a clearly separate project, which we are considering taking on in the future, after having laid the theoretical groundwork with this paper.
The mechanisms appear to be inspired by catastrophic forgetting but then are applied to the related but distinct problem of generalizable task adaptations.
The reviewer highlights that results on the generalizability of our learned task abstractions belong to a different problem and would need comparison to different controls, such as few-shot learning or meta-learning models. We moved these results to the appendix and clarified that they only serve to show that the abstractions are functional and can be recomposed through gradient descent.
We will clarify our contributions in the introduction section in the final revision.
Thanks for all the effort! Updated thoughts:
- Presentation is much better now, thank you for incorporating the comments.
- ML contribution: the added experiments help as well. It's still not a big ML result of course, but it's enough to seem promising. Your discussion with reviewer pBhB about orthogonality was illuminating.
- Cog sci contribution: fair argument re: neural data not being available for this setting. That said, a lack of testbeds to demonstrate claims doesn't make it okay to have insufficient evidence. The similarity is still high-level and only somewhat demonstrated. I wouldn't call the hypothesis "minimal" so much as "weak" (not weak as in bad, weak as in nascent/weakly correlated). I do buy that, given the state of the fields, this is a more thoughtful and tight experiment than I'd originally felt.
Raising my score to a 6.
We thank the reviewers for their thoughtful comments, which helped us improve the manuscript. We were glad that the reviewers found “the model remarkably simple and easy to understand” (reviewers 9Ed7 and pBhB) and studying the emergence of flexible task abstractions “creative and exciting” (9Ed7, SwaN). The reviewers also appreciated the significance of the work as modeling “an important observed difference between humans and neural networks” (pBhB, SwaN), with a “motivation that is interesting at a high-level” (reviewer 9Ed7).
While the reviewers found the experiments to be “extensive” and the claims well backed up (reviewers 9Ed7 and pBhB), they also identified improvements that could be made to the conceptual structure of the paper for clarity. Additionally, reviewers pBhB and SwaN pointed out that the manuscript was dense and that some important experimental details were relegated to the supplementary material.
We discuss these insightful comments in the responses to individual reviewers. To address them, we are taking the following actions for the final revision:
- We restructured the manuscript for clarity: We rearranged the order of analytical results and grouped the points needed to state the analytical mechanism behind the model. We moved the Related works section to the main text to clearly state our contribution. We moved generalization results that did not support the main narrative to the appendix.
- We ran simulations to test how the model handles shared task information, as our manuscript had used only orthogonal tasks, chosen to permit theoretical analysis. We see that tasks are separated to the extent needed, but not more (see Figure R1 in the attached PDF).
- We now better clarify our aim to contribute to theoretical cognitive science by developing a minimal neural network model of cognitive flexibility. In particular, the work addresses a gap in a line of previous work at NeurIPS which aimed at interpretable neural models, positioned between Bayesian approaches and more expressive machine learning architectures.
- Finally, we attach the necessary simulations which allow us to address the reviewers’ questions.
Again, we thank the reviewers for their time and insightful comments. We hope our specific responses below address concerns and questions raised.
The reviewers all emphasized that the experiments are extensive and detailed, solidly backing up the stated claims. The simple proposed changes (faster timescales, non-negativity, and regularization) on the gating layer are clearly demonstrated to induce task selectivity, enabling specialization of the subnetworks. During the discussion phase, the authors answered the insightful questions posed in the reviews in detail and provided new results for experiments on datasets with compositional tasks and a more complex image dataset, which provided additional insights and further evidence for the claims in the paper. The paper is a clear accept!