264 lines
6.2 KiB
Go
264 lines
6.2 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.wisehodl.dev/jay/aicli/config"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// sequenceTransport is an http.RoundTripper that replays a fixed list of
// canned responses, one per request, in order. Once the list is exhausted,
// every further request fails with an error. It does no locking, so it is
// not safe for concurrent use.
type sequenceTransport struct {
	// responses are handed out in order, one per successful RoundTrip call.
	responses []*http.Response
	// index is the position of the next response to hand out, i.e. the
	// number of responses served so far.
	index int
}

// RoundTrip returns the next canned response, or an error once all of them
// have been served. As the http.RoundTripper contract requires, it closes the
// request body (when present) on every path, including the error path.
func (s *sequenceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if req != nil && req.Body != nil {
		// A close failure on a test request body carries no useful signal.
		_ = req.Body.Close()
	}
	if s.index >= len(s.responses) {
		return nil, fmt.Errorf("no more responses in sequence")
	}
	resp := s.responses[s.index]
	s.index++
	return resp, nil
}
|
|
|
|
func TestTryModel(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg config.ConfigData
|
|
model string
|
|
query string
|
|
mockResp *http.Response
|
|
mockErr error
|
|
want string
|
|
wantErr bool
|
|
errContains string
|
|
}{
|
|
{
|
|
name: "successful request",
|
|
cfg: config.ConfigData{
|
|
Protocol: config.ProtocolOpenAI,
|
|
URL: "https://api.example.com",
|
|
APIKey: "sk-test",
|
|
},
|
|
model: "gpt-4",
|
|
query: "test query",
|
|
mockResp: makeResponse(200, `{"choices":[{"message":{"content":"response text"}}]}`),
|
|
want: "response text",
|
|
},
|
|
{
|
|
name: "http error",
|
|
cfg: config.ConfigData{
|
|
Protocol: config.ProtocolOpenAI,
|
|
URL: "https://api.example.com",
|
|
APIKey: "sk-test",
|
|
},
|
|
model: "gpt-4",
|
|
query: "test query",
|
|
mockResp: makeResponse(500, `{"error":"server error"}`),
|
|
wantErr: true,
|
|
errContains: "HTTP 500",
|
|
},
|
|
{
|
|
name: "parse error",
|
|
cfg: config.ConfigData{
|
|
Protocol: config.ProtocolOpenAI,
|
|
URL: "https://api.example.com",
|
|
APIKey: "sk-test",
|
|
},
|
|
model: "gpt-4",
|
|
query: "test query",
|
|
mockResp: makeResponse(200, `{"choices":[]}`),
|
|
wantErr: true,
|
|
errContains: "empty choices array",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
transport := &mockRoundTripper{
|
|
response: tt.mockResp,
|
|
err: tt.mockErr,
|
|
}
|
|
oldClient := httpClient
|
|
httpClient = &http.Client{
|
|
Timeout: 5 * time.Minute,
|
|
Transport: transport,
|
|
}
|
|
defer func() { httpClient = oldClient }()
|
|
|
|
got, err := tryModel(tt.cfg, tt.model, tt.query)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
if tt.errContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errContains)
|
|
}
|
|
return
|
|
}
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// captureStderr runs f while os.Stderr is redirected into a pipe and returns
// everything written to os.Stderr during that call.
//
// The pipe is drained concurrently, so f may write more than the OS pipe
// buffer without blocking. os.Stderr is restored before returning, even if f
// panics (the panic is then re-raised to the caller). Failing to create the
// pipe panics, because a test helper cannot do anything useful without it.
func captureStderr(f func()) string {
	r, w, err := os.Pipe()
	if err != nil {
		panic(fmt.Sprintf("captureStderr: creating pipe: %v", err))
	}
	defer r.Close()

	// The reader runs alongside f. The buffered channel lets the goroutine
	// exit even if nobody receives (e.g. when f panics).
	captured := make(chan string, 1)
	go func() {
		var buf bytes.Buffer
		_, _ = io.Copy(&buf, r) // stops at EOF once w is closed
		captured <- buf.String()
	}()

	old := os.Stderr
	os.Stderr = w
	func() {
		defer func() {
			os.Stderr = old
			w.Close() // signals EOF to the reader goroutine
		}()
		f()
	}()

	return <-captured
}
|
|
|
|
func TestSendChatRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg config.ConfigData
|
|
query string
|
|
mockResp []*http.Response
|
|
wantResponse string
|
|
wantModel string
|
|
wantErr bool
|
|
errContains string
|
|
checkStderr func(*testing.T, string)
|
|
}{
|
|
{
|
|
name: "primary model succeeds",
|
|
cfg: config.ConfigData{
|
|
Protocol: config.ProtocolOpenAI,
|
|
URL: "https://api.example.com",
|
|
APIKey: "sk-test",
|
|
Model: "gpt-4",
|
|
FallbackModels: []string{"gpt-3.5"},
|
|
},
|
|
query: "test",
|
|
mockResp: []*http.Response{
|
|
makeResponse(200, `{"choices":[{"message":{"content":"primary response"}}]}`),
|
|
},
|
|
wantResponse: "primary response",
|
|
wantModel: "gpt-4",
|
|
},
|
|
{
|
|
name: "primary fails, fallback succeeds",
|
|
cfg: config.ConfigData{
|
|
Protocol: config.ProtocolOpenAI,
|
|
URL: "https://api.example.com",
|
|
APIKey: "sk-test",
|
|
Model: "gpt-4",
|
|
FallbackModels: []string{"gpt-3.5"},
|
|
},
|
|
query: "test",
|
|
mockResp: []*http.Response{
|
|
makeResponse(500, `{"error":"server error"}`),
|
|
makeResponse(200, `{"choices":[{"message":{"content":"fallback response"}}]}`),
|
|
},
|
|
wantResponse: "fallback response",
|
|
wantModel: "gpt-3.5",
|
|
checkStderr: func(t *testing.T, stderr string) {
|
|
assert.Contains(t, stderr, "Model gpt-4 failed")
|
|
assert.Contains(t, stderr, "trying gpt-3.5")
|
|
},
|
|
},
|
|
{
|
|
name: "all models fail",
|
|
cfg: config.ConfigData{
|
|
Protocol: config.ProtocolOpenAI,
|
|
URL: "https://api.example.com",
|
|
APIKey: "sk-test",
|
|
Model: "gpt-4",
|
|
FallbackModels: []string{"gpt-3.5"},
|
|
},
|
|
query: "test",
|
|
mockResp: []*http.Response{
|
|
makeResponse(500, `{"error":"error1"}`),
|
|
makeResponse(500, `{"error":"error2"}`),
|
|
},
|
|
wantErr: true,
|
|
errContains: "all models failed",
|
|
checkStderr: func(t *testing.T, stderr string) {
|
|
assert.Contains(t, stderr, "Model gpt-4 failed")
|
|
assert.Contains(t, stderr, "Model gpt-3.5 failed")
|
|
},
|
|
},
|
|
{
|
|
name: "quiet mode suppresses progress",
|
|
cfg: config.ConfigData{
|
|
Protocol: config.ProtocolOpenAI,
|
|
URL: "https://api.example.com",
|
|
APIKey: "sk-test",
|
|
Model: "gpt-4",
|
|
FallbackModels: []string{"gpt-3.5"},
|
|
Quiet: true,
|
|
},
|
|
query: "test",
|
|
mockResp: []*http.Response{
|
|
makeResponse(500, `{"error":"error1"}`),
|
|
makeResponse(200, `{"choices":[{"message":{"content":"response"}}]}`),
|
|
},
|
|
wantResponse: "response",
|
|
wantModel: "gpt-3.5",
|
|
checkStderr: func(t *testing.T, stderr string) {
|
|
assert.Empty(t, stderr)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
transport := &sequenceTransport{responses: tt.mockResp}
|
|
|
|
oldClient := httpClient
|
|
httpClient = &http.Client{
|
|
Timeout: 5 * time.Minute,
|
|
Transport: transport,
|
|
}
|
|
defer func() { httpClient = oldClient }()
|
|
|
|
var stderr string
|
|
var response string
|
|
var model string
|
|
var duration time.Duration
|
|
var err error
|
|
|
|
if tt.checkStderr != nil {
|
|
stderr = captureStderr(func() {
|
|
response, model, duration, err = SendChatRequest(tt.cfg, tt.query)
|
|
})
|
|
} else {
|
|
response, model, duration, err = SendChatRequest(tt.cfg, tt.query)
|
|
}
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
if tt.errContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errContains)
|
|
}
|
|
if tt.checkStderr != nil {
|
|
tt.checkStderr(t, stderr)
|
|
}
|
|
return
|
|
}
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.wantResponse, response)
|
|
assert.Equal(t, tt.wantModel, model)
|
|
assert.Greater(t, duration, time.Duration(0))
|
|
|
|
if tt.checkStderr != nil {
|
|
tt.checkStderr(t, stderr)
|
|
}
|
|
})
|
|
}
|
|
}
|