In a typical diagnosis, clinicians undertake a symptom investigation and diagnose a disease based on the observed symptoms (Fig. 1). The detailed architecture of the proposed dialogue system, A-SIDDS (association guided symptom investigation and diagnosis dialogue system), is illustrated in Fig. 2. A patient initiates the diagnosis process by reporting the symptoms they are suffering from (explicit symptoms). The controller policy of the proposed dialogue system acts as a clinic receptionist, activating a lower-level department policy as per the patient's report. The activated departmental policy conducts a symptom investigation guided by the Association and Recommendation Module (ARM). Once the lower-level policies have collected adequate information, the controller policy activates the disease classifier, which diagnoses the patient's disease based on the collected information. The detailed working methodologies of each module are as follows:
Symptom investigation
Symptom investigation is the first stage of diagnosis, in which doctors conduct an investigation and elicit other relevant symptoms based on the patient's reported chief complaints and the symptoms confirmed during inspection. Thus, the agent aims to learn appropriate and intelligent behavior for collecting adequate symptom information in minimal time, i.e., an optimal diagnosis dialogue policy. The policy learning loop consists of three main components: I. Diagnosis Policy Learning, II. Association & relevance module (ARM), and III. Internal & Association and Recommendation (AR) critics. Each sub-stage and its detailed working method are explained below.
Diagnosis policy learning
The diagnosis policy (\(\pi\)) is the decision function that decides whether to investigate a symptom or predict a disease after observing symptoms, i.e., \(a = \pi (S)\), where S is the set of observed symptoms and a can be either a symptom (inquiry) or a disease (diagnosis). To improve investigation efficacy and patient satisfaction, clinics typically have different departments such as ENT (Ear, Nose, and Throat) and pediatrics. Motivated by this real-world practice and the promising results obtained by Liao et al. [12, 37], we also utilize a hierarchical policy learning method, where the higher-level policy (controller) activates one of the lower-level policies (departmental) depending on the patient's self-report and other confirmed symptoms, and the activated department policy conducts group-specific symptom investigation.
Controller policy
The controller policy is the first-layer policy, which is responsible for activating an appropriate department policy (DP\(_i\)) or the disease classifier for symptom inspection and disease prediction, respectively. It can be seen as a clinic's receptionist who refers patients to a particular department as per their chief complaint/self-report. It is also responsible for triggering the disease classifier (DC) once the lower-level (department) policies have collected adequate symptom information. The controller policy selects an action (ac) depending upon the current dialogue state (S) as follows: \(ac = P (A_c \vert S, \pi _c)\), where \(\pi _c\) is the controller policy and \(A_c\) is its action space, which consists of the department policies (\(DP_{i}\)) and the disease classifier. For each action ac taken in a state S, the agent gets a reward/penalty (r\(_c\): Reward(S, ac)) depending upon the effectiveness of the taken action, as follows:
$$\begin{aligned} rc_{t} = {\left\{ \begin{array}{ll} \sum _{i=1}^{n} \gamma ^{i}_{c} \, r_{t+i}^{d}, &\quad \text {if } ac_t = DP_{i} \\ r_{t}^{d}, &\quad \text {if } ac_t = DC \end{array}\right. } \end{aligned}$$
(1)
where n is the number of turns taken by the lower-level policy activated by the controller action ac\(_t\). The agent aims to maximize the cumulative reward over episodes (\(R = \sum _{n=1} ^{N} \sum _{t=0} ^{T} \gamma _{c}^{t} * rc_{t}\)), leading to adequate symptom investigation and thus accurate diagnosis. Here, N and T are the number of dialogues in an episode and the number of turns in the n\(^{th}\) conversation, respectively, and \(\gamma _c\) is the discount factor, which governs the relative roles of immediate and future rewards in policy learning.
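For concreteness, the following minimal Python sketch illustrates how the controller reward of Eq. (1) could be computed from the rewards produced while a lower-level policy holds control; the function and variable names are illustrative assumptions, not the paper's implementation.

```python
# Illustrative sketch of the controller reward rc_t in Eq. (1).
def controller_reward(activated_department, lower_rewards, gamma_c):
    """Compute rc_t for a controller action.

    activated_department: True if the controller activated a department policy (DP_i),
                          False if it triggered the disease classifier (DC).
    lower_rewards:        [r^d_{t+1}, ..., r^d_{t+n}] earned while the department
                          policy held control (a single-element list [r^d_t] for DC).
    gamma_c:              controller discount factor.
    """
    if activated_department:
        # Discounted sum of the activated department policy's turn rewards.
        return sum(gamma_c ** i * r for i, r in enumerate(lower_rewards, start=1))
    # Disease-classifier action: the immediate diagnosis reward r^d_t.
    return lower_rewards[0]
```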
The controller policy \(\pi _{c}\) is optimized using a value-based deep reinforcement learning technique called Deep Q-Network (DQN) [38]. It learns a state-action value function (Q\(^c\)(S, ac)), which estimates a value for each action (department) given a dialogue state S (informed symptoms). The policy selects the action with the highest Q value (expected reward), i.e., \(ac = argmax_i {Q^{c} (S, A_{c}^{i} \vert \pi _c})\). The \(Q^c\) function is computed and optimized through the Bellman equation [39] and the temporal difference (TD) loss [40] as follows:
$$\begin{aligned} Q^c(S_t, ac_t) = {\mathbb {E}} \left[ rc_t + \gamma _{c} * \max _{ac_{t+1}} Q^c(S_{t+1}, ac_{t+1})\right] \end{aligned}$$
(2)
$$\begin{aligned} L_{t}^{c} = \left[ \left( rc_{t} + \gamma _c * \max _{ac_{t+1} \in A_c} Q^c (S_{t+1}, ac_{t+1} \vert \pi _c^{t-1}, \theta _{t-1})\right) - Q^{c}(S_t, ac_t \vert \pi _c^t, \theta _{t})\right] ^2 \end{aligned}$$
(3)
where \(L_{t}^{c}\) is the loss at the \(t^{th}\) time step, which is the squared difference between the state-action value computed with the current policy parameters (behavior network: \(\theta _t\)) and the TD target computed with the previously frozen policy parameters (target network: \(\theta _{t-1}\)).
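A minimal PyTorch sketch of this TD update is given below, assuming behavior_net and target_net are Q-networks over the controller's action space and that transitions are sampled from a replay memory; the names and batch format are assumptions, not the authors' code.

```python
import torch
import torch.nn.functional as F

def controller_td_loss(behavior_net, target_net, batch, gamma_c):
    """TD loss of Eq. (3) for a batch of controller transitions."""
    states, actions, rewards, next_states = batch  # tensors sampled from replay memory
    # Q^c(S_t, ac_t) under the current (behavior) parameters theta_t.
    q_current = behavior_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # max_{ac_{t+1}} Q^c(S_{t+1}, ac_{t+1}) under the frozen target parameters theta_{t-1}.
        q_next = target_net(next_states).max(dim=1).values
    td_target = rewards + gamma_c * q_next
    return F.mse_loss(q_current, td_target)
```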
Departmental policy
The departmental/lower-level policies (DP\(_i\): \(\pi ^{i}\)) are responsible for symptom inspection within their respective departments. The proposed model has nine departmental policies, one for each disease group. These departmental policies learn to select an appropriate action (a symptom to inspect) depending upon the current dialogue state, which contains the informed/confirmed symptoms. Each policy selects an action (a\(_i\)) as follows:
$$\begin{aligned} a_i = \text {argmax}_j Q^i (A_{ij} \vert S, \pi ^i) \end{aligned}$$
(4)
where \(Q^i\) is the state-action value function of the i\(^{th}\) department policy (\(\pi ^{i}\)) and A\(_{ij}\) is the j\(^{th}\) action of the i\(^{th}\) departmental policy. The state, S, consists of the status of informed and inspected symptoms, the dialogue turn, the agent's previous actions, the K most relevant symptoms predicted by the ARM module, and the reward. The size of the action space of each policy is N\(_i\) + 1, where N\(_i\) is the number of symptoms in the i\(^{th}\) department; the additional action returns control to the controller policy. The department agent gets a reward/penalty (internal and ARM critics) at each time step depending upon the appropriateness and relevance of the agent's action (a\(_i\)) to the current state (S). These policies (\(\pi ^{i}\)) are also optimized using the DQN algorithm, as with the controller policy (Eqs. 2 and 3).
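The following PyTorch sketch shows one possible realization of a departmental Q-network with the extra return-control action (Eq. 4); the class name, layer sizes, and activation are illustrative assumptions.

```python
import torch
import torch.nn as nn

class DepartmentPolicy(nn.Module):
    """Q-network of one department policy: N_i symptom-inquiry actions
    plus one action that returns control to the controller policy."""

    def __init__(self, state_dim, num_symptoms, hidden_dim=128):
        super().__init__()
        self.q_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_symptoms + 1),  # +1: return-control action
        )

    def select_action(self, state):
        # a_i = argmax_j Q^i(A_ij | S); index num_symptoms denotes "return control".
        with torch.no_grad():
            return int(self.q_net(state).argmax().item())
```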
Association and relevance module (ARM)
The Association and Relevance Module (ARM) is responsible for conducting a knowledge-aware, association-guided symptom investigation for adequate symptom information extraction. The module takes the current state (S\(_t\)) and the inspected symptom (Sym\(_t\)) as inputs, and it outputs an association score (as\(_t\)) and a symptom recommendation (RS\(_t\)). The association module provides the association score depending upon the relevance of the currently requested symptom (Sym\(_t\)) to the confirmed symptoms (SS), i.e.,
$$\begin{aligned} as_t = \sum _{k=1}^{n_t} Association (Sym_t, SS_k) \end{aligned}$$
(5)
where SS is the set of inspected and confirmed symptoms (including the patient self-report) up to the t\(^{th}\) turn of the dialogue and n\(_t\) is the number of symptoms in it. The association score is provided as a critic to the agent, reinforcing it to conduct an association-aware symptom investigation. We construct and utilize a symptom-symptom knowledge graph to calculate the association between two symptoms. In the knowledge graph, nodes represent symptoms, and an edge between two nodes signifies the correlation between the two symptoms. The edge between two nodes/symptoms (S\(_i\), S\(_j\)) is determined by the frequency of their co-occurrence. The weight of the edge from symptom \(S_i\) to \(S_j\) is computed as follows:
$$\begin{aligned} Association(S_i, S_j) = \frac{n(S_i, S_j)}{\sum _k n(S_i, S_k)} \end{aligned}$$
(6)
where \(n(S_i, S_j)\) is the number of instances in the diagnosis dataset in which S\(_i\) and S\(_j\) co-occur. The index k ranges over the entire symptom space (Sy); hence, the denominator is the total number of instances in which symptom \(S_i\) co-occurs with any symptom \(S_k\) (\(S_k \in\) Sy). Thus, the association score of symptom \(S_i\) with \(S_j\) signifies the likelihood of \(S_j\) occurring alongside \(S_i\). A high value of the association score (\(S_i\), \(S_j\)) indicates that a patient who exhibits symptom \(S_i\) is likely to also suffer from symptom \(S_j\).
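A simple Python sketch of Eqs. (5) and (6), building co-occurrence statistics from the diagnosis dataset and deriving the association score of a requested symptom, is shown below; the data layout (one symptom set per diagnosis case) is an assumption.

```python
from collections import defaultdict
from itertools import permutations

def build_cooccurrence(records):
    """records: iterable of symptom sets, one per diagnosis case."""
    counts = defaultdict(lambda: defaultdict(int))
    for symptoms in records:
        for s_i, s_j in permutations(symptoms, 2):
            counts[s_i][s_j] += 1  # n(S_i, S_j)
    return counts

def association(counts, s_i, s_j):
    # Eq. (6): n(S_i, S_j) normalized by all of S_i's co-occurrences.
    total = sum(counts[s_i].values())
    return counts[s_i][s_j] / total if total else 0.0

def association_score(counts, sym_t, confirmed):
    # Eq. (5): relevance of the requested symptom Sym_t to the confirmed set SS.
    return sum(association(counts, sym_t, ss_k) for ss_k in confirmed)
```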
A symptom may strongly suggest the existence of another symptom when both are caused by a common condition. For instance, when we think about a cold, the next symptom that comes to mind is a cough; cold and cough often co-occur. Motivated by this observation, the proposed model incorporates a recommendation module, which recommends some of the most relevant symptoms (RS) from the entire symptom set (Sy) depending upon the confirmed symptoms, SS. It selects the top K symptoms from the symptom space that are most relevant to the current context (the confirmed symptom set, SS), i.e., those that most frequently co-occur with it. This module utilizes the association scores to determine the top K relevant symptoms as follows:
$$\begin{aligned} RS = \Pi _{i=1}^{K} argmax _{s \in Sy} \sum _{j=1}^{|SS|} Association(SS_j, s) \end{aligned}$$
(7)
These recommended symptoms are reflected in the current dialogue state, and the agent is reinforced to investigate these most relevant symptoms through the recommendation critic. This module aids the agent in conducting a knowledge-aware, association-guided symptom investigation, which improves the user experience and reduces the number of turns required to diagnose the patient.
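One possible reading of Eq. (7) is a greedy top-K selection over the symptom space, reusing the association helper from the sketch above; excluding already confirmed symptoms from the candidates is our assumption.

```python
def recommend_symptoms(counts, confirmed, symptom_space, k=5):
    """Return RS: the K symptoms most associated with the confirmed set SS."""
    candidates = [s for s in symptom_space if s not in confirmed]
    scored = sorted(
        candidates,
        key=lambda s: sum(association(counts, ss_j, s) for ss_j in confirmed),
        reverse=True,
    )
    return scored[:k]
```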
Internal and association and recommendation (AR) critics
A reinforcement learning agent’s reward model is one of its most critical elements, as it implicitly supervises the agent for the underlying task. We propose and incorporate two novel reward functions, namely a recommendation-based critic and an association-based critic, to reinforce the agent to conduct a context-aware, association-guided symptom investigation. The critics (r\(_d\): internal critic, r\(_{rr}\): recommendation-based critic, and r\(_{ar}\): association-based critic) are defined as follows:
$$\begin{aligned} r_{d} = {\left\{ \begin{array}{ll} + t_1 * N, &\quad \text {if success}\\ + t_2 * N, &\quad \text {if } match (Sym_t) = 1 \\ - t_3 * N, &\quad \text {if repetition} \\ 0, &\quad \text {otherwise} \end{array}\right. } \end{aligned}$$
(8)
$$\begin{aligned} r_{rr} = {\left\{ \begin{array}{ll} + t_4, &\quad \text {if } Sym_t \in RS_t\\ - t_5, &\quad \text {otherwise} \end{array}\right. } \end{aligned}$$
(9)
$$\begin{aligned} r_{ar} = {\left\{ \begin{array}{ll} + t_4, &\quad \text {if } as_t > h\\ + 1, &\quad \text {if } l< as_t < h \\ - t_5, &\quad \text {otherwise} \end{array}\right. } \end{aligned}$$
(10)
$$\begin{aligned} r = r_{d} + (r_{rr} + r_{ar}) \end{aligned}$$
(11)
where N and \(t_i\) are the maximum number of allowed turns for diagnosis and the reward-shaping parameters, respectively. The term match(Sym\(_t\)) = 1 indicates that the department policy has requested a symptom (Sym\(_t\)) from which the patient is truly suffering. Here, \(Sym_t\), \(RS_t\), and \(as_t\) are the agent's requested symptom, the recommended symptoms, and the association score between \(Sym_t\) and the other confirmed symptoms (SS) at the t\(^{th}\) turn, respectively. The terms l and h denote the lower and desired thresholds for the association score between the requested symptom and the confirmed symptoms (SS), respectively.
The internal critic (r\(_d\)) reinforces the agent to complete the task successfully, whereas the immediate rewards (recommendation: r\(_{rr}\) and association: r\(_{ar}\)) act as behavior-shaping elements. The recommendation and association reward models provide a reward/penalty depending upon the appropriateness of the agent's action and its relevance to the dialogue context (the already informed symptoms, including the patient self-report, SS). If the agent inspects a recommended symptom, it gets a reward (the first case of Eq. 9); otherwise, it gets a small penalty. The association reward (r\(_{ar}\)) provides a reward/penalty proportional to the relevance (association score) of the currently requested symptom to the ongoing context/confirmed symptoms (SS), which motivates the agent to inquire about relevant and knowledge-grounded symptoms.
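A compact sketch of the three critics and their combination (Eqs. 8-11) is given below; the shaping parameters t1..t5, the thresholds l and h, and their default values are illustrative hyperparameters, not the paper's settings.

```python
def internal_reward(outcome, max_turns, t1=2.0, t2=1.0, t3=1.0):
    # Eq. (8): outcome is one of "success", "match", "repetition", or anything else.
    return {"success": t1 * max_turns,
            "match": t2 * max_turns,
            "repetition": -t3 * max_turns}.get(outcome, 0.0)

def recommendation_reward(sym_t, recommended, t4=1.0, t5=0.5):
    # Eq. (9): reward if the requested symptom was recommended by ARM, small penalty otherwise.
    return t4 if sym_t in recommended else -t5

def association_reward(as_t, l=0.1, h=0.5, t4=1.0, t5=0.5):
    # Eq. (10): reward graded by the association score of the requested symptom with SS.
    if as_t > h:
        return t4
    if l < as_t < h:
        return 1.0
    return -t5

def total_reward(r_d, r_rr, r_ar):
    # Eq. (11): internal critic plus the two ARM-based shaping critics.
    return r_d + (r_rr + r_ar)
```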
Diagnosis state tracker, patient and disease classifier
The diagnosis state tracker is responsible for tracking the dialogue (diagnosis) state, which contains information about inspected symptoms, the dialogue turn, and the agent's previous actions. After each agent and user turn, the state tracker updates the dialogue state with essential information such as the agent-requested symptoms, the user response, the turn number, and the reward/critic corresponding to the agent action. We have developed a pseudo environment/user simulator similar to popular task-oriented user simulators [12, 41]. The user simulator initializes each diagnosis session with a diagnosis case sampled from the training set. At the first turn of a conversation, the patient simulator informs the diagnosis agent of the self-report (all explicit symptoms) and asks it to identify the disease/condition that the patient may be experiencing. Then, the simulator responds to each of the agent's symptom requests as per the sampled diagnosis case during the conversation. Disease classification is the final stage, which diagnoses a disease depending upon the extracted symptoms (including the patient's self-report). In our work, it is a two-layer deep neural network, which takes a one-hot encoded representation of the symptom status as input and predicts a probability distribution over all diseases.
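A hedged PyTorch sketch of such a two-layer disease classifier is shown below; the hidden size and the use of a ReLU activation are assumptions.

```python
import torch.nn as nn

class DiseaseClassifier(nn.Module):
    """Maps the one-hot symptom-status encoding to a distribution over diseases."""

    def __init__(self, num_symptoms, num_diseases, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_symptoms, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_diseases),
        )

    def forward(self, symptom_status):
        # Probability distribution over all diseases.
        return self.net(symptom_status).softmax(dim=-1)
```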
Investigation relevance score (IReS)
Automatic disease diagnosis is a sequential decision problem in which an agent interacts with end-users over time for symptom investigation and then diagnoses the most appropriate disease based on the observed symptoms. Thus, collecting an adequate set of symptoms is critical to accurate diagnosis, and it directly influences end-users' engagement with the system. A single irrelevant symptom inspection by a diagnosis agent may cause end-users to lose trust in the system, resulting in the termination of dialogues in a large number of such cases. For instance, if a person presents with difficulty breathing and the agent inspects irrelevant (or less relevant) symptoms such as skin growth and knee swelling, the end-user may become annoyed and terminate the chat. However, existing works [10, 12, 15, 24] employ objective metrics such as diagnosis accuracy and symptom investigation time for measuring their proposed models' efficacy. Motivated by the vital significance of symptom relevance and the inability of existing evaluation metrics to capture this critical aspect, we formulate and propose a novel automatic evaluation metric called the Investigation Relevance Score (IReS) for evaluating a diagnosis agent's efficacy in terms of the relevance of the symptoms it inspects during symptom investigation. The metric is formulated as follows:
$$\begin{aligned} IReS\text {-}1 = \frac{\sum _{i=1}^{m} \sum _{j=1}^{n_i} Association(S_j, PSR_i)}{t * \sum _{i} n_i} \end{aligned}$$
(12)
$$\begin{aligned} IReS\text {-}2 = \frac{\sum _{i=1}^{m} \sum _{j=1}^{n_i} Association(S_j, SS_{ij})}{t * \sum _{i} n_i} \end{aligned}$$
(13)
where m is the number of test samples and \(n_i\) is the number of turns taken by the agent for the i\(^{th}\) diagnosis test sample. The terms PSR\(_i\) and SS\(_{ij}\) denote the patient self-report and the confirmed symptoms up to the j\(^{th}\) turn of the i\(^{th}\) sample, respectively. IReS-1 measures the relevance of the symptom investigation to the patient's self-reported symptoms (PSR), and IReS-2 measures the relevance of the symptom investigation to the ongoing context (confirmed symptoms, including PSR).
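The following sketch illustrates how IReS-1 (Eq. 12) could be computed over a test set, under two assumptions of ours: the association with the self-report PSR_i is summed over its reported symptoms, and the normalizer is treated as the total number of inspected symptoms (the constant factor t in the denominator is dropped). Here, association_fn stands for the association of Eq. (6), and the dialogue record keys are hypothetical.

```python
def ires_1(test_dialogues, association_fn):
    """test_dialogues: list of dicts with keys
       'inspected'   - symptoms requested by the agent, in turn order
       'self_report' - PSR_i, the patient's explicit symptoms."""
    total_relevance, total_turns = 0.0, 0
    for dialogue in test_dialogues:
        for s_j in dialogue["inspected"]:
            total_relevance += sum(association_fn(s_j, p) for p in dialogue["self_report"])
        total_turns += len(dialogue["inspected"])
    return total_relevance / total_turns if total_turns else 0.0
```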