A hybrid self-attention deep learning framework for multivariate sleep stage classification

Abstract

Background

Sleep is a complex and dynamic biological process characterized by different sleep patterns. Comprehensive sleep monitoring and analysis using multivariate polysomnography (PSG) records plays an important role in preventing sleep-related disorders. To reduce the time consumed by manual visual inspection of PSG records, automatic multivariate sleep stage classification has become an important research topic in medicine and bioinformatics.

Results

We present a unified hybrid self-attention deep learning framework, namely HybridAtt, to automatically classify sleep stages by capturing channel and temporal correlations from multivariate PSG records. We construct a new multi-view convolutional representation module to learn channel-specific and global-view features from the heterogeneous PSG inputs. The hybrid attention mechanism is designed to further fuse the multi-view features by inferring their dependencies without any additional supervision. The learned attentional representation is subsequently fed through a softmax layer to train an end-to-end deep learning model.

Conclusions

We empirically evaluate our proposed HybridAtt model on a benchmark PSG dataset in two feature domains, referred to as the time and frequency domains. Experimental results show that HybridAtt consistently outperforms ten baseline methods in both feature spaces, demonstrating the effectiveness of HybridAtt in the task of sleep stage classification.

Background

Sleep is a complicated biological process that plays an essential role in health. Sleep occurs in cycles and involves different sleep stages, helping restore functions of body and mind, such as the immune, nervous, skeletal, and muscular systems [1]. Unhealthy lifestyles and work-related stress may lead to sleep disturbances, which have become a serious issue in modern societies. Sleep disorders not only reduce physical performance during the day, but also have negative effects on cognitive functions [2]. Moreover, some psychological and neurological diseases can deteriorate normal sleep patterns [3]. To support the prevention and treatment of sleep-related disorders, sleep stage analysis has therefore garnered great interest among researchers in medicine and bioinformatics.

In practice, physicians often use polysomnography (PSG) records to comprehensively analyze sleep [4]. PSG data contain multivariate physiological signals, such as the electroencephalogram (EEG), electromyogram (EMG), electrocardiogram (ECG), and electrooculogram (EOG), which monitor different body regions. Through visual inspection, each 30-s time slot of PSG data can be classified into a sleep stage according to different scoring rules. Under the standard Rechtschaffen and Kales (R&K) rules [5], for example, sleep is classified into wakefulness, non-rapid eye movement (NREM) sleep, and rapid eye movement sleep, and NREM sleep is further subdivided into four stages referred to as S1, S2, S3, and S4. However, it is extremely time-consuming and laborious for physicians to visually inspect long-term PSG records, and identifying and analyzing sleep patterns requires highly trained professionals. It is therefore necessary to develop an automatic sleep stage classification system to enhance the efficiency of PSG-based sleep analysis.

In recent years, various automatic sleep stage classification systems based on overnight PSG records have been presented [2, 3]. Several researchers focus on extracting handcrafted features from multivariate PSG data to train an aggregated classifier. On one hand, different kinds of discriminative features, such as time-domain features [6, 7], frequency-domain features [8, 9], and other nonlinear measurements [10, 11], have been adopted to analyze the PSG data in each time slot. On the other hand, well-known machine learning classifiers, including support vector machines (SVM) [12, 13] and neural networks (NN) [14, 15], are employed to identify the sleep stages. These methods have advanced the development of automatic sleep stage classification systems, but they typically require a significant amount of domain knowledge, and their multi-stage training procedures do not guarantee that all the components work well together. More recently, advances in deep learning have allowed researchers to improve classification performance by learning feature representations directly from the multivariate biosignals [16]. By constructing multi-layer neural networks in different ways, classic deep learning architectures such as deep belief networks (DBN) [17, 18], convolutional neural networks (CNN) [19–21], and recurrent neural networks (RNN) [22, 23] have been applied to sleep stage classification with promising results.

However, existing deep learning models lack a mechanism for extracting comprehensive correlations from multivariate PSG records, which makes it challenging to classify sleep stages accurately. First, the complex correlations among PSG channels are important for recognizing sleep patterns. For instance, the abnormal wake-up (i.e., wakefulness stage) in central sleep apnea is caused by nervous system irregularities that trigger heart abnormalities and muscle movements [24]. These correlated physiological conditions are reflected in the EEG, ECG, and EMG, respectively, and are therefore helpful for sleep stage classification. Second, PSG data involve dynamic correlations across timestamps (or time slots), which help identify informative events during sleep, such as irregular sleep-wake rhythms and sudden involuntary movements [25], and thereby improve classification performance.

To this end, we propose HybridAtt, a deep learning framework with hybrid self-attention mechanism to classify sleep stages from the multivariate PSG inputs. The proposed hybrid self-attention mechanism is able to capture the dual correlations of PSG channels and timestamps by inferring their dependencies without any additional supervision. Moreover, a multi-view convolutional representation module is constructed to help the proposed attention mechanism fuse PSG data. We conduct cross-subject experiments in comparison with ten baseline methods, and demonstrate the effectiveness of our proposed HybridAtt model on a benchmark PSG dataset in two feature domains, referred to as the time and frequency domains. We summarize our main contributions as follows:

  • We propose HybridAtt, an end-to-end hybrid self-attention deep learning framework for sleep stage classification using multivariate PSG records.

  • HybridAtt explicitly extracts the dual correlations of PSG channels and timestamps by inferring their dependencies based on multi-view convolutional representations.

  • We empirically show that HybridAtt consistently achieves the best performance compared with ten baselines on a benchmark dataset under different feature domains.

Methods

In this section, we introduce the technical details of our HybridAtt model with multivariate PSG inputs. We first describe the overall architecture and then detail the main components of HybridAtt.

Model architecture

Figure 1 presents the architecture of our proposed HybridAtt model. The goal of HybridAtt is to capture the dual correlations of PSG channels and timestamps by calculating the dependencies of their multi-view convolutional representations, in order to improve the performance of sleep stage classification using multivariate PSG records. Formally, we assume that there are M multivariate PSG records, where the m-th record has T(m) timestamps, denoted as \(\left \{\boldsymbol {X}_{1}^{(m)}, \boldsymbol {X}_{2}^{(m)}, \cdots, \boldsymbol {X}_{T^{(m)}}^{(m)}\right \}_{m=1}^{M}\). Each record Xt at timestamp t contains a set of C-channel heterogeneous waveform vectors \(\left \{\boldsymbol {x}_{t}^{1}, \boldsymbol {x}_{t}^{2}, \cdots, \boldsymbol {x}_{t}^{C}\right \}\), where \(\boldsymbol {x}_{t}^{c} \in \mathbb {R}^{n^{(c)}}\). To learn informative features from the heterogeneous inputs, we first feed the input Xt into a multi-view convolutional representation module to extract the channel-view hidden features \(\boldsymbol {d}_{t}^{1:C}\) and the global-view hidden features st. We then develop a channel-wise attention module to capture the complex channel correlations at each timestamp based on the learned multi-view features. Subsequently, a time-wise attention module, combined with bidirectional gated recurrent units (BGRU), is utilized to capture the dynamic correlations across timestamps. Here we use ht and ct to denote the learned hidden state and context vector at timestamp t. Finally, we obtain an attentional hidden representation \(\boldsymbol {\hat {h}}_{t}\) to predict the label \(\boldsymbol {y}_{t} \in \{0,1\}^{|\mathcal {C}|}\), where \(|\mathcal {C}|\) is the number of sleep stage categories. The proposed model can be trained in an end-to-end fashion.

Fig. 1 Main architecture of the HybridAtt model. The goal of HybridAtt is to capture dual correlations of PSG channels and timestamps by calculating the dependencies of their multi-view convolutional representations, in order to improve the performance of sleep stage classification using multivariate PSG records

Multi-view convolutional representation

In practice, the collected PSG data tend to be heterogeneous, exhibiting different sample rates, signal strengths, and rhythm patterns. Inspired by the rapid development of multi-view deep learning [26–29], we modify the CNN structure to preserve the unique characteristics of each biomedical channel during feature representation. Given the input \(\boldsymbol {x}_{t}^{c}\) of the c-th channel at timestamp t, we use a 1-D channel-CNN encoder (i.e., CNNc) to derive its channel-view representation \(\boldsymbol {d}_{t}^{c} \in \mathbb {R}^{p}\), as follows:

$$\begin{array}{@{}rcl@{}} \boldsymbol{d}_{t}^{c} = \text{CNN}_{c}\left(\boldsymbol{x}_{t}^{c};\boldsymbol{\theta_{c}}\right), \end{array} $$
(1)

where θc denotes the learnable parameter set of CNNc. Similarly, we utilize a 2-D global-CNN encoder (i.e., CNNg) to obtain a global-view representation \(\boldsymbol {s}_{t} \in \mathbb {R}^{p}\) based on all the channels, as follows:

$$\begin{array}{@{}rcl@{}} \boldsymbol{s}_{t} = \text{CNN}_{g}\left(\boldsymbol{x}_{t}^{1:C};\boldsymbol{\theta_{g}}\right), \end{array} $$
(2)

where θg denotes the learnable parameter set of CNNg. Here we align the input dimension of each channel using linear interpolation to obtain a matrix input for Eq. (2).
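As a concrete illustration of this alignment step, the following is a minimal PyTorch sketch; the target length of 3840 samples (30 s at 128 Hz) is an illustrative assumption rather than the paper's stated setting.

```python
import torch
import torch.nn.functional as F

def align_channels(channels, target_len=3840):
    """Linearly interpolate each 1-D channel waveform to a common length so
    that the C heterogeneous channels stack into a (C, target_len) matrix,
    the input expected by the global-view encoder in Eq. (2)."""
    aligned = [
        F.interpolate(x.view(1, 1, -1), size=target_len,
                      mode="linear", align_corners=False).view(-1)
        for x in channels
    ]
    return torch.stack(aligned)  # (C, target_len)
```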

To strengthen the multi-view convolutional representation module, we further refine its CNN structure design, as shown in Fig. 2. The design strategy consists of two aspects. First, the convolutional layers should cover multiple resolution scales, since the waveform patterns of biosignals are related to different frequency modes [30]. We therefore apply feature kernels of different sizes in parallel to extract multi-scale features from the biosignals. Second, CNNc and CNNg should focus on different characteristics of the input data during feature learning. To this end, we use max pooling in CNNc to extract the most salient features of each channel, and average pooling in CNNg to retain more general information across all the channels. Taking advantage of this multi-view structure, informative features with the same dimensions can be learned from the heterogeneous PSG inputs, which further helps the hybrid attention mechanism capture the dual correlations.

Fig. 2 CNN structure of the multi-view convolutional representation module in HybridAtt. Informative features can be well extracted from the heterogeneous PSG inputs, and hence can further help the hybrid attention mechanism capture dual correlations
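To make the module concrete, here is a minimal PyTorch sketch of the two encoders. The kernel sizes and filter counts below are illustrative assumptions (the actual configuration is given in Table 1), and for simplicity CNNg is sketched as a 1-D convolution over the stacked channel matrix rather than the paper's 2-D encoder.

```python
import torch
import torch.nn as nn

class MultiScaleCNN(nn.Module):
    """Parallel convolutions with different kernel sizes extract multi-scale
    waveform features; global pooling then yields a fixed p-dim vector."""
    def __init__(self, in_ch, n_filters=32, kernel_sizes=(3, 5, 7),
                 pool="max", p=128):
        super().__init__()
        self.branches = nn.ModuleList([
            nn.Sequential(nn.Conv1d(in_ch, n_filters, k, padding=k // 2),
                          nn.ReLU())
            for k in kernel_sizes
        ])
        # max pooling for the channel view, average pooling for the global view
        self.pool = (nn.AdaptiveMaxPool1d(1) if pool == "max"
                     else nn.AdaptiveAvgPool1d(1))
        self.proj = nn.Linear(n_filters * len(kernel_sizes), p)

    def forward(self, x):                    # x: (batch, in_ch, length)
        z = torch.cat([b(x) for b in self.branches], dim=1)
        z = self.pool(z).squeeze(-1)         # (batch, n_filters * #scales)
        return self.proj(z)                  # (batch, p)

C = 14  # number of PSG channels in the UCD dataset
# one channel-view encoder CNN_c per channel, with max pooling
channel_cnns = nn.ModuleList([MultiScaleCNN(1, pool="max") for _ in range(C)])
# a single global-view encoder CNN_g over the aligned channel matrix,
# with average pooling
global_cnn = MultiScaleCNN(C, pool="avg")
```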

Hybrid self-attention mechanism

Channel-wise attention

To capture the complex correlations among PSG channels, we develop a channel-wise attention layer that infers the importance of each channel based on the learned multi-view features and fuses the representations by relying on the more informative ones. Given the multi-view features \(\boldsymbol {d}_{t}^{c}\) and st obtained by Eqs. (1) and (2), we first compute a fusional rate \(r_{t}^{c} \in \mathbb {R}\) for each channel c at timestamp t, which infers how much of the information carried by each CNN encoder should be fused. The formulation is as follows:

$$\begin{array}{@{}rcl@{}} r_{t}^{c} = \sigma \left(\boldsymbol{W}_{rg}^{\top}\boldsymbol{s}_{t} + \boldsymbol{W}_{rc}^{\top}\boldsymbol{d}_{t}^{c} + b_{rc}\right), \end{array} $$
(3)

where \(\boldsymbol {W}_{rg} \in \mathbb {R}^{p}, \boldsymbol {W}_{rc} \in \mathbb {R}^{p}\), and \(b_{rc} \in \mathbb {R}\) are learnable parameters. Here we rescale \(r_{t}^{c}\) into the range [0,1] using the sigmoid function σ(·) in Eq. (3). Then, we assign an attention energy \(e_{t}^{c}\) to each channel c based on its fusional rate \(r_{t}^{c}\), as follows:

$$\begin{array}{@{}rcl@{}} e_{t}^{c} = \boldsymbol{W}_{ec}^{\top}\left(\left(1-r_{t}^{c}\right) \odot \boldsymbol{s}_{t} + r_{t}^{c} \odot \boldsymbol{d}_{t}^{c}\right)+b_{ec}, \end{array} $$
(4)

where \(\boldsymbol {W}_{ec} \in \mathbb {R}^{p}\) and \(b_{ec} \in \mathbb {R}\) are learnable parameters, and ⊙ denotes the element-wise multiplication operator. Given the attention energies, a channel-wise contribution score vector \(\boldsymbol {\alpha }_{t} \in \mathbb {R}^{C}\) is obtained via the softmax function, as follows:

$$\begin{array}{@{}rcl@{}} \boldsymbol{\alpha}_{t} = \text{Softmax}\left(\left[e_{t}^{1}, \cdots, e_{t}^{c},\cdots, e_{t}^{C}\right]\right). \end{array} $$
(5)

Each element \(\alpha _{t}^{c}\) in the vector measures the importance of information carried by the c-th channel.

Accordingly, we use weighted aggregation to calculate the output vector of the channel-wise attention \(\boldsymbol {\tilde {x}}_{t} \in \mathbb {R}^{2p}\) based on the contribution score vector αt:

$$\begin{array}{@{}rcl@{}} \boldsymbol{\tilde{x}}_{t}=\boldsymbol{s}_{t} \oplus \left(\sum\limits_{c=1}^{C}\alpha_{t}^{c} \odot \boldsymbol{d}_{t}^{c}\right), \end{array} $$
(6)

where ⊕ is the concatenation operator. In this way, our model can fully incorporate the information carried by both feature views, and thus fuse more informative features from the multivariate PSG records.
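Under these definitions, the channel-wise attention layer can be sketched in PyTorch as follows; the batched tensor shapes are our assumptions, and the bias terms brc and bec are folded into the linear layers.

```python
import torch
import torch.nn as nn

class ChannelWiseAttention(nn.Module):
    """Fusional-rate attention over PSG channels, following Eqs. (3)-(6)."""
    def __init__(self, p):
        super().__init__()
        self.W_rg = nn.Linear(p, 1, bias=False)  # acts on the global view s_t
        self.W_rc = nn.Linear(p, 1)              # acts on d_t^c, carries b_rc
        self.W_ec = nn.Linear(p, 1)              # attention energy, carries b_ec

    def forward(self, s_t, d_t):
        # s_t: (batch, p) global view; d_t: (batch, C, p) channel views
        r = torch.sigmoid(self.W_rg(s_t).unsqueeze(1) + self.W_rc(d_t))  # Eq. (3)
        fused = (1 - r) * s_t.unsqueeze(1) + r * d_t                     # Eq. (4), inner term
        e = self.W_ec(fused).squeeze(-1)                                 # Eq. (4): (batch, C)
        alpha = torch.softmax(e, dim=-1)                                 # Eq. (5)
        pooled = (alpha.unsqueeze(-1) * d_t).sum(dim=1)                  # weighted aggregation
        return torch.cat([s_t, pooled], dim=-1)                          # Eq. (6): (batch, 2p)
```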

Time-wise attention

To capture the dynamic correlations across timestamps, we apply the same attention strategy along the time axis, namely time-wise attention. Given the learned vector sequence from \(\boldsymbol {\tilde {x}}_{1}\) to \(\boldsymbol {\tilde {x}}_{T}\), we derive the hidden state \(\boldsymbol {h}_{t} \in \mathbb {R}^{2q}\) through a 2-layer BGRU [31], as follows:

$$\begin{array}{@{}rcl@{}} \boldsymbol{h}_{1:T}=\text{BGRU}(\boldsymbol{\tilde{x}}_{1:T};\boldsymbol{\theta}_{r}), \end{array} $$
(7)

where θr is the learnable parameter set of BGRU. Here the hidden state ht at timestamp t is obtained by concatenating the forward hidden vector \(\overrightarrow {\boldsymbol {h}}_{t} \in \mathbb {R}^{q}\) and the backward hidden vector \(\overleftarrow {\boldsymbol {h}}_{t} \in \mathbb {R}^{q}\) in BGRU.

Subsequently, we can reformulate the attention strategy of Eqs. (3) to (5) to compute the time-wise contribution score vector \(\boldsymbol {\beta }_{t} \in \mathbb {R}^{T}\):

$$\begin{array}{@{}rcl@{}} r_{i} = \sigma \left(\boldsymbol{W}_{rt}^{\top}\boldsymbol{h}_{t} + \boldsymbol{W}_{ri}^{\top}\boldsymbol{h}_{i} + b_{rt}\right), \end{array} $$
$$\begin{array}{@{}rcl@{}} e_{t,i} = \boldsymbol{W}_{et}^{\top}((1-r_{i}) \odot \boldsymbol{h}_{t} + r_{i} \odot \boldsymbol{h}_{i})+b_{et}, \end{array} $$
$$\begin{array}{@{}rcl@{}} \boldsymbol{\beta}_{t} = \text{Softmax}([e_{t,1}, \cdots, e_{t,i},\cdots, e_{t,T}]), \end{array} $$

where \(\boldsymbol {W}_{rt} \in \mathbb {R}^{2q}, \boldsymbol {W}_{ri} \in \mathbb {R}^{2q}, b_{rt} \in \mathbb {R}, \boldsymbol {W}_{et} \in \mathbb {R}^{2q}\), and \(b_{et} \in \mathbb {R}\) are the learnable parameters. Finally, a temporal context vector \(\boldsymbol {c}_{t} \in \mathbb {R}^{2q}\) can be derived as the output of the time-wise attention:

$$\begin{array}{@{}rcl@{}} \boldsymbol{c}_{t}=\sum\limits_{i=1}^{T}\beta_{t,i} \odot \boldsymbol{h}_{i}. \end{array} $$
(8)
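The sketch below combines the 2-layer BGRU of Eq. (7) with the time-wise attention of Eq. (8); computing all pairwise rates at once is our vectorization choice rather than a detail specified in the paper.

```python
import torch
import torch.nn as nn

class TimeWiseAttention(nn.Module):
    """2-layer BGRU followed by fusional-rate attention over timestamps."""
    def __init__(self, in_dim, q):
        super().__init__()
        self.bgru = nn.GRU(in_dim, q, num_layers=2,
                           bidirectional=True, batch_first=True)
        self.W_rt = nn.Linear(2 * q, 1, bias=False)  # acts on the query state h_t
        self.W_ri = nn.Linear(2 * q, 1)              # acts on h_i, carries b_rt
        self.W_et = nn.Linear(2 * q, 1)              # energy, carries b_et

    def forward(self, x):                 # x: (batch, T, in_dim)
        h, _ = self.bgru(x)               # h: (batch, T, 2q), Eq. (7)
        # fusional rates for every (query t, key i) pair: (batch, T, T, 1)
        r = torch.sigmoid(self.W_rt(h).unsqueeze(2) + self.W_ri(h).unsqueeze(1))
        fused = (1 - r) * h.unsqueeze(2) + r * h.unsqueeze(1)
        e = self.W_et(fused).squeeze(-1)              # energies: (batch, T, T)
        beta = torch.softmax(e, dim=-1)               # scores beta_t over i
        c = torch.einsum("bti,bid->btd", beta, h)     # context c_t, Eq. (8)
        return h, c
```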

Unified neural classifier

With the help of our hybrid attention mechanism, we can obtain an attentional representation \(\boldsymbol {\hat {h}}_{t} \in \mathbb {R}^{r}\) by fusing the context vector ct and the current hidden state ht, defined as:

$$\begin{array}{@{}rcl@{}} \boldsymbol{\hat{h}}_{t} = f(\boldsymbol{W}_{h}[\boldsymbol{c}_{t} \oplus \boldsymbol{h}_{t}] + \boldsymbol{b}_{h}), \end{array} $$

where \(\boldsymbol {W}_{h} \in \mathbb {R}^{r \times 4q}\) and \(\boldsymbol {b}_{h} \in \mathbb {R}^{r}\) denote the learnable parameters. The attentional representation is then fed through the softmax layer to classify sleep stages, as follows:

$$\begin{array}{@{}rcl@{}} \boldsymbol{\hat{y}}_{t} = \text{Softmax}(\boldsymbol{W}_{s}\boldsymbol{\hat{h}}_{t}+\boldsymbol{b}_{s}). \end{array} $$
(9)

where \(\boldsymbol {W}_{s} \in \mathbb {R}^{|\mathcal {C}| \times r}\) and \(\boldsymbol {b}_{s} \in \mathbb {R}^{|\mathcal {C}|}\) are the learnable parameters. To train HybridAtt in an end-to-end manner, we employ cross-entropy to measure the classification loss between the \(\boldsymbol {\hat {y}}_{t}\) obtained by Eq. (9) and the ground truth yt. The cost function JHybridAtt of our unified HybridAtt model is defined as:

$$\begin{array}{@{}rcl@{}}\small \begin{aligned} & J_{\textsf{\scriptsize HybridAtt}}\left(\boldsymbol{X}_{1}^{(1)},\cdots,\boldsymbol{X}_{T^{(1)}}^{(1)},\cdots,\boldsymbol{X}_{1}^{(M)}, \cdots,\boldsymbol{X}_{T^{(M)}}^{(M)}\right) \\ = & \ -\frac{1}{M} \sum^{M}_{i=1}\frac{1}{T^{(i)}}\sum\limits_{t=1}^{T^{(i)}}\left[ \boldsymbol{y}_{t}^{\top}\log{\boldsymbol{\hat{y}}_{t}}+(\boldsymbol{1}-\boldsymbol{y}_{t})^{\top} \log{(\boldsymbol{1}-\boldsymbol{\hat{y}}_{t})}\right]. \end{aligned} \end{array} $$
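A sketch of the classifier head and cost function is given below; the activation f(·) is not specified above, so tanh is assumed, and the outer average over records is left to mini-batching.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnifiedClassifier(nn.Module):
    """Fuses c_t and h_t into the attentional representation, then applies
    the softmax layer of Eq. (9)."""
    def __init__(self, q, r, n_classes=5):
        super().__init__()
        self.W_h = nn.Linear(4 * q, r)
        self.W_s = nn.Linear(r, n_classes)

    def forward(self, h, c):              # h, c: (batch, T, 2q)
        h_hat = torch.tanh(self.W_h(torch.cat([c, h], dim=-1)))  # f(.) assumed tanh
        return F.softmax(self.W_s(h_hat), dim=-1)                # predictions per timestamp

def hybridatt_loss(y_hat, y, eps=1e-8):
    """Element-wise cross-entropy of the cost function, averaged over the
    timestamps of each record in the batch."""
    ll = y * torch.log(y_hat + eps) + (1 - y) * torch.log(1 - y_hat + eps)
    return -ll.sum(dim=-1).mean()
```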

Results and discussion

In this section, we evaluate HybridAtt on a benchmark PSG dataset in two feature domains, referred to as the time and frequency domains. We first introduce the dataset, then describe the baselines and some experiment details. We finally present and discuss the quantitative results in terms of different evaluation metrics.

Dataset description

We conduct experiments on multivariate PSG sleep stage classification using the UCD dataset collected from St. Vincent’s University Hospital and University College Dublin [32]. This dataset contains 14-channel overnight PSG data, consisting of 128 Hz EEG, 64 Hz EMG, and other types of biosignals. We generate 287,840 input vectors from all 25 subjects, and each 30-s fragment is labeled as one of five sleep stages. In more detail, each 30-s timestamp contains 53,760 data points in the time domain, and 27,300 data points in the frequency domain after applying the short-time Fourier transform (STFT). Note that we merge the original S3 and S4 stages into a new S3 stage, and retain only the time slots belonging to the five sleep stages in our experiments.
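For illustration, frequency-domain inputs can be produced with SciPy's STFT as sketched below; the sampling rate and window length are assumptions and do not reproduce the exact 27,300-point configuration above.

```python
import numpy as np
from scipy.signal import stft

def epoch_to_spectrogram(epoch, fs=128, nperseg=256):
    """Map one 30-s single-channel epoch to a frequency-domain feature vector
    by flattening the STFT magnitude spectrogram."""
    _, _, Z = stft(epoch, fs=fs, nperseg=nperseg)
    return np.abs(Z).ravel()
```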

Baselines

We compare HybridAtt with the following ten existing biosignal feature learning baselines:

SVM. SVM is a classic machine learning method; here we use a one-vs-all SVM for the five-class classification task. To avoid the curse of dimensionality, we apply principal component analysis (PCA) to select the top-r components from all the PSG channels as features before training the SVM, and refer to this baseline as PSVM.

Deep neural networks (DNN). DNN is a basic multi-layer neural network. We train a 3-layer DNN with a softmax output layer, taking the concatenation of all the PSG channels as input.

RNN. RNN is designed for time series. Similar to DNN, we concatenate the channel data and train the same BGRU structure as used in HybridAtt, followed by a softmax layer.

RNNAtt. RNNAtt is an RNN variant with an attention mechanism. We add two existing attention strategies, location-based and concatenation-based attention [33], after the BGRU structure, referred to as RNNAtt_l and RNNAtt_c, respectively.

CNN. CNN is a commonly used deep learning model for biosignals. We integrate the PSG data into a matrix and train the same CNN structure as in our multi-view convolutional representation module.

CRNN. CRNN is a CNN variant combined with RNN. Here we directly integrate the aforementioned CNN and BGRU to train a unified model.

CRNNAtt. CRNNAtt applies an attention mechanism after the CRNN structure. We follow the same procedure as for RNNAtt, yielding CRNNAtt_l and CRNNAtt_c, respectively.

ChannelAtt. ChannelAtt [34] soft-selects critical channels from multivariate biosignals using a global attention mechanism. Different from the original model, which used fully-connected layers for feature extraction, we use the proposed CNN structure as the feature encoder to train the model.

Our approaches

To fairly evaluate our proposed attention strategy, we show the performance of the following two approaches in the experiments.

HybridAtt_l. HybridAtt_l is a reduced model that uses the location-based attention mechanism in HybridAtt for sleep stage classification.

HybridAtt_f. HybridAtt_f uses the proposed attention strategy to calculate the score vectors in both the channel-wise and time-wise layers.

Evaluation criteria

To quantify performance, five evaluation measurements are used to validate HybridAtt for PSG-based sleep stage classification. Both accuracy and F1-score are adopted for evaluation; we measure F1-score with the Macro and Micro metrics, namely Macro-F1 and Micro-F1, respectively. The Macro-averaged area under the curve (AUC) of the precision-recall (PR) and receiver operating characteristic (ROC) curves are also used to evaluate each approach, namely AUC-PR and AUC-ROC, respectively. Moreover, to evaluate our model as a general cross-subject classifier, we perform 5-fold subject-independent cross-validation and report the average test performance with standard deviation (μ±σ) for each method. The ratio of the training, validation and test sets is 0.7:0.1:0.2.
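For reference, the five measurements can be computed per test fold with scikit-learn as sketched below; the array shapes and one-hot conversion are our assumptions.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, average_precision_score,
                             f1_score, roc_auc_score)

def evaluate_fold(y_true, y_prob):
    """y_true: (N,) integer stage labels; y_prob: (N, 5) softmax outputs."""
    y_pred = y_prob.argmax(axis=1)
    y_onehot = np.eye(y_prob.shape[1])[y_true]   # one-hot labels for the AUC metrics
    return {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Macro-F1": f1_score(y_true, y_pred, average="macro"),
        "Micro-F1": f1_score(y_true, y_pred, average="micro"),
        "AUC-PR": average_precision_score(y_onehot, y_prob, average="macro"),
        "AUC-ROC": roc_auc_score(y_onehot, y_prob, average="macro"),
    }
```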

Implementation details

We implement all the approaches in PyTorch. Training is performed locally on an NVIDIA Titan Xp GPU. Adadelta [35] is adopted to optimize the cost function with respect to the learnable parameters. We also use weight decay with an L2 penalty coefficient of 0.001, a momentum of 0.95, and a dropout rate of 0.5 for all the approaches. The structure configuration of our multi-view convolutional representation module is listed in Table 1, and we set p=128, q=128, and r=128 for our models and the baselines.

Table 1 Configurations of the multi-view convolutional representation module in HybridAtt
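The optimizer setup can be sketched as follows; reading the 0.95 momentum as Adadelta's decay rate rho is our interpretation.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 5)  # stand-in for an instantiated HybridAtt network
optimizer = torch.optim.Adadelta(
    model.parameters(),
    rho=0.95,           # the stated 0.95 momentum, read as Adadelta's rho
    weight_decay=1e-3,  # L2 penalty coefficient of 0.001
)
```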

Experimental results

We investigate the effectiveness of our proposed HybridAtt model, compared to the aforementioned baseline methods in the task of sleep stage classification. Tables 2 and 3 report the comparison results tested in the frequency and time domains, respectively. We highlight the best evaluation scores in boldface. We observe that HybridAtt achieves the best performance compared with the corresponding baselines in both feature domains on the UCD dataset.

Table 2 Classification performance comparisons on the UCD dataset in the frequency domain
Table 3 Classification performance comparisons on the UCD dataset in the time domain

Among the baselines, the traditional classification method PSVM performs better than DNN and the RNN-based models in the frequency domain, but worse in the time domain. This suggests that the raw frequency features of PSG data carry distinctive information that helps the SVM learn a relatively clear hyperplane to separate sleep stages. The results of DNN across the two feature domains support the same observation, demonstrating the capability of the PSG spectral features. The limited improvement of the attention-based RNN models over RNN in both domains shows that the features learned by RNN do not provide enough information for the attention mechanisms to make correct classifications. This also indicates that simply concatenating PSG data is unsuitable for fully-connected networks to learn informative features, since it ignores the multivariate prior information. We can see that the CNN-based models achieve better performance than the other baselines, benefiting from the structure proposed in the multi-view convolutional representation module. Compared with CRNN, the attention-based CRNN models perform better, because the attention mechanism is able to fuse features based on the more useful information carried by the sequential representations. By fusing features from multi-channel representations with an attention mechanism, ChannelAtt works well in classifying sleep stages and achieves better results than CRNNAtt, illustrating that the hidden connections among PSG channels captured by ChannelAtt are more helpful for sleep stage classification. Furthermore, by capturing the dual correlations among channels and timestamps, our proposed HybridAtt model consistently attains the best evaluation scores in both the time and frequency feature domains.

Among our models, HybridAtt_f outperforms the baselines in terms of all five evaluation measurements. For example, HybridAtt_f obtains the best accuracy of 0.7424 in the time domain, compared with 0.7317 and 0.7169 achieved by our reduced model HybridAtt_l and the baseline model ChannelAtt, respectively. Comparing the results between the two domains, on one hand, we observe that HybridAtt_l performs on par with ChannelAtt in the frequency domain. This means that adopting traditional location-based attention in the channel-wise layer cannot capture enough information from the multi-view representations, and hence fails to help the time-wise attention extract high-level features. On the other hand, HybridAtt_f, which uses the proposed attention strategy to fuse multi-view features, achieves robust performance under different raw feature spaces. Moreover, HybridAtt performs better in the time domain than in the frequency domain. We conjecture that CNN has a convolution procedure similar to STFT, but CNN adopts learnable kernels during convolution whereas STFT employs fixed Fourier functions. Taking advantage of end-to-end learning, our hybrid attention mechanism can help the CNN learn more representative convolutional kernels than the handcrafted window functions in STFT.

Figure 3 illustrates the ROC and PR curves over all the test folds on the UCD dataset. We observe that the proposed HybridAtt_f method consistently attains the best AUC in terms of both PR and ROC in the different domains, demonstrating that it is an effective cross-subject method for sleep stage classification. Based on the overall performance comparisons, we conclude that the attention mechanism is key to identifying sleep patterns for sleep stage classification. Adopting single-dimension attention in one aspect only, as in CRNNAtt and ChannelAtt, may lose useful information when dealing with multivariate PSG records. Multi-view representation is also essential for the attention mechanism to infer important information. By constructing hybrid attention networks based on multi-view convolutional representations, HybridAtt achieves better results in both feature domains than the different feature learning methods, demonstrating its effectiveness in PSG-based sleep stage classification.

Fig. 3 ROC and PR curves of the proposed method and the baselines in different feature domains on the UCD dataset. a and b plot the ROC curves in the frequency and time domains, respectively; c and d plot the PR curves in the frequency and time domains, respectively

Conclusions

In this paper, we present a unified hybrid self-attention deep learning framework, namely HybridAtt, to classify sleep stages from multivariate PSG records. HybridAtt is designed to capture the dual correlations among channels and timestamps based on multi-view convolutional feature representations. Experiments on a benchmark PSG dataset show that HybridAtt is able to effectively fuse multivariate information from PSG data and hence consistently beats the baselines in both the time and frequency feature domains. In future work, we will extend HybridAtt to other biomedical applications with similar data structures, and develop an advanced attention mechanism that can jointly learn two-dimensional contribution scores in one step, instead of adopting the current multi-step attention strategy.

Availability of data and materials

The UCD dataset used in our experiments can be downloaded from https://physionet.org/physiobank/database/ucddb/. The data are publicly available and free to use.

Abbreviations

AUC: Area under the curve

BGRU: Bidirectional gated recurrent units

CNN: Convolutional neural networks

DBN: Deep belief networks

DNN: Deep neural networks

ECG: Electrocardiogram

EEG: Electroencephalogram

EMG: Electromyogram

EOG: Electrooculogram

NN: Neural networks

NREM: Non-rapid eye movement

PCA: Principal component analysis

PR: Precision-recall

PSG: Polysomnography

R&K: Rechtschaffen and Kales

RNN: Recurrent neural networks

ROC: Receiver operating characteristic

STFT: Short-time Fourier transform

SVM: Support vector machine

References

  1. Luyster FS, Strollo PJ, Zee PC, Walsh JK. Sleep: a health imperative. Sleep. 2012; 35(6):727–34.

  2. Aboalayon KAI, Faezipour M, Almuhammadi WS, Moslehpour S. Sleep stage classification using eeg signal analysis: a comprehensive survey and new investigation. Entropy. 2016; 18(9):272.

  3. Boostani R, Karimzadeh F, Nami M. A comparative review on sleep stage classification methods in patients and healthy individuals. Comput Methods Programs Biomed. 2017; 140:77–91.

  4. Şen B, Peker M, Çavuşoğlu A, Çelebi FV. A comparative study on classification of sleep stage based on eeg signals using feature selection and classification algorithms. J Med Syst. 2014; 38(3):18.

  5. Wolpert EA. A manual of standardized terminology, techniques and scoring system for sleep stages of human subjects. Arch Gen Psychiatr. 1969; 20(2):246–7.

  6. Khalighi S, Sousa T, Oliveira D, Pires G, Nunes U. Efficient feature selection for sleep staging based on maximal overlap discrete wavelet transform and svm. In: Engineering in Medicine and Biology Society, EMBC, 2011 Annual International Conference of the IEEE. IEEE: 2011. p. 3306–9.

  7. Tsai P-Y, Hu W, Kuo TB, Shyu L-Y. A portable device for real time drowsiness detection using novel active dry electrode system. In: Engineering in Medicine and Biology Society, 2009. EMBC 2009. Annual International Conference of the IEEE. IEEE: 2009. p. 3775–8.

  8. Charbonnier S, Zoubek L, Lesecq S, Chapotot F. Self-evaluated automatic classifier as a decision-support tool for sleep/wake staging. Comput Biol Med. 2011; 41(6):380–9.

  9. Li Y, Yingle F, Gu L, Qinye T. Sleep stage classification based on eeg hilbert-huang transform. In: Industrial Electronics and Applications, 2009. ICIEA 2009. 4th IEEE Conference On. IEEE: 2009. p. 3676–81. https://doi.org/10.1109/iciea.2009.5138842.

  10. Shi J, Liu X, Li Y, Zhang Q, Li Y, Ying S. Multi-channel eeg-based sleep stage classification with joint collaborative representation and multiple kernel learning. J Neurosci Methods. 2015; 254:94–101.

  11. Phan H, Do Q, Do T-L, Vu D-L. Metric learning for automatic sleep stage classification. In: Engineering in Medicine and Biology Society (EMBC), 2013 35th Annual International Conference of the IEEE. IEEE: 2013. p. 5025–8. https://doi.org/10.1109/embc.2013.6610677.

  12. Huang C-S, Lin C-L, Ko L-W, Liu S-Y, Sua T-P, Lin C-T. A hierarchical classification system for sleep stage scoring via forehead eeg signals. In: Computational Intelligence, Cognitive Algorithms, Mind, and Brain (CCMB), 2013 IEEE Symposium On. IEEE: 2013. p. 1–5.

  13. Gudmundsson S, Runarsson TP, Sigurdsson S. Automatic sleep staging using support vector machines with posterior probability estimates. In: Computational Intelligence for Modelling, Control and Automation, 2005 and International Conference on Intelligent Agents, Web Technologies and Internet Commerce, International Conference On, vol. 2. IEEE: 2005. p. 366–72. https://doi.org/10.1109/cimca.2005.1631496.

  14. Özşen S. Classification of sleep stages using class-dependent sequential feature selection and artificial neural network. Neural Comput Applic. 2013; 23(5):1239–50.

  15. Tagluk ME, Sezgin N, Akin M. Estimation of sleep stages by an artificial neural network employing eeg, emg and eog. J Med Syst. 2010; 34(4):717–25.

  16. Najdi S, Gharbali AA, Fonseca JM. Feature transformation based on stacked sparse autoencoders for sleep stage classification. In: Doctoral Conference on Computing, Electrical and Industrial Systems. Springer: 2017. p. 191–200. https://doi.org/10.1007/978-3-319-56077-9_18.

  17. Längkvist M, Karlsson L, Loutfi A. Sleep stage classification using unsupervised feature learning. Adv Artif Neural Syst. 2012; 2012:5.

  18. Zhang J, Wu Y, Bai J, Chen F. Automatic sleep stage classification based on sparse deep belief net and combination of multiple classifiers. Trans Inst Meas Control. 2016; 38(4):435–51.

  19. Supratak A, Dong H, Wu C, Guo Y. Deepsleepnet: a model for automatic sleep stage scoring based on raw single-channel eeg. IEEE Trans Neural Syst Rehabil Eng. 2017; 25(11):1998–2008.

  20. Tsinalis O, Matthews PM, Guo Y, Zafeiriou S. Automatic sleep stage scoring with single-channel eeg using convolutional neural networks. 2016. arXiv preprint arXiv:1610.01683.

  21. Chambon S, Galtier MN, Arnal PJ, Wainrib G, Gramfort A. A deep learning architecture for temporal sleep stage classification using multivariate and multimodal time series. IEEE Trans Neural Syst Rehabil Eng. 2018; 26:758–69.

  22. Giri EP, Fanany MI, Arymurthy AM. Combining generative and discriminative neural networks for sleep stages classification. 2016. arXiv preprint arXiv:1610.01741.

  23. Zhao M, Yue S, Katabi D, Jaakkola TS, Bianchi MT. Learning sleep stages from radio signals: a conditional adversarial architecture. In: International Conference on Machine Learning. ACM: 2017. p. 4100–9.

  24. Guilleminault C, Tilkian A, Dement WC. The sleep apnea syndromes. Annu Rev Med. 1976; 27(1):465–84.

  25. Thorpy MJ. Classification of sleep disorders. J Clin Neurophysiol. 1990; 7(1):67–81.

  26. Zhao J, Xie X, Xu X, Sun S. Multi-view learning overview: Recent progress and new challenges. Inform Fusion. 2017; 38:43–54.

  27. Yuan Y, Xun G, Jia K, Zhang A. A multi-view deep learning method for epileptic seizure detection using short-time fourier transform. In: Proceedings of the 8th ACM International Conference on Bioinformatics, Computational Biology, and Health Informatics. ACM: 2017. p. 213–22. https://doi.org/10.1145/3107411.3107419.

  28. Yuan Y, Xun G, Jia K, Zhang A. A multi-context learning approach for eeg epileptic seizure detection. BMC Syst Biol. 2018; 12(6):107.

  29. Yuan Y, Jia K, Ma F, Xun G, Wang Y, Su L, Zhang A. Multivariate sleep stage classification using hybrid self-attentive deep learning networks. In: 2018 IEEE International Conference on Bioinformatics and Biomedicine (BIBM). IEEE: 2018. p. 963–8. https://doi.org/10.1109/bibm.2018.8621146.

  30. Yuan Y, Xun G, Suo Q, Jia K, Zhang A. Wave2vec: Deep representation learning for clinical temporal data. Neurocomputing. 2018; 324:31–42.

  31. Schuster M, Paliwal KK. Bidirectional recurrent neural networks. IEEE Trans Signal Process. 1997; 45(11):2673–81.

  32. Goldberger AL, Amaral LA, Glass L, Hausdorff JM, Ivanov PC, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. Physiobank, physiotoolkit, and physionet. Circulation. 2000; 101(23):215–20.

  33. Ma F, Chitta R, Zhou J, You Q, Sun T, Gao J. Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In: Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. ACM: 2017. p. 1903–11. https://doi.org/10.1145/3097983.3098088.

  34. Yuan Y, Xun G, Ma F, Suo Q, Xue H, Jia K, Zhang A. A novel channel-aware attention framework for multi-channel eeg seizure detection via multi-view deep learning. In: Biomedical & Health Informatics (BHI), 2018 IEEE EMBS International Conference On. IEEE: 2018. p. 206–9. https://doi.org/10.1109/bhi.2018.8333405.

  35. Zeiler MD. Adadelta: an adaptive learning rate method. 2012. arXiv preprint arXiv:1212.5701.

Acknowledgements

Not applicable.

About this supplement

This article has been published as part of BMC Bioinformatics Volume 20 Supplement 16, 2019: Selected articles from the IEEE BIBM International Conference on Bioinformatics & Biomedicine (BIBM) 2018: bioinformatics and systems biology. The full contents of the supplement are available online at https://bmcbioinformatics.biomedcentral.com/articles/supplements/volume-20-supplement-16.

Funding

This work has been financially supported by the National Science Foundation of China (81871394, 61672064), and the Science and Technology Project of Beijing Municipal Education Commission (KM201810005030). The publication costs were funded by the Beijing Laboratory of Advanced Information Networks (PXM2019_014204_500029).

Author information

Contributions

YY, GX and FM developed the study concept and designed the model. YY programmed the deep learning framework, carried out the experiments, and wrote most of the manuscript. YW acquired and processed the data. YY and KJ analyzed the data and the experimental results. LS contributed to writing the manuscript. KJ and AZ supervised and helped conceive the study. All the authors were involved in the revision of the manuscript. All the authors read and approved the final manuscript.

Corresponding author

Correspondence to Kebin Jia.

Ethics declarations

Ethics approval and consent to participate

Not applicable.

Consent for publication

Not applicable.

Competing interests

The authors declare that they have no competing interests.

Additional information

Publisher’s Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Rights and permissions

Open Access This article is distributed under the terms of the Creative Commons Attribution 4.0 International License (http://creativecommons.org/licenses/by/4.0/), which permits unrestricted use, distribution, and reproduction in any medium, provided you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons license, and indicate if changes were made. The Creative Commons Public Domain Dedication waiver (http://creativecommons.org/publicdomain/zero/1.0/) applies to the data made available in this article, unless otherwise stated.

About this article

Cite this article

Yuan, Y., Jia, K., Ma, F. et al. A hybrid self-attention deep learning framework for multivariate sleep stage classification. BMC Bioinformatics 20 (Suppl 16), 586 (2019). https://doi.org/10.1186/s12859-019-3075-z
