Patient Outcome and Zero-shot Diagnosis Prediction with Hypernetwork-guided Multitask Learning
Shaoxiong Ji and Pekka Marttinen Aalto University, Finland
{shaoxiong.ji; pekka.marttinen}@aalto.fi
Abstract
Multitask deep learning has been applied to patient outcome prediction from text, taking clinical notes as input and training deep neural networks with a joint loss function of multiple tasks. However, the joint training scheme of multitask learning suffers from inter-task interference, and diagnosis prediction among the multiple tasks generalizes poorly to rare diseases and unseen diagnoses. To address these challenges, we propose a hypernetwork-based approach that generates task-conditioned parameters and coefficients of multitask prediction heads to learn task-specific prediction and balance the multitask learning. We also incorporate semantic task information to improve the generalizability of our task-conditioned multitask model. Experiments on early and discharge notes extracted from the real-world MIMIC database show that our method achieves better performance on multitask patient outcome prediction than strong baselines in most cases. In addition, our method can effectively handle scenarios with limited information and improve zero-shot prediction on unseen diagnosis categories.
1 Introduction
Recent advances apply artificial intelligence to predict clinical events or infer the probable diagnosis for clinical decision support (Shickel et al., 2017; Li et al., 2021). These works extensively study clinical data from Electronic Health Records (EHRs). For example, the Doctor AI model for predictive modeling on EHRs (Choi et al., 2016) and the Deep Patient model for unsupervised patient record representation learning (Miotto et al., 2016) leverage longitudinal medical data.
Within the patient records, we can also find various clinical notes written by general practitioners or physicians and annotated with specific diagnosis results, medical codes, and other clues of patient outcome. These clinical notes contain meaningful health information, including health profile, clinical synopsis, diagnostic investigations, and medications, which can support clinical decision making (Li et al., 2021). Clinical text mining has demonstrated its feasibility in diverse healthcare applications, including medical code prediction (Ji et al., 2021), readmission prediction (Huang et al., 2019), and diagnosis prediction (van Aken et al., 2021). Readmission and length-of-stay prediction can help with utility management, which is especially important when healthcare service is in high demand. Diagnostic prediction can support clinicians in making decisions based on a large number of clinical reports. This paper focuses on multitask patient outcome prediction from clinical notes, which classifies patients' diagnostic results and predicts clinical outcomes to empower expert clinicians and improve the efficiency of healthcare services. Fig. 1 demonstrates an example of multitask patient outcome prediction. The multitask deep learning model takes a series of clinical notes as input, including patient health profile, clinical synopsis, and laboratory test reports. It outputs several patient outcomes such as the length of stay at the hospital, possible diagnoses, and the readmission probability after hospitalization.
Existing work on clinical note representation learning and multitask clinical outcome prediction focuses on building neural architectures for feature learning (e.g., Huang et al., 2019; Mulyar et al., 2019) or considers different prediction tasks under the joint learning framework (e.g., Harutyunyan et al., 2019; Mulyar and McInnes, 2020). However, some challenges remain unsolved. First, the joint learning scheme for multiple tasks cannot effectively deal with inter-task interference. Second, the task context information is usually underused, making multitask learning task-ignorant. Third, current learning algorithms tend to overfit seen training examples and fail to generalize to rare diseases or unseen diagnoses.
arXiv:2109.03062v1 [cs.CL] 7 Sep 2021
[Figure 1 diagram: a timeline of clinical notes (health profile; adverse reactions and alerts; clinical synopsis; operative report; blood test; medications; radiology report; X-ray examination), with their findings, results, and summaries, fed into a multitask deep learning model.]
Figure 1: An example of multitask patient outcome prediction based on sequential inputs of clinical notes in the electronic health record. The multitask deep learning model dynamically predicts the probabilities of several patient outcomes at different stages during hospital admission.
We propose a novel multitask learning method for patient outcome and zero-shot diagnosis prediction from clinical notes to address the aforementioned challenges. We incorporate task information to enable task-specific prediction in multitask learning. Inspired by hypernetworks (Ha et al., 2017), which use a small neural network to generate parameters for a larger neural network, we encode the semantic task information as the task context and use the encoded task embeddings as shared meta knowledge to generate the task-specific parameters of prediction heads for multiple tasks. Our proposed method effectively utilizes the task information and produces task-aware multitask prediction heads. Our method also generalizes to unseen diagnoses by maintaining the shared meta knowledge encoded as semantic task embeddings. Furthermore, we dynamically update the weight coefficients of the multitask learning objective function to balance the learning from multiple clinical tasks.
Our contributions include:
• We propose to utilize task information via task embeddings for multitask patient outcome prediction spanning readmission, diagnosis, and length-of-stay prediction from clinical text, and develop a hypernetwork for task-specific parameter generation to share information among different tasks.
• We propose to regularize the objective function via a parameterized task weighting scheme that effectively balances the learning process among multiple tasks.
• Experiments on two datasets extracted from a real-world clinical database show that our proposed method outperforms strong multitask learning baselines and achieves more robust performance in zero-shot diagnosis prediction.
2 Related Work
2.1 Deep Learning on EHRs
Patient outcome prediction, including automatic coding (Friedman et al., 2004; Yan et al., 2010), patient severity (Naemi et al., 2020), in-hospital mortality, decompensation, length of stay, and phenotyping (Harutyunyan et al., 2019), can empower health practitioners to make better clinical decisions. Many studies have investigated the usability of deep learning for accurate prediction using time series, clinical text, or multimodal data. Zufferey et al. (2015) compared different multi-label learning algorithms. Che et al. (2015) proposed a deep phenotyping model regularized on the categorical structure of the medical ontology. Reddy and Delen (2018) utilized RNN-LSTM to predict hospital readmission for lupus patients. Zhang et al. (2019) proposed an attention-based LSTM network to model disease progression in a time-aware fashion. Similarly, Men et al. (2021) extended the LSTM model to be attention- and time-aware for multi-disease prediction. Convolutional neural networks have also been deployed in this field. For example, Bardak and Tan (2021) developed CNN-based networks for multimodal learning from patient records. Clinical notes, an essential modality of EHRs, have also been studied intensively. Ghassemi et al. (2014) extracted topic features from free-text patient data for mortality prediction. Ji et al. (2020) proposed a dilated CNN model for medical code prediction from clinical notes. Mulyar and McInnes (2020) extracted clinical information from clinical documents by fine-tuning large-scale pretrained language models.
2.2 Multitask Learning
Multitask learning learns multiple tasks simultaneously with a shared representation and transfers the domain knowledge in related tasks to improve the learning capacity (Caruana, 1997). It has also attracted much attention from the research community of machine learning and natural language processing for healthcare. Mahajan et al. (2020) studied semantic similarity degrees of clinical notes with a multitask learning method with iterative data selection. Sun et al. (2021) proposed a multitask learning framework with feature calibration for medical coding. Several multitask learning schemes have also been developed for learning from clinical time series (Harutyunyan et al., 2019). Hypernetworks (Ha et al., 2017), which generate weights for other networks, have also been adopted to solve learning problems with multiple tasks, for example, the Hyperformer that generates the parameters of adapter modules for fine-tuning NLP task streams (Mahabadi et al., 2021). Inspired by this work, our work starts from a different setting that learns different tasks simultaneously, utilizes the semantic information of clinical prediction tasks, and balances the learning objective via task conditioning.
3 Method
3.1 Problem Setup and Overall Architecture
The clinical document in a patient's health record is denoted as $d_{1:n} = x_1, \dots, x_n$, where $x_i$ is the $i$-th token. Patient outcome prediction predicts clinical results by mapping the input text into prediction scores with a function $F: \mathcal{X}^n \to \mathcal{Y}^m$ such that $y = F(x_1, \dots, x_n; \theta)$, where $y \in \mathbb{R}^m$ is the patient outcome, $m$ is the number of classes, and $\theta$ are the model parameters. Under the multitask setting, we have $k$ learning tasks $\{T_t\}_{t=1}^{k}$ that are related to the outcome of clinical intervention. A multitask learning algorithm captures the relatedness of multiple tasks and improves the modeling (Zhang and Yang, 2021). In this paper, we predict multiple clinical outcomes. Thus, the goal of the learning process is to fit a single-input multi-output function $y_t = F(x_1, \dots, x_n; \theta_{\text{shared}}, \theta_t)$, where $y_t \in \mathbb{R}^{m_t}$ is the ground truth label and $m_t$ is the number of classes of the $t$-th task, $\theta_{\text{shared}}$ are the shared model parameters, and $\theta_t$ are the task-specific parameters. The learning objective is to minimize the weighted sum of loss functions of multiple tasks, denoted as
$$L(\theta) = \sum_{t=1}^{k} \alpha_t L_t(y_t, \hat{p}_t; \theta_{\text{shared}}, \theta_t), \tag{1}$$
where $\alpha_t$ is the weight coefficient of the $t$-th task.
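As a minimal sketch of Eq. (1), the snippet below combines per-task losses with fixed weights. Cross-entropy as the per-task loss $L_t$ and the helper names `cross_entropy` and `joint_loss` are assumptions made for illustration; this section does not specify the loss used for each task.

```python
import math


def cross_entropy(probs, label):
    """Negative log-likelihood of the true class (assumed per-task loss L_t)."""
    return -math.log(probs[label])


def joint_loss(task_probs, task_labels, alphas):
    """Weighted sum of per-task losses, following the form of Eq. (1).

    task_probs:  one predicted probability vector per task
    task_labels: one ground-truth class index per task
    alphas:      one weight coefficient alpha_t per task
    """
    losses = [cross_entropy(p, y) for p, y in zip(task_probs, task_labels)]
    return sum(a * loss for a, loss in zip(alphas, losses))
```

With equal weights, the joint loss reduces to the mean of the task losses; unequal weights let one task dominate the gradient signal, which is what the learned coefficients later in this section are meant to control.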
The progress of a clinical intervention comes with multiple records, as illustrated in Fig. 1, leading to lengthy clinical notes. Following Huang et al. (2019), we segment long clinical documents into several chunks, i.e., a document segmented into $r$ chunks becomes $d_{1:r} = c_1, \dots, c_i, \dots, c_r$, where the $i$-th chunk $c_i$ contains $s$ tokens. The segmentation has two advantages: 1) it enables progressive prediction from chunks of clinical notes, calculated from the predictions of individual chunks (see details in Experiments); 2) it makes the input text length suitable for standard transformer-based pretrained models (described in the next section).
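The segmentation step can be sketched as follows. This is an illustration under stated assumptions: the function name `segment` is hypothetical, tokens are taken to be an already tokenized list, and the final chunk is simply left shorter than $s$ when the document length is not a multiple of $s$ (the text does not say how the remainder is handled).

```python
def segment(tokens, s):
    """Split a token sequence into consecutive chunks of at most s tokens."""
    if s <= 0:
        raise ValueError("chunk size s must be positive")
    return [tokens[i:i + s] for i in range(0, len(tokens), s)]
```

Each chunk can then be encoded independently, which keeps every input within the length limit of a standard transformer encoder.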
We propose a hypernetwork-guided multitask learning framework to predict multiple patient outcomes. Fig. 2 illustrates the overall framework. Specifically, we use a shared transformer-based text encoder to obtain hidden representations of clinical notes and task information. We develop a hypernetwork-based module, called Adapter HyperNet, to generate task-specific parameters for classification heads based on the task embeddings trained by a bottleneck network. The generated task-specific parameters then become the parameters of the classification heads. To balance the learning of the joint objective function, we deploy the weight hypernetwork to generate weight coefficients conditioned on the task embeddings.
3.2 Base Text Encoder
We deploy the Bidirectional Encoder Representations from Transformers (BERT) model (Devlin et al., 2019) as our base text encoder to learn rich text features. BERT pre-trains a language model with masked language modeling and next sentence prediction objectives in an unsupervised manner. Downstream fine-tuning then tunes the pre-trained model checkpoint with learning objectives suited to the downstream application. This pre-training and fine-tuning paradigm can effectively exploit semantic knowledge from large training corpora and achieves superior performance in many downstream applications. Because vocabularies differ between general and specific domains, many efforts have been made to pre-train transformer language models in specific domains. To benefit the most from unsupervised representation learning for clinical applications, we adopt ClinicalBERT, which starts from the standard BERT checkpoint and continues pre-training on clinical notes (Huang et al., 2019).
[Figure 2 diagram: a shared base text encoder (transformer encoders) encodes the clinical notes and three task descriptions into hidden representations, including the [CLS] representation; the hypernetworks (Task HyperNet, Adapter HyperNet, Weight HyperNet) generate the parameters of the multitask classifiers (FCN heads) and the weights α1, α2, α3 that combine Loss 1, Loss 2, and Loss 3 into the joint loss. The legend distinguishes computation flow from generation flow.]
Figure 2: The illustration of our proposed hypernetwork-guided multitask learning framework
Given text chunks $c_i$ as inputs, we use the last layer's hidden representation of the BERT-based text encoder to represent the encoded clinical note, i.e.,
$$H = \mathrm{BERT}(c_i, \theta_0), \tag{2}$$
where $H \in \mathbb{R}^{s \times d_h}$, $c_i$ is the $i$-th chunked text, and $\theta_0$ is the model parameter of the BERT encoder initialized from self-supervised pretraining. We can use either the embedding of the special token [CLS] or the pooled embedding of the hidden representations to represent the clinical note, which is denoted as $h \in \mathbb{R}^{d_h}$. Similarly, we input the textual label of each task to obtain the task description embeddings with a shared BERT-based text encoder parameterized with $\theta_0$. Specifically, for the $t$-th task with $m_t$ classes, the initial task embeddings are $T_t \in \mathbb{R}^{m_t \times d_t}$, where $d_t$ is the embedding dimension. The obtained task embeddings capture the task information and preserve the semantic meaning of the classes in each task.
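The two ways of reducing the hidden states $H$ to a single note vector $h$ can be illustrated with plain Python lists. This is only a sketch: the function names are hypothetical, the [CLS] token is assumed to sit at position 0 (as in standard BERT inputs), and mean pooling is assumed as one possible choice of "pooled embedding".

```python
def cls_embedding(H):
    """Use the hidden state of the first token ([CLS]) as the note vector h."""
    return list(H[0])


def mean_pooled_embedding(H):
    """Average the hidden states over the s token positions (assumed pooling)."""
    s = len(H)
    d_h = len(H[0])
    return [sum(row[j] for row in H) / s for j in range(d_h)]
```

Both functions map an $s \times d_h$ matrix to a vector of length $d_h$, so either can feed the same classification heads.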
3.3 Task-Conditioned Hypernetworks
Our goal is to utilize task side-information to enable robust patient outcome prediction, especially for classes with very few instances or classes that were unseen in the training set. Joint training in multitask learning is prone to inter-task interference. Inspired by hypernetworks (Ha et al., 2017), which generate model parameters to enable weight-sharing across neural layers, we propose task-conditioned hypernetworks with semantic task label embeddings to generate the parameters of multitask classification heads for clinical outcomes. The task-conditioned parameter generation enables task context-aware learning from multiple tasks and shares knowledge across different tasks.
Task Embeddings: To facilitate the parameter generation conditioned on task information, we further use a bottleneck network to train the task embeddings obtained from the contextualized embeddings of the BERT-based text encoder. The bottleneck network contains two feed-forward fully-connected layers with ReLU as the activation function (Nair and Hinton, 2010), denoted as:
$$Z_t = \mathrm{ReLU}(T_t W_1) W_2, \tag{3}$$
where $W_1 \in \mathbb{R}^{d_t \times d_b}$ and $W_2 \in \mathbb{R}^{d_b \times d_t}$ are the weight matrices of the linear layers, $Z_t \in \mathbb{R}^{m_t \times d_t}$ are the output task embeddings, $d_b$ is the hidden dimension that restricts the bottleneck to fewer neurons, and the bias terms are omitted for simplicity. The bottleneck network has been applied in many fields, such as image noise reduction. We use this structure to reduce the potential noise in task classes because the assignment of diagnoses in MIMIC-III does not use a systematic ontology.
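Eq. (3) can be traced with a small pure-Python sketch. The helper names are hypothetical, matrices are represented as lists of rows, and bias terms are omitted as in the equation.

```python
def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def relu(M):
    """Element-wise ReLU."""
    return [[max(0.0, x) for x in row] for row in M]


def bottleneck(T, W1, W2):
    """Z_t = ReLU(T_t W1) W2 with T_t: m_t x d_t, W1: d_t x d_b, W2: d_b x d_t."""
    return matmul(relu(matmul(T, W1)), W2)
```

Because $d_b$ is smaller than $d_t$, the intermediate product `matmul(T, W1)` has fewer columns than the input, and the output is projected back to the original width $d_t$.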
Adapter Hypernetwork: Inspired by the hypernetwork-generated adapters (Mahabadi et al., 2021) that fine-tune transformer layers, we equip the multitask classification heads with adapter hypernetworks that generate the parameters. Furthermore, we inject the contextual information in the task labels into parameter generation, which provides task-specific classification heads capable of task context-aware learning.
The adapter hypernetwork takes the task embeddings of each task as input and generates model parameters, written as
$$W_t = \mathcal{H}(Z_t), \quad \forall\, t = 1, \dots, k, \tag{4}$$
where $W_t \in \mathbb{R}^{d_h \times m_t}$ is the generated weight parameters for the classification head of the $t$-th task. We instantiate hypernetworks with linear layers to benefit from their simplicity. Specifically, the matrices of task embeddings $Z_t$ are flattened into vectors $v_t \in \mathbb{R}^{m_t d_t \times 1}$, projected into the hidden representation space $\mathbb{R}^{m_t d_h \times 1}$ with nonlinear activation, and reshaped to the size of the weight matrices $\mathbb{R}^{m_t \times d_h}$ to act as the parameters of classification heads. The bias parameters are generated similarly. The adapter hypernetwork conditioned on the task context shares the meta knowledge of tasks and enables task-specific prediction. Thus, the inter-task interference in the multitask setting can be effectively mitigated. Furthermore, the semantic embeddings of the task context provide the adapter hypernetwork with a good initialization of task classes to generate task-specific parameters, facilitating meta learning from limited information in zero- or few-shot scenarios.
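The flatten, project, and reshape steps can be sketched in pure Python. Several details here are assumptions for illustration only: the function names, the use of tanh as the nonlinear activation (the text does not name one), and a single linear layer with projection matrix `P` and bias `b`. The generated matrix follows the $m_t \times d_h$ layout of the reshaping step.

```python
import math


def adapter_hypernet(Z, P, b, d_h):
    """Generate an (m_t x d_h) classification-head weight matrix from Z_t.

    Z: task embeddings, m_t rows of length d_t
    P: projection matrix with m_t * d_h rows, each of length m_t * d_t
    b: projection bias of length m_t * d_h
    """
    m_t = len(Z)
    v = [x for row in Z for x in row]  # flatten Z_t into the vector v_t
    if len(P) != m_t * d_h or len(b) != m_t * d_h:
        raise ValueError("projection must have m_t * d_h outputs")
    if any(len(row) != len(v) for row in P):
        raise ValueError("projection rows must match the flattened input")
    projected = [
        math.tanh(sum(p * x for p, x in zip(row, v)) + bias)  # assumed activation
        for row, bias in zip(P, b)
    ]
    # reshape the vector of length m_t * d_h into m_t rows of length d_h
    return [projected[i * d_h:(i + 1) * d_h] for i in range(m_t)]


def head_logits(W, h):
    """Score a note embedding h (length d_h) with the generated head."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]
```

Since each row of the generated matrix corresponds to one class, a class that never occurs in training still receives a weight row derived from its label embedding, which is the mechanism the text relies on for zero-shot prediction.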
3.4 Task-Conditioned Learning Objective
In the joint training framework of multitask learning, the standard approach is to tune the weight coefficients $\alpha_t$ manually. However, this requires human effort to configure additional hyperparameters. We propose to use another hypernetwork $\mathcal{G}$ to generate the weight coefficients and make the weighted objective function of joint multitask learning conditioned on the task context. The weight coefficient hypernetwork $\mathcal{G}$ is defined as
$$\beta_t = \mathcal{G}(Z_t), \quad \forall\, t = 1, \dots, k. \tag{5}$$
We instantiate the hypernetwork with a multilayer perceptron (MLP), which takes the task embeddings $Z_t$ as input and outputs a scalar. Then we apply the softmax function to generate normalized weights, i.e.,
$$\alpha_t = \frac{\exp(\beta_t)}{\sum_{j=1}^{k} \exp(\beta_j)}. \tag{6}$$
In this way, the task-conditioned learning objective can dynamically weigh the losses of different tasks and balance their contributions to the overall loss during the training process. In addition, the weight coefficients learned from the task context can preserve the semantic task information and reweight the joint loss dynamically.
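The normalization in Eq. (6) is a standard softmax over the $k$ task scores. A minimal sketch follows; the function name `task_weights` is hypothetical, and subtracting the maximum score is a common numerical-stability step that does not change the result.

```python
import math


def task_weights(betas):
    """Softmax over per-task scores beta_t, giving weights alpha_t that sum to 1."""
    m = max(betas)  # shift for numerical stability; cancels in the ratio
    exps = [math.exp(b - m) for b in betas]
    total = sum(exps)
    return [e / total for e in exps]
```

Equal scores give equal weights, and a task with a larger score receives a larger share of the joint loss.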
4 Experiments
4.1 Datasets and Setup
4.1.1 Datasets
We use the admission-level patient records in the MIMIC-III (Medical Information Mart for Intensive Care III)¹ dataset for experiments. From the vast information in this database, we select the free-text patient notes as the testbed for natural language processing research. We release our code for reproducibility².
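Table 1: A statistical summary of datasets and tasks
| Dataset | # Samples: Train | # Samples: Val. | # Samples: Test | # Classes: Readm. | # Classes: Adm. Type | # Classes: Diag. | # Classes: LoS |
|---|---|---|---|---|---|---|---|
| Discharge | 26,245 | 3,037 | 3,063 | 2 | 3 | 2,624 | 10 |
| Early Notes | 6,656 | 743 | 608 | 2 | 3 | 2,637 | 10 |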
This paper considers four classification tasks, i.e., readmission prediction, diagnosis classification, length-of-stay prediction, and admission type classification. Following prior work (Huang et al., 2019), we build two datasets, i.e., one extracted from discharge summaries (denoted as Discharge) and one with early notes that were created within three days after admission (denoted as Early Notes), according to the interval between the patient's admission date and the date when the notes were charted. Table 1 summarizes the number of instances and classes in our extracted datasets. In-hospital death prediction is also an essential patient outcome task. However, we focus on readmission, and mortality precludes the possibility of readmission. Thus, we filter out all in-hospital death cases, which also aligns with prior work (Huang et al., 2019). We also consider a proxy task of admission type classification that categorizes the clinical notes into “emergency”, “elective”, and “urgent” (studied in the discussion section below).
¹ Dataset available by request: https://mimic.mit.edu/docs/iii/
² https://agit.ai/jsx/MT-Hyper
4.1.2 Tasks
Readmission Prediction: The goal of this binary classification task is to predict whether the patient will be readmitted within 30 days of discharge.
Diagnosis Prediction: The MIMIC-III dataset has a total of 15,691 classes of diagnoses in the “ADMISSIONS” table. We extract the clinical notes from the “NOTEEVENTS” table and get a total of 2,715 diagnoses in the extracted datasets of discharge and early notes, where frequent diagnoses include pneumonia, congestive heart failure, and sepsis. In our train/val/test split of discharge notes, 2,209 diagnoses appear in the training set, which allows us to test the performance of zero-shot prediction.
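Identifying the zero-shot classes amounts to a set difference between the diagnoses in the evaluation data and those in the training data. The sketch below uses a hypothetical function name and toy labels purely for illustration; it is not the authors' preprocessing code.

```python
def unseen_labels(train_labels, eval_labels):
    """Return diagnosis labels that occur in evaluation data but never in training."""
    return set(eval_labels) - set(train_labels)


# Toy example with made-up splits
train = ["pneumonia", "sepsis", "pneumonia"]
evaluation = ["sepsis", "congestive heart failure"]
zero_shot = unseen_labels(train, evaluation)
```

In the toy example, only "congestive heart failure" is a zero-shot class, since "sepsis" was already seen during training.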
Length-of-Stay (LoS) Prediction: Following the setting of Harutyunyan et al. (2019), we define length-of-stay prediction as a 10-class classification problem. Specifically, the duration of stay is divided into ten buckets, i.e., one class for stays of less than a day, seven classes for one-to-seven-day stays, one for stays longer than a week but shorter than two, and the last class for stays over two weeks. One difference is that we compute the duration of stay as the span between when the patient was admitted to the hospital and when the patient was discharged from the hospital.
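The bucketing can be written as a small function. The exact boundaries below (day-long buckets up to 8 days, then 8 to 14 days, then 14 days or more) are an assumption based on one reading of the description above, and the function name `los_bucket` is hypothetical.

```python
def los_bucket(days):
    """Map a length of stay in days to one of ten class indices (0 to 9).

    Assumed boundaries: [0, 1) -> 0; [1, 2) ... [7, 8) -> 1 ... 7;
    [8, 14) -> 8; 14 days or more -> 9.
    """
    if days < 0:
        raise ValueError("length of stay cannot be negative")
    if days < 1:
        return 0
    if days < 8:
        return int(days)  # one bucket per day of the first week
    if days < 14:
        return 8
    return 9
```

For instance, a stay of 3.5 days falls in class 3, and a 20-day stay falls in the last class.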
4.1.3 Baselines
We compare our method, dubbed MT-Hyper, with several multitask learning algorithms on electronic health records and clinical natural language processing. MT-LSTM (Harutyunyan et al., 2019) adopts a basic LSTM-based neural network and is jointly trained with several patient outcome prediction tasks. This method was initially designed for modeling time series. We modify it with a word embedding module for modeling clinical text in patients' health records. MT-BERT (Mulyar and McInnes, 2020) is a unified clinical NLP model that utilizes the BERT model architecture for text encoding and learns features with multiple clinical task prediction heads simultaneously. It originally consists of eight task heads, while our implementation uses fewer clinical prediction tasks. MT-RAM (Sun et al., 2021) proposes a neural module with feature recalibration and aggregation for multitask learning that mitigates the mutual interference of different tasks and refines the feature learning from noisy clinical text.
4.1.4 Settings and Evaluation
We use the Adam optimizer (Kingma and Ba, 2014) to optimize the model and set different learning rates for different text encoders and classification heads. The learning rates of LSTM-based methods range from 5e-5 to 1e-3. For BERT-based methods, we use a lower learning rate for both the base text encoder and the classification heads, ranging from 1e-5 to 5e-5. Because ClinicalBERT only provides pretrained checkpoints for the BERT base model, and because of limited computing resources, we adopt the BERT base architecture as our text encoder. We use 300-dimensional word embeddings, and the dimension of the LSTM hidden states is also set to 300. We run each algorithm ten times and report the average score and standard deviation. We monitor the validation loss in each trial and set the number of epochs that triggers the early stopping mechanism to 3 for the BERT encoder and 10 for the LSTM encoder.
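The early stopping rule can be sketched as a patience counter on the validation loss. This is an illustrative reading of the setup above, not the released training code: the function name is hypothetical, and "no improvement" is assumed to mean that the loss did not fall below the best value seen so far.

```python
def early_stop_epoch(val_losses, patience):
    """Return the 0-based epoch at which training stops, or None if it never stops.

    Training stops once the validation loss has failed to improve on its
    best value for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None
```

Under this reading, a patience of 3 corresponds to the BERT encoder setting and a patience of 10 to the LSTM encoder setting.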
Our problem definition treats readmission as a binary classification problem, while the rest of the tasks are defined as multi-class classification problems. For readmission prediction, we use weighted F1 and AUC-ROC. For multi-class classification problems, we report micro-averaged scores. As we segment lengthy clinical notes into shorter chunks for progressive prediction, we report the results of two types of evaluation schemes, i.e., progressive scores that are calculated with the prediction probabilities of note chunks, and ultimate scores of a complete admission that are measured by aggregating every prediction on the segmented notes. The ultimate scores can also be viewed as a strategy similar to a bagging-based ensemble, in which each single prediction is based on a different chunk of an instance and the final prediction is an ensemble of the different predictions.
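The ultimate score of an admission combines the chunk-level probabilities into one prediction. The aggregation rule is not spelled out in this section, so the sketch below assumes a plain average over chunks; the function names are hypothetical.

```python
def aggregate_chunks(chunk_probs):
    """Average class probabilities over all chunks of one admission (assumed rule)."""
    n = len(chunk_probs)
    m = len(chunk_probs[0])
    return [sum(p[j] for p in chunk_probs) / n for j in range(m)]


def ultimate_prediction(chunk_probs):
    """Pick the class with the highest aggregated probability."""
    avg = aggregate_chunks(chunk_probs)
    return max(range(len(avg)), key=avg.__getitem__)
```

Progressive scores would instead be computed on each row of `chunk_probs` separately, before any aggregation.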
4.2 Main Results
This section presents the main results on prediction at discharge and early prediction. As admission type classification has limited use in real-world scenarios, we do not include it in the comparison of main results but use it in the discussion of task conditioning. The full results for all four tasks are reported in the appendix.
Early Prediction Table 2 shows the results of patient outcome prediction from early notes. In both progressive and ultimate prediction, our proposed method outperforms the three baselines in most cases. The baseline models are relatively stronger on tasks with fewer classes. For example, MT-BERT and MT-RAM produce relatively good predictions on readmission, while their performance on multi-class diagnosis prediction is much worse than that of our proposed method. Furthermore, our proposed method makes more stable predictions in most cases, as indicated by the standard deviation. Our method outperforms the baselines by a relatively large margin in terms of the average scores over the three prediction tasks.

Table 2: Results of patient outcome prediction from early notes with average score ± standard deviation reported

| Task | Method | Progressive F1 | Progressive AUC-ROC | Ultimate F1 | Ultimate AUC-ROC |
|---|---|---|---|---|---|
| Readmission | MT-LSTM | 51.69 ± 3.72 | 53.80 ± 1.99 | 52.36 ± 3.62 | 56.17 ± 2.23 |
| | MT-BERT | 55.52 ± 3.92 | 56.75 ± 2.56 | 56.02 ± 4.34 | 60.85 ± 1.61 |
| | MT-RAM | 56.00 ± 3.09 | 57.26 ± 1.94 | 56.51 ± 2.80 | 62.78 ± 1.39 |
| | MT-Hyper | 56.38 ± 2.30 | 57.57 ± 1.27 | 57.67 ± 2.21 | 62.41 ± 1.61 |
| Diagnosis | MT-LSTM | 8.50 ± 3.39 | 63.00 ± 0.95 | 8.69 ± 3.79 | 62.26 ± 0.91 |
| | MT-BERT | 11.63 ± 4.83 | 64.13 ± 1.30 | 11.80 ± 5.02 | 63.24 ± 1.30 |
| | MT-RAM | 14.06 ± 1.21 | 68.58 ± 1.60 | 14.39 ± 1.41 | 67.84 ± 1.64 |
| | MT-Hyper | 19.56 ± 1.33 | 72.98 ± 0.45 | 20.47 ± 1.38 | 73.21 ± 0.43 |
| LOS | MT-LSTM | 30.19 ± 2.70 | 75.73 ± 0.92 | 30.54 ± 2.68 | 76.50 ± 0.87 |
| | MT-BERT | 26.40 ± 4.34 | 72.60 ± 2.81 | 27.25 ± 4.08 | 73.71 ± 2.66 |
| | MT-RAM | 26.20 ± 3.08 | 73.15 ± 3.44 | 27.08 ± 2.80 | 74.04 ± 3.31 |
| | MT-Hyper | 33.18 ± 0.91 | 71.84 ± 1.95 | 33.28 ± 1.03 | 72.23 ± 2.18 |
| Average | MT-LSTM | 30.12 ± 3.27 | 64.18 ± 1.29 | 30.53 ± 3.36 | 64.97 ± 1.34 |
| | MT-BERT | 31.18 ± 4.36 | 64.49 ± 2.22 | 31.69 ± 4.48 | 65.93 ± 1.86 |
| | MT-RAM | 32.09 ± 2.46 | 66.33 ± 2.33 | 32.66 ± 2.34 | 68.22 ± 2.11 |
| | MT-Hyper | 36.37 ± 1.51 | 67.47 ± 1.23 | 37.14 ± 1.54 | 69.28 ± 1.41 |
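As a small worked example of the average scores discussed above, the sketch below recomputes three of the MT-Hyper entries in the Average rows of Table 2. It assumes that these rows are unweighted means over the three tasks; the reported numbers are consistent with this reading up to rounding, but the averaging rule itself is an assumption, and the variable names are illustrative.

```python
def mean(values):
    """Unweighted arithmetic mean."""
    return sum(values) / len(values)


# MT-Hyper scores from Table 2 (early notes), in the order
# readmission, diagnosis, length-of-stay.
mt_hyper = {
    "progressive_f1": [56.38, 19.56, 33.18],
    "ultimate_f1": [57.67, 20.47, 33.28],
    "ultimate_auc": [62.41, 73.21, 72.23],
}

# Assumed rule: the Average row is the mean over the three tasks.
averages = {name: round(mean(scores), 2) for name, scores in mt_hyper.items()}
```

Under this assumption the recomputed values match the reported 36.37, 37.14, and 69.28. Other entries may differ in the last digit because the table values are themselves rounded.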
Prediction at Discharge Next, we consider prediction at discharge, keeping the setup otherwise the same as for early prediction. Discharge notes provide relatively complete information about patient hospital visits. Table 3 shows the prediction results. Compared with the baselines, our method has comparable performance on readmission prediction, superior performance on diagnosis prediction, inferior performance on length-of-stay prediction, and relatively better overall performance as measured by the average score. In addition, our method has smaller standard deviations in most cases when making predictions at discharge.
4.3 Zero-shot Diagnosis Prediction
We study the more challenging scenario of zero-shot diagnosis prediction. In the extracted datasets, 2,209 of the 2,624 discharge diagnoses appear in the training set, i.e., 415 diagnosis types are never seen during training. We adopt this sampled setting to test the performance of zero-shot diagnosis prediction. Table 4 shows the AUC-ROC scores of progressive and ultimate diagnosis prediction on the zero-shot diagnosis categories. The baseline models make badly incorrect predictions, with AUC-ROC scores far below the chance level of 50. These baselines cannot learn meta-knowledge of diagnoses and therefore overfit to the diagnoses seen in the training set. In contrast, our proposed MT-Hyper captures the task context and semantic diagnosis information. Our method therefore generalizes better to unseen diagnoses and achieves satisfactory performance under the challenging zero-shot setting.
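To make the zero-shot setting and the below-chance AUC-ROC scores concrete, the following sketch identifies unseen labels as a set difference and computes AUC-ROC for one label from pairwise comparisons. The function names and toy data are hypothetical. The scores here lie in [0, 1], whereas the tables report percentages.

```python
def unseen_labels(train_labels, eval_labels):
    """Labels that occur in the evaluation data but never in training."""
    return set(eval_labels) - set(train_labels)


def auc_roc(scores, labels):
    """AUC-ROC as the probability that a random positive example is
    scored above a random negative example (ties count as 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC-ROC needs at least one positive and one negative")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy label sets (hypothetical diagnosis identifiers).
zero_shot = unseen_labels(["d1", "d2", "d3"], ["d2", "d3", "d4", "d5"])

# Toy scores for a single unseen diagnosis.
labels = [1, 0, 0, 1, 0]
good_scores = [0.9, 0.2, 0.1, 0.8, 0.3]        # positives ranked on top
inverted_scores = [1 - s for s in good_scores]  # positives ranked at the bottom
constant_scores = [0.5] * len(labels)           # uninformative scores

auc_good = auc_roc(good_scores, labels)
auc_inverted = auc_roc(inverted_scores, labels)
auc_constant = auc_roc(constant_scores, labels)
```

An AUC-ROC well below chance, as observed for the baselines, corresponds to the inverted case: true instances of an unseen diagnosis systematically receive lower scores than the other instances.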
4.4 Discussion of Task Conditioning
We investigate the effect of task conditioning and whether hypernetwork-guided learning can facilitate multitask learning and capture the task context information.
Parameterized Objective Weights Conditioned on Task Information Hypernetwork-guided learning automatically weighs the terms of the multitask learning objective, which also lets us monitor which task branch the learning algorithm prioritizes. We plot the curves of the weight coefficients α_t of the different tasks in the multitask learning objective in Fig. 3. The curves show that the learning algorithm gives readmission and length-of-stay prediction a higher weight than diagnosis prediction. The weight coefficients of the three tasks are updated dynamically during model optimization, which offers a perspective for explaining the model's learning process. We also conduct an ablation study on the effect of the weight hypernetwork in the appendix.
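The idea of task-conditioned objective weights can be sketched as follows. This is a minimal illustration and not the architecture used in our experiments: it assumes that a single linear map turns each task embedding into a scalar and that a softmax normalizes the scalars into coefficients α_t. The embeddings, the parameters `w` and `b`, and the loss values are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def weight_hypernet(task_embeddings, w, b):
    """Map each task embedding to a scalar score and normalize across tasks.

    The linear map and the softmax normalization are assumptions of this sketch.
    """
    scores = [sum(wi * ei for wi, ei in zip(w, emb)) + b for emb in task_embeddings]
    return softmax(scores)


def joint_loss(task_losses, alphas):
    """Weighted sum of per-task losses: sum over t of alpha_t * L_t."""
    return sum(a * loss for a, loss in zip(alphas, task_losses))


# Hypothetical 2-dimensional task embeddings, in the order
# readmission, diagnosis, length-of-stay.
task_embeddings = [[0.5, 0.1], [-0.4, 0.2], [0.3, 0.3]]
w, b = [1.0, 0.5], 0.0  # hypothetical hypernetwork parameters

alphas = weight_hypernet(task_embeddings, w, b)
task_losses = [0.7, 4.2, 1.9]  # hypothetical per-task losses
total_loss = joint_loss(task_losses, alphas)
```

Because `w` and `b` would be trained together with the rest of the model, the coefficients change over the course of optimization, which is what makes it possible to plot them as curves.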
Table 3: Results of patient outcome prediction at discharge with average score ± standard deviation reported. Bold values indicate the cases in which our model obtains the best performance.

| Task | Method | Progressive F1 | Progressive AUC-ROC | Ultimate F1 | Ultimate AUC-ROC |
|---|---|---|---|---|---|
| Readmission | MT-LSTM | 60.54 ± 1.38 | 60.39 ± 1.22 | 63.17 ± 1.64 | 68.46 ± 1.93 |
| | MT-BERT | 64.86 ± 1.08 | 64.78 ± 1.15 | 68.66 ± 3.17 | 76.33 ± 2.88 |
| | MT-RAM | 65.92 ± 1.86 | 65.88 ± 1.34 | 70.41 ± 2.67 | 77.49 ± 2.30 |
| | MT-Hyper | **66.41 ± 0.77** | **66.08 ± 0.74** | 69.66 ± 1.52 | 76.93 ± 1.11 |
| Diagnosis | MT-LSTM | 9.08 ± 1.60 | 66.51 ± 0.98 | 11.04 ± 2.74 | 64.85 ± 1.01 |
| | MT-BERT | 9.75 ± 1.82 | 69.19 ± 6.22 | 10.80 ± 2.83 | 68.04 ± 7.98 |
| | MT-RAM | 10.26 ± 1.99 | 67.38 ± 0.79 | 11.39 ± 2.99 | 65.70 ± 0.92 |
| | MT-Hyper | **10.28 ± 0.25** | **76.93 ± 0.73** | **12.99 ± 1.57** | **79.07 ± 1.07** |
| LOS | MT-LSTM | 31.09 ± 0.93 | 76.77 ± 0.67 | 32.57 ± 1.29 | 79.33 ± 1.19 |
| | MT-BERT | 31.37 ± 1.23 | 78.27 ± 0.90 | 34.91 ± 0.31 | 82.13 ± 0.26 |
| | MT-RAM | 31.29 ± 1.91 | 76.14 ± 4.48 | 32.29 ± 2.95 | 77.69 ± 5.34 |
| | MT-Hyper | 31.33 ± 0.38 | 73.76 ± 1.31 | 31.08 ± 0.21 | 73.01 ± 1.28 |
| Average | MT-LSTM | 33.57 ± 1.30 | 67.89 ± 0.96 | 35.59 ± 1.89 | 70.88 ± 1.38 |
| | MT-BERT | 35.33 ± 1.37 | 70.75 ± 2.76 | 38.13 ± 2.10 | 75.50 ± 3.71 |
| | MT-RAM | 35.82 ± 1.92 | 69.80 ± 2.20 | 38.03 ± 2.87 | 73.63 ± 2.86 |
| | MT-Hyper | **36.01 ± 0.47** | **72.26 ± 0.93** | 37.91 ± 1.10 | **76.34 ± 1.15** |
Table 4: Results (AUC-ROC) of zero-shot diagnosis prediction on unseen diagnoses with average score ± standard deviation reported

| Dataset | Method | Progressive | Ultimate |
|---|---|---|---|
| Discharge | MT-LSTM | 11.00 ± 2.31 | 9.74 ± 1.55 |
| | MT-BERT | 10.47 ± 0.76 | 9.00 ± 0.36 |
| | MT-RAM | 11.40 ± 1.44 | 10.37 ± 1.25 |
| | MT-Hyper | 64.06 ± 2.02 | 68.33 ± 2.76 |
| Early | MT-LSTM | 8.65 ± 1.19 | 8.47 ± 1.00 |
| | MT-BERT | 8.24 ± 0.79 | 8.20 ± 0.74 |
| | MT-RAM | 14.25 ± 3.11 | 14.12 ± 3.12 |
| | MT-Hyper | 62.04 ± 1.13 | 63.76 ± 1.03 |
Figure 3: Curves of parameterized weight coefficients in the multitask learning objective.
Semantic Relation of Task Embeddings We further study the semantic relation of task embeddings to verify whether task conditioning uses task information. To this end, we include admission type classification as an auxiliary task and train four tasks jointly. We filter out "newborn" and "death" and define admission type classification as a three-way classification task with the admission types "emergency", "elective", and "urgent". The task categorizes the type of admission for each clinical note. Although it may have no practical use in clinical support, it is a useful auxiliary task for clinical note understanding, so we use it to help study the semantic relation of task embeddings.
We obtain the task embeddings from the trained model and aggregate the embeddings of each task into a single vector via mean pooling. To visualize the task vectors in a low-dimensional space, we apply principal component analysis (PCA) to reduce their dimension. The dimension-reduced task vectors are shown in Fig. 4. Readmission and admission type are closer to each other, as both are related to hospital admission management. The task vectors for diagnosis and length-of-stay are further apart from each other and also from the readmission and admission type vectors. Task embeddings can capture this pattern, although pooling and dimension reduction may lose some information. Task conditioning retains the task information to some extent, which is helpful for the generation of task-specific parameters.
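The pooling and projection steps can be sketched with NumPy. This is an illustration under stated assumptions: the embeddings below are random stand-ins rather than embeddings from the trained model, the embedding dimension of 8 and the number of diagnosis labels are arbitrary, and PCA is computed through a singular value decomposition of the centered task vectors.

```python
import numpy as np


def mean_pool(embeddings):
    """Aggregate a (num_items, dim) array of embeddings into one vector."""
    return np.asarray(embeddings, dtype=float).mean(axis=0)


def pca_2d(vectors):
    """Project row vectors onto their first two principal components."""
    x = np.asarray(vectors, dtype=float)
    centered = x - x.mean(axis=0)
    # Rows of vt are the principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


# Random stand-ins for the embeddings of each task (hypothetical shapes).
rng = np.random.default_rng(0)
task_embeddings = {
    "readmission": rng.normal(size=(2, 8)),
    "admission_type": rng.normal(size=(3, 8)),
    "diagnosis": rng.normal(size=(50, 8)),
    "length_of_stay": rng.normal(size=(10, 8)),
}

# One pooled vector per task, then a 2-D projection for plotting.
task_vectors = np.stack([mean_pool(e) for e in task_embeddings.values()])
coords = pca_2d(task_vectors)  # shape: (number of tasks, 2)
```

With only four task vectors, the projection is determined by very few points, so distances in the resulting plot should be read as a qualitative indication of similarity.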
Figure 4: Visualization of task embeddings with dimension reduced by PCA. Axes: Component 1 and Component 2; plotted points: Readm., Adm. T, Diag., and LoS.
5 Conclusion
Patient hospitalization is a complex process that requires a learning algorithm to model multiple risk indicators. Multitask learning can jointly learn from patient records to empower clinical decision-making. To address the challenges of inter-task interference and poor generalizability to unseen labels in recent multitask learning frameworks for healthcare, we propose a hypernetwork-guided multitask learning method that learns from the task context and generates task-specific parameters for effective multitask prediction of patient outcomes. The proposed method incorporates contextualized language representations to encode clinical notes and captures the task context via the semantic embeddings of tasks. The hypernetwork-guided multitask learning framework enables task-conditioned learning and balances the joint learning objective across different tasks. Experimental studies on the real-world MIMIC-III dataset show that our proposed method achieves better performance on early prediction and more robust performance on zero-shot diagnosis prediction.
References
Batuhan Bardak and Mehmet Tan. 2021. Improving clinical outcome predictions using convolution over medical entities with multimodal learning. Artificial Intelligence in Medicine, 117:102112.
Rich Caruana. 1997. Multitask learning. Machine Learning, 28(1):41–75.
Zhengping Che, David Kale, Wenzhe Li, Mohammad Taha Bahadori, and Yan Liu. 2015. Deep computational phenotyping. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 507–516.
Edward Choi, Mohammad Taha Bahadori, Andy Schuetz, Walter F Stewart, and Jimeng Sun. 2016. Doctor AI: Predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference, pages 301–318. PMLR.
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In NAACL-HLT.
Carol Friedman, Lyudmila Shagina, Yves Lussier, and George Hripcsak. 2004. Automated encoding of clinical documents based on natural language processing. Journal of the American Medical Informatics Association, 11(5):392–402.
Marzyeh Ghassemi, Tristan Naumann, Finale Doshi-Velez, Nicole Brimmer, Rohit Joshi, Anna Rumshisky, and Peter Szolovits. 2014. Unfolding physiological state: Mortality modelling in intensive care units. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 75–84.
David Ha, Andrew Dai, and Quoc V Le. 2017. Hypernetworks. In Proceedings of ICLR.
Hrayr Harutyunyan, Hrant Khachatrian, David C Kale, Greg Ver Steeg, and Aram Galstyan. 2019. Multitask learning and benchmarking with clinical time series data. Scientific Data, 6(1):1–18.
Kexin Huang, Jaan Altosaar, and Rajesh Ranganath. 2019. ClinicalBERT: Modeling clinical notes and predicting hospital readmission. arXiv preprint arXiv:1904.05342.
Shaoxiong Ji, Erik Cambria, and Pekka Marttinen. 2020. Dilated convolutional attention network for medical code assignment from clinical text. In Proceedings of the 3rd Clinical Natural Language Processing Workshop at EMNLP, pages 73–78.
Shaoxiong Ji, Shirui Pan, and Pekka Marttinen. 2021. Medical code assignment with gated convolution and note-code interaction. In Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021, pages 1034–1043. Association for Computational Linguistics.
Diederik P Kingma and Jimmy Ba. 2014. Adam: A Method for Stochastic Optimization. arXiv preprint arXiv:1412.6980.
Irene Li, Jessica Pan, Jeremy Goldwasser, Neha Verma, Wai Pan Wong, Muhammed Yavuz Nuzumlalı, Benjamin Rosand, Yixin Li, Matthew Zhang, David Chang, et al. 2021. Neural natural language processing for unstructured data in electronic health records: a review. arXiv preprint arXiv:2107.02975.
Rabeeh Karimi Mahabadi, Sebastian Ruder, Mostafa Dehghani, and James Henderson. 2021. Parameter-efficient multi-task fine-tuning for transformers via shared hypernetworks. In Proceedings of ACL-IJCNLP.
Diwakar Mahajan, Ananya Poddar, Jennifer J Liang, Yen-Ting Lin, John M Prager, Parthasarathy Suryanarayanan, Preethi Raghavan, and Ching-Huei Tsou. 2020. Identification of semantically similar sentences in clinical notes: Iterative intermediate training using multi-task learning. JMIR Medical Informatics, 8(11):e22508.
Lu Men, Noyan Ilk, Xinlin Tang, and Yuan Liu. 2021. Multi-disease prediction using lstm recurrent neural networks. Expert Systems with Applications, 177:114905.
Riccardo Miotto, Li Li, Brian A Kidd, and Joel T Dudley. 2016. Deep Patient: An Unsupervised Representation to Predict the Future of Patients from the Electronic Health Records. Scientific Reports, 6(1):1–10.
Andriy Mulyar and Bridget T McInnes. 2020. MT-clinical BERT: scaling clinical information extraction with multitask learning. arXiv preprint arXiv:2004.10220.
Andriy Mulyar, Elliot Schumacher, Masoud Rouhizadeh, and Mark Dredze. 2019. Phenotyping of Clinical Notes with Improved Document Classification Models Using Contextualized Neural Language Models. arXiv preprint arXiv:1910.13664.
Amin Naemi, Marjan Mansourvar, Thomas Schmidt, and Uffe Kock Wiil. 2020. Prediction of patients severity at emergency department using NARX and ensemble learning. In 2020 IEEE International Conference on Bioinformatics and Biomedicine (BIBM), pages 2793–2799. IEEE.
Vinod Nair and Geoffrey E Hinton. 2010. Rectified linear units improve restricted boltzmann machines. In Proceedings of ICML.
Bhargava K Reddy and Dursun Delen. 2018. Predicting hospital readmission for lupus patients: An RNN-LSTM-based deep-learning methodology. Computers in Biology and Medicine, 101:199–209.
Benjamin Shickel, Patrick James Tighe, Azra Bihorac, and Parisa Rashidi. 2017. Deep EHR: a survey of recent advances in deep learning techniques for electronic health record (EHR) analysis. IEEE Journal of Biomedical and Health Informatics, 22(5):1589–1604.
Wei Sun, Shaoxiong Ji, Erik Cambria, and Pekka Marttinen. 2021. Multitask recalibrated aggregation network for medical code prediction. In Proceedings of ECML-PKDD.
Betty van Aken, Jens-Michalis Papaioannou, Manuel Mayrdorfer, Klemens Budde, Felix Gers, and Alexander Loeser. 2021. Clinical outcome prediction from admission notes using self-supervised knowledge integration. In Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume, pages 881–893.
Yan Yan, Glenn Fung, Jennifer G Dy, and Romer Rosales. 2010. Medical coding classification by leveraging inter-code relationships. In Proceedings of the 16th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 193–202.
Yu Zhang and Qiang Yang. 2021. A survey on multi-task learning. IEEE Transactions on Knowledge and Data Engineering.
Yuan Zhang, Xi Yang, Julie Ivy, and Min Chi. 2019. ATTAIN: Attention-based Time-Aware LSTM Networks for Disease Progression Modeling. In Proceedings of the 28th International Joint Conference on Artificial Intelligence, pages 4369–4375.
Damien Zufferey, Thomas Hofer, Jean Hennebert, Michael Schumacher, Rolf Ingold, and Stefano Bromuri. 2015. Performance comparison of multi-label learning algorithms on clinical data for chronic diseases. Computers in Biology and Medicine, 65:34–43.
A Additional Illustrations of Task Label Embeddings
We further illustrate more combinations of task label embeddings. First, we plot the task label embeddings of readmission, admission type, and length-of-stay together with the first 50 diagnoses. Figure 5 shows that the label embeddings of different tasks are relatively far apart from each other, while the two tasks that are similar to each other, readmission prediction and admission type prediction, appear close to each other. We further zoom in on the embeddings of the length-of-stay task labels in Figure 6. Here, however, we cannot find any clearly interpretable pattern. We hypothesize that length-of-stay prediction would benefit from clinical document understanding and numerical reasoning over the ordinal categories, and that the hypernetwork-based architecture with semantic task embeddings does not effectively capture this ordinal relation between different lengths of stay. This may partially explain why our method does not provide an advantage over other methods on length-of-stay prediction.
Figure 5: Visualization of task label embeddings with dimension reduced by PCA. Axes: Component 1 and Component 2; plotted groups: Readm., Adm. T, Diag., and LoS.
Figure 6: Visualization of LOS label embeddings with dimension reduced by PCA. Axes: Component 1 and Component 2; plotted points: LoS-0 to LoS-9.
B Ablation Study
Our proposed method consists of two hypernetworks, i.e., an adapter hypernetwork for task-specific parameter generation and a weight hypernetwork (WH) for task weight coefficient generation. We conduct an ablation study on the use of the weight hypernetwork. Table 5 and Table 6 show the results of patient outcome prediction from early notes and prediction at discharge, respectively. The weight hypernetwork dynamically adjusts the weight coefficients of the joint loss and contributes to performance improvements in most cases.
C Results of Multitask Learning with Four Tasks Jointly Trained
We compare the results of multitask learning with four tasks jointly trained, i.e., readmission prediction, admission type classification, diagnosis prediction, and length-of-stay prediction. As admission type classification has no clinical use, we include this comparison in the appendix for completeness. Table 8 and Table 7 show the results of patient outcome prediction from early notes and prediction at discharge, respectively. Our method performs better than the baselines in most cases on the four tasks and achieves the best overall prediction performance as measured by the average score.
Table 5: Ablation study on patient outcome prediction from early notes with three tasks jointly trained

| Task | Method | Progressive F1 | Progressive AUC-ROC | Ultimate F1 | Ultimate AUC-ROC |
|---|---|---|---|---|---|
| Readmission | without WH | 56.74 ± 1.26 | 57.28 ± 1.17 | 57.94 ± 1.53 | 61.56 ± 1.79 |
| | with WH | 56.38 ± 2.30 | 57.57 ± 1.27 | 57.67 ± 2.21 | 62.41 ± 1.61 |
| Diagnosis | without WH | 18.90 ± 2.06 | 69.26 ± 0.33 | 19.91 ± 2.29 | 69.58 ± 0.47 |
| | with WH | 19.56 ± 1.33 | 72.98 ± 0.45 | 20.47 ± 1.38 | 73.21 ± 0.43 |
| LOS | without WH | 33.59 ± 0.36 | 71.14 ± 1.25 | 33.85 ± 0.34 | 71.44 ± 1.32 |
| | with WH | 33.18 ± 0.91 | 71.84 ± 1.95 | 33.28 ± 1.03 | 72.23 ± 2.18 |
| Average | without WH | 36.41 ± 1.23 | 65.89 ± 0.92 | 37.23 ± 1.39 | 67.53 ± 1.19 |
| | with WH | 36.37 ± 1.51 | 67.46 ± 1.22 | 37.14 ± 1.54 | 69.28 ± 1.41 |
Table 6: Ablation study on patient outcome prediction at discharge with three tasks jointly trained

| Task | Method | Progressive F1 | Progressive AUC-ROC | Ultimate F1 | Ultimate AUC-ROC |
|---|---|---|---|---|---|
| Readmission | without WH | 65.42 ± 0.88 | 65.21 ± 0.73 | 68.11 ± 1.46 | 75.99 ± 0.37 |
| | with WH | 66.41 ± 0.77 | 66.08 ± 0.74 | 69.66 ± 1.52 | 76.93 ± 1.11 |
| Diagnosis | without WH | 10.05 ± 0.87 | 75.13 ± 1.12 | 12.68 ± 1.39 | 78.54 ± 1.56 |
| | with WH | 10.28 ± 0.25 | 76.93 ± 0.73 | 12.99 ± 1.57 | 79.07 ± 1.07 |
| LOS | without WH | 31.20 ± 0.31 | 73.47 ± 0.59 | 30.82 ± 0.34 | 72.62 ± 0.79 |
| | with WH | 31.33 ± 0.38 | 73.76 ± 1.31 | 31.08 ± 0.21 | 73.01 ± 1.28 |
| Average | without WH | 35.56 ± 0.69 | 71.27 ± 0.81 | 37.20 ± 1.06 | 75.72 ± 0.91 |
| | with WH | 36.01 ± 0.47 | 72.26 ± 0.93 | 37.91 ± 1.10 | 76.34 ± 1.15 |
Table 7: Results of patient outcome prediction at discharge with four tasks jointly trained

| Task | Method | Progressive F1 | Progressive AUC-ROC | Ultimate F1 | Ultimate AUC-ROC |
|---|---|---|---|---|---|
| Readmission | MT-LSTM | 58.63 ± 3.14 | 58.87 ± 2.14 | 60.81 ± 5.04 | 66.82 ± 3.22 |
| | MT-BERT | 64.64 ± 0.89 | 64.54 ± 0.51 | 68.36 ± 0.65 | 76.32 ± 1.41 |
| | MT-RAM | 63.92 ± 1.53 | 64.35 ± 0.90 | 67.66 ± 1.61 | 76.17 ± 0.93 |
| | MT-Hyper | 64.93 ± 0.92 | 64.76 ± 0.88 | 67.86 ± 2.29 | 76.62 ± 1.26 |
| Admission Type | MT-LSTM | 87.66 ± 0.33 | 97.15 ± 0.31 | 84.42 ± 0.96 | 96.35 ± 0.54 |
| | MT-BERT | 87.70 ± 0.24 | 97.28 ± 0.14 | 84.29 ± 0.52 | 96.43 ± 0.17 |
| | MT-RAM | 88.58 ± 0.76 | 97.26 ± 0.14 | 86.15 ± 1.05 | 96.78 ± 0.25 |
| | MT-Hyper | 88.70 ± 0.53 | 96.36 ± 0.91 | 86.69 ± 1.54 | 96.15 ± 1.45 |
| Diagnosis | MT-LSTM | 4.57 ± 0.97 | 64.94 ± 0.28 | 4.00 ± 1.28 | 62.87 ± 0.40 |
| | MT-BERT | 9.73 ± 0.45 | 65.99 ± 0.46 | 10.27 ± 0.66 | 64.07 ± 0.41 |
| | MT-RAM | 7.72 ± 1.55 | 66.22 ± 0.43 | 7.76 ± 2.05 | 64.17 ± 0.44 |
| | MT-Hyper | 10.18 ± 0.90 | 76.00 ± 1.70 | 13.10 ± 1.61 | 68.17 ± 2.03 |
| LOS | MT-LSTM | 31.31 ± 0.72 | 75.73 ± 1.50 | 31.67 ± 0.84 | 75.87 ± 2.97 |
| | MT-BERT | 32.11 ± 0.84 | 78.59 ± 0.54 | 35.10 ± 0.80 | 81.36 ± 0.26 |
| | MT-RAM | 32.12 ± 1.26 | 77.53 ± 1.14 | 34.32 ± 1.76 | 80.27 ± 2.09 |
| | MT-Hyper | 32.37 ± 0.46 | 77.10 ± 1.17 | 32.53 ± 0.54 | 77.86 ± 1.95 |
| Average | MT-LSTM | 45.54 ± 1.29 | 74.17 ± 1.06 | 45.23 ± 2.03 | 75.48 ± 1.78 |
| | MT-BERT | 48.54 ± 0.61 | 76.60 ± 0.42 | 49.51 ± 0.66 | 79.54 ± 0.56 |
| | MT-RAM | 48.08 ± 1.28 | 76.34 ± 0.65 | 48.97 ± 1.62 | 79.35 ± 0.93 |
| | MT-Hyper | 49.05 ± 0.70 | 78.55 ± 1.17 | 50.05 ± 1.49 | 79.70 ± 1.67 |
Table 8: Results of patient outcome prediction from early notes with four tasks jointly trained

| Task | Method | Progressive F1 | Progressive AUC-ROC | Ultimate F1 | Ultimate AUC-ROC |
|---|---|---|---|---|---|
| Readmission | MT-LSTM | 50.78 ± 2.96 | 53.48 ± 1.46 | 51.24 ± 2.92 | 57.29 ± 1.54 |
| | MT-BERT | 55.93 ± 1.81 | 56.29 ± 1.74 | 57.05 ± 1.61 | 60.74 ± 1.53 |
| | MT-RAM | 55.78 ± 1.74 | 56.39 ± 1.65 | 57.10 ± 1.55 | 61.36 ± 1.43 |
| | MT-Hyper | 56.32 ± 1.54 | 56.67 ± 1.04 | 57.23 ± 1.47 | 61.07 ± 1.02 |
| Admission Type | MT-LSTM | 85.64 ± 0.98 | 95.04 ± 0.72 | 84.41 ± 1.08 | 94.47 ± 0.83 |
| | MT-BERT | 87.99 ± 1.43 | 95.78 ± 0.56 | 87.12 ± 1.50 | 95.47 ± 0.65 |
| | MT-RAM | 88.32 ± 1.15 | 95.94 ± 0.33 | 87.58 ± 1.08 | 95.57 ± 0.33 |
| | MT-Hyper | 89.31 ± 1.76 | 95.58 ± 1.80 | 88.31 ± 1.85 | 95.16 ± 1.97 |
| Diagnosis | MT-LSTM | 4.10 ± 0.64 | 62.74 ± 0.09 | 3.60 ± 0.65 | 61.95 ± 0.11 |
| | MT-BERT | 11.56 ± 0.54 | 64.01 ± 0.51 | 11.43 ± 0.90 | 63.20 ± 0.57 |
| | MT-RAM | 10.41 ± 1.18 | 65.87 ± 0.55 | 10.46 ± 1.12 | 65.13 ± 0.66 |
| | MT-Hyper | 12.00 ± 1.43 | 68.95 ± 1.07 | 13.16 ± 1.45 | 69.20 ± 1.07 |
| LOS | MT-LSTM | 32.85 ± 1.02 | 75.70 ± 1.49 | 33.54 ± 0.68 | 76.27 ± 1.43 |
| | MT-BERT | 27.62 ± 1.99 | 74.09 ± 1.70 | 28.88 ± 1.99 | 75.32 ± 1.67 |
| | MT-RAM | 27.19 ± 1.72 | 73.59 ± 2.32 | 28.47 ± 1.84 | 74.81 ± 2.19 |
| | MT-Hyper | 31.54 ± 1.42 | 72.35 ± 1.99 | 31.46 ± 1.81 | 72.91 ± 2.30 |
| Average | MT-LSTM | 43.34 ± 1.40 | 71.74 ± 0.94 | 43.20 ± 1.33 | 72.50 ± 0.98 |
| | MT-BERT | 45.77 ± 1.44 | 72.54 ± 1.13 | 46.12 ± 1.50 | 73.68 ± 1.10 |
| | MT-RAM | 45.43 ± 1.44 | 72.95 ± 1.21 | 45.91 ± 1.40 | 74.22 ± 1.15 |
| | MT-Hyper | 47.29 ± 1.54 | 73.39 ± 1.48 | 47.54 ± 1.65 | 74.59 ± 1.59 |