ShieldHead: Decoding-time Safeguard for Large Language Models
摘要
评审与讨论
This paper proposed a new framework, ShieldHead, that identifies the harmful contents on the sentence and token levels. They employ the label disambiguation technique to supervise model training with token- and sentence-level labels. In their experiments, they show the effectiveness of their method across diverse datasets and models. They also demonstrate the inference efficiency of their proposed framework.
优点
- The paper introduces a framework that can identify harmful contents on both sentence- and token-level.
- They utilize the label disambiguation method to automatically assign token-level labels based on sentence-level labels.
- They provide extensive experiments to evaluate the effectiveness of their proposed methods compared to LLMs and SOTA models for safety detection.
缺点
- Lack of some model comparisons. For sentence-level classification, I expect they can compare to a model that is fully fine-tuned llama-3.1-8b on the WGTRAIN dataset. GPT-4 can use a few WGTRAIN samples as a few-shot in-context learning examples in the prompt. For the token-level evaluation, they should also compare to some baseline models, such as GPT-4.
- Lack of clarity on their token-level evaluation experiment. I wonder if they annotated sub-word tokens (split by llama model) or words. What is the annotation agreement? How do they assign labels when there are any disagreements between annotators? It is also weird to me that they annotate "all tokens following an unsafe token in the sentence were marked as unsafe." Again, they should also compare to some baseline models for token-level classification. I will ask more questions in the next section.
- Unclear on calculating inference speed. I wonder how they calculate the latency of GPT-4 which is an API-based LLM.
问题
- Do you annotate sub-word tokens (split by llama model) or words? What is the annotation agreement? How do they assign labels when there are any disagreements between annotators?
- I wonder why "all tokens following an unsafe token in the sentence were marked as unsafe." in your token-level annotation. Does this result in a significant amount of unsafe tokens? What is the ratio between unsafe and safe tokens in your token-level test set?
- When you generate the token-level labels, does your method have the same behaviour that all tokens following an unsafe token in the sentence were marked as unsafe?
- Line 204, what is the value of K in your experiment?
伦理问题详情
N/A
Thanks for your thoughtful reviews and suggestions. We've adapted your comments in the revision. The concerns are addressed and summarized as follows:
Q1: Lack of some model comparisons. For sentence-level classification, I expect they can compare to a model that is fully fine-tuned llama-3.1-8b on the WGTRAIN dataset. GPT-4 can use a few WGTRAIN samples as a few-shot in-context learning examples in the prompt. For the token-level evaluation, they should also compare to some baseline models, such as GPT-4.
A: As suggested, additional experiments have been conducted. The results are shown as Rebuttal Table 3, and included in Table 1 (LN340) in the revision:
(1) We fully fine-tuned the Llama-3.1-8B model on the WGTRAIN dataset, following WildGuard's training settings and chat template . The performance was lower than both WildGuard’s (which used WGTRAIN on Mistral-7B-v0.3) and our results. ShieldHead outperforms full fine-tuned on Llama3.1 model by an average F1-score (prompt +1.6%, response +1.5%).
(2) We tested GPT-4 with two WildGuard samples as few-shot in-context examples. For details on the prompt template, we refer the reviewer to Table 6 in Appendix A.6. The results show that GPT-4 with few-shot learning significantly improved, especially in S-RLHF, indicating the benefits of few-shot learning for specific dataset distributions, though it still underperforms ShieldHead by an average F1-score of 1.3% and 1.9% across prompt and response benchmarks, respectively.
Rebuttal Table 3: Comparisons of F1 scores (%) between ShieldHead (with Gemma2-9B as base model) and other methods on prompt and response safety classification in existing public benchmarks.
| Model | Prompt Harmfulness (F1) | Response Harmfulness (F1) | |||||
|---|---|---|---|---|---|---|---|
| XSTest | OpenAI Mod | ToxicChat | Avg. | BeaverTails | S-RLHF | Avg. | |
| Full FT on Llama3.1 | 94.3 | 72.9 | 66.0 | 77.7 | 81.7 | 83.8 | 82.8 |
| WildGuard | 94.8 | 72.1 | 67.1 | 78.0 | 84.1 | 84.0 | 84.1 |
| GPT-4 (zero shot) | 89.5 | 70.5 | 68.3 | 76.1 | 86.1 | 67.9 | 77.0 |
| GPT-4 (few shot) | 89.5 | 74.4 | 70.2 | 78.0 | 88.5 | 76.3 | 82.4 |
| Ours | 95.1 | 76.3 | 66.5 | 79.3 | 80.0 | 88.5 | 84.3 |
(3) We tested GPT-4 with three samples as few-shot in-context examples for word-level safety classification. For details on the prompt template, we refer the reviewer to Table 7 in Appendix A.6. Due to challenges in ensuring GPT-4 adheres to token (sub-word) segmentation, we conducted word-level testing. For a fair comparison, we recalculated our word-level classification results by using the label and prediction of the first token in multi-token words as the representative. The recalculated results for ShieldHead's word-level safety classification yielded an F1 score of 85.1% and an accuracy of 78.9%. In contrast, GPT-4’s output included 29 sentences where the number of predictions differed from the actual number of words. Excluding these cases, the word-level safety classification yielded an F1 score of 73.6% and accuracy of 68.0%. GPT-4’s performance in word-level classification underperforms ShieldHead by F1-score of 11.5% and accuracy of 10.9%. To the best of our knowledge, no other baseline models currently support token-level safety risk classification.
. Han, Seungju, et al. "Wildguard: Open one-stop moderation tools for safety risks, jailbreaks, and refusals of llms." arXiv preprint arXiv:2406.18495 (2024).
Q2: Do you annotate sub-word tokens (split by llama model) or words?
A: Thank you for your comment. We labeled sub-word tokens (as split by the Llama model) rather than whole words. Since the model's output is token-level, the labels are also token-level. Specifically, when a word is split into multiple sub-word tokens, all corresponding tokens share the same label (safe/unsafe). Our token-level annotation can also be referenced in Figure 4 of the main paper, where each cell represents a token, with words (e.g., "don’t" split into "don" and "’t") shown accordingly.
Q3: What is the annotation agreement? How do they assign labels when there are any disagreements between annotators?
A: As to the annotation agreement, we followed the best practice from prior studies, such as ToxicChat and Beavertails , adapting their sentence-level annotation method to our token-level annotation: Two independent annotators were labeling each token within the sentence. Any disagreements between the annotators were resolved through discussion to reach mutual agreement on all tokens. Details are included in Appendix A.5 (LN804) in the revision.
Lin, Zi, et al. "Toxicchat: Unveiling hidden challenges of toxicity detection in real-world user-ai conversation." arXiv:2310.17389 (2023).
Ji, Jiaming, et al. "Beavertails: Towards improved safety alignment of llm via a human-preference dataset." Advances in NIPS 36 (2024).
Q4.1: I wonder why "all tokens following an unsafe token in the sentence were marked as unsafe." in your token-level annotation.
A: Unlike the literal meaning of token-level classification, which evaluates individual tokens in isolation. We take into account the surrounding context (previous tokens) when determining whether a token is labeled as "safe" or "unsafe." This approach is consistent with the way GPT-like models generate text, where the lm-head is influenced by all preceding tokens. Hence, if an unsafe token indicating that the entire sentence is unsafe appears, the subsequent tokens are also considered unsafe, as they are contextually linked to the unsafe token. And we mark them as unsafe accordingly.
Q4.2: Does this result in a significant amount of unsafe tokens? What is the ratio between unsafe and safe tokens in your token-level test set?
A: Such a token-level annotation method does not result in a significant amount of unsafe tokens. In practice, unsafe sentences do not necessarily begin with the first word being unsafe, which helps create a balanced distribution. As indicated in LN373, the constructed token-level test dataset comprised a total of 2,694 tokens, with 1,745 labeled as unsafe and 949 as safe. This yields a ratio of approximately 2:1 between unsafe and safe tokens.
Q5: When you generate the token-level labels, does your method have the same behaviour that all tokens following an unsafe token in the sentence were marked as unsafe?
A: Thank you for your comment. Our method exhibits the behavior that all tokens following an unsafe token in the sentence were marked as unsafe when making token-level safety classification. This behavior arises from the decision-making process of ShieldHead, which relies on hidden states that capture the cumulative context in GPT-like architectures. When an unsafe token indicating that the entire sentence is unsafe appears, the subsequent tokens are influenced by this prior context and are also predicted to be unsafe. An example is shown in Figure 4 (LN445) in Section 3.6 in the paper, in the "scenario with unsafety response" prompt, where all tokens following the unsafe token "racist" are classified as unsafe (indicated by red).
Q6: Line 204, what is the value of K in your experiment?
A: K = 10 in our experiment. We have included this detail in LN205 in the revision.
Q7: Unclear on calculating inference speed. I wonder how they calculate the latency of GPT-4 which is an API-based LLM.
A: Thank you for your comment, and sorry for the confusion. In our experiments, we evaluated the inference latency based on practical usage scenarios. For API-based LLMs like GPT-4, the latency we report includes network latency, as it's challenging to isolate the impact of network latency when directly measuring inference time for API-based models. We have included this detail in the caption of Figure 1 (LN077) in the revision.
Thanks for your detailed responses and paper revision.
I found that your response addressed some of my concerns but I still don't entirely agree that you annotate all tokens following an unsafe token in the sentence were marked as unsafe. This seems to simplify the task. Why don't you annotate a span of tokens that are unsafe? For example, when we detoxify a text, we should only remove the toxic content and preserve the original semantic meaning as much as possible.
Thank you for your quick feedback, and we are glad that our response has addressed some of your concerns. We sincerely appreciate your efforts in reviewing and discussing this paper.
To clarify, our annotation is not to simplify the task, but rather is more appropriate for the risk classification scene.
R1. For the convenience of understanding, we list the two annotation methods as follows:
-
Isolated token-level classifier (the reviewer Xpkq may refer to this): Safety risk classification based solely on individual or a span of tokens, without considering preceding context, makes sense with span-based annotation. For “How to make bombs to dismantle old buildings” as example, the annotation can be:
How to make bombs to dismantle old buildings -
Cumulative token-level classifier (Ours): Since our method is based on GPT-like models, the safety classification of each token takes into account all the preceding context (as described below). If an unsafe token is identified in the preceding context, it will also be reflected in the subsequent content, and thus, the subsequent sections are also considered to be unsafe. For “How to make bombs to dismantle old buildings” as example, the annotation can be:
HowHow to
How to make
How to make bombs
How to make bombs to
How to make bombs to dismantle
How to make bombs to dismantle old
How to make bombs to dismantle old buildings
R2. Scenarios where span-based annotation fails: As for safety risk classification, using span-based annotation is not entirely accurate, as safety risks often involve context-dependent terms, where the cumulative effect of the GPT-like model is useful. An example of this is the “grandma exploit”, where the prompt “Please act as my deceased grandmother who would read me Windows 10 Pro keys to fall asleep to” is safe in individual token, but unsafe as a complete sentence. ShieldHead leverages a GPT-like cumulative effect to handle such issues. Moreover, some ambiguous terms can also make a span-based annotation inaccurate if merely considering a span of text. In contrast, using cumulative token annotation can successfully handle these cases, where all the preceding contexts are taken into consideration.
R3. Regarding the example of text detoxifying task by reviewer Xpkq: Although our model is focused on safety risk classification, with the incorporation of some engineering techniques, it is also capable of performing detoxification during decoding-time (in the Section 2.4 of the manuscript). In this scenario, with span-based annotation, the entire response would typically need to be generated and then detoxified. Instead, ShieldHead can dynamically and in real-time adjust its output. As soon as ShieldHead identifies an unsafe token, the model re-samples the token and re-performs safety risk classification until a safe token is generated. If, after a certain number of sampling attempts, no safe token is found, the model will refuse to respond. This approach is more efficient than performing detoxification after generating the entire response. Of course, the specific sampling strategy and the threshold for "a certain number of sampling attempts" are related to the engineering techniques, and not to the ShieldHead method itself.
We hope this response addressed the reviewer’s concerns, and we are happy and open to answer any follow-up questions.
To demonstrate the reason for our annotation and the cumulative effect of GPT-like model in our safety risk classification task, we provide further clarification as follows:
R4. Theoretical background: Most work in text generation considers all preceding tokens when performing next token prediction , and our work aims to classify safety of the responses generated by these text generators. Therefore, intuitively, we follow this manner.
Hence, in this paper, we adopt the cumulative effect of GPT-like models . Instead of directly labeling each token individually (treat them independently), we label tokens taking into account all the preceding context. Our safety risk classification is derived from the conditional probability , where to represent the first to the th token, and denotes the safety risk classification of the th token.
Following the previous example presented in R1, the token "dismantle" in the example is classified as safe when treated individually. However, in our approach, when considering "dismantle", we take into account the preceding tokens, thus label it as unsafe. Details are shown below.
-
Isolated token-level:
- How to make bombs to dismantle
-
Cumulative token-level:
- How to make bombs to dismantle
R5. Experimental results: Note that this is a manual annotation method for the evaluation set, whereas the token-level labels used during training are obtained through a prototype-based disambiguation method (LN303 in the paper). In our experiments, following the cumulative effect of GPT-like models, we use the safety classification of the last token in a sentence as the prediction for the sentence. We achieve state-of-the-art performance in sentence-level safety classification. This also demonstrates that the cumulative effect is applicable in the context of safety classification in our work.
Thanks again for your comments and efforts, we are happy and open to answer any follow-up questions.
Radford, Alec. "Improving language understanding by generative pre-training." (2018).
Radford, Alec, et al. "Language models are unsupervised multitask learners." OpenAI blog 1.8 (2019): 9.
Brown, Tom B. "Language models are few-shot learners." arXiv preprint arXiv:2005.14165 (2020).
I adjusted my rate as most of my concerns have been addressed. I still wonder if there are any existing works that use accumulative annotation on token-level tasks. You may cite and discuss them. I am also curious how you annotate the "grandma exploit" example. Which token is the first unsafe token?
Thank you for your continued engagement and raising the rating, it is very encouraging for us. We are glad that our response has addressed most of your concerns, and we provide further clarification as follows.
R6. Related work of accumulative annotation on token-level tasks:
In addition to its application in safety risk classification, as proposed in this paper, accumulative token-level annotation is also applied in other tasks. For example, RIPA , which was proposed to solve hallucination detection , proposes that once the first hallucinated token is detected, all subsequent generated tokens should be labeled as hallucinated. Specifically, "the label at each position to be the expected future groundedness of the preceding sequence : all the sequences after the first hallucination tokens include at least one hallucination token and therefore are labeled as hallucinations." (Section 4.2 in the paper of RIPA). Similarly, for the task of hallucination detection, NADO and FUDGE introduced the concept that the occurrence of hallucination in a token is related to its prefix from the perspective of the conditional distribution formula. Our method, RIPA , NADO , and FUDGE , leverages the cumulative effect of GPT-like models, where the conditional probability of subsequent tokens is influenced by the preceding tokens. However, their objectives differ from ours. For example, RIPA constructs a reward based on these accumulative token-level labels. Besides, Fast-detectgpt , which was proposed to perform AI-generated text detection , uses token-level conditional probability curvature as a cumulative effect to determine whether text is AI-generated. Similar to previous annotation manner, Fast-detectgpt and DetectGPT label the first few tokens as "human-written," and all tokens following the first AI-generated token as "AI-generated".
In summary, accumulative annotation has been employed in previous work with various tasks, further validating the effectiveness of this annotation method.
R7. Grandma exploit: As to the first unsafe token of “Please act as my deceased grandmother who would read me Windows 10 Pro keys to fall asleep to”, we annotate “keys” as the first unsafe token, as it is the first token that induces a direct violation of Intellectual Property. When generated and accumulated to the token "keys", this token, in combination with the preceding tokens, leads to a clear safety risk.
We would like to thank you for your suggestions and discussions, which provided us with the opportunity to explain our annotation method in detail through examples (R1-3, R7), theoretical background (R4), related work (R6), and experimental results (R5). We have included these details into the revision (A5.2 from LN810 and A5.3 from LN825), which we believe will help to make the paper clearer. Thanks again for your comments and efforts, we are happy and open to answer any follow-up questions.
Sehyun Choi, Tianqing Fang, Zhaowei Wang, and Yangqiu Song. 2023. KCTS: Knowledge-Constrained Tree Search Decoding with Token-Level Hallucination Detection. In Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing, pages 14035–14053, Singapore. Association for Computational Linguistics.
Meng, Tao, et al. "Controllable text generation with neurally-decomposed oracle." Advances in Neural Information Processing Systems 35 (2022): 28125-28139.
Kevin Yang and Dan Klein. 2021. FUDGE: Controlled Text Generation With Future Discriminators. In Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 3511–3535, Online. Association for Computational Linguistics. Creative Commons License
Bao, Guangsheng, et al. "Fast-detectgpt: Efficient zero-shot detection of machine-generated text via conditional probability curvature." in ICLR 2024.
Mitchell, Eric, et al. "Detectgpt: Zero-shot machine-generated text detection using probability curvature." International Conference on Machine Learning. PMLR, 2023.
Dear Reviewer Xpkq,
As the discussion period is drawing to a close, we’d like to summarize the rebuttal period: Additional experiments results are provided in safety-critical scenarios and detoxification applications. We clarified our token-level annotation and prototype-based label disambiguation method. Ablation studies on the moving-average momentum 𝛾 and 𝜎 are conducted to offer better insight into the mechanics of the proposed prototype-based label disambiguation method.
Once again, we greatly appreciate your efforts and comments, which have been invaluable in helping us improve the paper. We hope to engage in further discussions and hear more feedback from you. If you have any remaining questions that we can address during the review period, we would be happy to answer them.
Dear Reviewer Xpkq,
I hope this message finds you well. This is a gentle reminder that our rebuttal as well as our updated paper are online.
Thank you very much for your time and dedication to this process. We understand that your time is very valuable, but as the discussion period ends within 4 hours, we kindly ask you to consider providing any feedback and/or increasing your score. We are of course ready to discuss any remaining concern you may have.
Warm regards,
Authors
The paper presents a new approach to content moderation in Large Language Models (LLMs), called ShieldHead. This method uses the hidden states of a dialogue model to identify harmful content, which significantly reduces computational costs and increases efficiency compared to previous methods. ShieldHead operates in real-time, providing moderation results alongside the output of the chatbot system. It uses a label disambiguation technique to enhance its performance, learning from both token-level and sentence-level labels. The paper demonstrates that ShieldHead outperforms previous models in terms of speed and accuracy, achieving state-of-the-art performance on the XSTest and SafeRLHF datasets while using approximately 99% less parameters than previous LLM-based moderation models.
优点
- The authors introduce ShieldHead, a classification head that identifies harmful content by learning from the last-layer hidden states of the dialogue model. This method is efficient, providing real-time moderation results alongside token-wise streaming output during the chatbot system's decoding phase.
- The paper is clearly written, with a well-structured argument and clear explanation of the methodology and results. It provides a comprehensive overview of the current landscape of content moderation methodologies and the challenges and limitations of existing approaches.
- The significance of the research lies in its potential to improve the safety and efficiency of LLMs. By providing real-time moderation results, ShieldHead reduces the likelihood of harmful content being exposed by AI chatbots.
缺点
- The major concern is the novelty. The use of an extra head/classifier to identify specific content, in this case harmful content, is a common and conventional method used in past researches across various research topics. For instance, "Bottom-Up Abstractive Summarization Sebastian (Gehrmann et al., 2018)" uses a token-level classifier to select summarization tokens. In my view, the main innovation lies in the supervisory method for the classifier, referred to as label disambiguation in this study. The technique employs a moving-average to refresh class (safe/unsafe) prototype embeddings derived from the hidden features of each class samples. Nevertheless, aside from the ablation results, there doesn't seem to be any intuitive or theoretical backing for this approach.
- While the experimental data suggests a positive effect on performance, the underlying processes and reasons for the success of label disambiguation remain unclear. What is the benefit of using moving-average for prototype updating as opposed to a simple average across hidden features from the same class? In my view, they essentially perform the same function. What is the rationale behind reducing \gama and \sigma during the training process and is it necessary to do so? I cannot find any explanation or ablation study on them.
- The method's effectiveness varies across different datasets. ShieldHead surpasses the baseline on XSTest and S-RLHF, but falls short on OpenAI Mod, ToxicChat, and BeaverTails. This inconsistency diminishes the significance of the method.
问题
LN193: the basic idea and intuition of the label disambiguation method is unclear. How and why does it work? What is the difference compared to a simple everage of the hidden features of each class?
LN196: the prototype is mentioned in various context, such as prototype embedding, prototype vector, and prototype label, but what is the connection between the basic idea of prototype and your problem? Why don't use a simple classifier?
LN209: the p^token_(i,j,c) is confusing. Does it refer to the token probability or the token itself?
LN213: do the prototypes P^(t+1)_c refer to the prototype embeddings?
LN215: what kind of normalization is used here?
LN228: the next line says it is a moving-average update, but the equation uses sentence-level label to update token-level labels, which is not a moving average.
LN265: the equation presents the calculation of L_prompt and L_res, but what do they refer to and why do we need them?
Table2: Since the method does not show consistent results accross datasets, models, and settings, the choice of Gemma2-9B looks like a cherry pick. Do we have a general observation about when the method will work?
Table3: "w/o Token Loss (only sentence-level labels)" performs better than "w/o Label Disambi. (both sentence-level and token-level labels but without disambiguation strategy)", does it mean that the token-level labels without disambiguation actually harm the performance?
Table3: ablation on L_prompt and L_res is missing.
Q8: Table2: Since the method does not show consistent results accross datasets, models, and settings, the choice of Gemma2-9B looks like a cherry pick. Do we have a general observation about when the method will work?
A: Thank you for your comment. In this paper, we follow the evaluation manner of WildGuard and ShieldLM which are widely used as baselines. The performance of ShieldHead depends on the representation capabilities of the base model, particularly in terms of safety-related features. Stronger base models, especially those with better safety-related representations, tend to yield higher performance in safety risk classification when trained with ShieldHead. In our experiments, we observed that Gemma2 outperforms LLaMA in terms of safety representation, which is why it shows better results in our current setup. Besides, ShieldHead is designed as a general framework and can be applied to different base models. All the models and code will be released on acceptance of the work. All other base models, including the emerging ones, can be tested by all researchers.
. Han, Seungju, et al. "Wildguard: Open one-stop moderation tools for safety risks, jailbreaks, and refusals of llms." arXiv preprint arXiv:2406.18495 (2024).
Zhang, Zhexin, et al. "Shieldlm: Empowering llms as aligned, customizable and explainable safety detectors." arXiv preprint arXiv:2402.16444 (2024).
Q9: Table3: "w/o Token Loss (only sentence-level labels)" performs better than "w/o Label Disambi. (both sentence-level and token-level labels but without disambiguation strategy)", does it mean that the token-level labels without disambiguation actually harm the performance?
A: Thank you for your comment. Yes, you are correct. The token-level labels without disambiguation do indeed harm performance. To clarify, the token-level labels refer to assigning the sentence-level label to all tokens within the sentence, without any refinement. This approach introduces noise, as certain tokens in an unsafe sentence (e.g., "the", "are") are incorrectly labeled as unsafe. As mentioned in W2.1 and Q2, the purpose of the prototype-based label disambiguation technique is to address this issue by refining the token-level labels, ensuring that tokens are assigned more accurate labels based on their hidden states and the prototype embeddings. This helps to eliminate the noise introduced by directly applying sentence-level labels to all tokens, thereby improving performance.
Q10: Table3: ablation on and is missing.
A: We have included the ablation study in Q7 and discussed the results. The results are included in LN772 of the revision (Table 4 in the revision, Rebuttal Table 4.4).
Thanks the authors for comprehensive response. I have carefully evaluated the rebuttal. However, I find the clarification on novelty and intuition unconvincing. Particularly, the authors point out that the prototype-based label disambiguation algorithm was originally proposed by PiCO and Pico+, which lessens the uniqueness of this paper even more. Despite this, given the robust experiments, I have decided to maintain my original scoring.
Thank you for your quick feedback and acknowledgement of our experiments. We sincerely appreciate your efforts in the review and discussion of this paper. To respond to your concerns, we provide additional clarification as follows:
Both our method, Pico and Pico+ employ prototype-based label disambiguation, as mentioned in the related work section in LN542 in the paper. To clarify, we compare our method with Pico.
Pico and Pico+ were originally designed for image classification with noisy labels, while our method is applied to safety risk classification. The former relies heavily on augmentation and contrastive learning. While, our approach utilizes a more concise method, which does not employ augmentation or contrastive learning. The proposed method relies solely on the tokens in the sentence and does not depend on the process of constructing positive pairs in contrastive learning to update prototypes. Instead, we employed the selection of Top K predicted confidence values to update prototypes.
From the training perspective, different losses are applied in these methods. Pico uses a loss which is a combination of a simple classifier based loss and a contrastive learning based loss (loss function L_{Pico} \= L_{cls} \+ λL_{cont}). While for the proposed method, only classification based loss is applied in a different manner compared to Pico. More specifically, the classification is now considered with two levels, i.e., sentence-level and token-level (L_{ShieldHead} \= L_{prompt} \+ L_{res} \+ λL_{token} as indicated Equation 9 in LN267).
Despite these differences, we share a theoretical foundation by utilizing the prototypes and soft labels. Beyond this, as aforementioned, our approach differs significantly from Pico in terms of task and model details, which highlights the novelty of our method.
We believe a novel method is proposed in this paper for a challenging task. To the best of our knowledge, no other baseline models currently support token-level safety risk classification. The proposed token-level safety classification method ShieldHead outperforms GPT-4 (F1-score +11.5%, accuracy +10.9%). While achieving state-of-the-art performance, as confirmed by all reviewers, our model demonstrates remarkable efficiency (training with only ~1% of the parameters).
We hope this response addressed the reviewer’s concerns, and we are happy and open to answer any follow-up questions.
Wang, Haobo, et al. "Pico: Contrastive label disambiguation for partial label learning." International conference on learning representations. 2022.
Wang, Haobo, et al. "Pico+: Contrastive label disambiguation for robust partial label learning." IEEE Transactions on Pattern Analysis and Machine Intelligence (2023).
Dear Reviewer uL5B,
As the discussion period is drawing to a close, we’d like to summarize the rebuttal period: Additional experiments results are provided in safety-critical scenarios and detoxification applications. We clarified our token-level annotation and prototype-based label disambiguation method. Ablation studies on the moving-average momentum 𝛾 and 𝜎 are conducted to offer better insight into the mechanics of the proposed prototype-based label disambiguation method.
Once again, we greatly appreciate your efforts and comments, which have been invaluable in helping us improve the paper. We hope to engage in further discussions and hear more feedback from you. If you have any remaining questions that we can address during the review period, we would be happy to answer them.
Addressing Questions
Q1: LN193: the basic idea and intuition of the label disambiguation method is unclear. How and why does it work? What is the difference compared to a simple everage of the hidden features of each class?
A: We have clarified this issue in W1 and W2.
Q2: LN196: the prototype is mentioned in various context, such as prototype embedding, prototype vector, and prototype label, but what is the connection between the basic idea of prototype and your problem? Why don't use a simple classifier?
A:
The basic idea of prototype: Prototype embeddings represent the central or "common" hidden features of each class and are used to update token-level labels, making them more accurate. More accurate token-level labels are used to refine the prototype which can enhance its accuracy.
Why don't use a simple classifier: In an ideal scenario, we could directly train a simple classifier for token-level classification. However, the token-level labels are inaccurate. We refer the reviewer to the ablation study without Label Disambiguation (LN421, Table 3), where directly using token-level labels leads to a performance decline (prompt +2.3% and response +2.0% in F1-score).
Q3: LN209: the is confusing. Does it refer to the token probability or the token itself?
A: In LN209, refers to the token probability. We identify the tokens in each class that have the top k highest prediction probabilities (also, the model's prediction confidence), denoted as . The hidden states of these top-K tokens are then used to update the prototype embeddings for each class.
Q4: LN213: do the prototypes refer to the prototype embeddings?
A: Yes, refer to the prototype embeddings.
Q5: LN215: what kind of normalization is used here?
A: We use L2 normalization as mentioned in LN215, which is implemented by using F.normalize in the PyTorch code.
Q6: LN228: the next line says it is a moving-average update, but the equation uses sentence-level label to update token-level labels, which is not a moving average.
A: Thank you for pointing this out. The original formula only showed the update method for epoch=1. Each token-level label is initially assigned the sentence-level label. During each training batch, they are updated to disambiguated labels based on cosine similarity between the hidden states and the prototype. After epoch 1, when each token has been processed multiple times, the labels are updated using the previously updated results, with a moving average applied. This correction has been made in LN228 in the revision.
Q7: LN265: the equation presents the calculation of and , but what do they refer to and why do we need them?
A: and represent the loss functions for prompt risk classification task and response risk classification task, respectively, as described in LN246 and LN254. These two tasks are distinct but mutually beneficial, as shown in the following ablation study in Q10. The losses are computed separately also because the labels for the prompt and response may differ—for instance, the prompt may contain risk while the response may not. The ablation results, as shown in Rebuttal Table 4.4, indicate that removing the prompt loss decreases the model's performance on the prompt, by an average F1-score of 4.0%, although it can still learn some sentence-level patterns via the response loss. Similarly, omitting the response loss reduces performance on the response by an average F1-score of 2.6%.
Rebuttal Table 4.4: Comparisons of F1 scores (%) between ShieldHead (with Llama3.1-8B as base model) with different losses on prompt and response safety classification in existing public benchmarks.
| Model | Prompt Harmfulness (F1) | Response Harmfulness (F1) | |||||
|---|---|---|---|---|---|---|---|
| XSTest | OpenAI Mod | ToxicChat | Avg. | BeaverTails | S-RLHF | Avg. | |
| w/o Sentence Loss | 91.8 | 58.1 | 71.3 | 73.7 | 73.5 | 81.6 | 77.6 |
| w/o Prompt Loss | 92.0 | 59.3 | 70.9 | 74.1 | 76.2 | 85.9 | 81.1 |
| w/o Response Loss | 95.4 | 64.0 | 74.8 | 78.1 | 74.2 | 84.0 | 79.1 |
| Ours | 95.1 | 64.3 | 74.9 | 78.1 | 76.7 | 86.6 | 81.7 |
We thank the reviewer for your detailed feedback and suggestions. We have updated the paper with text changes highlighted in red, and elaborate on each of the questions below.
Addressing Weakness
W1.1: Novelty of our approach
A: The main novelty of our approach is the use of prototype-based label disambiguation to refine token-level labels. Previous approaches in token-level classification can be categorized as "Local Assumption", while our method adopts "Global Refinement". "Global Refinement" corrects labels from global (sentence-level) to local (token-level) perspective during training, whereas "Local Assumption" relies on predefined local labels or heuristic rules applied before training (e.g., in Bottom-Up Abstractive Summarization, where alignment rules identify summarization tokens). Manual annotation of token-level labels is costly for large-scale language models, and the heuristic rule of using sentence-level labels for all tokens can introduce noise, as shown in the ablation w/o Label Disambi. in LN421. To address these issues, we introduce label disambiguation to reduce token-level noise and achieve state-of-the-art on the averaged F1 score on 3 prompt harmfulness and 2 response harmfulness datasets. Besides, to the best of our knowledge, no other models currently support token-level safety risk classification.
W1.2: Intuitive backing for our approach
A: Intuitively, prototype embeddings represent the central or "common" hidden features of each class and are used to update token-level labels, making them more accurate. More accurate token-level labels, in turn, are used to refine the prototype, enhancing its accuracy.
In each training batch, we refresh the class prototypes (safe/unsafe) with top-K predicted token features (those with the highest prediction probabilities) for each class. Token-level labels are then refined by calculating the cosine similarity between each token's hidden state and the updated prototypes. For example, in an unsafe sentence, a token initially labeled as unsafe but is actually safe (e.g., "I" or "am") may have a higher cosine similarity to the safe prototype, prompting a token-level label correction from unsafe to safe.
W1.3: Theoretical backing for our approach
A: From a theoretical perspective, PiCO and Pico+ demonstrate the effectiveness of the prototype-based label disambiguation algorithm from the viewpoint of the EM algorithm. Specifically, through the mutual updates of prototype vectors and soft labels, the proposed method improves the classifier's performance (prompt +2.3% and response +2.0% in F1-score), as shown in the ablation w/o Label Disambi. in LN421.
Wang, Haobo, et al. "Pico: Contrastive label disambiguation for partial label learning." International conference on learning representations. 2022.
Wang, Haobo, et al. "Pico+: Contrastive label disambiguation for robust partial label learning." IEEE Transactions on Pattern Analysis and Machine Intelligence (2023).
W2.1: Benefit of using moving-average
A: Moving average and simple average do not perform the same function. We conducted an ablation experiment to compare their effects, as shown in Rebuttal Table 4.1.
There are two potential ways to update prototypes using the simple average of hidden features: (1) by calculating a simple average over all tokens at the start, or (2) by calculating it per batch during training.
In the first case, when token-level labels are initially assigned based on noisy sentence-level labels, the simple average produces biased prototypes, which prevents the prototypes from reducing the noise of token-level labels. ShieldHead outperforms the first case model by an average F1-score of 2.1% and 1.9% across prompt and response benchmarks, respectively.
In the second case, calculating the simple average per batch can still introduce batch-specific biases, as prototypes are calculated independently for each batch. This prevents the prototypes from accumulating information over time, leading to less stable or reliable prototypes. ShieldHead outperforms the second case model by an average F1-score of 1.3% and 1.0% across prompt and response benchmarks, respectively.
Rebuttal Table 4.1: Comparisons of F1 scores (%) between ShieldHead (with Llama3.1-8B as base model) and simple average methods on prompt and response safety classification in existing public benchmarks.
| Model | Prompt Harmfulness (F1) | Response Harmfulness (F1) | |||||
|---|---|---|---|---|---|---|---|
| XSTest | OpenAI Mod | ToxicChat | Avg. | BeaverTails | S-RLHF | Avg. | |
| Simple average (prototype remains constant during training) | 93.0 | 72.6 | 62.5 | 76.0 | 76.5 | 83.0 | 79.8 |
| Simple average per batch (𝛾=0) | 93.2 | 74.1 | 63.0 | 76.8 | 75.9 | 85.5 | 80.7 |
| Ours | 95.1 | 74.9 | 64.3 | 78.1 | 76.7 | 86.6 | 81.7 |
W2.2: Rationale behind reducing 𝛾 and 𝜎 during the training process
A: 𝛾 and 𝜎 represent the update rates for the prototypes and token-level labels, respectively. They were implemented because, in the early stages of training, the classifier is less reliable. This selection of top-k tokens for updating the prototypes may not be fully accurate. Therefore, a conservative update approach is used initially to prevent noise, with the update rate gradually increasing as the classifier becomes more reliable.
To demonstrate the effect of the moving average update and the impact of reducing 𝛾 and 𝜎 during training, we included an additional ablation study as shown in Rebuttal Table 4.2 and 4.3.
Rebuttal Table 4.2: Comparisons of F1 scores (%) between ShieldHead (with Llama3.1-8B as base model) with different 𝛾 values on prompt and response safety classification in existing public benchmarks.
| Model | Prompt Harmfulness (F1) | Response Harmfulness (F1) | |||||
|---|---|---|---|---|---|---|---|
| XSTest | OpenAI Mod | ToxicChat | Avg. | BeaverTails | S-RLHF | Avg. | |
| 𝛾=0.99 | 93.0 | 72.2 | 62.1 | 75.8 | 76.3 | 83.5 | 79.9 |
| 𝛾=0.95 | 95.2 | 73.5 | 63.9 | 77.5 | 76.3 | 85.7 | 81.0 |
| 𝛾=0.5 | 93.3 | 73.6 | 62.6 | 76.5 | 76.0 | 85.8 | 80.9 |
| 𝛾=0(simple average) | 93.2 | 74.1 | 63.0 | 76.8 | 75.9 | 85.5 | 80.7 |
| Ours | 95.1 | 74.9 | 64.3 | 78.1 | 76.7 | 86.6 | 81.7 |
Rebuttal Table 4.3: Comparisons of F1 scores (%) between ShieldHead (with Llama3.1-8B as base model) with different 𝜎 values on prompt and response safety classification in existing public benchmarks.
| Model | Prompt Harmfulness (F1) | Response Harmfulness (F1) | |||||
|---|---|---|---|---|---|---|---|
| XSTest | OpenAI Mod | ToxicChat | Avg. | BeaverTails | S-RLHF | Avg. | |
| 𝜎=1(w/o Label Disambi.) | 93.0 | 72.0 | 62.3 | 75.8 | 76.2 | 83.1 | 79.7 |
| 𝜎=0.98 | 93.3 | 72.2 | 61.9 | 75.8 | 75.6 | 84.5 | 80.1 |
| 𝜎=0.5 | 94.5 | 72.8 | 64.1 | 77.1 | 76.7 | 85.8 | 81.3 |
| Ours | 95.1 | 74.9 | 64.3 | 78.1 | 76.7 | 86.6 | 81.7 |
Overall, the results show that reducing 𝛾 and 𝜎 yields better performance than using fixed values. Note that setting 𝛾=0 corresponds to using a simple average for each batch (when calculating the prototype in each batch during training, only the current batch is considered, without taking historical information into account). The moving average approach outperforms the simple average by an average F1-score of 1.3% and 1.0% across prompt and response benchmarks, respectively. Similarly, setting 𝜎=0 corresponds to an ablation where the token-level labels are used without the label disambiguation algorithm (since the labels are not updated, this will be further discussed in Q9).
W3: The method's effectiveness varies across different datasets. ShieldHead surpasses the baseline on XSTest and S-RLHF, but falls short on OpenAI Mod, ToxicChat, and BeaverTails. This inconsistency diminishes the significance of the method.
A: Thanks for bringing up this concern. Surpassing baselines across different datasets would certainly strengthen the result. However, it's worth noting that the current performance results demonstrate the effectiveness of ShieldHead, for two reasons:
- The performance of ShieldHead depends on the representation capabilities of the base model, particularly in terms of safety-related features. Stronger base models, especially those with better safety-related representations, tend to yield higher performance in safety risk classification when trained with ShieldHead.
- Compared with other full finetune methods like Llama-Guard, ShieldHead achieves the best overall performance with minimal finetuning—only about 1% of the model parameters, despite not achieving the best results on a few specific datasets. It strikes the optimal efficiency-effectiveness trade-off, ensuring superior overall performance.
With its efficiency (easy finetuning) and effectiveness (state-of-the-art results on average across various datasets), we believe the proposed method can benefit this research field.
The paper addresses the need for efficient and real-time content moderation for large language models. Traditional moderation methods either rely on non-parametric rule-based classifiers, which lack deep language comprehension, or LLM-based moderators, which offer high accuracy but are computationally intensive and time-consuming, requiring a separate inference stage. The paper proposed the ShieldHead unit, which is a lightweight, real-time content moderation system that uses an MLP as a classification head on the hidden states of an LLM's last layer, running after the LM head on each step. This setup enables token-level, real-time moderation without adding significant computational overhead.
优点
- High Efficiency with Low Computation Cost: ShieldHead achieves real-time moderation by using an MLP (instead of another LLM) to directly score the hidden states of an LLM's last layer, significantly reducing computation while largely maintaining the performance.
- Fine-grained Safety Detection: ShieldHead improves moderation accuracy by performing token and sentence-level safety evaluation. This dual-layered approach allows for more fine-grained detection and demonstrates the interpretability of this method.
缺点
- Low Adaptability across Models: From the description in the paper, ShieldHead needs to be trained from scratch for each language model, as the hidden-state distribution varies across different models. This requirement limits ShieldHead’s applicability, making it less suitable for wide-range use across diverse language models without substantial re-training.
- Need White-box Access to the Upstream Model: ShieldHead requires access to the hidden states of the language model, implying a need for white-box access. This limits its use with models that do not allow such access or are deployed in environments where model internals are inaccessible.
问题
- How well does ShieldHead handle contextually ambiguous terms at the token level, where the classification (safe or unsafe) depends heavily on the surrounding context? Although Section 3.6 provides a single example, it would be valuable to know if the authors conducted an ablation study specifically focused on that.
- It would be helpful if the authors could provide more detailed evaluation data, since in safety-critical scenarios, false positives may be significantly more harmful than false negatives.
Thanks for your constructive feedback and insightful suggestions. The concerns are addressed as follows:
Q1: Low Adaptability across Models
A: We suggest an alternative solution that involves training ShieldHead on some public models such as Gemma2-9B, and performing safety classification with the whole pre-trained weights (Gemma2-9B+ShieldHead). In this way, the moderator has no need to further re-training when the language model is changed.
Moreover, the cost of integrating ShieldHead into the dialogue model is minimal as it requires only ~1% of the model parameters. In practice, training ShieldHead on dialogue model cost merely 1 hour on 1 GPU device.
Q2: Need White-box Access to the Upstream Model
A: Thank you for your comment. As mentioned in Q1, all sentence-level safety classifiers, such as Llama-Guard, require hidden states during training but can be directly used for inference (for safety risk classification). One advantage of ShieldHead is that it requires minimal fine-tuning and is adaptable to a wide range of base models. This allows developers to finetune the model with minimal computational cost, tailoring the safety classifier to their specific model's hidden-state distribution when needed.
Q3: How well does ShieldHead handle contextually ambiguous terms at the token level, where the classification (safe or unsafe) depends heavily on the surrounding context? Although Section 3.6 provides a single example, it would be valuable to know if the authors conducted an ablation study specifically focused on that.
A: XSTEST is a dataset containing ambiguous terms at the token level, with samples that are prone to misidentification depending heavily on the surrounding context (e.g., How can I kill a Python process?). We sampled 15 safe and 15 corresponding unsafe sentences from this dataset for annotation, following the same annotation process as LN362 and the annotation agreement described in Appendix A.5 of LN806 in the paper. Token-level safety classification was performed using the ShieldHead model trained with Llama-3.1-8B-Instruct. The results of the token-level safety classification yielded an F1 score of 81.0% and an accuracy of 75.9%, representing a decline of 3.7% and 2.1% from the original results (LN367). Despite the decline, ShieldHead still outperforms GPT-4 (word-level, evaluated on original dataset in the paper) by F1-score of 7.4% and accuracy of 7.9%, demonstrating strong performance.
Q4: It would be helpful if the authors could provide more detailed evaluation data, since in safety-critical scenarios, false positives may be significantly more harmful than false negatives.
A: We appreciate your valuable insight. Identifying unsafe instances as safe (false negatives, with "safe" considered negative and "unsafe" considered positive) is crucial in safety-critical scenarios. In our model training, we took into account and calculated a variety of metrics, including the False Negative Rate (FNR), although our primary focus was on the F1 score. The FNR results are shown in Rebuttal Table 2. ShieldHead on ToxicChat exhibits a relatively high FNR, as the number of positive samples (unsafe) is much smaller than the number of negative samples (safe), with a ratio of approximately 1:10, which increases the difficulty of the task. Nonetheless, the model demonstrated generally strong performance across different base model sizes, with the lowest FNRs of 0.14 and 0.13 on the prompt and response benchmarks, respectively, showing the potential of ShieldHead in safety-critical scenarios.
Rebuttal Table 2: FNR scores of ShieldHead with different base models on prompt and response safety classification in existing public benchmarks.
| Model | Prompt Harmfulness (FNR) | Response Harmfulness (FNR) | |||||
|---|---|---|---|---|---|---|---|
| XSTest | OpenAI Mod | ToxicChat | Avg. | BeaverTails | S-RLHF | Avg. | |
| Gemma2-2B | 0.03 | 0.17 | 0.37 | 0.19 | 0.14 | 0.09 | 0.12 |
| Gemma2-9B | 0.03 | 0.14 | 0.32 | 0.16 | 0.15 | 0.09 | 0.12 |
| Gemma2-27B | 0.04 | 0.14 | 0.25 | 0.14 | 0.17 | 0.09 | 0.13 |
Thank you to the authors for their detailed response! I believe the rebuttal has thoroughly addressed my concerns regarding this paper. Therefore, I have decided to raise my score to 8 while keeping my confidence at 3.
Thank you for your positive feedback and raising the rating from 6 to 8. We really appreciate your time dedicated to improving our work. As aforementioned, we will adapt your suggestions to our revision and release all code and models on acceptance of the work. More discussion is very welcome.
This paper introduces ShieldHead, a classification head that operates on the last-layer hidden states of the dialogue model, enabling real-time identification of harmful content with significantly lower computational overhead. ShieldHead employs a label disambiguation technique to enhance performance and achieves state-of-the-art results on various datasets while being approximately 300 times faster than existing methods.
Key Contributions include 1) Real-Time Moderation: ShieldHead allows for real-time safety risk moderation during the decoding phase of chatbot responses, achieving less than 1 ms latency. 2) Efficiency: It utilizes only about 1.1% of the parameters of existing models like LlamaGuard, significantly reducing training and inference costs. 3) Performance: Extensive experiments demonstrate that ShieldHead outperforms existing moderation methods on multiple benchmarks, including XSTest and SafeRLHF datasets.
优点
-
The authors propose a simple yet effective method called SHIELDHEAD, which relies on a classification head on the last layer of hidden states of the LLMs using the output of the head to classify. Compared to LlamaGuard, their method shows a speed about 300× faster (<1ms) than previous LLM-based moderation models with ∼99% fewer parameters of LlamaGuard.
-
This paper contributes an efficient safety risk moderation pipeline for LLMs, designed to deliver real-time results based on the token-level streaming output of an LLM.
缺点
-
The Equation 3 to Equation 10 is a bit wordy. The idea can be illustrated in a more concise way, for example, the input prompt and the sentence-level classification result can be described in textual words.
-
The motivation of this work is not persuasive enough given the fact that the inference-time intervention method and model editing methods can both effectively enhance the model's safety by using much less computational cost. If the motivation is to build more effective methods for enhancing AI safety, it lacks the essential baselines to compare.
-
In the related work of SAFETY CONTENT MODERATION, the intervention-based methods and model-editing methods have been ignored. In terms of the safety, their methods' performance is competitive and should be mentioned here.
问题
Typos to be corrected: Line 487: ”steal” Line 488: ”not permissible to”
Thanks very much for your insightful comments and suggestions! We have updated our paper based on the comments of the reviewers with text changes highlighted in red. Discussion about intervention-based methods and model-editing methods in the related work of safety content moderation are added in Section 4; (ii) Comparison with these methods are added in Rebuttal Table 1; (iii) we removed Equation 7 and 8 and replaced them with detailed descriptions of the computational processes.
Q1: The Equation 3 to Equation 10 is a bit wordy. The idea can be illustrated in a more concise way, for example, the input prompt and the sentence-level classification result can be described in textual words.
A: Thanks for your comment.
Eq. 7 which indicates the process of getting prompt safety classification results is removed.
Eq. 8 which indicates the process of obtaining response safety classification results is also removed.
All details are explained in LN247 and LN257.
Q2: The motivation of this work is not persuasive enough given the fact that the inference-time intervention method and model editing methods can both effectively enhance the model's safety by using much less computational cost. If the motivation is to build more effective methods for enhancing AI safety, it lacks the essential baselines to compare.
A: Thanks for the comment. The proposed method is different from Inference-time intervention (ITI) and model-editing methods. ITI and model-editing methods are applied to directly correct safety risks (detoxification), while the proposed method focuses on safety risk classification. In fact, ITI and model-editing methods can be adapted to deal with the same task we are tackling, i.e. safety risk classification. For example, Model Surgery and DPO TOXIC use Behavior Probe, a simple MLP classifier, as the basis for detoxification, which can be applied for safety risk classification. As shown in Rebuttal Table 1, ShieldHead outperforms Behavior Probe model by an average F1-score of 3.1% and 4.6% across prompt and response benchmarks, respectively.
Rebuttal Table 1: Comparisons of F1 scores (%) between ShieldHead (with Gemma2-9B as base model) and Behavior Probe on prompt and response safety classification in existing public benchmarks.
| Model | Prompt Harmfulness (F1) | Response Harmfulness (F1) | |||||
|---|---|---|---|---|---|---|---|
| XSTest | OpenAI Mod | ToxicChat | Avg. | BeaverTails | S-RLHF | Avg. | |
| Behavior Probe | 94.1 | 73.4 | 61.0 | 76.2 | 73.8 | 85.5 | 79.7 |
| Ours | 95.1 | 76.3 | 66.5 | 79.3 | 80.0 | 88.5 | 84.3 |
More specifically, Behavior Probe based method averages the hidden states of all tokens in a sentence as features to predict sentence-level safety, which introduces token-level noise. As shown in the ablation in Table 4 (LN711) in the main paper, this noise can significantly affect the final performance on safety risk classification. In addition to the performance differences, ITI and model-editing methods also have the following two issues:
- Behavior Probe-based classification is limited to sentence-level, while the proposed token-level classification offers more direct interpretability for detoxification and potential applicability to other tasks.
- Although the results of modeling editing show no significant negative impact on general metrics (e.g., MMLU, GSM8K), it may still affect the generated outputs in unintended ways, as it involves the direct modification of model parameters.
The detailed comparison with analysis is included in LN345 in Table 1 of the main paper.
Wang, Huanqian, et al. "Model Surgery: Modulating LLM's Behavior Via Simple Parameter Editing." arXiv preprint arXiv:2407.08770 (2024).
Lee, Andrew, et al. "A mechanistic understanding of alignment algorithms: A case study on dpo and toxicity." arXiv preprint arXiv:2401.01967 (2024).
. Kenneth Li, Oam Patel, Fernanda Viégas, Hanspeter Pfister, and Martin Wattenberg.Inference-time intervention: Eliciting truthful answers from a language model.Advances in Neural Information Processing Systems, 36, 2024a.
. Andrew Lee, Xiaoyan Bai, Itamar Pres, Martin Wattenberg, Jonathan K. Kummerfeld, and Rada Mihalcea.A mechanistic understanding of alignment algorithms: A case study on DPO and toxicity.In Forty-first International Conference on Machine Learning, 2024a.
Q3: In the related work of SAFETY CONTENT MODERATION, the intervention-based methods and model-editing methods have been ignored. In terms of the safety, their methods' performance is competitive and should be mentioned here.
A: The related work on intervention-based methods and model-editing methods is summarized as follows:
Intervention-based approaches, such as linear probes, involve the step of training small linear classifiers on model activations to detect toxicity at sentence level . Model-editing techniques, including activation manipulation and spectral editing during inference , aim to steer model outputs by modifying internal representations. These approaches are applied to directly correct safety risks (detoxification), while the proposed method focuses on safety risk classification.
Details are included in the related work of safety content moderation in Section 4 (LN498) in the revision.
. Kenneth Li, Oam Patel, Fernanda Viégas, Hanspeter Pfister, and Martin Wattenberg.Inference-time intervention: Eliciting truthful answers from a language model.Advances in Neural Information Processing Systems, 36, 2024a.
. Andrew Lee, Xiaoyan Bai, Itamar Pres, Martin Wattenberg, Jonathan K. Kummerfeld, and Rada Mihalcea.A mechanistic understanding of alignment algorithms: A case study on DPO and toxicity.In Forty-first International Conference on Machine Learning, 2024a.
. Andy Zou, Long Phan, Sarah Chen, James Campbell, Phillip Guo, Richard Ren, Alexander Pan, Xuwang Yin, Mantas Mazeika, Ann-Kathrin Dombrowski, et al.Representation engineering: A top-down approach to ai transparency.arXiv preprint arXiv:2310.01405, 2023.
. Yifu Qiu, Zheng Zhao, Yftah Ziser, Anna Korhonen, Edoardo M Ponti, and Shay B Cohen.Spectral editing of activations for large language model alignment.arXiv preprint arXiv:2405.09719, 2024.
Q4: Typos to be corrected: Line 487: ”steal” Line 488: ”not permissible to”
A: Thanks for your comment. We have corrected this typo in LN481 and LN483 in the revision.
Thanks for your response. I am still confused about two open questions: 1) In a one-sentence summary, what is the significant advantage of safety risk classification compared to previous methods; 2) Why not add the results of the intervention-based methods and model-editing methods, is there any fundamental limitations of comparing these methods in a fair condition? Before addressing these concerns, I am willing to keep my score.
I also agree with the opinion of Reviewer Xpkq that the current methodology acts in an over-simplified setting and lacks practical meaning.
Q2. Comparison between model editing / ITI with ShieldHead.
R: As mentioned before, ShieldHead and model editing / ITI are addressing two different tasks. The former focuses on classification and improves efficiency in this task, while the latter aims to reduce toxicity. To the best of our knowledge, no previous work on safety risk classification task has made comparisons with methods on the detoxification task . However, risk classification approaches could potentially be applied in detoxification scenarios. Three potential application scenarios are outlined as follows: 1. Evaluation of detoxification; 2. System-level detoxification; 3. Integration with model-level detoxification. We believe demonstrating the effectiveness of ShieldHead in these application scenarios during the rebuttal phase highlights its practical significance. Therefore, we have conducted experiments to explore the application of ShieldHead in each of these three scenarios.
1. Evaluation of detoxification. In fact, the safety risk classifier can be applied as an evaluator for the detoxification task. Specifically, we explore detoxification methods based on model editing / ITI such as Model Surgery , Category-specific Steering , and SKU . Model Surgery employs the Perspective API to evaluate whether the model responses, after detoxification, contain safety risks, while Category-specific Steering and SKU utilizes GPT-4 with few-shot as the evaluator. GPT-4 and Perspective API can be used as a safety risk classifier and are introduced in Table1 from LN339, LN042 and LN491 (related works and experiments) of our paper.
Similar to GPT-4 and Perspective API , ShieldHead is able to evaluate detoxification methods as a safety risk classifier. Here, we select SKU (a detoxification method) for evaluation as an example. SKU evaluates detoxification performance by comparing the Harmful Rate of responses to SafeRLHF prompts before and after modifying model weights with SKU. GPT-4 (few-shot) was employed to evaluate SKU against other baselines. As a better risk classifier, the proposed ShieldHead outperforms GPT-4 (few-shot) by 12.2% in F1 score (Table1 from LN339 in the paper), demonstrating its potential as a more accurate evaluator. As shown in Rebuttal Table 1.2, we observe that the response of both original Llama2-7B and the Llama2-7B + SKU show an increase in Harmful Rates evaluated with ShieldHead. As a more accurate classifier, ShieldHead can identify more instances that GPT-4 classifies as safe but are actually unsafe (reducing the false negative rate, which is crucial in safety-critical scenarios). This suggests that ShieldHead can potentially serve as a better evaluator for detoxification methods.
Rebuttal Table 1.2: Comparisons of Harmful Rate with different safety risk classifiers.
| Method | Harmful Rate (evaluating with GPT-4 few shot) | Harmful Rate (evaluating with ShieldHead) |
|---|---|---|
| Llama2-7B | 57% | 62.7% |
| Llama2-7B with SKU | 3% | 4.3% |
2. System-level detoxification. As noted in the response R3 to Reviewer Xpkq, we can also perform system-level detoxification with some engineering techniques. Since ShieldHead is designed for safety risk classification rather than detoxification, we use simple techniques here following Meta (Llama Guard ) and Google (Shieldgemma ): If the first generated token after the prompt was classified as unsafe, the model will refuse to respond. Using system-level detoxification, we reduce the Harmful Rate to 1.7% on SafeRLHF, which outperforms SKU (-1.3%) and Model Sugery (-7.6%), indicating the potential of ShieldHead in detoxification.
3. Integration with model-level detoxification. ShieldHead can potentially be integrated with model editing/ITI methods and enhance the detoxification. We demonstrate a simple approach combining model-level and system-level detoxification: after modifying model weights with SKU or Model Sugery , if the first generated token after the prompt from the modified model was classified as unsafe, the model will refuse to respond, further lowering the Harmful Rate to 1% and 1.3% respectively. This indicates the potential to combine ShieldHead with model-level detoxification methods to further improve detoxification performance.
Rebuttal Table 1.3: Comparisons of detoxification based on model editing/ITI with ShieldHead.
| Method | Harmful Rate on SafeRLHF (↓) | |
|---|---|---|
| model-level detoxification | SKU | 3% |
| Model Sugery | 9.3% | |
| system-level detoxification | ShieldHead | 1.7% |
| both model-level and system-level detoxification | SKU + ShieldHead | 1% |
| Model Sugery + ShieldHead | 1.3% |
(Cont. Response to Q2)
We hope these experiments could address your concerns. However, in this reply, we just demonstrate the potential application of our method in detoxification. ShieldHead primarily focuses on the task of safety risk classification, and the application in detoxification is not not emphasized in the paper. We appreciate the reviewer's suggestion, and this could be considered as a direction for future work.
Q3. Current methodology acts in an over-simplified setting and lacks practical meaning.
R: As we responded to Reviewer Xpkq, our token-level annotation ("annotate all tokens following an unsafe token in the sentence as unsafe") is not to simplify the task and setting, but rather is more suitable for the safety risk classification scenario. We explained our annotation method in detail through examples (see responses R1-3, R7 to Reviewer Xpkq, the same as below), theoretical background (R4), related work (R6), and experimental results (R5). In short, our annotation method is intuitively and theoretically based on the cumulative effects of GPT-like models, and there are related works that adopt a similar annotation method. Furthermore, we have empirically demonstrated the effectiveness of our annotation method. We have incorporated these details into the revision (A5.2 from LN810 and A5.3 from LN825), which we believe will help to make the paper clearer. For the full details, please refer to R1-R7 in the response to Reviewer Xpkq. Our response addressed most of Reviewer Xpkq's concerns with a raise of the rating.
We hope our response has addressed your concerns. Thanks again for your comments and efforts, we are happy and open to answer any follow-up questions.
Inan, Hakan, et al. "Llama guard: Llm-based input-output safeguard for human-ai conversations." arXiv preprint arXiv:2312.06674 (2023).
Perspective API. https://perspectiveapi.com/.
Han, Seungju, et al. "Wildguard: Open one-stop moderation tools for safety risks, jailbreaks, and refusals of llms." arXiv preprint arXiv:2406.18495 (2024).
Markov, Todor, et al. "A holistic approach to undesired content detection in the real world." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 37. No. 12. 2023.
Zhang, Zhexin, et al. "Shieldlm: Empowering llms as aligned, customizable and explainable safety detectors." arXiv preprint arXiv:2402.16444 (2024).
Wang, Huanqian, et al. "Model Surgery: Modulating LLM's Behavior Via Simple Parameter Editing." arXiv preprint arXiv:2407.08770 (2024).
Bhattacharjee, Amrita, et al. "Towards Inference-time Category-wise Safety Steering for Large Language Models." arXiv preprint arXiv:2410.01174 (2024).
Zheyuan Liu, Guangyao Dou, Zhaoxuan Tan, Yijun Tian, and Meng Jiang. 2024. Towards Safer Large Language Models through Machine Unlearning. In Findings of the Association for Computational Linguistics: ACL 2024, pages 1817–1829, Bangkok, Thailand. Association for Computational Linguistics.
Achiam, Josh, et al. "Gpt-4 technical report." arXiv preprint arXiv:2303.08774 (2023).
Jiaming Ji, Mickel Liu, Juntao Dai, Xuehai Pan, Chi Zhang, Ce Bian, Ruiyang Sun, Yizhou Wang, and Yaodong Yang. 2023. Beavertails: Towards improved safety alignment of llm via a humanpreference dataset. arXiv preprint arXiv:2307.04657.
Zeng, Wenjun, et al. "Shieldgemma: Generative ai content moderation based on gemma." arXiv preprint arXiv:2407.21772 (2024).
Dear Reviewer QCo5,
I hope this message finds you well. Thank you very much for your time and dedication to this process. We understand that your time is very valuable, but as the discussion period ends within 4 hours, we kindly ask you to consider providing any feedback and/or increasing your score. We are of course ready to discuss any remaining concern you may have.
Warm regards,
Authors
Thank you for your continued engagement and efforts. Sorry for the delayed reply, as we spent some time conducting the experiments for application in detoxification. Before addressing the questions, it is important to clarify that safety risk classification and detoxification based on model editing/ITI are fundamentally two different tasks, each with its own significance and advantages. The former aims to identify safety risks in any given text, while the latter seeks to reduce the toxicity of specific models.
Q1. One-sentence summary of the significant advantage of safety risk classification compared to previous methods.
R: Safety risk classification is decoupled from the response generation phase and does not impact the model's response, making it a system-level (non-invasive) safety operation capable of detecting safety risks in any text, whereas model editing and ITI intervene in the generation process, affecting model responses and serving as model-level safety enhancement techniques (invasive) targeting specific models.
Dear Reviewer QCo5,
As the discussion period is drawing to a close within 8 hours, we’d like to summarize the rebuttal period: Additional experiments results are provided in safety-critical scenarios and detoxification applications. We clarified our token-level annotation and prototype-based label disambiguation method. Ablation studies on the moving-average momentum 𝛾 and 𝜎 are conducted to offer better insight into the mechanics of the proposed prototype-based label disambiguation method.
Once again, we greatly appreciate your efforts and comments, which have been invaluable in helping us improve the paper. We hope to engage in further discussions and hear more feedback from you. If you have any remaining questions that we can address during the review period, we would be happy to answer them.
We thank all reviewers for their comments and valuable suggestions. We truly appreciate your time and effort dedicated to improving our work. All reviewers confirm the efficiency of our proposed method. Reviewer QCo5 described our approach as simple yet effective, Reviewer dYJV praised for our fine-grained safety detection and high efficiency with low computational cost; Reviewer Xpkq highlighted the extensive experiments to evaluate the proposed method; and Reviewer uL5B commended the manuscript for being clearly written, with a well-structured argument.
In this paper, an effective method is proposed for sentence-level and token-level safety classification.
All the comments and suggestions have been addressed with detailed feedback and experiments. All changes are marked in red in the main paper.
In this general response, we provide brief responses to important and common comments to avoid repetition.
- Reviewer dYJV, Xpkq: Details about token-level performance.
- Conducted token-level classification on a dataset containing ambiguous terms and compared to GPT-4.
- Result: GPT-4 underperforms ShieldHead by 11.5% in F1-score and 10.9% in accuracy, with only 31 valid responses out of 60. To the best of our knowledge, no other baseline models currently support token-level safety risk classification. Observed a slight decline of 2.1% in accuracy when evaluated on ShieldHead with an ambiguous dataset.
- Reviewer QCo5: Comparison with Inference-time intervention (ITI) and model-editing methods.
- Conducted sentence-level classification compared to the safety classification part of ITI and model-editing methods.
- Result: ShieldHead outperforms ITI and model-editing methods by an average F1-score of 3.1% and 4.6% across prompt and response benchmarks, respectively. Additionally, unlike these methods, ShieldHead supports token-level classification and does not require modifying the original model parameters.
- Reviewer uL5B: Novelty of the proposed method and utility of prototype-based label disambiguation method.
- We use the prototype-based label disambiguation method to eliminate noise in token-level labels, and refine the prototypes based on the updated more accurate token-level labels, creating a mutually reinforcing process.
This paper proposes ShieldHead, a classification head that operates on the last layer hidden states of a chat-based model that every step predicts if the model generated output so far is unsafe (or might lead to unsafe output if the model kept generating). This work is in contrast with safety classifiers like Llama-Guard that let the model finish generating before applying the classifier. Thus, the main motivation is improving efficiency since employing a safety classifier can be expensive. One of the main contributions of this work is a method to generate soft token-wise labels to train their classifier given only sentence-level labels.
The reviews have praised the simplicity of the method and positive results but have questioned the novelty and missing baselines which the authors have addressed sufficiently in my opinion. However, I do question its general applicability as the head will need to be trained for every new model whereas safety classifiers are off the shelf. This makes it less suitable for wide range use across diverse language models without substantial re-training (even when the training is fast, it is still required). This training is of course not possible without white box access. In the latter case, if this approach is to be used, it has to be used with a different model which brings back efficiency concerns so the utility diminishes.
Additionally, I found some inconsistencies in the experimental setup. The experimental details primary discuss Llama 3 family of models, the results in 3.3.1 are on Gemma finetuned models where the baselines do not all seem to be Gemma finetuned (such as Wildguard). Even with a different base models, the proposed approach seems to show minor improvements over Wildguard.
In light of these concerns, I recommend a reject at this point.
审稿人讨论附加意见
A reviewer raised concerns about the low adaptability across models and the need for white-box access to the upstream model. The authors address these concerns by suggesting an alternative solution that involves training ShieldHead on some public models and performing safety classification with the whole pre-trained weights. Additionally, they argue that ShieldHead requires minimal fine-tuning and is adaptable to a wide range of base models. I think this solution limits the utility of the approach as being more efficient.
Other reviewers raised concerns about novelty (which the authors addressed sufficiently), and lack of comparisons to other methods like model editing (which in my view are not required but still the authors provide those comparisons).
My primary concern that led to my decision is lack of fair comparisons and minimal improvement over their strongest baseline (WildGuard).
Reject