diff --git a/cmd/acp.go b/cmd/acp.go index e3147017..a4adbd78 100644 --- a/cmd/acp.go +++ b/cmd/acp.go @@ -1,7 +1,11 @@ package cmd import ( + "bufio" + "bytes" + "encoding/json" "fmt" + "io" "log/slog" "os" "os/signal" @@ -37,8 +41,10 @@ func runACP(cmd *cobra.Command, _ []string) error { defer agent.Close() // Create the stdio connection. The SDK reads JSON-RPC from stdin and - // writes responses to stdout. - conn := acp.NewAgentSideConnection(agent, os.Stdout, os.Stdin) + // writes responses to stdout. We wrap stdin with a normalizer that + // fills in optional fields the SDK's generated validation requires + // (e.g. mcpServers) so clients that omit them still work. + conn := acp.NewAgentSideConnection(agent, os.Stdout, newACPNormalizer(os.Stdin)) // Wire the connection back to the agent so it can send session updates. agent.SetAgentConnection(conn) @@ -63,3 +69,91 @@ func runACP(cmd *cobra.Command, _ []string) error { return nil } + +// acpNormalizer wraps an io.Reader carrying newline-delimited JSON-RPC and +// patches incoming messages so that fields the SDK validates as required — +// but that some clients (e.g. Zed) omit — are defaulted. This avoids +// InvalidParams errors without forking the SDK. +type acpNormalizer struct { + scanner *bufio.Scanner + buf bytes.Buffer // leftover bytes from the last normalized line +} + +func newACPNormalizer(r io.Reader) *acpNormalizer { + const maxMsg = 10 * 1024 * 1024 // 10 MB, matches SDK buffer + s := bufio.NewScanner(r) + s.Buffer(make([]byte, 0, 1024*1024), maxMsg) + return &acpNormalizer{scanner: s} +} + +// Read satisfies io.Reader. It feeds one normalized JSON line (plus newline) +// per underlying scan, buffering across short caller reads. +func (n *acpNormalizer) Read(p []byte) (int, error) { + // Drain any leftover bytes from the previous line first. + if n.buf.Len() > 0 { + return n.buf.Read(p) + } + + if !n.scanner.Scan() { + if err := n.scanner.Err(); err != nil { + return 0, err + } + return 0, io.EOF + } + + line := n.scanner.Bytes() + normalized := normalizeACPLine(line) + n.buf.Write(normalized) + n.buf.WriteByte('\n') + return n.buf.Read(p) +} + +// normalizeACPLine ensures session/new and session/load params contain an +// mcpServers array. Returns the original line unchanged for all other methods. +func normalizeACPLine(line []byte) []byte { + // Quick check: if it already contains mcpServers, nothing to do. + if bytes.Contains(line, []byte(`"mcpServers"`)) { + return line + } + + // Only bother parsing if the method could be session/new or session/load. + if !bytes.Contains(line, []byte(`"session/new"`)) && + !bytes.Contains(line, []byte(`"session/load"`)) { + return line + } + + var msg struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` + } + if err := json.Unmarshal(line, &msg); err != nil { + return line + } + if msg.Method != "session/new" && msg.Method != "session/load" { + return line + } + + // Patch params to include mcpServers: []. + var params map[string]json.RawMessage + if err := json.Unmarshal(msg.Params, ¶ms); err != nil { + return line + } + if _, ok := params["mcpServers"]; ok { + return line + } + params["mcpServers"] = json.RawMessage(`[]`) + + patched, err := json.Marshal(params) + if err != nil { + return line + } + msg.Params = patched + + out, err := json.Marshal(msg) + if err != nil { + return line + } + return out +}