Lazy Fisher-Yates shuffling for precise rejection sampling

This is a trick I figured out a while ago. It came up in a problem I was working on, so I thought I’d write it up.

I haven’t seen it anywhere else, but I would be very surprised if it were in any way original to me rather than a reinvention. I think it’s niche and not very useful, so it’s hard to find prior art.

Attention conservation notice: My conclusion at the end is going to be that this is a neat trick that is probably not worth bothering with. You get a slight asymptotic improvement for a worse constant factor. I mention some variants and cases where it might be useful at the end, but it’s definitely more of an interesting intellectual exercise than a practical tool.

Suppose you want to implement the following function:

def sample(random, n, f):
    """returns a number i in range(n) such that f(i)
    is True. Raises ValueError if no such i exists."""

What is the right way to do this?

The obvious way to do it is this:

def sample(random, n, f):
    choices = [i for i in range(n) if f(i)]
    if not choices:
        raise ValueError("No such i!")
    return random.choice(choices)

We calculate all of the possible outcomes up front and then sample from those.

This works, but it is always \(O(n)\). We can’t hope to do better than that in general (in the case where \(f(i)\) always returns False, we must call it \(n\) times to verify that and raise an error), but it is highly inefficient if, say, \(f\) always returns True.

For cases where we know a priori that there is at least one i with f(i) returning True, we can use the following implementation:

def sample(random, n, f):
    while True:
        i = random.randrange(0, n)
        if f(i):
            return i

Here we do rejection sampling: We simulate the distribution directly by picking from the larger uniform distribution and selecting on the conditional property that \(f(i)\) is True.

How well this works depends on a parameter \(A\): the number of values \(i\) in \([0, n)\) such that \(f(i)\) is True. The probability of stopping at any loop iteration is \(\frac{A}{n}\) (the probability of the result being True), so the number of loop iterations follows a geometric distribution with that parameter, and the expected number of loop iterations is \(\frac{n}{A}\). That is, compared with the \(n\) calls to \(f\) that filtering makes, we’ve effectively sped up our search by a factor of \(A\).
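A quick simulation makes the geometric-distribution argument concrete. This is a sketch with hypothetical helper names (`rejection_sample`, `average_calls`), where \(f(i)\) is simply membership in a fixed set of valid indices; with \(n = 100\) and \(A = 10\), the average number of calls to \(f\) should land close to \(\frac{n}{A} = 10\).

```python
import random


def rejection_sample(rng, n, f):
    # Plain rejection sampling: draw uniformly from range(n) until f accepts.
    # Assumes at least one valid i exists; otherwise it never returns.
    while True:
        i = rng.randrange(0, n)
        if f(i):
            return i


def average_calls(n, valid, trials, seed=0):
    """Average number of calls to f per sample, where f(i) is `i in valid`.
    Hypothetical helper for this illustration only."""
    rng = random.Random(seed)
    calls = 0

    def f(i):
        nonlocal calls
        calls += 1
        return i in valid

    for _ in range(trials):
        rejection_sample(rng, n, f)
    return calls / trials


valid = set(range(0, 100, 10))  # A = 10 valid values out of n = 100
print(average_calls(100, valid, trials=20_000))  # expected to be close to 10
```

When every index is valid (\(A = n\)), the same helper reports exactly one call per sample.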

So in expectation rejection sampling is always at least as good as, and usually better than, our straightforward loop, but it has three problems:

  1. It will occasionally take more than \(n\) iterations to complete because it may try the same \(i\) more than once.
  2. It loops forever when \(A = 0\).
  3. Even when \(A > 0\) its worst-case may take strictly more iterations than the filtering method.
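To put a number on the first problem: with \(n = 10\) and a single valid index, each draw succeeds with probability \(\frac{1}{10}\), so plain rejection sampling needs more than \(n\) draws with probability \(0.9^{10} \approx 0.35\), even though a search that never repeats an index would always finish within \(n\). A minimal simulation, with `draws_until_success` as a hypothetical helper:

```python
import random


def draws_until_success(rng, n, valid):
    """Number of draws plain rejection sampling makes before hitting a
    value in `valid`. Hypothetical helper; assumes `valid` is non-empty."""
    draws = 0
    while True:
        draws += 1
        if rng.randrange(n) in valid:
            return draws


rng = random.Random(1)
n, valid = 10, {3}  # A = 1
runs = [draws_until_success(rng, n, valid) for _ in range(10_000)]

# Fraction of runs that needed more than n draws; the exact probability
# is (1 - A/n) ** n = 0.9 ** 10, roughly 0.35.
over_n = sum(r > n for r in runs) / len(runs)
print(over_n, 0.9 ** 10)
```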

So how can we fix this?

If we fix the first by arranging so that it never calls \(f\) with the same \(i\) more than once, then we also implicitly fix the other two: At some point we’ve called \(f\) with every \(i\) and we can terminate the loop and error, and if each loop iteration returns a fresh \(i\) then we can’t have more than \(n\) iterations.

We can do this fairly naturally by shuffling the order in which we try the indices and then returning the first one for which \(f\) returns True:

def sample(random, n, f):
    indices = list(range(n))
    random.shuffle(indices)
    for i in indices:
        if f(i):
            return i
    raise ValueError("No such i!")

Effectively at each iteration of the loop we are uniformly at random selecting an \(i\) among those we haven’t seen yet.
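Why is the result still uniform over the valid values? In a uniformly random permutation, each valid index is equally likely to be the first valid one to appear, by symmetry. A quick empirical check of that claim, using the shuffle-then-scan version under the hypothetical name `sample_by_shuffle` and a seeded `random.Random` instance:

```python
import random
from collections import Counter


def sample_by_shuffle(rng, n, f):
    # Shuffle every index up front, then return the first one f accepts.
    indices = list(range(n))
    rng.shuffle(indices)
    for i in indices:
        if f(i):
            return i
    raise ValueError("No such i!")


rng = random.Random(2)
valid = {1, 4, 7}
counts = Counter(sample_by_shuffle(rng, 10, valid.__contains__) for _ in range(30_000))
print(counts)  # each of 1, 4 and 7 should appear roughly 10,000 times
```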

So now we are making fewer calls to \(f\) – we only call it until we first find an example, as in our rejection sampling. However we’re still paying an \(O(n)\) cost for actually doing the shuffle!

We can fix this partially by inlining the implementation of the shuffle as follows (this is a standard Fisher-Yates shuffle, though we’re doing it backwards, filling in the result from the front rather than from the back):

def sample(random, n, f):
    indices = list(range(n))
    for j in range(n):
        # Swap a uniformly chosen not-yet-visited index into position j.
        k = random.randrange(j, n)
        indices[j], indices[k] = indices[k], indices[j]
        i = indices[j]
        if f(i):
            return i
    raise ValueError("No such i!")

This works because after iteration \(j\) of the loop, positions \(0\) through \(j\) of indices hold a uniformly random sequence of distinct indices, exactly as they would at that point in a full Fisher-Yates shuffle. Thus we can effectively fuse our search into the shuffle: if we know we’re going to terminate at this iteration, we can just stop and not bother shuffling the rest.
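The prefix property can be packaged as a lazy shuffle in its own right. The sketch below (a hypothetical generator, `lazy_shuffled`) yields the elements of a list in uniformly random order while doing only one swap per element actually consumed, so taking the first three elements of a ten-element list costs three swaps rather than a full shuffle:

```python
import random
from itertools import islice
from typing import Iterator


def lazy_shuffled(rng: random.Random, items: list) -> Iterator:
    """Yield the elements of `items` in uniformly random order, doing one
    Fisher-Yates swap per element consumed. Works front-to-back and
    permutes `items` in place. Hypothetical helper."""
    n = len(items)
    for j in range(n):
        k = rng.randrange(j, n)  # pick uniformly from the unvisited suffix
        items[j], items[k] = items[k], items[j]
        yield items[j]  # items[0..j] is now a random ordered sample


rng = random.Random(3)
data = list(range(10))
first_three = list(islice(lazy_shuffled(rng, data), 3))
print(first_three)  # three distinct values; only three swaps were made
```

Consuming the generator completely yields a full uniformly random permutation, just like an ordinary Fisher-Yates shuffle.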

But we’re still paying \(O(n)\) cost for initialising the list of indices. Can we do better? Yes we can!

The idea is that we can amortise the cost of our reads of indices into our writes of it, and we only do \(O(\frac{n}{A})\) writes in expectation. If we haven’t written to a position in the indices list yet, then its value must be equal to its position.

We can do this naturally in Python using a defaultdict subclass (there are better data structures for this, but a hashmap gets us the desired amortized complexity) as follows:

from collections import defaultdict


class key_dict(defaultdict):
    # A missing key reads as itself, so untouched positions behave
    # like the corresponding entry of list(range(n)).
    def __missing__(self, key):
        return key


def sample(random, n, f):
    indices = key_dict()
    for j in range(n):
        k = random.randrange(j, n)
        indices[j], indices[k] = indices[k], indices[j]
        i = indices[j]
        if f(i):
            return i
    raise ValueError("No such i!")

So we have now completely avoided any \(O(n)\) costs (assuming we’re writing Python 3, where range is lazy; if you’re still using legacy Python, replace it with xrange).
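The same trick can also be written with a plain `dict` and `dict.get(key, key)` instead of a `defaultdict` subclass. The sketch below (hypothetical name `lazy_sample`) makes one further simplification: once iteration \(j\) is done, position \(j\) is never read again, because every later \(k\) is drawn from \([j + 1, n)\), so only the write to position \(k\) needs to be stored. Memory use then grows with the number of iterations rather than with \(n\), which is what lets it run with a very large \(n\):

```python
import random


def lazy_sample(rng, n, f):
    """Return a uniformly random i in range(n) with f(i) true, never
    calling f twice on the same i; raise ValueError if there is none.
    `swapped` is a sparse view of list(range(n)): absent keys hold
    their own index."""
    swapped = {}
    for j in range(n):
        k = rng.randrange(j, n)
        # Current values at positions j and k (identity if never written).
        vj = swapped.get(j, j)
        vk = swapped.get(k, k)
        # Swap: vk moves to position j and vj to position k. Position j is
        # never read again, so only the write to position k is recorded.
        swapped[k] = vj
        if f(vk):
            return vk
    raise ValueError("No such i!")


# Works for huge n because nothing of size n is ever allocated.
print(lazy_sample(random.Random(5), 10**12, lambda i: i % 7 == 0))
```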

What is the actual expected number of loop iterations?

Well, if \(A = n\) then we only make one loop iteration, and if \(A = 0\) we always make \(n\). In general, when \(A > 0\) we never make more than \(n - A + 1\) iterations, because after \(n - A\) iterations we’ve exhausted all the values for which \(f\) is False, so the next index we try must succeed.
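For small \(n\) the worst case can be checked exhaustively: running the search over every possible visiting order is the same as considering every outcome of the shuffle. A brute-force sketch (hypothetical helpers `iterations_for_order` and `worst_case`), comparing against \(\min(n, n - A + 1)\):

```python
from itertools import permutations


def iterations_for_order(order, valid):
    """Loop iterations the search makes when it visits indices in `order`:
    the 1-based position of the first valid index, or len(order) if none."""
    for pos, i in enumerate(order, start=1):
        if i in valid:
            return pos
    return len(order)


def worst_case(n, valid):
    # Every permutation is a possible shuffle outcome; only feasible for tiny n.
    return max(iterations_for_order(p, valid) for p in permutations(range(n)))


n = 6
for a in range(n + 1):
    print(a, worst_case(n, set(range(a))), min(n, n - a + 1))
```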

I, err, confess I’ve got a bit lost in the maths for what the expected number of iterations of this loop is. If \(L\) is a random variable that is the position of the loop iteration on which this terminates with a successful result (with \(L = n + 1\) if there are no valid values), then a counting argument gives us that the probability that the first \(k\) tries all fail is \(P(L > k) = \frac{(n - A)! (n - k)!}{n! (n - A - k)!}\), and we can calculate \(E(L) = \sum\limits_{k=0}^{n - A} P(L > k)\), but the sum is fiddly and my ability to do such sums is rusty. Based on plugging some special cases into Wolfram Alpha, I think the expected number of loop iterations is something like \(\frac{n+1}{A + 1}\), which at least gives the right numbers for the boundary cases. If that is the right answer then I’m sure there’s some elegant combinatorial argument that shows it, but I’m not currently seeing it. Assuming this is right, this is asymptotically no better than the original rejection sampling.
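For what it’s worth, the conjectured \(\frac{n+1}{A+1}\) does follow from a short indicator-variable argument, sketched here assuming \(A \geq 1\). Fix an index \(x\) for which \(f(x)\) is False and look only at the relative order of \(x\) and the \(A\) valid indices: in a uniformly random ordering all \((A+1)!\) relative orders are equally likely, so \(x\) comes before every valid index with probability \(\frac{1}{A+1}\). \(L\) is one more than the number of such \(x\) that come first, so by linearity of expectation:

\[
E(L) = 1 + \sum_{x \,:\, f(x) \text{ False}} P(x \text{ precedes every valid index}) = 1 + \frac{n - A}{A + 1} = \frac{n + 1}{A + 1}
\]

When \(A = 0\), the convention \(L = n + 1\) gives the same formula.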

Regardless, the expected number of loop iterations is definitely \(\leq \frac{n}{A}\), because we can simulate it by running our rejection sampling and counting \(1\) only when it draws an index it hasn’t seen before, so it is dominated by the expected number of iterations of pure rejection sampling. So we do achieve our promised result of an algorithm that improves on both rejection sampling and filter-then-sample: its expected complexity is at least as good as the former’s, and its worst-case complexity is the same as the latter’s.
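The counting formula can also be checked exactly for small cases with Python’s `fractions.Fraction`: sum the probability that the first \(k\) tries all fail and compare the result with both \(\frac{n+1}{A+1}\) and the rejection-sampling \(\frac{n}{A}\). The helper names are hypothetical, and this is a check on small cases rather than a proof:

```python
from fractions import Fraction
from math import factorial


def prob_more_than(n, a, k):
    # P(L > k): the first k indices tried are all ones where f is False.
    return Fraction(factorial(n - a) * factorial(n - k),
                    factorial(n) * factorial(n - a - k))


def expected_iterations(n, a):
    # E(L) = sum_{k=0}^{n-a} P(L > k), with L = n + 1 when a = 0.
    return sum(prob_more_than(n, a, k) for k in range(n - a + 1))


for n in range(1, 9):
    for a in range(1, n + 1):
        e = expected_iterations(n, a)
        assert e == Fraction(n + 1, a + 1)  # the conjectured closed form
        assert e <= Fraction(n, a)          # never worse than rejection sampling
print("checked all 1 <= a <= n <= 8")
```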

Is this method actually worth using? Ehhh… maybe? It’s conceptually pleasing to me, and it’s nice to have a method that complexity-wise out-performs either of the two natural choices, but in practice the constant overhead of using the hash map is almost certainly greater than any benefit you get from it.

The real problem that limits this being actually useful is that either \(n\) is small, in which case who cares, or \(n\) is large, in which case the chance of plain rejection sampling drawing a duplicate is low enough that the calls it wastes on duplicates are negligible.

There are a few ways this idea can still be useful in the right niche, though:

  • This does reduce the constant factor in the number of calls to \(f\) (especially if \(A\) is small), so if \(f\) is expensive then the constant factor improvement in number of calls may be enough to justify this.
  • If you already have your values you want to sample in an array, and the order of the array is arbitrary, then you can use the lazy Fisher-Yates shuffle trick directly without the masking hash map.
  • If you’re genuinely unsure about whether \(A > 0\) and do want to be able to check, this method allows you to do that (but you could also just run rejection sampling \(n\) times and then fall back to filter and sample and it would probably be slightly faster to do so).
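Minimal sketches of the second and third variants, assuming a seeded `random.Random` instance is passed in; `sample_in_place` and `sample_with_fallback` are hypothetical names. The first runs the lazy Fisher-Yates loop directly on an existing list (reordering it as a side effect), and the second is the "rejection sampling \(n\) times, then filter and sample" fallback:

```python
import random


def sample_in_place(rng, values, f):
    """Run the lazy Fisher-Yates loop directly on `values` (a list whose
    order we don't care about), mutating it, and return the first element
    f accepts. No hash map is needed."""
    n = len(values)
    for j in range(n):
        k = rng.randrange(j, n)
        values[j], values[k] = values[k], values[j]
        if f(values[j]):
            return values[j]
    raise ValueError("No such value!")


def sample_with_fallback(rng, n, f):
    """Plain rejection sampling for n tries, then fall back to
    filter-then-sample, which either finds a value or proves none exists."""
    for _ in range(n):
        i = rng.randrange(n)
        if f(i):
            return i
    choices = [i for i in range(n) if f(i)]
    if not choices:
        raise ValueError("No such i!")
    return rng.choice(choices)
```

The fallback still returns a uniformly chosen valid value: if rejection sampling succeeds its result is uniform, and if it fails, filter-then-sample is uniform too.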

If any of those apply, this might be a worthwhile trick. If they don’t, hopefully you enjoyed reading about it anyway. Sorry.

This entry was posted in Python.

One thought on “Lazy Fisher-Yates shuffling for precise rejection sampling”

  1. Alice

    How about dumping the values you sample into a list at the start, and when the list hits a particular fraction of n (a half, say), pay the O(n) cost of listing and shuffling the remaining values? You won’t check the same value very often (about one in four will be duplicates), so the happy case is low-cost O(1) and the worst case is still O(n).
