#### Autovectorised FMA in JDK10

*Fused multiply-add* (FMA) allows floating point expressions of the form `a * x + b` to be evaluated in a single instruction, which is useful for numerical linear algebra. Despite the obvious appeal of FMA, JVM implementors are rather constrained when it comes to floating point arithmetic, because Java programs are expected to be **reproducible** across versions and target architectures. FMA does not produce precisely the same result as the equivalent multiplication and addition instructions. A separate multiply and add each round their result, whereas FMA rounds only once, at the end. Its use is therefore a change in semantics rather than an optimisation, and the user must opt in. To the best of my knowledge, support for FMA was first proposed in 2000, along with reorderable floating point operations, which would have been activated by a `fastfp` keyword, but this proposal was withdrawn. In Java 9, the intrinsic `Math.fma` was introduced to provide access to FMA for the first time.
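
To make the semantic difference concrete, here is a minimal sketch comparing `Math.fma(a, b, c)` with `a * b + c`. The inputs are chosen so that rounding the intermediate product matters. The class and method names are illustrative only.

```java
// Illustrative comparison of one rounding (Math.fma) versus two roundings (a * b + c)
final class FmaVersusMulAdd {

    // a * b is rounded to the nearest double before c is added, then the sum is rounded
    static double mulAdd(double a, double b, double c) {
        return a * b + c;
    }

    // a * b + c is computed as if exactly, then rounded once (Java 9+)
    static double fused(double a, double b, double c) {
        return Math.fma(a, b, c);
    }

    public static void main(String[] args) {
        // 0.1 is not exactly representable, and 0.1 * 10 rounds to exactly 1.0,
        // so the representation error of 0.1 is lost before -1 is added
        System.out.println(mulAdd(0.1, 10, -1)); // 0.0
        // the fused version keeps it: the result is 2^-54
        System.out.println(fused(0.1, 10, -1));  // 5.551115123125783E-17
    }
}
```

Neither answer is "wrong". They are different roundings of the same expression, which is exactly why FMA cannot be substituted for a multiply and an add behind the programmer's back.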

#### DAXPY Benchmark

A good use case to evaluate `Math.fma` is *DAXPY* from BLAS (the Basic Linear Algebra Subprograms), which computes `a[i] += s * b[i]` over two arrays. The JMH benchmark below compiles with JDK 9+:

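```
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.infra.Blackhole;

// DoubleData is a separate JMH @State class (not shown) holding the two double[]
// arrays data1 and data2, presumably sized by the "size" parameter in the results.
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
public class DAXPY {

    double s;

    @Setup(Level.Invocation)
    public void init() {
        // a fresh random scalar for every invocation
        s = ThreadLocalRandom.current().nextDouble();
    }

    @Benchmark
    public void daxpyFMA(DoubleData state, Blackhole bh) {
        double[] a = state.data1;
        double[] b = state.data2;
        for (int i = 0; i < a.length; ++i) {
            a[i] = Math.fma(s, b[i], a[i]); // s * b[i] + a[i], rounded once
        }
        bh.consume(a);
    }

    @Benchmark
    public void daxpy(DoubleData state, Blackhole bh) {
        double[] a = state.data1;
        double[] b = state.data2;
        for (int i = 0; i < a.length; ++i) {
            a[i] += s * b[i]; // product and sum rounded separately
        }
        bh.consume(a);
    }
}
```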
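
Outside JMH, a quick way to see that the two kernels really compute slightly different things is to run them side by side. The sketch below is a hypothetical JMH-free harness. The class name `DaxpyKernels` and its methods are mine, not part of the benchmark. It copies the two loop bodies and reports the largest element-wise difference.

```java
import java.util.Random;

// Hypothetical JMH-free harness: the same two loop bodies as the benchmark,
// computing a[i] = s * b[i] + a[i], used to compare results rather than to time them.
final class DaxpyKernels {

    // Two roundings per element: the product, then the sum
    static void daxpy(double s, double[] b, double[] a) {
        for (int i = 0; i < a.length; ++i) {
            a[i] += s * b[i];
        }
    }

    // One rounding per element: s * b[i] + a[i] is computed exactly, then rounded
    static void daxpyFMA(double s, double[] b, double[] a) {
        for (int i = 0; i < a.length; ++i) {
            a[i] = Math.fma(s, b[i], a[i]);
        }
    }

    // Largest absolute element-wise difference between the two kernels; inputs are not modified
    static double maxDifference(double s, double[] b, double[] a) {
        double[] plain = a.clone();
        double[] fused = a.clone();
        daxpy(s, b, plain);
        daxpyFMA(s, b, fused);
        double max = 0D;
        for (int i = 0; i < a.length; ++i) {
            max = Math.max(max, Math.abs(plain[i] - fused[i]));
        }
        return max;
    }

    public static void main(String[] args) {
        Random random = new Random(42);
        double[] a = random.doubles(1_000).toArray();
        double[] b = random.doubles(1_000).toArray();
        // Small, but not necessarily zero: the kernels are not bit-for-bit identical
        System.out.println(maxDifference(random.nextDouble(), b, a));
    }
}
```

With inputs in [0, 1), the difference is at most a bit or two in the last place. It only becomes noticeable relative to the result when the product and the addend nearly cancel.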

Running this benchmark with Java 9, you may wonder why you bothered, because the FMA version is actually slower.

Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size |
---|---|---|---|---|---|---|---|
daxpy | thrpt | 1 | 10 | 25.011242 | 2.259007 | ops/ms | 100000 |
daxpy | thrpt | 1 | 10 | 0.706180 | 0.046146 | ops/ms | 1000000 |
daxpyFMA | thrpt | 1 | 10 | 15.334652 | 0.271946 | ops/ms | 100000 |
daxpyFMA | thrpt | 1 | 10 | 0.623838 | 0.018041 | ops/ms | 1000000 |

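This is because using `Math.fma` disables autovectorisation in JDK 9. Looking at the `PrintAssembly` output, you can see that the main loop of the naive `daxpy` routine uses 256-bit AVX instructions on `ymm` registers, processing four doubles at a time. By contrast, `daxpyFMA` reverts to scalar instructions that handle one double at a time in an `xmm` register.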

```
// daxpy routine, code taken from main vectorised loop
vmovdqu ymm1,ymmword ptr [r10+rdx*8+10h]
vmulpd ymm1,ymm1,ymm2
vaddpd ymm1,ymm1,ymmword ptr [r8+rdx*8+10h]
vmovdqu ymmword ptr [r8+rdx*8+10h],ymm1

// daxpyFMA routine
vmovsd xmm2,qword ptr [rcx+r13*8+10h]
vfmadd231sd xmm2,xmm0,xmm1
vmovsd qword ptr [rcx+r13*8+10h],xmm2
```

Not to worry: this seems to have been fixed in JDK 10. Since Java 10’s release is around the corner, early access builds are available for all platforms. Rerunning the benchmark on JDK 10, FMA no longer incurs a cost, although it doesn’t bring the performance boost some people might expect either. The benefit is less floating point error, because the number of roundings per element is halved.

Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size |
---|---|---|---|---|---|---|---|
daxpy | thrpt | 1 | 10 | 2582.363228 | 116.637400 | ops/ms | 1000 |
daxpy | thrpt | 1 | 10 | 405.904377 | 32.364782 | ops/ms | 10000 |
daxpy | thrpt | 1 | 10 | 25.210111 | 1.671794 | ops/ms | 100000 |
daxpy | thrpt | 1 | 10 | 0.608660 | 0.112512 | ops/ms | 1000000 |
daxpyFMA | thrpt | 1 | 10 | 2650.264580 | 211.342407 | ops/ms | 1000 |
daxpyFMA | thrpt | 1 | 10 | 389.274693 | 43.567450 | ops/ms | 10000 |
daxpyFMA | thrpt | 1 | 10 | 24.941172 | 2.393358 | ops/ms | 100000 |
daxpyFMA | thrpt | 1 | 10 | 0.671310 | 0.158470 | ops/ms | 1000000 |

```
// vectorised daxpyFMA routine, code taken from main loop (you can still see the old code in pre/post loops)
vmovdqu ymm0,ymmword ptr [r9+r13*8+10h]
vfmadd231pd ymm0,ymm1,ymmword ptr [rbx+r13*8+10h]
vmovdqu ymmword ptr [r9+r13*8+10h],ymm0
```

#### New Methods in Java 9: Math.fma and Arrays.mismatch

There are two noteworthy new methods in Java 9: `Arrays.mismatch` and `Math.fma`.

#### Arrays.mismatch

This method takes two primitive arrays and returns the index of the first position at which they differ, or `-1` if they are identical. This effectively computes the length of the longest common prefix of the two arrays. This is really quite useful, mostly for text processing but also for bioinformatics (protein sequencing and so on, much more interesting than the sort of thing I work on). Having worked extensively with Apache HBase (where the vast majority of the API involves manipulating byte arrays), I can think of lots of less interesting use cases for this method.
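
The semantics include two edge cases worth remembering. A return value of `-1` means no mismatch at all. If one array is a proper prefix of the other, the method returns the length of the shorter array. Here is a minimal sketch; `commonPrefixLength` is a hypothetical helper, not a JDK method.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

final class MismatchDemo {

    // Hypothetical helper: length of the longest common prefix, via Arrays.mismatch (Java 9+)
    static int commonPrefixLength(byte[] a, byte[] b) {
        int i = Arrays.mismatch(a, b);
        // -1 means the arrays are identical, so the whole array is the common prefix
        return i == -1 ? a.length : i;
    }

    public static void main(String[] args) {
        byte[] x = "GATTACA".getBytes(StandardCharsets.US_ASCII);
        byte[] y = "GATTCCA".getBytes(StandardCharsets.US_ASCII);
        byte[] z = "GATT".getBytes(StandardCharsets.US_ASCII);
        System.out.println(Arrays.mismatch(x, y));             // 4: 'A' vs 'C' at index 4
        System.out.println(Arrays.mismatch(x, z));             // 4: z is a proper prefix of x
        System.out.println(Arrays.mismatch(x, x.clone()));     // -1: no mismatch
        System.out.println(commonPrefixLength(x, x.clone()));  // 7
    }
}
```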

Looking at the JDK source, you can see that the method calls into the internal `ArraysSupport` utility class, which tries to perform a vectorised mismatch (`vectorizedMismatch` is an intrinsic candidate). On hardware with AVX support the intrinsic can use AVX instructions, so this is very fast: much faster than a handwritten loop.

Let’s measure the speed-up against a handwritten loop, testing `byte[]` arrays across a range of common prefixes and array lengths.

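```
// Methods of a JMH benchmark class. BytePrefixData is a JMH @State class (not shown)
// holding two byte[] arrays, data1 and data2, parameterised by (prefix) and (size).

@Benchmark
@CompilerControl(CompilerControl.Mode.DONT_INLINE)
public void Mismatch_Intrinsic(BytePrefixData data, Blackhole bh) {
    bh.consume(Arrays.mismatch(data.data1, data.data2));
}

@Benchmark
@CompilerControl(CompilerControl.Mode.DONT_INLINE)
public void Mismatch_Handwritten(BytePrefixData data, Blackhole bh) {
    byte[] data1 = data.data1;
    byte[] data2 = data.data2;
    int length = Math.min(data1.length, data2.length);
    int mismatch = -1;
    for (int i = 0; i < length; ++i) {
        if (data1[i] != data2[i]) {
            mismatch = i;
            break;
        }
    }
    // Match Arrays.mismatch: if one array is a proper prefix of the other,
    // the mismatch is reported at the end of the shorter one
    if (mismatch == -1 && data1.length != data2.length) {
        mismatch = length;
    }
    bh.consume(mismatch);
}
```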

The results speak for themselves. Irritatingly, there is some duplication in output because I haven’t figured out how to make JMH use a subset of the Cartesian product of its parameters.

Benchmark | (prefix) | (size) | Mode | Cnt | Score | Error | Units |
---|---|---|---|---|---|---|---|
Mismatch_Handwritten | 10 | 100 | thrpt | 10 | 22.360 | ± 0.938 | ops/us |
Mismatch_Handwritten | 10 | 1000 | thrpt | 10 | 2.459 | ± 0.256 | ops/us |
Mismatch_Handwritten | 10 | 10000 | thrpt | 10 | 0.255 | ± 0.009 | ops/us |
Mismatch_Handwritten | 100 | 100 | thrpt | 10 | 22.763 | ± 0.869 | ops/us |
Mismatch_Handwritten | 100 | 1000 | thrpt | 10 | 2.690 | ± 0.044 | ops/us |
Mismatch_Handwritten | 100 | 10000 | thrpt | 10 | 0.273 | ± 0.008 | ops/us |
Mismatch_Handwritten | 1000 | 100 | thrpt | 10 | 24.970 | ± 0.713 | ops/us |
Mismatch_Handwritten | 1000 | 1000 | thrpt | 10 | 2.791 | ± 0.066 | ops/us |
Mismatch_Handwritten | 1000 | 10000 | thrpt | 10 | 0.281 | ± 0.007 | ops/us |
Mismatch_Intrinsic | 10 | 100 | thrpt | 10 | 89.169 | ± 2.759 | ops/us |
Mismatch_Intrinsic | 10 | 1000 | thrpt | 10 | 26.995 | ± 0.501 | ops/us |
Mismatch_Intrinsic | 10 | 10000 | thrpt | 10 | 3.553 | ± 0.065 | ops/us |
Mismatch_Intrinsic | 100 | 100 | thrpt | 10 | 83.037 | ± 5.590 | ops/us |
Mismatch_Intrinsic | 100 | 1000 | thrpt | 10 | 26.249 | ± 0.714 | ops/us |
Mismatch_Intrinsic | 100 | 10000 | thrpt | 10 | 3.523 | ± 0.122 | ops/us |
Mismatch_Intrinsic | 1000 | 100 | thrpt | 10 | 87.921 | ± 6.566 | ops/us |
Mismatch_Intrinsic | 1000 | 1000 | thrpt | 10 | 25.812 | ± 0.442 | ops/us |
Mismatch_Intrinsic | 1000 | 10000 | thrpt | 10 | 4.177 | ± 0.059 | ops/us |

Why is there such a big difference? Look at how the score decreases as a function of array length, even when the common prefix (and therefore the number of comparisons required) is small. Clearly, the performance of this algorithm depends on the efficiency of memory access. `Arrays.mismatch` optimises this by reading the arrays a qword (eight bytes) at a time, using SIMD registers when the intrinsic is available. Working one `long` at a time, a single XOR instruction determines whether any of the eight bytes differ. Individual bytes only need to be examined once a difference has been found.

```
java/util/ArraysSupport.vectorizedMismatch(Ljava/lang/Object;JLjava/lang/Object;JII)I [0x000002bd9215a820, 0x000002bd9215aa78] 600 bytes
Argument 0 is unknown.RIP: 0x2bd9215a820 Code size: 0x00000258
[Entry Point]
[Verified Entry Point]
[Constants]
  # {method} {0x000002bda79cbf68} 'vectorizedMismatch' '(Ljava/lang/Object;JLjava/lang/Object;JII)I' in 'java/util/ArraysSupport'
  # parm0:    rdx:rdx   = 'java/lang/Object'
  # parm1:    r8:r8     = long
  # parm2:    r9:r9     = 'java/lang/Object'
  # parm3:    rdi:rdi   = long
  # parm4:    rsi       = int
  # parm5:    rcx       = int
  #           [sp+0x60]  (sp of caller)
  0x000002bd9215a820: mov dword ptr [rsp+0ffffffffffff9000h],eax ;...89 ;...84 ;...24 ;...00 ;...90 ;...ff ;...ff
  0x000002bd9215a827: push rbp ;...55
  0x000002bd9215a828: sub rsp,50h ;...48 ;...83 ;...ec ;...50
      ;*synchronization entry
      ; - java.util.ArraysSupport::vectorizedMismatch@-1 (line 120)
  0x000002bd9215a82c: mov r10,rdi ;...4c ;...8b ;...d7
  0x000002bd9215a82f: vmovq xmm2,r9 ;...c4 ;...c1 ;...f9 ;...6e ;...d1
  0x000002bd9215a834: vmovq xmm1,rdx ;...c4 ;...e1 ;...f9 ;...6e ;...ca
  0x000002bd9215a839: mov r14d,ecx ;...44 ;...8b ;...f1
  0x000002bd9215a83c: vmovd xmm0,esi ;...c5 ;...f9 ;...6e ;...c6
  0x000002bd9215a840: mov r9d,3h ;...41 ;...b9 ;...03 ;...00 ;...00 ;...00
  0x000002bd9215a846: sub r9d,ecx ;...44 ;...2b ;...c9
      ;*isub {reexecute=0 rethrow=0 return_oop=0}
      ; - java.util.ArraysSupport::vectorizedMismatch@5 (line 120)
  0x000002bd9215a849: mov edx,esi ;...8b ;...d6
  0x000002bd9215a84b: mov ecx,r9d ;...41 ;...8b ;...c9
  0x000002bd9215a84e: sar edx,cl ;...d3 ;...fa
      ;*ishr {reexecute=0 rethrow=0 return_oop=0}
      ; - java.util.ArraysSupport::vectorizedMismatch@17 (line 122)
  0x000002bd9215a850: mov eax,1h ;...b8 ;...01 ;...00 ;...00 ;...00
  0x000002bd9215a855: xor edi,edi ;...33 ;...ff
  0x000002bd9215a857: test edx,edx ;...85 ;...d2
  0x000002bd9215a859: jle 2bd9215a97ah ;...0f ;...8e ;...1b ;...01 ;...00 ;...00
```
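
The XOR idea described above can be sketched in plain Java. This is a scalar illustration under my own naming (`LongWiseMismatch` is hypothetical). It is not the JDK's `vectorizedMismatch`, which works on raw memory offsets and can be replaced by an intrinsic. Here, a little-endian `ByteBuffer` stands in for the eight-byte reads.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical scalar sketch of "compare eight bytes at a time with XOR".
final class LongWiseMismatch {

    // Same contract as Arrays.mismatch(byte[], byte[])
    static int mismatch(byte[] a, byte[] b) {
        int length = Math.min(a.length, b.length);
        // Little-endian: the lowest-indexed byte becomes the least significant byte of the long
        ByteBuffer ba = ByteBuffer.wrap(a).order(ByteOrder.LITTLE_ENDIAN);
        ByteBuffer bb = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
        int i = 0;
        for (; i + Long.BYTES <= length; i += Long.BYTES) {
            long diff = ba.getLong(i) ^ bb.getLong(i); // zero iff all eight bytes are equal
            if (diff != 0) {
                // the lowest set bit lies in the first differing byte
                return i + (Long.numberOfTrailingZeros(diff) >>> 3);
            }
        }
        // compare the remaining tail (fewer than eight bytes) one byte at a time
        for (; i < length; ++i) {
            if (a[i] != b[i]) {
                return i;
            }
        }
        // no difference in the common prefix: -1 if same length, else the shorter length
        return a.length == b.length ? -1 : length;
    }

    public static void main(String[] args) {
        byte[] x = "The quick brown fox jumps".getBytes(StandardCharsets.US_ASCII);
        byte[] y = "The quick brown cat jumps".getBytes(StandardCharsets.US_ASCII);
        System.out.println(mismatch(x, y) + " " + Arrays.mismatch(x, y)); // 16 16
    }
}
```

Little-endian order is what makes `Long.numberOfTrailingZeros(diff) >>> 3` the offset of the first differing byte within the word. With big-endian reads you would count leading zeros instead.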

The code for this benchmark is on GitHub.

#### Math.fma

In comparison to users of some languages, Java programmers are lackadaisical about floating point errors. It’s a good job that historically Java hasn’t been considered suitable for implementing numerical algorithms. But all of a sudden there is a revolution of data science on the JVM, albeit mostly driven by the Scala community, with JVM implementations of structures like recurrent neural networks abounding. Floating point error matters less for machine learning than for root finding. Still, how accurate can these implementations be without JVM-level support for minimising the propagation of floating point errors? With `Math.fma` this is improving, because a multiplication and an addition can now be performed with a single rounding at the end.

`Math.fma` fuses a multiplication and an addition into a single floating point operation: `Math.fma(a, b, c)` computes `a * b + c`. This has two key benefits:

- There’s only one operation, and only one rounding error
- It maps directly onto hardware: on x86 it is supported by the `VFMADD*` instructions of the FMA3 extension, which Intel introduced alongside AVX2 in Haswell
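
One consequence of the single rounding deserves a sketch, because it underpins error-compensating algorithms. Since `Math.fma(a, b, -p)` rounds only once, it recovers the exact rounding error of `p = a * b`, barring overflow and underflow. This "two-product" trick is a standard technique, not something used in the benchmarks here, and the class name is illustrative.

```java
// Illustrative error-free product: a * b == p + err exactly, where p = fl(a * b)
final class TwoProduct {

    // Returns {p, err}: p is the rounded product, err is the exact rounding error
    static double[] twoProduct(double a, double b) {
        double p = a * b;                  // rounded product
        double err = Math.fma(a, b, -p);   // a * b - p is exactly representable, and fma rounds once
        return new double[] {p, err};
    }

    public static void main(String[] args) {
        double x = 1 + 0x1p-30;             // 1 + 2^-30, exactly representable
        double[] r = twoProduct(x, x);      // exact square: 1 + 2^-29 + 2^-60
        System.out.println(r[0] == 1 + 0x1p-29); // true: the 2^-60 term is rounded away...
        System.out.println(r[1] == 0x1p-60);     // true: ...and recovered by the fma
    }
}
```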

#### Newton’s Method

To investigate whether FMA suppresses floating point errors any better, I use a toy implementation of Newton’s method to compute the root of a quadratic equation. Any teenager could calculate this root analytically, so the error is easy to quantify.

I compare these two implementations on a quadratic with a repeated root at 1.5, to get an idea of the error after a large number of iterations.
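
A short derivation helps to explain what follows. Any quadratic with a repeated root at 1.5 can be written as $k(x - 1.5)^2$ for some nonzero $k$. For such an $f$, the Newton update performed in `solve` exactly halves the distance to the root:

```latex
\begin{aligned}
x_{n+1} &= x_n - \frac{f(x_n)}{f'(x_n)}, \qquad f(x) = k\left(x - \tfrac{3}{2}\right)^2, \quad f'(x) = 2k\left(x - \tfrac{3}{2}\right) \\
x_{n+1} - \tfrac{3}{2} &= \left(x_n - \tfrac{3}{2}\right) - \frac{k\left(x_n - \tfrac{3}{2}\right)^2}{2k\left(x_n - \tfrac{3}{2}\right)} = \tfrac{1}{2}\left(x_n - \tfrac{3}{2}\right)
\end{aligned}
```

Convergence is therefore only linear. In floating point, the iteration can also stop improving once $(x_n - 1.5)^2$ is comparable to the rounding error made when evaluating $f$. With coefficients of modest size, that plausibly leaves the estimate roughly $\sqrt{\varepsilon} \approx 1.5 \times 10^{-8}$ from the root, where $\varepsilon = 2^{-52}$. This is consistent with the order of the errors reported for this experiment.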

I implemented this using FMA:

```
public class NewtonsMethodFMA {

    // polynomial coefficients, highest power first
    private final double[] coefficients;

    public NewtonsMethodFMA(double[] coefficients) {
        this.coefficients = coefficients;
    }

    // f(x): each term c * x^p is multiplied and added to the sum with a single rounding
    public double evaluateF(double x) {
        double f = 0D;
        int power = coefficients.length - 1;
        for (int i = 0; i < coefficients.length; ++i) {
            f = Math.fma(coefficients[i], Math.pow(x, power--), f);
        }
        return f;
    }

    // f'(x): each term c * x^p differentiates to p * c * x^(p - 1)
    public double evaluateDF(double x) {
        double df = 0D;
        int power = coefficients.length - 2;
        for (int i = 0; i < coefficients.length - 1; ++i) {
            df = Math.fma((power + 1) * coefficients[i], Math.pow(x, power--), df);
        }
        return df;
    }

    // Newton iteration x <- x - f(x) / f'(x), for a fixed number of steps
    public double solve(double initialEstimate, int maxIterations) {
        double result = initialEstimate;
        for (int i = 0; i < maxIterations; ++i) {
            result -= evaluateF(result) / evaluateDF(result);
        }
        return result;
    }
}
```

And an implementation with normal operations:

```
public class NewtonsMethod {

    // polynomial coefficients, highest power first
    private final double[] coefficients;

    public NewtonsMethod(double[] coefficients) {
        this.coefficients = coefficients;
    }

    // f(x): each product is rounded, then each addition is rounded
    public double evaluateF(double x) {
        double f = 0D;
        int power = coefficients.length - 1;
        for (int i = 0; i < coefficients.length; ++i) {
            f += coefficients[i] * Math.pow(x, power--);
        }
        return f;
    }

    // f'(x): each term c * x^p differentiates to p * c * x^(p - 1)
    public double evaluateDF(double x) {
        double df = 0D;
        int power = coefficients.length - 2;
        for (int i = 0; i < coefficients.length - 1; ++i) {
            df += (power + 1) * coefficients[i] * Math.pow(x, power--);
        }
        return df;
    }

    // Newton iteration x <- x - f(x) / f'(x), for a fixed number of steps
    public double solve(double initialEstimate, int maxIterations) {
        double result = initialEstimate;
        for (int i = 0; i < maxIterations; ++i) {
            result -= evaluateF(result) / evaluateDF(result);
        }
        return result;
    }
}
```
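
Both implementations evaluate each term with `Math.pow`, so every term carries its own rounding before it reaches the multiply-add. A common alternative is Horner's rule, where each step is exactly one fused multiply-add. I sketch it below as a swapped-in technique; it is not what I benchmarked. `HornerNewton` is a hypothetical class. It also stops early when `f` or `f'` is exactly zero, because at an exact repeated root both vanish and `0/0` would turn the result into `NaN`.

```java
// Hypothetical alternative to NewtonsMethodFMA: Horner's rule with one Math.fma per coefficient
final class HornerNewton {

    // polynomial coefficients, highest power first (same convention as above)
    private final double[] coefficients;

    HornerNewton(double[] coefficients) {
        this.coefficients = coefficients.clone();
    }

    // f(x) = ((c0 * x + c1) * x + c2) * x + ..., each step rounded once
    double evaluateF(double x) {
        double f = 0D;
        for (double c : coefficients) {
            f = Math.fma(f, x, c);
        }
        return f;
    }

    // f'(x) by Horner's rule on the derivative's coefficients (n - i) * c_i
    double evaluateDF(double x) {
        int degree = coefficients.length - 1;
        double df = 0D;
        for (int i = 0; i < degree; ++i) {
            df = Math.fma(df, x, (degree - i) * coefficients[i]);
        }
        return df;
    }

    double solve(double initialEstimate, int maxIterations) {
        double result = initialEstimate;
        for (int i = 0; i < maxIterations; ++i) {
            double f = evaluateF(result);
            double df = evaluateDF(result);
            if (f == 0D || df == 0D) {
                break; // exact root, or a flat point where the Newton step is undefined
            }
            result -= f / df;
        }
        return result;
    }

    public static void main(String[] args) {
        // x^2 - 3x + 2 has simple roots at 1 and 2; starting from 3 heads towards 2
        System.out.println(new HornerNewton(new double[] {1, -3, 2}).solve(3, 100));
        // (x - 1.5)^2 = x^2 - 3x + 2.25 has a repeated root at 1.5
        System.out.println(new HornerNewton(new double[] {1, -3, 2.25}).solve(1, 1000));
    }
}
```

For a degree-n polynomial, Horner's rule needs n fused multiply-adds and no calls to `Math.pow`, which makes it a natural place for FMA in polynomial evaluation.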

When I run this code for 1000 iterations, the FMA version results in 1.5000000083575202, whereas the vanilla version results in 1.500000017233207. It’s completely unscientific, but it seems plausible and confirms my prejudice, so… In fact, it’s not that simple: over a range of initial values, there is only a very small difference in FMA’s favour. There’s not even a performance improvement, so clearly this method wasn’t added so you can start implementing numerical root finding algorithms. The key takeaway is that the results are slightly different, because a different rounding strategy has been used.

Benchmark | (maxIterations) | Mode | Cnt | Score | Error | Units |
---|---|---|---|---|---|---|
NM_FMA | 100 | thrpt | 10 | 93.805 | ± 5.174 | ops/ms |
NM_FMA | 1000 | thrpt | 10 | 9.420 | ± 1.169 | ops/ms |
NM_FMA | 10000 | thrpt | 10 | 0.962 | ± 0.044 | ops/ms |
NM_HandWritten | 100 | thrpt | 10 | 93.457 | ± 5.048 | ops/ms |
NM_HandWritten | 1000 | thrpt | 10 | 9.274 | ± 0.483 | ops/ms |
NM_HandWritten | 10000 | thrpt | 10 | 0.928 | ± 0.041 | ops/ms |