Binary search is one of those classic algorithms that most people who know about algorithms at all will know how to do (and many will even be able to implement correctly! Probably fewer than think they can, though – it took me a long time to go from thinking I could implement binary search correctly to actually being able to implement it correctly).
Some of this is because the way people think about binary search is somewhat flawed. It’s often treated as being about sorted array data, when that’s really only one application of it. So let’s start with a review of the right way to think about binary search.
We have two integers \(a\) and \(b\) (probably non-negative, but it doesn’t matter), with \(a < b\). We also have some function \(f\) defined on integers \(a \leq i \leq b\), with \(f(a) \neq f(b)\). We want to find \(c\) with \(a \leq c < b\) such that \(f(c) \neq f(c + 1)\).
i.e. we’re looking to find a single exact point at which a function changes value. For functions that are monotonic (that is, either non-increasing or non-decreasing), this point will be unique, but in general it may not be.
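As a tiny concrete instance (an illustrative example of my own, not from the problem statement above):

```python
# On [a, b] = [0, 10], the monotone predicate f(i) = (i < 5) changes
# value exactly once, at c = 4: f(4) is True while f(5) is False.
f = lambda i: i < 5
a, b = 0, 10
assert f(a) != f(b)                      # precondition of the problem
c = 4
assert a <= c < b and f(c) != f(c + 1)   # c is the change point
```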
To recover the normal idea of binary search, suppose we have some array \(xs\) of length \(n\). We want to find the smallest insertion point for some value \(v\).
To do this, we can use the function \(f(i)\) that returns whether \(xs[i] < v\). Either this function is constantly true (in which case every element is \(< v\) and \(v\) should be inserted at the end), constantly false (in which case \(v\) should be inserted at the beginning), or there is an index \(i\) with \(f(i) \neq f(i + 1)\), which is the point after which \(v\) should be inserted.
This also helps clarify the logic for writing a binary search:
    def binary_search(f, lo, hi):
        # Invariant: f(hi) != f(lo)
        while lo + 1 < hi:
            assert f(lo) != f(hi)
            mid = (lo + hi) // 2
            if f(mid) == f(lo):
                lo = mid
            else:
                hi = mid
        return lo
Every iteration we cut the interval in half - the loop only runs while the gap between lo and hi is at least two, so the midpoint is strictly between them and this must reduce the length. If \(f\) gives the same value to the midpoint as to lo, the midpoint must be our new lower bound; if not, it must be our new upper bound (note that in the second case we might not generically have \(f(mid) = f(hi)\), though in the typical case where \(f\) only takes two values we will).
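To make the connection back to sorted arrays concrete, here is the insertion-point application written out on top of binary_search (insertion_point is my own name for this helper, not a standard function; binary_search is repeated so the snippet runs standalone):

```python
def binary_search(f, lo, hi):
    # Same as above: invariant f(lo) != f(hi) throughout.
    while lo + 1 < hi:
        assert f(lo) != f(hi)
        mid = (lo + hi) // 2
        if f(mid) == f(lo):
            lo = mid
        else:
            hi = mid
    return lo

def insertion_point(xs, v):
    """Smallest index at which v can be inserted to keep xs sorted."""
    f = lambda i: xs[i] < v
    n = len(xs)
    if n == 0 or not f(0):
        return 0          # f is constantly false: v belongs at the front
    if f(n - 1):
        return n          # f is constantly true: v belongs at the end
    # Otherwise f(0) != f(n - 1), so binary_search finds the index i with
    # f(i) != f(i + 1); v should be inserted just after index i.
    return binary_search(f, 0, n - 1) + 1
```

On sorted input this should agree with bisect.bisect_left from the standard library.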
Anyway, all of this is beside the point of this post, it's just scene setting.
Because the point of this post is this: Is this actually optimal?
Generically, yes it is. If we consider the family of functions \(f_k(i) = i < k\), each value we examine can rule out at best half of the remaining functions, so we must ask at least \(\log_2(b - a)\) questions, and binary search is optimal. But that's the generic case. In a lot of typical cases we have something else going for us: often we expect the change to happen quite early, or at least to be very close to the lower bound. For example, suppose we were binary searching for a small value in a sorted list. Chances are good it's going to be a lot closer to the left-hand side than the right, but we're going to do a full \(\log_2(n)\) calls every single time. We can solve this by starting the binary search with an exponential probe - we try small gaps, growing the gap by a factor of two each time, until we find a value that differs. This gives us a (hopefully much smaller) upper bound, and a lower bound somewhat closer to it.
    def exponential_probe(f, lo, hi):
        gap = 1
        while lo + gap < hi:
            if f(lo + gap) == f(lo):
                lo += gap
                gap *= 2
            else:
                return lo, lo + gap
        return lo, hi
We can then put these together to give a better search algorithm, by using the exponential probe as the new upper bound for our binary search:
    def find_change_point(f, lo, hi):
        assert f(lo) != f(hi)
        return binary_search(f, *exponential_probe(f, lo, hi))
When the value found is near or after the middle, this will end up being more expensive by a factor of about two - we have to do up to an extra \(\log_2(n)\) calls to probe up to the midpoint - but when the heuristic wins it potentially wins big: it will often take the \(\log_2(n)\) factor (which although not huge can easily be in the 10-20 range for reasonably sized gaps) and turn it into 1 or 2. Complexity-wise, this will run in \(O(\log(k - lo))\), where \(k\) is the value returned, rather than the original \(O(\log(hi - lo))\).
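To see the saving concretely, here is a rough instrumentation sketch (the count_calls harness is my own; the three functions from above are repeated so it runs standalone), counting evaluations of \(f_k(i) = i < k\) when the change point is near the lower bound:

```python
def binary_search(f, lo, hi):
    # Invariant: f(lo) != f(hi)
    while lo + 1 < hi:
        assert f(lo) != f(hi)
        mid = (lo + hi) // 2
        if f(mid) == f(lo):
            lo = mid
        else:
            hi = mid
    return lo

def exponential_probe(f, lo, hi):
    gap = 1
    while lo + gap < hi:
        if f(lo + gap) == f(lo):
            lo += gap
            gap *= 2
        else:
            return lo, lo + gap
    return lo, hi

def find_change_point(f, lo, hi):
    assert f(lo) != f(hi)
    return binary_search(f, *exponential_probe(f, lo, hi))

def count_calls(search, k, n):
    # Count how many times the search evaluates f_k(i) = i < k on [0, n].
    calls = [0]
    def f(i):
        calls[0] += 1
        return i < k
    assert search(f, 0, n) == k - 1   # the change point of f_k is k - 1
    return calls[0]
```

With a change point near the start (say k = 3) and a large range, find_change_point makes far fewer calls than a plain binary_search, and its call count doesn't grow with the size of the range.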
This idea isn't as intrinsically valuable as binary search, because it doesn't really improve the constant factors or the response to random data, but it's still very useful in a lot of real world applications. I first came across it in the context of timsort, which uses this to find a good merge point when merging two sublists in its merge step.
Edit to note: It was pointed out to me on Twitter that I'm relying on Python's bigints to avoid the overflow problem that binary search has if you implement it on fixed-size integers. I did know this at one point, but I confess I had forgotten. The above code works fine in Python, but if your int type is fixed size you want the following slightly less clear versions:
    def midpoint(lo, hi):
        if lo <= 0 and hi >= 0:
            return (lo + hi) // 2
        else:
            return lo + (hi - lo) // 2

    def binary_search(f, lo, hi):
        # Invariant: f(hi) != f(lo)
        while lo + 1 < hi:
            assert f(lo) != f(hi)
            mid = midpoint(lo, hi)
            if f(mid) == f(lo):
                lo = mid
            else:
                hi = mid
        return lo

    def exponential_probe(f, lo, hi):
        gap = 1
        midway = midpoint(lo, hi)
        while True:
            if f(lo + gap) == f(lo):
                lo += gap
                if lo >= midway:
                    break
                else:
                    gap *= 2
            else:
                hi = lo + gap
                break
        return lo, hi
These avoid calculating any intermediate integers which overflow in the midpoint calculation:
- If \(lo \leq 0\) and \(hi \geq 0\) then \(lo \leq lo + hi \leq hi\), so the sum is representable.
- If \(lo \geq 0\) then \(0 \leq hi - lo \leq hi\), so the difference is representable.
- If \(hi < 0\) (the remaining case the else branch handles) then \(0 < hi - lo \leq -1 - lo \leq\) INT_MAX, so the difference is representable even when \(lo\) is INT_MIN.
The reason we need the two different cases is that e.g. if \(lo\) were INT_MIN and \(hi\) were INT_MAX, then \(hi - lo\) would overflow but \(lo + hi\) would be fine. Conversely if \(lo\) were INT_MAX - 1 and \(hi\) were INT_MAX, \(hi - lo\) would be fine but \(hi + lo\) would overflow.
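To see the failure mode outside of Python's bigints, here is a sketch simulating signed 8-bit wraparound (the wrap8 helper is my own illustration; real C signed overflow is undefined behaviour rather than guaranteed wrapping):

```python
def wrap8(x):
    # Reduce x into the signed 8-bit range [-128, 127], simulating
    # two's-complement wraparound.
    return (x + 128) % 256 - 128

def midpoint(lo, hi):
    # The guarded midpoint from above, with the intermediate sums
    # forced through 8-bit wraparound to expose any overflow.
    if lo <= 0 and hi >= 0:
        return wrap8(lo + hi) // 2
    else:
        return lo + wrap8(hi - lo) // 2

# The naive (lo + hi) // 2 would overflow here: 100 + 120 = 220 wraps
# to -36 in 8 bits, giving a midpoint nowhere near the interval.
assert wrap8(100 + 120) // 2 != 110
# The guarded version takes the else branch: hi - lo = 20 is representable.
assert midpoint(100, 120) == 110
```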
The following should then be a branch-free way of doing the same:
    def midpoint(lo, hi):
        large_part = lo // 2 + hi // 2
        small_part = ((lo & 1) + (hi & 1)) // 2
        return large_part + small_part
We calculate (lo + hi) // 2 as lo // 2 + hi // 2, and then fix up the rounding error this causes by calculating the midpoint of the low bits correctly. The intermediate parts don't overflow because the first sum fits in \([lo, hi]\) and the second fits in \([0, 1]\). The final sum also fits in \([lo, hi]\), so it doesn't overflow either.
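As a quick sanity check of that argument, here's an exhaustive comparison of the branch-free formula against the naive one over a small signed range (this uses Python's floor-division semantics, so it says nothing about C's truncating division):

```python
def midpoint(lo, hi):
    large_part = lo // 2 + hi // 2
    small_part = ((lo & 1) + (hi & 1)) // 2
    return large_part + small_part

# In Python, // floors and & treats negative ints as two's complement,
# so the fix-up term accounts exactly for the low bits dropped by the
# two halvings. Compare exhaustively on a small range:
for lo in range(-50, 51):
    for hi in range(lo, 51):
        assert midpoint(lo, hi) == (lo + hi) // 2
```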
I haven't verified this part too carefully, but Hypothesis tells me it at least works for Python's big integers. One caveat if you port it to C: integer division there truncates towards zero rather than rounding down, so for negative operands the two halves round differently and the result can differ from the floor midpoint. For non-negative lo and hi, though, it should behave the same.