diff --git a/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicSource.java b/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicSource.java index 83a95f9e..79f9bbaf 100644 --- a/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicSource.java +++ b/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicSource.java @@ -24,14 +24,14 @@ * from a logical description. */ public class ElasticsearchDynamicSource implements LookupTableSource, SupportsProjectionPushDown { - private final DecodingFormat> format; - private final ElasticsearchConfiguration config; + protected final DecodingFormat> format; + protected final ElasticsearchConfiguration config; private final int lookupMaxRetryTimes; private final LookupCache lookupCache; private final String docType; private final String summaryString; - private final ElasticsearchApiCallBridge apiCallBridge; - private DataType physicalRowDataType; + protected final ElasticsearchApiCallBridge apiCallBridge; + protected DataType physicalRowDataType; public ElasticsearchDynamicSource( DecodingFormat> format, @@ -84,7 +84,7 @@ public LookupRuntimeProvider getLookupRuntimeProvider(LookupContext context) { } } - private NetworkClientConfig buildNetworkClientConfig() { + protected NetworkClientConfig buildNetworkClientConfig() { NetworkClientConfig.Builder builder = new NetworkClientConfig.Builder(); if (config.getUsername().isPresent() && !StringUtils.isNullOrWhitespaceOnly(config.getUsername().get())) { diff --git a/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicTableFactoryBase.java b/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicTableFactoryBase.java index 0395eafd..7da1294b 100644 --- a/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicTableFactoryBase.java +++ b/flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicTableFactoryBase.java @@ -162,7 +162,7 @@ ElasticsearchConfiguration getConfiguration(FactoryUtil.TableFactoryHelper helpe } @Nullable - private LookupCache getLookupCache(ReadableConfig tableOptions) { + protected LookupCache getLookupCache(ReadableConfig tableOptions) { LookupCache cache = null; if (tableOptions .get(LookupOptions.CACHE_TYPE) diff --git a/flink-connector-elasticsearch7/pom.xml b/flink-connector-elasticsearch7/pom.xml index f8cbbf43..b0b8014d 100644 --- a/flink-connector-elasticsearch7/pom.xml +++ b/flink-connector-elasticsearch7/pom.xml @@ -165,6 +165,14 @@ under the License. test + + org.apache.flink + flink-table-planner_${scala.binary.version} + ${flink.version} + test-jar + test + + org.apache.flink flink-table-runtime diff --git a/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7Configuration.java b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7Configuration.java new file mode 100644 index 00000000..7c0711fb --- /dev/null +++ b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7Configuration.java @@ -0,0 +1,16 @@ +package org.apache.flink.connector.elasticsearch.table; + +import org.apache.flink.configuration.ReadableConfig; + +import static org.apache.flink.connector.elasticsearch.table.Elasticsearch7ConnectorOptions.VECTOR_SEARCH_MAX_RETRIES; + +/** Elasticsearch 7 specific configuration. */ +public class Elasticsearch7Configuration extends ElasticsearchConfiguration { + Elasticsearch7Configuration(ReadableConfig config) { + super(config); + } + + public int getVectorSearchMaxRetries() { + return config.get(VECTOR_SEARCH_MAX_RETRIES); + } +} diff --git a/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7ConnectorOptions.java b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7ConnectorOptions.java new file mode 100644 index 00000000..ac5296c5 --- /dev/null +++ b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7ConnectorOptions.java @@ -0,0 +1,18 @@ +package org.apache.flink.connector.elasticsearch.table; + +import org.apache.flink.configuration.ConfigOption; +import org.apache.flink.configuration.ConfigOptions; + +/** + * Options specific for the Elasticsearch 7 connector. Public so that the {@link + * org.apache.flink.table.api.TableDescriptor} can access it. + */ +public class Elasticsearch7ConnectorOptions extends ElasticsearchConnectorOptions { + private Elasticsearch7ConnectorOptions() {} + + public static final ConfigOption VECTOR_SEARCH_MAX_RETRIES = + ConfigOptions.key("vector-search.max-retries") + .intType() + .defaultValue(3) + .withDescription("The max retry times for vector searching Elasticsearch."); +} diff --git a/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicSource.java b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicSource.java new file mode 100644 index 00000000..88e072fb --- /dev/null +++ b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicSource.java @@ -0,0 +1,109 @@ +package org.apache.flink.connector.elasticsearch.table; + +import org.apache.flink.api.common.serialization.DeserializationSchema; +import org.apache.flink.connector.elasticsearch.ElasticsearchApiCallBridge; +import org.apache.flink.connector.elasticsearch.NetworkClientConfig; +import org.apache.flink.connector.elasticsearch.table.search.ElasticsearchRowDataVectorSearchFunction; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.connector.format.DecodingFormat; +import org.apache.flink.table.connector.source.DynamicTableSource; +import org.apache.flink.table.connector.source.VectorSearchTableSource; +import org.apache.flink.table.connector.source.lookup.cache.LookupCache; +import org.apache.flink.table.connector.source.search.VectorSearchFunctionProvider; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.RowType; + +import org.elasticsearch.client.RestHighLevelClient; + +import javax.annotation.Nullable; + +/** + * A {@link DynamicTableSource} that describes how to create a {@link Elasticsearch7DynamicSource} + * from a logical description. + */ +public class Elasticsearch7DynamicSource extends ElasticsearchDynamicSource + implements VectorSearchTableSource { + + public Elasticsearch7DynamicSource( + DecodingFormat> format, + ElasticsearchConfiguration config, + DataType physicalRowDataType, + int lookupMaxRetryTimes, + String summaryString, + ElasticsearchApiCallBridge apiCallBridge, + @Nullable LookupCache lookupCache, + @Nullable String docType) { + super( + format, + config, + physicalRowDataType, + lookupMaxRetryTimes, + summaryString, + apiCallBridge, + lookupCache, + docType); + } + + @Override + public VectorSearchRuntimeProvider getSearchRuntimeProvider( + VectorSearchContext vectorSearchContext) { + + NetworkClientConfig networkClientConfig = buildNetworkClientConfig(); + + ElasticsearchRowDataVectorSearchFunction vectorSearchFunction = + new ElasticsearchRowDataVectorSearchFunction( + this.format.createRuntimeDecoder(vectorSearchContext, physicalRowDataType), + ((Elasticsearch7Configuration) config).getVectorSearchMaxRetries(), + config.getIndex(), + getSearchColumn(vectorSearchContext), + DataType.getFieldNames(physicalRowDataType).toArray(new String[0]), + config.getHosts(), + networkClientConfig, + (ElasticsearchApiCallBridge) apiCallBridge); + + return VectorSearchFunctionProvider.of(vectorSearchFunction); + } + + private String getSearchColumn(VectorSearchContext vectorSearchContext) { + int[][] searchColumns = vectorSearchContext.getSearchColumns(); + + if (searchColumns.length != 1) { + throw new IllegalArgumentException( + String.format( + "Elasticsearch only supports one search columns now, but input search columns size is %d.", + searchColumns.length)); + } + int[] searchColumn = searchColumns[0]; + if (searchColumn.length != 1) { + throw new IllegalArgumentException( + "Elasticsearch doesn't support to search data using nested columns."); + } + int searchColumnIndex = searchColumn[0]; + + if (searchColumnIndex < 0 + || searchColumnIndex >= physicalRowDataType.getChildren().size()) { + throw new ValidationException( + String.format( + "The specified search column with index %d doesn't exist in schema.", + searchColumnIndex)); + } + + DataType searchColumnType = physicalRowDataType.getChildren().get(searchColumnIndex); + if (!searchColumnType.getLogicalType().is(LogicalTypeRoot.ARRAY) + || !((ArrayType) searchColumnType.getLogicalType()) + .getElementType() + .is(LogicalTypeRoot.FLOAT)) { + throw new UnsupportedOperationException( + String.format( + "Elasticsearch only supports search data using float vector now, but input search column type is %s.", + searchColumnType)); + } + + return ((RowType) (physicalRowDataType.getLogicalType())) + .getFieldNames() + .get(searchColumnIndex); + } +} diff --git a/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicTableFactory.java b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicTableFactory.java index 2f6d8849..06268c44 100644 --- a/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicTableFactory.java +++ b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7DynamicTableFactory.java @@ -19,10 +19,22 @@ package org.apache.flink.connector.elasticsearch.table; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.serialization.DeserializationSchema; +import org.apache.flink.configuration.ReadableConfig; import org.apache.flink.connector.elasticsearch.Elasticsearch7ApiCallBridge; import org.apache.flink.connector.elasticsearch.ElasticsearchApiCallBridge; import org.apache.flink.connector.elasticsearch.sink.Elasticsearch7SinkBuilder; +import org.apache.flink.table.connector.format.DecodingFormat; +import org.apache.flink.table.connector.source.DynamicTableSource; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.factories.DeserializationFormatFactory; import org.apache.flink.table.factories.DynamicTableSinkFactory; +import org.apache.flink.table.factories.FactoryUtil; + +import org.elasticsearch.client.RestHighLevelClient; + +import static org.apache.flink.table.connector.source.lookup.LookupOptions.MAX_RETRIES; +import static org.elasticsearch.common.Strings.capitalize; /** A {@link DynamicTableSinkFactory} for discovering {@link ElasticsearchDynamicSink}. */ @Internal @@ -34,7 +46,38 @@ public Elasticsearch7DynamicTableFactory() { } @Override - ElasticsearchApiCallBridge getElasticsearchApiCallBridge() { + ElasticsearchConfiguration getConfiguration(FactoryUtil.TableFactoryHelper helper) { + return new Elasticsearch7Configuration(helper.getOptions()); + } + + @Override + public DynamicTableSource createDynamicTableSource(Context context) { + final FactoryUtil.TableFactoryHelper helper = + FactoryUtil.createTableFactoryHelper(this, context); + final ReadableConfig options = helper.getOptions(); + final DecodingFormat> format = + helper.discoverDecodingFormat( + DeserializationFormatFactory.class, + org.apache.flink.connector.elasticsearch.table.ElasticsearchConnectorOptions + .FORMAT_OPTION); + + ElasticsearchConfiguration config = getConfiguration(helper); + helper.validate(); + validateConfiguration(config); + + return new Elasticsearch7DynamicSource( + format, + config, + context.getPhysicalRowDataType(), + options.get(MAX_RETRIES), + capitalize(FACTORY_IDENTIFIER), + getElasticsearchApiCallBridge(), + getLookupCache(options), + getDocumentType(config)); + } + + @Override + ElasticsearchApiCallBridge getElasticsearchApiCallBridge() { return new Elasticsearch7ApiCallBridge(); } } diff --git a/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java new file mode 100644 index 00000000..b6ee3465 --- /dev/null +++ b/flink-connector-elasticsearch7/src/main/java/org/apache/flink/connector/elasticsearch/table/search/ElasticsearchRowDataVectorSearchFunction.java @@ -0,0 +1,185 @@ +package org.apache.flink.connector.elasticsearch.table.search; + +import org.apache.flink.api.common.serialization.DeserializationSchema; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.connector.elasticsearch.ElasticsearchApiCallBridge; +import org.apache.flink.connector.elasticsearch.NetworkClientConfig; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.utils.JoinedRowData; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.VectorSearchFunction; +import org.apache.flink.util.FlinkRuntimeException; + +import org.apache.http.HttpHost; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.RestHighLevelClient; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** The {@link VectorSearchFunction} implementation for Elasticsearch. */ +public class ElasticsearchRowDataVectorSearchFunction extends VectorSearchFunction { + private static final Logger LOG = + LoggerFactory.getLogger(ElasticsearchRowDataVectorSearchFunction.class); + private static final long serialVersionUID = 1L; + private static final String QUERY_VECTOR = "query_vector"; + + private final DeserializationSchema deserializationSchema; + + private final String index; + + private final String[] producedNames; + private final int maxRetryTimes; + private SearchRequest searchRequest; + private SearchSourceBuilder searchSourceBuilder; + + private final ElasticsearchApiCallBridge callBridge; + private final NetworkClientConfig networkClientConfig; + private final List hosts; + private final String cosineSimilarity; + + private transient RestHighLevelClient client; + + public ElasticsearchRowDataVectorSearchFunction( + DeserializationSchema deserializationSchema, + int maxRetryTimes, + String index, + String searchColumn, + String[] producedNames, + List hosts, + NetworkClientConfig networkClientConfig, + ElasticsearchApiCallBridge callBridge) { + + checkNotNull(deserializationSchema, "No DeserializationSchema supplied."); + checkNotNull(maxRetryTimes, "No maxRetryTimes supplied."); + checkNotNull(producedNames, "No fieldNames supplied."); + checkNotNull(hosts, "No hosts supplied."); + checkNotNull(networkClientConfig, "No networkClientConfig supplied."); + checkNotNull(callBridge, "No ElasticsearchApiCallBridge supplied."); + + this.deserializationSchema = deserializationSchema; + this.maxRetryTimes = maxRetryTimes; + this.index = index; + this.producedNames = producedNames; + + this.networkClientConfig = networkClientConfig; + this.hosts = hosts; + this.callBridge = callBridge; + this.cosineSimilarity = + String.format( + "cosineSimilarity(params.%s, '%s') + 1.0", QUERY_VECTOR, searchColumn); + } + + @Override + public void open(FunctionContext context) throws Exception { + this.client = callBridge.createClient(networkClientConfig, hosts); + + // Set searchRequest in open method in case of amount of calling in eval method when every + // record comes. + this.searchRequest = new SearchRequest(index); + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(producedNames, null); + deserializationSchema.open(null); + } + + @Override + public Collection vectorSearch(int topK, RowData features) throws IOException { + // Elasticsearch 7.x doesn't support ANN, we use script score to achieve exact matching. + Map params = + Collections.singletonMap(QUERY_VECTOR, features.getArray(0).toFloatArray()); + + Script script = new Script(ScriptType.INLINE, "painless", cosineSimilarity, params); + + ScriptScoreQueryBuilder scriptScoreQuery = + new ScriptScoreQueryBuilder(new MatchAllQueryBuilder(), script); + + searchSourceBuilder.query(scriptScoreQuery).size(topK); + + searchRequest.source(searchSourceBuilder); + + for (int retry = 0; retry <= maxRetryTimes; retry++) { + try { + ArrayList rows = new ArrayList<>(); + Tuple2 searchResponse = search(client, searchRequest); + + if (searchResponse.f1.length > 0) { + for (SearchResult result : searchResponse.f1) { + String source = result.source; + RowData row = parseSearchResult(source); + GenericRowData scoreData = new GenericRowData(1); + scoreData.setField(0, Double.valueOf(result.score)); + if (row != null) { + rows.add(new JoinedRowData(row, scoreData)); + } + } + rows.trimToSize(); + return rows; + } + } catch (IOException e) { + LOG.error(String.format("Elasticsearch search error, retry times = %d", retry), e); + if (retry >= maxRetryTimes) { + throw new FlinkRuntimeException("Execution of Elasticsearch search failed.", e); + } + try { + Thread.sleep(1000L * retry); + } catch (InterruptedException e1) { + LOG.warn( + "Interrupted while waiting to retry failed elasticsearch search, aborting"); + throw new FlinkRuntimeException(e1); + } + } + } + return Collections.emptyList(); + } + + private RowData parseSearchResult(String result) { + RowData row = null; + try { + row = deserializationSchema.deserialize(result.getBytes()); + } catch (IOException e) { + LOG.error("Deserialize search hit failed: " + e.getMessage()); + } + + return row; + } + + private Tuple2 search( + RestHighLevelClient client, SearchRequest searchRequest) throws IOException { + SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); + SearchHit[] searchHits = searchResponse.getHits().getHits(); + + return new Tuple2<>( + searchResponse.getScrollId(), + Stream.of(searchHits) + .map(hit -> new SearchResult(hit.getSourceAsString(), hit.getScore())) + .toArray(SearchResult[]::new)); + } + + private static class SearchResult { + private final String source; + private final Float score; + + public SearchResult(String source, Float score) { + this.source = source; + this.score = score; + } + } +} diff --git a/flink-connector-elasticsearch7/src/test/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7VectorSearchITCase.java b/flink-connector-elasticsearch7/src/test/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7VectorSearchITCase.java new file mode 100644 index 00000000..9ddbb5d5 --- /dev/null +++ b/flink-connector-elasticsearch7/src/test/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch7VectorSearchITCase.java @@ -0,0 +1,326 @@ +package org.apache.flink.connector.elasticsearch.table; + +import org.apache.flink.connector.elasticsearch.ElasticsearchUtil; +import org.apache.flink.connector.elasticsearch.test.DockerImageVersions; +import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.planner.factories.TestValuesTableFactory; +import org.apache.flink.test.junit5.MiniClusterExtension; +import org.apache.flink.types.Row; +import org.apache.flink.util.CollectionUtil; + +import org.apache.http.HttpHost; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.RestClient; +import org.elasticsearch.client.RestHighLevelClient; +import org.elasticsearch.client.indices.CreateIndexRequest; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.elasticsearch.ElasticsearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.io.IOException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.flink.table.api.Expressions.row; +import static org.assertj.core.api.Assertions.assertThat; + +/** {@code VECTOR_SEARCH } ITCase for Elasticsearch. */ +@Testcontainers +public class Elasticsearch7VectorSearchITCase { + private static final Logger LOG = + LoggerFactory.getLogger(Elasticsearch7VectorSearchITCase.class); + + private static final int PARALLELISM = 2; + + @Container + private static final ElasticsearchContainer ES_CONTAINER = + ElasticsearchUtil.createElasticsearchContainer( + DockerImageVersions.ELASTICSEARCH_7, LOG); + + String getElasticsearchHttpHostAddress() { + return ES_CONTAINER.getHttpHostAddress(); + } + + private RestHighLevelClient getClient() { + return new RestHighLevelClient( + RestClient.builder(HttpHost.create(getElasticsearchHttpHostAddress()))); + } + + @RegisterExtension + private static final MiniClusterExtension MINI_CLUSTER_RESOURCE = + new MiniClusterExtension( + new MiniClusterResourceConfiguration.Builder() + .setNumberTaskManagers(1) + .setNumberSlotsPerTaskManager(PARALLELISM) + .build()); + + private final List inputData = + Arrays.asList( + Row.of(1L, "Spark", new Float[] {5f, 12f, 13f}), + Row.of(2L, "Flink", new Float[] {-5f, -12f, -13f})); + + private TableEnvironment tEnv; + + @BeforeEach + void beforeEach() { + tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode()); + } + + @Test + public void testSearchFullTypeVectorTable() throws Exception { + String index = "table_with_all_supported_types"; + createFullTypesIndex(index); + tEnv.executeSql( + "CREATE TABLE esTable (" + + " id BIGINT,\n" + + " f1 STRING,\n" + + " f2 BOOLEAN,\n" + + " f3 TINYINT,\n" + + " f4 SMALLINT,\n" + + " f5 INTEGER,\n" + + " f6 DATE,\n" + + " f7 TIMESTAMP,\n" + + " f8 FLOAT,\n" + + " f9 DOUBLE,\n" + + " f10 ARRAY,\n" + + " f11 ARRAY,\n" + + " f12 ARRAY,\n" + + " f13 ARRAY,\n" + + " PRIMARY KEY (id) NOT ENFORCED\n" + + ")\n" + + "WITH (\n" + + String.format("'%s'='%s',\n", "connector", "elasticsearch-7") + + String.format( + "'%s'='%s',\n", + ElasticsearchConnectorOptions.INDEX_OPTION.key(), index) + + String.format( + "'%s'='%s'\n", + ElasticsearchConnectorOptions.HOSTS_OPTION.key(), + ES_CONTAINER.getHttpHostAddress()) + + ")"); + + tEnv.fromValues( + row( + 1, + "ABCDE", + true, + (byte) 127, + (short) 257, + 65535, + LocalDate.ofEpochDay(12345), + LocalDateTime.parse("2012-12-12T12:12:12"), + 11.11f, + 12.22d, + new Float[] {11.11f, 11.12f}, + new Double[] {12.22d, 12.22d}, + new int[] {Integer.MIN_VALUE, Integer.MAX_VALUE}, + new long[] {Long.MIN_VALUE, Long.MAX_VALUE})) + .executeInsert("esTable") + .await(); + + // Wait for es construct index. + Thread.sleep(2000); + + List rows = + CollectionUtil.iteratorToList( + tEnv.executeSql( + "WITH t(id, vector) AS (SELECT * FROM (VALUES (1, CAST(ARRAY[11.11, 1] AS ARRAY))))\n" + + "SELECT * FROM t, LATERAL TABLE(VECTOR_SEARCH(TABLE esTable, DESCRIPTOR(f10), t.vector, 3))\n") + .collect()) + .stream() + .map(Row::toString) + .collect(Collectors.toList()); + assertThat(rows) + .isEqualTo( + Collections.singletonList( + "+I[1, [11.11, 1.0], 1, ABCDE, true, 127, 257, 65535, 2003-10-20, 2012-12-12T12:12:12, 11.11, 12.22, [11.11, 11.12], [12.22, 12.22], [-2147483648, 2147483647], [-9223372036854775808, 9223372036854775807], 1.767361044883728]")); + } + + @Test + void testSearchUsingFloatArray() throws Exception { + String index = "table_with_multiple_data_with"; + createSimpleIndex(index); + tEnv.executeSql( + "CREATE TABLE es_table(" + + " id BIGINT," + + " label STRING," + + " vector ARRAY" + + ")\n WITH (\n" + + String.format("'%s'='%s',\n", "connector", "elasticsearch-7") + + String.format( + "'%s'='%s',\n", + ElasticsearchConnectorOptions.INDEX_OPTION.key(), index) + + String.format( + "'%s'='%s'\n", + ElasticsearchConnectorOptions.HOSTS_OPTION.key(), + ES_CONTAINER.getHttpHostAddress()) + + ")"); + + tEnv.fromValues( + row(1L, "Batch", new Float[] {5f, 12f, 13f}), + row(2L, "Streaming", new Float[] {-5f, -12f, -13f}), + row(3L, "Big Data", new Float[] {1f, 1f, 0f})) + .executeInsert("es_table") + .await(); + + // Wait for es construct index. + Thread.sleep(2000); + + tEnv.executeSql( + String.format( + "CREATE TABLE src(\n" + + " id BIGINT PRIMARY KEY NOT ENFORCED,\n" + + " content STRING,\n" + + " index ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'data-id' = '%s'\n" + + ");\n", + TestValuesTableFactory.registerData(inputData))); + assertThat( + CollectionUtil.iteratorToList( + tEnv.executeSql( + "SELECT content, label FROM src, LATERAL TABLE(VECTOR_SEARCH(TABLE es_table, DESCRIPTOR(vector), src.index, 2))") + .collect()) + .stream() + .map(Row::toString) + .collect(Collectors.toList())) + .isEqualTo( + Arrays.asList( + "+I[Spark, Batch]", + "+I[Spark, Big Data]", + "+I[Flink, Streaming]", + "+I[Flink, Big Data]")); + } + + private void createFullTypesIndex(String index) throws IOException { + XContentBuilder mappingBuilder = XContentFactory.jsonBuilder(); + mappingBuilder.startObject(); + mappingBuilder.startObject("properties"); + + // id: long + mappingBuilder.startObject("id"); + mappingBuilder.field("type", "long"); + mappingBuilder.endObject(); + + // f1: string + mappingBuilder.startObject("f1"); + mappingBuilder.field("type", "text"); + mappingBuilder.endObject(); + + // f2: boolean + mappingBuilder.startObject("f2"); + mappingBuilder.field("type", "boolean"); + mappingBuilder.endObject(); + + // f3: tinyint + mappingBuilder.startObject("f3"); + mappingBuilder.field("type", "byte"); + mappingBuilder.endObject(); + + // f4: long + mappingBuilder.startObject("f4"); + mappingBuilder.field("type", "short"); + mappingBuilder.endObject(); + + // f5: long + mappingBuilder.startObject("f5"); + mappingBuilder.field("type", "integer"); + mappingBuilder.endObject(); + + // f6: date + mappingBuilder.startObject("f6"); + mappingBuilder.field("type", "date"); + mappingBuilder.endObject(); + + // f7: timestamp + mappingBuilder.startObject("f7"); + mappingBuilder.field("type", "text"); + mappingBuilder.endObject(); + + // f8: float + mappingBuilder.startObject("f8"); + mappingBuilder.field("type", "float"); + mappingBuilder.endObject(); + + // f9: double + mappingBuilder.startObject("f9"); + mappingBuilder.field("type", "double"); + mappingBuilder.endObject(); + + // f10: Array + mappingBuilder.startObject("f10"); + mappingBuilder.field("type", "dense_vector"); + mappingBuilder.field("dims", 2); + mappingBuilder.endObject(); + + // f11: Array + mappingBuilder.startObject("f11"); + mappingBuilder.field("type", "dense_vector"); + mappingBuilder.field("dims", 2); + mappingBuilder.endObject(); + + // f12: Array + mappingBuilder.startObject("f12"); + mappingBuilder.field("type", "dense_vector"); + mappingBuilder.field("dims", 2); + mappingBuilder.endObject(); + + // f13: Array + mappingBuilder.startObject("f13"); + mappingBuilder.field("type", "dense_vector"); + mappingBuilder.field("dims", 2); + mappingBuilder.endObject(); + + mappingBuilder.endObject(); // end properties + mappingBuilder.endObject(); // end root + + CreateIndexRequest request = new CreateIndexRequest(index); + request.mapping(mappingBuilder); + + this.getClient().indices().create(request, RequestOptions.DEFAULT); + } + + private void createSimpleIndex(String index) throws IOException { + XContentBuilder mappingBuilder = XContentFactory.jsonBuilder(); + mappingBuilder.startObject(); + mappingBuilder.startObject("properties"); + + // id: long + mappingBuilder.startObject("id"); + mappingBuilder.field("type", "long"); + mappingBuilder.endObject(); + + // f1: string + mappingBuilder.startObject("label"); + mappingBuilder.field("type", "text"); + mappingBuilder.endObject(); + + // f2: float vector + mappingBuilder.startObject("vector"); + mappingBuilder.field("type", "dense_vector"); + mappingBuilder.field("dims", 3); + mappingBuilder.endObject(); + + mappingBuilder.endObject(); // end properties + mappingBuilder.endObject(); // end root + + CreateIndexRequest request = new CreateIndexRequest(index); + request.mapping(mappingBuilder); + + this.getClient().indices().create(request, RequestOptions.DEFAULT); + } +} diff --git a/flink-connector-elasticsearch7/src/test/resources/testcontainers.properties b/flink-connector-elasticsearch7/src/test/resources/testcontainers.properties new file mode 100644 index 00000000..07514cc8 --- /dev/null +++ b/flink-connector-elasticsearch7/src/test/resources/testcontainers.properties @@ -0,0 +1,17 @@ +################################################################################ +# Copyright 2023 Ververica Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +ryuk.container.image = testcontainers/ryuk:0.6.0 diff --git a/pom.xml b/pom.xml index 0656d7f0..ad1200bb 100644 --- a/pom.xml +++ b/pom.xml @@ -51,7 +51,8 @@ under the License. - 2.0.0 + 2.2.0 + 2.12 2.15.3 4.13.2