mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
ee66477498
Remove runAgenticStep, runAgenticLoop, runInteractiveLoop, addMessagesToHistory, replaceMessagesHistory, AgenticLoopConfig, executeStopHook, runNonInteractiveMode, and runInteractiveMode from root.go. Functions still needed by the script command are moved to script.go.
1171 lines
38 KiB
Go
1171 lines
38 KiB
Go
package cmd
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/mark3labs/mcphost/internal/agent"
|
|
"github.com/mark3labs/mcphost/internal/config"
|
|
"github.com/mark3labs/mcphost/internal/hooks"
|
|
"github.com/mark3labs/mcphost/internal/models"
|
|
"github.com/mark3labs/mcphost/internal/session"
|
|
"github.com/mark3labs/mcphost/internal/tools"
|
|
"github.com/mark3labs/mcphost/internal/ui"
|
|
|
|
"github.com/spf13/cobra"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// scriptCmd represents the script command for executing MCPHost script files.
|
|
// Script files can contain YAML frontmatter configuration followed by a prompt,
|
|
// allowing for reproducible AI interactions with custom configurations and
|
|
// variable substitution support.
|
|
var scriptCmd = &cobra.Command{
|
|
Use: "script <script-file>",
|
|
Short: "Execute a script file with YAML frontmatter configuration",
|
|
Long: `Execute a script file that contains YAML frontmatter with configuration
|
|
and a prompt. The script file can contain MCP server configurations,
|
|
model settings, and other options.
|
|
|
|
Example script file:
|
|
---
|
|
model: "anthropic/claude-sonnet-4-5-20250929"
|
|
max-steps: 10
|
|
mcpServers:
|
|
filesystem:
|
|
type: "local"
|
|
command: ["npx", "-y", "@modelcontextprotocol/server-filesystem", "${directory:-/tmp}"]
|
|
---
|
|
Hello ${name:-World}! List the files in ${directory:-/tmp} and tell me about them.
|
|
|
|
The script command supports the same flags as the main command,
|
|
which will override any settings in the script file.
|
|
|
|
Variable substitution:
|
|
Variables in the script can be substituted using ${variable} syntax.
|
|
Variables can have default values using ${variable:-default} syntax.
|
|
Pass variables using --args:variable value syntax:
|
|
|
|
mcphost script myscript.sh --args:directory /tmp --args:name "John"
|
|
|
|
This will replace ${directory} with "/tmp" and ${name} with "John" in the script.
|
|
Variables with defaults (${var:-default}) are optional and use the default if not provided.`,
|
|
Args: cobra.ExactArgs(1),
|
|
FParseErrWhitelist: cobra.FParseErrWhitelist{
|
|
UnknownFlags: true, // Allow unknown flags for variable substitution
|
|
},
|
|
PreRun: func(cmd *cobra.Command, args []string) {
|
|
// Override config with frontmatter values from the script file
|
|
scriptFile := args[0]
|
|
variables := parseCustomVariables(cmd)
|
|
overrideConfigWithFrontmatter(scriptFile, variables, cmd)
|
|
},
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
scriptFile := args[0]
|
|
|
|
// Parse custom variables from unknown flags
|
|
variables := parseCustomVariables(cmd)
|
|
|
|
return runScriptCommand(context.Background(), scriptFile, variables, cmd)
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(scriptCmd)
|
|
}
|
|
|
|
// overrideConfigWithFrontmatter parses the script file and overrides viper config with frontmatter values
|
|
// This is the only purpose of this function - to apply frontmatter configuration to viper
|
|
func overrideConfigWithFrontmatter(scriptFile string, variables map[string]string, cmd *cobra.Command) {
|
|
// Parse the script file to get frontmatter configuration
|
|
scriptConfig, err := parseScriptFile(scriptFile, variables)
|
|
if err != nil {
|
|
// If we can't parse the script file, just continue with existing config
|
|
// The error will be handled again in runScriptCommand
|
|
return
|
|
}
|
|
|
|
// Override viper values with frontmatter values (only if flags weren't explicitly set)
|
|
// Check both local flags and persistent flags since script inherits from root
|
|
flagChanged := func(name string) bool {
|
|
return cmd.Flags().Changed(name) || rootCmd.PersistentFlags().Changed(name)
|
|
}
|
|
|
|
if scriptConfig.Model != "" && !flagChanged("model") {
|
|
viper.Set("model", scriptConfig.Model)
|
|
}
|
|
if scriptConfig.MaxSteps != 0 && !flagChanged("max-steps") {
|
|
viper.Set("max-steps", scriptConfig.MaxSteps)
|
|
}
|
|
if scriptConfig.Debug && !flagChanged("debug") {
|
|
viper.Set("debug", scriptConfig.Debug)
|
|
}
|
|
if scriptConfig.Compact && !flagChanged("compact") {
|
|
viper.Set("compact", scriptConfig.Compact)
|
|
}
|
|
if scriptConfig.SystemPrompt != "" && !flagChanged("system-prompt") {
|
|
viper.Set("system-prompt", scriptConfig.SystemPrompt)
|
|
}
|
|
if scriptConfig.ProviderAPIKey != "" && !flagChanged("provider-api-key") {
|
|
viper.Set("provider-api-key", scriptConfig.ProviderAPIKey)
|
|
}
|
|
if scriptConfig.ProviderURL != "" && !flagChanged("provider-url") {
|
|
viper.Set("provider-url", scriptConfig.ProviderURL)
|
|
}
|
|
if scriptConfig.MaxTokens != 0 && !flagChanged("max-tokens") {
|
|
viper.Set("max-tokens", scriptConfig.MaxTokens)
|
|
}
|
|
if scriptConfig.Temperature != nil && !flagChanged("temperature") {
|
|
viper.Set("temperature", *scriptConfig.Temperature)
|
|
}
|
|
if scriptConfig.TopP != nil && !flagChanged("top-p") {
|
|
viper.Set("top-p", *scriptConfig.TopP)
|
|
}
|
|
if scriptConfig.TopK != nil && !flagChanged("top-k") {
|
|
viper.Set("top-k", *scriptConfig.TopK)
|
|
}
|
|
if len(scriptConfig.StopSequences) > 0 && !flagChanged("stop-sequences") {
|
|
viper.Set("stop-sequences", scriptConfig.StopSequences)
|
|
}
|
|
if scriptConfig.NoExit && !flagChanged("no-exit") {
|
|
// Set the global noExitFlag variable if it wasn't explicitly set via command line
|
|
noExitFlag = scriptConfig.NoExit
|
|
}
|
|
if scriptConfig.Stream != nil && !flagChanged("stream") {
|
|
viper.Set("stream", *scriptConfig.Stream)
|
|
}
|
|
if scriptConfig.TLSSkipVerify && !flagChanged("tls-skip-verify") {
|
|
viper.Set("tls-skip-verify", scriptConfig.TLSSkipVerify)
|
|
}
|
|
}
|
|
|
|
// parseCustomVariables extracts custom variables from command line arguments
|
|
func parseCustomVariables(_ *cobra.Command) map[string]string {
|
|
variables := make(map[string]string)
|
|
|
|
// Get all arguments passed to the command
|
|
args := os.Args[1:] // Skip program name
|
|
|
|
// Find the script subcommand position
|
|
scriptPos := -1
|
|
for i, arg := range args {
|
|
if arg == "script" {
|
|
scriptPos = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if scriptPos == -1 {
|
|
return variables
|
|
}
|
|
|
|
// Parse arguments after the script file
|
|
scriptFileFound := false
|
|
|
|
for i := scriptPos + 1; i < len(args); i++ {
|
|
arg := args[i]
|
|
|
|
// Skip the script file argument (first non-flag after "script")
|
|
if !scriptFileFound && !strings.HasPrefix(arg, "-") {
|
|
scriptFileFound = true
|
|
continue
|
|
}
|
|
|
|
// Parse custom variables with --args: prefix
|
|
if after, ok := strings.CutPrefix(arg, "--args:"); ok {
|
|
varName := after
|
|
if varName == "" {
|
|
continue // Skip malformed --args: without name
|
|
}
|
|
|
|
// Check if we have a value
|
|
if i+1 < len(args) {
|
|
varValue := args[i+1]
|
|
|
|
// Make sure the next arg isn't a flag
|
|
if !strings.HasPrefix(varValue, "-") {
|
|
variables[varName] = varValue
|
|
i++ // Skip the value
|
|
} else {
|
|
// No value provided, treat as empty string
|
|
variables[varName] = ""
|
|
}
|
|
} else {
|
|
// No value provided, treat as empty string
|
|
variables[varName] = ""
|
|
}
|
|
}
|
|
}
|
|
|
|
return variables
|
|
}
|
|
|
|
func runScriptCommand(ctx context.Context, scriptFile string, variables map[string]string, _ *cobra.Command) error {
|
|
// Parse the script file to get MCP servers and prompt
|
|
scriptConfig, err := parseScriptFile(scriptFile, variables)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse script file: %v", err)
|
|
}
|
|
|
|
// Get MCP config - use script servers if available, otherwise use global viper config
|
|
var mcpConfig *config.Config
|
|
if len(scriptConfig.MCPServers) > 0 {
|
|
// Load base config and merge with script config
|
|
baseConfig, err := config.LoadAndValidateConfig()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load base config: %v", err)
|
|
}
|
|
mcpConfig = config.MergeConfigs(baseConfig, scriptConfig)
|
|
} else {
|
|
// Use the new config loader
|
|
var err error
|
|
mcpConfig, err = config.LoadAndValidateConfig()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load MCP config: %v", err)
|
|
}
|
|
}
|
|
|
|
// Get final prompt - prioritize command line flag, then script content
|
|
finalPrompt := viper.GetString("prompt")
|
|
if finalPrompt == "" && scriptConfig.Prompt != "" {
|
|
finalPrompt = scriptConfig.Prompt
|
|
}
|
|
|
|
// Get final no-exit setting - prioritize command line flag, then script config
|
|
finalNoExit := noExitFlag || scriptConfig.NoExit
|
|
|
|
// Validate that --no-exit is only used when there's a prompt
|
|
if finalNoExit && finalPrompt == "" {
|
|
return fmt.Errorf("--no-exit flag can only be used when there's a prompt (either from script content or --prompt flag)")
|
|
}
|
|
|
|
// Run the script using the unified agentic loop
|
|
return runScriptMode(ctx, mcpConfig, finalPrompt, finalNoExit)
|
|
}
|
|
|
|
// Historical note: mergeScriptConfig and setScriptValuesInViper were removed.
// Frontmatter overrides are now applied by overrideConfigWithFrontmatter in the PreRun hook.
|
|
|
|
// parseScriptFile parses a script file with YAML frontmatter and returns config
|
|
func parseScriptFile(filename string, variables map[string]string) (*config.Config, error) {
|
|
file, err := os.Open(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = file.Close() }()
|
|
|
|
scanner := bufio.NewScanner(file)
|
|
|
|
// Skip shebang line if present
|
|
if scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "#!") {
|
|
// If it's not a shebang, we need to process this line
|
|
return parseScriptContent(line+"\n"+readRemainingLines(scanner), variables)
|
|
}
|
|
}
|
|
|
|
// Read the rest of the file
|
|
content := readRemainingLines(scanner)
|
|
return parseScriptContent(content, variables)
|
|
}
|
|
|
|
// readRemainingLines drains scanner and returns every line it still yields,
// separated by "\n" and without a trailing newline. It stops at the first scan
// failure; callers that care about I/O errors must check scanner.Err().
func readRemainingLines(scanner *bufio.Scanner) string {
	var sb strings.Builder
	for first := true; scanner.Scan(); first = false {
		if !first {
			sb.WriteByte('\n')
		}
		sb.WriteString(scanner.Text())
	}
	return sb.String()
}
|
|
|
|
// parseScriptContent parses the content to extract YAML frontmatter and prompt
|
|
func parseScriptContent(content string, variables map[string]string) (*config.Config, error) {
|
|
// STEP 1: Apply environment variable substitution FIRST
|
|
envSubstituter := &config.EnvSubstituter{}
|
|
processedContent, err := envSubstituter.SubstituteEnvVars(content)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("script env substitution failed: %v", err)
|
|
}
|
|
|
|
// STEP 2: Validate that all declared script variables are provided
|
|
if err := validateVariables(processedContent, variables); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// STEP 3: Apply script args substitution
|
|
argsSubstituter := config.NewArgsSubstituter(variables)
|
|
content, err = argsSubstituter.SubstituteArgs(processedContent)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("script args substitution failed: %v", err)
|
|
}
|
|
|
|
lines := strings.Split(content, "\n")
|
|
|
|
// Find YAML frontmatter between --- delimiters
|
|
var yamlLines []string
|
|
var promptLines []string
|
|
var inFrontmatter bool
|
|
var foundFrontmatter bool
|
|
var frontmatterEnd = -1
|
|
|
|
for i, line := range lines {
|
|
trimmed := strings.TrimSpace(line)
|
|
|
|
// Skip comment lines (lines starting with #)
|
|
if strings.HasPrefix(trimmed, "#") {
|
|
continue
|
|
}
|
|
|
|
// Check for frontmatter start
|
|
if trimmed == "---" && !inFrontmatter {
|
|
// Start of frontmatter
|
|
inFrontmatter = true
|
|
foundFrontmatter = true
|
|
continue
|
|
}
|
|
|
|
// Check for frontmatter end
|
|
if trimmed == "---" && inFrontmatter {
|
|
// End of frontmatter
|
|
inFrontmatter = false
|
|
frontmatterEnd = i + 1
|
|
continue
|
|
}
|
|
|
|
// Collect frontmatter lines
|
|
if inFrontmatter {
|
|
yamlLines = append(yamlLines, line)
|
|
}
|
|
}
|
|
|
|
// Extract prompt (everything after frontmatter)
|
|
if foundFrontmatter && frontmatterEnd != -1 && frontmatterEnd < len(lines) {
|
|
promptLines = lines[frontmatterEnd:]
|
|
} else if !foundFrontmatter {
|
|
// If no frontmatter found, treat entire content as prompt
|
|
promptLines = lines
|
|
yamlLines = []string{} // Empty YAML
|
|
}
|
|
|
|
// Parse YAML frontmatter using Viper for consistency with config file parsing
|
|
var scriptConfig config.Config
|
|
if len(yamlLines) > 0 {
|
|
yamlContent := strings.Join(yamlLines, "\n")
|
|
|
|
// Create temporary viper instance for frontmatter parsing
|
|
frontmatterViper := viper.New()
|
|
frontmatterViper.SetConfigType("yaml")
|
|
|
|
if err := frontmatterViper.ReadConfig(strings.NewReader(yamlContent)); err != nil {
|
|
return nil, fmt.Errorf("failed to parse YAML frontmatter: %v\nYAML content:\n%s", err, yamlContent)
|
|
}
|
|
|
|
if err := frontmatterViper.Unmarshal(&scriptConfig); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal frontmatter config: %v", err)
|
|
}
|
|
|
|
// Manually extract hyphenated keys that Viper might not handle correctly during unmarshal
|
|
if providerURL := frontmatterViper.GetString("provider-url"); providerURL != "" {
|
|
scriptConfig.ProviderURL = providerURL
|
|
}
|
|
if providerAPIKey := frontmatterViper.GetString("provider-api-key"); providerAPIKey != "" {
|
|
scriptConfig.ProviderAPIKey = providerAPIKey
|
|
}
|
|
if systemPrompt := frontmatterViper.GetString("system-prompt"); systemPrompt != "" {
|
|
scriptConfig.SystemPrompt = systemPrompt
|
|
}
|
|
if maxSteps := frontmatterViper.GetInt("max-steps"); maxSteps != 0 {
|
|
scriptConfig.MaxSteps = maxSteps
|
|
}
|
|
if maxTokens := frontmatterViper.GetInt("max-tokens"); maxTokens != 0 {
|
|
scriptConfig.MaxTokens = maxTokens
|
|
}
|
|
if topP := frontmatterViper.GetFloat64("top-p"); topP != 0 {
|
|
topPFloat32 := float32(topP)
|
|
scriptConfig.TopP = &topPFloat32
|
|
}
|
|
if topK := frontmatterViper.GetInt("top-k"); topK != 0 {
|
|
topKInt32 := int32(topK)
|
|
scriptConfig.TopK = &topKInt32
|
|
}
|
|
if stopSequences := frontmatterViper.GetStringSlice("stop-sequences"); len(stopSequences) > 0 {
|
|
scriptConfig.StopSequences = stopSequences
|
|
}
|
|
if noExit := frontmatterViper.GetBool("no-exit"); noExit {
|
|
scriptConfig.NoExit = noExit
|
|
}
|
|
if tlsSkipVerify := frontmatterViper.GetBool("tls-skip-verify"); tlsSkipVerify {
|
|
scriptConfig.TLSSkipVerify = tlsSkipVerify
|
|
}
|
|
}
|
|
|
|
// Set prompt from content after frontmatter
|
|
if len(promptLines) > 0 {
|
|
prompt := strings.Join(promptLines, "\n")
|
|
prompt = strings.TrimSpace(prompt) // Remove leading/trailing whitespace
|
|
if prompt != "" {
|
|
scriptConfig.Prompt = prompt
|
|
}
|
|
}
|
|
|
|
return &scriptConfig, nil
|
|
}
|
|
|
|
// Variable represents a script variable with optional default value.
// Variables can be declared in scripts using ${variable} syntax for required variables
// or ${variable:-default} syntax for variables with default values.
type Variable struct {
	Name         string // The name of the variable as it appears in the script
	DefaultValue string // The default value if specified using ${variable:-default} syntax
	HasDefault   bool   // Whether this variable has a default value
}

