multiprocessing: How do I share a dict among multiple processes?
A program creates several processes that work on a join-able queue, Q, and may eventually manipulate a global dictionary D to store results (so each child process may use D to store its own result and also see what results the other child processes are producing).
If I print the dictionary D in a child process, I see the modifications that have been done on it (i.e. on D). But after the main process joins Q, if I print D, it's an empty dict!
I understand it is a synchronization/lock issue. Can someone tell me what is happening here, and how I can synchronize access to D?
A general answer involves using a Manager object. Adapted from the docs:
from multiprocessing import Process, Manager
def f(d):
    d[1] += '1'
    d['2'] += 2


if __name__ == '__main__':
    manager = Manager()

    d = manager.dict()
    d[1] = '1'
    d['2'] = 2

    p1 = Process(target=f, args=(d,))
    p2 = Process(target=f, args=(d,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()

    print d
Output:
$ python mul.py
{1: '111', '2': 6}
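Note that the updates in f (d[1] += '1') are read-modify-write operations on the proxy; they are not atomic, so with more workers or a longer critical section an update could be lost. A minimal sketch, not part of the original answer, that guards each update with a manager.Lock:

from multiprocessing import Process, Manager

def f(d, lock):
    # Hold the lock for the whole read-modify-write so the updates cannot interleave.
    with lock:
        d[1] += '1'
        d['2'] += 2

if __name__ == '__main__':
    manager = Manager()
    d = manager.dict({1: '1', '2': 2})
    lock = manager.Lock()

    procs = [Process(target=f, args=(d, lock)) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

    print(dict(d))  # expected: {1: '111', '2': 6}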
multiprocessing is not like threading. Each child process will get a copy of the main process's memory. Generally state is shared via communication (pipes/sockets), signals, or shared memory.
Multiprocessing makes some abstractions available for your use case - shared state that's treated as local by use of proxies or shared memory: http://docs.python.org/library/multiprocessing.html#sharing-state-between-processes
Relevant sections:
- http://docs.python.org/library/multiprocessing.html#shared-ctypes-objects
- http://docs.python.org/library/multiprocessing.html#module-multiprocessing.managers
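To illustrate the point above about each child working on its own copy of memory, here is a minimal sketch (not from the original answer) showing that a plain module-level dict modified in a child is never seen by the parent:

from multiprocessing import Process

D = {}  # a plain dict: each child process works on its own copy

def worker():
    D['answer'] = 42
    print('in child: ', D)  # the child sees its own modification

if __name__ == '__main__':
    p = Process(target=worker)
    p.start()
    p.join()
    print('in parent:', D)  # still {} -- the child's change never propagated back

This is exactly why the question's global dictionary D comes out empty; a Manager dict (or shared memory) is what makes results visible across processes.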
In addition to @senderle's answer here, some might also be wondering how to use the functionality of multiprocessing.Pool.
The nice thing is that the manager instance has a .Pool() method that mimics all the familiar API of the top-level multiprocessing module.
from itertools import repeat
import multiprocessing as mp
import os
import pprint
def f(d: dict) -> None:
    pid = os.getpid()
    d[pid] = "Hi, I was written by process %d" % pid


if __name__ == '__main__':
    with mp.Manager() as manager:
        d = manager.dict()

        with manager.Pool() as pool:
            pool.map(f, repeat(d, 10))

        # `d` is a DictProxy object that can be converted to dict
        pprint.pprint(dict(d))
Output:
$ python3 mul.py
{22562: 'Hi, I was written by process 22562',
22563: 'Hi, I was written by process 22563',
22564: 'Hi, I was written by process 22564',
22565: 'Hi, I was written by process 22565',
22566: 'Hi, I was written by process 22566',
22567: 'Hi, I was written by process 22567',
22568: 'Hi, I was written by process 22568',
22569: 'Hi, I was written by process 22569',
22570: 'Hi, I was written by process 22570',
22571: 'Hi, I was written by process 22571'}
This is a slightly different example where each process just logs its process ID to the global DictProxy object d.
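One gotcha worth knowing with a DictProxy (a general note, not something from the answer above): in-place mutation of a plain nested container stored in the proxy is not propagated back, because fetching the inner object returns a copy. Reassigning the inner value makes the change stick. A minimal sketch:

import multiprocessing as mp

def broken(d):
    d['inner']['x'] = 1   # mutates a local copy of the inner dict; the change is lost

def working(d):
    inner = d['inner']    # fetch a copy of the inner dict
    inner['x'] = 1        # modify the copy
    d['inner'] = inner    # reassign so the manager actually sees the change

if __name__ == '__main__':
    with mp.Manager() as manager:
        d = manager.dict()
        d['inner'] = {}

        p = mp.Process(target=broken, args=(d,))
        p.start()
        p.join()
        print(dict(d))    # {'inner': {}}

        p = mp.Process(target=working, args=(d,))
        p.start()
        p.join()
        print(dict(d))    # {'inner': {'x': 1}}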
I'd like to share my own work: a shared dict that is faster than Manager's dict, and simpler and more stable than the pyshmht library, which uses tons of memory and doesn't work on Mac OS. My dict only works for plain strings and is currently immutable. It uses a linear-probing implementation and stores the key/value pairs in a separate memory block after the table.
from mmap import mmap
import struct
from timeit import default_timer
from multiprocessing import Manager
from pyshmht import HashTable
class shared_immutable_dict:
    def __init__(self, a):
        # Table size: the next power of two above 3x the number of items.
        self.hs = 1 << (len(a) * 3).bit_length()
        kvp = self.hs * 4
        ht = [0xffffffff] * self.hs
        kvl = []
        for k, v in a.iteritems():
            # Linear probing: walk forward until a free slot is found.
            h = self.hash(k)
            while ht[h] != 0xffffffff:
                h = (h + 1) & (self.hs - 1)
            ht[h] = kvp
            kvp += self.kvlen(k) + self.kvlen(v)
            kvl.append(k)
            kvl.append(v)
        # The key/value blobs live after the table, in one anonymous mmap.
        self.m = mmap(-1, kvp)
        for p in ht:
            self.m.write(uint_format.pack(p))
        for x in kvl:
            if len(x) <= 0x7f:
                self.m.write_byte(chr(len(x)))
            else:
                self.m.write(uint_format.pack(0x80000000 + len(x)))
            self.m.write(x)

    def hash(self, k):
        h = hash(k)
        h = (h + (h >> 3) + (h >> 13) + (h >> 23)) * 1749375391 & (self.hs - 1)
        return h

    def get(self, k, d=None):
        h = self.hash(k)
        while True:
            x = uint_format.unpack(self.m[h * 4:h * 4 + 4])[0]
            if x == 0xffffffff:
                return d
            self.m.seek(x)
            if k == self.read_kv():
                return self.read_kv()
            h = (h + 1) & (self.hs - 1)

    def read_kv(self):
        sz = ord(self.m.read_byte())
        if sz & 0x80:
            sz = uint_format.unpack(chr(sz) + self.m.read(3))[0] - 0x80000000
        return self.m.read(sz)

    def kvlen(self, k):
        # One length byte for short strings, a 4-byte header otherwise.
        return len(k) + (1 if len(k) <= 0x7f else 4)

    def __contains__(self, k):
        return self.get(k, None) is not None

    def close(self):
        self.m.close()


uint_format = struct.Struct('>I')


def uget(a, k, d=None):
    return to_unicode(a.get(to_str(k), d))

def uin(a, k):
    return to_str(k) in a

def to_unicode(s):
    return s.decode('utf-8') if isinstance(s, str) else s

def to_str(s):
    return s.encode('utf-8') if isinstance(s, unicode) else s


def mmap_test():
    n = 1000000
    d = shared_immutable_dict({str(i * 2): '1' for i in xrange(n)})
    start_time = default_timer()
    for i in xrange(n):
        if bool(d.get(str(i))) != (i % 2 == 0):
            raise Exception(i)
    print 'mmap speed: %d gets per sec' % (n / (default_timer() - start_time))

def manager_test():
    n = 100000
    d = Manager().dict({str(i * 2): '1' for i in xrange(n)})
    start_time = default_timer()
    for i in xrange(n):
        if bool(d.get(str(i))) != (i % 2 == 0):
            raise Exception(i)
    print 'manager speed: %d gets per sec' % (n / (default_timer() - start_time))

def shm_test():
    n = 1000000
    d = HashTable('tmp', n)
    d.update({str(i * 2): '1' for i in xrange(n)})
    start_time = default_timer()
    for i in xrange(n):
        if bool(d.get(str(i))) != (i % 2 == 0):
            raise Exception(i)
    print 'shm speed: %d gets per sec' % (n / (default_timer() - start_time))


if __name__ == '__main__':
    mmap_test()
    manager_test()
    shm_test()
On my laptop performance results are:
mmap speed: 247288 gets per sec
manager speed: 33792 gets per sec
shm speed: 691332 gets per sec
A simple usage example:
ht = shared_immutable_dict({'a': '1', 'b': '2'})
print ht.get('a')
Maybe you can try pyshmht, a shared-memory-based hash table extension for Python.
Notice:
- It's not fully tested, just for your reference.
- It currently lacks lock/semaphore mechanisms for multiprocessing.
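For reference, a minimal usage sketch, using only the calls that already appear in shm_test() above (HashTable(name, capacity), update() and get()); treat the exact API as an assumption, since the project is not fully documented:

from pyshmht import HashTable

# 'tmp' is the shared-memory name and 1024 the capacity, mirroring shm_test() above
ht = HashTable('tmp', 1024)
ht.update({'a': '1', 'b': '2'})
print(ht.get('a'))  # '1'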
In my case I was not getting consistent outputs: e.g. the value of __total_count__ doesn't always reach 100, the number of inputs.
from itertools import repeat
import multiprocessing as mp
import os
import pprint
from functools import partial
import numpy as np
import time
def counter(value, d: dict) -> None:
    if value not in d:
        d["__unique_count__"] += 1
        d[value] = 1
    else:
        d[value] += 1
    d["__total_count__"] += 1


if __name__ == '__main__':
    mp.freeze_support()
    with mp.Manager() as manager:
        d = manager.dict()
        d["__unique_count__"] = 0
        d["__total_count__"] = 0

        numbers = np.random.randint(0, 5, size=100)
        print(len(numbers))

        with manager.Pool() as pool:
            pool.map(partial(counter, d=d), numbers)

        # `d` is a DictProxy object that can be converted to dict
        final_d = dict(d)
        pprint.pprint(final_d)
        print(final_d["__unique_count__"], final_d["__total_count__"])
Output 1:
100
{0: 26,
1: 16,
2: 26,
3: 14,
4: 18,
'__total_count__': 92,
'__unique_count__': 5}
5 92
Output 2:
100
{0: 10,
1: 21,
2: 28,
3: 22,
4: 19,
'__total_count__': 95,
'__unique_count__': 5}
5 95
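The missing counts above come from the unsynchronized read-modify-write in counter(): two workers can read the same value before either writes its increment back, so one increment is lost. A minimal sketch (my addition, not part of the original example) that serializes each update with a manager.Lock, so __total_count__ reliably ends up at 100:

from functools import partial
import multiprocessing as mp

import numpy as np

def counter(value, d, lock) -> None:
    # Hold the lock for the whole read-modify-write so no update is lost.
    with lock:
        if value not in d:
            d["__unique_count__"] += 1
            d[value] = 1
        else:
            d[value] += 1
        d["__total_count__"] += 1

if __name__ == '__main__':
    with mp.Manager() as manager:
        d = manager.dict({"__unique_count__": 0, "__total_count__": 0})
        lock = manager.Lock()
        numbers = np.random.randint(0, 5, size=100)
        with manager.Pool() as pool:
            pool.map(partial(counter, d=d, lock=lock), numbers)
        print(d["__unique_count__"], d["__total_count__"])  # e.g. 5 100

The lock serializes every update through the manager process, so this is correct but slower; for pure counting, having each worker return its own partial counts and merging them in the parent avoids the shared state entirely.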