Elasticsearch: RestHighLevelClient impossible mocking

Created on 27 Mar 2019 · 7 Comments · Source: elastic/elasticsearch

Hi,

I'm trying to unit test some code using RestHighLevelClient.

I have:

    private synchronized void createIndex(String indexName) throws IOException {
        if (!client.exists(new GetRequest(indexName), RequestOptions.DEFAULT)) {
            CreateIndexRequest request = new CreateIndexRequest(indexName);
            request.mapping(indexType, ElasticSearchConstants.MAPPING_SOURCE);
            request.alias(new Alias(alias));
            request.settings(ElasticSearchConstants.MAPPING_SETTINGS, XContentType.JSON);
            client.indices().create(request, RequestOptions.DEFAULT);
        }
    }
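The difficulty is that `createIndex` talks to the concrete client directly. As a dependency-free illustration (not the Elasticsearch API), here is the same create-if-absent logic written against an interface seam. `IndexAdmin`, `IndexCreator` and `RecordingIndexAdmin` are hypothetical names made up for this sketch; with such a seam, a hand-written fake is enough to check whether `create` was called and with which arguments.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical seam: only the operations createIndex needs, as an interface.
interface IndexAdmin {
    boolean exists(String index);
    void create(String index, String alias);
}

// Same create-if-absent logic as above, but depending on the interface.
final class IndexCreator {
    private final IndexAdmin admin;
    private final String alias;

    IndexCreator(IndexAdmin admin, String alias) {
        this.admin = admin;
        this.alias = alias;
    }

    synchronized void createIndex(String indexName) {
        if (!admin.exists(indexName)) {
            admin.create(indexName, alias);
        }
    }
}

// Hand-written fake: tracks which indices "exist" and records every create call.
final class RecordingIndexAdmin implements IndexAdmin {
    final Set<String> existing = new HashSet<>();
    final List<String> createCalls = new ArrayList<>();

    @Override
    public boolean exists(String index) {
        return existing.contains(index);
    }

    @Override
    public void create(String index, String alias) {
        createCalls.add(index + "->" + alias);
        existing.add(index);
    }
}
```

Calling `createIndex("logs")` twice on a fresh fake should record exactly one `create` call, because the second call sees the index as existing.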

Now I want to verify that client.indices().create was called.

RestHighLevelClient is not final and I was able to mock it.

A mocked RestHighLevelClient returns null from indices() by default,
so of course we want to stub that:

when(elasticSearchClient.indices()).thenReturn(mock(IndicesClient.class));

BAM, problem: IndicesClient is a final class, so I can't mock it.

So I can't test whether create gets called, or with which parameters.

IndicesClient should not be a final class, or indices() should return an interface with your final implementation class hidden behind it.
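To illustrate why an interface-typed accessor would help: the JDK alone can build a recording test double for any interface with `java.lang.reflect.Proxy`, while `Proxy` rejects classes outright (and subclass-based mock makers cannot extend a final class). `Indices`, `FinalIndices` and `RecordingProxies` below are hypothetical stand-ins, not Elasticsearch types.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.List;

// Hypothetical interface that an accessor such as indices() could return.
interface Indices {
    boolean exists(String index);
    void create(String index);
}

// Hypothetical final implementation, hidden behind the interface.
final class FinalIndices implements Indices {
    @Override public boolean exists(String index) { return false; }
    @Override public void create(String index) { /* would call the cluster */ }
}

final class RecordingProxies {
    private RecordingProxies() {}

    // Returns a proxy that records "methodName[args]" for each call.
    // Only boolean, void and reference return types are handled in this sketch
    // (false for boolean, null otherwise). Proxy only accepts interfaces and
    // throws IllegalArgumentException when given a class.
    static <T> T recorder(Class<T> type, List<String> calls) {
        InvocationHandler handler = (proxy, method, args) -> {
            calls.add(method.getName() + Arrays.toString(args));
            return method.getReturnType() == boolean.class ? Boolean.FALSE : null;
        };
        Object proxy = Proxy.newProxyInstance(
                type.getClassLoader(), new Class<?>[] {type}, handler);
        return type.cast(proxy);
    }
}
```

A real mocking library adds stubbing and argument matchers on top, but the point stands: once the accessor's return type is an interface, intercepting and verifying calls needs no special tooling.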

This is pretty basic stuff for public APIs, guys.

Labels: :Core/Features/Java High Level REST Client, Core/Features

All 7 comments

Pinging @elastic/es-core-features

Same boat here. This makes it impossible to properly unit test the code: I have to wrap each client and hide it behind an interface that can then be mocked. Questionable design.

I ended up wrapping the client like this:

import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.get.MultiGetItemResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;

import java.util.List;
import java.util.Optional;

public interface DatabaseClient {

  boolean exists(String index, String id);

  BulkItemResponse[] bulk(BulkRequest request);

  Optional<String> get(String index, String id);

  MultiGetItemResponse[] multiGet(String index, List<String> ids);

  SearchResponse search(SearchRequest request);

  SearchResponse searchScroll(SearchScrollRequest request);

  void clearScroll(String scrollId);
}

Then I provide a Spring configuration that instantiates the wrapper:

import com.octoperf.database.client.api.DatabaseClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
class ElasticsearchClientConfig {

  @Bean
  DatabaseClient databaseClient(final RestHighLevelClient client) {
    return new ElasticsearchClient(
      client::exists,
      client::bulk,
      client::get,
      client::mget,
      client::search,
      client::scroll,
      client::clearScroll
    );
  }
}
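This configuration compiles because each `client::xxx` method reference is bound to vavr's `CheckedFunction2`, whose `apply` declares `throws Throwable`, so client methods that throw `IOException` fit; a `java.util.function.BiFunction` target would not accept them. The sketch below reproduces that mechanism without vavr or Elasticsearch: `CheckedBiFunction`, `FakeRestClient` and `CheckedRefs` are hypothetical stand-ins.

```java
import java.io.IOException;

// Hypothetical stand-in for vavr's CheckedFunction2: apply may throw any Throwable.
@FunctionalInterface
interface CheckedBiFunction<A, B, R> {
    R apply(A a, B b) throws Throwable;
}

// Hypothetical client whose method throws a checked exception, like the real client's.
final class FakeRestClient {
    boolean exists(String index, String id) throws IOException {
        if (index.isEmpty()) {
            throw new IOException("no index given");
        }
        return "known-id".equals(id);
    }
}

final class CheckedRefs {
    private CheckedRefs() {}

    // The method reference binds because apply() declares throws Throwable;
    // the boolean result is boxed to Boolean. A BiFunction<String, String, Boolean>
    // target would be rejected by the compiler since exists() throws IOException.
    static CheckedBiFunction<String, String, Boolean> existsFunction(FakeRestClient client) {
        return client::exists;
    }
}
```

The checked exception is not lost: it propagates out of `apply`, which is what lets the wrapper decide per call whether to log it, rethrow it, or fall back to a default.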

And the wrapper itself:

import com.octoperf.database.client.api.DatabaseClient;
import io.vavr.CheckedFunction2;
import lombok.AllArgsConstructor;
import lombok.NonNull;
import lombok.experimental.FieldDefaults;
import lombok.extern.slf4j.Slf4j;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.*;
import org.elasticsearch.action.search.*;
import org.elasticsearch.client.RequestOptions;

import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

import static io.vavr.control.Try.of;
import static io.vavr.control.Try.success;
import static lombok.AccessLevel.PACKAGE;
import static lombok.AccessLevel.PRIVATE;
import static org.elasticsearch.client.RequestOptions.DEFAULT;
import static org.elasticsearch.search.fetch.subphase.FetchSourceContext.DO_NOT_FETCH_SOURCE;
import static org.elasticsearch.search.fetch.subphase.FetchSourceContext.FETCH_SOURCE;

@Slf4j
@AllArgsConstructor(access = PACKAGE)
@FieldDefaults(level = PRIVATE, makeFinal = true)
final class ElasticsearchClient implements DatabaseClient {
  public static final Consumer<? super Throwable> LOG_ERROR = e -> log.error("ElasticsearchClient", e);

  @NonNull
  CheckedFunction2<GetRequest, RequestOptions, Boolean> exists;
  @NonNull
  CheckedFunction2<BulkRequest, RequestOptions, BulkResponse> bulk;
  @NonNull
  CheckedFunction2<GetRequest, RequestOptions, GetResponse> get;
  @NonNull
  CheckedFunction2<MultiGetRequest, RequestOptions, MultiGetResponse> multiGet;
  @NonNull
  CheckedFunction2<SearchRequest, RequestOptions, SearchResponse> search;
  @NonNull
  CheckedFunction2<SearchScrollRequest, RequestOptions, SearchResponse> searchScroll;
  @NonNull
  CheckedFunction2<ClearScrollRequest, RequestOptions, ClearScrollResponse> clearScroll;

  @Override
  public boolean exists(final String index, final String id) {
    return success(new GetRequest(index, id))
      .map(r -> r.fetchSourceContext(DO_NOT_FETCH_SOURCE))
      .mapTry(r -> exists.apply(r, DEFAULT))
      .onFailure(LOG_ERROR)
      .get();
  }

  @Override
  public BulkItemResponse[] bulk(final BulkRequest request) {
    return of(() -> bulk.apply(request, DEFAULT))
      .map(BulkResponse::getItems)
      .onFailure(LOG_ERROR)
      .getOrElse(new BulkItemResponse[0]);
  }

  @Override
  public Optional<String> get(final String index, final String id) {
    return success(new GetRequest(index, id))
      .map(r -> r.fetchSourceContext(FETCH_SOURCE))
      .mapTry(r -> get.apply(r, DEFAULT))
      .filter(GetResponse::isExists)
      .map(GetResponse::getSourceAsString)
      .onFailure(LOG_ERROR)
      .toJavaOptional();
  }

  @Override
  public MultiGetItemResponse[] multiGet(final String index,
                                         final List<String> ids) {
    final MultiGetRequest multi = new MultiGetRequest();
    ids.forEach(id -> multi.add(index, id));

    return success(multi)
      .mapTry(r -> multiGet.apply(r, DEFAULT))
      .map(MultiGetResponse::getResponses)
      .onFailure(LOG_ERROR)
      .getOrElse(new MultiGetItemResponse[0]);
  }

  @Override
  public SearchResponse search(final SearchRequest request) {
    return success(request).mapTry(r -> search.apply(r, DEFAULT)).get();
  }

  @Override
  public SearchResponse searchScroll(final SearchScrollRequest request) {
    return success(request).mapTry(r -> searchScroll.apply(r, DEFAULT)).get();
  }

  @Override
  public void clearScroll(final String scrollId) {
    final ClearScrollRequest request = new ClearScrollRequest();
    request.addScrollId(scrollId);

    of(() -> clearScroll.apply(request, DEFAULT)).onFailure(LOG_ERROR);
  }
}

This way, I can mock every single function injected into ElasticsearchClient. The RestHighLevelClient itself can easily be provided by instantiating it in JUnit.
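Because every client call is injected as a function, a test can pass lambdas that record their arguments and return canned values, with no `RestHighLevelClient` involved. The sketch below mirrors the spirit of the wrapper's `get` method using plain try/catch instead of vavr's `Try`; `GetFn`, `DocumentStore` and the string-based request are simplifications made up for illustration.

```java
import java.io.IOException;
import java.util.Optional;

// Hypothetical injected operation: (index, id) -> source JSON, or null if missing.
@FunctionalInterface
interface GetFn {
    String apply(String index, String id) throws IOException;
}

// Simplified wrapper in the spirit of ElasticsearchClient.get(...):
// a missing document or a failure both end up as Optional.empty().
final class DocumentStore {
    private final GetFn get;

    DocumentStore(GetFn get) {
        this.get = get;
    }

    Optional<String> get(String index, String id) {
        try {
            return Optional.ofNullable(get.apply(index, id));
        } catch (IOException e) {
            // Stand-in for LOG_ERROR in the original wrapper.
            System.err.println("DocumentStore: " + e.getMessage());
            return Optional.empty();
        }
    }
}
```

In a test, `new DocumentStore((index, id) -> { seen.add(index + "/" + id); return "{}"; })` both stubs the response and captures the arguments, which is the verification the original issue asked for.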

Thanks @jloisel for sharing your solution.

Sure, no problem. Of course, it would be much nicer if the client exposed only interfaces. Other classes, such as SearchRequest, also cause mocking issues because they are final; in that case, I instantiate the real object.

FWIW - I was able to use Mockito 2's mock-maker-inline to mock all the things. Here's a snippet of the body of a test that mocks every part used while scrolling.

    // Needs Mockito 2's inline mock maker: several of these types and methods are final.
    String testHitOneJson = "{...testjson}";
    String testScrollId = "unit-test-scroll-id";
    String testHitOneId = "hit-one-id";
    int testSize = 100;
    EsClientFactory clientFactory = mock(EsClientFactory.class);
    RestHighLevelClient restHighLevelClient = mock(RestHighLevelClient.class);
    SearchResponse firstSearchResponse = mock(SearchResponse.class);
    SearchHits firstSearchHits = mock(SearchHits.class);
    SearchHit firstSearchHitsHitOne = mock(SearchHit.class);
    ClearScrollResponse clearScrollResponse = mock(ClearScrollResponse.class);

    // First page: one hit.
    when(firstSearchResponse.getScrollId()).thenReturn(testScrollId);
    when(firstSearchHitsHitOne.getId()).thenReturn(testHitOneId);
    when(firstSearchHitsHitOne.getSourceAsString()).thenReturn(testHitOneJson);
    when(firstSearchHits.getHits()).thenReturn(new SearchHit[] {firstSearchHitsHitOne});
    when(firstSearchResponse.getHits()).thenReturn(firstSearchHits);

    // Second page: no hits, which ends the scroll.
    SearchResponse secondSearchResponse = mock(SearchResponse.class);
    SearchHits secondSearchHits = mock(SearchHits.class);

    when(secondSearchResponse.getScrollId()).thenReturn(testScrollId);
    when(secondSearchHits.getHits()).thenReturn(new SearchHit[] {});
    when(secondSearchResponse.getHits()).thenReturn(secondSearchHits);

    when(restHighLevelClient.search(any(SearchRequest.class), any(RequestOptions.class)))
        .thenReturn(firstSearchResponse);

    when(restHighLevelClient.scroll(any(SearchScrollRequest.class), any(RequestOptions.class)))
        .thenReturn(secondSearchResponse);

    when(clientFactory.buildClient()).thenReturn(restHighLevelClient);

    when(restHighLevelClient.clearScroll(
            any(ClearScrollRequest.class), any(RequestOptions.class)))
        .thenReturn(clearScrollResponse);

    doNothing().when(restHighLevelClient).close();

    // test things that use client

    // verify close and anything else needed
    verify(restHighLevelClient, times(1)).close();
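For context on the snippet above: in Mockito 2 the inline mock maker is opt-in, typically enabled by adding a file `src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker` whose only line is `mock-maker-inline` (or by depending on the `mockito-inline` artifact). The stubs model a scroll that returns one hit and then an empty page. The dependency-free sketch below shows the kind of code such a test exercises, with hand-written stubs in place of mocks; `Page`, `ScrollClient` and `ScrollDrainer` are hypothetical names, not Elasticsearch types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

// Hypothetical page of results: a scroll id plus the ids of the hits.
final class Page {
    final String scrollId;
    final List<String> hitIds;

    Page(String scrollId, List<String> hitIds) {
        this.scrollId = scrollId;
        this.hitIds = hitIds;
    }
}

// Hypothetical client surface used while scrolling; close() throws nothing.
interface ScrollClient extends AutoCloseable {
    Page search();
    Page scroll(String scrollId);
    void clearScroll(String scrollId);

    @Override
    void close();
}

final class ScrollDrainer {
    private ScrollDrainer() {}

    // Builds a client, follows the scroll until an empty page, then clears the
    // scroll context; try-with-resources guarantees close() is called once.
    static List<String> drainAll(Supplier<? extends ScrollClient> factory) {
        List<String> ids = new ArrayList<>();
        try (ScrollClient client = factory.get()) {
            Page page = client.search();
            while (!page.hitIds.isEmpty()) {
                ids.addAll(page.hitIds);
                page = client.scroll(page.scrollId);
            }
            client.clearScroll(page.scrollId);
        }
        return ids;
    }
}
```

With a stub whose `search()` returns one hit and whose `scroll(...)` returns an empty page, `drainAll` should return that single id and call `scroll`, `clearScroll` and `close` exactly once each, which matches what the Mockito test above stubs and verifies.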

Just an FYI: I don't think we will be doing this for now; see https://github.com/elastic/elasticsearch/issues/31065#issuecomment-396360557 for more information.
