Fix for openai (#95)

This commit is contained in:
Ed Zynda
2025-06-28 13:38:58 +03:00
committed by GitHub
parent ddd7856f9b
commit 053e6c32b0
2 changed files with 165 additions and 8 deletions
+156
View File
@@ -0,0 +1,156 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
)
// CustomChatModel wraps the eino-ext OpenAI model with custom tool schema handling
type CustomChatModel struct {
wrapped *einoopenai.ChatModel
}
// CustomRoundTripper intercepts HTTP requests to fix OpenAI function schemas
type CustomRoundTripper struct {
wrapped http.RoundTripper
}
// NewCustomChatModel creates a new custom OpenAI chat model
func NewCustomChatModel(ctx context.Context, config *einoopenai.ChatModelConfig) (*CustomChatModel, error) {
// Create a custom HTTP client that intercepts requests
if config.HTTPClient == nil {
config.HTTPClient = &http.Client{}
}
// Wrap the transport to intercept requests
if config.HTTPClient.Transport == nil {
config.HTTPClient.Transport = http.DefaultTransport
}
config.HTTPClient.Transport = &CustomRoundTripper{
wrapped: config.HTTPClient.Transport,
}
wrapped, err := einoopenai.NewChatModel(ctx, config)
if err != nil {
return nil, err
}
return &CustomChatModel{
wrapped: wrapped,
}, nil
}
// RoundTrip implements http.RoundTripper to intercept and fix OpenAI requests
func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Only intercept OpenAI chat completions requests
if !strings.Contains(req.URL.Path, "/chat/completions") {
return c.wrapped.RoundTrip(req)
}
// Read the request body
if req.Body == nil {
return c.wrapped.RoundTrip(req)
}
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return c.wrapped.RoundTrip(req)
}
req.Body.Close()
// Parse the JSON request
var requestData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
// If we can't parse it, just pass it through
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return c.wrapped.RoundTrip(req)
}
// Fix function schemas if present
if tools, ok := requestData["tools"].([]interface{}); ok {
for _, tool := range tools {
if toolMap, ok := tool.(map[string]interface{}); ok {
if function, ok := toolMap["function"].(map[string]interface{}); ok {
if parameters, ok := function["parameters"].(map[string]interface{}); ok {
if typeVal, ok := parameters["type"].(string); ok && typeVal == "object" {
// Check if properties is missing or empty
if properties, exists := parameters["properties"]; !exists || properties == nil {
parameters["properties"] = map[string]interface{}{}
} else if propMap, ok := properties.(map[string]interface{}); ok && len(propMap) == 0 {
parameters["properties"] = map[string]interface{}{}
}
}
}
}
}
}
}
// Marshal the fixed request back to JSON
fixedBodyBytes, err := json.Marshal(requestData)
if err != nil {
// If we can't marshal it, use the original
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return c.wrapped.RoundTrip(req)
}
// Create new request body with fixed data
req.Body = io.NopCloser(bytes.NewReader(fixedBodyBytes))
req.ContentLength = int64(len(fixedBodyBytes))
return c.wrapped.RoundTrip(req)
}
// Generate implements model.ChatModel
func (c *CustomChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.Message, error) {
return c.wrapped.Generate(ctx, in, opts...)
}
// Stream implements model.ChatModel
func (c *CustomChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
return c.wrapped.Stream(ctx, in, opts...)
}
// WithTools implements model.ToolCallingChatModel
func (c *CustomChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
wrappedWithTools, err := c.wrapped.WithTools(tools)
if err != nil {
return nil, err
}
// Type assert back to *einoopenai.ChatModel
wrappedChatModel, ok := wrappedWithTools.(*einoopenai.ChatModel)
if !ok {
return nil, fmt.Errorf("unexpected type returned from WithTools")
}
return &CustomChatModel{wrapped: wrappedChatModel}, nil
}
// BindTools implements model.ToolCallingChatModel
func (c *CustomChatModel) BindTools(tools []*schema.ToolInfo) error {
return c.wrapped.BindTools(tools)
}
// BindForcedTools implements model.ToolCallingChatModel
func (c *CustomChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
return c.wrapped.BindForcedTools(tools)
}
// GetType implements model.ChatModel
func (c *CustomChatModel) GetType() string {
return "CustomOpenAI"
}
// IsCallbacksEnabled implements model.ChatModel
func (c *CustomChatModel) IsCallbacksEnabled() bool {
return c.wrapped.IsCallbacksEnabled()
}
+9 -8
View File
@@ -11,10 +11,11 @@ import (
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/claude"
einoclaude "github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/ollama"
"github.com/cloudwego/eino-ext/components/model/openai"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/mark3labs/mcphost/internal/models/openai"
"github.com/mark3labs/mcphost/internal/ui/progress"
"github.com/ollama/ollama/api"
"google.golang.org/genai"
@@ -182,7 +183,7 @@ func createAzureOpenAIProvider(ctx context.Context, config *ProviderConfig, mode
return nil, fmt.Errorf("Azure OpenAI API key not provided. Use --provider-api-key flag or AZURE_OPENAI_API_KEY environment variable")
}
azureConfig := &openai.ChatModelConfig{
azureConfig := &einoopenai.ChatModelConfig{
APIKey: apiKey,
Model: modelName,
ByAzure: true, // Indicate this is an Azure OpenAI model
@@ -214,7 +215,7 @@ func createAzureOpenAIProvider(ctx context.Context, config *ProviderConfig, mode
azureConfig.Stop = config.StopSequences
}
return openai.NewChatModel(ctx, azureConfig)
return openai.NewCustomChatModel(ctx, azureConfig)
}
func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
@@ -235,7 +236,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
maxTokens = 4096 // Default value
}
claudeConfig := &claude.Config{
claudeConfig := &einoclaude.Config{
Model: modelName,
MaxTokens: maxTokens,
}
@@ -272,7 +273,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
claudeConfig.StopSequences = config.StopSequences
}
return claude.NewChatModel(ctx, claudeConfig)
return einoclaude.NewChatModel(ctx, claudeConfig)
}
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
@@ -284,7 +285,7 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
}
openaiConfig := &openai.ChatModelConfig{
openaiConfig := &einoopenai.ChatModelConfig{
APIKey: apiKey,
Model: modelName,
}
@@ -330,7 +331,7 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
openaiConfig.Stop = config.StopSequences
}
return openai.NewChatModel(ctx, openaiConfig)
return openai.NewCustomChatModel(ctx, openaiConfig)
}
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {