In this section, we introduce the construction of functional brain networks and the GAT-LI method, including the GAT2 model and the interpretation method, and then verify the proposed method through classification and interpretation experiments.
Construction of functional brain networks
The process of functional brain network construction from resting-state fMRI data is shown in Fig. 1.
Node of network The whole brain is parcellated into N ROIs using a brain atlas, so each network has N nodes. We use the Harvard Oxford (HO) atlas [23], giving N = 110 nodes.
Edge and connectivity matrix The mean time series of each ROI is extracted, and the resting-state functional connectivity (rsFC) between ROIs is measured by computing the Pearson correlation coefficient of the extracted time series. An \(\mathrm{N}\times \mathrm{N}\) connectivity matrix is constructed for each subject, which can be represented as
$$S=\left[\begin{array}{cccc}{\rho }_{{r}_{1},{r}_{1}}& {\rho }_{{r}_{1},{r}_{2}}& \cdots & {\rho }_{{r}_{1},{r}_{N}}\\ {\rho }_{{r}_{2},{r}_{1}}& {\rho }_{{r}_{2},{r}_{2}}& \cdots & {\rho }_{{r}_{2},{r}_{N}}\\ \vdots & \vdots & \ddots & \vdots \\ {\rho }_{{r}_{N},{r}_{1}}& {\rho }_{{r}_{N},{r}_{2}}& \cdots & {\rho }_{{r}_{N},{r}_{N}}\end{array}\right],$$
(1)
where \({r}_{i}\) represents the \(i\)th ROI.
Edge weight For the connected edge between two nodes, the edge weight is the absolute value of the Pearson correlation coefficient between the time series of the two nodes. That is, for node \({r}_{i}\) and node \({r}_{j}\), the edge weight between the two nodes is \(\left|{\rho }_{{r}_{i},{r}_{j}}\right|\).
Node feature The node feature (or node attribute) of each node (ROI) is represented by its functional connectivity profile with the rest of the regions [8], i.e., the corresponding row of the connectivity matrix:
$${{\varvec{h}}}_{i}=\left\{{\rho }_{{r}_{i},{r}_{1}},{\rho }_{{r}_{i},{r}_{2}},\dots ,{\rho }_{{r}_{i},{r}_{N}}\right\}.$$
(2)
With N = 110 nodes, a \(110\times 110\) connectivity matrix is constructed for each subject, and the dimension of each node feature is 110.
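For concreteness, the following is a minimal NumPy sketch of this construction; it assumes the ROI mean time series have already been extracted into a T × 110 array, and the variable names are illustrative rather than taken from our implementation.

```python
import numpy as np

def build_brain_network(roi_timeseries):
    """Build a functional brain network from ROI mean time series.

    roi_timeseries: array of shape (T, N) holding the mean time series of
    each of the N ROIs (N = 110 for the HO atlas). Returns the connectivity
    matrix S of Eq. (1), whose rows are the node features of Eq. (2), and
    the edge-weight matrix |S|.
    """
    S = np.corrcoef(roi_timeseries, rowvar=False)   # Pearson correlations, (N, N)
    node_features = S.copy()                        # node feature of ROI i = row i of S
    edge_weights = np.abs(S)                        # edge weight = |rho_{ri,rj}|
    return node_features, edge_weights

# Illustrative usage with random data (T = 200 time points, N = 110 ROIs)
ts = np.random.randn(200, 110)
X, W = build_brain_network(ts)
print(X.shape, W.shape)   # (110, 110) (110, 110)
```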
GAT2 model
The architecture of the GAT2 model is illustrated in Fig. 2. The model is composed of two parts: the node representation learning part and the pooling-and-prediction part. First, the node representation learning part learns the feature representations of the nodes with the graph attention network. Then, the pooling-and-prediction part learns the graph representation from the node representations and outputs the prediction probability.
Node representation learning The input to the layer is a set of node features, \(\mathbf{h}=\left\{{{\varvec{h}}}_{1},{{\varvec{h}}}_{2},\dots ,{{\varvec{h}}}_{N}\right\}\), \({{\varvec{h}}}_{i}\in {\mathbb{R}}^{F}\), where N is the number of nodes and F is the dimension of the node features. The graph attention layer [17] uses a self-attention mechanism to aggregate a node's 1-hop neighbors when computing the node representation. The attention coefficients are computed as follows:
$${a}_{ij}=softmax\left(LeakyReLU\left({{\varvec{a}}}^{{\varvec{T}}}\left[\mathrm{W}{{\varvec{h}}}_{i}\Vert \mathrm{W}{{\varvec{h}}}_{j}\right]\right)\right),$$
(3)
where \({\varvec{a}}\in {\mathbb{R}}^{2{F}^{{\prime}}}\) is the weight vector that parameterizes the self-attention mechanism. Masked attention is used to introduce the network structure information: attention is assigned only to the neighbor node set \({N}_{i}\) of node \(i\). The node representation generated by multi-head attention is computed as follows:
$${{\varvec{h}}}_{i}^{{{\prime}}}={\parallel }_{k=1}^{K}\sigma \left(\sum\limits_{j\epsilon {N}_{i}}{a}_{ij}^{k}{\mathrm{W}}^{k}{{\varvec{h}}}_{j}\right),$$
(4)
$${{\varvec{h}}}_{i}^{{\prime}}=\sigma \left(\frac{1}{K}\sum\limits_{k=1}^{K}\sum\limits_{j\in {N}_{i}}{a}_{ij}^{k}{\text{W}}^{k}{{\varvec{h}}}_{j}\right),$$
(5)
where Eq. (4) uses \(\parallel\) to denote the concatenation operation, joining the feature representations obtained by each attention head; Eq. (5) is used to obtain the node representation of the last layer by averaging the features of the multiple attention heads; and \(\sigma (x)=\frac{1}{1+{e}^{-x}}\) is the sigmoid activation.
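As an illustration of Eqs. (3)–(5), the following is a minimal NumPy sketch of one multi-head graph attention layer; the weight matrices, attention vectors, and the random adjacency used here are placeholders, not the trained parameters of GAT2.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gat_layer(H, adj, W_heads, a_heads, average=False):
    """One multi-head graph attention layer (Eqs. 3-5).

    H: (N, F) node features; adj: (N, N) boolean adjacency used for masked
    attention; W_heads: K weight matrices of shape (F, F'); a_heads: K
    attention vectors of length 2*F'. If average is True the heads are
    averaged (Eq. 5), otherwise they are concatenated (Eq. 4).
    """
    head_outputs = []
    for W, a in zip(W_heads, a_heads):
        Wh = H @ W                                        # (N, F')
        Fp = Wh.shape[1]
        # e_ij = LeakyReLU(a^T [W h_i || W h_j]); split a into source/target parts
        e = leaky_relu((Wh @ a[:Fp])[:, None] + (Wh @ a[Fp:])[None, :])
        e = np.where(adj, e, -1e9)                        # masked attention over N_i
        alpha = np.exp(e - e.max(axis=1, keepdims=True))
        alpha = alpha / alpha.sum(axis=1, keepdims=True)  # softmax over neighbors j (Eq. 3)
        head_outputs.append(alpha @ Wh)                   # sum_j alpha_ij W h_j
    if average:                                           # last layer, Eq. (5)
        return sigmoid(np.mean(head_outputs, axis=0))
    return np.concatenate([sigmoid(h) for h in head_outputs], axis=1)   # Eq. (4)

# Illustrative shapes: N = 110 nodes, F = 110 features, K = 5 heads of F' = 24
rng = np.random.default_rng(0)
H = rng.standard_normal((110, 110))
adj = rng.random((110, 110)) > 0.5
W_heads = [0.1 * rng.standard_normal((110, 24)) for _ in range(5)]
a_heads = [0.1 * rng.standard_normal(48) for _ in range(5)]
print(gat_layer(H, adj, W_heads, a_heads).shape)          # (110, 120)
```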
Graph attention pooling To summarize the graph representation from the node representations, we apply a weight vector shared across all nodes, and the new one-dimensional representation \({P}_{i}\) of each node is obtained through this mapping, as shown in Eq. (6). Finally, we obtain the graph representation \(\mathbf{P}=\left\{{P}_{1},{P}_{2},\dots ,{P}_{N}\right\}\), whose dimension equals the number of nodes.
$${P}_{i}=\sigma \left({\mathrm{W}}^{p}{{\varvec{h}}}_{i}^{{\prime}}\right),$$
(6)
where \({\mathrm{W}}^{p}\in {\mathbb{R}}^{1\times {F}^{{\prime}}}\).
Prediction To account for the contribution of each node to the final prediction result, each node representation is assigned a weight, computed as shown in Eq. (7):
$$\mathbf{A}=softmax\left({\mathrm{W}}^{A}\mathbf{P}\right),$$
(7)
where \({\mathrm{W}}^{A}\in {\mathbb{R}}^{N\times N}\) and \(\mathbf{P}=\left\{{P}_{1},{P}_{2},\dots ,{P}_{N}\right\}\). Then, using the contribution weights, the weighted sum of the node representations is used as the prediction of the model, as shown in Eq. (8):
$$\mathrm{prob}=\sum\limits_{i=1}^{N}{\mathbf{A}}_{i}{P}_{i}.$$
(8)
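The pooling-and-prediction part of Eqs. (6)–(8) can be sketched as follows; the weights \(W^{p}\) and \(W^{A}\) are shown as random placeholders rather than learned parameters.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def pool_and_predict(H_prime, W_p, W_A):
    """Graph attention pooling and prediction (Eqs. 6-8).

    H_prime: (N, F') node representations from the last GAT layer;
    W_p: (F',) shared pooling weight vector (Eq. 6);
    W_A: (N, N) contribution-weight matrix (Eq. 7).
    """
    P = sigmoid(H_prime @ W_p)     # (N,) graph representation, Eq. (6)
    A = softmax(W_A @ P)           # (N,) node contribution weights, Eq. (7)
    prob = float(A @ P)            # weighted sum used as the prediction, Eq. (8)
    return prob, P, A

# Illustrative usage: N = 110 nodes, F' = 3 units in the last GAT layer
rng = np.random.default_rng(1)
prob, P, A = pool_and_predict(rng.standard_normal((110, 3)),
                              rng.standard_normal(3),
                              rng.standard_normal((110, 110)))
print(prob)
```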
Interpretation methods
We use GNNExplainer [19] to interpret the trained GAT2 model and identify the important features in the GAT2 model. GNNExplainer learns a feature mask that masks out unimportant node features; that is, if the value of an element in the feature mask matrix is close to zero, the corresponding feature is considered unimportant. The feature mask matrix is 110 × 110 in this study.
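The following PyTorch sketch illustrates the feature-mask idea (a soft mask optimized to preserve the model's prediction while staying sparse); it is a simplified stand-in for the GNNExplainer implementation, and the stand-in model, step count, and sparsity coefficient are illustrative assumptions.

```python
import torch

def learn_feature_mask(model, x, adj, target, steps=200, lr=0.01):
    """Learn a soft 110 x 110 feature mask in the spirit of GNNExplainer.

    model: a trained graph classifier mapping (node_features, edge_weights)
    to the probability of the positive class; x: (110, 110) node features;
    adj: (110, 110) edge weights; target: 1.0 for ASD, 0.0 for HC.
    Mask elements that stay close to zero mark unimportant features.
    """
    mask_logits = torch.zeros_like(x, requires_grad=True)
    opt = torch.optim.Adam([mask_logits], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        mask = torch.sigmoid(mask_logits)          # soft mask in (0, 1)
        prob = model(x * mask, adj)
        # keep the model's prediction while encouraging a sparse mask
        pred_loss = torch.nn.functional.binary_cross_entropy(prob, torch.tensor(target))
        loss = pred_loss + 0.1 * mask.mean()
        loss.backward()
        opt.step()
    return torch.sigmoid(mask_logits).detach()

# Illustrative usage with a stand-in model (replace with the trained GAT2)
toy_model = lambda feats, w: torch.sigmoid((feats * w).mean())
feature_mask = learn_feature_mask(toy_model, torch.rand(110, 110), torch.rand(110, 110), 1.0)
print(feature_mask.shape)   # torch.Size([110, 110])
```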
Experiments
Dataset and preprocessing
We used the resting-state fMRI data of 1035 subjects from the ABIDE I initiative [20] for this study. The dataset includes 505 individuals diagnosed with ASD and 530 HC. The preprocessed resting-state fMRI data were downloaded from the Preprocessed Connectomes Project (http://preprocessed-connectomes-project.org/abide/download.html). The data were preprocessed with the Configurable Pipeline for the Analysis of Connectomes (CPAC) [24], which included the following steps: slice timing correction, motion realignment, intensity normalization, regression of nuisance signals, band-pass filtering (0.01–0.1 Hz), and registration of the fMRI images to standard anatomical space (MNI152).
Experimental setup
Given the above GAT2 model, we conducted experiments on the ABIDE I dataset with 1035 subjects and applied the interpretation method to explain the results.
To evaluate the performance of the proposed model, we used sensitivity, specificity, accuracy, F1 score, AUC, and Matthews correlation coefficient (MCC) as our metrics. These metrics are defined as follows:
$$sensitivity = \frac{TP}{{TP + FN}}$$
(9)
$$specificity = \frac{TN}{{TN + FP}}$$
(10)
$$accuracy = \frac{TP + TN}{{TP + FN + TN + FP}}$$
(11)
$$F1 = \frac{2 \times TP}{{2 \times TP + FP + FN}}$$
(12)
$$MCC = \frac{TP \times TN - FP \times FN}{{\sqrt {(TP + FP) \times (TP + FN) \times (TN + FP) \times (TN + FN)} }}$$
(13)
where true positive (TP) is defined as the number of ASD subjects that are correctly classified, false positive (FP) is the number of HC subjects that are misclassified as ASD, true negative (TN) is the number of HC subjects that are correctly classified, and false negative (FN) is the number of ASD subjects that are misclassified as HC. Sensitivity measures the proportion of correctly identified ASD subjects among all real ASD subjects. Specificity measures the proportion of correctly identified HC subjects among all real HC subjects. AUC is defined as the area under the receiver operating characteristic curve.
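A small helper that evaluates Eqs. (9)–(13) from the confusion counts (the counts in the usage line are illustrative only):

```python
import math

def classification_metrics(tp, fp, tn, fn):
    """Evaluate Eqs. (9)-(13) from the confusion counts."""
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    accuracy = (tp + tn) / (tp + fn + tn + fp)
    f1 = 2 * tp / (2 * tp + fp + fn)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"sensitivity": sensitivity, "specificity": specificity,
            "accuracy": accuracy, "F1": f1, "MCC": mcc}

# Illustrative counts only
print(classification_metrics(tp=80, fp=25, tn=85, fn=20))
```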
Classification comparison models and parameters
The comparison models include (i) traditional machine learning methods: SVM, PCA + SVM, and RF; (ii) non-graph deep learning models: MLP and CNN; (iii) GCN layer based GNN models: GCN-at (1st-order) and GCN-at (Cheby); and (iv) GAT layer based GNN models: GAT2, GAT-average, GAT-fc, and GAT-learn. The comparison models and their corresponding parameters are described as follows.
SVM Support vector machine (SVM) model with a linear kernel. The SVM is an accepted benchmark method and has been widely used to classify fMRI data for brain disorders. The parameter C of the SVM model is set to 1.0.
PCA + SVM Principal component analysis (PCA) is first used to reduce the dimension of the feature vector, and the reduced vector is then input into the SVM model for training and classification. PCA is set to retain 99% of the feature information, which reduces the dimension to 700, and the dimensionality-reduced vector is fed into an SVM with a linear kernel. The coefficient C is set to 1.0.
RF Random forest (RF) is an ensemble learning method for classification. We trained an RF with 300 trees, and the maximum depth of each tree is set to 30.
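A minimal scikit-learn sketch of these three traditional baselines is given below; the placeholder data and the train-on-all evaluation are for illustration only, and the exact implementation used in the study may differ.

```python
import numpy as np
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier

# Placeholder data: flattened upper-triangle connectivity features and labels
X = np.random.rand(100, 5995)
y = np.random.randint(0, 2, size=100)

svm = SVC(kernel="linear", C=1.0)                                     # SVM baseline
pca_svm = make_pipeline(PCA(n_components=0.99, svd_solver="full"),    # keep 99% of the information
                        SVC(kernel="linear", C=1.0))                  # PCA + SVM baseline
rf = RandomForestClassifier(n_estimators=300, max_depth=30)           # RF baseline

for name, clf in [("SVM", svm), ("PCA+SVM", pca_svm), ("RF", rf)]:
    clf.fit(X, y)
    print(name, clf.score(X, y))
```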
MLP The multilayer perceptron (MLP) model has two fully connected hidden layers with the LeakyReLU activation function, with 64 and 32 units, respectively. A dropout layer with a dropout rate of 0.5 is added to avoid overfitting. The output layer has one neuron followed by a sigmoid activation function. The model is trained with the Adam optimizer, the learning rate is set to 0.0005, and the cross-entropy loss function is used.
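As an illustration, a PyTorch sketch of this MLP architecture follows; the framework is an assumption (the study does not specify it here), and the 5995-dimensional input refers to the flattened connectivity vector described later in this section.

```python
import torch
import torch.nn as nn

# Two hidden layers (64 and 32 units) with LeakyReLU, dropout 0.5, sigmoid output
mlp = nn.Sequential(
    nn.Linear(5995, 64), nn.LeakyReLU(),
    nn.Linear(64, 32), nn.LeakyReLU(),
    nn.Dropout(0.5),
    nn.Linear(32, 1), nn.Sigmoid(),
)
optimizer = torch.optim.Adam(mlp.parameters(), lr=5e-4)   # learning rate 0.0005
loss_fn = nn.BCELoss()                                    # cross-entropy for binary labels
```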
CNN The convolutional neural network (CNN) model contains three convolutional layers and two fully connected layers. The numbers of convolutional kernels are 32, 64, and 128, respectively, the size of all kernels is 3 × 3, and the ReLU activation function is used. The two fully connected layers have 1024 and 2 neurons, respectively, also with the ReLU activation function.
GCN-at (1st-order), GCN-at (Cheby) In order to verify the effectiveness of the GAT layer for node representation learning in the GAT2 model, we designed the GCN-at (1st-order) and GCN-at (Cheby) models to classify the functional brain networks. In these two models, the GCN layer is trained to obtain the node representations, which are then input into the same pooling-and-prediction part as GAT2 for prediction.
According to the implementation of the GCN layer proposed in [25], for the GCN-at (1st-order) model, the node representation is obtained from the GCN layer via a first-order approximation of localized spectral filters on graphs; for the GCN-at (Cheby) model, the node representation is obtained from the GCN layer via Chebyshev polynomial filters with the polynomial order set to 3. Each model contains one GCN layer with 24 units, the LeakyReLU activation function is used, and the cross-entropy loss function is used.
GAT-fc, GAT-average, GAT-learn In order to verify the validity of the prediction part in the GAT2 model, we designed GAT-fc, GAT-learn, and GAT-average models to classify the functional brain networks.
In the GAT-fc model, after the node representation vectors are obtained through the GAT layers, they are concatenated into a one-dimensional vector, which is input into a fully connected layer for prediction.
The GAT-fc model contains two GAT layers; the numbers of attention heads are set to 5 and 3 and the numbers of units to 24 and 3, respectively, and the fully connected layer has 64 units. The LeakyReLU activation function is used, the output layer is followed by a softmax activation function, and the cross-entropy loss function is used.
In the GAT-average model, after the node representation vectors are obtained from the GAT layers, the node representation \({P}_{i}\) is mapped through the sigmoid function. Based on the average-pooling method in GCN [11], the final prediction probability of the GAT-average model is obtained by averaging the information of all nodes, as shown in Eq. (14):
$$\mathrm{prob}=\frac{{\sum }_{i=1}^{N}{P}_{i}}{N}.$$
(14)
The GAT-average model contains two GAT layers; the numbers of attention heads are set to 5 and 3 and the numbers of units to 24 and 3, respectively, with the LeakyReLU activation function. The cross-entropy loss function is used.
In the GAT-learn model, we use the learnable pooling method in [15] for GAT. The GAT-learn model comprises two GAT layers, one cascaded convolution-pooling block, and one fully connected layer. The block generates an \(N\times 11\) feature map (\({Y}^{(l)}\)) and an \(N\times 1\) cluster assignment matrix (\(S\), used as its transpose \({S}^{T}\) in the pooling) in two separate paths, and combines them through the pooling formulation of Eq. (15) to obtain a \(1\times 11\) pooled feature map (\(Y^{pool}\)).
$$Y^{pool} = S^{T} Y^{(l)}$$
(15)
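A minimal NumPy illustration of the pooling formulation in Eq. (15), with random placeholder values:

```python
import numpy as np

N = 110
Y_l = np.random.rand(N, 11)      # N x 11 feature map Y^(l) from the feature path
S = np.random.rand(N, 1)         # N x 1 cluster assignment from the assignment path
Y_pool = S.T @ Y_l               # Eq. (15): 1 x 11 pooled feature map
print(Y_pool.shape)              # (1, 11)
```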
GAT2 The model contains two GAT layers, the number of attention heads is set to 5 and 3, and the number of neurons is set to 24 and 3, respectively, with the LeakyReLU activation function. The node representation \({P}_{i}\) is obtained through the sigmoid function, and the weighted sum of the node information is then used for the prediction of the model. We also compared different numbers of GAT layers and different numbers of attention heads to evaluate these hyper-parameter settings.
For the inputs fed into the non-graph learning models, including SVM, PCA + SVM, RF, and MLP, the upper-triangle values of the connectivity matrices are extracted and flattened into vectors, so the dimension of the feature vector is \(\left(110\times \left(110-1\right)\right)/2=5995\). The whole connectivity matrices are used as inputs for the CNN model.
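The flattening of the upper triangle can be sketched as follows (NumPy, with a random matrix standing in for a subject's connectivity matrix):

```python
import numpy as np

def flatten_upper_triangle(conn_matrix):
    """Return the upper-triangle entries (excluding the diagonal) as a vector."""
    iu = np.triu_indices_from(conn_matrix, k=1)
    return conn_matrix[iu]

S = np.random.rand(110, 110)     # stand-in connectivity matrix
vec = flatten_upper_triangle(S)
print(vec.shape)                 # (5995,) = 110 * (110 - 1) / 2
```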
All the above graph neural network based models use the Adam optimizer for training, and the learning rate is set to 0.0001. All the above deep learning models use an early stopping mechanism: training is stopped if the error rate on the test set does not decrease for 15 consecutive rounds.
Comparison of classification with different network construction methods
We conducted more experiments to compare the classification performance of the GAT2 model with different network construction methods.
(i) Influence of network construction via different brain atlases
We used the HO atlas [23] and the Automated Anatomical Labeling (AAL) atlas [26] to parcellate the brain, extracted functional connectivity features to construct the brain networks, and compared the classification performance of the GAT2 model.
(ii) Influence of network sparsity
Considering that even weak connections between nodes may carry some information, we used the dense network representation in the classification experiments, where the dense network is the original network without applying a threshold to eliminate weak connections.
In this study, we additionally thresholded the brain networks to obtain sparse networks and examined the influence of network sparsity. For the adjacency matrix, only the edges whose weight is greater than the threshold were retained, according to the edge weight values between nodes. The GAT2 model was used for this experimental comparison.
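A minimal sketch of this thresholding step is shown below; the threshold value of 0.3 is illustrative, not the value used in the experiments.

```python
import numpy as np

def sparsify(edge_weights, threshold):
    """Keep only the edges whose weight exceeds the threshold."""
    adj = edge_weights.copy()
    adj[adj <= threshold] = 0.0
    return adj

W = np.random.rand(110, 110)              # dense edge-weight matrix (|correlations|)
W_sparse = sparsify(W, threshold=0.3)     # illustrative threshold
print((W_sparse > 0).mean())              # fraction of edges retained
```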
Validating GAT2 in a larger dataset
We also validated the performance of the GAT2 model on a larger synthetic dataset. We constructed a graph classification dataset with 4000 graphs, where each graph had 30 nodes and the weight of each connection was randomly selected from 0 to 1. The graph dataset was divided into two categories by the following steps: (a) 15 nodes of the graph were randomly selected; (b) the sum of the connection weights among these 15 nodes was defined as W1, the sum of the connection weights between these 15 nodes and the remaining 15 nodes was defined as W2, the score of each graph was defined as \(\mathrm{W}0=\mathrm{W}1\times 2+\mathrm{W}2\), and the average value of W0 over the 4000 graphs was then calculated; and (c) if W0 was larger than the average value, the graph was assigned to Class-one, otherwise it was assigned to Class-two. We also used the corresponding row of the connectivity matrix as the node feature, similar to the construction of brain networks described in the “Construction of functional brain networks” section.
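The following NumPy sketch reproduces steps (a)–(c); details such as the symmetrization of the random weights and counting each within-group edge once are our assumptions, since the text does not fix them.

```python
import numpy as np

rng = np.random.default_rng(0)
n_graphs, n_nodes = 4000, 30
graphs, scores = [], []
for _ in range(n_graphs):
    upper = np.triu(rng.random((n_nodes, n_nodes)), k=1)
    W = upper + upper.T                                   # symmetric weights in (0, 1)
    sel = rng.choice(n_nodes, size=15, replace=False)     # step (a): pick 15 nodes
    rest = np.setdiff1d(np.arange(n_nodes), sel)
    W1 = W[np.ix_(sel, sel)].sum() / 2                    # weights among the 15 nodes
    W2 = W[np.ix_(sel, rest)].sum()                       # weights between the two groups
    scores.append(2 * W1 + W2)                            # step (b): W0 = W1 * 2 + W2
    graphs.append(W)

avg = np.mean(scores)                                     # average W0 over all graphs
labels = np.array([1 if s > avg else 2 for s in scores])  # step (c): Class-one / Class-two
node_features = [W.copy() for W in graphs]                # node feature = row of the matrix
print(labels.mean(), node_features[0].shape)
```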
We compared the classification performance of the GAT2 model against SVM, RF, and CNN under settings similar to the previous experiments on the ABIDE dataset. The specific model parameters used in this experiment are as follows: the GAT2 model contained two GAT layers, with the number of attention heads set to 4 and 4 and the number of neurons set to 16 and 16, respectively; the CNN model contained three convolutional layers and two fully connected layers, with 16, 32, and 64 convolutional kernels, respectively; and the RF had 128 trees, with the maximum depth of each tree set to 20.
Interpretation experiments
(i) Comparison methods
We also used Saliency Map [21] and DeepLIFT [22] as comparative interpretation methods. Saliency Map is a typical neural network interpretation method based on gradient sensitivity. To apply Saliency Map to the GAT2 model, we calculated the gradient of the model loss with respect to the input features and analyzed the features according to the gradient values: the larger the gradient value, the greater the impact of the corresponding feature on the classification. DeepLIFT decomposes the output prediction of a neural network on a specific input by backpropagating the contributions of all neurons in the network to each feature of the input.
We explored the impact of features on classifying the functional brain networks of ASD individuals. The feature dimension of each input sample is N × F, where N is the number of nodes and F is the node feature dimension. As described in the “GAT2 model” section, the constructed network has N = 110 nodes and F = 110 features per node. The steps for obtaining the feature gradient values are as follows: (a) for the test samples, the gradient of the model loss with respect to the input features was calculated to obtain the gradient value of each feature; (b) for each feature, the gradient was averaged across all samples and its absolute value was taken.
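A PyTorch sketch of these two steps is given below; the stand-in model and the sample count are illustrative, and the actual computation is run on the trained GAT2 model and the test samples.

```python
import torch

def saliency_map(model, features, adjs, labels):
    """Gradient-based saliency for a graph classifier (steps (a)-(b) above).

    features: (n_samples, 110, 110) node-feature matrices;
    adjs: (n_samples, 110, 110) edge weights; labels: (n_samples,) in {0, 1}.
    Returns a 110 x 110 map of |mean gradient| per feature.
    """
    grads = []
    for x, adj, y in zip(features, adjs, labels):
        x = x.clone().requires_grad_(True)
        prob = model(x, adj)
        loss = torch.nn.functional.binary_cross_entropy(prob, y)
        loss.backward()                                # step (a): gradient of the loss w.r.t. input
        grads.append(x.grad.detach())
    return torch.stack(grads).mean(dim=0).abs()        # step (b): |mean over samples|

# Illustrative usage with a stand-in model (replace with the trained GAT2)
toy_model = lambda feats, w: torch.sigmoid((feats * w).mean())
sal = saliency_map(toy_model, torch.rand(8, 110, 110), torch.rand(8, 110, 110),
                   torch.randint(0, 2, (8,)).float())
print(sal.shape)   # torch.Size([110, 110])
```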
(ii) Interpretation experiments
We applied Saliency Map, DeepLIFT, and GNNExplainer to interpret the trained GAT2 model, and estimated the impact on GAT2's classification performance through feature perturbation. We then compared the change in GAT2's predictions when modifying the same number of features, in order to compare the quality of the interpretation methods.
We hacked the model input by setting the values of selected node features of instance x to zero, and observed the change in GAT2's predictions on one fold of the data from the above fivefold cross-validation division. We used metrics including sensitivity, specificity, accuracy, the change of prediction probability (CPP), which is the absolute change in the probability of classifying \(\mathrm{x}\) as a positive instance, and the number of label-changed instances (NLCI), which is the number of instances whose predicted label changes after being hacked.
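The perturbation and the CPP/NLCI metrics can be sketched as follows (PyTorch, with an illustrative stand-in model and a random importance map in place of the masks produced by the interpretation methods):

```python
import torch

def perturb_and_evaluate(model, features, adjs, mask):
    """Zero out the node features selected by `mask` and measure the effect.

    mask: boolean 110 x 110 matrix of features to remove (True = set to zero).
    Returns the mean change of prediction probability (CPP) and the number of
    label-changed instances (NLCI).
    """
    cpp, nlci = [], 0
    for x, adj in zip(features, adjs):
        p_orig = model(x, adj)
        p_hack = model(x * (~mask).float(), adj)       # hacked input: selected features set to zero
        cpp.append((p_orig - p_hack).abs().item())
        nlci += int((p_orig > 0.5) != (p_hack > 0.5))
    return sum(cpp) / len(cpp), nlci

# Illustrative usage: remove the 100 features ranked most important by an interpretation method
toy_model = lambda feats, w: torch.sigmoid((feats * w).mean())
importance = torch.rand(110, 110)
top = importance >= importance.flatten().topk(100).values.min()
print(perturb_and_evaluate(toy_model, torch.rand(8, 110, 110), torch.rand(8, 110, 110), top))
```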