Generating a list of functions in python [duplicate]
I have the following python code that generates a list of anonymous functions:
basis = [ (lambda x: n*x) for n in [0, 1, 2] ]
print(basis[0](1))
I would have expected it to be equivalent to
basis = [ (lambda x: 0*x), (lambda x: 1*x), (lambda x: 2*x) ]
print(basis[0](1))
However, whereas the second snippet prints 0, as I would expect, the first prints 2. What's wrong with the first snippet of code, and why doesn't it behave as expected?
You can use a default parameter to capture the current value of n in each lambda:
>>> basis = [ (lambda x,n=n: n*x) for n in [0, 1, 2] ]
>>> print(basis[0](1))
0
Because n is looked up by name when the lambda runs, not when it is created (late binding). That is, when the lambda is run, it executes n*x: x is bound to 1 (it is a parameter), while n is looked up in the enclosing environment, where it is now 2. So the result is 2.
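A minimal sketch of the same late binding outside of any loop; the names scale and result are illustrative, not from the question:

```python
n = 0
scale = lambda x: n * x  # the body refers to the name n, not to its current value
n = 2                    # rebinding n after the lambda was created...
result = scale(1)        # ...changes what the lambda sees at call time: 2 * 1
print(result)            # 2
```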
The problem is that in the first example, each lambda is bound to the same n: in other words, it is capturing the variable, not the variable's value. Since n has the value 2 at the end of the loop, each lambda uses the value 2 for n.
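One way to see that the lambdas share a single variable is to compare them with a factory function, a common alternative fix that none of the answers here uses. The helper names make_basis_late and make_multiplier are hypothetical; each call to make_multiplier creates a fresh scope, so each returned lambda closes over its own n:

```python
def make_basis_late():
    # Every lambda closes over the same loop variable n, which ends at 2
    return [lambda x: n * x for n in [0, 1, 2]]

def make_multiplier(n):
    # Each call gets its own n, so each lambda captures a different variable
    return lambda x: n * x

late = make_basis_late()
fixed = [make_multiplier(n) for n in [0, 1, 2]]
print([f(1) for f in late])   # [2, 2, 2]
print([f(1) for f in fixed])  # [0, 1, 2]
```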
Apparently you can use default parameters to solve this problem:
basis = [ (lambda x,n=n: n*x) for n in [0, 1, 2] ]
print(basis[0](1))
Default parameter values are evaluated once, when the lambda is created, so the n on the right side of n=n is evaluated on each pass through the loop, giving each lambda its own captured value.
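A sketch of the default-parameter fix next to functools.partial, which is a swapped-in alternative rather than something the answers here suggest. It also shows one trade-off of n=n: n becomes an ordinary parameter, so a caller can override it:

```python
import operator
from functools import partial

# Default-argument fix: n=n is evaluated once per lambda, at creation time
basis = [(lambda x, n=n: n * x) for n in [0, 1, 2]]
print([f(3) for f in basis])  # [0, 3, 6]

# Caveat: the captured value can be overridden by passing n explicitly
print(basis[0](3, n=10))      # 30

# Alternative: partial stores n as the first argument of operator.mul
basis_partial = [partial(operator.mul, n) for n in [0, 1, 2]]
print([f(3) for f in basis_partial])  # [0, 3, 6]
```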
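I want to help clarify the comment from Karl Knechtel (Dec 13 '10 at 7:32). The following session shows that with a generator expression the original lambda definition gives the intended result, but with a list or tuple it does not. The generator works because each lambda is called before the generator advances to the next value of n, whereas list() and tuple() run the whole loop first, so every lambda sees the final value n = 2: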
>>> #GENERATOR
... basis = ( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'generator'>
>>> basis = ( (lambda x: n*x) for n in [0, 1, 2] )
>>> print([x(3) for x in basis])
[0, 3, 6]
>>> #TUPLE
... basis = tuple( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'tuple'>
>>> print([x(3) for x in basis])
[6, 6, 6]
>>> #LIST
... basis = list( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'list'>
>>> print([x(3) for x in basis])
[6, 6, 6]
>>> #CORRECTED LIST
... basis = list( (lambda x, n=n: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'list'>
>>> print([x(3) for x in basis])
[0, 3, 6]
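The difference comes down to when each lambda is called relative to the loop. A minimal sketch (the function name lazy_vs_eager is hypothetical) that consumes the same generator expression lazily and eagerly:

```python
def lazy_vs_eager():
    # Lazy: each lambda is called while the generator is paused at its n
    gen = ((lambda x: n * x) for n in [0, 1, 2])
    lazy = [f(3) for f in gen]

    # Eager: list() exhausts the generator first, leaving n at 2
    funcs = list((lambda x: n * x) for n in [0, 1, 2])
    eager = [f(3) for f in funcs]
    return lazy, eager

lazy, eager = lazy_vs_eager()
print(lazy)   # [0, 3, 6]
print(eager)  # [6, 6, 6]
```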