返回

Python递归函数中数学运算次数统计方法 (含AST静态分析)

python

Python 递归函数中基本数学运算次数的统计方法

这篇博客要讲的是,怎么去数一个 Python 递归函数里面,到底做了多少次基础的数学运算,像加减乘除、取余数、比大小这些。这对于分析算法用了多少时间、空间,挺有用的。咱们就拿下面这段代码来说事:

def bar(k):
    score = 0
    for i in range(2, k + 1):
        j = i - 1
        while j > 0:
            if i % j == 0:
                score = score // 2
            else:
                score = score + 10
            j = j - 1
    return score

mod = 10**9 + 7

def foo(n):
    m = n % 3
    if n == 0:
        return 1
    elif m == 0:
        v = foo(n // 3)
        t = 0
        for i in range(1, n+1):
            t = t + bar(4 * i)
        return v + t
    elif m == 1:
        v = foo(n - 1)
        return v + bar(n * n * n)
    else:
        v = foo(n - 2)
        r = 1
        for a in range(2, n):
            r = r * a % mod
            r = r + bar(n)
        return v + r

一、 问题出在哪儿?

一开始,你可能想着,直接在递归函数里头,碰到一次运算就加个一,像这样:

def count_operations(n):
    if n == 0:
        return 1
    elif n % 3 == 0:
        return count_operations(n // 3) + 4
    elif n % 3 == 1:
        return 6 + count_operations(n - 1) + 4
    else:
        return 9 + 2 * count_operations(n - 2)

但这么搞,有两个大问题:

  1. bar 函数没算进去。 foo 函数里头还套了个 bar 函数,bar 里面做的那些运算,你都没给算上。
  2. 循环也没算清楚。 for 循环和 while 循环里,每次循环判断,其实都藏着运算,这些你也没仔细数。

二、 咋解决?几种办法

1. 手动计数 (累,但直观)

最笨的办法,就是一点一点抠代码,人肉去数。

  • 拆解 foo 函数:

    • m = n % 3: 1 次 %
    • if n == 0:1 次 ==
    • elif m == 0:1 次 ==
    • elif m == 1:1 次 ==
    • 递归调用 foo(n // 3)foo(n - 1)foo(n - 2): //- 运算各一次, 加上函数本身计数.
    • bar函数的调用
    • 各种 return: 可能带有 +.
  • 拆解 bar 函数:

    • score = 0:没运算。
    • for i in range(2, k + 1):循环开始前有 1 次 +,循环每次判断有若干次 +<=
    • j = i - 1:1 次 -
    • while j > 0:每次循环判断 1 次 >
    • if i % j == 0:1 次 %,1 次 ==
    • score = score // 2score = score + 10:1 次 // 或 1 次 +
    • j = j - 1:1 次 -
  • 然后根据不同的 n 的值,去算它走了 foo 的哪个分支,bar 执行了几次,每次 bar 里面的循环又走了几遍……总之特别麻烦,容易漏。
    例子:
    当 N=1, foo(1)被执行:

    1. m = n % 3n % 3 (1 次)
    2. 三次 ==比较 (3 次).
    3. 执行elif m == 1, v = foo(n - 1) 被执行 (n-1, 1 次).
    4. foo(0) 被调用:
      1. n%3, n==0 (2 次).
      2. 返回 1.
    5. 返回到 foo(1)继续执行,计算 bar(n * n * n). 三次* (3次).
    6. bar(1) 被调用
      1. range(2, k+1):其中有k+1, 算一次 + 运算。但是由于 range 函数特性, 不产生运算。
    7. 返回到foo(1), 计算最后的+.

总共 1+3+1+2+3+1= 11次。
当 N >1, 情况非常复杂,容易导致错误,所以不建议。

2. 重载运算符 (巧妙, 但有坑)

有个很巧的方法,就是把 Python 里面的那些运算符,给它“偷梁换柱”一下。

class Integer(int):
    n_ops = 0

def new_patch(name):
    def patch(self, *args):
        Integer.n_ops += 1
        value = getattr(int, name)(self, *args)
        if isinstance(value, int) and not (value is True or value is False):
            value = Integer(value)
        return value
    patch.__name__ = name
    return patch

methods = {
    '__le__': '\u2264',
    '__lt__': '<',
    '__ge__': '\u2265',
    '__gt__': '>',
    '__eq__': '==',
    '__add__': '+',
    '__sub__': '-',
    '__mul__': '*',
    '__floordiv__': '//',
    '__mod__': '%',
}

for name in methods:
    setattr(Integer, name, new_patch(name))

def bar(k):
    score = Integer(0)
    for i in range(2, k + 1):
        j = i - 1
        while j > 0:
            if i % j == 0:
                score = score // 2
            else:
                score = score + 10
            j = j - 1
    return score

mod = 10**9 + 7

def foo(n):
    n = Integer(n)
    m = n % 3
    if n == 0:
        return 1
    elif m == 0:
        v = foo(n // 3)
        t = Integer(0)
        for i in range(1, n+1):
            t = t + bar(4 * i)
        return v + t
    elif m == 1:
        v = foo(n - 1)
        return v + bar(n * n * n)
    else:
        v = foo(n - 2)
        r = Integer(1)
        for a in range(2, n):
            r = r * a % mod
            r = r + bar(n)
        return v + r
Integer.n_ops = 0 # 重置计数
print(foo(5))  #随便给个 n 值, 或其它你需要的值。
print(Integer.n_ops)

原理:

  1. Integer 类: 咱们自己造个 Integer 类,让它继承 Python 里的 int 类。
  2. new_patch 函数: 这个函数是个“补丁”,用来给 Integer 类打补丁的。
  3. patch 函数: 这就是“补丁”的核心。它先给全局变量 Integer.n_ops 加 1,表示运算次数加 1。然后,它调用真正的 int 类的运算符(比如 int.__add__ 就是真的加法)。如果运算结果还是个整数,而且不是布尔值,就把它也变成 Integer 类型,这样下次运算又能被我们“监控”到。
  4. methods 字典: 这里头列出了我们要“监控”的运算符。
  5. 循环 setattrInteger 类的这些运算符,一个个都给换成咱们的 patch 函数。
  6. 将函数内初始赋值为0, 1的变量类型进行替换.

注意:

  • 代码里面 mod = 10**9 + 7没被包括进去.
  • 小心使用.

3. AST 静态分析(高级,但准确)

用 Python 自带的 ast 模块,可以把代码变成一棵“抽象语法树”(Abstract Syntax Tree)。

import ast

class OpCounter(ast.NodeVisitor):
    def __init__(self):
        self.counts = {
            '+': 0, '-': 0, '*': 0, '//': 0, '%': 0,
            '>': 0, '<': 0, '==': 0, '<=': 0, '>=': 0
        }

    def visit_BinOp(self, node):
        op_map = {
            ast.Add: '+', ast.Sub: '-', ast.Mult: '*',
            ast.FloorDiv: '//', ast.Mod: '%'
        }
        op_type = type(node.op)
        if op_type in op_map:
            self.counts[op_map[op_type]] += 1
        self.generic_visit(node)

    def visit_Compare(self, node):
        op_map = {
            ast.Gt: '>', ast.Lt: '<', ast.Eq: '==',
            ast.LtE: '<=', ast.GtE: '>='
        }
        for op in node.ops:
            op_type = type(op)
            if op_type in op_map:
                self.counts[op_map[op_type]] += 1
        self.generic_visit(node)

def count_ops_in_function(func):
    counter = OpCounter()
    counter.visit(ast.parse(inspect.getsource(func)))
    return counter.counts
import inspect

print(count_ops_in_function(bar))
print(count_ops_in_function(foo))

原理:

  • ast.parse 把函数代码变成 AST。
  • OpCounter 类去遍历这棵树。
  • visit_BinOp 专门处理二元运算(加减乘除这些)。
  • visit_Compare 专门处理比较运算(大于小于这些)。
  • 统计完,分别打印.

这种方法不用运行代码, 不用改代码, 特别准,还能分别统计不同类型的运算。

####三种方法的取舍

  • 方法一非常不推荐, 当N很大时, 非常复杂且容易遗漏或数错.
  • 方法二可以满足绝大部分运算次数的统计需求, 缺点是可能会与其他利用了运算符重载的代码冲突.
  • 方法三准确且通用, 是高级玩家的首选.