A friend asked for articles to explain neural networks to business people, and I couldn’t find any articles I liked, so I decided to write one. This was originally written in my notebook, but it fits the format here better, so I’ve moved it over.
Neural networks are very hot right now, and people seem to treat them as a magic black box that can solve any problem. Unfortunately, while neural networks are a powerful and useful tool, they are much more limited than their popular representation would suggest, and I’d like to give you a bit of a sense of what they are, why they’re useful, and where that use falls down.
The following are the most important things to understand about neural networks:
- Neural networks are an interesting implementation technique for a well-studied set of problems.
- Neural networks do not allow a machine to “think” in any general sense – they are mostly useful for implementing simple decision procedures and predictive models, but they do not do any general reasoning or planning (they may be used as components in systems that do though).
- Neural networks are only as good as the data you give them.
What do neural networks do?
The general problem that neural networks solve is called machine learning. This is a somewhat misleading term because how machines “learn” in this sense doesn’t map that well to how humans learn. Machine learning is really just a particular type of automated statistics that can be used to make simple predictions and decisions based on similarity to previously observed data. There are many ways to do machine learning, and any given instance of it is a particular algorithm, a precise set of rules that describe how the computer “learns” in this particular problem.
Machine learning always starts with some set of data that we “train” on. In the same way that “learning” is misleading, “training” is also misleading. The analogy is that we are teaching the machine to do something by showing it lots of examples, but the reality is more that a program does statistics to determine patterns in the data that it can use to make predictions later.
Typically we present this data to the machine learning algorithm as a collection of features, which are a way of breaking down the objects we want to work on into a set of numeric or categorical (i.e. one of a small number of possibilities) properties. For example:
- We might represent an image as red-green-blue integer values for each pixel.
- We might represent a person in terms of “age”, “country of birth”, “gender”.
- We might represent a piece of text as a “bag of words”, which counts the number of times each word appears in it.
All decisions made by the machine learning process are based only on these features – two things that are somehow different in a way not captured by their features will still get the same result. Features are not intended to fully represent the things we study, but capture some aspects of it that we think are important and should be sufficient to predict the results we want to know. The process of turning real world things into useful features is something of an art, and is often as or more important than the actual machine learning you do.
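To make the “bag of words” idea concrete, here is a minimal sketch in Python. The vocabulary and the two example emails are made up for illustration, and real systems do a lot more cleaning of the text than this. It shows both how text becomes a list of numbers and how two different texts can end up with identical features.

```python
from collections import Counter

def bag_of_words(text):
    """Count how many times each (lowercased) word appears in the text."""
    return Counter(text.lower().split())

def to_feature_vector(bag, vocabulary):
    """Turn a bag of words into a fixed-length list of counts, one per vocabulary word."""
    return [bag[word] for word in vocabulary]

# Hypothetical vocabulary and emails, chosen only for illustration.
vocabulary = ["free", "money", "meeting", "tomorrow"]
email_a = "Free money free money"
email_b = "money FREE money free"

features_a = to_feature_vector(bag_of_words(email_a), vocabulary)
features_b = to_feature_vector(bag_of_words(email_b), vocabulary)

print(features_a)
# Word order is not captured by these features, so the two emails look identical.
print(features_a == features_b)
```

Anything the features throw away (here, word order and capitalisation) is invisible to whatever learning happens afterwards.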
Once we have turned our data set into features, we train on it. This is the process of taking our data set and turning it into a model that we will use as the basis of our future decisions. Sometimes we do this with the data presented entirely up front, and sometimes the algorithm is designed so that it can learn as you go. The former is much more common, especially with neural networks, but both approaches are possible.
One important difference between machine and human learning is that machine learning tends to need a lot more data than you would expect a human to. Often we train our machine learning algorithms on hundreds of thousands or millions of data points, where you might expect a human to learn after only a few. It takes a lot longer for most machine learning approaches to learn what a picture of a dog looks like than it does a toddler. Why this happens is complicated, and how to do better is an open research problem, but the main reasons are:
- the toddler has a lot of built in machinery about how to do image processing already, while the machine learning system has to learn that from scratch.
- the machine learning system is much more general and can learn things that humans are bad at as easily as it can learn things that humans are good at – e.g. a toddler would struggle to do much with actuarial data that a machine learning algorithm would find easier than recognising dogs.
The need for these large data sets for training is where “big data” comes in. There are a lot of arguments as to what should count as big data, but a good rule of thumb is “if you can fit it on a single computer then it’s not big data”. Given how large computers can affordably get these days, most things that people call big data probably aren’t.
Anyway, given a set of input data, there are roughly three types of commonly studied machine learning that we can do with it:
- Supervised learning takes some input data and some labels or scores for it, and tries to learn how to predict those labels or scores for other similar data. For example, you might want to classify an email as “spam” or “not spam”, or you might want to predict the life expectancy of someone given a set of data about their health.
- Reinforcement learning is for making decisions based on data. Based on the outcome of the decision, you feed back into the process with either a “reward” or a “punishment” (I want to emphasise again that this is a metaphor and the algorithm is not in any meaningful sense thinking or able to experience pleasure or pain), and the algorithm adjusts its decision making process to favour decisions that are similar to previous ones that have worked well and different from previous ones that have worked badly. For example, you might use this for stock picking, and reward the algorithm for picking stocks that do well and punish it for picking stocks that do badly.
- Unsupervised learning builds a model of the “shape” of the data. I won’t talk about this too much, but if you’ve seen those deep dream pictures or “I trained a bot on this corpus of text and this is what I got” articles, that is usually what’s going on here (although the much more common case, particularly with the articles, is that someone has just made it up and no machine learning was involved at all): the system has built a predictive model of the data, which is then used to randomly generate something that matches that model.
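As an illustration of supervised learning, here is a minimal sketch of about the simplest possible approach: predict the label of whichever previously seen example is most similar. This technique is usually called nearest neighbour. The features and training data are invented for the example, and it is not how a real spam filter is built.

```python
def distance(a, b):
    """Squared straight-line distance between two lists of features."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def predict(training_data, features):
    """Return the label of the training example most similar to `features`."""
    _, nearest_label = min(
        training_data, key=lambda example: distance(example[0], features)
    )
    return nearest_label

# Hypothetical features: [count of "free", count of "meeting"], with known labels.
training_data = [
    ([3, 0], "spam"),
    ([2, 0], "spam"),
    ([0, 2], "not spam"),
    ([0, 1], "not spam"),
]

print(predict(training_data, [4, 0]))  # closest to the examples labelled "spam"
print(predict(training_data, [0, 3]))  # closest to the examples labelled "not spam"
```

Note that the program has no idea what spam is. It only knows which labelled examples the new email's numbers are closest to.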
Almost all applications of machine learning you hear about are one of these three things, possibly with some additional logic layered on top. For example AlphaGo, Google’s Go playing AI, is roughly a combination of supervised learning and reinforcement learning with a rules-based system that describes the rules of Go, plus a search procedure that uses the output of the machine learning to choose moves.
Neural networks are used for reinforcement learning less often than for the other two. I’m unclear to what degree this is due to intrinsic limitations and to what degree it’s just not well supported in the tooling, but the way we usually build neural networks requires a large batch training process, so most applications of neural networks will be supervised or unsupervised learning.
There are a very large number of different approaches you can take to these problems, of which neural networks are only one. An easy to understand and classic approach to machine learning is the use of decision trees: simple rules of the form “If (this feature has this range of values) then (do this) else (do this other thing)”. These are very easy to learn – you pick some feature that splits the data set well, break the data set into two parts based on that, and then try again on the smaller parts. If nothing works particularly well, you add a rule that says “Predict (most common outcome)” (I am eliding a lot of details here that don’t matter to you unless you want to actually implement these ideas).
One classic limitation that machine learning struggles with is learning patterns that require understanding complex aspects of the data that are not easily understood from a small number of features. If you imagine image processing, to the computer a small image is just a list of a couple hundred thousand numbers. A human would struggle to get anything useful from that too!
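For the curious, the decision tree procedure described above can be sketched in a few dozen lines of Python. This is a simplified illustration and not a production algorithm: it assumes numeric features, measures how “well” a split works by counting mistakes, and the loan data at the bottom is invented. Real implementations use more careful criteria.

```python
from collections import Counter

def most_common(labels):
    """The outcome that appears most often."""
    return Counter(labels).most_common(1)[0][0]

def mistakes(labels):
    """How many examples we get wrong if we always predict the most common outcome."""
    return len(labels) - Counter(labels).most_common(1)[0][1]

def learn_tree(rows, labels, depth=0, max_depth=3):
    """rows: lists of numeric features; labels: the known outcome for each row."""
    best, best_mistakes = None, mistakes(labels)
    if depth < max_depth:
        for feature in range(len(rows[0])):
            for threshold in sorted({row[feature] for row in rows}):
                low = [i for i, row in enumerate(rows) if row[feature] <= threshold]
                high = [i for i, row in enumerate(rows) if row[feature] > threshold]
                if not low or not high:
                    continue  # this rule does not actually split the data
                split_mistakes = (mistakes([labels[i] for i in low])
                                  + mistakes([labels[i] for i in high]))
                if split_mistakes < best_mistakes:
                    best, best_mistakes = (feature, threshold, low, high), split_mistakes
    if best is None:
        # Nothing works particularly well: predict the most common outcome.
        return most_common(labels)
    feature, threshold, low, high = best
    # Try again on each of the two smaller parts.
    return (feature, threshold,
            learn_tree([rows[i] for i in low], [labels[i] for i in low], depth + 1, max_depth),
            learn_tree([rows[i] for i in high], [labels[i] for i in high], depth + 1, max_depth))

def predict(tree, row):
    """Follow the if/else rules until we reach a prediction."""
    while isinstance(tree, tuple):
        feature, threshold, if_low, if_high = tree
        tree = if_low if row[feature] <= threshold else if_high
    return tree

# Hypothetical loan data: [age, income in thousands] and the decision that was made.
rows = [[25, 20], [30, 60], [45, 80], [50, 30], [35, 90], [22, 25]]
labels = ["decline", "approve", "approve", "decline", "approve", "decline"]

tree = learn_tree(rows, labels)
print(tree)
print(predict(tree, [40, 70]))
```

On this data the learned tree is a single rule on income (feature 1), because that one split separates the two outcomes perfectly.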
Neural networks attempt to overcome this by learning in layers. Each layer consists of something that looks a bit like unsupervised learning and a bit like supervised learning, where each layer “learns” to extract high level features from the one below. In the image case you start with the bottom layer that looks at the raw pixel data, and then it might e.g. try to identify patterns of lines in that data. Each layer is essentially a new set of features that takes the overly detailed previous layer’s features and tries to turn it into a representation that it can more easily work with. This allows us to combine a very simple learning technique (the “neurons” that are really just a very simple bit of machine learning that tries to turn a set of features into a score between zero and one) into a much more complex one that is able to handle high level features of the data. This is where the “deep” in “deep learning” comes from – a deep neural network is one with many layers.
The result of this is that where a “shallower” machine learning system might suffer from a problem of not being able to see the wood for the trees, neural networks can make predictions based on a more structured representation of the data, which makes things obvious that were hard for the algorithm to see from the more fine-grained representation. Sometimes these structured representations will be ones that are obvious to humans (e.g. lines in images), but often they will not be (e.g. they might capture some strategic aspect of the Go board).
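Here is a rough sketch of what “neurons” and “layers” look like as code. It only shows how a network turns inputs into an output once it has been trained. The weights below are made-up numbers, whereas in a real network the training process chooses them. The squashing function used here is one common choice among several.

```python
import math

def neuron(inputs, weights, bias):
    """Weighted sum of the inputs, squashed into a score between zero and one."""
    total = sum(x * w for x, w in zip(inputs, weights)) + bias
    return 1.0 / (1.0 + math.exp(-total))

def layer(inputs, neurons):
    """Run every neuron in the layer on the same inputs.

    The list of scores that comes out is the set of features for the next layer.
    """
    return [neuron(inputs, weights, bias) for weights, bias in neurons]

def network(inputs, layers):
    """Feed the raw features through each layer in turn."""
    features = inputs
    for neurons in layers:
        features = layer(features, neurons)
    return features

# Hypothetical hand-picked (weights, bias) pairs; training would normally find these.
layers = [
    [([1.0, -1.0, 0.5], 0.0), ([-0.5, 1.0, 1.0], 0.1)],  # 3 raw features -> 2 features
    [([2.0, -2.0], 0.0)],                                # 2 features -> 1 final score
]

output = network([0.2, 0.7, 0.1], layers)
print(output)
```

A “deep” network is the same thing with many more layers and many more neurons per layer.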
Why are neural networks cool right now?
Neural networks are not at all new technology, but are recently seeing a revival for a number of reasons:
- We have a lot more data now, which allows us to paper over many limitations by just training the networks more.
- We have much faster computers now, particularly as a lot of neural network training can be made highly parallel (that means that you can break it up into small chunks that can be run at the same time without waiting for each other, then combine the results).
- We have made a number of incremental improvements to the algorithms that allow us to speed up our training and improve the quality of results.
This has allowed us to apply them to problems where it was previously infeasible, which is the main source of all of the current progress in this field.
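The “break it up into chunks, then combine the results” idea from the list above can be sketched like this. It is only an illustration of the structure: it measures how wrong a one-number model is on some invented data, which is one of the building blocks of training, and it uses Python threads purely to show the shape of the computation, not to get a real speed-up.

```python
from concurrent.futures import ThreadPoolExecutor

def chunk_error(args):
    """Total squared prediction error of the model `weight * x` on one chunk of data."""
    weight, chunk = args
    return sum((weight * x - y) ** 2 for x, y in chunk)

def total_error_parallel(weight, data, chunks=4):
    """Split the data into chunks, score each chunk independently, then add up the results."""
    size = (len(data) + chunks - 1) // chunks
    pieces = [data[i:i + size] for i in range(0, len(data), size)]
    with ThreadPoolExecutor() as pool:
        partial_sums = list(pool.map(chunk_error, [(weight, piece) for piece in pieces]))
    return sum(partial_sums)

# Hypothetical data where the right answer is always twice the input.
data = [(x, 2 * x) for x in range(1000)]

# Scoring in chunks gives exactly the same answer as scoring everything in one go.
print(total_error_parallel(3, data) == chunk_error((3, data)))
```

No chunk needs to know anything about the others until the final sum, which is what makes this kind of work easy to spread across many processors.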
What is going to go wrong when I use neural networks?
This is important to understand, because things will go wrong. The main points are:
- It is much more important to have good input data than good algorithms, and gathering good input data is expensive. Machine learning is only as good as its training, and if your input data is biased or missing important examples, your machine learning will perform badly. For example, a classic failure mode here is that if you train image recognition only on white people, it will often fail to see people of colour.
- Compared to a human decision maker, neural networks are fragile. There is an entire class of things called “adversarial examples” where carefully chosen trivial changes that a human wouldn’t even notice can cause the neural network to output an entirely different answer.
- Even without adversaries, machine learning algorithms will make stupid goofs that are obvious to a human decision maker. They will do this even if they are on average better than a human decision maker. This is simply because different things are obvious to machines and humans. Depending on how visible these decisions are, this will probably make you look silly when it happens.
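To give a feel for how adversarial examples are possible, here is a toy sketch. The “model” is a deliberately simple made-up one (a weighted sum over 1000 features, not a real neural network), but it shows the core trick: lots of tiny changes, each chosen to push in the wrong direction, add up to a completely different answer.

```python
def score(weights, features):
    """Weighted sum of the features; positive means the model thinks it sees a dog."""
    return sum(w * x for w, x in zip(weights, features))

def classify(weights, features):
    return "dog" if score(weights, features) > 0 else "not dog"

def adversarial(weights, features, step):
    """Nudge every feature by `step` in whichever direction lowers the score."""
    return [x - step if w > 0 else x + step for w, x in zip(weights, features)]

# Hypothetical model and input with 1000 features (think: pixel brightnesses).
weights = [1 if i % 2 == 0 else -1 for i in range(1000)]
original = [0.51 if i % 2 == 0 else 0.49 for i in range(1000)]

changed = adversarial(weights, original, step=0.02)
biggest_change = max(abs(a - b) for a, b in zip(original, changed))

print(classify(weights, original))
print(classify(weights, changed))
print(biggest_change)
```

No single feature moves by more than about 0.02, which a human looking at the input would never notice, yet the model's answer flips.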
When should I use neural networks?
First off, you should not be making a decision about whether to use neural networks if this article is teaching you new things. You should be making a decision on whether to use machine learning. Leave the decision of what type you should be using to an expert – there is a good chance they will choose neural networks (tech people are very trend driven), but there might be something simpler and better suited for your problem.
Roughly the following rules of thumb should help you decide whether to use machine learning:
- Is this a problem that would benefit from being able to make lots of fairly constrained decisions very fast? If no, you’re probably out of luck, and even if you’re not, this is going to be a huge and expensive research project. It may still be worth talking to an expert, because some problems don’t look like this on the face of it but can still benefit (e.g. translation isn’t obviously of this form, but it can benefit from machine learning).
- Could you just write down a short list of rules that are sufficient to make those decisions? If yes, just do that instead of trying to learn them from data.
- If you had an averagely intelligent human with maybe a couple of days of on the job training and enough background knowledge to understand the problem making those decisions, would their answers be good enough? If no, you’re really going to struggle to get your machine learning to be good enough.
- Think through the failure modes of the previous section. Do you have a plan in place to avoid them, or to mitigate them when they occur? Is the cost of that plan worth the benefits of using machine learning?
- Reflect on the following question: When someone in the media says “Why did you use a machine rather than a person here?” and your answer is “It was cheaper”, how is it going to play out and are you happy with that outcome? If no, your problem is probably politically a poor fit for machine learning.
If your problem passed all of those tests, it might well be amenable to machine learning. Go talk to an expert about it.