This section briefly introduces the proposed Joint Learning Attention Network (JLAN), as shown in Fig. 3.
JLAN consists of three parts. The first part captures the semantic information of the clinical texts using a residual neural network and a bidirectional long short-term memory (Bi-LSTM) network. The second part, called joint learning, extracts the appropriate information from the label attention and self-attention mechanisms. The third part introduces a denoising mechanism to reduce the noise in the training samples and help the model converge faster. Together, these components significantly improve the medical code prediction results.
Specifically, we use the self-attention mechanism on clinical texts to identify the code-related components of each document. At the same time, we introduce the label attention mechanism so that the ICD codes attend to the clinical document representation. We design a joint learning strategy that adapts the two parts and outputs a comprehensive document representation.
In addition, we consider the noise problem in clinical diagnoses and capture the noise through an auxiliary noise model built on top of the classifier. We first assign a probability score to each training sample and then use this score to selectively guide the learning of the noise model. Our loss function confines noisy samples to the noise model and drives the classifier to learn from the clean training samples.
Problem definition
Let \(T = {\{({x}_{i}, {y}_{i})\}}_{i=1}^{N}\) denote the set of clinical texts, which contains \(N\) documents together with their associated medical codes \({y}_{i}\in {\{0,1\}}^{C}\), where \(C\) is the total number of labels. Every word is encoded into a low-dimensional space via the word2vec technique [20]. Let \({x}_{i} = \{{w}_{1},\dots ,{w}_{n}\}\) denote the \(i\)-th clinical record, where \({w}_{n}\) is the \(n\)-th word vector in the record.
For the ICD coding task, each code has an associated textual description; therefore, each code can be represented as an embedding vector, and the set of codes can be encoded by a trainable matrix \(C\), introduced in the label attention mechanism below. Our model trains the classifier to assign the most relevant codes to a newly arriving record by learning from the input documents and their associated codes.
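As a minimal illustration of this setup, the sketch below builds the multi-hot label vector \(y_i\) for a single record; the word indices, code indices, and the value of \(C\) are toy values, not taken from our dataset.

```python
import numpy as np

# Toy illustration of one training pair (x_i, y_i):
# x_i is a sequence of word indices, y_i is a multi-hot vector over C codes.
C = 5                                   # total number of ICD codes (toy value)
record_words = [12, 845, 97, 3, 3021]   # x_i: word indices of the i-th record
assigned_codes = [1, 4]                 # codes annotated for this record

y_i = np.zeros(C, dtype=np.float32)
y_i[assigned_codes] = 1.0               # y_i in {0,1}^C, here [0, 1, 0, 0, 1]
```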
Input representation
Word embeddings are widely used in neural networks to effectively capture the basic semantic information of words. Because clinical notes are written by medical professionals and contain domain-specific terminology, we use a distributed representation to obtain word vectors that are close to the meaning of each target word.
Our model takes a word sequence \(c = \{{c}_{1}, {c}_{2}, \dots , {c}_{n}\}\) as input, where \(n\) denotes the length of the sequence. Let \(E\) denote the word embedding matrix, which is pretrained on the dataset via word2vec [20]. Hence, the input can be replaced by the matrix \(E = \{{e}_{1},{e}_{2},\dots ,{e}_{n}\}\), where \({e}_{n}\) is the word vector of the \(n\)-th word.
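The sketch below shows one way to realize this input layer in PyTorch, assuming the word2vec vectors have already been trained; the vocabulary size, embedding dimension, and variable names are illustrative.

```python
import torch
import torch.nn as nn

# Sketch: map a record's word indices to its embedding matrix E.
vocab_size, embed_dim = 50000, 100                       # illustrative sizes
pretrained_vectors = torch.randn(vocab_size, embed_dim)  # placeholder for word2vec weights

embedding = nn.Embedding.from_pretrained(pretrained_vectors, freeze=False)

word_ids = torch.tensor([[12, 845, 97, 3, 3021]])        # one record, n = 5 words
E = embedding(word_ids)                                  # shape: (1, n, embed_dim)
```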
Residual convolutional network
To alleviate the degradation problem of deep neural networks, we introduce a residual neural network into the model. Specifically, residual connections make the model converge faster and allow us to adopt a deeper design for the feedforward network. We feed the word embedding matrix into the residual block [21]. Thus, the residual block can be formalized as:
$${Y}_{i}=F\left({E}_{i},\left\{{W}_{i}\right\}\right)+h({E}_{i})$$
(1)
$${E}_{i+1}=ReLU\left({Y}_{i}\right)$$
(2)
where \({E}_{i}\) and \({Y}_{i}\) denote the input and output of this layer, and \(F({E}_{i},\{{W}_{i}\})\) denotes the residual mapping. A residual block consists of two parts: the first part passes the input through the convolutional network and activation function, and the second part uses a shortcut connection \(h({E}_{i})\) to add the layer input to the output of the first part. Finally, the summed result is passed through the activation function to complete the residual block.
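A minimal PyTorch sketch of such a residual block is given below; the use of two one-dimensional convolutions and the channel sizes are illustrative choices, not specifics of our implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Sketch of Eqs. (1)-(2): Y = F(E, {W}) + h(E), E_next = ReLU(Y)."""
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding)
        # h(E): identity when channel sizes match, otherwise a 1x1 projection
        self.shortcut = (nn.Identity() if in_channels == out_channels
                         else nn.Conv1d(in_channels, out_channels, kernel_size=1))

    def forward(self, E):                          # E: (batch, in_channels, n)
        F_E = self.conv2(F.relu(self.conv1(E)))    # residual mapping F(E, {W})
        Y = F_E + self.shortcut(E)                 # Eq. (1)
        return F.relu(Y)                           # Eq. (2)
```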
Bidirectional LSTM layer
To capture each word's forward and backward contextual information in each clinical text, we adopt the Bi-LSTM model [22] to encode the word embeddings of each clinical record. In addition, the Bi-LSTM can retain long-range dependencies and alleviate the vanishing gradient problem, so it is well suited to capturing long-term dependency features. At time step \(d\), the hidden state is updated from the current input and the output of step \((d-1)\); we compute the vectors as:
$${\overrightarrow {{h_{d} }} = LSTM\left( {\overrightarrow {{h_{d - 1} }} ,w_{d} } \right)}$$
(3)
$${\overleftarrow {{h_{d} }} = LSTM\left( {\overleftarrow {{h_{d - 1} }} ,w_{d} } \right)}$$
(4)
$${h_{d} = \overrightarrow {{h_{d} }} \oplus \overleftarrow {{h_{d} }} }$$
(5)
The dimensionality of each hidden state is set to \(k\), so each Bi-LSTM vector \({h}_{d}\) has dimension \(2k\). Therefore, the whole document can be represented as a matrix \(H=[{h}_{1},{h}_{2},\dots ,{h}_{n}]\in {R}^{2k\times n}\).
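For concreteness, a short PyTorch sketch of the Bi-LSTM encoder follows; the embedding dimension, hidden size \(k\), and sequence length are illustrative.

```python
import torch
import torch.nn as nn

# Sketch of the Bi-LSTM encoder (Eqs. (3)-(5)).
embed_dim, k, n = 100, 256, 5                     # illustrative sizes
bilstm = nn.LSTM(input_size=embed_dim, hidden_size=k,
                 batch_first=True, bidirectional=True)

E = torch.randn(1, n, embed_dim)                  # word embeddings of one record
H, _ = bilstm(E)                                  # (1, n, 2k): forward/backward states concatenated
H = H.transpose(1, 2)                             # (1, 2k, n), matching H in R^{2k x n}
```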
Dual attention network
The difficulty of the long-tail problem is that most labels have only a few instances, so classifying labels with limited instances is an urgent problem. The attention mechanism can assign more weight to a small amount of crucial information when processing large amounts of data, which makes it naturally suitable for long-tail problems. Moreover, the number of cases varies greatly across diseases, so comprehensively characterizing the data is challenging. To this end, we design a dual attention mechanism that effectively links different feature information and adaptively integrates disease-related text information.
In this subsection, we introduce a dual attention network for medical code and document representation learning. This network is composed of a label attention mechanism and a self-attention mechanism, which we introduce in detail in the following two subsections.
The dual attention network aims to identify the components related to the medical code in each clinical text. Intuitively, it can simultaneously take the clinical text and medical codes into account and expand the receptive field of the model. Therefore, this strategy is suitable for clinical code classification.
For example, consider the text “This is an 81-year-old woman with a history of emphysema; her primary care doctor thought she had shortness of breath for three days and thought it was a COPD attack.” It is assigned two codes: emphysema and COPD. The mention of “emphysema” relates to the patient's medical history rather than directly to her symptoms, whereas “COPD” (chronic obstructive pulmonary disease) relates to the patient's current symptoms. Next, we introduce the two components of the dual attention network.
Self-attention mechanism
As mentioned above, a multi-label clinical text can be marked with more than one medical code, and each clinical document contains contexts that are most relevant to its corresponding medical codes. In other words, each record may contain multiple components that contribute differently to each medical code.
To capture the different components of each text, we adopt a self-attention mechanism [23], which has been successfully used in various text mining tasks [24]. The clinical text attention score \({T}^{S}\in {R}^{l\times n}\) is calculated as:
$${T}^{S }=softmax\left({W}_{1}\mathit{tan}h\left({W}_{2}H\right)\right)$$
(6)
where \({W}_{1}\in {R}^{l\times d}\) and \({W}_{2}\in {R}^{d\times 2k}\) are the self-attention parameters to be trained, and \(d\) is a tunable hyperparameter; these dimensions ensure \({T}^{S}\in {R}^{l\times n}\). Each row \({T}_{j}^{S}\) (an \(n\)-dimensional row vector, where \(n\) is the total number of words) represents the contribution of each word in the clinical record to the \(j\)-th label. Taking the corresponding linear combination of the contextual representations, the clinical text representation of the medical codes \({M}^{(S)}\in {R}^{l\times 2k}\) is calculated as follows:
$${M}_{j}^{(S)}={T}_{j}^{S}{H}^{T}$$
(7)
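The following PyTorch sketch implements Eqs. (6)-(7) with the batch dimension first and the softmax taken over the word dimension; the hidden size \(d\) and the number of labels are illustrative.

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Sketch of Eqs. (6)-(7): T^S = softmax(W1 tanh(W2 H)), M_j^(S) = T_j^S H^T."""
    def __init__(self, hidden_2k, d, num_labels):
        super().__init__()
        self.W2 = nn.Linear(hidden_2k, d, bias=False)    # W2 in R^{d x 2k}
        self.W1 = nn.Linear(d, num_labels, bias=False)   # W1 in R^{l x d}

    def forward(self, H):                                   # H: (batch, n, 2k)
        scores = self.W1(torch.tanh(self.W2(H)))            # (batch, n, l)
        T_S = torch.softmax(scores, dim=1).transpose(1, 2)  # (batch, l, n), softmax over words
        M_S = T_S @ H                                        # (batch, l, 2k)
        return T_S, M_S
```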
Label attention mechanism
The self-attention mechanism can be regarded as the attention based on the clinical text because it focuses on the document content.
Medical codes have specific semantics in ICD coding. To utilize this semantic information, we preprocess the code descriptions and represent them as a trainable matrix \(C\in {R}^{l\times 2k}\) that lies in the same \(2k\)-dimensional space as the document representation \(H\).
Once we have the word representations from the Bi-LSTM and the code embeddings in \(C\), we can determine the semantic relationship between each word-code pair. We calculate the dot product between \({h}_{d}\) and \({C}_{j}\) as follows:
$${B}^{\left(l\right)}=CH$$
(8)
where \({B}^{(l)}\in {R}^{l\times n}\) captures the relation between words and codes from both the forward and backward directions. As in the self-attention mechanism, the medical code representation is constructed by linearly combining the context words of each code, as shown below:
$${M}^{\left(l\right)}={B}^{\left(l\right)}{H}^{T}$$
(9)
Finally, the document is re-represented with respect to the codes by \({M}^{(l)}\in {R}^{l\times 2k}\).
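A corresponding sketch of Eqs. (8)-(9) is shown below; the number of labels, the hidden size \(2k\), and the sequence length are illustrative, and \(C\) is kept as a trainable parameter.

```python
import torch
import torch.nn as nn

# Sketch of Eqs. (8)-(9): B^(l) = C H and M^(l) = B^(l) H^T.
num_labels, two_k, n = 50, 512, 5                       # illustrative sizes

C = nn.Parameter(torch.randn(num_labels, two_k))        # trainable code embeddings, (l, 2k)
H = torch.randn(1, two_k, n)                            # document representation, (1, 2k, n)

B = torch.einsum('lk,bkn->bln', C, H)                   # Eq. (8): word-code relation scores, (1, l, n)
M_l = B @ H.transpose(1, 2)                             # Eq. (9): label-attended representation, (1, l, 2k)
```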
Joint learning mechanism
Once we obtain the label attention matrix \(L\) (i.e., \({M}^{(l)}\)) and the self-attention matrix \(S\) (i.e., \({M}^{(S)}\)), how to make full use of these two sources of information becomes a vital issue. In this section, we propose a joint learning strategy to extract the critical information from the two attention matrices.
Joint learning can integrate multiple sub-models into one model. Specifically, after the label attention and self-attention matrices are determined, joint learning trains the attention modules and the rest of the model together by introducing weighting parameters. In this way, we build specific document representations for both high-frequency and low-frequency labels.
The label attention matrix focuses on the semantic connection between the medical codes and the clinical text, whereas the self-attention matrix focuses on the content of the clinical record itself. We introduce the joint learning mechanism, shown in Fig. 4, to extract the appropriate information from both parts.
Specifically, we multiply the self-attention matrix and the label attention matrix by \({W}_{3}\) and \({W}_{4}\), respectively, and feed the results to the sigmoid activation function. This yields two weight vectors, \(\alpha\) and \(\beta\), that represent the importance of the two attention matrices; they are obtained by applying a fully connected layer to \(S\) and \(L\).
$${Sigmoid \left( x \right) = \frac{1}{{1 + e^{ - x} }}}$$
(10)
$${\alpha = Sigmoid\left( { S W_{3} } \right), S \in R^{l \times 2k} }$$
(11)
$${\beta = Sigmoid\left( { L W_{4} } \right), L \in R^{l \times 2k} }$$
(12)
\({W}_{3},{W}_{4}\in {R}^{2k}\) are the parameters to be trained. \({\alpha }_{i}\) and \({\beta }_{i}\) represent the importance of the two attention matrices when constructing the final attention representation for the \(i\)-th label. Therefore, we apply the following constraint to the two weight vectors:
$${0 < \alpha_{{i{ }}} + \beta_{i} \le 1}$$
(13)
After that, we multiply the weight vectors with the label attention and self-attention matrices. Finally, for each label, we splice (concatenate) the re-weighted label attention and self-attention representations to obtain the final attention matrix.
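One possible realization of this joint learning step is sketched below; the per-label concatenation of the two re-weighted matrices reflects our reading of the splicing step, and the constraint in Eq. (13) is not explicitly enforced in this sketch.

```python
import torch
import torch.nn as nn

class JointLearning(nn.Module):
    """Sketch of Eqs. (10)-(12): alpha = sigmoid(S W3), beta = sigmoid(L W4),
    followed by re-weighting and per-label concatenation of S and L."""
    def __init__(self, two_k):
        super().__init__()
        self.W3 = nn.Linear(two_k, 1, bias=False)   # W3 in R^{2k}
        self.W4 = nn.Linear(two_k, 1, bias=False)   # W4 in R^{2k}

    def forward(self, S, L):                        # S, L: (batch, l, 2k)
        alpha = torch.sigmoid(self.W3(S))           # (batch, l, 1), weight of self-attention
        beta = torch.sigmoid(self.W4(L))            # (batch, l, 1), weight of label attention
        # Note: the constraint 0 < alpha_i + beta_i <= 1 (Eq. (13)) is not enforced here.
        fused = torch.cat([alpha * S, beta * L], dim=-1)   # (batch, l, 4k)
        return fused
```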
Denoising mechanism
In this part, we consider the noise problem in medical code allocation. Specifically, ICD code assignment is usually a manual process that takes a long time per patient. Due to inexperienced coders, disagreements between coders, and incorrect code grouping, it is also prone to errors. In addition, clinical diagnosis and treatment records are often long texts prone to misspellings or typos, which lead to wrong code predictions and degrade model performance [25].
Since noise negatively influences the classification results, we introduce a denoising mechanism and design an auxiliary noise model on top of the classifier. Our goal is to identify and prune the noisy samples to improve the quality of classifier training [26].
We leverage the finding that learning from clean labels is easier than learning from noisy labels [27]. Furthermore, we build on the binary cross-entropy loss function [28] and design a truncation loss function. Specifically, the truncation loss discards large-loss samples with a dynamic threshold in each iteration. Our training goal is to minimize the loss between the prediction \(\tilde{y}\) and the target \(y\):
$${T}_{loss}\left(y,\tilde{y}\right)=\left\{\begin{array}{ll}0, & {BC}_{loss}\left(y,\tilde{y}\right)>\varepsilon \cup \left(\tilde{y}=1\right)\\ {BC}_{loss}\left(y,\tilde{y}\right), & \mathrm{otherwise},\end{array}\right.$$
(14)
where \(\varepsilon\) denotes the pre-defined threshold and \({BC}_{loss}\) represents the binary cross entropy loss.
The truncation loss removes noisy samples whose binary cross-entropy loss is larger than \(\varepsilon\). Although this truncation loss is easy to explain and implement, a fixed threshold may not suit the entire training process: noisy samples typically have large loss values during the early epochs [29], and the training loss decreases as the training iterations increase. To adapt to this overall trend of the training loss, we replace the fixed threshold with a dynamic threshold function \({D}_{T}\), which changes the threshold during training:
$${D_{T} = \min \left( {\gamma T,D_{max} } \right),}$$
(15)
where \(T\) denotes the current training iteration, \(D_{max}\) is the upper bound, and \(\gamma\) is a parameter that adjusts how quickly the maximum drop rate is reached.
Thus, the training strategy constrains the noise and drives the classifier to learn from the clean training samples. The dynamic threshold function truncates the loss values of high-loss samples to zero, discarding the influence of high-loss noise.
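The sketch below illustrates one reading of the truncation loss in which the dynamic schedule of Eq. (15) is used as a drop rate, i.e., the highest-loss fraction of samples is truncated to zero in each iteration; the hyperparameters gamma and drop_max are illustrative.

```python
import torch
import torch.nn.functional as F

def truncated_bce_loss(y_true, y_prob, epoch, gamma=0.05, drop_max=0.3):
    """Sketch of the truncation loss (Eqs. (14)-(15)): per-sample BCE losses
    in the highest D_T fraction are truncated to zero so that the classifier
    learns mainly from the cleaner samples."""
    # per-sample binary cross-entropy averaged over all codes
    bce = F.binary_cross_entropy(y_prob, y_true, reduction='none').mean(dim=1)
    drop_rate = min(gamma * epoch, drop_max)                 # D_T = min(gamma * T, D_max)
    num_keep = max(1, int(len(bce) * (1.0 - drop_rate)))     # number of low-loss samples to keep
    kept_losses, _ = torch.topk(bce, num_keep, largest=False)
    return kept_losses.mean()
```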
Output layer
In this part, we feed the denoised representation \(V\) into the classifier. Given the comprehensive representation of the clinical texts and medical codes, we build a multi-label text classifier using a multilayer perceptron with two fully connected layers, followed by a sum-pooling operation that produces the score \(\widehat{y}\) for the ICD codes. Mathematically, the predicted probability \(\tilde{y}\) of each code is estimated as follows:
$$\widehat{y}=pooling\left(V\right), { \widehat{y}}_{i}=\sum_{j=1}^{k}{V}_{ij}, V\in {R}^{n\times k}$$
(16)
$$\tilde{y }=sigmoid\left(\widehat{y}\right)$$
(17)
Finally, the sigmoid function is used to convert the score vector into a probability vector.
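A compact sketch of this output layer follows; the assumption that \(V\) holds one row per code and the layer sizes are illustrative.

```python
import torch
import torch.nn as nn

class OutputLayer(nn.Module):
    """Sketch of Eqs. (16)-(17): a two-layer perceptron over the denoised
    representation V, sum-pooling to one score per code, then a sigmoid."""
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, V):                         # V: (batch, l, in_dim), one row per code
        scores = self.mlp(V).sum(dim=-1)          # sum-pooling: one score per code, (batch, l)
        return torch.sigmoid(scores)              # Eq. (17): predicted probability per code
```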