SMART: Towards Pre-trained Missing-Aware Model for Patient Health Status Prediction
Abstract
Reviews and Discussion
This paper presents SMART, a novel model designed to tackle the challenges of missing and irregular data in electronic health records (EHRs). Utilizing a two-stage training strategy, SMART first pre-trains to handle missing data in the latent space and then fine-tunes for specific clinical tasks. The model's innovative masked attention recurrent transformer (MART) block captures temporal and variable interactions, significantly improving prediction accuracy across various clinical tasks. Evaluated on three EHR datasets, SMART outperformed existing models, demonstrating robust performance and versatility. Despite its strengths, including handling missing data and achieving superior prediction accuracy, the model's complexity and limited dataset variety highlight areas for further exploration and improvement.
Strengths
One of the notable strengths of this paper is its originality in addressing the pervasive issue of missing and irregular data in electronic health records (EHRs). By introducing the SMART model with its innovative masked attention recurrent transformer (MART) block, the authors present a novel approach that captures temporal and variable interactions more effectively than traditional methods. The paper excels in explaining the complex mechanisms of the SMART model and its components, making it accessible even to those not deeply familiar with the intricacies of machine learning models.
Weaknesses
- (1) The lack of comparative analysis with state-of-the-art models beyond the specific baseline models mentioned. While SMART outperformed these baselines, a broader comparison with the latest advancements in EHR prediction models would provide a clearer benchmark of its superiority. This omission could leave readers questioning how SMART fares against the most cutting-edge approaches in the field.
- (2) The paper could benefit from a more in-depth discussion of the computational efficiency of the SMART model, particularly regarding inference time and resource requirements. Given the increasing emphasis on real-time decision support in clinical settings, understanding the model's computational demands is crucial for practical implementation.
Questions
I'd like the authors to respond to the questions below if they can:
- (1) Could you elaborate on why you chose to reconstruct latent representations during pre-training rather than imputing missing values in the input space?
- (2) Which component of SMART do you believe contributes the most to its superior performance, and why?
- (3) Given the quadratic complexity of temporal attention in SMART, how scalable is the model when applied to EHR datasets with varying lengths of patient records? Have you explored strategies to mitigate computational costs without compromising performance?
- (4) The ablation study provides insights into the importance of different components within SMART. What specific findings surprised you the most during these experiments, and how did they influence the design or interpretation of the final model?
- (5) Could you discuss any challenges or limitations encountered during the implementation of SMART in real-world clinical settings?
Limitations
There is a notable focus on performance across various datasets, but the generalizability of the model to different healthcare systems or diverse patient populations remains unclear. It would be beneficial for the authors to discuss potential ethical concerns or unintended consequences their approach might introduce, such as biases in predictions or challenges in interpretability that could affect clinical decision-making.
Thank you for recognizing the strengths of SMART, especially its effectiveness and novel designs. We address your concerns and answer your questions below.
W1: Lack of comparative analysis with state-of-the-art models beyond the specific baseline models mentioned.
Thank you for your question about the latest baseline methods. We would like to clarify that we have compared against recent methods, including SAFARI and RainDrop from 2022, Warpformer and Primenet from 2023, and PPN from 2024. If we have missed any recent methods, please let us know and we will include them in our comparison.
W2: The paper could benefit from a more in-depth discussion on the computational efficiency of the SMART model, particularly regarding inference time and resource requirements.
Thank you for your concern about computational efficiency. We would like to clarify that we compare the computational efficiency of the models in Section 4.3, where SMART achieves the best balance between training time and performance among all compared methods. As for inference time, SMART completes inference for a patient within one second, which meets the needs of clinical scenarios.
Q1: Could you elaborate on why you chose to reconstruct latent representations during pre-training?
Thank you for your question about why we reconstruct in the latent space rather than imputing in the input space. As we explained in lines 59-60 of the introduction, imputation in the input space may cause the model to get stuck in unnecessary details instead of learning higher-level semantic features that benefit downstream tasks. More specifically, imputation models pursue more accurate interpolation at each missing sample point, while this work pursues performance on clinical tasks, that is, completing tasks based on the overall information of the sequence. The optimization directions of these two goals differ. Therefore, we do not impute missingness in the input space during pre-training. For more explanation, please refer to our global response.
Q2: Which component do you believe contributes the most to its superior performance, and why?
Thank you for your interest in the effectiveness of SMART. We conducted ablation experiments on each design of the model in Section 4.2.2, and the results show that the information from the missing mask is the most important: without it, the temporal and variable attention we propose cannot work properly, and the model cannot perceive the missing data. In addition, as described in lines 287-310, the ablation experiments verify that each innovative design (including pre-training, temporal attention, variable attention, and the CLS vector) is important and brings performance improvements.
Q3: Given the quadratic complexity of temporal attention in SMART, how scalable is the model when applied to EHR datasets with varying lengths of patient records? Have you explored strategies to mitigate computational costs without compromising performance?
Thank you for your question about scalability. The record length of patients in our datasets is variable, and SMART scales to them. However, due to the lack of corresponding datasets, we did not conduct experiments on longer series. The focus of our work is to combine missing awareness to improve performance on clinical predictions, so we have not explored strategies to reduce computational costs without affecting performance. Existing methods such as Linformer [1] and RWKV [2] reduce the computational complexity of the attention mechanism and could be used to accelerate our method, which we leave as future work.
[1] Linformer: Self-attention with linear complexity. arXiv 2020.
[2] RWKV: Reinventing RNNs for the Transformer Era. EMNLP 2023.
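For intuition, a minimal single-head sketch of the Linformer-style idea [1] is given below: keys and values are projected along the sequence dimension down to a fixed length k, so the score matrix is n×k instead of n×n. The class name, single-head form, and dimension choices are our own illustration, not part of SMART.

```python
import torch
import torch.nn as nn

class LinformerStyleAttention(nn.Module):
    """Single-head sketch: linear (not quadratic) cost in sequence length."""
    def __init__(self, d_model: int, seq_len: int, k: int = 64):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.kv = nn.Linear(d_model, 2 * d_model)
        # Learned projections that compress the length dimension n -> k.
        self.proj_k = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.proj_v = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.scale = d_model ** -0.5

    def forward(self, x):                          # x: (batch, n, d_model)
        q = self.q(x)
        key, val = self.kv(x).chunk(2, dim=-1)
        key = self.proj_k @ key                    # (batch, k, d_model)
        val = self.proj_v @ val                    # (batch, k, d_model)
        attn = torch.softmax(q @ key.transpose(-2, -1) * self.scale, dim=-1)
        return attn @ val                          # (batch, n, d_model)
```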
Q4: What specific findings surprised you the most during ablation studies, and how did they influence the design or interpretation of the final model?
Thank you for your question about the ablation experiments. In the ablation experiments, we found that SMART improves clearly over imputing in the input space, so we were glad to see the proposed method verified experimentally. In addition, the effectiveness of the CLS vector also surprised us, since it brought a very large improvement. The introduction of the CLS vector also provides direction and insights for the future development of time-series models.
Q5: Could you discuss any challenges or limitations during the implementation in real-world clinical settings?
Thank you for your interest in the application of SMART in real clinical scenarios. We have mentioned the challenges of SMART in real-world application scenarios in Appendix B Broader Impact, including the possibility of making unfair predictions for patients and potential ethical issues. In particular, we would like to emphasize that the model is only a tool to assist physicians in making decisions. The model should be used together with physicians to make the best decision for patients.
Limitations: The generalizability of the model to different healthcare systems or diverse patient populations remains unclear. It would be beneficial for the authors to discuss potential ethical concerns or unintended consequences their approach might introduce, such as biases in predictions or challenges in interpretability that could affect clinical decision-making.
Thank you for your concern about ethical issues. We have mentioned the possible ethical issues of SMART in Broader Impact. We acknowledge that SMART may make unfair predictions for patients, leading to potential ethical issues. However, the model cannot make any decisions on behalf of physicians. The model is only a tool to help physicians understand the patient's condition. How to improve the fairness of the model is one of the possible research directions in the future.
We hope these explanations adequately address your concerns.
Dear Reviewer Hfjo, I am a NeurIPS 2024 Area Chair of the paper that you reviewed.
This is a reminder that the authors have left a rebuttal to your review. We need your follow-up response. Please leave a comment on any unanswered questions you had, or on your assessment of the authors' rebuttal.
The author-reviewer discussion is closed on Aug 13 11:59pm AoE.
Best regards, AC
I appreciate the authors for taking the time to provide clarification. In general, I feel that most of my concerns have been addressed in the rebuttal, and my review of the paper is now complete.
We are pleased to hear that your questions have been satisfactorily answered! Your questions are insightful and valuable. We will ensure to incorporate these additional specifics in our final revision, aiming to enhance the clarity of the presentation.
The paper presents SMART, a self-supervised representation learning approach that tries to tackle the problem of missingness in EHR data. It proposes a novel self-supervised pre-training approach which is able to reconstruct missing data representations in the latent space and makes use of both temporal and variable attention mechanisms to achieve that. The pre-trained encoder can be further fine-tuned with a label-specific decoder for different downstream classification tasks. Through multiple datasets, comparisons with baselines, and comprehensive ablation studies, the authors show the effectiveness of their method at generalization and robustness to missing data.
Strengths
- Missingness is an important issue in the medical domain, especially in EHR data, so having an ML model that is missing-aware and can still create meaningful representations is impactful
- Even though the two-stage training process may not be new, the creation of the MART blocks and the pre-training paradigm seems novel
- The proposed method is able to comprehensively beat the previous baselines across all datasets that were tested, both in performance and in training times
- Multiple ablations showcase the effectiveness of different components of the model architecture
- The paper is well-written and code is provided
Weaknesses
- During the pre-training stage, since it is based on reconstruction, access to the full dataset with all observations is required. The method would not work if the training data had missing values as well.
- During fine-tuning, only the label decoder is updated during the first few epochs of training but it is unclear as to why this is needed. It is mentioned that the pre-trained parameters are reserved, but a more quantitative explanation or an ablation would help.
Questions
- Are there any evaluations done just for the pre-training task? How good is the model at reconstruction and imputing missing values?
- Were there any constraints placed on the imputed values during training? It might happen that the model imputes missing variables with unrealistic values.
- Can this model be extended to incorporate multiple modalities like images or text (clinical notes), perhaps by learning separate missing-aware encoders for each of them? Something like this is done in [1].
- How does this model compare against a fully-supervised method where all of the data is available?
[1] Wu, Zhenbang, et al. "Multimodal patient representation learning with missing modalities and labels." The Twelfth International Conference on Learning Representations. 2024.
Limitations
N/A
Thank you for recognizing the strengths of SMART, especially its novelty and effectiveness. We address your concerns and answer your questions below.
W1: During the pre-training stage, since it is based on reconstruction, access to the full dataset with all observations is required. The method would not work if the training data had missing values as well.
Thank you for your concern about data integrity. It appears there may be a misunderstanding regarding the task at hand. We mention in lines 122-128 that the EHR data we use already contains substantial missingness. The goal of this work is to endow the model with the ability to perceive missing data. In the pre-training stage, we probabilistically sample additional missingness on top of the existing missing data to serve as the learning target, rather than relying on fully observed data. The missing-aware method we propose effectively enhances the model's ability to learn from missing data and improves predictive performance on clinical tasks.
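As a minimal sketch of this sampling step (the function name, tensor layout, and masking probability are hypothetical, not the paper's exact settings):

```python
import torch

def sample_pretrain_mask(obs_mask: torch.Tensor, p: float = 0.3):
    """obs_mask: (batch, time, variables) bool, True where a value was
    actually recorded; the EHR data is already sparse to begin with."""
    # Only entries that were observed in the first place can be hidden.
    hide = (torch.rand_like(obs_mask, dtype=torch.float) < p) & obs_mask
    input_mask = obs_mask & ~hide   # what the encoder treats as observed
    target_mask = hide              # positions reconstructed in latent space
    return input_mask, target_mask
```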
W2: During fine-tuning, only the label decoder is updated during the first few epochs of training but it is unclear as to why this is needed. It is mentioned that the pre-trained parameters are reserved, but a more quantitative explanation or an ablation would help.
Thank you for your interest in the fine-tuning settings. We want the model to retain the pre-trained parameters during the first several epochs instead of updating all parameters, so that the initial optimization goal is to adapt the classifier to the pre-trained embeddings. Assigning different learning rates or schedules to different parameter groups is a common strategy during fine-tuning. To address this further, we provide the results of fine-tuning without freezing parameters below.
| Model | Cardiology AUPRC(%) | F1(%) | Sepsis AUPRC(%) | F1(%) | In-hospital Mortality AUPRC(%) | F1(%) |
|---|---|---|---|---|---|---|
| w/o freeze | 51.46±2.38 | 46.89±3.19 | 79.81±3.15 | 74.29±3.11 | 50.81±1.47 | 42.85±2.63 |
| SMART | 53.84±2.24 | 47.53±2.33 | 81.67±0.84 | 75.37±2.62 | 53.30±0.12 | 44.23±2.03 |
The ablation results show that if the parameters are not frozen, performance degrades, likely because the large initial learning rate prevents the pre-trained parameters from being fully retained.
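A minimal sketch of this freezing schedule, assuming the model exposes an `encoder` submodule and a binary task head (attribute names, epoch counts, and the learning rate are illustrative):

```python
import torch

def finetune(model, loader, warmup_epochs=3, epochs=30, lr=1e-4):
    loss_fn = torch.nn.BCEWithLogitsLoss()           # y: float 0/1 labels
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        # First few epochs: freeze the pre-trained encoder so only the
        # label decoder adapts to the fixed embeddings; then unfreeze.
        for prm in model.encoder.parameters():
            prm.requires_grad = epoch >= warmup_epochs
        for x, mask, y in loader:
            optim.zero_grad()
            loss = loss_fn(model(x, mask), y)
            loss.backward()                          # frozen params get no grad
            optim.step()
```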
Q1: Are there any evaluations done just for the pre-training task? How good is the model at reconstruction and imputing missing values?
Thank you for your interest in pre-training evaluation. Our model is not specifically designed to solve the imputation problem, but to improve performance on clinical prediction tasks. Thus, it can only complete the reconstruction in the latent space.
For better understanding, we give a detailed explanation here and will add it to our future submission. The observed values in the input space can be viewed as a composed signal. The learned encoder can be viewed as a signal filter that decomposes the observed composed signal over a learned "dictionary". The "dictionary" is the affine transformation(s), shared by all the time series, that maps the input composed signal to learned decomposed embeddings. The learned embeddings can be regarded as the decomposition over this "dictionary". Our reconstruction in the latent space is essentially a reconstruction of the decomposed, filtered signals of different entries. On the one hand, this reduces the noise in the original input composed signal; on the other hand, it pursues the consensus of different time series, which converge to the underlying expectation. Thus, reconstruction in the latent space achieves better performance than reconstruction in the input space. As for evaluation, although the reconstruction loss can be calculated, it is not directly meaningful and cannot be compared against existing imputation methods.
Q2: Were there any constraints placed on the imputed values during training? It might happen that the model imputes missing variables with unrealistic values.
Thank you for your concern about the imputed values. Because our learning rate is relatively small and the pre-training process is data-driven, the model does not impute outside the data distribution, so no explicit constraints are required. Although the model may produce incorrect imputed representations, they still reflect some objective reality learned by the model. On the other hand, since our model reconstructs the missingness in the latent space, we cannot directly confirm whether it would produce unrealistic imputed values.
Q3: Can this model be extended to incorporate multiple modalities like images or text (clinical notes)?
Thank you for your curiosity about the potential of the model to handle multimodal data. The work you provided is a very good reference for handling multimodal data, and we will cite it in future versions. The method we propose has the potential to be extended to more modalities: it can support other modalities by masking parts of the image or text and restoring them during pre-training, so that the model better learns high-order representations and completes prediction tasks. However, since multimodal data are more complex, the multi-task learning introduced by pre-training may be a challenge, which we leave as a future exploration direction.
Q4: How does this model compare against a fully-supervised method where all of the data is available?
Thank you for your curiosity about full-data training. The fine-tuning phase of our method and all baseline methods is fully supervised and trained on the full data, although the data contain a large amount of missingness. The results show that SMART is significantly ahead of these methods, verifying the effectiveness of our work on the full data.
We hope these explanations adequately address your concerns.
I thank the authors for providing responses to the questions and for the additional experiments. The rebuttal has clarified most of my concerns; a few points to note:
- Regarding the W1 clarification, I think the Figure 1 caption should be updated which could be misunderstood. Currently, it states that 'We randomly mask EHR data...' but it should say something along the lines of 'The EHR data already contains missing values and we randomly mask existing observations...'.
- In the introduction, it is stated in lines 41 and 42 that "By masking certain observations to serve as targets for imputation, a portion of the data is withheld from the model and cannot be fully utilized for predicting patient". However, if I am understanding correctly, the pre-training stage does this as well, as stated on line 195: "generate a mask to remove the existing observations partially" which will have the same weaknesses. If this is true, the introduction should be reworded.
- Algorithm 1 should be updated to clearly define the inputs and the outputs.
- Please update the manuscript to specify why it is difficult to evaluate the reconstruction (since it is happening in the latent space).
Overall, I am satisfied with the rebuttal and will raise my score accordingly.
We are delighted that our responses have addressed most of your concerns, and we greatly appreciate your thoughtful suggestions. We will incorporate these additional revisions to further enhance the clarity and quality of our manuscript. Below, we address each of your remaining points.
The Figure 1 caption should be updated.
Thank you very much for your suggestion! We will take steps to describe our problem and data more clearly, including revising the caption to more accurately describe our data and methodology.
The introduction should be reworded, since the pre-training stage masks certain observations to serve as targets for imputation as well.
Thank you for your concern. You are correct in noting that the pre-training stage involves masking observations for imputation. However, we would like to clarify that during the fine-tuning stage, SMART leverages the complete dataset for clinical task training, thereby mitigating the limitations associated with training on incomplete data. This two-stage approach, also used by Primenet [1] (as referenced in lines 110-112), helps to address the concerns related to training on incomplete data.
[1] Primenet: Pre-training for irregular multivariate time series. AAAI 2023.
Algorithm 1 should be updated to clearly define the inputs and the outputs.
Thank you for this suggestion. We will revise Algorithm 1 to clearly define the inputs and outputs, thereby improving the clarity of the problem description.
Please update the manuscript to specify why it is difficult to evaluate the reconstruction (since it is happening in the latent space).
We appreciate your suggestion and will update the manuscript to include an explanation. Evaluating the reconstruction in latent space is challenging due to the model-dependent nature of the reconstruction and the absence of ground truth, making direct evaluation difficult.
Thank you once again for your insightful feedback and for your willingness to raise the score. We are committed to making these revisions to ensure our work is presented as clearly as possible.
The authors propose a novel approach to handling missing data in an attention-based module in a method that is geared for predicting downstream health-related outcomes given multivariate time series patient data in EHR settings. Specifically, the proposed module, termed as the MART block, biases the attention across the temporal dimension using a heuristic which favors time steps with observations. In quantitative benchmarks involving six different disease-like outcomes, the method obtains the highest accuracy among several other methods that are also designed to learn patient representations given time series data in an EHR setting.
Strengths
The paper is clearly written. The functionality of the proposed module was described in a straightforward way. I believe that the MIMIC pipeline adhered to is quite standard, and if the quantitative results are reliable then the amount of improvement attained by the approach deserves recognition.
Weaknesses
I felt there was not enough motivation in the introduction in terms of why handling imputation in the latent space is expected to be a better approach than prior approaches, other than that it is possibly a novel approach. Is the intuition supposed to be similar to how latent diffusion works well? Some background or citations would further convince the reader.
Does the work have any considerations for MAR, MCAR, MNAR, and structured missingness cases (all of which are present in EHR data); are MART blocks supposed to work well for all such types of missing data patterns? Due to some of these questions, I was not convinced upon a first read that biasing pairwise missing value statuses in a monotonically increasing scale (Eq 1) should intuitively be beneficial; although the quantitative results speak for themselves.
Were there any results on measuring how well the imputation methods impute missing observations in the actual time series data? For example, this could be done through a simulation. It seemed that all benchmarks were based on downstream tasks. I assume the proposed method also does well in this case, or there was an assessment that it does not matter from the perspective of improving downstream performance. It would be nice to understand if the contribution of the method is mainly in improving prediction of downstream tasks, or if it also imputes the data well.
The quantitative result obtained by SMART was encouraging, but I did not feel that the benchmark fully encompassed the breadth of methods which are available. Firstly, it seemed lacking in terms of basic baselines using logistic regression and gradient boosting methods to provide a good intuition on how easy or difficult the downstream tasks are; I believe these are important for studies that employ their own benchmarking pipeline on MIMIC.
I also personally wished to see more baselines which are commonly explored in the imputation literature (or at least mention them at all). To list some, there is softImpute, MissForest, MICE, and a Joint low-rank model from Sportisse et al. (2020). Among more recent works, there is HI-VAE, GP-VAE, and notMIWAE. It is true that many of these works do not handle time series data or may not have been designed with EHR in mind, but there are straightforward ways to process the data such that they can be applied.
Furthermore, there are a few recent works that also use an attention model and claimed sota at the time of their release in terms of time series imputation. Some words on how they are related or why or why not they are a good fit for these tasks would add to the comprehensiveness of the work.
[1] Zhang et al. 2023, "Improving Medical Predictions by Irregular Multimodal Electronic Health Records Modeling". [2] Du et al. 2023, "SAITS: Self-Attention-based Imputation for Time Series". [3] Cini et al. 2022, "Filling the G_AP_S: Multivariate Time Series Imputation by Graph Neural Networks". [4] Marisca et al. 2022, "Learning to Reconstruct Missing Data from Spatiotemporal Graphs with Sparse Observations".
Since the representation learning part of the approach is emphasized in a few parts of the work, I was surprised not to see any visualizations of the embeddings learned (per patient or per variable), and how if any it differs significantly from those obtained from prior works.
A minor point, but for first time readers it would be nice to have citations in lines 257 to 259 for SAFARI, PPN, and GRASP despite they might have been mentioned & cited elsewhere.
Questions
Is the CLS token involved in the pre-training stage at all, similarly to BERT? If so, what does it predict in the pre-training stage?
I just want to confirm: would it be correct for a reader to conclude that even when the model is ablated in any way (no mask, no temporal or variable attention, no CLS), it would still be a top-1 or top-2 ranking method in the benchmarks? Was there any ablation which degraded the model beyond this ballpark?
Limitations
The work discusses some limitations in the final section. I did not assess that this work could have a potentially negative societal impact.
W1: Why handling imputation in the latent space is expected to be a better approach than prior approaches?
Thank you for your insightful question about why our approach works. We show in lines 59-60 that reconstruction in the latent space helps the model better learn higher-order data patterns instead of focusing on trivial details. For more explanation, please refer to our global response.
W2: Are MART blocks supposed to work well for all types of missing patterns? Why biasing in a monotonically increasing scale in Eq 1 is beneficial?
Missingness in EHR is usually considered to be MNAR [5]. Whether to check an indicator is determined by physicians, and the reason is not reflected in the data. Whether a particular indicator is assessed is itself informative, as its absence could suggest that the associated condition is not severe. The performance gains from the mask also verify that MART is well suited to MNAR data. However, since it is impossible to model the cause of missingness, we randomly sample missing data in pre-training.
For the question about Eq. 1, we intended to endow the model with a perception of the degree of missingness (both observed, only one, or neither). Whether the values are manually set or learnable does not matter much, since the weight matrices in the attention adapt to them. We also conducted a simple experiment replacing 1 and 2 in Eq. 1 with two learnable parameters (w/ learnable bias), as shown in the rebuttal PDF.
The experimental results show that the manually set values perform better, so we kept the manually set, monotonically increasing values.
[5] A Bayesian latent class approach for EHR-based phenotyping. Statistics in Medicine 2019.
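For readers interested in the mechanism, here is a small sketch of how such a monotonically increasing missingness bias can enter temporal attention. The 0/1/2 mapping for neither/one/both observed is our assumed reading of Eq. 1, not a verbatim excerpt from the paper:

```python
import torch

def missingness_biased_attention(q, k, obs, scale):
    """q, k: (batch, T, d) queries/keys for one variable;
    obs: (batch, T) bool, True where that variable was observed."""
    scores = (q @ k.transpose(-2, -1)) * scale       # (batch, T, T)
    both = obs.unsqueeze(-1) & obs.unsqueeze(-2)     # both steps observed
    one = obs.unsqueeze(-1) ^ obs.unsqueeze(-2)      # exactly one observed
    bias = 2.0 * both + 1.0 * one                    # monotone in #observed
    return torch.softmax(scores + bias, dim=-1)      # favors observed pairs
```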
W3: Were there any results on measuring how the imputation methods impute missingness?
Since the goal of pre-training is to improve performance on clinical tasks, SMART is not designed to provide an imputed series. For further exploration, we added a decoder at the same level as the embedding decoder in the pre-training, which is used to impute in the input space, and evaluated the performance after fine-tuning (w/ both imputation), as shown in the rebuttal PDF.
We found that with this decoder, the performance declined and was even worse than w/o Pre-training, which indicates that there may be a trade-off between imputation in the latent and input space, and the pursuit of accurate imputation may lead to poor performance in downstream tasks. (We did not add special designs on this ablation, so the conclusion may be inaccurate.)
W4: It seemed lacking in terms of basic baselines.
We show the results of the basic methods GRU and Transformer in Appendix A.6; they are worse than most of the baselines, showing the difficulty of the tasks. Since LR and tree models (XGBoost, etc.) are not recurrent, they do not directly apply to our datasets composed of variable-length series.
W5: I wished to see more baselines commonly explored in the imputation literature.
Thank you for providing this literature; we are delighted to read it and will discuss it to enrich our paper. However, most of the methods, including [3,4], are not designed for completing clinical tasks. There are also some imputation models combined with downstream tasks, such as [1,2]: [1] does not essentially use any supervision related to imputation, similar to Warpformer [6], and [2] does not apply to time series of varying lengths, so it cannot be evaluated on our datasets. In addition, although [2] is not designed to complete downstream tasks, its authors compare results on these tasks by using a GRU to model the imputed series. Nevertheless, training [2] for imputation removes some of the sampled data to serve as imputation supervision, which potentially reduces performance on downstream tasks, as we mention in lines 107-109.
[6] Warpformer: A multi-scale modeling approach for irregular clinical time series. KDD 2023.
W6: Were there any visualizations of the embeddings learned?
Thanks for your reminder; we are sorry to have overlooked this. We have uploaded a t-SNE comparison of the embeddings of test-set patients from the Cardiology dataset in the rebuttal PDF. The embeddings learned by SMART are more discriminative, which qualitatively verifies its effectiveness.
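A comparison of this kind can be produced with a short script like the following (perplexity and styling are arbitrary choices; `emb` would be the per-patient embeddings extracted from each model):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_embeddings(emb: np.ndarray, labels: np.ndarray, title: str):
    """Project (n_patients, d) embeddings to 2-D and color by outcome."""
    z = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(emb)
    plt.scatter(z[:, 0], z[:, 1], c=labels, s=8, cmap="coolwarm")
    plt.title(title)
    plt.show()
```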
W7: It would be nice to have more citations of baselines.
Thanks for your suggestion. We will add these citations in future versions, which will improve the quality and readability.
Q1: Is the CLS vector in the pre-training similar to BERT? What does it predict?
The CLS vector is somewhat similar to [CLS] in BERT, but they are not the same; the mention of BERT is just for ease of understanding. In the next-sentence-prediction task of BERT, [CLS] is used for prediction, which explicitly encourages it to express information about the entire sequence, while there is no such supervision in the pre-training of SMART. In pre-training, the encoding at the CLS vector position is not included in the loss (because it is never masked out as a reconstruction target): our intention is to reconstruct the hidden representations of the removed sample points, and including the CLS position in the loss may bring slight performance changes.
But similar to [CLS], the CLS vector does not represent any time step but learns the overall representation, which is a bridge between the pre-training and fine-tuning, as mentioned in line 308. It provides a location to store the overall information (because its mask is always True), and using it as a query in Variable Attention encourages this.
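A minimal sketch of this CLS handling (shapes and names are illustrative, not the paper's exact code):

```python
import torch

def prepend_cls(h, mask, cls_vec):
    """h: (batch, T, d) hidden states; mask: (batch, T) bool observation
    mask; cls_vec: (d,) learnable parameter."""
    b = h.size(0)
    h = torch.cat([cls_vec.expand(b, 1, -1), h], dim=1)   # (batch, T+1, d)
    # The CLS position is always "observed", so attention never treats it
    # as missing and it can accumulate sequence-level information.
    ones = torch.ones(b, 1, dtype=torch.bool, device=mask.device)
    mask = torch.cat([ones, mask], dim=1)                 # (batch, T+1)
    # During pre-training, the reconstruction loss is computed only over
    # the removed time steps, so index 0 (CLS) is never a target.
    return h, mask
```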
Q2: Would SMART be a top-ranking method without novel designs?
Thank you for your question about the performance. However, this idea may be incorrect. Without all the unique designs, the model will degenerate into a Transformer. As shown in Appendix A.6, it is significantly weaker than most baselines, and the gains brought by pre-training obviously cannot make up for such a large gap.
Thank you for the thorough response, overall I believe most of my concerns have been addressed.
The responses to W1,W3, W5, W7, Q1, Q2 seem reasonable.
W2: We also conducted a simple experiment by replacing 1 and 2 in Eq 1 with two learnable parameters (w/ learnable bias), as shown in the rebuttal PDF.
The addition of this experiment is appreciated. I think there is enough evidence to support that the original idea is reasonable and should be shared with others who might be interested in using the attention mechanism in similar settings with missing data. For future work it would be interesting to investigate the quite large variance with the learnable bias (±3, so on the high end it can surpass the fixed bias?) and similarly check if the bias learned from scratch aligns at all with the intuition of the proposed bias strategy in this case. Also, would it be advantageous at all to have a pairwise bias that is separate depending on which side of the attention you are processing (i.e. upper vs lower triangle)? Just a thought, I believe the work so far is sufficient.
Small note that maybe the F1 score of 75.50±2.75 should be bolded instead of 75.37±2.62 if the intent was to bold the highest obtained score if the table will make it into the final paper.
W4: It seemed lacking in terms of basic baselines.
Even though some of the methods are not recursive temporally, the prediction of the downstream task is not time specific (it is an outcome prediction) so I personally find the lack of basic processing -> baseline method for outcome prediction not satisfactory. For instance, a feature vector with counts of observed conditions & prescriptions (or means and standard deviations for continuous measures) could be generated summarizing the entire time series of each patient (eg [1]). I would not further ask for this analysis though as no other reviewer has raised a similar concern and it seems very few prior works in the area do this.
For the reasons above I would be willing to raise the score.
[1] EHRSHOT: An EHR Benchmark for Few-Shot Evaluation of Foundation Models
We are pleased to hear that your questions have been satisfactorily answered! Your suggestions, including adding more literature on imputation models, are indeed valuable. We will ensure to incorporate these additional specifics in our final revision, aiming to enhance the clarity of the presentation.
For more concerns about W2 and W4, we answer below.
W2: We also conducted a simple experiment by replacing 1 and 2 in Eq 1 with two learnable parameters (w/ learnable bias), as shown in the rebuttal PDF.
We're glad the additional experiment with learnable parameters provided sufficient evidence. We agree that exploring the variance observed with the learnable bias and introducing a separate pairwise bias depending on the attention side (upper vs. lower triangle) are promising avenues for future research. The attention side approach, in particular, may help address low-rank issues in vanilla attention.
Thank you for pointing out the F1 score in the table. We will ensure that the correct value, 75.50±2.75, is bolded in the final version to reflect the highest obtained score.
W4: It seemed lacking in terms of basic baselines.
Thank you for your feedback. We recognize the importance of including basic baselines and have conducted two ablations to illustrate the performance of Logistic Regression (LR) and XGBoost:
- Classify the last observation directly: This approach may introduce significant missingness.
- Classify the last observation with a front-fill strategy: Here, missing values are imputed using the latest available values from past observations.
| Model | Cardiology AUPRC(%) | F1(%) | Sepsis AUPRC(%) | F1(%) | In-hospital Mortality AUPRC(%) | F1(%) |
|---|---|---|---|---|---|---|
| LR (Direct) | 30.67±0.33 | 14.13±0.88 | 14.05±0.67 | 0.47±0.33 | 30.29±0.85 | 10.85±2.25 |
| XGBoost (Direct) | 32.29±2.48 | 21.69±1.70 | 21.26±1.90 | 13.78±0.14 | 30.76±1.36 | 21.80±2.53 |
| LR (Front-fill) | 47.31±3.47 | 35.98±0.32 | 18.92±1.84 | 3.79±0.41 | 46.45±2.51 | 32.09±1.85 |
| XGBoost (Front-fill) | 45.41±3.31 | 35.38±0.36 | 27.31±0.39 | 17.39±0.33 | 47.55±1.76 | 38.00±2.06 |
| SMART | 53.84±2.24 | 47.53±2.33 | 81.67±0.84 | 75.37±2.62 | 53.30±0.12 | 44.23±2.03 |
We were surprised to find that the front-fill strategy outperformed some of the baseline methods on the Cardiology dataset. This may be due to the specific characteristics of the Cardiology dataset, which could be more amenable to basic baseline approaches. Please note that the front-fill strategy was not applied in the experiments in our manuscript.
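For concreteness, the front-fill variant above can be reproduced with a few lines like the following; the NaN convention, zero-fill for never-observed variables, and solver settings are illustrative, not our exact pipeline:

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

def last_observation(record: pd.DataFrame, front_fill: bool = True) -> np.ndarray:
    """Reduce one variable-length record (rows = time steps, columns =
    variables, NaN = missing) to a fixed-size feature vector."""
    if front_fill:
        record = record.ffill()      # carry forward the latest observed value
    return record.iloc[-1].fillna(0.0).to_numpy()

# Hypothetical usage: `records` is a list of DataFrames, `labels` an array.
# X = np.stack([last_observation(r) for r in records])
# clf = LogisticRegression(max_iter=1000).fit(X, labels)
```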
We appreciate your willingness to raise the score and thank you again for your thoughtful suggestions.
Thank you for providing the additional baselines; I believe they provide a great deal of intuition on the difficulty of the tasks, especially for those not familiar with the domain or the methods. The continued work by the authors is appreciated.
The paper presents a strategy to account for missing data in EHR called SMART. This is broken into two stages: pre-training and fine-tuning. Pre-training learns a hidden state representation by randomly masking the input and reconstructing the masked representations, while fine-tuning uses this hidden state representation and is tuned for specific downstream tasks. The authors demonstrate the efficacy on multiple datasets and in high-impact areas: cardiology, sepsis, and in-hospital mortality. The methodology is also quite lightweight, as demonstrated in Fig. 3, and the authors also capture some of the limitations of the modified attention mechanism.
Strengths
- Enough experiments to convince of initial efficacy
- The new attention mechanism is quite lightweight
- Latent representations are learned effectively to handle missing data
Weaknesses
- Could show more areas of impact by looking at conditions where missing data causes more problems, to probe the limits of SMART
- The quadratic attention mechanism is also a problem (although approximations have been shown to reduce that cost)
- Can the authors project the hidden state representation back to the data level to show what the model is learning by using the masks? This would give an idea of what the model deems important to learn when the data is missing
Questions
Can the authors project the hidden state representation back to the data level to show what the model is learning by using the masks? This would give an idea of what the model deems important to learn when the data is missing.
Limitations
Yes they have addressed it.
Thank you for recognizing the strengths of SMART, especially its effectiveness and novel designs. We address your concerns and answer your questions below.
W1: Can you discuss the limitations of SMART by showing more areas of impact by looking at conditions where missing data could cause more problems?
Thank you for your question about the limitations of SMART. We discuss the limitations in clinical scenarios in the Conclusions and Limitations section and in the Broader Impact section of the Appendix. Missing vital signs or lab results can lead to incorrect predictions or delayed diagnoses. SMART is aware of missingness in the time series and enhances its prediction performance, avoiding potential delayed diagnoses due to a lack of observations. However, we recognize the necessity for more precise methods in future research.
For non-clinical fields such as finance, electricity, and meteorology, where missingness is common in time series, SMART has the potential to play a role. However, for ECG/EEG or other high-frequency data, the quadratic computational complexity may limit the usability of SMART.
W2: The quadratic attention mechanism is also a problem (although approximations have been shown to reduce that cost).
Thank you for your concern about the computational complexity. Our contribution lies in proposing a missing-aware EHR model to better accomplish clinical prediction tasks. Although the quadratic computational complexity may limit the use of SMART on very long time series, it is applicable to EHR data, given the limited length of time series. In addition, there are some methods compatible with SMART, such as model compression or linear attention [1,2], which can reduce overhead and improve scalability, but this is beyond the scope of this paper and can be further explored in future work.
[1] Linformer: Self-attention with linear complexity. arXiv 2020.
[2] RWKV: Reinventing RNNs for the Transformer Era. EMNLP 2023.
W3&Q1: Can you project the hidden state representation back to the data level and show what the model is learning by using the masks?
Thank you for your question about how the model imputes the missingness. We focus on reconstructing missing data in the latent space, so the imputed data cannot be mapped back to the input space. In particular, because our goal is to improve the accuracy of clinical tasks when designing the method, we encourage the model to learn as much information as possible about the entire sequence in the pre-training stage and learn task-related embeddings in fine-tuning, rather than simply imputing.
For better understanding, we give a detailed explanation here and will add it to our future submission. The observed values in the input space can be viewed as a composed signal. The learned encoder can be viewed as a signal filter that decomposes the observed composed signal over a learned "dictionary". The "dictionary" is the affine transformation(s), shared by all the time series, that maps the input composed signal to learned decomposed embeddings. The learned embeddings can be regarded as the decomposition over this "dictionary". Our reconstruction in the latent space is essentially a reconstruction of the decomposed, filtered signals of different entries. On the one hand, this reduces the noise in the original input composed signal; on the other hand, it pursues the consensus of different time series, which converge to the underlying expectation. Thus, reconstruction in the latent space achieves better performance than reconstruction in the input space.
We hope these explanations adequately address your concerns.
Thanks for the response. Some of the responses point to the efficacy of the method and I’d like to keep the score the same.
We are pleased to hear that your questions have been satisfactorily answered! Your questions are insightful and valuable. We will ensure to incorporate these additional specifics in our final revision, aiming to enhance the clarity of the presentation.
Dear Reviewer, I am a NeurIPS 2024 Area Chair of the paper that you reviewed.
This is a reminder that the authors have left a rebuttal to your review. We need your follow-up response. Please leave a comment on any unanswered questions you had, or on your assessment of the authors' rebuttal.
The author-reviewer discussion is closed on Aug 13 11:59pm AoE.
Best regards, AC
We thank all reviewers for their high-quality comments and for recognizing the strengths of SMART. We have addressed all the concerns and answered questions in the rebuttal.
Here, for the commonly asked question of why reconstruction in the latent space is more effective than imputation in the input space, we give a detailed explanation, which we will add to our future submission. The observed values in the input space can be viewed as a composed signal. The learned encoder can be viewed as a signal filter that decomposes the observed composed signal over a learned "dictionary". The "dictionary" is the affine transformation(s), shared by all the time series, that maps the input composed signal to learned decomposed embeddings. The learned embeddings can be regarded as the decomposition over this "dictionary". Our reconstruction in the latent space is essentially a reconstruction of the decomposed, filtered signals of different entries. On the one hand, this reduces the noise in the original input composed signal; on the other hand, it pursues the consensus of different time series, which converge to the underlying expectation. Thus, reconstruction in the latent space achieves better performance than reconstruction in the input space.
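To make this concrete, below is a hedged sketch of a latent-space reconstruction objective consistent with the description above. How the latent targets are obtained (here, a stop-gradient encoding of the series before the extra masking) is our assumption for illustration, not necessarily the paper's exact procedure:

```python
import torch

def latent_reconstruction_loss(encoder, emb_decoder, x, input_mask, target_mask):
    """input_mask: entries the encoder sees as observed; target_mask:
    observed entries that were artificially removed for pre-training."""
    with torch.no_grad():                        # latent targets, no gradient
        target = encoder(x, input_mask | target_mask)
    pred = emb_decoder(encoder(x, input_mask))   # encode the masked series
    se = (pred - target).pow(2).sum(-1)          # per-entry squared error
    # Average only over the artificially removed positions.
    return (se * target_mask).sum() / target_mask.sum().clamp(min=1)
```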
In the rebuttal PDF, we uploaded more ablation experiments and the t-SNE embedding visualization comparison of SMART and baseline methods on the Cardiology dataset mentioned by Reviewer MRHA. We can observe that the embeddings learned by SMART are more discriminative, which qualitatively verifies the effectiveness of our method.
The paper presents SMART, a novel approach for handling missing data in electronic health records (EHR) using an attention-based module. The method aims to improve predictions for downstream health-related outcomes given multivariate time series patient data.
The reviewers generally agreed on the paper's strengths, noting its clear presentation, novel approach to handling missing data in the latent space, and impressive performance on multiple datasets and high-impact clinical areas. The lightweight nature of the modified attention mechanism was also appreciated.
However, some concerns were raised. Reviewer MRHA questioned the motivation for handling imputation in the latent space and requested more background on why this approach is expected to be superior. The authors addressed this in their rebuttal, explaining that reconstruction in the latent space helps learn higher-order data patterns and reduces noise in the original input signal.
Another concern was about the method's effectiveness for different types of missing data patterns (MAR, MCAR, MNAR). The authors clarified that missingness in EHR is usually considered MNAR and provided evidence that SMART is well-suited for this type of data.
Reviewer MRHA also requested more comprehensive baselines, including basic methods like logistic regression and gradient boosting, as well as more recent imputation techniques. The authors explained that some basic methods were included in the appendix and that certain suggested baselines were not directly applicable to their variable-length time series data.
Both reviewers raised questions about the imputation quality in the input space. The authors conducted additional experiments showing that pursuing accurate imputation in the input space may lead to poor performance in downstream tasks, supporting their focus on latent space reconstruction.
Reviewer PYgK suggested exploring more areas of impact and discussing limitations in scenarios where missing data could cause more problems. The authors acknowledged this and provided some discussion on potential applications and limitations in non-clinical fields.
After considering the reviews, rebuttal, and subsequent discussions, I believe the paper presents a valuable contribution to the field of EHR data analysis and clinical prediction tasks. The authors have adequately addressed most of the reviewers' concerns and demonstrated the effectiveness of their approach.
As an Area Chair, my recommendation is to accept this paper for publication at NeurIPS. The novel approach, strong empirical results, and potential impact in clinical applications make it a worthy addition to the conference. However, I encourage the authors to incorporate the additional explanations and discussions from their rebuttal into the final version of the paper to enhance its clarity and comprehensiveness.