// scriptVariablePattern matches ${name} and ${name:-default}. Group 1 is the
// name (no '}' or ':'); group 2 is the optional default (no '}'). It is
// compiled once at package initialisation rather than on every call.
var scriptVariablePattern = regexp.MustCompile(`\$\{([^}:]+)(?::-([^}]*))?\}`)

// findVariables returns the unique variable names referenced via ${...} in
// content, in order of first appearance, or nil when there are none.
// Kept for backward compatibility with callers that only need the names.
func findVariables(content string) []string {
	variables := findVariablesWithDefaults(content)
	if len(variables) == 0 {
		return nil
	}
	names := make([]string, 0, len(variables))
	for _, v := range variables {
		names = append(names, v.Name)
	}
	return names
}

// findVariablesWithDefaults returns each distinct variable referenced in
// content, in order of first appearance. It supports both ${variable} and
// ${variable:-default}. Only the first occurrence of a name is recorded, so a
// later ${name:-x} does not add a default to an earlier ${name}.
func findVariablesWithDefaults(content string) []Variable {
	var variables []Variable
	seen := make(map[string]bool)

	for _, match := range scriptVariablePattern.FindAllStringSubmatch(content, -1) {
		name := match[1]
		if seen[name] {
			continue
		}
		seen[name] = true

		// Group 1 cannot contain ':', so ":-" in the full match can only be the
		// default separator; this distinguishes ${name:-} from ${name}.
		hasDefault := strings.Contains(match[0], ":-")
		v := Variable{Name: name, HasDefault: hasDefault}
		if hasDefault {
			v.DefaultValue = match[2] // Can be empty string
		}
		variables = append(variables, v)
	}

	return variables
}

