[英]Optimize summation code
来自采访:
int fn(int a, int b)
{
int sum = 0;
for (int i = a * 4; i > 0; i--)
{
sum += b * i * i;
}
return sum;
}
该代码如何进一步优化? 我知道有一个求和公式,但是我不认为记住这样的公式是面试官想要的。 那么,您将如何优化它?
编辑:谢谢chqrlie,faivvy,asimes和Ap31的建议和解答。 所以我想目前有三种优化方法:
在这三个答案中,我可能会选择1和3,因为它们可以应用于结构相似的所有类型的代码。 您应该提到有一个公式可以用作奖励,但是我怀疑公式是否是面试官想要的。
还有其他建议吗?
公式:1 * 1 + 2 * 2 + ... + n * n = n(n + 1)(2n + 1)/ 6
int fn(int a, int b)
{
a <<= 2;
return (a*(a + 1)*((a << 1) + 1) / 6) * b;
}
那是你要的吗?
除非a
为负,否则函数fn
计算b
乘以平方和,直到4*a
。
从1
到n
的平方和可以计算为n(n + 1)(2n + 1)/ 6 。
这是C语言的翻译:
int fn(int a, int b) {
if (a <= 0 || b == 0) {
return 0;
} else {
int n = a * 4;
return n * (n + 1) * (2 * n + 1) / 6 * b;
}
}
正如Ap31所指出的,clang精巧到足以检测到循环优化并将原始函数转换为直接计算的能力,但是它将上面的代码编译为更加紧凑的16条汇编指令 (原始代码为36条)。
为避免中间结果可能出现溢出,这里有一个略有不同的公式,该公式不计算较大的中间结果:
int fn(int a, int b) {
if (a <= 0 || b == 0) {
return 0;
} else {
if (a % 3 == 0)
return (a / 3) * (4 * a + 1) * (8 * a + 1) * b * 2;
else
return (4 * a + 1) * (8 * a + 1) / 3 * a * b * 2;
}
}
如果long long
类型大于int
类型,则更简单的选择是:
int fn(int a, int b) {
if (a <= 0 || b == 0) {
return 0;
} else {
unsigned long long n = a * 4;
return (int)(n * (n + 1) * (2 * n + 1) / 6 * b);
}
}
面试官当然希望通过@faivvy(和@chqrlie)答案进行优化,您可以始终导出公式,或者只是说您知道该公式存在,就可以完全摆脱循环。
不要忘记一些通常的陷阱: a
可能为负, a*a*(2*a + 1)
可能溢出。
还要注意的另一件事是, 现代编译器可以自己执行此操作 -您也可以向采访者提及。
正如@faivvy在他的回答中指出的那样,您可以尝试完全消除for循环
但是,另一种方法(正确处理负数a
)是执行循环展开,我将调用该函数fnUnroll
。 如果您不熟悉循环展开,则可以减少迭代次数并并行求和
如评论中所述,每次迭代都不需要乘以b
,这可以在最后完成。 我添加了另一个名为fnUnrollNoMult
函数来显示此信息
#include <chrono>
#include <cstdlib>
#include <iostream>
int fn(int a, int b) {
int sum = 0;
for (int i = a * 4; i > 0; i--)
sum += b * i * i;
return sum;
}
int fnUnroll(int a, int b) {
// Set up some number of accumulators, I picked 4
int sum0 = 0;
int sum1 = 0;
int sum2 = 0;
int sum3 = 0;
int i = 1;
int limit = a * 4;
// Sum 4 values in parallel
for ( ; i < limit; i += 4) {
sum0 += b * i * i;
sum1 += b * (i + 1) * (i + 1);
sum2 += b * (i + 2) * (i + 2);
sum3 += b * (i + 3) * (i + 3);
}
// Handle the remainder (if any)
for ( ; i < limit; i++)
sum0 += b * i + i;
// Sum the accumulators
return sum0 + sum1 + sum2 + sum3;
}
int fnUnrollNoMult(int a, int b) {
int sum0 = 0;
int sum1 = 0;
int sum2 = 0;
int sum3 = 0;
// Remove b from the loops
int i = 1;
int limit = a * 4;
for ( ; i < limit; i += 4) {
sum0 += i * i;
sum1 += (i + 1) * (i + 1);
sum2 += (i + 2) * (i + 2);
sum3 += (i + 3) * (i + 3);
}
for ( ; i < limit; i++)
sum0 += i + i;
// Handle b here
return b * (sum0 + sum1 + sum2 + sum3);
}
int main(int argc, char** argv) {
// Expects two arguments: a and b
if (argc != 3) {
std::cout << "Usage: " << argv[0] << " <int> <int>\n";
return 1;
}
int a = atoi(argv[1]);
int b = atoi(argv[2]);
// This is just to demonstrate correctness
for (int i = 0; i < 100; i++)
for (int j = 0; j < 100; j++)
if (
fn(i, j) != fnUnroll(i, j) ||
fn(i, j) != fnUnrollNoMult(i, j)
) {
std::cout << "Not equal: " << i << ", " << j << std::endl;
return 1;
}
// Benchmark
using namespace std::chrono;
{
auto start = high_resolution_clock::now();
int result = fn(a, b);
auto stop = high_resolution_clock::now();
std::cout << "fn value: " << result << std::endl;
std::cout << "fn nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
}
{
auto start = high_resolution_clock::now();
int result = fnUnroll(a, b);
auto stop = high_resolution_clock::now();
std::cout << "fnUnroll value: " << result << std::endl;
std::cout << "fnUnroll nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
}
{
auto start = high_resolution_clock::now();
int result = fnUnrollNoMult(a, b);
auto stop = high_resolution_clock::now();
std::cout << "fnUnrollNoMult value: " << result << std::endl;
std::cout << "fnUnrollNoMult nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
}
return 0;
}
下面的程序需要两个表示a
和b
参数。 下面,我将该程序编译为g++ -std=c++14 foo.cpp -O3
并针对一些a
值获得了这些结果:
./a.out 1 2
fn value: 60
fn nanos: 373
fnUnroll value: 60
fnUnroll nanos: 209
fnUnrollNoMult value: 60
fnUnrollNoMult nanos: 157
./a.out 1000 2
fn value: -267004960
fn nanos: 3509
fnUnroll value: -267004960
fnUnroll nanos: 2820
fnUnrollNoMult value: -267004960
fnUnrollNoMult nanos: 1568
./a.out 1000000 2
fn value: -619707648
fn nanos: 3137685
fnUnroll value: -619707648
fnUnroll nanos: 2387840
fnUnrollNoMult value: -619707648
fnUnrollNoMult nanos: 1220519
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.