Generating a list of functions in Python [duplicate]
I have the following python code that generates a list of anonymous functions:
basis = [ (lambda x: n*x) for n in [0, 1, 2] ]
print(basis[0](1))
I would have expected it to be equivalent to
basis = [ (lambda x: 0*x), (lambda x: 1*x), (lambda x: 2*x) ]
print(basis[0](1))
However, whereas the second snippet prints 0, which is what I would expect, the first prints 2. What's wrong with the first snippet, and why doesn't it behave as expected?
You can use a default parameter to capture the current value of n in each lambda:
>>> basis = [ (lambda x,n=n: n*x) for n in [0, 1, 2] ]
>>> print(basis[0](1))
0
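If you want an actual closure on n rather than a default argument, a factory function gives each lambda its own scope. This is a swapped-in alternative, not the answer above, and `make_multiplier` is a hypothetical helper name:

```python
def make_multiplier(n):
    # Each call creates a fresh local scope, so the returned lambda
    # closes over this call's n rather than a shared loop variable.
    return lambda x: n * x

basis = [make_multiplier(n) for n in [0, 1, 2]]
print([f(1) for f in basis])  # [0, 1, 2]
```

Unlike the default-parameter version, callers cannot override n by passing a second argument.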
Because of late binding: the lambda refers to the name n, not to the value n had when the lambda was created.
That is, when the lambda runs, it executes n*x: x is bound to 1 (it is a parameter), and n is looked up in the enclosing scope, where it is now 2. So the result is 2.
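A minimal illustration of that lookup timing, outside any loop: rebinding n after the lambda is created changes what the lambda returns.

```python
n = 1
scale = lambda x: n * x  # n is not looked up yet

n = 2                    # rebind n after the lambda was created
print(scale(1))          # 2: n is looked up when scale is called
```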
The problem is that in the first example, each lambda is bound to the same n; in other words, it captures the variable, not the variable's value. Since n has the value 2 at the end of the loop, every lambda uses 2 for n.
You can use default parameters to solve this problem:
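One way to see that the lambdas share a single variable is to inspect their closures. This sketch uses an explicit loop inside a function (`build_basis` is a hypothetical name) so the scope of n is unambiguous; each lambda's `__closure__` holds a cell, and all three hold the same cell object:

```python
def build_basis():
    funcs = []
    for n in [0, 1, 2]:
        # Every lambda refers to the same variable n of build_basis
        funcs.append(lambda x: n * x)
    return funcs

basis = build_basis()
print([f(1) for f in basis])  # [2, 2, 2]

# The lambdas share one closure cell, which holds the final value of n
cells = [f.__closure__[0] for f in basis]
print(all(c is cells[0] for c in cells))  # True
print(cells[0].cell_contents)             # 2
```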
basis = [ (lambda x,n=n: n*x) for n in [0, 1, 2] ]
print(basis[0](1))
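Default parameter values are evaluated once, when the function is defined. Because a new lambda is created on each pass through the loop, the n on the right side of n=n is evaluated each time through, and each lambda stores the value n had at that moment.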
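A quick check of that claim: each lambda's `__defaults__` tuple records the value n had when that lambda was created. The last line shows a side effect worth knowing about: n is still an ordinary parameter, so a caller can override it.

```python
basis = [(lambda x, n=n: n * x) for n in [0, 1, 2]]

# Each lambda stored the value of n from its own iteration
print([f.__defaults__ for f in basis])  # [(0,), (1,), (2,)]
print([f(1) for f in basis])            # [0, 1, 2]

# Caveat: passing a second argument replaces the captured default
print(basis[0](1, 5))                   # 5
```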
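I want to help explain the comment from Karl Knechtel (Dec 13 '10 at 7:32). The following session shows that with a generator, the original lambda definition gives the intended result, while with a list or tuple it does not. The generator only appears to work because each lambda is called before the generator advances and rebinds n; tuple() and list() exhaust the generator first, so by the time the lambdas are called, n is 2.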
>>> #GENERATOR
... basis = ( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'generator'>
>>> basis = ( (lambda x: n*x) for n in [0, 1, 2] )
>>> print([x(3) for x in basis])
[0, 3, 6]
>>> #TUPLE
... basis = tuple( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'tuple'>
>>> print([x(3) for x in basis])
[6, 6, 6]
>>> #LIST
... basis = list( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'list'>
>>> print([x(3) for x in basis])
[6, 6, 6]
>>> #CORRECTED LIST
... basis = list( (lambda x, n=n: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'list'>
>>> print([x(3) for x in basis])
[0, 3, 6]
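Stepping the generator by hand makes the timing visible. In this sketch the first lambda is called once before and once after the generator advances; it returns a different result because it reads the generator's current n:

```python
gen = ((lambda x: n * x) for n in [0, 1, 2])

first = next(gen)
print(first(3))    # 0: the generator's n is currently 0

second = next(gen) # advancing the generator rebinds n to 1
print(first(3))    # 3: first sees the new value of n
print(second(3))   # 3
```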