Skip to content

PyO3:使用 Rust 扩展 Python 提升执行速度

发布于  at 12:37 AM更新于  at 12:03 PM

PyO3 是 Python 的 Rust 绑定。

虚拟环境与 maturin

创建一个 Python 虚拟环境(更好的管理)

mkdir simple_trig
pipenv install

激活虚拟环境,安装 maturin

pipenv shell
pipenv install maturin

初始化 maturin

maturin init

会自动生成 Rust 的项目,生成的 lib.rs 如下:

use pyo3::prelude::*;

/// Formats the sum of two numbers as string.
#[pyfunction]
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
    Ok((a + b).to_string())
}

/// A Python module implemented in Rust.
#[pymodule]
fn simple_trig(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
    Ok(())
}

只需要运行 maturin develop,就可以自动打包和安装到 Python 的虚拟环境。

% python
>>> import simple_trig
>>> simple_trig.sum_as_string(5, 20)
'25'

新建 Python 脚本

from random import random
from time import perf_counter

COUNT = 500000  # Change this value depending on the speed of your computer
DATA = [(random() - 0.5) * 3 for _ in range(COUNT)]

e = 2.7182818284590452353602874713527

def sinh(x):
    return (1 - (e ** (-2 * x))) / (2 * (e ** -x))

def cosh(x):
    return (1 + (e ** (-2 * x))) / (2 * (e ** -x))

def tanh(x):
    tanh_x = sinh(x) / cosh(x)
    return tanh_x

def test(fn, name):
    start = perf_counter()
    result = fn(DATA)
    duration = perf_counter() - start
    print('{} took {:.3f} seconds\n'.format(name, duration))

    for d in result:
        assert -1 <= d <= 1, " incorrect values"

if __name__ == "__main__":
    print('Running benchmarks with COUNT = {}'.format(COUNT))

    test(lambda d: [tanh(x) for x in d], '[tanh(x) for x in d] (Python implementation)')

Rust 实现

use pyo3::prelude::*;

/// Formats the sum of two numbers as string.
#[pyfunction]
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
    Ok((a + b).to_string())
}

const E: f64 = 2.7182818284590452353602874713527;

fn sinh_impl(x: f64) -> f64 {
    (1.0 - f64::powf(E, -2.0 * x)) / (2.0 * f64::powf(E, -x))
}

fn cosh_impl(x: f64) -> f64 {
    (1.0 + f64::powf(E, -2.0 * x)) / (2.0 * f64::powf(E, -x))
}

#[pyfunction]
fn fast_tanh(x: f64) -> PyResult<f64> {
    Ok(sinh_impl(x) / cosh_impl(x))
}

/// A Python module implemented in Rust.
#[pymodule]
fn simple_trig(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
    m.add_function(wrap_pyfunction!(fast_tanh, m)?)?;
    Ok(())
}

Bound

PyO3 0.20+ 中引入了 Bound 类型,用于解决 Rust 与 Python 交互时的内存安全性和对象生命周期管理问题。

Bound 是 PyO3 提供的一个智能指针(类似 RcArc),用于安全地持有对 Python 对象的引用。

Bound<'_, PyModule> 的含义:

&Bound<'_, PyModule> 表示对一个具有生命周期管理的 Python 模块的不可变引用。

引入 Bound 之前,如果 Python 垃圾回收器在 Rust 代码执行期间回收了模块对象,Rust 代码中的 &PyModule 会变成悬垂指针,导致未定义行为(UB)。

加入测试函数

from simple_trig import fast_tanh

if __name__ == "__main__":
    print('Running benchmarks with COUNT = {}'.format(COUNT))

    test(lambda d: [tanh(x) for x in d], '[tanh(x) for x in d] (Python implementation)')
    test(lambda d: [fast_tanh(x) for x in d], '[fast_tanh(x) for x in d] (Rust extension)')

令我震惊的是,Rust 居然更慢……

Running benchmarks with COUNT = 500000
[tanh(x) for x in d] (Python implementation) took 0.234 seconds

[fast_tanh(x) for x in d] (Rust extension) took 0.282 seconds

好吧,一定是哪里出问题了。

(2025 年现在重新运行,PyO3 默认已经比 Python 版本快了)

实际使用的时候应该使用 --release,可以极大的提升速度:

maturin develop --release

最终可以达到这样的效果(下面的测试机器和上面的测试机器不同,以各自测试的 Python 速度作为基准比较合适,不要横向比较):

Running benchmarks with COUNT = 500000
[tanh(x) for x in d] (Python implementation) took 0.431 seconds

[fast_tanh(x) for x in d] (Rust extension) took 0.070 seconds

小结

Python 在没有优化的情况下确实在性能上相对而言差一点,使用 Rust 重写 Python 模块,确实可以极大的提升 Python 性能。

本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自小谷的随笔

上一篇
自然语言处理与大语言模型简史
下一篇
布隆过滤器原理及其应用