Vectorised Algorithms in Java

There has been a Cambrian explosion of JVM data technologies in recent years. It’s all very exciting, but is the JVM really competitive with C in this area? I would argue that there is a reason Apache Arrow is polyglot, and it’s not just interoperability with Python. To pick on one project impressive enough to be thriving after seven years, if you’ve actually used Apache Spark you will be aware that it only looks fast next to its predecessor, MapReduce. Big data is a lot like teenage sex: everybody talks about it, nobody really knows how to do it, and everyone keeps their embarrassing stories to themselves. In games of incomplete information, it’s possible to overestimate the competence of others: nobody opens up about how slow their Spark jobs really are because there’s a risk of looking stupid.

If it can be accepted that Spark is inefficient, the question becomes whether Spark is fundamentally inefficient. Flare provides a drop-in replacement for Spark’s backend which replaces JIT compiled code with highly efficient native code, yielding order-of-magnitude improvements in job throughput. Some of Flare’s gains come from generating specialised code, but the rest come from just generating better native code than C2 does. To the extent that Flare validates Spark’s execution model, it raises questions about the suitability of the JVM for high throughput data processing.

I think this will change radically in the coming years, and the most important reason is the advent of explicit support for SIMD provided by the vector API, which is currently incubating in Project Panama. Once the vector API is complete, I conjecture that projects like Spark will be able to profit enormously from it. This post takes a look at the API in its current state and ignores performance.

Why Vectorisation?

Assuming a flat processor frequency, throughput is improved by a combination of executing many instructions per cycle (instruction-level parallelism) and processing multiple data items per instruction (SIMD). SIMD instruction sets are provided by Intel as the various generations of SSE and AVX. Vectorisation, that is, expressing a computation in terms of SIMD instructions, is therefore one of the main levers for increasing throughput; if throughput is the only goal, maximising SIMD may even be worth a reduction in frequency, which can happen on Intel chips when AVX is used.

Analytical workloads are particularly suitable for vectorisation, especially over columnar data, because they typically involve operations consuming the entire range of a few numerical attributes of a data set. Vectorised analytical processing with filters is explicitly supported by vector masks, and vectorisation is also profitable for operations on indices typically performed for filtering prior to calculations. I don’t actually need to make a strong case for the impact of vectorisation on analytical workloads: just read the work of top researchers like Daniel Abadi and Daniel Lemire.

Vectorisation in the JVM

C2 provides quite a lot of autovectorisation, which works very well sometimes, but the support is limited and brittle. I have written about this several times. Because AVX can reduce the processor frequency, it’s not always profitable to vectorise, so compilers employ cost models to decide when they should do so. Such cost models require platform specific calibration, and sometimes C2 can get it wrong. Sometimes, specifically in the case of floating point operations, using SIMD conflicts with the JLS, and the code C2 generates can be quite inefficient. In general, data parallel code can be better optimised by C compilers, such as GCC, than C2 because there are fewer constraints, and there is a larger budget for analysis at compile time. This all makes having intrinsics very appealing, and as a user I would like to be able to:

  1. Bypass JLS floating point constraints.
  2. Bypass cost model based decisions.
  3. Avoid JNI at all costs.
  4. Use a modern “object-functional” style. SIMD intrinsics in C are painful.

There is another attempt to provide SIMD intrinsics to JVM users via LMS, a framework for writing programs which write programs, designed by Tiark Rompf (who is also behind Flare). This work is very promising (I have written about it before), but it uses JNI. It’s only at the prototype stage: currently the intrinsics are auto-generated from XML definitions, which leads to a one-to-one mapping to the intrinsics in immintrin.h and a similar programming experience. This could likely be improved a lot, but the reliance on JNI is fundamental, albeit with minimal boundary crossing.

I am quite excited by the vector API in Project Panama because it looks like it will meet all of these requirements, at least to some extent. It remains to be seen quite how far the implementors will go in the direction of associative floating point arithmetic, but it has to opt out of JLS floating point semantics to some extent, which I think is progressive.

The Vector API

Disclaimer: Everything below is based on my experience with a recent build of the experimental code in the Project Panama fork of OpenJDK. I am not affiliated with the design or implementation of this API, may not be using it properly, and it may change according to its designers’ will before it is released!

To understand the vector API you need to know that there are different register widths and different SIMD instruction sets. Because of my area of work, and because 99% of the server market is Intel, I am only interested in AVX, but ARM has its own SIMD implementations with different maximum register sizes, which a JVM vector API presumably needs to handle too. On Intel CPUs, the SSE instruction sets use up to 128 bit registers (xmm, four ints), AVX and AVX2 use up to 256 bit registers (ymm, eight ints), and AVX-512 uses up to 512 bit registers (zmm, sixteen ints).

The instruction sets are typed, and instructions designed to operate on packed doubles can’t operate on packed ints without explicit casting. This is modeled by the interface Vector<Shape>, parametrised by the Shape interface which models the register width.

The type of the vector elements is modelled by abstract element type specific classes such as IntVector. At the leaves of the hierarchy are the concrete classes specialised both to element type and register width, such as IntVector256, which extends IntVector<Shapes.S256Bit>.

Since EJB, factory has been a dirty word, which might be why this API uses the word species instead. To create an IntVector<Shapes.S256Bit>, you can create the factory/species as follows:

public static final IntVector.IntSpecies<Shapes.S256Bit> YMM_INT = 
          (IntVector.IntSpecies<Shapes.S256Bit>) Vector.species(int.class, Shapes.S_256_BIT);
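
The YMM_FLOAT species used in the matrix multiplication later in this post is presumably created the same way; this is my reconstruction by analogy, not a line taken from the original source:

public static final FloatVector.FloatSpecies<Shapes.S256Bit> YMM_FLOAT = 
          (FloatVector.FloatSpecies<Shapes.S256Bit>) Vector.species(float.class, Shapes.S_256_BIT);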

There are now various ways to create a vector from the species, which all have their use cases. First, you can load vectors from arrays: imagine you want to calculate the bitwise intersection of two int[]s. This can be written quite cleanly, without any shape/register information.


public static int[] intersect(int[] left, int[] right) {
    assert left.length == right.length;
    int[] result = new int[left.length];
    for (int i = 0; i < left.length; i += YMM_INT.length()) {
      YMM_INT.fromArray(left, i)
             .and(YMM_INT.fromArray(right, i))
             .intoArray(result, i);
    }
    return result;
}
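
The loop above assumes the array length is a multiple of the species length. A minimal sketch of one way to handle the general case, with a plain scalar tail loop for the remainder (my addition, not anything prescribed by the API):

public static int[] intersectWithTail(int[] left, int[] right) {
    assert left.length == right.length;
    int[] result = new int[left.length];
    // round the vectorised bound down to a multiple of the species length
    int upperBound = left.length - (left.length % YMM_INT.length());
    int i = 0;
    for (; i < upperBound; i += YMM_INT.length()) {
      YMM_INT.fromArray(left, i)
             .and(YMM_INT.fromArray(right, i))
             .intoArray(result, i);
    }
    for (; i < left.length; ++i) { // scalar remainder
      result[i] = left[i] & right[i];
    }
    return result;
}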

A common pattern in vectorised code is to broadcast a variable into a vector, for instance, to facilitate the multiplication of a vector by a scalar.

IntVector<Shapes.S256Bit> multiplier = YMM_INT.broadcast(x);
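
For instance, scaling an int[] in place might look something like the sketch below; note that I am assuming an elementwise mul operation on IntVector here, which I have not verified against this particular snapshot of the API:

public static void scale(int[] data, int factor) {
    IntVector<Shapes.S256Bit> multiplier = YMM_INT.broadcast(factor);
    for (int i = 0; i < data.length; i += YMM_INT.length()) {
      YMM_INT.fromArray(data, i)
             .mul(multiplier)      // assumed elementwise multiply
             .intoArray(data, i);
    }
}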

Or to create a vector from some scalars, for instance in a lookup table.

IntVector<Shapes.S256Bit> vector = YMM_INT.scalars(0, 1, 2, 3, 4, 5, 6, 7);

A zero vector can be created from a species:

IntVector<Shapes.S256Bit> zero = YMM_INT.zero();

The big split in the class hierarchy is between integral and floating point types. Integral types have meaningful bitwise operations (I am looking forward to trying to write a vectorised population count algorithm), which are absent from FloatVector and DoubleVector, and there is no concept of fused multiply-add for integral types, so there is obviously no IntVector.fma. The common subset consists of arithmetic, casting, and load/store operations.

I generally like the API a lot: it feels familiar from programming with streams, yet it isn’t too far removed from traditional intrinsics. Below is an implementation of a fast matrix multiplication written in C, and below it is the same code written with the vector API:


static void mmul_tiled_avx_unrolled(const int n, const float *left, const float *right, float *result) {
    const int block_width = n >= 256 ? 512 : 256;
    const int block_height = n >= 512 ? 8 : n >= 256 ? 16 : 32;
    for (int column_offset = 0; column_offset < n; column_offset += block_width) {
        for (int row_offset = 0; row_offset < n; row_offset += block_height) {
            for (int i = 0; i < n; ++i) {
                for (int j = column_offset; j < column_offset + block_width && j < n; j += 64) {
                    __m256 sum1 = _mm256_load_ps(result + i * n + j);
                    __m256 sum2 = _mm256_load_ps(result + i * n + j + 8);
                    __m256 sum3 = _mm256_load_ps(result + i * n + j + 16);
                    __m256 sum4 = _mm256_load_ps(result + i * n + j + 24);
                    __m256 sum5 = _mm256_load_ps(result + i * n + j + 32);
                    __m256 sum6 = _mm256_load_ps(result + i * n + j + 40);
                    __m256 sum7 = _mm256_load_ps(result + i * n + j + 48);
                    __m256 sum8 = _mm256_load_ps(result + i * n + j + 56);
                    for (int k = row_offset; k < row_offset + block_height && k < n; ++k) {
                        __m256 multiplier = _mm256_set1_ps(left[i * n + k]);
                        sum1 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j), sum1);
                        sum2 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 8), sum2);
                        sum3 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 16), sum3);
                        sum4 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 24), sum4);
                        sum5 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 32), sum5);
                        sum6 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 40), sum6);
                        sum7 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 48), sum7);
                        sum8 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 56), sum8);
                    }
                    _mm256_store_ps(result + i * n + j, sum1);
                    _mm256_store_ps(result + i * n + j + 8, sum2);
                    _mm256_store_ps(result + i * n + j + 16, sum3);
                    _mm256_store_ps(result + i * n + j + 24, sum4);
                    _mm256_store_ps(result + i * n + j + 32, sum5);
                    _mm256_store_ps(result + i * n + j + 40, sum6);
                    _mm256_store_ps(result + i * n + j + 48, sum7);
                    _mm256_store_ps(result + i * n + j + 56, sum8);
                }
            }
        }
    }
}


  private static void mmul(int n, float[] left, float[] right, float[] result) {
    int blockWidth = n >= 256 ? 512 : 256;
    int blockHeight = n >= 512 ? 8 : n >= 256 ? 16 : 32;
    for (int columnOffset = 0; columnOffset < n; columnOffset += blockWidth) {
      for (int rowOffset = 0; rowOffset < n; rowOffset += blockHeight) {
        for (int i = 0; i < n; ++i) {
          for (int j = columnOffset; j < columnOffset + blockWidth && j < n; j += 64) {
            var sum1 = YMM_FLOAT.fromArray(result, i * n + j);
            var sum2 = YMM_FLOAT.fromArray(result, i * n + j + 8);
            var sum3 = YMM_FLOAT.fromArray(result, i * n + j + 16);
            var sum4 = YMM_FLOAT.fromArray(result, i * n + j + 24);
            var sum5 = YMM_FLOAT.fromArray(result, i * n + j + 32);
            var sum6 = YMM_FLOAT.fromArray(result, i * n + j + 40);
            var sum7 = YMM_FLOAT.fromArray(result, i * n + j + 48);
            var sum8 = YMM_FLOAT.fromArray(result, i * n + j + 56);
            for (int k = rowOffset; k < rowOffset + blockHeight && k < n; ++k) {
              var multiplier = YMM_FLOAT.broadcast(left[i * n + k]);
              sum1 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j), sum1);
              sum2 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 8), sum2);
              sum3 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 16), sum3);
              sum4 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 24), sum4);
              sum5 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 32), sum5);
              sum6 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 40), sum6);
              sum7 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 48), sum7);
              sum8 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 56), sum8);
            }
            sum1.intoArray(result, i * n + j);
            sum2.intoArray(result, i * n + j + 8);
            sum3.intoArray(result, i * n + j + 16);
            sum4.intoArray(result, i * n + j + 24);
            sum5.intoArray(result, i * n + j + 32);
            sum6.intoArray(result, i * n + j + 40);
            sum7.intoArray(result, i * n + j + 48);
            sum8.intoArray(result, i * n + j + 56);
          }
        }
      }
    }
  }

They just aren’t that different, and it’s easy to translate between the two. I wouldn’t expect it to be fast yet though. I have no idea how much work is involved in implementing all of the C2 intrinsics to make this possible, but I assume it’s vast. The class jdk.incubator.vector.VectorIntrinsics seems to contain all of the intrinsics implemented so far, and it doesn’t contain every operation used in my matrix multiplication code. There is also the question of value types and vector box elimination. I will probably look at this again in the future when more of the JIT compiler work has been done, but I’m starting to get very excited about the possibility of much faster JVM based data processing.

I have written various benchmarks for useful analytical subroutines using the Vector API at github.

Sum of Squares

Streams and lambdas, even with the limited support offered for primitive types, are a fantastic addition to the Java language. They’re not supposed to be fast, but how do these features compare to a good old for loop? For a simple calculation amenable to instruction level parallelism, I compare modern and traditional implementations and observe the differences in the instructions generated.

Sum of Squares

The sum of squares is the building block of a linear regression analysis so is ubiquitous in statistical computing. It is associative and therefore data-parallel. I compare four implementations: a sequential stream wrapping an array, a parallel stream wrapping an array, a generative sequential stream and a traditional for loop. The benchmark code is on github.


  @Param({"1024", "8192"})
  int size;

  private double[] data;

  @Setup(Level.Iteration)
  public void init() {
    this.data = createDoubleArray(size);
  }

  @Benchmark
  public double SS_SequentialStream() {
    return DoubleStream.of(data)
            .map(x -> x * x)
            .reduce((x, y) -> x + y)
            .orElse(0D);
  }

  @Benchmark
  public double SS_ParallelStream() {
    return DoubleStream.of(data)
            .parallel()
            .map(x -> x * x)
            .reduce((x, y) -> x + y)
            .orElse(0);
  }

  @Benchmark
  public double SS_ForLoop() {
    double result = 0D;
    for (int i = 0; i < data.length; ++i) {
      result += data[i] * data[i];
    }
    return result;
  }

  @Benchmark
  public double SS_GenerativeSequentialStream() {
    return IntStream.iterate(0, i -> i < size, i -> i + 1)
            .mapToDouble(i -> data[i])
            .map(x -> x * x)
            .reduce((x, y) -> x + y)
            .orElse(0);
  }

I must admit I prefer the readability of the stream versions, but let’s see if there is a comedown after the syntactic sugar rush.

Running a Benchmark

I compare the four implementations on arrays of 1024 and 8192 doubles (the sizes in the @Param annotation above). I am using JDK 9.0.1, VM 9.0.1+11 on a fairly powerful laptop with 8 logical processors:

$ cat /proc/cpuinfo
processor       : 0
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz
stepping        : 3
cpu MHz         : 2592.000
cache size      : 256 KB
physical id     : 0
siblings        : 8
core id         : 0
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 22
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe pni dtes64 monitor ds_cpl vmx est tm2 ssse3 fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt aes xsave osxsave avx f16c rdrand lahf_lm ida arat epb xsaveopt pln pts dtherm fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx rdseed adx smap clflushopt
clflush size    : 64
cache_alignment : 64
address sizes   : 39 bits physical, 48 bits virtual
power management:

Before running the benchmark we might expect the for loop and the stream to have similar performance, and the parallel version to be about eight times faster (though remember that the arrays aren’t too big). The generative version is very similar to the for loop, so a slowdown might not be expected.

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
SS_ForLoop thrpt 1 10 258351.774491 39797.567968 ops/s 1024
SS_ForLoop thrpt 1 10 29463.408428 4814.826388 ops/s 8192
SS_GenerativeSequentialStream thrpt 1 10 219699.607567 9095.569546 ops/s 1024
SS_GenerativeSequentialStream thrpt 1 10 28351.900454 828.513989 ops/s 8192
SS_ParallelStream thrpt 1 10 22827.821827 2826.577213 ops/s 1024
SS_ParallelStream thrpt 1 10 23230.623610 273.415352 ops/s 8192
SS_SequentialStream thrpt 1 10 225431.985145 9051.538442 ops/s 1024
SS_SequentialStream thrpt 1 10 29123.734157 1333.721437 ops/s 8192

The for loop and the sequential stream are similar. The parallel version is a long way behind (yes, that’s right: more threads, less power), but exhibits constant scaling (incidentally, a measurement like this is a good way to guess the minimum unit of work in a parallelised implementation). If the data were much larger, it could become profitable to use it. The generative stream is surprisingly good, almost as good as the version that wraps the array, though there is a fail-safe way to slow it down: add a limit clause to the method chain (try it
).

Profiling with perfasm, it is clear that the for loop body is being vectorised, but only the loads and multiplications are done in parallel: the complicated string of scalar instructions that follows is the reduction, which must be done in order.

<-- unrolled load -->
  0.01%    0x00000243d8969170: vmovdqu ymm1,ymmword ptr [r11+r8*8+0f0h]
  0.07%    0x00000243d896917a: vmovdqu ymm2,ymmword ptr [r11+r8*8+0d0h]
  0.75%    0x00000243d8969184: vmovdqu ymm3,ymmword ptr [r11+r8*8+0b0h]
  0.01%    0x00000243d896918e: vmovdqu ymm4,ymmword ptr [r11+r8*8+90h]
  0.02%    0x00000243d8969198: vmovdqu ymm5,ymmword ptr [r11+r8*8+70h]
  0.03%    0x00000243d896919f: vmovdqu ymm6,ymmword ptr [r11+r8*8+50h]
  0.77%    0x00000243d89691a6: vmovdqu ymm10,ymmword ptr [r11+r8*8+30h]
  0.02%    0x00000243d89691ad: vmovdqu ymm7,ymmword ptr [r11+r8*8+10h]
<-- multiplication starts -->
  0.01%    0x00000243d89691b4: vmulpd  ymm1,ymm1,ymm1
  0.02%    0x00000243d89691b8: vmovdqu ymmword ptr [rsp+28h],ymm1
  0.76%    0x00000243d89691be: vmulpd  ymm15,ymm7,ymm7
  0.00%    0x00000243d89691c2: vmulpd  ymm12,ymm2,ymm2
  0.01%    0x00000243d89691c6: vmulpd  ymm7,ymm3,ymm3
  0.02%    0x00000243d89691ca: vmulpd  ymm8,ymm4,ymm4
  0.72%    0x00000243d89691ce: vmulpd  ymm9,ymm5,ymm5
  0.00%    0x00000243d89691d2: vmulpd  ymm11,ymm6,ymm6
  0.01%    0x00000243d89691d6: vmulpd  ymm13,ymm10,ymm10
<-- multiplication ends here, scalar reduction starts -->
  0.03%    0x00000243d89691db: vaddsd  xmm0,xmm0,xmm15
  0.72%    0x00000243d89691e0: vpshufd xmm5,xmm15,0eh
  0.01%    0x00000243d89691e6: vaddsd  xmm0,xmm0,xmm5
  2.14%    0x00000243d89691ea: vextractf128 xmm6,ymm15,1h
  0.03%    0x00000243d89691f0: vaddsd  xmm0,xmm0,xmm6
  3.21%    0x00000243d89691f4: vpshufd xmm5,xmm6,0eh
  0.02%    0x00000243d89691f9: vaddsd  xmm0,xmm0,xmm5
  2.81%    0x00000243d89691fd: vaddsd  xmm0,xmm0,xmm13
  2.82%    0x00000243d8969202: vpshufd xmm5,xmm13,0eh
  0.03%    0x00000243d8969208: vaddsd  xmm0,xmm0,xmm5
  2.87%    0x00000243d896920c: vextractf128 xmm6,ymm13,1h
  0.01%    0x00000243d8969212: vaddsd  xmm0,xmm0,xmm6
  3.03%    0x00000243d8969216: vpshufd xmm5,xmm6,0eh
  0.03%    0x00000243d896921b: vaddsd  xmm0,xmm0,xmm5
  2.94%    0x00000243d896921f: vaddsd  xmm0,xmm0,xmm11
  2.70%    0x00000243d8969224: vpshufd xmm5,xmm11,0eh
  0.03%    0x00000243d896922a: vaddsd  xmm0,xmm0,xmm5
  2.98%    0x00000243d896922e: vextractf128 xmm6,ymm11,1h
  0.01%    0x00000243d8969234: vaddsd  xmm0,xmm0,xmm6
  3.11%    0x00000243d8969238: vpshufd xmm5,xmm6,0eh
  0.03%    0x00000243d896923d: vaddsd  xmm0,xmm0,xmm5
  2.95%    0x00000243d8969241: vaddsd  xmm0,xmm0,xmm9
  2.61%    0x00000243d8969246: vpshufd xmm5,xmm9,0eh
  0.02%    0x00000243d896924c: vaddsd  xmm0,xmm0,xmm5
  2.89%    0x00000243d8969250: vextractf128 xmm6,ymm9,1h
  0.04%    0x00000243d8969256: vaddsd  xmm0,xmm0,xmm6
  3.13%    0x00000243d896925a: vpshufd xmm5,xmm6,0eh
  0.01%    0x00000243d896925f: vaddsd  xmm0,xmm0,xmm5
  2.96%    0x00000243d8969263: vaddsd  xmm0,xmm0,xmm8
  2.83%    0x00000243d8969268: vpshufd xmm4,xmm8,0eh
  0.01%    0x00000243d896926e: vaddsd  xmm0,xmm0,xmm4
  3.00%    0x00000243d8969272: vextractf128 xmm10,ymm8,1h
  0.02%    0x00000243d8969278: vaddsd  xmm0,xmm0,xmm10
  3.13%    0x00000243d896927d: vpshufd xmm4,xmm10,0eh
  0.01%    0x00000243d8969283: vaddsd  xmm0,xmm0,xmm4
  3.01%    0x00000243d8969287: vaddsd  xmm0,xmm0,xmm7
  2.95%    0x00000243d896928b: vpshufd xmm1,xmm7,0eh
  0.02%    0x00000243d8969290: vaddsd  xmm0,xmm0,xmm1
  3.06%    0x00000243d8969294: vextractf128 xmm2,ymm7,1h
  0.01%    0x00000243d896929a: vaddsd  xmm0,xmm0,xmm2
  3.07%    0x00000243d896929e: vpshufd xmm1,xmm2,0eh
  0.02%    0x00000243d89692a3: vaddsd  xmm0,xmm0,xmm1
  3.07%    0x00000243d89692a7: vaddsd  xmm0,xmm0,xmm12
  2.92%    0x00000243d89692ac: vpshufd xmm3,xmm12,0eh
  0.02%    0x00000243d89692b2: vaddsd  xmm0,xmm0,xmm3
  3.11%    0x00000243d89692b6: vextractf128 xmm1,ymm12,1h
  0.01%    0x00000243d89692bc: vaddsd  xmm0,xmm0,xmm1
  3.02%    0x00000243d89692c0: vpshufd xmm3,xmm1,0eh
  0.02%    0x00000243d89692c5: vaddsd  xmm0,xmm0,xmm3
  2.97%    0x00000243d89692c9: vmovdqu ymm1,ymmword ptr [rsp+28h]
  0.02%    0x00000243d89692cf: vaddsd  xmm0,xmm0,xmm1
  3.05%    0x00000243d89692d3: vpshufd xmm2,xmm1,0eh
  0.03%    0x00000243d89692d8: vaddsd  xmm0,xmm0,xmm2
  2.97%    0x00000243d89692dc: vextractf128 xmm14,ymm1,1h
  0.01%    0x00000243d89692e2: vaddsd  xmm0,xmm0,xmm14
  2.99%    0x00000243d89692e7: vpshufd xmm2,xmm14,0eh
  0.02%    0x00000243d89692ed: vaddsd  xmm0,xmm0,xmm2 
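
Since the ordered reduction is the bottleneck, there is a manual workaround if you are willing to accept a different rounding order, which is precisely the reordering the JLS forbids C2 from performing on your behalf. A sketch (my addition, not part of the original benchmark) using four independent accumulators to break the serial dependency chain:

  public double SS_ForLoop_MultipleAccumulators() {
    double s0 = 0D, s1 = 0D, s2 = 0D, s3 = 0D;
    int i = 0;
    for (; i + 3 < data.length; i += 4) {
      s0 += data[i] * data[i];
      s1 += data[i + 1] * data[i + 1];
      s2 += data[i + 2] * data[i + 2];
      s3 += data[i + 3] * data[i + 3];
    }
    for (; i < data.length; ++i) { // remainder
      s0 += data[i] * data[i];
    }
    return (s0 + s1) + (s2 + s3);
  }

The result can differ from the strictly ordered sum by floating point error, which is exactly why C2 cannot make this transformation automatically.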

The sequential stream code is not as good – it is scalar – but the difference in performance is not as stark as it might be because of the inefficient scalar reduction in the for loop: this is JLS floating point semantics twisting C2’s arm behind its back.

  0.00%    0x0000021a1df54c24: vmovsd  xmm0,qword ptr [rbx+r9*8+48h]
  0.00%    0x0000021a1df54c2b: vmovsd  xmm2,qword ptr [rbx+r9*8+18h]
  0.02%    0x0000021a1df54c32: vmovsd  xmm3,qword ptr [rbx+r9*8+40h]
  2.93%    0x0000021a1df54c39: vmovsd  xmm4,qword ptr [rbx+r9*8+38h]
  0.00%    0x0000021a1df54c40: vmovsd  xmm5,qword ptr [rbx+r9*8+30h]
  0.01%    0x0000021a1df54c47: vmovsd  xmm6,qword ptr [rbx+r9*8+28h]
  0.02%    0x0000021a1df54c4e: vmovsd  xmm7,qword ptr [rbx+r9*8+20h]
  2.99%    0x0000021a1df54c55: vmulsd  xmm8,xmm0,xmm0
  0.00%    0x0000021a1df54c59: vmulsd  xmm0,xmm7,xmm7
           0x0000021a1df54c5d: vmulsd  xmm6,xmm6,xmm6
  0.01%    0x0000021a1df54c61: vmulsd  xmm5,xmm5,xmm5
  2.91%    0x0000021a1df54c65: vmulsd  xmm4,xmm4,xmm4
  0.00%    0x0000021a1df54c69: vmulsd  xmm3,xmm3,xmm3
  0.00%    0x0000021a1df54c6d: vmulsd  xmm2,xmm2,xmm2
  0.02%    0x0000021a1df54c71: vaddsd  xmm1,xmm2,xmm1
  6.10%    0x0000021a1df54c75: vaddsd  xmm0,xmm0,xmm1
  5.97%    0x0000021a1df54c79: vaddsd  xmm0,xmm6,xmm0
 16.22%    0x0000021a1df54c7d: vaddsd  xmm0,xmm5,xmm0
  7.86%    0x0000021a1df54c81: vaddsd  xmm0,xmm4,xmm0
 11.16%    0x0000021a1df54c85: vaddsd  xmm1,xmm3,xmm0
 11.90%    0x0000021a1df54c89: vaddsd  xmm0,xmm8,xmm1

The same code can be seen in SS_ParallelStream. SS_GenerativeSequentialStream is much more interesting because it hasn’t been unrolled – see the interleaved control statements. It is also not vectorised.

           0x0000013c1a639c17: vmovsd  xmm0,qword ptr [rbp+r9*8+10h]
  0.01%    0x0000013c1a639c1e: vmulsd  xmm2,xmm0,xmm0    
  0.01%    0x0000013c1a639c22: test    r8d,r8d
           0x0000013c1a639c25: jne     13c1a639e09h   
           0x0000013c1a639c2b: mov     r10d,dword ptr [r12+rax*8+8h]
           0x0000013c1a639c30: cmp     r10d,0f8022d85h 
           0x0000013c1a639c37: jne     13c1a639e3bh     
  0.01%    0x0000013c1a639c3d: vaddsd  xmm2,xmm1,xmm2
  0.01%    0x0000013c1a639c41: vmovsd  qword ptr [rdi+10h],xmm2
  0.00%    0x0000013c1a639c46: movsxd  r10,r9d
           0x0000013c1a639c49: vmovsd  xmm0,qword ptr [rbp+r10*8+18h]
  0.01%    0x0000013c1a639c50: vmulsd  xmm0,xmm0,xmm0
  0.01%    0x0000013c1a639c54: mov     r10d,dword ptr [r12+rax*8+8h]
  0.00%    0x0000013c1a639c59: cmp     r10d,0f8022d85h
           0x0000013c1a639c60: jne     13c1a639e30h
           0x0000013c1a639c66: vaddsd  xmm0,xmm0,xmm2
  0.02%    0x0000013c1a639c6a: vmovsd  qword ptr [rdi+10h],xmm0
  0.02%    0x0000013c1a639c6f: mov     r10d,r9d
           0x0000013c1a639c72: add     r10d,2h 
           0x0000013c1a639c76: cmp     r10d,r11d
           0x0000013c1a639c79: jnl     13c1a639d96h 
  0.01%    0x0000013c1a639c7f: add     r9d,4h 
  0.02%    0x0000013c1a639c83: vmovsd  xmm1,qword ptr [rbp+r10*8+10h]
  0.00%    0x0000013c1a639c8a: movzx   r8d,byte ptr [rdi+0ch]
  0.00%    0x0000013c1a639c8f: vmulsd  xmm1,xmm1,xmm1
  0.01%    0x0000013c1a639c93: test    r8d,r8d
           0x0000013c1a639c96: jne     13c1a639dfbh
  0.01%    0x0000013c1a639c9c: vaddsd  xmm1,xmm0,xmm1
  0.01%    0x0000013c1a639ca0: vmovsd  qword ptr [rdi+10h],xmm1
  0.02%    0x0000013c1a639ca5: movsxd  r8,r10d
  0.00%    0x0000013c1a639ca8: vmovsd  xmm0,qword ptr [rbp+r8*8+18h]
           0x0000013c1a639caf: vmulsd  xmm0,xmm0,xmm0
           0x0000013c1a639cb3: vaddsd  xmm0,xmm0,xmm1
  0.06%    0x0000013c1a639cb7: vmovsd  qword ptr [rdi+10h],xmm0

So it looks like streams don’t vectorise like good old for loops, and you won’t gain from parallel streams unless you have humongous arrays (which you might be avoiding for other reasons). This was actually a very nice case for the Stream, because optimal code can’t be generated for floating point reductions anyway. What happens with the sum of squares over ints?
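
The int data is presumably generated just like the double data; a hypothetical sketch of the setup (createIntArray is my assumed analogue of createDoubleArray, it is not shown in the source):

  private int[] intData;

  @Setup(Level.Iteration)
  public void initInts() {
    this.intData = createIntArray(size);
  }

The benchmarks themselves are direct translations of the double versions: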


  @Benchmark
  public int SS_SequentialStream_Int() {
    return IntStream.of(intData)
            .map(x -> x * x)
            .reduce((x, y) -> x + y)
            .orElse(0);
  }

  @Benchmark
  public int SS_ParallelStream_Int() {
    return IntStream.of(intData)
            .parallel()
            .map(x -> x * x)
            .reduce((x, y) -> x + y)
            .orElse(0);
  }

  @Benchmark
  public int SS_ForLoop_Int() {
    int result = 0;
    for (int i = 0; i < intData.length; ++i) {
      result += intData[i] * intData[i];
    }
    return result;
  }

  @Benchmark
  public int SS_GenerativeSequentialStream_Int() {
    return IntStream.iterate(0, i -> i < size, i -> i + 1)
            .map(i -> intData[i])
            .map(x -> x * x)
            .reduce((x, y) -> x + y)
            .orElse(0);
  }

The landscape has completely changed, thanks to the exploitation of associative arithmetic and the VPHADDD instruction which simplifies the reduction in the for loop.

<-- load -->
  0.00%    0x000001f5cdd8cd30: vmovdqu ymm0,ymmword ptr [rdi+r10*4+0f0h]
  1.93%    0x000001f5cdd8cd3a: vmovdqu ymm1,ymmword ptr [rdi+r10*4+0d0h]
  0.10%    0x000001f5cdd8cd44: vmovdqu ymm2,ymmword ptr [rdi+r10*4+0b0h]
  0.07%    0x000001f5cdd8cd4e: vmovdqu ymm3,ymmword ptr [rdi+r10*4+90h]
  0.05%    0x000001f5cdd8cd58: vmovdqu ymm4,ymmword ptr [rdi+r10*4+70h]
  1.75%    0x000001f5cdd8cd5f: vmovdqu ymm5,ymmword ptr [rdi+r10*4+50h]
  0.08%    0x000001f5cdd8cd66: vmovdqu ymm6,ymmword ptr [rdi+r10*4+30h]
  0.07%    0x000001f5cdd8cd6d: vmovdqu ymm7,ymmword ptr [rdi+r10*4+10h]
<-- multiply -->
  0.01%    0x000001f5cdd8cd74: vpmulld ymm0,ymm0,ymm0
  1.81%    0x000001f5cdd8cd79: vmovdqu ymmword ptr [rsp+28h],ymm0
  0.02%    0x000001f5cdd8cd7f: vpmulld ymm15,ymm7,ymm7
  1.79%    0x000001f5cdd8cd84: vpmulld ymm11,ymm1,ymm1
  0.06%    0x000001f5cdd8cd89: vpmulld ymm8,ymm2,ymm2
  1.82%    0x000001f5cdd8cd8e: vpmulld ymm9,ymm3,ymm3
  0.06%    0x000001f5cdd8cd93: vpmulld ymm10,ymm4,ymm4
  1.79%    0x000001f5cdd8cd98: vpmulld ymm12,ymm5,ymm5
  0.08%    0x000001f5cdd8cd9d: vpmulld ymm6,ymm6,ymm6
<-- vectorised reduce -->
  1.83%    0x000001f5cdd8cda2: vphaddd ymm4,ymm15,ymm15
  0.04%    0x000001f5cdd8cda7: vphaddd ymm4,ymm4,ymm7
  1.85%    0x000001f5cdd8cdac: vextracti128 xmm7,ymm4,1h
  0.07%    0x000001f5cdd8cdb2: vpaddd  xmm4,xmm4,xmm7
  1.78%    0x000001f5cdd8cdb6: vmovd   xmm7,r8d
  0.01%    0x000001f5cdd8cdbb: vpaddd  xmm7,xmm7,xmm4
  0.11%    0x000001f5cdd8cdbf: vmovd   r11d,xmm7
  0.05%    0x000001f5cdd8cdc4: vphaddd ymm4,ymm6,ymm6
  1.84%    0x000001f5cdd8cdc9: vphaddd ymm4,ymm4,ymm7
  5.43%    0x000001f5cdd8cdce: vextracti128 xmm7,ymm4,1h
  0.13%    0x000001f5cdd8cdd4: vpaddd  xmm4,xmm4,xmm7
  4.34%    0x000001f5cdd8cdd8: vmovd   xmm7,r11d
  0.36%    0x000001f5cdd8cddd: vpaddd  xmm7,xmm7,xmm4
  1.40%    0x000001f5cdd8cde1: vmovd   r8d,xmm7
  0.01%    0x000001f5cdd8cde6: vphaddd ymm6,ymm12,ymm12
  2.89%    0x000001f5cdd8cdeb: vphaddd ymm6,ymm6,ymm4
  3.25%    0x000001f5cdd8cdf0: vextracti128 xmm4,ymm6,1h
  0.87%    0x000001f5cdd8cdf6: vpaddd  xmm6,xmm6,xmm4
  6.36%    0x000001f5cdd8cdfa: vmovd   xmm4,r8d
  0.01%    0x000001f5cdd8cdff: vpaddd  xmm4,xmm4,xmm6
  1.69%    0x000001f5cdd8ce03: vmovd   r8d,xmm4
  0.03%    0x000001f5cdd8ce08: vphaddd ymm4,ymm10,ymm10
  1.83%    0x000001f5cdd8ce0d: vphaddd ymm4,ymm4,ymm7
  0.10%    0x000001f5cdd8ce12: vextracti128 xmm7,ymm4,1h
  3.29%    0x000001f5cdd8ce18: vpaddd  xmm4,xmm4,xmm7
  0.72%    0x000001f5cdd8ce1c: vmovd   xmm7,r8d
  0.23%    0x000001f5cdd8ce21: vpaddd  xmm7,xmm7,xmm4
  4.42%    0x000001f5cdd8ce25: vmovd   r11d,xmm7
  0.12%    0x000001f5cdd8ce2a: vphaddd ymm5,ymm9,ymm9
  1.69%    0x000001f5cdd8ce2f: vphaddd ymm5,ymm5,ymm1
  0.12%    0x000001f5cdd8ce34: vextracti128 xmm1,ymm5,1h
  3.28%    0x000001f5cdd8ce3a: vpaddd  xmm5,xmm5,xmm1
  0.22%    0x000001f5cdd8ce3e: vmovd   xmm1,r11d
  0.14%    0x000001f5cdd8ce43: vpaddd  xmm1,xmm1,xmm5
  3.81%    0x000001f5cdd8ce47: vmovd   r11d,xmm1
  0.22%    0x000001f5cdd8ce4c: vphaddd ymm0,ymm8,ymm8
  1.58%    0x000001f5cdd8ce51: vphaddd ymm0,ymm0,ymm3
  0.22%    0x000001f5cdd8ce56: vextracti128 xmm3,ymm0,1h
  2.82%    0x000001f5cdd8ce5c: vpaddd  xmm0,xmm0,xmm3
  0.36%    0x000001f5cdd8ce60: vmovd   xmm3,r11d
  0.20%    0x000001f5cdd8ce65: vpaddd  xmm3,xmm3,xmm0
  4.55%    0x000001f5cdd8ce69: vmovd   r8d,xmm3
  0.10%    0x000001f5cdd8ce6e: vphaddd ymm2,ymm11,ymm11
  1.71%    0x000001f5cdd8ce73: vphaddd ymm2,ymm2,ymm1
  0.09%    0x000001f5cdd8ce78: vextracti128 xmm1,ymm2,1h
  2.91%    0x000001f5cdd8ce7e: vpaddd  xmm2,xmm2,xmm1
  1.57%    0x000001f5cdd8ce82: vmovd   xmm1,r8d
  0.05%    0x000001f5cdd8ce87: vpaddd  xmm1,xmm1,xmm2
  4.84%    0x000001f5cdd8ce8b: vmovd   r11d,xmm1
  0.06%    0x000001f5cdd8ce90: vmovdqu ymm0,ymmword ptr [rsp+28h]
  0.03%    0x000001f5cdd8ce96: vphaddd ymm13,ymm0,ymm0
  1.83%    0x000001f5cdd8ce9b: vphaddd ymm13,ymm13,ymm14
  2.16%    0x000001f5cdd8cea0: vextracti128 xmm14,ymm13,1h
  0.14%    0x000001f5cdd8cea6: vpaddd  xmm13,xmm13,xmm14
  0.09%    0x000001f5cdd8ceab: vmovd   xmm14,r11d
  0.51%    0x000001f5cdd8ceb0: vpaddd  xmm14,xmm14,xmm13

If you’re the guy replacing all the for loops with streams because it’s 2018, you may be committing performance vandalism! That nice declarative API (as opposed to a language feature) is at arm’s length and it really isn’t well optimised yet.

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
SS_ForLoop_Int thrpt 1 10 1021725.018981 74264.883362 ops/s 1024
SS_ForLoop_Int thrpt 1 10 129250.855026 5764.608094 ops/s 8192
SS_GenerativeSequentialStream_Int thrpt 1 10 55069.227826 1111.903102 ops/s 1024
SS_GenerativeSequentialStream_Int thrpt 1 10 6769.176830 684.970867 ops/s 8192
SS_ParallelStream_Int thrpt 1 10 20970.387258 719.846643 ops/s 1024
SS_ParallelStream_Int thrpt 1 10 19621.397202 1514.374286 ops/s 8192
SS_SequentialStream_Int thrpt 1 10 586847.001223 22390.512706 ops/s 1024
SS_SequentialStream_Int thrpt 1 10 87620.959677 3437.083075 ops/s 8192

Parallel streams might not be the best thing to reach for.

Multiplying Matrices, Fast and Slow

I recently read a very interesting blog post about exposing Intel SIMD intrinsics via a fork of the Scala compiler (scala-virtualized), which reports multiplicative improvements in throughput over HotSpot JIT compiled code. The academic paper (SIMD Intrinsics on Managed Language Runtimes), which has been accepted at CGO 2018, proposes a powerful alternative to the traditional JVM approach of pairing dumb programmers with a (hopefully) smart JIT compiler. Lightweight Modular Staging (LMS) allows the generation of an executable binary from a high level representation: handcrafted representations of vectorised algorithms, written in a dialect of Scala, can be compiled natively and later invoked with a single JNI call. This approach bypasses C2 without incurring excessive JNI costs. The freely available benchmarks can be easily run to reproduce the results in the paper, which is an achievement in itself, but some of the Java implementations used as baselines look less efficient than they could be. This post is about improving the efficiency of the Java matrix multiplication the LMS generated code is benchmarked against. Despite finding edge cases where autovectorisation fails, I find it is possible to get performance comparable to LMS with plain Java (and a JDK upgrade).

Two implementations of Java matrix multiplication are provided in the NGen benchmarks: JMMM.baseline, a naive and cache-unfriendly matrix multiplication, and JMMM.blocked, which is supplied as an improvement. JMMM.blocked is something of a local maximum because it does manual loop unrolling: this actually removes the trigger for autovectorisation analysis. I provide a simple and cache-efficient Java implementation (with the same asymptotic complexity; the improvement is just technical) and benchmark these implementations using JDK8 and the soon-to-be-released JDK10 separately.

public void fast(float[] a, float[] b, float[] c, int n) {
    int in = 0;
    for (int i = 0; i < n; ++i) {
        int kn = 0;
        for (int k = 0; k < n; ++k) {
            float aik = a[in + k];
            for (int j = 0; j < n; ++j) {
                c[in + j] += aik * b[kn + j];
            }
            kn += n;
        }
        in += n;
    }
}

With JDK 1.8.0_131, the “fast” implementation is only 2x faster than the blocked algorithm; this is nowhere near fast enough to match LMS. In fact, LMS does a lot better than 5x blocked (6x-8x) on my Skylake laptop at 2.6GHz, and performs between 2x and 4x better than the improved implementation. Flops/cycle is calculated as 2 * size^3 * throughput (ops/s) / CPU frequency (Hz).
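
As a concrete check of that arithmetic, here is a small sketch (my addition) of the flops per cycle calculation, assuming a fixed 2.6GHz clock and ignoring turbo and AVX frequency scaling:

static double flopsPerCycle(int n, double opsPerSecond) {
    double flopsPerInvocation = 2.0 * n * n * n; // one multiply and one add per inner iteration
    double cyclesPerSecond = 2.6e9;              // nominal clock of the Skylake laptop above
    return flopsPerInvocation * opsPerSecond / cyclesPerSecond;
}
// e.g. flopsPerCycle(1024, 1.66255) is roughly 1.37, matching the JDK10 table further down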

====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.4994459272
          32 | 1.0666533335
          64 | 0.9429120397
         128 | 0.9692385519
         192 | 0.9796619688
         256 | 1.0141446247
         320 | 0.9894415771
         384 | 1.0046245750
         448 | 1.0221353392
         512 | 0.9943527764
         576 | 0.9952093603
         640 | 0.9854689714
         704 | 0.9947153752
         768 | 1.0197765248
         832 | 1.0479691069
         896 | 1.0060121097
         960 | 0.9937347412
        1024 | 0.9056494897
====================================================

====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.2500390686
          32 | 3.9999921875
          64 | 4.1626523901
         128 | 4.4618695374
         192 | 3.9598982956
         256 | 4.3737341517
         320 | 4.2412225389
         384 | 3.9640163416
         448 | 4.0957167537
         512 | 3.3801071278
         576 | 4.1869326167
         640 | 3.8225244883
         704 | 3.8648224140
         768 | 3.5240611589
         832 | 3.7941562681
         896 | 3.1735179981
         960 | 2.5856903789
        1024 | 1.7817152313
====================================================

====================================================
Benchmarking MMM.jMMM.blocked (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.3333854248
          32 | 0.6336670915
          64 | 0.5733484649
         128 | 0.5987433798
         192 | 0.5819900921
         256 | 0.5473562109
         320 | 0.5623263520
         384 | 0.5583823292
         448 | 0.5657882256
         512 | 0.5430879470
         576 | 0.5269635678
         640 | 0.5595204791
         704 | 0.5297557807
         768 | 0.5493631388
         832 | 0.5471832673
         896 | 0.4769554752
         960 | 0.4985080443
        1024 | 0.4014589400
====================================================

JDK10 is about to be released so it’s worth looking at the effect of recent improvements to C2, including better use of AVX2 and support for vectorised FMA. Since LMS depends on scala-virtualized, which currently only supports Scala 2.11, the LMS implementation cannot be run with a more recent JDK so its performance running in JDK10 could only be extrapolated. Since its raison d’ĂȘtre is to bypass C2, it could be reasonably assumed it is insulated from JVM performance improvements (or regressions). Measurements of floating point operations per cycle provide a sensible comparison, in any case.

Moving away from ScalaMeter, I created a JMH benchmark to see how matrix multiplication behaves in JDK10.

@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
public class MMM {

  @Param({"8", "32", "64", "128", "192", "256", "320", "384", "448", "512" , "576", "640", "704", "768", "832", "896", "960", "1024"})
  int size;

  private float[] a;
  private float[] b;
  private float[] c;

  @Setup(Level.Trial)
  public void init() {
    a = DataUtil.createFloatArray(size * size);
    b = DataUtil.createFloatArray(size * size);
    c = new float[size * size];
  }

  @Benchmark
  public void fast(Blackhole bh) {
    fast(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void baseline(Blackhole bh) {
    baseline(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void blocked(Blackhole bh) {
    blocked(a, b, c, size);
    bh.consume(c);
  }

  //
  // Baseline implementation of a Matrix-Matrix-Multiplication
  //
  public void baseline (float[] a, float[] b, float[] c, int n){
    for (int i = 0; i < n; i += 1) {
      for (int j = 0; j < n; j += 1) {
        float sum = 0.0f;
        for (int k = 0; k < n; k += 1) {
          sum += a[i * n + k] * b[k * n + j];
        }
        c[i * n + j] = sum;
      }
    }
  }

  //
  // Blocked version of MMM, reference implementation available at:
  // http://csapp.cs.cmu.edu/2e/waside/waside-blocking.pdf
  //
  public void blocked(float[] a, float[] b, float[] c, int n) {
    int BLOCK_SIZE = 8;
    for (int kk = 0; kk < n; kk += BLOCK_SIZE) {
      for (int jj = 0; jj < n; jj += BLOCK_SIZE) {
        for (int i = 0; i < n; i++) {
          for (int j = jj; j < jj + BLOCK_SIZE; ++j) {
            float sum = c[i * n + j];
            for (int k = kk; k < kk + BLOCK_SIZE; ++k) {
              sum += a[i * n + k] * b[k * n + j];
            }
            c[i * n + j] = sum;
          }
        }
      }
    }
  }

  public void fast(float[] a, float[] b, float[] c, int n) {
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        for (int j = 0; j < n; ++j) {
          c[in + j] = Math.fma(aik,  b[kn + j], c[in + j]);
        }
        kn += n;
      }
      in += n;
    }
  }
}

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size Ratio to blocked Flops/Cycle
baseline thrpt 1 10 1228544.82 38793.17392 ops/s 8 1.061598336 0.483857652
baseline thrpt 1 10 22973.03402 1012.043446 ops/s 32 1.302266947 0.57906183
baseline thrpt 1 10 2943.088879 221.57475 ops/s 64 1.301414733 0.593471609
baseline thrpt 1 10 358.010135 9.342801 ops/s 128 1.292889618 0.577539747
baseline thrpt 1 10 105.758366 4.275503 ops/s 192 1.246415143 0.575804515
baseline thrpt 1 10 41.465557 1.112753 ops/s 256 1.430003946 0.535135851
baseline thrpt 1 10 20.479081 0.462547 ops/s 320 1.154267894 0.516198866
baseline thrpt 1 10 11.686685 0.263476 ops/s 384 1.186535349 0.509027985
baseline thrpt 1 10 7.344184 0.269656 ops/s 448 1.166421127 0.507965526
baseline thrpt 1 10 3.545153 0.108086 ops/s 512 0.81796657 0.366017216
baseline thrpt 1 10 3.789384 0.130934 ops/s 576 1.327168294 0.557048123
baseline thrpt 1 10 1.981957 0.040136 ops/s 640 1.020965271 0.399660104
baseline thrpt 1 10 1.76672 0.036386 ops/s 704 1.168272442 0.474179037
baseline thrpt 1 10 1.01026 0.049853 ops/s 768 0.845514112 0.352024966
baseline thrpt 1 10 1.115814 0.03803 ops/s 832 1.148752171 0.494331667
baseline thrpt 1 10 0.703561 0.110626 ops/s 896 0.938435436 0.389298235
baseline thrpt 1 10 0.629896 0.052448 ops/s 960 1.081741651 0.428685898
baseline thrpt 1 10 0.407772 0.019079 ops/s 1024 1.025356561 0.336801424
blocked thrpt 1 10 1157259.558 49097.48711 ops/s 8 1 0.455782226
blocked thrpt 1 10 17640.8025 1226.401298 ops/s 32 1 0.444656782
blocked thrpt 1 10 2261.453481 98.937035 ops/s 64 1 0.456020355
blocked thrpt 1 10 276.906961 22.851857 ops/s 128 1 0.446704605
blocked thrpt 1 10 84.850033 4.441454 ops/s 192 1 0.461968485
blocked thrpt 1 10 28.996813 7.585551 ops/s 256 1 0.374219842
blocked thrpt 1 10 17.742052 0.627629 ops/s 320 1 0.447208892
blocked thrpt 1 10 9.84942 0.367603 ops/s 384 1 0.429003641
blocked thrpt 1 10 6.29634 0.402846 ops/s 448 1 0.435490676
blocked thrpt 1 10 4.334105 0.384849 ops/s 512 1 0.447472097
blocked thrpt 1 10 2.85524 0.199102 ops/s 576 1 0.419726816
blocked thrpt 1 10 1.941258 0.10915 ops/s 640 1 0.391453182
blocked thrpt 1 10 1.51225 0.076621 ops/s 704 1 0.40588053
blocked thrpt 1 10 1.194847 0.063147 ops/s 768 1 0.416344283
blocked thrpt 1 10 0.971327 0.040421 ops/s 832 1 0.430320551
blocked thrpt 1 10 0.749717 0.042997 ops/s 896 1 0.414837526
blocked thrpt 1 10 0.582298 0.016725 ops/s 960 1 0.39629231
blocked thrpt 1 10 0.397688 0.043639 ops/s 1024 1 0.328472491
fast thrpt 1 10 1869676.345 76416.50848 ops/s 8 1.615606743 0.736364837
fast thrpt 1 10 48485.47216 1301.926828 ops/s 32 2.748484496 1.222132271
fast thrpt 1 10 6431.341657 153.905413 ops/s 64 2.843897392 1.296875098
fast thrpt 1 10 840.601821 45.998723 ops/s 128 3.035683242 1.356053685
fast thrpt 1 10 260.386996 13.022418 ops/s 192 3.068790745 1.417684611
fast thrpt 1 10 107.895708 6.584674 ops/s 256 3.720950575 1.392453537
fast thrpt 1 10 56.245336 2.729061 ops/s 320 3.170170846 1.417728592
fast thrpt 1 10 32.917996 2.196624 ops/s 384 3.342125323 1.433783932
fast thrpt 1 10 20.960189 2.077684 ops/s 448 3.328948087 1.449725854
fast thrpt 1 10 14.005186 0.7839 ops/s 512 3.231390564 1.445957112
fast thrpt 1 10 8.827584 0.883654 ops/s 576 3.091713481 1.297675056
fast thrpt 1 10 7.455607 0.442882 ops/s 640 3.840605937 1.503417416
fast thrpt 1 10 5.322894 0.464362 ops/s 704 3.519850554 1.428638807
fast thrpt 1 10 4.308522 0.153846 ops/s 768 3.605919419 1.501303934
fast thrpt 1 10 3.375274 0.106715 ops/s 832 3.474910097 1.495325228
fast thrpt 1 10 2.320152 0.367881 ops/s 896 3.094703735 1.28379924
fast thrpt 1 10 2.057478 0.150198 ops/s 960 3.533376381 1.400249889
fast thrpt 1 10 1.66255 0.181116 ops/s 1024 4.180538513 1.3731919

Interestingly, the blocked algorithm is now the worst of the JVM implementations. The code generated by C2 got a lot faster, but peaks at 1.5 flops/cycle, which still doesn’t compete with LMS. Why? Taking a look at the assembly, it’s clear that the autovectoriser choked on the array offsets and produced scalar code, just like the implementations in the paper. I wasn’t expecting this.

vmovss  xmm5,dword ptr [rdi+rcx*4+10h]
vfmadd231ss xmm5,xmm6,xmm2
vmovss  dword ptr [rdi+rcx*4+10h],xmm5

Is this the end of the story? No, with some hacks and the cost of array allocation and a copy or two, autovectorisation can be tricked into working again to generate faster code:


    public void fast(float[] a, float[] b, float[] c, int n) {
        float[] bBuffer = new float[n];
        float[] cBuffer = new float[n];
        int in = 0;
        for (int i = 0; i < n; ++i) {
            int kn = 0;
            for (int k = 0; k < n; ++k) {
                float aik = a[in + k];
                System.arraycopy(b, kn, bBuffer, 0, n);
                saxpy(n, aik, bBuffer, cBuffer);
                kn += n;
            }
            System.arraycopy(cBuffer, 0, c, in, n); 
            Arrays.fill(cBuffer, 0f);
            in += n;
        }
    }

    private void saxpy(int n, float aik, float[] b, float[] c) {
        for (int i = 0; i < n; ++i) {
            c[i] += aik * b[i];
        }
    }

Adding this hack into the NGen benchmark (back in JDK 1.8.0_131) I get closer to the LMS generated code, and beat it beyond L3 cache residency (6MB). LMS is still faster when both matrices fit in L3 concurrently, but by percentage points rather than a multiple. The cost of the hacky array buffers gives the game up for small matrices.

====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.2500390686
          32 | 0.7710872405
          64 | 1.1302489072
         128 | 2.5113453810
         192 | 2.9525859816
         256 | 3.1180920385
         320 | 3.1081563593
         384 | 3.1458423577
         448 | 3.0493148252
         512 | 3.0551158263
         576 | 3.1430376938
         640 | 3.2169923048
         704 | 3.1026513283
         768 | 2.4190053777
         832 | 3.3358586705
         896 | 3.0755689237
         960 | 2.9996690697
        1024 | 2.2935654309
====================================================

====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 1.0001562744
          32 | 5.3330416826
          64 | 5.8180867784
         128 | 5.1717318641
         192 | 5.1639907462
         256 | 4.3418618628
         320 | 5.2536572701
         384 | 4.0801359215
         448 | 4.1337007093
         512 | 3.2678160754
         576 | 3.7973028890
         640 | 3.3557513664
         704 | 4.0103133240
         768 | 3.4188362575
         832 | 3.2189488327
         896 | 3.2316685219
         960 | 2.9985655539
        1024 | 1.7750946796
====================================================

With the benchmark below I calculate flops/cycle with improved JDK10 autovectorisation.


  @Benchmark
  public void fastBuffered(Blackhole bh) {
    fastBuffered(a, b, c, size);
    bh.consume(c);
  }

  public void fastBuffered(float[] a, float[] b, float[] c, int n) {
    float[] bBuffer = new float[n];
    float[] cBuffer = new float[n];
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        System.arraycopy(b, kn, bBuffer, 0, n);
        saxpy(n, aik, bBuffer, cBuffer);
        kn += n;
      }
      System.arraycopy(cBuffer, 0, c, in, n);
      Arrays.fill(cBuffer, 0f);
      in += n;
    }
  }

  private void saxpy(int n, float aik, float[] b, float[] c) {
    for (int i = 0; i < n; ++i) {
      c[i] = Math.fma(aik, b[i], c[i]);
    }
  }

Just as in the modified NGen benchmark, this starts paying off once the matrices have 64 rows and columns. Finally, and it took an upgrade and a hack, I breached 4 Flops per cycle:

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size Flops / Cycle
fastBuffered thrpt 1 10 1047184.034 63532.95095 ops/s 8 0.412429404
fastBuffered thrpt 1 10 58373.56367 3239.615866 ops/s 32 1.471373026
fastBuffered thrpt 1 10 12099.41654 497.33988 ops/s 64 2.439838038
fastBuffered thrpt 1 10 2136.50264 105.038006 ops/s 128 3.446592911
fastBuffered thrpt 1 10 673.470622 102.577237 ops/s 192 3.666730488
fastBuffered thrpt 1 10 305.541519 25.959163 ops/s 256 3.943181586
fastBuffered thrpt 1 10 158.437372 6.708384 ops/s 320 3.993596774
fastBuffered thrpt 1 10 88.283718 7.58883 ops/s 384 3.845306266
fastBuffered thrpt 1 10 58.574507 4.248521 ops/s 448 4.051345968
fastBuffered thrpt 1 10 37.183635 4.360319 ops/s 512 3.839002314
fastBuffered thrpt 1 10 29.949884 0.63346 ops/s 576 4.40270151
fastBuffered thrpt 1 10 20.715833 4.175897 ops/s 640 4.177331789
fastBuffered thrpt 1 10 10.824837 0.902983 ops/s 704 2.905333492
fastBuffered thrpt 1 10 8.285254 1.438701 ops/s 768 2.886995686
fastBuffered thrpt 1 10 6.17029 0.746537 ops/s 832 2.733582608
fastBuffered thrpt 1 10 4.828872 1.316901 ops/s 896 2.671937962
fastBuffered thrpt 1 10 3.6343 1.293923 ops/s 960 2.473381573
fastBuffered thrpt 1 10 2.458296 0.171224 ops/s 1024 2.030442485

The code generated for the core of the loop looks better now:

vmovdqu ymm1,ymmword ptr [r13+r11*4+10h]
vfmadd231ps ymm1,ymm3,ymmword ptr [r14+r11*4+10h]
vmovdqu ymmword ptr [r13+r11*4+10h],ymm1                                               

These benchmark results can be compared on a line chart.

Given this improvement, it would be exciting to see how LMS can profit from JDK9 or JDK10 – does LMS provide the impetus to resume maintenance of scala-virtualized? L3 cache, which the LMS generated code seems to depend on for throughput, is typically shared between cores: a single thread rarely enjoys exclusive access. I would like to see benchmarks for the LMS generated code in the presence of concurrency.

Incidental Similarity

I recently saw an interesting class, BitVector, in Apache Arrow, which represents a column of bits, providing minimal or zero copy distribution. The implementation is similar to a bitset but backed by a byte[] rather than a long[]. Given the coincidental similarity in implementation, it’s tempting to look at this, extend its interface and try to use it as a general purpose, distributed bitset. Could this work? Why not just implement some extra methods? Fork it on Github!

This post details the caveats of trying to adapt an abstraction beyond its intended purpose, examined through the lens of performance: what happens if generic bitset capabilities are added to BitVector without due consideration? It runs into the observable effect of word widening on throughput, given the constraints imposed by JLS 15.22. In the end, the only remedy is to use a long[], sacrificing the original zero copy design goal. I hope this is a fairly self-contained example of how uncontrolled adaptation can be hostile to the original design goals: having the source code isn’t enough reason to modify it.

Checking bits

How fast is it to check if the bit at index i is set or not? BitVector implements this functionality, and was designed for it. This can be measured by JMH by generating a random long[] and creating a byte[] 8x longer with identical bits. The throughput of checking the value of the bit at random indices can be measured. It turns out that if all you want to do is access bits, byte[] isn’t such a bad choice, and if those bytes are coming directly from the network, it could even be a great choice. I ran the benchmark below and saw that the two operations are similar (within measurement error).


@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Thread)
public class BitSet {

    @Param({"1024", "2048", "4096", "8192"})
    int size;

    private long[] leftLongs;
    private long[] rightLongs;
    private long[] differenceLongs;
    private byte[] leftBytes;
    private byte[] rightBytes;
    private byte[] differenceBytes;

    @Setup(Level.Trial)
    public void init() {
        this.leftLongs = createLongArray(size);
        this.rightLongs = createLongArray(size);
        this.differenceLongs = new long[size];
        this.leftBytes = makeBytesFromLongs(leftLongs);
        this.rightBytes = makeBytesFromLongs(rightLongs);
        this.differenceBytes = new byte[size * 8];
    }

    @Benchmark
    public boolean CheckBit_LongArray() {
        int index = index();
        return (leftLongs[index >>> 6] & (1L << index)) != 0;
    }

    @Benchmark
    public boolean CheckBit_ByteArray() {
        int index = index();
        return ((leftBytes[index >>> 3] & 0xFF) & (1 << (index & 7))) != 0;
    }

    private int index() {
        return ThreadLocalRandom.current().nextInt(size * 64);
    }

    private static byte[] makeBytesFromLongs(long[] array) {
        byte[] bytes = new byte[8 * array.length];
        for (int i = 0; i < array.length; ++i) {
            long word = array[i];
            bytes[8 * i + 7] = (byte) word;
            bytes[8 * i + 6] = (byte) (word >>> 8);
            bytes[8 * i + 5] = (byte) (word >>> 16);
            bytes[8 * i + 4] = (byte) (word >>> 24);
            bytes[8 * i + 3] = (byte) (word >>> 32);
            bytes[8 * i + 2] = (byte) (word >>> 40);
            bytes[8 * i + 1] = (byte) (word >>> 48);
            bytes[8 * i]     = (byte) (word >>> 56);
        }
        return bytes;
    }
}

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
CheckBit_ByteArray thrpt 1 10 174.421170 1.583275 ops/us 1024
CheckBit_ByteArray thrpt 1 10 173.938408 1.445796 ops/us 2048
CheckBit_ByteArray thrpt 1 10 172.522190 0.815596 ops/us 4096
CheckBit_ByteArray thrpt 1 10 167.550530 1.677091 ops/us 8192
CheckBit_LongArray thrpt 1 10 171.639695 0.934494 ops/us 1024
CheckBit_LongArray thrpt 1 10 169.703960 2.427244 ops/us 2048
CheckBit_LongArray thrpt 1 10 169.333360 1.649654 ops/us 4096
CheckBit_LongArray thrpt 1 10 166.518375 0.815433 ops/us 8192

To support this functionality alone, there’s no reason to prefer one representation over the other, and it must be very appealing to use the bytes as they are delivered from the network, avoiding copying costs. Given that this is the only operation needed for a database column, and that Apache Arrow has a stated aim to copy data as little as possible, this seems like quite a good decision.

Logical Conjugations

But what happens if you try to add a logical operation to BitVector, such as an XOR? We need to handle the fact that bytes are signed and that their sign bit is preserved in promotion, according to the JLS. This would break the bitset, so extra operations are required to keep the eighth bit in its right place (see the small illustration below). With the widening and its associated workarounds, suddenly the byte[] is a much poorer choice than a long[], and it shows in benchmarks.
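
As a tiny illustration of the promotion issue (my addition): the byte 0b1000_0000 is -128, and widening it to int sign-extends it, which is why the masking appears in the byte[] benchmark below.

byte b = (byte) 0b1000_0000;   // -128
int widened = b;               // sign extension: 0xFFFFFF80
int masked = b & 0xFF;         // 0x00000080, the bit pattern we actually want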


    @Benchmark
    public void Difference_ByteArray(Blackhole bh) {
        for (int i = 0; i < leftBytes.length && i < rightBytes.length; ++i) {
            differenceBytes[i] = (byte)((leftBytes[i] & 0xFF) ^ (rightBytes[i] & 0xFF));
        }
        bh.consume(differenceBytes);
    }

    @Benchmark
    public void Difference_LongArray(Blackhole bh) {
        for (int i = 0; i < leftLongs.length && i < rightLongs.length; ++i) {
            differenceLongs[i] = leftLongs[i] ^ rightLongs[i];
        }
        bh.consume(differenceLongs);
    }

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
Difference_ByteArray thrpt 1 10 0.805872 0.038644 ops/us 1024
Difference_ByteArray thrpt 1 10 0.391705 0.017453 ops/us 2048
Difference_ByteArray thrpt 1 10 0.190102 0.008580 ops/us 4096
Difference_ByteArray thrpt 1 10 0.169104 0.015086 ops/us 8192
Difference_LongArray thrpt 1 10 2.450659 0.094590 ops/us 1024
Difference_LongArray thrpt 1 10 1.047330 0.016898 ops/us 2048
Difference_LongArray thrpt 1 10 0.546286 0.014211 ops/us 4096
Difference_LongArray thrpt 1 10 0.277378 0.015663 ops/us 8192

This is a fairly crazy slowdown. Why? You need to look at the assembly generated in each case. For long[] it’s demonstrable that the logical operations do vectorise. The JLS, specifically section 15.22, doesn’t really give the byte[] implementation a chance: it states that for logical operations, sub-dword primitive types must be promoted or widened before the operation. This means that if one were to try to implement this operation with, say, AVX2, using 256 bit ymmwords each consisting of 32 bytes, then each ymmword would have to be inflated by a factor of four once the bytes are widened to ints: it gets complicated quickly, given this constraint. Despite that complexity, I was surprised to see that C2 does use 128 bit xmmwords, but it’s not as fast as using the full 256 bit registers available. This can be seen by printing out the emitted assembly as usual.

movsxd  r10,ebx     

vmovq   xmm2,mmword ptr [rsi+r10+10h]

vpxor   xmm2,xmm2,xmmword ptr [r8+r10+10h]

vmovq   mmword ptr [rax+r10+10h],xmm2

Vectorised Logical Operations in Java 9

This is a short post for my own reference, since I feel I have already done the topic of “does Java 9 use AVX for this?” to death. Cutting to the chase, Java 9 autovectorises loops that compute logical ANDs, XORs, ORs and ANDNOTs between arrays, making use of the instructions VPXOR, VPOR and VPAND. You can replicate this by running the code at github.

XOR


    @Benchmark
    public long[] xor(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] ^ data2[i];
        }
        return result;
    }

vmovdqu ymm0,ymmword ptr [r10+r13*8+10h]

vpxor   ymm0,ymm0,ymmword ptr [rbx+r13*8+10h]

vmovdqu ymmword ptr [rax+r13*8+10h],ymm0

OR


    @Benchmark
    public long[] or(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] | data2[i];
        }
        return result;
    }

vmovdqu ymm0,ymmword ptr [r10+rsi*8+30h]
 
vpor    ymm0,ymm0,ymmword ptr [rbx+rsi*8+30h]

vmovdqu ymmword ptr [rax+rsi*8+30h],ymm0

AND


    @Benchmark
    public long[] and(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] & data2[i];
        }
        return result;
    }

vmovdqu ymm0,ymmword ptr [r10+r13*8+10h]

vpand   ymm0,ymm0,ymmword ptr [rbx+r13*8+10h]

vmovdqu ymmword ptr [rax+r13*8+10h],ymm0

ANDNOT


    @Benchmark
    public long[] andNot(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] & ~data2[i];
        }
        return result;
    }

vpunpcklqdq xmm0,xmm0,xmm0

vinserti128 ymm0,ymm0,xmm0,1h

vmovdqu ymm1,ymmword ptr [rbx+r13*8+10h]

vpxor   ymm1,ymm1,ymm0

vpand   ymm1,ymm1,ymmword ptr [r10+r13*8+10h]

vmovdqu ymmword ptr [rax+r13*8+10h],ymm1