Autovectorised FMA in JDK10

Fused multiply-add (FMA) allows floating point expressions of the form a * x + b to be evaluated in a single instruction, which is useful for numerical linear algebra. Despite the obvious appeal of FMA, JVM implementors are rather constrained when it comes to floating point arithmetic, because Java programs are expected to be reproducible across versions and target architectures. FMA does not produce precisely the same result as the equivalent multiplication and addition instructions, because it rounds once where the separate operations round twice, so its use is a change in semantics rather than an optimisation; the user must opt in. To the best of my knowledge, support for FMA was first proposed in 2000, along with reorderable floating point operations, which would have been enabled by a fastfp keyword, but the proposal was withdrawn. In Java 9, the Math.fma intrinsic was introduced, providing access to FMA for the first time.
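
To make the semantic difference concrete, here is a minimal sketch (mine, not taken from the post, and the FmaRounding class name is arbitrary) showing that Math.fma can return a different value from a separately rounded multiply and add; on a typical IEEE 754 platform the FMA result below is a tiny non-zero residual while the multiply-add result is exactly zero.

public class FmaRounding {
  public static void main(String[] args) {
    double x = 1.0 / 3.0;
    double p = x * x;                    // the product, rounded to a double
    double viaMulAdd = x * x - p;        // recomputes the rounded product, so this is exactly 0.0
    double viaFma = Math.fma(x, x, -p);  // one rounding: recovers the rounding error lost in p
    System.out.println(viaMulAdd);       // 0.0
    System.out.println(viaFma);          // tiny non-zero value on IEEE 754 hardware
  }
}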

DAXPY Benchmark

A good use case for evaluating Math.fma is DAXPY (y = a * x + y) from the Basic Linear Algebra Subprograms (BLAS). The code below compiles with JDK 9+.

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;

import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

// DoubleData is the author's JMH state class exposing two double[] fields,
// data1 and data2, sized by the size parameter reported in the results.
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
public class DAXPY {
  
  double s;

  @Setup(Level.Invocation)
  public void init() {
    s = ThreadLocalRandom.current().nextDouble();
  }

  @Benchmark
  public void daxpyFMA(DoubleData state, Blackhole bh) {
    double[] a = state.data1;
    double[] b = state.data2;
    for (int i = 0; i < a.length; ++i) {
      a[i] = Math.fma(s, b[i], a[i]);
    }
    bh.consume(a);
  }

  @Benchmark
  public void daxpy(DoubleData state, Blackhole bh) {
    double[] a = state.data1;
    double[] b = state.data2;
    for (int i = 0; i < a.length; ++i) {
      a[i] += s * b[i];
    }
    bh.consume(a);
  }
}
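
To run the benchmark outside the author's harness, a minimal JMH runner along these lines should work (the RunDAXPY class name, include pattern and fork count are my choices, not the post's):

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class RunDAXPY {
  public static void main(String[] args) throws RunnerException {
    Options options = new OptionsBuilder()
        .include(DAXPY.class.getSimpleName()) // runs both daxpy and daxpyFMA
        .forks(1)
        .build();
    new Runner(options).run();
  }
}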

Running this benchmark with Java 9, you may wonder why you bothered: the FMA version is actually slower.

Benchmark  Mode   Threads  Samples  Score      Score Error (99.9%)  Unit    Param: size
daxpy      thrpt  1        10       25.011242  2.259007             ops/ms  100000
daxpy      thrpt  1        10       0.706180   0.046146             ops/ms  1000000
daxpyFMA   thrpt  1        10       15.334652  0.271946             ops/ms  100000
daxpyFMA   thrpt  1        10       0.623838   0.018041             ops/ms  1000000

This is because using Math.fma disables autovectorisation. Taking a look at the PrintAssembly output, you can see that the naive daxpy routine exploits 256-bit AVX instructions, whereas daxpyFMA falls back to scalar FMA instructions operating on xmm registers.

// daxpy routine, code taken from main vectorised loop
vmovdqu ymm1,ymmword ptr [r10+rdx*8+10h]
vmulpd  ymm1,ymm1,ymm2
vaddpd  ymm1,ymm1,ymmword ptr [r8+rdx*8+10h]
vmovdqu ymmword ptr [r8+rdx*8+10h],ymm1

// daxpyFMA routine
vmovsd  xmm2,qword ptr [rcx+r13*8+10h]
vfmadd231sd xmm2,xmm0,xmm1
vmovsd  qword ptr [rcx+r13*8+10h],xmm2
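
To reproduce listings like these, one option (my setup, not necessarily the author's) is to append the HotSpot diagnostic flags to the forked benchmark JVM via JMH's @Fork annotation; note that -XX:+PrintAssembly only produces disassembly when the hsdis plugin is available to the JVM.

// Assumed fork configuration for dumping the generated code; the flags are
// standard HotSpot diagnostics rather than anything taken from the post.
@Fork(value = 1, jvmArgsAppend = {
    "-XX:+UnlockDiagnosticVMOptions",
    "-XX:+PrintAssembly"
})
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
public class DAXPY {
  // ... benchmarks as above ...
}

If the full log is too noisy, something like -XX:CompileCommand=print,*DAXPY.daxpy* should restrict the output to the benchmark methods.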

Not to worry: this seems to have been fixed in JDK 10. Since Java 10’s release is around the corner, there are early access builds available for all platforms. Rerunning the benchmark, FMA no longer incurs a cost, but nor does it bring the performance boost some people might expect. The benefit is reduced floating point error, because the total number of roundings is halved.

Benchmark  Mode   Threads  Samples  Score        Score Error (99.9%)  Unit    Param: size
daxpy      thrpt  1        10       2582.363228  116.637400           ops/ms  1000
daxpy      thrpt  1        10       405.904377   32.364782            ops/ms  10000
daxpy      thrpt  1        10       25.210111    1.671794             ops/ms  100000
daxpy      thrpt  1        10       0.608660     0.112512             ops/ms  1000000
daxpyFMA   thrpt  1        10       2650.264580  211.342407           ops/ms  1000
daxpyFMA   thrpt  1        10       389.274693   43.567450            ops/ms  10000
daxpyFMA   thrpt  1        10       24.941172    2.393358             ops/ms  100000
daxpyFMA   thrpt  1        10       0.671310     0.158470             ops/ms  1000000

// vectorised daxpyFMA routine, code taken from main loop (you can still see the old code in pre/post loops)
vmovdqu ymm0,ymmword ptr [r9+r13*8+10h]
vfmadd231pd ymm0,ymm1,ymmword ptr [rbx+r13*8+10h]
vmovdqu ymmword ptr [r9+r13*8+10h],ymm0

8 Comments

  • Maaartinus says:

    Do I understand it right that there’s no speed benefit from FMA in Java? The smaller rounding error oftentimes doesn’t matter.

    Is better performance in assembly possible? Or is the ALU saturated, or maybe the L3 cache or memory? AFAIK, the throughput of FMA is two instructions per cycle. No idea about L3/memory; I guess the smaller size fits in L3 and the bigger doesn’t. Could you try much smaller sizes?

  • @Maaartinus I’ve updated the post with results for lengths 1000 and 10000. There’s not a lot of difference. Note that I am not using powers of 2 for lengths because I think it is important to capture the costs of pre or post loops.

  • Maaartinus says:

    Thank you. The score for 10000 is 400 where I’d expect 250; that’s very strange.

    Concerning the cost of pre- and post-loops, my intuition says that they may matter for tens or maybe hundreds of iterations, but not for thousands.

    • You can run the benchmark on GitHub – build it, then run java -jar simd.jar --include DAXPY.*. I’d be interested to see your numbers.

      By the way, these benchmarks are just a personal hobby that I write about, and there may be errors, but I’ve found a few effects reproducible enough to warrant new OpenJDK tickets.

      • I think this is because of poor cache alignment, leading to long scalar pre and post loops. In fact, I looked at lengths 1000 (long scalar loop) vs 1024 (no scalar loop, fully vectorised) and saw this:

        Benchmark  Mode   Threads  Samples  Score        Score Error (99.9%)  Unit    Param: size
        daxpy      thrpt  1        10       7174.402950  159.243965           ops/ms  1024
        daxpy      thrpt  1        10       2513.730256  228.738629           ops/ms  1000
        daxpyFMA   thrpt  1        10       7043.651612  516.362625           ops/ms  1024
        daxpyFMA   thrpt  1        10       2584.155841  229.136919           ops/ms  1000

        Pretty amazing. Make sure you size your arrays as powers of 2, I guess.

  • Maaartinus says:

    I’d need to set up a new VM for testing this and that’s complicated at the moment. Actually, my CPU is not up to the task either.

    Sure, you can make mistakes, catching them is what the commenters are for. 😉

    I’d say the vectorisation is half-arsed at best. AFAIK there’s nothing that could stop it working pretty well for size=1000, but your benchmark shows nearly a threefold slowdown. I’ve never heard of any cache alignment requirement above 16 bytes. Anyway, the pre- and post-loops should take care of a few (<10) array elements and everything else should run at full speed. Or am I missing something substantial?

  • Maaartinus says:

    Anyway, do we agree that the performance for size=1000 and 1024 should be about the same? If so (and you’re using the most recent version), then you may want to file a bug, as JDK-8181616 is marked fixed and I wouldn’t call it working.
