An overview
A longitudinal sequential medication recommendation task can be defined as follows:
\(\textit{Definition 1: Longitudinal EMR Data.}\) In EMR data, each patient’s records can be represented as a sequence of multivariate observations: \(S^{n}=\left\{ P_1^{(n)},P_2^{(n)},\cdots ,P_{T^{(n)}}^{(n)} \right\}\), where n denotes the n-th patient and \(T^{(n)}\) is the number of visits of the n-th patient. The EMR record of the t-th visit is described as \(P_t^{(n)}=\left\{ d_t^{(n)},m_t^{(n)},s_t^{(n)}\right\}\), where \(d_t^{(n)}\) is the set of ICD-10 diagnostic codes, \(m_t^{(n)}\) is the set of National Drug Code (NDC) drug codes, and \(s_t^{(n)}\) is the set of self-reported symptoms (denoted SYM), such as “fever”.
\(\textit{Definition 2: Longitudinal Sequential Medication Recommendation Task.}\) Given the n-th patient’s historical EMR records \(S_{1:t-1}^{(n)}=\left\{ P_1^{(n)},P_2^{(n)},\cdots ,P_{t-1}^{(n)} \right\}\) together with the diagnostic codes \(d_t^{(n)}\) and symptoms \(s_t^{(n)}\) of the t-th visit, we want to recommend the drugs for the t-th visit by generating a multi-label output \(\hat{y}_t\in \left\{ 0,1\right\} ^{ML}\), where ML is the number of drug codes. That is, the output of medication recommendation is a list of appropriate drugs, and the recommendation problem is thus transformed into a multi-label classification problem.
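For illustration, the following minimal sketch shows how the prescribed drug set of one visit corresponds to such a multi-label vector; the drug vocabulary used here is a placeholder, not a real NDC code list.

```python
# Minimal sketch: a visit's drug set as a multi-hot vector of length ML.
drug_vocab = ["drug_A", "drug_B", "drug_C", "drug_D"]          # ML = 4 (placeholder codes)
code_to_idx = {code: i for i, code in enumerate(drug_vocab)}

def to_multi_hot(visit_drugs):
    """Return the 0/1 label vector marking the drugs of one visit."""
    y = [0] * len(drug_vocab)
    for code in visit_drugs:
        y[code_to_idx[code]] = 1
    return y

print(to_multi_hot({"drug_A", "drug_D"}))                       # [1, 0, 0, 1]
```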
This study proposes the MR-KPA model to realize this task on small-scale data. On the one hand, the proposed model adopts knowledge-enhanced pre-training: a large number of single-visit EMR records are used as pre-training data to alleviate the scarcity of longitudinal EMR data, and the classification knowledge of diagnostic and drug codes is encoded as external domain features and then fused into the EMR embeddings. On the other hand, the model integrates adversarial training into a multi-layer perceptron (MLP) to avoid over-fitting during the fine-tuning process.
The whole framework of MR-KPA is shown in Fig. 2. It includes three modules: input representation, pre-training and prediction. The input representation module transforms each EMR record into a diagnostic code embedding, a drug code embedding and a symptom embedding. Based on these three types of embeddings, the pre-training module builds a pre-trained visit model by performing two types of pre-training tasks. Finally, the prediction module fine-tunes the pre-trained visit model and obtains the predicted drug codes from the patient’s multi-visit records. The details are described in the following subsections.
Input representation
The input representation module transforms each EMR into a group of multi-dimensional embeddings that serve as the input of the subsequent modules. As shown in Fig. 3, multi-visit records are input into this module. Each record includes the columns SUBJECT ID, HADM ID, ICD-10, NDC and SYM, which represent the patient ID, hospital admission ID, diagnostic codes, drug codes and segmented symptom terms, respectively. They are transformed into two ontology embeddings and one dictionary embedding.
For the EMR of the n-th patient at the t-th visit, \(P_t^{(n)}=\left\{ d_t^{(n)},m_t^{(n)},s_t^{(n)}\right\}\), the input embeddings are obtained as follows.
\(\textit{Ontology embedding.}\) Ontology embedding is adopted to realize domain knowledge-based external feature fusion. Two types of code ontology embeddings are constructed from the ICD-10 ontology \(O_{d}\) and the NDC ontology \(O_m\). Because the medical codes in raw EMR data are leaf nodes of the code ontology trees, code ontology embeddings can be obtained with a graph attention network (GAT) [8, 10, 12, 13], which encodes the classification knowledge in the diagnostic and drug code trees as external domain features. For each medical code \(c_*\in d_t^{(n)} \cup m_t^{(n)}\), an initial embedding is taken from the embedding matrix \(H_e\), and the following procedure is performed to obtain its ontology embedding:
$$\begin{aligned} o_{c_*}=g\left( c_*,pa(c_* ),H_e\right) =\parallel _{k=1}^{K}\sigma \left( \sum _{j\in N_{c_*}}a_{c_*,j}^k W^k h_j\right) \end{aligned}$$
(1)
where \(*\in \left\{ d,m\right\}\), \(N_{c_*}=\left\{ c_*\right\} \cup \left\{ pa(c_* )\right\}\) is the set containing \(c_*\) and its parent nodes, \(\parallel\) denotes concatenation, which enables the multi-head attention mechanism with K heads, \(\sigma\) is a nonlinear activation function, \(W^k \in {\mathbb {R}}^{m\times d}\) is the weight matrix for input transformation, and \(a_{c_*,j}^k\) is the corresponding normalized attention coefficient of the k-th head.
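As an illustration only, the following PyTorch sketch implements one attention aggregation of the form of Eq. (1) over a code and its ancestor nodes; the dimensions, the two-head setting and the choice of \(\sigma\) are assumptions rather than the exact configuration used in this paper.

```python
# Sketch of Eq. (1): GAT-style aggregation of a code c_* with its parent nodes.
import torch
import torch.nn as nn
import torch.nn.functional as F

class OntologyAttention(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads=2):
        super().__init__()
        self.W = nn.ModuleList(nn.Linear(in_dim, out_dim, bias=False)
                               for _ in range(num_heads))      # W^k
        self.a = nn.ModuleList(nn.Linear(2 * out_dim, 1, bias=False)
                               for _ in range(num_heads))      # attention scorer per head

    def forward(self, h_code, h_neighbors):
        # h_code: (in_dim,) embedding of c_*; h_neighbors: (|N_{c_*}|, in_dim)
        heads = []
        for W, a in zip(self.W, self.a):
            z_c, z_n = W(h_code), W(h_neighbors)
            # normalized attention coefficients a^k_{c_*, j} over N_{c_*}
            scores = a(torch.cat([z_c.expand_as(z_n), z_n], dim=-1)).squeeze(-1)
            alpha = F.softmax(F.leaky_relu(scores), dim=0)
            heads.append(torch.sigmoid((alpha.unsqueeze(-1) * z_n).sum(dim=0)))
        return torch.cat(heads, dim=-1)        # concatenation over the K heads

# Example: a leaf code with two ancestors; H_e holds their initial embeddings.
H_e = torch.randn(3, 64)
o_c = OntologyAttention(in_dim=64, out_dim=32)(H_e[0], H_e)    # o_{c_*}, shape (64,)
```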
\(\textit{Dictionary embedding.}\) Dictionary embedding is constructed from a symptom dictionary \(D_s\), which contains all symptoms appearing in the EMR data. For each symptom \(s_i\in s_t^{(n)}\), its dictionary embedding \(d_{s_i}\) is simply its index value in \(D_s\).
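A minimal sketch of this lookup follows; the dictionary contents are placeholders for illustration.

```python
# Dictionary embedding: each symptom is mapped to its index in D_s.
D_s = ["fever", "cough", "headache", "nausea"]                 # placeholder dictionary
symptom_to_idx = {s: i for i, s in enumerate(D_s)}

def dictionary_embedding(symptoms):
    """Return the index values d_{s_i} for the symptoms of one visit."""
    return [symptom_to_idx[s] for s in symptoms]

print(dictionary_embedding(["fever", "headache"]))             # [0, 2]
```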
Pre-training
The pre-training module creates a pre-trained visit model based on the input embeddings transformed from single-visit EMR records. Through pre-training, the large number of single-visit records is effectively exploited to mine richer internal features of the EMR data.
Before pre-training, a multi-layer Transformer architecture [50] is adopted to derive visit embeddings from the two ontology embeddings and one dictionary embedding of each EMR record. For \(P_t^{(n)}\), three types of visit embedding are obtained as follows:
$$\begin{aligned} v_d^t= & {} Transformer\left( \left\{ [CLS]\right\} \cup \left\{ o_{d_i}\mid d_i\in d_t^{(n)}\right\} \right) \end{aligned}$$
(2)
$$\begin{aligned} v_m^t= & {} Transformer\left( \left\{ [CLS]\right\} \cup \left\{ o_{m_i}\mid m_i\in m_t^{(n)}\right\} \right) \end{aligned}$$
(3)
$$\begin{aligned} v_s^t= & {} Transformer\left( \left\{ [CLS]\right\} \cup \left\{ d_{s_i}\mid s_i\in s_t^{(n)}\right\} \right) \end{aligned}$$
(4)
where \(v_d^t\) is the diagnostic visit embedding, \(v_m^t\) is the drug visit embedding, \(v_s^t\) is the symptom visit embedding, and [CLS] is the first token of each sequence, whose final hidden state is used as the aggregate sequence representation for classification, enabling BERT-style models to better handle various downstream tasks. In order to obtain a consistent input length, the token sequences are aligned by padding.
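The PyTorch sketch below illustrates, for one embedding type, how Eqs. (2)-(4) can be realized: a Transformer encoder runs over the padded code embeddings of a visit with a learnable [CLS] token prepended, and the final hidden state of [CLS] serves as the visit embedding. The layer count, head count and dimensions are assumptions for illustration.

```python
# Sketch of Eqs. (2)-(4): Transformer encoder over one visit's code embeddings.
import torch
import torch.nn as nn

class VisitEncoder(nn.Module):
    def __init__(self, dim=64, num_layers=2, num_heads=4):
        super().__init__()
        self.cls = nn.Parameter(torch.randn(1, 1, dim))        # learnable [CLS] embedding
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, code_embeddings):
        # code_embeddings: (batch, n_codes, dim), already padded to a fixed length
        batch = code_embeddings.size(0)
        x = torch.cat([self.cls.expand(batch, -1, -1), code_embeddings], dim=1)
        h = self.encoder(x)
        return h[:, 0]          # final hidden state of [CLS] = visit embedding v_*^t

# Example: diagnostic visit embedding v_d^t from the ontology embeddings o_{d_i}.
o_d = torch.randn(1, 10, 64)    # 10 diagnosis codes of one visit (padded)
v_d_t = VisitEncoder()(o_d)     # shape (1, 64)
```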
This paper conducts the following two kinds of pre-training tasks to make the visit embeddings absorb sufficient information for medication recommendation.
\(\textit{Mask EMR Field Task (Mask EF Task).}\) This task randomly masks some of the input embeddings so that the model better represents the composition of EMR records. By changing the word-token masking of sentences [51] into field masking of EMR records, the following loss function is calculated:
$$\begin{aligned} \mathrm {L}_{s}\left( v_*,C_*^{(n)}\right) =-logP\left( C_*^{(n)}\mid v_*\right) =-\sum _{c_*\in C_*^{(n)}} logP(c_*\mid v_*)-\sum _{c_*\in C_*\setminus C_*^{(n)}} log\left( 1-P(c_*\mid v_*)\right) \end{aligned}$$
(5)
where \(C_*^{(n)}\) with \(*\in \left\{ d,m,s\right\}\) denotes the medical codes or symptoms recorded for the n-th patient at this visit, taken from \(d_t^{(n)}\), \(m_t^{(n)}\) and \(s_t^{(n)}\) respectively, \(c_*\in C_*^{(n)}\) denotes a medical code or symptom involved in the n-th patient, and \(c_*\in C_*\setminus C_*^{(n)}\) denotes a medical code or symptom of the full vocabulary \(C_*\) not used for the n-th patient. We minimize the binary cross-entropy loss \(L_s\) to give the model stronger self-prediction ability.
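A minimal sketch of the loss in Eq. (5) is given below, treating code prediction from the (partially masked) visit embedding as a multi-label binary cross-entropy; the projection head and vocabulary size are assumptions for illustration.

```python
# Sketch of Eq. (5): multi-label BCE between the visit's codes and the codes
# predicted from its visit embedding.
import torch
import torch.nn as nn

vocab_size, dim = 100, 64                      # |C_*| and embedding dimension (assumed)
predict_codes = nn.Linear(dim, vocab_size)     # P(c_* | v_*) via a sigmoid output layer

def mask_ef_loss(v, code_indices):
    """v: (dim,) visit embedding; code_indices: indices of codes present in the visit."""
    target = torch.zeros(vocab_size)
    target[code_indices] = 1.0                 # 1 for c_* in C_*^{(n)}, 0 otherwise
    logits = predict_codes(v)
    # BCE = -sum[ y*log P + (1-y)*log(1-P) ], matching Eq. (5)
    return nn.functional.binary_cross_entropy_with_logits(
        logits, target, reduction="sum")

loss = mask_ef_loss(torch.randn(dim), torch.tensor([3, 17, 42]))
```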
\(\textit{Correlation Prediction Task (CorP Task).}\) This task is used to capture the correlations among diagnostic codes, drug codes and symptoms. In BERT, the next sentence prediction (NSP) task facilitates the prediction of sentence relations. G-Bert revised the NSP task into a multidirectional prediction task for predicting the unknown disease or drug codes of a sequence [16]. This paper revises the NSP task [52] into the CorP Task. For the mutual prediction of diagnostic codes, drug codes and symptoms, the following three loss functions are calculated:
$$\begin{aligned} \mathrm {L}_{dm}= & {} -logP\left( C_d^{(n)}\mid v_m\right) -logP\left( C_m^{(n)}\mid v_d\right) \end{aligned}$$
(6)
$$\begin{aligned} \mathrm {L}_{ds}= & {} -logP\left( C_d^{(n)}\mid v_s\right) -logP\left( C_s^{(n)}\mid v_d\right) \end{aligned}$$
(7)
$$\begin{aligned} \mathrm {L}_{ms}= & {} -logP\left( C_m^{(n)}\mid v_s\right) -logP\left( C_s^{(n)}\mid v_m\right) \end{aligned}$$
(8)
Finally, the pre-training optimization objective can simply be the combination of the aforementioned losses:
$$\begin{aligned} \mathrm {L}_{pr}=\mathrm {L}_{s}\left( v_d,C_d^{(n)}\right) +\mathrm {L}_{s}\left( v_m,C_m^{(n)}\right) +\mathrm {L}_{s}\left( v_s,C_s^{(n)}\right) +\mathrm {L}_{dm}+\mathrm {L}_{ds}+\mathrm {L}_{ms} \end{aligned}$$
(9)
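The sketch below illustrates the CorP losses of Eqs. (6)-(8) and the combined objective of Eq. (9); the projection heads, vocabulary sizes and the reuse of the Eq. (5)-style binary cross-entropy are assumptions for illustration.

```python
# Sketch of Eqs. (6)-(9): mutual prediction among visit embeddings plus the
# combined pre-training objective.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, vocab = 64, {"d": 100, "m": 80, "s": 50}          # assumed vocabulary sizes
heads = nn.ModuleDict({k: nn.Linear(dim, n) for k, n in vocab.items()})

def nll(v_src, tgt_type, tgt_multi_hot):
    """-log P(C_tgt | v_src) as a multi-label binary cross-entropy."""
    return F.binary_cross_entropy_with_logits(
        heads[tgt_type](v_src), tgt_multi_hot, reduction="sum")

v = {k: torch.randn(dim) for k in vocab}               # v_d, v_m, v_s of one visit
y = {k: torch.randint(0, 2, (n,)).float() for k, n in vocab.items()}

L_dm = nll(v["m"], "d", y["d"]) + nll(v["d"], "m", y["m"])     # Eq. (6)
L_ds = nll(v["s"], "d", y["d"]) + nll(v["d"], "s", y["s"])     # Eq. (7)
L_ms = nll(v["s"], "m", y["m"]) + nll(v["m"], "s", y["s"])     # Eq. (8)
L_self = sum(nll(v[k], k, y[k]) for k in vocab)                # the three Eq. (5) terms
L_pr = L_self + L_dm + L_ds + L_ms                             # Eq. (9)
```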
Prediction
An MLP module with adversarial training is used to perform the final prediction task. Based on the pre-trained model, multi-visit EMR sequences can be transformed into three types of visit embedding sequences. Concatenating the averages of the diagnostic, drug and symptom visit embeddings before the t-th visit with the diagnostic and symptom visit embeddings at the t-th visit, the MLP [53] predicts the recommended drug codes at the t-th visit as follows:
$$\begin{aligned} y_t=Sigmoid\left( W\left[ \left( \frac{1}{t-1} \sum _{\tau<t}v_d^{\tau }\right) \parallel \left( \frac{1}{t-1} \sum _{\tau<t}v_s^{\tau }\right) \parallel \left( \frac{1}{t-1} \sum _{\tau <t}v_m^{\tau }\right) \parallel v_d^{t} \parallel v_s^{t}\right] +b\right) \end{aligned}$$
(10)
where \(W\in {\mathbb {R}}^{\mid C_m \mid \times 5l}\) is a learnable transformation matrix, \(l\) is the dimension of a visit embedding, \(\mid C_m \mid\) is the number of drug codes, and \(b\) is a learnable bias vector.
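A minimal PyTorch sketch of the prediction head in Eq. (10) follows; the embedding dimension and drug-vocabulary size are assumptions for illustration.

```python
# Sketch of Eq. (10): concatenate averaged historical visit embeddings with the
# current diagnostic and symptom visit embeddings, then apply a sigmoid layer.
import torch
import torch.nn as nn

dim, n_drugs = 64, 131                       # visit-embedding dim l and |C_m| (assumed)
W = nn.Linear(5 * dim, n_drugs)              # learnable W and bias b

def predict_drugs(v_d_hist, v_s_hist, v_m_hist, v_d_t, v_s_t):
    """*_hist tensors have shape (t-1, dim); v_*_t have shape (dim,)."""
    x = torch.cat([v_d_hist.mean(dim=0),     # mean of diagnostic embeddings before t
                   v_s_hist.mean(dim=0),     # mean of symptom embeddings before t
                   v_m_hist.mean(dim=0),     # mean of drug embeddings before t
                   v_d_t,                    # diagnostic embedding at visit t
                   v_s_t])                   # symptom embedding at visit t
    return torch.sigmoid(W(x))               # y_t: per-drug recommendation probabilities

t = 4                                        # example: predict drugs for the 4th visit
y_t = predict_drugs(torch.randn(t - 1, dim), torch.randn(t - 1, dim),
                    torch.randn(t - 1, dim), torch.randn(dim), torch.randn(dim))
```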
Therefore, the loss function can be calculated as follows:
$$\begin{aligned} L_n=-\frac{1}{T-1} \sum _{t=2}^T\left( \hat{y}_t^{\top } log\left( y_t\right) +\left( 1-\hat{y}_t\right) ^{\top } log\left( 1-y_t\right) \right) \end{aligned}$$
(11)
where \(y_t\) is the predicted drug probability vector from Eq. (10) and \(\hat{y}_t\) is the corresponding true multi-hot drug vector. In this formula, t = 2 means that prediction starts from the second visit of the patient, because this paper focuses on longitudinal sequential medication recommendation, which predicts the drugs currently suitable for the patient based on the patient’s historical records and the current diagnosis and symptoms.
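For clarity, a short sketch of the fine-tuning loss in Eq. (11), averaged over visits 2 to T, is shown below (tensor sizes are placeholders).

```python
# Sketch of Eq. (11): BCE between predicted probabilities and true drug vectors,
# averaged over the T-1 visits that are actually predicted (t = 2, ..., T).
import torch

def finetune_loss(y_pred, y_true, eps=1e-8):
    """y_pred, y_true: tensors of shape (T-1, n_drugs)."""
    bce = y_true * torch.log(y_pred + eps) + (1 - y_true) * torch.log(1 - y_pred + eps)
    return -bce.sum(dim=1).mean()

L_n = finetune_loss(torch.rand(3, 131), torch.randint(0, 2, (3, 131)).float())
```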
In order to avoid over-fitting of the model, this paper integrates fast gradient method (FGM) adversarial training into the deep prediction model [54]. Adversarial training can not only improve the model’s defense against adversarial samples, but also improve its generalization on the original samples. For the prediction task, the disturbances \(r_{adv-d}\) and \(r_{adv-m}\) are added to the diagnostic visit embedding and the drug visit embedding respectively, so as to mislead the model as much as possible and thereby increase its robustness. Referring to [54], the disturbances can be calculated as follows:
$$\begin{aligned} r_{adv-d}&=\epsilon \frac{ \nabla _{v_d^t} L_n}{\left\| \nabla _{v_d^t} L_n \right\| _2} \\ r_{adv-m}&=\epsilon \frac{\nabla _{v_m^t} L_n}{ \left\| \nabla _{v_m^t} L_n \right\| _2} \end{aligned}$$
(12)
where \(\epsilon\) is a constant controlling the magnitude of the disturbance, and \(r_{adv-d}\) and \(r_{adv-m}\) are the normalized gradients of the loss with respect to \(v_d^t\) and \(v_m^t\). The drug probabilities \(y_t\) are then predicted from the disturbed embeddings \(v_d^{\tau }+r_{adv-d}\) and \(v_m^{\tau }+r_{adv-m}\), and combined with the true drug vector \(\hat{y}_t\) to construct a loss function. In back propagation, the gradient of adversarial training is accumulated on top of the normal gradient. Then the original values of \(v_d^{\tau }\) and \(v_m^{\tau }\) are restored. Finally, the parameters are updated according to the accumulated gradients. The loss function after adversarial training is defined in the same way as Eq. (11), where \(y_t\) is calculated from the disturbed diagnostic and drug visit embeddings according to Eq. (13) as follows:
$$\begin{aligned} \begin{aligned} y_t=Sigmoid\left( W\left[ \left( \frac{1}{t-1} \sum _{\tau<t}\left( v_d^{\tau }+r_{adv-d}\right) \right) \parallel \left( \frac{1}{t-1} \sum _{\tau<t}v_s^{\tau }\right) \parallel \right. \right. \\\left. \left. \left( \frac{1}{t-1} \sum _{\tau <t}\left( v_m^{\tau }+r_{adv-m}\right) \right) \parallel \left( v_d^{t}+r_{adv-d}\right) \parallel v_s^{t}\right] +b\right) \end{aligned} \end{aligned}$$
(13)
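The following PyTorch sketch outlines one FGM fine-tuning step under the above formulation: a clean forward and backward pass, perturbation of the diagnostic and drug visit embeddings along the normalized loss gradient, and an adversarial pass whose gradients are accumulated on top of the clean ones. The prediction head, dimensions, and the use of BCE-with-logits are illustrative assumptions.

```python
# Sketch of FGM adversarial fine-tuning (Eqs. 12-13); all sizes are placeholders.
import torch
import torch.nn as nn

dim, n_drugs, eps = 64, 131, 0.5
W = nn.Linear(5 * dim, n_drugs)                      # prediction head of Eq. (10)
bce = nn.BCEWithLogitsLoss(reduction="sum")

def logits(v_d_hist, v_s_hist, v_m_hist, v_d_t, v_s_t):
    return W(torch.cat([v_d_hist.mean(0), v_s_hist.mean(0), v_m_hist.mean(0),
                        v_d_t, v_s_t]))

# Visit embeddings of one patient (three historical visits) and the true drug vector.
v_d_hist, v_m_hist = (torch.randn(3, dim, requires_grad=True) for _ in range(2))
v_s_hist, v_s_t = torch.randn(3, dim), torch.randn(dim)
v_d_t = torch.randn(dim, requires_grad=True)
y_true = torch.randint(0, 2, (n_drugs,)).float()

# 1) Clean pass: accumulate the normal gradients.
loss = bce(logits(v_d_hist, v_s_hist, v_m_hist, v_d_t, v_s_t), y_true)
loss.backward()

# 2) Perturbations r_adv: normalized loss gradients with respect to the diagnostic
#    and drug visit embeddings, scaled by epsilon.
with torch.no_grad():
    r_d_t = eps * v_d_t.grad / (v_d_t.grad.norm() + 1e-12)
    r_d_hist = eps * v_d_hist.grad / (v_d_hist.grad.norm() + 1e-12)
    r_m_hist = eps * v_m_hist.grad / (v_m_hist.grad.norm() + 1e-12)

# 3) Adversarial pass on the disturbed embeddings (Eq. 13); its gradients are
#    accumulated on top of the clean ones before the parameters are updated.
adv_loss = bce(logits(v_d_hist + r_d_hist, v_s_hist, v_m_hist + r_m_hist,
                      v_d_t + r_d_t, v_s_t), y_true)
adv_loss.backward()
```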