1. Introduction
Pre-trained language models based on the Transformer architecture have consistently achieved exceptional performance across a wide range of natural language processing (NLP) tasks in recent years. Examples of these models include GPT [1], BERT [2], ET-BERT [3], BFCN [4], and PAL-BERT [5]. Despite their excellent performance, these models are difficult to deploy on edge devices because of their large parameter counts and high computational demands [6,7]. Consequently, reducing computational costs and model size while maintaining the original performance has become a vital area of research [8,9].
Pre-trained language models perform well in many tasks, but in specific subjects or tasks further training and fine-tuning are usually required to improve performance, often consuming significant computing resources and time. To address this, researchers have developed various model compression and acceleration techniques for efficient deployment on edge devices, including pruning [10], quantization [11], knowledge distillation (KD) [12], and low-rank decomposition [13]. Among these, KD significantly reduces computing time and resource consumption by transferring the knowledge of a complex teacher model to a simplified student model. Although the student model may perform slightly worse than the teacher model, it can still maintain sufficient accuracy in resource-limited environments [14]. The key to efficient deployment is balancing computing requirements and model performance [15,16].
Task-specific knowledge distillation has shown significant advantages in improving model performance [17,18]. Tang et al. compressed the BERT model by distilling task-specific knowledge into a lightweight BiLSTM model [19]. BERT-PKD guides the training of student models by extracting knowledge from the intermediate layers of BERT [20]. Other researchers enabled the student model to capture the abstract knowledge of the teacher model more comprehensively by matching the internal representations of BERT [21]. DynaBERT achieves dynamic compression of the BERT model by adaptively adjusting the model width and depth [22]. BERT of Theseus reduces model complexity by progressively replacing BERT layers [23]. TinyBERT employs a two-stage distillation method to significantly reduce model size and accelerate inference [24]. However, these methods have certain limitations. They primarily focus on learning the individual features of the teacher model while overlooking the interrelationships between these features, which are crucial in many complex natural language processing tasks [25]. In addition, when calculating the hidden states loss, the hidden states dimension of the student model must match that of the teacher model; otherwise, an additional projection is required to align the two, which limits the flexibility of the student model structure.
The main research question of this paper is how to enable the student model to capture and learn the characteristics of the teacher model more effectively during knowledge distillation, while fundamentally resolving the hidden states dimension mismatch between the teacher and student models. The goal is to reduce the number of model parameters significantly while maintaining high performance. We propose the Autocorrelation Matrix Knowledge Distillation (AMKD) approach to address these limitations of traditional KD methods. Unlike existing techniques such as BERT-PKD, DistilBERT, or TinyBERT, AMKD captures complex feature interactions and directly resolves hidden states dimension mismatches without additional projection layers.
AMKD not only allows the student model to learn the features of the teacher model but also captures the complex relationships among these features. The relationships between features play a critical role in NLP tasks. By minimizing the differences between the autocorrelation matrices of the student and teacher models, the student model learns not only individual feature behaviors but also the relationships between features in the teacher model. This approach effectively captures high-order feature interactions, enabling the student model to better understand and learn the complex features of the teacher model. More importantly, AMKD effectively addresses the issue of hidden states dimension mismatch between student and teacher models without requiring additional projection layers for dimension alignment. AMKD retains the essential information from the teacher model, minimizing information loss during training. Compared to traditional methods, AMKD significantly enhances the performance of the student model, demonstrating superior flexibility and robustness.
We evaluated AMKD on multiple NLP tasks and demonstrated that it significantly enhances the performance of small BERT models in specific tasks. AMKD addresses the limitations of traditional knowledge distillation in learning complex feature relationships and overcoming hidden states dimension mismatches. The resulting student models exhibit greater flexibility and robustness, offering new approaches for model compression and acceleration in resource-constrained environments.
2. Preliminaries
This section begins with an introduction to the Transformer architecture and the core components of this framework [26], followed by a discussion of KD techniques [14]. Our proposed AMKD method is developed based on these foundational concepts.
2.1. Transformer Architecture
The Transformer architecture, introduced by Vaswani et al. in 2017 [26], is a widely adopted deep learning model, particularly in natural language processing and machine translation. Unlike traditional models such as RNN [27] and CNN [28], the Transformer captures long-range dependencies through self-attention. The model consists of an encoder, which converts input sequences into hidden representations, and a decoder, which generates outputs from these representations. The key components include multi-head attention, feed-forward networks [29], residual connections [30], and layer normalization [31].
Multi-Head Attention Mechanism. The multi-head attention mechanism is a fundamental component of the Transformer architecture, enabling the model to attend to multiple segments of the input sequence concurrently. The computation starts with scaled dot-product attention. Given a query matrix $Q$, a key matrix $K$, and a value matrix $V$, the attention output is computed as
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V \tag{1}$$
where $QK^{\top}$ denotes the dot products between the queries and keys, reflecting their similarity, and $d_k$ is the dimensionality of the keys, used to scale the dot products and prevent gradient instability. The softmax function then transforms these similarities into weights, which are multiplied by the value matrix to yield the weighted output.
In the multi-head attention mechanism, $Q$, $K$, and $V$ are each projected into $h$ different subspaces through separate linear projections, which allows attention to be computed in parallel across multiple representation spaces. Specifically, for each head $i$, the inputs are transformed with the projection matrices $W_i^Q$, $W_i^K$, and $W_i^V$, and attention is computed as follows:
$$\mathrm{head}_i=\mathrm{Attention}\big(QW_i^Q,\;KW_i^K,\;VW_i^V\big) \tag{2}$$
where $W_i^Q$, $W_i^K$, and $W_i^V$ are the linear transformation matrices that project the queries, keys, and values into the subspace of head $i$. The outputs of all heads are concatenated and passed through a linear transformation to produce the final output:
$$\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\ldots,\mathrm{head}_h)\,W^O \tag{3}$$
where $W^O$ is the linear transformation matrix that maps the concatenated multi-head outputs back to the model dimension. This allows multi-head attention to capture different aspects of the input sequence from multiple subspaces, enhancing the model’s representational capability.
Feed-Forward Networks. Each encoder and decoder layer of the Transformer also contains a feed-forward network applied independently to every position in the sequence. It consists of two linear transformations with a ReLU activation in between:
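$$\mathrm{FFN}(x)=\max(0,\;xW_1+b_1)\,W_2+b_2 \tag{4}$$
where $W_1$, $b_1$, $W_2$, and $b_2$ are the weights and biases of the two linear transformations.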
Residual Connections and Layer Normalization. Following the feed-forward network, the output from each sub-layer undergoes processing through residual connections and layer normalization, as defined below:
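$$\mathrm{Output}=\mathrm{LayerNorm}\big(x+\mathrm{Sublayer}(x)\big) \tag{5}$$
where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself (multi-head attention or the feed-forward network).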
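The following NumPy sketch puts Eqs. (1) to (5) together for a single post-norm encoder layer operating on one sequence. It is illustrative only: the toy sizes, the random weights, and the omission of the layer-normalization gain and bias, dropout, masking, and batching are simplifying assumptions, and the helper names (`attention`, `multi_head`, `encoder_layer`, and so on) are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row-wise max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V):
    # Eq. (1): softmax(Q K^T / sqrt(d_k)) V; the attention weights are returned too.
    d_k = K.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))
    return weights @ V, weights


def multi_head(X, Wq, Wk, Wv, Wo):
    # Eqs. (2)-(3), self-attention case (Q = K = V = X): one projection triple
    # per head, concatenate the head outputs, then apply W^O.
    heads = [attention(X @ q, X @ k, X @ v)[0] for q, k, v in zip(Wq, Wk, Wv)]
    return np.concatenate(heads, axis=-1) @ Wo


def ffn(x, W1, b1, W2, b2):
    # Eq. (4): two linear maps with a ReLU in between, applied at every position.
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2


def layer_norm(x, eps=1e-6):
    # Normalize each position; learnable gain and bias are omitted in this sketch.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def encoder_layer(X, p):
    # Eq. (5) after each sub-layer: LayerNorm(x + Sublayer(x)).
    X = layer_norm(X + multi_head(X, p["Wq"], p["Wk"], p["Wv"], p["Wo"]))
    return layer_norm(X + ffn(X, p["W1"], p["b1"], p["W2"], p["b2"]))


# Toy sizes (assumptions): sequence length l, model width, heads h, FFN width.
l, d_model, h, d_ff = 5, 8, 2, 16
d_head = d_model // h
rng = np.random.default_rng(0)
p = {
    "Wq": [rng.normal(size=(d_model, d_head)) for _ in range(h)],
    "Wk": [rng.normal(size=(d_model, d_head)) for _ in range(h)],
    "Wv": [rng.normal(size=(d_model, d_head)) for _ in range(h)],
    "Wo": rng.normal(size=(h * d_head, d_model)),
    "W1": rng.normal(size=(d_model, d_ff)),
    "b1": np.zeros(d_ff),
    "W2": rng.normal(size=(d_ff, d_model)),
    "b2": np.zeros(d_model),
}
X = rng.normal(size=(l, d_model))
out = encoder_layer(X, p)
print(out.shape)  # (5, 8)
```

Each row of the attention weights returned by `attention` is a probability distribution over positions, and the output keeps the input shape, which is what lets Transformer layers be stacked.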
2.2. Knowledge Distillation
KD is a technique that extracts knowledge from a complex deep model (teacher model) and transfers it to a smaller model (student model) [14]. By imitating the output distribution of the teacher model, the student model can reduce the complexity of the model and the demand for computing resources while maintaining high performance. This method is often used to pursue a balance between performance and efficiency when deploying models. The goal of KD is to make the output of the student model as close as possible to the output of the teacher model by optimizing a loss function. This loss function can be expressed as
$$\mathcal{L}_{KD}=\sum_{z\in\mathcal{X}} L\big(f^S(z),\,f^T(z)\big) \tag{6}$$
where $L(\cdot)$ is the loss function that measures the difference between the student model and the teacher model, $f^S(z)$ and $f^T(z)$ represent the outputs of the student model and the teacher model, respectively, $z$ is the input data, and $\mathcal{X}$ is the training dataset. By minimizing this loss function, the student model can effectively learn the knowledge of the teacher model.
3. Method
This section presents a novel distillation approach, AMKD. AMKD effectively improves the ability of the student model to capture and understand semantic relationships from the teacher model, resolving the dimensional mismatch issues encountered during the distillation process.
3.1. Overview of AMKD
The core idea of AMKD is to enable the small student model $S$ to learn from the knowledge of the large teacher model $T$, thereby improving the performance of the student model on specific tasks under the guidance of the teacher. Both the student and teacher models consist of an embedding layer, several Transformer layers, and a prediction layer. Each Transformer layer consists of an attention layer and a hidden layer.
AMKD consists of four distillation components: prediction layer distillation, hidden states distillation, attention matrix distillation, and embedding layer distillation. Assuming that the teacher model contains $M$ Transformer layers and the student model contains $N$ Transformer layers, Figure 1 presents the overall distillation framework. $H_i^S \in \mathbb{R}^{l\times d'}$ and $H_j^T \in \mathbb{R}^{l\times D}$ represent the hidden states of the $i$-th layer of the student model and the $j$-th layer of the teacher model, respectively; $d'$ and $D$ are their hidden states dimensions; $l$ is the input sequence length; and $^{\top}$ denotes matrix transposition. $h$ is the number of attention heads, and $A^S$ and $A^T$ represent the attention matrices of the student model and the teacher model, respectively.
In contrast to previous approaches, we leverage the autocorrelation matrix of hidden states to capture complex feature relationships, rather than focusing solely on the individual features of the teacher model. The student model learns not only the behavior of each feature from the teacher model but also the interactions between features. Furthermore, AMKD effectively resolves the issue of hidden states dimensional mismatches between the teacher and student models, allowing for greater flexibility in the configuration of the student model while reducing computational costs. Next, we will provide a detailed description of the four distillation strategies in AMKD.
3.2. Prediction Layer Distillation
To enable the student model to mimic the output of the teacher model's prediction layer, we employ the KD technique introduced by Hinton et al. [14]. In particular, we compute the Kullback–Leibler (KL) divergence between the temperature-softened output distributions (the softmax of the logits) of the teacher and student models, which aligns the outputs of the student model more closely with those of the teacher. The calculation is as follows:
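$$\mathcal{L}_{soft}=\mathrm{KL}\big(p^T\,\|\,p^S\big)=\sum_i p_i^T\log\frac{p_i^T}{p_i^S},\qquad p_i=\frac{\exp(z_i/t)}{\sum_c\exp(z_c/t)} \tag{7}$$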
where $z^T$ and $z^S$ represent the logits output by the teacher and student models, respectively, from which $p^T$ and $p^S$ are computed. Logits are the raw, un-normalized scores that a classification network produces before the softmax function is applied. The temperature $t$ smooths the probability distributions, and the index $i$ ranges over the categories of the classification task. Additionally, we consider the hard label loss of the student model, $\mathcal{L}_{hard}$, computed as the cross-entropy between the true labels and the predicted probabilities.
The total prediction-layer loss of the student model consists of two components, the hard label loss and the soft label loss, and is defined as follows:
$$\mathcal{L}_{pred}=\alpha\,\mathcal{L}_{soft}+(1-\alpha)\,\mathcal{L}_{hard} \tag{8}$$
where $\alpha$ is a weighting factor that balances the hard label loss and the soft label loss.
3.3. Attention Matrix Distillation
Attention matrix distillation aims to transfer the attention weight information from the Transformer structure in the teacher model to the student model. These weights capture rich linguistic information, which is crucial for natural language understanding [32]. By applying attention matrix distillation, the student model can better inherit the language comprehension capabilities of the teacher model. The loss function for attention matrix distillation is defined as follows:
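$$\mathcal{L}_{attn}=\frac{1}{Nh}\sum_{i=1}^{N}\sum_{k=1}^{h}\mathrm{MSE}\big(A_{i,k}^S,\;A_{j,k}^T\big),\qquad A_{i,k}^S,\,A_{j,k}^T\in\mathbb{R}^{l\times l} \tag{9}$$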
where $N$ is the number of Transformer layers in the student model, $h$ is the number of attention heads, $A_{i,k}^S$ is the attention matrix of the $k$-th head in the $i$-th layer of the student model ($A_{j,k}^T$ is defined analogously for the teacher), $l$ is the input length, and $\mathrm{MSE}(\cdot)$ is the mean squared error loss. The $i$-th layer of the student model learns from the corresponding $j$-th layer of the teacher model.
3.4. Hidden States Distillation
In the Transformer model, the hidden states are the intermediate representation of each layer and capture the semantic information and features of the input sequence. In traditional knowledge distillation methods, the student model usually needs a projection layer to resolve the mismatch between its hidden states dimension and that of the teacher model: the student hidden states matrix $H^S \in \mathbb{R}^{l\times d'}$ and the teacher hidden states matrix $H^T \in \mathbb{R}^{l\times D}$ cannot be compared directly because their dimensions differ. Here, $l$ represents the length of the input sequence, $D$ and $d'$ represent the hidden states dimensions of the teacher model and the student model, respectively, and $d'$ is usually smaller than $D$. To solve this problem, traditional methods introduce a projection layer that maps the student hidden states to the teacher's dimension:
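$$\tilde{H}^S=H^S W_h,\qquad W_h\in\mathbb{R}^{d'\times D} \tag{10}$$
where $W_h$ is a learnable projection matrix, after which $\tilde{H}^S$ can be compared with $H^T$.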
However, this method has two major limitations. First, the projection adds computational cost, which grows significantly for large models. Second, the dimension conversion may lose information, so that the student model cannot fully capture the representations of the teacher model.
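To make the cost of the projection approach concrete, here is a minimal NumPy sketch of Eq. (10) using the hidden sizes from Table 1 ($d' = 312$, $D = 768$) and a sequence length of 128. Random arrays stand in for real hidden states, and the name `W_h` and its initialization scale are assumptions; in practice such a projection is learned jointly with the student.

```python
import numpy as np

l, d_s, D = 128, 312, 768          # sequence length, student and teacher hidden sizes
rng = np.random.default_rng(0)
H_s = rng.normal(size=(l, d_s))    # stand-in for student hidden states, shape (l, d')
H_t = rng.normal(size=(l, D))      # stand-in for teacher hidden states, shape (l, D)

# An element-wise MSE between H_s and H_t is undefined: the widths differ (312 vs 768).
# Projection-based methods add a learnable matrix of shape (d', D), as in Eq. (10).
W_h = rng.normal(scale=0.02, size=(d_s, D))


def mse(a, b):
    # Mean squared error over all entries of two equally shaped arrays.
    return float(np.mean((a - b) ** 2))


projected = H_s @ W_h              # shape (l, D), now comparable with H_t
loss = mse(projected, H_t)
extra_params = W_h.size            # weights in this one projection matrix
print(projected.shape, extra_params)  # (128, 768) 239616
```

The 239,616 extra weights belong to a single projection matrix; how many such matrices a given method uses (one shared, or one per distilled layer) varies between methods.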
To overcome these problems, AMKD takes a different approach: by computing the autocorrelation matrix of the hidden states, it avoids explicit dimensional projection. Specifically, AMKD computes the autocorrelation matrix of the hidden states matrix $H \in \mathbb{R}^{l\times d}$ of each layer, converting it from shape $l\times d$ (where $d$ is $d'$ for the student and $D$ for the teacher) to the unified shape $l\times l$; that is,
$$G=HH^{\top},\qquad H\in\mathbb{R}^{l\times d},\quad G\in\mathbb{R}^{l\times l} \tag{11}$$
where $G$ represents the autocorrelation matrix. Computing the autocorrelation matrix resolves the dimension mismatch: however different the hidden states dimensions of the student and teacher models are, both autocorrelation matrices have shape $l\times l$ and can be compared directly.
The autocorrelation matrix captures not only the behavior of individual features but also higher-order relationships between them. Each entry is computed as follows:
$$G_{uv}=\sum_{c=1}^{d}H_{uc}H_{vc}=\langle H_u,\,H_v\rangle \tag{12}$$
where $G_{uv}$ represents the inner product between the $u$-th row $H_u$ and the $v$-th row $H_v$ of the hidden states matrix $H$. Figure 2 shows the autocorrelation matrix of the teacher model (left) and that of the student model (right). Comparing the two illustrates how the student model learns the complex feature relationships of the teacher model through AMKD; darker colors indicate stronger correlations.
In AMKD, we use the mean squared error (MSE) to measure the difference between the autocorrelation matrices of the student model and the teacher model, and the hidden states distillation loss is defined as
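$$\mathcal{L}_{hidn}=\frac{1}{N}\sum_{i=1}^{N}\mathrm{MSE}\big(H_i^S H_i^{S\top},\;H_j^T H_j^{T\top}\big) \tag{13}$$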
By minimizing the MSE of the autocorrelation matrix, AMKD ensures that the student model can not only learn the individual features of the teacher model, but also capture the complex relationships between the features. This significantly improves the performance of the student model in complex tasks.
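A minimal NumPy sketch of Eqs. (11) to (13) for one student/teacher layer pair, assuming unnormalized autocorrelation matrices and a plain element-wise MSE (any additional scaling the authors may apply is not shown). The function names are hypothetical, and random arrays stand in for real hidden states. Because $H$ has one row per token, $G$ is an $l \times l$ matrix whose shape does not depend on the hidden size.

```python
import numpy as np


def autocorrelation(H):
    # Eqs. (11)-(12): G = H H^T, so G[u, v] is the inner product of rows u and v.
    # The result has shape (l, l) whatever the hidden size of H is.
    return H @ H.T


def amkd_hidden_loss(H_s, H_t):
    # Eq. (13) for one (student layer, teacher layer) pair: MSE between the two
    # autocorrelation matrices; no projection matrix is needed.
    G_s, G_t = autocorrelation(H_s), autocorrelation(H_t)
    return float(np.mean((G_s - G_t) ** 2))


rng = np.random.default_rng(0)
l, d_s, D = 128, 312, 768               # sizes from Table 1 and Section 4.2
H_s = rng.normal(size=(l, d_s))         # stand-in for a student layer output
H_t = rng.normal(size=(l, D))           # stand-in for a teacher layer output
print(autocorrelation(H_s).shape, autocorrelation(H_t).shape)  # (128, 128) (128, 128)
print(amkd_hidden_loss(H_s, H_t) > 0)   # True

# Teacher-width states built as H_s @ Q.T, where Q.T (312 x 768) has orthonormal
# rows, have exactly the same autocorrelation matrix, so the loss is ~0.
Q, _ = np.linalg.qr(rng.normal(size=(D, d_s)))  # Q: (768, 312), orthonormal columns
H_t_rotated = H_s @ Q.T                          # shape (128, 768)
print(np.isclose(amkd_hidden_loss(H_s, H_t_rotated), 0.0))  # True
```

The last check shows why no projection is needed: $HH^{\top}$ is unchanged when $H$ is multiplied on the right by any matrix $P$ with $PP^{\top} = I$, so the loss compares relational structure rather than a particular coordinate system. The embedding-layer loss in Section 3.5 has the same form, applied to the embedding outputs instead of the hidden states.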
3.5. Embedding Layer Distillation
We also performed embedding layer distillation to enable the student model to learn the semantic information from the embedding layer of the teacher model. This process is similar to computing the hidden states loss: by calculating the inner product of the embedding layer outputs from both the student model and the teacher model, the embedding layer distillation loss can be derived. The formula is as follows:
$$\mathcal{L}_{embd}=\mathrm{MSE}\big(E^S E^{S\top},\;E^T E^{T\top}\big) \tag{14}$$
where $E^S$ and $E^T$ represent the embedding layer outputs of the student model and the teacher model, respectively, and $^{\top}$ denotes matrix transposition. This enables the student model to learn the key features of the teacher's embedding layer while also capturing the relationships between embedding features.
3.6. Overall Loss Function
We combined prediction layer distillation, attention matrix distillation, hidden states distillation, and embedding layer distillation. The total distillation loss of AMKD can be expressed as follows:
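$$\mathcal{L}_{AMKD}=\alpha\,\mathcal{L}_{soft}+(1-\alpha)\,\mathcal{L}_{hard}+\mathcal{L}_{attn}+\mathcal{L}_{hidn}+\mathcal{L}_{embd} \tag{15}$$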
where $\alpha$ is the weighting factor from Eq. (8) that balances the hard label loss and the soft label loss.
3.7. Data Augmentation
To improve the generalization and robustness of the student model, we adopted a simple yet effective data augmentation technique. The specific steps are detailed in Algorithm 1. By randomly shuffling the positions of tokens in the input sequence, we increased data diversity and reduced the dependence of the model on specific sequences, thereby enhancing the adaptability and stability of the student model. This method is easy to implement and has a low computational overhead.
Algorithm 1. Short Disorder Data Augmentation
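The following Python sketch illustrates one plausible reading of the token-shuffling step described above: the tokens inside a single randomly chosen short window are shuffled, and everything else stays in place. The window size, the one-window-per-call policy, and the function name `short_disorder` are assumptions, not details taken from Algorithm 1.

```python
import random


def short_disorder(tokens, window=3, rng=None):
    """Shuffle the tokens inside one randomly chosen window (hypothetical helper).

    A window at least as long as the sequence reduces to a full random shuffle.
    """
    if rng is None:
        rng = random.Random()
    tokens = list(tokens)                        # copy: never mutate the caller's list
    if len(tokens) < 2:
        return tokens
    w = min(window, len(tokens))
    start = rng.randrange(len(tokens) - w + 1)   # random window start position
    span = tokens[start:start + w]
    rng.shuffle(span)                            # permute positions within the window
    tokens[start:start + w] = span
    return tokens


sentence = "the movie was surprisingly good".split()
augmented = short_disorder(sentence, window=3, rng=random.Random(0))
print(augmented)
```

Keeping the perturbation local preserves most of the sentence while still varying word order, which is one way to add diversity at low cost; a seeded `random.Random` makes the augmentation reproducible.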
3.8. Skip-Layer Distillation
In the skip-layer distillation method, the student model acquires knowledge by selectively learning from teacher layers at different depths. Assuming that the teacher model has $M$ layers and the student model has $N$ layers, we first compute the average skip-layer interval by dividing $M$ by $N$. When $M$ is not divisible by $N$, this interval is not an integer, so we round it down and use its integer part, $k = \lfloor M/N \rfloor$, as the skip-layer interval. For the remaining layers, additional learning is performed according to their importance to ensure that key information is not ignored. This skip-layer selection strategy acquires knowledge from different levels of the teacher model more comprehensively, thereby improving the overall performance of the student model.
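A small Python sketch of the interval computation described above. The function name is hypothetical, and because the importance-based treatment of the remaining layers is not specified, the sketch only reports which teacher layers lie above $Nk$ and are therefore not covered by the regular interval.

```python
def skip_layer_mapping(M, N):
    """Pair student layer i (1..N) with teacher layer i * k, where k = floor(M / N).

    Returns the interval k, the {student layer: teacher layer} mapping, and the
    teacher layers above N * k that the regular interval leaves uncovered.
    """
    if not 0 < N <= M:
        raise ValueError("expected 0 < N <= M")
    k = M // N                                   # rounding-down strategy
    mapping = {i: i * k for i in range(1, N + 1)}
    remaining = list(range(N * k + 1, M + 1))
    return k, mapping, remaining


print(skip_layer_mapping(12, 4))  # (3, {1: 3, 2: 6, 3: 9, 4: 12}, [])
print(skip_layer_mapping(12, 5))  # (2, {1: 2, 2: 4, 3: 6, 4: 8, 5: 10}, [11, 12])
```

With the configuration of Section 4.2 ($M = 12$, $N = 4$), the interval is $k = 3$ and no layers remain; a five-layer student would leave teacher layers 11 and 12 for the additional importance-based step.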
4. Experiments
In this section, we evaluated the proposed AMKD method across various NLP tasks.
4.1. Datasets
We evaluated the proposed method on the GLUE benchmark, which includes a variety of natural language understanding tasks designed to evaluate comprehensively how well a model handles different linguistic phenomena [33]. These tasks include Multi-Genre Natural Language Inference (MNLI), which involves determining the relationship between a premise and a hypothesis [34]; Quora Question Pairs (QQP), which evaluates the semantic equivalence of question pairs [35]; Question Answering Natural Language Inference (QNLI), which determines whether a given context contains the answer to a question; Stanford Sentiment Treebank (SST-2), which involves the sentiment classification of movie review sentences [36]; Microsoft Research Paraphrase Corpus (MRPC), which assesses whether sentence pairs are paraphrases [37]; and Recognizing Textual Entailment (RTE), which determines whether one sentence entails another [38].
For the machine reading comprehension task, we used the Stanford Question Answering Dataset (SQuAD v1.1) for evaluation [39]. Created by Rajpurkar et al. in 2016, it contains more than 100,000 crowdsourced question–answer pairs, and the task is to find the text span within a Wikipedia passage that answers the question. SQuAD v2.0 builds on this by adding questions that have no answer in the passage, making the task more realistic: the model must determine whether an answer exists and handle unanswerable questions appropriately [40].
4.2. Training Details and Baselines
In our experiments, we used a BERTBASE model fine-tuned for each specific task as the teacher model [2]. It has 12 layers ($M = 12$), a hidden size of 768 ($D = 768$), a feed-forward size of 3072, and 12 attention heads ($h = 12$), totaling 109 million parameters. The student model was BERTTINY, a pre-trained language model with far fewer parameters: 4 layers ($N = 4$), a hidden size of 312 ($d' = 312$), a feed-forward size of 1200, and 12 attention heads ($h = 12$), totaling 14.5 million parameters.
The distillation temperature $t$ was chosen from {1, 4, 8}, and the weighting factor $\alpha$ from {0.2, 0.5, 0.8, 1}. The learning rate was chosen from a predefined set of candidate values, and the batch size from {16, 32}. Fine-tuning was run for 5 epochs without data augmentation, whereas distillation was run for 25 epochs with data augmentation. The sequence length for task-specific distillation was fixed at 128. For skip-layer distillation, we set $k = 3$, meaning that the student $S$ learned from every third layer of the teacher $T$. For each experiment, the best result was reported.
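To illustrate how the temperature $t$ and the weight $\alpha$ tuned above enter the prediction-layer loss of Section 3.2, here is a dependency-free Python sketch for a single example. The logits are made up, the function names are hypothetical, and the combination $\alpha\,\mathcal{L}_{soft} + (1-\alpha)\,\mathcal{L}_{hard}$ (that is, which term $\alpha$ multiplies) is an assumption.

```python
import math


def softmax(z, t=1.0):
    # Temperature-scaled softmax; subtracting the max keeps exp() stable.
    m = max(z)
    exps = [math.exp((x - m) / t) for x in z]
    total = sum(exps)
    return [e / total for e in exps]


def soft_loss(z_teacher, z_student, t):
    # KL(p_T || p_S) between the temperature-softened distributions.
    p_t, p_s = softmax(z_teacher, t), softmax(z_student, t)
    return sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))


def hard_loss(z_student, label):
    # Cross-entropy between the true label and the student's prediction (t = 1).
    return -math.log(softmax(z_student)[label])


def prediction_loss(z_teacher, z_student, label, t=4.0, alpha=0.5):
    # Assumed combination: alpha weights the soft term, (1 - alpha) the hard term.
    return alpha * soft_loss(z_teacher, z_student, t) + (1 - alpha) * hard_loss(z_student, label)


z_teacher = [4.0, 1.0, 0.5]   # hypothetical teacher logits for a 3-class task
z_student = [2.5, 1.5, 0.0]   # hypothetical student logits
for t in (1, 4, 8):           # temperature grid used above
    print(t, round(max(softmax(z_teacher, t)), 3), round(soft_loss(z_teacher, z_student, t), 4))
for alpha in (0.2, 0.5, 0.8, 1.0):  # alpha grid used above
    print(alpha, round(prediction_loss(z_teacher, z_student, label=0, alpha=alpha), 4))
```

Raising $t$ flattens the teacher's distribution and exposes the relative scores of the non-target classes. Hinton et al. [14] recommend multiplying the soft term by $t^2$ so that its gradient scale stays comparable across temperatures; the sketch omits this factor, and whether AMKD applies it is not stated.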
We compared our method, BERTTINY-AMKD, with several baseline methods, including BERT-KD [14], which represents traditional prediction layer distillation, BERT-PKD, and DistilBERT. We fine-tuned the published pre-trained models BERT4-PKD [20] and DistilBERT4 [19] on each specific task using the recommended hyperparameters. Table 1 provides a structural comparison of the models used in our experiments.
4.3. Experimental Results on GLUE
The experimental results are presented in Table 3. BERTBASE (Teacher) represents our implementation of the BERT teacher model [2]. BERTTINY-FT shows the results of directly fine-tuning the pre-trained BERTTINY model on each task. BERTTINY-KD represents the results of applying the prediction distillation method to the BERTTINY model [14]. BERTTINY-AMKD is our proposed method. The results of TinyBERT4 are obtained by directly calculating the MSE of the hidden states. Other baseline data are obtained by directly fine-tuning the publicly available pre-trained models [19,20]. Accuracy is used as the evaluation metric for all tasks except for the MRPC task, which uses F1 score. The number of training samples for each dataset is provided below the dataset name.
The results from the four-layer student models suggest that a substantial reduction in model size results in a significant performance gap between BERTTINY-FT (or BERTSMALL) and BERTBASE.
BERTTINY-AMKD outperforms BERTTINY-FT on all GLUE tasks, with an average improvement of 4.7 percentage points, indicating that AMKD can significantly enhance the performance of small models on various downstream tasks. In addition, BERTTINY-AMKD achieves an average GLUE score of 83.6%, well above the 79.5% of the traditional knowledge distillation method BERTTINY-KD.
BERTTINY-AMKD achieves 96.3% of the average performance of the BERTBASE teacher model while using only 13.3% of its parameters, showing that it maintains high accuracy at a fraction of the computational cost. It also outperforms the four-layer KD baselines BERT4-PKD and DistilBERT4 by 2.6 and 3.9 percentage points, respectively, while using only 28% of their parameters.
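The headline figures above mix two kinds of quantities: ratios (96.3%, 13.3%, 28%) and absolute differences between GLUE averages, which are percentage points. This short Python snippet recomputes both from Table 1 and Table 3; it is only an arithmetic check on the reported numbers.

```python
# Figures copied from Table 1 (parameters, millions) and Table 3 (GLUE averages, %).
params = {"BERTBASE": 109.0, "BERTTINY": 14.5, "BERT4-PKD": 52.2}
glue_avg = {"Teacher": 86.8, "FT": 78.9, "KD": 79.5, "PKD": 81.0,
            "DistilBERT4": 79.7, "TinyBERT4": 81.9, "AMKD": 83.6}

# Ratios (relative quantities).
print(f"{glue_avg['AMKD'] / glue_avg['Teacher']:.1%} of teacher accuracy")     # 96.3% of teacher accuracy
print(f"{params['BERTTINY'] / params['BERTBASE']:.1%} of teacher parameters")  # 13.3% of teacher parameters
print(f"{params['BERTTINY'] / params['BERT4-PKD']:.1%} of BERT4-PKD parameters")  # 27.8% of BERT4-PKD parameters

# Differences (absolute gains of BERTTINY-AMKD, in percentage points).
for baseline in ("FT", "KD", "PKD", "DistilBERT4", "TinyBERT4"):
    print(baseline, round(glue_avg["AMKD"] - glue_avg[baseline], 1))
```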
Compared with TinyBERT4, BERTTINY-AMKD has clear advantages in both performance and flexibility. TinyBERT4 applies the MSE loss directly to the hidden states, whereas AMKD not only learns individual features but also captures high-order dependencies through the autocorrelation matrix. As a result, BERTTINY-AMKD outperforms TinyBERT4 on every GLUE task and by 1.7 percentage points on average. Figure 3 shows the average performance of the different models on GLUE; BERTTINY-AMKD clearly improves on the other models.
On the MNLI task, BERTTINY-AMKD achieved accuracies of 82.1% on MNLI-m and 81.7% on MNLI-mm, well above the 76.5% and 76.1% of BERTTINY-KD. It also outperformed BERT4-PKD (79.9% and 79.3%) and DistilBERT4 (78.9% and 78.0%). On the CoLA (Corpus of Linguistic Acceptability) dataset, despite the significant performance gap between all four-layer models and the teacher model, BERTTINY-AMKD still achieved notable improvements. Figure 4 gives a detailed per-task comparison of several methods on GLUE: AMKD outperforms the other student models on every task, with performance approaching that of the teacher model.
BERT4-PKD and DistilBERT4 initialize the student model from specific layers of BERTBASE during training, which requires the student model to match the hidden states dimensions of the teacher model. In contrast, our proposed AMKD method allows for greater flexibility in the configuration of the student model. AMKD not only allows the student model to learn the features of the teacher model, but also captures the complex relationships among these features.
4.4. Experimental Results on SQuAD
We further validated the effectiveness of AMKD on question answering (QA) tasks using the SQuAD v1.1 and SQuAD v2.0 datasets. Unlike the GLUE tasks, QA tasks require more nuanced knowledge to identify the correct answer, which makes the learning process more complex. We did not use data augmentation in these QA experiments. The experimental results are shown in Table 4.
BERTTINY-AMKD outperformed the two four-layer baseline models, BERT4-PKD and DistilBERT4, on both the SQuAD v1.1 and SQuAD v2.0 datasets. These results once again demonstrate the effectiveness of the AMKD method in capturing and learning teacher model features. Even in challenging tasks like question answering, the performance of the student model is significantly improved.
4.5. Performance of AMKD-Last vs. AMKD-Skip
We compared two techniques: BERTTINY(AMKD-Last) and BERTTINY(AMKD-Skip). AMKD-Last means that the student model only learns from the last few layers of the teacher model, while AMKD-Skip lets the student model extract knowledge from the teacher model every k layers. Table 5 summarizes the experimental results of these two AMKD methods.
Although both methods outperform the KD baseline, BERTTINY(AMKD-Skip) shows slightly better performance than BERTTINY(AMKD-Last). This performance advantage is likely due to the ability of AMKD-Skip to extract information every k layers, capturing a broader range of semantic representations from lower to higher layers. In contrast, AMKD-Last focuses only on the final layers, resulting in less comprehensive semantic information.
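The two layer-selection schemes can be written as simple index maps. In this Python sketch, AMKD-Last is assumed to pair the $N$ student layers with the last $N$ teacher layers in order, and AMKD-Skip uses the interval $k = \lfloor M/N \rfloor$ from Section 3.8; the function names are hypothetical.

```python
def last_mapping(M, N):
    # AMKD-Last (assumed form): student layer i learns from teacher layer M - N + i.
    return [M - N + i for i in range(1, N + 1)]


def skip_mapping(M, N):
    # AMKD-Skip: student layer i learns from teacher layer i * k, with k = floor(M / N).
    k = M // N
    return [i * k for i in range(1, N + 1)]


M, N = 12, 4  # teacher and student depths from Section 4.2
print("Last:", last_mapping(M, N))  # Last: [9, 10, 11, 12]
print("Skip:", skip_mapping(M, N))  # Skip: [3, 6, 9, 12]
```

For the 12-layer teacher and 4-layer student used here, AMKD-Skip draws on the lower, middle, and upper parts of the teacher, whereas AMKD-Last only sees layers 9 to 12, which is consistent with the explanation above.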
4.6. Ablation Studies
This section explores the impact of different parts of the Transformer architecture on the effectiveness of distillation through ablation experiments.
The ablation experiment results in Table 6 demonstrate that removing different distillation objectives significantly affects the learning performance of AMKD. The removal of attention distillation (w/o Atten) has the most substantial impact on overall model performance, particularly on the MNLI task. Similarly, removing hidden layer distillation (w/o Hidden) also results in a significant performance reduction. In contrast, removing embedding layer distillation (w/o Embed) causes a smaller decrease in performance.
5. Conclusions
BERTTINY-AMKD significantly outperforms traditional distillation methods and other four-layer baseline models. On GLUE, BERTTINY-AMKD achieves an average score of 83.6%, which is 4.1 percentage points higher than BERTTINY-KD, 2.6 points higher than BERT4-PKD, and 3.9 points higher than DistilBERT4. On SQuAD v1.1, BERTTINY-AMKD reaches EM and F1 scores of 72.1% and 81.8%, respectively, also surpassing the other four-layer baseline models. Ablation experiments show that attention matrix distillation and hidden states distillation are crucial to the performance improvement, while the skip-layer distillation strategy obtains knowledge from different levels of the teacher model more comprehensively. These results demonstrate that AMKD can significantly improve the accuracy of small models on natural language processing tasks.
By introducing the autocorrelation matrix, AMKD not only effectively captures the complex relationships between the features of the teacher model, but also skillfully handles the dimensional differences between the student model and the teacher model, reducing information loss. This method is suitable not only for classification tasks but also for a range of applications, including reading comprehension and regression tasks. AMKD significantly improves the performance of small BERT models in NLP tasks, providing an efficient and flexible solution for deploying pre-trained language models in resource-constrained environments. The AMKD method we proposed is mainly used for the distillation of specific tasks. If conditions permit, distillation can be performed in the pre-training stage in the future to generate a general small BERT model suitable for a wider range of tasks. In addition, it is difficult to adapt the fixed temperature parameter t in the distillation process to different samples. In the future, it will be possible to consider dynamically adjusting the temperature to better adapt to the needs of different samples.
Conceptualization, K.Z.; methodology, J.L.; software, J.L.; validation, J.L. and B.W.; formal analysis, J.L.; investigation, H.M.; resources, K.Z.; data curation, B.W.; writing—original draft preparation, J.L.; writing—review and editing, K.Z.; visualization, B.W.; supervision, K.Z.; project administration, K.Z.; funding acquisition, K.Z. All authors have read and agreed to the published version of the manuscript.
Not applicable.
Not applicable.
The original contributions presented in the study are included in the article; further inquiries can be directed to the corresponding author.
The authors declare no conflicts of interest.
The following abbreviations are used in this manuscript:
| AMKD | Autocorrelation Matrix Knowledge Distillation |
| NLP | Natural Language Processing |
| KD | Knowledge Distillation |
Footnotes
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.
Figure 1. An overview of the AMKD method, showing how the knowledge of an M-layer Transformer teacher model T is distilled into an N-layer Transformer student model S.
Figure 2. Comparison of the autocorrelation matrix between the teacher model and the student model.
Figure 3. Comparison of average performance of AMKD and other methods on GLUE tasks.
Figure 4. Radar chart comparison of performance of AMKD and other methods on GLUE tasks.
Comparison of architecture and parameter size between the BERT teacher model and student models.
| Model | #Layer | Hidden Size | Feed-Forward | Speedup | #Params | Relative |
|---|---|---|---|---|---|---|
| BERTBASE | 12 | 768 | 3072 | 1.0× | 109M | 100% |
| BERTTINY | 4 | 312 | 1200 | 9.4× | 14.5M | 13.3% |
| BERTSMALL | 4 | 512 | 2048 | 5.7× | 29.2M | 26.8% |
| BERT4-PKD | 4 | 768 | 3072 | 3.0× | 52.2M | 47.9% |
| DistilBERT4 | 4 | 768 | 3072 | 3.0× | 52.2M | 47.9% |
Comparison of distillation components in different approaches.
| Model | Teacher Model | Prediction Layer | Embedding Layer | Hidden States | Attention Matrix |
|---|---|---|---|---|---|
| BERT-KD | BERTBASE | ✓ | | | |
| DistilBERT | BERTBASE | ✓ | ✓ | | |
| BERT-PKD | BERTBASE | ✓ | ✓ | | |
| BERT-AMKD | BERTBASE | ✓ | ✓ | ✓ | ✓ |
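Every approach in the table distills the prediction layer. A common formulation, following Hinton et al. [14] and not necessarily identical to the loss used in AMKD, is the soft cross-entropy between the teacher's and the student's output distributions, both softened by a fixed temperature t. The NumPy sketch below assumes that formulation; the function names and example logits are hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray, t: float) -> np.ndarray:
    """Numerically stable log-softmax of logits divided by temperature t."""
    z = logits / t
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def prediction_layer_loss(student_logits: np.ndarray,
                          teacher_logits: np.ndarray,
                          t: float = 1.0) -> float:
    """Soft cross-entropy -sum_c p_teacher(c) * log p_student(c), averaged over the batch."""
    p_teacher = np.exp(log_softmax(teacher_logits, t))
    log_p_student = log_softmax(student_logits, t)
    return float(-(p_teacher * log_p_student).sum(axis=-1).mean())


teacher = np.array([[4.0, 1.0, -2.0]])  # hypothetical 3-class logits
student = np.array([[2.5, 1.5, -1.0]])

for t in (1.0, 2.0, 4.0):
    # Larger t flattens both distributions, exposing more of the teacher's
    # relative preferences among the non-target classes.
    print(t, prediction_layer_loss(student, teacher, t))
```

Some implementations also multiply this loss by t² to keep gradient magnitudes comparable across temperatures; whether AMKD does so is not stated here. A per-sample temperature, as suggested for future work, would replace the scalar t with an array of shape (batch, 1) that broadcasts over the class dimension.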
Comprehensive evaluation results of various models on the GLUE benchmark. The best result for each task is shown in bold.
| Model | MNLI-m (393K) | MNLI-mm (393K) | QQP (364K) | QNLI (105K) | SST-2 (67K) | MRPC (3.7K) | RTE (2.5K) | Avg |
|---|---|---|---|---|---|---|---|---|
| BERTBASE (Google) | 84.6 | 83.4 | 89.2 | 90.5 | 93.5 | 88.9 | 66.4 | 85.2 |
| BERTBASE (Teacher) | 84.3 | 83.2 | 91.3 | 91.6 | 93.7 | 89.8 | 73.6 | 86.8 |
| BERTTINY-FT | 75.4 | 74.9 | 83.5 | 84.8 | 87.6 | 83.2 | 62.6 | 78.9 |
| BERTSMALL | 77.6 | 77.0 | 87.0 | 86.4 | 89.7 | 83.4 | 61.8 | 80.4 |
| BERT4-PKD | 79.9 | 79.3 | 88.3 | 85.1 | 89.4 | 82.6 | 62.3 | 81.0 |
| DistilBERT4 | 78.9 | 78.0 | 87.6 | 85.2 | 91.4 | 82.4 | 54.1 | 79.7 |
| TinyBERT4 | 80.5 | 81.0 | 88.5 | 85.7 | 90.7 | 83.3 | 63.5 | 81.9 |
| BERTTINY-KD | 76.5 | 76.1 | 84.6 | 85.1 | 88.2 | 83.5 | 62.8 | 79.5 |
| BERTTINY-AMKD | 82.1 | 81.7 | 89.4 | 87.9 | 92.1 | 86.9 | 65.4 | 83.6 |
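The Avg column and the headline figures quoted in the abstract can be reproduced from this table: Avg is the unweighted mean of the seven task scores, and the gains over the baselines are absolute differences, that is, percentage points. The short check below copies the scores from the table and assumes nothing else.

```python
# Scores from the GLUE table: MNLI-m, MNLI-mm, QQP, QNLI, SST-2, MRPC, RTE.
scores = {
    "BERTBASE (Teacher)": [84.3, 83.2, 91.3, 91.6, 93.7, 89.8, 73.6],
    "BERT4-PKD": [79.9, 79.3, 88.3, 85.1, 89.4, 82.6, 62.3],
    "DistilBERT4": [78.9, 78.0, 87.6, 85.2, 91.4, 82.4, 54.1],
    "BERTTINY-KD": [76.5, 76.1, 84.6, 85.1, 88.2, 83.5, 62.8],
    "BERTTINY-AMKD": [82.1, 81.7, 89.4, 87.9, 92.1, 86.9, 65.4],
}

# Avg column: unweighted mean of the seven task scores, rounded to one decimal.
avg = {name: round(sum(v) / len(v), 1) for name, v in scores.items()}

amkd = avg["BERTTINY-AMKD"]
for baseline in ("BERTTINY-KD", "BERT4-PKD", "DistilBERT4"):
    # Gains are absolute differences, i.e. percentage points.
    print(f"{baseline}: +{amkd - avg[baseline]:.1f} points")
# BERTTINY-KD: +4.1 points
# BERT4-PKD: +2.6 points
# DistilBERT4: +3.9 points

retention = amkd / avg["BERTBASE (Teacher)"]
print(f"retention: {retention:.1%}")  # retention: 96.3%
```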
Comprehensive evaluation results of baseline models and BERTTINY-AMKD on SQuAD. Evaluation metrics: EM (Exact Match) and F1 (F1 Score).
| Model | SQuAD v1.1 EM | SQuAD v1.1 F1 | SQuAD v2.0 EM | SQuAD v2.0 F1 |
|---|---|---|---|---|
| BERTBASE (Teacher) | 79.8 | 87.9 | 73.2 | 76.4 |
| BERT4-PKD | 68.7 | 78.1 | 59.6 | 63.9 |
| DistilBERT4 | 70.2 | 79.8 | 59.4 | 63.2 |
| BERTTINY-AMKD | 72.1 | 81.8 | 65.2 | 68.4 |
Performance comparison between AMKD-Last and AMKD-Skip on the GLUE benchmark.
| Model | MNLI-m | MNLI-mm | QQP | QNLI | SST-2 | MRPC | RTE | Avg |
|---|---|---|---|---|---|---|---|---|
| BERTTINY (AMKD-Last) | 81.1 | 80.5 | 88.2 | 86.3 | 91.8 | 87.0 | 65.0 | 82.8 |
| BERTTINY (AMKD-Skip) | 82.1 | 81.7 | 89.4 | 87.9 | 92.1 | 86.9 | 65.4 | 83.6 |
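This section does not define the two variants. By analogy with the PKD-Last and PKD-Skip strategies of BERT-PKD [20], AMKD-Last presumably maps the N student layers to the last N teacher layers, while AMKD-Skip maps them to evenly spaced teacher layers (here every (M/N)-th layer, as in the uniform mapping of TinyBERT [24]). Under that assumption, with M = 12 and N = 4, the mapping can be sketched as follows; the function name is hypothetical.

```python
def layer_map(num_student: int, num_teacher: int, strategy: str) -> list[int]:
    """Teacher layer (1-indexed) assigned to each student layer m = 1..N.

    Hypothetical helper: "skip" assumes evenly spaced layers g(m) = m * M / N,
    "last" assumes the top N teacher layers g(m) = M - N + m.
    """
    if strategy == "skip":
        if num_teacher % num_student != 0:
            raise ValueError("this sketch assumes M is a multiple of N")
        step = num_teacher // num_student
        return [m * step for m in range(1, num_student + 1)]
    if strategy == "last":
        offset = num_teacher - num_student
        return [offset + m for m in range(1, num_student + 1)]
    raise ValueError(f"unknown strategy: {strategy!r}")


print(layer_map(4, 12, "skip"))  # [3, 6, 9, 12]
print(layer_map(4, 12, "last"))  # [9, 10, 11, 12]
```

Under this reading, the Skip variant exposes the student to knowledge from across the full depth of the teacher rather than only its top layers, which is consistent with its higher average (83.6 vs. 82.8).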
Ablation studies on distillation components in AMKD learning.
| Model | MNLI-m | MNLI-mm | SST-2 | MRPC | Avg |
|---|---|---|---|---|---|
| BERTTINY-AMKD | 82.1 | 81.7 | 92.1 | 86.9 | 85.7 |
| w/o Embed | 81.4 | 81.3 | 91.2 | 86.2 | 85.0 |
| w/o Atten | 79.3 | 78.7 | 89.5 | 83.6 | 82.8 |
| w/o Hidden | 79.5 | 78.9 | 91.0 | 84.2 | 83.4 |
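The ablation indicates that attention-matrix distillation contributes most: removing it lowers the four-task average from 85.7 to 82.8, compared with 83.4 without hidden-state distillation and 85.0 without embedding distillation. For illustration, an attention-matrix loss in the spirit of TinyBERT [24] is sketched below as the mean squared error between teacher and student attention matrices over mapped layer pairs. This formulation, the equal head count (12 for both models), the uniform layer mapping, and the helper names are assumptions rather than details given in this section.

```python
import numpy as np


def attention_loss(att_student: np.ndarray, att_teacher: np.ndarray) -> float:
    """MSE between the attention matrices of one mapped student/teacher layer pair.

    Both arrays: (num_heads, seq_len, seq_len). Their size depends on head count
    and sequence length, not on hidden size, so 312-d and 768-d models can be compared.
    """
    if att_student.shape != att_teacher.shape:
        raise ValueError("this sketch assumes equal head counts and sequence lengths")
    return float(np.mean((att_student - att_teacher) ** 2))


def total_attention_loss(student_maps, teacher_maps, mapping):
    """Sum of per-layer losses; mapping[i] is the 1-indexed teacher layer for student layer i + 1."""
    return sum(attention_loss(s, teacher_maps[j - 1]) for s, j in zip(student_maps, mapping))


rng = np.random.default_rng(0)
heads, seq_len = 12, 16  # assumed equal head counts for teacher and student
teacher_maps = [rng.standard_normal((heads, seq_len, seq_len)) for _ in range(12)]
student_maps = [rng.standard_normal((heads, seq_len, seq_len)) for _ in range(4)]
mapping = [3, 6, 9, 12]  # assumed uniform ("skip") mapping for M = 12, N = 4
print(total_attention_loss(student_maps, teacher_maps, mapping))
```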
References
1. Achiam, O.J.; Adler, S.; Agarwal, S.; Ahmad, L.; Akkaya, I.; Aleman, F.L.; Almeida, D.; Altenschmidt, J.; Altman, S.; Anadkat, S. et al. GPT-4 Technical Report; OpenAI: San Francisco, CA, USA, 2023.
2. Devlin, J.; Chang, M.W.; Lee, K.; Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv; 2018; arXiv: 1810.04805
3. Lin, X.; Xiong, G.; Gou, G.; Li, Z.; Shi, J.; Yu, J. ET-BERT: A Contextualized Datagram Representation with Pre-training Transformers for Encrypted Traffic Classification. Proceedings of the ACM Web Conference 2022; Lyon, France, 25–29 April 2022.
4. Shi, Z.; Luktarhan, N.; Song, Y.; Tian, G. BFCN: A Novel Classification Method of Encrypted Traffic Based on BERT and CNN. Electronics; 2023; 12, 516. [DOI: https://dx.doi.org/10.3390/electronics12030516]
5. Zheng, W.; Lu, S.; Cai, Z.; Wang, R.; Wang, L.; Yin, L. PAL-BERT: An Improved Question Answering Model. Comput. Model. Eng. Sci.; 2023; 139, pp. 2729-2745. [DOI: https://dx.doi.org/10.32604/cmes.2023.046692]
6. Wu, T.; Hou, C.; Zhao, Z.; Lao, S.; Li, J.; Wong, N.; Yang, Y. Weight-Inherited Distillation for Task-Agnostic BERT Compression. arXiv; 2023; arXiv: 2305.09098
7. Piao, T.; Cho, I.; Kang, U. SensiMix: Sensitivity-Aware 8-bit index & 1-bit value mixed precision quantization for BERT compression. PLoS ONE; 2022; 17, e0265621.
8. Liu, Y.; Lin, Z.; Yuan, F. ROSITA: Refined BERT cOmpreSsion with InTegrAted techniques. arXiv; 2021; arXiv: 2103.11367 [DOI: https://dx.doi.org/10.1609/aaai.v35i10.17056]
9. Lin, Y.J.; Chen, K.Y.; Kao, H.Y. LAD: Layer-Wise Adaptive Distillation for BERT Model Compression. Sensors; 2023; 23, 1483. [DOI: https://dx.doi.org/10.3390/s23031483]
10. Hoefler, T.; Alistarh, D.; Ben-Nun, T.; Dryden, N.; Peste, A. Sparsity in Deep Learning: Pruning and growth for efficient inference and training in neural networks. arXiv; 2021; arXiv: 2102.00554
11. Zhang, J.; Zhou, Y.; Saab, R. Post-training Quantization for Neural Networks with Provable Guarantees. SIAM J. Math. Data Sci.; 2022; 5, pp. 373-399. [DOI: https://dx.doi.org/10.1137/22M1511709]
12. Muksimova, S.; Umirzakova, S.; Mardieva, S.; Cho, Y.I. Enhancing Medical Image Denoising with Innovative Teacher–Student Model-Based Approaches for Precision Diagnostics. Sensors; 2023; 23, 9502. [DOI: https://dx.doi.org/10.3390/s23239502]
13. Kaushal, A.; Vaidhya, T.; Rish, I. LORD: Low Rank Decomposition Of Monolingual Code LLMs For One-Shot Compression. arXiv; 2023; arXiv: 2309.14021
14. Hinton, G.; Vinyals, O.; Dean, J. Distilling the knowledge in a neural network. arXiv; 2015; arXiv: 1503.02531
15. Qi, P.; Zhou, X.; Ding, Y.; Zhang, Z.; Zheng, S.; Li, Z. FedBKD: Heterogenous Federated Learning via Bidirectional Knowledge Distillation for Modulation Classification in IoT-Edge System. IEEE J. Sel. Top. Signal Process.; 2023; 17, pp. 189-204. [DOI: https://dx.doi.org/10.1109/JSTSP.2022.3224597]
16. Jiao, X.; Yin, Y.; Shang, L.; Jiang, X.; Chen, X.; Li, L.; Wang, F.; Liu, Q. LightMBERT: A Simple Yet Effective Method for Multilingual BERT Distillation. arXiv; 2021; arXiv: 2103.06418
17. Jiang, M.; Lin, J.; Wang, Z.J. ShuffleCount: Task-Specific Knowledge Distillation for Crowd Counting. Proceedings of the 2021 IEEE International Conference on Image Processing (ICIP); Anchorage, AK, USA, 19–22 September 2021; pp. 999-1003.
18. Wu, Y.; Chanda, S.; Hosseinzadeh, M.; Liu, Z.; Wang, Y. Few-Shot Learning of Compact Models via Task-Specific Meta Distillation. Proceedings of the 2023 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV); Waikoloa, HI, USA, 3–7 January 2023; pp. 6254-6263.
19. Tang, R.; Lu, Y.; Liu, L.; Mou, L.; Vechtomova, O.; Lin, J. Distilling task-specific knowledge from BERT into simple neural networks. arXiv; 2019; arXiv: 1903.12136
20. Sun, S.; Cheng, Y.; Gan, Z.; Liu, J. Patient knowledge distillation for BERT model compression. arXiv; 2019; arXiv: 1908.09355
21. Aguilar, G.; Ling, Y.; Zhang, Y.; Yao, B.; Fan, X.; Guo, E. Knowledge Distillation from Internal Representations. arXiv; 2019; arXiv: 1910.03723 [DOI: https://dx.doi.org/10.1609/aaai.v34i05.6229]
22. Hou, L.; Huang, Z.; Shang, L.; Jiang, X.; Chen, X.; Liu, Q. DynaBERT: Dynamic BERT with adaptive width and depth. Adv. Neural Inf. Process. Syst.; 2020; 33, pp. 9782-9793.
23. Xu, C.; Zhou, W.; Ge, T.; Wei, F.; Zhou, M. BERT-of-Theseus: Compressing BERT by progressive module replacing. arXiv; 2020; arXiv: 2002.02925
24. Jiao, X.; Yin, Y.; Shang, L.; Jiang, X.; Chen, X.; Li, L.; Wang, F.; Liu, Q. TinyBERT: Distilling BERT for natural language understanding. arXiv; 2019; arXiv: 1909.10351
25. Sanh, V.; Debut, L.; Chaumond, J.; Wolf, T. DistilBERT, a distilled version of BERT: Smaller, faster, cheaper and lighter. arXiv; 2019; arXiv: 1910.01108
26. Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A.N.; Kaiser, Ł.; Polosukhin, I. Attention is all you need. Proceedings of the Advances in Neural Information Processing Systems; Long Beach, CA, USA, 4–9 December 2017; Volume 30.
27. Luo, Y.; Yu, J. Music Source Separation With Band-Split RNN. IEEE/ACM Trans. Audio Speech Lang. Process.; 2022; 31, pp. 1893-1901. [DOI: https://dx.doi.org/10.1109/TASLP.2023.3271145]
28. Alzubaidi, L.; Zhang, J.; Humaidi, A.J.; Al-dujaili, A.; Duan, Y.; Al-Shamma, O.; Santamaría, J.I.; Fadhel, M.A.; Al-Amidie, M.; Farhan, L. Review of deep learning: Concepts, CNN architectures, challenges, applications, future directions. J. Big Data; 2021; 8, 53. [DOI: https://dx.doi.org/10.1186/s40537-021-00444-8]
29. Sonkar, S.; Baraniuk, R. Investigating the Role of Feed-Forward Networks in Transformers Using Parallel Attention and Feed-Forward Net Design. arXiv; 2023; arXiv: 2305.13297
30. Biçici, E.; Kanburoglu, A.B.; Türksoy, R.T. Residual Connections Improve Prediction Performance. Proceedings of the 2023 4th International Informatics and Software Engineering Conference (IISEC); Ankara, Turkiye, 21–22 December 2023; pp. 1-5.
31. Cui, Y.; Xu, Y.; Peng, R.; Wu, D. Layer Normalization for TSK Fuzzy System Optimization in Regression Problems. IEEE Trans. Fuzzy Syst.; 2023; 31, pp. 254-264. [DOI: https://dx.doi.org/10.1109/TFUZZ.2022.3185464]
32. Sáenz, C.A.C.; Becker, K. Understanding stance classification of BERT models: An attention-based framework. Knowl. Inf. Syst.; 2023; 66, pp. 419-451. [DOI: https://dx.doi.org/10.1007/s10115-023-01962-y]
33. Wang, A.; Singh, A.; Michael, J.; Hill, F.; Levy, O.; Bowman, S.R. GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding. arXiv; 2018; arXiv: 1804.07461
34. Williams, A.; Nangia, N.; Bowman, S.R. A broad-coverage challenge corpus for sentence understanding through inference. arXiv; 2017; arXiv: 1704.05426
35. Chen, Z.; Zhang, H.; Zhang, X.; Zhao, L. Quora Question Pairs. 2018. Online Resource. Available online: https://api.semanticscholar.org/CorpusID:233225749 (accessed on 8 October 2024).
36. Socher, R.; Perelygin, A.; Wu, J.; Chuang, J.; Manning, C.D.; Ng, A.Y.; Potts, C. Recursive deep models for semantic compositionality over a sentiment treebank. Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing; Seattle, WA, USA, 18–21 October 2013; pp. 1631-1642.
37. Dolan, B.; Brockett, C. Automatically constructing a corpus of sentential paraphrases. Proceedings of the Third international workshop on paraphrasing (IWP2005); Jeju Island, Republic of Korea, 14 October 2005.
38. Bentivogli, L.; Clark, P.; Dagan, I.; Giampiccolo, D. The Fifth PASCAL Recognizing Textual Entailment Challenge. TAC; 2009; 7, 1.
39. Rajpurkar, P.; Zhang, J.; Lopyrev, K.; Liang, P. SQuAD: 100,000+ questions for machine comprehension of text. arXiv; 2016; arXiv: 1606.05250
40. Rajpurkar, P.; Jia, R.; Liang, P. Know what you don’t know: Unanswerable questions for SQuAD. arXiv; 2018; arXiv: 1806.03822
41. Turc, I.; Chang, M.W.; Lee, K.; Toutanova, K. Well-read students learn better: On the importance of pre-training compact models. arXiv; 2019; arXiv: 1908.08962
© 2024 by the authors. Licensee MDPI, Basel, Switzerland. This article is an open access article distributed under the terms and conditions of the Creative Commons Attribution (CC BY) license (https://creativecommons.org/licenses/by/4.0/).
Abstract
Pre-trained language models perform well on a wide range of natural language processing tasks. However, their large number of parameters poses significant challenges for resource-limited edge devices and greatly restricts their practical deployment. This paper introduces a simple and efficient method, Autocorrelation Matrix Knowledge Distillation (AMKD), aimed at improving the performance of smaller BERT models on specific tasks and making them more suitable for practical deployment. AMKD uses the autocorrelation matrix to capture the relationships between features, enabling the student model to learn not only the individual features of the teacher model but also the correlations among them. It also resolves the dimensional mismatch between the hidden states of the student and teacher models: even when the student's hidden dimension is smaller than the teacher's, AMKD retains the essential features of the teacher model, thereby minimizing information loss. Experimental results show that BERTTINY-AMKD outperforms traditional distillation methods and baseline models, achieving an average score of 83.6 on the GLUE tasks. This is 4.1 percentage points higher than BERTTINY-KD and 2.6 and 3.9 points higher than BERT4-PKD and DistilBERT4, respectively. Moreover, with only 13.3% of the parameters of BERTBASE, BERTTINY-AMKD retains approximately 96.3% of the performance of its teacher model, BERTBASE.