
Why does my optimized prime-factor counting algorithm run slower?

Hi, I saw an online answer for counting the distinct prime factors of a number, and it looked non-optimal. So I tried to improve it, but in a simple benchmark my variant is much slower than the original.

The algorithm counts the distinct prime factors of a number. The original uses a HashSet to collect the factors, then calls size() to get their number. My "improved" version uses an int counter and breaks the while loops up into if/while to avoid unnecessary calls.

Update: tl;dr (see the accepted answer for details)

The original code had a performance bug, calling Math.sqrt unnecessarily, which the JIT compiler fixed on its own:

int n = ...;
// sqrt does not need to be recomputed if n does not change
for (int i = 3; i <= Math.sqrt(n); i += 2) {
    while (n % i == 0) {
        n /= i;
    }
}

The JIT compiler optimized the sqrt call so that it only happens when n changes. But by making the loop body a little more complex (with no functional change), the compiler stopped optimizing it that way, and sqrt was called on every iteration.
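One way to sidestep the question of whether the JIT hoists Math.sqrt, not taken from the thread itself, is to drop the floating-point bound altogether and compare (long) i * i <= n in integer arithmetic. A minimal sketch, assuming n >= 1; the class name IntegerBoundFactors is made up for illustration:

```java
public class IntegerBoundFactors {

    // Counts the distinct prime factors of n by trial division (assumes n >= 1).
    // The loop bound uses integer multiplication instead of Math.sqrt,
    // so there is no sqrt call for the JIT to hoist or recompute.
    static int countDistinctPrimeFactors(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1: " + n);
        }
        int count = 0;
        if (n % 2 == 0) {
            count++;                  // count 2 once
            do {
                n /= 2;
            } while (n % 2 == 0);
        }
        // Widen to long so that i * i cannot overflow for n near Integer.MAX_VALUE.
        for (int i = 3; (long) i * i <= n; i += 2) {
            if (n % i == 0) {
                count++;              // count i once
                do {
                    n /= i;
                } while (n % i == 0);
            }
        }
        // Anything left above 1 is one remaining prime factor.
        if (n > 1) {
            count++;
        }
        return count;
    }

    public static void main(String[] args) {
        // 1092 = 2 * 2 * 3 * 7 * 13, so this prints 4
        System.out.println(countDistinctPrimeFactors(1092));
    }
}
```

The widening to long keeps i * i from overflowing when n is close to Integer.MAX_VALUE. Whether this beats a properly hoisted sqrt on a given JVM is something to measure rather than assume.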

Original question

import java.util.HashSet;
import java.util.Set;

public class PrimeFactors {

    // fast version, takes 10s for input 8
    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    // slow version, takes 19s for input 8
    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count ++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    static int findNumberWithNPrimeFactors(final int n) {
        for (int i = 3; ; i++) {
            // switch implementations
            if (countPrimeFactorsCounter(i) == n) {
            // if (countPrimeFactorsSet(i) == n) {
                return i;
            }
        }
    }

    public static void main(String[] args) {
        findNumberWithNPrimeFactors(8); // benchmark warmup
        findNumberWithNPrimeFactors(8);
        long start = System.currentTimeMillis();
        int result = findNumberWithNPrimeFactors(8);
        long duration = System.currentTimeMillis() - start;

        System.out.println("took ms " + duration + " to find " + result);
    }
}

The output for the original version is consistently around 10s (on Java 8), whereas the "optimized" version is closer to 20s (both print the same result). In fact, just changing the single while loop into an if block that contains a while loop already slows the original method down to half its speed.

Running the JVM in interpreted mode with -Xint, the optimized version runs 3 times faster. With -Xcomp, both implementations run at a similar speed. So it seems the JIT can optimize the version with a single while loop and a HashSet better than the version with a simple int counter.

Would a proper microbenchmark (see "How do I write a correct micro-benchmark in Java?") tell me something else? Is there a performance optimization principle I overlooked (e.g. from Java performance tips)?

I converted your example into a JMH benchmark to get fair measurements, and indeed the set variant turned out to be twice as fast as counter:

Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5   717,976 ±  7,232  ops/ms
PrimeFactors.set      thrpt    5  1410,705 ± 15,894  ops/ms

To find out why, I reran the benchmark with the built-in -prof xperfasm profiler. It turned out that the counter method spent more than 60% of its time executing the vsqrtsd instruction, the compiled counterpart of Math.sqrt(n) (in the listing, the samples are attributed to the instruction right after it).

  0,02%   │  │  │     │  0x0000000002ab8f3e: vsqrtsd %xmm0,%xmm0,%xmm0    <-- Math.sqrt
 61,27%   │  │  │     │  0x0000000002ab8f42: vcvtsi2sd %r10d,%xmm1,%xmm1

Meanwhile, the hottest instruction in the set method was idiv, the compiled form of n % i.

             │  │ ││  0x0000000002ecb9e7: idiv   %ebp               ;*irem
 55,81%      │  ↘ ↘│  0x0000000002ecb9e9: test   %edx,%edx

It's no surprise that Math.sqrt is a slow operation. But why was it executed more frequently in the first case?

The clue is the transformation you made to the code during optimization. You wrapped a simple while loop in an extra if block. This made the control flow a little more complex, so the JIT failed to hoist the Math.sqrt computation out of the loop and had to recompute it on every iteration.

We need to help the JIT compiler a bit to bring the performance back. Let's hoist the Math.sqrt computation out of the loop manually.

    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
        for (int i = 3; i <= sn; i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
            sn = Math.sqrt(n);     // recompute, since n may have changed in the while loop
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count ++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
        for (int i = 3; i <= sn; i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
                sn = Math.sqrt(n);     // recompute after n changes
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

Now the counter method is fast, even a bit faster than set (which is expected, because it does the same amount of computation minus the Set overhead).
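The comparison above assumes that both methods compute the same thing. A quick sanity check for that, separate from any benchmarking, is to run the question's two original methods side by side over a range of inputs. A minimal sketch; the class name PrimeFactorsAgreement and the helper firstMismatch are hypothetical:

```java
import java.util.HashSet;
import java.util.Set;

public class PrimeFactorsAgreement {

    // Set-based version, copied from the question.
    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    // Counter-based version, copied from the question.
    static int countPrimeFactorsCounter(int n) {
        int count = 0;
        if (n % 2 == 0) {
            count++;
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++;
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    // Returns the first n in [from, to] where the two versions disagree, or -1 if none.
    // Assumes 1 <= from <= to < Integer.MAX_VALUE.
    static int firstMismatch(int from, int to) {
        for (int n = from; n <= to; n++) {
            if (countPrimeFactorsSet(n) != countPrimeFactorsCounter(n)) {
                return n;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        System.out.println("first mismatch: " + firstMismatch(1, 1_000_000));
    }
}
```

A result of -1 means no disagreement was found in the range tested.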

Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5  1513,228 ± 13,046  ops/ms
PrimeFactors.set      thrpt    5  1411,573 ± 10,004  ops/ms

Note that set performance did not change, because the JIT was able to do the same optimization by itself, thanks to the simpler control flow graph.

Conclusion: Java performance is a really complicated thing, especially when it comes to micro-optimizations. JIT optimizations are fragile, and it's hard to understand what the JVM is doing without specialized tools like JMH and profilers.

First, there are two kinds of operations in the tests: testing for factors, and recording those factors. When switching implementations, whether you use a Set, an ArrayList (as in my rewrite below), or simply count the factors will make a difference.

Second, I'm seeing very large variations in the timings. This is running from Eclipse, and I have no clear sense of what is causing the big variations.

My lesson learned is to be mindful of what exactly is being measured. Is the intent to measure the factorization algorithm itself (the cost of the while loops plus the arithmetic operations)? Should the time spent recording the factors be included?
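One way to make that choice explicit is to separate the factor search from the recording step by passing in a callback. A minimal sketch, assuming java.util.function.IntConsumer as the callback type; FactorSink and forEachFactor are made-up names, and the callback adds some overhead of its own that a real benchmark would have to account for:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;

public class FactorSink {

    // Trial division that reports every prime factor (with multiplicity)
    // to a caller-supplied sink, so the cost of "recording" is pluggable.
    static void forEachFactor(int n, IntConsumer sink) {
        while (n % 2 == 0) {
            sink.accept(2);
            n /= 2;
        }
        for (int f = 3; f <= Math.sqrt(n); f += 2) {
            while (n % f == 0) {
                sink.accept(f);
                n /= f;
            }
        }
        if (n > 2) {
            sink.accept(n);
        }
    }

    public static void main(String[] args) {
        // Recording variant: each factor is boxed and stored in a List.
        List<Integer> factors = new ArrayList<>();
        forEachFactor(1092, f -> factors.add(f));
        System.out.println(factors); // prints [2, 2, 3, 7, 13]

        // Counting-only variant: the sink just bumps a counter.
        int[] count = {0};
        forEachFactor(1092, f -> count[0]++);
        System.out.println(count[0]); // prints 5
    }
}
```

Timing the same search with a no-op sink, a counting sink, and a list-adding sink would then isolate how much of the total goes to recording.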

A minor technical point: the lack of something like multiple-value-setq, which is available in Lisp, is keenly felt in this implementation. One would much rather perform the remainder and the integer division as a single operation than write them out as two distinct steps. From a language-design and algorithms perspective, this is worth looking up.
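Java has no operator that returns quotient and remainder together, but the remainder can be recovered from the quotient with one multiplication and one subtraction. A minimal sketch of that idea, counting factors with multiplicity like findFactorsCount; SingleDivision and countFactorsWithMultiplicity are hypothetical names. On x86 the idiv instruction already produces both results, so whether this is actually faster depends on the JIT and is worth measuring:

```java
public class SingleDivision {

    // Counts the prime factors of n with multiplicity (so 18 -> 3, 64 -> 6).
    // Each test performs a single integer division: the remainder is
    // derived as n - q * f instead of computing n % f separately.
    static int countFactorsWithMultiplicity(int n) {
        int count = 0;
        // Try f = 2, then odd f = 3, 5, 7, ... while f * f <= n.
        for (int f = 2; (long) f * f <= n; f = (f == 2) ? 3 : f + 2) {
            while (true) {
                int q = n / f;              // the only division
                if (n - q * f != 0) {       // remainder, derived from q
                    break;
                }
                n = q;                      // reuse the quotient, no second division
                count++;
            }
        }
        if (n > 1) {
            count++;                        // leftover prime factor
        }
        return count;
    }

    public static void main(String[] args) {
        // 1092 = 2 * 2 * 3 * 7 * 13, so this prints 5
        System.out.println(countFactorsWithMultiplicity(1092));
    }
}
```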

Here are timing results for three variations of the factorization implementation. The first is the initial (un-optimized) implementation, changed to store the factors in a simple List instead of a harder-to-time Set. The second is your optimization, still recording the factors in a List. The third is your optimization plus the change to count the factors instead of recording them. Note that all three variations count repeated factors (18 gives [2, 3, 3]), unlike the distinct count in the question.

Average time per call, in ns:

  Value  Non-optimized  Optimized  Counting  Iterations  Trial
     18           3790       1450      2410          10
     64           1630       1220       260          10
   1091          16170       2850      1180          10
   1092           2720       1370       380          10
4096210          28830       5430      9120          10      1
4096210          18380       6190      5920          10      2
4096210          10072       5816      4836         100      1
4096210           7202       5036      3682         100      2

---

Test value [ 18 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621713914872600 (ns) ]
End   [ 1621713914910500 (ns) ]
Delta [ 37900 (ns) ]
Avg   [ 3790 (ns) ]
Factors: [2, 3, 3]
Times [optimized]
Start [ 1621713915343500 (ns) ]
End   [ 1621713915358000 (ns) ]
Delta [ 14500 (ns) ]
Avg   [ 1450 (ns) ]
Factors: [2, 3, 3]
Times [counting]
Start [ 1621713915550400 (ns) ]
End   [ 1621713915574500 (ns) ]
Delta [ 24100 (ns) ]
Avg   [ 2410 (ns) ]
Factors: 3
---
Test value [ 64 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621747046013900 (ns) ]
End   [ 1621747046030200 (ns) ]
Delta [ 16300 (ns) ]
Avg   [ 1630 (ns) ]
Factors: [2, 2, 2, 2, 2, 2]
Times [optimized]
Start [ 1621747046337800 (ns) ]
End   [ 1621747046350000 (ns) ]
Delta [ 12200 (ns) ]
Avg   [ 1220 (ns) ]
Factors: [2, 2, 2, 2, 2, 2]
Times [counting]
Start [ 1621747046507900 (ns) ]
End   [ 1621747046510500 (ns) ]
Delta [ 2600 (ns) ]
Avg   [ 260 (ns) ]
Factors: 6
---
Test value [ 1091 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621687024226500 (ns) ]
End   [ 1621687024388200 (ns) ]
Delta [ 161700 (ns) ]
Avg   [ 16170 (ns) ]
Factors: [1091]
Times [optimized]
Start [ 1621687024773200 (ns) ]
End   [ 1621687024801700 (ns) ]
Delta [ 28500 (ns) ]
Avg   [ 2850 (ns) ]
Factors: [1091]
Times [counting]
Start [ 1621687024954900 (ns) ]
End   [ 1621687024966700 (ns) ]
Delta [ 11800 (ns) ]
Avg   [ 1180 (ns) ]
Factors: 1
---
Test value [ 1092 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621619636267500 (ns) ]
End   [ 1621619636294700 (ns) ]
Delta [ 27200 (ns) ]
Avg   [ 2720 (ns) ]
Factors: [2, 2, 3, 7, 13]
Times [optimized]
Start [ 1621619636657100 (ns) ]
End   [ 1621619636670800 (ns) ]
Delta [ 13700 (ns) ]
Avg   [ 1370 (ns) ]
Factors: [2, 2, 3, 7, 13]
Times [counting]
Start [ 1621619636895300 (ns) ]
End   [ 1621619636899100 (ns) ]
Delta [ 3800 (ns) ]
Avg   [ 380 (ns) ]
Factors: 5
---
Test value [ 4096210 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621652753519800 (ns) ]
End   [ 1621652753808100 (ns) ]
Delta [ 288300 (ns) ]
Avg   [ 28830 (ns) ]
Factors: [2, 5, 19, 21559]
Times [optimized]
Start [ 1621652754116300 (ns) ]
End   [ 1621652754170600 (ns) ]
Delta [ 54300 (ns) ]
Avg   [ 5430 (ns) ]
Factors: [2, 5, 19, 21559]
Times [counting]
Start [ 1621652754323500 (ns) ]
End   [ 1621652754414700 (ns) ]
Delta [ 91200 (ns) ]
Avg   [ 9120 (ns) ]
Factors: 4

Here is my rewrite of the test code. The methods of most interest are findFactors, findFactorsOpt, and findFactorsCount.

package my.tests;

import java.util.ArrayList;
import java.util.List;

public class PrimeFactorsTest {

    public static void main(String[] args) {
        if ( args.length < 3 ) {
            System.out.println("Usage: " + PrimeFactorsTest.class.getName() + " testValue warmupIterations testIterations");
            return;
        }

        int testValue = Integer.parseInt(args[0]);
        int warmCount = Integer.parseInt(args[1]);
        int testCount = Integer.parseInt(args[2]);

        if ( testValue < 2 ) {
            System.out.println("Test value [ " + testValue + " ] must be at least 2.");
            return;
        } else {
            System.out.println("Test value [ " + testValue + " ]");
        }
        if ( warmCount < 1 ) {
            System.out.println("Warm-up count [ " + warmCount + " ] must be at least 1.");
            return;
        } else {
            System.out.println("Warm-up count [ " + warmCount + " ]");
        }
        if ( testCount < 1 ) {
            // a zero test count would also divide by zero when averaging below
            System.out.println("Test count [ " + testCount + " ] must be at least 1.");
            return;
        } else {
            System.out.println("Test count [ " + testCount + " ]");
        }

        timedFactors(testValue, warmCount, testCount);
        timedFactorsOpt(testValue, warmCount, testCount);
        timedFactorsCount(testValue, warmCount, testCount);
    }

    public static void timedFactors(int testValue, int warmCount, int testCount) {
        List<Integer> factors = new ArrayList<Integer>();

        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            factors.clear();
            findFactors(testValue, factors);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            factors.clear();
            findFactors(testValue, factors);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [non-optimized]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + factors);
    }

    public static void findFactors(int n, List<Integer> factors) {
        while ( n % 2 == 0 ) {
            n /= 2;
            factors.add( Integer.valueOf(2) );
        }

        for ( int factor = 3; factor <= Math.sqrt(n); factor += 2 ) {
            while ( n % factor == 0 ) {
                n /= factor;
                factors.add( Integer.valueOf(factor) );
            }
        }

        if ( n > 2 ) {
            factors.add( Integer.valueOf(n) );
        }
    }

    public static void timedFactorsOpt(int testValue, int warmCount, int testCount) {
        List<Integer> factors = new ArrayList<Integer>();
        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            factors.clear();
            findFactorsOpt(testValue, factors);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            factors.clear();
            findFactorsOpt(testValue, factors);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [optimized]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + factors);
    }

    public static void findFactorsOpt(int n, List<Integer> factors) {
        if ( n % 2 == 0 ) {
            n /= 2;

            Integer factor = Integer.valueOf(2); 
            factors.add(factor);

            while (n % 2 == 0) {
                n /= 2;

                factors.add(factor);
            }
        }

        for ( int factorValue = 3; factorValue <= Math.sqrt(n); factorValue += 2) {
            if ( n % factorValue == 0 ) {
                n /= factorValue;

                Integer factor = Integer.valueOf(factorValue); 
                factors.add(factor);

                while ( n % factorValue == 0 ) {
                    n /= factorValue;
                    factors.add(factor);
                }
            }
        }

        if (n > 2) {
            factors.add( Integer.valueOf(n) );
        }
    }

    public static void timedFactorsCount(int testValue, int warmCount, int testCount) {
        int numFactors = 0;

        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            numFactors = findFactorsCount(testValue);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            numFactors = findFactorsCount(testValue);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [counting]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + numFactors);
    }

    public static int findFactorsCount(int n) {
        int numFactors = 0;

        if ( n % 2 == 0 ) {
            n /= 2;
            numFactors++;

            while (n % 2 == 0) {
                n /= 2;
                numFactors++;
            }
        }

        for ( int factorValue = 3; factorValue <= Math.sqrt(n); factorValue += 2) {
            if ( n % factorValue == 0 ) {
                n /= factorValue;
                numFactors++;

                while ( n % factorValue == 0 ) {
                    n /= factorValue;
                    numFactors++;
                }
            }
        }

        if (n > 2) {
            numFactors++;
        }

        return numFactors;
    }
}

First, the if block in for (int i = 3; i <= Math.sqrt(n); i += 2) { if (n % i == 0) {... should be moved out of the loop.

Second, you can write this code in different ways. For example, this loop:

while (n % 2 == 0) { current++; n /= 2; }

can be replaced with: if (n % 2 == 0) { current++; n /= 2; } if (n % 2 == 0) { current++; n /= 2; }

Essentially, you should avoid conditions or extra instructions inside loops because of your method findNumberWithNPrimeFactors:

the cost of your algorithm is the cost of each loop iteration in findNumberWithNPrimeFactors multiplied by the number of iterations.

If you add a test or an assignment inside your loop, you pay for it on every iteration: the total becomes (cost of each iteration + 1) × (number of iterations).

The following makes Math.sqrt superfluous by dividing n down instead. Continuously comparing against a shrinking square root might even be the slowest operation.

Also, a do-while would be better style here.

static int countPrimeFactorsCounter2(int n) {
    int count = 0; // using simple int
    if (n % 2 == 0) {
        ++count; // only add on first division
        do {
            n /= 2;
        } while (n % 2 == 0);
    }
    for (int i = 3; i <= n; i += 2) {
        if (n % i == 0) {
            count++; // only add on first division
            do {
                n /= i;
            } while (n % i == 0);
        }
    }
    //if (n > 2) {
    //    ++count;
    //}
    return count;
}

The logical fallacy of using the square root rests on the fact that for all a, b with ab = n, you only need to try a ≤ √n. However, in a loop that keeps dividing n, you save just a single step. Notice also that the sqrt is calculated for every odd number i.
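How many steps the square-root bound saves is easy to check by counting loop iterations directly. A minimal sketch with made-up names (LoopBoundIterations, iterationsSqrtBound, iterationsNBound) that counts how many odd trial divisors each bound visits; iteration count is only part of the cost, since every iteration also pays for its bound check:

```java
public class LoopBoundIterations {

    // Counts the odd trial divisors visited when the bound is i <= Math.sqrt(n),
    // with n shrinking as factors are divided out (the shape used in the question).
    // Assumes n >= 1.
    static int iterationsSqrtBound(int n) {
        while (n % 2 == 0) {
            n /= 2;
        }
        int iterations = 0;
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            iterations++;
            while (n % i == 0) {
                n /= i;
            }
        }
        return iterations;
    }

    // Same count, but with the bound i <= n (the shape of countPrimeFactorsCounter2).
    // Assumes 1 <= n < Integer.MAX_VALUE so that i += 2 cannot overflow.
    static int iterationsNBound(int n) {
        while (n % 2 == 0) {
            n /= 2;
        }
        int iterations = 0;
        for (int i = 3; i <= n; i += 2) {
            iterations++;
            while (n % i == 0) {
                n /= i;
            }
        }
        return iterations;
    }

    public static void main(String[] args) {
        int[] samples = {1091, 1092, 4096210};
        for (int n : samples) {
            System.out.println(n + ": sqrt bound = " + iterationsSqrtBound(n)
                    + ", n bound = " + iterationsNBound(n));
        }
    }
}
```

Prime inputs and inputs whose largest prime factor is small behave quite differently here, so it is worth trying both kinds.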

The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM