How to serialize/deserialize PyBrain networks?
PyBrain is a Python library that provides (among other things) easy-to-use artificial neural networks.
I fail to properly serialize/deserialize PyBrain networks using either pickle or cPickle.
See the following example:
from pybrain.datasets import SupervisedDataSet
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
import cPickle as pickle
import numpy as np
#generate some data
np.random.seed(93939393)
data = SupervisedDataSet(2, 1)
for x in xrange(10):
    y = x * 3
    z = x + y + 0.2 * np.random.randn()
    data.addSample((x, y), (z,))
#build a network and train it
net1 = buildNetwork( data.indim, 2, data.outdim )
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True)
for i in xrange(4):
    trainer1.trainEpochs(1)
    print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0])
This is the output of the above code:
Total error: 201.501998476
value after 0 epochs: 2.79
Total error: 152.487616382
value after 1 epochs: 5.44
Total error: 120.48092561
value after 2 epochs: 7.56
Total error: 97.9884043452
value after 3 epochs: 8.41
As you can see, network total error decreases as the training progresses. You can also see that the predicted value approaches the expected value of 12.
Now we will do a similar exercise, but will include serialization/deserialization:
print 'creating net2'
net2 = buildNetwork(data.indim, 2, data.outdim)
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
trainer2.trainEpochs(1)
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0])
#So far, so good. Let's test pickle
pickle.dump(net2, open('testNetwork.dump', 'w'))
net2 = pickle.load(open('testNetwork.dump'))
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
print 'loaded net2 using pickle, continue training'
for i in xrange(1, 4):
    trainer2.trainEpochs(1)
    print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0])
This is the output of this block:
creating net2
Total error: 176.339378639
value after 1 epochs: 5.45
loaded net2 using pickle, continue training
Total error: 123.392181859
value after 1 epochs: 5.45
Total error: 94.2867637623
value after 2 epochs: 5.45
Total error: 78.076711114
value after 3 epochs: 5.45
As you can see, training still has some effect on the network (the reported total error keeps decreasing), but the output value of the network is frozen at the value it had after the first training epoch, before the network was pickled.
Is there any caching mechanism that I need to be aware of that causes this erroneous behaviour? Are there better ways to serialize/deserialize pybrain networks?
Relevant version numbers:
- Python 2.6.5 (r265:79096, Mar 19 2010, 21:48:26) [MSC v.1500 32 bit (Intel)]
- Numpy 1.5.1
- cPickle 1.71
- pybrain 0.3
P.S. I have created a bug report on the project's site and will keep both SO and the bug tracker updated.
Cause
The mechanism that causes this behavior is the handling of parameters (.params) and derivatives (.derivs) in PyBrain modules: in fact, all network parameters are stored in one array, but the individual Module or Connection objects have access to "their own" .params, which, however, are just a view on a slice of the total array. This allows both local and network-wide writes and read-outs on the same data structure.
Apparently this slice-view link gets lost by pickling-unpickling.
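The effect can be reproduced with plain NumPy, without PyBrain. The sketch below is only an illustration: the Owner class and its attribute names are hypothetical stand-ins, not PyBrain's actual classes. It shows that a slice view shares memory with its parent array before pickling, and comes back as an independent copy afterwards.

```python
import pickle
import numpy as np


class Owner(object):
    """Hypothetical stand-in for a network: one flat array plus a slice view."""

    def __init__(self):
        self.params = np.zeros(4)       # the "network-wide" parameter array
        self.part = self.params[1:3]    # a "module-local" view on a slice of it


original = Owner()
original.part[:] = 7.0                  # a local write is visible in the full array
shared_before = np.shares_memory(original.params, original.part)

# Round trip through pickle: each array is restored with its own buffer.
restored = pickle.loads(pickle.dumps(original))
shared_after = np.shares_memory(restored.params, restored.part)

restored.part[:] = 9.0                  # this write no longer reaches restored.params

print(shared_before, shared_after)
print(restored.params)
```

If PyBrain's modules end up in the same state after unpickling, one side of the network would be updated by training while the other keeps using stale values, which would be consistent with the frozen output in the question.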
Solution
Insert
net2.sorted = False
net2.sortModules()
after loading from the file (which recreates this sharing), and it should work.
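To avoid forgetting this step, the two lines can be wrapped in a small loader. This is a minimal sketch: the function names restore_sharing and load_network are made up for illustration and are not part of PyBrain, and the only PyBrain calls used are the sorted attribute and sortModules() mentioned above. Opening the file in binary mode is an extra precaution, not part of the fix itself.

```python
import pickle


def restore_sharing(net):
    """Force the network to rebuild its internal structure after unpickling.

    Uses only the two steps described above: reset the 'sorted' flag,
    then call sortModules() so the parameter sharing is recreated.
    """
    net.sorted = False
    net.sortModules()
    return net


def load_network(path):
    """Hypothetical helper: unpickle a network and repair parameter sharing."""
    with open(path, 'rb') as handle:    # binary mode is safe on every platform
        net = pickle.load(handle)
    return restore_sharing(net)
```

With this in place, the loading line in the question would become net2 = load_network('testNetwork.dump'), assuming the file was also written in binary mode ('wb').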