Due to reasons, I found myself in need of a data structure supporting a slightly unusual combination of operations. Developing it involved a fairly straightforward process of refinement and a number of interesting tricks, so I thought it might be informative to walk through (a somewhat stylised version of) the design.

The data structure is a particular type of random sampler, starting from a shared array of values (possibly containing duplicates). Values are hashable and comparable for equality.

It needs to support the following operations:

- Initialise from a random number generator and a shared immutable array of values so that it holds all those values.
- Sample an element uniformly at random from the remaining values, or raise an error if there are none.
- Unconditionally (i.e. without checking whether it’s present) remove all instances of a value from the sampler.

The actual data structure I want is a bit more complicated than that, but those are enough to demonstrate the basic principles.

What’s surprising is that you can do all of these operations in amortised O(1). This includes the initialisation from a list of n values!

The idea behind designing this is to start with the most natural data structure that doesn’t achieve these bounds and then try to refine it until it does. That data structure is a resizable array. You can sample uniformly by just picking a random index into the array. You can remove a value by scanning the array and deleting every element equal to it. This means you have to be able to mutate the array, so initialising it requires copying.

Which means it’s time for some code.

First, let’s write a test suite for this data structure:

```python
from collections import Counter

from hypothesis.stateful import RuleBasedStateMachine, rule, precondition
import hypothesis.strategies as st

from sampler import Sampler


class FakeRandom(object):
    def __init__(self, data):
        self.__data = data

    def randint(self, m, n):
        return self.__data.draw(
            st.integers(m, n), label="randint(%d, %d)" % (m, n)
        )


class SamplerRules(RuleBasedStateMachine):
    def __init__(self):
        super(SamplerRules, self).__init__()
        self.__initialized = False

    @precondition(lambda self: not self.__initialized)
    @rule(
        values=st.lists(st.integers()).map(tuple),
        data=st.data()
    )
    def initialize(self, values, data):
        self.__initial_values = values
        self.__sampler = Sampler(values, FakeRandom(data))
        self.__counts = Counter(values)
        self.__initialized = True

    @precondition(lambda self: self.__initialized)
    @rule()
    def sample(self):
        if sum(self.__counts.values()) != 0:
            v = self.__sampler.sample()
            assert self.__counts[v] != 0
        else:
            try:
                self.__sampler.sample()
                assert False, "Should have raised"
            except IndexError:
                pass

    @precondition(lambda self: self.__initialized)
    @rule(data=st.data())
    def remove(self, data):
        v = data.draw(st.sampled_from(self.__initial_values))
        self.__sampler.remove(v)
        self.__counts[v] = 0


TestSampler = SamplerRules.TestCase
```

This uses Hypothesis’s rule based stateful testing to completely describe the valid range of behaviour of the data structure. There are a number of interesting and possibly non-obvious details in there, but this is a data structures post rather than a Hypothesis post, so I’m just going to gloss over them and invite you to peruse the tests in more detail at your leisure if you’re interested.

Now let’s look at an implementation of this, using the approach described above:

```python
class Sampler(object):
    def __init__(self, values, random):
        # Copy the shared immutable values so we can mutate our own list: O(n).
        self.__values = list(values)
        self.__random = random

    def sample(self):
        if not self.__values:
            raise IndexError("Cannot sample from empty list")
        i = self.__random.randint(0, len(self.__values) - 1)
        return self.__values[i]

    def remove(self, value):
        # Rebuild the list without any copies of value: O(n).
        self.__values = [v for v in self.__values if v != value]
```

The test suite passes, so we’ve successfully implemented the operations (or our bugs are too subtle for Hypothesis to find in a couple of seconds). Hurrah.

But we’ve not really achieved our goals: Sampling is O(1), sure, but remove and initialisation are both O(n). How can we fix that?

The idea is to incrementally patch up this data structure by finding the things that make it O(n) and seeing if we can defer the cost for each element until we definitely need to pay it to get the correct result.

Let’s start by fixing removal.

The first key observation is that if we don’t care about the order of values in a list (which we don’t because we only access it through random sampling), we can remove the element present at an index in O(1) by popping the element that is at the end of the list and overwriting the index with that value (if it wasn’t the last index). This is the approach normally taken if you want to implement random sampling without replacement, but in our use case we’ve separated removal from sampling so it’s not quite so easy.
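To make the pop-and-overwrite trick concrete, here is a minimal sketch. `swap_remove` is a hypothetical helper name, not something from the post’s code, and it assumes `items` is an ordinary Python list and `i` is a valid index into it.

```python
def swap_remove(items, i):
    """Remove and return items[i] in O(1), without preserving order.

    Hypothetical helper: the last element is moved into slot i, which is
    fine when the list is only ever accessed by uniform random sampling.
    """
    removed = items[i]
    last = items.pop()      # O(1): take the final element off the end
    if i != len(items):     # i was not the last index...
        items[i] = last     # ...so fill the vacated slot with it
    return removed          # if i was the last index, the pop removed it


xs = ["a", "b", "c", "d"]
swap_remove(xs, 1)
print(xs)  # ['a', 'd', 'c']
swap_remove(xs, 2)
print(xs)  # ['a', 'd']
```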

The problem is that we don’t know where (or even if) the value we want to delete is in our array, so we still have to do an O(n) scan to find it.

One solution to this problem (which is an entirely valid one) is to have a mapping of values to the indexes they are found in. This is a little tricky to get right with duplicates, but it’s an entirely workable solution. It makes it much harder to do our O(1) initialisation later, though, so we won’t go down this route.
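For comparison, here is a rough sketch of that alternative, assuming values are hashable. `IndexedSampler` and its `positions` map are hypothetical names for illustration only. Its constructor has to copy the values and index every one of them, which is exactly the O(n) initialisation we want to avoid. The duplicate handling is where the care goes: a swap-remove can move another copy of the value being removed into the vacated slot.

```python
import random
from collections import defaultdict


class IndexedSampler(object):
    """Hypothetical sketch of the rejected design, not the post's Sampler.

    Keeps a map from each value to the set of indices holding it, so that
    removal can find copies directly. Construction is O(n).
    """

    def __init__(self, values, rng):
        self.values = list(values)
        self.rng = rng
        self.positions = defaultdict(set)
        for i, v in enumerate(self.values):
            self.positions[v].add(i)

    def sample(self):
        if not self.values:
            raise IndexError("Cannot sample from empty list")
        return self.values[self.rng.randint(0, len(self.values) - 1)]

    def _swap_remove(self, i):
        # Forget index i for the value being removed *before* moving the
        # last element, since that element may be another copy of it.
        self.positions[self.values[i]].discard(i)
        last_index = len(self.values) - 1
        last = self.values.pop()
        self.positions[last].discard(last_index)
        if i != last_index:
            self.values[i] = last
            self.positions[last].add(i)

    def remove(self, value):
        # Roughly O(1) per copy removed, so O(1) when values are distinct.
        while self.positions.get(value):
            self._swap_remove(next(iter(self.positions[value])))
        self.positions.pop(value, None)


s = IndexedSampler([1, 2, 2, 3], random.Random(0))
s.remove(2)
print(sorted(s.values))  # [1, 3]
```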

Instead, the idea is to defer the deletion until we know of an index for it, which we can do during our sampling! We keep a set of the values that have been deleted and, if we end up sampling one of them, we remove that copy from the array using the pop-and-overwrite trick above and sample again.

This means that a single sample can cost extra work: one pop for every deleted element it lands on. But these costs are “queued up” by previous remove calls, and each array slot can be popped at most once, so once incurred they never repeat. That doesn’t break our claim of O(1) *amortised* time: the costs we pay on sampling are just one-off deferred costs from earlier.
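Here is a small instrumented sketch of that argument. `lazy_sample` is a hypothetical function that mirrors the deferred-deletion loop with a counter added, and it assumes the standard library’s `random.Random` as the generator. Even over far more samples than there are elements, the total number of pops stays bounded by the number of deleted elements in the array.

```python
import random


def lazy_sample(values, deletions, rng, stats):
    """Deferred-deletion sampling loop, instrumented to count pops.

    Hypothetical helper mirroring Sampler.sample; stats["pops"] counts
    how many deleted slots were actually swap-removed.
    """
    while True:
        if not values:
            raise IndexError("Cannot sample from empty list")
        i = rng.randint(0, len(values) - 1)
        v = values[i]
        if v in deletions:
            replacement = values.pop()
            if i != len(values):
                values[i] = replacement
            stats["pops"] += 1  # one deleted element removed for good
        else:
            return v


rng = random.Random(0)
values = list(range(1000))
deletions = set(range(0, 1000, 2))  # lazily delete every even value
stats = {"pops": 0}
for _ in range(10000):
    assert lazy_sample(values, deletions, rng, stats) % 2 == 1
# Each pop permanently removes one of the 500 deleted elements, so all
# 10000 samples together cost at most 500 pops, not 500 per sample.
print(stats["pops"] <= 500)  # True
```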

Here’s some code:

```python
class Sampler(object):
    def __init__(self, values, random):
        self.__values = list(values)
        self.__random = random
        # Values whose copies are removed lazily, the next time we hit them.
        self.__deletions = set()

    def sample(self):
        while True:
            if not self.__values:
                raise IndexError("Cannot sample from empty list")
            i = self.__random.randint(0, len(self.__values) - 1)
            v = self.__values[i]
            if v in self.__deletions:
                # Deferred deletion: swap-remove this slot in O(1), then retry.
                replacement = self.__values.pop()
                if i != len(self.__values):
                    self.__values[i] = replacement
            else:
                return v

    def remove(self, value):
        # O(1): just record the value; the actual removal happens in sample().
        self.__deletions.add(value)
```

So now we’re almost done. All we have to do is figure out some way to create a mutable version of our immutable list in O(1).

This sounds impossible but turns out to be surprisingly easy.

The idea is to create a *mask* in front of our immutable sequence, which tracks a length and a mapping of indices to values. Whenever we write to the mutable “copy” we write to the mask. Whenever we read from the copy, we first check that it’s in bounds and if it is we read from the mask. If the index is not present in the mask we read from the original sequence.
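As a standalone illustration of the mask, here is a minimal copy-on-write list view, assuming the base sequence is never mutated by anyone else. `MaskedList` is a hypothetical name, and it only supports the operations the sampler needs: length, indexed reads, indexed writes and popping from the end.

```python
class MaskedList(object):
    """Hypothetical copy-on-write view over an immutable sequence.

    Creation is O(1): we keep a reference to the base sequence, a length,
    and an empty dict of overridden indices. Writes only touch the dict.
    """

    def __init__(self, base):
        self._base = base
        self._length = len(base)
        self._mask = {}

    def __len__(self):
        return self._length

    def _check(self, i):
        if not 0 <= i < self._length:
            raise IndexError("index out of range")

    def __getitem__(self, i):
        self._check(i)
        # Prefer the mask; fall back to the original sequence.
        try:
            return self._mask[i]
        except KeyError:
            return self._base[i]

    def __setitem__(self, i, value):
        self._check(i)
        self._mask[i] = value

    def pop(self):
        if not self._length:
            raise IndexError("pop from empty MaskedList")
        j = self._length - 1
        value = self[j]
        self._mask.pop(j, None)  # drop any override past the new end
        self._length = j
        return value


base = (10, 20, 30, 40)
m = MaskedList(base)
m[1] = 99
print(m.pop(), len(m), [m[i] for i in range(len(m))])  # 40 3 [10, 99, 30]
print(base)  # (10, 20, 30, 40): the original is untouched
```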

The result is essentially a sort of very fine grained copy on write – we never have to look at the whole sequence, only the bits that we are reading from, so we never have to do O(n) work.

Here’s some code:

```python
class Sampler(object):
    def __init__(self, values, random):
        # O(1): keep a reference to the shared immutable values, no copy.
        self.__values = values
        self.__length = len(values)
        # Index -> value overrides layered over self.__values.
        self.__mask = {}
        self.__random = random
        self.__deletions = set()

    def sample(self):
        while True:
            if not self.__length:
                raise IndexError("Cannot sample from empty list")
            i = self.__random.randint(0, self.__length - 1)
            # Read through the mask, falling back to the original sequence.
            try:
                v = self.__mask[i]
            except KeyError:
                v = self.__values[i]
            if v in self.__deletions:
                # Swap-remove slot i: move the last live element into it.
                j = self.__length - 1
                try:
                    replacement = self.__mask.pop(j)
                except KeyError:
                    replacement = self.__values[j]
                self.__length = j
                if i != j:
                    self.__mask[i] = replacement
            else:
                return v

    def remove(self, value):
        self.__deletions.add(value)
```

And that’s it, we’re done!

There are more optimisations we could do here. For example, the masking trick is relatively expensive, so it might make sense to switch back to a mutable array once we’ve masked off the entirety of the array, e.g. using a representation akin to the PyPy dict implementation and throwing away the hash table part when the value array is of full length.

But that would only improve the constants (you can’t get better than O(1) asymptotically!), so I’d be loath to take on the increased complexity until I saw a real world workload where this was the bottleneck (which I expect to happen at some point if this idea bears fruit, but I don’t yet know if it will). We’ve got the asymptotics I wanted, so let’s stop there while the implementation is fairly simple.

I’ve yet to actually use this in practice, but I’m still really pleased with the design of this thing. Starting from a fairly naive initial implementation, we’ve used some fairly generic tricks to patch up what started out as O(n) costs and turn them into O(1) ones. As well as everything dropping out nicely, a lot of these techniques are probably reusable for other things (the masking trick in particular is highly generic).

Update 09/4/2017: An earlier version of this post claimed that this solution allowed you to remove a *single* instance of a value from the list. A comment below correctly pointed out that that approach biases the distribution, so I’ve changed it to remove *all* instances of a value instead. Fortunately for me, in my original use case the values are all distinct anyway, so the distinction doesn’t matter, but the post and the code have now both been updated to remove *all* instances of the value.
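To see the bias concretely, the following sketch computes the exact distribution of the next sample under the old count-based scheme for a tiny input. `first_sample_distribution` is a hypothetical helper written for illustration, not code from the comment, and it assumes the same swap-remove-and-retry loop as the `Sampler` above. With values `a, a, b` and one pending deletion of `a`, the remaining multiset is `{a, b}`, so each should come up half the time.

```python
from collections import Counter
from fractions import Fraction


def first_sample_distribution(values, pending):
    """Exact distribution of the next sample under the OLD count-based scheme.

    Hypothetical helper for illustration. pending[v] is how many single
    copies of v are still waiting to be deleted. Hitting such a value
    swap-removes that copy, decrements its count and retries.
    """
    n = len(values)
    result = Counter()
    for i, v in enumerate(values):
        if pending.get(v, 0) > 0:
            rest = list(values)  # swap-remove index i from a copy
            last = rest.pop()
            if i != len(rest):
                rest[i] = last
            counts = dict(pending)
            counts[v] -= 1
            for w, p in first_sample_distribution(rest, counts).items():
                result[w] += Fraction(1, n) * p
        else:
            result[v] += Fraction(1, n)
    return result


dist = first_sample_distribution(["a", "a", "b"], {"a": 1})
print(dist["a"], dist["b"])  # 1/3 2/3, rather than 1/2 1/2
```

Under the count-based scheme, whichever copy of `a` we hit first is treated as the deleted one, so `a` cannot be returned until the pending deletion has been paid off, while `b` can. When *all* copies of a value are deleted, every hit on that value really is a dead element, so retrying is plain rejection sampling and the result stays uniform.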

Do you like data structures? Of course you do! Who doesn’t like data structures? Would you like *more* data structures? Naturally. So why not sign up for my Patreon and tell me so, so you can get more exciting blog posts like this! You’ll get access to drafts of upcoming blog posts, a slightly increased blogging rate from me, and the warm fuzzy feeling of supporting someone whose writing you enjoy.