Refactored, added comprehensive testing.
All checks were successful
Release / release (push) Successful in 3m17s

Jay
2025-10-26 23:23:43 -04:00
parent ec32b75267
commit 1936f055e2
61 changed files with 4678 additions and 769 deletions

api/api.go (new file, 62 lines)

@@ -0,0 +1,62 @@
package api

import (
    "encoding/json"
    "fmt"
    "os"
    "time"

    "git.wisehodl.dev/jay/aicli/config"
)

// tryModel attempts a single model request through the complete pipeline:
// payload construction, HTTP execution, and response parsing.
func tryModel(cfg config.ConfigData, model string, query string) (string, error) {
    payload := buildPayload(cfg, model, query)

    if cfg.Verbose {
        payloadJSON, _ := json.Marshal(payload)
        fmt.Fprintf(os.Stderr, "[verbose] Request payload: %s\n", string(payloadJSON))
    }

    body, err := executeHTTP(cfg, payload)
    if err != nil {
        return "", err
    }

    if cfg.Verbose {
        fmt.Fprintf(os.Stderr, "[verbose] Response: %s\n", string(body))
    }

    response, err := parseResponse(body, cfg.Protocol)
    if err != nil {
        return "", err
    }

    return response, nil
}

// SendChatRequest sends a query to the configured model with automatic fallback.
// Returns the response content, the model name that succeeded, total duration, and any error.
// On failure, attempts each fallback model in sequence until one succeeds or all fail.
func SendChatRequest(cfg config.ConfigData, query string) (string, string, time.Duration, error) {
    models := append([]string{cfg.Model}, cfg.FallbackModels...)
    start := time.Now()

    for i, model := range models {
        if !cfg.Quiet && i > 0 {
            fmt.Fprintf(os.Stderr, "Model %s failed, trying %s...\n", models[i-1], model)
        }

        response, err := tryModel(cfg, model, query)
        if err == nil {
            return response, model, time.Since(start), nil
        }

        if !cfg.Quiet {
            fmt.Fprintf(os.Stderr, "Model %s failed: %v\n", model, err)
        }
    }

    return "", "", time.Since(start), fmt.Errorf("all models failed")
}
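
The fallback loop above gives callers a single entry point. As a minimal sketch of a caller (the endpoint URL and key below are placeholders; only the ConfigData fields exercised in this commit are used):

package main

import (
    "fmt"
    "log"

    "git.wisehodl.dev/jay/aicli/api"
    "git.wisehodl.dev/jay/aicli/config"
)

func main() {
    // Placeholder endpoint and key; fallback models are tried in order
    // after the primary model fails.
    cfg := config.ConfigData{
        Protocol:       config.ProtocolOpenAI,
        URL:            "https://api.example.com/v1/chat/completions",
        APIKey:         "sk-placeholder",
        Model:          "gpt-4",
        FallbackModels: []string{"gpt-3.5"},
    }

    response, model, elapsed, err := api.SendChatRequest(cfg, "Summarize the diff.")
    if err != nil {
        log.Fatalf("all models failed: %v", err)
    }
    fmt.Printf("%s\n(answered by %s in %s)\n", response, model, elapsed)
}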

api/api_test.go (new file, 263 lines)

@@ -0,0 +1,263 @@
package api

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
    "os"
    "testing"
    "time"

    "git.wisehodl.dev/jay/aicli/config"
    "github.com/stretchr/testify/assert"
)

// sequenceTransport returns its queued responses in order, one per request,
// so a single test can drive the primary-then-fallback request sequence.
type sequenceTransport struct {
    responses []*http.Response
    index     int
}

func (t *sequenceTransport) RoundTrip(*http.Request) (*http.Response, error) {
    if t.index >= len(t.responses) {
        return nil, fmt.Errorf("no more responses in sequence")
    }
    resp := t.responses[t.index]
    t.index++
    return resp, nil
}

func TestTryModel(t *testing.T) {
    tests := []struct {
        name        string
        cfg         config.ConfigData
        model       string
        query       string
        mockResp    *http.Response
        mockErr     error
        want        string
        wantErr     bool
        errContains string
    }{
        {
            name: "successful request",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOpenAI,
                URL:      "https://api.example.com",
                APIKey:   "sk-test",
            },
            model:    "gpt-4",
            query:    "test query",
            mockResp: makeResponse(200, `{"choices":[{"message":{"content":"response text"}}]}`),
            want:     "response text",
        },
        {
            name: "http error",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOpenAI,
                URL:      "https://api.example.com",
                APIKey:   "sk-test",
            },
            model:       "gpt-4",
            query:       "test query",
            mockResp:    makeResponse(500, `{"error":"server error"}`),
            wantErr:     true,
            errContains: "HTTP 500",
        },
        {
            name: "parse error",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOpenAI,
                URL:      "https://api.example.com",
                APIKey:   "sk-test",
            },
            model:       "gpt-4",
            query:       "test query",
            mockResp:    makeResponse(200, `{"choices":[]}`),
            wantErr:     true,
            errContains: "empty choices array",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            transport := &mockRoundTripper{
                response: tt.mockResp,
                err:      tt.mockErr,
            }
            oldClient := httpClient
            httpClient = &http.Client{
                Timeout:   5 * time.Minute,
                Transport: transport,
            }
            defer func() { httpClient = oldClient }()

            got, err := tryModel(tt.cfg, tt.model, tt.query)

            if tt.wantErr {
                assert.Error(t, err)
                if tt.errContains != "" {
                    assert.Contains(t, err.Error(), tt.errContains)
                }
                return
            }

            assert.NoError(t, err)
            assert.Equal(t, tt.want, got)
        })
    }
}

// captureStderr redirects os.Stderr to a pipe for the duration of f and
// returns everything written to it.
func captureStderr(f func()) string {
    old := os.Stderr
    r, w, _ := os.Pipe()
    os.Stderr = w

    f()

    w.Close()
    os.Stderr = old

    var buf bytes.Buffer
    io.Copy(&buf, r)
    return buf.String()
}

func TestSendChatRequest(t *testing.T) {
    tests := []struct {
        name         string
        cfg          config.ConfigData
        query        string
        mockResp     []*http.Response
        wantResponse string
        wantModel    string
        wantErr      bool
        errContains  string
        checkStderr  func(*testing.T, string)
    }{
        {
            name: "primary model succeeds",
            cfg: config.ConfigData{
                Protocol:       config.ProtocolOpenAI,
                URL:            "https://api.example.com",
                APIKey:         "sk-test",
                Model:          "gpt-4",
                FallbackModels: []string{"gpt-3.5"},
            },
            query: "test",
            mockResp: []*http.Response{
                makeResponse(200, `{"choices":[{"message":{"content":"primary response"}}]}`),
            },
            wantResponse: "primary response",
            wantModel:    "gpt-4",
        },
        {
            name: "primary fails, fallback succeeds",
            cfg: config.ConfigData{
                Protocol:       config.ProtocolOpenAI,
                URL:            "https://api.example.com",
                APIKey:         "sk-test",
                Model:          "gpt-4",
                FallbackModels: []string{"gpt-3.5"},
            },
            query: "test",
            mockResp: []*http.Response{
                makeResponse(500, `{"error":"server error"}`),
                makeResponse(200, `{"choices":[{"message":{"content":"fallback response"}}]}`),
            },
            wantResponse: "fallback response",
            wantModel:    "gpt-3.5",
            checkStderr: func(t *testing.T, stderr string) {
                assert.Contains(t, stderr, "Model gpt-4 failed")
                assert.Contains(t, stderr, "trying gpt-3.5")
            },
        },
        {
            name: "all models fail",
            cfg: config.ConfigData{
                Protocol:       config.ProtocolOpenAI,
                URL:            "https://api.example.com",
                APIKey:         "sk-test",
                Model:          "gpt-4",
                FallbackModels: []string{"gpt-3.5"},
            },
            query: "test",
            mockResp: []*http.Response{
                makeResponse(500, `{"error":"error1"}`),
                makeResponse(500, `{"error":"error2"}`),
            },
            wantErr:     true,
            errContains: "all models failed",
            checkStderr: func(t *testing.T, stderr string) {
                assert.Contains(t, stderr, "Model gpt-4 failed")
                assert.Contains(t, stderr, "Model gpt-3.5 failed")
            },
        },
        {
            name: "quiet mode suppresses progress",
            cfg: config.ConfigData{
                Protocol:       config.ProtocolOpenAI,
                URL:            "https://api.example.com",
                APIKey:         "sk-test",
                Model:          "gpt-4",
                FallbackModels: []string{"gpt-3.5"},
                Quiet:          true,
            },
            query: "test",
            mockResp: []*http.Response{
                makeResponse(500, `{"error":"error1"}`),
                makeResponse(200, `{"choices":[{"message":{"content":"response"}}]}`),
            },
            wantResponse: "response",
            wantModel:    "gpt-3.5",
            checkStderr: func(t *testing.T, stderr string) {
                assert.Empty(t, stderr)
            },
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            transport := &sequenceTransport{responses: tt.mockResp}
            oldClient := httpClient
            httpClient = &http.Client{
                Timeout:   5 * time.Minute,
                Transport: transport,
            }
            defer func() { httpClient = oldClient }()

            var stderr string
            var response string
            var model string
            var duration time.Duration
            var err error

            if tt.checkStderr != nil {
                stderr = captureStderr(func() {
                    response, model, duration, err = SendChatRequest(tt.cfg, tt.query)
                })
            } else {
                response, model, duration, err = SendChatRequest(tt.cfg, tt.query)
            }

            if tt.wantErr {
                assert.Error(t, err)
                if tt.errContains != "" {
                    assert.Contains(t, err.Error(), tt.errContains)
                }
                if tt.checkStderr != nil {
                    tt.checkStderr(t, stderr)
                }
                return
            }

            assert.NoError(t, err)
            assert.Equal(t, tt.wantResponse, response)
            assert.Equal(t, tt.wantModel, model)
            assert.Greater(t, duration, time.Duration(0))

            if tt.checkStderr != nil {
                tt.checkStderr(t, stderr)
            }
        })
    }
}

api/http.go (new file, 47 lines)

@@ -0,0 +1,47 @@
package api

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "time"

    "git.wisehodl.dev/jay/aicli/config"
)

// httpClient is package-level so tests can substitute a client with a mock
// transport. The generous timeout accommodates slow model responses.
var httpClient = &http.Client{Timeout: 5 * time.Minute}

// executeHTTP sends the payload to the API endpoint and returns the response body.
func executeHTTP(cfg config.ConfigData, payload map[string]interface{}) ([]byte, error) {
    body, err := json.Marshal(payload)
    if err != nil {
        return nil, fmt.Errorf("marshal payload: %w", err)
    }

    req, err := http.NewRequest("POST", cfg.URL, bytes.NewReader(body))
    if err != nil {
        return nil, fmt.Errorf("create request: %w", err)
    }
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.APIKey))

    resp, err := httpClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("execute request: %w", err)
    }
    defer resp.Body.Close()

    respBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, fmt.Errorf("read response: %w", err)
    }

    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
    }

    return respBody, nil
}
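
The package-level httpClient is the seam the tests below rely on: swapping in a custom Transport intercepts every request executeHTTP sends. As a sketch, under the assumption that it lives in package api (loggingTransport is illustrative, not part of this commit):

package api

import (
    "fmt"
    "net/http"
    "os"
)

// loggingTransport is a hypothetical RoundTripper that traces each request
// before delegating to a real transport.
type loggingTransport struct {
    next http.RoundTripper
}

func (t *loggingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    fmt.Fprintf(os.Stderr, "-> %s %s\n", req.Method, req.URL)
    return t.next.RoundTrip(req)
}

Setting httpClient.Transport = &loggingTransport{next: http.DefaultTransport} would trace live traffic; the tests instead queue canned *http.Response values through the same seam.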

api/http_test.go (new file, 193 lines)

@@ -0,0 +1,193 @@
package api

import (
    "bytes"
    "encoding/json"
    "io"
    "net/http"
    "strings"
    "testing"
    "time"

    "git.wisehodl.dev/jay/aicli/config"
    "github.com/stretchr/testify/assert"
)

// mockRoundTripper returns a fixed response (or error) and records the last
// request so tests can inspect its headers and body.
type mockRoundTripper struct {
    response *http.Response
    err      error
    request  *http.Request
}

func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    m.request = req
    return m.response, m.err
}

// makeResponse builds a minimal *http.Response with the given status and body.
func makeResponse(statusCode int, body string) *http.Response {
    return &http.Response{
        StatusCode: statusCode,
        Body:       io.NopCloser(strings.NewReader(body)),
        Header:     make(http.Header),
    }
}

func TestExecuteHTTP(t *testing.T) {
    tests := []struct {
        name        string
        cfg         config.ConfigData
        payload     map[string]interface{}
        mockResp    *http.Response
        mockErr     error
        wantBody    string
        wantErr     bool
        errContains string
    }{
        {
            name: "successful request",
            cfg: config.ConfigData{
                URL:    "https://api.example.com/chat",
                APIKey: "sk-test123",
            },
            payload: map[string]interface{}{
                "model": "gpt-4",
            },
            mockResp: makeResponse(200, `{"choices":[{"message":{"content":"response"}}]}`),
            wantBody: `{"choices":[{"message":{"content":"response"}}]}`,
            wantErr:  false,
        },
        {
            name: "HTTP 400 error",
            cfg: config.ConfigData{
                URL:    "https://api.example.com/chat",
                APIKey: "sk-test123",
            },
            payload:     map[string]interface{}{"model": "gpt-4"},
            mockResp:    makeResponse(400, `{"error":"bad request"}`),
            wantErr:     true,
            errContains: "HTTP 400",
        },
        {
            name: "HTTP 401 unauthorized",
            cfg: config.ConfigData{
                URL:    "https://api.example.com/chat",
                APIKey: "invalid-key",
            },
            payload:     map[string]interface{}{"model": "gpt-4"},
            mockResp:    makeResponse(401, `{"error":"unauthorized"}`),
            wantErr:     true,
            errContains: "HTTP 401",
        },
        {
            name: "HTTP 429 rate limit",
            cfg: config.ConfigData{
                URL:    "https://api.example.com/chat",
                APIKey: "sk-test123",
            },
            payload:     map[string]interface{}{"model": "gpt-4"},
            mockResp:    makeResponse(429, `{"error":"rate limit exceeded"}`),
            wantErr:     true,
            errContains: "HTTP 429",
        },
        {
            name: "HTTP 500 server error",
            cfg: config.ConfigData{
                URL:    "https://api.example.com/chat",
                APIKey: "sk-test123",
            },
            payload:     map[string]interface{}{"model": "gpt-4"},
            mockResp:    makeResponse(500, `{"error":"internal server error"}`),
            wantErr:     true,
            errContains: "HTTP 500",
        },
        {
            name: "network error",
            cfg: config.ConfigData{
                URL:    "https://api.example.com/chat",
                APIKey: "sk-test123",
            },
            payload:     map[string]interface{}{"model": "gpt-4"},
            mockErr:     http.ErrHandlerTimeout,
            wantErr:     true,
            errContains: "execute request",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            transport := &mockRoundTripper{
                response: tt.mockResp,
                err:      tt.mockErr,
            }
            oldClient := httpClient
            httpClient = &http.Client{
                Timeout:   5 * time.Minute,
                Transport: transport,
            }
            defer func() { httpClient = oldClient }()

            got, err := executeHTTP(tt.cfg, tt.payload)

            if tt.wantErr {
                assert.Error(t, err)
                if tt.errContains != "" {
                    assert.Contains(t, err.Error(), tt.errContains)
                }
                return
            }

            assert.NoError(t, err)
            assert.Equal(t, tt.wantBody, string(got))
        })
    }
}

func TestExecuteHTTPHeaders(t *testing.T) {
    cfg := config.ConfigData{
        URL:    "https://api.example.com/chat",
        APIKey: "sk-test-key",
    }
    payload := map[string]interface{}{"model": "gpt-4"}

    transport := &mockRoundTripper{
        response: makeResponse(200, `{"result":"ok"}`),
    }
    oldClient := httpClient
    httpClient = &http.Client{
        Timeout:   5 * time.Minute,
        Transport: transport,
    }
    defer func() { httpClient = oldClient }()

    _, err := executeHTTP(cfg, payload)
    assert.NoError(t, err)
    assert.Equal(t, "application/json", transport.request.Header.Get("Content-Type"))
    assert.Equal(t, "Bearer sk-test-key", transport.request.Header.Get("Authorization"))
}

// TestExecuteHTTPTimeout builds a request the same way executeHTTP does and
// runs it under the same 5-minute client timeout. It verifies that the happy
// path completes within the configured limit; it does not simulate an
// expired timeout.
func TestExecuteHTTPTimeout(t *testing.T) {
    cfg := config.ConfigData{
        URL:    "https://api.example.com/chat",
        APIKey: "sk-test123",
    }
    payload := map[string]interface{}{"model": "gpt-4"}

    transport := &mockRoundTripper{
        response: makeResponse(200, `{"ok":true}`),
    }
    client := &http.Client{
        Timeout:   5 * time.Minute,
        Transport: transport,
    }

    body, _ := json.Marshal(payload)
    req, _ := http.NewRequest("POST", cfg.URL, bytes.NewReader(body))
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Authorization", "Bearer "+cfg.APIKey)

    resp, err := client.Do(req)
    assert.NoError(t, err)
    assert.Equal(t, 200, resp.StatusCode)
}

api/parse.go (new file, 51 lines)

@@ -0,0 +1,51 @@
package api

import (
    "encoding/json"
    "fmt"

    "git.wisehodl.dev/jay/aicli/config"
)

// parseResponse extracts the response content from the API response body.
func parseResponse(body []byte, protocol config.APIProtocol) (string, error) {
    var result map[string]interface{}
    if err := json.Unmarshal(body, &result); err != nil {
        return "", fmt.Errorf("parse response: %w", err)
    }

    if protocol == config.ProtocolOllama {
        response, ok := result["response"].(string)
        if !ok {
            return "", fmt.Errorf("no response field in ollama response")
        }
        return response, nil
    }

    // OpenAI protocol
    choices, ok := result["choices"].([]interface{})
    if !ok {
        return "", fmt.Errorf("no choices in response")
    }
    if len(choices) == 0 {
        return "", fmt.Errorf("empty choices array")
    }

    firstChoice, ok := choices[0].(map[string]interface{})
    if !ok {
        return "", fmt.Errorf("invalid choice format")
    }

    message, ok := firstChoice["message"].(map[string]interface{})
    if !ok {
        return "", fmt.Errorf("no message in choice")
    }

    content, ok := message["content"].(string)
    if !ok {
        return "", fmt.Errorf("no content in message")
    }

    return content, nil
}
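
Both protocol branches reduce to a string extraction. A sketch of the two accepted shapes, written as if inside package api since parseResponse is unexported (the function name is illustrative, not part of this commit):

package api

import (
    "fmt"

    "git.wisehodl.dev/jay/aicli/config"
)

// exampleParseResponse shows the two body shapes parseResponse accepts.
func exampleParseResponse() {
    openaiBody := []byte(`{"choices":[{"message":{"content":"hi"}}]}`)
    ollamaBody := []byte(`{"response":"hi"}`)

    a, _ := parseResponse(openaiBody, config.ProtocolOpenAI)
    b, _ := parseResponse(ollamaBody, config.ProtocolOllama)
    fmt.Println(a, b) // hi hi
}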

api/parse_test.go (new file, 159 lines)

@@ -0,0 +1,159 @@
package api

import (
    "os"
    "testing"

    "git.wisehodl.dev/jay/aicli/config"
    "github.com/stretchr/testify/assert"
)

func TestParseResponse(t *testing.T) {
    tests := []struct {
        name        string
        body        string
        protocol    config.APIProtocol
        want        string
        wantErr     bool
        errContains string
    }{
        {
            name:     "openai success",
            body:     `{"choices":[{"message":{"content":"This is the response text."}}]}`,
            protocol: config.ProtocolOpenAI,
            want:     "This is the response text.",
        },
        {
            name:        "openai empty choices",
            body:        `{"choices":[]}`,
            protocol:    config.ProtocolOpenAI,
            wantErr:     true,
            errContains: "empty choices array",
        },
        {
            name:        "openai no choices field",
            body:        `{"result":"ok"}`,
            protocol:    config.ProtocolOpenAI,
            wantErr:     true,
            errContains: "no choices in response",
        },
        {
            name:        "openai invalid choice format",
            body:        `{"choices":["invalid"]}`,
            protocol:    config.ProtocolOpenAI,
            wantErr:     true,
            errContains: "invalid choice format",
        },
        {
            name:        "openai no message field",
            body:        `{"choices":[{"text":"wrong structure"}]}`,
            protocol:    config.ProtocolOpenAI,
            wantErr:     true,
            errContains: "no message in choice",
        },
        {
            name:        "openai no content field",
            body:        `{"choices":[{"message":{"role":"assistant"}}]}`,
            protocol:    config.ProtocolOpenAI,
            wantErr:     true,
            errContains: "no content in message",
        },
        {
            name:     "ollama success",
            body:     `{"response":"This is the Ollama response."}`,
            protocol: config.ProtocolOllama,
            want:     "This is the Ollama response.",
        },
        {
            name:        "ollama no response field",
            body:        `{"model":"llama3"}`,
            protocol:    config.ProtocolOllama,
            wantErr:     true,
            errContains: "no response field in ollama response",
        },
        {
            name:        "malformed json",
            body:        `{invalid json`,
            protocol:    config.ProtocolOpenAI,
            wantErr:     true,
            errContains: "parse response",
        },
        {
            name:     "openai multiline content",
            body:     `{"choices":[{"message":{"content":"Line 1\nLine 2\nLine 3"}}]}`,
            protocol: config.ProtocolOpenAI,
            want:     "Line 1\nLine 2\nLine 3",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := parseResponse([]byte(tt.body), tt.protocol)

            if tt.wantErr {
                assert.Error(t, err)
                if tt.errContains != "" {
                    assert.Contains(t, err.Error(), tt.errContains)
                }
                return
            }

            assert.NoError(t, err)
            assert.Equal(t, tt.want, got)
        })
    }
}

func TestParseResponseWithTestdata(t *testing.T) {
    tests := []struct {
        name        string
        file        string
        protocol    config.APIProtocol
        want        string
        wantErr     bool
        errContains string
    }{
        {
            name:     "openai success from file",
            file:     "testdata/openai_success.json",
            protocol: config.ProtocolOpenAI,
            want:     "This is the response text.",
        },
        {
            name:        "openai empty choices from file",
            file:        "testdata/openai_empty_choices.json",
            protocol:    config.ProtocolOpenAI,
            wantErr:     true,
            errContains: "empty choices array",
        },
        {
            name:     "ollama success from file",
            file:     "testdata/ollama_success.json",
            protocol: config.ProtocolOllama,
            want:     "This is the Ollama response.",
        },
        {
            name:        "ollama no response from file",
            file:        "testdata/ollama_no_response.json",
            protocol:    config.ProtocolOllama,
            wantErr:     true,
            errContains: "no response field",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            body, err := os.ReadFile(tt.file)
            assert.NoError(t, err, "failed to read test file")

            got, err := parseResponse(body, tt.protocol)

            if tt.wantErr {
                assert.Error(t, err)
                if tt.errContains != "" {
                    assert.Contains(t, err.Error(), tt.errContains)
                }
                return
            }

            assert.NoError(t, err)
            assert.Equal(t, tt.want, got)
        })
    }
}

api/payload.go (new file, 36 lines)

@@ -0,0 +1,36 @@
package api

import "git.wisehodl.dev/jay/aicli/config"

// buildPayload constructs the JSON payload for the API request based on protocol.
func buildPayload(cfg config.ConfigData, model string, query string) map[string]interface{} {
    if cfg.Protocol == config.ProtocolOllama {
        payload := map[string]interface{}{
            "model":  model,
            "prompt": query,
            "stream": false,
        }
        if cfg.SystemPrompt != "" {
            payload["system"] = cfg.SystemPrompt
        }
        return payload
    }

    // OpenAI protocol
    messages := []map[string]string{}
    if cfg.SystemPrompt != "" {
        messages = append(messages, map[string]string{
            "role":    "system",
            "content": cfg.SystemPrompt,
        })
    }
    messages = append(messages, map[string]string{
        "role":    "user",
        "content": query,
    })

    return map[string]interface{}{
        "model":    model,
        "messages": messages,
    }
}
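
For reference, with a system prompt set, the two branches marshal to payloads like these (matching the expectations encoded in payload_test.go below):

OpenAI protocol:
{"model":"gpt-4","messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"analyze this"}]}

Ollama protocol:
{"model":"llama3","prompt":"analyze this","system":"You are helpful","stream":false}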

api/payload_test.go (new file, 126 lines)

@@ -0,0 +1,126 @@
package api

import (
    "testing"

    "git.wisehodl.dev/jay/aicli/config"
    "github.com/stretchr/testify/assert"
)

func TestBuildPayload(t *testing.T) {
    tests := []struct {
        name  string
        cfg   config.ConfigData
        model string
        query string
        want  map[string]interface{}
    }{
        {
            name: "openai without system prompt",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOpenAI,
            },
            model: "gpt-4",
            query: "analyze this",
            want: map[string]interface{}{
                "model": "gpt-4",
                "messages": []map[string]string{
                    {"role": "user", "content": "analyze this"},
                },
            },
        },
        {
            name: "openai with system prompt",
            cfg: config.ConfigData{
                Protocol:     config.ProtocolOpenAI,
                SystemPrompt: "You are helpful",
            },
            model: "gpt-4",
            query: "analyze this",
            want: map[string]interface{}{
                "model": "gpt-4",
                "messages": []map[string]string{
                    {"role": "system", "content": "You are helpful"},
                    {"role": "user", "content": "analyze this"},
                },
            },
        },
        {
            name: "ollama without system prompt",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOllama,
            },
            model: "llama3",
            query: "analyze this",
            want: map[string]interface{}{
                "model":  "llama3",
                "prompt": "analyze this",
                "stream": false,
            },
        },
        {
            name: "ollama with system prompt",
            cfg: config.ConfigData{
                Protocol:     config.ProtocolOllama,
                SystemPrompt: "You are helpful",
            },
            model: "llama3",
            query: "analyze this",
            want: map[string]interface{}{
                "model":  "llama3",
                "prompt": "analyze this",
                "system": "You are helpful",
                "stream": false,
            },
        },
        {
            name: "empty query",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOpenAI,
            },
            model: "gpt-4",
            query: "",
            want: map[string]interface{}{
                "model": "gpt-4",
                "messages": []map[string]string{
                    {"role": "user", "content": ""},
                },
            },
        },
        {
            name: "multiline query",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOpenAI,
            },
            model: "gpt-4",
            query: "line1\nline2\nline3",
            want: map[string]interface{}{
                "model": "gpt-4",
                "messages": []map[string]string{
                    {"role": "user", "content": "line1\nline2\nline3"},
                },
            },
        },
        {
            name: "model name injection",
            cfg: config.ConfigData{
                Protocol: config.ProtocolOpenAI,
            },
            model: "custom-model-name",
            query: "test",
            want: map[string]interface{}{
                "model": "custom-model-name",
                "messages": []map[string]string{
                    {"role": "user", "content": "test"},
                },
            },
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := buildPayload(tt.cfg, tt.model, tt.query)
            assert.Equal(t, tt.want, got)
        })
    }
}

api/testdata/ollama_no_response.json (new vendored file, 3 lines)

@@ -0,0 +1,3 @@
{
  "model": "llama3"
}

api/testdata/ollama_success.json (new vendored file, 3 lines)

@@ -0,0 +1,3 @@
{
  "response": "This is the Ollama response."
}

api/testdata/openai_empty_choices.json (new vendored file, 3 lines)

@@ -0,0 +1,3 @@
{
  "choices": []
}

api/testdata/openai_success.json (new vendored file, 9 lines)

@@ -0,0 +1,9 @@
{
  "choices": [
    {
      "message": {
        "content": "This is the response text."
      }
    }
  ]
}