Chapter 2: Computational modelling of the incremental processing of a sentence
2.3. Modelling prediction with neural networks
2.3.4. Adding recurrence in the network
Although a simple neural network can be trained to generate an accurate prediction from a given linguistic context, it lacks one of the most important aspects of human speech processing. Speech comprehension in humans involves understanding the relationship between words that unfold sequentially over time and interpreting them in the context of each other. The cognitive significance of “time” is not limited to language, as human behaviours are generally co-ordinated in time. Temporal order directly implies causation, and understanding the causal relationship between a series of behaviours over time, in turn, informs one’s metacognitive processes. Therefore, any plausible cognitive model of human behaviour must represent the temporal relations between sequences of events.
An intuitive approach is to express time explicitly as an input in the form of a vector (or matrix). The first element in this vector represents the first temporal event, the second element represents the second temporal event, and so on. However, the duration (or the number) of events often varies in practice, and sequences of different lengths cannot be compared in this framework (i.e. all vectors must be of the same length). Also, consider the following two vectors:
[0 1 1 1 0 0 0 0 0] [0 0 0 1 1 1 0 0 0]
Although these vectors could plausibly reflect the same basic pattern in time (e.g. “He chose the path that ran by the river” vs. “The experienced walker chose the path crossing the river”), they can be judged as highly dissimilar because of the geometric difference in their absolute temporal positions (Elman, 1990). Rather than providing the information about time explicitly as an input in a specific format, Elman (1990) argued for representing time implicitly by its effects on processing. In this perspective, an input is an operator on the mental state such that it alters the state of the system to produce a goal-oriented behaviour. Then, the implicit representation of time can be expressed by adding recurrent links between the states of the system over time (see Figure 2-5).
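The dissimilarity of the two shifted vectors above can be checked numerically. The sketch below is only an illustration: the choice of Euclidean distance and cosine similarity as the "geometric" measures, and the helper names `euclidean`, `cosine` and `shift_right`, are assumptions made here and are not taken from Elman (1990).

```python
import math

# The two time-as-space vectors from the text: the same three-event
# pattern, displaced by two positions.
a = [0, 1, 1, 1, 0, 0, 0, 0, 0]
b = [0, 0, 0, 1, 1, 1, 0, 0, 0]


def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((ui - vi) ** 2 for ui, vi in zip(u, v)))


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(ui * vi for ui, vi in zip(u, v))
    norm_u = math.sqrt(sum(ui * ui for ui in u))
    norm_v = math.sqrt(sum(vi * vi for vi in v))
    return dot / (norm_u * norm_v)


def shift_right(u, k):
    """Shift a vector k positions to the right, padding with zeros."""
    return [0] * k + u[:len(u) - k]


dist = euclidean(a, b)             # four positions differ, so sqrt(4) = 2
sim = cosine(a, b)                 # only one position overlaps: about 1/3
aligned = shift_right(a, 2) == b   # the patterns match once re-aligned

print(dist, sim, aligned)
```

Under these measures the two vectors overlap in only one of their three active positions, even though they become identical once one of them is shifted by two steps.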
Neural networks with these recurrent links are called recurrent neural networks, which are now a common approach to language modelling. Unlike a simple neural network, whose prediction is based purely on the current input, a recurrent network updates its previous internal state based on the current input (see Figure 2-5).
Figure 2-5: Visual illustration of a recurrent neural network. 𝑥, 𝑠 and 𝑜 are input, hidden and output representations respectively. 𝑈 is a weight matrix that projects the input 𝑥 at any given time 𝑡 to the hidden layer 𝑠 at 𝑡. 𝑊 is a weight matrix mapping the previous hidden state 𝑠(𝑡 − 1) to the current state 𝑠(𝑡) (i.e. a recurrent link). 𝑉 is a weight matrix mapping the hidden state 𝑠 to the output 𝑜. Note that the recurrent link 𝑊 is a new feature of this recurrent architecture that does not exist in the simple neural network in Figure 2-3. With this addition, the concept of “time” is now implicitly represented by the architecture.
The forward propagation in this architecture can be expressed by the set of equations below:
$$s(t) = \sigma\big(x(t)U + s(t-1)W + b_1\big) \tag{8}$$
$$o(t) = \varphi\big(s(t)V + b_2\big) \tag{9}$$
where 𝜎 and 𝜑 are arbitrary non-linear activation functions at the hidden (e.g. sigmoid) and output (e.g. softmax) layers respectively, and 𝑏1 and 𝑏2 are the bias terms, which allow the layers to model a data space centred on some point other than the origin. Other notation is as described in Figure 2-5. Without the 𝑠(𝑡 − 1)𝑊 term in (8), the propagation becomes exactly the same as in the simple feedforward neural network described above.
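Equations (8) and (9) can be sketched in a few lines of NumPy. This is a minimal illustration rather than the implementation used in this thesis: the layer sizes, the random initialisation, the zero initial state and the function name `rnn_forward` are all assumptions, with a sigmoid hidden layer and a softmax output layer chosen to match the examples given above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z):
    e = np.exp(z - np.max(z))  # subtract the max for numerical stability
    return e / e.sum()


def rnn_forward(xs, U, W, V, b1, b2):
    """Apply equations (8) and (9) to a sequence of row-vector inputs.

    The initial state s(0) is assumed to be all zeros.
    """
    s = np.zeros(W.shape[0])
    states, outputs = [], []
    for x in xs:
        s = sigmoid(x @ U + s @ W + b1)   # equation (8)
        o = softmax(s @ V + b2)           # equation (9)
        states.append(s)
        outputs.append(o)
    return np.array(states), np.array(outputs)


# Hypothetical sizes: 5 input units, 4 hidden units, 5 output units.
rng = np.random.default_rng(0)
n_in, n_hid, n_out = 5, 4, 5
U = rng.normal(scale=0.5, size=(n_in, n_hid))
W = rng.normal(scale=0.5, size=(n_hid, n_hid))
V = rng.normal(scale=0.5, size=(n_hid, n_out))
b1, b2 = np.zeros(n_hid), np.zeros(n_out)

xs = np.eye(n_in)[[0, 2, 1]]   # three one-hot "words"

states, outputs = rnn_forward(xs, U, W, V, b1, b2)

# Zeroing W removes the s(t-1)W term, leaving a plain feedforward pass.
ff_states, ff_outputs = rnn_forward(xs, U, np.zeros_like(W), V, b1, b2)
```

With 𝑊 set to zero, every hidden state depends only on the current input, which is the feedforward case mentioned above. With a non-zero 𝑊, the states from the second time step onwards also depend on what came before.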
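Training an RNN works similarly to training a simple neural network, except that the recurrent link 𝑊 is also trained by back-propagating the error gradient through time using the chain rule, as described in Appendix 7:
$$\frac{\partial}{\partial W_{q_1,q_2}} H(Y(t), O(t)) = \sum_{j=1}^{J} \frac{\partial H(Y(t), O(t))}{\partial s_1(t)_j} \, \frac{\partial s_1(t)_j}{\partial s(t)_{q_2}} \, \frac{\partial s(t)_{q_2}}{\partial s_2(t)_{q_2}} \, \frac{\partial s_2(t)_{q_2}}{\partial W_{q_1,q_2}} \tag{10}$$
where 𝑠1(𝑡) = 𝑠(𝑡)𝑉 + 𝑏2 and 𝑠2(𝑡) = 𝑥(𝑡)𝑈 + 𝑠(𝑡 − 1)𝑊 + 𝑏1. Then,
$$\frac{\partial}{\partial W_{q_1,q_2}} H(Y(t), O(t)) = \sum_{j=1}^{J} V(t)_{q_2,j} \big(o(t)_j - y(t)_j\big) \, s(t)_{q_2} \big(1 - s(t)_{q_2}\big) \, s(t-1)_{q_1} \tag{11}$$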
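The one-step gradient in (11) can be sanity-checked against a finite-difference estimate. The sketch below assumes that 𝐻 is the cross-entropy loss, that the output layer is a softmax and that the hidden layer is a sigmoid, which is what the (𝑜 − 𝑦) and 𝑠(1 − 𝑠) factors in (11) suggest. The sizes, random values and function names are hypothetical, and 𝑠(𝑡 − 1) is treated as a fixed input, as in (11).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()


def loss(W, x, s_prev, y, U, V, b1, b2):
    """Cross-entropy H(Y(t), O(t)) for one step, with s(t-1) held fixed."""
    s = sigmoid(x @ U + s_prev @ W + b1)
    o = softmax(s @ V + b2)
    return -np.sum(y * np.log(o))


def grad_W_eq11(W, x, s_prev, y, U, V, b1, b2):
    """Equation (11), evaluated for every (q1, q2) at once."""
    s = sigmoid(x @ U + s_prev @ W + b1)
    o = softmax(s @ V + b2)
    # sum_j V[q2, j] * (o_j - y_j), times s_q2 * (1 - s_q2): indexed by q2
    delta = (V @ (o - y)) * s * (1.0 - s)
    # multiply by s(t-1)_q1 to obtain the [q1, q2] entry
    return np.outer(s_prev, delta)


def numerical_grad(f, W, eps=1e-6):
    """Central finite-difference estimate of df/dW."""
    g = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        g[idx] = (f(Wp) - f(Wm)) / (2 * eps)
    return g


# Hypothetical sizes and values, for illustration only.
rng = np.random.default_rng(1)
n_in, n_hid, n_out = 3, 4, 3
U = rng.normal(scale=0.5, size=(n_in, n_hid))
W = rng.normal(scale=0.5, size=(n_hid, n_hid))
V = rng.normal(scale=0.5, size=(n_hid, n_out))
b1, b2 = rng.normal(size=n_hid), rng.normal(size=n_out)
x = np.array([0.0, 1.0, 0.0])              # one-hot input x(t)
y = np.array([1.0, 0.0, 0.0])              # one-hot target y(t)
s_prev = sigmoid(rng.normal(size=n_hid))   # stand-in for s(t-1)

analytic = grad_W_eq11(W, x, s_prev, y, U, V, b1, b2)
numeric = numerical_grad(lambda M: loss(M, x, s_prev, y, U, V, b1, b2), W)
max_abs_diff = np.max(np.abs(analytic - numeric))
print(max_abs_diff)
```

If the two gradients agree to within finite-difference error, the closed form in (11) is consistent with the forward equations (8) and (9) under the stated assumptions.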
As formulated, the gradient in (11) takes into account only the immediately preceding state. However, for a simple sentence such as “The business owner declared bankruptcy”, the model will predict “bankruptcy” much better when it knows the subject “The business owner” as well as the verb “declared”. To incorporate the contributions from every hidden state over time, it is necessary to sum the contributions of each time step to the gradient. Following on from (10), this can be formulated as below:
$$\frac{\partial}{\partial W_{q_1,q_2}} H(Y(t), O(t)) = \sum_{j=1}^{J} \frac{\partial H(Y(t), O(t))}{\partial s_1(t)_j} \, \frac{\partial s_1(t)_j}{\partial s(t)_{q_2}} \sum_{\tau=0}^{t-1} \frac{\partial s(t)_{q_2}}{\partial s_2(t-\tau)_{q_2}} \, \frac{\partial s_2(t-\tau)_{q_2}}{\partial W_{q_1,q_2}} \tag{12}$$
Note that $\partial s(t)_{q_2} / \partial s_2(t-\tau)_{q_2}$ can be expanded using the chain rule depending on 𝜏. Hence, the error propagation through time can be computed by the extended formulation of (12). This is known as the back-propagation through time (BPTT) algorithm. Because training becomes very difficult as 𝑡 → ∞, a practical implementation of BPTT back-propagates the error gradient only up to a certain number of time steps.
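The summation over 𝜏 and the truncation described above can be sketched as follows. This is a minimal illustration under several assumptions: the loss is the cross-entropy at the final time step only, the hidden and output activations are sigmoid and softmax, the initial state is zero, and the chain-rule expansion sums over all hidden units at each step back in time. The names `bptt_grad_W` and `max_steps` are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()


def forward(xs, W, U, V, b1, b2):
    """Return the states s(0..T), with s(0) = 0, and the final output o(T)."""
    states = [np.zeros(W.shape[0])]
    for x in xs:
        states.append(sigmoid(x @ U + states[-1] @ W + b1))
    return states, softmax(states[-1] @ V + b2)


def final_loss(xs, y, W, U, V, b1, b2):
    _, o = forward(xs, W, U, V, b1, b2)
    return -np.sum(y * np.log(o))


def bptt_grad_W(xs, y, W, U, V, b1, b2, max_steps=None):
    """Gradient of the final-step loss w.r.t. W, summed over tau as in (12).

    max_steps limits how far back the error is propagated (truncated BPTT);
    None means propagating all the way back to the first time step.
    """
    states, o = forward(xs, W, U, V, b1, b2)
    T = len(xs)
    steps = T if max_steps is None else min(max_steps, T)
    s = states[T]
    delta = (V @ (o - y)) * s * (1.0 - s)   # dH/ds2(T)
    grad = np.zeros_like(W)
    for tau in range(steps):
        k = T - tau
        s_prev = states[k - 1]
        grad += np.outer(s_prev, delta)     # contribution of time step t - tau
        delta = (W @ delta) * s_prev * (1.0 - s_prev)   # dH/ds2(k-1)
    return grad


def numerical_grad(f, W, eps=1e-6):
    g = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        g[idx] = (f(Wp) - f(Wm)) / (2 * eps)
    return g


# Hypothetical sizes and data, for illustration only.
rng = np.random.default_rng(2)
n_in, n_hid, n_out, T = 3, 4, 3, 4
U = rng.normal(scale=0.5, size=(n_in, n_hid))
W = rng.normal(scale=0.5, size=(n_hid, n_hid))
V = rng.normal(scale=0.5, size=(n_hid, n_out))
b1, b2 = rng.normal(size=n_hid), rng.normal(size=n_out)
xs = np.eye(n_in)[[0, 1, 2, 1]]
y = np.array([0.0, 0.0, 1.0])

numeric = numerical_grad(lambda M: final_loss(xs, y, M, U, V, b1, b2), W)
full = bptt_grad_W(xs, y, W, U, V, b1, b2)
one_step = bptt_grad_W(xs, y, W, U, V, b1, b2, max_steps=1)

full_err = np.max(np.abs(full - numeric))
trunc_err = np.max(np.abs(one_step - numeric))
print(full_err, trunc_err)
```

Setting `max_steps=1` keeps only the 𝜏 = 0 term, which corresponds to the one-step gradient in (11). The full sum should match the finite-difference gradient closely, while the truncated version trades some accuracy for a cheaper backward pass.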
Not surprisingly, a recurrent neural network (RNN) generally performs better than the simple neural network when the inputs are sequences (like a sentence in language) instead of unrelated individual events. However, an important limitation of the RNN is that it often fails to capture long-distance dependencies (e.g. the dependency relation between “child” and “smiled” in “The child who I thought you liked smiled”). This is mainly because of the
“vanishing gradient” problem during training described above: with the derivative of the sigmoid being well below 1 (i.e. ≤ 0.25), propagating the error through a number of recurrent layers tends to drive the gradient towards zero, given the number of multiplications. One solution I suggested above is to use the ReLU instead of the sigmoid, as its derivative is either 0 or 1, but this function brings other problems such as dead neurons (i.e. a group of neurons can be plunged into a perpetually inactive state). To address the vanishing gradient problem more effectively, a more sophisticated architecture called long short-term memory (LSTM) was introduced (Hochreiter & Schmidhuber, 1997).
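The effect of repeatedly multiplying by the sigmoid derivative can be illustrated with simple arithmetic. The sketch below considers only the activation derivatives and ignores the recurrent weights 𝑊, which also enter each multiplication. It is therefore a simplified upper-bound illustration, and the chosen step counts are arbitrary.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_derivative(z):
    s = sigmoid(z)
    return s * (1.0 - s)


# The sigmoid derivative peaks at z = 0, where it equals 0.25.
max_derivative = sigmoid_derivative(0.0)

# Largest possible product of n sigmoid derivatives (weights ignored).
sigmoid_bounds = {n: max_derivative ** n for n in (1, 5, 10, 20)}

# For comparison: the ReLU derivative is 1 for positive inputs, so the
# same product stays at 1 along a path of active units.
relu_products = {n: 1.0 ** n for n in (1, 5, 10, 20)}

for n in (1, 5, 10, 20):
    print(n, sigmoid_bounds[n], relu_products[n])
```

Even in the best case for the sigmoid, the product of derivatives falls below one in a million after ten time steps, whereas the corresponding product for active ReLU units does not shrink.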