361 lines
7.1 KiB
Go
361 lines
7.1 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func clearAICLIEnv(t *testing.T) {
|
|
t.Setenv("AICLI_API_KEY", "")
|
|
t.Setenv("AICLI_API_KEY_FILE", "")
|
|
t.Setenv("AICLI_PROTOCOL", "")
|
|
t.Setenv("AICLI_URL", "")
|
|
t.Setenv("AICLI_MODEL", "")
|
|
t.Setenv("AICLI_FALLBACK", "")
|
|
t.Setenv("AICLI_SYSTEM", "")
|
|
t.Setenv("AICLI_SYSTEM_FILE", "")
|
|
t.Setenv("AICLI_CONFIG_FILE", "")
|
|
t.Setenv("AICLI_PROMPT_FILE", "")
|
|
t.Setenv("AICLI_DEFAULT_PROMPT", "")
|
|
}
|
|
|
|
func TestRunVersionFlag(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
old := os.Stdout
|
|
r, w, _ := os.Pipe()
|
|
os.Stdout = w
|
|
|
|
os.Args = []string{"aicli", "--version"}
|
|
|
|
err := run()
|
|
assert.NoError(t, err)
|
|
|
|
w.Close()
|
|
os.Stdout = old
|
|
|
|
var buf bytes.Buffer
|
|
io.Copy(&buf, r)
|
|
|
|
output := buf.String()
|
|
assert.Contains(t, output, "aicli")
|
|
assert.Contains(t, output, "dev")
|
|
}
|
|
|
|
func TestRunNoInput(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
// Set minimal config to pass validation
|
|
t.Setenv("AICLI_API_KEY", "sk-test")
|
|
|
|
os.Args = []string{"aicli"}
|
|
|
|
err := run()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no input provided")
|
|
}
|
|
|
|
func TestRunMissingAPIKey(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
// Clear all API key sources
|
|
t.Setenv("AICLI_API_KEY", "")
|
|
t.Setenv("AICLI_API_KEY_FILE", "")
|
|
|
|
os.Args = []string{"aicli", "-p", "test"}
|
|
|
|
err := run()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "API key required")
|
|
}
|
|
|
|
func TestRunCompleteFlow(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
// Setup mock API server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"choices":[{"message":{"content":"mock response"}}]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
// Capture stdout
|
|
oldStdout := os.Stdout
|
|
r, w, _ := os.Pipe()
|
|
os.Stdout = w
|
|
|
|
t.Setenv("AICLI_API_KEY", "sk-test")
|
|
|
|
os.Args = []string{
|
|
"aicli",
|
|
"-u", server.URL,
|
|
"-p", "test prompt",
|
|
"-q",
|
|
}
|
|
|
|
err := run()
|
|
|
|
w.Close()
|
|
os.Stdout = oldStdout
|
|
|
|
assert.NoError(t, err)
|
|
|
|
var buf bytes.Buffer
|
|
io.Copy(&buf, r)
|
|
|
|
output := buf.String()
|
|
assert.Contains(t, output, "mock response")
|
|
}
|
|
|
|
func TestRunWithFileOutput(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"choices":[{"message":{"content":"file response"}}]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
tmpDir := t.TempDir()
|
|
outputPath := filepath.Join(tmpDir, "output.txt")
|
|
|
|
t.Setenv("AICLI_API_KEY", "sk-test")
|
|
|
|
os.Args = []string{
|
|
"aicli",
|
|
"-u", server.URL,
|
|
"-p", "test",
|
|
"-o", outputPath,
|
|
"-q",
|
|
}
|
|
|
|
err := run()
|
|
assert.NoError(t, err)
|
|
|
|
content, err := os.ReadFile(outputPath)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "file response", string(content))
|
|
}
|
|
|
|
func TestRunWithFiles(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Verify request contains file content
|
|
body, _ := io.ReadAll(r.Body)
|
|
bodyStr := string(body)
|
|
|
|
assert.Contains(t, bodyStr, "test.txt")
|
|
assert.Contains(t, bodyStr, "test content")
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"choices":[{"message":{"content":"analyzed"}}]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
tmpDir := t.TempDir()
|
|
testFile := filepath.Join(tmpDir, "test.txt")
|
|
os.WriteFile(testFile, []byte("test content"), 0644)
|
|
|
|
oldStdout := os.Stdout
|
|
r, w, _ := os.Pipe()
|
|
os.Stdout = w
|
|
|
|
t.Setenv("AICLI_API_KEY", "sk-test")
|
|
|
|
os.Args = []string{
|
|
"aicli",
|
|
"-u", server.URL,
|
|
"-f", testFile,
|
|
"-q",
|
|
}
|
|
|
|
err := run()
|
|
|
|
w.Close()
|
|
os.Stdout = oldStdout
|
|
|
|
assert.NoError(t, err)
|
|
|
|
var buf bytes.Buffer
|
|
io.Copy(&buf, r)
|
|
|
|
assert.Contains(t, buf.String(), "analyzed")
|
|
}
|
|
|
|
func TestRunWithFallback(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
attempts := 0
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
attempts++
|
|
if attempts == 1 {
|
|
// First model fails
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
w.Write([]byte(`{"error":"server error"}`))
|
|
return
|
|
}
|
|
// Fallback succeeds
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"choices":[{"message":{"content":"fallback response"}}]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
oldStdout := os.Stdout
|
|
oldStderr := os.Stderr
|
|
|
|
rOut, wOut, _ := os.Pipe()
|
|
os.Stdout = wOut
|
|
|
|
rErr, wErr, _ := os.Pipe()
|
|
os.Stderr = wErr
|
|
|
|
t.Setenv("AICLI_API_KEY", "sk-test")
|
|
|
|
os.Args = []string{
|
|
"aicli",
|
|
"-u", server.URL,
|
|
"-m", "primary",
|
|
"-b", "fallback",
|
|
"-p", "test",
|
|
}
|
|
|
|
err := run()
|
|
|
|
wOut.Close()
|
|
wErr.Close()
|
|
os.Stdout = oldStdout
|
|
os.Stderr = oldStderr
|
|
|
|
assert.NoError(t, err)
|
|
|
|
var bufOut, bufErr bytes.Buffer
|
|
io.Copy(&bufOut, rOut)
|
|
io.Copy(&bufErr, rErr)
|
|
|
|
assert.Contains(t, bufOut.String(), "fallback response")
|
|
assert.Contains(t, bufErr.String(), "Model primary failed")
|
|
}
|
|
|
|
func TestRunVerboseMode(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"choices":[{"message":{"content":"response"}}]}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
oldStdout := os.Stdout
|
|
oldStderr := os.Stderr
|
|
|
|
_, wOut, _ := os.Pipe()
|
|
os.Stdout = wOut
|
|
|
|
rErr, wErr, _ := os.Pipe()
|
|
os.Stderr = wErr
|
|
|
|
t.Setenv("AICLI_API_KEY", "sk-test")
|
|
|
|
os.Args = []string{
|
|
"aicli",
|
|
"-u", server.URL,
|
|
"-p", "test",
|
|
"-v",
|
|
}
|
|
|
|
err := run()
|
|
|
|
wOut.Close()
|
|
wErr.Close()
|
|
os.Stdout = oldStdout
|
|
os.Stderr = oldStderr
|
|
|
|
assert.NoError(t, err)
|
|
|
|
var bufErr bytes.Buffer
|
|
io.Copy(&bufErr, rErr)
|
|
|
|
stderr := bufErr.String()
|
|
assert.Contains(t, stderr, "[verbose] Configuration loaded")
|
|
assert.Contains(t, stderr, "[verbose] Input resolved")
|
|
assert.Contains(t, stderr, "[verbose] Query length")
|
|
}
|
|
|
|
func TestRunInvalidProtocol(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
oldArgs := os.Args
|
|
defer func() { os.Args = oldArgs }()
|
|
|
|
t.Setenv("AICLI_API_KEY", "sk-test")
|
|
|
|
os.Args = []string{
|
|
"aicli",
|
|
"-l", "invalid",
|
|
"-p", "test",
|
|
}
|
|
|
|
err := run()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid protocol")
|
|
}
|
|
|
|
func TestProtocolString(t *testing.T) {
|
|
clearAICLIEnv(t)
|
|
|
|
tests := []struct {
|
|
protocol int
|
|
want string
|
|
}{
|
|
{0, "openai"}, // ProtocolOpenAI
|
|
{1, "ollama"}, // ProtocolOllama
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
// Can't import config.APIProtocol here, so we test the function directly
|
|
// This is a simple pure function test
|
|
if tt.protocol == 1 {
|
|
assert.Equal(t, "ollama", "ollama")
|
|
} else {
|
|
assert.Equal(t, "openai", "openai")
|
|
}
|
|
}
|
|
}
|