Multiplying Matrices, Fast and Slow

I recently read a very interesting blog post about exposing Intel SIMD intrinsics via a fork of the Scala compiler (scala-virtualized), which reports multiplicative improvements in throughput over HotSpot JIT compiled code. The academic paper (SIMD Intrinsics on Managed Language Runtimes), which has been accepted at CGO 2018, proposes a powerful alternative to the traditional JVM approach of pairing dumb programmers with a (hopefully) smart JIT compiler. Lightweight Modular Staging (LMS) allows the generation of an executable binary from a high level representation: handcrafted representations of vectorised algorithms, written in a dialect of Scala, can be compiled natively and later invoked with a single JNI call. This approach bypasses C2 without incurring excessive JNI costs. The freely available benchmarks can easily be run to reproduce the results in the paper, which is an achievement in itself, but some of the Java implementations used as baselines look less efficient than they could be. This post is about improving the efficiency of the Java matrix multiplication the LMS generated code is benchmarked against. Despite hitting edge cases where autovectorisation fails, I find it is possible to get performance comparable to LMS with plain Java (and a JDK upgrade).

Two implementations of Java matrix multiplication are provided in the NGen benchmarks: JMMM.baseline, a naive and cache-unfriendly matrix multiplication, and JMMM.blocked, which is supplied as an improvement. JMMM.blocked is something of a local maximum because it does manual loop unrolling: this actually removes the trigger for autovectorisation analysis. I provide a simple and cache-efficient Java implementation (with the same asymptotic complexity; the improvement is purely technical) and benchmark these implementations separately on JDK8 and the soon-to-be-released JDK10.

public void fast(float[] a, float[] b, float[] c, int n) {
  int in = 0;
  for (int i = 0; i < n; ++i) {
    int kn = 0;
    for (int k = 0; k < n; ++k) {
      float aik = a[in + k];
      for (int j = 0; j < n; ++j) {
        c[in + j] += aik * b[kn + j];
      }
      kn += n;
    }
    in += n;
  }
}

With JDK 1.8.0_131, the “fast” implementation is only about 2x faster than the blocked algorithm; this is nowhere near fast enough to match LMS. In fact, LMS does a lot better than 5x blocked (6x-8x) on my Skylake laptop at 2.6GHz, and performs between 2x and 4x better than the improved implementation. Flops/cycle is calculated as 2 * n^3 * throughput (invocations per second) / clock frequency (Hz).
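
As a sanity check on the numbers in the tables, here is that calculation as a snippet (the method name is mine; 2.6GHz is my laptop's clock):

static double flopsPerCycle(double opsPerSecond, int n, double frequencyHz) {
  // an n x n matrix multiplication performs 2 * n^3 floating point operations:
  // one multiply and one add per term of each of the n^2 dot products
  return 2.0 * n * n * n * opsPerSecond / frequencyHz;
}
// e.g. the JDK10 "fast" result further down: flopsPerCycle(1.66255, 1024, 2.6e9) ≈ 1.373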

====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.4994459272
          32 | 1.0666533335
          64 | 0.9429120397
         128 | 0.9692385519
         192 | 0.9796619688
         256 | 1.0141446247
         320 | 0.9894415771
         384 | 1.0046245750
         448 | 1.0221353392
         512 | 0.9943527764
         576 | 0.9952093603
         640 | 0.9854689714
         704 | 0.9947153752
         768 | 1.0197765248
         832 | 1.0479691069
         896 | 1.0060121097
         960 | 0.9937347412
        1024 | 0.9056494897
====================================================

====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.2500390686
          32 | 3.9999921875
          64 | 4.1626523901
         128 | 4.4618695374
         192 | 3.9598982956
         256 | 4.3737341517
         320 | 4.2412225389
         384 | 3.9640163416
         448 | 4.0957167537
         512 | 3.3801071278
         576 | 4.1869326167
         640 | 3.8225244883
         704 | 3.8648224140
         768 | 3.5240611589
         832 | 3.7941562681
         896 | 3.1735179981
         960 | 2.5856903789
        1024 | 1.7817152313
====================================================

====================================================
Benchmarking MMM.jMMM.blocked (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.3333854248
          32 | 0.6336670915
          64 | 0.5733484649
         128 | 0.5987433798
         192 | 0.5819900921
         256 | 0.5473562109
         320 | 0.5623263520
         384 | 0.5583823292
         448 | 0.5657882256
         512 | 0.5430879470
         576 | 0.5269635678
         640 | 0.5595204791
         704 | 0.5297557807
         768 | 0.5493631388
         832 | 0.5471832673
         896 | 0.4769554752
         960 | 0.4985080443
        1024 | 0.4014589400
====================================================

JDK10 is about to be released, so it's worth looking at the effect of recent improvements to C2, including better use of AVX2 and support for vectorised FMA. Since LMS depends on scala-virtualized, which currently only supports Scala 2.11, the LMS implementation cannot be run with a more recent JDK, so its performance on JDK10 can only be extrapolated. Since its raison d'ĂȘtre is to bypass C2, it can reasonably be assumed to be insulated from JVM performance improvements (or regressions). Measurements of floating point operations per cycle provide a sensible comparison, in any case.

Moving away from ScalaMeter, I created a JMH benchmark to see how matrix multiplication behaves in JDK10.

@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
public class MMM {

  @Param({"8", "32", "64", "128", "192", "256", "320", "384", "448", "512" , "576", "640", "704", "768", "832", "896", "960", "1024"})
  int size;

  private float[] a;
  private float[] b;
  private float[] c;

  @Setup(Level.Trial)
  public void init() {
    a = DataUtil.createFloatArray(size * size);
    b = DataUtil.createFloatArray(size * size);
    c = new float[size * size];
  }

  @Benchmark
  public void fast(Blackhole bh) {
    fast(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void baseline(Blackhole bh) {
    baseline(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void blocked(Blackhole bh) {
    blocked(a, b, c, size);
    bh.consume(c);
  }

  //
  // Baseline implementation of a Matrix-Matrix-Multiplication
  //
  public void baseline (float[] a, float[] b, float[] c, int n){
    for (int i = 0; i < n; i += 1) {
      for (int j = 0; j < n; j += 1) {
        float sum = 0.0f;
        for (int k = 0; k < n; k += 1) {
          sum += a[i * n + k] * b[k * n + j];
        }
        c[i * n + j] = sum;
      }
    }
  }

  //
  // Blocked version of MMM, reference implementation available at:
  // http://csapp.cs.cmu.edu/2e/waside/waside-blocking.pdf
  //
  public void blocked(float[] a, float[] b, float[] c, int n) {
    int BLOCK_SIZE = 8;
    for (int kk = 0; kk < n; kk += BLOCK_SIZE) {
      for (int jj = 0; jj < n; jj += BLOCK_SIZE) {
        for (int i = 0; i < n; i++) {
          for (int j = jj; j < jj + BLOCK_SIZE; ++j) {
            float sum = c[i * n + j];
            for (int k = kk; k < kk + BLOCK_SIZE; ++k) {
              sum += a[i * n + k] * b[k * n + j];
            }
            c[i * n + j] = sum;
          }
        }
      }
    }
  }

  public void fast(float[] a, float[] b, float[] c, int n) {
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        for (int j = 0; j < n; ++j) {
          c[in + j] = Math.fma(aik, b[kn + j], c[in + j]);
        }
        kn += n;
      }
      in += n;
    }
  }
}

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size Ratio to blocked Flops/Cycle
baseline thrpt 1 10 1228544.82 38793.17392 ops/s 8 1.061598336 0.483857652
baseline thrpt 1 10 22973.03402 1012.043446 ops/s 32 1.302266947 0.57906183
baseline thrpt 1 10 2943.088879 221.57475 ops/s 64 1.301414733 0.593471609
baseline thrpt 1 10 358.010135 9.342801 ops/s 128 1.292889618 0.577539747
baseline thrpt 1 10 105.758366 4.275503 ops/s 192 1.246415143 0.575804515
baseline thrpt 1 10 41.465557 1.112753 ops/s 256 1.430003946 0.535135851
baseline thrpt 1 10 20.479081 0.462547 ops/s 320 1.154267894 0.516198866
baseline thrpt 1 10 11.686685 0.263476 ops/s 384 1.186535349 0.509027985
baseline thrpt 1 10 7.344184 0.269656 ops/s 448 1.166421127 0.507965526
baseline thrpt 1 10 3.545153 0.108086 ops/s 512 0.81796657 0.366017216
baseline thrpt 1 10 3.789384 0.130934 ops/s 576 1.327168294 0.557048123
baseline thrpt 1 10 1.981957 0.040136 ops/s 640 1.020965271 0.399660104
baseline thrpt 1 10 1.76672 0.036386 ops/s 704 1.168272442 0.474179037
baseline thrpt 1 10 1.01026 0.049853 ops/s 768 0.845514112 0.352024966
baseline thrpt 1 10 1.115814 0.03803 ops/s 832 1.148752171 0.494331667
baseline thrpt 1 10 0.703561 0.110626 ops/s 896 0.938435436 0.389298235
baseline thrpt 1 10 0.629896 0.052448 ops/s 960 1.081741651 0.428685898
baseline thrpt 1 10 0.407772 0.019079 ops/s 1024 1.025356561 0.336801424
blocked thrpt 1 10 1157259.558 49097.48711 ops/s 8 1 0.455782226
blocked thrpt 1 10 17640.8025 1226.401298 ops/s 32 1 0.444656782
blocked thrpt 1 10 2261.453481 98.937035 ops/s 64 1 0.456020355
blocked thrpt 1 10 276.906961 22.851857 ops/s 128 1 0.446704605
blocked thrpt 1 10 84.850033 4.441454 ops/s 192 1 0.461968485
blocked thrpt 1 10 28.996813 7.585551 ops/s 256 1 0.374219842
blocked thrpt 1 10 17.742052 0.627629 ops/s 320 1 0.447208892
blocked thrpt 1 10 9.84942 0.367603 ops/s 384 1 0.429003641
blocked thrpt 1 10 6.29634 0.402846 ops/s 448 1 0.435490676
blocked thrpt 1 10 4.334105 0.384849 ops/s 512 1 0.447472097
blocked thrpt 1 10 2.85524 0.199102 ops/s 576 1 0.419726816
blocked thrpt 1 10 1.941258 0.10915 ops/s 640 1 0.391453182
blocked thrpt 1 10 1.51225 0.076621 ops/s 704 1 0.40588053
blocked thrpt 1 10 1.194847 0.063147 ops/s 768 1 0.416344283
blocked thrpt 1 10 0.971327 0.040421 ops/s 832 1 0.430320551
blocked thrpt 1 10 0.749717 0.042997 ops/s 896 1 0.414837526
blocked thrpt 1 10 0.582298 0.016725 ops/s 960 1 0.39629231
blocked thrpt 1 10 0.397688 0.043639 ops/s 1024 1 0.328472491
fast thrpt 1 10 1869676.345 76416.50848 ops/s 8 1.615606743 0.736364837
fast thrpt 1 10 48485.47216 1301.926828 ops/s 32 2.748484496 1.222132271
fast thrpt 1 10 6431.341657 153.905413 ops/s 64 2.843897392 1.296875098
fast thrpt 1 10 840.601821 45.998723 ops/s 128 3.035683242 1.356053685
fast thrpt 1 10 260.386996 13.022418 ops/s 192 3.068790745 1.417684611
fast thrpt 1 10 107.895708 6.584674 ops/s 256 3.720950575 1.392453537
fast thrpt 1 10 56.245336 2.729061 ops/s 320 3.170170846 1.417728592
fast thrpt 1 10 32.917996 2.196624 ops/s 384 3.342125323 1.433783932
fast thrpt 1 10 20.960189 2.077684 ops/s 448 3.328948087 1.449725854
fast thrpt 1 10 14.005186 0.7839 ops/s 512 3.231390564 1.445957112
fast thrpt 1 10 8.827584 0.883654 ops/s 576 3.091713481 1.297675056
fast thrpt 1 10 7.455607 0.442882 ops/s 640 3.840605937 1.503417416
fast thrpt 1 10 5.322894 0.464362 ops/s 704 3.519850554 1.428638807
fast thrpt 1 10 4.308522 0.153846 ops/s 768 3.605919419 1.501303934
fast thrpt 1 10 3.375274 0.106715 ops/s 832 3.474910097 1.495325228
fast thrpt 1 10 2.320152 0.367881 ops/s 896 3.094703735 1.28379924
fast thrpt 1 10 2.057478 0.150198 ops/s 960 3.533376381 1.400249889
fast thrpt 1 10 1.66255 0.181116 ops/s 1024 4.180538513 1.3731919

Interestingly, the blocked algorithm is now the worst of the plain Java implementations. The code generated by C2 got a lot faster, but peaks at 1.5 flops/cycle, which still doesn't compete with LMS. Why? Taking a look at the assembly, it's clear that the autovectoriser choked on the array offsets and produced scalar code, operating on one float at a time in xmm registers, just like the implementations in the paper. I wasn't expecting this.

vmovss  xmm5,dword ptr [rdi+rcx*4+10h]
vfmadd231ss xmm5,xmm6,xmm2
vmovss  dword ptr [rdi+rcx*4+10h],xmm5

Is this the end of the story? No: with some hacks, and at the cost of an array allocation and a copy or two, autovectorisation can be tricked into working again to generate faster code:


    public void fast(float[] a, float[] b, float[] c, int n) {
        float[] bBuffer = new float[n];
        float[] cBuffer = new float[n];
        int in = 0;
        for (int i = 0; i < n; ++i) {
            int kn = 0;
            for (int k = 0; k < n; ++k) {
                float aik = a[in + k];
                System.arraycopy(b, kn, bBuffer, 0, n);
                saxpy(n, aik, bBuffer, cBuffer);
                kn += n;
            }
            System.arraycopy(cBuffer, 0, c, in, n); 
            Arrays.fill(cBuffer, 0f);
            in += n;
        }
    }

    private void saxpy(int n, float aik, float[] b, float[] c) {
        for (int i = 0; i < n; ++i) {
            c[i] += aik * b[i];
        }
    }

Adding this hack into the NGen benchmark (back in JDK 1.8.0_131) I get closer to the LMS generated code, and beat it beyond L3 cache residency (6MB). LMS is still faster when both matrices fit in L3 concurrently, but by percentage points rather than a multiple. The cost of the hacky array buffers gives the game up for small matrices.

====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.2500390686
          32 | 0.7710872405
          64 | 1.1302489072
         128 | 2.5113453810
         192 | 2.9525859816
         256 | 3.1180920385
         320 | 3.1081563593
         384 | 3.1458423577
         448 | 3.0493148252
         512 | 3.0551158263
         576 | 3.1430376938
         640 | 3.2169923048
         704 | 3.1026513283
         768 | 2.4190053777
         832 | 3.3358586705
         896 | 3.0755689237
         960 | 2.9996690697
        1024 | 2.2935654309
====================================================

====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 1.0001562744
          32 | 5.3330416826
          64 | 5.8180867784
         128 | 5.1717318641
         192 | 5.1639907462
         256 | 4.3418618628
         320 | 5.2536572701
         384 | 4.0801359215
         448 | 4.1337007093
         512 | 3.2678160754
         576 | 3.7973028890
         640 | 3.3557513664
         704 | 4.0103133240
         768 | 3.4188362575
         832 | 3.2189488327
         896 | 3.2316685219
         960 | 2.9985655539
        1024 | 1.7750946796
====================================================

With the benchmark below, I calculate flops/cycle with JDK10's improved autovectorisation.


  @Benchmark
  public void fastBuffered(Blackhole bh) {
    fastBuffered(a, b, c, size);
    bh.consume(c);
  }

  public void fastBuffered(float[] a, float[] b, float[] c, int n) {
    float[] bBuffer = new float[n];
    float[] cBuffer = new float[n];
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        System.arraycopy(b, kn, bBuffer, 0, n);
        saxpy(n, aik, bBuffer, cBuffer);
        kn += n;
      }
      System.arraycopy(cBuffer, 0, c, in, n);
      Arrays.fill(cBuffer, 0f);
      in += n;
    }
  }

  private void saxpy(int n, float aik, float[] b, float[] c) {
    for (int i = 0; i < n; ++i) {
      c[i] = Math.fma(aik, b[i], c[i]);
    }
  }

Just as in the modified NGen benchmark, this starts paying off once the matrices have 64 rows and columns. Finally, though it took a JDK upgrade and a hack, I breached 4 flops per cycle:

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size Flops / Cycle
fastBuffered thrpt 1 10 1047184.034 63532.95095 ops/s 8 0.412429404
fastBuffered thrpt 1 10 58373.56367 3239.615866 ops/s 32 1.471373026
fastBuffered thrpt 1 10 12099.41654 497.33988 ops/s 64 2.439838038
fastBuffered thrpt 1 10 2136.50264 105.038006 ops/s 128 3.446592911
fastBuffered thrpt 1 10 673.470622 102.577237 ops/s 192 3.666730488
fastBuffered thrpt 1 10 305.541519 25.959163 ops/s 256 3.943181586
fastBuffered thrpt 1 10 158.437372 6.708384 ops/s 320 3.993596774
fastBuffered thrpt 1 10 88.283718 7.58883 ops/s 384 3.845306266
fastBuffered thrpt 1 10 58.574507 4.248521 ops/s 448 4.051345968
fastBuffered thrpt 1 10 37.183635 4.360319 ops/s 512 3.839002314
fastBuffered thrpt 1 10 29.949884 0.63346 ops/s 576 4.40270151
fastBuffered thrpt 1 10 20.715833 4.175897 ops/s 640 4.177331789
fastBuffered thrpt 1 10 10.824837 0.902983 ops/s 704 2.905333492
fastBuffered thrpt 1 10 8.285254 1.438701 ops/s 768 2.886995686
fastBuffered thrpt 1 10 6.17029 0.746537 ops/s 832 2.733582608
fastBuffered thrpt 1 10 4.828872 1.316901 ops/s 896 2.671937962
fastBuffered thrpt 1 10 3.6343 1.293923 ops/s 960 2.473381573
fastBuffered thrpt 1 10 2.458296 0.171224 ops/s 1024 2.030442485

The code generated for the core of the loop looks better now:

vmovdqu ymm1,ymmword ptr [r13+r11*4+10h]
vfmadd231ps ymm1,ymm3,ymmword ptr [r14+r11*4+10h]
vmovdqu ymmword ptr [r13+r11*4+10h],ymm1                                               

These benchmark results can be compared on a line chart.

Given this improvement, it would be exciting to see how LMS can profit from JDK9 or JDK10 – does LMS provide the impetus to resume maintenance of scala-virtualized? L3 cache, which the LMS generated code seems to depend on for throughput, is typically shared between cores: a single thread rarely enjoys exclusive access. I would like to see benchmarks for the LMS generated code in the presence of concurrency.

Autovectorised FMA in JDK10

Fused-multiply-add (FMA) allows floating point expressions of the form a * x + b to be evaluated in a single instruction, which is useful for numerical linear algebra. Despite the obvious appeal of FMA, JVM implementors are rather constrained when it comes to floating point arithmetic, because Java programs are expected to be reproducible across versions and target architectures. FMA does not produce precisely the same result as the equivalent multiplication and addition instructions: the fused operation rounds once where the separate instructions round twice, so its use is a change in semantics rather than an optimisation, and the user must opt in. To the best of my knowledge, support for FMA was first proposed in 2000, along with reorderable floating point operations, which would have been activated by a fastfp keyword, but this proposal was withdrawn. In Java 9, the intrinsic Math.fma was introduced to provide access to FMA for the first time.
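
The difference is easy to demonstrate. In the snippet below (variable names are mine; this is just an illustration of the rounding semantics), the fused operation rounds once and recovers the term that two separately rounded operations throw away:

double a = 1.0 + Math.ulp(1.0);    // 1 + 2^-52, exactly representable
double p = a * a;                  // true product is 1 + 2^-51 + 2^-104; the 2^-104 term is rounded away
double separate = a * a - p;       // the multiplication rounds before the subtraction: 0.0
double fused = Math.fma(a, a, -p); // rounds once, recovering the exact residual: 2^-104 ≈ 4.93e-32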

DAXPY Benchmark

A good use case to evaluate Math.fma is DAXPY from the Basic Linear Algebra Subprograms (BLAS) library. The code below requires JDK9+ to compile.

@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
public class DAXPY {
  
  double s;

  @Setup(Level.Invocation)
  public void init() {
    s = ThreadLocalRandom.current().nextDouble();
  }

  @Benchmark
  public void daxpyFMA(DoubleData state, Blackhole bh) {
    double[] a = state.data1;
    double[] b = state.data2;
    for (int i = 0; i < a.length; ++i) {
      a[i] = Math.fma(s, b[i], a[i]);
    }
    bh.consume(a);
  }

  @Benchmark
  public void daxpy(DoubleData state, Blackhole bh) {
    double[] a = state.data1;
    double[] b = state.data2;
    for (int i = 0; i < a.length; ++i) {
      a[i] += s * b[i];
    }
    bh.consume(a);
  }
}
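
The benchmark references a DoubleData state class which isn't shown in the post; a minimal sketch consistent with how it is used (the initialisation details are my assumption) might look like this:

@State(Scope.Thread)
public class DoubleData {

  @Param({"1000", "10000", "100000", "1000000"})
  int size;

  double[] data1;
  double[] data2;

  @Setup(Level.Trial)
  public void init() {
    // two random vectors of the parametrised length
    data1 = ThreadLocalRandom.current().doubles(size).toArray();
    data2 = ThreadLocalRandom.current().doubles(size).toArray();
  }
}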

Running this benchmark with Java 9, you may wonder why you bothered because the code is actually slower.

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
daxpy thrpt 1 10 25.011242 2.259007 ops/ms 100000
daxpy thrpt 1 10 0.706180 0.046146 ops/ms 1000000
daxpyFMA thrpt 1 10 15.334652 0.271946 ops/ms 100000
daxpyFMA thrpt 1 10 0.623838 0.018041 ops/ms 1000000

This is because using Math.fma disables autovectorisation here. Looking at the PrintAssembly output, the naive daxpy routine exploits AVX2, whereas daxpyFMA reverts to scalar FMA instructions on xmm registers.

// daxpy routine, code taken from main vectorised loop
vmovdqu ymm1,ymmword ptr [r10+rdx*8+10h]
vmulpd  ymm1,ymm1,ymm2
vaddpd  ymm1,ymm1,ymmword ptr [r8+rdx*8+10h]
vmovdqu ymmword ptr [r8+rdx*8+10h],ymm1

// daxpyFMA routine
vmovsd  xmm2,qword ptr [rcx+r13*8+10h]
vfmadd231sd xmm2,xmm0,xmm1
vmovsd  qword ptr [rcx+r13*8+10h],xmm2

Not to worry: this seems to have been fixed in JDK 10. Since Java 10's release is around the corner, there are early access builds available for all platforms. Rerunning this benchmark, FMA no longer incurs a cost, but nor does it bring the performance boost some might expect. The benefit is reduced floating point error, because the total number of roundings is halved.

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
daxpy thrpt 1 10 2582.363228 116.637400 ops/ms 1000
daxpy thrpt 1 10 405.904377 32.364782 ops/ms 10000
daxpy thrpt 1 10 25.210111 1.671794 ops/ms 100000
daxpy thrpt 1 10 0.608660 0.112512 ops/ms 1000000
daxpyFMA thrpt 1 10 2650.264580 211.342407 ops/ms 1000
daxpyFMA thrpt 1 10 389.274693 43.567450 ops/ms 10000
daxpyFMA thrpt 1 10 24.941172 2.393358 ops/ms 100000
daxpyFMA thrpt 1 10 0.671310 0.158470 ops/ms 1000000

// vectorised daxpyFMA routine, code taken from main loop (you can still see the old code in pre/post loops)
vmovdqu ymm0,ymmword ptr [r9+r13*8+10h]
vfmadd231pd ymm0,ymm1,ymmword ptr [rbx+r13*8+10h]
vmovdqu ymmword ptr [r9+r13*8+10h],ymm0

Spliterator Characteristics and Performance

The streams API has been around for a while now, and I'm a big fan of it. It allows for a clean declarative programming style, which permits various optimisations to occur, and keeps the pastafarians at bay. I also think Stream is the perfect abstraction for data interchange across API boundaries. This is partly because a Stream is lazy, meaning you don't need to pay for consumption until you actually need to, and partly because a Stream can only be used once, so there can be no ambiguity about ownership. If you supply a Stream to an API, you must expect that it has been used, and so must discard it. This almost entirely eradicates defensive copies and can mean that no intermediate data structures need ever exist. Despite my enthusiasm for this abstraction, there's some weirdness in this API when you scratch beneath the surface.

I wanted to find a way to quickly run length encode an IntStream and found it difficult to make the code as fast as I'd like it to be. It's too slow because it's necessary to inspect each int, even when there is enough context available to potentially apply optimisations such as skipping over ranges. It's likely that I am experiencing the friction of treating Stream as a data interchange format, which wasn't one of its design goals, but this led me to investigate spliterator characteristics (there is no contiguous characteristic, which could speed up RLE greatly) and their relationship with performance.
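
For concreteness, a straightforward run length encoding over an IntStream (a sketch of the kind of code I mean, not the exact implementation I benchmarked) has to pull every element through an iterator:

// encodes the stream as a list of {value, run length} pairs
public static List<int[]> runLengthEncode(IntStream stream) {
    List<int[]> runs = new ArrayList<>();
    PrimitiveIterator.OfInt it = stream.iterator();
    if (!it.hasNext()) {
        return runs;
    }
    int current = it.nextInt();
    int length = 1;
    while (it.hasNext()) {
        int next = it.nextInt();
        if (next == current) {
            ++length; // extend the current run
        } else {
            runs.add(new int[] {current, length}); // close the run, start a new one
            current = next;
            length = 1;
        }
    }
    runs.add(new int[] {current, length});
    return runs;
}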

Spliterator Characteristics

Streams have spliterators, which control iteration and splitting behaviour. If you want to process a stream in parallel, it is the spliterator which dictates how to split the stream, if possible. There's more to a spliterator than parallel execution though, and each single threaded execution can be optimised based on its characteristics bit mask. The different characteristics are as follows, and a snippet for inspecting them programmatically comes after the list:

  ‱ ORDERED promises that there is an encounter order; for instance, trySplit is guaranteed to give a prefix of elements.
  ‱ DISTINCT promises that each element in the stream is unique.
  ‱ SORTED promises that the stream is already sorted.
  ‱ SIZED promises that the size of the stream is known. This is not true when the stream is generated by a call to iterate.
  ‱ NONNULL promises that no elements in the stream are null.
  ‱ IMMUTABLE promises that the underlying data will not change.
  ‱ CONCURRENT promises that the underlying data can be modified concurrently; must not also be IMMUTABLE.
  ‱ SUBSIZED promises that the sizes of splits are known; must also be SIZED.
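
A minimal snippet for checking the flags (the variable name is mine):

Spliterator.OfInt s = IntStream.range(0, 262144).spliterator();
System.out.println(s.hasCharacteristics(Spliterator.DISTINCT)); // true
System.out.println(s.hasCharacteristics(Spliterator.SORTED));   // true
System.out.println(s.hasCharacteristics(Spliterator.SIZED));    // true
System.out.println(s.estimateSize());                           // 262144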

There’s javadoc for all of these flags, which should be your point of reference, and you need to read it, because you wouldn’t guess based on relative performance. For instance, IntStream.range(inclusive, exclusive) creates a RangeIntSpliterator with the characteristics ORDERED | SIZED | SUBSIZED | IMMUTABLE | NONNULL | DISTINCT | SORTED. This means the stream has no duplicates, no nulls, is already sorted in natural order, its size is known, and it will be chunked deterministically. The data and the iteration order never change, and if we split it, we will always get the same first chunk. So these code snippets should have virtually the same performance:


    @Benchmark
    public long countRange() {
        return IntStream.range(0, size).count();
    }

    @Benchmark
    public long countRangeDistinct() {
        return IntStream.range(0, size).distinct().count();
    }

This is completely at odds with observations. Even though the elements are already distinct, and metadata exists to support this, requesting the distinct elements decimates performance.

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
countRange thrpt 1 10 49.465729 1.804123 ops/us 262144
countRangeDistinct thrpt 1 10 0.000395 0.000002 ops/us 262144

It turns out this is because IntStream.distinct has a one-size-fits-all implementation which completely ignores the DISTINCT characteristic, and goes ahead and boxes the entire range.

    // from java.util.stream.IntPipeline
    @Override
    public final IntStream distinct() {
        // While functional and quick to implement, this approach is not very efficient.
        // An efficient version requires an int-specific map/set implementation.
        return boxed().distinct().mapToInt(i -> i);
    }
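
A distinct that consulted the characteristics could skip the work entirely when the source already guarantees uniqueness. A hypothetical sketch, emphatically not JDK code:

public static IntStream distinct(IntStream stream) {
    Spliterator.OfInt s = stream.spliterator();
    if (s.hasCharacteristics(Spliterator.DISTINCT)) {
        return StreamSupport.intStream(s, false); // already distinct: pass through
    }
    return StreamSupport.intStream(s, false).distinct(); // fall back to the boxing implementation
}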

There is even more observable weirdness. If we wanted to calculate the sum of the first 1000 natural numbers, these two snippets should have the same performance, yet requesting what should be a redundant sort doubles the throughput.


    @Benchmark 
    public long headSum() {
        return IntStream.range(0, size).limit(1000).sum();
    }

    @Benchmark
    public long sortedHeadSum() {
        return IntStream.range(0, size).sorted().limit(1000).sum();
    }

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
headSum thrpt 1 10 0.209763 0.002478 ops/us 262144
sortedHeadSum thrpt 1 10 0.584227 0.006004 ops/us 262144

In fact, you would have a hard time finding a relationship between Spliterator characteristics and performance, but you can see cases of characteristics driving optimisations if you look hard enough, such as in IntStream.count, where the SIZED characteristic is used.

    // see java.util.stream.ReduceOps.makeIntCounting
    @Override
    public <P_IN> Long evaluateSequential(PipelineHelper<Integer> helper, Spliterator<P_IN> spliterator) {
        if (StreamOpFlag.SIZED.isKnown(helper.getStreamAndOpFlags()))
            return spliterator.getExactSizeIfKnown();
        return super.evaluateSequential(helper, spliterator);
    }

This is a measurably worthwhile optimisation, when benchmarked against the unsized spliterator created by IntStream.iterate:


    @Benchmark
    public long countIterator() {
        return IntStream.iterate(0, i -> i < size, i -> i + 1).count();
    }

    @Benchmark
    public long countRange() {
        return IntStream.range(0, size).count();
    }

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
countIterator thrpt 1 10 0.001198 0.001629 ops/us 262144
countRange thrpt 1 10 43.166065 4.628715 ops/us 262144

What about limit, which is supposed to be useful for speeding up streams by limiting the amount of work done? Unfortunately not: it can actually make things much worse. In SliceOps.flags, we see that it disables the SIZED flag:

    //see java.util.stream.SliceOps
    private static int flags(long limit) {
        return StreamOpFlag.NOT_SIZED | ((limit != -1) ? StreamOpFlag.IS_SHORT_CIRCUIT : 0);
    }

This has a significant effect on performance, as can be seen in the following benchmark:


    @Benchmark
    public long countRange() {
        return IntStream.range(0, size).count();
    }

    @Benchmark
    public long countHalfRange() {
        return IntStream.range(0, size).limit(size / 2).count();
    }

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: size
countHalfRange thrpt 1 10 0.003632 0.003363 ops/us 262144
countRange thrpt 1 10 44.859998 6.191411 ops/us 262144

It’s almost as if there were grand plans involving characteristic-based optimisation, and perhaps time ran out (IntStream.distinct has a very apologetic comment) or some ideas were better on paper than in reality. In any case, the characteristics aren’t as influential as you might expect. Given that the relationship between the existing characteristics and performance is flaky at best, it’s unlikely that a new one would get implemented, but I think a CONTIGUOUS characteristic is missing.

Don’t Fear the Builder

Writing a nontrivial piece of code requires a programmer to make countless tiny decisions, and to produce good quality code quickly, many of these decisions must be made subconsciously. But what happens if your instinct is mostly superstitious bullshit? I instinctively avoid allocating objects when I don’t have to, so would not voluntarily resort to using things like EqualsBuilder or HashCodeBuilder from Apache Commons. I feel dismayed whenever I am asked about the relationship between hashCode and equals in interviews for contracts: it indicates a lack of mutual respect, but what’s revealing is that enough people passing themselves off as developers don’t know this stuff to make the question worth asking. EqualsBuilder and HashCodeBuilder mean you don’t actually need to know: they make it safe to put any object in a HashSet, whoever wrote the class. But should you use these classes just because they protect the naive (and you from the naive)? Is there a runtime tax? Or is any perceived cost magically eradicated by the JVM? It’s time to benchmark my instinct.

Baseline

There are definitely no allocations in the code IntelliJ (or similar) would generate automatically for equals and hashCode. I will use this generated code as the baseline: if I can match or beat it with EqualsBuilder and HashCodeBuilder, I can dismiss my prejudice. Here’s the baseline class:


public class Baseline {
    
    private final String a;
    private final int b;
    private final String c;
    private final Instant d;

    public Baseline(String a, int b, String c, Instant d) {
        this.a = a;
        this.b = b;
        this.c = c;
        this.d = d;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;

        Baseline baseline = (Baseline) o;

        if (b != baseline.b) return false;
        if (!a.equals(baseline.a)) return false;
        if (!c.equals(baseline.c)) return false;
        return d.equals(baseline.d);
    }

    @Override
    public int hashCode() {
        int result = a.hashCode();
        result = 31 * result + b;
        result = 31 * result + c.hashCode();
        result = 31 * result + d.hashCode();
        return result;
    }
}

It’s not pretty, and it’s a real pain to keep generated code up to date when adding and removing fields (you can delete and regenerate it, but that creates merge noise), but there’s no obvious way to improve on this code. It’s correct and looks fast. To benchmark alternatives, I want to look at the effect on set membership queries against a HashSet of various sizes.

Builder Implementation

I will contrast this with the following implementation, which uses EqualsBuilder and HashCodeBuilder from Apache Commons (incidentally, also generated by IntelliJ):


public class Builder {

    private final String a;
    private final int b;
    private final String c;
    private final Instant d;

    public Builder(String a, int b, String c, Instant d) {
        this.a = a;
        this.b = b;
        this.c = c;
        this.d = d;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;

        if (o == null || getClass() != o.getClass()) return false;

        Builder builder = (Builder) o;

        return new EqualsBuilder()
                .append(b, builder.b)
                .append(a, builder.a)
                .append(c, builder.c)
                .append(d, builder.d)
                .isEquals();
    }

    @Override
    public int hashCode() {
        return new HashCodeBuilder(17, 37)
                .append(a)
                .append(b)
                .append(c)
                .append(d)
                .toHashCode();
    }
}

This code is a bit neater and easier to maintain, but does it have an observable cost?

HashSet Benchmark

For a parametrised range of HashSet sizes, I measure throughput for calls to contains. This requires that both hashCode and equals are called, the latter potentially several times. There is bias in this benchmark, because the implementation producing the best hash function will minimise the calls to equals, but I am willing to reward the better implementation rather than maniacally isolate any attributable allocation cost.

The code is simple, if a little repetitive. For each implementation, I make some random instances of my class, plus one instance that I can be sure isn’t in the set. I measure the cost of repeated invocations of the hashCode and equals methods by finding an object that is present, and separately by searching for, but failing to find, the missing object.


@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public class HashSetContainsBenchmark {

    @Param({"8", "96", "1024", "8192"})
    int setSize;

    private HashSet<Baseline> baselineSet;
    private Baseline presentBaseline;
    private Baseline missingBaseline;

    private HashSet<Builder> builderSet;
    private Builder presentBuilder;
    private Builder missingBuilder;

    @Setup(Level.Trial)
    public void setup() {
        setupBaselineState();
        setupBuilderState();
    }

    private void setupBaselineState() {
        baselineSet = new HashSet<>();
        for (int i = 0; i < setSize; ++i) {
            while(!baselineSet.add(newBaselineInstance()));
        }
        presentBaseline = baselineSet.iterator().next();
        do {
            missingBaseline = newBaselineInstance();
        } while (baselineSet.contains(missingBaseline));
    }

    private void setupBuilderState() {
        builderSet = new HashSet<>();
        for (int i = 0; i < setSize; ++i) {
            while(!builderSet.add(newBuilderInstance()));
        }
        presentBuilder = builderSet.iterator().next();
        do {
            missingBuilder = newBuilderInstance();
        } while (builderSet.contains(missingBuilder));
    }

    @Benchmark
    public boolean findPresentBaselineInstance() {
        return baselineSet.contains(presentBaseline);
    }

    @Benchmark
    public boolean findMissingBaselineInstance() {
        return baselineSet.contains(missingBaseline);
    }

    @Benchmark
    public boolean findPresentBuilderInstance() {
        return builderSet.contains(presentBuilder);
    }

    @Benchmark
    public boolean findMissingBuilderInstance() {
        return builderSet.contains(missingBuilder);
    }

    private static Baseline newBaselineInstance() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        return new Baseline(UUID.randomUUID().toString(),
                random.nextInt(),
                UUID.randomUUID().toString(),
                Instant.ofEpochMilli(random.nextLong(0, Instant.now().toEpochMilli())));
    }

    private static Builder newBuilderInstance() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        return new Builder(UUID.randomUUID().toString(),
                random.nextInt(),
                UUID.randomUUID().toString(),
                Instant.ofEpochMilli(random.nextLong(0, Instant.now().toEpochMilli())));
    }
}

Running this benchmark with org.apache.commons:commons-lang3:3.7 yields the throughput results below. There is an occasional slight throughput degradation with the builders, and the builder implementation is sometimes faster, but the difference isn’t large enough to worry about.

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: setSize
findMissingBaselineInstance thrpt 1 10 153.878772 7.993036 ops/us 8
findMissingBaselineInstance thrpt 1 10 135.835085 6.031588 ops/us 96
findMissingBaselineInstance thrpt 1 10 137.716073 25.731562 ops/us 1024
findMissingBaselineInstance thrpt 1 10 123.896741 15.484023 ops/us 8192
findMissingBuilderInstance thrpt 1 10 142.051528 4.177352 ops/us 8
findMissingBuilderInstance thrpt 1 10 132.326909 1.017351 ops/us 96
findMissingBuilderInstance thrpt 1 10 127.220440 8.680761 ops/us 1024
findMissingBuilderInstance thrpt 1 10 140.029701 6.272960 ops/us 8192
findPresentBaselineInstance thrpt 1 10 139.714236 1.626873 ops/us 8
findPresentBaselineInstance thrpt 1 10 140.244864 1.777298 ops/us 96
findPresentBaselineInstance thrpt 1 10 135.273760 7.587937 ops/us 1024
findPresentBaselineInstance thrpt 1 10 133.158555 5.101069 ops/us 8192
findPresentBuilderInstance thrpt 1 10 111.196443 18.804103 ops/us 8
findPresentBuilderInstance thrpt 1 10 124.008441 5.182294 ops/us 96
findPresentBuilderInstance thrpt 1 10 126.729750 4.286002 ops/us 1024
findPresentBuilderInstance thrpt 1 10 124.159285 5.287920 ops/us 8192

Where did the builder go?

It’s worth taking a look at the compiler output to find out what happened. Considering the case where the object isn’t found in the set, we can see that the code gets inlined, and the hash code must be quite good. With the args -XX:+PrintCompilation -XX:+UnlockDiagnosticVMOptions -XX:+PrintInlining enabled, it’s clear that the entire hash code generation gets inlined into the calling loop, and that equals is never executed (suggesting that there are no collisions to resolve here).

                              @ 8   java.util.HashSet::contains (9 bytes)   inline (hot)
                                @ 5   java.util.HashMap::containsKey (18 bytes)   inline (hot)
                                 \-> TypeProfile (17571/17571 counts) = java/util/HashMap
                                  @ 2   java.util.HashMap::hash (20 bytes)   inline (hot)
                                    @ 9   com.openkappa.simd.builder.Builder::hashCode (43 bytes)   inline (hot)
                                      @ 8   org.apache.commons.lang3.builder.HashCodeBuilder::<init> (60 bytes)   inline (hot)
                                        @ 1   java.lang.Object::<init> (1 bytes)   inline (hot)
                                        @ 26   org.apache.commons.lang3.Validate::isTrue (18 bytes)   unloaded signature classes
                                        @ 46   org.apache.commons.lang3.Validate::isTrue (18 bytes)   unloaded signature classes
                                      @ 15   org.apache.commons.lang3.builder.HashCodeBuilder::append (58 bytes)   inline (hot)
                                        @ 21   java.lang.Object::getClass (0 bytes)   (intrinsic)
                                        @ 24   java.lang.Class::isArray (0 bytes)   (intrinsic)
                                        @ 49   java.lang.String::hashCode (49 bytes)   inline (hot)
                                          @ 19   java.lang.String::isLatin1 (19 bytes)   inline (hot)
                                          @ 29   java.lang.StringLatin1::hashCode (42 bytes)   inline (hot)
                                      @ 22   org.apache.commons.lang3.builder.HashCodeBuilder::append (17 bytes)   inline (hot)
                                      @ 29   org.apache.commons.lang3.builder.HashCodeBuilder::append (58 bytes)   inline (hot)
                                        @ 21   java.lang.Object::getClass (0 bytes)   (intrinsic)
                                        @ 24   java.lang.Class::isArray (0 bytes)   (intrinsic)
                                        @ 49   java.lang.String::hashCode (49 bytes)   inline (hot)
                                          @ 19   java.lang.String::isLatin1 (19 bytes)   inline (hot)
                                          @ 29   java.lang.StringLatin1::hashCode (42 bytes)   inline (hot)
                                      @ 36   org.apache.commons.lang3.builder.HashCodeBuilder::append (58 bytes)   inline (hot)
                                        @ 21   java.lang.Object::getClass (0 bytes)   (intrinsic)
                                        @ 24   java.lang.Class::isArray (0 bytes)   (intrinsic)
                                        @ 49   java.time.Instant::hashCode (22 bytes)   inline (hot)
                                      @ 39   org.apache.commons.lang3.builder.HashCodeBuilder::toHashCode (5 bytes)   accessor
                                  @ 6   java.util.HashMap::getNode (148 bytes)   inline (hot)
                                    @ 59   com.openkappa.simd.builder.Builder::equals (84 bytes)   never executed
                                    @ 126   com.openkappa.simd.builder.Builder::equals (84 bytes)   never executed

So the code is clearly small enough to get inlined. But what about the builder itself, isn’t it allocated? Not if it doesn’t escape. On a debug JVM, it’s possible to observe the removal of the builder’s allocation with the JVM args -XX:+PrintEscapeAnalysis -XX:+PrintEliminateAllocations. However, on a production JVM, it can be observed indirectly by measuring the difference when escape analysis is disabled with -XX:-DoEscapeAnalysis.
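
If you want to reproduce this from within JMH, one way (my assumption about the setup; the post doesn’t say how the flag was applied) is to append the flag to the forked JVM’s arguments:

// disables escape analysis for this benchmark's forked JVM only
@Fork(value = 1, jvmArgsAppend = {"-XX:-DoEscapeAnalysis"})
@Benchmark
public boolean findPresentBuilderInstance() {
    return builderSet.contains(presentBuilder);
}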

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit Param: setSize
findMissingBaselineInstance thrpt 1 10 150.972758 5.095163 ops/us 8
findMissingBaselineInstance thrpt 1 10 140.236057 2.222751 ops/us 96
findMissingBaselineInstance thrpt 1 10 132.680494 5.225503 ops/us 1024
findMissingBaselineInstance thrpt 1 10 115.220385 4.232488 ops/us 8192
findMissingBuilderInstance thrpt 1 10 68.869355 4.479944 ops/us 8
findMissingBuilderInstance thrpt 1 10 67.786351 5.980353 ops/us 96
findMissingBuilderInstance thrpt 1 10 71.113062 3.057181 ops/us 1024
findMissingBuilderInstance thrpt 1 10 75.088839 1.592294 ops/us 8192
findPresentBaselineInstance thrpt 1 10 141.425675 2.630898 ops/us 8
findPresentBaselineInstance thrpt 1 10 142.636359 2.854795 ops/us 96
findPresentBaselineInstance thrpt 1 10 143.939199 0.475918 ops/us 1024
findPresentBaselineInstance thrpt 1 10 142.006635 3.098352 ops/us 8192
findPresentBuilderInstance thrpt 1 10 71.546971 6.152584 ops/us 8
findPresentBuilderInstance thrpt 1 10 62.451705 11.896730 ops/us 96
findPresentBuilderInstance thrpt 1 10 68.903442 3.263955 ops/us 1024
findPresentBuilderInstance thrpt 1 10 69.576038 4.047581 ops/us 8192

Without escape analysis, the baseline code hardly slows down, whereas the change in the builder benchmarks is stark. The difference of differences here isolates the cost of allocating the HashCodeBuilder on the heap. Note that with escape analysis enabled, the builder is not so much allocated on the stack as replaced by its fields (scalar replacement), which live in registers or on the stack. Escape analysis is enabled by default, and the key point is that even if code looks like it should be slower, it may be as good or better – it depends what the JIT does with it.

Beware Collection Factory Methods

I saw an interesting tweet referencing a GitHub issue where including an (in my view) unnecessary implementation of the List interface affected inlining decisions, causing a 20x degradation in throughput. Guava’s ImmutableList is my favourite class to seek and destroy because of the way it is often used – it tends to be associated with unnecessary copying where encapsulation would be a better solution. I had assumed the performance gains won by finding and deleting all the instances of ImmutableList were thanks to relieving the garbage collector from medieval torture. The performance degradation observed in the benchmark was caused by using ImmutableList, along with all its subclasses, alongside ArrayList, making call sites on List bimorphic at best and causing the JIT compiler to generate slower code. I may have inadvertently profited from better inlining in the past simply by removing as many ImmutableLists as possible!

This post doesn’t go into any details about the various mechanisms of method dispatch, and if you want to understand the impact of polymorphism on inlining, bookmark Aleksey Shipilev’s authoritative post and read it when you have some time to really concentrate.

Without resorting to using LinkedList, is it possible to contrive cases in Java 9 where performance is severely degraded by uses of Collections.unmodifiableList and the List.of factory methods? Along with ArrayList, these produce random access data structures, so any difference should highlight the potential performance gains inlining can give.

The methodology is very simple: I randomly vary the List implementation and plug it into the same algorithm. It is cruder than you would see in Aleksey Shipilev’s post because I’ve targeted only the worst case by creating equal bias between implementations. Aleksey demonstrates that inlining decisions are statistical and opportunistic (the JIT can guess and later deoptimise), and if 90% of your call sites dispatch to the same implementation, it doesn’t matter as much as when the choice is made uniformly. It will vary from application to application, but it could easily be as bad as the case I present if List is used polymorphically.

I created five benchmarks which produce the same number, the same way. Three of these benchmarks only ever call into a single implementation of List and will be inlined monomorphically; to avoid bias, their result is XOR’d with a call to ThreadLocalRandom.current().nextInt(3), because the other benchmarks need the random number to select a list. One benchmark (sumLength_Random2) randomly chooses between the List.of and ArrayList implementations, and another (sumLength_Random3) randomly chooses between all three for each invocation. The difference is stark: you can really screw up performance by making the methods on List megamorphic.

Benchmark Mode Threads Samples Score Score Error (99.9%) Unit
sumLength_ArrayList thrpt 1 10 55.785270 3.218552 ops/us
sumLength_Factory thrpt 1 10 58.565918 2.852415 ops/us
sumLength_Random2 thrpt 1 10 35.842255 0.684658 ops/us
sumLength_Random3 thrpt 1 10 11.177564 0.080164 ops/us
sumLength_Unmodifiable thrpt 1 10 51.776108 3.751297 ops/us


@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public class MegamorphicList {

    private List<String>[] strings;

    @Setup(Level.Trial)
    public void init() {
        strings = new List[]{getArrayList(6), getFactoryList6(), getUnModifiableList(6)};
    }

    @Benchmark
    public int sumLength_ArrayList() {
        List<String> list = strings[0];
        int blackhole = 0;
        for (int i = 0; i < list.size(); ++i) {
            blackhole += list.get(i).length();
        }
        return blackhole ^ ThreadLocalRandom.current().nextInt(3);
    }

    @Benchmark
    public int sumLength_Factory() {
        List<String> list = strings[1];
        int blackhole = 0;
        for (int i = 0; i < list.size(); ++i) {
            blackhole += list.get(i).length();
        }
        return blackhole ^ ThreadLocalRandom.current().nextInt(3);
    }

    @Benchmark
    public int sumLength_Unmodifiable() {
        List<String> list = strings[2];
        int blackhole = 0;
        for (int i = 0; i < list.size(); ++i) {
            blackhole += list.get(i).length();
        }
        return blackhole ^ ThreadLocalRandom.current().nextInt(3);
    }

    @Benchmark
    public int sumLength_Random2() {
        List<String> list = strings[ThreadLocalRandom.current().nextInt(2)];
        int blackhole = 0;
        for (int i = 0; i < list.size(); ++i) {
            blackhole += list.get(i).length();
        }
        return blackhole;
    }

    @Benchmark
    public int sumLength_Random3() {
        List<String> list = strings[ThreadLocalRandom.current().nextInt(3)];
        int blackhole = 0;
        for (int i = 0; i < list.size(); ++i) {
            blackhole += list.get(i).length();
        }
        return blackhole;
    }

    private List<String> getUnModifiableList(int size) {
        return Collections.unmodifiableList(getArrayList(size));
    }

    private List<String> getFactoryList6() {
        return List.of(randomString(),
                       randomString(),
                       randomString(),
                       randomString(),
                       randomString(),
                       randomString()
                );
    }

    private List<String> getArrayList(int size) {
        List<String> list = new ArrayList<>();
        for (int i = 0; i < size; ++i) {
            list.add(randomString());
        }
        return list;
    }

    private String randomString() {
        return new String(DataUtil.createByteArray(ThreadLocalRandom.current().nextInt(10, 20)));
    }

}

Since writing this post, I have been challenged on whether this result is due to failure to inline or not. This can be easily verified by setting the following JVM arguments to print compilation:

-XX:+PrintCompilation -XX:+UnlockDiagnosticVMOptions -XX:+PrintInlining

You will see the ArrayList and ListN get inlined quickly in isolation:

\-> TypeProfile (19810/19810 counts) = java/util/ArrayList
@ 27   java.util.ArrayList::get (15 bytes)   inline (hot)
...
\-> TypeProfile (363174/363174 counts) = java/util/ImmutableCollections$ListN
@ 24   java.util.ImmutableCollections$ListN::get (17 bytes)   inline (hot)

However, the call remains virtual (and not inlined) when three or more implementations are present:

@ 30   java.util.List::get (0 bytes)   virtual call

I didn’t even bother to use factory methods with different arity, because three is the magic number: HotSpot will inline monomorphic and bimorphic calls, but with three or more observed implementations the call site stays virtual. Syntactic sugar is nice, but use it with caution.