Skip to content

Commit eac98e7

Browse files
committed
test: add OpenAI provider tests to improve coverage
- Add testable OpenAI provider with configurable base URL
- Add comprehensive tests for OpenAI Embed, Chat, ModelID, Dimensions
- Provider coverage improved from 45.7% to 90.6%
1 parent 6ac37cc commit eac98e7

2 files changed

Lines changed: 310 additions & 3 deletions

File tree

internal/ai/provider/provider.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,17 @@ func (p *OllamaProvider) Chat(ctx context.Context, messages []ChatMessage) (stri
189189
return chatResp.Message.Content, nil
190190
}
191191

192+
// DefaultOpenAIBaseURL is the default base URL for OpenAI API.
// Tests substitute an httptest server URL via newOpenAIProviderWithConfig.
const DefaultOpenAIBaseURL = "https://api.openai.com"
194+
192195
// OpenAIProvider implements EmbeddingProvider and ChatProvider using OpenAI API.
type OpenAIProvider struct {
	apiKey         string       // sent as "Bearer <apiKey>" in the Authorization header
	embeddingModel string       // model used for /v1/embeddings requests
	chatModel      string       // model used for /v1/chat/completions requests
	client         *http.Client // HTTP client configured with the caller-supplied timeout
	dimensions     int          // embedding vector size, derived from embeddingModel
	baseURL        string       // API root; DefaultOpenAIBaseURL in production, overridden in tests
}
200204

201205
// NewOpenAIProvider creates a new OpenAI provider.
@@ -205,6 +209,11 @@ func NewOpenAIProvider(embeddingModel, chatModel string, timeout time.Duration)
205209
return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
206210
}
207211

212+
return newOpenAIProviderWithConfig(apiKey, embeddingModel, chatModel, DefaultOpenAIBaseURL, timeout), nil
213+
}
214+
215+
// newOpenAIProviderWithConfig creates an OpenAI provider with explicit configuration (for testing).
216+
func newOpenAIProviderWithConfig(apiKey, embeddingModel, chatModel, baseURL string, timeout time.Duration) *OpenAIProvider {
208217
// Determine dimensions based on model
209218
dimensions := 1536 // default for text-embedding-3-small
210219
switch embeddingModel {
@@ -224,7 +233,8 @@ func NewOpenAIProvider(embeddingModel, chatModel string, timeout time.Duration)
224233
Timeout: timeout,
225234
},
226235
dimensions: dimensions,
227-
}, nil
236+
baseURL: baseURL,
237+
}
228238
}
229239

230240
// openAIEmbedRequest is the request body for OpenAI embeddings API.
@@ -255,7 +265,7 @@ func (p *OpenAIProvider) Embed(ctx context.Context, text string) ([]float32, err
255265
return nil, fmt.Errorf("failed to marshal request: %w", err)
256266
}
257267

258-
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/embeddings", bytes.NewReader(jsonBody))
268+
req, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/embeddings", bytes.NewReader(jsonBody))
259269
if err != nil {
260270
return nil, fmt.Errorf("failed to create request: %w", err)
261271
}
@@ -328,7 +338,7 @@ func (p *OpenAIProvider) Chat(ctx context.Context, messages []ChatMessage) (stri
328338
return "", fmt.Errorf("failed to marshal request: %w", err)
329339
}
330340

331-
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/chat/completions", bytes.NewReader(jsonBody))
341+
req, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/chat/completions", bytes.NewReader(jsonBody))
332342
if err != nil {
333343
return "", fmt.Errorf("failed to create request: %w", err)
334344
}

