返回
Python递归函数中数学运算次数统计方法 (含AST静态分析)
python
2025-03-15 05:45:41
Python 递归函数中基本数学运算次数的统计方法
这篇博客要讲的是,怎么去数一个 Python 递归函数里面,到底做了多少次基础的数学运算,像加减乘除、取余数、比大小这些。这对于分析算法用了多少时间、空间,挺有用的。咱们就拿下面这段代码来说事:
def bar(k):
score = 0
for i in range(2, k + 1):
j = i - 1
while j > 0:
if i % j == 0:
score = score // 2
else:
score = score + 10
j = j - 1
return score
mod = 10**9 + 7
def foo(n):
m = n % 3
if n == 0:
return 1
elif m == 0:
v = foo(n // 3)
t = 0
for i in range(1, n+1):
t = t + bar(4 * i)
return v + t
elif m == 1:
v = foo(n - 1)
return v + bar(n * n * n)
else:
v = foo(n - 2)
r = 1
for a in range(2, n):
r = r * a % mod
r = r + bar(n)
return v + r
一、 问题出在哪儿?
一开始,你可能想着,直接在递归函数里头,碰到一次运算就加个一,像这样:
def count_operations(n):
if n == 0:
return 1
elif n % 3 == 0:
return count_operations(n // 3) + 4
elif n % 3 == 1:
return 6 + count_operations(n - 1) + 4
else:
return 9 + 2 * count_operations(n - 2)
但这么搞,有两个大问题:
bar
函数没算进去。foo
函数里头还套了个bar
函数,bar
里面做的那些运算,你都没给算上。- 循环也没算清楚。
for
循环和while
循环里,每次循环判断,其实都藏着运算,这些你也没仔细数。
二、 咋解决?几种办法
1. 手动计数 (累,但直观)
最笨的办法,就是一点一点抠代码,人肉去数。
-
拆解
foo
函数:m = n % 3
: 1 次%
if n == 0
:1 次==
elif m == 0
:1 次==
elif m == 1
:1 次==
- 递归调用
foo(n // 3)
或foo(n - 1)
或foo(n - 2)
://
或-
运算各一次, 加上函数本身计数. bar
函数的调用- 各种
return
: 可能带有+
.
-
拆解
bar
函数:score = 0
:没运算。for i in range(2, k + 1)
:循环开始前有 1 次+
,循环每次判断有若干次+
和<=
。j = i - 1
:1 次-
。while j > 0
:每次循环判断 1 次>
。if i % j == 0
:1 次%
,1 次==
。score = score // 2
或score = score + 10
:1 次//
或 1 次+
。j = j - 1
:1 次-
。
-
然后根据不同的 n 的值,去算它走了
foo
的哪个分支,bar
执行了几次,每次bar
里面的循环又走了几遍……总之特别麻烦,容易漏。
例子:
当 N=1,foo(1)
被执行:m = n % 3
:n % 3
(1 次)- 三次
==
比较 (3 次). - 执行
elif m == 1
,v = foo(n - 1)
被执行 (n-1
, 1 次). foo(0)
被调用:n%3
,n==0
(2 次).- 返回
1
.
- 返回到
foo(1)
继续执行,计算bar(n * n * n)
. 三次*
(3次). bar(1)
被调用range(2, k+1)
:其中有k+1
, 算一次+
运算。但是由于 range 函数特性, 不产生运算。
- 返回到
foo(1)
, 计算最后的+
.
总共 1+3+1+2+3+1= 11次。
当 N >1, 情况非常复杂,容易导致错误,所以不建议。
2. 重载运算符 (巧妙, 但有坑)
有个很巧的方法,就是把 Python 里面的那些运算符,给它“偷梁换柱”一下。
class Integer(int):
n_ops = 0
def new_patch(name):
def patch(self, *args):
Integer.n_ops += 1
value = getattr(int, name)(self, *args)
if isinstance(value, int) and not (value is True or value is False):
value = Integer(value)
return value
patch.__name__ = name
return patch
methods = {
'__le__': '\u2264',
'__lt__': '<',
'__ge__': '\u2265',
'__gt__': '>',
'__eq__': '==',
'__add__': '+',
'__sub__': '-',
'__mul__': '*',
'__floordiv__': '//',
'__mod__': '%',
}
for name in methods:
setattr(Integer, name, new_patch(name))
def bar(k):
score = Integer(0)
for i in range(2, k + 1):
j = i - 1
while j > 0:
if i % j == 0:
score = score // 2
else:
score = score + 10
j = j - 1
return score
mod = 10**9 + 7
def foo(n):
n = Integer(n)
m = n % 3
if n == 0:
return 1
elif m == 0:
v = foo(n // 3)
t = Integer(0)
for i in range(1, n+1):
t = t + bar(4 * i)
return v + t
elif m == 1:
v = foo(n - 1)
return v + bar(n * n * n)
else:
v = foo(n - 2)
r = Integer(1)
for a in range(2, n):
r = r * a % mod
r = r + bar(n)
return v + r
Integer.n_ops = 0 # 重置计数
print(foo(5)) #随便给个 n 值, 或其它你需要的值。
print(Integer.n_ops)
原理:
Integer
类: 咱们自己造个Integer
类,让它继承 Python 里的int
类。new_patch
函数: 这个函数是个“补丁”,用来给Integer
类打补丁的。patch
函数: 这就是“补丁”的核心。它先给全局变量Integer.n_ops
加 1,表示运算次数加 1。然后,它调用真正的int
类的运算符(比如int.__add__
就是真的加法)。如果运算结果还是个整数,而且不是布尔值,就把它也变成Integer
类型,这样下次运算又能被我们“监控”到。methods
字典: 这里头列出了我们要“监控”的运算符。- 循环
setattr
: 把Integer
类的这些运算符,一个个都给换成咱们的patch
函数。 - 将函数内初始赋值为0, 1的变量类型进行替换.
注意:
- 代码里面
mod = 10**9 + 7
没被包括进去. - 小心使用.
3. AST 静态分析(高级,但准确)
用 Python 自带的 ast
模块,可以把代码变成一棵“抽象语法树”(Abstract Syntax Tree)。
import ast
class OpCounter(ast.NodeVisitor):
def __init__(self):
self.counts = {
'+': 0, '-': 0, '*': 0, '//': 0, '%': 0,
'>': 0, '<': 0, '==': 0, '<=': 0, '>=': 0
}
def visit_BinOp(self, node):
op_map = {
ast.Add: '+', ast.Sub: '-', ast.Mult: '*',
ast.FloorDiv: '//', ast.Mod: '%'
}
op_type = type(node.op)
if op_type in op_map:
self.counts[op_map[op_type]] += 1
self.generic_visit(node)
def visit_Compare(self, node):
op_map = {
ast.Gt: '>', ast.Lt: '<', ast.Eq: '==',
ast.LtE: '<=', ast.GtE: '>='
}
for op in node.ops:
op_type = type(op)
if op_type in op_map:
self.counts[op_map[op_type]] += 1
self.generic_visit(node)
def count_ops_in_function(func):
counter = OpCounter()
counter.visit(ast.parse(inspect.getsource(func)))
return counter.counts
import inspect
print(count_ops_in_function(bar))
print(count_ops_in_function(foo))
原理:
- 用
ast.parse
把函数代码变成 AST。 OpCounter
类去遍历这棵树。visit_BinOp
专门处理二元运算(加减乘除这些)。visit_Compare
专门处理比较运算(大于小于这些)。- 统计完,分别打印.
这种方法不用运行代码, 不用改代码, 特别准,还能分别统计不同类型的运算。
####三种方法的取舍
- 方法一非常不推荐, 当N很大时, 非常复杂且容易遗漏或数错.
- 方法二可以满足绝大部分运算次数的统计需求, 缺点是可能会与其他利用了运算符重载的代码冲突.
- 方法三准确且通用, 是高级玩家的首选.