The workflow of the deep neural network lung cancer prediction model based on KL divergence gene selection proposed in this paper is shown in Fig. 1. First, we use KL divergence to select genes related to lung cancer as the input of the deep neural network prediction model. Second, we build a deep neural network that uses focal loss as the loss function and train it on the training set. Finally, we use the validation set to verify the generalization performance of the lung cancer prediction model and select the model with the best parameters.
Data collection and preprocessing
The data we used were extracted from the TCGA portal (https://tcga‐data.nci.nih.gov/tcga/) and the ICGC portal (https://dcc.icgc.org/). Both datasets consist of RNA-seq gene expression data of lung adenocarcinoma (LUAD) samples. The TCGA dataset contains 533 lung cancer samples and 59 normal samples; the ICGC dataset contains 488 lung cancer samples and 55 normal samples.
This paper uses Python to process the data into a training format that TensorFlow and sklearn can recognize.
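As a minimal sketch of this preprocessing step, assuming the expression matrix has been exported to a CSV file with one row per sample, one column per gene, and a binary label column (the file name and column names below are hypothetical):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical layout: rows = samples, columns = per-gene expression values,
# plus a "label" column (1 = tumor, 0 = normal).
df = pd.read_csv("luad_expression.csv")

X = df.drop(columns=["label"]).to_numpy(dtype=np.float32)  # TensorFlow-friendly dtype
y = df["label"].to_numpy(dtype=np.float32)

# A stratified split preserves the tumor/normal ratio in both subsets,
# which matters for such an unbalanced dataset.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
```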
KL divergence gene selection
There are more than 60,000 genes in the RNA-seq data in the TCGA and ICGC databases, of which more than 20,000 are protein-coding. Training a lung cancer prediction model on all of this genetic data makes overfitting likely. In clinical practice, the number of available cancer samples is very small compared to the number of gene features, which leads to model overfitting and decreased prediction accuracy. Feature selection is a good way to deal with these problems [12]. By reducing the entire feature space to a subset of selected features, overfitting of the prediction model can be avoided, thereby mitigating the problems caused by small sample sizes and high-dimensional data. As mentioned above, existing differential-analysis gene selection methods and machine learning-based gene selection methods have shortcomings [13]. For example, differential-analysis methods place requirements on the data distribution, and machine learning-based methods require large amounts of data, otherwise they easily overfit. Taking these shortcomings into account, this paper proposes a gene selection method based on KL divergence.
KL divergence [14] (Formula (1)) is an asymmetric measure of the difference between two probability distributions P and Q over the same variable x. In practice, P represents the true distribution of the data, and Q represents the theoretical distribution of the data or an approximation of P.
$$D_{\mathrm{KL}} = -\sum_{i=1}^{n} P(i)\,\ln\frac{Q(i)}{P(i)} \ge 0$$
(1)
The KL divergence is always greater than or equal to zero. When the two data distributions are the same, the value of the KL divergence is 0. The greater the difference between the two distributions, the greater the value of the KL divergence.
For gene expression data, we can estimate the empirical distribution of each gene in the disease group and the control group, even from a small sample set, and then use KL divergence to measure the difference between the two distributions. If the two distributions are consistent, the gene is likely unrelated to the disease; if they differ substantially, the gene is likely related to the disease.
KL divergence has the advantage of being simple to compute. The difference between two distributions can be calculated easily even from a small dataset, which makes it well suited to gene expression data.
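A minimal sketch of this selection step, assuming the samples have already been split into tumor and normal groups: the per-gene distributions are approximated with histograms over a shared binning, and a small smoothing constant avoids log(0). The bin count and the number of retained genes are illustrative choices, not the paper's settings.

```python
import numpy as np

def kl_divergence_per_gene(tumor, normal, bins=20, eps=1e-10):
    """Score each gene by the KL divergence between its empirical expression
    distributions in the tumor group (P) and the normal group (Q), Formula (1).

    tumor:  array of shape (n_tumor_samples, n_genes)
    normal: array of shape (n_normal_samples, n_genes)
    """
    n_genes = tumor.shape[1]
    scores = np.zeros(n_genes)
    for g in range(n_genes):
        lo = min(tumor[:, g].min(), normal[:, g].min())
        hi = max(tumor[:, g].max(), normal[:, g].max())
        edges = np.linspace(lo, hi, bins + 1)  # shared bins for both groups
        p, _ = np.histogram(tumor[:, g], bins=edges)
        q, _ = np.histogram(normal[:, g], bins=edges)
        # Smooth and normalize so both are valid probability distributions.
        p = (p + eps) / (p + eps).sum()
        q = (q + eps) / (q + eps).sum()
        scores[g] = np.sum(p * np.log(p / q))  # D_KL(P || Q) >= 0
    return scores

# Keep the k genes whose distributions differ most between the groups
# (genes with constant expression should be filtered out beforehand).
# scores = kl_divergence_per_gene(X_tumor, X_normal)
# selected = np.argsort(scores)[::-1][:200]  # k = 200 is illustrative
```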
Building deep neural network model
This paper uses a deep neural network model to predict lung cancer. Deep neural networks are inspired by the working principle of the brain and have been widely used in many fields. A deep neural network generates output based on input variables: given a set of features and a target, it can learn a nonlinear function approximation. Between the input and output there are one or more nonlinear layers, called hidden layers. These multiple nonlinear hidden layers enable the deep neural network to learn complex nonlinear functional relationships from high-dimensional raw data without the guidance of hand-crafted rules [9].
Figure 2 shows the deep neural network model constructed in this paper. The leftmost layer is the input layer, the rightmost layer is the output layer, and the middle layers are hidden layers composed of hidden neurons. We then set a loss function suited to the task; gradually reducing the loss value during training drives the model to convergence. The specific formulas of the model's inference process are shown in Formulas (2) to (6).
$$hidden\_layer\_1 = \mathrm{relu}\left( input \cdot W_{1} \right)$$
(2)
$$hidden\_layer\_2 = \mathrm{relu}\left( hidden\_layer\_1 \cdot W_{2} \right)$$
(3)
$$\hat{y} = \mathrm{sigmoid}\left( hidden\_layer\_2 \right)$$
(4)
$$\mathrm{relu}\left( x \right) = \max\left( x, 0 \right)$$
(5)
$$\mathrm{sigmoid}\left( x \right) = \frac{1}{1 + e^{-x}}$$
(6)
Formulas (2) and (3) are the hidden-layer computations, where input is the input vector, W1 is the parameter matrix of the first layer, and relu(x) is the non-linear function defined in Formula (5). Formula (4) gives the model's predicted output, using the sigmoid function as the activation function [15]. The sigmoid function is defined in Formula (6); its output lies in the range 0–1, which conforms to the meaning of a probability, so the output value is the model's predicted probability.
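The forward pass of Formulas (2) to (4) can be sketched in TensorFlow/Keras as follows. The formulas do not write out bias terms or a final projection to a single output unit, so their presence here, like the hidden-layer sizes, is an assumption of standard practice rather than the paper's exact configuration:

```python
import tensorflow as tf

def build_model(n_genes, hidden_units=(128, 64)):
    """Feed-forward network following Formulas (2)-(4):
    two relu hidden layers followed by a sigmoid output."""
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(n_genes,)),
        tf.keras.layers.Dense(hidden_units[0], activation="relu"),  # Formula (2)
        tf.keras.layers.Dense(hidden_units[1], activation="relu"),  # Formula (3)
        tf.keras.layers.Dense(1, activation="sigmoid"),             # Formula (4)
    ])
```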
Binary classification models generally use cross-entropy as the loss function, as shown in Formula (7), where y indicates whether the sample is diseased (1 for diseased, 0 for non-diseased) and \(\hat{y}\) is the estimated probability that the sample is diseased; the smaller the model's classification error, the smaller the value of Formula (7). The cross-entropy loss function achieves good results on balanced datasets. However, gene expression datasets are unbalanced: the easily classified majority class contributes far more terms and ultimately dominates the total loss, biasing predictions toward the majority class.
$$loss = -y\log\left( \hat{y} \right) - \left( 1 - y \right)\log\left( 1 - \hat{y} \right)$$
(7)
To address the imbalanced data, we use the focal loss function, shown in Formula (8). Focal loss is based on the observation that samples the model already classifies easily (samples with high confidence) contribute very little to improving the model; the model should instead focus on hard-to-distinguish samples. It also adjusts the balance of positive and negative samples: the parameter α rebalances the positive and negative classes, and the parameter γ down-weights easily distinguished samples, raising the relative weight of hard-to-distinguish ones.
$$loss = -y\,\alpha\left( 1 - \hat{y} \right)^{\gamma}\log\left( \hat{y} \right) - \left( 1 - y \right)\hat{y}^{\gamma}\log\left( 1 - \hat{y} \right)$$
(8)
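Formula (8) can be written as a custom TensorFlow loss, sketched below. The clipping constant is added purely for numerical stability, and the default α and γ are common choices from the focal loss literature, not necessarily the values tuned in this paper:

```python
import tensorflow as tf

def focal_loss(alpha=0.25, gamma=2.0):
    """Binary focal loss following Formula (8): alpha rebalances positive and
    negative samples, gamma down-weights easily classified samples."""
    def loss_fn(y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)  # avoid log(0)
        pos = -y_true * alpha * tf.pow(1.0 - y_pred, gamma) * tf.math.log(y_pred)
        neg = -(1.0 - y_true) * tf.pow(y_pred, gamma) * tf.math.log(1.0 - y_pred)
        return tf.reduce_mean(pos + neg)
    return loss_fn
```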
We use gradient descent to adjust W1 and W2. Taking into account the shortcomings of existing methods and the characteristics of gene expression datasets, this paper uses TensorFlow to build a three-layer deep neural network and uses Adam as the gradient descent optimizer [16]. To address the imbalanced data, focal loss is used as the loss function, and multiple rounds of training are performed on the training dataset.
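Putting the pieces together, a hypothetical training run (reusing the build_model and focal_loss sketches above, with illustrative hyperparameters rather than the paper's tuned values) might look like:

```python
model = build_model(n_genes=X_train.shape[1])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),  # Adam optimizer [16]
    loss=focal_loss(alpha=0.25, gamma=2.0),
    metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
)
# Multiple rounds (epochs) of training; validation data tracks generalization.
model.fit(X_train, y_train, validation_data=(X_val, y_val),
          epochs=50, batch_size=32)
```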