internal/ai/provider/provider_coverage_test.go

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,300 @@ func TestOllamaProviderBaseURL(t *testing.T) {
305305
t.Errorf("expected base URL '%s', got '%s'", baseURL, provider.baseURL)
306306
}
307307
}
308+
309+
// TestOpenAIProviderEmbed tests OpenAI Embed with mock server.
310+
func TestOpenAIProviderEmbed(t *testing.T) {
311+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
312+
if r.URL.Path != "/v1/embeddings" {
313+
t.Errorf("expected path /v1/embeddings, got %s", r.URL.Path)
314+
}
315+
316+
if r.Method != "POST" {
317+
t.Errorf("expected POST method, got %s", r.Method)
318+
}
319+
320+
// Check authorization header
321+
auth := r.Header.Get("Authorization")
322+
if auth != "Bearer test-api-key" {
323+
t.Errorf("expected Authorization 'Bearer test-api-key', got '%s'", auth)
324+
}
325+
326+
var req openAIEmbedRequest
327+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
328+
t.Errorf("failed to decode request: %v", err)
329+
}
330+
331+
if req.Model != "text-embedding-3-small" {
332+
t.Errorf("expected model 'text-embedding-3-small', got '%s'", req.Model)
333+
}
334+
335+
resp := openAIEmbedResponse{
336+
Data: []struct {
337+
Embedding []float64 `json:"embedding"`
338+
}{
339+
{Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5}},
340+
},
341+
}
342+
w.Header().Set("Content-Type", "application/json")
343+
_ = json.NewEncoder(w).Encode(resp)
344+
}))
345+
defer server.Close()
346+
347+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
348+
349+
embedding, err := provider.Embed(context.Background(), "test text")
350+
if err != nil {
351+
t.Fatalf("Embed failed: %v", err)
352+
}
353+
354+
if len(embedding) != 5 {
355+
t.Errorf("expected 5 dimensions, got %d", len(embedding))
356+
}
357+
358+
if embedding[0] != 0.1 {
359+
t.Errorf("expected first value 0.1, got %f", embedding[0])
360+
}
361+
}
362+
363+
// TestOpenAIProviderEmbedError tests OpenAI Embed error handling.
364+
func TestOpenAIProviderEmbedError(t *testing.T) {
365+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
366+
resp := openAIEmbedResponse{
367+
Error: &struct {
368+
Message string `json:"message"`
369+
}{
370+
Message: "Rate limit exceeded",
371+
},
372+
}
373+
w.Header().Set("Content-Type", "application/json")
374+
_ = json.NewEncoder(w).Encode(resp)
375+
}))
376+
defer server.Close()
377+
378+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
379+
380+
_, err := provider.Embed(context.Background(), "test text")
381+
if err == nil {
382+
t.Error("expected error, got nil")
383+
}
384+
if err != nil && !contains(err.Error(), "Rate limit exceeded") {
385+
t.Errorf("expected error about rate limit, got: %v", err)
386+
}
387+
}
388+
389+
// TestOpenAIProviderEmbedEmptyResponse tests OpenAI Embed with empty data.
390+
func TestOpenAIProviderEmbedEmptyResponse(t *testing.T) {
391+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
392+
resp := openAIEmbedResponse{
393+
Data: []struct {
394+
Embedding []float64 `json:"embedding"`
395+
}{},
396+
}
397+
w.Header().Set("Content-Type", "application/json")
398+
_ = json.NewEncoder(w).Encode(resp)
399+
}))
400+
defer server.Close()
401+
402+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
403+
404+
_, err := provider.Embed(context.Background(), "test text")
405+
if err == nil {
406+
t.Error("expected error for empty data, got nil")
407+
}
408+
}
409+
410+
// TestOpenAIProviderEmbedInvalidJSON tests OpenAI Embed with invalid JSON.
411+
func TestOpenAIProviderEmbedInvalidJSON(t *testing.T) {
412+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
413+
w.Header().Set("Content-Type", "application/json")
414+
_, _ = w.Write([]byte("not valid json"))
415+
}))
416+
defer server.Close()
417+
418+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
419+
420+
_, err := provider.Embed(context.Background(), "test text")
421+
if err == nil {
422+
t.Error("expected error for invalid JSON, got nil")
423+
}
424+
}
425+
426+
// TestOpenAIProviderChat tests OpenAI Chat with mock server.
427+
func TestOpenAIProviderChat(t *testing.T) {
428+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
429+
if r.URL.Path != "/v1/chat/completions" {
430+
t.Errorf("expected path /v1/chat/completions, got %s", r.URL.Path)
431+
}
432+
433+
var req openAIChatRequest
434+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
435+
t.Errorf("failed to decode request: %v", err)
436+
}
437+
438+
if req.Model != "gpt-4o-mini" {
439+
t.Errorf("expected model 'gpt-4o-mini', got '%s'", req.Model)
440+
}
441+
442+
resp := openAIChatResponse{
443+
Choices: []struct {
444+
Message ChatMessage `json:"message"`
445+
}{
446+
{Message: ChatMessage{Role: "assistant", Content: "Hello from OpenAI!"}},
447+
},
448+
}
449+
w.Header().Set("Content-Type", "application/json")
450+
_ = json.NewEncoder(w).Encode(resp)
451+
}))
452+
defer server.Close()
453+
454+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
455+
456+
messages := []ChatMessage{
457+
{Role: "user", Content: "Hello"},
458+
}
459+
460+
response, err := provider.Chat(context.Background(), messages)
461+
if err != nil {
462+
t.Fatalf("Chat failed: %v", err)
463+
}
464+
465+
if response != "Hello from OpenAI!" {
466+
t.Errorf("expected 'Hello from OpenAI!', got '%s'", response)
467+
}
468+
}
469+
470+
// TestOpenAIProviderChatError tests OpenAI Chat error handling.
471+
func TestOpenAIProviderChatError(t *testing.T) {
472+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
473+
resp := openAIChatResponse{
474+
Error: &struct {
475+
Message string `json:"message"`
476+
}{
477+
Message: "Invalid API key",
478+
},
479+
}
480+
w.Header().Set("Content-Type", "application/json")
481+
_ = json.NewEncoder(w).Encode(resp)
482+
}))
483+
defer server.Close()
484+
485+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
486+
487+
messages := []ChatMessage{
488+
{Role: "user", Content: "Hello"},
489+
}
490+
491+
_, err := provider.Chat(context.Background(), messages)
492+
if err == nil {
493+
t.Error("expected error, got nil")
494+
}
495+
}
496+
497+
// TestOpenAIProviderChatEmptyChoices tests OpenAI Chat with empty choices.
498+
func TestOpenAIProviderChatEmptyChoices(t *testing.T) {
499+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
500+
resp := openAIChatResponse{
501+
Choices: []struct {
502+
Message ChatMessage `json:"message"`
503+
}{},
504+
}
505+
w.Header().Set("Content-Type", "application/json")
506+
_ = json.NewEncoder(w).Encode(resp)
507+
}))
508+
defer server.Close()
509+
510+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
511+
512+
messages := []ChatMessage{
513+
{Role: "user", Content: "Hello"},
514+
}
515+
516+
_, err := provider.Chat(context.Background(), messages)
517+
if err == nil {
518+
t.Error("expected error for empty choices, got nil")
519+
}
520+
}
521+
522+
// TestOpenAIProviderChatInvalidJSON tests OpenAI Chat with invalid JSON.
523+
func TestOpenAIProviderChatInvalidJSON(t *testing.T) {
524+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
525+
w.Header().Set("Content-Type", "application/json")
526+
_, _ = w.Write([]byte("not valid json"))
527+
}))
528+
defer server.Close()
529+
530+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
531+
532+
messages := []ChatMessage{
533+
{Role: "user", Content: "Hello"},
534+
}
535+
536+
_, err := provider.Chat(context.Background(), messages)
537+
if err == nil {
538+
t.Error("expected error for invalid JSON, got nil")
539+
}
540+
}
541+
542+
// TestOpenAIProviderModelID tests OpenAI ModelID.
543+
func TestOpenAIProviderModelID(t *testing.T) {
544+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", "http://localhost", 60*time.Second)
545+
546+
if provider.ModelID() != "text-embedding-3-small" {
547+
t.Errorf("expected 'text-embedding-3-small', got '%s'", provider.ModelID())
548+
}
549+
}
550+
551+
// TestOpenAIProviderDimensions tests OpenAI Dimensions for different models.
552+
func TestOpenAIProviderDimensionsAllModels(t *testing.T) {
553+
tests := []struct {
554+
model string
555+
dimensions int
556+
}{
557+
{"text-embedding-3-small", 1536},
558+
{"text-embedding-3-large", 3072},
559+
{"text-embedding-ada-002", 1536},
560+
{"custom-model", 1536}, // default
561+
}
562+
563+
for _, tt := range tests {
564+
t.Run(tt.model, func(t *testing.T) {
565+
provider := newOpenAIProviderWithConfig("test-api-key", tt.model, "gpt-4o-mini", "http://localhost", 60*time.Second)
566+
if provider.Dimensions() != tt.dimensions {
567+
t.Errorf("expected %d dimensions, got %d", tt.dimensions, provider.Dimensions())
568+
}
569+
})
570+
}
571+
}
572+
573+
// TestOpenAIProviderContextCancelled tests context cancellation.
574+
func TestOpenAIProviderContextCancelled(t *testing.T) {
575+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
576+
time.Sleep(100 * time.Millisecond)
577+
w.WriteHeader(http.StatusOK)
578+
}))
579+
defer server.Close()
580+
581+
provider := newOpenAIProviderWithConfig("test-api-key", "text-embedding-3-small", "gpt-4o-mini", server.URL, 60*time.Second)
582+
583+
ctx, cancel := context.WithCancel(context.Background())
584+
cancel()
585+
586+
_, err := provider.Embed(ctx, "test text")
587+
if err == nil {
588+
t.Error("expected error with cancelled context")
589+
}
590+
}
591+
592+
// Helper function for string contains check
593+
func contains(s, substr string) bool {
594+
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
595+
}
596+
597+
// containsHelper performs a naive substring scan, comparing substr against
// every window of s of the same length. An empty substr matches at offset 0;
// a substr longer than s never matches (the loop body never runs).
func containsHelper(s, substr string) bool {
	last := len(s) - len(substr) // highest valid start offset
	for i := 0; i <= last; i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

0 commit comments

Comments
 (0)