Hierarchically branched diffusion models leverage dataset structure for class-conditional generation
Class-conditional diffusion models can be trained to share the diffusion timeline and model parameters according to similarity between classes
Abstract
Reviews and Discussion
The paper proposes a novel framework for class-conditional generation using diffusion models, which are models that can generate realistic objects by reversing a noisy diffusion process. The framework, called hierarchically branched diffusion models, leverages the hierarchical relationship between different classes in the dataset to learn the diffusion process in a branched manner. The framework has several advantages over existing methods, such as being easily extendable to new classes, enabling analogy-based conditional generation (i.e. transmutation), and offering interpretability into the generation process. The paper evaluates the framework on several benchmark and large real-world scientific datasets, spanning different data modalities (images, tabular data, and graphs).
Strengths
- The paper addresses a novel and important problem of class-conditional generation using diffusion models, which can capture the rich information and structure of the data classes.
- The authors introduce a novel and flexible framework that can exploit the inherent hierarchy between distinct data classes by using branch points and can handle different types of diffusion models and paradigms.
- Extensive experiments and analysis are conducted to demonstrate the advantages of the proposed framework in continual learning, transmutation, and interpretability.
Weaknesses
- The paper is well-written, but some claims could be clarified. For example, in Section 4, Page 6, it is unclear how the model performs with versus without fine-tuning the upstream branches that also diffuse over the newly added class in the continual-learning setting, as no empirical results are given in support. Another example: I found the observation in Section 5, Page 6, that letters with a larger feature value tended to transmute to letters with a larger feature value hard to interpret from Figure 3b, which gives only scatterplots of some feature values.
- The efficiency of branched models over standard linear models when sampling multiple classes is clear from Table S9 when the number of classes is relatively small. Could the authors provide some insight into what happens when the number of classes grows very large, with a potential comparison in terms of complexity analysis?
- In Figure 4, the interpretation at the immediate branch point between two classes is shown and aligns with intuition. I am curious what the visualization looks like for a more upstream branch point, such as the branch point between class 0 and the immediate branch point between classes 4 and 9.
Questions
- Why is the FID between true and generated cells the same for the branched and linear models in Figure S3c?
- Could the authors give some insights on why only one baseline of the label-guided (linear) models (Ho et al., 2021) is used for comparison?
Thank you for taking the time to review and provide valuable feedback! We have answered all the questions and comments in the following.
Fine-tuning upstream branches (W1)
In our experiments on continual learning in branched diffusion models, we found that fine-tuning upstream branches did not help (or hurt) the generative performance on downstream branches. This is an expected result, as our branch points are explicitly defined to be the points in diffusion time after which classes diffuse nearly identically. Hence, fine-tuning these upstream branches is not expected to significantly change the model's performance. We will clarify this point in the revised version.
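Concretely, branch points of this kind can be estimated from the forward process and the training data alone. Below is a minimal numpy sketch under a variance-preserving-style Gaussian schedule; the function names, the mean-distance criterion, and the tolerance are illustrative assumptions, not the paper's exact procedure:

```python
import numpy as np

def estimate_branch_point(x_a, x_b, ts, noise_scale, tol=0.05):
    """Illustrative sketch: find the earliest diffusion time t after which
    two classes' noised samples are nearly indistinguishable.

    x_a, x_b    : (n, d) arrays of samples from the two classes
    ts          : increasing grid of diffusion times in (0, 1)
    noise_scale : callable t -> (alpha_t, sigma_t) for the forward process
    tol         : threshold on the noised class-mean gap, in noise units
    """
    mu_a, mu_b = x_a.mean(axis=0), x_b.mean(axis=0)
    for t in ts:
        alpha_t, sigma_t = noise_scale(t)
        # Distance between the class means after noising, relative to the
        # noise level (averaged per dimension).
        gap = alpha_t * np.linalg.norm(mu_a - mu_b) / (sigma_t * np.sqrt(mu_a.size))
        if gap < tol:
            return t  # classes diffuse nearly identically from t onward
    return ts[-1]

# Variance-preserving-style schedule (assumed here for illustration).
vp_schedule = lambda t: (np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t))
```

Under this criterion, similar classes yield early branch points (more shared diffusion time), while distinct classes branch late.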
Transmutation of letters (W1)
In Figure 3b, we showed the scatterplots for just two examples of features, but the histograms below (still in panel 3b) show that we found positive (often strongly positive) correlations for all features in the dataset. Of course, if analogous features were not being retained in the process of transmutation, we would expect these histograms to be tightly centered around 0.
Efficiency of branched models with more classes (W2)
The speedup of branched-diffusion models is only indirectly influenced by the number of classes, and instead is directly determined by the hierarchy and how much diffusion time can be shared. The amount of shared diffusion time, in turn, is a function of how similar different classes are (and the diffusion process). For datasets with a lot of very similar classes, we are able to share much more diffusion time, and therefore will enjoy more speedup. However, datasets with very distinct classes will have very late branch points, leading to lower efficiency. It is also worth mentioning that datasets with only very distinct classes are not particularly well-suited to the method of branched diffusion which we propose in this work. Instead, the main contribution of our work is intended for scientific applications, whose datasets are typically characterized by a large amount of internal structure (e.g. cell types, chemical classes, structural folds, etc.).
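To make this dependence concrete, the cost of sampling one object per class is just the total diffusion time covered by the union of branches. A toy sketch, where the hierarchy encoding and step count are illustrative assumptions:

```python
def sampling_cost(branches, steps_per_unit_time=1000):
    """Total reverse-diffusion network evaluations needed to generate one
    sample of every class. `branches` is a hypothetical encoding of the
    hierarchy as (t_start, t_end) intervals, one per branch; each shared
    branch is traversed only once."""
    return sum(round((t_end - t_start) * steps_per_unit_time)
               for t_start, t_end in branches)

# Three classes: two similar ones branching at t = 0.3, and one distinct
# class branching at t = 0.8.
branched = [(0.8, 1.0),              # root branch shared by all classes
            (0.3, 0.8),              # shared by the two similar classes
            (0.0, 0.3), (0.0, 0.3),  # terminal branches of the similar pair
            (0.0, 0.8)]              # terminal branch of the distinct class
linear = [(0.0, 1.0)] * 3            # a linear model reruns the full timeline
```

Here the branched model costs 2100 evaluations versus 3000 for the linear model; pushing the similar pair's branch point earlier increases the savings, while making all classes distinct (late branch points) erodes them.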
Interpretation of branch points between less-related classes (W3)
Branch points between less-related classes can certainly still be interpreted (and still yield valid interpolations/intermediates). However, between very unrelated classes, the interpretations are naturally less meaningful, as only very high-level (and relatively uninformative) features will be shared between these classes.
To demonstrate this, we updated our manuscript to show the hybrid resulting from interpreting the branch point between more dissimilar classes (Supplementary Figure S9). For example, on our tabular letters dataset, we quantitatively show that the branch-point hybrids between dissimilar classes still exhibit feature values which are interpolated between the two original classes, and often act as intermediates which exhibit properties of both. We also visually show the branch-point hybrids between dissimilar digits of our MNIST dataset. Our resulting hybrids show the common features between dissimilar digits, which are fairly high-level, including: 1) the digit is in the center of the image; and 2) there are areas which are relatively empty (coinciding with the “holes” of how many people draw digits like 6s and 9s). This is a direct consequence of the lower similarity between two unrelated data classes rather than a limitation of branched diffusion models.
FID of RNA-seq dataset (Q1)
For the RNA-seq dataset, below we reproduce the FID values for each class. Due to the high variation in FID values between the classes (likely because some cell types are inherently harder to generate than others), differences in these values were not obviously rendered in Supplementary Figure S3.
| | Branched | Label-guided |
|---|---|---|
| CD16+ NK | 377.254 | 376.300 |
| Cl. Mono. | 163.470 | 166.414 |
| Late Eryth. | 150.605 | 153.134 |
| Macroph. | 151.118 | 156.442 |
| Megakar. | 64.968 | 65.506 |
| Mem. B | 1604.477 | 1610.171 |
| NK | 84.447 | 85.517 |
| Plasmabl. | 317.073 | 317.584 |
| Tem/Eff. H. T | 174.465 | 175.692 |
Choice of baseline (Q2)
Our proposed method of branched diffusion is a fundamentally novel way to perform class-conditional generation using diffusion. We thus sought to compare the benefits of branched diffusion to the current state-of-the-art methods of class-conditional diffusion. In the literature, there are effectively two ways to perform class-conditional diffusion: classifier-guided diffusion and classifier-free diffusion (the latter of which we called label-guided diffusion). Classifier-guided diffusion requires training external classifiers on all noise levels, and is not commonly used today for this reason. Label-guided diffusion, on the other hand, is universally relied upon today for conditional diffusion, and state-of-the-art results in diffusion arise practically entirely from this method. Hence, we opted to perform comparisons relative to this technique (the only state-of-the-art method used today).
Many thanks to your response and the additional experiments. I have also read the reviews and responses with other reviewers. My concerns are mostly addressed and I have raised my score accordingly.
The paper presents a new way to do class-conditional generation in diffusion generative models by modelling the formation of class structures as a hierarchical ‘branching’-type process: The generation starts from a noisy image that contains no class information, and as generation progresses, the partially formed image starts narrowing down to a smaller set of classes. Whenever some classes are ruled out, this is denoted as a ‘branch point’. In the method, these points are empirically estimated from the noising forward process and the original training data, and during training and generation, a separate conditioning signal is given to each of the branches. Each class has a unique combination of branches. It is shown that this formulation of class-conditional generation helps in avoiding catastrophic forgetting in some scenarios, lends itself to novel image-translation-type conditional generation for MNIST, gene data and molecules. Additionally, the paper shows a method to visualize and interpret the branching points using averages of the noisy data points. There can also be a benefit in generation efficiency if sampling multiple classes by combining the generation of multiple classes at once.
Strengths
- The paper presents a creative use of the diffusion forward process itself for controlling the generation in a way that is mostly unexplored at the moment.
- The observation that training new classes on only subsets of the diffusion forward process avoids catastrophic forgetting is particularly striking and seems like a genuinely new effect. Possibly this paper could be a first step towards utilizing this property in more realistic continual learning scenarios.
- The paper also goes on to come up with different creative use-cases of the explicit branching structure, such as ‘transmutation’ where data points are transferred from one class to another using the branching structure, and in general finds multiple potentially relevant scientific data sets to experiment on.
- Perhaps the ideas here could inspire more research towards creating more structured diffusion generative models in the future.
Weaknesses
- Some of the experiments are not, at the moment, particularly convincing of the usefulness of the effects that they are showcasing. For the analogy-based generation with the RNA-seq data set, some marker genes were indeed changed, but do we have any other way to evaluate the success of these generations? Could we formalize a clear objective on what the conditional generation aims to do in the first place here? A similar issue exists for the molecule data set: Indeed regenerating does allow to generate cycled molecules from acyclic ones, but it is not clear what, if any, properties of the original molecule are retained this way. Just looking at the results, it seems possible that the generated molecules are a random mixture of the desired property and some atoms and bonds from the original molecules.
- Continuing on the analogy-based generation, I feel that it would be appropriate to do an ablation where we use a regular diffusion model, noise out the data partially, and regenerate with the changed label. Would this work equally well, or differently somehow?
- The points about interpretability are also interesting, but it remains a bit unclear what could be the use of these average branching points.
Questions
- What does the high correlation of expression between genes before and after transmutation mean here? That those particular genes often did not change? Is this the property that we want?
- Since adding uncorrelated noise probably does not result in clean hierarchical class structures in all cases (maybe in more complex image data sets, as pointed out in the paper), do you think it would be possible to induce such structure, e.g., by diffusing in some specifically designed latent spaces or otherwise designing the diffusion process itself to encourage it?
- I wonder if the example where catastrophic forgetting is avoided in MNIST is possible to extend to multiple steps, to a slightly more realistic continual learning scenario?
Overall, I think the idea is interesting and the paper presents new qualitative effects that emerge from the new formulation, but it is not there yet for publication. In particular, a more thorough experimental validation for continual learning and analogy-based generation would be in place, so that the reader would have clear takeaways. For the analogy-based generation, some kind of formalization of what are we targeting with the conditional generation, would also help with showcasing the potential significance.
Thank you for taking the time to review and provide valuable feedback! We have answered all the questions and comments in the following.
Formalizing and quantifying the goals of transmutation (W1, Q1)
In transmutation (i.e. analogy-based conditional generation), we start with an object $x_0$ of class $c_0$, and transmute it to another object $x_1$ of class $c_1$ (usually $c_0 \neq c_1$). The goal of transmutation is two-fold: 1) efficacy: $x_0$ should have features which define it to be in class $c_0$, and $x_1$ should have features which define it to be in class $c_1$; and 2) analogy: features which do not distinguish/define an object's class to be in $c_0$ or $c_1$ should be retained. This makes analogy-based conditional generation substantially different from standard conditional generation, as we are interested in sampling from $q(c_1, x_0 \in c_0)$ instead of $q(c_1)$. In addition to several qualitative results, we systematically quantified both objectives in our experiments, demonstrating both efficacy and analogy in multiple real-world datasets.
For the RNA-seq example, we showed efficacy by quantifying the marker genes before and after transmutation. The class of a cell (i.e. cell type) is a complex concept, and marker genes are by far the most widely accepted method for determining cell type. As such, we showed key marker genes which demonstrate that a cell was successfully converted from one type to another. To demonstrate analogy, we quantified the correlation of non-marker genes, namely genes responsible for COVID-related inflammation (which do not define the cell type). These inflammation genes were explicitly identified by Lee et al. (2020) (the publication which is the source of the RNA-seq data) as being key upregulated genes in COVID-infected cells (regardless of cell type). The correlation here is quantified over a random sample of cells (each point in the correlation is a single cell). A high correlation indicates that if the starting cell had a high expression of the gene, then the transmuted cell also had high expression of that gene (and vice versa). For these inflammation genes, this is precisely the desired property, as it indicates that the model is transmuting COVID-infected cells of one type into COVID-infected cells of another type, and healthy cells from one type into healthy cells from another type (i.e. the non-cell-type-defining features are retained).
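For instance, the per-gene correlation described above can be computed in a few lines of numpy. This is a sketch with illustrative names (not the paper's exact analysis code); each row is one cell, sampled before and after transmutation:

```python
import numpy as np

def per_gene_correlations(before, after):
    """Analogy metric sketch: for each gene, the Pearson correlation across
    cells between its expression before and after transmutation.

    before, after : (n_cells, n_genes) expression arrays, rows aligned so
                    that row i of `after` is the transmutation of row i of
                    `before`.
    """
    b = before - before.mean(axis=0)   # center each gene across cells
    a = after - after.mean(axis=0)
    num = (b * a).sum(axis=0)
    den = np.sqrt((b ** 2).sum(axis=0) * (a ** 2).sum(axis=0))
    return num / den                   # one correlation per gene
```

Genes whose values are retained through transmutation (the non-class-defining, inflammation-like ones) should show correlations near 1, while genes the model rewrites should not.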
For the molecular example, we quantified efficacy simply by counting the number of molecules which satisfied the goal structural property. We quantified analogy by computing the similarity of functional groups (which, in turn, endow the chemical properties of the molecule) before and after transmutation. Overall, in this experiment, the objective of transmutation is to generate molecules that satisfy the new class (e.g. “cyclic” or “halogenated”), while also resembling the initial molecules. Therefore, our results showing that the generated molecules satisfy the desired target property and are some “mixture” of similar chemical motifs from the original molecule, are certainly an indication that transmutation is working properly. For any starting molecule, it is desirable to see a distribution of similar atoms and functional groups in the generated molecules (i.e. analogy), but still being valid molecules which satisfy the target property (i.e. efficacy).
Attempting transmutation on a linear diffusion model (W2)
It is certainly possible to perform what was suggested here in a linear (i.e. regular) diffusion model. However, there are no explicitly defined branch points in a linear model, so it is not clear to what point forward diffusion should be performed before reverse diffusing to a different class (i.e. what “turn-back point” to use). As such, we would be forced to select a post hoc turn-back point which the model was not trained with. This leads to several pitfalls:
- Choosing a turn-back point which is too early will hurt efficacy, as the model will not have enough diffusion time to generate an object of the target class.
- Choosing a turn-back point which is too late will hurt analogy, as the model will lose many features which are non-class-defining.
- A turn-back point which retains both efficacy and analogy for one individual example from the dataset may not be sufficient for another example, because a linear diffusion model is not trained with any concept of branch points, and so there is no imposition of this structure between classes as a whole.
One of our core contributions is precisely the definition of formal branch points, which are the optimal time points to perform transmutation. Branched diffusion models are explicitly trained with these branch points, and so the model learns to generate shared characteristics in upstream branches, and class-specific characteristics in class-specific branches, for all examples.
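As a sketch, transmutation in a branched model is just forward diffusion to the learned branch point followed by reverse diffusion down the target class's branch. The helper names below are illustrative, not the paper's API:

```python
def transmute(x0, t_branch, forward_noise, reverse_branch, target):
    """Transmute x0 into an analogous object of class `target`.

    t_branch is the branch point between x0's class and `target`: the
    diffusion time after which the two classes diffuse nearly identically.
    """
    # 1) Forward-diffuse to the branch point: class-specific features are
    #    destroyed, while shared (non-class-defining) features survive.
    x_t = forward_noise(x0, t_branch)
    # 2) Reverse-diffuse from the branch point along the target class's
    #    branch, regenerating class-specific features for `target`.
    return reverse_branch(x_t, t_branch, target)
```

Because the branch point is learned during training rather than chosen post hoc, the same `t_branch` retains efficacy and analogy across all examples of the two classes.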
Following the suggestion, to better show the advantages of branched diffusion models for transmutation compared to standard (linear) diffusion models, we have added a new ablation study (Supplementary Figure S8). Our results show that in the absence of explicitly defined branch points, attempting to perform transmutation in a linear model is significantly more difficult. Imposing a post hoc turn-back point that is too early causes the model to fail at efficacy: the target class is not generated at all. Imposing a turn-back point which is too late causes the model to suffer in analogy: non-class-specific features are no longer correlated before and after transmutation. Additionally, choosing certain turn-back points can cause some objects to be transmuted efficaciously, and others to fail to generate the target class entirely. These issues all severely complicate the attempt to perform transmutation in a linear model. Furthermore, even if we had prior knowledge of what a branch point should be, the transmuted results from the linear model at the “correct” branch point are still of lower quality than in the branched model, likely because of crosstalk between the classes, which is not controlled for at all in the training of the linear model.
Thus, branched diffusion models are much more well-structured for tasks like transmutation (and of course, branched models are also uniquely more easily extendable in a continual-learning setting, their branch points are interpretable, and generation of multiple classes is significantly more efficient).
Benefits of interpreting branch points (W3)
In our experiments, we have shown how hybrid intermediates naturally capture shared characteristics between classes (Figure 4). In this paper, we focused on validating this hypothesis, as it demonstrates how the proposed model is learning to hierarchically generate objects. We will extend our discussion to better highlight the potential benefits of branch points on interpretability, which include:
- Discovery. Branch points can help identify shared higher-order features across two (or more) classes. While class similarities could also be discovered through other techniques, using branch points enables a model-based definition of these shared features, as they depend not only on the underlying data distribution but also on the trained model, including the defined diffusion process and the entire dataset on which the model was trained.
- Distribution-based interpretation of shared features. Branch-point intermediates are distributions of objects, where each individual feature has its own distribution of values given any two classes. This allows us to sample from the intermediate distributions, thereby not only allowing us to glean insight into shared features and concepts, but also assess their significance/confidence.
- Improved model understanding. Branch points may be used to diagnose potential issues in the generative model, providing a quantitative/qualitative way of checking generation intermediates, in addition to generated samples. On a related note, during the development of this method, we partially relied on interpreting branch points to verify the integrity of our discovered hierarchies.
We plan to further investigate applications of directly interpreting branch-point intermediates in future work.
Inducing hierarchies by defining alternative diffusion processes (Q2)
We found that even with a simple Gaussian diffusion process, we were able to recover meaningful hierarchies in our datasets. For example, although there is often some variation in the branching structures which can arise, our models remain robust to this variation, and these branching structures generally make intuitive sense (Supplementary Figure S5).
Defining alternative diffusion processes or leveraging uniquely structured latent spaces are certainly interesting directions which we hope to explore in future work. Depending on the scientific goal in mind, a unique diffusion process may be defined which targets a particular outcome for tasks like transmutation. For example, in the generation of molecules, a discrete diffusion process which operates on junction-tree representations may further encourage whole substructures and motifs to be retained in the transmutation process. On the other hand, a more fine-grained diffusion process which readily breaks/makes bonds might be better for a goal like scaffold hopping. Leveraging a specifically designed latent space (e.g. one which preserves short distances between molecules which have similar solubility) may result in a branched model which is better tailored to solubility-related tasks, as more diffusion time can be shared.
Multi-step continual learning (Q3)
The continual-learning experiments we showed can certainly be extended to multiple steps. The biggest benefit of branched diffusion models for continual learning is that when we introduce a new data class, the existing model remains completely frozen (i.e. the model’s generation of previous classes remains absolutely identical). As such, we have no doubts that since we achieved good generative results after introducing one class, our findings easily extend to the case where we introduce additional classes.
To further show how our method performs in larger continual-learning experiments, we present novel results performing multi-step continual-learning experiments where we extended our MNIST models to three never-before-seen classes one by one: 7s, 1s, then 2s (a mixture of classes which are similar to or distinct from existing classes in the model) (Supplementary Figure S7).
As expected, our results here show the exact same trend as before: branched models are easily extended with an efficient fine-tuning step, where the performance of previous classes is untouched upon the introduction of a new data class. In label-guided (linear) models, fine-tuning on new data causes catastrophic forgetting. It is also worth noting that due to the reliance of label-guided models on class embeddings, typical neural-network architectures are unable to accommodate never-before-seen labels unless they are explicitly initialized to allow for extra class labels; in our experiments on continual learning, we needed to separately train a label-guided model on fewer classes, but with sufficient extra capacity (“embedding slots”) to accommodate future classes. In contrast, branched models have no such limitation, highlighting an additional advantage of branched models in continual learning.
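For concreteness, the “embedding slot” workaround amounts to pre-allocating rows in the label-embedding table up front. A toy numpy illustration (the helper name is ours, not from the paper):

```python
import numpy as np

def make_label_embeddings(n_initial_classes, dim, n_reserved=0, seed=0):
    """A label-guided model's embedding table must reserve rows for future
    classes at initialization; a fixed table cannot accept an unseen label
    id at all."""
    rng = np.random.default_rng(seed)
    return rng.normal(size=(n_initial_classes + n_reserved, dim))

emb = make_label_embeddings(n_initial_classes=8, dim=16, n_reserved=3)
vec = emb[9]  # works only because extra slots were reserved in advance
```

Without the reserved rows, any label id beyond the initial classes would simply be out of bounds, which is exactly the limitation branched models avoid.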
I thank the authors for the new experiments and explanations! I understand some questions better now, but some remain open. Here's some comments / questions that came to mind while reading the response:
Transmutation
- How do we define which features distinguish/define an object's class to be c_0 or c_1? I still think that a mathematical definition would be better here, so that we can clearly evaluate the sensibility of the definition and the quality of the results.
- E.g., q(c_1, x_0 \in c_0): What does this formula mean? By q(c_1), do the authors mean the distribution of the label or the distribution of x_0 given the label?
- Figure S8 does indeed look promising! But this is still just one rather small example, and a more thorough evaluation against this standard diffusion baseline (with numerical metrics) would make the case more convincing that training with the branching procedure does make the model behave clearly differently from a standard diffusion model.
Interpretability
- I think I understand slightly better now the interpretability point, and the model understanding part in particular seems useful for working with models like this.
Continual learning
- How can you make sure that the model's generation of previous classes remains identical? Do you train a new neural network? Otherwise this claim seems impossible to me, since finetuning means that you do retrain the weights, but maybe I did not understand properly.
- I don't understand how branched models do not have the limitation of adding some new weights for new classes? Is it the case that no direct input is given to the model here?
- For the fine-tuning experiments on continual learning, do you have the details on how long the models were fine-tuned, etc.? This seems like important information here.
FID values
An important question I forgot to ask: How is the FID calculated for the cell data? Does FID make sense in this context, given that it uses the InceptionV3 network meant for image data? As a side note, even for MNIST, FID doesn't seem like the best metric, since the InceptionV3 network is trained on ImageNet, which differs quite greatly from the characteristics of MNIST (and as such the inner representations may not accurately reflect the structure of MNIST digits very well). But I'm not demanding any new experiments on the MNIST part.
Overall, my view of the paper overall has not changed as of now. The new results and explanations do give me more confidence that the method is interesting to the community, but the paper in the current stage seems preliminary still, as any of the effects, such as continual learning, have not been explored very thoroughly. In my view, especially the continual learning part seems interesting, due to the clear improvement over the standard diffusion model, and expanding experiments on that to other wider types of data would make for a really interesting paper. I am also open to hearing out further responses from the authors.
Thank you for the feedback and for clarifying remaining concerns!
Firstly, we would like to note that we do conduct fairly extensive experiments over many different data types and datasets, spanning from images to graphs to tabular data. Two of our datasets were also large, real-world scientific datasets, and altogether we showed global, quantitative results for each of the benefits brought by our proposed method of branched diffusion (continual learning, transmutation, interpretability, and efficiency). As such, we believe that these experiments (in addition to the further experiments we have shown throughout the rebuttal period) are sufficient to demonstrate the advantages and disadvantages of our method relative to the current state of the art.
Transmutation
Unfortunately, typical probability notation is not very well-suited to capture the process of transmutation. To be more clear, we defined $q(c_1, x_0 \in c_0)$ above to be the conditional distribution of objects for class $c_1$, but also conditioned on an object $x_0 \in c_0$. If we were to draw an object $x_1$ from $q(c_1, x_0 \in c_0)$, we would like both efficacy and analogy to hold (i.e. $x_1$ belongs to class $c_1$ and features that do not define class identity are shared between $x_0$ and $x_1$).
Although we agree that it would be nice to have more formalization on the process of transmutation, it is unfortunately not possible to mathematically formalize both goals of efficacy and analogy in a straightforward way. In particular, it remains generally infeasible to clearly distinguish and define the class-defining features which should be modified in efficacious transmutation, and the instance-specific features which should be kept analogous. Although class-defining features and instance-specific features certainly exist, they are oftentimes latent, high-level, and entangled with each other. If the class-defining and instance-specific features were disentangled and clearly defined, then a general mathematical formalization would be more accessible, but unfortunately, the realization of these features is highly dependent on the specific application, including the dataset, the definition of classes, and the meaning of the underlying data. Unless we have a more simplistic dataset or we have adequate domain knowledge to perform this disentanglement and latent discovery, it remains infeasible to formally define transmutation at a feature level.
For this reason, we showed globally quantitative results on transmutation for many datasets of different data types (Figure 3). For example, we showed transmutation on a tabular dataset, where we could examine feature analogy by inspecting individual features. We also showed globally quantitative results demonstrating both efficacy and analogy on two real-world large scientific datasets: RNA-seq and drug-like molecules. For RNA-seq, to show efficacy and analogy we relied on knowledge of the domain: namely the presence of marker genes which define cell type and analogous COVID-related inflammation genes which were first discovered in Lee et al. (2020). For the molecules, our quantification of efficacy and analogy relied on high-order structural properties and knowledge of functional groups.
Overall, however, it remains relatively infeasible to clearly mathematically formalize the low- and high-order features which we expect to be retained for analogy but modified for efficacy, particularly when we are attempting to remain general to most scientific applications. Importantly, note that transmutation can be thought of as a generalization of neural style transfer (NST) (i.e. generating images where the subject of the image is given the style of another image). Note that in the seminal works which define NST [1, 2], there is also no mathematical formalization of the goals of NST, due to the inability to define ahead of time the features which we expect to be modified or retained.
We are also happy to hear that our novel ablation study shown in Supplementary Figure S8 was helpful. Importantly, we would like to clarify that transmutation is a novel technique that we presented, which was not a part of traditional linear diffusion models. Indeed, our main contribution is the discovery of branch points and applying them to a diffusion model, so that we can perform transmutation in the first place. As our results show, attempting to force a conventional linear model to perform transmutation is unnatural and difficult, and transmutation as a technique in linear diffusion models also did not exist prior to our paper. As such, beyond this ablation study showing why branch points are important for transmutation, we do not find it particularly meaningful to perform such an extensive evaluation (as we have done in our other experiments) against such an unnatural and difficult technique which understandably did not exist before.
Continual learning
In a branched diffusion model, we split the diffusion process of different classes into distinct branches, where reverse diffusion along each branch is predicted by a different output task of a multi-task neural network (Figure 1c). The branching is designed such that every class of the dataset has a subset of branches which collectively form the full diffusion timeline. Branches at late diffusion times tend to be shared between classes, as objects diffuse very similarly (regardless of class) at these late times. Each class, however, is guaranteed to have its own terminal branch (i.e. the branch starting from t = 0), which is solely responsible for generating that class.
When we introduce a new class to a branched diffusion model, we add a branch point to the existing hierarchy; the placement of this branch point depends on how similar the new class is to existing classes. This creates a new terminal branch for specifically that new class, and also splits an existing internal branch in two (Figure 2a). In terms of the neural network, we simply create a new network of the same architecture but with two more output tasks. We copy over the weights from the old model to the new model, matching up branches appropriately (i.e. we copy over the weights from the corresponding branches in the existing hierarchy; for the new terminal branch, we can initialize the weights to be the same as the terminal branch for a similar class). When we fine-tune this branched model, we freeze the weights in all the shared layers, and all the output tasks other than the one new terminal branch. We then fine-tune the model (updating only the weights in the new terminal branch), using data of the new class only. This guarantees that all other branches will predict reverse diffusion identically as before, and so all pre-existing classes are guaranteed to be predicted identically (as we only touched the new terminal branch, which does not affect the generation of any other classes).
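To make the freezing scheme concrete, here is a minimal sketch with a toy numpy stand-in for the multi-task network (the class `BranchedDenoiser`, its layer shapes, and the branch names are our own illustrative assumptions, not the paper's actual architecture):

```python
import numpy as np

class BranchedDenoiser:
    """Toy multi-task denoiser: a shared trunk plus one linear head per branch.
    Purely illustrative; it stands in for the real multi-task network."""

    def __init__(self, dim, branch_names, rng):
        self.trunk = rng.normal(size=(dim, dim))                      # shared layers
        self.heads = {b: rng.normal(size=(dim, dim)) for b in branch_names}

    def predict(self, x, branch):
        # Reverse-diffusion prediction for one branch (output task)
        return np.tanh(x @ self.trunk) @ self.heads[branch]

def extend_with_new_class(old, new_branch, similar_branch):
    """Copy the old model, add one terminal branch for the new class,
    initialized from the terminal branch of a similar existing class."""
    new = BranchedDenoiser.__new__(BranchedDenoiser)
    new.trunk = old.trunk.copy()                                      # copied, then frozen
    new.heads = {b: w.copy() for b, w in old.heads.items()}           # copied, then frozen
    new.heads[new_branch] = old.heads[similar_branch].copy()          # the only trainable part
    return new
```

Fine-tuning then updates only the new branch's head; because the trunk and all other heads are exact copies that are never touched, generation for pre-existing classes is guaranteed to be identical to the old model.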
Contrast this with a linear model, which does not cleanly separate the diffusion of different classes in any manner, and so new classes cannot be introduced efficiently without catastrophic forgetting.
It is also worth noting that in a branched model, there is no need to fine-tune upstream branches of the newly added class (i.e. upon the addition of a new class, we only needed to fine-tune the newly added terminal branch). This is an intentional and direct consequence of the definition of branch points, as the diffusion of objects upstream of branch points is nearly identical for different classes, and so anything upstream of a branch point would not need to be modified.
We showed extensive quantitative results for continual learning, comparing branched and linear models, in Figure 2 and Supplementary Figure S7.
For these continual-learning experiments, whenever we added a new class, we fine-tuned for 10 epochs (for both branched and label-guided models). This included our MNIST experiments and our RNA-seq experiments. We noted that the loss had converged each time after 10 epochs. In our experiments, we also fine-tuned label-guided models on the entire dataset (instead of just the new data class) to show that the label-guided model could recover generative performance when fine-tuned on the entire dataset (which is of course quite inefficient). When fine-tuning label-guided models on the whole dataset, we trained for 30 epochs (note that each epoch covers a much larger dataset here compared to fine-tuning on a class-limited dataset). We will certainly add this to our supplementary methods, and we thank the reviewer for catching this missing detail.
FID
For our RNA-seq dataset, we computed FID by applying the FID metric directly to the feature values, as FID is a well-defined metric on general vector distributions. We absolutely agree that every metric of generative performance (including FID) has its advantages and disadvantages. In our analyses, we believe that FID is sufficient to capture the global trends in generative performance that we wish to understand. In addition to the analysis of FID, we also verified the generative performance of our branched diffusion models using other methods which we directly adapted from previous works (Supplementary Figures S1–S2).
Importantly, however, we emphasize that the main contributions of our work lie beyond generative performance, and we simply wish to ensure that branched diffusion models are not significantly worse in performance compared to their linear counterparts. We believe that our analyses in this regard are sufficient to demonstrate this claim.
References
[1] Gatys, L., Ecker, A., Bethge, M. A Neural Algorithm of Artistic Style (2015).
[2] Johnson, J., Alahi, A., Fei-Fei, L. Perceptual Losses for Real-Time Style Transfer and Super-Resolution (2016).
Thank you for the further elaborations! Couple of last questions (I hope these do not come too late) that came to mind:
"Contrast this with a linear model, which does not cleanly separate the diffusion of different classes in any manner, and so new classes cannot be introduced efficiently without catastrophic forgetting."
Did you try also only fine-tuning the new output heads in a linear diffusion model? I understand that the difference is that one would have to train it for all the diffusion steps instead of just for the last branch, but this would be a good baseline to have. It sounds that the ability to finetune a small amount of parameters for only a part of the diffusion process is the key to making the continual learning scenario to work here. I think that this is quite important information (and could be useful for characterising when could we except this method to be useful).
"we computed FID by applying the FID metric directly to the feature values"
So does this mean that you calculate the 2-Wasserstein distance for the two sets of distances (with a Gaussian assumption)? I am mainly confused how does the calculation work in practice here, since FID itself is defined in a very specific way for image data and involves extracting features from a convolutional network, and I can't see how this could be applied to general vectors.
Thank you for clarifying the last few questions!
Fine-tuning a linear model
Linear (i.e. traditional, label-guided) diffusion models are generally single-task models which are not trained with multiple output heads (Ho et al., 2021). Instead, class identity is specified to the model as a label embedding which is fed in as an input. It is certainly true that one of the major reasons why branched diffusion models are much easier to extend to new data is the ability to fine-tune a small number of parameters for only a small part of diffusion time, while guaranteeing that no previously learned tasks/classes are affected. Indeed, one of the contributions of branched diffusion models (in addition to recognizing that there are natural branch points along the diffusion process) is to turn the neural network into a multi-task network which optimally separates diffusion time into distinct branches and tasks to accomplish this separation of parameters.
In contrast, training a traditional diffusion model to be multi-task (with one output task per data class) is generally not a method which is used today. A lot of the diffusion process between classes (particularly at later times) would be inefficiently learned by multiple tasks, and so a single-task model with label embeddings is the ubiquitously relied-upon approach for class-conditional diffusion today. A branched model solves this problem of inefficient learning by leveraging a hierarchy of class similarities to define branches of shared diffusion. That said, we agree that this method of multi-task diffusion in an otherwise linear model would be far more amenable to continual learning than the ubiquitous single-task linear diffusion models, although it would still be less suitable than branched diffusion models due to the need to train on the entire diffusion timeline for each class (as you suggested).
Additionally, training a multi-task linear diffusion model would not allow for the many other benefits of branched diffusion models, such as transmutation, interpretability of branch points, or efficient generation.
FID on feature values
That's right! 2-Wasserstein distance is perhaps the clearer term to use for similarity comparisons on general vectors. We will update the language in our manuscript accordingly.
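Concretely, the computation reduces to the closed-form 2-Wasserstein (Fréchet) distance between two Gaussians fitted to the raw feature vectors. A minimal sketch follows (our own illustrative code, not the authors' implementation; the function name and the use of `scipy.linalg.sqrtm` are assumptions):

```python
import numpy as np
from scipy.linalg import sqrtm

def frechet_distance(x, y):
    """2-Wasserstein distance between Gaussians fitted to two sample sets.

    x, y: (n_samples, n_features) arrays of raw feature values; no
    convolutional feature extractor is involved, matching how "FID" can be
    applied directly to general vector data such as RNA-seq features.
    """
    mu_x, mu_y = x.mean(axis=0), y.mean(axis=0)
    cov_x = np.cov(x, rowvar=False)
    cov_y = np.cov(y, rowvar=False)
    covmean = sqrtm(cov_x @ cov_y)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical noise
    d2 = np.sum((mu_x - mu_y) ** 2) + np.trace(cov_x + cov_y - 2.0 * covmean)
    return float(np.sqrt(max(d2, 0.0)))
```

This is exactly the formula used inside standard FID code; the only difference from image FID is that the vectors here are the data features themselves rather than Inception-network activations.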
This paper proposes hierarchically branched diffusion models for class-conditional generation. In a hierarchically branched diffusion model, branch points between all classes are placed based on the similarity between each class pair. The proposed model can be easily extended to continual-learning scenarios. The model facilitates analogy-based conditional generation and provides interpretability into the class-conditional generation process.
Strengths
- This paper is well-written and easy to understand.
- The inclusion of well-crafted visualizations greatly enhances the comprehension of key concepts.
- The proposed method offers meaningful advantages.
Weaknesses
I did not find notable weaknesses of this paper.
Questions
A concern arises regarding the scalability of the proposed method as the number of classes increases. The experiments conducted appear to be limited to datasets with a small number of classes. It would be beneficial if the authors could present results for datasets with a larger number of classes.
Details of Ethics Concerns
N/A
Thank you for taking the time to review and provide valuable feedback!
To answer your question about datasets with more classes, we note that one of the datasets we presented had 26 classes, which we consider fairly typical in magnitude (or even larger in magnitude) compared to most scientific applications.
For example, training a model where the classes are cell types would require only 10-15 classes for most tissue systems. For drug-discovery applications where the classes are target receptors, most families of receptors only have 5-10 members under investigation at any given point.
I appreciate the authors for answering my questions. However, the large number of classes I mentioned wasn't around 28 but a greater number (e.g., 100). I would like to maintain my positive score.
The paper proposes a method for class-conditional (label guided) sampling from a diffusion model by introducing branching. Analysis on several datasets suggests that the approach can be competitive (or perhaps even superior) in terms of generated sample quality. The proposed approach has several advantages compared to “classifier-free” guidance. It can readily incorporate new classes and can be used for transmutation (transferring a specific instance from one class to another). The method is considerably more efficient if the aim is multi-class sampling.
Strengths
S1. The paper proposes a highly novel and elegant approach to class-conditional (label guided) sampling from a diffusion model.
S2. Experiments indicate that the proposed method offers competitive (or better) performance to the state-of-the-art “classifier-free” guidance approach in terms of Frechet inception distance.
S3. The paper details how the presented method can efficiently incorporate new classes without retraining the entire model and illustrates how the method can be employed for transmutation. Multi-class sampling is considerably more efficient compared to the state-of-the-art approach.
S4. The paper is well-written and presents the proposed method clearly.
Weaknesses
W1. Some of the experiments are not particularly compelling and serve more as examples of the potential of the technique rather than providing convincing evidence in support of the claims of the paper. In the image domain, analysis is limited to MNIST and the letter-recognition dataset; this leaves open the question as to how well the approach scales to other (less-structured) types of images and more challenging image classes, where the class hierarchy may not be so clear.
W2. The branching point definition involves a threshold. There does not appear to be a concrete recipe for the selection of this threshold. Based on the paper it seems to be left to the practitioner to determine when branching points are “too close” to 0 or T. There is no investigation of how the selection of this threshold impacts performance.
W3. Some of the claims in the paper are not adequately supported by experimental evidence. The claims should be moderated or experimental results provided to support the more general conclusions.
The paper proposes a novel, intriguing and elegant approach. The major weakness of the paper is that most of the claims in the paper are supported by relatively limited experimentation. For example, the paper claims that the method achieves similar or better generative performance as the state-of-the-art, but does not clearly preface this with the clarification that the outperformance is observed only for two simple character-based image datasets and the similar performance is only established for one other dataset. The paper would be considerably more convincing if there were experiments on more challenging image datasets.
Questions
Q1. “In general, the branched diffusion models achieved similar or better generative performance compared to the current state-of-the-art label-guided strategy. In many cases, the branched models outperformed the label-guided models, likely due to the multi-tasking architecture which can help limit inappropriate crosstalk between distinct classes.” – these sentences seem to be strong claims when the experiments are conducted on three datasets (two of which are similar). There seems to be no discernible outperformance for the single-cell RNA-seq data. The outperformance is really only for two character-based image datasets, so “In many cases” seems to be a stretch. Considerably more extensive experiments on a variety of datasets are required to support the general claim made in the paper. Alternatively it could be restricted to “For experiments performed on two character-based datasets and an RNA-seq dataset, ….” Can the authors clarify whether they consider the current experiments to be sufficient to demonstrate similar or better generative performance?
Q2. There are concerns that the Frechet inception distance can provide an incomplete or even misleading picture of generative quality (e.g., “The Role of Imagenet Classes in Frechet Inception Distance”, ICLR 2023; “Assessing Generative Models via Precision and Recall”, NeurIPS 2018). Do the authors consider that there would be value in employing other approaches for investigating sample quality?
Q3. Why is the new class experiment limited to training on 3 classes? Is this to make the task easier? The new class experiment for the single-cell RNA-seq dataset seems to be similarly limited (just starting with two classes and adding a third). Is there a reason that the more obvious experiment of removing just one class and adding it back is avoided? What happens if the “1” class is already included (i.e. something that is much closer to the introduced task)?
Q4. “Of course, images and image-like data are the only modalities that suffer from this issue.” – why are images and image-like data the only modalities? Is there a “not” missing? Otherwise this seems to be an odd claim. The class-defining subject of a sequence could be at multiple positions in the sequence. The class of a graph can be defined by two subgraphs that are far from one another.
Q5. “Additionally, this limitation on images may be avoided by diffusing in latent space.” – is there evidence for this claim?
Thank you for taking the time to review and provide valuable feedback! We have answered all the questions and comments in the following.
Comparison of generative performance and use of FID (W3, Q1, Q2)
In our manuscript, we quantified the generative performance of branched vs linear (i.e. traditional) diffusion models over three datasets. For MNIST, we saw better performance in the branched model for 10/10 classes. For letters, our branched model outperformed in 24/26 classes. For RNA-seq, our branched model outperformed in 8/9 classes. Given these results, we consider it reasonable to claim “in general, the branched diffusion models achieved similar or better generative performance compared to the current state-of-the-art label-guided strategy”, and that “in many cases, the branched models outperformed the label-guided models” (quoted from Section 3.1). Note that our claims here are certainly within the context of our experiments, and we are not attempting to claim that branched diffusion models always offer improved generative performance over their linear counterparts in all situations. For improved clarity, we will modify these sentences as follows: “In our experiments, the branched diffusion models generally achieved similar or better generative performance…”
Specifically for the RNA-seq dataset, below we reproduce the FID values for each class. Due to the high variation in FID values between the classes (likely because some cell types are inherently harder to generate than others), differences in these values were not obviously rendered in Supplementary Figure S3.
| Cell type | Branched | Label-guided |
|---|---|---|
| CD16+ NK | 377.254 | 376.300 |
| Cl. Mono. | 163.470 | 166.414 |
| Late Eryth. | 150.605 | 153.134 |
| Macroph. | 151.118 | 156.442 |
| Megakar. | 64.968 | 65.506 |
| Mem. B | 1604.477 | 1610.171 |
| NK | 84.447 | 85.517 |
| Plasmabl. | 317.073 | 317.584 |
| Tem/Eff. H. T | 174.465 | 175.692 |
Finally, we agree that every metric for performance (including FID) has its advantages and disadvantages. In our analyses, we believe that FID is sufficient to capture the global trends in generative performance that we wish to understand. In addition to the analysis of FID, we also verified the generative performance of our branched diffusion models using other methods (Supplementary Figures S1–S2).
Importantly, however, we emphasize that the main contributions of our work lie beyond generative performance, and we simply wish to ensure that branched diffusion models are not significantly worse in performance compared to their linear counterparts. We believe that our analyses in this regard are sufficient to demonstrate this claim.
Application to scientific domains and datasets outside of images (W1)
The main contribution of our work is intended for scientific applications, whose datasets are typically characterized by a large amount of internal structure (e.g. cell types, chemical classes, structural folds, etc.). As such, although we demonstrated our method’s soundness on well-known image benchmarks such as MNIST, we also focused much of our work on real-world large scientific datasets such as single-cell RNA-seq and drug-like molecules, where we showed the promise of branched diffusion to perform real scientific discovery by recovering known biology.
Notably, our proposed method is very much dependent on datasets with internal structure. Datasets which do not have inherent structure (or only very weak structure) between classes are not suitable for branched diffusion, whose core hierarchical backbone reflects the intrinsic similarities between classes in the dataset. We will modify the language in our manuscript to better clarify this intended use case of our method.
Defining thresholds and robustness to thresholds (W2)
We found that in general, our selection of ε was not too difficult, as there is a range of possible values which lead to branch points that are not all too close to t = 0 or t = T. Additionally, we also found that variation in the branch points (resulting from randomness in the branch-point discovery procedure) still yielded similar results (Supplementary Figure S5).
To further demonstrate this, we present novel results on the robustness of our method to different values of ε in the updated version of our manuscript (Supplementary Figure S6). On the MNIST dataset, we swept through a range of possible values of ε, uniformly distributed in logarithmic space (a range which includes the value of ε used in our main results). We computed branch points for each value of ε and trained a branched diffusion model on each. The two largest values of ε yielded class-specific branches of length 0, and were therefore discarded from this analysis. Supplementary Figure S6 shows that the branch points identified from our various values of ε are very similar to the distribution of branch points identified in Supplementary Figure S5 (variation due to randomness in sampling and in diffusion). This suggests that our branch-point discovery algorithm is rather robust to various values of ε, provided obviously inappropriate values (such as those which yield length-0 terminal branches) are not considered. Furthermore, models trained with these branch points still yielded generative performance which was largely robust to ε.
Additional continual-learning experiments with larger models (Q3)
The core benefit of branched diffusion models for continual learning is that the entire existing model can be frozen, and only a single new branch is added for a new class. Thus, in the exploration of this continual-learning benefit, we expect no difference in difficulty if we were to start with a larger existing model or a smaller one. Specifically for the continual-learning experiments, we opted to use smaller models simply for efficiency in training/loading. Regarding class extension when a highly similar class is already present (e.g. the model has already been trained to generate 1s and we are now introducing 7s for the first time), we would expect branched models to perform even better (or at least, certainly not worse) in these experiments, while we expect label-guided models to suffer from the same issues.
Additionally, we opted to show the result of adding a new data class which has never before been seen by the model (instead of removing an existing class and adding it back), as we believe this better simulates how these scientific models may be deployed in practice (i.e. new data coming in from experiments on new cell types or systems).
To further show how our method performs in larger continual-learning experiments, we present novel results performing multi-step continual-learning where we extended our MNIST models to three never-before-seen classes one by one: 7s, 1s, then 2s (a mixture of classes which are similar to or distinct from existing classes in the model) (Supplementary Figure S7).
As expected, our new results highlight the same trend as before: branched models are easily extended with an efficient fine-tuning step, where the performance of previous classes is untouched upon the introduction of a new data class. In label-guided (linear) models, fine-tuning on new data causes catastrophic forgetting. It is also worth noting that due to the reliance of label-guided models on class embeddings, typical neural-network architectures are unable to accommodate never-before-seen labels unless they are explicitly initialized to allow for extra class labels; in our experiments on continual learning, we needed to separately train a label-guided model on fewer classes, but with sufficient extra capacity ("embedding slots") to accommodate future classes. In contrast, branched models have no such limitation, highlighting an additional advantage of branched models in continual learning.
Centering problem for images (Q4)
Images (and image-like data) are unique in the sense that there is a very well-defined alignment of pixels between different images and the class of an image is largely translationally invariant with respect to the subject. Importantly, typical neural-network architectures which learn to generate images (including the U-Net architectures used for image diffusion models) rely heavily on this property. That is, for the network to generalize to generating images where the subject is in the top left of the image, it needs to see images in the training set where the subject is in the top left (this is true for traditional linear diffusion models, as well).
Because branched diffusion models explicitly model the subdistribution of data distinctly for each class (whereas the separation of class-specific distributions is more implicit in a label-guided model), we conjecture that with limited training data, branched models may have a harder time generalizing to generate different classes. We also conjecture that this issue may be bypassed by training in latent space, or with different image-generating architectures which more elegantly encode translational equivariance.
For other data types such as graphs, although it is true that subgraphs which define a class may be in distinct areas of a graph, graphs do not have a well-defined alignment with each other which is relied upon by the neural network. Instead, typical graph-generating neural networks effectively treat two graphs identically even if they are rotated or rearranged in space, as long as the connectivity remains the same. Thus, the issue of “centering” does not exist for such data types (including graphs, tabular data, etc.).
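As a toy illustration of this point (our own example, not from the paper): any permutation-invariant graph summary, such as the sorted degree sequence, is unchanged when the nodes of a graph are relabeled, so there is no notion of a graph being "off-center" in the way an image subject can be:

```python
import numpy as np

def invariant_readout(adj):
    """A trivially permutation-invariant graph summary: the sorted degree
    sequence. Relabeling (permuting) the nodes leaves it unchanged,
    illustrating why graph models need not rely on any canonical node
    alignment, unlike pixel grids in images."""
    return np.sort(adj.sum(axis=0))
```

Typical graph-generating networks are built from operations with exactly this property, which is why the "centering" issue discussed above is specific to images and image-like data.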
We plan to extend the discussion on this point in the revised version of the manuscript.
Diffusion in latent space for images (Q5)
To clarify, our intended meaning for this sentence was that we expect this issue for images may be avoided by diffusing in latent space (i.e. “may” here being an expression of possibility rather than assured potential), although we leave exploration of this for future work, particularly because our contribution is meant for general scientific data (including RNA-seq, molecular graphs, etc.) rather than being focused on images. We will clarify and reword this in the revised version.
That said, our RNA-seq diffusion models were trained in latent space, so we already have some evidence that supports the use of latent-space representations for branched-diffusion-model training.
Thank you for the detailed response to my comments. The majority of my questions and identified weaknesses have been addressed by the response. I am currently reflecting on the revised paper, the reviews, and responses, and will potentially raise my recommendation.
W2: Thank you for providing the additional results demonstrating the robustness of the approach for selecting the threshold. The results support the argument that the choice of threshold doesn’t overly matter, provided it does not lead to terminal branches of length 0. The initial concern I expressed in my review was that “simply choosing a value where the branch points are not all too close to t = 0 or t = T” does not make clear how “too close” is defined. For reproducibility, it would be helpful to provide a concrete process for the selection of \epsilon, or at least describe how it was done to produce the results in the paper.
Q3: I am somewhat surprised that the decision to use smaller models to explore continual learning was based on considerations about the time required for training. The image datasets analyzed in the paper are already small. If there is a need to reduce them even further to enable training, then it suggests that the training burden for the approach is very high. The multi-step experiments produce some impressive results and are a good addition to the paper. Was there another reason not to provide results for a larger initial model? The claim “we expect no difference in difficulty if we were to start with a larger existing model or a smaller one” doesn’t seem to be supported by evidence, and it seems like a simple enough experiment. Does training a 10-class model take a very long amount of time?
We sincerely appreciate the time taken to take our new results and response into consideration and reflection! Here, we will address these remaining questions.
W2
Our selection of ε was done by first examining the similarity between pairs of classes throughout the diffusion timeline from t = 0 to t = T. Given our rough goal of having branch points not too close to t = 0 or t = T, we considered the similarity between pairs of classes at t = T/2. We then simply selected our value of ε to be the average similarity over pairs of identical classes (i.e. the average similarity between random partitions of data from the same class), minus the average similarity over pairs of different classes, rounded to one significant digit for simplicity. This difference measures the expected gap in similarity between identical versus distinct classes at the midpoint of forward diffusion. We found that this procedure was sufficient to ensure that the branch points were not all too close to t = 0 or t = T.
Of course, our results in Supplementary Figure S6 also show that the branch points (and downstream performance) are fairly robust to the selection of ε anyway. We agree that this procedural detail could be very helpful in justifying our selection of ε (even if performance is not particularly sensitive to this hyperparameter), and we will update our manuscript accordingly.
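Under our reading of this procedure, the heuristic can be sketched as follows (a hypothetical illustration of the described steps, not the authors' code; `choose_epsilon` and its inputs are our own names):

```python
import numpy as np

def choose_epsilon(same_class_sims, diff_class_sims):
    """Select ε as the gap between within-class and between-class similarity
    at the diffusion midpoint t = T/2, rounded to one significant digit.

    same_class_sims: similarities between random partitions of the same class
    diff_class_sims: similarities between pairs of different classes
    """
    gap = float(np.mean(same_class_sims) - np.mean(diff_class_sims))
    if gap == 0.0:
        return 0.0
    # Round to one significant digit
    exponent = int(np.floor(np.log10(abs(gap))))
    return float(np.round(gap, -exponent))
```

A larger gap means classes are still clearly distinguishable at mid-diffusion, which pushes branch points later; the rounding simply keeps ε a convenient round number.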
Q3
The decision to start with a model of fewer classes was motivated by two reasons: 1) sheer convenience of faster training time and data loading (particularly during development and testing); and 2) we needed to hold out never-before-seen classes to introduce for our continual-learning experiments (i.e. if we had trained a model on all available classes, then we would not have new classes to add for our experiments).
Our expectation that there is no increase in difficulty when starting with a larger model versus a smaller model is firmly rooted in the following observations: 1) when we add a new data class to a branched diffusion model, the weights in the existing architecture are frozen so that catastrophic forgetting is impossible; and 2) learning to generate a new class effectively reduces to fine-tuning a standard diffusion model, where the early layers of the neural network are pre-trained, and the late diffusion times are transferred from similar data. Note that this process of freezing early layers and only training on early diffusion times (for a single class) is also what allows continual learning in a branched model to be much faster than in a linear model.
In order to support our expectation that there is no increase in difficulty for larger models, we note that our experiments in Supplementary Figure S7 showed that the continual-learning benefit of branched diffusion models remained equally effective and performant as more classes were introduced and the size of the model increased. Furthermore, we also showed continual-learning results on branched diffusion models trained on our much larger dataset of single-cell RNA-seq (Figure 2).
Finally, to directly answer the final question, training a 10-class model does not take a prohibitively large amount of time, and this was never a limitation for us. For example, to generate Supplementary Figures S5—S6, we trained many branched diffusion models (on the full 10-class dataset) for 150 epochs, which was the same as the linear model.
Thank you for the additional clarification. Since all of my concerns have been addressed, and I think the paper makes a highly novel and promising contribution, I have raised my overall score to 8.
To the reviewers, we very much appreciate your time in providing valuable feedback.
In response to the initial reviews, we have conducted additional experiments and updated our manuscript with novel results. We have also responded in detail to all of the raised questions.
As the discussion period closes in two days on 22 November, we would like to confirm if our responses have sufficiently clarified and addressed any remaining concerns. We are happy to provide any additional clarification and discussion.
Thank you!
This paper proposes an extension of diffusion models for conditional generation using hierarchically branched diffusion. It builds on the idea that classes that are similar should at the earlier stages of the generative process not be distinguishable.
In general, the referees like the paper, but they also have some hesitations around some of the empirical findings relative to the motivation of the methodology.
The paper has a style quite different from most ICLR papers, with not a single equation in the main text. Appendix A contains most of the methodology, and digging into it, it becomes clear that the method is more ad hoc than it appears at face value. Therefore, the paper cannot be recommended for acceptance in its current form. What is important is to include equations detailing the modified log-likelihood lower-bound objective, as well as an ablation comparing the log likelihood against the same model without the branching process. This would make the paper substantially stronger and help the work have a bigger impact.
Why not a higher score
The formulation is too informal for ICLR and the benchmarking is not complete.
Why not a lower score
None.
Reject