Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package org.springframework.ai.model.stabilityai.autoconfigure;

import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.model.SpringAIModelProperties;
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
Expand Down Expand Up @@ -67,8 +70,13 @@ public StabilityAiApi stabilityAiApi(StabilityAiConnectionProperties commonPrope
@Bean
@ConditionalOnMissingBean
public StabilityAiImageModel stabilityAiImageModel(StabilityAiApi stabilityAiApi,
StabilityAiImageProperties stabilityAiImageProperties) {
return new StabilityAiImageModel(stabilityAiApi, stabilityAiImageProperties.getOptions());
StabilityAiImageProperties stabilityAiImageProperties,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ImageModelObservationConvention> observationConvention) {
var imageModel = new StabilityAiImageModel(stabilityAiApi, stabilityAiImageProperties.getOptions(),
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
observationConvention.ifAvailable(imageModel::setObservationConvention);
return imageModel;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright 2023-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.model.stabilityai.autoconfigure;

import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistry;
import org.junit.jupiter.api.Test;

import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Autoconfiguration tests that exercise the observability wiring of
* {@link StabilityAiImageAutoConfiguration}.
*
* @author Gaurav Kumar
*/
class StabilityAiImageAutoConfigurationObservabilityTests {

private static final String[] BASE_PROPS = new String[] { "spring.ai.stabilityai.image.api-key=API_KEY",
"spring.ai.stabilityai.image.base-url=https://example.invalid" };

@Test
void defaultsToNoopObservationRegistryWhenNoBeanPresent() {
new ApplicationContextRunner().withPropertyValues(BASE_PROPS)
.withConfiguration(AutoConfigurations.of(StabilityAiImageAutoConfiguration.class))
.run(context -> assertThat(context).hasSingleBean(StabilityAiImageModel.class));
}

@Test
void usesUserProvidedObservationRegistryBean() {
new ApplicationContextRunner().withPropertyValues(BASE_PROPS)
.withUserConfiguration(ObservationRegistryConfig.class)
.withConfiguration(AutoConfigurations.of(StabilityAiImageAutoConfiguration.class))
.run(context -> {
assertThat(context).hasSingleBean(StabilityAiImageModel.class);
assertThat(context).hasSingleBean(ObservationRegistry.class);
});
}

@Test
void appliesUserProvidedObservationConventionBean() {
new ApplicationContextRunner().withPropertyValues(BASE_PROPS)
.withUserConfiguration(ObservationRegistryConfig.class, CustomConventionConfig.class)
.withConfiguration(AutoConfigurations.of(StabilityAiImageAutoConfiguration.class))
.run(context -> {
assertThat(context).hasSingleBean(StabilityAiImageModel.class);
assertThat(context).hasSingleBean(ImageModelObservationConvention.class);
});
}

@Configuration(proxyBeanMethods = false)
static class ObservationRegistryConfig {

@Bean
ObservationRegistry observationRegistry() {
return TestObservationRegistry.create();
}

}

@Configuration(proxyBeanMethods = false)
static class CustomConventionConfig {

@Bean
ImageModelObservationConvention imageModelObservationConvention() {
return new DefaultImageModelObservationConvention();
}

}

}
6 changes: 6 additions & 0 deletions models/spring-ai-stability-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package org.springframework.ai.stabilityai;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import io.micrometer.observation.ObservationRegistry;
import org.jspecify.annotations.Nullable;

import org.springframework.ai.image.Image;
Expand All @@ -28,30 +30,63 @@
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.util.Assert;

/**
* StabilityAiImageModel is a class that implements the ImageModel interface. It provides
* a client for calling the StabilityAI image generation API.
*
* <p>
* Observability data is emitted through the provided {@link ObservationRegistry} using
* the portable {@link ImageModelObservationConvention} infrastructure, matching the
* pattern used by the other Spring AI image models.
*
* @author Mark Pollack
* @author Gaurav Kumar
*/
public class StabilityAiImageModel implements ImageModel {

private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention();

private final StabilityAiImageOptions defaultOptions;

private final StabilityAiApi stabilityAiApi;

private final ObservationRegistry observationRegistry;

private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public StabilityAiImageModel(StabilityAiApi stabilityAiApi) {
this(stabilityAiApi, StabilityAiImageOptions.builder().build());
this(stabilityAiApi, StabilityAiImageOptions.builder().build(), ObservationRegistry.NOOP);
}

public StabilityAiImageModel(StabilityAiApi stabilityAiApi, StabilityAiImageOptions defaultOptions) {
this(stabilityAiApi, defaultOptions, ObservationRegistry.NOOP);
}

/**
* Creates a new StabilityAiImageModel.
* @param stabilityAiApi the StabilityAI API client
* @param defaultOptions the default image options
* @param observationRegistry the {@link ObservationRegistry} used to record image
* generation observations; {@link ObservationRegistry#NOOP} is used when {@code null}
* @since 2.0.0
*/
public StabilityAiImageModel(StabilityAiApi stabilityAiApi, StabilityAiImageOptions defaultOptions,
@Nullable ObservationRegistry observationRegistry) {
Assert.notNull(stabilityAiApi, "StabilityAiApi must not be null");
Assert.notNull(defaultOptions, "StabilityAiImageOptions must not be null");
this.stabilityAiApi = stabilityAiApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = Objects.requireNonNullElse(observationRegistry, ObservationRegistry.NOOP);
}

private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt,
Expand Down Expand Up @@ -79,30 +114,40 @@ public StabilityAiImageOptions getOptions() {
}

/**
* Calls the StabilityAiImageModel with the given StabilityAiImagePrompt and returns
* the ImageResponse. This overloaded call method lets you pass the full set of Prompt
* instructions that StabilityAI supports.
* Calls the StabilityAiImageModel with the given ImagePrompt and returns the
* ImageResponse. Emits a {@code gen_ai.client.operation} observation via the
* configured {@link ObservationRegistry}.
* @param imagePrompt the StabilityAiImagePrompt containing the prompt and image model
* options
* @return the ImageResponse generated by the StabilityAiImageModel
*/
public ImageResponse call(ImagePrompt imagePrompt) {
// Merge the runtime options passed via the prompt with the default options
// configured via the constructor.
// Runtime options overwrite StabilityAiImageModel options
// configured via the constructor. Runtime options overwrite defaults.
StabilityAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);

// Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions data
// types to the data types used in StabilityAiApi
StabilityAiApi.GenerateImageRequest generateImageRequest = getGenerateImageRequest(imagePrompt,
requestImageOptions);

// Make the request
StabilityAiApi.GenerateImageResponse generateImageResponse = this.stabilityAiApi
.generateImage(generateImageRequest);
// Pass the original ImagePrompt to the observation context, matching the pattern
// used by other Spring AI image models (e.g. OpenAiImageModel) so that
// observation tags remain consistent across providers.
var observationContext = ImageModelObservationContext.builder()
.imagePrompt(imagePrompt)
.provider(AiProvider.STABILITY_AI.value())
.build();

// Convert to org.springframework.ai.model derived ImageResponse data type
return convertResponse(generateImageResponse);
return Objects.requireNonNull(
ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
StabilityAiApi.GenerateImageResponse generateImageResponse = this.stabilityAiApi
.generateImage(generateImageRequest);
ImageResponse imageResponse = convertResponse(generateImageResponse);
observationContext.setResponse(imageResponse);
return imageResponse;
}));
}

private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse generateImageResponse) {
Expand All @@ -115,6 +160,16 @@ private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse gener
return new ImageResponse(imageGenerationList, new ImageResponseMetadata());
}

/**
* Use the provided convention for reporting observation data.
* @param observationConvention the provided convention
* @since 1.1.0
*/
public void setObservationConvention(ImageModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

/**
* Merge runtime and default {@link ImageOptions} to compute the final options to use
* in the request. Protected access for testing purposes, though maybe useful for
Expand Down
Loading