Refactored, added comprehensive testing.
All checks were successful
Release / release (push) Successful in 3m17s
All checks were successful
Release / release (push) Successful in 3m17s
This commit is contained in:
62
api/api.go
Normal file
62
api/api.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/aicli/config"
|
||||
)
|
||||
|
||||
// tryModel attempts a single model request through the complete pipeline:
|
||||
// payload construction, HTTP execution, and response parsing.
|
||||
func tryModel(cfg config.ConfigData, model string, query string) (string, error) {
|
||||
payload := buildPayload(cfg, model, query)
|
||||
|
||||
if cfg.Verbose {
|
||||
payloadJSON, _ := json.Marshal(payload)
|
||||
fmt.Fprintf(os.Stderr, "[verbose] Request payload: %s\n", string(payloadJSON))
|
||||
}
|
||||
|
||||
body, err := executeHTTP(cfg, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if cfg.Verbose {
|
||||
fmt.Fprintf(os.Stderr, "[verbose] Response: %s\n", string(body))
|
||||
}
|
||||
|
||||
response, err := parseResponse(body, cfg.Protocol)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// SendChatRequest sends a query to the configured model with automatic fallback.
|
||||
// Returns the response content, the model name that succeeded, total duration, and any error.
|
||||
// On failure, attempts each fallback model in sequence until one succeeds or all fail.
|
||||
func SendChatRequest(cfg config.ConfigData, query string) (string, string, time.Duration, error) {
|
||||
models := append([]string{cfg.Model}, cfg.FallbackModels...)
|
||||
start := time.Now()
|
||||
|
||||
for i, model := range models {
|
||||
if !cfg.Quiet && i > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Model %s failed, trying %s...\n", models[i-1], model)
|
||||
}
|
||||
|
||||
response, err := tryModel(cfg, model, query)
|
||||
if err == nil {
|
||||
return response, model, time.Since(start), nil
|
||||
}
|
||||
|
||||
if !cfg.Quiet {
|
||||
fmt.Fprintf(os.Stderr, "Model %s failed: %v\n", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", time.Since(start), fmt.Errorf("all models failed")
|
||||
}
|
||||
263
api/api_test.go
Normal file
263
api/api_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/aicli/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type sequenceTransport struct {
|
||||
responses []*http.Response
|
||||
index int
|
||||
}
|
||||
|
||||
func (t *sequenceTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
if t.index >= len(t.responses) {
|
||||
return nil, fmt.Errorf("no more responses in sequence")
|
||||
}
|
||||
resp := t.responses[t.index]
|
||||
t.index++
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func TestTryModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ConfigData
|
||||
model string
|
||||
query string
|
||||
mockResp *http.Response
|
||||
mockErr error
|
||||
want string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "successful request",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
URL: "https://api.example.com",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
model: "gpt-4",
|
||||
query: "test query",
|
||||
mockResp: makeResponse(200, `{"choices":[{"message":{"content":"response text"}}]}`),
|
||||
want: "response text",
|
||||
},
|
||||
{
|
||||
name: "http error",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
URL: "https://api.example.com",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
model: "gpt-4",
|
||||
query: "test query",
|
||||
mockResp: makeResponse(500, `{"error":"server error"}`),
|
||||
wantErr: true,
|
||||
errContains: "HTTP 500",
|
||||
},
|
||||
{
|
||||
name: "parse error",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
URL: "https://api.example.com",
|
||||
APIKey: "sk-test",
|
||||
},
|
||||
model: "gpt-4",
|
||||
query: "test query",
|
||||
mockResp: makeResponse(200, `{"choices":[]}`),
|
||||
wantErr: true,
|
||||
errContains: "empty choices array",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
transport := &mockRoundTripper{
|
||||
response: tt.mockResp,
|
||||
err: tt.mockErr,
|
||||
}
|
||||
oldClient := httpClient
|
||||
httpClient = &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: transport,
|
||||
}
|
||||
defer func() { httpClient = oldClient }()
|
||||
|
||||
got, err := tryModel(tt.cfg, tt.model, tt.query)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func captureStderr(f func()) string {
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
f()
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestSendChatRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ConfigData
|
||||
query string
|
||||
mockResp []*http.Response
|
||||
wantResponse string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
errContains string
|
||||
checkStderr func(*testing.T, string)
|
||||
}{
|
||||
{
|
||||
name: "primary model succeeds",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
URL: "https://api.example.com",
|
||||
APIKey: "sk-test",
|
||||
Model: "gpt-4",
|
||||
FallbackModels: []string{"gpt-3.5"},
|
||||
},
|
||||
query: "test",
|
||||
mockResp: []*http.Response{
|
||||
makeResponse(200, `{"choices":[{"message":{"content":"primary response"}}]}`),
|
||||
},
|
||||
wantResponse: "primary response",
|
||||
wantModel: "gpt-4",
|
||||
},
|
||||
{
|
||||
name: "primary fails, fallback succeeds",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
URL: "https://api.example.com",
|
||||
APIKey: "sk-test",
|
||||
Model: "gpt-4",
|
||||
FallbackModels: []string{"gpt-3.5"},
|
||||
},
|
||||
query: "test",
|
||||
mockResp: []*http.Response{
|
||||
makeResponse(500, `{"error":"server error"}`),
|
||||
makeResponse(200, `{"choices":[{"message":{"content":"fallback response"}}]}`),
|
||||
},
|
||||
wantResponse: "fallback response",
|
||||
wantModel: "gpt-3.5",
|
||||
checkStderr: func(t *testing.T, stderr string) {
|
||||
assert.Contains(t, stderr, "Model gpt-4 failed")
|
||||
assert.Contains(t, stderr, "trying gpt-3.5")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all models fail",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
URL: "https://api.example.com",
|
||||
APIKey: "sk-test",
|
||||
Model: "gpt-4",
|
||||
FallbackModels: []string{"gpt-3.5"},
|
||||
},
|
||||
query: "test",
|
||||
mockResp: []*http.Response{
|
||||
makeResponse(500, `{"error":"error1"}`),
|
||||
makeResponse(500, `{"error":"error2"}`),
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "all models failed",
|
||||
checkStderr: func(t *testing.T, stderr string) {
|
||||
assert.Contains(t, stderr, "Model gpt-4 failed")
|
||||
assert.Contains(t, stderr, "Model gpt-3.5 failed")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "quiet mode suppresses progress",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
URL: "https://api.example.com",
|
||||
APIKey: "sk-test",
|
||||
Model: "gpt-4",
|
||||
FallbackModels: []string{"gpt-3.5"},
|
||||
Quiet: true,
|
||||
},
|
||||
query: "test",
|
||||
mockResp: []*http.Response{
|
||||
makeResponse(500, `{"error":"error1"}`),
|
||||
makeResponse(200, `{"choices":[{"message":{"content":"response"}}]}`),
|
||||
},
|
||||
wantResponse: "response",
|
||||
wantModel: "gpt-3.5",
|
||||
checkStderr: func(t *testing.T, stderr string) {
|
||||
assert.Empty(t, stderr)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
transport := &sequenceTransport{responses: tt.mockResp}
|
||||
|
||||
oldClient := httpClient
|
||||
httpClient = &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: transport,
|
||||
}
|
||||
defer func() { httpClient = oldClient }()
|
||||
|
||||
var stderr string
|
||||
var response string
|
||||
var model string
|
||||
var duration time.Duration
|
||||
var err error
|
||||
|
||||
if tt.checkStderr != nil {
|
||||
stderr = captureStderr(func() {
|
||||
response, model, duration, err = SendChatRequest(tt.cfg, tt.query)
|
||||
})
|
||||
} else {
|
||||
response, model, duration, err = SendChatRequest(tt.cfg, tt.query)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
if tt.checkStderr != nil {
|
||||
tt.checkStderr(t, stderr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResponse, response)
|
||||
assert.Equal(t, tt.wantModel, model)
|
||||
assert.Greater(t, duration, time.Duration(0))
|
||||
|
||||
if tt.checkStderr != nil {
|
||||
tt.checkStderr(t, stderr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
47
api/http.go
Normal file
47
api/http.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/aicli/config"
|
||||
)
|
||||
|
||||
var httpClient = &http.Client{Timeout: 5 * time.Minute}
|
||||
|
||||
// executeHTTP sends the payload to the API endpoint and returns the response body.
|
||||
func executeHTTP(cfg config.ConfigData, payload map[string]interface{}) ([]byte, error) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", cfg.URL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.APIKey))
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return respBody, nil
|
||||
}
|
||||
193
api/http_test.go
Normal file
193
api/http_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wisehodl.dev/jay/aicli/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockRoundTripper struct {
|
||||
response *http.Response
|
||||
err error
|
||||
request *http.Request
|
||||
}
|
||||
|
||||
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
m.request = req
|
||||
return m.response, m.err
|
||||
}
|
||||
|
||||
func makeResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ConfigData
|
||||
payload map[string]interface{}
|
||||
mockResp *http.Response
|
||||
mockErr error
|
||||
wantBody string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "successful request",
|
||||
cfg: config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "sk-test123",
|
||||
},
|
||||
payload: map[string]interface{}{
|
||||
"model": "gpt-4",
|
||||
},
|
||||
mockResp: makeResponse(200, `{"choices":[{"message":{"content":"response"}}]}`),
|
||||
wantBody: `{"choices":[{"message":{"content":"response"}}]}`,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP 400 error",
|
||||
cfg: config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "sk-test123",
|
||||
},
|
||||
payload: map[string]interface{}{"model": "gpt-4"},
|
||||
mockResp: makeResponse(400, `{"error":"bad request"}`),
|
||||
wantErr: true,
|
||||
errContains: "HTTP 400",
|
||||
},
|
||||
{
|
||||
name: "HTTP 401 unauthorized",
|
||||
cfg: config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "invalid-key",
|
||||
},
|
||||
payload: map[string]interface{}{"model": "gpt-4"},
|
||||
mockResp: makeResponse(401, `{"error":"unauthorized"}`),
|
||||
wantErr: true,
|
||||
errContains: "HTTP 401",
|
||||
},
|
||||
{
|
||||
name: "HTTP 429 rate limit",
|
||||
cfg: config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "sk-test123",
|
||||
},
|
||||
payload: map[string]interface{}{"model": "gpt-4"},
|
||||
mockResp: makeResponse(429, `{"error":"rate limit exceeded"}`),
|
||||
wantErr: true,
|
||||
errContains: "HTTP 429",
|
||||
},
|
||||
{
|
||||
name: "HTTP 500 server error",
|
||||
cfg: config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "sk-test123",
|
||||
},
|
||||
payload: map[string]interface{}{"model": "gpt-4"},
|
||||
mockResp: makeResponse(500, `{"error":"internal server error"}`),
|
||||
wantErr: true,
|
||||
errContains: "HTTP 500",
|
||||
},
|
||||
{
|
||||
name: "network error",
|
||||
cfg: config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "sk-test123",
|
||||
},
|
||||
payload: map[string]interface{}{"model": "gpt-4"},
|
||||
mockErr: http.ErrHandlerTimeout,
|
||||
wantErr: true,
|
||||
errContains: "execute request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
transport := &mockRoundTripper{
|
||||
response: tt.mockResp,
|
||||
err: tt.mockErr,
|
||||
}
|
||||
oldClient := httpClient
|
||||
httpClient = &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: transport,
|
||||
}
|
||||
defer func() { httpClient = oldClient }()
|
||||
|
||||
got, err := executeHTTP(tt.cfg, tt.payload)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantBody, string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHTTPHeaders(t *testing.T) {
|
||||
cfg := config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "sk-test-key",
|
||||
}
|
||||
payload := map[string]interface{}{"model": "gpt-4"}
|
||||
|
||||
transport := &mockRoundTripper{
|
||||
response: makeResponse(200, `{"result":"ok"}`),
|
||||
}
|
||||
oldClient := httpClient
|
||||
httpClient = &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: transport,
|
||||
}
|
||||
defer func() { httpClient = oldClient }()
|
||||
|
||||
_, err := executeHTTP(cfg, payload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "application/json", transport.request.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "Bearer sk-test-key", transport.request.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func TestExecuteHTTPTimeout(t *testing.T) {
|
||||
cfg := config.ConfigData{
|
||||
URL: "https://api.example.com/chat",
|
||||
APIKey: "sk-test123",
|
||||
}
|
||||
payload := map[string]interface{}{"model": "gpt-4"}
|
||||
|
||||
transport := &mockRoundTripper{
|
||||
response: makeResponse(200, `{"ok":true}`),
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, _ := http.NewRequest("POST", cfg.URL, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
51
api/parse.go
Normal file
51
api/parse.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"git.wisehodl.dev/jay/aicli/config"
|
||||
)
|
||||
|
||||
// parseResponse extracts the response content from the API response body.
|
||||
func parseResponse(body []byte, protocol config.APIProtocol) (string, error) {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if protocol == config.ProtocolOllama {
|
||||
response, ok := result["response"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no response field in ollama response")
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// OpenAI protocol
|
||||
choices, ok := result["choices"].([]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
if len(choices) == 0 {
|
||||
return "", fmt.Errorf("empty choices array")
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid choice format")
|
||||
}
|
||||
|
||||
message, ok := firstChoice["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no message in choice")
|
||||
}
|
||||
|
||||
content, ok := message["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no content in message")
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
159
api/parse_test.go
Normal file
159
api/parse_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.wisehodl.dev/jay/aicli/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
protocol config.APIProtocol
|
||||
want string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "openai success",
|
||||
body: `{"choices":[{"message":{"content":"This is the response text."}}]}`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
want: "This is the response text.",
|
||||
},
|
||||
{
|
||||
name: "openai empty choices",
|
||||
body: `{"choices":[]}`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
wantErr: true,
|
||||
errContains: "empty choices array",
|
||||
},
|
||||
{
|
||||
name: "openai no choices field",
|
||||
body: `{"result":"ok"}`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
wantErr: true,
|
||||
errContains: "no choices in response",
|
||||
},
|
||||
{
|
||||
name: "openai invalid choice format",
|
||||
body: `{"choices":["invalid"]}`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
wantErr: true,
|
||||
errContains: "invalid choice format",
|
||||
},
|
||||
{
|
||||
name: "openai no message field",
|
||||
body: `{"choices":[{"text":"wrong structure"}]}`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
wantErr: true,
|
||||
errContains: "no message in choice",
|
||||
},
|
||||
{
|
||||
name: "openai no content field",
|
||||
body: `{"choices":[{"message":{"role":"assistant"}}]}`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
wantErr: true,
|
||||
errContains: "no content in message",
|
||||
},
|
||||
{
|
||||
name: "ollama success",
|
||||
body: `{"response":"This is the Ollama response."}`,
|
||||
protocol: config.ProtocolOllama,
|
||||
want: "This is the Ollama response.",
|
||||
},
|
||||
{
|
||||
name: "ollama no response field",
|
||||
body: `{"model":"llama3"}`,
|
||||
protocol: config.ProtocolOllama,
|
||||
wantErr: true,
|
||||
errContains: "no response field in ollama response",
|
||||
},
|
||||
{
|
||||
name: "malformed json",
|
||||
body: `{invalid json`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
wantErr: true,
|
||||
errContains: "parse response",
|
||||
},
|
||||
{
|
||||
name: "openai multiline content",
|
||||
body: `{"choices":[{"message":{"content":"Line 1\nLine 2\nLine 3"}}]}`,
|
||||
protocol: config.ProtocolOpenAI,
|
||||
want: "Line 1\nLine 2\nLine 3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseResponse([]byte(tt.body), tt.protocol)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponseWithTestdata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
file string
|
||||
protocol config.APIProtocol
|
||||
want string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "openai success from file",
|
||||
file: "testdata/openai_success.json",
|
||||
protocol: config.ProtocolOpenAI,
|
||||
want: "This is the response text.",
|
||||
},
|
||||
{
|
||||
name: "openai empty choices from file",
|
||||
file: "testdata/openai_empty_choices.json",
|
||||
protocol: config.ProtocolOpenAI,
|
||||
wantErr: true,
|
||||
errContains: "empty choices array",
|
||||
},
|
||||
{
|
||||
name: "ollama success from file",
|
||||
file: "testdata/ollama_success.json",
|
||||
protocol: config.ProtocolOllama,
|
||||
want: "This is the Ollama response.",
|
||||
},
|
||||
{
|
||||
name: "ollama no response from file",
|
||||
file: "testdata/ollama_no_response.json",
|
||||
protocol: config.ProtocolOllama,
|
||||
wantErr: true,
|
||||
errContains: "no response field",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, err := os.ReadFile(tt.file)
|
||||
assert.NoError(t, err, "failed to read test file")
|
||||
|
||||
got, err := parseResponse(body, tt.protocol)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
36
api/payload.go
Normal file
36
api/payload.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package api
|
||||
|
||||
import "git.wisehodl.dev/jay/aicli/config"
|
||||
|
||||
// buildPayload constructs the JSON payload for the API request based on protocol.
|
||||
func buildPayload(cfg config.ConfigData, model string, query string) map[string]interface{} {
|
||||
if cfg.Protocol == config.ProtocolOllama {
|
||||
payload := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": query,
|
||||
"stream": false,
|
||||
}
|
||||
if cfg.SystemPrompt != "" {
|
||||
payload["system"] = cfg.SystemPrompt
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// OpenAI protocol
|
||||
messages := []map[string]string{}
|
||||
if cfg.SystemPrompt != "" {
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "system",
|
||||
"content": cfg.SystemPrompt,
|
||||
})
|
||||
}
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "user",
|
||||
"content": query,
|
||||
})
|
||||
|
||||
return map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
}
|
||||
126
api/payload_test.go
Normal file
126
api/payload_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.wisehodl.dev/jay/aicli/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuildPayload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ConfigData
|
||||
model string
|
||||
query string
|
||||
want map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "openai without system prompt",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
},
|
||||
model: "gpt-4",
|
||||
query: "analyze this",
|
||||
want: map[string]interface{}{
|
||||
"model": "gpt-4",
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "analyze this"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "openai with system prompt",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
SystemPrompt: "You are helpful",
|
||||
},
|
||||
model: "gpt-4",
|
||||
query: "analyze this",
|
||||
want: map[string]interface{}{
|
||||
"model": "gpt-4",
|
||||
"messages": []map[string]string{
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "analyze this"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ollama without system prompt",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOllama,
|
||||
},
|
||||
model: "llama3",
|
||||
query: "analyze this",
|
||||
want: map[string]interface{}{
|
||||
"model": "llama3",
|
||||
"prompt": "analyze this",
|
||||
"stream": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ollama with system prompt",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOllama,
|
||||
SystemPrompt: "You are helpful",
|
||||
},
|
||||
model: "llama3",
|
||||
query: "analyze this",
|
||||
want: map[string]interface{}{
|
||||
"model": "llama3",
|
||||
"prompt": "analyze this",
|
||||
"system": "You are helpful",
|
||||
"stream": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty query",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
},
|
||||
model: "gpt-4",
|
||||
query: "",
|
||||
want: map[string]interface{}{
|
||||
"model": "gpt-4",
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": ""},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiline query",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
},
|
||||
model: "gpt-4",
|
||||
query: "line1\nline2\nline3",
|
||||
want: map[string]interface{}{
|
||||
"model": "gpt-4",
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "line1\nline2\nline3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model name injection",
|
||||
cfg: config.ConfigData{
|
||||
Protocol: config.ProtocolOpenAI,
|
||||
},
|
||||
model: "custom-model-name",
|
||||
query: "test",
|
||||
want: map[string]interface{}{
|
||||
"model": "custom-model-name",
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildPayload(tt.cfg, tt.model, tt.query)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
3
api/testdata/ollama_no_response.json
vendored
Normal file
3
api/testdata/ollama_no_response.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"model": "llama3"
|
||||
}
|
||||
3
api/testdata/ollama_success.json
vendored
Normal file
3
api/testdata/ollama_success.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"response": "This is the Ollama response."
|
||||
}
|
||||
3
api/testdata/openai_empty_choices.json
vendored
Normal file
3
api/testdata/openai_empty_choices.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"choices": []
|
||||
}
|
||||
9
api/testdata/openai_success.json
vendored
Normal file
9
api/testdata/openai_success.json
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "This is the response text."
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user