
Compute (a*b)%n FAST for 64-bit unsigned arguments in C(++) on x86-64 platforms?

I'm looking for a fast method to compute (a*b) modulo n (in the mathematical sense) for a, b, n of type uint64_t. I could live with preconditions such as n != 0, or even a < n && b < n.

Note that the C expression (a*b)%n won't cut it, because the product is truncated to 64 bits. I'm looking for (uint64_t)(((uint128_t)a*b)%n), except that I do not have a uint128_t (that I know of, in Visual C++).
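For comparison, here is a minimal sketch of what that expression looks like where a 128-bit type does exist. It assumes GCC or clang, which provide `unsigned __int128` as an extension (MSVC does not); the names `mulmod_ref` and `mulmod_truncating` are made up for illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Reference (a*b) % n through the GCC/clang 128-bit extension type.
   Assumption: compiled with GCC or clang on a 64-bit target. Requires n != 0. */
static inline uint64_t mulmod_ref(uint64_t a, uint64_t b, uint64_t n) {
    assert(n != 0);
    return (uint64_t)(((unsigned __int128)a * b) % n);
}

/* The naive expression: the product wraps modulo 2^64 before % is applied. */
static inline uint64_t mulmod_truncating(uint64_t a, uint64_t b, uint64_t n) {
    assert(n != 0);
    return (a * b) % n;
}
```

With a = b = 2^32 and n = 2^64-1, the true product is 2^64, which is 1 modulo n, while the truncated product is 0. As far as I know, GCC typically compiles the 128-bit `%` to a library call rather than a single `div`, so this is a correctness reference more than a speed target.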

I'm looking for a Visual C++ (preferably) or GCC/clang intrinsic that makes best use of the underlying hardware available on x86-64 platforms; or, if that can't be done, a portable inline function.

OK, how about this (not tested):

modmul:
; rcx = a
; rdx = b
; r8 = n
mov rax, rdx
mul rcx
div r8
mov rax, rdx
ret

The precondition is that a * b / n <= ~0ULL (the quotient must fit in 64 bits), otherwise there will be a divide error. That's a slightly less strict condition than a < n && b < n: one of them can be bigger than n as long as the other is small enough.
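That precondition can be tested up front: the quotient fits in 64 bits exactly when the high 64 bits of the product are smaller than n. A small sketch of the check, assuming GCC/clang's `unsigned __int128` to get at the high half (the helper names are hypothetical):

```c
#include <assert.h>
#include <stdint.h>

/* Returns 1 if floor(a*b / n) fits in 64 bits, i.e. the 128/64-bit `div`
   would not raise a divide error. Equivalent to: high half of a*b < n. */
static inline int mul_div_is_safe(uint64_t a, uint64_t b, uint64_t n) {
    unsigned __int128 p = (unsigned __int128)a * b;
    uint64_t hi = (uint64_t)(p >> 64);
    return n != 0 && hi < n;
}

/* (a*b) % n that asserts the precondition instead of faulting. */
static inline uint64_t mulmod_checked(uint64_t a, uint64_t b, uint64_t n) {
    assert(mul_div_is_safe(a, b, n));
    return (uint64_t)(((unsigned __int128)a * b) % n);
}
```

For example, a = 2^64-1, b = 2, n = 3 passes the check even though a > n, because the high half of the product is only 1.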

Unfortunately it has to be assembled and linked in separately, because MSVC doesn't support inline asm for 64-bit targets.

It's also still slow. The real problem is that 64-bit div, which can take nearly a hundred cycles (seriously, up to 90 cycles on Nehalem, for example).

You could do it the old-fashioned way with shift/add/subtract. The code below assumes a < n and n < 2^63 (so things don't overflow):

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b&1)
            if ((rv += a) >= n) rv -= n;
        if ((a += a) >= n) a -= n;
        b >>= 1; }
    return rv;
}

You could use while (a && b) for the loop instead, to short-circuit things if it's likely that a will be a factor of n. It will be slightly slower (more comparisons and likely correctly predicted branches) if a is not a factor of n.
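Spelled out, the short-circuit variant might look like the sketch below. The name `mulmod_short` and the asserts are my additions; the loop body is the same as above, it just stops as soon as the running multiple of a has become 0 modulo n.

```c
#include <assert.h>
#include <stdint.h>

/* Shift/add (a*b) % n that exits early once a has been reduced to 0.
   Same preconditions as the plain loop: a < n and n < 2^63. */
static inline uint64_t mulmod_short(uint64_t a, uint64_t b, uint64_t n) {
    assert(a < n);
    assert(n < (UINT64_C(1) << 63));
    uint64_t rv = 0;
    while (a && b) {
        if (b & 1)
            if ((rv += a) >= n) rv -= n;   /* add the current multiple of a */
        if ((a += a) >= n) a -= n;         /* double a modulo n */
        b >>= 1;
    }
    return rv;
}
```

For instance, with a = 3, b = 4, n = 12 the doubled value of a hits 0 after two iterations and the loop exits with rv = 0 without looking at the remaining bit of b.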

If you really, absolutely need that last bit (allowing n up to 2^64-1), you can use:

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b&1) {
            rv += a;
            if (rv < a || rv >= n) rv -= n; }
        uint64_t t = a;
        a += a;
        if (a < t || a >= n) a -= n;
        b >>= 1; }
    return rv;
}
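The trick in that version is the pair of tests `rv < a || rv >= n`: the first catches a sum that wrapped around 2^64, the second a sum that is merely >= n, and a single subtraction of n repairs both. Isolated as a helper (the name `addmod` is hypothetical), the step is roughly:

```c
#include <assert.h>
#include <stdint.h>

/* (x + y) % n for x, y < n, valid for any n up to 2^64-1. */
static inline uint64_t addmod(uint64_t x, uint64_t y, uint64_t n) {
    assert(x < n && y < n);
    uint64_t s = x + y;           /* may wrap modulo 2^64 */
    if (s < x || s >= n)          /* wrapped, or in range but not yet reduced */
        s -= n;                   /* true sum is < 2n, so one subtraction suffices */
    return s;
}
```

When the addition wraps, the true sum is s + 2^64; subtracting n in 64-bit arithmetic wraps back and gives exactly (true sum - n), which is below n because x + y < 2n.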

Alternatively, just use GCC inline asm to access the underlying x64 instructions:

inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv;
    asm ("mul %3" : "=d"(rv), "=a"(a) : "1"(a), "r"(b));
    asm ("div %4" : "=d"(rv), "=a"(a) : "0"(rv), "1"(a), "r"(n));
    return rv;
}

The 64-bit div instruction is really slow, however, so the loop might actually be faster. You'd need to profile to be sure.

7 years later, I got a solution working in Visual Studio 2019:

#include <stdint.h>
#include <intrin.h>
#pragma intrinsic(_umul128)
#pragma intrinsic(_udiv128)

// compute (a*b)%n with 128-bit intermediary result
// assumes n>0  and  a*b < n * 2**64 (always the case when a<=n || b<=n )
inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
  uint64_t r, s = _umul128(a, b, &r);
  (void)_udiv128(r, s, n, &r);
  return r;
}

// compute (a*b)%n with 128-bit intermediary result
// assumes n>0, works including if a*b >= n * 2**64
inline uint64_t mulmod1(uint64_t a, uint64_t b, uint64_t n) {
  uint64_t r, s = _umul128(a % n, b, &r);
  (void)_udiv128(r, s, n, &r);
  return r;
}
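If the same source has to build with both toolchains, one option is to pick the implementation at compile time. This is only a sketch: the name `mulmod_any` is made up, the MSVC branch assumes a compiler recent enough to provide `_udiv128` (Visual Studio 2019, as above), and the other branch assumes GCC/clang's `unsigned __int128`.

```c
#include <assert.h>
#include <stdint.h>

#if defined(_MSC_VER) && defined(_M_X64)
#include <intrin.h>
/* MSVC x64: 64x64->128 multiply, then 128/64 divide keeping the remainder. */
static inline uint64_t mulmod_any(uint64_t a, uint64_t b, uint64_t n) {
    assert(n != 0);
    uint64_t hi, rem;
    uint64_t lo = _umul128(a % n, b, &hi);  /* a % n keeps the quotient below 2^64 */
    (void)_udiv128(hi, lo, n, &rem);
    return rem;
}
#else
/* Assumption: GCC or clang with the unsigned __int128 extension. */
static inline uint64_t mulmod_any(uint64_t a, uint64_t b, uint64_t n) {
    assert(n != 0);
    return (uint64_t)(((unsigned __int128)a * b) % n);
}
#endif
```

Both branches accept any a and b as long as n is nonzero; the MSVC branch pays for that with the extra `a % n`, exactly like mulmod1.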

The intrinsic is named _mul128 (the unsigned variant is _umul128).

typedef unsigned long long BIG;

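// handles only the "hard" case when high bit of n is set
// returns (v << by) mod n, one doubling per iteration
BIG shl_mod( BIG v, BIG n, int by )
{
    if (v >= n) v -= n;          // high bit of n is set, so one subtraction brings v below n
    while (by--) {
        if (v >= (n-v))          // 2*v >= n: compute 2*v - n without overflowing
            v -= n-v;
        else
            v <<= 1;             // 2*v < n, so the shift cannot overflow
    }
    return v;
}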

Now you can use shl_mod(B, n, 64).
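To make the idea concrete: if the 128-bit product is hi*2^64 + lo, then shl_mod(hi, n, 64) reduces the high part, and what is left is one modular addition with lo. The sketch below is my reading of how the pieces fit together, not necessarily what the answer had in mind. `mul_64x64`, `add_mod` and `mulmod_split` are hypothetical helpers, the multiply is done portably with 32-bit halves, and the copy of shl_mod uses `>=` comparisons so that the edge cases v == n and 2*v == n reduce to 0.

```c
#include <assert.h>
#include <stdint.h>

/* Portable 64x64 -> 128 multiply using 32-bit halves. */
static void mul_64x64(uint64_t a, uint64_t b, uint64_t *hi, uint64_t *lo) {
    uint64_t a_lo = (uint32_t)a, a_hi = a >> 32;
    uint64_t b_lo = (uint32_t)b, b_hi = b >> 32;
    uint64_t p0 = a_lo * b_lo, p1 = a_lo * b_hi;
    uint64_t p2 = a_hi * b_lo, p3 = a_hi * b_hi;
    uint64_t mid = (p0 >> 32) + (uint32_t)p1 + (uint32_t)p2;  /* < 2^34, no overflow */
    *lo = (mid << 32) | (uint32_t)p0;
    *hi = p3 + (p1 >> 32) + (p2 >> 32) + (mid >> 32);
}

/* (v << by) mod n, valid when the high bit of n is set. */
static uint64_t shl_mod(uint64_t v, uint64_t n, int by) {
    if (v >= n) v -= n;
    while (by--) {
        if (v >= n - v) v -= n - v;   /* 2v >= n: 2v - n without overflow */
        else v <<= 1;                 /* 2v < n: plain shift */
    }
    return v;
}

/* (x + y) mod n for x, y < n, without overflowing. */
static uint64_t add_mod(uint64_t x, uint64_t y, uint64_t n) {
    return (x >= n - y) ? x - (n - y) : x + y;
}

/* (a*b) mod n for n with its high bit set: reduce hi*2^64 and lo separately. */
static uint64_t mulmod_split(uint64_t a, uint64_t b, uint64_t n) {
    assert(n >> 63);                        /* "hard" case only */
    uint64_t hi, lo;
    mul_64x64(a, b, &hi, &lo);
    uint64_t h = shl_mod(hi, n, 64);        /* hi * 2^64 mod n */
    uint64_t l = (lo >= n) ? lo - n : lo;   /* lo < 2^64 <= 2n */
    return add_mod(h, l, n);
}
```

The shl_mod call runs 64 iterations regardless of the inputs, so this is in the same cost class as the shift/add loops above, not a replacement for the hardware div.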

Having no inline assembly kind of sucks. Anyway, the function call overhead is actually extremely small. Parameters are passed in volatile registers and no cleanup is needed.

I don't have an assembler, and x64 targets don't support __asm, so I had no choice but to "assemble" my function from opcodes myself.

Obviously it depends on . I'm using mpir (gmp) as a reference to show that the function produces correct results.


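#include "stdafx.h"   // precompiled header; assumed to pull in <windows.h>, <stdint.h>, <stdio.h> and the MPIR/GMP header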

// mulmod64(a, b, m) == (a * b) % m
typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m);

uint8_t mulmod64_opcodes[] = {
    0x48, 0x89, 0xC8, // mov rax, rcx
    0x48, 0xF7, 0xE2, // mul rdx
    0x4C, 0x89, 0xC1, // mov rcx, r8
    0x48, 0xF7, 0xF1, // div rcx
    0x48, 0x89, 0xD0, // mov rax,rdx
    0xC3              // ret
};

mulmod64_fnptr_t mulmod64_fnptr;

void init() {
    DWORD dwOldProtect;
    VirtualProtect(
        &mulmod64_opcodes,
        sizeof(mulmod64_opcodes),
        PAGE_EXECUTE_READWRITE,
        &dwOldProtect);
    // NOTE: reinterpret byte array as a function pointer
    mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes;
}

int main() {
    init();

    uint64_t a64 = 2139018971924123ull;
    uint64_t b64 = 1239485798578921ull;
    uint64_t m64 = 8975489368910167ull;

    // reference code
    mpz_t a, b, c, m, r;
    mpz_inits(a, b, c, m, r, NULL);
    mpz_set_ui(a, a64);
    mpz_set_ui(b, b64);
    mpz_set_ui(m, m64);
    mpz_mul(c, a, b);
    mpz_mod(r, c, m);

    gmp_printf("(%Zd * %Zd) mod %Zd = %Zd\n", a, b, m, r);

    // using mulmod64
    uint64_t r64 = mulmod64_fnptr(a64, b64, m64);
    printf("(%llu * %llu) mod %llu = %llu\n", a64, b64, m64, r64);
    return 0;
}
