Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,6 @@
*/
package org.opensearch.search.relevance;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.opensearch.action.support.ActionFilter;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
Expand Down Expand Up @@ -42,82 +34,93 @@
import org.opensearch.search.relevance.transformer.ResultTransformer;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfigurationFactory;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.pipeline.KendraRankingResponseProcessor;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParametersExtBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class SearchRelevancePlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin {

private OpenSearchClient openSearchClient;
private KendraHttpClient kendraClient;
private KendraIntelligentRanker kendraIntelligentRanker;
private OpenSearchClient openSearchClient;
private KendraHttpClient kendraClient;
private KendraIntelligentRanker kendraIntelligentRanker;
private KendraClientSettings kendraClientSettings;

private Collection<ResultTransformer> getAllResultTransformers() {
// Initialize and add other transformers here
return List.of(this.kendraIntelligentRanker);
}

private Collection<ResultTransformerConfigurationFactory> getResultTransformerConfigurationFactories() {
return List.of(KendraIntelligentRankingConfigurationFactory.INSTANCE);
}

@Override
public List<ActionFilter> getActionFilters() {
return Arrays.asList(new SearchActionFilter(getAllResultTransformers(), openSearchClient));
}

@Override
public List<Setting<?>> getSettings() {
// NOTE: cannot use kendraIntelligentRanker.getTransformerSettings because the object is not yet created
List<Setting<?>> allTransformerSettings = new ArrayList<>();
allTransformerSettings.addAll(KendraIntelligentRankerSettings.getAllSettings());
// Add settings for other transformers here
return allTransformerSettings;
}

private Collection<ResultTransformer> getAllResultTransformers() {
// Initialize and add other transformers here
return List.of(this.kendraIntelligentRanker);
}
@Override
public Collection<Object> createComponents(
Client client,
ClusterService clusterService,
ThreadPool threadPool,
ResourceWatcherService resourceWatcherService,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
Environment environment,
NodeEnvironment nodeEnvironment,
NamedWriteableRegistry namedWriteableRegistry,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
this.openSearchClient = new OpenSearchClient(client);
this.kendraClientSettings = KendraClientSettings.getClientSettings(environment.settings());
this.kendraClient = new KendraHttpClient(this.kendraClientSettings);
this.kendraIntelligentRanker = new KendraIntelligentRanker(this.kendraClient);

private Collection<ResultTransformerConfigurationFactory> getResultTransformerConfigurationFactories() {
return List.of(KendraIntelligentRankingConfigurationFactory.INSTANCE);
}

@Override
public List<ActionFilter> getActionFilters() {
return Arrays.asList(new SearchActionFilter(getAllResultTransformers(), openSearchClient));
}

@Override
public List<Setting<?>> getSettings() {
// NOTE: cannot use kendraIntelligentRanker.getTransformerSettings because the object is not yet created
List<Setting<?>> allTransformerSettings = new ArrayList<>();
allTransformerSettings.addAll(KendraIntelligentRankerSettings.getAllSettings());
// Add settings for other transformers here
return allTransformerSettings;
}

@Override
public Collection<Object> createComponents(
Client client,
ClusterService clusterService,
ThreadPool threadPool,
ResourceWatcherService resourceWatcherService,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
Environment environment,
NodeEnvironment nodeEnvironment,
NamedWriteableRegistry namedWriteableRegistry,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
this.openSearchClient = new OpenSearchClient(client);
this.kendraClient = new KendraHttpClient(KendraClientSettings.getClientSettings(environment.settings()));
this.kendraIntelligentRanker = new KendraIntelligentRanker(this.kendraClient);

return Arrays.asList(
this.openSearchClient,
this.kendraClient,
this.kendraIntelligentRanker
);
}
return Arrays.asList(
this.openSearchClient,
this.kendraClientSettings,
this.kendraClient,
this.kendraIntelligentRanker
);
}

@Override
public List<SearchExtSpec<?>> getSearchExts() {
Map<String, ResultTransformerConfigurationFactory> resultTransformerMap = getResultTransformerConfigurationFactories().stream()
.collect(Collectors.toMap(ResultTransformerConfigurationFactory::getName, i -> i));
return List.of(new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME,
input -> new SearchConfigurationExtBuilder(input, resultTransformerMap),
parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap)),
new SearchExtSpec<>(PersonalizeRequestParametersExtBuilder.NAME,
input -> new PersonalizeRequestParametersExtBuilder(input),
parser -> PersonalizeRequestParametersExtBuilder.parse(parser)));
}
@Override
public List<SearchExtSpec<?>> getSearchExts() {
Map<String, ResultTransformerConfigurationFactory> resultTransformerMap = getResultTransformerConfigurationFactories().stream()
.collect(Collectors.toMap(ResultTransformerConfigurationFactory::getName, i -> i));
return List.of(new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME,
input -> new SearchConfigurationExtBuilder(input, resultTransformerMap),
parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap)),
new SearchExtSpec<>(PersonalizeRequestParametersExtBuilder.NAME,
input -> new PersonalizeRequestParametersExtBuilder(input),
parser -> PersonalizeRequestParametersExtBuilder.parse(parser)));
}

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Processor.Parameters parameters) {
return Map.of(PersonalizeRankingResponseProcessor.TYPE,
new PersonalizeRankingResponseProcessor.Factory(
PersonalizeClientSettings.getClientSettings(parameters.env.settings())));
}
}
@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Processor.Parameters parameters) {
return Map.of(PersonalizeRankingResponseProcessor.TYPE, new PersonalizeRankingResponseProcessor.Factory(PersonalizeClientSettings.getClientSettings(parameters.env.settings())),
KendraRankingResponseProcessor.TYPE, new KendraRankingResponseProcessor.Factory(this.kendraClientSettings));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.relevance.transformer.kendraintelligentranking.pipeline;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.KendraIntelligentRanker;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraClientSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraHttpClient;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration;

import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants.KENDRA_DEFAULT_DOC_LIMIT;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
* This is a {@link SearchResponseProcessor} that applies kendra intelligence ranking
*/
public class KendraRankingResponseProcessor implements SearchResponseProcessor {
/**
* key to reference this processor type from a search pipeline
*/
public static final String TYPE = "kendra_ranking";
private final List<String> titleField;
private final List<String> bodyField;
private final int docLimit;
private final String tag;
private final String description;
private final KendraHttpClient kendraClient;

private static final Logger logger = LogManager.getLogger(KendraRankingResponseProcessor.class);

/**
* Constructor that apply configuration for kendra re-ranking
*
* @param tag processor tag
* @param description processor description
* @param titleField titleField applied to kendra re-ranking
* @param bodyField bodyField applied to kendra re-ranking
* @param inputDocLimit docLimit applied to kendra re-ranking
* @param kendraClient kendraClient to connect with kendra
*/
public KendraRankingResponseProcessor(String tag, String description, List<String> titleField, List<String> bodyField, Integer inputDocLimit, KendraHttpClient kendraClient) {
super();
this.titleField = titleField;
this.bodyField = bodyField;
this.tag = tag;
this.description = description;
this.kendraClient = kendraClient;
int docLimit;
if (inputDocLimit == null) {
docLimit = KENDRA_DEFAULT_DOC_LIMIT;
} else {
docLimit = inputDocLimit;
}
this.docLimit = docLimit;
}

/**
* Gets the type of the processor.
*/
@Override
public String getType() {
return TYPE;
}

/**
* Gets the tag of a processor.
*/
@Override
public String getTag() {
return tag;
}

/**
* Gets the description of a processor.
*/
@Override
public String getDescription() {
return description;
}

/**
* Transform the response hit and apply kendra re-ranking logic
*/
@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
SearchHits hits = response.getHits();

if (hits.getHits().length == 0) {
// Avoid call to re-rank empty results
logger.info("TotalHits = 0. Returning search response without transforming.");
return response;
}

KendraIntelligentRankingConfiguration.KendraIntelligentRankingProperties properties = new KendraIntelligentRankingConfiguration.KendraIntelligentRankingProperties(bodyField, titleField, docLimit);
KendraIntelligentRankingConfiguration configuration = new KendraIntelligentRankingConfiguration(1, properties);
KendraIntelligentRanker ranker = new KendraIntelligentRanker(this.kendraClient);
SearchRequest processedRequest = ranker.preprocessRequest(request, configuration);

if (ranker.shouldTransform(processedRequest, configuration)) {
long startTime = System.nanoTime();
SearchHits reRankedSearchHits = ranker.transform(hits, processedRequest, configuration);
long timeTookMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);

final SearchResponseSections internalResponse = new InternalSearchResponse(reRankedSearchHits,
(InternalAggregations) response.getAggregations(), response.getSuggest(),
new SearchProfileShardResults(response.getProfileResults()), response.isTimedOut(),
response.isTerminatedEarly(), response.getNumReducePhases());

final SearchResponse newResponse = new SearchResponse(internalResponse, response.getScrollId(),
response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), timeTookMillis, response.getShardFailures(),
response.getClusters());
logger.info("kendra ranking processor took " + timeTookMillis + " ms");
return newResponse;
} else
return response;
}

