Pythonic way to select list elements with different probability [duplicate]
import random
pos = ["A", "B", "C"]
x = random.choice(pos)
This code gives me either "A", "B" or "C" with equal probability. Is there a nice way to express it when you want "A" with 30%, "B" with 40% and "C" with 30% probability?
Weights define a probability distribution function (pdf). Random numbers from any such pdf can be generated by applying its associated inverse cumulative distribution function to uniform random numbers between 0 and 1.
See also this SO explanation, or, as explained by Wikipedia:
If Y has a U[0,1] distribution then F⁻¹(Y) is distributed as F. This is used in random number generation using the inverse transform sampling-method.
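Here is a worked instance for the question's weights (0.3, 0.4, 0.3). The cumulative values are 0.3, 0.7 and 1.0, so a uniform draw u in [0, 1) maps to an item as shown below. Which side of a boundary a tie lands on has probability zero, so it does not matter.

$$
F^{-1}(u) =
\begin{cases}
\text{A}, & 0 \le u < 0.3 \\
\text{B}, & 0.3 \le u < 0.7 \\
\text{C}, & 0.7 \le u < 1
\end{cases}
$$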
import random
import bisect
import collections

def cdf(weights):
    # Normalized running totals, e.g. [0.3, 0.7, 1.0]
    total = sum(weights)
    result = []
    cumsum = 0
    for w in weights:
        cumsum += w
        result.append(cumsum / total)
    return result

def choice(population, weights):
    assert len(population) == len(weights)
    cdf_vals = cdf(weights)
    x = random.random()                # uniform in [0, 1)
    idx = bisect.bisect(cdf_vals, x)   # first bucket whose CDF value exceeds x
    return population[idx]

weights = [0.3, 0.4, 0.3]
population = 'ABC'
counts = collections.defaultdict(int)
for i in range(10000):
    counts[choice(population, weights)] += 1
print(counts)
# % test.py
# defaultdict(<type 'int'>, {'A': 3066, 'C': 2964, 'B': 3970})
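The choice function above uses bisect.bisect, so once the cumulative weights are known, each selection takes O(log n) time, where n is the length of weights. As written, though, choice rebuilds the CDF on every call, which costs O(n). If you need many draws, compute the CDF once.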
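The sketch below shows that precomputation. The names make_sampler and draw are hypothetical and not from the original answer. It builds the cumulative weights once with itertools.accumulate and reuses them for every draw, so each call does only a binary search.

```python
import bisect
import itertools
import random


def make_sampler(population, weights, rng=random):
    """Return a zero-argument function that draws one weighted item.

    Hypothetical helper: the cumulative weights are computed once (O(n));
    each draw is then a single binary search (O(log n)).
    """
    if len(population) != len(weights):
        raise ValueError("population and weights must have the same length")
    cum = list(itertools.accumulate(weights))  # running totals, e.g. [0.3, 0.7, 1.0]
    total = cum[-1]

    def draw():
        # Scale a uniform draw in [0, 1) to [0, total) and find its bucket.
        x = rng.random() * total
        idx = bisect.bisect_right(cum, x)
        # Guard against floating-point rounding pushing x up to total.
        return population[min(idx, len(population) - 1)]

    return draw


sample = make_sampler("ABC", [0.3, 0.4, 0.3])
print(sample())
```

bisect_right skips over runs of equal cumulative values, so items with weight 0 are never returned.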
Note that as of version 1.7.0, NumPy has a Cythonized np.random.choice function. For example, this generates 1000 samples from the population [0,1,2,3] with weights [0.1, 0.2, 0.3, 0.4]:
import numpy as np
np.random.choice(4, 1000, p=[0.1, 0.2, 0.3, 0.4])
np.random.choice also has a replace parameter for sampling with or without replacement.
A theoretically better algorithm is the Alias Method. Building its table takes O(n) time, but after that each sample can be drawn in O(1) time. So if you need to draw many samples, the Alias Method may in theory be faster. There is a Python implementation of the Walker Alias Method here, and a numpy version here.
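Below is a minimal sketch of Vose's variant of the alias method. It is not the implementation linked above, and the function names build_alias_table and alias_draw are hypothetical. The table is built once in O(n). Each draw then picks a bucket uniformly and flips one biased coin.

```python
import random


def build_alias_table(weights):
    """Build probability and alias tables (Vose's variant) in O(n)."""
    n = len(weights)
    total = float(sum(weights))
    # Scale so that the average bucket holds exactly 1.0 of probability mass.
    scaled = [w * n / total for w in weights]
    prob = [0.0] * n
    alias = [0] * n
    small = [i for i, s in enumerate(scaled) if s < 1.0]
    large = [i for i, s in enumerate(scaled) if s >= 1.0]
    while small and large:
        s = small.pop()
        g = large.pop()
        # Bucket s keeps its own mass and is topped up to 1.0 by item g.
        prob[s] = scaled[s]
        alias[s] = g
        scaled[g] = (scaled[g] + scaled[s]) - 1.0
        if scaled[g] < 1.0:
            small.append(g)
        else:
            large.append(g)
    # Whatever remains is (up to rounding) exactly full.
    for i in small + large:
        prob[i] = 1.0
    return prob, alias


def alias_draw(population, prob, alias, rng=random):
    """Draw one item in O(1): pick a bucket, then flip a biased coin."""
    i = rng.randrange(len(prob))
    return population[i] if rng.random() < prob[i] else population[alias[i]]


prob, alias = build_alias_table([0.3, 0.4, 0.3])
print(alias_draw("ABC", prob, alias))
```

Each bucket holds at most two items, so a draw costs the same no matter how many items there are.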
Not... so much...
import random
pos = ['A'] * 3 + ['B'] * 4 + ['C'] * 3
print(random.choice(pos))
or
pos = {'A': 3, 'B': 4, 'C': 3}
print(random.choice([x for x in pos for _ in range(pos[x])]))
Here's a class to expose a bunch of items with relative probabilities, without actually expanding the list:
import bisect

class WeightedTuple(object):
    """
    >>> p = WeightedTuple({'A': 2, 'B': 1, 'C': 3})
    >>> len(p)
    6
    >>> p[0], p[1], p[2], p[3], p[4], p[5]
    ('A', 'A', 'B', 'C', 'C', 'C')
    >>> p[-1], p[-2], p[-3], p[-4], p[-5], p[-6]
    ('C', 'C', 'C', 'B', 'A', 'A')
    >>> p[6]
    Traceback (most recent call last):
      ...
    IndexError
    >>> p[-7]
    Traceback (most recent call last):
      ...
    IndexError
    """
    def __init__(self, items):
        self.indexes = []
        self.items = []
        next_index = 0
        for key in sorted(items.keys()):
            val = items[key]
            self.indexes.append(next_index)
            self.items.append(key)
            next_index += val
        self.len = next_index

    def __getitem__(self, n):
        if n < 0:
            n = self.len + n
        if n < 0 or n >= self.len:
            raise IndexError
        idx = bisect.bisect_right(self.indexes, n)
        return self.items[idx-1]

    def __len__(self):
        return self.len
Now, just say:
import random
data = WeightedTuple({'A': 30, 'B': 40, 'C': 30})
random.choice(data)
As of Python 3.6, there is random.choices for that.
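Here is a short usage sketch of random.choices for the question's 30/40/30 split. The weights are relative, so they do not need to sum to 1 or 100. The function always returns a list, so a single pick is taken with [0]. The cum_weights parameter accepts precomputed running totals instead.

```python
import random

population = ["A", "B", "C"]
weights = [30, 40, 30]  # relative weights; they need not sum to 1 or 100

# One weighted pick (choices always returns a list, so take element 0).
one = random.choices(population, weights=weights)[0]

# Many picks at once, drawn with replacement.
many = random.choices(population, weights=weights, k=10000)

# Equivalent distribution, given as cumulative weights instead.
same = random.choices(population, cum_weights=[30, 70, 100], k=5)
print(one, {x: many.count(x) for x in population}, same)
```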
Original answer from 2010:
You can also use this form, which does not create an arbitrarily large list (and works with either integer or decimal probabilities):
import random

pos = [("A", 30), ("B", 40), ("C", 30)]

def w_choice(seq):
    total_prob = sum(item[1] for item in seq)
    chosen = random.uniform(0, total_prob)
    cumulative = 0
    for item, probability in seq:
        cumulative += probability
        if cumulative > chosen:
            return item
    # uniform() may return total_prob itself due to rounding; fall back to the last item
    return seq[-1][0]
There are some good solutions offered here, but before choosing one I would suggest looking at Eli Bendersky's thorough discussion of this issue, which compares various algorithms for it (with implementations in Python).
Try this:
import random
from decimal import Decimal

pos = {'A': Decimal("0.3"), 'B': Decimal("0.4"), 'C': Decimal("0.3")}
choice = random.random()  # uniform in [0, 1)
F_x = 0                   # running cumulative probability
for k, p in pos.items():
    F_x += p
    if choice <= F_x:
        x = k
        break