diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/pom.xml new file mode 100644 index 0000000000..6d159cc3a6 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/pom.xml @@ -0,0 +1,74 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../../pom.xml + + spring-ai-autoconfigure-model-openai-batch-repository-jdbc + jar + Spring AI JDBC OpenAI Batch Execution Repository Auto Configuration + Spring AI JDBC OpenAI Batch Execution Repository Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + scm:git:git://github.com/spring-projects/spring-ai.git + scm:git:ssh://git@github.com/spring-projects/spring-ai.git + + + + + + org.springframework.ai + spring-ai-openai-batch-repository-jdbc + ${project.parent.version} + + + + org.springframework.ai + spring-ai-autoconfigure-model-openai + ${project.parent.version} + + + + + org.springframework.boot + spring-boot-starter + + + org.springframework.boot + spring-boot-starter-jdbc + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + org.springframework.boot + spring-boot-autoconfigure-processor + true + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + com.h2database + h2 + test + + + + diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositoryAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositoryAutoConfiguration.java new file mode 100644 index 0000000000..20d6aee84a 
--- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositoryAutoConfiguration.java @@ -0,0 +1,70 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.openai.batch.repository.jdbc.autoconfigure; + +import javax.sql.DataSource; + +import org.springframework.ai.model.openai.autoconfigure.OpenAiBatchAutoConfiguration; +import org.springframework.ai.openai.batch.BatchExecutionRepository; +import org.springframework.ai.openai.batch.repository.jdbc.JdbcBatchExecutionRepository; +import org.springframework.ai.openai.batch.repository.jdbc.JdbcBatchExecutionRepositoryDialect; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.sql.autoconfigure.init.OnDatabaseInitializationCondition; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Conditional; +import org.springframework.jdbc.core.JdbcTemplate; + +/** + * Auto-configuration for JDBC-based {@link 
BatchExecutionRepository}. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +@AutoConfiguration(before = OpenAiBatchAutoConfiguration.class) +@ConditionalOnClass({ JdbcBatchExecutionRepository.class, DataSource.class, JdbcTemplate.class }) +@EnableConfigurationProperties(JdbcBatchExecutionRepositoryProperties.class) +public class JdbcBatchExecutionRepositoryAutoConfiguration { + + @Bean + @ConditionalOnMissingBean(BatchExecutionRepository.class) + JdbcBatchExecutionRepository jdbcBatchExecutionRepository(JdbcTemplate jdbcTemplate, DataSource dataSource) { + JdbcBatchExecutionRepositoryDialect dialect = JdbcBatchExecutionRepositoryDialect.from(dataSource); + return JdbcBatchExecutionRepository.builder().jdbcTemplate(jdbcTemplate).dialect(dialect).build(); + } + + @Bean + @ConditionalOnMissingBean + @Conditional(OnJdbcBatchExecutionRepositoryDatasourceInitializationCondition.class) + JdbcBatchExecutionRepositorySchemaInitializer jdbcBatchExecutionScriptDatabaseInitializer(DataSource dataSource, + JdbcBatchExecutionRepositoryProperties properties) { + return new JdbcBatchExecutionRepositorySchemaInitializer(dataSource, properties); + } + + static class OnJdbcBatchExecutionRepositoryDatasourceInitializationCondition + extends OnDatabaseInitializationCondition { + + OnJdbcBatchExecutionRepositoryDatasourceInitializationCondition() { + super("Jdbc Batch Execution Repository", + JdbcBatchExecutionRepositoryProperties.CONFIG_PREFIX + ".initialize-schema"); + } + + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositoryProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositoryProperties.java new file mode 100644 index 0000000000..b8a8b3bfdc --- /dev/null 
+++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositoryProperties.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.openai.batch.repository.jdbc.autoconfigure; + +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.jdbc.init.DatabaseInitializationProperties; + +/** + * Configuration properties for JDBC Batch Execution Repository. 
+ * + * @author Yasin Akbas + * @since 2.0.0 + */ +@ConfigurationProperties(JdbcBatchExecutionRepositoryProperties.CONFIG_PREFIX) +public class JdbcBatchExecutionRepositoryProperties extends DatabaseInitializationProperties { + + public static final String CONFIG_PREFIX = "spring.ai.openai.batch.repository.jdbc"; + + private static final String DEFAULT_SCHEMA_LOCATION = "classpath:org/springframework/ai/openai/batch/repository/jdbc/schema-@@platform@@.sql"; + + @Override + public String getDefaultSchemaLocation() { + return DEFAULT_SCHEMA_LOCATION; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositorySchemaInitializer.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositorySchemaInitializer.java new file mode 100644 index 0000000000..e5b9ab618a --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/JdbcBatchExecutionRepositorySchemaInitializer.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.model.openai.batch.repository.jdbc.autoconfigure; + +import javax.sql.DataSource; + +import org.springframework.boot.jdbc.init.PropertiesBasedDataSourceScriptDatabaseInitializer; + +/** + * Performs database initialization for the JDBC Batch Execution Repository. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +class JdbcBatchExecutionRepositorySchemaInitializer + extends PropertiesBasedDataSourceScriptDatabaseInitializer { + + JdbcBatchExecutionRepositorySchemaInitializer(DataSource dataSource, + JdbcBatchExecutionRepositoryProperties properties) { + super(dataSource, properties); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/package-info.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/package-info.java new file mode 100644 index 0000000000..92a87e6ae7 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/model/openai/batch/repository/jdbc/autoconfigure/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +@NullMarked +package org.springframework.ai.model.openai.batch.repository.jdbc.autoconfigure; + +import org.jspecify.annotations.NullMarked; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 0000000000..ba0114bb5b --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1,16 @@ +# +# Copyright 2023-present the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +org.springframework.ai.model.openai.batch.repository.jdbc.autoconfigure.JdbcBatchExecutionRepositoryAutoConfiguration diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiBatchAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiBatchAutoConfiguration.java new file mode 100644 index 0000000000..5362b4cd53 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiBatchAutoConfiguration.java @@ -0,0 +1,93 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.model.openai.autoconfigure; + +import java.util.List; + +import com.openai.client.OpenAIClient; + +import org.springframework.ai.openai.AbstractOpenAiOptions; +import org.springframework.ai.openai.batch.BatchExecutionRepository; +import org.springframework.ai.openai.batch.BatchRequestHandler; +import org.springframework.ai.openai.batch.InMemoryBatchExecutionRepository; +import org.springframework.ai.openai.batch.OpenAiBatchApi; +import org.springframework.ai.openai.batch.OpenAiBatchListener; +import org.springframework.ai.openai.batch.OpenAiBatchModel; +import org.springframework.ai.openai.setup.OpenAiSetup; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +/** + * Batch API {@link AutoConfiguration Auto-configuration} for OpenAI SDK. + *

+ * Enabled only when {@code spring.ai.openai.batch.enabled=true} is set. Wires together + * the {@link OpenAiBatchApi}, {@link OpenAiBatchModel}, registered + * {@link BatchRequestHandler}s, and {@link OpenAiBatchListener}s. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +@AutoConfiguration +@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiBatchProperties.class }) +@ConditionalOnProperty(name = "spring.ai.openai.batch.enabled", havingValue = "true") +public class OpenAiBatchAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public OpenAiBatchApi openAiBatchApi(OpenAiConnectionProperties commonProperties, + OpenAiBatchProperties batchProperties) { + OpenAiAutoConfigurationUtil.ResolvedConnectionProperties resolved = OpenAiAutoConfigurationUtil + .resolveConnectionProperties(commonProperties, batchProperties); + + OpenAIClient openAIClient = this.openAiClient(resolved); + + return new OpenAiBatchApi(openAIClient, new com.fasterxml.jackson.databind.ObjectMapper()); + } + + @Bean + @ConditionalOnMissingBean + public BatchExecutionRepository batchExecutionRepository() { + return new InMemoryBatchExecutionRepository(); + } + + @Bean + @ConditionalOnMissingBean + public OpenAiBatchModel openAiBatchModel(OpenAiBatchProperties batchProperties, OpenAiBatchApi batchApi, + BatchExecutionRepository executionRepository, ObjectProvider>> handlers, + ObjectProvider> listeners) { + return OpenAiBatchModel.builder() + .batchApi(batchApi) + .options(batchProperties.getOptions()) + .executionRepository(executionRepository) + .handlers(handlers.getIfAvailable(List::of)) + .listeners(listeners.getIfAvailable(List::of)) + .build(); + } + + private OpenAIClient openAiClient(AbstractOpenAiOptions resolved) { + return OpenAiSetup.setupSyncClient(resolved.getBaseUrl(), resolved.getApiKey(), resolved.getCredential(), + resolved.getMicrosoftDeploymentName(), resolved.getMicrosoftFoundryServiceVersion(), + resolved.getOrganizationId(), 
resolved.isMicrosoftFoundry(), resolved.isGitHubModels(), + resolved.getModel(), resolved.getTimeout(), resolved.getMaxRetries(), resolved.getProxy(), + resolved.getCustomHeaders()); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiBatchProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiBatchProperties.java new file mode 100644 index 0000000000..08615a51dc --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiBatchProperties.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.openai.autoconfigure; + +import org.springframework.ai.openai.AbstractOpenAiOptions; +import org.springframework.ai.openai.batch.OpenAiBatchOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * OpenAI SDK Batch API autoconfiguration properties. + *

+ * Configuration properties are available under the {@code spring.ai.openai.batch} prefix. + * All batch-specific settings (rate limits, token budgets, retry policies) are + * configurable, eliminating the need for hardcoded values. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +@ConfigurationProperties(OpenAiBatchProperties.CONFIG_PREFIX) +public class OpenAiBatchProperties extends AbstractOpenAiOptions { + + public static final String CONFIG_PREFIX = "spring.ai.openai.batch"; + + /** + * Whether the OpenAI Batch API support is enabled. + */ + private boolean enabled = false; + + @NestedConfigurationProperty + private final OpenAiBatchOptions options = OpenAiBatchOptions.builder().build(); + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public OpenAiBatchOptions getOptions() { + return this.options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 0f89c1f89d..7e6713082b 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -19,3 +19,4 @@ org.springframework.ai.model.openai.autoconfigure.OpenAiImageAutoConfiguration org.springframework.ai.model.openai.autoconfigure.OpenAiAudioSpeechAutoConfiguration org.springframework.ai.model.openai.autoconfigure.OpenAiAudioTranscriptionAutoConfiguration org.springframework.ai.model.openai.autoconfigure.OpenAiModerationAutoConfiguration 
+org.springframework.ai.model.openai.autoconfigure.OpenAiBatchAutoConfiguration diff --git a/models/spring-ai-openai-batch-repository-jdbc/pom.xml b/models/spring-ai-openai-batch-repository-jdbc/pom.xml new file mode 100644 index 0000000000..1bb0147d2a --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/pom.xml @@ -0,0 +1,78 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../pom.xml + + + spring-ai-openai-batch-repository-jdbc + Spring AI JDBC OpenAI Batch Execution Repository + Spring AI JDBC implementation of BatchExecutionRepository for OpenAI Batch API + + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + scm:git:git://github.com/spring-projects/spring-ai.git + scm:git:ssh://git@github.com/spring-projects/spring-ai.git + + + + + org.springframework.ai + spring-ai-openai + ${project.version} + + + + org.springframework + spring-jdbc + + + + com.zaxxer + HikariCP + + + + + com.h2database + h2 + test + true + + + + org.springframework.boot + spring-boot-starter-jdbc + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/H2BatchExecutionRepositoryDialect.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/H2BatchExecutionRepositoryDialect.java new file mode 100644 index 0000000000..2651f46fe7 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/H2BatchExecutionRepositoryDialect.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc; + +/** + * H2-specific SQL dialect for batch execution repository. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class H2BatchExecutionRepositoryDialect implements JdbcBatchExecutionRepositoryDialect { + + @Override + public String getUpsertSql() { + return """ + MERGE INTO SPRING_AI_BATCH_EXECUTION (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + KEY (batch_id) + VALUES (?, ?, ?, ?, ?, ?, ?)"""; + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/HsqldbBatchExecutionRepositoryDialect.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/HsqldbBatchExecutionRepositoryDialect.java new file mode 100644 index 0000000000..a52c397b48 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/HsqldbBatchExecutionRepositoryDialect.java @@ -0,0 +1,44 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc; + +/** + * HSQLDB-specific SQL dialect for batch execution repository. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class HsqldbBatchExecutionRepositoryDialect implements JdbcBatchExecutionRepositoryDialect { + + @Override + public String getUpsertSql() { + return """ + MERGE INTO SPRING_AI_BATCH_EXECUTION AS target + USING (VALUES (?, ?, ?, ?, ?, ?, ?)) AS source (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + ON target.batch_id = source.batch_id + WHEN MATCHED THEN UPDATE SET + target.endpoint = source.endpoint, + target.status = source.status, + target.request_count = source.request_count, + target.input_file_id = source.input_file_id, + target.created_at = source.created_at, + target.updated_at = source.updated_at + WHEN NOT MATCHED THEN INSERT (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + VALUES (source.batch_id, source.endpoint, source.status, source.request_count, source.input_file_id, source.created_at, source.updated_at)"""; + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/JdbcBatchExecutionRepository.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/JdbcBatchExecutionRepository.java new file mode 100644 index 0000000000..ce7d4a27f6 --- /dev/null +++ 
b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/JdbcBatchExecutionRepository.java @@ -0,0 +1,230 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc; + +import java.lang.reflect.Field; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import javax.sql.DataSource; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.openai.batch.BatchExecution; +import org.springframework.ai.openai.batch.BatchExecutionRepository; +import org.springframework.ai.openai.batch.BatchExecutionStatus; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.support.TransactionTemplate; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * An implementation of {@link BatchExecutionRepository} for JDBC. 
+ * + * @author Yasin Akbas + * @since 2.0.0 + */ +public final class JdbcBatchExecutionRepository implements BatchExecutionRepository { + + private final JdbcTemplate jdbcTemplate; + + private final TransactionTemplate transactionTemplate; + + private final JdbcBatchExecutionRepositoryDialect dialect; + + private static final BatchExecutionRowMapper ROW_MAPPER = new BatchExecutionRowMapper(); + + private JdbcBatchExecutionRepository(JdbcTemplate jdbcTemplate, JdbcBatchExecutionRepositoryDialect dialect, + @Nullable PlatformTransactionManager txManager) { + Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null"); + Assert.notNull(dialect, "dialect cannot be null"); + this.jdbcTemplate = jdbcTemplate; + this.dialect = dialect; + if (txManager == null) { + Assert.state(jdbcTemplate.getDataSource() != null, "jdbcTemplate dataSource cannot be null"); + txManager = new DataSourceTransactionManager(jdbcTemplate.getDataSource()); + } + this.transactionTemplate = new TransactionTemplate(txManager); + } + + @Override + public void save(BatchExecution execution) { + Assert.notNull(execution, "execution cannot be null"); + this.transactionTemplate.execute(status -> { + this.jdbcTemplate.update(this.dialect.getUpsertSql(), execution.getBatchId(), execution.getEndpoint(), + execution.getStatus().name(), execution.getRequestCount(), execution.getInputFileId(), + Timestamp.from(execution.getCreatedAt()), Timestamp.from(execution.getUpdatedAt())); + return null; + }); + } + + @Override + public Optional findById(String batchId) { + Assert.hasText(batchId, "batchId cannot be null or empty"); + List results = this.jdbcTemplate.query(this.dialect.getSelectByIdSql(), ROW_MAPPER, batchId); + return results.isEmpty() ? 
Optional.empty() : Optional.of(results.get(0)); + } + + @Override + public List findByStatus(BatchExecutionStatus status) { + Assert.notNull(status, "status cannot be null"); + return this.jdbcTemplate.query(this.dialect.getSelectByStatusSql(), ROW_MAPPER, status.name()); + } + + @Override + public List findPendingExecutions() { + return this.jdbcTemplate.query(this.dialect.getSelectPendingExecutionsSql(), ROW_MAPPER); + } + + @Override + public void deleteById(String batchId) { + Assert.hasText(batchId, "batchId cannot be null or empty"); + this.jdbcTemplate.update(this.dialect.getDeleteByIdSql(), batchId); + } + + public static Builder builder() { + return new Builder(); + } + + private static class BatchExecutionRowMapper implements RowMapper { + + @Override + public BatchExecution mapRow(ResultSet rs, int rowNum) throws SQLException { + String batchId = rs.getString("batch_id"); + String endpoint = rs.getString("endpoint"); + BatchExecutionStatus status = BatchExecutionStatus.valueOf(rs.getString("status")); + int requestCount = rs.getInt("request_count"); + String inputFileId = rs.getString("input_file_id"); + Timestamp createdAtTs = rs.getTimestamp("created_at"); + Timestamp updatedAtTs = rs.getTimestamp("updated_at"); + + BatchExecution execution = new BatchExecution(batchId, endpoint, status, requestCount, inputFileId); + + // Use reflection to set the persisted timestamps since BatchExecution + // sets createdAt/updatedAt to Instant.now() in its constructor. 
+ if (createdAtTs != null) { + setFieldValue(execution, "createdAt", createdAtTs.toInstant()); + } + if (updatedAtTs != null) { + setFieldValue(execution, "updatedAt", updatedAtTs.toInstant()); + } + + return execution; + } + + private static void setFieldValue(BatchExecution execution, String fieldName, Instant value) { + Field field = ReflectionUtils.findField(BatchExecution.class, fieldName); + if (field != null) { + ReflectionUtils.makeAccessible(field); + ReflectionUtils.setField(field, execution, value); + } + } + + } + + public static final class Builder { + + private @Nullable JdbcTemplate jdbcTemplate; + + private @Nullable JdbcBatchExecutionRepositoryDialect dialect; + + private @Nullable DataSource dataSource; + + private @Nullable PlatformTransactionManager platformTransactionManager; + + private static final Logger logger = LoggerFactory.getLogger(Builder.class); + + private Builder() { + } + + public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = jdbcTemplate; + return this; + } + + public Builder dialect(JdbcBatchExecutionRepositoryDialect dialect) { + this.dialect = dialect; + return this; + } + + public Builder dataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + public Builder transactionManager(PlatformTransactionManager txManager) { + this.platformTransactionManager = txManager; + return this; + } + + public JdbcBatchExecutionRepository build() { + DataSource effectiveDataSource = resolveDataSource(); + JdbcBatchExecutionRepositoryDialect effectiveDialect = resolveDialect(effectiveDataSource); + return new JdbcBatchExecutionRepository(resolveJdbcTemplate(), effectiveDialect, + this.platformTransactionManager); + } + + private JdbcTemplate resolveJdbcTemplate() { + if (this.jdbcTemplate != null) { + return this.jdbcTemplate; + } + if (this.dataSource != null) { + return new JdbcTemplate(this.dataSource); + } + throw new IllegalArgumentException("DataSource must be set (either via 
dataSource() or jdbcTemplate())"); + } + + private DataSource resolveDataSource() { + if (this.dataSource != null) { + return this.dataSource; + } + if (this.jdbcTemplate != null && this.jdbcTemplate.getDataSource() != null) { + return this.jdbcTemplate.getDataSource(); + } + throw new IllegalArgumentException("DataSource must be set (either via dataSource() or jdbcTemplate())"); + } + + private JdbcBatchExecutionRepositoryDialect resolveDialect(DataSource dataSource) { + if (this.dialect == null) { + return JdbcBatchExecutionRepositoryDialect.from(dataSource); + } + else { + warnIfDialectMismatch(dataSource, this.dialect); + return this.dialect; + } + } + + private void warnIfDialectMismatch(DataSource dataSource, JdbcBatchExecutionRepositoryDialect explicitDialect) { + JdbcBatchExecutionRepositoryDialect detected = JdbcBatchExecutionRepositoryDialect.from(dataSource); + if (!detected.getClass().equals(explicitDialect.getClass())) { + logger.warn("Explicitly set dialect {} will be used instead of detected dialect {} from datasource", + explicitDialect.getClass().getSimpleName(), detected.getClass().getSimpleName()); + } + } + + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/JdbcBatchExecutionRepositoryDialect.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/JdbcBatchExecutionRepositoryDialect.java new file mode 100644 index 0000000000..d273086262 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/JdbcBatchExecutionRepositoryDialect.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
package org.springframework.ai.openai.batch.repository.jdbc;

import java.sql.DatabaseMetaData;

import javax.sql.DataSource;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.jdbc.support.JdbcUtils;

/**
 * Abstraction for database-specific SQL for batch execution repository.
 *
 * <p>
 * The read/delete statements are portable ANSI SQL and shipped as defaults;
 * only the upsert differs per vendor and must be provided by implementations.
 *
 * @author Yasin Akbas
 * @since 2.0.0
 */
public interface JdbcBatchExecutionRepositoryDialect {

	Logger logger = LoggerFactory.getLogger(JdbcBatchExecutionRepositoryDialect.class);

	String SELECT_COLUMNS = "batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at";

	/**
	 * Returns the SQL to find a batch execution by ID.
	 */
	default String getSelectByIdSql() {
		return "SELECT " + SELECT_COLUMNS + " FROM SPRING_AI_BATCH_EXECUTION WHERE batch_id = ?";
	}

	/**
	 * Returns the SQL to find batch executions by status.
	 */
	default String getSelectByStatusSql() {
		return "SELECT " + SELECT_COLUMNS + " FROM SPRING_AI_BATCH_EXECUTION WHERE status = ?";
	}

	/**
	 * Returns the SQL to find all pending (non-terminal) batch executions.
	 */
	default String getSelectPendingExecutionsSql() {
		return "SELECT " + SELECT_COLUMNS
				+ " FROM SPRING_AI_BATCH_EXECUTION WHERE status NOT IN ('RESULTS_PROCESSED', 'FAILED', 'EXPIRED', 'CANCELLED')";
	}

	/**
	 * Returns the database-specific upsert SQL. Parameter order is: batch_id,
	 * endpoint, status, request_count, input_file_id, created_at, updated_at.
	 */
	String getUpsertSql();

	/**
	 * Returns the SQL to delete a batch execution by ID.
	 */
	default String getDeleteByIdSql() {
		return "DELETE FROM SPRING_AI_BATCH_EXECUTION WHERE batch_id = ?";
	}

	/**
	 * Detects the dialect from the DataSource via
	 * {@link DatabaseMetaData#getDatabaseProductName()}. Falls back to the
	 * Postgres dialect (with a warning) when the vendor cannot be determined
	 * or is unrecognized.
	 */
	static JdbcBatchExecutionRepositoryDialect from(DataSource dataSource) {
		String productName = null;
		try {
			productName = JdbcUtils.extractDatabaseMetaData(dataSource, DatabaseMetaData::getDatabaseProductName);
		}
		catch (Exception e) {
			logger.warn("Due to failure in establishing JDBC connection or parsing metadata, the JDBC database vendor "
					+ "could not be determined", e);
		}
		if (productName == null || productName.isBlank()) {
			logger.warn("Database product name is null or empty, defaulting to Postgres dialect.");
			return new PostgresBatchExecutionRepositoryDialect();
		}
		return switch (productName) {
			case "PostgreSQL" -> new PostgresBatchExecutionRepositoryDialect();
			case "MySQL", "MariaDB" -> new MysqlBatchExecutionRepositoryDialect();
			case "Microsoft SQL Server" -> new SqlServerBatchExecutionRepositoryDialect();
			case "HSQL Database Engine" -> new HsqldbBatchExecutionRepositoryDialect();
			case "SQLite" -> new SqliteBatchExecutionRepositoryDialect();
			case "H2" -> new H2BatchExecutionRepositoryDialect();
			case "Oracle" -> new OracleBatchExecutionRepositoryDialect();
			default -> new PostgresBatchExecutionRepositoryDialect();
		};
	}

}
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc; + +/** + * MySQL dialect for batch execution repository. Also works for MariaDB. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class MysqlBatchExecutionRepositoryDialect implements JdbcBatchExecutionRepositoryDialect { + + @Override + public String getUpsertSql() { + return """ + INSERT INTO SPRING_AI_BATCH_EXECUTION (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + AS new_values + ON DUPLICATE KEY UPDATE + endpoint = new_values.endpoint, + status = new_values.status, + request_count = new_values.request_count, + input_file_id = new_values.input_file_id, + created_at = new_values.created_at, + updated_at = new_values.updated_at"""; + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/OracleBatchExecutionRepositoryDialect.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/OracleBatchExecutionRepositoryDialect.java new file mode 100644 index 0000000000..d2a1581fa8 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/OracleBatchExecutionRepositoryDialect.java @@ -0,0 +1,44 @@ +/* + * Copyright 2023-present the original author or authors. 
package org.springframework.ai.openai.batch.repository.jdbc;

/**
 * Oracle dialect for batch execution repository.
 *
 * @author Yasin Akbas
 * @since 2.0.0
 */
public class OracleBatchExecutionRepositoryDialect implements JdbcBatchExecutionRepositoryDialect {

	@Override
	public String getUpsertSql() {
		// MERGE-based upsert keyed on batch_id; the seven bind parameters are
		// projected through a single-row DUAL subquery so they can be referenced
		// by name in both the UPDATE and INSERT branches.
		// NOTE(review): MERGE is not atomic under concurrent inserts of the same
		// batch_id (a race can surface as ORA-00001) — confirm whether callers
		// ever save the same batch concurrently.
		return """
				MERGE INTO SPRING_AI_BATCH_EXECUTION target
				USING (SELECT ? AS batch_id, ? AS endpoint, ? AS status, ? AS request_count, ? AS input_file_id, ? AS created_at, ? AS updated_at FROM DUAL) source
				ON (target.batch_id = source.batch_id)
				WHEN MATCHED THEN UPDATE SET
				target.endpoint = source.endpoint,
				target.status = source.status,
				target.request_count = source.request_count,
				target.input_file_id = source.input_file_id,
				target.created_at = source.created_at,
				target.updated_at = source.updated_at
				WHEN NOT MATCHED THEN INSERT (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at)
				VALUES (source.batch_id, source.endpoint, source.status, source.request_count, source.input_file_id, source.created_at, source.updated_at)""";
	}

}
+ * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class PostgresBatchExecutionRepositoryDialect implements JdbcBatchExecutionRepositoryDialect { + + @Override + public String getUpsertSql() { + return """ + INSERT INTO SPRING_AI_BATCH_EXECUTION (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (batch_id) DO UPDATE SET + endpoint = EXCLUDED.endpoint, + status = EXCLUDED.status, + request_count = EXCLUDED.request_count, + input_file_id = EXCLUDED.input_file_id, + created_at = EXCLUDED.created_at, + updated_at = EXCLUDED.updated_at"""; + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/SqlServerBatchExecutionRepositoryDialect.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/SqlServerBatchExecutionRepositoryDialect.java new file mode 100644 index 0000000000..823c8f4b01 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/SqlServerBatchExecutionRepositoryDialect.java @@ -0,0 +1,44 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc; + +/** + * SQL Server dialect for batch execution repository. 
+ * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class SqlServerBatchExecutionRepositoryDialect implements JdbcBatchExecutionRepositoryDialect { + + @Override + public String getUpsertSql() { + return """ + MERGE INTO SPRING_AI_BATCH_EXECUTION WITH (HOLDLOCK) AS target + USING (VALUES (?, ?, ?, ?, ?, ?, ?)) AS source (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + ON target.batch_id = source.batch_id + WHEN MATCHED THEN UPDATE SET + target.endpoint = source.endpoint, + target.status = source.status, + target.request_count = source.request_count, + target.input_file_id = source.input_file_id, + target.created_at = source.created_at, + target.updated_at = source.updated_at + WHEN NOT MATCHED THEN INSERT (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + VALUES (source.batch_id, source.endpoint, source.status, source.request_count, source.input_file_id, source.created_at, source.updated_at);"""; + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/SqliteBatchExecutionRepositoryDialect.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/SqliteBatchExecutionRepositoryDialect.java new file mode 100644 index 0000000000..a6a5af69b0 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/SqliteBatchExecutionRepositoryDialect.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc; + +/** + * SQLite dialect for batch execution repository. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class SqliteBatchExecutionRepositoryDialect implements JdbcBatchExecutionRepositoryDialect { + + @Override + public String getUpsertSql() { + return """ + INSERT OR REPLACE INTO SPRING_AI_BATCH_EXECUTION (batch_id, endpoint, status, request_count, input_file_id, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?)"""; + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/aot/hint/JdbcBatchExecutionRepositoryRuntimeHints.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/aot/hint/JdbcBatchExecutionRepositoryRuntimeHints.java new file mode 100644 index 0000000000..a72228ab73 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/aot/hint/JdbcBatchExecutionRepositoryRuntimeHints.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc.aot.hint; + +import javax.sql.DataSource; + +import org.jspecify.annotations.Nullable; + +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; + +/** + * A {@link RuntimeHintsRegistrar} for JDBC Batch Execution Repository hints. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +class JdbcBatchExecutionRepositoryRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) { + hints.reflection() + .registerType(DataSource.class, hint -> hint.withMembers(MemberCategory.INVOKE_DECLARED_METHODS)); + + hints.resources().registerPattern("org/springframework/ai/openai/batch/repository/jdbc/schema-*.sql"); + } + +} diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/aot/hint/package-info.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/aot/hint/package-info.java new file mode 100644 index 0000000000..84f641e2ce --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/aot/hint/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023-present the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NullMarked +package org.springframework.ai.openai.batch.repository.jdbc.aot.hint; + +import org.jspecify.annotations.NullMarked; diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/package-info.java b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/package-info.java new file mode 100644 index 0000000000..caca48f8fb --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/java/org/springframework/ai/openai/batch/repository/jdbc/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +@NullMarked +package org.springframework.ai.openai.batch.repository.jdbc; + +import org.jspecify.annotations.NullMarked; diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-h2.sql b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-h2.sql new file mode 100644 index 0000000000..cf3332d9dd --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-h2.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS SPRING_AI_BATCH_EXECUTION ( + batch_id VARCHAR(255) NOT NULL PRIMARY KEY, + endpoint VARCHAR(255) NOT NULL, + status VARCHAR(50) NOT NULL, + request_count INTEGER NOT NULL, + input_file_id VARCHAR(255), + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL +); diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-hsqldb.sql b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-hsqldb.sql new file mode 100644 index 0000000000..cf3332d9dd --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-hsqldb.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS SPRING_AI_BATCH_EXECUTION ( + batch_id VARCHAR(255) NOT NULL PRIMARY KEY, + endpoint VARCHAR(255) NOT NULL, + status VARCHAR(50) NOT NULL, + request_count INTEGER NOT NULL, + input_file_id VARCHAR(255), + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL +); diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-mariadb.sql 
b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-mariadb.sql new file mode 100644 index 0000000000..cf3332d9dd --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-mariadb.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS SPRING_AI_BATCH_EXECUTION ( + batch_id VARCHAR(255) NOT NULL PRIMARY KEY, + endpoint VARCHAR(255) NOT NULL, + status VARCHAR(50) NOT NULL, + request_count INTEGER NOT NULL, + input_file_id VARCHAR(255), + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL +); diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-mysql.sql b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-mysql.sql new file mode 100644 index 0000000000..cf3332d9dd --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-mysql.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS SPRING_AI_BATCH_EXECUTION ( + batch_id VARCHAR(255) NOT NULL PRIMARY KEY, + endpoint VARCHAR(255) NOT NULL, + status VARCHAR(50) NOT NULL, + request_count INTEGER NOT NULL, + input_file_id VARCHAR(255), + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL +); diff --git a/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-oracle.sql b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-oracle.sql new file mode 100644 index 0000000000..b234ca8b88 --- /dev/null +++ b/models/spring-ai-openai-batch-repository-jdbc/src/main/resources/org/springframework/ai/openai/batch/repository/jdbc/schema-oracle.sql @@ -0,0 +1,16 @@ +BEGIN + EXECUTE IMMEDIATE 
-- schema-oracle.sql
-- Oracle lacks CREATE TABLE IF NOT EXISTS, so the DDL is wrapped in a PL/SQL
-- block that swallows only ORA-00955 ("name is already used by an existing
-- object") and re-raises anything else.
BEGIN
    EXECUTE IMMEDIATE
        'CREATE TABLE SPRING_AI_BATCH_EXECUTION (
            batch_id VARCHAR2(255) NOT NULL PRIMARY KEY,
            endpoint VARCHAR2(255) NOT NULL,
            status VARCHAR2(50) NOT NULL,
            request_count NUMBER(10) NOT NULL,
            input_file_id VARCHAR2(255),
            created_at TIMESTAMP NOT NULL,
            updated_at TIMESTAMP NOT NULL
        )';
EXCEPTION
    WHEN OTHERS THEN
        IF SQLCODE != -955 THEN
            RAISE;
        END IF;
END;

-- schema-postgresql.sql / schema-sqlite.sql (identical content)
-- Bookkeeping table for OpenAI batch executions; batch_id is the OpenAI
-- batch identifier and serves as the natural primary key.
CREATE TABLE IF NOT EXISTS SPRING_AI_BATCH_EXECUTION (
    batch_id VARCHAR(255) NOT NULL PRIMARY KEY,
    endpoint VARCHAR(255) NOT NULL,
    status VARCHAR(50) NOT NULL,
    request_count INTEGER NOT NULL,
    input_file_id VARCHAR(255),
    created_at TIMESTAMP NOT NULL,
    updated_at TIMESTAMP NOT NULL
);

-- schema-sqlserver.sql
-- SQL Server's conditional-create idiom; DATETIME2 is used for the audit
-- columns (higher precision and wider range than legacy DATETIME).
IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'SPRING_AI_BATCH_EXECUTION')
CREATE TABLE SPRING_AI_BATCH_EXECUTION (
    batch_id VARCHAR(255) NOT NULL PRIMARY KEY,
    endpoint VARCHAR(255) NOT NULL,
    status VARCHAR(50) NOT NULL,
    request_count INTEGER NOT NULL,
    input_file_id VARCHAR(255),
    created_at DATETIME2 NOT NULL,
    updated_at DATETIME2 NOT NULL
);
package org.springframework.ai.openai.batch.repository.jdbc;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;

import javax.sql.DataSource;

import org.junit.jupiter.api.Test;

import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.transaction.PlatformTransactionManager;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link JdbcBatchExecutionRepository.Builder}.
 *
 * @author Yasin Akbas
 * @since 2.0.0
 */
class JdbcBatchExecutionRepositoryBuilderTests {

	@Test
	void testBuilderWithExplicitDialect() {
		DataSource dataSource = mock(DataSource.class);
		JdbcBatchExecutionRepositoryDialect dialect = mock(JdbcBatchExecutionRepositoryDialect.class);

		JdbcBatchExecutionRepository repository = JdbcBatchExecutionRepository.builder()
			.dataSource(dataSource)
			.dialect(dialect)
			.build();

		assertThat(repository).isNotNull();
	}

	@Test
	void testBuilderWithExplicitDialectAndTransactionManager() {
		DataSource dataSource = mock(DataSource.class);
		JdbcBatchExecutionRepositoryDialect dialect = mock(JdbcBatchExecutionRepositoryDialect.class);
		PlatformTransactionManager txManager = mock(PlatformTransactionManager.class);

		JdbcBatchExecutionRepository repository = JdbcBatchExecutionRepository.builder()
			.dataSource(dataSource)
			.dialect(dialect)
			.transactionManager(txManager)
			.build();

		assertThat(repository).isNotNull();
	}

	@Test
	void testBuilderWithDialectFromDataSource() throws SQLException {
		DataSource dataSource = mock(DataSource.class);
		Connection connection = mock(Connection.class);
		DatabaseMetaData metaData = mock(DatabaseMetaData.class);

		when(dataSource.getConnection()).thenReturn(connection);
		when(connection.getMetaData()).thenReturn(metaData);
		// Dialect detection reads getDatabaseProductName(), not getURL();
		// stubbing only the URL would leave the product name null and the
		// test would silently exercise the default-dialect fallback instead
		// of real detection.
		when(metaData.getDatabaseProductName()).thenReturn("PostgreSQL");

		JdbcBatchExecutionRepository repository = JdbcBatchExecutionRepository.builder().dataSource(dataSource).build();

		assertThat(repository).isNotNull();
		assertThat(repository).extracting("dialect").isInstanceOf(PostgresBatchExecutionRepositoryDialect.class);
	}

	@Test
	void testBuilderWithNullDataSource() {
		assertThatThrownBy(() -> JdbcBatchExecutionRepository.builder().build())
			.isInstanceOf(IllegalArgumentException.class)
			.hasMessage("DataSource must be set (either via dataSource() or jdbcTemplate())");
	}

	@Test
	void repositoryShouldUseProvidedJdbcTemplate() {
		DataSource dataSource = mock(DataSource.class);
		JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);

		JdbcBatchExecutionRepository repository = JdbcBatchExecutionRepository.builder()
			.jdbcTemplate(jdbcTemplate)
			.build();

		assertThat(repository).extracting("jdbcTemplate").isSameAs(jdbcTemplate);
	}

}
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch.repository.jdbc; + +import java.util.List; +import java.util.Optional; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.openai.batch.BatchExecution; +import org.springframework.ai.openai.batch.BatchExecutionStatus; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link JdbcBatchExecutionRepository} using H2 in-memory database. 
+ * + * @author Yasin Akbas + * @since 2.0.0 + */ +class JdbcBatchExecutionRepositoryTests { + + private JdbcBatchExecutionRepository repository; + + private JdbcTemplate jdbcTemplate; + + @BeforeEach + void setUp() { + DataSource dataSource = new EmbeddedDatabaseBuilder().setType(EmbeddedDatabaseType.H2) + .generateUniqueName(true) + .addScript("org/springframework/ai/openai/batch/repository/jdbc/schema-h2.sql") + .build(); + this.jdbcTemplate = new JdbcTemplate(dataSource); + this.repository = JdbcBatchExecutionRepository.builder().dataSource(dataSource).build(); + } + + @Test + void saveAndFindById() { + BatchExecution execution = new BatchExecution("batch-1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, + 10, "file-abc123"); + + this.repository.save(execution); + + Optional found = this.repository.findById("batch-1"); + assertThat(found).isPresent(); + assertThat(found.get().getBatchId()).isEqualTo("batch-1"); + assertThat(found.get().getEndpoint()).isEqualTo("/v1/chat/completions"); + assertThat(found.get().getStatus()).isEqualTo(BatchExecutionStatus.SUBMITTED); + assertThat(found.get().getRequestCount()).isEqualTo(10); + assertThat(found.get().getInputFileId()).isEqualTo("file-abc123"); + assertThat(found.get().getCreatedAt()).isNotNull(); + assertThat(found.get().getUpdatedAt()).isNotNull(); + } + + @Test + void upsertUpdateExistingRecord() { + BatchExecution execution = new BatchExecution("batch-1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, + 10, "file-abc123"); + this.repository.save(execution); + + execution.setStatus(BatchExecutionStatus.IN_PROGRESS); + this.repository.save(execution); + + Optional found = this.repository.findById("batch-1"); + assertThat(found).isPresent(); + assertThat(found.get().getStatus()).isEqualTo(BatchExecutionStatus.IN_PROGRESS); + + // Verify only one record exists + Integer count = this.jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM SPRING_AI_BATCH_EXECUTION WHERE batch_id = ?", Integer.class, 
"batch-1"); + assertThat(count).isEqualTo(1); + } + + @Test + void findByStatus() { + this.repository + .save(new BatchExecution("batch-1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 5, "file-1")); + this.repository + .save(new BatchExecution("batch-2", "/v1/chat/completions", BatchExecutionStatus.IN_PROGRESS, 3, "file-2")); + this.repository + .save(new BatchExecution("batch-3", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 7, "file-3")); + + List submitted = this.repository.findByStatus(BatchExecutionStatus.SUBMITTED); + assertThat(submitted).hasSize(2); + assertThat(submitted).extracting(BatchExecution::getBatchId).containsExactlyInAnyOrder("batch-1", "batch-3"); + + List inProgress = this.repository.findByStatus(BatchExecutionStatus.IN_PROGRESS); + assertThat(inProgress).hasSize(1); + assertThat(inProgress.get(0).getBatchId()).isEqualTo("batch-2"); + } + + @Test + void findPendingExecutions() { + // Non-terminal statuses + this.repository + .save(new BatchExecution("batch-1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 5, "file-1")); + this.repository + .save(new BatchExecution("batch-2", "/v1/chat/completions", BatchExecutionStatus.IN_PROGRESS, 3, "file-2")); + this.repository + .save(new BatchExecution("batch-3", "/v1/chat/completions", BatchExecutionStatus.VALIDATING, 2, "file-3")); + this.repository + .save(new BatchExecution("batch-4", "/v1/chat/completions", BatchExecutionStatus.COMPLETED, 4, "file-4")); + + // Terminal statuses + this.repository.save(new BatchExecution("batch-5", "/v1/chat/completions", + BatchExecutionStatus.RESULTS_PROCESSED, 6, "file-5")); + this.repository + .save(new BatchExecution("batch-6", "/v1/chat/completions", BatchExecutionStatus.FAILED, 1, "file-6")); + this.repository + .save(new BatchExecution("batch-7", "/v1/chat/completions", BatchExecutionStatus.EXPIRED, 8, "file-7")); + this.repository + .save(new BatchExecution("batch-8", "/v1/chat/completions", BatchExecutionStatus.CANCELLED, 9, 
"file-8")); + + List pending = this.repository.findPendingExecutions(); + assertThat(pending).hasSize(4); + assertThat(pending).extracting(BatchExecution::getBatchId) + .containsExactlyInAnyOrder("batch-1", "batch-2", "batch-3", "batch-4"); + } + + @Test + void deleteById() { + this.repository + .save(new BatchExecution("batch-1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 5, "file-1")); + + assertThat(this.repository.findById("batch-1")).isPresent(); + + this.repository.deleteById("batch-1"); + + assertThat(this.repository.findById("batch-1")).isEmpty(); + } + + @Test + void findByIdReturnsEmptyForMissing() { + Optional found = this.repository.findById("nonexistent"); + assertThat(found).isEmpty(); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchExecution.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchExecution.java new file mode 100644 index 0000000000..9ef2f53e98 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchExecution.java @@ -0,0 +1,114 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import java.time.Instant; + +import org.springframework.util.Assert; + +/** + * Represents a tracked OpenAI Batch API execution. 
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.time.Instant;

import org.springframework.util.Assert;

/**
 * Tracking record for a single OpenAI Batch API execution. Captures the submission
 * metadata (batch ID, target endpoint, request count, input file ID) together with the
 * execution's current lifecycle {@link BatchExecutionStatus status}.
 * <p>
 * Instances are persisted through a {@link BatchExecutionRepository} and updated as the
 * batch progresses through its lifecycle.
 *
 * @author Yasin Akbas
 * @since 2.0.0
 */
public class BatchExecution {

	// Immutable submission metadata, fixed at creation time.
	private final String batchId;

	private final String endpoint;

	private final int requestCount;

	private final String inputFileId;

	private final Instant createdAt;

	// Mutable lifecycle state; updatedAt is bumped on every status change.
	private BatchExecutionStatus status;

	private Instant updatedAt;

	/**
	 * Creates a new batch execution record. {@code createdAt} and {@code updatedAt} are
	 * initialized to the same instant.
	 * @param batchId the OpenAI batch ID
	 * @param endpoint the API endpoint this batch targets
	 * @param status the initial status
	 * @param requestCount the number of requests in the batch; must be positive
	 * @param inputFileId the OpenAI file ID of the uploaded JSONL input
	 * @throws IllegalArgumentException if any argument is blank, null, or non-positive
	 */
	public BatchExecution(String batchId, String endpoint, BatchExecutionStatus status, int requestCount,
			String inputFileId) {
		Assert.hasText(batchId, "batchId must not be blank");
		Assert.hasText(endpoint, "endpoint must not be blank");
		Assert.notNull(status, "status must not be null");
		Assert.isTrue(requestCount > 0, "requestCount must be positive");
		Assert.hasText(inputFileId, "inputFileId must not be blank");
		this.batchId = batchId;
		this.endpoint = endpoint;
		this.status = status;
		this.requestCount = requestCount;
		this.inputFileId = inputFileId;
		Instant now = Instant.now();
		this.createdAt = now;
		this.updatedAt = now;
	}

	public String getBatchId() {
		return this.batchId;
	}

	public String getEndpoint() {
		return this.endpoint;
	}

	public int getRequestCount() {
		return this.requestCount;
	}

	public String getInputFileId() {
		return this.inputFileId;
	}

	public BatchExecutionStatus getStatus() {
		return this.status;
	}

	/**
	 * Transitions this execution to the given status and records the update time.
	 * @param status the new status; must not be null
	 */
	public void setStatus(BatchExecutionStatus status) {
		Assert.notNull(status, "status must not be null");
		this.status = status;
		this.updatedAt = Instant.now();
	}

	public Instant getCreatedAt() {
		return this.createdAt;
	}

	public Instant getUpdatedAt() {
		return this.updatedAt;
	}

	@Override
	public String toString() {
		return "BatchExecution{batchId='" + this.batchId + "', endpoint='" + this.endpoint + "', status=" + this.status
				+ ", requestCount=" + this.requestCount + '}';
	}

}
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.util.List;
import java.util.Optional;

/**
 * Repository interface for persisting and querying {@link BatchExecution} records.
 * <p>
 * Provides an out-of-the-box {@link InMemoryBatchExecutionRepository} for simple
 * use-cases and development. For production, users may implement this interface with a
 * database-backed store (e.g., Spring Data JPA, JDBC) to survive application restarts.
 *
 * @author Yasin Akbas
 * @since 2.0.0
 * @see InMemoryBatchExecutionRepository
 * @see BatchExecution
 */
public interface BatchExecutionRepository {

	/**
	 * Saves a batch execution record. If a record with the same batch ID already exists,
	 * it is replaced.
	 * @param execution the batch execution to save
	 */
	void save(BatchExecution execution);

	/**
	 * Finds a batch execution by its OpenAI batch ID.
	 * @param batchId the OpenAI batch ID
	 * @return the batch execution, or empty if not found
	 */
	Optional<BatchExecution> findById(String batchId);

	/**
	 * Finds all batch executions with the given status.
	 * @param status the status to filter by
	 * @return a list of matching batch executions
	 */
	List<BatchExecution> findByStatus(BatchExecutionStatus status);

	/**
	 * Finds all batch executions that have not yet reached a terminal state and should
	 * still be checked for progress. This covers every status for which
	 * {@link BatchExecutionStatus#isTerminal()} returns {@code false}:
	 * {@link BatchExecutionStatus#SUBMITTED}, {@link BatchExecutionStatus#VALIDATING},
	 * {@link BatchExecutionStatus#IN_PROGRESS}, {@link BatchExecutionStatus#FINALIZING},
	 * {@link BatchExecutionStatus#CANCELLING}, and
	 * {@link BatchExecutionStatus#COMPLETED} (completed on the OpenAI side but with
	 * results not yet processed by the application).
	 * @return a list of non-terminal batch executions
	 */
	List<BatchExecution> findPendingExecutions();

	/**
	 * Deletes a batch execution record.
	 * @param batchId the OpenAI batch ID to delete
	 */
	void deleteById(String batchId);

}
*/ + RESULTS_PROCESSED, + + /** Batch failed on the OpenAI side. */ + FAILED, + + /** Batch expired before completion (OpenAI's 24h window). */ + EXPIRED, + + /** Batch is being cancelled. */ + CANCELLING, + + /** Batch was cancelled. */ + CANCELLED; + + /** + * Converts an OpenAI SDK {@link Batch.Status} to a {@link BatchExecutionStatus}. + * @param openAiStatus the OpenAI batch status + * @return the corresponding execution status + */ + public static BatchExecutionStatus fromOpenAiStatus(Batch.Status openAiStatus) { + if (Batch.Status.VALIDATING.equals(openAiStatus)) { + return VALIDATING; + } + if (Batch.Status.IN_PROGRESS.equals(openAiStatus)) { + return IN_PROGRESS; + } + if (Batch.Status.FINALIZING.equals(openAiStatus)) { + return FINALIZING; + } + if (Batch.Status.COMPLETED.equals(openAiStatus)) { + return COMPLETED; + } + if (Batch.Status.FAILED.equals(openAiStatus)) { + return FAILED; + } + if (Batch.Status.EXPIRED.equals(openAiStatus)) { + return EXPIRED; + } + if (Batch.Status.CANCELLING.equals(openAiStatus)) { + return CANCELLING; + } + if (Batch.Status.CANCELLED.equals(openAiStatus)) { + return CANCELLED; + } + return SUBMITTED; + } + + /** + * Returns whether this status represents a terminal state (no further transitions). + */ + public boolean isTerminal() { + return this == RESULTS_PROCESSED || this == FAILED || this == EXPIRED || this == CANCELLED; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchRequestCustomId.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchRequestCustomId.java new file mode 100644 index 0000000000..1ae3c9acb0 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchRequestCustomId.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023-present the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import org.springframework.util.Assert; + +/** + * Represents a custom ID for an OpenAI Batch API request line. Uses a {@code ::} + * delimiter to safely encode an entity identifier and a handler identifier. + *

+ * The {@code ::} delimiter is chosen over single characters (e.g., {@code _} or + * {@code -}) to avoid collisions with identifiers that may naturally contain those + * characters. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public record BatchRequestCustomId(String entityId, String handlerId) { + + /** + * Delimiter used to separate the entity ID and handler ID in the serialized custom ID + * string. + */ + public static final String DELIMITER = "::"; + + public BatchRequestCustomId { + Assert.hasText(entityId, "entityId must not be blank"); + Assert.hasText(handlerId, "handlerId must not be blank"); + Assert.isTrue(!entityId.contains(DELIMITER), "entityId must not contain the delimiter '" + DELIMITER + "'"); + Assert.isTrue(!handlerId.contains(DELIMITER), "handlerId must not contain the delimiter '" + DELIMITER + "'"); + } + + /** + * Parses a custom ID string into a {@link BatchRequestCustomId}. + * @param customId the custom ID string in the format {@code entityId::handlerId} + * @return the parsed custom ID + * @throws IllegalArgumentException if the string does not contain exactly two parts + */ + public static BatchRequestCustomId parse(String customId) { + Assert.hasText(customId, "customId must not be blank"); + String[] parts = customId.split(DELIMITER, -1); + if (parts.length != 2) { + throw new IllegalArgumentException( + "Invalid custom ID format: expected 'entityId" + DELIMITER + "handlerId', got: " + customId); + } + return new BatchRequestCustomId(parts[0], parts[1]); + } + + @Override + public String toString() { + return this.entityId + DELIMITER + this.handlerId; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchRequestHandler.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/BatchRequestHandler.java new file mode 100644 index 0000000000..989da2e77d --- /dev/null +++ 
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.util.List;
import java.util.Map;

/**
 * Interface for batch request handlers that know how to generate OpenAI API request
 * bodies from domain-specific input data.
 * <p>
 * Implementations are responsible for:
 * <ul>
 * <li>generating endpoint-specific request bodies from domain input data</li>
 * <li>estimating token usage for budget management</li>
 * <li>supplying the pending items to include in the next batch</li>
 * <li>processing the success and error response lines of a completed batch</li>
 * </ul>
 * <p>
 * The handler follows the hybrid storage approach: domain-specific input data is
 * stored in the database, and the full OpenAI API envelope (model, reasoning_effort,
 * response_format, messages, etc.) is generated on demand at execution time. This means
 * configuration changes (e.g., fixing a wrong model or reasoning_effort) automatically
 * apply to all pending requests without data cleanup.
 *
 * @param <I> the type of domain-specific input data
 * @author Yasin Akbas
 * @since 2.0.0
 */
public interface BatchRequestHandler<I> {

	/**
	 * Returns the unique identifier for this handler. Used as the {@code handlerId}
	 * component of {@link BatchRequestCustomId}.
	 * @return the handler identifier, must be non-blank and not contain
	 * {@link BatchRequestCustomId#DELIMITER}
	 */
	String getHandlerId();

	/**
	 * Returns the OpenAI API endpoint this handler targets.
	 * @return the endpoint URL (e.g., {@code /v1/chat/completions},
	 * {@code /v1/embeddings})
	 */
	String getEndpoint();

	/**
	 * Generates the OpenAI API request body from the given domain input data. This method
	 * is called at batch execution time, allowing the request to reflect the latest
	 * handler configuration (model, prompts, parameters, etc.).
	 * @param input the domain-specific input data
	 * @return the request body as a map of parameters suitable for the target endpoint
	 */
	Map<String, Object> generateRequestBody(I input);

	/**
	 * Estimates the number of prompt tokens this request will consume. Used for token
	 * budget management.
	 * @param input the domain-specific input data
	 * @return the estimated token count
	 */
	int estimateTokenUsage(I input);

	/**
	 * Called for each successfully completed response line from the batch. The
	 * {@code batchVersion} is the version stored in the batch metadata at creation time.
	 * Handlers can use this to determine whether the response format matches the current
	 * handler logic.
	 * @param customId the parsed custom ID from the response
	 * @param responseBody the response body from the OpenAI API
	 * @param batchVersion the version from the batch metadata
	 */
	void onSuccess(BatchRequestCustomId customId, Map<String, Object> responseBody, int batchVersion);

	/**
	 * Called for each failed response line from the batch.
	 * @param customId the parsed custom ID from the response
	 * @param error the error details
	 * @param batchVersion the version from the batch metadata
	 */
	void onError(BatchRequestCustomId customId, BatchResponseLine.Error error, int batchVersion);

	/**
	 * Returns the pending domain input items to be processed in the next batch. Each item
	 * is keyed by its entity ID (used as the {@code entityId} in
	 * {@link BatchRequestCustomId}).
	 * @param maxItems the maximum number of items to return
	 * @return a map of entity IDs to domain input data
	 */
	Map<String, I> getPendingItems(int maxItems);

	/**
	 * Converts the given pending items into request lines to be submitted in a batch.
	 * This default implementation generates a {@link BatchRequestLine} for each pending
	 * item by calling {@link #generateRequestBody(Object)}.
	 * @param pendingItems the pending items keyed by entity ID
	 * @return the list of request lines ready for batch submission
	 */
	default List<BatchRequestLine> toRequestLines(Map<String, I> pendingItems) {
		return pendingItems.entrySet().stream().map(entry -> {
			String customId = new BatchRequestCustomId(entry.getKey(), getHandlerId()).toString();
			Map<String, Object> body = generateRequestBody(entry.getValue());
			return BatchRequestLine.post(customId, getEndpoint(), body);
		}).toList();
	}

}
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.util.Map;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.util.Assert;

/**
 * Represents a single request line in an OpenAI Batch API JSONL input file.
 * <p>
 * Each line follows the format:
 *
 * <pre>{@code
 * {
 *   "custom_id": "entity-123::my-handler",
 *   "method": "POST",
 *   "url": "/v1/chat/completions",
 *   "body": { ... }
 * }
 * }</pre>
 *
 * The {@code body} is a generic map that the handler populates with the appropriate API
 * request parameters. This keeps the batch framework endpoint-agnostic.
 *
 * @author Yasin Akbas
 * @since 2.0.0
 * @see <a href="https://platform.openai.com/docs/guides/batch">OpenAI Batch API</a>
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public record BatchRequestLine(@JsonProperty("custom_id") String customId, @JsonProperty("method") String method,
		@JsonProperty("url") String url, @JsonProperty("body") Map<String, Object> body) {

	public BatchRequestLine {
		Assert.hasText(customId, "customId must not be blank");
		Assert.hasText(method, "method must not be blank");
		Assert.hasText(url, "url must not be blank");
		Assert.notNull(body, "body must not be null");
	}

	/**
	 * Creates a POST request line for the given endpoint.
	 * @param customId the custom ID for tracking this request
	 * @param url the API endpoint URL (e.g., {@code /v1/chat/completions})
	 * @param body the request body parameters
	 * @return a new {@link BatchRequestLine}
	 */
	public static BatchRequestLine post(String customId, String url, Map<String, Object> body) {
		return new BatchRequestLine(customId, "POST", url, body);
	}

}
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.util.Map;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.jspecify.annotations.Nullable;

/**
 * Represents a single response line from an OpenAI Batch API JSONL output file.
 * <p>
 * Each line follows the format:
 *
 * <pre>{@code
 * {
 *   "id": "batch_req_abc123",
 *   "custom_id": "entity-123::my-handler",
 *   "response": {
 *     "status_code": 200,
 *     "request_id": "req_abc123",
 *     "body": { ... }
 *   },
 *   "error": {
 *     "code": "...",
 *     "message": "..."
 *   }
 * }
 * }</pre>
 *
 * @author Yasin Akbas
 * @since 2.0.0
 * @see <a href="https://platform.openai.com/docs/guides/batch">OpenAI Batch API</a>
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public record BatchResponseLine(@JsonProperty("id") @Nullable String id,
		@JsonProperty("custom_id") @Nullable String customId, @JsonProperty("response") @Nullable Response response,
		@JsonProperty("error") @Nullable Error error) {

	/**
	 * Returns whether this response line indicates a successful response, i.e. a
	 * {@code 200} status code and no error element.
	 */
	public boolean isSuccess() {
		return this.response != null && this.response.statusCode() != null && this.response.statusCode() == 200
				&& this.error == null;
	}

	/**
	 * The response envelope containing status code and body.
	 */
	@JsonIgnoreProperties(ignoreUnknown = true)
	@JsonInclude(JsonInclude.Include.NON_NULL)
	public record Response(@JsonProperty("status_code") @Nullable Integer statusCode,
			@JsonProperty("request_id") @Nullable String requestId,
			@JsonProperty("body") @Nullable Map<String, Object> body) {
	}

	/**
	 * Error details when a request in the batch fails.
	 */
	@JsonIgnoreProperties(ignoreUnknown = true)
	@JsonInclude(JsonInclude.Include.NON_NULL)
	public record Error(@JsonProperty("code") @Nullable String code,
			@JsonProperty("message") @Nullable String message) {
	}

}
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.util.Assert;

/**
 * An in-memory implementation of {@link BatchExecutionRepository} backed by a
 * {@link ConcurrentHashMap}.
 * <p>
 * Suitable for development, testing, and single-instance deployments where batch
 * execution tracking does not need to survive application restarts. For production
 * use-cases with multiple instances or restart resilience, provide a database-backed
 * implementation.
 *
 * @author Yasin Akbas
 * @since 2.0.0
 */
public final class InMemoryBatchExecutionRepository implements BatchExecutionRepository {

	// Keyed by BatchExecution#getBatchId(); ConcurrentHashMap makes individual
	// operations thread-safe without external locking.
	private final Map<String, BatchExecution> store = new ConcurrentHashMap<>();

	@Override
	public void save(BatchExecution execution) {
		Assert.notNull(execution, "execution must not be null");
		this.store.put(execution.getBatchId(), execution);
	}

	@Override
	public Optional<BatchExecution> findById(String batchId) {
		Assert.hasText(batchId, "batchId must not be blank");
		return Optional.ofNullable(this.store.get(batchId));
	}

	@Override
	public List<BatchExecution> findByStatus(BatchExecutionStatus status) {
		Assert.notNull(status, "status must not be null");
		return this.store.values().stream().filter(e -> e.getStatus() == status).toList();
	}

	@Override
	public List<BatchExecution> findPendingExecutions() {
		// Pending means any non-terminal status, including COMPLETED batches whose
		// results still await processing.
		return this.store.values().stream().filter(e -> !e.getStatus().isTerminal()).toList();
	}

	@Override
	public void deleteById(String batchId) {
		Assert.hasText(batchId, "batchId must not be blank");
		this.store.remove(batchId);
	}

}
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.core.JsonValue; +import com.openai.core.MultipartField; +import com.openai.core.http.HttpResponse; +import com.openai.models.batches.Batch; +import com.openai.models.batches.BatchCreateParams; +import com.openai.models.files.FileCreateParams; +import com.openai.models.files.FileObject; +import com.openai.models.files.FilePurpose; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.util.Assert; + +/** + * Low-level client for the OpenAI Batch and Files APIs, wrapping the official OpenAI Java + * SDK. + *

+ * Provides methods for: + *

+ * <ul>
+ * <li>Generating and uploading JSONL batch input files</li>
+ * <li>Creating, retrieving, listing, and cancelling batch executions</li>
+ * <li>Downloading batch output/error files</li>
+ * <li>Deleting files from OpenAI storage</li>
+ * </ul>
+ * + * @author Yasin Akbas + * @since 2.0.0 + * @see OpenAI Batch + * API + */ +public class OpenAiBatchApi { + + private static final Logger logger = LoggerFactory.getLogger(OpenAiBatchApi.class); + + private final OpenAIClient openAiClient; + + private final ObjectMapper objectMapper; + + public OpenAiBatchApi(OpenAIClient openAiClient, ObjectMapper objectMapper) { + Assert.notNull(openAiClient, "openAiClient must not be null"); + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.openAiClient = openAiClient; + this.objectMapper = objectMapper; + } + + /** + * Uploads a JSONL file containing batch request lines and creates a batch execution. + * @param requestLines the batch request lines to include + * @param endpoint the target API endpoint (e.g., {@code /v1/chat/completions}) + * @param completionWindow the completion window (e.g., {@code 24h}) + * @param metadata optional metadata to attach to the batch + * @return the created {@link Batch} object + */ + public Batch createBatch(List requestLines, BatchCreateParams.Endpoint endpoint, + BatchCreateParams.CompletionWindow completionWindow, Map metadata) { + Assert.notEmpty(requestLines, "requestLines must not be empty"); + + String jsonl = generateJsonl(requestLines); + FileObject inputFile = uploadJsonlFile(jsonl); + + logger.debug("Creating batch with {} requests, endpoint={}, inputFileId={}", requestLines.size(), + endpoint.asString(), inputFile.id()); + + BatchCreateParams.Builder createParams = BatchCreateParams.builder() + .inputFileId(inputFile.id()) + .endpoint(endpoint) + .completionWindow(completionWindow); + + if (metadata != null && !metadata.isEmpty()) { + BatchCreateParams.Metadata.Builder metaBuilder = BatchCreateParams.Metadata.builder(); + metadata.forEach((key, value) -> metaBuilder.putAdditionalProperty(key, JsonValue.from(value))); + createParams.metadata(metaBuilder.build()); + } + + return this.openAiClient.batches().create(createParams.build()); + } + + /** + * 
Retrieves the current status and details of a batch execution. + * @param batchId the OpenAI batch ID + * @return the {@link Batch} object with current status + */ + public Batch retrieveBatch(String batchId) { + Assert.hasText(batchId, "batchId must not be blank"); + return this.openAiClient.batches().retrieve(batchId); + } + + /** + * Cancels an in-progress batch execution. + * @param batchId the OpenAI batch ID to cancel + * @return the cancelled {@link Batch} object + */ + public Batch cancelBatch(String batchId) { + Assert.hasText(batchId, "batchId must not be blank"); + return this.openAiClient.batches().cancel(batchId); + } + + /** + * Downloads the content of a file from OpenAI and returns it as a string. + * @param fileId the file ID to download + * @return the file content as a UTF-8 string + */ + public String downloadFileContent(String fileId) { + Assert.hasText(fileId, "fileId must not be blank"); + HttpResponse response = this.openAiClient.files().content(fileId); + try (var is = response.body()) { + return new String(is.readAllBytes(), StandardCharsets.UTF_8); + } + catch (Exception ex) { + throw new OpenAiBatchException("Failed to download file content for fileId: " + fileId, ex); + } + } + + /** + * Parses JSONL output content into a list of {@link BatchResponseLine} objects. 
+ * @param jsonlContent the JSONL content string + * @return parsed response lines + */ + public List parseResponseLines(String jsonlContent) { + Assert.hasText(jsonlContent, "jsonlContent must not be blank"); + try (var reader = new BufferedReader(new InputStreamReader( + new ByteArrayInputStream(jsonlContent.getBytes(StandardCharsets.UTF_8)), StandardCharsets.UTF_8))) { + return reader.lines() + .filter(line -> !line.isBlank()) + .map(line -> deserialize(line, BatchResponseLine.class)) + .collect(Collectors.toList()); + } + catch (Exception ex) { + throw new OpenAiBatchException("Failed to parse batch response JSONL", ex); + } + } + + /** + * Deletes a file from OpenAI storage. + * @param fileId the file ID to delete + */ + public void deleteFile(String fileId) { + Assert.hasText(fileId, "fileId must not be blank"); + this.openAiClient.files().delete(fileId); + logger.debug("Deleted file: {}", fileId); + } + + private FileObject uploadJsonlFile(String jsonlContent) { + byte[] bytes = jsonlContent.getBytes(StandardCharsets.UTF_8); + MultipartField fileField = MultipartField.builder() + .value(new ByteArrayInputStream(bytes)) + .filename("batch_input.jsonl") + .contentType("application/jsonl") + .build(); + FileCreateParams params = FileCreateParams.builder().file(fileField).purpose(FilePurpose.BATCH).build(); + return this.openAiClient.files().create(params); + } + + private String generateJsonl(List requestLines) { + return requestLines.stream().map(line -> serialize(line)).collect(Collectors.joining("\n")); + } + + private String serialize(Object value) { + try { + return this.objectMapper.writeValueAsString(value); + } + catch (Exception ex) { + throw new OpenAiBatchException("Failed to serialize batch request line to JSON", ex); + } + } + + private T deserialize(String json, Class type) { + try { + return this.objectMapper.readValue(json, type); + } + catch (Exception ex) { + throw new OpenAiBatchException("Failed to deserialize batch response line from JSON: 
" + json, ex); + } + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchException.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchException.java new file mode 100644 index 0000000000..aaa69bc42e --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchException.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +/** + * Exception thrown when an error occurs during OpenAI Batch API operations. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class OpenAiBatchException extends RuntimeException { + + public OpenAiBatchException(String message) { + super(message); + } + + public OpenAiBatchException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchListener.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchListener.java new file mode 100644 index 0000000000..a7e615aff1 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchListener.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023-present the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import java.util.List; + +import com.openai.models.batches.Batch; + +/** + * Listener interface for OpenAI Batch API lifecycle events. + *

+ * Implementations can react to batch creation, completion, failure, expiration, and + * individual request results. A no-op default is provided for all methods so + * implementations only need to override the events they care about. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public interface OpenAiBatchListener { + + /** + * Called after a batch has been successfully created on the OpenAI API. + * @param batch the created batch + * @param requestCount the number of requests in the batch + */ + default void onBatchCreated(Batch batch, int requestCount) { + } + + /** + * Called when a batch has completed successfully. + * @param batch the completed batch + */ + default void onBatchCompleted(Batch batch) { + } + + /** + * Called when a batch has failed. + * @param batch the failed batch + */ + default void onBatchFailed(Batch batch) { + } + + /** + * Called when a batch has expired before completion. + * @param batch the expired batch + */ + default void onBatchExpired(Batch batch) { + } + + /** + * Called when a batch has been cancelled. + * @param batch the cancelled batch + */ + default void onBatchCancelled(Batch batch) { + } + + /** + * Called after individual response lines from a completed batch have been processed. + * @param batch the batch that produced these results + * @param successLines response lines that completed successfully + * @param errorLines response lines that failed + */ + default void onBatchResultsProcessed(Batch batch, List successLines, + List errorLines) { + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchModel.java new file mode 100644 index 0000000000..cf198ca90b --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchModel.java @@ -0,0 +1,583 @@ +/* + * Copyright 2023-present the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import com.openai.client.OpenAIClient; +import com.openai.models.batches.Batch; +import com.openai.models.batches.BatchCreateParams; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.openai.setup.OpenAiSetup; +import org.springframework.util.Assert; + +/** + * OpenAI Batch API model that orchestrates the lifecycle of batch executions. + *

+ * This model provides three main lifecycle phases: + *

+ * <ol>
+ * <li><b>Prepare</b> — Collects pending items from registered
+ * {@link BatchRequestHandler}s and converts them into {@link BatchRequestLine}s</li>
+ * <li><b>Execute</b> — Uploads JSONL files and creates batch executions on the OpenAI
+ * API</li>
+ * <li><b>Check</b> — Polls batch status, processes output/error files, and dispatches
+ * results to handlers</li>
+ * </ol>
+ * + *

+ * Following the spring-ai pattern, this class uses a builder for construction and can + * optionally auto-create the OpenAI client from options if one is not explicitly + * provided. + * + * @author Yasin Akbas + * @since 2.0.0 + * @see BatchRequestHandler + * @see OpenAiBatchApi + * @see OpenAiBatchListener + */ +public final class OpenAiBatchModel { + + private static final Logger logger = LoggerFactory.getLogger(OpenAiBatchModel.class); + + /** + * Metadata key for the handler version stored with each batch. + */ + static final String METADATA_HANDLER_VERSION = "handler-version"; + + private final OpenAiBatchApi batchApi; + + private final OpenAiBatchOptions options; + + private final List> handlers; + + private final List listeners; + + private final BatchExecutionRepository executionRepository; + + private OpenAiBatchModel(Builder builder) { + this.options = builder.options != null ? builder.options : OpenAiBatchOptions.builder().build(); + + if (builder.batchApi != null) { + this.batchApi = builder.batchApi; + } + else { + OpenAIClient openAiClient = Objects.requireNonNullElseGet(builder.openAiClient, + () -> OpenAiSetup.setupSyncClient(this.options.getBaseUrl(), this.options.getApiKey(), + this.options.getCredential(), this.options.getMicrosoftDeploymentName(), + this.options.getMicrosoftFoundryServiceVersion(), this.options.getOrganizationId(), + this.options.isMicrosoftFoundry(), this.options.isGitHubModels(), this.options.getModel(), + this.options.getTimeout(), this.options.getMaxRetries(), this.options.getProxy(), + this.options.getCustomHeaders())); + this.batchApi = new OpenAiBatchApi(openAiClient, new com.fasterxml.jackson.databind.ObjectMapper()); + } + + this.handlers = builder.handlers != null ? List.copyOf(builder.handlers) : List.of(); + this.listeners = builder.listeners != null ? List.copyOf(builder.listeners) : List.of(); + this.executionRepository = builder.executionRepository != null ? 
builder.executionRepository + : new InMemoryBatchExecutionRepository(); + + validateHandlers(); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Collects pending items from all registered handlers and creates batch executions. + * Requests are grouped by endpoint since OpenAI requires each batch to target a + * single endpoint. + * @return a list of created {@link Batch} objects + */ + public List createBatchExecutions() { + List createdBatches = new ArrayList<>(); + + Map> linesByEndpoint = new HashMap<>(); + + for (BatchRequestHandler handler : this.handlers) { + int maxItems = this.options.getMaxRequestsPerBatch(); + Map pendingItems = handler.getPendingItems(maxItems); + + if (pendingItems.isEmpty()) { + logger.debug("No pending items for handler: {}", handler.getHandlerId()); + continue; + } + + @SuppressWarnings("unchecked") + BatchRequestHandler typedHandler = (BatchRequestHandler) handler; + @SuppressWarnings("unchecked") + Map typedItems = (Map) pendingItems; + + List requestLines = typedHandler.toRequestLines(typedItems); + + linesByEndpoint.computeIfAbsent(handler.getEndpoint(), k -> new ArrayList<>()).addAll(requestLines); + + logger.info("Prepared {} request lines from handler '{}' for endpoint '{}'", requestLines.size(), + handler.getHandlerId(), handler.getEndpoint()); + } + + for (Map.Entry> entry : linesByEndpoint.entrySet()) { + String endpoint = entry.getKey(); + List allLines = entry.getValue(); + + if (allLines.isEmpty()) { + continue; + } + + // Split into chunks respecting the max requests per batch limit + for (int i = 0; i < allLines.size(); i += this.options.getMaxRequestsPerBatch()) { + List chunk = allLines.subList(i, + Math.min(i + this.options.getMaxRequestsPerBatch(), allLines.size())); + + try { + Batch batch = this.batchApi.createBatch(chunk, BatchCreateParams.Endpoint.of(endpoint), + BatchCreateParams.CompletionWindow.of(this.options.getCompletionWindow()), + Map.of("spring-ai-version", "2.0.0", 
METADATA_HANDLER_VERSION, + String.valueOf(this.options.getHandlerVersion()))); + + createdBatches.add(batch); + + BatchExecution execution = new BatchExecution(batch.id(), endpoint, BatchExecutionStatus.SUBMITTED, + chunk.size(), batch.inputFileId()); + this.executionRepository.save(execution); + + notifyBatchCreated(batch, chunk.size()); + + logger.info("Created batch '{}' with {} requests for endpoint '{}'", batch.id(), chunk.size(), + endpoint); + } + catch (Exception ex) { + logger.error("Failed to create batch for endpoint '{}' with {} requests", endpoint, chunk.size(), + ex); + } + } + } + + return createdBatches; + } + + /** + * Checks the status of a batch execution and processes results if completed. + * @param batchId the OpenAI batch ID to check + * @return the current {@link Batch} status + */ + public Batch checkBatchExecution(String batchId) { + Assert.hasText(batchId, "batchId must not be blank"); + + Batch batch = this.batchApi.retrieveBatch(batchId); + + Batch.Status status = batch.status(); + logger.debug("Batch '{}' status: {}", batchId, status); + + updateExecutionStatus(batchId, BatchExecutionStatus.fromOpenAiStatus(status)); + + if (Batch.Status.COMPLETED.equals(status)) { + handleCompletedBatch(batch); + updateExecutionStatus(batchId, BatchExecutionStatus.RESULTS_PROCESSED); + notifyBatchCompleted(batch); + } + else if (Batch.Status.FAILED.equals(status)) { + int batchVersion = extractBatchVersion(batch); + List errorLines = new ArrayList<>(); + handleErrorFile(batch, errorLines, batchVersion); + notifyBatchFailed(batch); + } + else if (Batch.Status.EXPIRED.equals(status)) { + handleCompletedBatch(batch); + updateExecutionStatus(batchId, BatchExecutionStatus.RESULTS_PROCESSED); + notifyBatchExpired(batch); + } + else if (Batch.Status.CANCELLED.equals(status)) { + notifyBatchCancelled(batch); + } + + return batch; + } + + /** + * Checks all pending (non-terminal) batch executions tracked by the repository. 
This + * is a convenience method for scheduled polling — call it periodically to + * automatically process all in-flight batches. + * @return a list of checked {@link Batch} objects with their current statuses + */ + public List checkAllPendingExecutions() { + List pending = this.executionRepository.findPendingExecutions(); + + if (pending.isEmpty()) { + logger.debug("No pending batch executions to check"); + return List.of(); + } + + logger.info("Checking {} pending batch executions", pending.size()); + List results = new ArrayList<>(); + + for (BatchExecution execution : pending) { + try { + Batch batch = checkBatchExecution(execution.getBatchId()); + results.add(batch); + } + catch (Exception ex) { + logger.error("Failed to check batch execution '{}'", execution.getBatchId(), ex); + } + } + + return results; + } + + /** + * Cancels a batch execution. + * @param batchId the OpenAI batch ID to cancel + * @return the cancelled {@link Batch} + */ + public Batch cancelBatch(String batchId) { + Assert.hasText(batchId, "batchId must not be blank"); + return this.batchApi.cancelBatch(batchId); + } + + /** + * Returns the configured options. + */ + public OpenAiBatchOptions getOptions() { + return this.options; + } + + /** + * Returns the registered handlers (unmodifiable). + */ + public List> getHandlers() { + return this.handlers; + } + + /** + * Returns the batch execution repository. 
+ */ + public BatchExecutionRepository getExecutionRepository() { + return this.executionRepository; + } + + private void handleCompletedBatch(Batch batch) { + int batchVersion = extractBatchVersion(batch); + List successLines = new ArrayList<>(); + List errorLines = new ArrayList<>(); + + // Process output file + batch.outputFileId().ifPresent(outputFileId -> { + try { + String content = this.batchApi.downloadFileContent(outputFileId); + List lines = this.batchApi.parseResponseLines(content); + + for (BatchResponseLine line : lines) { + if (line.isSuccess()) { + successLines.add(line); + dispatchSuccess(line, batchVersion); + } + else { + errorLines.add(line); + dispatchError(line, batchVersion); + } + } + + if (this.options.isDeleteFilesAfterProcessing()) { + this.batchApi.deleteFile(outputFileId); + } + } + catch (Exception ex) { + logger.error("Failed to process output file '{}' for batch '{}'", outputFileId, batch.id(), ex); + } + }); + + // Process error file + handleErrorFile(batch, errorLines, batchVersion); + + notifyResultsProcessed(batch, successLines, errorLines); + } + + private void handleErrorFile(Batch batch, List errorLines, int batchVersion) { + batch.errorFileId().ifPresent(errorFileId -> { + try { + String content = this.batchApi.downloadFileContent(errorFileId); + List lines = this.batchApi.parseResponseLines(content); + + for (BatchResponseLine line : lines) { + errorLines.add(line); + dispatchError(line, batchVersion); + } + + if (this.options.isDeleteFilesAfterProcessing()) { + this.batchApi.deleteFile(errorFileId); + } + } + catch (Exception ex) { + logger.error("Failed to process error file '{}' for batch '{}'", errorFileId, batch.id(), ex); + } + }); + + // Clean up input file + if (this.options.isDeleteFilesAfterProcessing()) { + try { + this.batchApi.deleteFile(batch.inputFileId()); + } + catch (Exception ex) { + logger.warn("Failed to delete input file '{}' for batch '{}'", batch.inputFileId(), batch.id(), ex); + } + } + } + + private 
void dispatchSuccess(BatchResponseLine line, int batchVersion) { + BatchResponseLine.Response response = line.response(); + if (line.customId() == null || response == null || response.body() == null) { + logger.warn("Skipping success response with missing customId or body: {}", line.id()); + return; + } + Map body = response.body(); + try { + BatchRequestCustomId customId = BatchRequestCustomId.parse(line.customId()); + findHandler(customId.handlerId()).ifPresentOrElse(handler -> { + try { + handler.onSuccess(customId, body, batchVersion); + } + catch (Exception ex) { + logger.error("Handler '{}' failed processing success for entity '{}'", customId.handlerId(), + customId.entityId(), ex); + } + }, () -> logger.warn("No handler found for handlerId '{}' from customId '{}'", customId.handlerId(), + line.customId())); + } + catch (IllegalArgumentException ex) { + logger.warn("Could not parse customId '{}': {}", line.customId(), ex.getMessage()); + } + } + + private void dispatchError(BatchResponseLine line, int batchVersion) { + if (line.customId() == null) { + logger.warn("Skipping error response with missing customId: {}", line.id()); + return; + } + try { + BatchRequestCustomId customId = BatchRequestCustomId.parse(line.customId()); + BatchResponseLine.Error error = extractError(line); + + findHandler(customId.handlerId()).ifPresent(handler -> { + try { + handler.onError(customId, error, batchVersion); + } + catch (Exception ex) { + logger.error("Handler '{}' failed processing error for entity '{}'", customId.handlerId(), + customId.entityId(), ex); + } + }); + } + catch (IllegalArgumentException ex) { + logger.warn("Could not parse customId '{}': {}", line.customId(), ex.getMessage()); + } + } + + @SuppressWarnings("unchecked") + private BatchResponseLine.Error extractError(BatchResponseLine line) { + if (line.error() != null) { + return line.error(); + } + // Extract error details from non-200 response body + BatchResponseLine.Response response = line.response(); 
+ if (response != null && response.body() != null) { + Map body = response.body(); + Object errorObj = body.get("error"); + if (errorObj instanceof Map errorMap) { + Object codeVal = errorMap.get("code"); + Object msgVal = errorMap.get("message"); + String code = codeVal != null ? String.valueOf(codeVal) : "http_" + response.statusCode(); + String message = msgVal != null ? String.valueOf(msgVal) : "Request failed"; + return new BatchResponseLine.Error(code, message); + } + return new BatchResponseLine.Error("http_" + response.statusCode(), + "Request failed with status " + response.statusCode()); + } + return new BatchResponseLine.Error("unknown", "No error details available"); + } + + private void updateExecutionStatus(String batchId, BatchExecutionStatus status) { + this.executionRepository.findById(batchId).ifPresent(execution -> { + execution.setStatus(status); + this.executionRepository.save(execution); + }); + } + + private int extractBatchVersion(Batch batch) { + return batch.metadata().map(metadata -> { + com.openai.core.JsonValue versionValue = metadata._additionalProperties().get(METADATA_HANDLER_VERSION); + if (versionValue != null) { + try { + String str = versionValue.toString().replace("\"", ""); + return Integer.parseInt(str); + } + catch (NumberFormatException ex) { + logger.warn("Invalid handler-version in batch '{}' metadata: {}", batch.id(), versionValue); + } + } + return OpenAiBatchOptions.DEFAULT_HANDLER_VERSION; + }).orElse(OpenAiBatchOptions.DEFAULT_HANDLER_VERSION); + } + + private java.util.Optional> findHandler(String handlerId) { + return this.handlers.stream().filter(h -> h.getHandlerId().equals(handlerId)).findFirst(); + } + + private void validateHandlers() { + Map handlerIdCounts = new HashMap<>(); + for (BatchRequestHandler handler : this.handlers) { + String id = handler.getHandlerId(); + Assert.hasText(id, "handler ID must not be blank"); + Assert.isTrue(!id.contains(BatchRequestCustomId.DELIMITER), + "handler ID must not contain 
the delimiter '" + BatchRequestCustomId.DELIMITER + "'"); + handlerIdCounts.merge(id, 1, Integer::sum); + } + for (Map.Entry entry : handlerIdCounts.entrySet()) { + if (entry.getValue() > 1) { + throw new IllegalArgumentException("Duplicate handler ID: '" + entry.getKey() + "' (found " + + entry.getValue() + " handlers with the same ID)"); + } + } + } + + private void notifyBatchCreated(Batch batch, int requestCount) { + for (OpenAiBatchListener listener : this.listeners) { + try { + listener.onBatchCreated(batch, requestCount); + } + catch (Exception ex) { + logger.warn("Listener failed on onBatchCreated", ex); + } + } + } + + private void notifyBatchCompleted(Batch batch) { + for (OpenAiBatchListener listener : this.listeners) { + try { + listener.onBatchCompleted(batch); + } + catch (Exception ex) { + logger.warn("Listener failed on onBatchCompleted", ex); + } + } + } + + private void notifyBatchFailed(Batch batch) { + for (OpenAiBatchListener listener : this.listeners) { + try { + listener.onBatchFailed(batch); + } + catch (Exception ex) { + logger.warn("Listener failed on onBatchFailed", ex); + } + } + } + + private void notifyBatchExpired(Batch batch) { + for (OpenAiBatchListener listener : this.listeners) { + try { + listener.onBatchExpired(batch); + } + catch (Exception ex) { + logger.warn("Listener failed on onBatchExpired", ex); + } + } + } + + private void notifyBatchCancelled(Batch batch) { + for (OpenAiBatchListener listener : this.listeners) { + try { + listener.onBatchCancelled(batch); + } + catch (Exception ex) { + logger.warn("Listener failed on onBatchCancelled", ex); + } + } + } + + private void notifyResultsProcessed(Batch batch, List successLines, + List errorLines) { + for (OpenAiBatchListener listener : this.listeners) { + try { + listener.onBatchResultsProcessed(batch, successLines, errorLines); + } + catch (Exception ex) { + logger.warn("Listener failed on onBatchResultsProcessed", ex); + } + } + } + + public static final class Builder { + 
+ private @Nullable OpenAIClient openAiClient; + + private @Nullable OpenAiBatchApi batchApi; + + private @Nullable OpenAiBatchOptions options; + + private @Nullable List> handlers; + + private @Nullable List listeners; + + private @Nullable BatchExecutionRepository executionRepository; + + private Builder() { + } + + public Builder openAiClient(OpenAIClient openAiClient) { + this.openAiClient = openAiClient; + return this; + } + + public Builder batchApi(OpenAiBatchApi batchApi) { + this.batchApi = batchApi; + return this; + } + + public Builder options(OpenAiBatchOptions options) { + this.options = options; + return this; + } + + public Builder handlers(List> handlers) { + this.handlers = handlers; + return this; + } + + public Builder listeners(List listeners) { + this.listeners = listeners; + return this; + } + + public Builder executionRepository(BatchExecutionRepository executionRepository) { + this.executionRepository = executionRepository; + return this; + } + + public OpenAiBatchModel build() { + return new OpenAiBatchModel(this); + } + + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchOptions.java new file mode 100644 index 0000000000..8da24bb3da --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/OpenAiBatchOptions.java @@ -0,0 +1,409 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import java.net.Proxy; +import java.time.Duration; +import java.util.Map; +import java.util.Objects; + +import com.openai.azure.AzureOpenAIServiceVersion; +import com.openai.credential.Credential; +import org.jspecify.annotations.Nullable; + +import org.springframework.ai.openai.AbstractOpenAiOptions; + +/** + * Options for OpenAI Batch API operations. + *

+ * Extends {@link AbstractOpenAiOptions} for connection configuration and adds + * batch-specific settings such as completion window, rate limits, retry policies, and + * token budget configuration. + * + * @author Yasin Akbas + * @since 2.0.0 + */ +public class OpenAiBatchOptions extends AbstractOpenAiOptions { + + /** + * Default completion window for batch jobs. + */ + public static final String DEFAULT_COMPLETION_WINDOW = "24h"; + + /** + * Default maximum number of requests per batch. + */ + public static final int DEFAULT_MAX_REQUESTS_PER_BATCH = 50_000; + + /** + * Default maximum file size in bytes (200 MB). + */ + public static final long DEFAULT_MAX_FILE_SIZE_BYTES = 200L * 1024 * 1024; + + /** + * Default safety factor applied to token estimates. + */ + public static final double DEFAULT_TOKEN_SAFETY_FACTOR = 1.2; + + /** + * Default minimum tokens required before submitting a batch. + */ + public static final long DEFAULT_MINIMUM_TOKENS_TO_SUBMIT = 5_000_000L; + + /** + * Default maximum retry attempts for failed requests. + */ + public static final int DEFAULT_MAX_RETRY_ATTEMPTS = 2; + + /** + * Default whether to clean up files after processing. + */ + public static final boolean DEFAULT_DELETE_FILES_AFTER_PROCESSING = true; + + /** + * Default handler version stored in batch metadata. 
+ */ + public static final int DEFAULT_HANDLER_VERSION = 1; + + private @Nullable String completionWindow; + + private int maxRequestsPerBatch = DEFAULT_MAX_REQUESTS_PER_BATCH; + + private long maxFileSizeBytes = DEFAULT_MAX_FILE_SIZE_BYTES; + + private double tokenSafetyFactor = DEFAULT_TOKEN_SAFETY_FACTOR; + + private long minimumTokensToSubmit = DEFAULT_MINIMUM_TOKENS_TO_SUBMIT; + + private int maxRetryAttempts = DEFAULT_MAX_RETRY_ATTEMPTS; + + private boolean deleteFilesAfterProcessing = DEFAULT_DELETE_FILES_AFTER_PROCESSING; + + private int handlerVersion = DEFAULT_HANDLER_VERSION; + + public static Builder builder() { + return new Builder(); + } + + public String getCompletionWindow() { + return this.completionWindow != null ? this.completionWindow : DEFAULT_COMPLETION_WINDOW; + } + + public void setCompletionWindow(@Nullable String completionWindow) { + this.completionWindow = completionWindow; + } + + public int getMaxRequestsPerBatch() { + return this.maxRequestsPerBatch; + } + + public void setMaxRequestsPerBatch(int maxRequestsPerBatch) { + this.maxRequestsPerBatch = maxRequestsPerBatch; + } + + public long getMaxFileSizeBytes() { + return this.maxFileSizeBytes; + } + + public void setMaxFileSizeBytes(long maxFileSizeBytes) { + this.maxFileSizeBytes = maxFileSizeBytes; + } + + public double getTokenSafetyFactor() { + return this.tokenSafetyFactor; + } + + public void setTokenSafetyFactor(double tokenSafetyFactor) { + this.tokenSafetyFactor = tokenSafetyFactor; + } + + public long getMinimumTokensToSubmit() { + return this.minimumTokensToSubmit; + } + + public void setMinimumTokensToSubmit(long minimumTokensToSubmit) { + this.minimumTokensToSubmit = minimumTokensToSubmit; + } + + public int getMaxRetryAttempts() { + return this.maxRetryAttempts; + } + + public void setMaxRetryAttempts(int maxRetryAttempts) { + this.maxRetryAttempts = maxRetryAttempts; + } + + public boolean isDeleteFilesAfterProcessing() { + return this.deleteFilesAfterProcessing; + } + + 
public void setDeleteFilesAfterProcessing(boolean deleteFilesAfterProcessing) { + this.deleteFilesAfterProcessing = deleteFilesAfterProcessing; + } + + public int getHandlerVersion() { + return this.handlerVersion; + } + + public void setHandlerVersion(int handlerVersion) { + this.handlerVersion = handlerVersion; + } + + public OpenAiBatchOptions copy() { + return builder().from(this).build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof OpenAiBatchOptions that)) { + return false; + } + return this.maxRequestsPerBatch == that.maxRequestsPerBatch && this.maxFileSizeBytes == that.maxFileSizeBytes + && Double.compare(this.tokenSafetyFactor, that.tokenSafetyFactor) == 0 + && this.minimumTokensToSubmit == that.minimumTokensToSubmit + && this.maxRetryAttempts == that.maxRetryAttempts + && this.deleteFilesAfterProcessing == that.deleteFilesAfterProcessing + && this.handlerVersion == that.handlerVersion + && Objects.equals(this.completionWindow, that.completionWindow) + && Objects.equals(getBaseUrl(), that.getBaseUrl()) && Objects.equals(getApiKey(), that.getApiKey()); + } + + @Override + public int hashCode() { + return Objects.hash(this.completionWindow, this.maxRequestsPerBatch, this.maxFileSizeBytes, + this.tokenSafetyFactor, this.minimumTokensToSubmit, this.maxRetryAttempts, + this.deleteFilesAfterProcessing, this.handlerVersion, getBaseUrl(), getApiKey()); + } + + @Override + public String toString() { + return "OpenAiBatchOptions{" + "completionWindow='" + getCompletionWindow() + '\'' + ", maxRequestsPerBatch=" + + this.maxRequestsPerBatch + ", maxFileSizeBytes=" + this.maxFileSizeBytes + ", tokenSafetyFactor=" + + this.tokenSafetyFactor + ", minimumTokensToSubmit=" + this.minimumTokensToSubmit + + ", maxRetryAttempts=" + this.maxRetryAttempts + ", deleteFilesAfterProcessing=" + + this.deleteFilesAfterProcessing + ", handlerVersion=" + this.handlerVersion + ", baseUrl='" + + getBaseUrl() + '\'' + 
'}'; + } + + public static final class Builder { + + private @Nullable String completionWindow; + + private int maxRequestsPerBatch = DEFAULT_MAX_REQUESTS_PER_BATCH; + + private long maxFileSizeBytes = DEFAULT_MAX_FILE_SIZE_BYTES; + + private double tokenSafetyFactor = DEFAULT_TOKEN_SAFETY_FACTOR; + + private long minimumTokensToSubmit = DEFAULT_MINIMUM_TOKENS_TO_SUBMIT; + + private int maxRetryAttempts = DEFAULT_MAX_RETRY_ATTEMPTS; + + private boolean deleteFilesAfterProcessing = DEFAULT_DELETE_FILES_AFTER_PROCESSING; + + private int handlerVersion = DEFAULT_HANDLER_VERSION; + + private @Nullable String baseUrl; + + private @Nullable String apiKey; + + private @Nullable Credential credential; + + private @Nullable String microsoftDeploymentName; + + private @Nullable AzureOpenAIServiceVersion microsoftFoundryServiceVersion; + + private @Nullable String organizationId; + + private boolean microsoftFoundry; + + private boolean gitHubModels; + + private @Nullable Duration timeout; + + private @Nullable Integer maxRetries; + + private @Nullable Proxy proxy; + + private @Nullable Map customHeaders; + + private Builder() { + } + + public Builder completionWindow(String completionWindow) { + this.completionWindow = completionWindow; + return this; + } + + public Builder maxRequestsPerBatch(int maxRequestsPerBatch) { + this.maxRequestsPerBatch = maxRequestsPerBatch; + return this; + } + + public Builder maxFileSizeBytes(long maxFileSizeBytes) { + this.maxFileSizeBytes = maxFileSizeBytes; + return this; + } + + public Builder tokenSafetyFactor(double tokenSafetyFactor) { + this.tokenSafetyFactor = tokenSafetyFactor; + return this; + } + + public Builder minimumTokensToSubmit(long minimumTokensToSubmit) { + this.minimumTokensToSubmit = minimumTokensToSubmit; + return this; + } + + public Builder maxRetryAttempts(int maxRetryAttempts) { + this.maxRetryAttempts = maxRetryAttempts; + return this; + } + + public Builder deleteFilesAfterProcessing(boolean 
deleteFilesAfterProcessing) { + this.deleteFilesAfterProcessing = deleteFilesAfterProcessing; + return this; + } + + public Builder handlerVersion(int handlerVersion) { + this.handlerVersion = handlerVersion; + return this; + } + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder credential(Credential credential) { + this.credential = credential; + return this; + } + + public Builder microsoftDeploymentName(String deploymentName) { + this.microsoftDeploymentName = deploymentName; + return this; + } + + public Builder microsoftFoundryServiceVersion(AzureOpenAIServiceVersion serviceVersion) { + this.microsoftFoundryServiceVersion = serviceVersion; + return this; + } + + public Builder organizationId(String organizationId) { + this.organizationId = organizationId; + return this; + } + + public Builder microsoftFoundry(boolean isMicrosoftFoundry) { + this.microsoftFoundry = isMicrosoftFoundry; + return this; + } + + public Builder gitHubModels(boolean isGitHubModels) { + this.gitHubModels = isGitHubModels; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder proxy(Proxy proxy) { + this.proxy = proxy; + return this; + } + + public Builder customHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + + public Builder from(OpenAiBatchOptions options) { + this.completionWindow = options.completionWindow; + this.maxRequestsPerBatch = options.maxRequestsPerBatch; + this.maxFileSizeBytes = options.maxFileSizeBytes; + this.tokenSafetyFactor = options.tokenSafetyFactor; + this.minimumTokensToSubmit = options.minimumTokensToSubmit; + this.maxRetryAttempts = options.maxRetryAttempts; + this.deleteFilesAfterProcessing = 
options.deleteFilesAfterProcessing; + this.handlerVersion = options.handlerVersion; + this.baseUrl = options.getBaseUrl(); + this.apiKey = options.getApiKey(); + this.credential = options.getCredential(); + this.microsoftDeploymentName = options.getMicrosoftDeploymentName(); + this.microsoftFoundryServiceVersion = options.getMicrosoftFoundryServiceVersion(); + this.organizationId = options.getOrganizationId(); + this.microsoftFoundry = options.isMicrosoftFoundry(); + this.gitHubModels = options.isGitHubModels(); + this.timeout = options.getTimeout(); + this.maxRetries = options.getMaxRetries(); + this.proxy = options.getProxy(); + if (options.getCustomHeaders() != null) { + this.customHeaders = options.getCustomHeaders(); + } + return this; + } + + public OpenAiBatchOptions build() { + OpenAiBatchOptions options = new OpenAiBatchOptions(); + options.setCompletionWindow(this.completionWindow); + options.setMaxRequestsPerBatch(this.maxRequestsPerBatch); + options.setMaxFileSizeBytes(this.maxFileSizeBytes); + options.setTokenSafetyFactor(this.tokenSafetyFactor); + options.setMinimumTokensToSubmit(this.minimumTokensToSubmit); + options.setMaxRetryAttempts(this.maxRetryAttempts); + options.setDeleteFilesAfterProcessing(this.deleteFilesAfterProcessing); + options.setHandlerVersion(this.handlerVersion); + options.setBaseUrl(this.baseUrl); + options.setApiKey(this.apiKey); + options.setCredential(this.credential); + options.setDeploymentName(this.microsoftDeploymentName); + options.setMicrosoftFoundryServiceVersion(this.microsoftFoundryServiceVersion); + options.setOrganizationId(this.organizationId); + options.setMicrosoftFoundry(this.microsoftFoundry); + options.setGitHubModels(this.gitHubModels); + if (this.timeout != null) { + options.setTimeout(this.timeout); + } + if (this.maxRetries != null) { + options.setMaxRetries(this.maxRetries); + } + options.setProxy(this.proxy); + if (this.customHeaders != null) { + options.setCustomHeaders(this.customHeaders); + } + return 
options; + } + + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/package-info.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/package-info.java new file mode 100644 index 0000000000..bb3dab7ba6 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/batch/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * OpenAI Batch API support for Spring AI. + *

+ * Provides an endpoint-agnostic batch processing framework that supports all OpenAI Batch + * API endpoints ({@code /v1/chat/completions}, {@code /v1/embeddings}, + * {@code /v1/moderations}, etc.) through the {@code BatchRequestHandler} abstraction. + */ +@org.jspecify.annotations.NullMarked +package org.springframework.ai.openai.batch; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchRequestCustomIdTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchRequestCustomIdTests.java new file mode 100644 index 0000000000..4c75b173ed --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchRequestCustomIdTests.java @@ -0,0 +1,109 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link BatchRequestCustomId}. 
 *
 * @author Yasin Akbas
 */
class BatchRequestCustomIdTests {

	// "::" is the delimiter between entityId and handlerId in the serialized form.
	@Test
	void shouldCreateCustomIdFromParts() {
		BatchRequestCustomId customId = new BatchRequestCustomId("entity-123", "chat-handler");
		assertThat(customId.entityId()).isEqualTo("entity-123");
		assertThat(customId.handlerId()).isEqualTo("chat-handler");
		assertThat(customId.toString()).isEqualTo("entity-123::chat-handler");
	}

	@Test
	void shouldParseCustomIdString() {
		BatchRequestCustomId customId = BatchRequestCustomId.parse("entity-456::embed-handler");
		assertThat(customId.entityId()).isEqualTo("entity-456");
		assertThat(customId.handlerId()).isEqualTo("embed-handler");
	}

	// Single hyphens/digits inside either part must not be mistaken for the delimiter.
	@Test
	void shouldParseCustomIdWithHyphensAndNumbers() {
		BatchRequestCustomId customId = BatchRequestCustomId.parse("my-entity-789::my-handler-v2");
		assertThat(customId.entityId()).isEqualTo("my-entity-789");
		assertThat(customId.handlerId()).isEqualTo("my-handler-v2");
	}

	@Test
	void shouldHandleEntityIdWithUnderscores() {
		BatchRequestCustomId customId = new BatchRequestCustomId("entity_with_underscores", "handler");
		assertThat(customId.toString()).isEqualTo("entity_with_underscores::handler");

		BatchRequestCustomId parsed = BatchRequestCustomId.parse("entity_with_underscores::handler");
		assertThat(parsed.entityId()).isEqualTo("entity_with_underscores");
	}

	@Test
	void shouldRejectBlankEntityId() {
		assertThatThrownBy(() -> new BatchRequestCustomId("", "handler")).isInstanceOf(IllegalArgumentException.class);
	}

	@Test
	void shouldRejectBlankHandlerId() {
		assertThatThrownBy(() -> new BatchRequestCustomId("entity", "")).isInstanceOf(IllegalArgumentException.class);
	}

	// Neither part may embed the delimiter, or round-tripping would be ambiguous.
	@Test
	void shouldRejectEntityIdContainingDelimiter() {
		assertThatThrownBy(() -> new BatchRequestCustomId("entity::bad", "handler"))
			.isInstanceOf(IllegalArgumentException.class);
	}

	@Test
	void shouldRejectHandlerIdContainingDelimiter() {
		assertThatThrownBy(() -> new BatchRequestCustomId("entity", "handler::bad"))
			.isInstanceOf(IllegalArgumentException.class);
	}

	@Test
	void shouldRejectInvalidFormatWhenParsing() {
		assertThatThrownBy(() -> BatchRequestCustomId.parse("no-delimiter-here"))
			.isInstanceOf(IllegalArgumentException.class)
			.hasMessageContaining("Invalid custom ID format");
	}

	@Test
	void shouldRejectTooManyPartsWhenParsing() {
		assertThatThrownBy(() -> BatchRequestCustomId.parse("a::b::c")).isInstanceOf(IllegalArgumentException.class)
			.hasMessageContaining("Invalid custom ID format");
	}

	@Test
	void shouldRejectBlankStringWhenParsing() {
		assertThatThrownBy(() -> BatchRequestCustomId.parse("")).isInstanceOf(IllegalArgumentException.class);
	}

	// toString() followed by parse() must yield an equal value (record equality).
	@Test
	void shouldRoundTrip() {
		BatchRequestCustomId original = new BatchRequestCustomId("comp-12345", "metadata-gen");
		BatchRequestCustomId parsed = BatchRequestCustomId.parse(original.toString());
		assertThat(parsed).isEqualTo(original);
	}

}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchRequestLineTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchRequestLineTests.java
new file mode 100644
index 0000000000..5ca6c0912d
--- /dev/null
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchRequestLineTests.java
@@ -0,0 +1,71 @@
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.util.Map;

import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
 * Tests for {@link BatchRequestLine}.
 *
 * @author Yasin Akbas
 */
class BatchRequestLineTests {

	// post(..) is a convenience factory that fixes method to "POST".
	@Test
	void shouldCreatePostRequestLine() {
		BatchRequestLine line = BatchRequestLine.post("entity-1::chat", "/v1/chat/completions",
				Map.of("model", "gpt-4o-mini", "messages", "hello"));

		assertThat(line.customId()).isEqualTo("entity-1::chat");
		assertThat(line.method()).isEqualTo("POST");
		assertThat(line.url()).isEqualTo("/v1/chat/completions");
		assertThat(line.body()).containsEntry("model", "gpt-4o-mini");
	}

	@Test
	void shouldCreateWithFullConstructor() {
		BatchRequestLine line = new BatchRequestLine("custom-id::handler", "POST", "/v1/embeddings",
				Map.of("model", "text-embedding-3-small", "input", "test"));

		assertThat(line.customId()).isEqualTo("custom-id::handler");
		assertThat(line.url()).isEqualTo("/v1/embeddings");
	}

	@Test
	void shouldRejectBlankCustomId() {
		assertThatThrownBy(() -> BatchRequestLine.post("", "/v1/chat/completions", Map.of()))
			.isInstanceOf(IllegalArgumentException.class);
	}

	@Test
	void shouldRejectBlankUrl() {
		assertThatThrownBy(() -> BatchRequestLine.post("id::handler", "", Map.of()))
			.isInstanceOf(IllegalArgumentException.class);
	}

	@Test
	void shouldRejectNullBody() {
		assertThatThrownBy(() -> BatchRequestLine.post("id::handler", "/v1/chat/completions", null))
			.isInstanceOf(IllegalArgumentException.class);
	}

}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchResponseLineTests.java
b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchResponseLineTests.java
new file mode 100644
index 0000000000..43d6ec8bfe
--- /dev/null
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/BatchResponseLineTests.java
@@ -0,0 +1,140 @@
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.openai.batch;

import java.util.Map;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Tests for {@link BatchResponseLine} JSON deserialization.
 *
 * @author Yasin Akbas
 */
class BatchResponseLineTests {

	private final ObjectMapper objectMapper = new ObjectMapper();

	// CHECKSTYLE.OFF
	// A success line carries a "response" object and no "error".
	@Test
	void shouldDeserializeSuccessResponse() throws Exception {
		String json = """
				{
				  "id": "batch_req_abc123",
				  "custom_id": "entity-1::chat-handler",
				  "response": {
				    "status_code": 200,
				    "request_id": "req_xyz",
				    "body": {
				      "id": "chatcmpl-123",
				      "choices": [{"message": {"content": "Hello!"}}]
				    }
				  }
				}
				""";

		BatchResponseLine line = this.objectMapper.readValue(json, BatchResponseLine.class);

		assertThat(line.id()).isEqualTo("batch_req_abc123");
		assertThat(line.customId()).isEqualTo("entity-1::chat-handler");
		assertThat(line.isSuccess()).isTrue();
		assertThat(line.response()).isNotNull();
		assertThat(line.response().statusCode()).isEqualTo(200);
		assertThat(line.response().requestId()).isEqualTo("req_xyz");
		assertThat(line.response().body()).containsKey("id");
		assertThat(line.error()).isNull();
	}

	// An error line has no "response" object at all; isSuccess() must be false.
	@Test
	void shouldDeserializeErrorResponse() throws Exception {
		String json = """
				{
				  "id": "batch_req_err456",
				  "custom_id": "entity-2::embed-handler",
				  "error": {
				    "code": "rate_limit_exceeded",
				    "message": "Too many requests"
				  }
				}
				""";

		BatchResponseLine line = this.objectMapper.readValue(json, BatchResponseLine.class);

		assertThat(line.id()).isEqualTo("batch_req_err456");
		assertThat(line.customId()).isEqualTo("entity-2::embed-handler");
		assertThat(line.isSuccess()).isFalse();
		assertThat(line.error()).isNotNull();
		assertThat(line.error().code()).isEqualTo("rate_limit_exceeded");
		assertThat(line.error().message()).isEqualTo("Too many requests");
	}

	// A response object with a non-2xx status is not a success even without "error".
	@Test
	void shouldDeserializeResponseWithNon200StatusCode() throws Exception {
		String json = """
				{
				  "id": "batch_req_500",
				  "custom_id": "entity-3::handler",
				  "response": {
				    "status_code": 500,
				    "body": {"error": {"message": "Internal error"}}
				  }
				}
				""";

		BatchResponseLine line = this.objectMapper.readValue(json, BatchResponseLine.class);

		assertThat(line.isSuccess()).isFalse();
		assertThat(line.response().statusCode()).isEqualTo(500);
	}

	// Unknown JSON properties must be tolerated (forward compatibility with the API).
	@Test
	void shouldHandleUnknownFields() throws Exception {
		String json = """
				{
				  "id": "batch_req_unknown",
				  "custom_id": "entity-4::handler",
				  "unknown_field": "should be ignored",
				  "response": {
				    "status_code": 200,
				    "body": {}
				  }
				}
				""";

		BatchResponseLine line = this.objectMapper.readValue(json, BatchResponseLine.class);
		assertThat(line.id()).isEqualTo("batch_req_unknown");
		assertThat(line.isSuccess()).isTrue();
	}
	// CHECKSTYLE.ON

	// Request lines must serialize with the snake_case field names the Batch API expects.
	@Test
	void shouldSerializeRequestLine() throws Exception {
		BatchRequestLine line = BatchRequestLine.post("entity-1::handler", "/v1/chat/completions",
				Map.of("model", "gpt-4o-mini", "messages", java.util.List.of(Map.of("role", "user", "content", "Hi"))));

		String json = this.objectMapper.writeValueAsString(line);

		assertThat(json).contains("\"custom_id\":\"entity-1::handler\"");
		assertThat(json).contains("\"method\":\"POST\"");
		assertThat(json).contains("\"url\":\"/v1/chat/completions\"");
		assertThat(json).contains("\"model\":\"gpt-4o-mini\"");
	}

}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/InMemoryBatchExecutionRepositoryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/InMemoryBatchExecutionRepositoryTests.java
new file mode 100644
index 0000000000..1cf2ffa48e
--- /dev/null
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/InMemoryBatchExecutionRepositoryTests.java
@@ -0,0 +1,125 @@
/*
 * Copyright 2023-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link InMemoryBatchExecutionRepository}. + * + * @author Yasin Akbas + */ +class InMemoryBatchExecutionRepositoryTests { + + @Test + void shouldSaveAndFindById() { + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + BatchExecution execution = new BatchExecution("batch_1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, + 10, "file-input-1"); + + repository.save(execution); + + assertThat(repository.findById("batch_1")).isPresent(); + assertThat(repository.findById("batch_1").get().getEndpoint()).isEqualTo("/v1/chat/completions"); + } + + @Test + void shouldReturnEmptyForMissingBatch() { + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + assertThat(repository.findById("nonexistent")).isEmpty(); + } + + @Test + void shouldFindByStatus() { + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + repository + .save(new BatchExecution("batch_1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 10, "file-1")); + repository + .save(new BatchExecution("batch_2", "/v1/embeddings", BatchExecutionStatus.IN_PROGRESS, 20, "file-2")); + repository + .save(new BatchExecution("batch_3", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 5, "file-3")); + + List submitted = repository.findByStatus(BatchExecutionStatus.SUBMITTED); + 
assertThat(submitted).hasSize(2); + assertThat(submitted).extracting(BatchExecution::getBatchId).containsExactlyInAnyOrder("batch_1", "batch_3"); + } + + @Test + void shouldFindPendingExecutions() { + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + repository + .save(new BatchExecution("batch_1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 10, "file-1")); + repository + .save(new BatchExecution("batch_2", "/v1/embeddings", BatchExecutionStatus.IN_PROGRESS, 20, "file-2")); + + BatchExecution completed = new BatchExecution("batch_3", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, + 5, "file-3"); + completed.setStatus(BatchExecutionStatus.RESULTS_PROCESSED); + repository.save(completed); + + BatchExecution failed = new BatchExecution("batch_4", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 3, + "file-4"); + failed.setStatus(BatchExecutionStatus.FAILED); + repository.save(failed); + + List pending = repository.findPendingExecutions(); + assertThat(pending).hasSize(2); + assertThat(pending).extracting(BatchExecution::getBatchId).containsExactlyInAnyOrder("batch_1", "batch_2"); + } + + @Test + void shouldDeleteById() { + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + repository + .save(new BatchExecution("batch_1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 10, "file-1")); + + repository.deleteById("batch_1"); + + assertThat(repository.findById("batch_1")).isEmpty(); + } + + @Test + void shouldReplaceOnSave() { + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + BatchExecution execution = new BatchExecution("batch_1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, + 10, "file-1"); + repository.save(execution); + + execution.setStatus(BatchExecutionStatus.IN_PROGRESS); + repository.save(execution); + + 
assertThat(repository.findById("batch_1").get().getStatus()).isEqualTo(BatchExecutionStatus.IN_PROGRESS); + } + + @Test + void shouldTrackTerminalStatuses() { + assertThat(BatchExecutionStatus.RESULTS_PROCESSED.isTerminal()).isTrue(); + assertThat(BatchExecutionStatus.FAILED.isTerminal()).isTrue(); + assertThat(BatchExecutionStatus.EXPIRED.isTerminal()).isTrue(); + assertThat(BatchExecutionStatus.CANCELLED.isTerminal()).isTrue(); + assertThat(BatchExecutionStatus.SUBMITTED.isTerminal()).isFalse(); + assertThat(BatchExecutionStatus.IN_PROGRESS.isTerminal()).isFalse(); + assertThat(BatchExecutionStatus.VALIDATING.isTerminal()).isFalse(); + assertThat(BatchExecutionStatus.FINALIZING.isTerminal()).isFalse(); + assertThat(BatchExecutionStatus.COMPLETED.isTerminal()).isFalse(); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/OpenAiBatchModelTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/OpenAiBatchModelTests.java new file mode 100644 index 0000000000..0714b35010 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/OpenAiBatchModelTests.java @@ -0,0 +1,406 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai.batch; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; + +import com.openai.models.batches.Batch; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OpenAiBatchModel}. + * + * @author Yasin Akbas + */ +class OpenAiBatchModelTests { + + @Test + void shouldRejectDuplicateHandlerIds() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + TestHandler handler1 = new TestHandler("same-id", "/v1/chat/completions"); + TestHandler handler2 = new TestHandler("same-id", "/v1/embeddings"); + + assertThatThrownBy( + () -> OpenAiBatchModel.builder().batchApi(batchApi).handlers(List.of(handler1, handler2)).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Duplicate handler ID"); + } + + @Test + void shouldAllowMultipleHandlersWithDifferentIds() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + TestHandler handler1 = new TestHandler("chat-handler", "/v1/chat/completions"); + TestHandler handler2 = new TestHandler("embed-handler", "/v1/embeddings"); + + OpenAiBatchModel model = OpenAiBatchModel.builder() + .batchApi(batchApi) + .handlers(List.of(handler1, handler2)) + .build(); + + assertThat(model.getHandlers()).hasSize(2); + } + + @Test + void shouldCreateBatchFromPendingItems() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + Batch mockBatch = mock(Batch.class); + 
when(mockBatch.id()).thenReturn("batch_abc123"); + when(mockBatch.inputFileId()).thenReturn("file-input"); + when(batchApi.createBatch(anyList(), any(), any(), anyMap())).thenReturn(mockBatch); + + TestHandler handler = new TestHandler("chat", "/v1/chat/completions"); + handler.addPendingItem("entity-1", Map.of("text", "Hello")); + handler.addPendingItem("entity-2", Map.of("text", "World")); + + OpenAiBatchModel model = OpenAiBatchModel.builder().batchApi(batchApi).handlers(List.of(handler)).build(); + + List batches = model.createBatchExecutions(); + + assertThat(batches).hasSize(1); + assertThat(batches.get(0).id()).isEqualTo("batch_abc123"); + } + + @Test + void shouldReturnEmptyWhenNoPendingItems() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + TestHandler handler = new TestHandler("chat", "/v1/chat/completions"); + + OpenAiBatchModel model = OpenAiBatchModel.builder().batchApi(batchApi).handlers(List.of(handler)).build(); + + List batches = model.createBatchExecutions(); + + assertThat(batches).isEmpty(); + verify(batchApi, never()).createBatch(anyList(), any(), any(), anyMap()); + } + + @Test + void shouldGroupRequestsByEndpoint() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + Batch mockBatch1 = mock(Batch.class); + Batch mockBatch2 = mock(Batch.class); + when(mockBatch1.id()).thenReturn("batch_chat"); + when(mockBatch1.inputFileId()).thenReturn("file-chat"); + when(mockBatch2.id()).thenReturn("batch_embed"); + when(mockBatch2.inputFileId()).thenReturn("file-embed"); + when(batchApi.createBatch(anyList(), any(), any(), anyMap())).thenReturn(mockBatch1, mockBatch2); + + TestHandler chatHandler = new TestHandler("chat", "/v1/chat/completions"); + chatHandler.addPendingItem("e1", Map.of("text", "hello")); + + TestHandler embedHandler = new TestHandler("embed", "/v1/embeddings"); + embedHandler.addPendingItem("e2", Map.of("input", "test")); + + OpenAiBatchModel model = OpenAiBatchModel.builder() + .batchApi(batchApi) + 
.handlers(List.of(chatHandler, embedHandler)) + .build(); + + List batches = model.createBatchExecutions(); + + assertThat(batches).hasSize(2); + } + + @Test + void shouldNotifyListenersOnBatchCreated() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + Batch mockBatch = mock(Batch.class); + when(mockBatch.id()).thenReturn("batch_123"); + when(mockBatch.inputFileId()).thenReturn("file-input"); + when(batchApi.createBatch(anyList(), any(), any(), anyMap())).thenReturn(mockBatch); + + TestHandler handler = new TestHandler("chat", "/v1/chat/completions"); + handler.addPendingItem("e1", Map.of("text", "hello")); + + AtomicInteger listenerCalled = new AtomicInteger(0); + OpenAiBatchListener listener = new OpenAiBatchListener() { + @Override + public void onBatchCreated(Batch batch, int requestCount) { + assertThat(batch.id()).isEqualTo("batch_123"); + assertThat(requestCount).isEqualTo(1); + listenerCalled.incrementAndGet(); + } + }; + + OpenAiBatchModel model = OpenAiBatchModel.builder() + .batchApi(batchApi) + .handlers(List.of(handler)) + .listeners(List.of(listener)) + .build(); + + model.createBatchExecutions(); + + assertThat(listenerCalled.get()).isEqualTo(1); + } + + @Test + void shouldDispatchSuccessToCorrectHandler() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + + Batch mockBatch = mock(Batch.class); + when(mockBatch.id()).thenReturn("batch_done"); + when(mockBatch.status()).thenReturn(Batch.Status.COMPLETED); + when(mockBatch.outputFileId()).thenReturn(Optional.of("file-output-123")); + when(mockBatch.errorFileId()).thenReturn(Optional.empty()); + when(mockBatch.inputFileId()).thenReturn("file-input-123"); + when(mockBatch.metadata()).thenReturn(Optional.empty()); + + String jsonlOutput = """ + {"id":"req1","custom_id":"entity-1::chat","response":{"status_code":200,"body":{"result":"ok"}}} + """; + when(batchApi.retrieveBatch("batch_done")).thenReturn(mockBatch); + 
when(batchApi.downloadFileContent("file-output-123")).thenReturn(jsonlOutput); + when(batchApi.parseResponseLines(jsonlOutput)).thenReturn(List.of(new BatchResponseLine("req1", + "entity-1::chat", new BatchResponseLine.Response(200, "req_x", Map.of("result", "ok")), null))); + + TestHandler chatHandler = new TestHandler("chat", "/v1/chat/completions"); + + OpenAiBatchModel model = OpenAiBatchModel.builder().batchApi(batchApi).handlers(List.of(chatHandler)).build(); + + model.checkBatchExecution("batch_done"); + + assertThat(chatHandler.getSuccessCount()).isEqualTo(1); + assertThat(chatHandler.getLastSuccessEntityId()).isEqualTo("entity-1"); + } + + @Test + void shouldPassBatchVersionFromMetadataToHandler() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + + Batch mockBatch = mock(Batch.class); + when(mockBatch.id()).thenReturn("batch_ver"); + when(mockBatch.status()).thenReturn(Batch.Status.COMPLETED); + when(mockBatch.outputFileId()).thenReturn(Optional.of("file-out")); + when(mockBatch.errorFileId()).thenReturn(Optional.empty()); + when(mockBatch.inputFileId()).thenReturn("file-in"); + + // Batch metadata contains handler-version=2 + com.openai.models.batches.Batch.Metadata metadata = com.openai.models.batches.Batch.Metadata.builder() + .putAdditionalProperty("handler-version", com.openai.core.JsonValue.from("2")) + .build(); + when(mockBatch.metadata()).thenReturn(Optional.of(metadata)); + + when(batchApi.retrieveBatch("batch_ver")).thenReturn(mockBatch); + when(batchApi.downloadFileContent("file-out")).thenReturn("line"); + when(batchApi.parseResponseLines("line")).thenReturn(List.of(new BatchResponseLine("req1", "entity-1::chat", + new BatchResponseLine.Response(200, "req_x", Map.of("result", "ok")), null))); + + List receivedVersions = new ArrayList<>(); + TestHandler handler = new TestHandler("chat", "/v1/chat/completions") { + @Override + public void onSuccess(BatchRequestCustomId customId, Map responseBody, int batchVersion) { + 
receivedVersions.add(batchVersion); + super.onSuccess(customId, responseBody, batchVersion); + } + }; + + OpenAiBatchModel model = OpenAiBatchModel.builder().batchApi(batchApi).handlers(List.of(handler)).build(); + + model.checkBatchExecution("batch_ver"); + + assertThat(handler.getSuccessCount()).isEqualTo(1); + assertThat(receivedVersions).containsExactly(2); + } + + @Test + void shouldSaveBatchExecutionToRepository() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + Batch mockBatch = mock(Batch.class); + when(mockBatch.id()).thenReturn("batch_repo_test"); + when(mockBatch.inputFileId()).thenReturn("file-input-456"); + when(batchApi.createBatch(anyList(), any(), any(), anyMap())).thenReturn(mockBatch); + + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + TestHandler handler = new TestHandler("chat", "/v1/chat/completions"); + handler.addPendingItem("e1", Map.of("text", "hello")); + + OpenAiBatchModel model = OpenAiBatchModel.builder() + .batchApi(batchApi) + .executionRepository(repository) + .handlers(List.of(handler)) + .build(); + + model.createBatchExecutions(); + + Optional execution = repository.findById("batch_repo_test"); + assertThat(execution).isPresent(); + assertThat(execution.get().getStatus()).isEqualTo(BatchExecutionStatus.SUBMITTED); + assertThat(execution.get().getRequestCount()).isEqualTo(1); + assertThat(execution.get().getEndpoint()).isEqualTo("/v1/chat/completions"); + } + + @Test + void shouldUpdateExecutionStatusOnCheck() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + + Batch mockBatch = mock(Batch.class); + when(mockBatch.id()).thenReturn("batch_status"); + when(mockBatch.status()).thenReturn(Batch.Status.IN_PROGRESS); + when(batchApi.retrieveBatch("batch_status")).thenReturn(mockBatch); + + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + repository.save(new BatchExecution("batch_status", "/v1/chat/completions", 
BatchExecutionStatus.SUBMITTED, 5, + "file-input")); + + OpenAiBatchModel model = OpenAiBatchModel.builder().batchApi(batchApi).executionRepository(repository).build(); + + model.checkBatchExecution("batch_status"); + + assertThat(repository.findById("batch_status").get().getStatus()).isEqualTo(BatchExecutionStatus.IN_PROGRESS); + } + + @Test + void shouldCheckAllPendingExecutions() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + + Batch batch1 = mock(Batch.class); + when(batch1.id()).thenReturn("batch_1"); + when(batch1.status()).thenReturn(Batch.Status.IN_PROGRESS); + when(batchApi.retrieveBatch("batch_1")).thenReturn(batch1); + + Batch batch2 = mock(Batch.class); + when(batch2.id()).thenReturn("batch_2"); + when(batch2.status()).thenReturn(Batch.Status.IN_PROGRESS); + when(batchApi.retrieveBatch("batch_2")).thenReturn(batch2); + + InMemoryBatchExecutionRepository repository = new InMemoryBatchExecutionRepository(); + repository + .save(new BatchExecution("batch_1", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 10, "file-1")); + repository + .save(new BatchExecution("batch_2", "/v1/embeddings", BatchExecutionStatus.IN_PROGRESS, 20, "file-2")); + // Terminal execution should NOT be checked + BatchExecution done = new BatchExecution("batch_3", "/v1/chat/completions", BatchExecutionStatus.SUBMITTED, 5, + "file-3"); + done.setStatus(BatchExecutionStatus.RESULTS_PROCESSED); + repository.save(done); + + OpenAiBatchModel model = OpenAiBatchModel.builder().batchApi(batchApi).executionRepository(repository).build(); + + List results = model.checkAllPendingExecutions(); + + assertThat(results).hasSize(2); + } + + @Test + void shouldBuildWithDefaultOptions() { + OpenAiBatchApi batchApi = mock(OpenAiBatchApi.class); + OpenAiBatchModel model = OpenAiBatchModel.builder().batchApi(batchApi).build(); + + assertThat(model.getOptions()).isNotNull(); + assertThat(model.getOptions().getCompletionWindow()).isEqualTo("24h"); + 
assertThat(model.getOptions().getMaxRequestsPerBatch()).isEqualTo(50_000); + } + + /** + * Simple test handler for unit testing. + */ + static class TestHandler implements BatchRequestHandler> { + + private final String handlerId; + + private final String endpoint; + + private final Map> pendingItems = new LinkedHashMap<>(); + + private final List successEntityIds = new ArrayList<>(); + + private final List errorEntityIds = new ArrayList<>(); + + TestHandler(String handlerId, String endpoint) { + this.handlerId = handlerId; + this.endpoint = endpoint; + } + + void addPendingItem(String entityId, Map input) { + this.pendingItems.put(entityId, input); + } + + int getSuccessCount() { + return this.successEntityIds.size(); + } + + int getErrorCount() { + return this.errorEntityIds.size(); + } + + String getLastSuccessEntityId() { + return this.successEntityIds.isEmpty() ? null : this.successEntityIds.get(this.successEntityIds.size() - 1); + } + + @Override + public String getHandlerId() { + return this.handlerId; + } + + @Override + public String getEndpoint() { + return this.endpoint; + } + + @Override + public Map generateRequestBody(Map input) { + Map body = new LinkedHashMap<>(input); + body.put("model", "gpt-4o-mini"); + return body; + } + + @Override + public int estimateTokenUsage(Map input) { + return 100; + } + + @Override + public void onSuccess(BatchRequestCustomId customId, Map responseBody, int batchVersion) { + this.successEntityIds.add(customId.entityId()); + } + + @Override + public void onError(BatchRequestCustomId customId, BatchResponseLine.Error error, int batchVersion) { + this.errorEntityIds.add(customId.entityId()); + } + + @Override + public Map> getPendingItems(int maxItems) { + Map> result = new LinkedHashMap<>(); + int count = 0; + for (Map.Entry> entry : this.pendingItems.entrySet()) { + if (count >= maxItems) { + break; + } + result.put(entry.getKey(), entry.getValue()); + count++; + } + return result; + } + + } + +} diff --git 
a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/OpenAiBatchOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/OpenAiBatchOptionsTests.java new file mode 100644 index 0000000000..1f01a5ee10 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/batch/OpenAiBatchOptionsTests.java @@ -0,0 +1,128 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.batch; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OpenAiBatchOptions}. 
+ * + * @author Yasin Akbas + */ +class OpenAiBatchOptionsTests { + + @Test + void shouldUseDefaults() { + OpenAiBatchOptions options = OpenAiBatchOptions.builder().build(); + + assertThat(options.getCompletionWindow()).isEqualTo("24h"); + assertThat(options.getMaxRequestsPerBatch()).isEqualTo(50_000); + assertThat(options.getMaxFileSizeBytes()).isEqualTo(200L * 1024 * 1024); + assertThat(options.getTokenSafetyFactor()).isEqualTo(1.2); + assertThat(options.getMinimumTokensToSubmit()).isEqualTo(5_000_000L); + assertThat(options.getMaxRetryAttempts()).isEqualTo(2); + assertThat(options.isDeleteFilesAfterProcessing()).isTrue(); + } + + @Test + void shouldApplyCustomValues() { + OpenAiBatchOptions options = OpenAiBatchOptions.builder() + .completionWindow("48h") + .maxRequestsPerBatch(10_000) + .maxFileSizeBytes(100L * 1024 * 1024) + .tokenSafetyFactor(1.5) + .minimumTokensToSubmit(1_000_000L) + .maxRetryAttempts(5) + .deleteFilesAfterProcessing(false) + .build(); + + assertThat(options.getCompletionWindow()).isEqualTo("48h"); + assertThat(options.getMaxRequestsPerBatch()).isEqualTo(10_000); + assertThat(options.getMaxFileSizeBytes()).isEqualTo(100L * 1024 * 1024); + assertThat(options.getTokenSafetyFactor()).isEqualTo(1.5); + assertThat(options.getMinimumTokensToSubmit()).isEqualTo(1_000_000L); + assertThat(options.getMaxRetryAttempts()).isEqualTo(5); + assertThat(options.isDeleteFilesAfterProcessing()).isFalse(); + } + + @Test + void shouldCopyOptions() { + OpenAiBatchOptions original = OpenAiBatchOptions.builder() + .completionWindow("12h") + .maxRequestsPerBatch(1000) + .baseUrl("https://custom.api.com") + .apiKey("sk-test") + .build(); + + OpenAiBatchOptions copy = original.copy(); + + assertThat(copy.getCompletionWindow()).isEqualTo("12h"); + assertThat(copy.getMaxRequestsPerBatch()).isEqualTo(1000); + assertThat(copy.getBaseUrl()).isEqualTo("https://custom.api.com"); + assertThat(copy.getApiKey()).isEqualTo("sk-test"); + assertThat(copy).isNotSameAs(original); + 
} + + @Test + void shouldSetConnectionProperties() { + OpenAiBatchOptions options = OpenAiBatchOptions.builder() + .baseUrl("https://api.openai.com") + .apiKey("sk-test-key") + .organizationId("org-123") + .build(); + + assertThat(options.getBaseUrl()).isEqualTo("https://api.openai.com"); + assertThat(options.getApiKey()).isEqualTo("sk-test-key"); + assertThat(options.getOrganizationId()).isEqualTo("org-123"); + } + + @Test + void shouldBuildFromExistingOptions() { + OpenAiBatchOptions original = OpenAiBatchOptions.builder() + .completionWindow("6h") + .maxRequestsPerBatch(500) + .tokenSafetyFactor(1.8) + .apiKey("original-key") + .build(); + + OpenAiBatchOptions rebuilt = OpenAiBatchOptions.builder().from(original).maxRequestsPerBatch(1000).build(); + + assertThat(rebuilt.getCompletionWindow()).isEqualTo("6h"); + assertThat(rebuilt.getMaxRequestsPerBatch()).isEqualTo(1000); + assertThat(rebuilt.getTokenSafetyFactor()).isEqualTo(1.8); + assertThat(rebuilt.getApiKey()).isEqualTo("original-key"); + } + + @Test + void shouldImplementEqualsAndHashCode() { + OpenAiBatchOptions options1 = OpenAiBatchOptions.builder() + .completionWindow("24h") + .maxRequestsPerBatch(1000) + .build(); + + OpenAiBatchOptions options2 = OpenAiBatchOptions.builder() + .completionWindow("24h") + .maxRequestsPerBatch(1000) + .build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + } + +} diff --git a/pom.xml b/pom.xml index 221d6c43c6..4cbd1019a3 100644 --- a/pom.xml +++ b/pom.xml @@ -109,6 +109,7 @@ auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs auto-configurations/models/spring-ai-autoconfigure-model-openai + auto-configurations/models/spring-ai-autoconfigure-model-openai-batch-repository-jdbc auto-configurations/models/spring-ai-autoconfigure-model-minimax auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai 
auto-configurations/models/spring-ai-autoconfigure-model-ollama @@ -184,6 +185,7 @@ models/spring-ai-mistral-ai models/spring-ai-ollama models/spring-ai-openai + models/spring-ai-openai-batch-repository-jdbc models/spring-ai-postgresml models/spring-ai-stability-ai models/spring-ai-transformers @@ -202,6 +204,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-model-mistral-ai spring-ai-spring-boot-starters/spring-ai-starter-model-ollama spring-ai-spring-boot-starters/spring-ai-starter-model-openai + spring-ai-spring-boot-starters/spring-ai-starter-model-openai-batch-repository-jdbc spring-ai-spring-boot-starters/spring-ai-starter-model-postgresml-embedding spring-ai-spring-boot-starters/spring-ai-starter-model-stability-ai spring-ai-spring-boot-starters/spring-ai-starter-model-transformers diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index e0b095da7b..39e46a3047 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -309,6 +309,24 @@ ${project.version} + + org.springframework.ai + spring-ai-openai-batch-repository-jdbc + ${project.version} + + + + org.springframework.ai + spring-ai-autoconfigure-model-openai-batch-repository-jdbc + ${project.version} + + + + org.springframework.ai + spring-ai-starter-model-openai-batch-repository-jdbc + ${project.version} + + org.springframework.ai spring-ai-postgresml diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index e3f225753a..88dd28f16b 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -66,6 +66,9 @@ *** xref:api/moderation[Moderation Models] **** xref:api/moderation/openai-moderation.adoc[OpenAI] **** xref:api/moderation/mistral-ai-moderation.adoc[Mistral AI] + +*** Batch API +**** xref:api/batch/openai-batch.adoc[OpenAI] // ** xref:api/generic-model.adoc[] ** xref:api/chat-memory.adoc[Chat Memory] diff --git 
a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/batch/openai-batch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/batch/openai-batch.adoc new file mode 100644 index 0000000000..eb0cb104b9 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/batch/openai-batch.adoc @@ -0,0 +1,506 @@ += OpenAI Batch API + +Spring AI supports https://platform.openai.com/docs/guides/batch[OpenAI's Batch API], which enables asynchronous processing of large volumes of API requests at 50% cost reduction compared to synchronous calls. +Batch requests are processed within a 24-hour completion window and have separate, higher rate limits from the synchronous API. + +The Batch API supports the following OpenAI endpoints: + +* `/v1/chat/completions` — Chat completions +* `/v1/embeddings` — Embedding generation +* `/v1/responses` — Responses API +* `/v1/completions` — Legacy completions +* `/v1/moderations` — Content moderation +* `/v1/images/generations` — Image generation +* `/v1/images/edits` — Image editing + +== Prerequisites + +. Create an OpenAI account and obtain an API key. You can sign up at the https://platform.openai.com/signup[OpenAI signup page] and generate an API key on the https://platform.openai.com/account/api-keys[API Keys page]. +. Add the `spring-ai-starter-model-openai` dependency to your project. For more information, refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section. + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the OpenAI Batch API. 
+To enable it, add the following dependency to your project's Maven `pom.xml` file: + +[tabs] +====== +Maven:: ++ +[source,xml] +---- + + org.springframework.ai + spring-ai-starter-model-openai + +---- + +Gradle:: ++ +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-openai' +} +---- +====== + +Then enable the batch support in your `application.properties`: + +[source,properties] +---- +spring.ai.openai.api-key=${OPENAI_API_KEY} +spring.ai.openai.batch.enabled=true +---- + +== Architecture + +The batch system uses a three-phase lifecycle: + +1. **Prepare** — Collect pending items from registered `BatchRequestHandler` implementations and convert them into `BatchRequestLine` objects +2. **Execute** — Upload JSONL files to OpenAI and create batch executions +3. **Check** — Poll batch status, download output/error files, and dispatch results back to handlers + +[NOTE] +==== +The batch support does **not** auto-configure scheduling. +You control when each lifecycle phase runs — via `@Scheduled`, Spring Batch, or any other trigger. +==== + +=== Hybrid Storage Approach + +Handlers follow a **hybrid storage** approach: domain-specific input data (e.g., the text to analyze) is stored in your database, while the full OpenAI API request envelope (model, parameters, messages) is generated on demand at execution time. +This means configuration changes (e.g., switching from `gpt-4o` to `gpt-4o-mini`, or adjusting `reasoning_effort`) automatically apply to all pending requests without requiring data cleanup. + +=== Endpoint-agnostic Design + +Each `BatchRequestHandler` declares its target endpoint. +The `OpenAiBatchModel` orchestrator automatically groups pending requests by endpoint, since OpenAI requires each batch to target a single endpoint. +This means you can register handlers for chat completions, embeddings, and other endpoints — all managed by a single `OpenAiBatchModel` instance. 
+ +== Key Components + +=== `BatchRequestHandler` + +The primary interface you implement to integrate with the batch system. +The type parameter `I` represents your domain-specific input data type. + +[source,java] +---- +public class ChatCompletionBatchHandler + implements BatchRequestHandler { + + @Override + public String getHandlerId() { + return "article-summarizer"; + } + + @Override + public String getEndpoint() { + return "/v1/chat/completions"; + } + + @Override + public Map generateRequestBody( + ArticleInput input) { + return Map.of( + "model", "gpt-4o-mini", + "messages", List.of( + Map.of("role", "system", + "content", "Summarize the article."), + Map.of("role", "user", + "content", input.text()) + ) + ); + } + + @Override + public int estimateTokenUsage(ArticleInput input) { + return input.text().length() / 4; // rough estimate + } + + @Override + public void onSuccess(BatchRequestCustomId customId, + Map responseBody, + int batchVersion) { + String entityId = customId.entityId(); + // Process the result — save to database, etc. + } + + @Override + public void onError(BatchRequestCustomId customId, + BatchResponseLine.Error error, + int batchVersion) { + log.error("Failed for {}: {} - {}", + customId.entityId(), + error.code(), error.message()); + } + + @Override + public Map getPendingItems( + int maxItems) { + // Query your database for unprocessed articles + return articleRepository.findPending(maxItems); + } +} +---- + +=== `OpenAiBatchModel` + +The central orchestrator that manages the batch lifecycle. 
+Build it with the builder pattern: + +[source,java] +---- +OpenAiBatchModel batchModel = OpenAiBatchModel.builder() + .batchApi(batchApi) + .handlers(List.of( + chatHandler, + embeddingHandler + )) + .listeners(List.of(metricsListener)) + .executionRepository( + new InMemoryBatchExecutionRepository()) + .options(OpenAiBatchOptions.builder() + .maxRequestsPerBatch(10_000) + .completionWindow("24h") + .deleteFilesAfterProcessing(true) + .build()) + .build(); +---- + +Use the lifecycle methods to drive batch processing: + +[source,java] +---- +// Phase 1+2: Collect pending items and create batches +List batches = batchModel.createBatchExecutions(); + +// Phase 3: Check all pending batches (recommended) +List results = + batchModel.checkAllPendingExecutions(); + +// Or check a specific batch by ID +Batch updated = batchModel.checkBatchExecution(batchId); + +// Cancel if needed +batchModel.cancelBatch(batchId); +---- + +=== `OpenAiBatchApi` + +Low-level wrapper around the OpenAI SDK's `BatchService` and `FileService`. +You typically don't use this directly — `OpenAiBatchModel` manages it for you. +However, it's available if you need direct access: + +[source,java] +---- +// Create from an OpenAI client +OpenAiBatchApi batchApi = new OpenAiBatchApi( + openAiClient, new ObjectMapper()); + +// Retrieve batch status +Batch batch = batchApi.retrieveBatch("batch_abc123"); + +// Download and parse results +String content = batchApi.downloadFileContent( + batch.outputFileId().orElseThrow()); +List lines = + batchApi.parseResponseLines(content); +---- + +=== `OpenAiBatchListener` + +Callback interface for observing batch lifecycle events. 
+All methods have no-op defaults — override only the events you care about: + +[source,java] +---- +public class MetricsListener implements OpenAiBatchListener { + + @Override + public void onBatchCreated(Batch batch, + int requestCount) { + metrics.counter("batch.created").increment(); + log.info("Batch {} created with {} requests", + batch.id(), requestCount); + } + + @Override + public void onBatchCompleted(Batch batch) { + metrics.counter("batch.completed").increment(); + } + + @Override + public void onBatchResultsProcessed(Batch batch, + List successLines, + List errorLines) { + metrics.counter("batch.success", + "count", String.valueOf(successLines.size())) + .increment(); + metrics.counter("batch.errors", + "count", String.valueOf(errorLines.size())) + .increment(); + } +} +---- + +=== `BatchRequestCustomId` + +Encodes an entity identifier and handler identifier into a single custom ID string using the `::` delimiter: + +[source,java] +---- +// Create a custom ID +BatchRequestCustomId id = new BatchRequestCustomId( + "article-42", "article-summarizer"); +id.toString(); // "article-42::article-summarizer" + +// Parse from response +BatchRequestCustomId parsed = + BatchRequestCustomId.parse( + "article-42::article-summarizer"); +parsed.entityId(); // "article-42" +parsed.handlerId(); // "article-summarizer" +---- + +=== `BatchExecutionRepository` + +Tracks batch executions across their lifecycle, eliminating the need for manual batch ID tracking. +An in-memory default implementation (`InMemoryBatchExecutionRepository`) is provided out of the box. + +The repository tracks `BatchExecution` records with status, request count, endpoint, and timestamps. +When using auto-configuration, the `InMemoryBatchExecutionRepository` is registered as the default bean. + +=== Persistence + +The auto-configured `InMemoryBatchExecutionRepository` works for development and single-instance deployments. 
+For production scenarios (restart resilience, multi-instance), use the provided JDBC implementation or implement your own `BatchExecutionRepository`. + +==== JDBC Persistence + +Spring AI provides `JdbcBatchExecutionRepository` with multi-database support. +Add the JDBC starter dependency: + +[tabs] +====== +Maven:: ++ +[source,xml] +---- + + org.springframework.ai + spring-ai-starter-model-openai-batch-repository-jdbc + +---- + +Gradle:: ++ +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-openai-batch-repository-jdbc' +} +---- +====== + +Configure schema initialization in `application.properties`: + +[source,properties] +---- +spring.ai.openai.batch.repository.jdbc.initialize-schema=always +---- + +Supported databases: PostgreSQL, MySQL, MariaDB, SQL Server, H2, HSQLDB, SQLite, Oracle. +The dialect is auto-detected from the `DataSource`. + +The JDBC implementation automatically takes precedence over the in-memory default when on the classpath. + +==== Custom Persistence + +You can implement your own `BatchExecutionRepository` for other data stores (MongoDB, Redis, etc.): + +[source,java] +---- +@Bean +public BatchExecutionRepository batchExecutionRepository() { + return new MyCustomBatchExecutionRepository(); +} +---- + +The `BatchRequestHandler` interface is responsible for managing its own domain data persistence (pending items, processed results). +The `BatchExecutionRepository` only tracks OpenAI batch execution metadata. + +== Handler Versioning + +When handler code changes (e.g., response parsing logic) while batches are in-flight, the new handler code may not be compatible with responses from the old version. +The batch system handles this by storing a version number in the OpenAI batch metadata. 
+ +Configure the handler version in `OpenAiBatchOptions`: + +[source,java] +---- +OpenAiBatchOptions options = OpenAiBatchOptions.builder() + .handlerVersion(2) // Increment when parsing logic changes + .build(); +---- + +Or via properties: + +[source,properties] +---- +spring.ai.openai.batch.options.handler-version=2 +---- + +When a batch is created, the current `handlerVersion` is stored in the batch metadata. +When responses arrive, the version is extracted from the batch metadata and passed to `onSuccess()` / `onError()` as the `batchVersion` parameter. +Handlers can compare this to their expected version and act accordingly — for example, skipping incompatible results or applying version-specific parsing logic. + +[source,java] +---- +@Override +public void onSuccess(BatchRequestCustomId customId, + Map responseBody, + int batchVersion) { + if (batchVersion < CURRENT_VERSION) { + log.warn("Outdated batch version {} for {}", + batchVersion, customId.entityId()); + // Re-queue for reprocessing, or apply + // backward-compatible parsing + return; + } + // Process normally +} +---- + +[TIP] +==== +Only increment the version when your handler's `onSuccess()` parsing logic changes in a backward-incompatible way. +Changes to `generateRequestBody()` do not require a version bump since the request is generated at submission time. +==== + +== Scheduling Example + +The batch system does not include built-in scheduling. 
+Here is a typical setup using Spring's `@Scheduled`: + +[source,java] +---- +@Component +public class BatchScheduler { + + private final OpenAiBatchModel batchModel; + + // Run every 15 minutes: collect and submit + @Scheduled(fixedRate = 15, timeUnit = TimeUnit.MINUTES) + public void submitBatches() { + batchModel.createBatchExecutions(); + // Batches are automatically tracked in the + // BatchExecutionRepository + } + + // Run every 5 minutes: check all pending batches + @Scheduled(fixedRate = 5, timeUnit = TimeUnit.MINUTES) + public void checkBatches() { + // Automatically finds and checks all + // non-terminal executions + batchModel.checkAllPendingExecutions(); + } +} +---- + +== Configuration Properties + +All properties are under the `spring.ai.openai.batch` prefix. + +[cols="3,1,3", options="header"] +|=== +| Property | Default | Description + +| `spring.ai.openai.batch.enabled` +| `false` +| Enable OpenAI Batch API auto-configuration + +| `spring.ai.openai.batch.options.completion-window` +| `24h` +| OpenAI batch completion window + +| `spring.ai.openai.batch.options.max-requests-per-batch` +| `50000` +| Maximum requests per batch (OpenAI limit: 50,000) + +| `spring.ai.openai.batch.options.max-file-size-bytes` +| `209715200` +| Maximum JSONL file size in bytes (200 MB) + +| `spring.ai.openai.batch.options.token-safety-factor` +| `1.2` +| Multiplier applied to token estimates for safety margin + +| `spring.ai.openai.batch.options.minimum-tokens-to-submit` +| `5000000` +| Minimum available tokens before submitting a batch + +| `spring.ai.openai.batch.options.max-retry-attempts` +| `2` +| Maximum retry attempts for failed/expired requests + +| `spring.ai.openai.batch.options.delete-files-after-processing` +| `true` +| Delete input/output/error files after processing + +| `spring.ai.openai.batch.options.handler-version` +| `1` +| Version stored in batch metadata; passed to handlers when processing results +|=== + +Connection properties (`api-key`, `base-url`, 
`organization-id`, etc.) are inherited from the standard `spring.ai.openai.*` properties when using auto-configuration. + +== Batch Lifecycle States + +OpenAI batches transition through the following states: + +[source] +---- +VALIDATING → IN_PROGRESS → FINALIZING → COMPLETED → RESULTS_PROCESSED + → EXPIRED + → FAILED + → CANCELLING → CANCELLED +---- + +The `OpenAiBatchModel` handles each terminal state: + +* **COMPLETED** → **RESULTS_PROCESSED** — Downloads output file, dispatches successes and errors to handlers, notifies listeners, updates repository +* **FAILED** — Downloads error file (if available), dispatches errors to handlers +* **EXPIRED** → **RESULTS_PROCESSED** — Processes any partial output, notifies listeners +* **CANCELLED** — Notifies listeners + +Non-200 responses in the output file are treated as errors and dispatched to `BatchRequestHandler.onError()` with error details extracted from the response body. + +== Error Handling + +The batch system provides multiple layers of error isolation: + +* **Per-handler errors** — If one handler throws during `onSuccess` or `onError`, other handlers and remaining items continue processing +* **File processing errors** — Output/error file download failures are logged but don't crash the batch check +* **Custom ID parsing** — Malformed custom IDs are logged and skipped +* **Non-200 responses** — Responses with non-200 status codes in the output file have error details extracted from the response body and dispatched as errors + +== Sample Configuration + +[source,yaml] +---- +spring: + ai: + openai: + api-key: ${OPENAI_API_KEY} + batch: + enabled: true + options: + completion-window: 24h + max-requests-per-batch: 10000 + token-safety-factor: 1.5 + max-retry-attempts: 3 + delete-files-after-processing: true + handler-version: 1 +---- diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-openai-batch-repository-jdbc/pom.xml 
b/spring-ai-spring-boot-starters/spring-ai-starter-model-openai-batch-repository-jdbc/pom.xml new file mode 100644 index 0000000000..765b47ad8e --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-openai-batch-repository-jdbc/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-starter-model-openai-batch-repository-jdbc + jar + Spring AI Starter - JDBC OpenAI Batch Execution Repository + Spring AI JDBC OpenAI Batch Execution Repository Starter + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + scm:git:git://github.com/spring-projects/spring-ai.git + scm:git:ssh://git@github.com/spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-autoconfigure-model-openai-batch-repository-jdbc + ${project.parent.version} + + + + org.springframework.ai + spring-ai-openai-batch-repository-jdbc + ${project.parent.version} + + + +