Co-locating Spark Partitions with HBase Regions

HBase scans can be accelerated if they start and stop on a single region server. IO costs can be reduced further if the scan is executed on the same machine as the region server. This article is about extending the Spark RDD abstraction to load an RDD from an HBase table so each partition is co-located with a region server. This pattern could be adopted to read data into Spark from other sharded data stores, whenever there is a metadata protocol available to dictate partitioning.

The strategy involves creating a custom implementation of the Spark class RDD, which understands how to create partitions from metadata about HBase regions. To read data from HBase, we want to execute a scan on a single region server, and we want to execute on the same machine as the region to minimise IO. Therefore we need the start key, stop key, and hostname for each region associated with each Spark partition.


public class HBasePartition implements Partition {

  private final String regionHostname;
  private final int partitionIndex;
  private final byte[] start;
  private final byte[] stop;

  public HBasePartition(String regionHostname, int partitionIndex, byte[] start, byte[] stop) {
    this.regionHostname = regionHostname;
    this.partitionIndex = partitionIndex;
    this.start = start;
    this.stop = stop;
  }

  public String getRegionHostname() {
    return regionHostname;
  }

  public byte[] getStart() {
    return start;
  }

  public byte[] getStop() {
    return stop;
  }

  @Override
  public int index() {
    return partitionIndex;
  }
}

The HBase interface RegionLocator, which can be obtained from a Connection instance, can be used to build an array of HBasePartitions. For efficiency, each region is checked against the supplied start and stop keys, and any region whose key range does not overlap the scan is skipped entirely.


public class HBasePartitioner implements Serializable {

  public Partition[] getPartitions(byte[] table, byte[] start, byte[] stop) {
    try(RegionLocator regionLocator = ConnectionFactory.createConnection().getRegionLocator(TableName.valueOf(table))) {
      List<HRegionLocation> regionLocations = regionLocator.getAllRegionLocations();
      int regionCount = regionLocations.size();
      List<Partition> partitions = Lists.newArrayListWithExpectedSize(regionCount);
      int partition = 0;
      for(HRegionLocation regionLocation : regionLocations) {
        HRegionInfo regionInfo = regionLocation.getRegionInfo();
        byte[] regionStart = regionInfo.getStartKey();
        byte[] regionStop = regionInfo.getEndKey();
        if(!skipRegion(start, stop, regionStart, regionStop)) {
          partitions.add(new HBasePartition(regionLocation.getHostname(),
                                            partition++,
                                            max(start, regionStart),
                                            min(stop, regionStop)));
        }
      }
      return partitions.toArray(new Partition[partition]);
    }
    catch (IOException e) {
      throw new RuntimeException("Could not create HBase region partitions", e);
    }
  }

  private static boolean skipRegion(byte[] scanStart, byte[] scanStop, byte[] regionStart, byte[] regionStop) {
    // check scan starts before region stops, and that the scan stops before the region starts
    return min(scanStart, regionStop) == regionStop || max(scanStop, regionStart) == regionStart;
  }

  private static byte[] min(byte[] left, byte[] right) {
    if(left.length == 0) {
      return left;
    }
    if(right.length == 0) {
      return right;
    }
    return Bytes.compareTo(left, right) < 0 ? left : right;   }   private static byte[] max(byte[] left, byte[] right) {     if(left.length == 0) {       return right;     }     if(right.length == 0) {       return left;     }     return Bytes.compareTo(left, right) >= 0 ? left : right;
  }
}

Finally, we can implement an RDD specialised for executing HBasePartitions. We want to exploit the ability to choose or influence where each partition is executed, so we need access to the Scala RDD method getPreferredLocations. This method is not available on JavaRDD, so we are forced to do some Scala conversions. The Scala/Java conversion work is quite tedious, but it is necessary when accessing low-level features from a Java-only project.


public class HBaseRDD<T> extends RDD<T> {

  private static <T> ClassTag<T> createClassTag(Class>T> klass) {
    return scala.reflect.ClassTag$.MODULE$.apply(klass);
  }

  private final HBasePartitioner partitioner;
  private final String tableName;
  private final byte[] startKey;
  private final byte[] stopKey;
  private final Function<Result, T> mapper;

  public HBaseRDD(SparkContext sparkContext,
                  Class<T> klass,
                  HBasePartitioner partitioner,
                  String tableName,
                  byte[] startKey,
                  byte[] stopKey,
                  Function<Result, T> mapper) {
    super(new EmptyRDD<>(sparkContext, createClassTag(klass)), createClassTag(klass));
    this.partitioner = partitioner;
    this.tableName = tableName;
    this.startKey = startKey;
    this.stopKey = stopKey;
    this.mapper = mapper;
  }

