A hybrid self-attention deep learning framework for multivariate sleep stage classification

Background Sleep is a complex and dynamic biological process characterized by different sleep patterns. Comprehensive sleep monitoring and analysis using multivariate polysomnography (PSG) records has attracted significant effort aimed at preventing sleep-related disorders. To alleviate the time consumption caused by manual visual inspection of PSG, automatic multivariate sleep stage classification has become an important research topic in the medical and bioinformatics fields. Results We present a unified hybrid self-attention deep learning framework, namely HybridAtt, to automatically classify sleep stages by capturing channel and temporal correlations from multivariate PSG records. We construct a new multi-view convolutional representation module to learn channel-specific and global-view features from the heterogeneous PSG inputs. The hybrid attention mechanism is designed to further fuse the multi-view features by inferring their dependencies without any additional supervision. The learned attentional representation is subsequently fed through a softmax layer to train an end-to-end deep learning model. Conclusions We empirically evaluate our proposed HybridAtt model on a benchmark PSG dataset in two feature domains, referred to as the time and frequency domains. Experimental results show that HybridAtt consistently outperforms ten baseline methods in both feature spaces, demonstrating the effectiveness of HybridAtt in the task of sleep stage classification.


Background
Sleep is a complicated biological process and plays an essential role in health. Sleep occurs in cycles and involves different sleep stages, helping restore functions of body and mind, such as the immune, nervous, skeletal, and muscular systems [1]. Unhealthy lifestyles and work-related stress may lead to sleep disturbances, which have become a serious issue in modern societies. Sleep disorders not only cause a reduction in physical performance during the day, but also have negative effects on cognitive functions [2]. Moreover, some psychological and neurological diseases can also deteriorate normal sleep patterns [3]. Towards this end, in order to provide prevention and treatment of sleep-related disorders, sleep stage analysis has recently garnered great interest among researchers in medicine and bioinformatics. (*Correspondence: kebinj@bjut.edu.cn; 1 College of Information and Communication Engineering, Beijing University of Technology, Beijing, China; 2 Beijing Laboratory of Advanced Information Networks, Beijing, China. Full list of author information is available at the end of the article.)
In practice, physicians often use polysomnography (PSG) records to comprehensively analyze sleep [4]. PSG data contain multivariate physiological signals, such as electroencephalogram (EEG), electromyogram (EMG), electrocardiogram (ECG), and electrooculogram (EOG), in order to monitor different body regions. In particular, through visual inspection, each 30-s time slot of PSG data can be classified into different sleep stages according to different rules. According to the standard Rechtschaffen and Kales (R&K) rules [5], for example, the sleep phase can be classified into the stages of wakefulness, non-rapid eye movement (NREM) sleep, and rapid eye movement (REM) sleep. Among them, NREM sleep is further subdivided into four sleep stages referred to as S1, S2, S3, and S4. However, it is extremely time-consuming and laborious for physicians to visually inspect long-term PSG records. In addition, identifying and analyzing sleep patterns also requires highly trained professionals. Therefore, it is necessary to develop an automatic system capable of classifying sleep stages to enhance the efficiency of PSG sleep analysis.
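The 30-s scoring convention above determines the basic unit of classification. A minimal sketch of how a whole-night channel could be cut into such epochs (the helper name, sampling rate, and signal length are illustrative assumptions, not part of the paper):

```python
import numpy as np

def segment_epochs(signal, fs, epoch_sec=30):
    """Split a 1-D biosignal into non-overlapping 30-s epochs.

    Hypothetical helper: sleep staging scores each 30-s PSG time slot
    separately, so models operate on epochs rather than whole-night signals.
    """
    samples_per_epoch = fs * epoch_sec
    n_epochs = len(signal) // samples_per_epoch  # drop any trailing partial epoch
    return signal[:n_epochs * samples_per_epoch].reshape(n_epochs, samples_per_epoch)

# Example: 10 minutes of a 128 Hz EEG channel -> 20 epochs of 3840 samples each
eeg = np.random.randn(10 * 60 * 128)
epochs = segment_epochs(eeg, fs=128)
```

Each row of `epochs` would then receive one sleep stage label, which is exactly the granularity at which the classifiers discussed below are trained and evaluated.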
In recent years, various automatic sleep stage classification systems have been presented utilizing overnight PSG records [2,3]. Several researchers focus on extracting different handcrafted features from multivariate PSG data to train an aggregated classifier. On one hand, different kinds of discriminative features, such as time-domain features [6,7], frequency-domain features [8,9], and other nonlinear measurements [10,11], have been adopted to analyze the PSG data in each time slot. On the other hand, some well-known classifiers in machine learning, including support vector machines (SVM) [12,13] and neural networks (NN) [14,15], are employed to help identify the sleep stages. These methods advance the development of automatic sleep stage classification systems, but they typically require a significant amount of domain knowledge, and their multi-stage training procedures do not guarantee consistently good performance when all the components are combined. Furthermore, recent advances in deep learning allow researchers to improve classification performance by directly learning feature representations from the multivariate biosignals [16]. By constructing multi-layer neural networks in different ways, some classic deep learning structures, such as deep belief networks (DBN) [17,18], convolutional neural networks (CNN) [19-21] and recurrent neural networks (RNN) [22,23], have been successfully applied to the task of sleep stage classification with promising results.
However, existing deep learning models lack a mechanism to extract comprehensive correlations from multivariate PSG records, which makes it challenging to classify sleep stages accurately. First, the complex correlations among PSG channels are important for recognizing sleep patterns. For instance, the abnormal wake-up (i.e., wakefulness stage) in central sleep apnea is caused by nervous system irregularities that trigger heart abnormalities and muscle movements [24]. These correlated physiological conditions are reflected in the EEG, ECG, and EMG, respectively, and are helpful for sleep stage classification. Second, PSG data involve dynamic correlations across different timestamps (or time slots), which help identify informative events during sleep, such as irregular sleep-wake rhythm and sudden involuntary movement [25], to improve classification performance.
To this end, we propose HybridAtt, a deep learning framework with a hybrid self-attention mechanism that classifies sleep stages from multivariate PSG inputs. The proposed hybrid self-attention mechanism is able to capture the dual correlations of PSG channels and timestamps by inferring their dependencies without any additional supervision. Moreover, a multi-view convolutional representation module is constructed to help the proposed attention mechanism fuse PSG data. We conduct cross-subject experiments in comparison with ten baseline methods, and demonstrate the effectiveness of our proposed HybridAtt model on a benchmark PSG dataset in two feature domains, referred to as the time and frequency domains. We summarize our main contributions as follows:
• We propose HybridAtt, an end-to-end hybrid self-attention deep learning framework for sleep stage classification using multivariate PSG records.
• HybridAtt explicitly extracts the dual correlations of PSG channels and timestamps by inferring their dependencies based on multi-view convolutional representations.
• We empirically show that HybridAtt consistently achieves the best performance compared with ten baselines on a benchmark dataset under different feature domains.

Methods
In this section, we introduce the technical details of our HybridAtt model with multivariate PSG inputs. We first describe the overall architecture and then detail the main components of HybridAtt.

Model architecture
Figure 1 presents the architecture of our proposed HybridAtt model. The goal of HybridAtt is to capture dual correlations of PSG channels and timestamps by calculating the dependencies of their multi-view convolutional representations, in order to improve the performance of sleep stage classification using multivariate PSG records. Formally, we assume that there are M multivariate PSG records, denoted as X^(1), · · · , X^(M), where the m-th record contains T^(m) timestamps. Each observation X_t at timestamp t contains C channels, with x^c_t denoting the input of the c-th channel. Here we use h_t and c_t to denote the learned hidden state and context vector at timestamp t. Finally, we can further obtain an attentional hidden representation ĥ_t to predict the label y_t ∈ {0, 1}^|C|, where |C| is the number of unique sleep stage categories. The proposed model can be trained in an end-to-end fashion.

Multi-view convolutional representation
In practice, the collected PSG data often tend to be heterogeneous, with different sample rates, signal strengths, and rhythm patterns. Inspired by the rapid development of multi-view deep learning [26-29], we propose to modify the CNN structure to preserve the unique characteristics of each biomedical channel during feature representation. Given the input x^c_t in the c-th channel at timestamp t, we use a 1-D channel-CNN encoder (i.e., CNN_c) to derive its channel-view representation d^c_t ∈ R^p, as follows:

d^c_t = CNN_c(x^c_t; θ_c),  (1)

where θ_c denotes the learnable parameter set of CNN_c. Similarly, we utilize a 2-D global-CNN encoder (i.e., CNN_g) to obtain a global-view representation s_t ∈ R^p based on all the channels, as follows:

s_t = CNN_g(X_t; θ_g),  (2)

where θ_g denotes the learnable parameter set of CNN_g.
Here we align the input dimension of each channel using linear interpolation to obtain a matrix input for Eq. (2). In order to unleash the power of the multi-view convolutional representation module, we further polish the CNN structure design in our HybridAtt model, as shown in Fig. 2. The main design strategy consists of two aspects. First, the convolutional layers should cover multiple resolution scales, since the waveform patterns of biosignals are related to different frequency modes [30]. Here we set different sizes of feature kernels in parallel to extract multiscale features from the biosignals. Second, CNN_c and CNN_g should focus on different characteristics of the input data during feature learning. Towards this end, we guide these two encoders by setting max pooling for CNN_c to extract the most important features of each channel, and setting average pooling for CNN_g to retain more general information across all the channels. Taking advantage of the multi-view structure, informative features with the same dimensions can be learned from the heterogeneous PSG inputs, which further helps the hybrid attention mechanism capture the dual correlations.
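The two design choices above (parallel multiscale kernels, then max versus average pooling) can be sketched in miniature with NumPy. The kernel sizes, random stand-in kernels, and function names below are illustrative assumptions, not the paper's actual configuration:

```python
import numpy as np

def conv1d_valid(x, kernel):
    """Plain valid-mode 1-D convolution (cross-correlation) with one kernel."""
    k = len(kernel)
    return np.array([np.dot(x[i:i + k], kernel) for i in range(len(x) - k + 1)])

def multiscale_features(x, kernel_sizes=(3, 5, 7), pool="max"):
    """Apply parallel kernels of several sizes, then pool each feature map.

    Mirrors the design strategy: multiple resolution scales in parallel,
    max pooling for the channel view (CNN_c) and average pooling for the
    global view (CNN_g). Kernels here are random stand-ins for learned ones.
    """
    rng = np.random.default_rng(0)  # fixed seed so both views share kernels
    feats = []
    for k in kernel_sizes:
        fmap = conv1d_valid(x, rng.standard_normal(k))
        feats.append(fmap.max() if pool == "max" else fmap.mean())
    return np.array(feats)

x = np.random.default_rng(1).standard_normal(256)  # one channel, one time slot
d_c = multiscale_features(x, pool="max")           # channel-view style features
s_t = multiscale_features(x, pool="avg")           # global-view style features
```

In the real module each pooled feature map would contribute many dimensions (p = 128 in total), and the global view would convolve over all channels at once; this sketch only shows the multiscale-then-pool pattern.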

Hybrid self-attention mechanism

Channel-wise attention
In order to capture the complex correlations among PSG channels, we develop a channel-wise attention layer that is able to infer the importance of each channel based on the learned multi-view features, and fuse representations relying on the more informative ones. Given the multi-view features d^c_t and s_t obtained by Eqs. (1) and (2), we first compute a fusional rate r^c_t ∈ R for each channel c at timestamp t, inferring how much information carried by each CNN encoder should be fused. The formulation is as follows:

r^c_t = σ(W_rg^T s_t + W_rc^T d^c_t + b_rc),  (3)

where W_rg ∈ R^p, W_rc ∈ R^p, and b_rc ∈ R are learnable parameters. Here we rescale r^c_t into the range of [0, 1] using the sigmoid function σ(·) in Eq. (3). Then, we assign an attention energy e^c_t to each channel c based on its fusional rate r^c_t, as follows:

e^c_t = W_ec^T (r^c_t ⊙ d^c_t + (1 − r^c_t) ⊙ s_t) + b_ec,  (4)

where W_ec ∈ R^p and b_ec ∈ R are learnable parameters, and ⊙ denotes the element-wise multiplication operator. Given the attention energies, a channel-wise contribution score vector α_t ∈ R^C can be normalized using the softmax function, as follows:

α_t = softmax([e^1_t, · · · , e^C_t]).  (5)

Each element α^c_t in the vector measures the importance of the information carried by the c-th channel.
Accordingly, we use weighted aggregation to calculate the output vector of the channel-wise attention x̃_t ∈ R^2p based on the contribution score vector α_t:

x̃_t = (Σ_{c=1}^{C} α^c_t d^c_t) ⊕ s_t,  (6)

where ⊕ is the concatenation operator. In this way, our model can fully incorporate the multi-view information carried by both feature views, and thus fuse more informative features from multivariate PSG records.
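A compact NumPy sketch of the channel-wise attention steps described above (fusional rates, attention energies, softmax scores, weighted aggregation). The gated blend of the two views inside the energy term is our reading of the garbled source equations, and all parameter values are random stand-ins:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def channel_attention(D, s, W_rg, W_rc, b_rc, W_ec, b_ec):
    """Channel-wise attention over C channel-view features D[c] (each in R^p)
    and one global-view feature s (in R^p). The gated energy form is an
    assumption; it blends each channel view with the global view by its rate."""
    r = sigmoid(D @ W_rc + s @ W_rg + b_rc)          # fusional rate per channel
    fused = r[:, None] * D + (1.0 - r)[:, None] * s  # blend channel/global views
    e = fused @ W_ec + b_ec                          # attention energies
    alpha = softmax(e)                               # contribution scores, sum to 1
    x_tilde = np.concatenate([alpha @ D, s])         # weighted aggregate ⊕ global view
    return alpha, x_tilde

p, C = 4, 3
rng = np.random.default_rng(0)
D, s = rng.standard_normal((C, p)), rng.standard_normal(p)
alpha, x_tilde = channel_attention(D, s, rng.standard_normal(p),
                                   rng.standard_normal(p), 0.0,
                                   rng.standard_normal(p), 0.0)
```

The output `x_tilde` has dimension 2p, matching the concatenation of the aggregated channel views with the global view.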

Time-wise attention
To capture the dynamic correlations across timestamps, the aforementioned attention strategy can be employed as well, namely time-wise attention. Given the learned vector sequence from x̃_1 to x̃_T, we derive the hidden state h_t ∈ R^2q through a 2-layer BGRU [31], as follows:

h_t = BGRU(x̃_t; θ_r),  (7)

where θ_r is the learnable parameter set of the BGRU. Here the hidden state h_t at timestamp t is obtained by concatenating the forward and backward hidden vectors (each in R^q) of the BGRU. Subsequently, we can reformulate the attention strategy of Eqs. (3) to (5) to compute the time-wise contribution score vector β_t ∈ R^T:

r_{t,i} = σ(W_rt^T h_t + W_ri^T h_i + b_rt),
e_{t,i} = W_et^T (r_{t,i} ⊙ h_i + (1 − r_{t,i}) ⊙ h_t) + b_et,
β_t = softmax([e_{t,1}, · · · , e_{t,i}, · · · , e_{t,T}]),

where W_rt ∈ R^2q, W_ri ∈ R^2q, b_rt ∈ R, W_et ∈ R^2q, and b_et ∈ R are the learnable parameters. Finally, a temporal context vector c_t ∈ R^2q can be derived as the output of the time-wise attention:

c_t = Σ_{i=1}^{T} β_{t,i} h_i.  (8)
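The time-wise layer reuses the channel-wise strategy across timestamps. A NumPy sketch for a single query timestamp t, with the same caveats as before (the gated energy form is our reconstruction of the garbled equations, and all weights are random stand-ins):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def time_attention(H, t, W_rt, W_ri, b_rt, W_et, b_et):
    """Time-wise attention for timestamp t over hidden states H (T x 2q).

    Scores every timestamp i against the current state h_t, then returns the
    normalized scores beta_t and the temporal context vector c_t = sum beta*h."""
    r = sigmoid(H @ W_ri + H[t] @ W_rt + b_rt)       # one rate per timestamp i
    fused = r[:, None] * H + (1.0 - r)[:, None] * H[t]
    e = fused @ W_et + b_et                          # energies e_{t,i}
    beta = softmax(e)                                # scores over all T timestamps
    return beta, beta @ H                            # context vector c_t in R^{2q}

T, two_q = 6, 4
rng = np.random.default_rng(0)
H = rng.standard_normal((T, two_q))
beta, c_t = time_attention(H, t=2,
                           W_rt=rng.standard_normal(two_q),
                           W_ri=rng.standard_normal(two_q), b_rt=0.0,
                           W_et=rng.standard_normal(two_q), b_et=0.0)
```

In the full model `H` would come from the 2-layer BGRU over the channel-attention outputs, so each row already encodes bidirectional temporal context.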

Unified neural classifier
With the help of our hybrid attention mechanism, we can obtain an attentional representation ĥ_t ∈ R^r by fusing the context vector c_t and the current hidden state h_t, and feed it to a softmax layer:

ĥ_t = tanh(W_f (c_t ⊕ h_t) + b_f),
ŷ_t = softmax(W_s ĥ_t + b_s),  (9)

where W_f ∈ R^{r×4q}, b_f ∈ R^r, W_s ∈ R^{|C|×r}, and b_s ∈ R^|C| are the learnable parameters. To train HybridAtt in an end-to-end manner, we employ cross-entropy to measure the classification loss between the ŷ_t obtained by Eq. (9) and the ground truth y_t. The cost function of our unified HybridAtt model is defined as:

J_HybridAtt = − (1/M) Σ_{m=1}^{M} (1/T^(m)) Σ_{t=1}^{T^(m)} y_t^T log(ŷ_t).  (10)
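A NumPy sketch of this classifier head for a single timestamp. The tanh projection used to fuse c_t and h_t is an assumed form (the source only lists the softmax-layer parameters), and the dimensions here are toy values:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def classify_and_loss(c_t, h_t, W_f, b_f, W_s, b_s, y_t):
    """Fuse the context vector and hidden state into an attentional
    representation, predict class probabilities, and return the
    cross-entropy term this timestamp contributes to the cost function."""
    h_hat = np.tanh(W_f @ np.concatenate([c_t, h_t]) + b_f)  # assumed fusion
    y_hat = softmax(W_s @ h_hat + b_s)                       # class probabilities
    return y_hat, -np.dot(y_t, np.log(y_hat))                # one term of J

two_q, r_dim, n_cls = 4, 5, 5    # toy sizes; the paper uses q = r = 128, 5 classes
rng = np.random.default_rng(0)
y_hat, loss = classify_and_loss(rng.standard_normal(two_q), rng.standard_normal(two_q),
                                rng.standard_normal((r_dim, 2 * two_q)), np.zeros(r_dim),
                                rng.standard_normal((n_cls, r_dim)), np.zeros(n_cls),
                                np.eye(n_cls)[2])            # one-hot ground truth
```

Averaging such terms over all timestamps and records gives the full cost function, which gradient descent then minimizes end to end.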

Results and discussion
In this section, we evaluate HybridAtt on a benchmark PSG dataset in two feature domains, referred to as the time and frequency domains. We first introduce the dataset, then describe the baselines and some experiment details. We finally present and discuss the quantitative results in terms of different evaluation metrics.

Dataset description
We evaluate HybridAtt on the UCD dataset, a benchmark collection of overnight multivariate PSG records. Each 30-s epoch is labeled with one of five sleep stage categories, and we extract features in two domains: the time domain (raw signal segments) and the frequency domain (spectral features obtained via the short-time Fourier transform, STFT).

Baselines
We compare HybridAtt with the following ten existing biosignal feature learning baselines: SVM. SVM is a classic machine learning method. Here we use a one-vs-all SVM for the five-class classification task. To avoid the curse of dimensionality, we utilize principal component analysis (PCA) to select the top-r components from all the PSG channels as features before training the SVM; we refer to this pipeline as PSVM.
Deep neural networks (DNN). DNN is a basic multi-layer neural network. We train a 3-layer DNN with softmax by concatenating all the PSG channels as input.
RNN. RNN is designed for time series. Similar to DNN, we concatenate the data and train the same BGRU structure as HybridAtt uses, with a softmax layer.
RNNAtt. RNNAtt is an RNN variant with an attention mechanism. We add two existing attention strategies, called location-based and concatenation-based attention [33], after the BGRU structure, referred to as RNNAtt_l and RNNAtt_c, respectively.
CNN. CNN is a commonly used deep learning model for biosignals. We integrate the PSG data as a matrix, and train the same CNN structure as in our multi-view convolutional representation module.
CRNN. CRNN is a CNN variant combined with RNN. Here we directly integrate the aforementioned CNN and BGRU to train a unified model.
CRNNAtt. CRNNAtt utilizes an attention mechanism after the CRNN structure. Similarly, we apply the same two attention strategies as RNNAtt, namely CRNNAtt_l and CRNNAtt_c, respectively.
ChannelAtt. ChannelAtt [34] is proposed to soft-select critical channels from multivariate biosignals using a global attention mechanism. Different from the original model, which uses a fully-connected layer for feature extraction, we use the proposed CNN structure as the feature encoder to train the model.

Our approaches
To fairly evaluate our proposed attention strategy, we show the performance of the following two approaches in the experiments.
HybridAtt_l. HybridAtt_l is a reduced model using the location-based attention mechanism in HybridAtt for sleep stage classification.
HybridAtt_f. HybridAtt_f uses the proposed attention strategy to calculate the score vectors in both the channel-wise and time-wise layers.

Evaluation criteria
To quantify the performance, five evaluation measurements are used to validate HybridAtt for PSG-based sleep stage classification. Both accuracy and F1-score are adopted for evaluation. Here we employ the Macro and Micro metrics to measure F1-score, namely Macro-F1 and Micro-F1, respectively. The Macro-based area-under-the-curve (AUC) of the precision-recall (PR) and receiver operating characteristic (ROC) curves are also utilized to evaluate each approach, namely AUC-PR and AUC-ROC, respectively. Moreover, to evaluate our model as a general cross-subject classifier, we perform 5-fold subject-independent cross validation and report the average test performance with standard deviation (μ ± σ) for each method. The ratio of the training, validation, and test sets is 0.7 : 0.1 : 0.2.
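The key property of the subject-independent protocol above is that folds partition subjects, not epochs, so no subject contributes data to both training and testing. A minimal sketch (the subject IDs and fold count are illustrative; the actual dataset split is not specified here):

```python
import numpy as np

def subject_folds(subject_ids, n_folds=5, seed=0):
    """Partition subjects (not epochs) into disjoint folds, so that no
    subject appears in both the training and test data of any fold."""
    rng = np.random.default_rng(seed)
    ids = np.array(sorted(set(subject_ids)))
    rng.shuffle(ids)
    # Round-robin assignment keeps fold sizes as balanced as possible
    return [ids[i::n_folds].tolist() for i in range(n_folds)]

# Hypothetical example: 25 subjects split into 5 disjoint folds of 5 subjects
folds = subject_folds([f"subj{i:02d}" for i in range(25)])
```

For each of the 5 runs, one fold's subjects form the test set while the remaining subjects are split into training and validation, matching the 0.7 : 0.1 : 0.2 ratio stated above.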

Implementation details
We implement all the approaches in PyTorch. Training is done locally on an NVIDIA Titan Xp GPU. Adadelta [35] is adopted to optimize the cost function with respect to the learnable parameters. We also use weight decay with an L2 penalty coefficient of 0.001, a momentum (decay) coefficient of 0.95, and a dropout rate of 0.5 for all the approaches. The structure configuration of our multi-view convolutional representation module is listed in Table 1, and we set p = 128, q = 128, and r = 128 for our models and the baselines.
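A minimal PyTorch sketch of this training configuration. The model is a stand-in Linear layer, and mapping the stated "0.95 momentum" to Adadelta's rho decay parameter is our assumption, since PyTorch's Adadelta takes no momentum argument:

```python
import torch

# Stand-in model; the real HybridAtt network would replace this Linear layer.
model = torch.nn.Linear(8, 5)
optimizer = torch.optim.Adadelta(model.parameters(),
                                 rho=0.95,          # assumed mapping of "0.95 momentum"
                                 weight_decay=0.001)  # L2 penalty coefficient
dropout = torch.nn.Dropout(p=0.5)

# One illustrative optimization step on random data
x, y = torch.randn(4, 8), torch.randint(0, 5, (4,))
loss = torch.nn.functional.cross_entropy(model(dropout(x)), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Dropout is applied during training only; `model.eval()` would disable it at test time.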

Experimental results
We investigate the effectiveness of our proposed HybridAtt model, compared to the aforementioned baseline methods, in the task of sleep stage classification. Tables 2 and 3 report the comparison results in the frequency and time domains, respectively. We highlight the best evaluation scores in boldface. We observe that HybridAtt achieves the best performance compared with the corresponding baselines in both feature domains on the UCD dataset. Among the baselines, the traditional classification method PSVM performs better than DNN and the RNN-based models in the frequency domain, but worse in the time domain. This suggests that the raw frequency features of PSG data carry distinctive information that helps the SVM learn a relatively clear hyperplane to separate sleep stages. The results of DNN across the two feature domains support the same observation, demonstrating the discriminative capability of the PSG spectral features. The limited improvement of the attention-based RNN models over plain RNN in both domains shows that the features learned by the RNN do not provide enough information for the attention mechanisms to make correct classifications. This also indicates that simply concatenating PSG data is unsuitable for fully-connected networks to learn informative features, since it ignores multivariate prior information. We can see that the CNN-based models achieve better performance than the other baselines, benefiting from the proposed structure in the multi-view convolutional representation module. Compared with CRNN, the attention-based CRNN models perform better, because the attention mechanism is able to fuse features based on the more useful information carried by the sequential representations. By fusing features from multi-channel representations with an attention mechanism, ChannelAtt works well in classifying sleep stages and achieves better results than CRNNAtt.
This illustrates that the hidden connections among PSG channels, captured by ChannelAtt, are more helpful for sleep stage classification. Furthermore, by capturing the dual correlations among channels and timestamps, our proposed HybridAtt model consistently attains the best evaluation scores in both the time and frequency feature domains.
From the results of our models, HybridAtt_f outperforms the baselines in terms of all five evaluation measurements. For example, HybridAtt_f obtains the best accuracy of 0.7424 in the time domain, compared with 0.7317 and 0.7169 achieved by our reduced model HybridAtt_l and the baseline model ChannelAtt, respectively. Comparing the results between the two domains, on one hand, we observe that HybridAtt_l performs on par with ChannelAtt in the frequency domain, which suggests that adopting the traditional location-based attention in the channel-wise layer cannot capture enough information from the multi-view features. On the other hand, we conjecture that CNN has a convolution procedure similar to STFT, but CNN adopts learnable kernels during convolution while STFT employs fixed Fourier basis functions. Taking advantage of end-to-end learning, our hybrid attention mechanism can help learn more representative convolutional kernels in the CNN than the handcrafted window functions in STFT. Figure 3 illustrates the ROC and PR curves of all the test folds on the UCD dataset. We observe that the proposed HybridAtt_f method consistently gains the best AUC in terms of both PR and ROC in the different domains, demonstrating an effective cross-subject method for the task of sleep stage classification. Based on the overall performance comparisons, we conclude that the attention mechanism is key to identifying sleep patterns for sleep stage classification. Adopting single-dimension attention in only one aspect, as in CRNNAtt and ChannelAtt, may lose useful information when dealing with multivariate PSG records. Multi-view representation is also essential for the attention mechanism to infer important information. By constructing hybrid attention over both channels and timestamps, HybridAtt fuses multivariate PSG information more comprehensively and thus achieves the best overall performance.

Conclusions
In this paper, we present a unified hybrid self-attention deep learning framework, namely HybridAtt, to classify sleep stages from multivariate PSG records. HybridAtt is designed to capture dual correlations among channels and timestamps based on multi-view convolutional feature representations. Experiments on a benchmark PSG dataset show that HybridAtt is able to efficiently fuse multivariate information from PSG data and hence consistently beats the baselines in both the time and frequency feature domains. In future work, we will extend HybridAtt to other biomedical applications with similar data structures, and propose an advanced attention mechanism that can jointly learn two-dimensional contribution scores in one step, instead of adopting the current multi-step attention strategy.