Vectorised Algorithms in Java

There has been a Cambrian explosion of JVM data technologies in recent years. It’s all very exciting, but is the JVM really competitive with C in this area? I would argue that there is a reason Apache Arrow is polyglot, and it’s not just interoperability with Python. To pick on one project impressive enough to be thriving after seven years, if you’ve actually used Apache Spark you will be aware that it looks fastest next to its predecessor, MapReduce. Big data is a lot like teenage sex: everybody talks about it, nobody really knows how to do it, and everyone keeps their embarrassing stories to themselves. In games of incomplete information, it’s possible to overestimate the competence of others: nobody opens up about how slow their Spark jobs really are because there’s a risk of looking stupid.

If it can be accepted that Spark is inefficient, the question becomes: is Spark fundamentally inefficient? Flare is a drop-in replacement for Spark’s backend which swaps JIT compiled code for highly efficient native code, yielding order of magnitude improvements in job throughput. Some of Flare’s gains come from generating specialised code, but the rest comes from just generating better native code than C2 does. If Flare validates Spark’s execution model, perhaps it raises questions about the suitability of the JVM for high throughput data processing.

I think this will change radically in the coming years, and the most important reason is the advent of explicit support for SIMD provided by the vector API, currently incubating in Project Panama. Once the vector API is complete, I conjecture that projects like Spark will be able to profit enormously from it. This post takes a look at the API in its current state and ignores performance.

Why Vectorisation?

Assuming a flat processor frequency, throughput is improved by a combination of executing many instructions per cycle (pipelining) and processing multiple data items per instruction (SIMD). SIMD instruction sets are provided by Intel as the various generations of SSE and AVX. If throughput is the only goal, maximising SIMD usage may even be worth a reduction in frequency, which can happen on Intel chips when AVX instructions are used. Vectorisation is the process of restructuring scalar code so that it can be executed with SIMD instructions.

Analytical workloads are particularly suitable for vectorisation, especially over columnar data, because they typically involve operations consuming the entire range of a few numerical attributes of a data set. Vectorised analytical processing with filters is explicitly supported by vector masks, and vectorisation is also profitable for operations on indices typically performed for filtering prior to calculations. I don’t actually need to make a strong case for the impact of vectorisation on analytical workloads: just read the work of top researchers like Daniel Abadi and Daniel Lemire.
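For concreteness, the kind of kernel I have in mind looks something like the sketch below (the column names and the method are hypothetical, chosen only to illustrate the access pattern): a predicate over one attribute selects the rows contributing to an aggregate over another attribute. The filter maps naturally onto a vector mask, and the sum onto a vectorised reduction.

// Hypothetical columnar kernel: sum one attribute over the rows selected by a
// predicate on another attribute. Loops like this are prime targets for SIMD.
static double filteredSum(double[] price, int[] quantity, int threshold) {
    double sum = 0D;
    for (int i = 0; i < price.length; ++i) {
        if (quantity[i] > threshold) { // would become a vector mask
            sum += price[i];           // would become a vectorised accumulation
        }
    }
    return sum;
}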

Vectorisation in the JVM

C2 provides quite a lot of autovectorisation, which works very well sometimes, but the support is limited and brittle. I have written about this several times. Because AVX can reduce the processor frequency, it’s not always profitable to vectorise, so compilers employ cost models to decide when they should do so. Such cost models require platform specific calibration, and sometimes C2 can get it wrong. Sometimes, specifically in the case of floating point operations, using SIMD conflicts with the JLS, and the code C2 generates can be quite inefficient. In general, data parallel code can be better optimised by C compilers, such as GCC, than C2 because there are fewer constraints, and there is a larger budget for analysis at compile time. This all makes having intrinsics very appealing, and as a user I would like to be able to:

  1. Bypass JLS floating point constraints.
  2. Bypass cost model based decisions.
  3. Avoid JNI at all costs.
  4. Use a modern “object-functional” style. SIMD intrinsics in C are painful.

There is another attempt to provide SIMD intrinsics to JVM users via LMS, a framework for writing programs which write programs, designed by Tiark Rompf (who is also behind Flare). This work is very promising (I have written about it before), but it uses JNI. It’s only at the prototype stage; currently the intrinsics are auto-generated from XML definitions, which leads to a one-to-one mapping to the intrinsics in immintrin.h and a similar programming experience. This could likely be improved a lot, but the reliance on JNI is fundamental, albeit with minimal boundary crossing.

I am quite excited by the vector API in Project Panama because it looks like it will meet all of these requirements, at least to some extent. It remains to be seen quite how far the implementors will go in the direction of associative floating point arithmetic, but it has to opt out of JLS floating point semantics to some extent, which I think is progressive.

The Vector API

Disclaimer: Everything below is based on my experience with a recent build of the experimental code in the Project Panama fork of OpenJDK. I am not affiliated with the design or implementation of this API, may not be using it properly, and it may change according to its designers’ will before it is released!

To understand the vector API you need to know that there are different register widths and different SIMD instruction sets. Because of my area of work, and because 99% of the server market is Intel, I am only interested in AVX, but ARM has its own implementations with different maximum register sizes, which presumably need to be handled by a JVM vector API. On Intel CPUs, the SSE instruction sets use up to 128 bit registers (xmm, four ints), AVX and AVX2 use up to 256 bit registers (ymm, eight ints), and AVX512 uses up to 512 bit registers (zmm, sixteen ints).

The instruction sets are typed, and instructions designed to operate on packed doubles can’t operate on packed ints without explicit casting. This is modeled by the interface Vector<Shape>, parametrised by the Shape interface which models the register width.

The type of the vector elements is modeled by abstract element-type-specific classes such as IntVector. At the leaves of the hierarchy are the concrete classes specialised both to element type and register width, such as IntVector256, which extends IntVector<Shapes.S256Bit>.

Since EJB, factory has been a dirty word, which might be why the word species is used in this API. To create an IntVector<Shapes.S256Bit>, you can create the factory/species as follows:

public static final IntVector.IntSpecies<Shapes.S256Bit> YMM_INT = 
          (IntVector.IntSpecies<Shapes.S256Bit>) Vector.species(int.class, Shapes.S_256_BIT);

There are now various ways to create a vector from the species, which all have their use cases. First, you can load vectors from arrays: imagine you want to calculate the bitwise intersection of two int[]s. This can be written quite cleanly, without any shape/register information.


public static int[] intersect(int[] left, int[] right) {
    assert left.length == right.length;
    int[] result = new int[left.length];
    // assumes the length is a multiple of the species length
    for (int i = 0; i < left.length; i += YMM_INT.length()) {
      YMM_INT.fromArray(left, i)
             .and(YMM_INT.fromArray(right, i))
             .intoArray(result, i);
    }
    return result;
}

A common pattern in vectorised code is to broadcast a variable into a vector, for instance, to facilitate the multiplication of a vector by a scalar.

IntVector<Shapes.S256Bit> multiplier = YMM_INT.broadcast(x);

Or to create a vector from some scalars, for instance in a lookup table.

IntVector<Shapes.S256Bit> vector = YMM_INT.scalars(0, 1, 2, 3, 4, 5, 6, 7);

A zero vector can be created from a species:

IntVector<Shapes.S256Bit> zero = YMM_INT.zero();
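Putting these together, here is a minimal sketch of scaling an int[] by a constant. It assumes IntVector exposes a mul operation analogous to the and used in intersect above, and that the array length is a multiple of the species length.

public static void scale(int[] data, int factor) {
    IntVector<Shapes.S256Bit> multiplier = YMM_INT.broadcast(factor);
    for (int i = 0; i < data.length; i += YMM_INT.length()) {
      YMM_INT.fromArray(data, i)
             .mul(multiplier)
             .intoArray(data, i);
    }
}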

The big split in the class hierarchy is between integral and floating point types. Integral types have meaningful bitwise operations (I am looking forward to trying to write a vectorised population count algorithm), which are absent from FloatVector and DoubleVector, and there is no concept of fused-multiply-add for integral types, so there is obviously no IntVector.fma. The common subset of operations is arithmetic, casting and loading/storing operations.

I generally like the API a lot: it feels a lot like programming with streams, but on the other hand, it isn’t too far removed from traditional intrinsics. Below is an implementation of a fast matrix multiplication written in C, and below it is the same code written with the vector API:


static void mmul_tiled_avx_unrolled(const int n, const float *left, const float *right, float *result) {
    const int block_width = n >= 256 ? 512 : 256;
    const int block_height = n >= 512 ? 8 : n >= 256 ? 16 : 32;
    for (int column_offset = 0; column_offset < n; column_offset += block_width) {
        for (int row_offset = 0; row_offset < n; row_offset += block_height) {
            for (int i = 0; i < n; ++i) {
                for (int j = column_offset; j < column_offset + block_width && j < n; j += 64) {
                    __m256 sum1 = _mm256_load_ps(result + i * n + j);
                    __m256 sum2 = _mm256_load_ps(result + i * n + j + 8);
                    __m256 sum3 = _mm256_load_ps(result + i * n + j + 16);
                    __m256 sum4 = _mm256_load_ps(result + i * n + j + 24);
                    __m256 sum5 = _mm256_load_ps(result + i * n + j + 32);
                    __m256 sum6 = _mm256_load_ps(result + i * n + j + 40);
                    __m256 sum7 = _mm256_load_ps(result + i * n + j + 48);
                    __m256 sum8 = _mm256_load_ps(result + i * n + j + 56);
                    for (int k = row_offset; k < row_offset + block_height && k < n; ++k) {
                        __m256 multiplier = _mm256_set1_ps(left[i * n + k]);
                        sum1 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j), sum1);
                        sum2 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 8), sum2);
                        sum3 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 16), sum3);
                        sum4 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 24), sum4);
                        sum5 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 32), sum5);
                        sum6 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 40), sum6);
                        sum7 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 48), sum7);
                        sum8 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 56), sum8);
                    }
                    _mm256_store_ps(result + i * n + j, sum1);
                    _mm256_store_ps(result + i * n + j + 8, sum2);
                    _mm256_store_ps(result + i * n + j + 16, sum3);
                    _mm256_store_ps(result + i * n + j + 24, sum4);
                    _mm256_store_ps(result + i * n + j + 32, sum5);
                    _mm256_store_ps(result + i * n + j + 40, sum6);
                    _mm256_store_ps(result + i * n + j + 48, sum7);
                    _mm256_store_ps(result + i * n + j + 56, sum8);
                }
            }
        }
    }
}
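The Java version below refers to a YMM_FLOAT species, which isn’t defined above; by analogy with YMM_INT (and assuming FloatVector exposes an equivalent FloatSpecies), it would be declared like this:

public static final FloatVector.FloatSpecies<Shapes.S256Bit> YMM_FLOAT = 
          (FloatVector.FloatSpecies<Shapes.S256Bit>) Vector.species(float.class, Shapes.S_256_BIT);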


  private static void mmul(int n, float[] left, float[] right, float[] result) {
    int blockWidth = n >= 256 ? 512 : 256;
    int blockHeight = n >= 512 ? 8 : n >= 256 ? 16 : 32;
    for (int columnOffset = 0; columnOffset < n; columnOffset += blockWidth) {
      for (int rowOffset = 0; rowOffset < n; rowOffset += blockHeight) {
        for (int i = 0; i < n; ++i) {
          for (int j = columnOffset; j < columnOffset + blockWidth && j < n; j += 64) {
            var sum1 = YMM_FLOAT.fromArray(result, i * n + j);
            var sum2 = YMM_FLOAT.fromArray(result, i * n + j + 8);
            var sum3 = YMM_FLOAT.fromArray(result, i * n + j + 16);
            var sum4 = YMM_FLOAT.fromArray(result, i * n + j + 24);
            var sum5 = YMM_FLOAT.fromArray(result, i * n + j + 32);
            var sum6 = YMM_FLOAT.fromArray(result, i * n + j + 40);
            var sum7 = YMM_FLOAT.fromArray(result, i * n + j + 48);
            var sum8 = YMM_FLOAT.fromArray(result, i * n + j + 56);
            for (int k = rowOffset; k < rowOffset + blockHeight && k < n; ++k) {
              var multiplier = YMM_FLOAT.broadcast(left[i * n + k]);
              sum1 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j), sum1);
              sum2 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 8), sum2);
              sum3 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 16), sum3);
              sum4 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 24), sum4);
              sum5 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 32), sum5);
              sum6 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 40), sum6);
              sum7 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 48), sum7);
              sum8 = multiplier.fma(YMM_FLOAT.fromArray(right, k * n + j + 56), sum8);
            }
            sum1.intoArray(result, i * n + j);
            sum2.intoArray(result, i * n + j + 8);
            sum3.intoArray(result, i * n + j + 16);
            sum4.intoArray(result, i * n + j + 24);
            sum5.intoArray(result, i * n + j + 32);
            sum6.intoArray(result, i * n + j + 40);
            sum7.intoArray(result, i * n + j + 48);
            sum8.intoArray(result, i * n + j + 56);
          }
        }
      }
    }
  }

They just aren’t that different, and it’s easy to translate between the two. I wouldn’t expect it to be fast yet though. I have no idea what the scope of work involved in implementing all of the C2 intrinsics to make this possible is, but I assume it’s vast. The class jdk.incubator.vector.VectorIntrinsics seems to contain all of the intrinsics implemented so far, and it doesn’t contain every operation used in my matrix multiplication code. There is also the question of value types and vector box elimination. I will probably look at this again in the future when more of the JIT compiler work has been done, but I’m starting to get very excited about the possibility of much faster JVM based data processing.

I have written various benchmarks for useful analytical subroutines using the Vector API at github.

Matrix Multiplication Revisited

In a recent post, I took a look at matrix multiplication in pure Java, to see if it can go faster than reported in SIMD Intrinsics on Managed Language Runtimes. I found faster implementations than the paper’s benchmarks implied were possible. Nevertheless, I found that there were some limitations in Hotspot’s autovectoriser that I didn’t expect to see, even in JDK10. Were these limitations somehow fundamental, or could other compilers do better with essentially the same input?

I took a look at the code generated by GCC’s autovectoriser to see what’s possible in C/C++ without resorting to complicated intrinsics. For a bit of fun, I went over to the dark side to squeeze out a lot of extra performance, which inspired a simple vectorised Java implementation which can maintain intensity as matrix size increases.

Background

The paper reports a 5x improvement in matrix multiplication throughput as a result of using LMS generated intrinsics. Using GCC as LMS’s backend, I easily reproduced very good throughput, but I found two Java implementations better than the paper’s baseline. The best performing Java implementation proposed in the paper was blocked. This post is not about the LMS benchmarks, but the blocked code below was the inspiration for this post.


public void blocked(float[] a, float[] b, float[] c, int n) {
  int BLOCK_SIZE = 8; 
  // GOOD: attempts to do as much work as possible within submatrices
  // GOOD: tries to avoid bringing data through cache multiple times
  for (int kk = 0; kk < n; kk += BLOCK_SIZE) {
    for (int jj = 0; jj < n; jj += BLOCK_SIZE) {
      for (int i = 0; i < n; i++) {
        for (int j = jj; j < jj + BLOCK_SIZE; ++j) {
          // BAD: manual unrolling, bypasses optimisations
          float sum = c[i * n + j]; 
          for (int k = kk; k < kk + BLOCK_SIZE; ++k) {
            // BAD: second read (k * n) requires a gather - bad for cache, bad for dTLB
            // BAD: horizontal sums are inefficient
            sum += a[i * n + k] * b[k * n + j]; 
          }
          c[i * n + j] = sum;
        }
      }
    }
  }
}

I proposed the following implementation for improved cache efficiency and expected it to vectorise automatically.

public void fast(float[] a, float[] b, float[] c, int n) {
   // GOOD: 2x faster than "blocked" - why?
   int in = 0;
   for (int i = 0; i < n; ++i) {
       int kn = 0;
       for (int k = 0; k < n; ++k) {
           float aik = a[in + k];
           // MIXED: passes over c[in:in+n] multiple times per k-value, "free" if n is small
           // MIXED: reloads b[kn:kn+n] repeatedly for each i, bad if n is large, "free" if n is small
           // BAD: doesn't vectorise but should
           for (int j = 0; j < n; ++j) {
               c[in + j] += aik * b[kn + j]; // sequential writes and reads, cache and vectoriser friendly
           }
           kn += n;
       }
       in += n;
    }
}

My code actually doesn’t vectorise, even in JDK10, which really surprised me because the inner loop vectorises if the offsets are always zero. In any case, there is a simple hack involving the use of buffers, which unfortunately thrashes the cache, but narrows the field significantly.

  
  public void fastBuffered(float[] a, float[] b, float[] c, int n) {
    float[] bBuffer = new float[n];
    float[] cBuffer = new float[n];
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        System.arraycopy(b, kn, bBuffer, 0, n);
        saxpy(n, aik, bBuffer, cBuffer);
        kn += n;
      }
      System.arraycopy(cBuffer, 0, c, in, n);
      Arrays.fill(cBuffer, 0f);
      in += n;
    }
  }
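The fastBuffered method delegates to a saxpy kernel which isn’t shown above; a minimal sketch consistent with the call site (the signature is inferred from the call, not taken from the original) would be:

  private void saxpy(int n, float multiplier, float[] vector, float[] accumulator) {
    // accumulator[j] += multiplier * vector[j]; zero offsets keep the autovectoriser happy
    for (int j = 0; j < n; ++j) {
      accumulator[j] += multiplier * vector[j];
    }
  }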

I left the problem looking like this, with the “JDKX vectorised” series in my benchmark charts using the algorithm above with the buffer hack.

GCC Autovectorisation

The Java code is very easy to translate into C/C++. Before looking at performance, I want to get an idea of what GCC’s autovectoriser does with it. The assembly at optimisation level 3, with unrolled loops, FMA and AVX2 enabled, can be generated as follows:

g++ -mavx2 -mfma -march=native -funroll-loops -O3 -S mmul.cpp

The generated assembly code can be seen in full context here. Let’s look at the mmul_saxpy routine first:

static void mmul_saxpy(const int n, const float* left, const float* right, float* result) {
    int in = 0;
    for (int i = 0; i < n; ++i) {
        int kn = 0;
        for (int k = 0; k < n; ++k) {
            float aik = left[in + k];
            for (int j = 0; j < n; ++j) {
                result[in + j] += aik * right[kn + j];
            }
            kn += n;
        }
        in += n;
    }
}

GCC compiles this routine to SIMD instructions, which means that, in principle, any other compiler could do the same. The inner loop has been unrolled, but only by virtue of the -funroll-loops flag. C2 does this sort of thing as standard, but only for hot loops. In general you might not want to unroll loops because of the impact on code size, and it’s great that a JIT compiler can decide only to do this when it’s profitable.

.L9:
  vmovups  (%rdx,%rax), %ymm4
  vfmadd213ps  (%rbx,%rax), %ymm3, %ymm4
  addl  $8, %r10d
  vmovaps  %ymm4, (%r11,%rax)
  vmovups  32(%rdx,%rax), %ymm5
  vfmadd213ps  32(%rbx,%rax), %ymm3, %ymm5
  vmovaps  %ymm5, 32(%r11,%rax)
  vmovups  64(%rdx,%rax), %ymm1
  vfmadd213ps  64(%rbx,%rax), %ymm3, %ymm1
  vmovaps  %ymm1, 64(%r11,%rax)
  vmovups  96(%rdx,%rax), %ymm2
  vfmadd213ps  96(%rbx,%rax), %ymm3, %ymm2
  vmovaps  %ymm2, 96(%r11,%rax)
  vmovups  128(%rdx,%rax), %ymm4
  vfmadd213ps  128(%rbx,%rax), %ymm3, %ymm4
  vmovaps  %ymm4, 128(%r11,%rax)
  vmovups  160(%rdx,%rax), %ymm5
  vfmadd213ps  160(%rbx,%rax), %ymm3, %ymm5
  vmovaps  %ymm5, 160(%r11,%rax)
  vmovups  192(%rdx,%rax), %ymm1
  vfmadd213ps  192(%rbx,%rax), %ymm3, %ymm1
  vmovaps  %ymm1, 192(%r11,%rax)
  vmovups  224(%rdx,%rax), %ymm2
  vfmadd213ps  224(%rbx,%rax), %ymm3, %ymm2
  vmovaps  %ymm2, 224(%r11,%rax)
  addq  $256, %rax
  cmpl  %r10d, 24(%rsp)
  ja  .L9

The mmul_blocked routine is compiled to quite convoluted assembly. It has a huge problem with the expression right[k * n + j], which requires a gather and is almost guaranteed to create 8 cache misses per block for large matrices. Moreover, this inefficiency gets much worse with problem size.

static void mmul_blocked(const int n, const float* left, const float* right, float* result) {
    int BLOCK_SIZE = 8;
    for (int kk = 0; kk < n; kk += BLOCK_SIZE) {
        for (int jj = 0; jj < n; jj += BLOCK_SIZE) {
            for (int i = 0; i < n; i++) {
                for (int j = jj; j < jj + BLOCK_SIZE; ++j) {
                    float sum = result[i * n + j];
                    for (int k = kk; k < kk + BLOCK_SIZE; ++k) {
                        sum += left[i * n + k] * right[k * n + j]; // second read here requires a gather
                    }
                    result[i * n + j] = sum;
                }
            }
        }
    }
}

This compiles to assembly with the unrolled vectorised loop below:

.L114:
  cmpq  %r10, %r9
  setbe  %cl
  cmpq  56(%rsp), %r8
  setnb  %dl
  orl  %ecx, %edx
  cmpq  %r14, %r9
  setbe  %cl
  cmpq  64(%rsp), %r8
  setnb  %r15b
  orl  %ecx, %r15d
  andl  %edx, %r15d
  cmpq  %r11, %r9
  setbe  %cl
  cmpq  48(%rsp), %r8
  setnb  %dl
  orl  %ecx, %edx
  andl  %r15d, %edx
  cmpq  %rbx, %r9
  setbe  %cl
  cmpq  40(%rsp), %r8
  setnb  %r15b
  orl  %ecx, %r15d
  andl  %edx, %r15d
  cmpq  %rsi, %r9
  setbe  %cl
  cmpq  32(%rsp), %r8
  setnb  %dl
  orl  %ecx, %edx
  andl  %r15d, %edx
  cmpq  %rdi, %r9
  setbe  %cl
  cmpq  24(%rsp), %r8
  setnb  %r15b
  orl  %ecx, %r15d
  andl  %edx, %r15d
  cmpq  %rbp, %r9
  setbe  %cl
  cmpq  16(%rsp), %r8
  setnb  %dl
  orl  %ecx, %edx
  andl  %r15d, %edx
  cmpq  %r12, %r9
  setbe  %cl
  cmpq  8(%rsp), %r8
  setnb  %r15b
  orl  %r15d, %ecx
  testb  %cl, %dl
  je  .L111
  leaq  32(%rax), %rdx
  cmpq  %rdx, %r8
  setnb  %cl
  cmpq  %rax, %r9
  setbe  %r15b
  orb  %r15b, %cl
  je  .L111
  vmovups  (%r8), %ymm2
  vbroadcastss  (%rax), %ymm0
  vfmadd132ps  (%r14), %ymm2, %ymm0
  vbroadcastss  4(%rax), %ymm1
  vfmadd231ps  (%r10), %ymm1, %ymm0
  vbroadcastss  8(%rax), %ymm3
  vfmadd231ps  (%r11), %ymm3, %ymm0
  vbroadcastss  12(%rax), %ymm4
  vfmadd231ps  (%rbx), %ymm4, %ymm0
  vbroadcastss  16(%rax), %ymm5
  vfmadd231ps  (%rsi), %ymm5, %ymm0
  vbroadcastss  20(%rax), %ymm2
  vfmadd231ps  (%rdi), %ymm2, %ymm0
  vbroadcastss  24(%rax), %ymm1
  vfmadd231ps  0(%rbp), %ymm1, %ymm0
  vbroadcastss  28(%rax), %ymm3
  vfmadd231ps  (%r12), %ymm3, %ymm0
  vmovups  %ymm0, (%r8)

Benchmarks

I implemented a suite of benchmarks to compare the implementations. You can run them, but since they measure throughput and intensity averaged over hundreds of iterations per matrix size, the full run will take several hours.

g++ -mavx2 -mfma -march=native -funroll-loops -O3 mmul.cpp -o mmul.exe && ./mmul.exe > results.csv

The saxpy routine wins, with blocked fading fast after a middling start.

name size throughput (ops/s) flops/cycle
blocked 64 22770.2 4.5916
saxpy 64 25638.4 5.16997
blocked 128 2736.9 4.41515
saxpy 128 4108.52 6.62783
blocked 192 788.132 4.29101
saxpy 192 1262.45 6.87346
blocked 256 291.728 3.76492
saxpy 256 521.515 6.73044
blocked 320 147.979 3.72997
saxpy 320 244.528 6.16362
blocked 384 76.986 3.35322
saxpy 384 150.441 6.55264
blocked 448 50.4686 3.4907
saxpy 448 95.0752 6.57594
blocked 512 30.0085 3.09821
saxpy 512 65.1842 6.72991
blocked 576 22.8301 3.35608
saxpy 576 44.871 6.59614
blocked 640 15.5007 3.12571
saxpy 640 32.3709 6.52757
blocked 704 12.2478 3.28726
saxpy 704 25.3047 6.79166
blocked 768 8.69277 3.02899
saxpy 768 19.8011 6.8997
blocked 832 7.29356 3.23122
saxpy 832 15.3437 6.7976
blocked 896 4.95207 2.74011
saxpy 896 11.9611 6.61836
blocked 960 3.4467 2.34571
saxpy 960 9.25535 6.29888
blocked 1024 2.02289 1.67082
saxpy 1024 6.87039 5.67463

With GCC autovectorisation, saxpy performs well, maintaining intensity as size increases, albeit well below the theoretical capacity. It would be nice if similar code could be JIT compiled in Java.

Intel Intrinsics

To understand the problem space a bit better, I wanted to find out how fast matrix multiplication can get without domain expertise, by handcrafting an algorithm with intrinsics. My laptop’s Skylake chip (turbo boost and hyperthreading disabled) is capable of 32 SP flops per cycle per core – Java and the LMS implementation previously fell a long way short of that. It was difficult getting beyond 4f/c with Java, and LMS peaked at almost 6f/c before quickly tailing off. GCC autovectorisation achieved and maintained 7f/c.

To start, I’ll take full advantage of the facility to align the matrices on 64 byte intervals, since I have 64B cache lines, though this might just be voodoo. I take the saxpy routine and replace its kernel with intrinsics. Because of the -funroll-loops option, this will get unrolled without effort.

static void mmul_saxpy_avx(const int n, const float* left, const float* right, float* result) {
    int in = 0;
    for (int i = 0; i < n; ++i) {
        int kn = 0;
        for (int k = 0; k < n; ++k) {
            __m256 aik = _mm256_set1_ps(left[in + k]);
            int j = 0;
            for (; j + 8 <= n; j += 8) {
                _mm256_store_ps(result + in + j, _mm256_fmadd_ps(aik, _mm256_load_ps(right + kn + j), _mm256_load_ps(result + in + j)));
            }
            for (; j < n; ++j) { // scalar tail for any remainder when n is not a multiple of 8
                result[in + j] += left[in + k] * right[kn + j];
            }
            kn += n;
        }
        in += n;
    }
}

This code is actually not a lot faster, if at all, than the basic saxpy above: a lot of aggressive optimisations have already been applied.

Combining Blocked and SAXPY

What makes blocked so poor is the gather and the cache miss, not the concept of blocking itself. A limiting factor for saxpy performance is that the ratio of loads to floating point operations is too high. With this in mind, I tried combining the blocking idea with saxpy, by implementing saxpy multiplications for smaller sub-matrices. This results in a different algorithm with fewer loads per floating point operation, and the inner two loops are swapped. It avoids the gather and the cache miss in blocked. Because the matrices are in row major format, I make the width of the blocks much larger than the height. Also, different heights and widths make sense depending on the size of the matrix, so I choose them dynamically. The design constraints are to avoid gathers and horizontal reduction.

static void mmul_tiled_avx(const int n, const float *left, const float *right, float *result) {
    const int block_width = n >= 256 ? 512 : 256;
    const int block_height = n >= 512 ? 8 : n >= 256 ? 16 : 32;
    for (int row_offset = 0; row_offset < n; row_offset += block_height) {
        for (int column_offset = 0; column_offset < n; column_offset += block_width) {
            for (int i = 0; i < n; ++i) {
                for (int j = column_offset; j < column_offset + block_width && j < n; j += 8) {
                    __m256 sum = _mm256_load_ps(result + i * n + j);
                    for (int k = row_offset; k < row_offset + block_height && k < n; ++k) {
                        sum = _mm256_fmadd_ps(_mm256_set1_ps(left[i * n + k]), _mm256_load_ps(right + k * n + j), sum);
                    }
                    _mm256_store_ps(result + i * n + j, sum);
                }
            }
        }
    }
}

You will see in the benchmark results that this routine really doesn’t do very well compared to saxpy. Finally, I unroll it manually, which is profitable even with -funroll-loops set, because there is slightly more to this than an unroll: the kernel becomes a sequence of independent vertical reductions with no data dependencies between them.

static void mmul_tiled_avx_unrolled(const int n, const float *left, const float *right, float *result) {
    const int block_width = n >= 256 ? 512 : 256;
    const int block_height = n >= 512 ? 8 : n >= 256 ? 16 : 32;
    for (int column_offset = 0; column_offset < n; column_offset += block_width) {
        for (int row_offset = 0; row_offset < n; row_offset += block_height) {
            for (int i = 0; i < n; ++i) {
                for (int j = column_offset; j < column_offset + block_width && j < n; j += 64) {
                    __m256 sum1 = _mm256_load_ps(result + i * n + j);
                    __m256 sum2 = _mm256_load_ps(result + i * n + j + 8);
                    __m256 sum3 = _mm256_load_ps(result + i * n + j + 16);
                    __m256 sum4 = _mm256_load_ps(result + i * n + j + 24);
                    __m256 sum5 = _mm256_load_ps(result + i * n + j + 32);
                    __m256 sum6 = _mm256_load_ps(result + i * n + j + 40);
                    __m256 sum7 = _mm256_load_ps(result + i * n + j + 48);
                    __m256 sum8 = _mm256_load_ps(result + i * n + j + 56);
                    for (int k = row_offset; k < row_offset + block_height && k < n; ++k) {
                        __m256 multiplier = _mm256_set1_ps(left[i * n + k]);
                        sum1 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j), sum1);
                        sum2 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 8), sum2);
                        sum3 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 16), sum3);
                        sum4 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 24), sum4);
                        sum5 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 32), sum5);
                        sum6 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 40), sum6);
                        sum7 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 48), sum7);
                        sum8 = _mm256_fmadd_ps(multiplier, _mm256_load_ps(right + k * n + j + 56), sum8);
                    }
                    _mm256_store_ps(result + i * n + j, sum1);
                    _mm256_store_ps(result + i * n + j + 8, sum2);
                    _mm256_store_ps(result + i * n + j + 16, sum3);
                    _mm256_store_ps(result + i * n + j + 24, sum4);
                    _mm256_store_ps(result + i * n + j + 32, sum5);
                    _mm256_store_ps(result + i * n + j + 40, sum6);
                    _mm256_store_ps(result + i * n + j + 48, sum7);
                    _mm256_store_ps(result + i * n + j + 56, sum8);
                }
            }
        }
    }
}

This final implementation is fast, and is probably as good as I am going to manage without reading papers. This should be a CPU bound problem, because the algorithm is O(n^3) whereas the problem size is O(n^2), but the flops/cycle decreases with problem size in all of these implementations. It’s possible that this could be ameliorated by a better dynamic tiling policy; I’m unlikely to be able to fix that.

It does make a huge difference being able to go very low level – handwritten intrinsics with GCC unlock awesome throughput – but it’s quite hard to actually get to the point where you can beat a good optimising compiler. Mind you, there are harder problems to solve than this, and you may well be a domain expert.

The benchmark results summarise this best:

name size throughput (ops/s) flops/cycle
saxpy_avx 64 49225.7 9.92632
tiled_avx 64 33680.5 6.79165
tiled_avx_unrolled 64 127936 25.7981
saxpy_avx 128 5871.02 9.47109
tiled_avx 128 4210.07 6.79166
tiled_avx_unrolled 128 15997.6 25.8072
saxpy_avx 192 1603.84 8.73214
tiled_avx 192 1203.33 6.55159
tiled_avx_unrolled 192 4383.09 23.8638
saxpy_avx 256 633.595 8.17689
tiled_avx 256 626.157 8.0809
tiled_avx_unrolled 256 1792.52 23.1335
saxpy_avx 320 284.161 7.1626
tiled_avx 320 323.197 8.14656
tiled_avx_unrolled 320 935.571 23.5822
saxpy_avx 384 161.517 7.03508
tiled_avx 384 188.215 8.19794
tiled_avx_unrolled 384 543.235 23.6613
saxpy_avx 448 99.1987 6.86115
tiled_avx 448 118.588 8.2022
tiled_avx_unrolled 448 314 21.718
saxpy_avx 512 70.0296 7.23017
tiled_avx 512 73.2019 7.55769
tiled_avx_unrolled 512 197.815 20.4233
saxpy_avx 576 46.1944 6.79068
tiled_avx 576 50.6315 7.44294
tiled_avx_unrolled 576 126.045 18.5289
saxpy_avx 640 33.8209 6.81996
tiled_avx 640 37.0288 7.46682
tiled_avx_unrolled 640 92.784 18.7098
saxpy_avx 704 24.9096 6.68561
tiled_avx 704 27.7543 7.44912
tiled_avx_unrolled 704 69.0399 18.53
saxpy_avx 768 19.5158 6.80027
tiled_avx 768 21.532 7.50282
tiled_avx_unrolled 768 54.1763 18.8777
saxpy_avx 832 12.8635 5.69882
tiled_avx 832 14.6666 6.49766
tiled_avx_unrolled 832 37.9592 16.8168
saxpy_avx 896 12.0526 6.66899
tiled_avx 896 13.3799 7.40346
tiled_avx_unrolled 896 34.0838 18.8595
saxpy_avx 960 8.97193 6.10599
tiled_avx 960 10.1052 6.87725
tiled_avx_unrolled 960 21.0263 14.3098
saxpy_avx 1024 6.73081 5.55935
tiled_avx 1024 7.21214 5.9569
tiled_avx_unrolled 1024 12.7768 10.5531

Can we do better in Java?

Writing genuinely fast code gives an indication of how little of the processor Java actually utilises, but is it possible to bring this knowledge over to Java? The saxpy based implementations in my previous post performed well for small to medium sized matrices. Once the matrices grow, however, they become too big to be allowed to pass through cache multiple times: we need hot, small cached data to be replenished from the larger matrix. Ideally we wouldn’t need to make any copies, but it seems that the autovectoriser doesn’t like offsets: System.arraycopy is a reasonably fast compromise. The basic sequential read pattern is validated: even native code requiring a gather does not perform well for this problem. The best effort C++ code translates almost verbatim into this Java code, which is quite fast for large matrices.


public void tiled(float[] a, float[] b, float[] c, int n) {
    final int bufferSize = 512;
    final int width = Math.min(n, bufferSize);
    final int height = Math.min(n, n >= 512 ? 8 : n >= 256 ? 16 : 32);
    float[] sum = new float[bufferSize];
    float[] vector = new float[bufferSize];
    for (int rowOffset = 0; rowOffset < n; rowOffset += height) {
      for (int columnOffset = 0; columnOffset < n; columnOffset += width) {
        for (int i = 0; i < n; ++i) {
          for (int j = columnOffset; j < columnOffset + width && j < n; j += width) {
            int stride = Math.min(n - columnOffset, bufferSize);
            // copy to give autovectorisation a hint
            System.arraycopy(c, i * n + j, sum, 0, stride);
            for (int k = rowOffset; k < rowOffset + height && k < n; ++k) {
              float multiplier = a[i * n + k];
              System.arraycopy(b, k * n  + j, vector, 0, stride);
              for (int l = 0; l < stride; ++l) {
                sum[l] = Math.fma(multiplier, vector[l], sum[l]);
              }
            }
            System.arraycopy(sum, 0, c, i * n + j, stride);
          }
        }
      }
    }
  }

Benchmarking it using the same harness used in the previous post, the performance is ~10% higher for large arrays than my previous best effort. Still, the reality is that this is too slow to be useful. If you need to do linear algebra, use C/C++ for the time being!

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size flops/cycle
fastBuffered thrpt 1 10 53.331195 0.270526 ops/s 448 3.688688696
fastBuffered thrpt 1 10 34.365765 0.16641 ops/s 512 3.548072999
fastBuffered thrpt 1 10 26.128264 0.239719 ops/s 576 3.840914622
fastBuffered thrpt 1 10 19.044509 0.139197 ops/s 640 3.84031059
fastBuffered thrpt 1 10 14.312154 1.045093 ops/s 704 3.841312378
fastBuffered thrpt 1 10 7.772745 0.074598 ops/s 768 2.708411991
fastBuffered thrpt 1 10 6.276182 0.067338 ops/s 832 2.780495238
fastBuffered thrpt 1 10 4.8784 0.067368 ops/s 896 2.699343067
fastBuffered thrpt 1 10 4.068907 0.038677 ops/s 960 2.769160387
fastBuffered thrpt 1 10 2.568101 0.108612 ops/s 1024 2.121136502
tiled thrpt 1 10 56.495366 0.584872 ops/s 448 3.907540754
tiled thrpt 1 10 30.884954 3.221017 ops/s 512 3.188698735
tiled thrpt 1 10 15.580581 0.412654 ops/s 576 2.290381075
tiled thrpt 1 10 9.178969 0.841178 ops/s 640 1.850932038
tiled thrpt 1 10 12.229763 0.350233 ops/s 704 3.282408783
tiled thrpt 1 10 9.371032 0.330742 ops/s 768 3.265334889
tiled thrpt 1 10 7.727068 0.277969 ops/s 832 3.423271628
tiled thrpt 1 10 6.076451 0.30305 ops/s 896 3.362255222
tiled thrpt 1 10 4.916811 0.2823 ops/s 960 3.346215151
tiled thrpt 1 10 3.722623 0.26486 ops/s 1024 3.074720008

Incidental Similarity

I recently saw an interesting class, BitVector, in Apache Arrow, which represents a column of bits, providing minimal or zero copy distribution. The implementation is similar to a bitset but backed by a byte[] rather than a long[]. Given the coincidental similarity in implementation, it’s tempting to look at this, extend its interface and try to use it as a general purpose, distributed bitset. Could this work? Why not just implement some extra methods? Fork it on Github!

This post details the caveats of trying to adapt an abstraction beyond its intended purpose, examined through the lens of performance: what happens if generic bitset capabilities are added to BitVector without due consideration? Doing so runs into the observable effect of word widening on throughput, given the constraints imposed by JLS 15.22. In the end, the only remedy is to use a long[], sacrificing the original zero copy design goal. I hope this is a fairly self-contained example of how uncontrolled adaptation can be hostile to the original design goals: having the source code isn’t enough reason to modify it.

Checking bits

How fast is it to check whether the bit at index i is set? BitVector implements this functionality, and was designed for it. It can be measured with JMH by generating a random long[], creating a byte[] eight times as long with identical bits, and measuring the throughput of checking the value of the bit at random indices. It turns out that if all you want to do is access bits, byte[] isn’t such a bad choice, and if those bytes are coming directly from the network, it could even be a great choice. I ran the benchmark below and saw that the two representations perform similarly (within measurement error).


import org.openjdk.jmh.annotations.*;

import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Thread)
public class BitSet {

    @Param({"1024", "2048", "4096", "8192"})
    int size;

    private long[] leftLongs;
    private long[] rightLongs;
    private long[] differenceLongs;
    private byte[] leftBytes;
    private byte[] rightBytes;
    private byte[] differenceBytes;

    @Setup(Level.Trial)
    public void init() {
        this.leftLongs = createLongArray(size);
        this.rightLongs = createLongArray(size);
        this.differenceLongs = new long[size];
        this.leftBytes = makeBytesFromLongs(leftLongs);
        this.rightBytes = makeBytesFromLongs(rightLongs);
        this.differenceBytes = new byte[size * 8];
    }

    @Benchmark
    public boolean CheckBit_LongArray() {
        int index = index();
        return (leftLongs[index >>> 6] & (1L << index)) != 0;
    }

    @Benchmark
    public boolean CheckBit_ByteArray() {
        int index = index();
        return ((leftBytes[index >>> 3] & 0xFF) & (1 << (index & 7))) != 0;
    }

    private int index() {
        return ThreadLocalRandom.current().nextInt(size * 64);
    }

    private static byte[] makeBytesFromLongs(long[] array) {
        byte[] bytes = new byte[8 * array.length];
        for (int i = 0; i < array.length; ++i) {
            long word = array[i];
            bytes[8 * i + 7] = (byte) word;
            bytes[8 * i + 6] = (byte) (word >>> 8);
            bytes[8 * i + 5] = (byte) (word >>> 16);
            bytes[8 * i + 4] = (byte) (word >>> 24);
            bytes[8 * i + 3] = (byte) (word >>> 32);
            bytes[8 * i + 2] = (byte) (word >>> 40);
            bytes[8 * i + 1] = (byte) (word >>> 48);
            bytes[8 * i]     = (byte) (word >>> 56);
        }
        return bytes;
    }

    // not shown in the original: a plausible helper filling an array with random bits
    private static long[] createLongArray(int size) {
        long[] array = new long[size];
        for (int i = 0; i < size; ++i) {
            array[i] = ThreadLocalRandom.current().nextLong();
        }
        return array;
    }
}

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
CheckBit_ByteArray thrpt 1 10 174.421170 1.583275 ops/us 1024
CheckBit_ByteArray thrpt 1 10 173.938408 1.445796 ops/us 2048
CheckBit_ByteArray thrpt 1 10 172.522190 0.815596 ops/us 4096
CheckBit_ByteArray thrpt 1 10 167.550530 1.677091 ops/us 8192
CheckBit_LongArray thrpt 1 10 171.639695 0.934494 ops/us 1024
CheckBit_LongArray thrpt 1 10 169.703960 2.427244 ops/us 2048
CheckBit_LongArray thrpt 1 10 169.333360 1.649654 ops/us 4096
CheckBit_LongArray thrpt 1 10 166.518375 0.815433 ops/us 8192

To support this functionality alone, there’s no reason to choose one representation over the other, and it must be very appealing to use the bytes exactly as they are delivered from the network, avoiding copying costs. Given that, for a database column, this is the only operation needed, and that Apache Arrow has a stated aim to copy data as little as possible, this seems like quite a good decision.

Logical Conjugations

But what happens if you try to add a logical operation to BitVector, such as an XOR? We need to handle the fact that bytes are signed and that, according to the JLS, their sign bit is preserved during promotion. Left unhandled, this would break the bitset, so extra operations are required to keep the eighth bit in its right place. With the widening and its associated workarounds, the byte[] is suddenly a much poorer choice than a long[], and it shows in the benchmarks.
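As a reminder of what the JLS mandates here: a bitwise operation on two byte operands is performed on ints after promotion, so the result has type int and must be cast back down to be stored in a byte[]; masking with 0xFF treats the promoted bytes as unsigned, which is why the & 0xFF dance appears in the benchmark below. A throwaway illustration:

byte left = (byte) 0x80, right = (byte) 0x01;
int widened = left ^ right;             // operands are promoted to int; the value is 0xFFFFFF81
byte narrowed = (byte) (left ^ right);  // an explicit cast is required to narrow back to a byte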


    @Benchmark
    public void Difference_ByteArray(Blackhole bh) {
        for (int i = 0; i < leftBytes.length && i < rightBytes.length; ++i) {
            differenceBytes[i] = (byte)((leftBytes[i] & 0xFF) ^ (rightBytes[i] & 0xFF));
        }
        bh.consume(differenceBytes);
    }

    @Benchmark
    public void Difference_LongArray(Blackhole bh) {
        for (int i = 0; i < leftLongs.length && i < rightLongs.length; ++i) {
            differenceLongs[i] = leftLongs[i] ^ rightLongs[i];
        }
        bh.consume(differenceLongs);
    }

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
Difference_ByteArray thrpt 1 10 0.805872 0.038644 ops/us 1024
Difference_ByteArray thrpt 1 10 0.391705 0.017453 ops/us 2048
Difference_ByteArray thrpt 1 10 0.190102 0.008580 ops/us 4096
Difference_ByteArray thrpt 1 10 0.169104 0.015086 ops/us 8192
Difference_LongArray thrpt 1 10 2.450659 0.094590 ops/us 1024
Difference_LongArray thrpt 1 10 1.047330 0.016898 ops/us 2048
Difference_LongArray thrpt 1 10 0.546286 0.014211 ops/us 4096
Difference_LongArray thrpt 1 10 0.277378 0.015663 ops/us 8192

This is a fairly crazy slowdown. Why? You need to look at the assembly generated in each case. For long[] it’s demonstrable that the logical operations do vectorise. The JLS, specifically section 15.22, doesn’t really give the byte[] implementation a chance. It states that for logical operations, sub dword primitive types must be promoted or widened before the operation. This means that if one were to try to implement this operation with, say, AVX2, using 256 bit ymmwords each holding 32 bytes, then each ymmword would have to be inflated by a factor of four after widening: it gets complicated quickly, given this constraint. Despite that complexity, I was surprised to see that C2 does use 128 bit xmmwords, but it’s not as fast as using the full 256 bit registers available. This can be seen by printing out the emitted assembly as usual.

movsxd  r10,ebx     

vmovq   xmm2,mmword ptr [rsi+r10+10h]

vpxor   xmm2,xmm2,xmmword ptr [r8+r10+10h]

vmovq   mmword ptr [rax+r10+10h],xmm2

Vectorised Logical Operations in Java 9

This is a short post for my own reference, since I feel I have already done the topic of does Java 9 use AVX for this? to death. Cutting to the chase, Java 9 autovectorises loops to compute logical ANDs, XORs, ORs and ANDNOTs between arrays, making use of the instructions VPXOR, VPOR and VPAND (ANDNOT is composed from VPXOR and VPAND, as shown below). You can replicate this by running the code at github.

XOR


    @Benchmark
    public long[] xor(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] ^ data2[i];
        }
        return result;
    }

vmovdqu ymm0,ymmword ptr [r10+r13*8+10h]

vpxor   ymm0,ymm0,ymmword ptr [rbx+r13*8+10h]

vmovdqu ymmword ptr [rax+r13*8+10h],ymm0

OR


    @Benchmark
    public long[] or(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] | data2[i];
        }
        return result;
    }

vmovdqu ymm0,ymmword ptr [r10+rsi*8+30h]
 
vpor    ymm0,ymm0,ymmword ptr [rbx+rsi*8+30h]

vmovdqu ymmword ptr [rax+rsi*8+30h],ymm0

AND


    @Benchmark
    public long[] and(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] & data2[i];
        }
        return result;
    }

vmovdqu ymm0,ymmword ptr [r10+r13*8+10h]

vpand   ymm0,ymm0,ymmword ptr [rbx+r13*8+10h]

vmovdqu ymmword ptr [rax+r13*8+10h],ymm0

ANDNOT


    @Benchmark
    public long[] andNot(LongData state) {
        long[] result = new long[state.data1.length];
        long[] data1 = state.data1;
        long[] data2 = state.data2;
        for (int i = 0; i < data1.length && i < data2.length; ++i) {
            result[i] = data1[i] & ~data2[i];
        }
        return result;
    }

vpunpcklqdq xmm0,xmm0,xmm0

vinserti128 ymm0,ymm0,xmm0,1h

vmovdqu ymm1,ymmword ptr [rbx+r13*8+10h]

vpxor   ymm1,ymm1,ymm0

vpand   ymm1,ymm1,ymmword ptr [r10+r13*8+10h]

vmovdqu ymmword ptr [rax+r13*8+10h],ymm1

How much Algebra does C2 Know? Part 2: Distributivity

In part one of this series of posts, I looked at how important associativity and independence are for fast loops. C2 seems to utilise these properties to generate unrolled and pipelined machine code for loops, achieving higher throughput even in cases where the kernel of the loop is 3x slower according to vendor advertised instruction throughputs. C2 has a weird and wonderful relationship with distributivity, and hints from the programmer can both help and hinder the generation of good quality machine code.

Viability and Correctness

Distributivity is the simple notion of factoring out brackets. Is this, in general, a viable loop rewrite strategy? This can be utilised to transform the method Scale into FactoredScale, both of which perform floating point arithmetic:


    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    @Benchmark
    public double Scale(DoubleData state) {
        double value = 0D;
        double[] data = state.data1;
        for (int i = 0; i < data.length; ++i) {
            value += 3.14159 * data[i];
        }
        return value;
    }

    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    @Benchmark
    public double FactoredScale(DoubleData state) {
        double value = 0D;
        double[] data = state.data1;
        for (int i = 0; i < data.length; ++i) {
            value += data[i];
        }
        return 3.14159 * value;
    }

Running the project at github with the argument --include .*scale.*, there may be a performance gain to be had from this rewrite, but it isn’t clear cut:

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
FactoredScale thrpt 1 10 7.011606 0.274742 ops/ms 100000
FactoredScale thrpt 1 10 0.621515 0.026853 ops/ms 1000000
Scale thrpt 1 10 6.962434 0.240180 ops/ms 100000
Scale thrpt 1 10 0.671042 0.011686 ops/ms 1000000

With real numbers the rewrite would be completely valid, but floating point arithmetic is not associative. Joseph Darcy explains why in this deep dive on floating point semantics. Broken associativity of addition entails broken distributivity of any operation over it, so the two loops are not equivalent, and they give different outputs (e.g. 15662.513298516365 vs 15662.51329851632 for one sample input). The rewrite simply isn’t correct for floating point data, so it isn’t an optimisation that could be applied in good faith, except in a very small number of cases. You have to rewrite the loop yourself and figure out whether the small but inevitable differences are acceptable.
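The non-associativity is easy to demonstrate with a throwaway snippet (the values are arbitrary):

double leftToRight = (0.1 + 0.2) + 0.3; // 0.6000000000000001
double rightToLeft = 0.1 + (0.2 + 0.3); // 0.6
System.out.println(leftToRight == rightToLeft); // false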

Counterintuitive Performance

Integer multiplication is distributive over addition, and we can check if C2 does this rewrite by running the same code with 32 bit integer values, for now fixing a scale factor of 10 (which seems like an innocuous value, no?)


    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    @Benchmark
    public int Scale_Int(IntData state) {
        int value = 0;
        int[] data = state.data1;
        for (int i = 0; i < data.length; ++i) {
            value += 10 * data[i];
        }
        return value;
    }

    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    @Benchmark
    public int FactoredScale_Int(IntData state) {
        int value = 0;
        int[] data = state.data1;
        for (int i = 0; i < data.length; ++i) {
            value += data[i];
        }
        return 10 * value;
    }

The results are fascinating:

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
FactoredScale_Int thrpt 1 10 28.339699 0.608075 ops/ms 100000
FactoredScale_Int thrpt 1 10 2.392579 0.506413 ops/ms 1000000
Scale_Int thrpt 1 10 33.335721 0.295334 ops/ms 100000
Scale_Int thrpt 1 10 2.838242 0.448213 ops/ms 1000000

The code is doing thousands more multiplications in less time when the multiplication is not factored out of the loop. So what the devil is going on? Inspecting the assembly for the faster loop is revealing:

  0x000001c89e499400: vmovdqu ymm8,ymmword ptr [rbp+r13*4+10h]
  0x000001c89e499407: movsxd  r10,r13d       
  0x000001c89e49940a: vmovdqu ymm9,ymmword ptr [rbp+r10*4+30h]
  0x000001c89e499411: vmovdqu ymm13,ymmword ptr [rbp+r10*4+0f0h]
  0x000001c89e49941b: vmovdqu ymm12,ymmword ptr [rbp+r10*4+50h]
  0x000001c89e499422: vmovdqu ymm4,ymmword ptr [rbp+r10*4+70h]
  0x000001c89e499429: vmovdqu ymm3,ymmword ptr [rbp+r10*4+90h]
  0x000001c89e499433: vmovdqu ymm2,ymmword ptr [rbp+r10*4+0b0h]
  0x000001c89e49943d: vmovdqu ymm0,ymmword ptr [rbp+r10*4+0d0h]
  0x000001c89e499447: vpslld  ymm11,ymm8,1h  
  0x000001c89e49944d: vpslld  ymm1,ymm0,1h   
  0x000001c89e499452: vpslld  ymm0,ymm0,3h   
  0x000001c89e499457: vpaddd  ymm5,ymm0,ymm1 
  0x000001c89e49945b: vpslld  ymm0,ymm2,3h   
  0x000001c89e499460: vpslld  ymm7,ymm3,3h   
  0x000001c89e499465: vpslld  ymm10,ymm4,3h 
  0x000001c89e49946a: vpslld  ymm15,ymm12,3h
  0x000001c89e499470: vpslld  ymm14,ymm13,3h
  0x000001c89e499476: vpslld  ymm1,ymm9,3h  
  0x000001c89e49947c: vpslld  ymm2,ymm2,1h  
  0x000001c89e499481: vpaddd  ymm6,ymm0,ymm2   
  0x000001c89e499485: vpslld  ymm0,ymm3,1h     
  0x000001c89e49948a: vpaddd  ymm7,ymm7,ymm0   
  0x000001c89e49948e: vpslld  ymm0,ymm4,1h     
  0x000001c89e499493: vpaddd  ymm10,ymm10,ymm0
  0x000001c89e499497: vpslld  ymm0,ymm12,1h   
  0x000001c89e49949d: vpaddd  ymm12,ymm15,ymm0
  0x000001c89e4994a1: vpslld  ymm0,ymm13,1h   
  0x000001c89e4994a7: vpaddd  ymm4,ymm14,ymm0 
  0x000001c89e4994ab: vpslld  ymm0,ymm9,1h    
  0x000001c89e4994b1: vpaddd  ymm2,ymm1,ymm0  
  0x000001c89e4994b5: vpslld  ymm0,ymm8,3h    
  0x000001c89e4994bb: vpaddd  ymm8,ymm0,ymm11 
  0x000001c89e4994c0: vphaddd ymm0,ymm8,ymm8  
  0x000001c89e4994c5: vphaddd ymm0,ymm0,ymm3  
  0x000001c89e4994ca: vextracti128 xmm3,ymm0,1h
  0x000001c89e4994d0: vpaddd  xmm0,xmm0,xmm3   
  0x000001c89e4994d4: vmovd   xmm3,ebx         
  0x000001c89e4994d8: vpaddd  xmm3,xmm3,xmm0   
  0x000001c89e4994dc: vmovd   r10d,xmm3        
  0x000001c89e4994e1: vphaddd ymm0,ymm2,ymm2   
  0x000001c89e4994e6: vphaddd ymm0,ymm0,ymm3   
  0x000001c89e4994eb: vextracti128 xmm3,ymm0,1h
  0x000001c89e4994f1: vpaddd  xmm0,xmm0,xmm3   
  0x000001c89e4994f5: vmovd   xmm3,r10d        
  0x000001c89e4994fa: vpaddd  xmm3,xmm3,xmm0   
  0x000001c89e4994fe: vmovd   r11d,xmm3        
  0x000001c89e499503: vphaddd ymm2,ymm12,ymm12  
  0x000001c89e499508: vphaddd ymm2,ymm2,ymm0    
  0x000001c89e49950d: vextracti128 xmm0,ymm2,1h 
  0x000001c89e499513: vpaddd  xmm2,xmm2,xmm0    
  0x000001c89e499517: vmovd   xmm0,r11d         
  0x000001c89e49951c: vpaddd  xmm0,xmm0,xmm2    
  0x000001c89e499520: vmovd   r10d,xmm0         
  0x000001c89e499525: vphaddd ymm0,ymm10,ymm10  
  0x000001c89e49952a: vphaddd ymm0,ymm0,ymm3   
  0x000001c89e49952f: vextracti128 xmm3,ymm0,1h
  0x000001c89e499535: vpaddd  xmm0,xmm0,xmm3
  0x000001c89e499539: vmovd   xmm3,r10d   
  0x000001c89e49953e: vpaddd  xmm3,xmm3,xmm0   
  0x000001c89e499542: vmovd   r11d,xmm3        
  0x000001c89e499547: vphaddd ymm2,ymm7,ymm7   
  0x000001c89e49954c: vphaddd ymm2,ymm2,ymm0   
  0x000001c89e499551: vextracti128 xmm0,ymm2,1h
  0x000001c89e499557: vpaddd  xmm2,xmm2,xmm0 
  0x000001c89e49955b: vmovd   xmm0,r11d      
  0x000001c89e499560: vpaddd  xmm0,xmm0,xmm2 
  0x000001c89e499564: vmovd   r10d,xmm0      
  0x000001c89e499569: vphaddd ymm0,ymm6,ymm6   
  0x000001c89e49956e: vphaddd ymm0,ymm0,ymm3   
  0x000001c89e499573: vextracti128 xmm3,ymm0,1h
  0x000001c89e499579: vpaddd  xmm0,xmm0,xmm3   
  0x000001c89e49957d: vmovd   xmm3,r10d        
  0x000001c89e499582: vpaddd  xmm3,xmm3,xmm0   
  0x000001c89e499586: vmovd   r11d,xmm3        
  0x000001c89e49958b: vphaddd ymm2,ymm5,ymm5   
  0x000001c89e499590: vphaddd ymm2,ymm2,ymm0   
  0x000001c89e499595: vextracti128 xmm0,ymm2,1h
  0x000001c89e49959b: vpaddd  xmm2,xmm2,xmm0
  0x000001c89e49959f: vmovd   xmm0,r11d     
  0x000001c89e4995a4: vpaddd  xmm0,xmm0,xmm2
  0x000001c89e4995a8: vmovd   r10d,xmm0
  0x000001c89e4995ad: vphaddd ymm2,ymm4,ymm4 
  0x000001c89e4995b2: vphaddd ymm2,ymm2,ymm1
  0x000001c89e4995b7: vextracti128 xmm1,ymm2,1h
  0x000001c89e4995bd: vpaddd  xmm2,xmm2,xmm1
  0x000001c89e4995c1: vmovd   xmm1,r10d  
  0x000001c89e4995c6: vpaddd  xmm1,xmm1,xmm2    
  0x000001c89e4995ca: vmovd   ebx,xmm1          

The loop is aggressively unrolled, pipelined, and vectorised. Moreover, the multiplication by ten results not in a multiplication but in two left shifts (see VPSLLD) and an addition. Note that (x << 1) + (x << 3) = x * 10 and C2 seems to know it; this rewrite can be applied because it can be proven statically that the factor is always 10. The “optimised” loop doesn’t vectorise at all (and I have no idea why not – isn’t this a bug? Yes it is.)

  0x000002bbebeda3c8: add     ebx,dword ptr [rbp+r8*4+14h]
  0x000002bbebeda3cd: add     ebx,dword ptr [rbp+r8*4+18h]
  0x000002bbebeda3d2: add     ebx,dword ptr [rbp+r8*4+1ch]
  0x000002bbebeda3d7: add     ebx,dword ptr [rbp+r8*4+20h]
  0x000002bbebeda3dc: add     ebx,dword ptr [rbp+r8*4+24h]
  0x000002bbebeda3e1: add     ebx,dword ptr [rbp+r8*4+28h]
  0x000002bbebeda3e6: add     ebx,dword ptr [rbp+r8*4+2ch]
  0x000002bbebeda3eb: add     r13d,8h           
  0x000002bbebeda3ef: cmp     r13d,r11d         
  0x000002bbebeda3f2: jl      2bbebeda3c0h      
  

This is a special case: data is usually dynamic and variable, so the loop cannot always be proven to be equivalent to a linear combination of bit shifts. The routine is compiled for all possible parameters, not just statically contrived cases like the one above, so you may never see this assembly in the wild. However, even with random factors, the slow looking loop is aggressively optimised in a way the hand “optimised” code is not:


    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    @Benchmark
    public int Scale_Int_Dynamic(ScaleState state) {
        int value = 0;
        int[] data = state.data;
        int factor = state.randomFactor();
        for (int i = 0; i < data.length; ++i) {
            value += factor * data[i];
        }
        return value;
    }

    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    @Benchmark
    public int FactoredScale_Int_Dynamic(ScaleState state) {
        int value = 0;
        int[] data = state.data;
        int factor = state.randomFactor();
        for (int i = 0; i < data.length; ++i) {
            value += data[i];
        }
        return factor * value;
    }

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
FactoredScale_Int_Dynamic thrpt 1 10 26.100439 0.340069 ops/ms 100000
FactoredScale_Int_Dynamic thrpt 1 10 1.918011 0.297925 ops/ms 1000000
Scale_Int_Dynamic thrpt 1 10 30.219809 2.977389 ops/ms 100000
Scale_Int_Dynamic thrpt 1 10 2.314159 0.378442 ops/ms 1000000

Far from seeking to exploit distributivity to reduce the number of multiplication instructions, it seems to almost embrace the extraneous operations as metadata to drive optimisations. The assembly for Scale_Int_Dynamic confirms this (it shows vectorised multiplication, not shifts, within the loop):


  0x000001f5ca2fa200: vmovdqu ymm0,ymmword ptr [r13+r14*4+10h]
  0x000001f5ca2fa207: vpmulld ymm11,ymm0,ymm2   
  0x000001f5ca2fa20c: movsxd  r10,r14d          
  0x000001f5ca2fa20f: vmovdqu ymm0,ymmword ptr [r13+r10*4+30h]
  0x000001f5ca2fa216: vmovdqu ymm1,ymmword ptr [r13+r10*4+0f0h]
  0x000001f5ca2fa220: vmovdqu ymm3,ymmword ptr [r13+r10*4+50h]
  0x000001f5ca2fa227: vmovdqu ymm7,ymmword ptr [r13+r10*4+70h]
  0x000001f5ca2fa22e: vmovdqu ymm6,ymmword ptr [r13+r10*4+90h]
  0x000001f5ca2fa238: vmovdqu ymm5,ymmword ptr [r13+r10*4+0b0h]
  0x000001f5ca2fa242: vmovdqu ymm4,ymmword ptr [r13+r10*4+0d0h]
  0x000001f5ca2fa24c: vpmulld ymm9,ymm0,ymm2    
  0x000001f5ca2fa251: vpmulld ymm4,ymm4,ymm2    
  0x000001f5ca2fa256: vpmulld ymm5,ymm5,ymm2    
  0x000001f5ca2fa25b: vpmulld ymm6,ymm6,ymm2    
  0x000001f5ca2fa260: vpmulld ymm8,ymm7,ymm2    
  0x000001f5ca2fa265: vpmulld ymm10,ymm3,ymm2   
  0x000001f5ca2fa26a: vpmulld ymm3,ymm1,ymm2    
  0x000001f5ca2fa26f: vphaddd ymm1,ymm11,ymm11  
  0x000001f5ca2fa274: vphaddd ymm1,ymm1,ymm0    
  0x000001f5ca2fa279: vextracti128 xmm0,ymm1,1h 
  0x000001f5ca2fa27f: vpaddd  xmm1,xmm1,xmm0    
  0x000001f5ca2fa283: vmovd   xmm0,ebx          
  0x000001f5ca2fa287: vpaddd  xmm0,xmm0,xmm1    
  0x000001f5ca2fa28b: vmovd   r10d,xmm0         
  0x000001f5ca2fa290: vphaddd ymm1,ymm9,ymm9    
  0x000001f5ca2fa295: vphaddd ymm1,ymm1,ymm0    
  0x000001f5ca2fa29a: vextracti128 xmm0,ymm1,1h 
  0x000001f5ca2fa2a0: vpaddd  xmm1,xmm1,xmm0    
  0x000001f5ca2fa2a4: vmovd   xmm0,r10d         
  0x000001f5ca2fa2a9: vpaddd  xmm0,xmm0,xmm1    
  0x000001f5ca2fa2ad: vmovd   r11d,xmm0         
  0x000001f5ca2fa2b2: vphaddd ymm0,ymm10,ymm10  
  0x000001f5ca2fa2b7: vphaddd ymm0,ymm0,ymm1    
  0x000001f5ca2fa2bc: vextracti128 xmm1,ymm0,1h 
  0x000001f5ca2fa2c2: vpaddd  xmm0,xmm0,xmm1    
  0x000001f5ca2fa2c6: vmovd   xmm1,r11d         
  0x000001f5ca2fa2cb: vpaddd  xmm1,xmm1,xmm0    
  0x000001f5ca2fa2cf: vmovd   r10d,xmm1         
  0x000001f5ca2fa2d4: vphaddd ymm1,ymm8,ymm8    
  0x000001f5ca2fa2d9: vphaddd ymm1,ymm1,ymm0    
  0x000001f5ca2fa2de: vextracti128 xmm0,ymm1,1h 
  0x000001f5ca2fa2e4: vpaddd  xmm1,xmm1,xmm0    
  0x000001f5ca2fa2e8: vmovd   xmm0,r10d         
  0x000001f5ca2fa2ed: vpaddd  xmm0,xmm0,xmm1    
  0x000001f5ca2fa2f1: vmovd   r11d,xmm0         
  0x000001f5ca2fa2f6: vphaddd ymm0,ymm6,ymm6    
  0x000001f5ca2fa2fb: vphaddd ymm0,ymm0,ymm1    
  0x000001f5ca2fa300: vextracti128 xmm1,ymm0,1h 
  0x000001f5ca2fa306: vpaddd  xmm0,xmm0,xmm1    
  0x000001f5ca2fa30a: vmovd   xmm1,r11d         
  0x000001f5ca2fa30f: vpaddd  xmm1,xmm1,xmm0    
  0x000001f5ca2fa313: vmovd   r10d,xmm1         
  0x000001f5ca2fa318: vphaddd ymm1,ymm5,ymm5    
  0x000001f5ca2fa31d: vphaddd ymm1,ymm1,ymm0    
  0x000001f5ca2fa322: vextracti128 xmm0,ymm1,1h 
  0x000001f5ca2fa328: vpaddd  xmm1,xmm1,xmm0    
  0x000001f5ca2fa32c: vmovd   xmm0,r10d         
  0x000001f5ca2fa331: vpaddd  xmm0,xmm0,xmm1    
  0x000001f5ca2fa335: vmovd   r11d,xmm0         
  0x000001f5ca2fa33a: vphaddd ymm0,ymm4,ymm4    
  0x000001f5ca2fa33f: vphaddd ymm0,ymm0,ymm1    
  0x000001f5ca2fa344: vextracti128 xmm1,ymm0,1h 
  0x000001f5ca2fa34a: vpaddd  xmm0,xmm0,xmm1    
  0x000001f5ca2fa34e: vmovd   xmm1,r11d         
  0x000001f5ca2fa353: vpaddd  xmm1,xmm1,xmm0    
  0x000001f5ca2fa357: vmovd   r10d,xmm1         
  0x000001f5ca2fa35c: vphaddd ymm1,ymm3,ymm3    
  0x000001f5ca2fa361: vphaddd ymm1,ymm1,ymm7    
  0x000001f5ca2fa366: vextracti128 xmm7,ymm1,1h 
  0x000001f5ca2fa36c: vpaddd  xmm1,xmm1,xmm7   
  0x000001f5ca2fa370: vmovd   xmm7,r10d        
  0x000001f5ca2fa375: vpaddd  xmm7,xmm7,xmm1   
  0x000001f5ca2fa379: vmovd   ebx,xmm7         

There are two lessons to be learnt here. The first is that what you see is not what you get. The second is about the practical relevance of asymptotic analysis. If hierarchical cache renders asymptotic analysis bullshit (linear time but cache friendly algorithms can, and do, outperform logarithmic algorithms with cache misses), optimising compilers render the field practically irrelevant.