// validateVariables checks that every variable declared in content without a
// default value is present in variables. Missing names are reported in order
// of first appearance in a single error.
func validateVariables(content string, variables map[string]string) error {
	var missingVars []string
	for _, variable := range findVariablesWithDefaults(content) {
		if _, exists := variables[variable.Name]; !exists && !variable.HasDefault {
			missingVars = append(missingVars, variable.Name)
		}
	}

	if len(missingVars) > 0 {
		return fmt.Errorf("missing required variables: %s\nProvide them using --args:variable value syntax", strings.Join(missingVars, ", "))
	}

	return nil
}
|
|
|
|
// substituteVariables replaces ${variable} and ${variable:-default} patterns with their values
|
|
// This function is kept for backward compatibility but now uses the shared ArgsSubstituter
|
|
func substituteVariables(content string, variables map[string]string) string {
|
|
substituter := config.NewArgsSubstituter(variables)
|
|
result, err := substituter.SubstituteArgs(content)
|
|
if err != nil {
|
|
// For backward compatibility, if there's an error, return the original content
|
|
// This maintains the existing behavior where missing variables were left as-is
|
|
return content
|
|
}
|
|
return result
|
|
}
|
|
|
|
// runScriptMode executes the script using the unified agentic loop
|
|
func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string, noExit bool) error {
|
|
// Set up logging
|
|
if debugMode || mcpConfig.Debug {
|
|
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
|
}
|
|
|
|
// Get final values from viper and script config
|
|
finalModel := viper.GetString("model")
|
|
if finalModel == "" && mcpConfig.Model != "" {
|
|
finalModel = mcpConfig.Model
|
|
}
|
|
if finalModel == "" {
|
|
finalModel = "anthropic/claude-sonnet-4-5-20250929" // default
|
|
}
|
|
|
|
finalSystemPrompt := viper.GetString("system-prompt")
|
|
if finalSystemPrompt == "" && mcpConfig.SystemPrompt != "" {
|
|
finalSystemPrompt = mcpConfig.SystemPrompt
|
|
}
|
|
|
|
finalDebug := viper.GetBool("debug") || mcpConfig.Debug
|
|
finalCompact := viper.GetBool("compact")
|
|
finalMaxSteps := viper.GetInt("max-steps")
|
|
if finalMaxSteps == 0 && mcpConfig.MaxSteps != 0 {
|
|
finalMaxSteps = mcpConfig.MaxSteps
|
|
}
|
|
|
|
finalProviderURL := viper.GetString("provider-url")
|
|
if finalProviderURL == "" && mcpConfig.ProviderURL != "" {
|
|
finalProviderURL = mcpConfig.ProviderURL
|
|
}
|
|
|
|
finalProviderAPIKey := viper.GetString("provider-api-key")
|
|
if finalProviderAPIKey == "" && mcpConfig.ProviderAPIKey != "" {
|
|
finalProviderAPIKey = mcpConfig.ProviderAPIKey
|
|
}
|
|
|
|
finalMaxTokens := viper.GetInt("max-tokens")
|
|
if finalMaxTokens == 0 && mcpConfig.MaxTokens != 0 {
|
|
finalMaxTokens = mcpConfig.MaxTokens
|
|
}
|
|
if finalMaxTokens == 0 {
|
|
finalMaxTokens = 4096 // default
|
|
}
|
|
|
|
finalTemperature := float32(viper.GetFloat64("temperature"))
|
|
if finalTemperature == 0 && mcpConfig.Temperature != nil {
|
|
finalTemperature = *mcpConfig.Temperature
|
|
}
|
|
if finalTemperature == 0 {
|
|
finalTemperature = 0.7 // default
|
|
}
|
|
|
|
finalTopP := float32(viper.GetFloat64("top-p"))
|
|
if finalTopP == 0 && mcpConfig.TopP != nil {
|
|
finalTopP = *mcpConfig.TopP
|
|
}
|
|
if finalTopP == 0 {
|
|
finalTopP = 0.95 // default
|
|
}
|
|
|
|
finalTopK := int32(viper.GetInt("top-k"))
|
|
if finalTopK == 0 && mcpConfig.TopK != nil {
|
|
finalTopK = *mcpConfig.TopK
|
|
}
|
|
if finalTopK == 0 {
|
|
finalTopK = 40 // default
|
|
}
|
|
|
|
finalStopSequences := viper.GetStringSlice("stop-sequences")
|
|
if len(finalStopSequences) == 0 && len(mcpConfig.StopSequences) > 0 {
|
|
finalStopSequences = mcpConfig.StopSequences
|
|
}
|
|
|
|
// Load system prompt
|
|
systemPrompt, err := config.LoadSystemPrompt(finalSystemPrompt)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load system prompt: %v", err)
|
|
}
|
|
|
|
// Create model configuration
|
|
modelConfig := &models.ProviderConfig{
|
|
ModelString: finalModel,
|
|
SystemPrompt: systemPrompt,
|
|
ProviderAPIKey: finalProviderAPIKey,
|
|
ProviderURL: finalProviderURL,
|
|
MaxTokens: finalMaxTokens,
|
|
Temperature: &finalTemperature,
|
|
TopP: &finalTopP,
|
|
TopK: &finalTopK,
|
|
StopSequences: finalStopSequences,
|
|
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
|
}
|
|
|
|
// Create the agent using the factory (scripts don't need spinners)
|
|
// Use a simple debug logger for scripts
|
|
var debugLogger tools.DebugLogger
|
|
if finalDebug {
|
|
debugLogger = tools.NewSimpleDebugLogger(true)
|
|
}
|
|
|
|
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
|
ModelConfig: modelConfig,
|
|
MCPConfig: mcpConfig,
|
|
SystemPrompt: systemPrompt,
|
|
MaxSteps: finalMaxSteps,
|
|
StreamingEnabled: viper.GetBool("stream"),
|
|
ShowSpinner: false, // Scripts don't need spinners
|
|
Quiet: quietFlag,
|
|
SpinnerFunc: nil, // No spinner function needed
|
|
DebugLogger: debugLogger,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create agent: %v", err)
|
|
}
|
|
defer func() { _ = mcpAgent.Close() }()
|
|
|
|
// Get model name for display
|
|
_, modelName, _ := models.ParseModelString(finalModel)
|
|
if modelName == "" {
|
|
modelName = "Unknown"
|
|
}
|
|
|
|
// Create an adapter for the agent to match the UI interface
|
|
agentAdapter := &agentUIAdapter{agent: mcpAgent}
|
|
|
|
// Create CLI interface using the factory
|
|
cli, err := ui.SetupCLI(&ui.CLISetupOptions{
|
|
Agent: agentAdapter,
|
|
ModelString: finalModel,
|
|
Debug: finalDebug,
|
|
Compact: finalCompact,
|
|
Quiet: quietFlag,
|
|
ShowDebug: false, // Will be handled separately below
|
|
ProviderAPIKey: finalProviderAPIKey,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to setup CLI: %v", err)
|
|
}
|
|
|
|
// Display debug configuration if debug mode is enabled
|
|
if !quietFlag && cli != nil && finalDebug {
|
|
debugConfig := map[string]any{
|
|
"model": finalModel,
|
|
"max-steps": finalMaxSteps,
|
|
"max-tokens": finalMaxTokens,
|
|
"temperature": finalTemperature,
|
|
"top-p": finalTopP,
|
|
"top-k": finalTopK,
|
|
"provider-url": finalProviderURL,
|
|
"system-prompt": finalSystemPrompt,
|
|
}
|
|
|
|
// Only include non-empty stop sequences
|
|
if len(finalStopSequences) > 0 {
|
|
debugConfig["stop-sequences"] = finalStopSequences
|
|
}
|
|
|
|
// Only include API keys if they're set (but don't show the actual values for security)
|
|
if finalProviderAPIKey != "" {
|
|
debugConfig["provider-api-key"] = "[SET]"
|
|
}
|
|
|
|
cli.DisplayDebugConfig(debugConfig)
|
|
}
|
|
|
|
// Initialize hooks
|
|
var hookExecutor *hooks.Executor
|
|
if hooksConfig := viper.Get("hooks"); hooksConfig != nil {
|
|
if hc, ok := hooksConfig.(*hooks.HookConfig); ok {
|
|
// Generate a session ID for this run
|
|
sessionID := fmt.Sprintf("mcphost-%d", time.Now().Unix())
|
|
transcriptPath := "" // We could add transcript logging later
|
|
hookExecutor = hooks.NewExecutor(hc, sessionID, transcriptPath)
|
|
|
|
// Set model and interactive mode
|
|
hookExecutor.SetModel(finalModel)
|
|
hookExecutor.SetInteractive(prompt == "")
|
|
}
|
|
}
|
|
|
|
// Prepare data for slash commands
|
|
var serverNames []string
|
|
for name := range mcpConfig.MCPServers {
|
|
serverNames = append(serverNames, name)
|
|
}
|
|
|
|
tools := mcpAgent.GetTools()
|
|
var toolNames []string
|
|
for _, tool := range tools {
|
|
info := tool.Info()
|
|
toolNames = append(toolNames, info.Name)
|
|
}
|
|
|
|
// Configure and run unified agentic loop
|
|
var messages []fantasy.Message
|
|
config := AgenticLoopConfig{
|
|
IsInteractive: prompt == "", // If no prompt, start in interactive mode
|
|
InitialPrompt: prompt,
|
|
ContinueAfterRun: noExit,
|
|
Quiet: quietFlag,
|
|
ServerNames: serverNames,
|
|
ToolNames: toolNames,
|
|
ModelName: modelName,
|
|
MCPConfig: mcpConfig,
|
|
}
|
|
|
|
return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor)
|
|
}
|
|
|
|
// AgenticLoopConfig configures the behavior of the unified agentic loop.
|
|
// This struct controls how the main interaction loop operates, whether in
|
|
// interactive or non-interactive mode, and manages various UI and session options.
|
|
type AgenticLoopConfig struct {
|
|
// Mode configuration
|
|
IsInteractive bool // true for interactive mode, false for non-interactive
|
|
InitialPrompt string // initial prompt for non-interactive mode
|
|
ContinueAfterRun bool // true to continue to interactive mode after initial run (--no-exit)
|
|
ApproveToolRun bool // only used in interactive mode
|
|
|
|
// UI configuration
|
|
Quiet bool // suppress all output except final response
|
|
|
|
// Context data
|
|
ServerNames []string // for slash commands
|
|
ToolNames []string // for slash commands
|
|
ModelName string // for display
|
|
MCPConfig *config.Config // for continuing to interactive mode
|
|
SessionManager *session.Manager // for session persistence
|
|
}
|
|
|
|
// replaceMessagesHistory replaces the conversation history and saves to session if available
|
|
func replaceMessagesHistory(messages *[]fantasy.Message, sessionManager *session.Manager, cli *ui.CLI, newMessages []fantasy.Message) {
|
|
// Replace local history
|
|
*messages = newMessages
|
|
|
|
// Save to session if session manager is available
|
|
if sessionManager != nil {
|
|
// Use ReplaceAllMessages to ensure session matches local history exactly
|
|
if err := sessionManager.ReplaceAllMessages(*messages); err != nil {
|
|
// Log error but don't fail the operation
|
|
if cli != nil {
|
|
cli.DisplayError(fmt.Errorf("failed to save messages to session: %v", err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// runAgenticLoop handles all execution modes with a single unified loop
|
|
func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []fantasy.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) error {
|
|
// Handle initial prompt for non-interactive modes
|
|
if !config.IsInteractive && config.InitialPrompt != "" {
|
|
// Execute UserPromptSubmit hooks for non-interactive mode
|
|
if hookExecutor != nil {
|
|
input := &hooks.UserPromptSubmitInput{
|
|
CommonInput: hookExecutor.PopulateCommonFields(hooks.UserPromptSubmit),
|
|
Prompt: config.InitialPrompt,
|
|
}
|
|
|
|
hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.UserPromptSubmit, input)
|
|
if err != nil {
|
|
// Log error but don't fail
|
|
if debugMode {
|
|
fmt.Fprintf(os.Stderr, "UserPromptSubmit hook execution error: %v\n", err)
|
|
}
|
|
}
|
|
|
|
// Check if hook blocked the prompt
|
|
if hookOutput != nil && hookOutput.Decision == "block" {
|
|
return fmt.Errorf("prompt blocked by hook: %s", hookOutput.Reason)
|
|
}
|
|
}
|
|
|
|
// Display user message (skip if quiet)
|
|
if !config.Quiet && cli != nil {
|
|
cli.DisplayUserMessage(config.InitialPrompt)
|
|
}
|
|
|
|
// Create temporary messages with user input for processing (don't add to history yet)
|
|
tempMessages := append(messages, fantasy.NewUserMessage(config.InitialPrompt))
|
|
|
|
// Process the initial prompt with tool calls
|
|
_, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config, hookExecutor)
|
|
if err != nil {
|
|
// Check if this was a user cancellation
|
|
if err.Error() == "generation cancelled by user" && cli != nil {
|
|
cli.DisplayCancellation()
|
|
// On cancellation, continue to interactive mode (like --no-exit)
|
|
// Don't add the cancelled message to history
|
|
config.IsInteractive = true
|
|
} else {
|
|
return err
|
|
}
|
|
} else {
|
|
// Only add to history after successful completion
|
|
// conversationMessages already includes the user message, tool calls, and final response
|
|
replaceMessagesHistory(&messages, config.SessionManager, cli, conversationMessages)
|
|
|
|
// If not continuing to interactive mode, exit here
|
|
if !config.ContinueAfterRun {
|
|
return nil
|
|
}
|
|
|
|
// Update config for interactive mode continuation
|
|
config.IsInteractive = true
|
|
}
|
|
}
|
|
|
|
// Interactive loop (or continuation after non-interactive): not supported in script mode.
|
|
return nil
|
|
}
|
|
|
|
// runAgenticStep processes a single step of the agentic loop (handles tool calls)
|
|
func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []fantasy.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) (*fantasy.Response, []fantasy.Message, error) {
|
|
var currentSpinner *ui.Spinner
|
|
|
|
// Start initial spinner (skip if quiet)
|
|
if !config.Quiet && cli != nil {
|
|
currentSpinner = ui.NewSpinner("")
|
|
currentSpinner.Start()
|
|
}
|
|
|
|
// Create streaming callback for real-time display
|
|
var streamingCallback agent.StreamingResponseHandler
|
|
var responseWasStreamed bool
|
|
var lastDisplayedContent string
|
|
var streamingContent strings.Builder
|
|
var streamingStarted bool
|
|
if cli != nil && !config.Quiet {
|
|
streamingCallback = func(chunk string) {
|
|
// Stop spinner before first chunk if still running
|
|
if currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
currentSpinner = nil
|
|
}
|
|
// Mark that this response is being streamed
|
|
responseWasStreamed = true
|
|
|
|
// Accumulate content and update message
|
|
if !streamingStarted {
|
|
streamingStarted = true
|
|
streamingContent.Reset() // Reset content for new streaming session
|
|
}
|
|
streamingContent.WriteString(chunk)
|
|
}
|
|
}
|
|
|
|
// Reset streaming state before agent execution
|
|
responseWasStreamed = false
|
|
streamingStarted = false
|
|
streamingContent.Reset()
|
|
|
|
// Variables to store tool information for hooks
|
|
var currentToolName string
|
|
var currentToolArgs string
|
|
var toolIsBlocked bool
|
|
var blockReason string
|
|
|
|
result, err := mcpAgent.GenerateWithLoopAndStreaming(ctx, messages,
|
|
// Tool call handler - called when a tool is about to be executed
|
|
func(toolName, toolArgs string) {
|
|
// Store tool info for use in execution handler
|
|
currentToolName = toolName
|
|
currentToolArgs = toolArgs
|
|
|
|
if !config.Quiet && cli != nil {
|
|
// Stop spinner before displaying tool call
|
|
if currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
currentSpinner = nil
|
|
}
|
|
cli.DisplayToolCallMessage(toolName, toolArgs)
|
|
}
|
|
},
|
|
// Tool execution handler - called when tool execution starts/ends
|
|
func(toolName string, isStarting bool) {
|
|
if isStarting {
|
|
// Execute PreToolUse hooks
|
|
if hookExecutor != nil {
|
|
input := &hooks.PreToolUseInput{
|
|
CommonInput: hookExecutor.PopulateCommonFields(hooks.PreToolUse),
|
|
ToolName: currentToolName,
|
|
ToolInput: json.RawMessage(currentToolArgs),
|
|
}
|
|
|
|
hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.PreToolUse, input)
|
|
if err != nil {
|
|
// Log error but don't fail the tool execution
|
|
if debugMode {
|
|
fmt.Fprintf(os.Stderr, "Hook execution error: %v\n", err)
|
|
}
|
|
}
|
|
|
|
// Check if hook blocked the execution
|
|
if hookOutput != nil && hookOutput.Decision == "block" {
|
|
toolIsBlocked = true
|
|
blockReason = hookOutput.Reason
|
|
if blockReason == "" {
|
|
blockReason = "Tool execution blocked by security policy"
|
|
}
|
|
if !config.Quiet && cli != nil {
|
|
cli.DisplayInfo(fmt.Sprintf("Tool execution blocked by hook: %s", blockReason))
|
|
}
|
|
}
|
|
}
|
|
|
|
if !config.Quiet && cli != nil {
|
|
// Start spinner for tool execution
|
|
currentSpinner = ui.NewSpinner(fmt.Sprintf("Executing %s...", toolName))
|
|
currentSpinner.Start()
|
|
}
|
|
} else {
|
|
// Stop spinner when tool execution completes
|
|
if !config.Quiet && cli != nil && currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
currentSpinner = nil
|
|
}
|
|
}
|
|
},
|
|
// Tool result handler - called when a tool execution completes
|
|
func(toolName, toolArgs, result string, isError bool) {
|
|
// Check if this tool was blocked
|
|
if toolIsBlocked {
|
|
// Reset the flag for next tool
|
|
toolIsBlocked = false
|
|
|
|
// Display the blocked message
|
|
if !config.Quiet && cli != nil {
|
|
cli.DisplayToolMessage(toolName, toolArgs, fmt.Sprintf("Tool execution blocked: %s", blockReason), true)
|
|
}
|
|
|
|
// Reset block reason
|
|
blockReason = ""
|
|
return
|
|
}
|
|
|
|
// Execute PostToolUse hooks
|
|
var postToolHookOutput *hooks.HookOutput
|
|
if hookExecutor != nil && result != "" {
|
|
input := &hooks.PostToolUseInput{
|
|
CommonInput: hookExecutor.PopulateCommonFields(hooks.PostToolUse),
|
|
ToolName: currentToolName,
|
|
ToolInput: json.RawMessage(currentToolArgs),
|
|
ToolResponse: json.RawMessage(result),
|
|
}
|
|
|
|
hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.PostToolUse, input)
|
|
if err != nil {
|
|
// Log error but don't fail
|
|
if debugMode {
|
|
fmt.Fprintf(os.Stderr, "PostToolUse hook execution error: %v\n", err)
|
|
}
|
|
}
|
|
postToolHookOutput = hookOutput
|
|
}
|
|
|
|
// Check if hook wants to suppress output
|
|
if postToolHookOutput != nil && postToolHookOutput.SuppressOutput {
|
|
// Skip displaying tool result to user
|
|
// Note: Result still goes to LLM unless ModifyOutput is used
|
|
return
|
|
}
|
|
|
|
if !config.Quiet && cli != nil {
|
|
// Parse tool result content - it might be JSON-encoded MCP content
|
|
resultContent := result
|
|
|
|
// Try to parse as MCP content structure
|
|
var mcpContent struct {
|
|
Content []struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text"`
|
|
} `json:"content"`
|
|
}
|
|
|
|
// First try to unmarshal as-is
|
|
if err := json.Unmarshal([]byte(result), &mcpContent); err == nil {
|
|
// Extract text from MCP content structure
|
|
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
|
|
resultContent = mcpContent.Content[0].Text
|
|
}
|
|
} else {
|
|
// If that fails, try unquoting first (in case it's double-encoded)
|
|
var unquoted string
|
|
if err := json.Unmarshal([]byte(result), &unquoted); err == nil {
|
|
if err := json.Unmarshal([]byte(unquoted), &mcpContent); err == nil {
|
|
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
|
|
resultContent = mcpContent.Content[0].Text
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
cli.DisplayToolMessage(toolName, toolArgs, resultContent, isError)
|
|
// Reset streaming state for next LLM call
|
|
responseWasStreamed = false
|
|
streamingStarted = false
|
|
// Start spinner again for next LLM call
|
|
currentSpinner = ui.NewSpinner("")
|
|
currentSpinner.Start()
|
|
}
|
|
},
|
|
// Response handler - called when the LLM generates a response
|
|
func(content string) {
|
|
if !config.Quiet && cli != nil {
|
|
// Stop spinner when we get the final response
|
|
if currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
currentSpinner = nil
|
|
}
|
|
}
|
|
},
|
|
// Tool call content handler - called when content accompanies tool calls
|
|
func(content string) {
|
|
if !config.Quiet && cli != nil && !responseWasStreamed {
|
|
// Only display if content wasn't already streamed
|
|
// Stop spinner before displaying content
|
|
if currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
currentSpinner = nil
|
|
}
|
|
_ = cli.DisplayAssistantMessageWithModel(content, config.ModelName)
|
|
lastDisplayedContent = content
|
|
// Start spinner again for tool calls
|
|
currentSpinner = ui.NewSpinner("")
|
|
currentSpinner.Start()
|
|
} else if responseWasStreamed {
|
|
// Content was already streamed, just track it and manage spinner
|
|
lastDisplayedContent = content
|
|
if currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
currentSpinner = nil
|
|
}
|
|
// Start spinner again for tool calls
|
|
currentSpinner = ui.NewSpinner("")
|
|
currentSpinner.Start()
|
|
}
|
|
},
|
|
// Add streaming callback handler
|
|
streamingCallback,
|
|
// Tool call approval handler - called before tool execution to get user approval
|
|
func(toolName, toolArgs string) (bool, error) {
|
|
if !config.IsInteractive || !config.ApproveToolRun {
|
|
return true, nil
|
|
}
|
|
if currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
currentSpinner = nil
|
|
}
|
|
// Tool approval via CLI is no longer supported; always approve in legacy path.
|
|
// Start spinner again for tool calls
|
|
currentSpinner = ui.NewSpinner("")
|
|
currentSpinner.Start()
|
|
|
|
return true, nil
|
|
},
|
|
)
|
|
|
|
// Make sure spinner is stopped if still running
|
|
if !config.Quiet && cli != nil && currentSpinner != nil {
|
|
currentSpinner.Stop()
|
|
}
|
|
|
|
if err != nil {
|
|
if !config.Quiet && cli != nil {
|
|
cli.DisplayError(fmt.Errorf("agent error: %v", err))
|
|
}
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Get the final response and conversation messages
|
|
response := result.FinalResponse
|
|
conversationMessages := result.ConversationMessages
|
|
|
|
// Extract the last user message for usage tracking (do this once)
|
|
lastUserMessage := ""
|
|
if len(messages) > 0 {
|
|
// Find the last user message
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if messages[i].Role == fantasy.MessageRoleUser {
|
|
// Extract text from message parts
|
|
for _, part := range messages[i].Content {
|
|
if tp, ok := part.(fantasy.TextPart); ok {
|
|
lastUserMessage = tp.Text
|
|
break
|
|
}
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get text content from response
|
|
responseText := response.Content.Text()
|
|
|
|
// Update usage tracking for ALL responses (streaming and non-streaming)
|
|
if !config.Quiet && cli != nil {
|
|
cli.UpdateUsageFromResponse(response, lastUserMessage)
|
|
}
|
|
|
|
// Display assistant response with model name
|
|
// Skip if: quiet mode, same content already displayed, or if streaming completed the full response
|
|
streamedFullResponse := responseWasStreamed && streamingContent.String() == responseText
|
|
if !config.Quiet && cli != nil && responseText != lastDisplayedContent && responseText != "" && !streamedFullResponse {
|
|
if err := cli.DisplayAssistantMessageWithModel(responseText, config.ModelName); err != nil {
|
|
cli.DisplayError(fmt.Errorf("display error: %v", err))
|
|
return nil, nil, err
|
|
}
|
|
} else if config.Quiet {
|
|
// In quiet mode, only output the final response content to stdout
|
|
fmt.Print(responseText)
|
|
}
|
|
|
|
// Display usage information immediately after the response (for both streaming and non-streaming)
|
|
if !config.Quiet && cli != nil {
|
|
cli.DisplayUsageAfterResponse()
|
|
}
|
|
|
|
// Execute Stop hook after agent has finished responding
|
|
executeStopHook(hookExecutor, response, "completed", config.ModelName)
|
|
|
|
// Return the final response and all conversation messages
|
|
return response, conversationMessages, nil
|
|
}
|
|
|
|
// executeStopHook executes the Stop hook if a hook executor is available
|
|
func executeStopHook(hookExecutor *hooks.Executor, response *fantasy.Response, stopReason string, modelName string) {
|
|
if hookExecutor != nil {
|
|
// Prepare metadata
|
|
var meta json.RawMessage
|
|
if response != nil {
|
|
metaData := map[string]any{
|
|
"model": modelName,
|
|
"has_tool_calls": len(response.Content.ToolCalls()) > 0,
|
|
}
|
|
if metaBytes, err := json.Marshal(metaData); err == nil {
|
|
meta = json.RawMessage(metaBytes)
|
|
}
|
|
}
|
|
|
|
responseContent := ""
|
|
if response != nil {
|
|
responseContent = response.Content.Text()
|
|
}
|
|
|
|
input := &hooks.StopInput{
|
|
CommonInput: hookExecutor.PopulateCommonFields(hooks.Stop),
|
|
StopHookActive: true,
|
|
Response: responseContent,
|
|
StopReason: stopReason,
|
|
Meta: meta,
|
|
}
|
|
|
|
// Execute Stop hook (ignore errors as we're exiting anyway)
|
|
_, _ = hookExecutor.ExecuteHooks(context.Background(), hooks.Stop, input)
|
|
}
|
|
}
|