PaperHub
Overall rating: 5.0 / 10 (Poster; 4 reviewers; min 2, max 7, std 1.9)
Individual ratings: 5, 7, 6, 2
Confidence: 3.8 · Soundness: 2.5 · Contribution: 2.5 · Presentation: 2.5
NeurIPS 2024

Retrieval & Fine-Tuning for In-Context Tabular Models

OpenReview · PDF
Submitted: 2024-05-15 · Updated: 2024-11-06
TL;DR

We use retrieval with fine-tuning to improve in-context learning on tabular data and show large improvement with better scaling

Abstract

Keywords
in-context learning · tabular data · retrieval · foundation models · transformers

Reviews & Discussion

Review
Rating: 5

Tabular data is an important yet understudied modality in machine learning. Following the recent success of TabTransformer and TabPFN, significant progress has been made on tabular tasks. This paper is an extension of TabPFN, a previous state-of-the-art tabular learning model for small-scale data, which the authors combine with retrieval and fine-tuning. The motivation is intuitive and the idea is natural.

In the retrieval part, the authors reuse the simple kNN trick and discuss in detail the resulting performance of TabPFN.

In the fine-tuning part, the authors construct the training data via shared contexts and fine-tune the TabPFN model.

The authors experiment on large-scale datasets and propose LoCalPFN, which achieves state-of-the-art performance.

Strengths

  • Tabular learning is an important topic in real-life applications.
  • The retrieval-and-fine-tuning methodology is not tied to a specific ICL tabular model.
  • The performance gain from retrieval and fine-tuning over TabPFN is clear.
  • A scaling experiment is provided, which is important for large-scale analysis.

Weaknesses

  • There are many limitations on the dataset requirements, e.g., number of features, number of classes, and task types.
  • No other deep learning or transformer-based models are included as baselines.
  • Sec. 2.4 is not entirely clear about the fine-tuning steps, though the general ideas are appreciated.

Questions

  1. L32, memory scales quadratically in the size of the dataset. Do the authors mean exactly the calculation of the attention matrix? If so, it seems this memory limit is general, and what the authors do is leverage the limited context length better?
  2. In Table 1, does the gain of LoCalPFN over TabPFN-kNN come only from fine-tuning per dataset?
  3. Just to make sure I understand correctly, do the authors fine-tune on each dataset of the 95 benchmarks?
  4. How are tree methods evaluated? Any hyperparameter optimization?
  5. What are the training and inference times on each dataset, and how do they scale with dataset size?
  6. What are the selection criteria for the 95 datasets? How are they chosen, preprocessed, etc.?

Limitations

NA

Author Response

There are many limitations on the datasets requirements, e.g. number of features, number of classes, tasks

We agree. There are ways to go beyond them while still using the same architecture, though. For instance, one can perform feature selection very efficiently, as the forward pass of TabPFN is very fast. Furthermore, all multiclass problems can always be reduced to binary classification problems. Finally, even regression can be cast as classification, as we demonstrated in the general message. But we share your feelings and are currently investigating training a better base model which does not suffer from these limitations. This is, however, a different technical contribution on top of which the present method can be applied as well.

No other deep learning or transformer based models as a baseline

We address this point in the paragraph “Deep learning model comparisons” in Section 4.2, which points to Table 5 in the appendix. Essentially, deep learning models are usually slower to train and we could only obtain results on a subset of the datasets. We still report their performance and ours on those. We believe the contrast is very clear.

Sec 2.4 is not quite clear in finetuning steps, though the general ideas are appreciated.

This section is indeed the most technical part of the paper and we’d be happy to answer any specific question and update the main text. In the meantime, to explain the main point differently: let’s say we want to use a context of size N_ctx = 1000 and we wish to classify N_qy = 500 queries. With vanilla TabPFN, we simply construct a sequence of size 1500 and tell the model which points are context and which are queries through the attention mask. However, when using a local context, each of the 500 queries now has its own context of 1000 points. We now have to fit (1000 + 1) × 500 points in the GPU! For inference this is fine, but when having to do backprop, and that over many steps, it just becomes too slow. What we do instead is find a way to select points which are all “local” to each other, so that we can use a context within that neighborhood and share it for all queries, even though it might not contain the exact neighbors of each query. We find that this works very well and efficiently in practice.
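For concreteness, here is a minimal sketch of that shared-local-context idea (the helper name, the anchor-based neighbourhood selection, and the random context/query split are our assumptions for illustration, not the authors' exact procedure; it also assumes the training set has at least n_ctx + n_qy rows):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def sample_shared_local_batch(X_train, y_train, n_ctx=1000, n_qy=500, seed=0):
    """Pick one anchor point, take its (n_ctx + n_qy) nearest training
    neighbours, and split them into a context shared by every query in the
    batch, plus the queries themselves."""
    rng = np.random.default_rng(seed)
    nn = NearestNeighbors(n_neighbors=n_ctx + n_qy).fit(X_train)
    anchor = X_train[rng.integers(len(X_train))][None, :]
    _, idx = nn.kneighbors(anchor)            # one local neighbourhood around the anchor
    idx = rng.permutation(idx[0])             # shuffle before the context/query split
    ctx_idx, qy_idx = idx[:n_ctx], idx[n_ctx:]
    # One sequence of n_ctx + n_qy points instead of n_qy sequences of n_ctx + 1 points.
    return (X_train[ctx_idx], y_train[ctx_idx]), (X_train[qy_idx], y_train[qy_idx])
```

The point of the sketch is only the memory argument: every batch is one sequence whose context is shared across its queries, so backprop never has to materialise a separate 1000-point context per query.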

L32, memory scales quadratically in the size of the dataset. Do the authors mean exactly the calculation of the attention matrix? If so, it seems this memory limit is general, and what the authors do is leverage the limited context length better?

Yes, this is correct for the TabPFN-kNN method. Note that this poses some challenges on how to implement this efficiently but your understanding is correct.

In Table 1, is the gain from LoCalPFN w.r.t. TabPFN-kNN coming only from fine-tuning per dataset?

Yes this is correct.

Just make sure I understand correctly, do the authors fine tune on each dataset of 95 benchmarks?

Yes. (see L43, L132)

How are tree methods evaluated? Any hyperparameter optimization?

We use the results from TabZilla. All baselines used HPO indeed. Please refer to Appendix A.2.1 for more details.

What is the training time and inference time on each dataset and how this time scales with dataset size?

We provided some runtime results in the attached pdf. The main takeaway is that retrieval is not the bottleneck here; the cost of fine-tuning is.

What is the selection criteria of 95 datasets? How are they chosen, preprocessed, etc.

We tried to address this question in the first paragraph of the experiment section “The 95 datasets are filtered from TabZilla to meet TabPFN’s architectural requirements by ensuring that each dataset has at most 100 features, at most 10 classes, does not contain NaN values, and has at least one instance per class for each split”. These are essentially conditions so that the datasets can be used without modification by the base model, TabPFN. The list of datasets is available in Appendix A.1

Comment

I thank the authors for the feedback. I am still borderline about the improvement w.r.t. TabPFN. Happy to hear what other reviewers think.

Comment

Thank you for engaging with us!

We have tried to address the point of the limitations of TabPFN by explaining how to expand to more classes and provided a regression experiment. Do you have a specific limitation in mind you would like us to address?

Furthermore, we would like to point out that our contribution is the retrieval and fine-tuning procedure (and their efficient implementation, such as the shared-context approximation and the joint retrieval + fine-tuning); our paper is not concerned with modifications to the base model (TabPFN), which is definitely important, but orthogonal, future work.

We believe we have addressed your concerns, which notably included the deep learning baselines and their runtime experiments, clarification that HPO is indeed used for our baselines (30 rounds), answers to specific clarification questions, and the concerns about TabPFN's limitations, which are addressed in the general message and discussed above. We hope that our additional experiments and explanations will be reflected in your evaluation of our work.

Review
Rating: 7

The authors extend the recently introduced TabPFN to larger and more complex datasets by fetching a relevant context for each test point using a kNN algorithm. The authors evaluate two methods: TabPFN-kNN, which consists of running the original TabPFN on the fetched context for each test point, and LoCalPFN, which adds a fine-tuning step to adapt TabPFN to this new kind of local context. For this fine-tuning, an approximation is devised to use nearby points together in both the context and the query, enabling a shared context which is faster but yields results close to those of a per-point kNN.

The paper shows extensive evaluations of these two methods, demonstrating state-of-the-art results against the original TabPFN, gradient-boosted trees like CatBoost, and (not pretrained) neural networks on a previously introduced benchmark.

Furthermore, the authors provide several ablations, for instance showing the importance of jointly applying the kNN and fine-tuning steps.

Strengths

Important and novel contribution, which allows TabPFN to be used in many more settings.

The evaluation is well done. The authors benchmark their method on a previously introduced benchmark, and the selection they use is principled, showing that no cherry-picking of datasets took place. The baselines are strong (modern GBDTs like CatBoost, modern NNs like FT-Transformer, and different TabPFN variants).

The ablations are interesting and numerous. For instance, I was wondering about TabPFN-3k-32ens-int and was happy to see it. The importance of doing the fine-tuning jointly with the kNN is interesting. The approximations used are also ablated, as in Table 12.

The paper is very well introduced and written, and makes for a very pleasant read.

Weaknesses

Some details are missing, for instance:

  • How are runtimes computed for Figure 11? Which hardware? Which hyperparameters?
  • I'm not sure whether LocalPFN undergoes some HPO or not (I think not). If it's not the case, I think it should be made clearer, as it makes the comparison to tuned GBDT more impressive. If it does, the hyperparameter space should be provided.

I think TabPFN-1k-32ens-int (maybe fewer than 32 ensembles if it's slow?) would be an interesting baseline, as TabPFN is optimized for datasets smaller than 3k.

I think more aggregation metrics should be provided (for instance, mean rank or mean normalized score such as z-score or other rescaled scores).

Figure 4 lacks details: which datasets are used? Are the different datasets subsampled to reach smaller sample sizes? I also don't understand why the absolute mean AUC decreases when the dataset size increases beyond a certain point in Figure 7. Is it because the list of datasets is changing?

The code is not currently available.

Questions

For inference, you have to use a different context for each point to predict, right? How long / memory-intensive is it? What is the runtime split between fine-tuning and inference?

Do you think you can finetune with a smaller k than what you use for inference?

We use the experimental results from TabZilla [35] when they are available.

I assume they are always available except for TabPFN variants? I also think that the fact that you're using the hyperparameter spaces from [35] should be stated in the main text (if that's indeed the case).

Which distance are you using in faiss? I'm wondering if using the euclidean distance would improve the performance of the ordinal encoded features for the knn (compared to using one-hot-encoding).

Limitations

The authors have adequately addressed the limitation of their work.

Author Response

Thank you very much for your thorough review and helping us improve the paper.

How are runtimes computed for figure 11? Which hardware? Which hyperparameters?

The time reported is the average time to perform training + inference for a single run (averaged over parameters/datasets/seeds). As such, it indeed does not account for the HPO of baseline methods that used it. As reported in Appendix A.2.2, the hardware is as follows: all experiments for our proposed methods can be run on a machine with a single NVIDIA RTX 6000 Ada Generation GPU, 995 GiB RAM, and an AMD Ryzen Threadripper PRO 5995WX 64-core CPU.

I'm not sure whether LocalPFN undergoes some HPO or not (I think not). If it's not the case, I think it should be made clearer, as it makes the comparison to tuned GBDT more impressive. If it does, the hyperparameter space should be provided.

We conducted some minor experiments with HPO and saw no difference except for the learning rate, which we tuned by hand. Thus, we kept the default parameters we had initially (from the TabPFN repo). We believe this, on top of other choices not having much effect (such as using the embedding vs. raw space, or Euclidean distance vs. inner product), shows the approach is quite robust. Indeed we will make this clear in the main text, thanks!

I think TabPFN-1k-32ens-int (maybe less than 32 ens if it's slow?) would be an interesting baselines as TabPFN is optimized for smaller datasets than 3K.

We do not have the exact numbers but generally TabPFN generalizes well to larger context and on average using larger context helps. The performance of 1k-32ens-int was above 1k-32ens but below 3k-32ens-int.

I think more aggregation metrics should be provided (for instance mean rank, mean normalized score (z-score or other rescaled scores)..).

We provide the mean rank, normalized AUC, and z-scores for the algorithms of Table 1 in Table 3 of the attached pdf. Please let us know if there are some metrics you’d particularly wish to see and we’ll try to include them during the discussion phase. We chose to avoid normalized measures as the main metrics in our paper because they can hinder the reproducibility of the results, since the scores then depend on an exact set of algorithms. However, we agree that they are also interesting and will include them in the paper.

figure 4 lacks details: which datasets are used? Are the different datasets subsampled to reach smaller sample sizes? I also don't understand why absolute mean AUC is decreasing when the dataset size is increased after a certain point in Figure 7. Is it because the list of datasets is changing?

Good question, we will clarify it in the main text. Your second guess is correct. In this figure we have not subsampled or tampered with the datasets in any way, but only binned them into the appropriate size range. This also means that the mean AUC of each bin is not directly comparable to another bin, as the datasets within one bin are totally disjoint from those in another bin. It may be that datasets in the range 1000-3000, for instance, are on average (in TabZilla) harder than those in the range 3000-10000. This is the main reason why we use a standard algorithm as a baseline, so that phenomena with dataset size (such as TabPFN’s performance decreasing) become visible.

The code is not currently available.

Following NeurIPS policy we are reaching out to the AC in order to be allowed to share an anonymous repo link with you. Note that we intended to release our code later (as we would like to see our method being used, we are particularly hopeful TabPFN-kNN could be used as a replacement for TabPFN), and it is currently not optimally refactored/documented.

For inference, you have to use a different context for each points to predict, right? How long / memory intensive is it? What is the runtime repartition between finetuning and inference?

TabPFN-kNN is essentially LoCalPFN without fine-tuning. Looking at Figure 11, it can be seen that LoCalPFN takes an order of magnitude longer (or more) compared to TabPFN-kNN. This is due to LoCalPFN having to fine-tune. Thus, fine-tuning takes an order of magnitude more time than inference.

Do you think you can finetune with a smaller k than what you use for inference?

We suspect this approach would work (see the previous point about robustness). However, it would cause a disparity between fine-tuning and inference and might have negative effects. In terms of a smaller k helping with fine-tuning time, most of the time is spent on model parameter updates, not retrieval.

I assume they are always available expect for TabPFN variants? I also think that the fact that your using the hyperparameter spaces from [35] should be said in the main text (if that's indeed the case).

We thank the reviewer for the observation, we will update the paper to reflect these more clearly. Most of the values were indeed present but we needed to re-run TabPFN (missing many datasets) and perform HPO for catboost on one dataset it was missing as well.

Which distance are you using in faiss? I'm wondering if using the euclidean distance would improve the performance of the ordinal encoded features for the knn (compared to using one-hot-encoding).

We also had intuition on which encoding/distance would work best. We tried multiple approaches and they made little difference thus we ended up using the simplest variant. Please refer to the general message for more details.

Comment

Thank you for the detailed answer, and for the new tables in the pdf!

The thing which is still not clear to me is the inference time. I think it would be useful to report the training and inference times separately. In many settings people care about having very low inference time, and this is already a limitation of TabPFN (very fast training but quite slow inference). Something like inference time per 1000 samples would be interesting to report in the paper. I'm still wondering whether for TabPFN-kNN you need one forward pass per inference point (which should be much slower than TabPFN). In Figure 11 TabPFN-kNN seems really fast, so I'm wondering if I'm missing something or if you have some tricks.

Comment

Indeed, we had only reported the full training + evaluation loop runtime for the algorithm, not the inference time specifically. Let's go into more detail.

To classify $N_\text{qy}$ examples given $N_\text{ctx}$ points, TabPFN would construct a tensor of size $(L = N_\text{ctx} + N_\text{qy},\; B = 1,\; d)$ with appropriate masking. As we do not share the context across queries, with TabPFN-kNN we need a specific context for each point, and as such our input tensor is of size $(L = N_\text{ctx} + 1,\; B = N_\text{qy},\; d)$, where the $+1$ is the query point to classify in each batch dimension.

For instance, for 512 query points and a context size of 1000 neighbors, the inference time of TabPFN would be about 0.01 s. A naive TabPFN-kNN takes about 1 s. However, using bf16 precision we can lower it to about 0.4 s (0.008 s for TabPFN). With 128 queries, the numbers would be 0.008 s (TabPFN) vs. 0.1 s (TabPFN-kNN). So, as we observe, we are slower than TabPFN.
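As a minimal illustration of the two input layouts (shapes only; d and the zero tensors are placeholders, not the actual TabPFN featurization):

```python
import torch

N_ctx, N_qy, d = 1000, 512, 100

# Vanilla TabPFN: one sequence holding the shared context and all queries.
vanilla_input = torch.zeros(N_ctx + N_qy, 1, d)   # (L, B, d) = (1512, 1, 100)

# TabPFN-kNN: each query carries its own retrieved context along the batch axis.
knn_input = torch.zeros(N_ctx + 1, N_qy, d)       # (L, B, d) = (1001, 512, 100)

print(vanilla_input.numel(), knn_input.numel())   # 151,200 vs 51,251,200 entries
```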

In the 71 datasets in the attached pdf on which the deep learning baselines were able to run according to the TabZilla design, the largest of the 95 datasets are typically excluded (this is because TabZilla allocates a time limit per algorithm and dataset, so many slow DL baselines do not finish on the largest datasets within the given budget). This is why TabPFN-kNN appears really fast there. In Fig. 12, including all datasets, the average runtime is about 10x slower but still fast considering there is no training. The fact that TabPFN-kNN is slower (and takes a lot more memory, especially when having to perform backprop, as we do when fine-tuning end-to-end with retrieval) is why we introduced the approximate-NN context sharing for fine-tuning. All in all, TabPFN-kNN is not the algorithm with the fastest inference speed; TabPFN is much faster, and so is XGBoost, etc.

However, we believe this algorithm still has very important strengths: it scales extremely well with the dataset size (compared to TabPFN, whose performance degrades strongly), so in the case of production ML models which have access to a very large dataset but only have to classify new queries at a lower rate, it is a very well-suited algorithm. Furthermore, indexing a large dataset is usually much faster than retraining any ML algorithm; as such, it can adapt very quickly to changes in the data distribution (such as COVID, for instance) without needing expensive retraining.

Note that the evaluation runtime was never the bottleneck in our research, so this is not what we focused on, but there are other simple ways one could improve upon it. For example, we can use the clustering capabilities of faiss to split the training data into a fixed number of contexts. This way we can ensure many queries will share the same context and harness the speed of TabPFN. While it may degrade the accuracy of the NN search, we think it can be an interesting trade-off if inference speed is an issue.
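A rough sketch of that clustering idea, assuming faiss k-means is used to pre-compute a fixed number of shared contexts (function and variable names are ours; this is not part of the released code):

```python
import numpy as np
import faiss

def build_clustered_contexts(X_train, n_contexts=32, seed=0):
    """Cluster the training set once so that each query can later reuse the
    context of its nearest cluster instead of its own exact neighbours."""
    x = np.ascontiguousarray(X_train, dtype=np.float32)
    kmeans = faiss.Kmeans(x.shape[1], n_contexts, niter=20, seed=seed)
    kmeans.train(x)
    _, assign = kmeans.index.search(x, 1)     # nearest centroid for every training point
    contexts = [np.where(assign[:, 0] == c)[0] for c in range(n_contexts)]
    return kmeans, contexts

# At inference, a query is routed to its nearest centroid and reuses that
# cluster's context, so many queries can share a single forward pass.
```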

Thank you for the interesting comment, we will include numbers or a figure specifically concerning the inference speed for clarity.

Have you received the anonymized code link from the AC?

Comment

Thank you for the very clear answer!

Have you received the anonymized code link from the AC?

No

Comment

Here we share the anonymous link for our code: https://anonymous.4open.science/r/retrieve_ft-E2B0/README.md

Note that we removed some files (including the notebooks) to make sure anonymity is preserved. Thus the code for the analysis figures is not available in this repo (but we included that dataframe with all results).

The code should be able to run, if you have a GPU with limited memory (<48G) we recommend lowering the batch size or number of neighbors (--context_length 500 for instance). If you are familiar with the TabPFN code base you will notice that a large portion of the code has been rewritten, notably to be compatible with the pytorch transformers class.

You will probably be interested in the methods/pfknn.py (TabPFN-kNN) and methods/ftknn.py (LoCalPFN) files. There are currently a lot of options available in the code (different ways of computing neighbors, different embeddings etc..) which can make the code harder to read.

Let us know if you have any questions (and please do not share this code)!

Review
Rating: 6

The paper introduces Locally-Calibrated PFN (LoCalPFN), an advanced model for tabular data that enhances the transformer-based TabPFN by incorporating retrieval and fine-tuning techniques. By using k-Nearest Neighbours (kNN) to select a local context for each data point and fine-tuning the model on this retrieved set, LoCalPFN adapts more effectively to larger and more complex datasets. Extensive evaluations on 95 datasets from TabZilla demonstrate that LoCalPFN outperforms both neural and tree-based methods, setting a new state-of-the-art in tabular data classification. The key contributions include addressing TabPFN's scaling issues, proposing an improved context usage with retrieval and fine-tuning, and showcasing superior performance through extensive experimentation and ablation studies.

Strengths

  • The paper introduces Locally-Calibrated PFN (LoCalPFN), which enhances transformer-based in-context learning for tabular data by combining retrieval and fine-tuning techniques.
  • The research is robust, with extensive evaluations on 95 datasets from TabZilla. The authors provide comprehensive experimentation and analysis, demonstrating the effectiveness of LoCalPFN compared to strong baselines, including neural and tree-based models.
  • The paper is well-organized and clearly written, making the complex concepts and methods accessible.

Weaknesses

  • Although LoCalPFN shows improved performance, the fine-tuning process increases computational complexity and runtime, especially with large datasets.
  • The paper's reliance on TabPFN as the base model restricts the generalizability of the proposed method. While LoCalPFN demonstrates significant improvements, it remains unclear whether these benefits would transfer to other in-context learning models for tabular data.
  • The current implementation of LoCalPFN is constrained by TabPFN’s limitations on the number of features and classes, as well as its incompatibility with regression tasks.

Questions

  • Have you tested the retrieval and fine-tuning techniques used in LoCalPFN on other in-context learning models?
  • Given the constraints on the number of features and classes due to TabPFN, how do you plan to extend LoCalPFN to handle datasets with more features and classes?

Limitations

  • The paper focuses solely on classification tasks and does not address regression tasks. This omission leaves a gap in evaluating the full potential of LoCalPFN.
Author Response

We thank the reviewer for their feedback.

“Although LoCalPFN shows improved performance, the fine-tuning process increases computational complexity and runtime, especially with large datasets.”

Yes, this is true; however, note that the complexity added by the fine-tuning process is not specific to LoCalPFN but a cost we would also have to pay when fine-tuning vanilla TabPFN. Note that despite the additional runtime, our method has a runtime similar to many other deep learning approaches while achieving very high performance.

“The paper's reliance on TabPFN as the base model restricts the generalizability of the proposed method. While LoCalPFN demonstrates significant improvements, it remains unclear whether these benefits would transfer to other in-context learning models for tabular data.”

Please refer to the general message. The headline is that we expect the techniques and concepts from this paper to generalize to future tabular foundation models.

“The current implementation of LoCalPFN is constrained by TabPFN’s limitations on the number of features and classes, as well as its incompatibility with regression tasks.”

This is correct. Our contribution is a post-training one so we don’t retrain a base model. However we agree with the reviewer concerning the limitations of TabPFN and are currently working towards removing those. Please refer to the general message for more discussion regarding the possibility of other models.

“Have you tested the retrieval and fine-tuning techniques used in LoCalPFN on other in-context learning models?”

We are not aware of tabular in-context learners other than TabPFN. Are you referring to general in-context learners (i.e., LLMs)? There are multiple issues:

  1. Since each row cannot be considered as a single token but rather as a (potentially very long) sequence of tokens of varying length, comparing rows for retrieval is not straightforward.
  2. How to embed rows is actually not obvious in the first place; since LLMs perform best on text, most successful methods [https://arxiv.org/abs/2206.06565] rely on describing each row as a sentence.
  3. Because each row would take many tokens, our context size would be much more limited: 1-2 orders of magnitude below what TabPFN can do [https://arxiv.org/pdf/2406.12031].
  4. Despite all of this, LLMs for tabular data prediction still lag behind traditional methods and are harder to evaluate, as they have memorized many well-known tabular datasets [https://arxiv.org/abs/2403.06644].

“Given the constraints on the number of features and classes due to TabPFN, how do you plan to extend LoCalPFN to handle datasets with more features and classes?”

Even though the current model is limited (100 features, 10 classes, no regression), it is still possible to extend it: for instance, we can easily perform feature selection/averaging as the forward pass is fast. Any multiclass problem can be turned into multiple binary classification problems. And as we showed in the main message, we can even perform regression as classification. However, we also wish to go beyond those limitations and are actively working towards designing a new architecture and training a new model that does not suffer from them.

Regression

Please refer to the general message

Review
Rating: 2

The paper proposes LoCalPFN, a new method that improves the scaling of transformer-based in-context learning for tabular data. It uses retrieval and fine-tuning to adapt the transformer to local subsets of the data, and demonstrates state-of-the-art performance on a variety of datasets. The paper makes the following contributions:

  • Provides a comprehensive analysis of TabPFN and identifies key limitations in TabPFN's ability to scale with dataset size and complexity.
  • Proposes LoCalPFN, a novel approach that combines retrieval and fine-tuning techniques. LoCalPFN leverages these methods to enable more effective utilization of context, thereby enhancing the scalability of in-context learning for tabular data.
  • Demonstrates the superiority of LoCalPFN through extensive evaluation and ablation studies. These evaluations reveal that LoCalPFN outperforms baselines on a wide range of datasets.

Strengths

  • Novel Combination of Retrieval and Fine-tuning: The paper introduces a new approach by combining retrieval and fine-tuning for tabular data.

  • Extensive Evaluation: The paper presents a thorough evaluation of the proposed method (LoCalPFN) across a wide range of datasets, comparing it with numerous baselines and conducting ablation studies.

Weaknesses

The paper contains several inaccurate statements. For example, in the abstract, "Recent advancements using transformer-based in-context learning have shown promise on smaller and less complex datasets, but have struggled to scale to larger and more complex ones." is inaccurate, depending on the dataset and task definitions. There are many such examples in the paper, which make it hard to understand the scope of the paper. It raises concerns about the overall reliability of the paper's findings.

The paper fails to clearly define the complexity of tabular datasets, making it difficult to assess the relevance and significance of the proposed method (LoCalPFN) in addressing the challenges of complex datasets.

The paper does not address the important practical considerations of the model's cost and latency at inference time, especially when retrieval is used. This omission is a significant limitation, as it hinders the understanding of the model's real-world applicability. As acknowledged in the paper, LoCalPFN has a slower runtime compared to tree-based methods, which are a popular choice for tabular data. While the paper mentions a faster variant (TabPFN-kNN), its performance is not as strong as LoCalPFN's.

The proposed LoCalPFN method is heavily reliant on TabPFN as the base model. This raises concerns about the generalizability of the approach to other base models and task types (e.g., regression).

Questions

  1. Why would the dataset size affect the model quality? Is it related to the task complexity and the diversity of the samples?
  2. Can you give an example of the query x_qy for tabular data, and of the distance metric used in the kNN?

Limitations

The paper does not address the important practical considerations of the model's cost and latency at inference time, especially when retrieval is used. This omission is a significant limitation, as it hinders the understanding of the model's real-world applicability. As acknowledged in the paper, LoCalPFN has a slower runtime compared to tree-based methods, which are a popular choice for tabular data. While the paper mentions a faster variant (TabPFN-kNN), its performance is not as strong as LoCalPFN's.

The proposed LoCalPFN method is heavily reliant on TabPFN as the base model. This raises concerns about the generalizability of the approach to other base models and task types (e.g., regression).

Author Response

We appreciate the time you took to review our paper. We hope to clarify some points of confusion below in our response.

"Recent advancements [...] have struggled to scale to larger and more complex ones." is inaccurate -- depending on the dataset and task definitions.

Can you please clarify this point? This sentence is meant to be understood in the context of tabular data, as the previous one begins with “Tabular data”. In any case, we will update it to “have shown promise on smaller and less complex tabular datasets“ to remove ambiguity.

However, if this sentence is being challenged in the context of tabular data: several works have indeed shown TabPFN to be a competitive model [e.g., https://arxiv.org/abs/2305.02997, https://openreview.net/forum?id=XctSyEsBzx]; however, it does not translate as well to large datasets because of the limited context size [https://arxiv.org/abs/2402.06971 Fig. 3, or Fig. 2 of our paper].

“There are many such examples in the paper, which make it hard to understand the scope of the paper. It raises concerns about the overall reliability of the paper's findings.”

We would be happy to have such examples pointed out so that we can better clarify the paper and convince you that our results are strong and reliable.

“The paper fails to clearly define the complexity of tabular datasets, making it difficult to assess the relevance and significance of the proposed method (LoCalPFN) in addressing the challenges of complex datasets.”

We would like to address the comment with several points.

The complexity of a dataset is one of the reasons we invoked when explaining why retrieval can help, even when the dataset size is small. Note that the main motivation for our paper is still to be able to use such in-context models on large datasets. Neither our method nor our results (Tables 1 and 6-7) are affected by how we measure complexity.

There isn’t one way to measure data complexity and it is still an open question. Many of the usual measures are based on a notion of “compressibility” dating back to Kolmogorov complexity. Many practical methods are based on a measure of discrepancy between the data distribution (usually $p(x)$ or $p(y|x)$) and a model of known capacity. For instance, let’s consider a linear model and a more complex non-linear model (e.g., polynomial). If both the linear and the non-linear model provide a similar approximation to $p(y|x)$ (i.e., low $\mathrm{KL}(p(y|x)\,\|\,q(y|x))$ or, equivalently, cross-entropy), then we can argue that the classification problem is not complex, as a simple linear model is enough to approximate it. On the other hand, if the discrepancy between the two models is large, then a non-linear model is necessary to explain the data and as such it is more complex. This is very much in line with our intuition and with how the term is used in machine learning, where some tasks are considered “hard” (such as mathematical reasoning) because small models don’t perform well but more advanced ones (e.g., adding more parameters or some search capability) perform better. Thus we argue that measuring the discrepancy between the loss/performance of different algorithms is both intuitive and widely used. Indeed, we see examples in the tabular domain [https://arxiv.org/abs/2207.08815, Sec. 3.1, “not easy” paragraph] where this gap is measured, or [https://arxiv.org/pdf/2305.02997] where datasets are considered “hard” if a simple baseline does not achieve top performance. We are therefore very much in line with the standards of the field, including the tabular sub-domain.
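As a toy illustration of this discrepancy-based view of complexity (our own example, not a metric used in the paper), one can compare the test cross-entropy of a linear model with that of a more flexible one:

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

def complexity_gap(X, y, seed=0):
    """The dataset is 'complex' in this sense when a flexible model beats a
    linear one by a wide cross-entropy margin."""
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.3, random_state=seed)
    linear = LogisticRegression(max_iter=1000).fit(Xtr, ytr)
    flexible = GradientBoostingClassifier(random_state=seed).fit(Xtr, ytr)
    return log_loss(yte, linear.predict_proba(Xte)) - log_loss(yte, flexible.predict_proba(Xte))

X, y = make_moons(n_samples=2000, noise=0.2, random_state=0)
print(f"complexity gap: {complexity_gap(X, y):.3f}")   # large gap => non-linear structure
```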

“The paper does not address the important practical considerations of the model's cost and latency at inference time, especially retrieval is used. This omission is a significant limitation, as it hinders the understanding of the model's real-world applicability.”

We believe you are referring to our mention of the runtime (L328) and Figure 11 in the Appendix. You can see there that the non-finetuned method (TabPFN-kNN) is faster than all tree based methods but does not significantly outperform XGBoost or CatBoost. As a reminder, TabPFN-kNN represents the raw cost of inference for our technique when only retrieval – not fine-tuning – is performed, which should hopefully assuage concerns about the runtime of the retrieval component. Meanwhile, LoCalPFN is slower (by a factor 10-25x on average) compared to a single XGBoost run, but achieves significantly better performance. Depending on the exact budget and the hardware available to a specific person, TabPFN-kNN, XGBoost, CatBoost, or LoCalPFN might be the best choice. Note that, as GPUs are improving and inference is made more efficient by the day (e.g., better flash attention, quantization, etc.), LoCalPFN’s inference speed should improve over time. Please see the general message for more details; we have included a runtime comparison against popular deep learning techniques there as well, showing that LoCalPFN leads on performance while remaining competitive on runtime.

The proposed LoCalPFN method is heavily reliant on TabPFN as the base model. This raises concerns about the generalizability of the approach to other base models and task types (e.g., regression).

Please refer to the general response.

Dataset Size and Model Quality: We empirically notice that for the datasets & models we have, the task tends to be a bit harder as dataset size increases.

Query Example: Could you clarify what kind of example you would like to see? For each query, we first standardize its features using training statistics and then retrieve its neighbours in the training set using the L2 distance. We then pass the (context, query) vectors to TabPFN.
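To make this concrete, here is a minimal sketch of that retrieval step (standardize with training statistics, then L2 nearest neighbours via a flat faiss index; the helper name is ours, not the authors' code):

```python
import numpy as np
import faiss

def retrieve_context(X_train, y_train, x_query, k=1000):
    """Standardise with training statistics, then fetch the k nearest
    training points as the local context for one query."""
    mu, sigma = X_train.mean(0), X_train.std(0) + 1e-8
    Xs = np.ascontiguousarray((X_train - mu) / sigma, dtype=np.float32)
    qs = np.ascontiguousarray((x_query - mu) / sigma, dtype=np.float32).reshape(1, -1)
    index = faiss.IndexFlatL2(Xs.shape[1])
    index.add(Xs)
    _, idx = index.search(qs, k)              # indices of the k nearest training points
    return X_train[idx[0]], y_train[idx[0]]   # (context features, context labels) -> TabPFN
```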

Overall: Given the strengths you pointed out (comprehensive analysis, extensive ablations, and a novel method), and given that we have hopefully addressed all your concerns, we ask you to reconsider your rating of 2, which is generally reserved for papers that are technically wrong.

Comment

Hi,

We would like to thank you again for taking the time to review our paper, but we would also very much appreciate it if you took the opportunity to engage in discussion, time permitting, considering that we have clarified some of the more negative points of your review both in the rebuttal addressed to you and in the shared rebuttal to all reviewers (which has a PDF attached).

Thanks!

Author Response

General message

We would like to thank the reviewers for their assessment of our work. Overall, reviewers appreciated our evaluations and experiments (mentioning points such as “extensive” [j47T], “robust and comprehensive” [iFbB], and “well done, … principled, … no cherry-picking, … strong ablations” [jqEN]). Our paper was deemed “well-organized and clearly written” [iFbB] and “well written and pleasant to read” [jqEN, paraphrased] by two reviewers. The contribution was also deemed “novel” by multiple reviewers [j47T, jqEN].

However, there were some important and relevant points raised by multiple reviewers, and so we would like to address those here and not just in the individual responses.

Dependence on TabPFN [iFbB, j47T, ZbyM]

Some reviewers (along with us in the paper) accurately noted that our results currently depend heavily on TabPFN. However we point out that our method is made for in-context architectures that consider each datapoint as a token. As far as we know, TabPFN is currently the only one based on this idea, but we expect many new models of this class to follow considering the success of TabPFN. We can draw an analogy to early techniques leveraging BERT as the base model when it was one of the very few performant language models; in particular, works promoting retrieval developed with a heavy reliance on BERT readily extended to future improvements, and many remain highly relevant today. We thus feel strongly that the ideas and concepts presented here will remain relevant for future tabular foundation models.

Regression [j47T, iFbB]

A fair point that was raised is that we do not perform regression experiments, as TabPFN was not trained for regression. We would like to address this point with two arguments:

Our method, in both its retrieval and fine-tuning aspects, would apply in exactly the same way to a pre-trained regression model that processes inputs in a similar manner. Following the reviews, we have also tried to perform regression with our method by simply binning the regression targets into 10 classes, i.e., a regression-as-classification approach that has shown success in other domains (e.g., https://arxiv.org/pdf/2402.13425, https://arxiv.org/abs/2403.03950). Note, however, that TabPFN was not trained to perform this directly. To improve performance when using local contexts, we can perform local binning and predict 10 local values for each point. On the other hand, using a global context would be akin to having a global binning as well, which can restrict the precision of the output. We validate this idea on the well-known california-housing dataset (about 20k samples, 8 features) and provide results in Table 2 of the attached pdf.
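A minimal sketch of this regression-as-classification idea with local binning (quantile bin edges and the probability-weighted readout are our assumptions for illustration, not necessarily the exact procedure used for Table 2):

```python
import numpy as np

def bin_targets(y_ctx, n_bins=10):
    """Turn the retrieved context's continuous targets into n_bins local classes."""
    edges = np.quantile(y_ctx, np.linspace(0.0, 1.0, n_bins + 1))
    return np.digitize(y_ctx, edges[1:-1]), edges      # labels in 0..n_bins-1, plus the bin edges

def decode_prediction(class_probs, edges):
    """Read a continuous value back out as the probability-weighted bin centre."""
    centres = 0.5 * (edges[:-1] + edges[1:])
    return float(np.dot(class_probs, centres))
```

The classifier (TabPFN on the locally binned context) then only has to output a 10-way posterior per query, which is decoded back into a continuous prediction.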

TabPFN using random samples as context and binning them is not a strong regressor: a kNN regressor can do much better. However, when using a local context and binning, we can predict finer-grained values, and thus our MSE and correlation are closer to XGBRegressor. Note that here we did not even fine-tune the method.

Thus, we have already observed a generalization of the ideas presented in this paper; as discussed above, we expect further generalizations of our ideas in the realm of tabular data to be discovered.

Runtime Details [j47T, jqEN, ZbyM]

While we include runtime results in Figure 11 of the Appendix, admittedly this only compares our technique against tree-based approaches. In Table 1 of the attached pdf, we include a more thorough comparison with deep learning baselines on the 71 datasets of Table 5. TabPFN-kNN and LoCalPFN outperform the deep learning baselines by a large margin. Furthermore, LoCalPFN has a runtime similar to other deep learning baselines while being significantly more performant.

Distances and Details for the Retrieval [j47T, jqEN]

We will update the main text of the paper and appendices to include more details, which we also describe here. We always standardized the features of the training set to build the index. Next, for the retrieval, we experimented with different metrics, namely Euclidean (L2) distance and cosine similarity. The main takeaway from our experiments is that these two choices are quite similar when applied to the standardized features and so we stuck with the Euclidean distance.

Yet we were initially convinced that using embeddings of the datapoints would be far better than just using the standardized features for retrieval. However, we found that techniques such as using the encoder outputs or some keys/values at different layers as embeddings – with either L2 or cosine – did not increase performance. Indeed, using embeddings affected performance negatively in many cases, while additionally forcing us to recompute the index during training as the embeddings of the training set were also changing, adding computational burden.

We also tried learning the distance function (e.g., Mahalanobis distance, weighted L2, etc.). Unfortunately, though, this is non-differentiable and thus requires less-standard optimization techniques such as zeroth-order optimization. While we were able to discard irrelevant features in toy examples, the noise introduced by this optimization did not lead to an increase in performance on real datasets.

Furthermore, we’d like to emphasize here a difference we have with Retrieval-Augmented Generation (RAG): while in RAG only a few documents (1 to 5) are usually selected, here we select about 1000 neighbors. This adds robustness against selecting a few irrelevant retrieved examples. Furthermore, in the worst-case scenario, if we retrieve examples based on irrelevant/noisy features, our context becomes random, which is what the TabPFN baseline uses. This fact is one of the reasons why retrieval helps so much: in the best case we have large performance improvements, and in the worst case we have something similar to TabPFN.

Final Decision

The paper provides a retrieval and fine-tuning method that stacks on top of the current best tabular-data in-context learner, TabPFN. The authors show consistent improvements compared to a TabPFN + kNN baseline. The reviewers agree the experimental design is solid and the paper is thorough. The authors also provided a solid rebuttal to all the comments raised by the reviewers. There was one very negative reviewer who did not respond and gave a low score; aside from this reviewer, the average was actually high (5, 7, 6). For these reasons, and having read the paper myself and found it experimentally solid, I vote to accept this paper.