/**
* This is a factor that creates the KendraRankingResponseProcessor
*/
public static final class Factory implements Processor.Factory<SearchResponseProcessor> {

private final KendraClientSettings clientSettings;

/**
* Constructor for factory
* @param kendraClientSettings credentials to create kendra client
*/
public Factory(KendraClientSettings kendraClientSettings) {
this.clientSettings = kendraClientSettings;
}

public KendraRankingResponseProcessor create(
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
String tag,
String description,
Map<String, Object> config
) throws Exception {
List<String> titleField = Collections.singletonList(ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, "title_field"));
List<String> bodyField = Collections.singletonList(ConfigurationUtils.readStringProperty(TYPE, tag, config, "body_field"));
String inputDocLimit = ConfigurationUtils.readOptionalStringOrIntProperty(TYPE, tag, config, "doc_limit");
KendraHttpClient kendraClient = new KendraHttpClient(this.clientSettings);
int docLimit;
if (inputDocLimit == null) {
docLimit = KENDRA_DEFAULT_DOC_LIMIT;
} else {
docLimit = Integer.parseInt(inputDocLimit);
}
return new KendraRankingResponseProcessor(tag, description, titleField, bodyField, docLimit, kendraClient);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
package org.opensearch.search.relevance.transformer.kendraintelligentranking;

import org.apache.lucene.search.TotalHits;
import org.mockito.Mockito;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Setting;
Expand All @@ -24,37 +23,23 @@
import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraClientSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraHttpClient;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraIntelligentClientTests;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration.KendraIntelligentRankingProperties;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreRequest;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreResult;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreResultItem;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;

public class KendraIntelligentRankerTests extends OpenSearchTestCase {
private static KendraHttpClient buildMockHttpClient(Function<RescoreRequest, RescoreResult> mockRescoreImpl) {
KendraHttpClient kendraHttpClient = Mockito.mock(KendraHttpClient.class);
Mockito.when(kendraHttpClient.isValid()).thenReturn(true);
Mockito.doAnswer(invocation -> {
RescoreRequest rescoreRequest = invocation.getArgument(0);
return mockRescoreImpl.apply(rescoreRequest);
}).when(kendraHttpClient).rescore(Mockito.any(RescoreRequest.class));
return kendraHttpClient;
}

private static KendraHttpClient buildMockHttpClient() {
return buildMockHttpClient(r -> new RescoreResult());
}
public class KendraIntelligentRankerTests extends KendraIntelligentClientTests {

public void testGetSettings() {
List<Setting<?>> settings = new KendraIntelligentRanker(buildMockHttpClient()).getTransformerSettings();
Expand Down
Loading