Generating a list of functions in python [duplicate]

This question already has answers here: Creating functions (or lambdas) in a loop (or comprehension) (6 answers). Closed 6 months ago.

I have the following Python code that generates a list of anonymous functions:

basis = [ (lambda x: n*x) for n in [0, 1, 2] ]
print(basis[0](1))

I would have expected it to be equivalent to

basis = [ (lambda x: 0*x), (lambda x: 1*x), (lambda x: 2*x) ]
print(basis[0](1))

However, whereas the second snippet prints out 0 which is what I would expect, the first prints 2. What's wrong with the first snippet of code, and why doesn't it behave as expected?


You can use a default parameter to capture the current value of n in each lambda:

>>> basis = [ (lambda x,n=n: n*x) for n in [0, 1, 2] ]
>>> print(basis[0](1))
0
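
For comparison, here is a sketch of another way to bind n early. It is not what this answer proposes: it swaps the default parameter for functools.partial and operator.mul from the standard library, and assumes the functions only ever need to compute n*x.

```python
from functools import partial
from operator import mul

# partial(mul, n) stores the current value of n as mul's first argument,
# so each function in the list keeps its own multiplier.
basis = [partial(mul, n) for n in [0, 1, 2]]

results = [f(1) for f in basis]
print(results)  # [0, 1, 2]
```

The trade-off is readability: the default-parameter form keeps the lambda visible, while partial avoids adding an extra parameter that a caller could override by accident.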


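Because n is looked up by name when the lambda is called, not when it is defined (this is usually called late binding).

That is, when the lambda runs, it executes n*x: x is bound to 1 (it is a parameter), while n is looked up in the enclosing environment, where it is now 2. So, the result is 2.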
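
A minimal sketch of that lookup-at-call-time behaviour, outside of any loop. The name scale is made up for the illustration and does not appear in the question.

```python
# The lambda body refers to the name n; it does not store n's value.
n = 0
scale = lambda x: n * x   # hypothetical name, for illustration only

first = scale(1)    # n is 0 when this call happens
n = 2
second = scale(1)   # same lambda object, but n is now 2

print(first, second)  # 0 2
```

The list comprehension in the question does the same thing three times over: every lambda refers to the one name n, and by the time any of them is called, n is 2.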


The problem is that in the first example, each lambda is bound to the same n -- in other words, it is capturing the variable, not the variable's value. Since n has the value of 2 at the end of the loop, each lambda is using the value 2 for n.

Apparently you can use default parameters to solve this problem:

basis = [ (lambda x,n=n: n*x) for n in [0, 1, 2] ]
print(basis[0](1))

Default parameter values are evaluated once, when the lambda is created, so the n on the right side of n=n is evaluated on each pass through the loop and each lambda stores its own value.
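
As a rough check of that claim, the stored defaults can be inspected on the function objects through the standard __defaults__ attribute. This is only a sketch; the variable names defaults and results are mine.

```python
# Each lambda gets its own default for n, evaluated when the lambda is created.
basis = [(lambda x, n=n: n * x) for n in [0, 1, 2]]

# __defaults__ holds the default values captured at definition time.
defaults = [f.__defaults__ for f in basis]
results = [f(1) for f in basis]

print(defaults)  # [(0,), (1,), (2,)]
print(results)   # [0, 1, 2]
```

One caveat with this approach: n is now an ordinary parameter, so a call such as basis[0](1, 5) overrides the stored value and returns 5.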


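This is meant to illustrate the comment from Karl Knechtel (Dec 13 '10 at 7:32). With a generator, the original lambda definition gives the intended result, because each lambda is called right after it is produced, while n still holds that iteration's value. With a list or tuple, all the lambdas are created before any of them is called, so n is already 2 and every call returns 2*3 = 6: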

>>> #GENERATOR
... basis = ( (lambda x: n*x) for n in [0, 1, 2] )  
>>> print(type(basis))
<type 'generator'>
>>> basis = ( (lambda x: n*x) for n in [0, 1, 2] ) 
>>> print([x(3) for x in basis])
[0, 3, 6]
>>> #TUPLE
... basis = tuple( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'tuple'>
>>> print([x(3) for x in basis])
[6, 6, 6]
>>> #LIST
... basis = list( (lambda x: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'list'>
>>> print([x(3) for x in basis])
[6, 6, 6]
>>> #CORRECTED LIST
... basis = list( (lambda x, n=n: n*x) for n in [0, 1, 2] )
>>> print(type(basis))
<type 'list'>
>>> print([x(3) for x in basis])
[0, 3, 6]
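
If the result should not depend on whether the functions end up in a generator, a tuple, or a list, one option is a small factory function. The sketch below assumes a helper called make_multiplier, which is a hypothetical name and not part of the original code.

```python
def make_multiplier(n):
    # n is a parameter here, so every call creates a separate binding
    # that the returned lambda closes over.
    return lambda x: n * x

as_list = [make_multiplier(n) for n in [0, 1, 2]]
as_tuple = tuple(make_multiplier(n) for n in [0, 1, 2])

list_results = [f(3) for f in as_list]
tuple_results = [f(3) for f in as_tuple]

print(list_results)   # [0, 3, 6]
print(tuple_results)  # [0, 3, 6]
```

Unlike the n=n default, this leaves the returned functions with a single parameter, so the multiplier cannot be overridden by the caller.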