Teaching machines how to do standardized test-like reading questions.
This project and article were jointly done by Yonah Mann, Rohan Menezes and me.
When you think about it, reading comprehension is a small miracle of human thinking. That we can take a piece of text and, with little to no context, gain a deep understanding of its purpose, and even infer facts that do not appear verbatim in it, is both remarkably difficult and impressive. In artificial intelligence and machine learning, researchers have spent decades trying to teach machines to read and comprehend. Over the past few weeks, our team has worked on one small piece of the “reading comprehension puzzle”.
How is this task modeled?
To effectively measure how well a machine can “understand text”, let’s think back to how humans are tested on their comprehension: through a standardized reading exam!
Remember these innocuous-looking but incredibly difficult questions from exams?
Well, OK, these questions look pretty easy. But that’s not always the case! As anyone who has taken a higher-level standardized test knows, reading comprehension questions can get quite difficult. For now, we focus on “fill-in-the-blank” questions: a machine is given a reading passage and a multiple-choice fill-in-the-blank question, and all it has to do is pick the option that best fills the blank. Sounds simple, right?
It turns out this problem is fairly difficult for machines to learn. This “fill-in-the-blank” style of question is called a cloze-style reading comprehension task, and it poses several difficulties:
- First, machines have to learn the structure and meaning of language. Unlike a human, who already knows how words fit together in a sentence and can “understand” the meaning behind it, a machine has to be taught this somehow.
- Second, it may not be obvious where in a passage to look for the answer. A person could spend a long time searching a piece of text for an answer that is right there on the page! For machines it’s even harder: since language is highly flexible, the sequence of words you are looking for might not appear word-for-word in the passage.
With reading comprehension being so difficult, there’s no single obvious approach for machines to take. So, now what do we do?
Let’s add a little…MACHINE LEARNING!
Why don’t we employ the power of machine learning to help us solve this problem?
Machine learning has emerged as an extremely powerful technique for reading text and extracting important concepts from it; it’s been the obsession of many computational linguists for the past few years.
So let’s put this obsession to good use on our problem!
First, a brief detour: we’re going to be using the CNN/Daily Mail dataset in this project. Take a look at an example document/query:
( @entity1 ) it’s the kind of thing you see in movies, like @entity6’s role in “@entity7” or @entity9’s “@entity8.” but, in real life, it’s hard to swallow the idea of a single person being stranded at sea for days, weeks, if not months and somehow living to talk about it. miracles do happen, though, and not just in @entity17…
Query: an @entity156 man says he drifted from @entity103 to @placeholder over a year @entity113
Each document and query has already undergone entity recognition and tokenization, with every named entity replaced by an anonymized marker such as @entity1. The goal is to guess which entity should be substituted for “@placeholder” so that the query makes sense.
Our next step is to formulate this as a machine learning problem on which we can train a model and then use it to predict the right entity. In this spirit, we frame it as a binary classification problem: each document-query pair is transformed into a set of new examples, one per candidate entity. The example for the entity that correctly fills the blank is a positive example, and the examples for all other entities are negative examples, i.e. cases where the model should learn that the entity does not fill the blank.
For every document-query pair, we also compute a small set of features to associate with it, since feeding the entire raw pair into a model like logistic regression is infeasible.
We use a logistic regression model to solve this problem. After training, the model achieves an accuracy of 29%, meaning that the blank was filled in correctly for 29% of the documents. For context, most documents in the dataset contain about 25 entities, so guessing an entity at random would give an accuracy of roughly 1/25 = 4%. So this model performs pretty decently!
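A minimal Python sketch of this reformulation, assuming the candidates are the distinct `@entityN` markers in the document and that the correct answer is one of them. The helper names `candidate_entities` and `make_binary_examples` are our own, hypothetical names:

```python
import re

# Matches anonymized entity markers such as "@entity1" or "@entity156".
ENTITY_PATTERN = re.compile(r"@entity\d+")

def candidate_entities(document: str) -> list:
    """Distinct entity markers in the document, in order of first appearance."""
    return list(dict.fromkeys(ENTITY_PATTERN.findall(document)))

def make_binary_examples(document: str, query: str, answer: str) -> list:
    """Turn one (document, query, answer) triple into labeled binary examples.

    Each example is (document, query, entity, label), where label is 1 if the
    entity is the correct fill for @placeholder and 0 otherwise.
    """
    return [
        (document, query, entity, int(entity == answer))
        for entity in candidate_entities(document)
    ]

doc = "( @entity1 ) it 's like @entity6 's role in @entity7 , but @entity1 disagrees"
examples = make_binary_examples(doc, "an @placeholder role", "@entity6")
print([(entity, label) for _, _, entity, label in examples])
# → [('@entity1', 0), ('@entity6', 1), ('@entity7', 0)]
```

Each original question yields exactly one positive example and many negatives, so the classes are heavily imbalanced; at test time we only need the candidate the classifier ranks highest.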
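The specific features aren’t listed here, so the sketch below is purely illustrative: its four features (how often the entity is mentioned, where it first appears, whether it appears in the query, and word overlap between the query and the text around each mention) are our own hypothetical choices. It also assumes whitespace-tokenized text in which each entity marker is its own token:

```python
def entity_features(document: str, query: str, entity: str, window: int = 5) -> list:
    """Hypothetical hand-crafted features for one (document, query, entity) example."""
    doc_tokens = document.split()
    query_tokens = set(query.split())
    positions = [i for i, tok in enumerate(doc_tokens) if tok == entity]
    n = len(doc_tokens)

    # 1. Mention frequency, normalized by document length.
    frequency = len(positions) / n if n else 0.0
    # 2. Relative position of the first mention (1.0 if never mentioned).
    first_position = positions[0] / n if positions else 1.0
    # 3. Entities already present in the query rarely fill its blank.
    in_query = 1.0 if entity in query_tokens else 0.0
    # 4. Fraction of query words that occur within `window` tokens of a mention.
    context = set()
    for i in positions:
        context.update(doc_tokens[max(0, i - window): i + window + 1])
    context.discard(entity)
    overlap = len(context & query_tokens) / len(query_tokens) if query_tokens else 0.0

    return [frequency, first_position, in_query, overlap]
```

Features like these turn a variable-length passage into a fixed-length vector, which is exactly what a linear classifier needs as input.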
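Here is a rough NumPy sketch of the training and prediction steps, using plain batch gradient descent on the log loss. The learning rate, epoch count and function names are assumptions for illustration, not the settings behind the 29% result. At prediction time every candidate for a document is scored, and the highest-scoring one fills the blank; the last line spells out the random baseline:

```python
import numpy as np

def train_logistic_regression(X, y, lr=0.5, epochs=2000):
    """Fit weights w and bias b by batch gradient descent on the log loss."""
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted P(label = 1)
        grad = p - y                              # gradient of the log loss w.r.t. the logit
        w -= lr * (X.T @ grad) / n
        b -= lr * grad.mean()
    return w, b

def pick_entity(candidates, feature_rows, w, b):
    """Score every candidate for one document and return the most probable one."""
    # The sigmoid is monotonic, so the argmax over logits equals the argmax over probabilities.
    scores = np.asarray(feature_rows) @ w + b
    return candidates[int(np.argmax(scores))]

# Random-guess baseline: about 25 candidate entities per document.
random_baseline = 1 / 25  # 0.04, i.e. roughly 4%
```

Treating prediction as an argmax over candidates (rather than thresholding each probability at 0.5) guarantees exactly one answer per question.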
The logistic regression model performs okay, but if we’re being honest, 29% accuracy isn’t exactly “human-like”. So, how can we make our model learn more effectively?
Here is where deep learning comes into play. When humans read text, they don’t just learn a few heuristics about it and then make guesses based on those heuristics. Rather, they learn to understand the underlying meaning of the text and to make meaningful inferences based on that understanding. That is our goal with this problem as well! Deep learning will provide us with the tools we need to truly teach machines to read.
With a new approach come new goals. From now on, we don’t want to limit ourselves to a binary classification problem; instead, we view the problem more holistically: our model may choose any word in the document as the correct entity to “fill in the blank”. This is more representative of actual comprehension than our previous formulation.
Using these machine learning techniques is great and all, but can we do better? On the one hand, logistic regression is an effective model and a quick way to get a baseline accuracy; on the other, it falls short in several respects. The way it decides whether a word should fill a blank is too rigid: logistic regression can only learn a linear decision boundary over its input features, which isn’t suitable for a wide range of problems.
This is where we turn to deep learning and the power of neural networks. Neural networks, which have become one of the hottest areas of machine learning in recent years, can learn more complex, nonlinear functions than simpler models like logistic regression.
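A classic illustration of this limit is XOR: no single straight line separates its positive and negative points, so no linear classifier gets all four right, whereas one hidden layer with a nonlinearity is enough. The weights below are hand-picked for illustration rather than learned:

```python
import numpy as np

# The four XOR inputs and their targets.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0, 1, 1, 0])

def relu(z):
    return np.maximum(z, 0.0)

# Hidden unit 1 = relu(x1 + x2), hidden unit 2 = relu(x1 + x2 - 1).
W1 = np.array([[1.0, 1.0], [1.0, 1.0]])
b1 = np.array([0.0, -1.0])
# Output = h1 - 2 * h2.
W2 = np.array([1.0, -2.0])

def two_layer_net(X):
    # The ReLU nonlinearity is what lets the network "bend" its decision boundary.
    hidden = relu(X @ W1 + b1)
    return hidden @ W2

print(two_layer_net(X))  # → [0. 1. 1. 0.]
```

Remove the `relu` and the network collapses into a single linear function of the inputs, which can no longer represent XOR.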
In this article, we’ll consider a special kind of neural network called Long Short-Term Memory, or LSTM for short. Here’s what the LSTM looks like:
It seems complicated, but if we break it down piece by piece, we can understand what this network is doing. Imagine reading a sentence: word by word, your mind processes what you’ve read and forms thoughts little by little. An LSTM works the same way: it takes in the words of a sentence one at a time and produces a hidden state after each one. You can think of this hidden state as the “thought” the LSTM passes on to the next time step, when it reads the next word.
In the context of our problem, we feed the passage followed by the question into our LSTM, and then guess which word best fits the blank in the query based on the LSTM’s final output. (If you’re wondering how we obtain that word: we take the LSTM’s output and turn it into a probability for every word that could fill the blank, then choose the word with the highest probability.)
What is special about the LSTM (compared with other networks with a similar recurrent structure) is that it can “remember” information about words across long stretches of text, and can quickly “forget” information when necessary. This lets an LSTM decide what matters about the current word and what it needs to keep from previous words.
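To make the word-by-word picture concrete, here is a minimal NumPy sketch of a single-layer LSTM forward pass using the standard gate equations. The toy sizes, the random weights and the order in which the gates are stacked are arbitrary choices for illustration, not the configuration of our model:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(inputs, W, U, b):
    """Run a single-layer LSTM over a sequence of word vectors.

    inputs: (T, d) array, one embedding per word
    W: (4h, d) input weights, U: (4h, h) recurrent weights, b: (4h,) bias,
    stacked in the order [input gate, forget gate, output gate, candidate].
    Returns the hidden state after every word, shape (T, h).
    """
    hidden_size = U.shape[1]
    h = np.zeros(hidden_size)  # the "thought" passed on to the next word
    c = np.zeros(hidden_size)  # the cell state (longer-term memory)
    states = []
    for x in inputs:           # read the sequence one word at a time
        z = W @ x + U @ h + b
        i, f, o, g = np.split(z, 4)
        i, f, o, g = sigmoid(i), sigmoid(f), sigmoid(o), np.tanh(g)
        c = f * c + i * g      # forget part of the old memory, write new information
        h = o * np.tanh(c)     # expose part of the memory as the new hidden state
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
T, d, hidden = 6, 8, 4  # toy sizes: 6 words, 8-dim embeddings, 4 hidden units
words = rng.normal(size=(T, d))
W = rng.normal(scale=0.1, size=(4 * hidden, d))
U = rng.normal(scale=0.1, size=(4 * hidden, hidden))
b = np.zeros(4 * hidden)
H = lstm_forward(words, W, U, b)
print(H.shape)  # → (6, 4)
```

The last row of `H` is the final hidden state, the summary of everything the network has read.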
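That final step might look like the sketch below. It assumes a hypothetical output matrix that maps the final hidden state to one score per vocabulary word, and masks out words that cannot fill the blank (for example, words that don’t appear in the passage) before applying a softmax:

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max()  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

def predict_blank(final_hidden, output_weights, vocab, candidates):
    """Turn the final hidden state into probabilities and pick the best candidate.

    output_weights: (V, h) matrix giving one score per vocabulary word.
    Only words in `candidates` (assumed non-empty) can receive probability.
    """
    scores = output_weights @ final_hidden
    mask = np.array([word in candidates for word in vocab])
    scores = np.where(mask, scores, -np.inf)  # rule out words that cannot fill the blank
    probs = softmax(scores)
    return vocab[int(np.argmax(probs))], probs

vocab = ["the", "@entity1", "@entity2", "sea"]
W_out = np.array([[5.0, 0.0], [1.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
word, probs = predict_blank(np.array([1.0, 0.0]), W_out, vocab, {"@entity1", "@entity2"})
print(word)  # → @entity2
```

Note that “the” has the highest raw score here, but the mask prevents the model from choosing a word that isn’t a valid answer.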
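The “remembering” and “forgetting” come from the LSTM’s cell-state update, `c_t = f_t * c_{t-1} + i_t * g_t`, where `f_t` is the forget gate. The toy function below (our own illustration, not part of any library) holds the forget gate fixed and writes nothing new, so the forget gate alone decides how much of a stored value survives:

```python
def remaining_memory(forget_gate: float, steps: int, stored: float = 1.0) -> float:
    """Track one cell-state value that is written once and then left alone.

    Each later step applies c_t = f * c_{t-1} + i * g with nothing new written
    (i * g = 0), so only the forget gate f decides how much survives.
    """
    c = stored
    for _ in range(steps):
        c = forget_gate * c + 0.0  # i * g = 0: no new information written
    return c

print(round(remaining_memory(0.99, 50), 3))  # most of the value survives 50 words
print(remaining_memory(0.0, 1))              # → 0.0, wiped out after a single word
```

Because the update is additive and gated, the network can learn to keep `f_t` close to 1 for information worth carrying far, and close to 0 when it should forget.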
You may now ask: “how do we feed words into a network?” We could feed the raw strings in directly, but neural networks compute on numbers, not raw strings of text. Instead, we represent each word with an embedding: a fixed-length vector of numbers that the LSTM can easily run computations on. Ideally, related words should be “closer” to each other in this embedding space.
Fortunately, the great minds at Stanford NLP have already done this work for us: they provide a downloadable set of embeddings called GloVe (Global Vectors for Word Representation) that have proven very effective in natural language processing tasks. We use these in our models and achieve a stunning increase in accuracy: 39%! This improvement over the non-deep baseline demonstrates the power of deep learning, and of the LSTM in particular, for modeling this task.
The logistic regression and BiLSTM loss curves. Note how steeply the BiLSTM loss drops, a sign that the model is learning well!
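A toy sketch of an embedding lookup, and of what “closer” usually means in practice: cosine similarity between vectors. The three-dimensional vectors are made up for illustration; GloVe vectors, for instance, come in 50 to 300 dimensions:

```python
import numpy as np

# Made-up 3-dimensional embeddings, for illustration only.
embeddings = {
    "boat":  np.array([0.9, 0.1, 0.0]),
    "ship":  np.array([0.8, 0.2, 0.1]),
    "piano": np.array([0.0, 0.1, 0.9]),
}

def embed(sentence, table, dim=3):
    """Map each whitespace-separated word to its vector; unknown words get zeros."""
    return np.stack([table.get(w, np.zeros(dim)) for w in sentence.split()])

def cosine(u, v):
    """Cosine similarity: 1.0 for vectors pointing the same way, 0.0 for orthogonal ones."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

print(embed("the boat sank", embeddings).shape)  # → (3, 3)
print(cosine(embeddings["boat"], embeddings["ship"]) > cosine(embeddings["boat"], embeddings["piano"]))  # → True
```

The matrix returned by `embed` (one row per word) is exactly the kind of input the LSTM reads one row at a time.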
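Using GloVe is mostly a matter of parsing its plain-text files, in which each line holds a word followed by its vector components. The sketch below shows one way to do it; the example file name, the random initialization for words GloVe doesn’t cover (such as the `@entityN` markers), and the helper names are our assumptions rather than a description of our exact setup:

```python
import io
import numpy as np

def load_glove(lines, dim):
    """Parse GloVe's text format: a word followed by `dim` space-separated floats."""
    vectors = {}
    for line in lines:
        parts = line.rstrip().split(" ")
        if len(parts) < dim + 1:
            continue  # skip blank or malformed lines
        word = " ".join(parts[:-dim])  # in case a token itself contains spaces
        vectors[word] = np.asarray(parts[-dim:], dtype=np.float32)
    return vectors

def build_embedding_matrix(vocab, vectors, dim, seed=0):
    """Rows follow `vocab`; words missing from GloVe keep small random vectors."""
    rng = np.random.default_rng(seed)
    matrix = rng.normal(scale=0.1, size=(len(vocab), dim)).astype(np.float32)
    for idx, word in enumerate(vocab):
        if word in vectors:
            matrix[idx] = vectors[word]
    return matrix

# In practice `lines` would be an open file such as glove.6B.100d.txt;
# here a two-line in-memory sample with 3-dimensional vectors stands in for it.
sample = io.StringIO("the 0.1 0.2 0.3\nsea -0.5 0.0 0.25\n")
glove = load_glove(sample, dim=3)
E = build_embedding_matrix(["sea", "@entity1"], glove, dim=3)
```

The resulting matrix can initialize the model’s embedding layer, so the LSTM starts from vectors that already encode useful word relationships.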