PaperHub
Average rating: 4.6/10 · Rejected (5 reviewers)
Ratings: 6, 3, 3, 6, 5 (min 3, max 6, std dev 1.4)
Confidence: 3.4 · Soundness: 2.2 · Contribution: 1.8 · Presentation: 2.2
ICLR 2025

Identifying Sub-networks in Neural Networks via Functionally Similar Representations

OpenReview · PDF
Submitted: 2024-09-28 · Updated: 2025-02-05
TL;DR

We propose a novel approach to identify functionally similar subnetworks within neural networks.

Abstract

Keywords
mechanistic interpretability, subnetworks

Reviews and Discussion

Review
Rating: 6

This paper proposes a new metric based on the Gromov-Wasserstein distance to measure the similarity between layers in neural networks, which can be used for mechanistic interpretability and understanding of the inner workings of neural networks. Experiments are conducted on algebraic and language tasks and a number of interesting patterns and observations are reported.

Strengths

  • A new metric for measuring the functional similarity of layers which has a number of desirable properties.
  • The metric behaves experimentally as expected which may aid in improving our understanding of the inner workings of neural networks.
  • Experiments on algebraic and language tasks under a number of settings, e.g., fine-tuning, sparsity, etc., reveal some sensible patterns.

Weaknesses

  • The metric relies on qualitative assessment of the numerical results, which may be complicated for large neural networks, or may be subjective or ambiguous.
  • Missing comparison with state-of-the-art metrics, e.g., Yu et al., 2024, Klabunde et al., 2023, Lohit & Jones, 2022 (listed in Sec 2).
  • The mathematical notation is inconsistent and in some minor instances incorrect, e.g., "Let f0 be a function that produces a set of output Y0 given a set of input X, where Y0 ∈ Rn×dy", i.e., so is Y0 a set or a matrix?
  • Overall, the paper was hard to follow, and I had to read it multiple times to appreciate the details.

Questions

  • Can we use this metric in a quantitative manner?
  • Can observations reported in Sec 5.2 and 5.3 be shown experimentally? For example, the paper states, "The first major differences occur at block 9 and then the last three layers (10, 11, 12) seem to form a distinct block. This seems to indicate that most of the function/task fitting occurs at these later layers. This may be explained by the fact that gradient values are potentially larger at these later layers so they change faster from the pretrained models." This can be shown experimentally.
  • The algebraic task is simple which is suitable for a controlled study, can a computer vision task be included to compare with the observations reported for the NLP task?
Comment

We thank the reviewer for their valuable feedback. In response, we have clarified notations and improved the presentation in the revision.

Metric in a quantitative manner: The proposed GW distance can absolutely be used in a quantitative manner. However, it would generally require some kind of ground truth for comparison, which is challenging to obtain in the context of sub-network discovery. Our synthetic experiment of modular sum is designed to provide ground truth subnetworks that compute different functions, but this does not guarantee the absence of further subnetworks within them (in fact, modular sum experiments have been shown to contain further sub-functions, as noted by Nanda et al., 2023).

Nevertheless, we design downstream tasks of model fine-tuning (Appendix G) and pruning/compression (Appendix J) to evaluate how well the discovered sub-networks perform in prediction. As shown in Appendix J, we take the original pre-trained BERT and use only the first n ∈ {12, 8, 4, 2, 1, 0} transformer blocks while discarding the rest. Here n = 12 corresponds to using all the transformer blocks, i.e., the same BERT model; n = 0, on the other hand, means that we use only a (linear) classifier layer after the embedding layer to predict the class label. The results are shown in Table 7 in Appendix J. As a reminder, the GW distance suggests the last 4 blocks in YELP (see Figure 6) and the last 2 blocks in SST (see Figure 17) are mostly different, which is marked by a star (*) in the table. It shows that by using a limited number of layers, we can achieve performance similar to the full 12-block model, with 0.01% and 0.54% differences in YELP and SST, respectively. In contrast, using one fewer transformer block risks a much larger performance drop, with 0.10% in YELP and 8.60% in SST2 (approximately a 10-fold larger reduction). These accuracy results further justify the quality of the identified subnetworks.
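
A minimal sketch of this truncation setup, assuming a HuggingFace BertForSequenceClassification; the helper name `truncate_bert` is our illustration, not code from the paper:

```python
from transformers import BertForSequenceClassification

def truncate_bert(n):
    """Load a fresh pre-trained BERT and keep only its first n encoder blocks;
    n=0 leaves just the embeddings followed by the classification head."""
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2)
    model.bert.encoder.layer = model.bert.encoder.layer[:n]  # nn.ModuleList slice
    model.config.num_hidden_layers = n
    return model

for n in [12, 8, 4, 2, 1, 0]:
    model = truncate_bert(n)
    # fine-tune `model` on YELP / SST2 here and record validation accuracy
```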

Missing comparison with state-of-the-art metrics: In the original draft, we tested cosine and Wasserstein distances, which are used by some of the suggested papers. In the revision, we added 6 more baselines to provide a more comprehensive comparison. The results demonstrate that our method consistently identifies subnetworks with greater clarity and reliability compared to these other approaches.

Observations reported in Sec 5.2 and 5.3: We added an experiment, as discussed above, to show that a small subset of layers can fit the function well. Regarding the exact positions and gradient values, we checked the gradients of all layers and found that there are no significant differences among them, except for the gradients in the linear classification layer (i.e., the very last layer), where the gradient values were more than 10 times higher. As a result, we have removed the statement in the revision.

Vision task: As requested, we test our approach on the lightweight ResNet9 model [5], for ease of visualization and comparison. We compare a randomly initialized ResNet9 with a trained model that achieves 91.63% accuracy. We show the distance measures computed using all methods in Figure 18, with the GW distance identifying 4 subnetworks. To further examine how the sub-network structures align visually with the learned representations, we visualize the computed distances along with the learned representations of a "ship" image across all layers in Figure 9. The top row shows the representations of the ship at each layer. To observe the gradual changes over layers, we visualize the distance between each layer and its previous layer, using the various methods that can handle different dimensions between compared spaces.

As shown in Figure 9, RSM, RSA, MSID, and CKA indicate significant changes across many layers, without providing clear evidence of sub-network structures. AGW highlights changes in the final few layers only. In comparison, the GW distance is visually the most consistent with the image representations. Specifically, the 3rd convolution layer (Layer ID 2.ReLU) introduces the first notable differences, where the ship's shape becomes less distinct, signaling the learning of mid-level features. The shapes become increasingly blurred in the 5th convolution layer (Layer ID 4.Conv2d), and by Layer 4.ReLU the ship's shape is nearly absent. The final convolutional layer (Layer ID 7.Conv2d) shows significant changes from its preceding layer (Layer ID 6.ReLU), marking the point where class-specific information is consolidated. These results suggest that the GW distance aligns most effectively with the learned image representations, providing strong evidence that it can reveal meaningful subnetwork structures in vision models.

[5] Park et al. "Trak: Attributing model behavior at scale." ICML 2023.
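
For readers who want to reproduce this kind of per-layer comparison, a hedged sketch of collecting layer representations with forward hooks follows; torchvision's resnet18 stands in for the paper's custom ResNet9, and hooking only Conv2d modules sidesteps torchvision's reused ReLU instances:

```python
import torch
import torch.nn as nn
import torchvision

# Stand-in model: torchvision's resnet18 (the paper uses a custom ResNet9).
model = torchvision.models.resnet18(weights=None).eval()

acts = {}
def save_activation(name):
    def hook(module, inputs, output):
        # Flatten spatial dims so each row is one sample's representation.
        acts[name] = output.flatten(start_dim=1).detach()
    return hook

# Hook only Conv2d modules; torchvision reuses ReLU instances, so a hook on
# one ReLU would fire several times and overwrite earlier activations.
handles = [m.register_forward_hook(save_activation(n))
           for n, m in model.named_modules() if isinstance(m, nn.Conv2d)]

with torch.no_grad():
    model(torch.randn(32, 3, 32, 32))  # a CIFAR-10-sized batch

for h in handles:
    h.remove()
# `acts` now maps layer names to (batch, features) matrices, ready for
# pairwise distance comparisons between consecutive layers.
```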

Review
Rating: 3

This paper proposes a novel perspective on model interpretation by utilizing the Gromov-Wasserstein (GW) distance to compute similarities between model layers and identify sub-networks. For algebraic and real-world NLP tasks, the proposed methods effectively identify sub-components.

Strengths

  1. This paper proposes a novel perspective for model interpretation that minimizes both human and computational expenses.
  2. The proposed methodology can effectively identify subnetworks.

Weaknesses

  1. The motivation for identifying subnetworks is unclear. Once identified, I'm curious about how they can enhance our understanding of networks in real-world applications.
  2. The key idea to identify subnetworks is the use of Gromov-Wasserstein (GW) distance, which has been proposed before and might diminish the novelty of the work.
  3. There are some typos, such as a missing comma on line 146.

Questions

  1. The motivation for why it is necessary to identify subnetworks is unclear. Could the authors provide more explanation about the relationship between model interpretation and subnetwork?
  2. Additionally, the field of model pruning involves identifying and retaining stronger subnetworks. What are the differences and connections between this approach and model pruning?
  3. Could the author provide more explanation about Figure 3, particularly the relationship between the three identified groups and the model architecture?

I am happy to engage further during the discussion phase.

Comment

We thank the reviewer for the valuable comments, and have fixed the typo per suggestion.

Could the authors provide more explanation about the relationship between model interpretation and subnetwork?: As discussed in the introduction, we believe that discovering sub-functions in the form of distinct subnetworks can provide valuable insights into how many functions a neural network utilizes to achieve specific outcomes, such as making predictions. This represents an important first step toward the automated identification of these functions. To address the broader question of "What kind of functions does the neural network learn?", it is both important and beneficial first to explore "How many distinct (sub)-functions does the neural network learn?". We believe understanding neural networks through the identification of subnetworks is essential due to the complexity and opacity of modern deep learning models. Neural networks, especially those with many layers and parameters, often exhibit behaviors that are difficult to interpret holistically. Identifying subnetworks allows us to break down the model into smaller, more interpretable units, providing insights into how individual components contribute to the model’s overall performance.

The key idea to identify subnetworks is the use of Gromov-Wasserstein (GW) distance, which has been proposed before and might diminish the novelty of the work: Our work proposes to use GW for understanding similarity among neural network representations. Although we do not technically deviate from the existing GW measure, we justify its usage through invariance properties and empirical evaluations, demonstrating its promising application in detecting substructures in neural networks.

Additionally, the field of model pruning involves identifying and retaining stronger subnetworks. What are the differences and connections between this approach and model pruning?: In our approach, the primary focus is on subnetwork discovery, which aims to identify and ultimately understand the distinct components or sub-functions within a neural network that contribute to its overall task, such as prediction. This process provides insight into the network's internal structure and the roles of different components, answering questions like "How many different sub-functions does the neural network learn?". In contrast, model pruning focuses on reducing the size of the model by removing less important or redundant parts, such as neurons or layers, to improve computational efficiency while retaining performance. The goal of pruning is not necessarily to understand the model's internal structure, but to streamline it without significantly impacting its functionality. While these two approaches have different objectives, they do share connections. Insights from subnetwork discovery can inform pruning decisions by identifying which parts of the model are essential for maintaining performance (as we have tested in Appendix L). Additionally, pruning can simplify the model, making it easier to identify and interpret the important subnetworks.

Could the author provide more explanation about Figure 3, particularly the relationship between the three identified groups and the model architecture?: The 3 identified groups correspond to different components within transformer blocks. Group i (layers roughly 22 to 44 and 52 to 74) and group iii (the initial and final layers) mainly involve MLP and/or residual computations. Figure 3 shows these groups share strong similarities among themselves, indicating that these layers compute relatively simple transformations. In contrast, group ii (layers roughly 12 to 20, 43 to 51, and around 74 to 82) consists of attention computations, suggesting more complex transformations compared to MLP layers. We provide more explanation in Figure 5 and the accompanying paragraph to the left of the figure.

Comment

Thank you for the rebuttal. However, my concerns remain unaddressed. I remain unconvinced by the motivation for the work. Furthermore, the proposed method, based on the Gromov-Wasserstein distance, has been explored in prior research, which raises concerns about the novelty of the contribution. Consequently, I will maintain my original score.

Comment

Thank you for taking the time to review our rebuttal and for your thoughtful consideration.

Motivation: Regarding the motivation, we have revised the introduction and related work sections to provide clearer context for our study. We would like to further emphasize that our work addresses the significant manual effort typically required to understand the underlying functions, which remains a challenge that previous methods have not adequately addressed. For example, Nanda et al., 2023 and Zhong et al., 2024 both rely on manual inspection to reverse engineer these functions, and are limited to smaller networks (e.g., one-layer transformers) and simpler problems (e.g., modular sums between two numbers). These limitations underscore the relevance of our work in automating this effort for larger networks and real datasets.

Novelty: Furthermore, while it is true that the Gromov-Wasserstein distance has been explored in prior research, our work makes a novel contribution by demonstrating its applicability to automated functional discovery. This not only advances its practical usage in the field but also uncovers new and interesting insights on real datasets.

We hope these clarifications address your concerns. We would greatly appreciate any additional suggestions or guidance to further improve the clarity and impact of our work.

Review
Rating: 3

This paper proposes a novel method for identifying sub-components within neural networks, addressing challenges arising from varying distributions and dimensionalities in intermediate representations.

Strengths

  • The proposed method discerns similar and dissimilar layers within the network, revealing potential sub-components, and addresses challenges from varying distributions and dimensionalities in intermediate representations.
  • The presentation of the experimental results is precise and well-organized.

Weaknesses

  • The experiments are limited due to insufficient validation across diverse tasks, datasets, and learning methods. This includes gaps in testing with image (classification) datasets, self-supervised learning methods, etc.
  • The contribution of the proposed method is unclear. Can the identified sub-network effectively serve as a proxy, replacing the entire original network to accelerate inference for specific downstream tasks?
  • The motivation for identifying subnetworks requires clarification. Specifically, the phrase "understanding the network by investigating the existence of distinct subnetworks" needs a more explicit explanation. It is unclear why identifying these subnetworks contributes to a deeper comprehension of the neural network as a whole.
  • The proposed technical solution appears straightforward. To address challenges associated with varying distributions and dimensionalities in intermediate representations, the primary approach involves identifying subnetworks using the Gromov-Wasserstein (GW) distance. However, since GW distance has been previously proposed, it could potentially reduce the novelty of this work [1].

[1] Demetci, Pinar, et al. "Revisiting invariances and introducing priors in Gromov-Wasserstein distances." arXiv preprint arXiv:2307.10093 (2023).

Questions

  • Could you elaborate on the necessity and benefits of understanding neural networks through the identification of sub-networks?
  • Could you detail the technical innovations presented in this paper that extend beyond the Gromov-Wasserstein method?
  • Could you present additional experimental results across diverse tasks, datasets, and learning methodologies?
  • Could you discuss a real-world application of sub-network identification, such as improving the inference speed of neural networks?
Comment

We thank the reviewer for their constructive feedback.

Could you elaborate on the necessity and benefits of understanding neural networks through the identification of sub-networks? Understanding neural networks through the identification of subnetworks is essential due to the complexity and opacity of modern deep learning models. Neural networks, especially those with many layers and parameters, often exhibit behaviors that are difficult to interpret holistically. Identifying subnetworks allows us to break down the model into smaller, more interpretable units, providing insights into how individual components contribute to the model’s overall performance. It also has many benefits. First, it enhances interpretability by enabling us to understand how specific subnetworks process information. Second, it aids in model optimization by identifying which components are critical for performance and which can be removed or simplified. Third, understanding subnetworks helps in debugging and improving the model. By identifying and analyzing subnetworks, we can simplify models, enhance their efficiency, and improve their transparency, leading to more reliable and interpretable neural networks.

Could you detail the technical innovations presented in this paper that extend beyond the Gromov-Wasserstein method? While we do not introduce technical modifications to the existing GW measure, we adapt its application to analyze the similarity among neural network representations and facilitate subnetwork discovery. We justify this approach through its invariance properties and empirical evaluations. Notably, our experiments demonstrate that the GW distance effectively identifies distinct substructures and aligns visually with the learned representations, as shown in Section 5.4.

Could you present additional experimental results across diverse tasks, datasets, and learning methodologies? Per suggestion, we have included experimental results for a convolutional neural network (ResNet9) on the image classification dataset CIFAR-10 in Section 5.4. Our method has now been tested on algebraic, NLP, and computer vision datasets with different neural network architectures. Additionally, we have incorporated the suggested baseline [6] along with several others. Overall, GW demonstrated the clearest subnetwork structures and the strongest visual consistency on image datasets. We believe this further supports the validity of our approach for subnetwork identification.

[6] Demetci, Pinar, et al. "Revisiting invariances and introducing priors in Gromov-Wasserstein distances." arXiv preprint arXiv:2307.10093 (2023).

Could you discuss a real-world application of sub-network identification, such as improving the inference speed of neural networks? Our paper focuses on a different objective (finding the number of components within a network, rather than identifying the most critical components for a specific task), with downstream applications aimed at enabling a more automated process for mechanistic interpretability by facilitating the study of each component independently. That said, the reviewer suggests an interesting direction of using a subnetwork to speed up inference. In Appendix G, we tested fine-tuning only specific layers in BERT while freezing the others, which can speed up training. In our revision, we added an additional experiment involving network pruning, where only certain layers are retained for training. As shown in Appendix L, we identified a subset of transformer blocks (4 for YELP and 2 for SST2, compared to all 12) that achieves comparable performance. This approach not only aids understanding but can also accelerate both training and testing.

Review
Rating: 6

The paper proposes to compute functional similarity between layers of neural networks using the Gromov-Wasserstein distance in order to identify subcomponents in the architecture of trained models. In particular, the goal is to compare different neural network layers in order to identify subcomponents of the networks, which may be specialized to solve particular tasks. Experiments are conducted on synthetic modular arithmetic tasks and on BERT text models fine-tuned on sentiment analysis tasks.

Strengths

  • The direction proposed by the paper is interesting and relevant: extracting subcomponents of neural networks based on similarity measures is a core direction for fields such as representation alignment, interpretability and model selection potentially leading to new strategies for transfer learning and neural network pruning and merging, among others.

  • The paper writing is clear and the proposed method is simple.

Weaknesses

While the direction of the paper is interesting, the realization does not seem to be as broad as stated in the introductory section and title. In particular:

  • As stated in the paper, one core objective is "to detect how many distinct (complex) functions exist in a learned network, and which layers correspond to each such function". While the question is very relevant, I don't see direct evidence that blocks of layers identify with a given complex function, except for the synthetic modular experiment. In particular, on real data there is no analysis of the correlation between accuracy on downstream tasks and GW similarity of layers, which to the best of my understanding seems necessary in order to assess whether a subcomponent learned a specialized function of the input. One way to understand this could be to train linear readouts on the downstream task from the representations of the two blocks separately and see how the performance changes and correlates with the GW distance. Related to this:

    • Figure 8 shows just the mean GW distances across all layers, which doesn't allow validating any task specialization of the blocks of layers with similar GW distance.
  • The experiments and method section are focused on transformer networks, not the general neural network architectures discussed in the intro/title. However, measuring GW distances in representation space should be agnostic to the architecture. Why the focus on just this architecture?

  • The paper focuses on the synthetic task of modular arithmetic and on sentiment analysis on text. In order to make the statements more general, more tasks, datasets, and/or modalities should be taken into account. For example, it would be interesting to see whether, in vision models based on ResNet or convolutional architectures, early-layer subcomponents are able to solve tasks which depend on local statistics of the input, while layers closer to the output are more likely to solve tasks which rely on global information, such as image classification.

  • The synthetic task on modular arithmetic is trained from scratch, but it doesn't seem to take into account possible effects of different random initializations. See the questions section for more details on this.

  • The following papers should be included for discussion and/or comparison:

    • Tsitsulin, Anton, et al. "The shape of data: Intrinsic distance for data distributions." ICLR 2020

    • Kornblith, Simon, et al. "Similarity of neural network representations revisited." International conference on machine learning. PMLR, 2019.

In the former, a similarity measure is proposed based on heat kernels of graphs on latent representations, which gives a lower bound on the GW distance between embeddings. In the latter, the CKA measure is proposed to compute similarity across data representations. This shares most of the favorable properties highlighted for the GW distance and is computationally efficient. It also shows how comparing different layers of networks reflects the functional structure of each component (e.g., see Figure 4 in that paper).

More broadly there are many contributions from the field of model merging which might be useful to take into account and discuss, for example:

  • Singh, Sidak Pal, and Martin Jaggi. "Model fusion via optimal transport." Advances in Neural Information Processing Systems 33 (2020)

  • Stoica, George, et al. "Zipit! merging models from different tasks without training." ICLR 2024.

  • More details on the method to compute the GW distance could be provided in the paper in order to make it self-contained. Specifically, why the choice of the Fused Gromov-Wasserstein method in particular?

Questions

  • Modular arithmetic experiment: could the authors justify why the GW distance represents a good choice for extracting subcomponents, compared to measures such as CKA [b], canonical correlation analysis, and other kernel distance measures mentioned in (Klabunde, 2023)?

  • Networks initialized with different seeds converge to different representations (see e.g. [a,b,c]), although they might perform similarly. How is this taken into account in the modular arithmetic experiment? It would be interesting to measure the GW distance between networks trained on the same data and task but with different initializations. Alternatively, the experiments could be repeated, averaging results over multiple seeds.

    • [a] Moschella, L., et al. "Relative representations enable zero-shot latent space communication." ICLR 2023

    • [b] Li, Yixuan, et al. "Convergent learning: Do different neural networks learn the same representations?." PMLR

    • [c] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International conference on machine learning. PMLR, 2019.

  • What is the exact difference between the implemented version of Wasserstein and GW distance in the paper?

Comment

We thank the reviewer for their constructive feedback and insightful suggestions.

In particular on real data there is no analysis on correlation on accuracy on downstream tasks and GW similarity of layers ... In order to understand this one way could be just to train linear readouts on the downstream task from representations of the two blocks separately and see how the performance changes and correlates with the GW distance. Even when intermediate layers learn different functions, their representations may not be effective for prediction. Per suggestion, we trained a logistic regression and a 2-layer MLP probe on every transformer block to predict the final label on the YELP dataset. The accuracies of the logistic regression readouts across transformer blocks are all around 50% until the very last block, while the MLP readouts gradually increase from 50% to 59% until the very last block. Hence linear readouts may not predict well due to their lack of predictive power. To further check whether layers identify with a given complex function, we use the vision dataset CIFAR-10 to visually inspect the subnetworks. Results are discussed below.
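
A hedged sketch of such a per-block linear readout; the random placeholders stand in for the per-block pooled representations and labels, and all names are illustrative rather than the paper's code:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def probe_accuracy(X_block, y):
    """Accuracy of a linear readout trained on one block's representations."""
    Xtr, Xte, ytr, yte = train_test_split(X_block, y, test_size=0.2, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(Xtr, ytr)
    return clf.score(Xte, yte)

rng = np.random.default_rng(0)
feats = [rng.normal(size=(500, 768)) for _ in range(12)]  # placeholder block features
y = rng.integers(0, 2, size=500)                          # placeholder binary labels
accs = [probe_accuracy(Xb, y) for Xb in feats]            # one readout accuracy per block
```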

GW distance vs other baselines: Per suggestion, we have added results for MSID and CKA, along with several others, using publicly available implementations and default parameters in the experiments. In comparison to CKA and other kernel distance measures, GW shares some of the invariance properties (including orthogonal transformation and isotropic scaling) but additionally considers the geometric properties of the metric space. Moreover, compared to MSID, we find that directly optimizing the GW distance yields better results in identifying subnetwork structures.
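
For concreteness, a minimal implementation of the linear CKA baseline mentioned here, following Kornblith et al. (2019), illustrating the shared orthogonal-transformation invariance; the random matrices are placeholders:

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between two (n_samples x dim) representation matrices
    with the same sample ordering."""
    X = X - X.mean(axis=0)  # center each feature
    Y = Y - Y.mean(axis=0)
    hsic = np.linalg.norm(Y.T @ X, "fro") ** 2
    return hsic / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro"))

rng = np.random.default_rng(0)
A = rng.normal(size=(200, 64))
Q, _ = np.linalg.qr(rng.normal(size=(64, 64)))  # random orthogonal matrix
print(linear_cka(A, A))      # 1.0 for identical representations
print(linear_cka(A, A @ Q))  # ~1.0: invariant to orthogonal transforms
```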

ResNet and Vision Experiments: We focused on attention-based architectures because they are the state of the art for NLP tasks and are widely used in existing mechanistic interpretability studies (e.g., Nanda et al., 2023). As the reviewer suggested, our method can indeed generalize to other models. We have tested it on a ResNet9 trained on the CIFAR-10 dataset. The results, including pairwise distances and visualizations, are provided in Section 5.4 and Appendix N. As shown in Figure 9, the GW distance is visually the most consistent with the image representations. Specifically, the 3rd convolution layer (Layer ID 2.ReLU) introduces the first notable differences in GW distance, where the ship's shape becomes less distinct, signaling the learning of mid-level features. The shapes become increasingly blurred in the 5th convolution layer (Layer ID 4.Conv2d), and by Layer 4.ReLU the ship's shape is nearly absent. The final convolutional layer (Layer ID 7.Conv2d) shows significant changes from its preceding layer (Layer ID 6.ReLU), marking the point where class-specific information is consolidated. These results confirm that early-layer sub-components focus on local statistics of the input such as shapes, while later layers learn higher-level features and rely on global information. They also suggest that the GW distance aligns most effectively with the learned image representations, providing strong evidence that it can find meaningful subnetwork structures in vision models.

Model merging related works: Thank you for the suggestions. [1] uses the GW distance as a regularizer to fuse models. We have cited these works in the related work section.

[1] Singh, Sidak Pal, and Martin Jaggi. "Model fusion via optimal transport." Advances in Neural Information Processing Systems 33 (2020)

[2] Stoica, George, et al. "Zipit! merging models from different tasks without training." ICLR 2024.

Implementation details: "Why the choice of the Fused Gromov-Wasserstein method in particular?" We do not use the Fused GW method. "What is the exact difference between the implemented version of Wasserstein and GW distance in the paper?" We have added a discussion in Appendix H on the exact implementation details of the W and GW distances. Specifically, we solve the W distance based on [3] and the GW distance based on conditional gradient [4].

[3] Bonneel et al. Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG), 2011

[4] Titouan et al. “Optimal Transport for structured data with application on graphs”. ICML, 2019.
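
A minimal sketch of the two computations using the POT library (assuming `pip install pot`), which provides an exact LP Wasserstein solver and a conditional-gradient GW solver; the random matrices stand in for two layers' representations:

```python
import numpy as np
import ot  # POT: Python Optimal Transport

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 64))   # layer-i representations: n samples x d_i
Y = rng.normal(size=(200, 128))  # layer-j representations: n samples x d_j
p = q = np.full(200, 1.0 / 200)  # uniform weights over samples

# Wasserstein needs a cross-space cost matrix, so both spaces must share a
# dimension; here X is compared with a shuffled copy of itself for illustration.
M = ot.dist(X, X[rng.permutation(200)])  # pairwise squared Euclidean costs
w_dist = ot.emd2(p, q, M)                # exact LP solver

# GW compares intra-space distance matrices, so the dimensions may differ.
C1 = ot.dist(X, X)
C2 = ot.dist(Y, Y)
gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun="square_loss")

print(f"W: {w_dist:.4f}  GW: {gw_dist:.4f}")
```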

Random Seeds and Similarity Results: To study the impact of initialization seeds on the learned representations, we train the same BERT model on the YELP dataset with different seeds and identical hyperparameters for a total of 27,000 iterations. As shown in Figure 19 in Appendix M, while the learned representations vary across seeds, the general block structures remain consistent when analyzed using GW distances.

Comment

I thank the authors for their answers, and additional experiments performed during the rebuttal period. In particular thanks for adding the baselines. I will increase my score to 6.

Although I think the paper is in a better state now, I still have some concerns: (i) outside of the synthetic modular arithmetic task, I'm concerned by the lack of any validation that the subspaces of the network correspond to specialized tasks; (ii) the CKA baseline seems to perform qualitatively similarly to GW (e.g., in Figure 15): I'm unsure what the actual benefits of using GW over CKA are, given that the inputs are already in correspondence.

Comment

We thank you for taking the time to review our rebuttal and for your further consideration.

Validation: Thank you for highlighting this important concern. We understand the significance of validating that the subspaces of the network correspond to specialized tasks, as it directly relates to the robustness and interpretability of our approach. We acknowledge that direct validation is necessary, which is why we designed the synthetic experiments. Real datasets generally lack such ground truth labels on subspaces, so it is difficult and impractical to compare methods quantitatively. In our revision, besides the modular sum experiment, we provide indirect validation through a set of downstream tasks of model fine-tuning (Appendix G) and pruning/compression (Appendix J) to evaluate how well the discovered sub-networks perform in prediction. Specifically, as shown in Appendix J, we take the original pre-trained BERT and use only the first n ∈ {12, 8, 4, 2, 1, 0} transformer blocks while discarding the rest. Here n = 12 corresponds to using all the transformer blocks, i.e., the same BERT model; n = 0, on the other hand, means that we use only a (linear) classifier layer after the embedding layer to predict the class label. The results are shown in Table 7 in Appendix J. As a reminder, the GW distance suggests the last 4 blocks in YELP (see Figure 6) and the last 2 blocks in SST (see Figure 17) are mostly different, which is marked by a star (*) in the table. It shows that by using a limited number of layers, we can achieve performance similar to the full 12-block model, with 0.01% and 0.54% differences in YELP and SST, respectively. In contrast, using one fewer transformer block risks a much larger performance drop, with 0.10% in YELP and 8.60% in SST2 (approximately a 10-fold larger reduction). These accuracy results further justify the quality of the identified subnetworks. We appreciate your feedback and welcome any additional suggestions or methodologies you believe would further strengthen this aspect of our work.

GW vs CKA: Thank you for pointing out the apparent qualitative similarity between the GW and CKA results in Figure 15. We appreciate your careful examination of the results and the opportunity to clarify the benefits of our approach. As briefly discussed in Appendix I, while CKA performs qualitatively similarly to GW in Figure 15, it also shows greater variability within block structures. Moreover, on the CIFAR-10 dataset, per Figure 20 (which may present a more challenging task due to its lower accuracy), CKA struggles to provide clear subnetwork structures, often showing different block structures at each layer. These results demonstrate that the geometry-based Gromov-Wasserstein distance captures more variance and better highlights the complex transformations of the data.

Review
Rating: 5

The paper studies the functional similarity of sub-modules by applying the Gromov-Wasserstein distance (GW) to their outputs. The core idea is that sub-modules (e.g., attention head/layer outputs at different depths along a transformer) can be seen as individual functions over their input, and these functions can be studied through the similarity of their outputs. However, existing methods in the literature involve either indirectly comparing the information content with respect to another function's output (a downstream task applied independently on each sub-module output) or comparing directly with some (dis)similarity measure that assumes pre-existing compatibility/comparability between representations. The approach proposed in this paper falls into the second category of methods, and by using GW it adopts a more generic measure (i.e., one that takes into account more invariances) than the considered baselines. The approach is tested in a controlled environment (synthetic dataset and controlled training dynamics of an ad-hoc model) and on an NLP task (sentiment analysis) with a pre-trained model (BERT).

Strengths

  • The research question is extremely fitting for the venue, as it directly studies neural representations and is not one of the many deep learning topics that only indirectly involve them.
  • The approach is well-founded and investigated even for training dynamics, with an insightful analysis in Figure 8.
  • The high-level structure of the paper is robust: overall idea -> synthetic investigation -> real-world case -> in-depth analysis for the real-world one. I particularly appreciated the experiment structuring with their modularity in "setup" and "results".

Weaknesses

  • I had a hard time understanding the scope and goals of the work: the paper is claimed to be aiming at mechanistic interpretability, but the methods, to my understanding, do not provide any human-understandable finding. Studying whether different layers provide similar or dissimilar representations is definitely interesting (and in fact, many existing works already do this), but I find the relation to mechanistic interpretability to be far-fetched.

  • The novelty of the approach is also not entirely clear: in particular, in what aspect is the approach novel? Prior work has already investigated the (dis)similarity of different layers, both intra- and inter-network [1, 2], and has also studied the block structure hinted at in this paper [2]. The similarity measure, by itself, reminds me of a topology-based measure in [3]. Some discussion on the novelty of the measure would help clarify this. If it is novel, then how does it compare to widely established measures such as CKA [4] or CCA-based measures? The former is cited in the work, but I don’t see it in any experiments. CCA and CKA take into account (in different ways) the “sample-to-sample” relationships, so they are more suited as baseline methods to show the advantages of GW. In general, [1, 2] are too relevant not to be discussed in the work.

  • Clarity of the method. While it is clear from the text why GW could be a strong measure to apply in this context, the method part is a bit confusing. From my understanding, GW in this paper is just “evaluated” on the representation pairs since the correspondence (π) is known (i.e., the representation of sample X at layer i corresponds to the representation of sample X at layer j). Is it correct? If this is the case, I would suggest specifying that π is not optimized in this setting, because equation 2 suggests the opposite. If this is not the case, it’s unclear to me why we would want to optimize a correspondence that is already known, at the risk of matching different samples across layers. Can the authors kindly expand on this?

  • I found the writing to be a bit confusing, often mixing low-level details with high-level takeaways. I suggest moving any detail not immediately necessary to understand the experiments to an appendix to allow the reader to focus on the core message.

  • Experiment inconsistencies. There are a few mismatches between the tables/figures content and the text. For example:

    • Table 1 introduces "All representations considered in experiments". However, all the other experiments consider only one of them. Table 2 is the only exception, where apparently all of them were used to find the most similar ones across all possible representations. I would suggest introducing all the variations in that experiment (the one related to Table 2), since they are applied only there; otherwise it comes across as confusing, especially since it's the first table in the paper.
    • Figure 2 (related to the experiment in section 4.1) contains the “Pretrained” (a) and “Fine-Tuned” (b) sub-figures, but I didn’t find any description of what they are in the text. The associated paragraph “Distance Distributions” just refers to the generic Figure 2 without describing its content or commenting on the different results according to the model.
    • The paragraph “Neighborhood Change” refers to a T-SNE plot showing the results that are commented on, but the figure is in the appendix (Figure 9e, Appendix D). Since the message of this paragraph is to study neighborhood consistency, I would use a quantitative metric for measuring neighborhood overlap (even a Jaccard similarity between the top-k neighbors, I think, could suffice), adding it to the main text rather than (only) a qualitative plot in the appendix.
    • Table 2 for experiment 5.1 shows results that I think are not well-commented in the text. What's missing is the message from them: what's the meaning of the top similar layers being the ones of those specific kinds and not others?
    • In the synthetic data experiment (section 5.1), 3 models are used: Model 0, Model E, and Model L. According to the Training procedure/setup paragraphs, Model E and L have three layers each. However, in the Results paragraph of the same section and in the related figures (3 and 4), they appear to have more than 90 layers each.

    [1] Raghu, Maithra, et al. "Do vision transformers see like convolutional neural networks?." Advances in neural information processing systems 34 (2021): 12116-12128.

    [2] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "On the origins of the block structure phenomenon in neural network representations." TMLR, 2022.

    [3] Klabunde, Max, et al. "Similarity of neural network models: A survey of functional and representational measures." arXiv preprint arXiv:2305.06329 (2023).

    [4] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International conference on machine learning. PMLR, 2019.

Questions

  • Unconnected related works. I think that the related work section could benefit from some restructuring. For example, it is not clear to me why the "Algorithm discovery in algebraic problems" and "Real applications" paragraphs (e.g., here "vision and language components" are mentioned, why?) are connected to this work. On the contrary, I found the last one (Similarity Measure between Neural Network Layers) to be clear and well-connected to the rest of the paper. I think a couple of sentences for each paragraph explaining the connection with the rest of the work could help.
  • Section 5.2, Setup paragraph: when the three sparse models are introduced, I would suggest explaining why they are used/useful in this experiment. Even a single sentence like "The sparse models are used to force their respective models to condense more the information in the few weights they have remaining, so we can analyze any link between this constraint and their structural similarity".
  • Section 5.1, Training procedure paragraph, line 356: "In all these models, we are able to achieve 100% prediction accuracy". Is it on a validation set hinting at perfect modeling of the target function or on a training one useful to understand if the models could overfit or not?
  • There's a typo in line 146/147 (intro to Section 3). A full stop is missing after "as a similarity search problem".
  • Experiment 5.2, Results paragraph: on line 439, it is written: "from layer 8 and only fine-tuning the later layers". As a minor comment, I would suggest changing this sentence to “freezing the model up to layer 8 and only fine-tuning the layers after that” to make the sentence clearer.
  • The figures are dense and have really small ticks/labels. For example, Figure 6 and Figure 2 require a lot of zooming in. My suggestion to improve the readability would be to start by sharing the x and y axes labels and avoid adding details in the titles (i.e., the KL divergence, especially because it appears only there and is not commented on in the text).
  • Section 4, start of the “Unknown Targets” paragraph (line 230/231): I think there are missing words, the opening sentence is not grammatically correct. A rewrite could be “When the search target is unknown, functionally similar parts cannot be identified by comparison to a predefined set of target functions.”
Comment

We thank the reviewer for many detailed comments and great suggestions.

Studying whether different layers provide similar or dissimilar representations is definitely interesting (and in fact, many existing works already do this), but I find the relation to mechanistic interpretability to be far-fetched. As we discussed in the introduction, we believe that sub-function discovery in the form of different subnetworks could provide insights into the number and type of sub-functions that exist in a network for modeling the prediction function. This serves as a good first step towards identifying these functions automatically. We believe understanding neural networks through the identification of subnetworks is essential due to the complexity and opacity of modern deep learning models. Neural networks, especially those with many layers and parameters, often exhibit behaviors that are difficult to interpret holistically. Identifying subnetworks allows us to decompose the model into smaller, more interpretable units, providing insights into how individual components contribute to the model’s overall performance.

Clarity of the method: "it’s unclear to me why we would want to optimize a correspondence..." Our goal is to uncover layer similarities, with Gromov-Wasserstein (GW) serving as a metric for comparing unaligned manifolds, effectively measuring the distance between two sets of samples. If the identity map (for π) yields the minimal distance, GW will recover that identity mapping. If we assume that the semantics of the examples do not change across layers, setting π to the identity would work. However, in neural networks it is plausible that the semantics of the examples change, as evidenced by the significant geometric changes often observed in local neighborhoods (visualized in the tSNE projection in Figure 11). These shifts arise because different network layers may focus on distinct aspects of sentences or images, so using a fixed identity mapping would result in higher distances at every layer regardless of other transformations. To account for these semantic shifts, it is desirable to optimize the correspondence, which enables us to uncover meaningful structure and relationships that would otherwise be obscured by a fixed identity map. Empirically, the baseline RSA is similar to the GW distance with the self transportation plan, based on inter-sample distances. On the YELP dataset, RSA is more sensitive to layer changes and gives less definitive boundaries between subnetwork structures (Figure 16). On the CIFAR-10 dataset, RSA also tends to show higher discrepancies between visually similar layers and does not provide clear sub-network structures (Figure 9).
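
For reference, the standard (squared-loss) GW objective makes explicit that the coupling π is optimized rather than fixed; here $d_X, d_Y$ denote the intra-space distances and $\Pi(p, q)$ the set of couplings between the two sample weightings (presumably the paper's equation 2 has this standard form):

$$
\mathrm{GW}(X, Y) \;=\; \min_{\pi \in \Pi(p,\, q)} \; \sum_{i,k}\sum_{j,l} \big(d_X(x_i, x_k) - d_Y(y_j, y_l)\big)^2 \, \pi_{ij}\, \pi_{kl}
$$

Fixing π to the identity coupling instead reduces this to an RSA-style comparison of the two intra-space distance matrices, which is exactly the empirical contrast drawn above.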

Baselines and Experiments: Per suggestion, we conducted experiments with 6 additional baselines, including the suggested CCA and CKA, on modular, NLP, and computer vision datasets. The results show that our method gives a clear division of subnetworks and aligns visually with the learned representations. Moreover, we also give a quantitative measure on the discovered networks. We have also added a discussion of [1, 2, 3, 4] to the related work section.

[1] Raghu, Maithra, et al. "Do vision transformers see like convolutional neural networks?." Advances in neural information processing systems 34 (2021): 12116-12128.

[2] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "On the origins of the block structure phenomenon in neural network representations." TMLR, 2022.

[3] Klabunde, Max, et al. "Similarity of neural network models: A survey of functional and representational measures." arXiv preprint arXiv:2305.06329 (2023).

[4] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International conference on machine learning. PMLR, 2019.

Clarifications and editing: Per suggestion, we moved the details on representation candidates in the original Table 1 to Appendix A. Note that the representations in Table 1 are also used in Figure 3, Figure 4, and other modular sum experiments. As discussed in Appendix B, the search space is of size 93 because all intermediate layers within transformer blocks are also compared. We have distinguished transformer blocks from intermediate layers in the revision to avoid confusion. Due to space limits, we removed the pretrained subfigure from Figure 2 and clarified its content further in Appendix F. We have added Jaccard similarity measures of neighborhood overlap to the main text, with more details in Table 5 of Appendix F. Moreover, we have clarified the KL divergence in the figures in Appendix F. For the modular sum experiment, as discussed in Appendix B, we use a separate validation dataset with an 80-20 train-validation split.
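
A hedged sketch of the kind of top-k neighborhood-overlap measure described above (our own illustration, not the paper's exact implementation); random matrices stand in for two layers' representations:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def topk_jaccard(X_i, X_j, k=10):
    """Mean Jaccard overlap between each sample's k-NN sets in two layers."""
    idx_i = NearestNeighbors(n_neighbors=k + 1).fit(X_i).kneighbors(
        X_i, return_distance=False)[:, 1:]  # drop each point itself
    idx_j = NearestNeighbors(n_neighbors=k + 1).fit(X_j).kneighbors(
        X_j, return_distance=False)[:, 1:]
    overlaps = [len(set(a) & set(b)) / len(set(a) | set(b))
                for a, b in zip(idx_i, idx_j)]
    return float(np.mean(overlaps))

rng = np.random.default_rng(0)
layer_i = rng.normal(size=(500, 64))   # placeholder layer-i representations
layer_j = rng.normal(size=(500, 128))  # placeholder layer-j representations
print(topk_jaccard(layer_i, layer_j))  # near 0 for unrelated spaces, 1 if identical
```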

Comment

I thank the authors for their rebuttal. I am glad to see the added baselines, but given that I am still overall unconvinced about the motivation of the work and the link to mechanistic interpretability, I am only raising my score to 5.

Comment

Thank you for your thoughtful comment and further evaluation of our work.

Motivation: We appreciate your concern regarding the motivation of the work and its connection to mechanistic interpretability. We would like to further clarify that the primary motivation of our work is to automate the process of understanding complex underlying functions. This challenge is central to the growing field of mechanistic interpretability, which aims to reveal the underlying processes and reasoning within models. Our work builds on existing literature in mechanistic interpretability. For example, Nanda et al., 2023 and Zhong et al., 2024 both rely on manual inspection to reverse engineer these functions, and are limited to smaller networks (e.g., one-layer transformers) and simpler problems (e.g., modular sums between two numbers). These limitations highlight the relevance of our work toward automating this effort for larger networks and real datasets.

Link to mechanistic interpretability: Mechanistic interpretability aims to uncover how models function internally and why they make certain decisions. Our work contributes to this goal by providing a framework for identifying (specialized) substructures within models. For instance, on CIFAR-10, our method enables us to visually inspect the relationships between different subnetworks and the image representations in the model's internal computation. This aligns with the broader goals of mechanistic interpretability by providing insight into the "how" behind model behavior.

We hope this explanation helps to clarify our motivation and its connection to mechanistic interpretability. We would be happy to expand on these points further if you have additional suggestions or if there are specific areas you believe we should elaborate on.

Comment

We thank the reviewers for their detailed and constructive feedback. Below we address major questions and clarify some confusions regarding our method. In response to the suggestions, we have updated the manuscript to reflect these changes, including improvements in presentation and clarity, the addition of 6 new baseline methods, and 1 new vision dataset, alongside many of the suggested experiments. The new results took some time, but we hope these updates address the reviewers' concerns. We look forward to further discussions with the reviewers.

Baselines: We included 6 additional baselines in the experiments: RSM [1], RSA [1,6], CCA [5], CKA [2], MSID [3], and AGW [4], with implementation details provided in Appendix H. These baselines are compared with GW distances across various experiments, including cases where the dimensions of the compared spaces are identical (a total of 10 methods compared) and where they differ (6 methods compared). Overall, GW demonstrated the clearest subnetwork structures and the strongest visual consistency on image datasets. These results further validate the effectiveness of our approach for subnetwork identification.

[1] Klabunde, Max, et al. "Similarity of neural network models: A survey of functional and representational measures." arXiv preprint arXiv:2305.06329 (2023).

[2] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International conference on machine learning. PMLR, 2019.

[3] Tsitsulin, Anton, et al. "The shape of data: Intrinsic distance for data distributions." ICLR 2020.

[4] Demetci, Pinar, et al. "Revisiting invariances and introducing priors in Gromov-Wasserstein distances." arXiv preprint arXiv:2307.10093 (2023).

[5] Morcos et al. "Insights on representational similarity in neural networks with canonical correlation." NeurIPS 2018.

[6] Kriegeskorte et al. "Representational similarity analysis - connecting the branches of systems neuroscience." Frontiers in Systems Neuroscience 2 (2008).

AC Meta-Review

This paper introduces a method to automate mechanistic interpretability by identifying subnetworks that are functionally distinct/similar. The method uses the GW distance, which allows matching samples across intermediate representations of different dimensionalities. Experiments are performed on algebraic, language, and vision tasks.

Ultimately, no reviewer championed the paper. The main limitations called out by the reviewers are:

Improvement needed on the motivation and how it is reflected in the experiments (in particular, validation on real datasets). This was raised by all reviewers, with the two most critical ones being dissatisfied with the authors' answer, and the two more positive ones also calling this out as an unaddressed limitation.

Additional Comments from Reviewer Discussion

Reviewers requested additional baselines, which were run (leading 4pzj to raise their score to 5, and bfy4 to 6). Overall, the positioning of the work did not convince the reviewers, and nobody championed the paper for acceptance.

Final Decision

Reject