Counting the number of recursions
I have some code that approximates a solution recursively. What it actually does is not important, but it works towards r' == rt by varying mg (m guess, starting with 4.0 because I "know" that ought to be in the ballpark).
-- precis (the tolerance) is defined elsewhere
solve_m f ar st qt = solve_m' f ar st qt 4.0
  where
    solve_m' f ar st qt mg
      | rd > precis    = f' (mg - sf)
      | rd < (-precis) = f' (mg + sf)
      | otherwise      = mg
      where
        f' = solve_m' f ar st qt
        rt = st + qt
        r' = f st ar mg
        rd = rt - r'
        sf = abs(rd)
What I would like to be able to do is count the number of cycles. I know the right way to do this is with the State monad, but what is the most elegant way to fit the put/get into a function like this? Make f' a do block? Or is it simply to add a counter argument to solve_m' and return (counter, mg)?
Thanks!
Edit: This seems to be basically what I want, and no monads are necessary:
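solve_m f ar st qt = (head rest, length steps)
  where
    -- span splits off the guesses that are not yet within precis; rest starts
    -- with the first guess that is (takeWhile + last would drop that guess)
    (steps, rest) = span termPred (iterate solve_m' 4.0)
    termPred m' = abs (rt - f st ar m') > precis
    rt = st + qt
    -- only ever applied to guesses that failed the precision check
    solve_m' mg
      | rt > r' = mg - sf
      | rt < r' = mg + sf
      where
        r' = f st ar mg
        rd = rt - r'
        sf = abs rd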
Still looks a little messy (repeated code) but I'll tidy it up... This is getting me acceptable results in 1/10000th of the iterations of the code it will replace!
Without looking at your algorithm, the generic way to do this is to separate your termination criterion from the iterative algorithm:
terminationPred :: a -> Bool
algorithm :: a -> a
then use either iterate and takeWhile:
intermediates = takeWhile (not . terminationPred) . iterate algorithm
resultAndRecursions :: a -> (a, Int)
resultAndRecursions a = (last (intermediates a), length (intermediates a) - 1)
-- you'd want to make your own safe function here, not use last and length
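For illustration, here is one possible safe variant (a sketch; the name resultAndRecursions', the maxSteps cap and the Newton's-method instance are hypothetical, not from the question). It uses break instead of takeWhile, which also matters for the result: takeWhile discards the first value that satisfies terminationPred, so last (intermediates a) is the value one step before termination, whereas break keeps it.

```haskell
-- Hypothetical safe replacement for last/length: returns the first value
-- that satisfies the termination predicate and the number of steps taken
-- to reach it, or Nothing if that does not happen within maxSteps steps.
resultAndRecursions' :: Int -> (a -> Bool) -> (a -> a) -> a -> Maybe (a, Int)
resultAndRecursions' maxSteps done step start =
  case break done (take (maxSteps + 1) (iterate step start)) of
    (before, x : _) -> Just (x, length before) -- before holds one entry per step
    (_, [])         -> Nothing                 -- cap reached without terminating

-- Hypothetical instance: Newton's method for sqrt 2, starting from 1.
newtonSqrt2 :: Maybe (Double, Int)
newtonSqrt2 =
  resultAndRecursions' 50 (\x -> abs (x * x - 2) < 1e-12) (\x -> (x + 2 / x) / 2) 1
```

Here done plays the role of terminationPred and step the role of algorithm; the cap turns an iteration that never terminates into Nothing instead of a hang.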
or unfoldr (from Data.List):
intermediates = unfoldr op
  where
    op a | terminationPred a = Nothing
         | otherwise         = let a' = algorithm a
                               in Just (a', a')
EDIT: also notice that these two versions of intermediates are slightly different: the first keeps the base case (the input a, hence the - 1) while the second does not, so the complementary resultAndRecursions would differ slightly as well.
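To make the difference concrete, here is a small sketch with a hypothetical instance of terminationPred and algorithm (halve a number until it drops below 1); only the shape of the two intermediates definitions comes from the answer.

```haskell
import Data.List (unfoldr)

-- Hypothetical instance: stop once the value is below 1.
terminationPred :: Double -> Bool
terminationPred x = x < 1

-- Hypothetical step: halve the value.
algorithm :: Double -> Double
algorithm x = x / 2

-- iterate/takeWhile version: starts with the input itself and stops
-- before the first value that satisfies terminationPred.
intermediatesIter :: Double -> [Double]
intermediatesIter = takeWhile (not . terminationPred) . iterate algorithm

-- unfoldr version: emits only newly computed values, including the
-- first one that satisfies terminationPred.
intermediatesUnfold :: Double -> [Double]
intermediatesUnfold = unfoldr op
  where
    op a | terminationPred a = Nothing
         | otherwise         = let a' = algorithm a in Just (a', a')
```

For a start value of 8, intermediatesIter yields [8.0,4.0,2.0,1.0] and intermediatesUnfold yields [4.0,2.0,1.0,0.5].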
Well, first of all, you can remove most of the arguments to solve_m': they don't change in recursive calls, and the arguments of solve_m are in scope for the where clause. This also makes the f' function unnecessary.
solve_m f ar st qt = solve_m' 4.0
  where
    solve_m' mg
      | rd > precis    = solve_m' (mg - sf)
      | rd < (-precis) = solve_m' (mg + sf)
      | otherwise      = mg
      where
        rt = st + qt
        r' = f st ar mg
        rd = rt - r'
        sf = abs(rd)
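The counter approach mentioned in the question fits naturally on top of this simplified solve_m': pass an extra accumulator and return it with the result. A minimal sketch, assuming a hypothetical tolerance precis = 1e-9 and an explicit type for f; the name solveMCounted is chosen for illustration.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Hypothetical tolerance standing in for the question's precis.
precis :: Double
precis = 1e-9

-- The simplified solve_m' with an extra counter argument: returns the final
-- guess together with the number of recursive calls it took to reach it.
solveMCounted :: (Double -> Double -> Double -> Double)
              -> Double -> Double -> Double -> (Double, Int)
solveMCounted f ar st qt = go 0 4.0
  where
    rt = st + qt
    go :: Int -> Double -> (Double, Int)
    go !n mg
      | rd > precis    = go (n + 1) (mg - sf)
      | rd < (-precis) = go (n + 1) (mg + sf)
      | otherwise      = (mg, n)
      where
        rd = rt - f st ar mg
        sf = abs rd
```

Because the counter is just an accumulator in a tail-recursive loop, no State monad is needed; the bang pattern (BangPatterns extension) keeps the count evaluated as the loop runs.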
Now, solve_m' has type Double -> Double, because all it does is perform the next iteration and then either finish or call itself tail-recursively. As it happens, the standard libraries include a function called iterate with type (a -> a) -> a -> [a], which takes such a function and produces a (possibly infinite) list of each step in the iteration. The number of recursive calls needed is, of course, precisely the length of the resulting list.
EDIT: That last claim is an embarrassing mistake in my answer.
What iterate actually does is produce an infinite list, in this case with endlessly repeating copies of the "final" result. Not really what you want. I was probably thinking of unfoldr :: (b -> Maybe (a, b)) -> b -> [a].
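A small, hypothetical illustration of the difference: a step function that settles on a final value (here a count that is capped at 10) makes iterate repeat that value forever, while unfoldr stops as soon as its step function returns Nothing.

```haskell
import Data.List (unfoldr)

-- Hypothetical step function that settles on a final value, like the
-- recursive solve_m' does: count up by 4, capped at 10.
capStep :: Int -> Int
capStep x = min 10 (x + 4)

-- iterate never stops; once the value settles it repeats forever,
-- so we can only look at a finite prefix.
viaIterate :: [Int]
viaIterate = take 6 (iterate capStep 0)

-- unfoldr stops as soon as the step function returns Nothing.
viaUnfoldr :: [Int]
viaUnfoldr = unfoldr next 0
  where
    next x | x >= 10   = Nothing
           | otherwise = let x' = capStep x in Just (x', x')
```

viaIterate evaluates to [0,4,8,10,10,10] and viaUnfoldr to [4,8,10].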
The other option, which I actually prefer, would be to remove the guard that checks whether the answer is close enough and use iterate after all, producing an infinite list of new approximations, then consume the resulting list by comparing adjacent elements to see how close you're getting. I'd give some example code, but given the earlier mistake that might be unwise.
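Here is a minimal sketch of that approach. The names stepM, converge and solveMAdjacent and the 1e-9 tolerance are hypothetical; note that it stops on the difference between consecutive approximations rather than on the residual rd the question uses. Both guarded branches of the question's step reduce to mg - rd, so the unguarded step can be written directly.

```haskell
-- One approximation step with the guards removed: mg - rd covers both of
-- the question's branches (mg - sf when rd > 0, mg + sf when rd < 0).
stepM :: (Double -> Double -> Double -> Double)
      -> Double -> Double -> Double -> Double -> Double
stepM f ar st qt mg = mg - (st + qt - f st ar mg)

-- Walk consecutive pairs of approximations and stop at the first pair closer
-- than tol, returning the later element and its index (= steps taken).
-- Returns Nothing only if a finite list runs out; on an infinite list that
-- never converges it keeps searching.
converge :: Double -> [Double] -> Maybe (Double, Int)
converge tol xs =
  case dropWhile farApart (zip3 [1 ..] xs (drop 1 xs)) of
    (n, _, x') : _ -> Just (x', n)
    []             -> Nothing
  where
    farApart (_, x, x') = abs (x' - x) > tol

-- Infinite list of approximations from the guess 4.0, consumed by converge.
solveMAdjacent :: (Double -> Double -> Double -> Double)
               -> Double -> Double -> Double -> Maybe (Double, Int)
solveMAdjacent f ar st qt = converge 1e-9 (iterate (stepM f ar st qt) 4.0)
```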
EDIT: Okay, for the sake of completeness, here are a couple of quick examples:
Using iterate and takeWhile:
solve_m_iter f ar st qt = takeWhile notDoneYet $ iterate nextApprox 4.0
  where rd mg = st + qt - f st ar mg
        notDoneYet mg = abs (rd mg) > precis
        nextApprox mg | rd mg > precis  = mg - abs (rd mg)
                      | rd mg < -precis = mg + abs (rd mg)
Using unfoldr:
import Data.List (unfoldr)

-- takes the starting guess (e.g. 4.0) as its final argument
solve_m_unfold f ar st qt = unfoldr nextApprox
  where nextApprox mg | rd > precis  = keep $ mg - abs rd
                      | rd < -precis = keep $ mg + abs rd
                      | otherwise    = Nothing
          where rd = st + qt - f st ar mg
        keep x = Just (x, x)
And a slightly better function to get the result without traversing the list twice:
getResult = foldl (\(n, _) x -> (n + 1, x)) (0, 4.0)
Definitely quick-and-dirty code, but hopefully helpful.
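Prelude's foldl is lazy in its accumulator, so getResult can build up unevaluated thunks on long lists. A possible strict variant (getResult' is a hypothetical name) uses foldl' from Data.List; because foldl' only forces the accumulator pair to weak head normal form, the count inside is forced explicitly with seq.

```haskell
import Data.List (foldl')

-- Count the approximations and keep the last one in a single pass.
-- (0, 4.0) is returned unchanged for an empty list, i.e. when the
-- starting guess 4.0 was already close enough.
getResult' :: [Double] -> (Int, Double)
getResult' = foldl' step (0, 4.0)
  where
    step (n, _) x = let n' = n + 1 in n' `seq` (n', x)
```

Used as getResult' (solve_m_unfold f ar st qt 4.0), it would return the number of steps and the final approximation; for example, getResult' [5, 5.5, 5.75] gives (3, 5.75).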