  @Override
  public Iterator<T> compute(Partition split, TaskContext context) {
    HBasePartition partition = (HBasePartition)split;
    try(Connection connection = ConnectionFactory.createConnection()) {
      Scan scan = new Scan()
                      .setStartRow(partition.getStart())
                      .setStopRow(partition.getStop())
                      .setCacheBlocks(false);
      Table table = connection.getTable(TableName.valueOf(tableName));
      ResultScanner scanner = table.getScanner(scan);
      return JavaConversions.asScalaIterator(
              StreamSupport.stream(scanner.spliterator(), false).map(mapper).iterator()
      );
    }
    catch (IOException e) {
      throw new RuntimeException("Region scan failed", e);
    }
  }

  @Override
  public Seq<String> getPreferredLocations(Partition split) {
    Set<String> locations = ImmutableSet.of(((HBasePartition)split).getRegionHostname());
    return JavaConversions.asScalaSet(locations).toSeq();
  }

  @Override
  public Partition[] getPartitions() {
    return partitioner.getPartitions(Bytes.toBytes(tableName), startKey, stopKey);
  }
}

As far as the interface of this class is concerned, it’s just normal Java, so it can be used from a more Java-centric Spark project, despite using some Scala APIs under the hood. We could achieve similar results with mapPartitions, but would have less control over partitioning and co-location.

Posted on

HBase Connection Management

I have recently built several web applications using Apache HBase as a backend data store. This article addresses some of the design concerns and the approaches I took to manage HBase connections efficiently.

One of the first things I noticed about the HBase client API was how long it takes to create the connection. HBase connection creation is effectively Zookeeper based service discovery, the end result being a client which knows where all the region servers are, and which region server is serving which key space. This operation is expensive and needs to be minimised.

At first I only created the connection once, when I started the web application. This is very simple and is fine for most use cases.


public static void main(String[] args) throws Exception {
        Configuration configuration = HBaseConfiguration.create();
        Connection connection = ConnectionFactory.createConnection(configuration);
}

This approach is great unless there is the requirement to proxy your end user when querying HBase. If Apache Ranger is enabled on your HBase cluster, proxying your users allows it to apply user specific authorisation to the query, rather than to your web application service user. This poses a few constraints: the most relevant being that you need to create a connection per user so you can’t just connect when you start your application any more.

Proxy Users

I needed to proxy users and minimise connection creation, so I built a connection pool class which, given a user principal, creates a connection as the user. I used Guava’s loading cache to handle cache eviction and concurrency. Guava’s cache also has a very useful eviction listener, which allows the connection to be closed when evicted from the cache.

In order to get the user proxying working, the UserGroupInformation for the web application service principal itself is required (see here), and you need to have successfully authenticated your user (I used SPNego to do this). The Hadoop class UserProvider is then used to create a proxy user. Your web application service principal also needs to be configured as a proxying user in core-site.xml, which you can manage via tools like Ambari.


public class ConnectionPool implements Closeable {

  private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionPool.class);
  private final Configuration configuration;
  private final LoadingCache<String, Connection> cache;
  private final ExecutorService threadPool;
  private final UserProvider userProvider;
  private volatile boolean closed = false;
  private final UserGroupInformation loginUser;

  public ConnectionPool(Configuration configuration, UserGroupInformation loginUser) {
    this.loginUser = loginUser;
    this.configuration = configuration;
    this.userProvider = UserProvider.instantiate(configuration);
    this.threadPool = Executors.newFixedThreadPool(50, new ThreadFactoryBuilder().setNameFormat("hbase-client-connection-pool").build());
    this.cache = createCache();
  }

  public Connection getConnection(Principal principal) throws IOException {
    return cache.getUnchecked(principal.getName());
  }

  @Override
  public void close() throws IOException {
    if(!closed) {
      closed = true;
      cache.invalidateAll();
      cache.cleanUp();
      connectionThreadPool.shutdown();
    }
  }

  private Connection createConnection(String userName) throws IOException {
      UserGroupInformation proxyUserGroupInformation = UserGroupInformation.createProxyUser(userName, loginUser);
      return ConnectionFactory.createConnection(configuration, threadPool, userProvider.create(proxyUserGroupInformation));
  }

  private LoadingCache<String, Connection> createCache() {
    return CacheBuilder.newBuilder()
                       .expireAfterAccess(10, TimeUnit.MINUTES)
            .<String, Connection>removalListener(eviction -> {
              Connection connection = eviction.getValue();
              if(null != connection) {
                try {
                  connection.close();
                } catch (IOException e) {
                  LOGGER.error("Connection could not be closed for user=" + eviction.getKey(), e);
                }
              }
            })
            .maximumSize(100)
            .build(new CacheLoader<String, Connection>() {
              @Override
              public Connection load(String userName) throws Exception {
                LOGGER.info("Create connection for user={}", userName);
                return createConnection(userName);
              }
            });
  }
}

One drawback of this approach is that the user experiences a slow response the first time they query the server, or any time after their connection has been evicted from the cache. They will also observe a lag if you are sharding your application behind a load balancer without sticky sessions: with a round-robin strategy, connection creation costs are incurred the first time each user is routed to each instance.

Posted on