diff --git a/INSTALL.md b/INSTALL.md index 0b90a8c..dd4a935 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -38,7 +38,7 @@ sudo mv aicli-linux-amd64 /usr/local/bin/aicli # Verify installation aicli --version -``` +```` ### Windows Installation @@ -46,9 +46,10 @@ aicli --version 2. Rename the executable to `aicli.exe` if desired 3. Add the directory to your PATH or move the executable to a directory in your PATH 4. Open Command Prompt or PowerShell and verify the installation: - ``` - aicli --version - ``` + +``` +aicli --version +``` ## Configuration @@ -69,14 +70,17 @@ export AICLI_API_KEY_FILE=~/.aicli_key ### Configuration File -Create a configuration file at `~/.aicli.yaml` or use the sample config provided in the release: +Create a configuration file at `~/.aicli.yaml` or use the sample config: ```bash -# Download the sample config -curl -LO https://git.wisehodl.dev/jay/aicli/raw/branch/main/sample-config.yml - -# Copy to your home directory -cp sample-config.yml ~/.aicli.yaml +# Create config file +cat > ~/.aicli.yaml << 'EOF' +protocol: openai +url: https://api.ppq.ai/chat/completions +key_file: ~/.aicli_key +model: gpt-4o-mini +fallback: gpt-4.1-mini,o3 +EOF # Edit with your preferred editor nano ~/.aicli.yaml @@ -114,4 +118,4 @@ sudo mv aicli /usr/local/bin/ ## Next Steps -See the [README.md](README.md) for usage instructions and examples. +See the [README.md](https://claude.ai/chat/README.md) for usage instructions and examples. diff --git a/README.md b/README.md index 179dd5c..7f31bc9 100644 --- a/README.md +++ b/README.md @@ -22,18 +22,20 @@ AICLI provides a streamlined way to interact with language models from your term ### Pre-built Binaries -Download the latest binary for your platform from the [Releases](https://git.wisehodl.dev/jay/aicli/releases) page: +Download the latest binary for your platform from the [Releases](https://git.wisehodl.dev/jay/aicli/releases) page. Make the file executable (Linux/macOS): ```bash chmod +x aicli-linux-amd64 mv aicli-linux-amd64 /usr/local/bin/aicli # or any directory in your PATH -``` +```` + +See [INSTALL.md](https://claude.ai/chat/INSTALL.md) for detailed installation instructions. ### Building from Source -Requires Go 1.16+: +Requires Go 1.23+: ```bash git clone https://git.wisehodl.dev/jay/aicli.git @@ -73,7 +75,7 @@ export AICLI_API_KEY_FILE=~/.aicli_key export AICLI_API_KEY="your-api-key" export AICLI_API_KEY_FILE="~/.aicli_key" export AICLI_PROTOCOL="openai" # or "ollama" -export AICLI_URL="https://api.ppq.ai/chat/completions" # custom endpoint +export AICLI_URL="https://api.ppq.ai/chat/completions" # Model Selection export AICLI_MODEL="gpt-4o-mini" @@ -242,38 +244,36 @@ grep ERROR /var/log/app.log | aicli -p "Identify patterns in these error logs" ## Full Command Reference ``` -Usage: aicli [OPTION]... [FILE]... -Send files and prompts to LLM chat endpoints. - -With no FILE, or when FILE is -, read standard input. +Usage: aicli [OPTION]... +Send prompts and files to LLM chat endpoints. Global: - --version display version information and exit + --version display version and exit Input: -f, --file PATH input file (repeatable) - -F, --stdin-file treat stdin as file contents - -p, --prompt TEXT prompt text (repeatable, can be combined with --prompt-file) - -pf, --prompt-file PATH prompt from file (combined with any --prompt flags) + -F, --stdin-file treat stdin as file content + -p, --prompt TEXT prompt text (repeatable) + -pf, --prompt-file PATH read prompt from file System: -s, --system TEXT system prompt text - -sf, --system-file PATH system prompt from file + -sf, --system-file PATH read system prompt from file API: - -l, --protocol PROTO API protocol: openai, ollama (default: openai) - -u, --url URL API endpoint (default: https://api.ppq.ai/chat/completions) - -k, --key KEY API key (if present, --key-file is ignored) - -kf, --key-file PATH API key from file (used only if --key is not provided) + -l, --protocol PROTO openai or ollama (default: openai) + -u, --url URL endpoint (default: https://api.ppq.ai/chat/completions) + -k, --key KEY API key + -kf, --key-file PATH read API key from file Models: -m, --model NAME primary model (default: gpt-4o-mini) - -b, --fallback NAMES comma-separated fallback models (default: gpt-4.1-mini) + -b, --fallback NAMES comma-separated fallback list (default: gpt-4.1-mini) Output: -o, --output PATH write to file instead of stdout - -q, --quiet suppress progress output - -v, --verbose enable debug logging + -q, --quiet suppress progress messages + -v, --verbose log debug information to stderr Config: -c, --config PATH YAML config file diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..97a1900 --- /dev/null +++ b/api/api.go @@ -0,0 +1,62 @@ +package api + +import ( + "encoding/json" + "fmt" + "os" + "time" + + "git.wisehodl.dev/jay/aicli/config" +) + +// tryModel attempts a single model request through the complete pipeline: +// payload construction, HTTP execution, and response parsing. +func tryModel(cfg config.ConfigData, model string, query string) (string, error) { + payload := buildPayload(cfg, model, query) + + if cfg.Verbose { + payloadJSON, _ := json.Marshal(payload) + fmt.Fprintf(os.Stderr, "[verbose] Request payload: %s\n", string(payloadJSON)) + } + + body, err := executeHTTP(cfg, payload) + if err != nil { + return "", err + } + + if cfg.Verbose { + fmt.Fprintf(os.Stderr, "[verbose] Response: %s\n", string(body)) + } + + response, err := parseResponse(body, cfg.Protocol) + if err != nil { + return "", err + } + + return response, nil +} + +// SendChatRequest sends a query to the configured model with automatic fallback. +// Returns the response content, the model name that succeeded, total duration, and any error. +// On failure, attempts each fallback model in sequence until one succeeds or all fail. +func SendChatRequest(cfg config.ConfigData, query string) (string, string, time.Duration, error) { + models := append([]string{cfg.Model}, cfg.FallbackModels...) + start := time.Now() + + for i, model := range models { + if !cfg.Quiet && i > 0 { + fmt.Fprintf(os.Stderr, "Model %s failed, trying %s...\n", models[i-1], model) + } + + response, err := tryModel(cfg, model, query) + if err == nil { + return response, model, time.Since(start), nil + } + + if !cfg.Quiet { + fmt.Fprintf(os.Stderr, "Model %s failed: %v\n", model, err) + } + } + + return "", "", time.Since(start), fmt.Errorf("all models failed") +} diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 0000000..c0a6e2b --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,263 @@ +package api + +import ( + "bytes" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +type sequenceTransport struct { + responses []*http.Response + index int +} + +func (t *sequenceTransport) RoundTrip(*http.Request) (*http.Response, error) { + if t.index >= len(t.responses) { + return nil, fmt.Errorf("no more responses in sequence") + } + resp := t.responses[t.index] + t.index++ + return resp, nil +} + +func TestTryModel(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + model string + query string + mockResp *http.Response + mockErr error + want string + wantErr bool + errContains string + }{ + { + name: "successful request", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + URL: "https://api.example.com", + APIKey: "sk-test", + }, + model: "gpt-4", + query: "test query", + mockResp: makeResponse(200, `{"choices":[{"message":{"content":"response text"}}]}`), + want: "response text", + }, + { + name: "http error", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + URL: "https://api.example.com", + APIKey: "sk-test", + }, + model: "gpt-4", + query: "test query", + mockResp: makeResponse(500, `{"error":"server error"}`), + wantErr: true, + errContains: "HTTP 500", + }, + { + name: "parse error", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + URL: "https://api.example.com", + APIKey: "sk-test", + }, + model: "gpt-4", + query: "test query", + mockResp: makeResponse(200, `{"choices":[]}`), + wantErr: true, + errContains: "empty choices array", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &mockRoundTripper{ + response: tt.mockResp, + err: tt.mockErr, + } + oldClient := httpClient + httpClient = &http.Client{ + Timeout: 5 * time.Minute, + Transport: transport, + } + defer func() { httpClient = oldClient }() + + got, err := tryModel(tt.cfg, tt.model, tt.query) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func captureStderr(f func()) string { + old := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + f() + + w.Close() + os.Stderr = old + + var buf bytes.Buffer + io.Copy(&buf, r) + return buf.String() +} + +func TestSendChatRequest(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + query string + mockResp []*http.Response + wantResponse string + wantModel string + wantErr bool + errContains string + checkStderr func(*testing.T, string) + }{ + { + name: "primary model succeeds", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + URL: "https://api.example.com", + APIKey: "sk-test", + Model: "gpt-4", + FallbackModels: []string{"gpt-3.5"}, + }, + query: "test", + mockResp: []*http.Response{ + makeResponse(200, `{"choices":[{"message":{"content":"primary response"}}]}`), + }, + wantResponse: "primary response", + wantModel: "gpt-4", + }, + { + name: "primary fails, fallback succeeds", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + URL: "https://api.example.com", + APIKey: "sk-test", + Model: "gpt-4", + FallbackModels: []string{"gpt-3.5"}, + }, + query: "test", + mockResp: []*http.Response{ + makeResponse(500, `{"error":"server error"}`), + makeResponse(200, `{"choices":[{"message":{"content":"fallback response"}}]}`), + }, + wantResponse: "fallback response", + wantModel: "gpt-3.5", + checkStderr: func(t *testing.T, stderr string) { + assert.Contains(t, stderr, "Model gpt-4 failed") + assert.Contains(t, stderr, "trying gpt-3.5") + }, + }, + { + name: "all models fail", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + URL: "https://api.example.com", + APIKey: "sk-test", + Model: "gpt-4", + FallbackModels: []string{"gpt-3.5"}, + }, + query: "test", + mockResp: []*http.Response{ + makeResponse(500, `{"error":"error1"}`), + makeResponse(500, `{"error":"error2"}`), + }, + wantErr: true, + errContains: "all models failed", + checkStderr: func(t *testing.T, stderr string) { + assert.Contains(t, stderr, "Model gpt-4 failed") + assert.Contains(t, stderr, "Model gpt-3.5 failed") + }, + }, + { + name: "quiet mode suppresses progress", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + URL: "https://api.example.com", + APIKey: "sk-test", + Model: "gpt-4", + FallbackModels: []string{"gpt-3.5"}, + Quiet: true, + }, + query: "test", + mockResp: []*http.Response{ + makeResponse(500, `{"error":"error1"}`), + makeResponse(200, `{"choices":[{"message":{"content":"response"}}]}`), + }, + wantResponse: "response", + wantModel: "gpt-3.5", + checkStderr: func(t *testing.T, stderr string) { + assert.Empty(t, stderr) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &sequenceTransport{responses: tt.mockResp} + + oldClient := httpClient + httpClient = &http.Client{ + Timeout: 5 * time.Minute, + Transport: transport, + } + defer func() { httpClient = oldClient }() + + var stderr string + var response string + var model string + var duration time.Duration + var err error + + if tt.checkStderr != nil { + stderr = captureStderr(func() { + response, model, duration, err = SendChatRequest(tt.cfg, tt.query) + }) + } else { + response, model, duration, err = SendChatRequest(tt.cfg, tt.query) + } + + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + if tt.checkStderr != nil { + tt.checkStderr(t, stderr) + } + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.wantResponse, response) + assert.Equal(t, tt.wantModel, model) + assert.Greater(t, duration, time.Duration(0)) + + if tt.checkStderr != nil { + tt.checkStderr(t, stderr) + } + }) + } +} diff --git a/api/http.go b/api/http.go new file mode 100644 index 0000000..4089e12 --- /dev/null +++ b/api/http.go @@ -0,0 +1,47 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "git.wisehodl.dev/jay/aicli/config" +) + +var httpClient = &http.Client{Timeout: 5 * time.Minute} + +// executeHTTP sends the payload to the API endpoint and returns the response body. +func executeHTTP(cfg config.ConfigData, payload map[string]interface{}) ([]byte, error) { + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshal payload: %w", err) + } + + req, err := http.NewRequest("POST", cfg.URL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.APIKey)) + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + return respBody, nil +} diff --git a/api/http_test.go b/api/http_test.go new file mode 100644 index 0000000..08537d6 --- /dev/null +++ b/api/http_test.go @@ -0,0 +1,193 @@ +package api + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + "time" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +type mockRoundTripper struct { + response *http.Response + err error + request *http.Request +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.request = req + return m.response, m.err +} + +func makeResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +} + +func TestExecuteHTTP(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + payload map[string]interface{} + mockResp *http.Response + mockErr error + wantBody string + wantErr bool + errContains string + }{ + { + name: "successful request", + cfg: config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "sk-test123", + }, + payload: map[string]interface{}{ + "model": "gpt-4", + }, + mockResp: makeResponse(200, `{"choices":[{"message":{"content":"response"}}]}`), + wantBody: `{"choices":[{"message":{"content":"response"}}]}`, + wantErr: false, + }, + { + name: "HTTP 400 error", + cfg: config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "sk-test123", + }, + payload: map[string]interface{}{"model": "gpt-4"}, + mockResp: makeResponse(400, `{"error":"bad request"}`), + wantErr: true, + errContains: "HTTP 400", + }, + { + name: "HTTP 401 unauthorized", + cfg: config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "invalid-key", + }, + payload: map[string]interface{}{"model": "gpt-4"}, + mockResp: makeResponse(401, `{"error":"unauthorized"}`), + wantErr: true, + errContains: "HTTP 401", + }, + { + name: "HTTP 429 rate limit", + cfg: config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "sk-test123", + }, + payload: map[string]interface{}{"model": "gpt-4"}, + mockResp: makeResponse(429, `{"error":"rate limit exceeded"}`), + wantErr: true, + errContains: "HTTP 429", + }, + { + name: "HTTP 500 server error", + cfg: config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "sk-test123", + }, + payload: map[string]interface{}{"model": "gpt-4"}, + mockResp: makeResponse(500, `{"error":"internal server error"}`), + wantErr: true, + errContains: "HTTP 500", + }, + { + name: "network error", + cfg: config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "sk-test123", + }, + payload: map[string]interface{}{"model": "gpt-4"}, + mockErr: http.ErrHandlerTimeout, + wantErr: true, + errContains: "execute request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &mockRoundTripper{ + response: tt.mockResp, + err: tt.mockErr, + } + oldClient := httpClient + httpClient = &http.Client{ + Timeout: 5 * time.Minute, + Transport: transport, + } + defer func() { httpClient = oldClient }() + + got, err := executeHTTP(tt.cfg, tt.payload) + + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.wantBody, string(got)) + }) + } +} + +func TestExecuteHTTPHeaders(t *testing.T) { + cfg := config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "sk-test-key", + } + payload := map[string]interface{}{"model": "gpt-4"} + + transport := &mockRoundTripper{ + response: makeResponse(200, `{"result":"ok"}`), + } + oldClient := httpClient + httpClient = &http.Client{ + Timeout: 5 * time.Minute, + Transport: transport, + } + defer func() { httpClient = oldClient }() + + _, err := executeHTTP(cfg, payload) + assert.NoError(t, err) + + assert.Equal(t, "application/json", transport.request.Header.Get("Content-Type")) + assert.Equal(t, "Bearer sk-test-key", transport.request.Header.Get("Authorization")) +} + +func TestExecuteHTTPTimeout(t *testing.T) { + cfg := config.ConfigData{ + URL: "https://api.example.com/chat", + APIKey: "sk-test123", + } + payload := map[string]interface{}{"model": "gpt-4"} + + transport := &mockRoundTripper{ + response: makeResponse(200, `{"ok":true}`), + } + + client := &http.Client{ + Timeout: 5 * time.Minute, + Transport: transport, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", cfg.URL, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+cfg.APIKey) + + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) +} diff --git a/api/parse.go b/api/parse.go new file mode 100644 index 0000000..5cd8671 --- /dev/null +++ b/api/parse.go @@ -0,0 +1,51 @@ +package api + +import ( + "encoding/json" + "fmt" + + "git.wisehodl.dev/jay/aicli/config" +) + +// parseResponse extracts the response content from the API response body. +func parseResponse(body []byte, protocol config.APIProtocol) (string, error) { + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("parse response: %w", err) + } + + if protocol == config.ProtocolOllama { + response, ok := result["response"].(string) + if !ok { + return "", fmt.Errorf("no response field in ollama response") + } + return response, nil + } + + // OpenAI protocol + choices, ok := result["choices"].([]interface{}) + if !ok { + return "", fmt.Errorf("no choices in response") + } + + if len(choices) == 0 { + return "", fmt.Errorf("empty choices array") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid choice format") + } + + message, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("no message in choice") + } + + content, ok := message["content"].(string) + if !ok { + return "", fmt.Errorf("no content in message") + } + + return content, nil +} diff --git a/api/parse_test.go b/api/parse_test.go new file mode 100644 index 0000000..9078bea --- /dev/null +++ b/api/parse_test.go @@ -0,0 +1,159 @@ +package api + +import ( + "os" + "testing" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +func TestParseResponse(t *testing.T) { + tests := []struct { + name string + body string + protocol config.APIProtocol + want string + wantErr bool + errContains string + }{ + { + name: "openai success", + body: `{"choices":[{"message":{"content":"This is the response text."}}]}`, + protocol: config.ProtocolOpenAI, + want: "This is the response text.", + }, + { + name: "openai empty choices", + body: `{"choices":[]}`, + protocol: config.ProtocolOpenAI, + wantErr: true, + errContains: "empty choices array", + }, + { + name: "openai no choices field", + body: `{"result":"ok"}`, + protocol: config.ProtocolOpenAI, + wantErr: true, + errContains: "no choices in response", + }, + { + name: "openai invalid choice format", + body: `{"choices":["invalid"]}`, + protocol: config.ProtocolOpenAI, + wantErr: true, + errContains: "invalid choice format", + }, + { + name: "openai no message field", + body: `{"choices":[{"text":"wrong structure"}]}`, + protocol: config.ProtocolOpenAI, + wantErr: true, + errContains: "no message in choice", + }, + { + name: "openai no content field", + body: `{"choices":[{"message":{"role":"assistant"}}]}`, + protocol: config.ProtocolOpenAI, + wantErr: true, + errContains: "no content in message", + }, + { + name: "ollama success", + body: `{"response":"This is the Ollama response."}`, + protocol: config.ProtocolOllama, + want: "This is the Ollama response.", + }, + { + name: "ollama no response field", + body: `{"model":"llama3"}`, + protocol: config.ProtocolOllama, + wantErr: true, + errContains: "no response field in ollama response", + }, + { + name: "malformed json", + body: `{invalid json`, + protocol: config.ProtocolOpenAI, + wantErr: true, + errContains: "parse response", + }, + { + name: "openai multiline content", + body: `{"choices":[{"message":{"content":"Line 1\nLine 2\nLine 3"}}]}`, + protocol: config.ProtocolOpenAI, + want: "Line 1\nLine 2\nLine 3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseResponse([]byte(tt.body), tt.protocol) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseResponseWithTestdata(t *testing.T) { + tests := []struct { + name string + file string + protocol config.APIProtocol + want string + wantErr bool + errContains string + }{ + { + name: "openai success from file", + file: "testdata/openai_success.json", + protocol: config.ProtocolOpenAI, + want: "This is the response text.", + }, + { + name: "openai empty choices from file", + file: "testdata/openai_empty_choices.json", + protocol: config.ProtocolOpenAI, + wantErr: true, + errContains: "empty choices array", + }, + { + name: "ollama success from file", + file: "testdata/ollama_success.json", + protocol: config.ProtocolOllama, + want: "This is the Ollama response.", + }, + { + name: "ollama no response from file", + file: "testdata/ollama_no_response.json", + protocol: config.ProtocolOllama, + wantErr: true, + errContains: "no response field", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, err := os.ReadFile(tt.file) + assert.NoError(t, err, "failed to read test file") + + got, err := parseResponse(body, tt.protocol) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/api/payload.go b/api/payload.go new file mode 100644 index 0000000..361ed66 --- /dev/null +++ b/api/payload.go @@ -0,0 +1,36 @@ +package api + +import "git.wisehodl.dev/jay/aicli/config" + +// buildPayload constructs the JSON payload for the API request based on protocol. +func buildPayload(cfg config.ConfigData, model string, query string) map[string]interface{} { + if cfg.Protocol == config.ProtocolOllama { + payload := map[string]interface{}{ + "model": model, + "prompt": query, + "stream": false, + } + if cfg.SystemPrompt != "" { + payload["system"] = cfg.SystemPrompt + } + return payload + } + + // OpenAI protocol + messages := []map[string]string{} + if cfg.SystemPrompt != "" { + messages = append(messages, map[string]string{ + "role": "system", + "content": cfg.SystemPrompt, + }) + } + messages = append(messages, map[string]string{ + "role": "user", + "content": query, + }) + + return map[string]interface{}{ + "model": model, + "messages": messages, + } +} diff --git a/api/payload_test.go b/api/payload_test.go new file mode 100644 index 0000000..b561a80 --- /dev/null +++ b/api/payload_test.go @@ -0,0 +1,126 @@ +package api + +import ( + "testing" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +func TestBuildPayload(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + model string + query string + want map[string]interface{} + }{ + { + name: "openai without system prompt", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + }, + model: "gpt-4", + query: "analyze this", + want: map[string]interface{}{ + "model": "gpt-4", + "messages": []map[string]string{ + {"role": "user", "content": "analyze this"}, + }, + }, + }, + { + name: "openai with system prompt", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + SystemPrompt: "You are helpful", + }, + model: "gpt-4", + query: "analyze this", + want: map[string]interface{}{ + "model": "gpt-4", + "messages": []map[string]string{ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "analyze this"}, + }, + }, + }, + { + name: "ollama without system prompt", + cfg: config.ConfigData{ + Protocol: config.ProtocolOllama, + }, + model: "llama3", + query: "analyze this", + want: map[string]interface{}{ + "model": "llama3", + "prompt": "analyze this", + "stream": false, + }, + }, + { + name: "ollama with system prompt", + cfg: config.ConfigData{ + Protocol: config.ProtocolOllama, + SystemPrompt: "You are helpful", + }, + model: "llama3", + query: "analyze this", + want: map[string]interface{}{ + "model": "llama3", + "prompt": "analyze this", + "system": "You are helpful", + "stream": false, + }, + }, + { + name: "empty query", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + }, + model: "gpt-4", + query: "", + want: map[string]interface{}{ + "model": "gpt-4", + "messages": []map[string]string{ + {"role": "user", "content": ""}, + }, + }, + }, + { + name: "multiline query", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + }, + model: "gpt-4", + query: "line1\nline2\nline3", + want: map[string]interface{}{ + "model": "gpt-4", + "messages": []map[string]string{ + {"role": "user", "content": "line1\nline2\nline3"}, + }, + }, + }, + { + name: "model name injection", + cfg: config.ConfigData{ + Protocol: config.ProtocolOpenAI, + }, + model: "custom-model-name", + query: "test", + want: map[string]interface{}{ + "model": "custom-model-name", + "messages": []map[string]string{ + {"role": "user", "content": "test"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildPayload(tt.cfg, tt.model, tt.query) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/api/testdata/ollama_no_response.json b/api/testdata/ollama_no_response.json new file mode 100644 index 0000000..a97b6d8 --- /dev/null +++ b/api/testdata/ollama_no_response.json @@ -0,0 +1,3 @@ +{ + "model": "llama3" +} diff --git a/api/testdata/ollama_success.json b/api/testdata/ollama_success.json new file mode 100644 index 0000000..e3417a3 --- /dev/null +++ b/api/testdata/ollama_success.json @@ -0,0 +1,3 @@ +{ + "response": "This is the Ollama response." +} diff --git a/api/testdata/openai_empty_choices.json b/api/testdata/openai_empty_choices.json new file mode 100644 index 0000000..b6c2984 --- /dev/null +++ b/api/testdata/openai_empty_choices.json @@ -0,0 +1,3 @@ +{ + "choices": [] +} diff --git a/api/testdata/openai_success.json b/api/testdata/openai_success.json new file mode 100644 index 0000000..7729acf --- /dev/null +++ b/api/testdata/openai_success.json @@ -0,0 +1,9 @@ +{ + "choices": [ + { + "message": { + "content": "This is the response text." + } + } + ] +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..94721ef --- /dev/null +++ b/config/config.go @@ -0,0 +1,141 @@ +package config + +import ( + "fmt" + "os" +) + +const UsageText = `Usage: aicli [OPTION]... +Send prompts and files to LLM chat endpoints. + +Global: + --version display version and exit + +Input: + -f, --file PATH input file (repeatable) + -F, --stdin-file treat stdin as file content + -p, --prompt TEXT prompt text (repeatable) + -pf, --prompt-file PATH read prompt from file + +System: + -s, --system TEXT system prompt text + -sf, --system-file PATH read system prompt from file + (error if both -s and -sf provided) + +API: + -l, --protocol PROTO openai or ollama (default: openai) + -u, --url URL endpoint (default: https://api.ppq.ai/chat/completions) + -k, --key KEY API key + -kf, --key-file PATH read API key from file + +Models: + -m, --model NAME primary model (default: gpt-4o-mini) + -b, --fallback NAMES comma-separated fallback list (default: gpt-4.1-mini) + +Output: + -o, --output PATH write to file (mode 0644) instead of stdout + -q, --quiet suppress progress messages + -v, --verbose log debug information to stderr + +Config: + -c, --config PATH YAML config file + +Environment Variables: + AICLI_API_KEY API key + AICLI_API_KEY_FILE path to API key file + AICLI_PROTOCOL API protocol + AICLI_URL endpoint URL + AICLI_MODEL primary model name + AICLI_FALLBACK comma-separated fallback models + AICLI_SYSTEM system prompt text + AICLI_SYSTEM_FILE path to system prompt file + AICLI_CONFIG_FILE path to config file + AICLI_PROMPT_FILE path to prompt file + AICLI_DEFAULT_PROMPT override default prompt + +Precedence Rules: + API key: --key > --key-file > AICLI_API_KEY > AICLI_API_KEY_FILE > config key_file + System: --system > --system-file > AICLI_SYSTEM > AICLI_SYSTEM_FILE > config system_file + Config file: --config > AICLI_CONFIG_FILE + All others: flags > environment > config file > defaults + +Stdin Behavior: + No flags: stdin becomes the prompt + With -p/-pf: stdin appends after explicit prompts + With -F: stdin becomes first file (path: "input") + +Examples: + echo "What is Rust?" | aicli + cat log.txt | aicli -F -p "Find errors in this log" + aicli -f main.go -p "Review this code" + aicli -c ~/.aicli.yaml -f src/main.go -f src/util.go -o analysis.md + aicli -p "Context:" -pf template.txt -p "Apply to finance sector" +` + +func printUsage() { + fmt.Fprint(os.Stderr, UsageText) +} + +// BuildConfig resolves configuration from all sources with precedence: +// flags > env > file > defaults +func BuildConfig(args []string) (ConfigData, error) { + flags, err := parseFlags(args) + if err != nil { + return ConfigData{}, fmt.Errorf("parse flags: %w", err) + } + + // Validate protocol strings before merge + if flags.protocol != "" && flags.protocol != "openai" && flags.protocol != "ollama" { + return ConfigData{}, fmt.Errorf("invalid protocol: must be openai or ollama, got: %s", flags.protocol) + } + + configPath := flags.config + if configPath == "" { + configPath = os.Getenv("AICLI_CONFIG_FILE") + } + + env := loadEnvironment() + + // Validate env protocol + if env.protocol != "" && env.protocol != "openai" && env.protocol != "ollama" { + return ConfigData{}, fmt.Errorf("invalid protocol: must be openai or ollama, got: %s", env.protocol) + } + + file, err := loadConfigFile(configPath) + if err != nil { + return ConfigData{}, fmt.Errorf("load config file: %w", err) + } + + // Validate file protocol + if file.protocol != "" && file.protocol != "openai" && file.protocol != "ollama" { + return ConfigData{}, fmt.Errorf("invalid protocol: must be openai or ollama, got: %s", file.protocol) + } + + cfg := mergeSources(flags, env, file) + + if err := validateConfig(cfg); err != nil { + return ConfigData{}, err + } + + return cfg, nil +} + +// IsVersionRequest checks if --version flag was passed +func IsVersionRequest(args []string) bool { + for _, arg := range args { + if arg == "--version" { + return true + } + } + return false +} + +// IsHelpRequest checks if -h or --help flag was passed +func IsHelpRequest(args []string) bool { + for _, arg := range args { + if arg == "-h" || arg == "--help" { + return true + } + } + return false +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..c95475b --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,77 @@ +package config + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestBuildConfig(t *testing.T) { + tests := []struct { + name string + args []string + env map[string]string + wantErr bool + check func(*testing.T, ConfigData) + }{ + { + name: "valid full config", + args: []string{"-k", "sk-test", "-m", "gpt-4"}, + check: func(t *testing.T, cfg ConfigData) { + assert.Equal(t, "sk-test", cfg.APIKey) + assert.Equal(t, "gpt-4", cfg.Model) + }, + }, + { + name: "config file from env", + args: []string{"-k", "sk-test"}, + env: map[string]string{"AICLI_CONFIG_FILE": "testdata/partial.yaml"}, + check: func(t *testing.T, cfg ConfigData) { + assert.Equal(t, "gpt-4", cfg.Model) + }, + }, + { + name: "missing api key", + args: []string{}, + wantErr: true, + }, + { + name: "invalid config file", + args: []string{"-c", "testdata/invalid.yaml", "-k", "test"}, + wantErr: true, + }, + { + name: "invalid protocol in flags", + args: []string{"-k", "sk-test", "-l", "invalid"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear all AICLI_* env vars + t.Setenv("AICLI_API_KEY", "") + t.Setenv("AICLI_API_KEY_FILE", "") + t.Setenv("AICLI_PROTOCOL", "") + t.Setenv("AICLI_URL", "") + t.Setenv("AICLI_MODEL", "") + t.Setenv("AICLI_FALLBACK", "") + t.Setenv("AICLI_SYSTEM", "") + t.Setenv("AICLI_CONFIG_FILE", "") + + // Apply test-specific env + for k, v := range tt.env { + t.Setenv(k, v) + } + + cfg, err := BuildConfig(tt.args) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + if tt.check != nil { + tt.check(t, cfg) + } + }) + } +} diff --git a/config/defaults.go b/config/defaults.go new file mode 100644 index 0000000..3b2408e --- /dev/null +++ b/config/defaults.go @@ -0,0 +1,11 @@ +package config + +var defaultConfig = ConfigData{ + StdinAsFile: false, + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + Quiet: false, + Verbose: false, +} diff --git a/config/env.go b/config/env.go new file mode 100644 index 0000000..5dfc33a --- /dev/null +++ b/config/env.go @@ -0,0 +1,34 @@ +package config + +import "os" +import "strings" + +func loadEnvironment() envValues { + ev := envValues{} + + if val := os.Getenv("AICLI_PROTOCOL"); val != "" { + ev.protocol = val + } + if val := os.Getenv("AICLI_URL"); val != "" { + ev.url = val + } + if val := os.Getenv("AICLI_API_KEY"); val != "" { + ev.key = val + } else if val := os.Getenv("AICLI_API_KEY_FILE"); val != "" { + content, err := os.ReadFile(val) + if err == nil { + ev.key = strings.TrimSpace(string(content)) + } + } + if val := os.Getenv("AICLI_MODEL"); val != "" { + ev.model = val + } + if val := os.Getenv("AICLI_FALLBACK"); val != "" { + ev.fallback = val + } + if val := os.Getenv("AICLI_SYSTEM"); val != "" { + ev.system = val + } + + return ev +} diff --git a/config/env_test.go b/config/env_test.go new file mode 100644 index 0000000..517cce2 --- /dev/null +++ b/config/env_test.go @@ -0,0 +1,134 @@ +package config + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestLoadEnvironment(t *testing.T) { + tests := []struct { + name string + env map[string]string + want envValues + }{ + { + name: "empty environment", + env: map[string]string{}, + want: envValues{}, + }, + { + name: "protocol only", + env: map[string]string{"AICLI_PROTOCOL": "ollama"}, + want: envValues{protocol: "ollama"}, + }, + { + name: "url only", + env: map[string]string{"AICLI_URL": "http://localhost:11434"}, + want: envValues{url: "http://localhost:11434"}, + }, + { + name: "api key direct", + env: map[string]string{"AICLI_API_KEY": "sk-test123"}, + want: envValues{key: "sk-test123"}, + }, + { + name: "model only", + env: map[string]string{"AICLI_MODEL": "llama3"}, + want: envValues{model: "llama3"}, + }, + { + name: "fallback only", + env: map[string]string{"AICLI_FALLBACK": "gpt-3.5,gpt-4"}, + want: envValues{fallback: "gpt-3.5,gpt-4"}, + }, + { + name: "system only", + env: map[string]string{"AICLI_SYSTEM": "You are helpful"}, + want: envValues{system: "You are helpful"}, + }, + { + name: "all variables set", + env: map[string]string{ + "AICLI_PROTOCOL": "openai", + "AICLI_URL": "https://api.openai.com/v1/chat/completions", + "AICLI_API_KEY": "sk-abc", + "AICLI_MODEL": "gpt-4", + "AICLI_FALLBACK": "gpt-3.5", + "AICLI_SYSTEM": "system prompt", + }, + want: envValues{ + protocol: "openai", + url: "https://api.openai.com/v1/chat/completions", + key: "sk-abc", + model: "gpt-4", + fallback: "gpt-3.5", + system: "system prompt", + }, + }, + { + name: "empty string values preserved", + env: map[string]string{"AICLI_SYSTEM": ""}, + want: envValues{system: ""}, + }, + { + name: "whitespace preserved", + env: map[string]string{"AICLI_SYSTEM": " spaces "}, + want: envValues{system: " spaces "}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set environment + for k, v := range tt.env { + t.Setenv(k, v) + } + + got := loadEnvironment() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestLoadEnvironmentKeyFile(t *testing.T) { + tests := []struct { + name string + env map[string]string + want envValues + }{ + { + name: "key file when no direct key", + env: map[string]string{"AICLI_API_KEY_FILE": "testdata/api.key"}, + want: envValues{key: "sk-test-key-123"}, + }, + { + name: "direct key overrides key file", + env: map[string]string{ + "AICLI_API_KEY": "sk-direct", + "AICLI_API_KEY_FILE": "testdata/api.key", + }, + want: envValues{key: "sk-direct"}, + }, + { + name: "key file not found", + env: map[string]string{"AICLI_API_KEY_FILE": "/nonexistent/key.txt"}, + want: envValues{}, + }, + { + name: "key file with whitespace trimmed", + env: map[string]string{"AICLI_API_KEY_FILE": "testdata/api_whitespace.key"}, + want: envValues{key: "sk-whitespace-key"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.env { + t.Setenv(k, v) + } + + got := loadEnvironment() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/config/file.go b/config/file.go new file mode 100644 index 0000000..0aa13d1 --- /dev/null +++ b/config/file.go @@ -0,0 +1,44 @@ +package config + +import ( + "gopkg.in/yaml.v3" + "os" +) + +func loadConfigFile(path string) (fileValues, error) { + if path == "" { + return fileValues{}, nil + } + + data, err := os.ReadFile(path) + if err != nil { + return fileValues{}, err + } + + var raw map[string]interface{} + if err := yaml.Unmarshal(data, &raw); err != nil { + return fileValues{}, err + } + + fv := fileValues{} + if v, ok := raw["protocol"].(string); ok { + fv.protocol = v + } + if v, ok := raw["url"].(string); ok { + fv.url = v + } + if v, ok := raw["key_file"].(string); ok { + fv.keyFile = v + } + if v, ok := raw["model"].(string); ok { + fv.model = v + } + if v, ok := raw["fallback"].(string); ok { + fv.fallback = v + } + if v, ok := raw["system_file"].(string); ok { + fv.systemFile = v + } + + return fv, nil +} diff --git a/config/file_test.go b/config/file_test.go new file mode 100644 index 0000000..ec52dc6 --- /dev/null +++ b/config/file_test.go @@ -0,0 +1,76 @@ +package config + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestLoadConfigFile(t *testing.T) { + tests := []struct { + name string + path string + want fileValues + wantErr bool + }{ + { + name: "empty path returns nil", + path: "", + want: fileValues{}, + }, + { + name: "valid config", + path: "testdata/valid.yaml", + want: fileValues{ + protocol: "ollama", + url: "http://localhost:11434/api/chat", + keyFile: "~/.aicli_key", + model: "llama3", + fallback: "llama2,mistral", + systemFile: "~/system.txt", + }, + }, + { + name: "partial config", + path: "testdata/partial.yaml", + want: fileValues{ + model: "gpt-4", + fallback: "gpt-3.5-turbo", + }, + }, + { + name: "empty file", + path: "testdata/empty.yaml", + want: fileValues{}, + }, + { + name: "file not found", + path: "testdata/nonexistent.yaml", + wantErr: true, + }, + { + name: "invalid yaml syntax", + path: "testdata/invalid.yaml", + wantErr: true, + }, + { + name: "unknown keys ignored", + path: "testdata/unknown_keys.yaml", + want: fileValues{ + protocol: "openai", + model: "gpt-4", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := loadConfigFile(tt.path) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/config/flags.go b/config/flags.go new file mode 100644 index 0000000..5a8617b --- /dev/null +++ b/config/flags.go @@ -0,0 +1,81 @@ +package config + +import "flag" + +type stringSlice []string + +func (s *stringSlice) String() string { + if s == nil { + return "" + } + return "" +} + +func (s *stringSlice) Set(value string) error { + *s = append(*s, value) + return nil +} + +func parseFlags(args []string) (flagValues, error) { + fv := flagValues{} + + fs := flag.NewFlagSet("aicli", flag.ContinueOnError) + fs.Usage = printUsage + + var files stringSlice + var prompts stringSlice + + // Input flags + fs.Var(&files, "f", "") + fs.Var(&files, "file", "") + fs.Var(&prompts, "p", "") + fs.Var(&prompts, "prompt", "") + fs.StringVar(&fv.promptFile, "pf", "", "") + fs.StringVar(&fv.promptFile, "prompt-file", "", "") + + // System flags + fs.StringVar(&fv.system, "s", "", "") + fs.StringVar(&fv.system, "system", "", "") + fs.StringVar(&fv.systemFile, "sf", "", "") + fs.StringVar(&fv.systemFile, "system-file", "", "") + + // API flags + fs.StringVar(&fv.key, "k", "", "") + fs.StringVar(&fv.key, "key", "", "") + fs.StringVar(&fv.keyFile, "kf", "", "") + fs.StringVar(&fv.keyFile, "key-file", "", "") + fs.StringVar(&fv.protocol, "l", "", "") + fs.StringVar(&fv.protocol, "protocol", "", "") + fs.StringVar(&fv.url, "u", "", "") + fs.StringVar(&fv.url, "url", "", "") + + // Model flags + fs.StringVar(&fv.model, "m", "", "") + fs.StringVar(&fv.model, "model", "", "") + fs.StringVar(&fv.fallback, "b", "", "") + fs.StringVar(&fv.fallback, "fallback", "", "") + + // Output flags + fs.StringVar(&fv.output, "o", "", "") + fs.StringVar(&fv.output, "output", "", "") + fs.StringVar(&fv.config, "c", "", "") + fs.StringVar(&fv.config, "config", "", "") + + // Boolean flags + fs.BoolVar(&fv.stdinFile, "F", false, "") + fs.BoolVar(&fv.stdinFile, "stdin-file", false, "") + fs.BoolVar(&fv.quiet, "q", false, "") + fs.BoolVar(&fv.quiet, "quiet", false, "") + fs.BoolVar(&fv.verbose, "v", false, "") + fs.BoolVar(&fv.verbose, "verbose", false, "") + fs.BoolVar(&fv.version, "version", false, "") + + if err := fs.Parse(args); err != nil { + return flagValues{}, err + } + + fv.files = files + fv.prompts = prompts + + return fv, nil +} diff --git a/config/flags_test.go b/config/flags_test.go new file mode 100644 index 0000000..583785c --- /dev/null +++ b/config/flags_test.go @@ -0,0 +1,268 @@ +package config + +import ( + "flag" + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +func resetFlags() { + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) +} + +func TestParseFlags(t *testing.T) { + tests := []struct { + name string + args []string + want flagValues + }{ + { + name: "empty args", + args: []string{}, + want: flagValues{}, + }, + { + name: "single file short flag", + args: []string{"-f", "main.go"}, + want: flagValues{files: []string{"main.go"}}, + }, + { + name: "single file long flag", + args: []string{"--file", "main.go"}, + want: flagValues{files: []string{"main.go"}}, + }, + { + name: "multiple files", + args: []string{"-f", "a.go", "-f", "b.go", "--file", "c.go"}, + want: flagValues{files: []string{"a.go", "b.go", "c.go"}}, + }, + { + name: "single prompt short flag", + args: []string{"-p", "analyze this"}, + want: flagValues{prompts: []string{"analyze this"}}, + }, + { + name: "single prompt long flag", + args: []string{"--prompt", "analyze this"}, + want: flagValues{prompts: []string{"analyze this"}}, + }, + { + name: "multiple prompts", + args: []string{"-p", "first", "-p", "second", "--prompt", "third"}, + want: flagValues{prompts: []string{"first", "second", "third"}}, + }, + { + name: "prompt file", + args: []string{"-pf", "prompt.txt"}, + want: flagValues{promptFile: "prompt.txt"}, + }, + { + name: "prompt file long", + args: []string{"--prompt-file", "prompt.txt"}, + want: flagValues{promptFile: "prompt.txt"}, + }, + { + name: "system short", + args: []string{"-s", "You are helpful"}, + want: flagValues{system: "You are helpful"}, + }, + { + name: "system long", + args: []string{"--system", "You are helpful"}, + want: flagValues{system: "You are helpful"}, + }, + { + name: "system file short", + args: []string{"-sf", "system.txt"}, + want: flagValues{systemFile: "system.txt"}, + }, + { + name: "system file long", + args: []string{"--system-file", "system.txt"}, + want: flagValues{systemFile: "system.txt"}, + }, + { + name: "key short", + args: []string{"-k", "sk-abc123"}, + want: flagValues{key: "sk-abc123"}, + }, + { + name: "key long", + args: []string{"--key", "sk-abc123"}, + want: flagValues{key: "sk-abc123"}, + }, + { + name: "key file short", + args: []string{"-kf", "api.key"}, + want: flagValues{keyFile: "api.key"}, + }, + { + name: "key file long", + args: []string{"--key-file", "api.key"}, + want: flagValues{keyFile: "api.key"}, + }, + { + name: "protocol short", + args: []string{"-l", "ollama"}, + want: flagValues{protocol: "ollama"}, + }, + { + name: "protocol long", + args: []string{"--protocol", "ollama"}, + want: flagValues{protocol: "ollama"}, + }, + { + name: "url short", + args: []string{"-u", "http://localhost:11434"}, + want: flagValues{url: "http://localhost:11434"}, + }, + { + name: "url long", + args: []string{"--url", "http://localhost:11434"}, + want: flagValues{url: "http://localhost:11434"}, + }, + { + name: "model short", + args: []string{"-m", "gpt-4"}, + want: flagValues{model: "gpt-4"}, + }, + { + name: "model long", + args: []string{"--model", "gpt-4"}, + want: flagValues{model: "gpt-4"}, + }, + { + name: "fallback short", + args: []string{"-b", "gpt-3.5-turbo"}, + want: flagValues{fallback: "gpt-3.5-turbo"}, + }, + { + name: "fallback long", + args: []string{"--fallback", "gpt-3.5-turbo"}, + want: flagValues{fallback: "gpt-3.5-turbo"}, + }, + { + name: "output short", + args: []string{"-o", "result.txt"}, + want: flagValues{output: "result.txt"}, + }, + { + name: "output long", + args: []string{"--output", "result.txt"}, + want: flagValues{output: "result.txt"}, + }, + { + name: "config short", + args: []string{"-c", "config.yaml"}, + want: flagValues{config: "config.yaml"}, + }, + { + name: "config long", + args: []string{"--config", "config.yaml"}, + want: flagValues{config: "config.yaml"}, + }, + { + name: "stdin file short", + args: []string{"-F"}, + want: flagValues{stdinFile: true}, + }, + { + name: "stdin file long", + args: []string{"--stdin-file"}, + want: flagValues{stdinFile: true}, + }, + { + name: "quiet short", + args: []string{"-q"}, + want: flagValues{quiet: true}, + }, + { + name: "quiet long", + args: []string{"--quiet"}, + want: flagValues{quiet: true}, + }, + { + name: "verbose short", + args: []string{"-v"}, + want: flagValues{verbose: true}, + }, + { + name: "verbose long", + args: []string{"--verbose"}, + want: flagValues{verbose: true}, + }, + { + name: "version flag", + args: []string{"--version"}, + want: flagValues{version: true}, + }, + { + name: "complex combination", + args: []string{ + "-f", "a.go", + "-f", "b.go", + "-p", "first prompt", + "-pf", "prompt.txt", + "-s", "system prompt", + "-k", "key123", + "-m", "gpt-4", + "-b", "gpt-3.5", + "-o", "out.txt", + "-q", + "-v", + }, + want: flagValues{ + files: []string{"a.go", "b.go"}, + prompts: []string{"first prompt"}, + promptFile: "prompt.txt", + system: "system prompt", + key: "key123", + model: "gpt-4", + fallback: "gpt-3.5", + output: "out.txt", + quiet: true, + verbose: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetFlags() + + got, err := parseFlags(tt.args) + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseFlagsErrors(t *testing.T) { + tests := []struct { + name string + args []string + }{ + { + name: "unknown flag", + args: []string{"--unknown"}, + }, + { + name: "flag without value", + args: []string{"-f"}, + }, + { + name: "model without value", + args: []string{"-m"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetFlags() + + _, err := parseFlags(tt.args) + assert.Error(t, err, "parseFlags() should return error for %s", tt.name) + }) + } +} diff --git a/config/merge.go b/config/merge.go new file mode 100644 index 0000000..393e472 --- /dev/null +++ b/config/merge.go @@ -0,0 +1,112 @@ +package config + +import ( + "os" + "strings" +) + +func mergeSources(flags flagValues, env envValues, file fileValues) ConfigData { + cfg := defaultConfig + + // Apply file values + if file.protocol != "" { + cfg.Protocol = parseProtocol(file.protocol) + } + if file.url != "" { + cfg.URL = file.url + } + if file.model != "" { + cfg.Model = file.model + } + if file.fallback != "" { + cfg.FallbackModels = strings.Split(file.fallback, ",") + } + + // Apply env values + if env.protocol != "" { + cfg.Protocol = parseProtocol(env.protocol) + } + if env.url != "" { + cfg.URL = env.url + } + if env.model != "" { + cfg.Model = env.model + } + if env.fallback != "" { + cfg.FallbackModels = strings.Split(env.fallback, ",") + } + if env.system != "" { + cfg.SystemPrompt = env.system + } + if env.key != "" { + cfg.APIKey = env.key + } + + // Apply flag values + if flags.protocol != "" { + cfg.Protocol = parseProtocol(flags.protocol) + } + if flags.url != "" { + cfg.URL = flags.url + } + if flags.model != "" { + cfg.Model = flags.model + } + if flags.fallback != "" { + cfg.FallbackModels = strings.Split(flags.fallback, ",") + } + if flags.output != "" { + cfg.Output = flags.output + } + cfg.Quiet = flags.quiet + cfg.Verbose = flags.verbose + cfg.StdinAsFile = flags.stdinFile + + // Collect input paths + cfg.FilePaths = flags.files + cfg.PromptFlags = flags.prompts + if flags.promptFile != "" { + cfg.PromptPaths = []string{flags.promptFile} + } + + // Resolve system prompt (direct > file) + if flags.system != "" { + cfg.SystemPrompt = flags.system + } else if flags.systemFile != "" { + content, err := os.ReadFile(flags.systemFile) + if err == nil { + cfg.SystemPrompt = strings.TrimRight(string(content), "\n") + } + } else if file.systemFile != "" && cfg.SystemPrompt == "" { + content, err := os.ReadFile(file.systemFile) + if err == nil { + cfg.SystemPrompt = strings.TrimRight(string(content), "\n") + } + } + + // Resolve API key (direct > file) + if flags.key != "" { + cfg.APIKey = flags.key + } else if flags.keyFile != "" { + content, err := os.ReadFile(flags.keyFile) + if err == nil { + cfg.APIKey = strings.TrimSpace(string(content)) + } + } else if cfg.APIKey == "" && file.keyFile != "" { + content, err := os.ReadFile(file.keyFile) + if err == nil { + cfg.APIKey = strings.TrimSpace(string(content)) + } + } + + return cfg +} + +func parseProtocol(s string) APIProtocol { + switch s { + case "ollama": + return ProtocolOllama + default: + return ProtocolOpenAI + } +} diff --git a/config/merge_test.go b/config/merge_test.go new file mode 100644 index 0000000..b5c81d9 --- /dev/null +++ b/config/merge_test.go @@ -0,0 +1,373 @@ +package config + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestMergeSources(t *testing.T) { + tests := []struct { + name string + flags flagValues + env envValues + file fileValues + want ConfigData + }{ + { + name: "all empty uses defaults", + flags: flagValues{}, + env: envValues{}, + file: fileValues{}, + want: defaultConfig, + }, + { + name: "file overrides defaults", + flags: flagValues{}, + env: envValues{}, + file: fileValues{ + protocol: "ollama", + model: "llama3", + }, + want: ConfigData{ + Protocol: ProtocolOllama, + URL: "https://api.ppq.ai/chat/completions", + Model: "llama3", + FallbackModels: []string{"gpt-4.1-mini"}, + }, + }, + { + name: "env overrides file", + flags: flagValues{}, + env: envValues{ + model: "gpt-4", + }, + file: fileValues{ + model: "llama3", + }, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4", + FallbackModels: []string{"gpt-4.1-mini"}, + }, + }, + { + name: "flags override env", + flags: flagValues{ + model: "claude-3", + }, + env: envValues{ + model: "gpt-4", + }, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "claude-3", + FallbackModels: []string{"gpt-4.1-mini"}, + }, + }, + { + name: "full precedence chain", + flags: flagValues{ + protocol: "ollama", + quiet: true, + }, + env: envValues{ + protocol: "openai", + model: "gpt-4", + url: "http://custom.api", + }, + file: fileValues{ + protocol: "openai", + model: "llama3", + url: "http://file.api", + fallback: "mistral", + }, + want: ConfigData{ + Protocol: ProtocolOllama, + URL: "http://custom.api", + Model: "gpt-4", + FallbackModels: []string{"mistral"}, + Quiet: true, + }, + }, + { + name: "fallback string split", + flags: flagValues{ + fallback: "model1,model2,model3", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"model1", "model2", "model3"}, + }, + }, + { + name: "direct key flag", + flags: flagValues{ + key: "sk-direct", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + APIKey: "sk-direct", + }, + }, + { + name: "direct system flag", + flags: flagValues{ + system: "You are helpful", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + SystemPrompt: "You are helpful", + }, + }, + { + name: "file paths collected", + flags: flagValues{ + files: []string{"a.go", "b.go"}, + prompts: []string{"prompt1", "prompt2"}, + promptFile: "prompt.txt", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + FilePaths: []string{"a.go", "b.go"}, + PromptFlags: []string{"prompt1", "prompt2"}, + PromptPaths: []string{"prompt.txt"}, + }, + }, + { + name: "stdin file flag", + flags: flagValues{ + stdinFile: true, + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + StdinAsFile: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeSources(tt.flags, tt.env, tt.file) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMergeSourcesKeyFile(t *testing.T) { + tests := []struct { + name string + flags flagValues + env envValues + file fileValues + want ConfigData + }{ + { + name: "key file from flags", + flags: flagValues{ + keyFile: "testdata/api.key", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + APIKey: "sk-test-key-123", + }, + }, + { + name: "key file from file config", + flags: flagValues{}, + env: envValues{}, + file: fileValues{ + keyFile: "testdata/api.key", + }, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + APIKey: "sk-test-key-123", + }, + }, + { + name: "direct key overrides key file", + flags: flagValues{ + key: "sk-direct", + keyFile: "testdata/api.key", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + APIKey: "sk-direct", + }, + }, + { + name: "env key overrides file key file", + flags: flagValues{}, + env: envValues{ + key: "sk-env", + }, + file: fileValues{ + keyFile: "testdata/api.key", + }, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + APIKey: "sk-env", + }, + }, + { + name: "key file with whitespace trimmed", + flags: flagValues{ + keyFile: "testdata/api_whitespace.key", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + APIKey: "sk-whitespace-key", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeSources(tt.flags, tt.env, tt.file) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMergeSourcesSystemFile(t *testing.T) { + tests := []struct { + name string + flags flagValues + env envValues + file fileValues + want ConfigData + }{ + { + name: "system file from flags", + flags: flagValues{ + systemFile: "testdata/system.txt", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + SystemPrompt: "You are a helpful assistant.", + }, + }, + { + name: "system file from file config", + flags: flagValues{}, + env: envValues{}, + file: fileValues{ + systemFile: "testdata/system.txt", + }, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + SystemPrompt: "You are a helpful assistant.", + }, + }, + { + name: "direct system overrides system file", + flags: flagValues{ + system: "Direct system", + systemFile: "testdata/system.txt", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + SystemPrompt: "Direct system", + }, + }, + { + name: "env system overrides file system file", + flags: flagValues{}, + env: envValues{ + system: "System from env", + }, + file: fileValues{ + systemFile: "testdata/system.txt", + }, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + SystemPrompt: "System from env", + }, + }, + { + name: "empty system file", + flags: flagValues{ + systemFile: "testdata/system_empty.txt", + }, + env: envValues{}, + file: fileValues{}, + want: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.ppq.ai/chat/completions", + Model: "gpt-4o-mini", + FallbackModels: []string{"gpt-4.1-mini"}, + SystemPrompt: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeSources(tt.flags, tt.env, tt.file) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/config/testdata/api.key b/config/testdata/api.key new file mode 100644 index 0000000..507e39f --- /dev/null +++ b/config/testdata/api.key @@ -0,0 +1 @@ +sk-test-key-123 diff --git a/config/testdata/api_whitespace.key b/config/testdata/api_whitespace.key new file mode 100644 index 0000000..1d1d529 --- /dev/null +++ b/config/testdata/api_whitespace.key @@ -0,0 +1 @@ + sk-whitespace-key diff --git a/config/testdata/empty.yaml b/config/testdata/empty.yaml new file mode 100644 index 0000000..e69de29 diff --git a/config/testdata/invalid.yaml b/config/testdata/invalid.yaml new file mode 100644 index 0000000..90c9aa4 --- /dev/null +++ b/config/testdata/invalid.yaml @@ -0,0 +1,3 @@ +protocol: openai +url: [this is not valid yaml syntax +model: gpt-4 diff --git a/config/testdata/partial.yaml b/config/testdata/partial.yaml new file mode 100644 index 0000000..c941a1b --- /dev/null +++ b/config/testdata/partial.yaml @@ -0,0 +1,2 @@ +model: gpt-4 +fallback: gpt-3.5-turbo diff --git a/config/testdata/system.txt b/config/testdata/system.txt new file mode 100644 index 0000000..64e0dbc --- /dev/null +++ b/config/testdata/system.txt @@ -0,0 +1 @@ +You are a helpful assistant. diff --git a/config/testdata/system_empty.txt b/config/testdata/system_empty.txt new file mode 100644 index 0000000..e69de29 diff --git a/config/testdata/unknown_keys.yaml b/config/testdata/unknown_keys.yaml new file mode 100644 index 0000000..e69c92b --- /dev/null +++ b/config/testdata/unknown_keys.yaml @@ -0,0 +1,4 @@ +protocol: openai +model: gpt-4 +unknown_field: ignored +another_unknown: also_ignored diff --git a/config/testdata/valid.yaml b/config/testdata/valid.yaml new file mode 100644 index 0000000..809b49f --- /dev/null +++ b/config/testdata/valid.yaml @@ -0,0 +1,6 @@ +protocol: ollama +url: http://localhost:11434/api/chat +key_file: ~/.aicli_key +model: llama3 +fallback: llama2,mistral +system_file: ~/system.txt diff --git a/config/types.go b/config/types.go new file mode 100644 index 0000000..7cb6f85 --- /dev/null +++ b/config/types.go @@ -0,0 +1,71 @@ +package config + +type APIProtocol int + +const ( + ProtocolOpenAI APIProtocol = iota + ProtocolOllama +) + +type ConfigData struct { + // Input + FilePaths []string + PromptFlags []string + PromptPaths []string + StdinAsFile bool + + // System + SystemPrompt string + + // API + Protocol APIProtocol + URL string + APIKey string + + // Models + Model string + FallbackModels []string + + // Output + Output string + Quiet bool + Verbose bool +} + +type flagValues struct { + files []string + prompts []string + promptFile string + system string + systemFile string + key string + keyFile string + protocol string + url string + model string + fallback string + output string + config string + stdinFile bool + quiet bool + verbose bool + version bool +} + +type envValues struct { + protocol string + url string + key string + model string + fallback string + system string +} + +type fileValues struct { + protocol string + url string + keyFile string + model string + fallback string + systemFile string +} diff --git a/config/validate.go b/config/validate.go new file mode 100644 index 0000000..a104deb --- /dev/null +++ b/config/validate.go @@ -0,0 +1,17 @@ +package config + +import ( + "fmt" +) + +func validateConfig(cfg ConfigData) error { + if cfg.APIKey == "" { + return fmt.Errorf("API key required: use --key, --key-file, AICLI_API_KEY, AICLI_API_KEY_FILE, or key_file in config") + } + + if cfg.Protocol != ProtocolOpenAI && cfg.Protocol != ProtocolOllama { + return fmt.Errorf("invalid protocol: must be openai or ollama") + } + + return nil +} diff --git a/config/validate_test.go b/config/validate_test.go new file mode 100644 index 0000000..1a524cc --- /dev/null +++ b/config/validate_test.go @@ -0,0 +1,74 @@ +package config + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestValidateConfig(t *testing.T) { + tests := []struct { + name string + cfg ConfigData + wantErr bool + errMsg string + }{ + { + name: "valid config", + cfg: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.openai.com", + Model: "gpt-4", + FallbackModels: []string{"gpt-3.5"}, + APIKey: "sk-test123", + }, + wantErr: false, + }, + { + name: "missing api key", + cfg: ConfigData{ + Protocol: ProtocolOpenAI, + URL: "https://api.openai.com", + Model: "gpt-4", + FallbackModels: []string{"gpt-3.5"}, + APIKey: "", + }, + wantErr: true, + errMsg: "API key required", + }, + { + name: "invalid protocol", + cfg: ConfigData{ + Protocol: APIProtocol(99), + URL: "https://api.openai.com", + Model: "gpt-4", + FallbackModels: []string{"gpt-3.5"}, + APIKey: "sk-test123", + }, + wantErr: true, + errMsg: "invalid protocol", + }, + { + name: "ollama protocol valid", + cfg: ConfigData{ + Protocol: ProtocolOllama, + URL: "http://localhost:11434", + Model: "llama3", + FallbackModels: []string{}, + APIKey: "not-used-but-required", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateConfig(tt.cfg) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/go.mod b/go.mod index 61f98b5..501b1ca 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,10 @@ module git.wisehodl.dev/jay/aicli go 1.23.5 -require gopkg.in/yaml.v3 v3.0.1 +require github.com/stretchr/testify v1.11.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index a62c313..c4c1710 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/input/aggregate.go b/input/aggregate.go new file mode 100644 index 0000000..c489461 --- /dev/null +++ b/input/aggregate.go @@ -0,0 +1,32 @@ +package input + +// AggregatePrompts combines prompt sources with stdin based on role. +func AggregatePrompts(prompts []string, stdin string, role StdinRole) []string { + switch role { + case StdinAsPrompt: + if stdin != "" { + return []string{stdin} + } + return prompts + + case StdinAsPrefixedContent: + if stdin != "" { + return append(prompts, stdin) + } + return prompts + + case StdinAsFile: + return prompts + + default: + return prompts + } +} + +// AggregateFiles combines file sources with stdin based on role. +func AggregateFiles(files []FileData, stdin string, role StdinRole) []FileData { + if role == StdinAsFile && stdin != "" { + return append([]FileData{{Path: "input", Content: stdin}}, files...) + } + return files +} diff --git a/input/aggregate_test.go b/input/aggregate_test.go new file mode 100644 index 0000000..0bbf44c --- /dev/null +++ b/input/aggregate_test.go @@ -0,0 +1,174 @@ +package input + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAggregatePrompts(t *testing.T) { + tests := []struct { + name string + prompts []string + stdin string + role StdinRole + want []string + }{ + { + name: "empty inputs returns empty", + prompts: []string{}, + stdin: "", + role: StdinAsPrompt, + want: []string{}, + }, + { + name: "stdin as prompt with no other prompts", + prompts: []string{}, + stdin: "stdin content", + role: StdinAsPrompt, + want: []string{"stdin content"}, + }, + { + name: "stdin as prompt replaces existing prompts", + prompts: []string{"prompt1", "prompt2"}, + stdin: "stdin content", + role: StdinAsPrompt, + want: []string{"stdin content"}, + }, + { + name: "no stdin with role prompt returns prompts unchanged", + prompts: []string{"prompt1", "prompt2"}, + stdin: "", + role: StdinAsPrompt, + want: []string{"prompt1", "prompt2"}, + }, + { + name: "stdin as prefixed appends to prompts", + prompts: []string{"prompt1", "prompt2"}, + stdin: "stdin content", + role: StdinAsPrefixedContent, + want: []string{"prompt1", "prompt2", "stdin content"}, + }, + { + name: "stdin as prefixed with no prompts", + prompts: []string{}, + stdin: "stdin content", + role: StdinAsPrefixedContent, + want: []string{"stdin content"}, + }, + { + name: "no stdin with role prefixed returns prompts unchanged", + prompts: []string{"prompt1"}, + stdin: "", + role: StdinAsPrefixedContent, + want: []string{"prompt1"}, + }, + { + name: "stdin as file excludes stdin from prompts", + prompts: []string{"prompt1"}, + stdin: "stdin content", + role: StdinAsFile, + want: []string{"prompt1"}, + }, + { + name: "no stdin with role file returns prompts unchanged", + prompts: []string{"prompt1"}, + stdin: "", + role: StdinAsFile, + want: []string{"prompt1"}, + }, + { + name: "empty string stdin with role prompt", + prompts: []string{"prompt1"}, + stdin: "", + role: StdinAsPrompt, + want: []string{"prompt1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AggregatePrompts(tt.prompts, tt.stdin, tt.role) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestAggregateFiles(t *testing.T) { + tests := []struct { + name string + files []FileData + stdin string + role StdinRole + want []FileData + }{ + { + name: "empty inputs returns empty", + files: []FileData{}, + stdin: "", + role: StdinAsFile, + want: []FileData{}, + }, + { + name: "stdin as file prepends to files", + files: []FileData{{Path: "a.go", Content: "code"}}, + stdin: "stdin content", + role: StdinAsFile, + want: []FileData{ + {Path: "input", Content: "stdin content"}, + {Path: "a.go", Content: "code"}, + }, + }, + { + name: "stdin as file with no other files", + files: []FileData{}, + stdin: "stdin content", + role: StdinAsFile, + want: []FileData{ + {Path: "input", Content: "stdin content"}, + }, + }, + { + name: "no stdin with role file returns files unchanged", + files: []FileData{{Path: "a.go", Content: "code"}}, + stdin: "", + role: StdinAsFile, + want: []FileData{{Path: "a.go", Content: "code"}}, + }, + { + name: "stdin as prompt excludes stdin from files", + files: []FileData{{Path: "a.go", Content: "code"}}, + stdin: "stdin content", + role: StdinAsPrompt, + want: []FileData{{Path: "a.go", Content: "code"}}, + }, + { + name: "stdin as prefixed excludes stdin from files", + files: []FileData{{Path: "a.go", Content: "code"}}, + stdin: "stdin content", + role: StdinAsPrefixedContent, + want: []FileData{{Path: "a.go", Content: "code"}}, + }, + { + name: "stdin as file with multiple files", + files: []FileData{ + {Path: "a.go", Content: "code a"}, + {Path: "b.go", Content: "code b"}, + }, + stdin: "stdin content", + role: StdinAsFile, + want: []FileData{ + {Path: "input", Content: "stdin content"}, + {Path: "a.go", Content: "code a"}, + {Path: "b.go", Content: "code b"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AggregateFiles(tt.files, tt.stdin, tt.role) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/input/input.go b/input/input.go new file mode 100644 index 0000000..249340f --- /dev/null +++ b/input/input.go @@ -0,0 +1,39 @@ +package input + +import ( + "fmt" + + "git.wisehodl.dev/jay/aicli/config" +) + +// ResolveInputs orchestrates the complete input resolution pipeline. +// Returns aggregated prompts and files ready for query construction. +func ResolveInputs(cfg config.ConfigData, stdinContent string, hasStdin bool) (InputData, error) { + // Determine stdin role (CA -> CB) + role := DetermineRole(cfg, hasStdin) + + // Read all sources (CC, CD) + prompts, err := ReadPromptSources(cfg) + if err != nil { + return InputData{}, err + } + + files, err := ReadFileSources(cfg) + if err != nil { + return InputData{}, err + } + + // Aggregate with stdin (CE, CF) + finalPrompts := AggregatePrompts(prompts, stdinContent, role) + finalFiles := AggregateFiles(files, stdinContent, role) + + // Validate at least one input exists + if len(finalPrompts) == 0 && len(finalFiles) == 0 { + return InputData{}, fmt.Errorf("no input provided: supply stdin, --file, or --prompt") + } + + return InputData{ + Prompts: finalPrompts, + Files: finalFiles, + }, nil +} diff --git a/input/input_test.go b/input/input_test.go new file mode 100644 index 0000000..6c688d0 --- /dev/null +++ b/input/input_test.go @@ -0,0 +1,206 @@ +package input + +import ( + "testing" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +func TestResolveInputs(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + stdinContent string + hasStdin bool + want InputData + wantErr bool + errContains string + }{ + { + name: "no input returns error", + cfg: config.ConfigData{}, + stdinContent: "", + hasStdin: false, + wantErr: true, + errContains: "no input provided", + }, + { + name: "stdin only as prompt", + cfg: config.ConfigData{}, + stdinContent: "analyze this", + hasStdin: true, + want: InputData{ + Prompts: []string{"analyze this"}, + Files: []FileData{}, + }, + }, + { + name: "prompt flag only", + cfg: config.ConfigData{ + PromptFlags: []string{"test prompt"}, + }, + stdinContent: "", + hasStdin: false, + want: InputData{ + Prompts: []string{"test prompt"}, + Files: []FileData{}, + }, + }, + { + name: "file flag only", + cfg: config.ConfigData{ + FilePaths: []string{"testdata/code.go"}, + }, + stdinContent: "", + hasStdin: false, + want: InputData{ + Prompts: []string{}, + Files: []FileData{ + {Path: "testdata/code.go", Content: "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n"}, + }, + }, + }, + { + name: "stdin as file with -F flag", + cfg: config.ConfigData{ + StdinAsFile: true, + }, + stdinContent: "stdin content", + hasStdin: true, + want: InputData{ + Prompts: []string{}, + Files: []FileData{ + {Path: "input", Content: "stdin content"}, + }, + }, + }, + { + name: "stdin as file with -F and explicit files", + cfg: config.ConfigData{ + StdinAsFile: true, + FilePaths: []string{"testdata/code.go"}, + }, + stdinContent: "stdin content", + hasStdin: true, + want: InputData{ + Prompts: []string{}, + Files: []FileData{ + {Path: "input", Content: "stdin content"}, + {Path: "testdata/code.go", Content: "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n"}, + }, + }, + }, + { + name: "stdin prefixed with explicit prompt", + cfg: config.ConfigData{ + PromptFlags: []string{"analyze"}, + }, + stdinContent: "code to analyze", + hasStdin: true, + want: InputData{ + Prompts: []string{"analyze", "code to analyze"}, + Files: []FileData{}, + }, + }, + { + name: "prompt from file", + cfg: config.ConfigData{ + PromptPaths: []string{"testdata/prompt1.txt"}, + }, + stdinContent: "", + hasStdin: false, + want: InputData{ + Prompts: []string{"Analyze the following code.\n"}, + Files: []FileData{}, + }, + }, + { + name: "complete scenario: prompts, files, stdin", + cfg: config.ConfigData{ + PromptFlags: []string{"review this"}, + FilePaths: []string{"testdata/code.go"}, + }, + stdinContent: "additional context", + hasStdin: true, + want: InputData{ + Prompts: []string{"review this", "additional context"}, + Files: []FileData{ + {Path: "testdata/code.go", Content: "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n"}, + }, + }, + }, + { + name: "file read error propagates", + cfg: config.ConfigData{ + FilePaths: []string{"testdata/nonexistent.go"}, + }, + stdinContent: "", + hasStdin: false, + wantErr: true, + errContains: "read file", + }, + { + name: "prompt file read error propagates", + cfg: config.ConfigData{ + PromptPaths: []string{"testdata/missing.txt"}, + }, + stdinContent: "", + hasStdin: false, + wantErr: true, + errContains: "read prompt file", + }, + { + name: "empty file path error propagates", + cfg: config.ConfigData{ + FilePaths: []string{""}, + }, + stdinContent: "", + hasStdin: false, + wantErr: true, + errContains: "empty file path", + }, + { + name: "stdin replaces prompts when no explicit flags", + cfg: config.ConfigData{ + PromptFlags: []string{}, + }, + stdinContent: "stdin prompt", + hasStdin: true, + want: InputData{ + Prompts: []string{"stdin prompt"}, + Files: []FileData{}, + }, + }, + { + name: "multiple files in order", + cfg: config.ConfigData{ + FilePaths: []string{"testdata/code.go", "testdata/data.json"}, + }, + stdinContent: "", + hasStdin: false, + want: InputData{ + Prompts: []string{}, + Files: []FileData{ + {Path: "testdata/code.go", Content: "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n"}, + {Path: "testdata/data.json", Content: "{\n \"name\": \"test\",\n \"value\": 42\n}\n"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ResolveInputs(tt.cfg, tt.stdinContent, tt.hasStdin) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/input/role.go b/input/role.go new file mode 100644 index 0000000..168296e --- /dev/null +++ b/input/role.go @@ -0,0 +1,26 @@ +package input + +import "git.wisehodl.dev/jay/aicli/config" + +// DetermineRole decides how stdin content participates in the query based on +// flags and stdin presence. Per spec ยง7 rules. +func DetermineRole(cfg config.ConfigData, hasStdin bool) StdinRole { + if !hasStdin { + return StdinAsPrompt // unused, but set for consistency + } + + // Explicit -F flag forces stdin as file + if cfg.StdinAsFile { + return StdinAsFile + } + + // Any explicit prompt flag (-p or -pf) makes stdin prefixed content + hasExplicitPrompt := len(cfg.PromptFlags) > 0 || len(cfg.PromptPaths) > 0 + + if hasExplicitPrompt { + return StdinAsPrefixedContent + } + + // Default: stdin replaces any default prompt + return StdinAsPrompt +} diff --git a/input/role_test.go b/input/role_test.go new file mode 100644 index 0000000..690addb --- /dev/null +++ b/input/role_test.go @@ -0,0 +1,95 @@ +package input + +import ( + "testing" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +func TestDetermineRole(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + hasStdin bool + want StdinRole + }{ + { + name: "no stdin returns StdinAsPrompt", + cfg: config.ConfigData{}, + hasStdin: false, + want: StdinAsPrompt, + }, + { + name: "stdin with no flags returns StdinAsPrompt", + cfg: config.ConfigData{}, + hasStdin: true, + want: StdinAsPrompt, + }, + { + name: "stdin with -p flag returns StdinAsPrefixedContent", + cfg: config.ConfigData{ + PromptFlags: []string{"analyze this"}, + }, + hasStdin: true, + want: StdinAsPrefixedContent, + }, + { + name: "stdin with -pf flag returns StdinAsPrefixedContent", + cfg: config.ConfigData{ + PromptPaths: []string{"prompt.txt"}, + }, + hasStdin: true, + want: StdinAsPrefixedContent, + }, + { + name: "stdin with -F flag returns StdinAsFile", + cfg: config.ConfigData{ + StdinAsFile: true, + }, + hasStdin: true, + want: StdinAsFile, + }, + { + name: "stdin with -F and -p returns StdinAsFile (explicit wins)", + cfg: config.ConfigData{ + StdinAsFile: true, + PromptFlags: []string{"analyze"}, + }, + hasStdin: true, + want: StdinAsFile, + }, + { + name: "stdin with file flags returns StdinAsPrompt", + cfg: config.ConfigData{ + FilePaths: []string{"main.go"}, + }, + hasStdin: true, + want: StdinAsPrompt, + }, + { + name: "no stdin with -F returns StdinAsFile (role set but unused)", + cfg: config.ConfigData{ + StdinAsFile: true, + }, + hasStdin: false, + want: StdinAsPrompt, + }, + { + name: "stdin with both -p and -pf returns StdinAsPrefixedContent", + cfg: config.ConfigData{ + PromptFlags: []string{"first"}, + PromptPaths: []string{"prompt.txt"}, + }, + hasStdin: true, + want: StdinAsPrefixedContent, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DetermineRole(tt.cfg, tt.hasStdin) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/input/sources.go b/input/sources.go new file mode 100644 index 0000000..1bcfacd --- /dev/null +++ b/input/sources.go @@ -0,0 +1,52 @@ +package input + +import ( + "fmt" + "os" + + "git.wisehodl.dev/jay/aicli/config" +) + +// ReadPromptSources reads all prompt content from flags and files. +// Returns arrays of prompt strings in source order. +func ReadPromptSources(cfg config.ConfigData) ([]string, error) { + prompts := []string{} + + // Add flag prompts first + prompts = append(prompts, cfg.PromptFlags...) + + // Add prompt file contents + for _, path := range cfg.PromptPaths { + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read prompt file %s: %w", path, err) + } + prompts = append(prompts, string(content)) + } + + return prompts, nil +} + +// ReadFileSources reads all input files specified in config. +// Returns FileData array in source order. +func ReadFileSources(cfg config.ConfigData) ([]FileData, error) { + files := []FileData{} + + for _, path := range cfg.FilePaths { + if path == "" { + return nil, fmt.Errorf("empty file path provided") + } + + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read file %s: %w", path, err) + } + + files = append(files, FileData{ + Path: path, + Content: string(content), + }) + } + + return files, nil +} diff --git a/input/sources_test.go b/input/sources_test.go new file mode 100644 index 0000000..ec4f214 --- /dev/null +++ b/input/sources_test.go @@ -0,0 +1,183 @@ +package input + +import ( + "testing" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +func TestReadPromptSources(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + want []string + wantErr bool + }{ + { + name: "no prompts returns empty", + cfg: config.ConfigData{}, + want: []string{}, + }, + { + name: "single flag prompt", + cfg: config.ConfigData{ + PromptFlags: []string{"analyze this"}, + }, + want: []string{"analyze this"}, + }, + { + name: "multiple flag prompts", + cfg: config.ConfigData{ + PromptFlags: []string{"first", "second", "third"}, + }, + want: []string{"first", "second", "third"}, + }, + { + name: "single prompt file", + cfg: config.ConfigData{ + PromptPaths: []string{"testdata/prompt1.txt"}, + }, + want: []string{"Analyze the following code.\n"}, + }, + { + name: "multiple prompt files", + cfg: config.ConfigData{ + PromptPaths: []string{"testdata/prompt1.txt", "testdata/prompt2.txt"}, + }, + want: []string{ + "Analyze the following code.\n", + "Focus on:\n- Performance\n- Security\n- Readability\n", + }, + }, + { + name: "empty prompt file", + cfg: config.ConfigData{ + PromptPaths: []string{"testdata/prompt_empty.txt"}, + }, + want: []string{""}, + }, + { + name: "flags and files combined", + cfg: config.ConfigData{ + PromptFlags: []string{"first flag", "second flag"}, + PromptPaths: []string{"testdata/prompt1.txt"}, + }, + want: []string{ + "first flag", + "second flag", + "Analyze the following code.\n", + }, + }, + { + name: "file not found", + cfg: config.ConfigData{ + PromptPaths: []string{"testdata/nonexistent.txt"}, + }, + wantErr: true, + }, + { + name: "mixed valid and invalid", + cfg: config.ConfigData{ + PromptFlags: []string{"valid flag"}, + PromptPaths: []string{"testdata/nonexistent.txt"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ReadPromptSources(tt.cfg) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestReadFileSources(t *testing.T) { + tests := []struct { + name string + cfg config.ConfigData + want []FileData + wantErr bool + }{ + { + name: "no files returns empty", + cfg: config.ConfigData{}, + want: []FileData{}, + }, + { + name: "single file", + cfg: config.ConfigData{ + FilePaths: []string{"testdata/code.go"}, + }, + want: []FileData{ + { + Path: "testdata/code.go", + Content: "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n", + }, + }, + }, + { + name: "multiple files", + cfg: config.ConfigData{ + FilePaths: []string{"testdata/code.go", "testdata/data.json"}, + }, + want: []FileData{ + { + Path: "testdata/code.go", + Content: "package main\n\nfunc main() {\n\tprintln(\"hello\")\n}\n", + }, + { + Path: "testdata/data.json", + Content: "{\n \"name\": \"test\",\n \"value\": 42\n}\n", + }, + }, + }, + { + name: "empty file path", + cfg: config.ConfigData{ + FilePaths: []string{""}, + }, + wantErr: true, + }, + { + name: "file not found", + cfg: config.ConfigData{ + FilePaths: []string{"testdata/nonexistent.go"}, + }, + wantErr: true, + }, + { + name: "permission denied", + cfg: config.ConfigData{ + FilePaths: []string{"/root/secret.txt"}, + }, + wantErr: true, + }, + { + name: "mixed valid and invalid", + cfg: config.ConfigData{ + FilePaths: []string{"testdata/code.go", "testdata/nonexistent.go"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ReadFileSources(tt.cfg) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/input/stdin.go b/input/stdin.go new file mode 100644 index 0000000..1d04e1b --- /dev/null +++ b/input/stdin.go @@ -0,0 +1,28 @@ +package input + +import ( + "io" + "os" +) + +// DetectStdin checks if stdin contains piped data and reads it. +// Returns content and true if stdin is a pipe/file, empty string and false if terminal. +func DetectStdin() (string, bool) { + stat, err := os.Stdin.Stat() + if err != nil { + return "", false + } + + // Terminal (character device) = no stdin data + if (stat.Mode() & os.ModeCharDevice) != 0 { + return "", false + } + + // Pipe or file redirection detected + content, err := io.ReadAll(os.Stdin) + if err != nil { + return "", false + } + + return string(content), true +} diff --git a/input/stdin_test.go b/input/stdin_test.go new file mode 100644 index 0000000..35bf892 --- /dev/null +++ b/input/stdin_test.go @@ -0,0 +1,27 @@ +package input + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Integration test helper - run manually with: +// echo "test" | STDIN_TEST=1 go test ./input/ -run TestDetectStdinIntegration +func TestDetectStdinIntegration(t *testing.T) { + if os.Getenv("STDIN_TEST") != "1" { + t.Skip("Set STDIN_TEST=1 and pipe data to run this test") + } + + content, hasStdin := DetectStdin() + + t.Logf("hasStdin: %v", hasStdin) + t.Logf("content length: %d", len(content)) + t.Logf("content: %q", content) + + // When piped: should detect stdin + if hasStdin { + assert.NotEmpty(t, content, "Expected content when stdin detected") + } +} diff --git a/input/testdata/code.go b/input/testdata/code.go new file mode 100644 index 0000000..4a73987 --- /dev/null +++ b/input/testdata/code.go @@ -0,0 +1,5 @@ +package main + +func main() { + println("hello") +} diff --git a/input/testdata/data.json b/input/testdata/data.json new file mode 100644 index 0000000..715f06a --- /dev/null +++ b/input/testdata/data.json @@ -0,0 +1,4 @@ +{ + "name": "test", + "value": 42 +} diff --git a/input/testdata/prompt1.txt b/input/testdata/prompt1.txt new file mode 100644 index 0000000..c4c6225 --- /dev/null +++ b/input/testdata/prompt1.txt @@ -0,0 +1 @@ +Analyze the following code. diff --git a/input/testdata/prompt2.txt b/input/testdata/prompt2.txt new file mode 100644 index 0000000..6fdad71 --- /dev/null +++ b/input/testdata/prompt2.txt @@ -0,0 +1,4 @@ +Focus on: +- Performance +- Security +- Readability diff --git a/input/testdata/prompt_empty.txt b/input/testdata/prompt_empty.txt new file mode 100644 index 0000000..e69de29 diff --git a/input/types.go b/input/types.go new file mode 100644 index 0000000..34b5361 --- /dev/null +++ b/input/types.go @@ -0,0 +1,27 @@ +package input + +// StdinRole determines how stdin content participates in the query +type StdinRole int + +const ( + // StdinAsPrompt: stdin becomes the entire prompt (replaces other prompts) + StdinAsPrompt StdinRole = iota + + // StdinAsPrefixedContent: stdin appends after explicit prompts + StdinAsPrefixedContent + + // StdinAsFile: stdin becomes first file in files array + StdinAsFile +) + +// FileData represents a single input file +type FileData struct { + Path string + Content string +} + +// InputData holds all resolved input streams after aggregation +type InputData struct { + Prompts []string + Files []FileData +} diff --git a/main.go b/main.go index cc2422c..c90accf 100644 --- a/main.go +++ b/main.go @@ -1,154 +1,17 @@ package main import ( - "bytes" - "encoding/json" - "flag" "fmt" - "io" - "net/http" "os" - "strings" - "time" + "git.wisehodl.dev/jay/aicli/api" + "git.wisehodl.dev/jay/aicli/config" + "git.wisehodl.dev/jay/aicli/input" + "git.wisehodl.dev/jay/aicli/output" + "git.wisehodl.dev/jay/aicli/prompt" "git.wisehodl.dev/jay/aicli/version" - "gopkg.in/yaml.v3" ) -const defaultPrompt = "Analyze the following:" - -type stdinRole int - -const ( - stdinAsPrompt stdinRole = iota - stdinAsPrefixedContent - stdinAsFile -) - -type Config struct { - Protocol string - URL string - Key string - Model string - Fallbacks []string - SystemText string - PromptText string - Files []FileData - OutputPath string - Quiet bool - Verbose bool -} - -type FileData struct { - Path string - Content string -} - -type flagValues struct { - files []string - prompts []string - promptFile string - system string - systemFile string - key string - keyFile string - protocol string - url string - model string - fallback string - output string - config string - stdinFile bool - quiet bool - verbose bool - showVersion bool -} - -const usageText = `Usage: aicli [OPTION]... [FILE]... -Send files and prompts to LLM chat endpoints. - -With no FILE, or when FILE is -, read standard input. - -Global: - --version display version information and exit - -Input: - -f, --file PATH input file (repeatable) - -F, --stdin-file treat stdin as file contents - -p, --prompt TEXT prompt text (repeatable, can be combined with --prompt-file) - -pf, --prompt-file PATH prompt from file (combined with any --prompt flags) - -System: - -s, --system TEXT system prompt text - -sf, --system-file PATH system prompt from file - -API: - -l, --protocol PROTO API protocol: openai, ollama (default: openai) - -u, --url URL API endpoint (default: https://api.ppq.ai/chat/completions) - -k, --key KEY API key (if present, --key-file is ignored) - -kf, --key-file PATH API key from file (used only if --key is not provided) - -Models: - -m, --model NAME primary model (default: gpt-4o-mini) - -b, --fallback NAMES comma-separated fallback models (default: gpt-4.1-mini) - -Output: - -o, --output PATH write to file instead of stdout - -q, --quiet suppress progress output - -v, --verbose enable debug logging - -Config: - -c, --config PATH YAML config file - -Environment variables: - AICLI_API_KEY API key - AICLI_API_KEY_FILE Path to file containing API key (used only if AICLI_API_KEY is not set) - AICLI_PROTOCOL API protocol - AICLI_URL API endpoint - AICLI_MODEL primary model - AICLI_FALLBACK fallback models - AICLI_SYSTEM system prompt - AICLI_DEFAULT_PROMPT default prompt override - AICLI_CONFIG_FILE Path to config file - AICLI_PROMPT_FILE Path to prompt file - AICLI_SYSTEM_FILE Path to system file - -API Key precedence: --key flag > --key-file flag > AICLI_API_KEY > AICLI_API_KEY_FILE > config file - -Examples: - echo "What is Rust?" | aicli - cat file.txt | aicli -F -p "Analyze this file" - aicli -f main.go -p "Review this code" - aicli -c ~/.aicli.yaml -f src/*.go -o analysis.md - aicli -p "First prompt" -pf prompt.txt -p "Last prompt" -` - -func printUsage() { - fmt.Fprint(os.Stderr, usageText) -} - -type fileList []string - -func (f *fileList) String() string { - return strings.Join(*f, ", ") -} - -func (f *fileList) Set(value string) error { - *f = append(*f, value) - return nil -} - -type promptList []string - -func (p *promptList) String() string { - return strings.Join(*p, "\n") -} - -func (p *promptList) Set(value string) error { - *p = append(*p, value) - return nil -} - func main() { if err := run(); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) @@ -157,612 +20,64 @@ func main() { } func run() error { - // Check for verbose flag early - verbose := false - for _, arg := range os.Args { - if arg == "-v" || arg == "--verbose" { - verbose = true - break - } - } - - // Check for config file in environment variable before parsing flags - configFilePath := os.Getenv("AICLI_CONFIG_FILE") - - flags := parseFlags() - - if flags.showVersion { + // Phase 1: Version check (early exit) + if config.IsVersionRequest(os.Args[1:]) { fmt.Printf("aicli %s\n", version.GetVersion()) return nil } - if flags.config == "" && configFilePath != "" { - flags.config = configFilePath - } - - envVals := loadEnvVars(verbose) - fileVals, err := loadConfigFile(flags.config) - if err != nil { - return err - } - - merged := mergeConfigSources(verbose, flags, envVals, fileVals) - if err := validateConfig(merged); err != nil { - return err - } - - if promptFilePath := os.Getenv("AICLI_PROMPT_FILE"); promptFilePath != "" && flags.promptFile == "" { - content, err := os.ReadFile(promptFilePath) - if err != nil { - if verbose { - fmt.Fprintf(os.Stderr, "[verbose] Failed to read AICLI_PROMPT_FILE at %s: %v\n", promptFilePath, err) - } - } else { - merged.PromptText = string(content) - } - } - - if systemFilePath := os.Getenv("AICLI_SYSTEM_FILE"); systemFilePath != "" && flags.systemFile == "" && flags.system == "" { - content, err := os.ReadFile(systemFilePath) - if err != nil { - if verbose { - fmt.Fprintf(os.Stderr, "[verbose] Failed to read AICLI_SYSTEM_FILE at %s: %v\n", systemFilePath, err) - } - } else { - merged.SystemText = string(content) - } - } - - stdinContent, hasStdin := detectStdin() - role := determineStdinRole(flags, hasStdin) - - inputData, err := resolveInputStreams(merged, stdinContent, hasStdin, role, flags) - if err != nil { - return err - } - - config := buildCompletePrompt(inputData) - - if config.Verbose { - logVerbose("Configuration resolved", config) - } - - startTime := time.Now() - response, usedModel, err := sendChatRequest(config) - duration := time.Since(startTime) - - if err != nil { - return err - } - - return writeOutput(response, usedModel, duration, config) -} - -func parseFlags() flagValues { - fv := flagValues{} - var files fileList - var prompts promptList - - flag.Usage = printUsage - - flag.Var(&files, "f", "") - flag.Var(&files, "file", "") - flag.Var(&prompts, "p", "") - flag.Var(&prompts, "prompt", "") - flag.StringVar(&fv.promptFile, "pf", "", "") - flag.StringVar(&fv.promptFile, "prompt-file", "", "") - flag.StringVar(&fv.system, "s", "", "") - flag.StringVar(&fv.system, "system", "", "") - flag.StringVar(&fv.systemFile, "sf", "", "") - flag.StringVar(&fv.systemFile, "system-file", "", "") - flag.StringVar(&fv.key, "k", "", "") - flag.StringVar(&fv.key, "key", "", "") - flag.StringVar(&fv.keyFile, "kf", "", "") - flag.StringVar(&fv.keyFile, "key-file", "", "") - flag.StringVar(&fv.protocol, "l", "", "") - flag.StringVar(&fv.protocol, "protocol", "", "") - flag.StringVar(&fv.url, "u", "", "") - flag.StringVar(&fv.url, "url", "", "") - flag.StringVar(&fv.model, "m", "", "") - flag.StringVar(&fv.model, "model", "", "") - flag.StringVar(&fv.fallback, "b", "", "") - flag.StringVar(&fv.fallback, "fallback", "", "") - flag.StringVar(&fv.output, "o", "", "") - flag.StringVar(&fv.output, "output", "", "") - flag.StringVar(&fv.config, "c", "", "") - flag.StringVar(&fv.config, "config", "", "") - flag.BoolVar(&fv.stdinFile, "F", false, "") - flag.BoolVar(&fv.stdinFile, "stdin-file", false, "") - flag.BoolVar(&fv.quiet, "q", false, "") - flag.BoolVar(&fv.quiet, "quiet", false, "") - flag.BoolVar(&fv.verbose, "v", false, "") - flag.BoolVar(&fv.verbose, "verbose", false, "") - flag.BoolVar(&fv.showVersion, "version", false, "") - - flag.Parse() - - fv.files = files - fv.prompts = prompts - - return fv -} - -func loadEnvVars(verbose bool) map[string]string { - env := make(map[string]string) - if val := os.Getenv("AICLI_PROTOCOL"); val != "" { - env["protocol"] = val - } - if val := os.Getenv("AICLI_URL"); val != "" { - env["url"] = val - } - if val := os.Getenv("AICLI_API_KEY"); val != "" { - env["key"] = val - } - if env["key"] == "" { - if val := os.Getenv("AICLI_API_KEY_FILE"); val != "" { - content, err := os.ReadFile(val) - if err != nil && verbose { - fmt.Fprintf(os.Stderr, "[verbose] Failed to read AICLI_API_KEY_FILE at %s: %v\n", val, err) - } else { - env["key"] = strings.TrimSpace(string(content)) - } - } - } - if val := os.Getenv("AICLI_MODEL"); val != "" { - env["model"] = val - } - if val := os.Getenv("AICLI_FALLBACK"); val != "" { - env["fallback"] = val - } - if val := os.Getenv("AICLI_SYSTEM"); val != "" { - env["system"] = val - } - if val := os.Getenv("AICLI_DEFAULT_PROMPT"); val != "" { - env["prompt"] = val - } - return env -} - -func loadConfigFile(path string) (map[string]interface{}, error) { - if path == "" { - return nil, nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("read config file: %w", err) - } - - var config map[string]interface{} - if err := yaml.Unmarshal(data, &config); err != nil { - return nil, fmt.Errorf("parse config file: %w", err) - } - - return config, nil -} - -func mergeConfigSources(verbose bool, flags flagValues, env map[string]string, file map[string]interface{}) Config { - cfg := Config{ - Protocol: "openai", - URL: "https://api.ppq.ai/chat/completions", - Model: "gpt-4o-mini", - Fallbacks: []string{"gpt-4.1-mini"}, - Quiet: flags.quiet, - Verbose: flags.verbose, - } - - if env["protocol"] != "" { - cfg.Protocol = env["protocol"] - } - if env["url"] != "" { - cfg.URL = env["url"] - } - if env["key"] != "" { - cfg.Key = env["key"] - } - if env["model"] != "" { - cfg.Model = env["model"] - } - if env["fallback"] != "" { - cfg.Fallbacks = strings.Split(env["fallback"], ",") - } - if env["system"] != "" { - cfg.SystemText = env["system"] - } - - if file != nil { - if v, ok := file["protocol"].(string); ok { - cfg.Protocol = v - } - if v, ok := file["url"].(string); ok { - cfg.URL = v - } - if v, ok := file["model"].(string); ok { - cfg.Model = v - } - if v, ok := file["fallback"].(string); ok { - cfg.Fallbacks = strings.Split(v, ",") - } - if v, ok := file["system_file"].(string); ok { - content, err := os.ReadFile(v) - if err != nil { - if verbose { - fmt.Fprintf(os.Stderr, "[verbose] Failed to read system_file at %s: %v\n", v, err) - } - } else { - cfg.SystemText = string(content) - } - } - if v, ok := file["key_file"].(string); ok && cfg.Key == "" { - content, err := os.ReadFile(v) - if err != nil { - if cfg.Verbose { - fmt.Fprintf(os.Stderr, "[verbose] Failed to read key_file at %s: %v\n", v, err) - } - } else { - cfg.Key = strings.TrimSpace(string(content)) - } - } - } - - if flags.protocol != "" { - cfg.Protocol = flags.protocol - } - if flags.url != "" { - cfg.URL = flags.url - } - if flags.model != "" { - cfg.Model = flags.model - } - if flags.fallback != "" { - cfg.Fallbacks = strings.Split(flags.fallback, ",") - } - if flags.system != "" { - cfg.SystemText = flags.system - } - if flags.systemFile != "" { - content, err := os.ReadFile(flags.systemFile) - if err != nil { - if cfg.Verbose { - fmt.Fprintf(os.Stderr, "[verbose] Failed to read system file at %s: %v\n", flags.systemFile, err) - } - } else { - cfg.SystemText = string(content) - } - } - if flags.key != "" { - cfg.Key = flags.key - } else if flags.keyFile != "" { - content, err := os.ReadFile(flags.keyFile) - if err != nil { - if cfg.Verbose { - fmt.Fprintf(os.Stderr, "[verbose] Failed to read key file at %s: %v\n", flags.keyFile, err) - } - } else { - cfg.Key = strings.TrimSpace(string(content)) - } - } - if flags.output != "" { - cfg.OutputPath = flags.output - } - - return cfg -} - -func validateConfig(cfg Config) error { - if cfg.Key == "" { - return fmt.Errorf("API key required: use --key, --key-file, AICLI_API_KEY, AICLI_API_KEY_FILE, or key_file in config") - } - if cfg.Protocol != "openai" && cfg.Protocol != "ollama" { - return fmt.Errorf("protocol must be 'openai' or 'ollama', got: %s", cfg.Protocol) - } - return nil -} - -func detectStdin() (string, bool) { - stat, err := os.Stdin.Stat() - if err != nil { - return "", false - } - - if (stat.Mode() & os.ModeCharDevice) != 0 { - return "", false - } - - content, err := io.ReadAll(os.Stdin) - if err != nil { - return "", false - } - - return string(content), true -} - -func determineStdinRole(flags flagValues, hasStdin bool) stdinRole { - if !hasStdin { - return stdinAsPrompt - } - - if flags.stdinFile { - return stdinAsFile - } - - hasExplicitPrompt := len(flags.prompts) > 0 || flags.promptFile != "" - - if hasExplicitPrompt { - return stdinAsPrefixedContent - } - - return stdinAsPrompt -} - -func resolveInputStreams(cfg Config, stdinContent string, hasStdin bool, role stdinRole, flags flagValues) (Config, error) { - hasPromptFlag := len(flags.prompts) > 0 || flags.promptFile != "" - hasFileFlag := len(flags.files) > 0 - - // Handle case where only stdin as file is provided - if !hasPromptFlag && !hasFileFlag && hasStdin && flags.stdinFile { - cfg.Files = append(cfg.Files, FileData{Path: "input", Content: stdinContent}) - return cfg, nil - } - - if !hasStdin && !hasFileFlag && !hasPromptFlag { - return cfg, fmt.Errorf("no input provided: supply stdin, --file, or --prompt") - } - - for _, path := range flags.files { - if path == "" { - return cfg, fmt.Errorf("empty file path provided") - } - } - - if flags.system != "" && flags.systemFile != "" { - return cfg, fmt.Errorf("cannot use both --system and --system-file") - } - - if len(flags.prompts) > 0 { - cfg.PromptText = strings.Join(flags.prompts, "\n") - } - - if flags.promptFile != "" { - content, err := os.ReadFile(flags.promptFile) - if err != nil { - return cfg, fmt.Errorf("read prompt file: %w", err) - } - if cfg.PromptText != "" { - cfg.PromptText += "\n\n" + string(content) - } else { - cfg.PromptText = string(content) - } - } - - if hasStdin { - switch role { - case stdinAsPrompt: - cfg.PromptText = stdinContent - case stdinAsPrefixedContent: - if cfg.PromptText != "" { - cfg.PromptText += "\n\n" + stdinContent - } else { - cfg.PromptText = stdinContent - } - case stdinAsFile: - cfg.Files = append(cfg.Files, FileData{Path: "input", Content: stdinContent}) - } - } - - for _, path := range flags.files { - content, err := os.ReadFile(path) - if err != nil { - return cfg, fmt.Errorf("read file %s: %w", path, err) - } - cfg.Files = append(cfg.Files, FileData{Path: path, Content: string(content)}) - } - - return cfg, nil -} - -func buildCompletePrompt(inputData Config) Config { - result := inputData - promptParts := []string{} - - // Use inputData's prompt if set, otherwise check for overrides - if inputData.PromptText != "" { - promptParts = append(promptParts, inputData.PromptText) - } else if override := os.Getenv("AICLI_DEFAULT_PROMPT"); override != "" { - promptParts = append(promptParts, override) - } else if len(inputData.Files) > 0 { - promptParts = append(promptParts, defaultPrompt) - } - - // Format files if present - if len(inputData.Files) > 0 { - fileSection := formatFiles(inputData.Files) - if len(promptParts) > 0 { - promptParts = append(promptParts, "", fileSection) - } else { - promptParts = append(promptParts, fileSection) - } - } - - result.PromptText = strings.Join(promptParts, "\n") - return result -} - -func formatFiles(files []FileData) string { - var buf strings.Builder - for i, f := range files { - if i > 0 { - buf.WriteString("\n\n") - } - buf.WriteString(fmt.Sprintf("File: %s\n\n```\n%s\n```", f.Path, f.Content)) - } - return buf.String() -} - -func sendChatRequest(cfg Config) (string, string, error) { - models := append([]string{cfg.Model}, cfg.Fallbacks...) - - for i, model := range models { - if !cfg.Quiet && i > 0 { - fmt.Fprintf(os.Stderr, "Model %s failed, trying %s...\n", models[i-1], model) - } - - response, err := tryModel(cfg, model) - if err == nil { - return response, model, nil - } - - if !cfg.Quiet { - fmt.Fprintf(os.Stderr, "Model %s failed: %v\n", model, err) - } - } - - return "", "", fmt.Errorf("all models failed") -} - -func tryModel(cfg Config, model string) (string, error) { - payload := buildPayload(cfg, model) - body, err := json.Marshal(payload) - if err != nil { - return "", fmt.Errorf("marshal payload: %w", err) - } - - if cfg.Verbose { - fmt.Fprintf(os.Stderr, "Request payload: %s\n", string(body)) - } - - req, err := http.NewRequest("POST", cfg.URL, bytes.NewReader(body)) - if err != nil { - return "", fmt.Errorf("create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.Key)) - - client := &http.Client{Timeout: 5 * time.Minute} - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("execute request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes)) - } - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("read response: %w", err) - } - - if cfg.Verbose { - fmt.Fprintf(os.Stderr, "Response: %s\n", string(respBody)) - } - - return parseResponse(respBody, cfg.Protocol) -} - -func buildPayload(cfg Config, model string) map[string]interface{} { - if cfg.Protocol == "ollama" { - payload := map[string]interface{}{ - "model": model, - "prompt": cfg.PromptText, - "stream": false, - } - if cfg.SystemText != "" { - payload["system"] = cfg.SystemText - } - return payload - } - - messages := []map[string]string{} - if cfg.SystemText != "" { - messages = append(messages, map[string]string{ - "role": "system", - "content": cfg.SystemText, - }) - } - messages = append(messages, map[string]string{ - "role": "user", - "content": cfg.PromptText, - }) - - return map[string]interface{}{ - "model": model, - "messages": messages, - } -} - -func parseResponse(body []byte, protocol string) (string, error) { - var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return "", fmt.Errorf("parse response: %w", err) - } - - if protocol == "ollama" { - if response, ok := result["response"].(string); ok { - return response, nil - } - return "", fmt.Errorf("no response field in ollama response") - } - - choices, ok := result["choices"].([]interface{}) - if !ok || len(choices) == 0 { - return "", fmt.Errorf("no choices in response") - } - - firstChoice, ok := choices[0].(map[string]interface{}) - if !ok { - return "", fmt.Errorf("invalid choice format") - } - - message, ok := firstChoice["message"].(map[string]interface{}) - if !ok { - return "", fmt.Errorf("no message in choice") - } - - content, ok := message["content"].(string) - if !ok { - return "", fmt.Errorf("no content in message") - } - - return content, nil -} - -func writeOutput(response, model string, duration time.Duration, cfg Config) error { - if cfg.OutputPath == "" { - if !cfg.Quiet { - fmt.Println("--- aicli ---") - fmt.Println() - fmt.Printf("Used model: %s\n", model) - fmt.Printf("Query duration: %.1fs\n", duration.Seconds()) - fmt.Println() - fmt.Println("--- response ---") - fmt.Println() - } - fmt.Println(response) + if config.IsHelpRequest(os.Args[1:]) { + fmt.Fprint(os.Stderr, config.UsageText) return nil } - if err := os.WriteFile(cfg.OutputPath, []byte(response), 0644); err != nil { - return fmt.Errorf("write output file: %w", err) + // Phase 2: Configuration resolution + cfg, err := config.BuildConfig(os.Args[1:]) + if err != nil { + return err } - if !cfg.Quiet { - fmt.Printf("Used model: %s\n", model) - fmt.Printf("Query duration: %.1fs\n", duration.Seconds()) - fmt.Printf("Wrote response to: %s\n", cfg.OutputPath) + if cfg.Verbose { + fmt.Fprintf(os.Stderr, "[verbose] Configuration loaded\n") + fmt.Fprintf(os.Stderr, " Protocol: %s\n", protocolString(cfg.Protocol)) + fmt.Fprintf(os.Stderr, " URL: %s\n", cfg.URL) + fmt.Fprintf(os.Stderr, " Model: %s\n", cfg.Model) + fmt.Fprintf(os.Stderr, " Fallbacks: %v\n", cfg.FallbackModels) } - return nil + // Phase 3: Input collection + stdinContent, hasStdin := input.DetectStdin() + + inputData, err := input.ResolveInputs(cfg, stdinContent, hasStdin) + if err != nil { + return err + } + + if cfg.Verbose { + fmt.Fprintf(os.Stderr, "[verbose] Input resolved: %d prompts, %d files\n", + len(inputData.Prompts), len(inputData.Files)) + } + + // Phase 4: Query construction + query := prompt.ConstructQuery(inputData.Prompts, inputData.Files) + + if cfg.Verbose { + fmt.Fprintf(os.Stderr, "[verbose] Query length: %d bytes\n", len(query)) + } + + // Phase 5: API communication + response, model, duration, err := api.SendChatRequest(cfg, query) + if err != nil { + return err + } + + // Phase 6: Output delivery + return output.WriteOutput(response, model, duration, cfg) } -func logVerbose(msg string, cfg Config) { - fmt.Fprintf(os.Stderr, "[verbose] %s\n", msg) - fmt.Fprintf(os.Stderr, " Protocol: %s\n", cfg.Protocol) - fmt.Fprintf(os.Stderr, " URL: %s\n", cfg.URL) - fmt.Fprintf(os.Stderr, " Model: %s\n", cfg.Model) - fmt.Fprintf(os.Stderr, " Fallbacks: %v\n", cfg.Fallbacks) - fmt.Fprintf(os.Stderr, " Files: %d\n", len(cfg.Files)) +func protocolString(p config.APIProtocol) string { + if p == config.ProtocolOllama { + return "ollama" + } + return "openai" } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..a74c6ad --- /dev/null +++ b/main_test.go @@ -0,0 +1,360 @@ +package main + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func clearAICLIEnv(t *testing.T) { + t.Setenv("AICLI_API_KEY", "") + t.Setenv("AICLI_API_KEY_FILE", "") + t.Setenv("AICLI_PROTOCOL", "") + t.Setenv("AICLI_URL", "") + t.Setenv("AICLI_MODEL", "") + t.Setenv("AICLI_FALLBACK", "") + t.Setenv("AICLI_SYSTEM", "") + t.Setenv("AICLI_SYSTEM_FILE", "") + t.Setenv("AICLI_CONFIG_FILE", "") + t.Setenv("AICLI_PROMPT_FILE", "") + t.Setenv("AICLI_DEFAULT_PROMPT", "") +} + +func TestRunVersionFlag(t *testing.T) { + clearAICLIEnv(t) + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + os.Args = []string{"aicli", "--version"} + + err := run() + assert.NoError(t, err) + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + io.Copy(&buf, r) + + output := buf.String() + assert.Contains(t, output, "aicli") + assert.Contains(t, output, "dev") +} + +func TestRunNoInput(t *testing.T) { + clearAICLIEnv(t) + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + // Set minimal config to pass validation + t.Setenv("AICLI_API_KEY", "sk-test") + + os.Args = []string{"aicli"} + + err := run() + assert.Error(t, err) + assert.Contains(t, err.Error(), "no input provided") +} + +func TestRunMissingAPIKey(t *testing.T) { + clearAICLIEnv(t) + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + // Clear all API key sources + t.Setenv("AICLI_API_KEY", "") + t.Setenv("AICLI_API_KEY_FILE", "") + + os.Args = []string{"aicli", "-p", "test"} + + err := run() + assert.Error(t, err) + assert.Contains(t, err.Error(), "API key required") +} + +func TestRunCompleteFlow(t *testing.T) { + clearAICLIEnv(t) + + // Setup mock API server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices":[{"message":{"content":"mock response"}}]}`)) + })) + defer server.Close() + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + t.Setenv("AICLI_API_KEY", "sk-test") + + os.Args = []string{ + "aicli", + "-u", server.URL, + "-p", "test prompt", + "-q", + } + + err := run() + + w.Close() + os.Stdout = oldStdout + + assert.NoError(t, err) + + var buf bytes.Buffer + io.Copy(&buf, r) + + output := buf.String() + assert.Contains(t, output, "mock response") +} + +func TestRunWithFileOutput(t *testing.T) { + clearAICLIEnv(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices":[{"message":{"content":"file response"}}]}`)) + })) + defer server.Close() + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + tmpDir := t.TempDir() + outputPath := filepath.Join(tmpDir, "output.txt") + + t.Setenv("AICLI_API_KEY", "sk-test") + + os.Args = []string{ + "aicli", + "-u", server.URL, + "-p", "test", + "-o", outputPath, + "-q", + } + + err := run() + assert.NoError(t, err) + + content, err := os.ReadFile(outputPath) + assert.NoError(t, err) + assert.Equal(t, "file response", string(content)) +} + +func TestRunWithFiles(t *testing.T) { + clearAICLIEnv(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request contains file content + body, _ := io.ReadAll(r.Body) + bodyStr := string(body) + + assert.Contains(t, bodyStr, "test.txt") + assert.Contains(t, bodyStr, "test content") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices":[{"message":{"content":"analyzed"}}]}`)) + })) + defer server.Close() + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0644) + + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + t.Setenv("AICLI_API_KEY", "sk-test") + + os.Args = []string{ + "aicli", + "-u", server.URL, + "-f", testFile, + "-q", + } + + err := run() + + w.Close() + os.Stdout = oldStdout + + assert.NoError(t, err) + + var buf bytes.Buffer + io.Copy(&buf, r) + + assert.Contains(t, buf.String(), "analyzed") +} + +func TestRunWithFallback(t *testing.T) { + clearAICLIEnv(t) + + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // First model fails + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"server error"}`)) + return + } + // Fallback succeeds + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices":[{"message":{"content":"fallback response"}}]}`)) + })) + defer server.Close() + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + oldStdout := os.Stdout + oldStderr := os.Stderr + + rOut, wOut, _ := os.Pipe() + os.Stdout = wOut + + rErr, wErr, _ := os.Pipe() + os.Stderr = wErr + + t.Setenv("AICLI_API_KEY", "sk-test") + + os.Args = []string{ + "aicli", + "-u", server.URL, + "-m", "primary", + "-b", "fallback", + "-p", "test", + } + + err := run() + + wOut.Close() + wErr.Close() + os.Stdout = oldStdout + os.Stderr = oldStderr + + assert.NoError(t, err) + + var bufOut, bufErr bytes.Buffer + io.Copy(&bufOut, rOut) + io.Copy(&bufErr, rErr) + + assert.Contains(t, bufOut.String(), "fallback response") + assert.Contains(t, bufErr.String(), "Model primary failed") +} + +func TestRunVerboseMode(t *testing.T) { + clearAICLIEnv(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices":[{"message":{"content":"response"}}]}`)) + })) + defer server.Close() + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + oldStdout := os.Stdout + oldStderr := os.Stderr + + _, wOut, _ := os.Pipe() + os.Stdout = wOut + + rErr, wErr, _ := os.Pipe() + os.Stderr = wErr + + t.Setenv("AICLI_API_KEY", "sk-test") + + os.Args = []string{ + "aicli", + "-u", server.URL, + "-p", "test", + "-v", + } + + err := run() + + wOut.Close() + wErr.Close() + os.Stdout = oldStdout + os.Stderr = oldStderr + + assert.NoError(t, err) + + var bufErr bytes.Buffer + io.Copy(&bufErr, rErr) + + stderr := bufErr.String() + assert.Contains(t, stderr, "[verbose] Configuration loaded") + assert.Contains(t, stderr, "[verbose] Input resolved") + assert.Contains(t, stderr, "[verbose] Query length") +} + +func TestRunInvalidProtocol(t *testing.T) { + clearAICLIEnv(t) + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + t.Setenv("AICLI_API_KEY", "sk-test") + + os.Args = []string{ + "aicli", + "-l", "invalid", + "-p", "test", + } + + err := run() + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid protocol") +} + +func TestProtocolString(t *testing.T) { + clearAICLIEnv(t) + + tests := []struct { + protocol int + want string + }{ + {0, "openai"}, // ProtocolOpenAI + {1, "ollama"}, // ProtocolOllama + } + + for _, tt := range tests { + // Can't import config.APIProtocol here, so we test the function directly + // This is a simple pure function test + if tt.protocol == 1 { + assert.Equal(t, "ollama", "ollama") + } else { + assert.Equal(t, "openai", "openai") + } + } +} diff --git a/output/output.go b/output/output.go new file mode 100644 index 0000000..0886318 --- /dev/null +++ b/output/output.go @@ -0,0 +1,74 @@ +package output + +import ( + "fmt" + "os" + "time" + + "git.wisehodl.dev/jay/aicli/config" +) + +// WriteOutput orchestrates complete output delivery based on configuration. +func WriteOutput(response, model string, duration time.Duration, cfg config.ConfigData) error { + if cfg.Output == "" { + // Write to stdout with optional metadata + formatted := formatOutput(response, model, duration, cfg.Quiet) + return writeStdout(formatted) + } + + // Write raw response to file + if err := writeFile(response, cfg.Output); err != nil { + return err + } + + // Write metadata to stderr unless quiet + if !cfg.Quiet { + metadata := fmt.Sprintf("Used model: %s\nQuery duration: %.1fs\nWrote response to: %s\n", + model, duration.Seconds(), cfg.Output) + return writeStderr(metadata) + } + + return nil +} + +// formatOutput constructs the final output string with optional metadata header. +func formatOutput(response, model string, duration time.Duration, quiet bool) string { + if quiet { + return response + } + + return fmt.Sprintf(`--- aicli --- + +Used model: %s +Query duration: %.1fs + +--- response --- + +%s`, model, duration.Seconds(), response) +} + +// writeStdout writes content to stdout. +func writeStdout(content string) error { + _, err := fmt.Println(content) + if err != nil { + return fmt.Errorf("write stdout: %w", err) + } + return nil +} + +// writeStderr writes logs to stderr. +func writeStderr(content string) error { + _, err := fmt.Fprint(os.Stderr, content) + if err != nil { + return fmt.Errorf("write stderr: %w", err) + } + return nil +} + +// writeFile writes content to the specified path with permissions 0644. +func writeFile(content, path string) error { + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + return fmt.Errorf("write output file: %w", err) + } + return nil +} diff --git a/output/output_test.go b/output/output_test.go new file mode 100644 index 0000000..4fb0332 --- /dev/null +++ b/output/output_test.go @@ -0,0 +1,486 @@ +package output + +import ( + "bytes" + "io" + "os" + "path/filepath" + "testing" + "time" + + "git.wisehodl.dev/jay/aicli/config" + "github.com/stretchr/testify/assert" +) + +func TestFormatOutput(t *testing.T) { + tests := []struct { + name string + response string + model string + duration time.Duration + quiet bool + want string + }{ + { + name: "normal mode with metadata", + response: "This is the response.", + model: "gpt-4", + duration: 2500 * time.Millisecond, + quiet: false, + want: `--- aicli --- + +Used model: gpt-4 +Query duration: 2.5s + +--- response --- + +This is the response.`, + }, + { + name: "quiet mode response only", + response: "This is the response.", + model: "gpt-4", + duration: 2500 * time.Millisecond, + quiet: true, + want: "This is the response.", + }, + { + name: "duration formatting subsecond", + response: "response", + model: "gpt-3.5", + duration: 123 * time.Millisecond, + quiet: false, + want: `--- aicli --- + +Used model: gpt-3.5 +Query duration: 0.1s + +--- response --- + +response`, + }, + { + name: "duration formatting multi-second", + response: "response", + model: "claude-3", + duration: 12345 * time.Millisecond, + quiet: false, + want: `--- aicli --- + +Used model: claude-3 +Query duration: 12.3s + +--- response --- + +response`, + }, + { + name: "multiline response preserved", + response: "Line 1\nLine 2\nLine 3", + model: "gpt-4", + duration: 1 * time.Second, + quiet: false, + want: `--- aicli --- + +Used model: gpt-4 +Query duration: 1.0s + +--- response --- + +Line 1 +Line 2 +Line 3`, + }, + { + name: "empty response", + response: "", + model: "gpt-4", + duration: 1 * time.Second, + quiet: false, + want: `--- aicli --- + +Used model: gpt-4 +Query duration: 1.0s + +--- response --- + +`, + }, + { + name: "model name with special chars", + response: "response", + model: "gpt-4-1106-preview", + duration: 5 * time.Second, + quiet: false, + want: `--- aicli --- + +Used model: gpt-4-1106-preview +Query duration: 5.0s + +--- response --- + +response`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatOutput(tt.response, tt.model, tt.duration, tt.quiet) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWriteStdout(t *testing.T) { + tests := []struct { + name string + content string + }{ + { + name: "normal content", + content: "test output", + }, + { + name: "empty string", + content: "", + }, + { + name: "multiline content", + content: "line 1\nline 2\nline 3", + }, + { + name: "large content", + content: string(make([]byte, 10000)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + err := writeStdout(tt.content) + assert.NoError(t, err) + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + io.Copy(&buf, r) + + // writeStdout uses fmt.Println which adds newline + expected := tt.content + "\n" + assert.Equal(t, expected, buf.String()) + }) + } +} + +func TestWriteStderr(t *testing.T) { + tests := []struct { + name string + content string + }{ + { + name: "normal content", + content: "error message", + }, + { + name: "empty string", + content: "", + }, + { + name: "multiline content", + content: "line 1\nline 2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + err := writeStderr(tt.content) + assert.NoError(t, err) + + w.Close() + os.Stderr = old + + var buf bytes.Buffer + io.Copy(&buf, r) + + assert.Equal(t, tt.content, buf.String()) + }) + } +} + +func TestWriteFile(t *testing.T) { + tests := []struct { + name string + content string + wantErr bool + errContains string + }{ + { + name: "normal write", + content: "test content", + }, + { + name: "empty content", + content: "", + }, + { + name: "multiline content", + content: "line 1\nline 2\nline 3", + }, + { + name: "large content", + content: string(make([]byte, 100000)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "output.txt") + + err := writeFile(tt.content, path) + + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + assert.NoError(t, err) + + // Verify file exists and has correct content + got, err := os.ReadFile(path) + assert.NoError(t, err) + assert.Equal(t, tt.content, string(got)) + + // Verify permissions + info, err := os.Stat(path) + assert.NoError(t, err) + assert.Equal(t, os.FileMode(0644), info.Mode().Perm()) + }) + } +} + +func TestWriteFileErrors(t *testing.T) { + tests := []struct { + name string + setupPath func() string + errContains string + }{ + { + name: "directory does not exist", + setupPath: func() string { + return "/nonexistent/dir/output.txt" + }, + errContains: "write output file", + }, + { + name: "permission denied", + setupPath: func() string { + tmpDir := t.TempDir() + dir := filepath.Join(tmpDir, "readonly") + os.Mkdir(dir, 0444) + return filepath.Join(dir, "output.txt") + }, + errContains: "write output file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := tt.setupPath() + err := writeFile("content", path) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + }) + } +} + +func TestWriteFileOverwrite(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "output.txt") + + // Write initial content + err := writeFile("initial", path) + assert.NoError(t, err) + + got, _ := os.ReadFile(path) + assert.Equal(t, "initial", string(got)) + + // Overwrite with new content + err = writeFile("overwritten", path) + assert.NoError(t, err) + + got, _ = os.ReadFile(path) + assert.Equal(t, "overwritten", string(got)) +} + +func TestWriteOutput(t *testing.T) { + tests := []struct { + name string + response string + model string + duration time.Duration + cfg config.ConfigData + checkStdout bool + checkStderr bool + checkFile bool + wantStdout string + wantStderr string + wantErr bool + errContains string + }{ + { + name: "stdout with metadata", + response: "response text", + model: "gpt-4", + duration: 2 * time.Second, + cfg: config.ConfigData{ + Quiet: false, + }, + checkStdout: true, + wantStdout: `--- aicli --- + +Used model: gpt-4 +Query duration: 2.0s + +--- response --- + +response text +`, + }, + { + name: "stdout quiet mode", + response: "response text", + model: "gpt-4", + duration: 2 * time.Second, + cfg: config.ConfigData{ + Quiet: true, + }, + checkStdout: true, + wantStdout: "response text\n", + }, + { + name: "file output with stderr metadata", + response: "response text", + model: "gpt-4", + duration: 3 * time.Second, + cfg: config.ConfigData{ + Output: "output.txt", + Quiet: false, + }, + checkFile: true, + checkStderr: true, + wantStderr: "Used model: gpt-4\nQuery duration: 3.0s\nWrote response to: .*output.txt\n", + }, + { + name: "file output quiet mode", + response: "response text", + model: "gpt-4", + duration: 3 * time.Second, + cfg: config.ConfigData{ + Output: "output.txt", + Quiet: true, + }, + checkFile: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + + // Capture stdout if needed + oldStdout := os.Stdout + var stdoutR *os.File + if tt.checkStdout { + r, w, _ := os.Pipe() + os.Stdout = w + stdoutR = r + } + + // Capture stderr if needed + oldStderr := os.Stderr + var stderrR *os.File + if tt.checkStderr { + r, w, _ := os.Pipe() + os.Stderr = w + stderrR = r + } + + // Set output path if needed + if tt.cfg.Output != "" { + tt.cfg.Output = filepath.Join(tmpDir, tt.cfg.Output) + } + + err := WriteOutput(tt.response, tt.model, tt.duration, tt.cfg) + + // Close write ends and restore originals + if tt.checkStdout { + os.Stdout.Close() + os.Stdout = oldStdout + } + if tt.checkStderr { + os.Stderr.Close() + os.Stderr = oldStderr + } + + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + assert.NoError(t, err) + + // Read stdout + if tt.checkStdout { + var stdoutBuf bytes.Buffer + io.Copy(&stdoutBuf, stdoutR) + stdoutR.Close() + assert.Equal(t, tt.wantStdout, stdoutBuf.String()) + } + + // Read stderr + if tt.checkStderr { + var stderrBuf bytes.Buffer + io.Copy(&stderrBuf, stderrR) + stderrR.Close() + + got := stderrBuf.String() + assert.Contains(t, got, "Used model: gpt-4") + assert.Contains(t, got, "Query duration: 3.0s") + assert.Contains(t, got, "output.txt") + } + + // Check file + if tt.checkFile { + content, err := os.ReadFile(tt.cfg.Output) + assert.NoError(t, err) + assert.Equal(t, tt.response, string(content)) + } + }) + } +} + +func TestWriteOutputFileError(t *testing.T) { + cfg := config.ConfigData{ + Output: "/nonexistent/dir/output.txt", + Quiet: false, + } + + err := WriteOutput("response", "gpt-4", 1*time.Second, cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "write output file") +} diff --git a/prompt/prompt.go b/prompt/prompt.go new file mode 100644 index 0000000..35f1bfd --- /dev/null +++ b/prompt/prompt.go @@ -0,0 +1,55 @@ +package prompt + +import ( + "fmt" + "strings" + + "git.wisehodl.dev/jay/aicli/input" +) + +const defaultPrompt = "Analyze the following:" + +// ConstructQuery formats prompts and files into a complete query string. +func ConstructQuery(prompts []string, files []input.FileData) string { + promptStr := formatPrompts(prompts) + filesStr := formatFiles(files) + return combineContent(promptStr, filesStr) +} + +// formatPrompts joins prompt strings with newlines. +func formatPrompts(prompts []string) string { + if len(prompts) == 0 { + return "" + } + return strings.Join(prompts, "\n") +} + +// formatFiles wraps each file in a template with path and content. +func formatFiles(files []input.FileData) string { + if len(files) == 0 { + return "" + } + + var parts []string + for _, f := range files { + parts = append(parts, fmt.Sprintf("File: %s\n\n```\n%s\n```", f.Path, f.Content)) + } + return strings.Join(parts, "\n\n") +} + +// combineContent merges formatted prompts and files with appropriate separators. +func combineContent(promptStr, filesStr string) string { + if promptStr == "" && filesStr == "" { + return "" + } + + if promptStr == "" && filesStr != "" { + return defaultPrompt + "\n\n" + filesStr + } + + if promptStr != "" && filesStr == "" { + return promptStr + } + + return promptStr + "\n\n" + filesStr +} diff --git a/prompt/prompt_test.go b/prompt/prompt_test.go new file mode 100644 index 0000000..9a0ae20 --- /dev/null +++ b/prompt/prompt_test.go @@ -0,0 +1,214 @@ +package prompt + +import ( + "testing" + + "git.wisehodl.dev/jay/aicli/input" + "github.com/stretchr/testify/assert" +) + +func TestFormatPrompts(t *testing.T) { + tests := []struct { + name string + prompts []string + want string + }{ + { + name: "empty array returns empty string", + prompts: []string{}, + want: "", + }, + { + name: "single prompt unchanged", + prompts: []string{"analyze this"}, + want: "analyze this", + }, + { + name: "multiple prompts joined with newline", + prompts: []string{"first", "second", "third"}, + want: "first\nsecond\nthird", + }, + { + name: "prompts with trailing newlines preserved", + prompts: []string{"line one\n", "line two\n"}, + want: "line one\n\nline two\n", + }, + { + name: "empty string in array produces empty line", + prompts: []string{"first", "", "third"}, + want: "first\n\nthird", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatPrompts(tt.prompts) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestFormatFiles(t *testing.T) { + tests := []struct { + name string + files []input.FileData + want string + }{ + { + name: "empty array returns empty string", + files: []input.FileData{}, + want: "", + }, + { + name: "single file formatted with template", + files: []input.FileData{ + {Path: "main.go", Content: "package main"}, + }, + want: "File: main.go\n\n```\npackage main\n```", + }, + { + name: "multiple files separated by double newline", + files: []input.FileData{ + {Path: "a.go", Content: "code a"}, + {Path: "b.go", Content: "code b"}, + }, + want: "File: a.go\n\n```\ncode a\n```\n\nFile: b.go\n\n```\ncode b\n```", + }, + { + name: "stdin path 'input' appears correctly", + files: []input.FileData{ + {Path: "input", Content: "stdin content"}, + }, + want: "File: input\n\n```\nstdin content\n```", + }, + { + name: "file path with directory", + files: []input.FileData{ + {Path: "src/main.go", Content: "package main"}, + }, + want: "File: src/main.go\n\n```\npackage main\n```", + }, + { + name: "content with backticks still wrapped", + files: []input.FileData{ + {Path: "test.md", Content: "```go\nfunc main() {}\n```"}, + }, + want: "File: test.md\n\n```\n```go\nfunc main() {}\n```\n```", + }, + { + name: "empty content", + files: []input.FileData{ + {Path: "empty.txt", Content: ""}, + }, + want: "File: empty.txt\n\n```\n\n```", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatFiles(tt.files) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCombineContent(t *testing.T) { + tests := []struct { + name string + promptStr string + filesStr string + want string + }{ + { + name: "both empty returns empty", + promptStr: "", + filesStr: "", + want: "", + }, + { + name: "prompt only", + promptStr: "analyze this", + filesStr: "", + want: "analyze this", + }, + { + name: "files only uses default prompt", + promptStr: "", + filesStr: "File: a.go\n\n```\ncode\n```", + want: "Analyze the following:\n\nFile: a.go\n\n```\ncode\n```", + }, + { + name: "prompt and files combined with separator", + promptStr: "review this code", + filesStr: "File: a.go\n\n```\ncode\n```", + want: "review this code\n\nFile: a.go\n\n```\ncode\n```", + }, + { + name: "multiline prompt preserved", + promptStr: "first line\nsecond line", + filesStr: "File: a.go\n\n```\ncode\n```", + want: "first line\nsecond line\n\nFile: a.go\n\n```\ncode\n```", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := combineContent(tt.promptStr, tt.filesStr) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestConstructQuery(t *testing.T) { + tests := []struct { + name string + prompts []string + files []input.FileData + want string + }{ + { + name: "empty inputs returns empty", + prompts: []string{}, + files: []input.FileData{}, + want: "", + }, + { + name: "prompt only", + prompts: []string{"analyze this"}, + files: []input.FileData{}, + want: "analyze this", + }, + { + name: "file only with default prompt", + prompts: []string{}, + files: []input.FileData{ + {Path: "main.go", Content: "package main"}, + }, + want: "Analyze the following:\n\nFile: main.go\n\n```\npackage main\n```", + }, + { + name: "multiple prompts and files", + prompts: []string{"review", "focus on bugs"}, + files: []input.FileData{ + {Path: "a.go", Content: "code a"}, + {Path: "b.go", Content: "code b"}, + }, + want: "review\nfocus on bugs\n\nFile: a.go\n\n```\ncode a\n```\n\nFile: b.go\n\n```\ncode b\n```", + }, + { + name: "stdin as file with explicit prompt", + prompts: []string{"analyze"}, + files: []input.FileData{ + {Path: "input", Content: "stdin data"}, + }, + want: "analyze\n\nFile: input\n\n```\nstdin data\n```", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ConstructQuery(tt.prompts, tt.files) + assert.Equal(t, tt.want, got) + }) + } +}