GAMformer: In-Context Learning for Generalized Additive Models
GAMformer uses in-context learning to estimate Generalized Additive Models in a single forward pass, performing comparably to traditional iterative approaches like EBMs and spline-based GAMs while maintaining interpretability.
Abstract
Reviews and Discussion
The authors use prior-fitted networks to train a large transformer-type architecture that can then learn the shape functions of additive models in a single forward pass. The resulting model is competitive with other approaches of similar interpretability and capacity.
Strengths
- Originality: The idea of learning shape functions in-context is novel and interesting.
- Numerical experiments: The comparisons cover a variety of models from different classes, giving a broader picture of GAMformer's capabilities.
Weaknesses
Major
Contributions
- [C1] The claim that "experimental results demonstrate GAMformer's capacity to match the accuracy of leading GAMs" might be accurate regarding performance, but the interpretability has not been sufficiently scrutinized. See comments on Experiments below.
- [C2] In light of GAMs requiring no tuning (at least in the mgcv package), the claim "... to form shape functions ... eliminating the need for ... iterative learning and hyperparameter tuning" does not seem particularly significant.
- [C3] The contribution claiming the model was applied to the MIMIC-II dataset lacks significance. This dataset has been analyzed previously. The current study does not add any new insights. The dataset itself is also not particularly challenging, yet the modeling approach seems to have missed a key property of the dataset (see E3 below).
Technical soundness/correctness
- [T1] The introduction to GAMs is missing a distributional assumption (a GAM consists of both structural and distributional assumptions; see Wood, 2017, and the formulation sketched after this sub-list).
- [T2] The simulated functions are not GAMs but deterministic functions. As correctly noted by the authors, a GAM is defined by a link function, yet they use a simple indicator function without induced noise or distributional assumptions for the simulation, which does not correspond to the data-generating process of a GAM. The code uses sigmoid, but there is no distribution involved (same for the regression task).
- [T3] "Allocating bins based on the quantiles of the feature in the training dataset" → This approach is likely inferior to equidistant binning, as quantile-based binning alters the data distribution of the feature (see Li and Wood, 2017).
- [T4] The comparisons with mgcv::gam appear incomplete (see below).
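For reference, the standard formulation these points refer to (Wood, 2017) combines a structural and a distributional assumption, roughly: $y_i \mid \mathbf{x}_i \sim \mathrm{EF}(\mu_i, \phi)$ together with $g(\mu_i) = \beta_0 + \sum_j f_j(x_{ij})$, where $\mathrm{EF}$ denotes an exponential-family distribution, $g$ a link function and the $f_j$ smooth functions. Simulating only a deterministic transformation of $\sum_j f_j(x_{ij})$, as criticized in [T2], omits the distributional part.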
Significance
- [S1] The computational costs of:
  - fitting a GAM are low (see Wood, 2020), scaling with the number of features and the number of basis functions (in mgcv, often set to 10); for the data used by the authors, this would amount to 50-800 parameters;
  - predicting with a fitted GAM is similarly cheap.
  In contrast, ICL requires millions of parameters, if I understand Sec. 3.2 correctly, and even with fast inference it is slow compared to GAMs; the quadratic scaling in the number of data points in the transformer is still the bottleneck. Moreover, the authors report that the model required 25 days on a high-performance GPU, whereas all the analyzed datasets could be fit within seconds using GAMs. GAMs can also be applied to much larger datasets using mgcv::bam (see Wood et al., 2017).
- [S2] The method does not seem to outperform other models in prediction accuracy and appears to be inferior to TabPFN. TabPFN itself could also be analyzed using SHAP after computing the predictions, raising the question of whether a specific architecture is even necessary.
- [S3] I could not identify any other significant insights of a theoretical nature or similar that could be derived from a GAMformer. In particular, I would assume that the SCMs in TabPFN likely already cover GAM-type models (related to S2). This raises the question of what additional benefit is gained by making them explicit, as done here.
Experiments
- [E1] The experiments do not show the shape functions of the GAM method, which would be particularly useful for illustrative examples.
- [E2] Simulations should be designed to correspond to an actual GAM to see whether the GAMformer can actually recover those (see T2).
- [E2/T4] The mgcv::gam should not be inferior to logistic regression if used correctly (see Figure 6).
- [E3] Isn't there censoring in the MIMIC datasets? A time-to-event model might then be more appropriate in this case.
- [E4] In the Appendix experiments, the authors switch to pyGAM, which is known to be inferior to mgcv::gam, and do not report the latter's performance.
Reproducibility
- [R1] The code does not provide competitor models.
Writing
- [W1] There is some redundancy between Sections 1 and 2, which disrupts the flow of reading.
- [W2] The notation is somewhat confusing, as the same index is used both to index the bins and to denote the corresponding feature of the input vector.
Minor / Technical soundness
- [M1] The functions are typically referred to as smooth terms or smooth functions in the GAM literature, not shape functions (a term seemingly invented by the NAM community). They are also not partial dependence plots (as these are plots, not functions; in GAM literature, they are referred to as partial effects).
- [M2] The function typically does not map to all of $\mathbb{R}$ but to a subspace (more specifically, e.g., $(0,1)$ for the logistic function).
- [M3] What is in equation (2)?
- [M4] "Spline-based GAMs use the backfitting algorithm" Backfitting was proposed by Hastie and Tibshirani. More recent approaches, like those from Wood, use PIRLS or alternatives like INLA (see Wood, 2017).
- [M5] The citation in footnote 3 (and for the mgcv package in general) seems incorrect.
- [M6] No "shape functions" for pairwise smooth interactions are shown.
References
- Wood, 2017: https://www.taylorfrancis.com/books/mono/10.1201/9781315370279/generalized-additive-models-simon-wood
- Li and Wood, 2017: https://link.springer.com/article/10.1007/s11222-019-09864-2
- Wood et al., 2017: https://www.tandfonline.com/doi/full/10.1080/01621459.2016.1195744
- Wood, 2020: https://www.maths.ed.ac.uk/~swood34/test-gam.pdf
Suggestions for Improving the Paper
Here are some suggestions for improving the paper:
- Writing/Novelty/Significance: 1) Clearly articulate how GAMformer advances beyond current GAM implementations and/or what additional insights it provides for PFNs or in-context learning. If you consider not changing your listed contributions, provide a more rigorous theoretical comparison with GAMs (computational complexity, etc.). I would, however, suggest changing your argumentation and thinking about what other aspects a PFN-type model can provide that a GAM cannot.
- Technical Correctness and Clarity: Revise the GAM background to more formally introduce GAMs, and merge the redundant parts of Sections 1 and 2.
- Numerical Experiments: Consider 1) simulating datasets that follow a GAM data-generating process, 2) comparing against mgcv::gam in performance, and 3) showing shape functions also for GAMs.
- Application: If the authors' approach allows the inclusion of censoring, consider modifying your application. Alternatively, consider a different and more challenging dataset.
- Reproducibility: Include the code for competitor models.
Questions
- Q1: I would be very happy if the authors could address the weaknesses I have mentioned above
- Q2: Are there any insights of GAMformers that I might have missed?
- Q3: Have the authors thought about analyzing the smoothness of GAMformers and whether this could provide more interpretable functions compared to the jagged functions of NAMs / EBMs?
- Q4: Have the authors thought about extending the class of GAMs? I would assume that this model could also learn a combination of GAMs, trees, NODEs, etc., and still remain interpretable.
Thank you for your detailed comments. From [S1], [S2] and the suggested improvements, it seems that there is some misunderstanding of our method. Our method is not slower at inference than traditional GAMs, and it produces GAMs with a comparable number of parameters. It also provides an extremely high degree of introspection, while PFN-type models are completely opaque and extremely slow to predict.
Addressing your individual points:
- [C1] The parametric form of the model makes it intrinsically interpretable. Whether the shape functions allow for meaningful interpretation is rather subjective and hard to quantify. We show a wide variety of shape functions produced by GAMformer in the paper for subjective analysis, so it would be great if you could clarify point [E1]. The shape functions produced by GAMformer are similar to and share many of the properties of shape functions produced by EBMs and NAMs.
- [C2] We compare against mgcv, and given the result in Figure 6, it is clear that mgcv without tuning is not a competitive method and is statistically significantly worse than Logistic Regression.
- [C3] See T3. We follow a standard evaluation procedure to show that our methodology provides comparable insights to existing approaches.
- [T1] We use the definition of additive models of Hastie, Tibshirani and Friedman, whom we consider authoritative in the area.
- [T2] The reviewer is correct that the data-generating process for the meta-learning data used to train the transformer model does not correspond to the data-generating process of GAMs. The goal of the synthetic meta-learning data is to produce a very large number of realistic tabular learning problems without restricting them to be GAMs. The restriction to GAMs occurs in the architectural constraint applied to the neural-net architectures that the transformer can learn from the data. We believe this is similar in spirit to most experimental evaluations of GAMs, where real-world data is not restricted to having been generated by a GAM, but a GAM model is then learned from this data.
- [T3] Quantile binning is employed by EBMs and leading gradient-boosted regression tree implementations such as XGBoost, LightGBM and CatBoost, which are current gold-standard algorithms in machine learning. The binning obviously changes the distribution of the features, but given the immense practical success of this method, it is unclear why this should be considered a disadvantage. In practice, all of these methods (including ours) choose a binning that is fine enough not to cause significant loss of resolution but coarse enough to make learning computationally more efficient. Tree-based algorithms such as XGBoost, LightGBM, CatBoost and EBMs are not sensitive to the choice of interval- or quantile-based binning as long as the binning resolution is sufficient (a minimal sketch of quantile binning appears at the end of this list).
- [S1] There seems to be some confusion about the training, meta-training and prediction phases. State-of-the-art GAM methods usually employ histogram-based lookup tables for prediction, which is the method that GAMformer uses. The inference speed of GAMformer is not slower than that of other GAMs; the opposite conclusion likely stems from a misunderstanding of the method (a minimal sketch of lookup-table prediction appears at the end of this list). Similarly, the number of parameters in the predicted GAM models is bins * classes * features for multiclass classification and bins * features for binary classification, which is similar to the range of 50-800 parameters in a fitted spline-based model that you suggest (though the minimum in our implementation would be 64 for a single feature and two classes). The millions of parameters are not fit to any individual dataset, but learned during meta-learning, which corresponds to the development of the algorithm in classical ML. Indeed, all the datasets studied in the table can be fit within seconds using GAMformer. The 25 days of GPU training should be compared to the time it took to implement mgcv, not to the inference time.
- [S2] As mentioned above, it’s not immediately clear how to analyze TabPFN with SHAP. This is likely computationally infeasible as TabPFN does not produce a model; also, it would only result in a post-hoc explanation, not a glass-box model.
- [S3] This statement is somewhat confusing. The SCMs are the data-generating process used for training TabPFN, and we indeed assume they cover GAM-type models as we re-use this prior. However, TabPFN does not recover the SCM, and so the predictions made by TabPFN are completely opaque. The benefit of making the additive model explicit is that the user has access to the model structure, which is not possible with TabPFN.
- [E1] Could you please clarify? Figure 2, Figure 4, Figure 7, Figure 10, Figure 11 and Figure 12 compare the results of GAMformer with a baseline SOTA GAM model, EBMs.
- [E2/T4] This seems to contradict [C2], and it’s unclear on what basis this claim is made. Either mgcv::gam requires tuning to perform reasonably, or it does not. We did not perform tuning and it was outperformed to a statistically significant degree by logistic regression. If it requires tuning (“using it correctly” as you say) to outperform logistic regression, then [C2] does not seem valid.
- [E3] The MIMIC datasets are censored with respect to some outcomes. This is common for many medical datasets and outcomes. Although some models trained on the MIMIC datasets are, as the reviewer suggests, time-to-event models, most are not, and training “standard” classification models on the MIMIC data as we do is very common.
- [Q2] Yes, please see above.
- [Q3] This would indeed be interesting; however, we are not familiar with established ways to measure function smoothness for GAM models.
- [Q4] This is somewhat confusing; it’s unclear to us how a more complex combination of models could still be interpretable. While one of the interesting aspects of this work is to show that it is possible to learn specific parametric forms with a meta-learning approach, and this approach could translate to other model families, it’s unclear how these models could still be interpretable. We specifically restricted the model class to GAMs for this very purpose.
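To make the binning ([T3]) and lookup-table prediction ([S1]) points above concrete, here is a minimal NumPy sketch. It is purely illustrative: the function and variable names are ours rather than GAMformer's actual code, and in our method the shape_values table would come from the transformer's forward pass rather than being hand-specified.

import numpy as np

def quantile_bin_edges(X_train, n_bins=64):
    # Per-feature bin edges taken from training-set quantiles ([T3]).
    qs = np.linspace(0.0, 1.0, n_bins + 1)
    return [np.quantile(X_train[:, j], qs) for j in range(X_train.shape[1])]

def predict_binned_gam(X, edges, shape_values, bias=0.0):
    # shape_values: (n_features, n_bins) lookup table of per-bin shape-function
    # values; prediction is just table lookups plus a sum ([S1]).
    n_bins = shape_values.shape[1]
    logits = np.full(X.shape[0], bias)
    for j, e in enumerate(edges):
        idx = np.clip(np.searchsorted(e, X[:, j], side="right") - 1, 0, n_bins - 1)
        logits += shape_values[j, idx]  # additive contribution of feature j
    return 1.0 / (1.0 + np.exp(-logits))  # binary case: sigmoid of the summed effects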
Dear Authors,
Thank you for your detailed comments. Some answers helped resolve points that were unclear to me. Below is my response to the rest:
[C3] See T3. We follow a standard evaluation procedure to show that our methodology provides comparable insights to existing approaches.
I am not sure how a standard application with a commonly used dataset can be a paper contribution. I agree that it does a good job as a sanity check, showing that your method works similarly to other methods. But this belongs to your second contribution point ("experimental results demonstrate ...").
We compare against mgcv, and given the result in Figure 6, it is clear that mgcv without tuning is not a competitive method and is statistically significantly worse than Logistic Regression
If a logistic regression is better than mgcv, I believe something went wrong, as I tried to elaborate in my review. To get to the bottom of this: what do you mean by "mgcv without tuning"? It could also help if you provide the code.
Furthermore, I still don't agree with the sentence "... to form shape functions ... eliminating the need for ... iterative learning and hyperparameter tuning". mgcv implements various approaches that do not require separate tuning, e.g., AIC-based selection of the smoothing parameter. See also E2/T4 response below.
The reviewer is correct that the data generating process for the meta learning data used to train the transformer model
I think there might have been a misunderstanding. I was talking about the data generation for your illustrative example, not the pre-training. Can you comment on this?
[S1] There seems to be a confusion about training, meta-training and prediction phases.
I was trying to say that I do not see the benefit in your method when accounting for its complexity.
- The inference is as expensive as for the GAM, correct?
- The training is much more expensive, correct? (on a factual level, not philosophically)
- Will it be significantly different from a fitted GAM in prediction performance? I don't think so, since GAMs have certain optimality properties.
- Is there any way I can trust your method's inference? It doesn't come with confidence intervals that have theoretical guarantees as for GAMs.
- Furthermore, once you have done the basis transformation, a GAM is a linear model. You don't need to train a transformer to learn a linear model; this can be done in a single nn.Linear layer (to be fair, you could benefit from meta-learning, but there are theoretical limits to this in linear models as well); see the sketch after this list.
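To make this point concrete, here is a minimal PyTorch sketch of the reviewer's argument (illustrative only: a Gaussian-bump basis stands in for a B-spline basis, and the weight decay stands in for a ridge/smoothness penalty):

import torch
import torch.nn as nn

def basis_expand(x, centers, width=0.5):
    # Evaluate a simple Gaussian-bump basis per feature: (n, d) -> (n, d * k).
    diffs = x.unsqueeze(-1) - centers          # (n, d, k)
    return torch.exp(-(diffs / width) ** 2).flatten(1)

n, d, k = 256, 3, 10
x = torch.rand(n, d)
y = (torch.sin(6 * x[:, 0]) + x[:, 1] ** 2 + 0.1 * torch.randn(n)).unsqueeze(1)

centers = torch.linspace(0, 1, k)              # shared knot/centre grid for each feature
phi = basis_expand(x, centers)                 # fixed basis transformation of the inputs

model = nn.Linear(d * k, 1)                    # after the transformation, the GAM is linear
opt = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-3)
for _ in range(500):
    opt.zero_grad()
    nn.functional.mse_loss(model(phi), y).backward()
    opt.step()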
As mentioned above, it’s not immediately clear how to analyze TabPFN with SHAP. This is likely computationally infeasible as TabPFN does not produce a model;
Maybe I am not too familiar with SHAP, but wouldn't it be possible to use any other XAI method that works on the level of predictions such as partial prediction plots?
also, it would only result in a post-hoc explanation, not a glass-box model.
But a million-parameter model is not a glass-box either, is it? Just because an LLM does have a linear head that we can interpret doesn't mean we understand its reasoning or uncertainty.
[S3] This statement is somewhat confusing.
To rephrase: If I had access to a post-hoc interpretability method like the one mentioned above, would I find that TabPFN is also learning a GAM if the data-generating process of the data on which we want to predict is also a GAM?
[E1] Could you please clarify?
I see shape functions for GAMformer and EBM. But I was talking about the shape functions of a simple GAM model (in particular mgcv, because some strange things are happening in pyGAM and statsmodels).
[E2 /T4] This seems to contradict [C2], and it’s unclear on what basis this claim is made. Either mgcv::gam requires tuning to perform reasonably, or it does not. We did not perform tuning and it was outperformed to a statistically significant degree by logistic regression. If it requires tuning (“using it correctly” as you say) to outperform logistic regression, then [C2] does not seem valid.
It does not contradict. I think we are just talking about different things when using the word "tuning".
- mgcv uses a criterion to define the optimal smoothing parameter (such as the AIC). If you want to call this "tuning", so be it. But this happens out of the box; no additional data is required, and you do not need to set any parameters.
- If you "tuned" mgcv (meaning you fiddled around with the parameters), then you likely either did something wrong or something more advanced, but unnecessary in most practical cases.
[Q4] This is somewhat confusing
Using a fANOVA approach, for example, as done here.
Due to the deadline extension, it would be interesting if the authors could also comment on my other points.
The training is much more expensive, correct? (on a factual level, not philosophically)
The training of a per-dataset model, i.e. the in-context learning is not more expensive. It is, as we noted, not scalable to large datasets given the current attention mechanism, but on smaller datasets the complexity is comparable to additive models if not faster, as it is a single forward pass in the transformer. Learning the transformer is a meta-learning step that, as mentioned before, could be conceptually compared to tuning the specific GAM implementation in mgcv::gam.
Furthermore, once you have done the basis transformation, a GAM is a linear model. You don't need to train a transformer to learn a linear model, this can be done in a single nn.Linear layer (to be fair, you could benefit from meta-learning, but there are theoretical limits to this in linear models as well).
With the representation we chose, directly optimizing a linear model does not lead to an accurate result, as the binned representation does not add any smoothness constraints. The point of learning the transformer for this model is that it learns to regularize based on the data distribution seen during training. That is the fundamental contribution of this work: learning how to infer an additive model that optimally generalizes given a distribution over training datasets in the form of the synthetic prior. It would be interesting to show a direct comparison between a linear model learned on our encoding and the GAMformer model; I'm not sure we will be able to produce this in time for the rebuttal, though. With any other GAM, the regularization is hard-coded in the basis function or in the regularization of the basis coefficients (see the sketch below), while we learn the optimal regularization for generalization (such as smoothness or dealing with correlated variables) from scratch from the training data.
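For reference, the kind of hard-coded regularization referred to here is, for example, a difference penalty on adjacent basis or bin coefficients (a P-spline-style smoothness penalty). A minimal sketch, with names of our own choosing:

import torch

def second_difference_penalty(coefs, lam=1.0):
    # coefs: per-bin (or per-basis) shape-function values for one feature.
    # Penalizing second differences hard-codes a preference for smooth shapes;
    # GAMformer instead has to learn an equivalent inductive bias during meta-training.
    d2 = coefs[2:] - 2 * coefs[1:-1] + coefs[:-2]
    return lam * (d2 ** 2).sum()

Minimizing a data-fit loss plus such a penalty gives a classical smoothed fit of the binned coefficients; dropping it corresponds to the unregularized linear model discussed above.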
Maybe I am not too familiar with SHAP, but wouldn't it be possible to use any other XAI method that works on the level of predictions such as partial prediction plots?
Again, as TabPFN does not produce a model, these would be extremely expensive. I highly encourage you to read the TabPFN paper in detail to understand what would be necessary to apply partial prediction plots to TabPFN.
But a million-parameter model is not a glass-box either, is it? Just because an LLM does have a linear head that we can interpret doesn't mean we understand its reasoning or uncertainty.
The models produced by GAMformer are not million-parameter models. The function that predicts for a given dataset is a compact GAM model. This is unlike a linear head on an LLM, which uses the latent representation computed with a transformer model. We are producing essentially a linear layer (after binning) that is applied to the original input; that is what makes it interpretable.
But I was talking about the shape functions of a simple GAM model
We understand this as including a spline-based GAM model? Unfortunately, we likely won't have time to produce these graphs, even with the extended deadline, but it would indeed be an interesting addition. We were unable to obtain competitive results from mgcv::gam, which made the comparison of less interest to us.
mgcv uses a criterion to define the optimal smoothing parameter (such as the AIC). If you want to call this "tuning", so be it. But this happens out of the box, no additional data is required. You do not need to set any parameters.
We did use the default parameters and achieved bad results. We will share the code for you to confirm our results. From these results, we draw the conclusion that the default parameters do not yield good results, and manual intervention is required.
Due to the deadline extension, it would be interesting if the authors could also comment on my other points.
Were there specific other points that you wanted us to address that we didn't address above?
It would be great to make sure you understand the architectural difference between this work and TabPFN and why this model is more interpretable than TabPFN. From the questions above, it seems these structural differences, which are the core contribution of our work, are not clear.
Dear Authors,
Thank you for the quick response. Please find my answers below:
The training of a per-dataset model, i.e. the in-context learning is not more expensive.
I think this is a misunderstanding. I was talking about the (pre-)training time. If you take the number of FLOPs an A100 can process in 25 days, then you could fit around 400 million GAM models. I understand that you learn (very fast) in context. But that was not my point.
Learning the transformer is a meta-learning step that, as mentioned before, could be conceptually compared to tuning the specific GAM implementation in mgcv::gam.
If I understand correctly what you refer to as "tuning": in mgcv this does not take more than a second for the sample sizes and the typical number of features you used in your experiments. Hence my confusion about the quote cited in [C2].
With the representation we choose, directly optimizing a linear model does not lead to an accurate result, as the binned representation does not add any smoothness constraints.
I think you might have misunderstood. What I meant is: If you would train a GAM with basis functions (such as a B-spline basis), then this effectively becomes a (ridge-regularized) linear model. I am not saying your GAMformer does that.
With any other GAM, the regularization is hard-coded in the basis function or the regularization of the basis coefficients
What is the downside of hard-coded regularization? I might need to choose a basis with a suitable null space, but that is all I have to do.
while we learn the optimal regularization for generalization
Is there any theoretical evidence that supports this claim? There is one for GAMs (however, only for the given dataset).
We did use the default parameters and achieved bad results. We will share the code for you to confirm our results. From these results, we draw the conclusion that the default parameters do not yield good results, and manual intervention is required.
I am a bit confused. How can you achieve bad results, and then --- with manual intervention --- perform worse than a logistic regression model?
Were there specific other points that you wanted us to address that we didn't address above?
[E4], [R1], whether the [W]-parts make sense, and all the comments from my response I wrote before (25 Nov 2024, 14:20 ET) that have not been addressed (in particular the "tuning" question).
It would be great to make sure you understand the architectural difference between this work and TabPFN and why this model is more interpretable than TabPFN. From the questions above, it seems these structural differences, which are the core contribution of our work, are not clear.
I think one way forward to avoid confusion could be to improve Figure 1 and its caption.
However, understanding your architecture is not the point that I am struggling with. With my review,
- I am providing you pointers and constructive feedback to improve your manuscript (such as your list of contributions) --- something that so far the authors do not seem to value much
- and asking you
  - to provide evidence that your in-context learning can recover a GAM model,
  - what benefits there are if inference is as expensive as for GAMs, (pre-)training is much more expensive (on the level of 400 million GAMs), prediction performance is likely the same, and I don't know whether I can trust your method.
First, we would like to thank reviewer DRNQ for going the extra mile to understand our work and make suggestions about how to improve it. This kind of thorough reviewing is not so common these days, and we really appreciate it. Thank you. We also appreciate that you clearly articulated your concerns.
We will try to provide empirical evaluation for recovering a GAM and details on the evaluation procedure for mgcv after the (US) holiday weekend.
Although we did not discuss how to generate error bars with GAMformer, the method we would use is the same bootstrap method used in EBMs and NAMs. Specifically, we would form multiple bootstrap samples of the data, use GAMformer to generate a GAM for each bootstrap sample, and then return the mean and variance (or confidence interval) at each point on the learned shape function. The speed of GAMformer's forward pass makes this even more efficient than it is with methods such as EBMs or NAMs, which need to iteratively refit a model for each bootstrap sample. We will add a short paragraph describing this method for generating confidence intervals to the final draft, as well as add confidence intervals to the shape functions in the figures.
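A minimal sketch of that bootstrap procedure, assuming a generic fit_gam(X, y) that returns a per-feature, per-bin table of shape-function values (a hypothetical stand-in for GAMformer's forward pass; this is not the paper's actual API):

import numpy as np

def bootstrap_shape_functions(X, y, fit_gam, n_boot=30, seed=0):
    # fit_gam(X, y) -> array of shape (n_features, n_bins) with binned shape values.
    rng = np.random.default_rng(seed)
    tables = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(X), size=len(X))       # resample rows with replacement
        tables.append(fit_gam(X[idx], y[idx]))
    tables = np.stack(tables)                             # (n_boot, n_features, n_bins)
    mean = tables.mean(axis=0)
    lo, hi = np.percentile(tables, [2.5, 97.5], axis=0)   # pointwise 95% interval per bin
    return mean, lo, hi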
As you suggest, GAMs formed by GAMformer have a similar computational cost at prediction time to GAMs learned with other algorithms, so no win or loss there. And while the cost of the forward pass in GAMformer that generates a GAM is likely less than the cost of the iterative optimization required by algorithms such as EBMs or NAMs, this is traded off against a very large meta cost to train the GAMformer transformer in the first place. This tradeoff, however, is not as bad as it might seem. First, the GAMformer transformer model need only be trained once, and it can then be publicly served or distributed to efficiently generate many GAMs for many users and datasets. And we can already see above how this kind of efficiency could have utility for things like bootstrap analysis and exploratory modeling.
Unfortunately, we were not able to obtain conclusive results for recovering GAM models from data generated by a GAM with a binomial distribution in time for the rebuttal.
Regarding the training using mgcv, we used the following code for training and prediction:

library(mgcv)  # gam()
library(nnet)  # multinom() for the multiclass case
library(pROC)  # roc(), multiclass.roc(), auc()

# Build the model formula from the predictor names: target ~ x1 + x2 + ...
formula_str <- paste(target_variable, "~", paste(predictor_variables, collapse = " + "))
formula <- as.formula(formula_str)

if (num_classes == 2) {
  # Binary classification: use gam with binomial family
  model <- gam(formula, data = train_data, family = binomial)
  test_probs <- predict(model, newdata = test_data, type = "response")
  positive_class <- levels(data[[target_variable]])[1]
  binary_labels <- as.numeric(test_data[[target_variable]] == positive_class)
  roc_curve <- roc(binary_labels, test_probs)
  auc_values <- c(auc_values, auc(roc_curve))
} else {
  # Multiclass: multinomial logistic regression via nnet::multinom
  model <- multinom(formula, data = train_data)
  test_probs <- predict(model, newdata = test_data, type = "probs")
  multiclass_roc <- multiclass.roc(test_data[[target_variable]], test_probs)
  auc_values <- c(auc_values, as.numeric(auc(multiclass_roc)))
}
The paper addresses the problem of supervised learning for tabular data and proposes a solution based on generalized additive models. A key feature is the use of an attention-based neural network (Transformer) to process the training data and provide a prior over the parameters of the non-linear predictive functions. The learning process involves splitting the training data into a training set and a holdout set. A predictive likelihood over the holdout set is used to learn the prior based on the training set. Experiments on synthetic data and OpenML datasets are conducted to compare the proposed solution with explainable boosting machines, demonstrating its ability to achieve comparable predictive performance.
Strengths
- Clarity: The paper is clear and well-written.
- Relevance: The paper addresses an important and relevant problem, specifically how to leverage deep learning to learn a prior for predictive tasks on tabular data.
- Code availability: The code is available, and a Jupyter Notebook is provided to demonstrate how the proposed model and explainable boosting machines generate the predictive functions. However, no checks have been performed to verify the reproducibility of the experiments.
Weaknesses
- Novelty: The novelty of the paper is limited and incremental. The main ideas have already appeared in two previous works [1,2], and the primary difference seems to be the use of a different classifier/regressor. In other words, instead of considering Bayesian neural networks or structural causal models as in [1,2], the authors focus on generalised additive models. In essence, the work can be seen as an application of existing ideas within the context of generalised additive models.
- Soundness: There are several vague and overstated claims that are not properly supported. For instance, the abstract mentions that the proposed solution generates highly interpretable predictive functions. However, this is also true for the competitors, and it is unclear what the real advantage of the proposed solution is over existing generalised additive models and other interpretable models (such as XGBoost). In the experiments (e.g., lines 304-305), it is stated that the proposed solution outperforms explainable boosting machines (EBMs), but these claims seem exaggerated. Firstly, in the low-data regime (32 samples) with a larger number of features (64), the proposed solution clearly underperforms compared to EBMs by 14 points, suggesting a possible blind spot and indicating that sufficient data is required for the proposed solution to perform on par. Secondly, it is unclear whether the differences in the results are statistically significant, as no standard deviation is provided. Similarly, for Figures 2 and 3, it is claimed that the proposed solution clearly learns smoother predictive functions. However, this is subjective and not consistently true (only the 1st and 3rd plots in Figure 2 support the authors' claim).
- Quality: The experimental analysis lacks a consistent comparison across datasets and tasks with other interpretable models. Additionally, the analysis focuses on the case where the ground truth classifier lies within the hypothesis space. What about the agnostic case?
- Quality/Significance: The experiments are conducted on small datasets, reflecting the poor scalability of the approach. While the idea of synthesizing data may be reasonable for small datasets, it may not be tractable or feasible for higher-dimensional data, given the potential for combinatorial explosion. Scalability and feasibility are currently overlooked, which is a significant limitation of the proposed solution. As a result, it is unclear why one should prefer this approach over existing interpretable models that are more scalable.
References
[1] Müller, Hollmann, Arango, Grabocka, Hutter. “Transformers can do Bayesian Inference”. ICLR 2022
[2] Hollmann, Müller, Eggensperger, Hutter. “TabPFN: A Transformer that Solves Small Tabular Classification Problems in A Second”. ICLR 2023
Questions
Please, refer to the main weaknesses.
Thank you for your comments and suggestions.
Novelty
As mentioned in the shared response, we think there is substantial novelty in producing a parametric, interpretable model with efficient inference via in-context learning, which TabPFN does not.
Soundness
Could you please elaborate on how XGBoost is an interpretable model? Gradient boosting models are generally considered to be black-box models that at most allow post-hoc interpretation. It is not the point of the paper to claim that we are outperforming EBMs across the board, and we are happy to adjust the phrasing. Rather, we want to claim that it is possible to create competitive additive models using in-context learning; this work is meant as a proof of concept of this idea, and we do not expect it to immediately replace existing solutions.
Thank you for the clarifications, but my concerns remain unaddressed. Hence, I keep my score.
Could you please clarify which of your concerns remain unaddressed? The main idea of the work, producing parametric, interpretable models with efficient inference via in-context learning, has not been previously investigated, so we would like to understand your concern about novelty better.
We would also like to understand your concern about quality better. The real-world datasets that we evaluate on have no ground-truth classifier, and it is unlikely that the Bayes estimators for these datasets all lie within the class of generalized additive models as we parametrize it. Therefore, nearly all our results address the agnostic case.
The paper proposes an in-context learning approach for learning generalized additive models for tabular data building on prior work (PFN and TabPFN) for tabular classification. The training procedure executes on synthetic data by sampling a random causal graph and generating data from an initial random sample. The data is split into training and test datasets to simulate inference. A transformer model applies attention across the data points and features and handles tabular data of varying sizes. A single forward pass of the transformer estimates the shape functions for the given in-context training data which are then applied to the test example. The shape functions themselves are represented as discrete functions which apply to discretized and binned features. The method is demonstrated experimentally on synthetic and real data including a mortality risk case study where the shape functions are used to interpret model predictions.
Strengths
The method appears to be a novel approach for learning generalized additive models.
The paper is well-written, ideas and goals are clearly stated, background work is acknowledged and limitations are addressed.
Experiments are done on synthetic and real examples with an extensive public health case study interpreting the learned shape functions and their implications.
The paper discusses the limitations of the model, which are 1) the lack of accounting for higher-order interactions, 2) the lack of improvement on datasets larger than those seen during training, and 3) the quadratic complexity of the transformer.
The paper also proposes an extension to model higher-order effects by concatenating the data and higher-order effects.
Weaknesses
The approach appears to be limited to discrete target values. Shape functions are learned as discretized functions over discretized features which could be limiting.
Questions
Do you only consider discrete target variables in the experiments? Given that the features are binned and discretized, could the method be applied for regression with continuous variables?
Thank you for your comments. The results in the paper are indeed limited to the classification setting; however, we are in the process of training a separate model for regression. We are not concerned about the discretization of input variables; we use the same discretization scheme that is commonly used with gradient-boosted models and EBMs, both of which provide state-of-the-art results for regression (see Grinsztajn et al., "Why do tree-based models still outperform deep learning on tabular data?"). In fact, the excellent performance of discretized gradient boosting algorithms was the motivation for this architecture.
The paper presents GAMformer, a novel model for fitting Generalized Additive Models (GAMs) using in-context learning (ICL) within a transformer-based framework. Unlike traditional GAMs that rely on iterative methods such as splines or gradient boosting, GAMformer uses a single forward pass to estimate shape functions for each feature, eliminating the need for hyperparameter tuning and iterative optimization. This approach is trained exclusively on synthetic data but performs competitively with existing GAMs on real-world tasks. GAMformer’s non-parametric, binned approach to shape function estimation enables high interpretability of feature impacts. Experimental results show that GAMformer matches or surpasses other interpretable machine learning methods on both synthetic and real-world tabular datasets, including clinical applications on the MIMIC dataset for ICU mortality prediction. Additionally, the model’s adaptability to real-world data demonstrates its potential for scalable, interpretable applications without extensive tuning.
Strengths
GAMformer is a contribution to GAMs, leveraging ICL and transformer models to eliminate iterative optimization, thereby simplifying the modeling process and reducing the computational overhead associated with traditional GAMs.
The model maintains high interpretability—crucial for critical fields like healthcare—while matching the performance of established methods like Explainable Boosting Machines (EBMs).
GAMformer’s training on synthetic data enables it to generalize to real-world data effectively, a challenging task for many models, especially in interpretability-driven applications.
The use of a non-parametric, binned representation for shape functions allows for flexibility, particularly for capturing discontinuities or sudden shifts in feature impacts.
The model was rigorously tested across various benchmark datasets, and a case study on ICU mortality in the MIMIC-II dataset demonstrated its clinical interpretability potential, which is well-aligned with the paper’s goals.
Weaknesses
GAMformer currently only supports main effects and second-order feature interactions, limiting its applicability for datasets where higher-order interactions are significant.
The Transformer architecture in GAMformer scales quadratically with the number of data points, leading to potential performance bottlenecks for very large datasets. Exploring scalable attention mechanisms, as the authors suggest, would strengthen the model’s practical use.
While the clinical case study is insightful, further empirical evaluations across diverse fields (e.g., finance, manufacturing) would provide a clearer picture of GAMformer’s interpretability and performance across different domains.
There is a lack of quantitative results tables comparing the model with recent baselines.
Questions
Please see the weaknesses section.
Thank you for your comments. Regarding the limitation to first- and second-order effects: for glassbox models, limiting to second-order interactions and main effects is a common approach, as higher-order interactions are usually not easy to interpret. It has been found that by modeling up to second-order interactions, GAMs can achieve state-of-the-art results on a wide variety of datasets (see Chang et al., "How Interpretable and Trustworthy are GAMs?"). While it would be possible to extend GAMformer to higher-order interactions, this is unlikely to be useful in a setting that requires interpretable models, and other interpretable GAM models such as EBMs and NAMs typically restrict models to main effects and pairwise interactions.
What recent baselines do you think are missing from our comparison?
I have read the author's response, but my concerns remain unresolved.
As a result, I will maintain my original score.
I'm sorry to hear you don't think your concerns are addressed. How can we address them?
We want to thank all the reviewers for their insightful comments. We want to address some broader comments before addressing individual reviews.
Novelty
While this work builds on the work of Mueller and Hollmann, it presents several novel ideas that are absent from that work. We are producing a parametric, interpretable model with efficient inference via in-context learning. None of these properties hold for TabPFN, which does not explicitly represent the prediction function. We are not, as reviewer 7tYk suggests, using a different classifier/regressor. The work of Hollmann (TabPFN) produces no model and learns to perform predictions via in-context learning with a transformer, which means that for each individual prediction the transformer model has to be invoked, leading to slow predictions compared with more traditional models. This also means that existing post-hoc methods like SHAP, which reviewer DRNQ suggests [S1] as an alternative, do not readily apply to TabPFN, as no model is produced. A brute-force version of SHAP could be applied to TabPFN by fixing the training set and varying prediction points; however, this essentially means running the transformer on all probing points required by SHAP and is therefore extremely computationally intensive compared to our approach. Furthermore, SHAP provides only post-hoc explanations that can be hard to interpret when aggregated to the feature level. In this work we use transformers to produce a GAM model that is interpretable, is very efficient at making predictions, and can be edited if necessary to make corrections to the model because of bias in the training data. We also innovate beyond the architecture of TabPFN by supporting an arbitrary number of features and providing a model that is equivariant to the ordering of features.
Scalability
Scalability of transformer-based architectures is a widely studied problem; solving it goes beyond the scope of this work. We are in the process of extending our model to more scalable architectures, however, any future development improving attention scalability (which is an incredibly active research field given the applications to LLMs) is likely to improve the scalability of our approach.
The paper introduces GAMformer, a method that uses in-context learning to fit Generalized Additive Models in a single step, rather than traditional iterative methods. The models presented are trained on synthetic data but demonstrates good performance on real-world datasets. The authors claim competitive performance with leading GAM implementations while maintaining interpretability.
The method is limited to first- and second-order feature interactions. In addition, the neural-network-based approach loses many of the theoretical guarantees of traditional GAMs. Reviewers were concerned by the cost of training the model compared to ad hoc GAM training. Reviewers also had concerns about the experimental validation and the novelty compared to recent works.
Based on the reviews and discussion, this paper appears to be marginally below the acceptance threshold. While the approach is interesting, I believe this is outweighed by reviewers' concerns so I will recommend rejection.
Additional Comments on Reviewer Discussion
Reviews were initially slightly negative, highlighting novelty concerns, computational complexity, statistical significance of results, comparison with standard baselines, and reproducibility.
The authors responded to reviewers' concerns, making good points on the one-time training cost (vs. the every-time training cost for GAMs), and defended their experimental methodology while acknowledging some limitations. Their rebuttal helped clarify some points but left others unresolved. One reviewer was particularly concerned about potential mistakes in the baseline experiments. After the rebuttal, some reviewer concerns remained unaddressed.
Reject