找回密码
 立即注册
查看: 253|回复: 4

阶乘优化目前是什么程度?

[复制链接]
发表于 2023-2-13 13:10 | 显示全部楼层 |阅读模式
上课时无意间谈到100000!,于是给了我们组一个题目,要求我们将100000!优化至阶乘6差不多的运行时间。。。。。。当时就懵b了,所以有没有哪位大佬可以留下一些思路和代码,不胜感激。(个人认为一秒左右可以过关)
发表于 2023-2-13 13:18 | 显示全部楼层
题主的野心还是不够大,下列算法求一百万的阶乘,只需要0.2s!(不含进制转换和输出的时间)

怎么样用c语言求1000的阶乘?
发表于 2023-2-13 13:23 | 显示全部楼层
利用快速傅里叶变化和拉格朗日插值可以优化到
 楼主| 发表于 2023-2-13 13:23 | 显示全部楼层
另一位答主@北海若 用python从朴素算法一步步优化,展示了几种计算阶乘的方法,本来我想用C++实现一下,但效率一直无法突破,所以此处暂时略去,如果以后有啥发现再补,没有就不补了。
在这里补一下C++实现的朴素算法吧,只放一些没注释的代码,原理与下面的python一致(虽然python也没有注释
#include<iostream>
#include<chrono>
#include "unsigned_bigint.h"
#include "timer.h"
using namespace std;
using ubigint = kedixa::unsigned_bigint;

ubigint multi(int s, int t)
{
    if(t - s < 40)
    {
        ubigint u(1u);
        for(; s <= t; ++s)
            u *= s;
        return u;
    }
    int mid = (t+s) / 2;
    return multi(s, mid) * multi(mid+1, t);
}
ubigint fac(int n)
{
    return multi(1, n);
}
ubigint fac3(int n)
{
    int M = 64;
    vector<ubigint> vbig(M, ubigint(1));
    for(int i = 2; i <= n;)
    {
        ubigint t(1u);
        int r = 0;
        while(i <= n && t.size() < 50)
            t *= i, ++i, ++r;
        for(int j = 0; j < M; ++j)
        {
            if(vbig[j] == 1u) {
                vbig[j] = std::move(t);
                break;
            }
            t *= vbig[j];
            vbig[j] = ubigint(1u);
        }
    }
    ubigint u(1u);
    for(auto &b : vbig)
        u *= b;
    return u;
}
int main()
{
    kedixa::timer t;
    auto u1 = fac(100000);
    cout << t.stop() << endl;
    t.reset();
    u1 = fac3(100000);
    cout << t.stop() << endl;
}
无符号整数乘法是自己写的,需要的头文件在kedixa/klibcpp,实现在kedixa/klibcpp,实现的思路在C++大整数运算(一):概述 - kedixa的博客,找一些优秀的高精度计算库可能会更快一些。cpython中的算法还没看懂,如果能按照相同的思路写出来,效率应该差不多。
补充完毕。
----------
此处说一个关于朴素算法另一位答主可能没发现的优化,在计算高精度乘法的时候,用一个很大的整数乘以一个很小的整数,此时不管用什么方法计算,复杂度都不会降下来,而当两个大小相近的整数相乘的时候,高精度乘法的优势才能展现出来,因此如果用下面这种方法计算,速度会提高很多。
import math
import time
import timeit

# 朴素方法
def f1():
    y = 1
    for i in range(1, 100001):
        y *= i
    return y

# 一点小优化
def f2(s, t):
    if s==t:
        return s
    mid = (s+t)//2
    return f2(s, mid) * f2(mid+1, t)

# 小范围优化
def f3(s, t):
    if t - s < 50:
        r = 1
        for i in range(s, t+1):
            r *= i
        return r
    mid = (s+t)//2
    return f3(s, mid) * f3(mid+1, t)

print(timeit.timeit("f1()", "from __main__ import f1", number = 1))
print(timeit.timeit("f2(1, 100000)", "from __main__ import f2", number = 1))
print(timeit.timeit("f3(1, 100000)", "from __main__ import f3", number = 1))
print(timeit.timeit("math.factorial(100000)", "import math", number=1))我的电脑性能略差(划掉),修正了电脑性能差的问题后(换了个电脑跑),上面的代码输出约为
3.5504379999999855
0.32651099999998223
0.23477900000000318
0.18042300000001887感觉仔细研究一下还能再快一点,在同一个机器上,我用C++实现这种方法(f3)在 O2的优化下大概需要0.18秒,与math.factorial比较接近了。
8.23
再补一个非递归版的
def f4(n):
    L = [1 for i in range(32)]
    i = 2
    while i <= n:
        t, r = 1, 0
        while i <= n and r < 50:
            t *= i
            i += 1
            r += 1
            pass
        for j in range(0, 32):
            if L[j] == 1:
                L[j] = t
                break
            t *= L[j]
            L[j] = 1
    u = 1
    for i in L:
        u *= i
    return u按说非递归应该更快一点,但实际跑起来比f3略慢。看了一下cpython的实现 https://github.com/python/cpython/blob/772d809a63f40fd35679da3fb115cdf7fa81bd20/Modules/mathmodule.c#L1654  好像用的就是[1]中的方法。
另外提供一点资料
[1] Matrix67: The Aha Moments
[2] Java and C# Implementations
发表于 2023-2-13 13:28 | 显示全部楼层
一般的高精度乘法,随着位数的增加是一个O(n^2)的复杂度,因为我们模拟了人列竖式运算的过程。
第一层优化,可以把多位结合成一个int一起计算。计算机算123*12和3*2和速度相差无几。
然后,其实这个运算,是一个做多项式乘法的过程,可以用快速傅里叶变换,削减到O(nlog2n)的复杂度。Python内部的高精度乘法用的就是这个。
我想Python高精度已经是最优的方法了,测试了一下,大概三秒多的样子(不包括输出)。
——————————————补充——————————————
测试了一下,3.071614719199715秒
但是,如果使用Python数学库内的阶乘,只需要0.1973350872481774秒!
# 测试程序
import math
import time
begin = time.clock()
ans3 = math.factorial(100000)
end = time.clock()
print(end - begin)

begin = time.clock()
ans2 = 1
for i in range(1, 100001):
    ans2 *= i
end = time.clock()
print(end - begin)

print(ans2 == ans3)
"""
输出
0.1973350872481774
3.071614719199715
True
"""回答问题的同时自己也长知识了,阶乘居然可以这么算!
快速阶乘运算 - alexingcool的专栏 - CSDN博客
这篇文章的第一种算法对于计算较小的阶乘来说的确是足够的(或者,每步阶乘计算时取个余数什么的也是足够的),但是由于需要高精度这一因素,而x = (1 << n) + 1(n = 100000)的情况下,计算x ^ n即便使用log2n的快速幂,依旧是非常耗时的工作
在2000!时这个算法已经有些超时了,不用说100000了。

这篇文章后面还有第二种算法。其实只是对第一种的小改进,储存了已经计算好的组合数。
但是这样依然不够。依旧是上述问题。在计算2000!的时候,只需要0.2秒。但是3000!的时候便消耗了1.23s。甚至还不如用直接相乘的朴素算法来的快!
想想怎么改进吧……
# -*- coding: utf-8 -*-
"""
测试代码
"""

import math
import time


c = [1 for i in range(100001)]
cnrs = [0 for i in range(100001)]
num = 0
p_size = 0
msk = 0


def power(n, m):
    global num, p_size, msk
    if m == 1:
        temp = n
    elif m & 0x01 != 0:
        temp = n * power(n, m - 1)
    else:
        temp = power(n, m >> 1)
        temp *= temp
        cnrs[num] = ((temp >> ((m >> 1) * p_size)) & msk)
        num += 1
    return temp


def factor(n):
    global num
    if n == 1:
        return 1
    elif n & 0x01 == 1:
        return n * factor(n - 1)
    else:
        temp = factor(n >> 1)
        temp = cnrs[num] * (temp * temp)
        num += 1
        return temp


def factorial(n):
    global num, msk, p_size
    x = (2 ** n) + 1
    num = 0
    msk = x - 2
    p_size = n
    power(x, n)
    num = 0
    return factor(n)

begin = time.clock()
ans3 = math.factorial(3000)
end = time.clock()
print(end - begin)

begin = time.clock()
ans2 = factorial(3000)
end = time.clock()
print(end - begin)
print(ans2 == ans3)对朴素算法的优化倒是想到一个。
2 4 6 8 ... 100000的乘积可以变成2 ^ 50000(由于快速幂会快一些) * 50000 !
50000!再做同样的处理。
由于高精度的存在,减少了位数就是减少了运算时间。
这样一次下来,就减少到了1.55秒。正在实现递归优化。
————————补充————————
import math
import time


def fact(n):
    if n == 1:
        return 1
    tmp = n >> 1
    ans1 = fact(tmp) * (2 ** tmp)
    ans2 = 1
    for i in range(1, n + 1, 2):
        ans2 *= i
    return ans2 * ans1

begin = time.clock()
ans3 = math.factorial(100000)
end = time.clock()
print(end - begin)

begin = time.clock()
ans2 = fact(100000)
end = time.clock()
print(end - begin)
print(ans2 == ans3)输出:
0.20090164677547095
1.1225459590110793
True
一个简单的优化,大概少了3倍的时间。1秒左右的算法出来咯~
感觉还有待继续优化。因为在不考虑高精度的情况下,比较上面所述的log2n的算法来说,O(n)的时间其实还有很大的优化空间。
如果上面可以避免那两个高精度位运算,或者不让中间值那么大,就能达到进一步优化的效果。
————————更新————————
常数优化。方法是黑科技一般的拆循环。
import math
import time


def fact(n):
    if n == 1:
        return 1
    tmp = n >> 1
    ans1 = fact(tmp) * (2 ** tmp)
    ans2 = 1
    ans3 = 1
    for i in range(1, n + 1, 4):
        ans2 *= i
    for i in range(3, n + 1, 4):
        ans3 *= i
    return ans3 * ans2 * ans1

begin = time.clock()
ans3 = math.factorial(100000)
end = time.clock()
print(end - begin)

begin = time.clock()
ans2 = fact(100000)
end = time.clock()
print(end - begin)
print(ans2 == ans3)这样进一步减少了大部分高精度计算位数
现在只需要0.65秒
懒得打字嘛,点击右侧快捷回复 【右侧内容,后台自定义】
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Unity开发者联盟 ( 粤ICP备20003399号 )

GMT+8, 2025-1-23 22:31 , Processed in 0.131779 second(s), 23 queries .

Powered by Discuz! X3.5 Licensed

© 2001-2024 Discuz! Team.

快速回复 返回顶部 返回列表