Commit 1fc7c28

mo khan <mo@mokhan.ca>
2025-08-18 21:04:35
refactor: simplify the fetch server
1 parent dee7660
Changed files (2)
pkg
fetch
htmlprocessor
pkg/fetch/server.go
@@ -1,6 +1,7 @@
 package fetch
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -9,35 +10,27 @@ import (
 	"strings"
 	"time"
 
-	"github.com/xlgmokha/mcp/pkg/htmlprocessor"
 	"github.com/xlgmokha/mcp/pkg/mcp"
 )
 
-// FetchResult represents the result of a fetch operation
 type FetchResult struct {
 	URL         string `json:"url"`
 	Content     string `json:"content"`
 	ContentType string `json:"content_type"`
-	Length      int    `json:"length"`
-	Truncated   bool   `json:"truncated,omitempty"`
-	NextIndex   int    `json:"next_index,omitempty"`
+	IsBinary    bool   `json:"is_binary"`
 }
 
-// FetchOperations provides HTTP client operations for fetching content
 type FetchOperations struct {
-	httpClient    *http.Client
-	userAgent     string
-	htmlProcessor *htmlprocessor.ContentExtractor
+	httpClient *http.Client
+	userAgent  string
 }
 
-// NewFetchOperations creates a new FetchOperations helper
 func NewFetchOperations() *FetchOperations {
 	return &FetchOperations{
 		httpClient: &http.Client{
 			Timeout: 30 * time.Second,
 		},
-		userAgent:     "ModelContextProtocol/1.0 (Fetch; +https://github.com/xlgmokha/mcp)",
-		htmlProcessor: htmlprocessor.NewContentExtractor(),
+		userAgent: "ModelContextProtocol/1.0 (Fetch; +https://github.com/xlgmokha/mcp)",
 	}
 }
 
@@ -46,8 +39,7 @@ func New() *mcp.Server {
 	fetch := NewFetchOperations()
 	builder := mcp.NewServerBuilder("mcp-fetch", "1.0.0")
 
-	// Add fetch tool
-	builder.AddTool(mcp.NewTool("fetch", "Fetches a URL from the internet and extracts its contents as markdown. Always returns successful response with content or error details.", map[string]interface{}{
+	builder.AddTool(mcp.NewTool("fetch", "Fetches a URL and returns the content. Text content is returned as-is, binary content is base64 encoded.", map[string]interface{}{
 		"type": "object",
 		"properties": map[string]interface{}{
 			"url": map[string]interface{}{
@@ -55,24 +47,6 @@ func New() *mcp.Server {
 				"description": "URL to fetch",
 				"format":      "uri",
 			},
-			"max_length": map[string]interface{}{
-				"type":        "integer",
-				"description": "Maximum number of characters to return. Defaults to 5000",
-				"minimum":     1,
-				"maximum":     999999,
-				"default":     5000,
-			},
-			"start_index": map[string]interface{}{
-				"type":        "integer",
-				"description": "Start reading content from this character index. Defaults to 0",
-				"minimum":     0,
-				"default":     0,
-			},
-			"raw": map[string]interface{}{
-				"type":        "boolean",
-				"description": "Get raw HTML content without markdown conversion. Defaults to false",
-				"default":     false,
-			},
 		},
 		"required": []string{"url"},
 	}, func(req mcp.CallToolRequest) (mcp.CallToolResult, error) {
@@ -81,57 +55,16 @@ func New() *mcp.Server {
 			return mcp.NewToolError("url is required"), nil
 		}
 
-		// Parse and validate URL
 		parsedURL, err := url.Parse(urlStr)
 		if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
 			return mcp.NewToolError("Invalid URL format"), nil
 		}
 
-		// Get optional parameters
-		maxLength := 5000
-		if ml, ok := req.Arguments["max_length"]; ok {
-			switch v := ml.(type) {
-			case float64:
-				maxLength = int(v)
-			case int:
-				maxLength = v
-			default:
-				return mcp.NewToolError("max_length must be a number"), nil
-			}
-			if maxLength < 1 || maxLength > 999999 {
-				return mcp.NewToolError("max_length must be between 1 and 999999"), nil
-			}
-		}
-
-		startIndex := 0
-		if si, ok := req.Arguments["start_index"]; ok {
-			switch v := si.(type) {
-			case float64:
-				startIndex = int(v)
-			case int:
-				startIndex = v
-			default:
-				return mcp.NewToolError("start_index must be a number"), nil
-			}
-			if startIndex < 0 {
-				return mcp.NewToolError("start_index must be >= 0"), nil
-			}
-		}
-
-		raw := false
-		if r, ok := req.Arguments["raw"]; ok {
-			if rBool, ok := r.(bool); ok {
-				raw = rBool
-			}
-		}
-
-		// Fetch the content
-		result, err := fetch.fetchContent(parsedURL.String(), maxLength, startIndex, raw)
+		result, err := fetch.fetchContent(parsedURL.String())
 		if err != nil {
 			return mcp.NewToolError(err.Error()), nil
 		}
 
-		// Format result as JSON
 		jsonResult, err := json.MarshalIndent(result, "", "  ")
 		if err != nil {
 			return mcp.NewToolError(fmt.Sprintf("Failed to marshal result: %v", err)), nil
@@ -140,155 +73,70 @@ func New() *mcp.Server {
 		return mcp.NewToolResult(mcp.NewTextContent(string(jsonResult))), nil
 	}))
 
-	// Add fetch prompt
-	builder.AddPrompt(mcp.NewPrompt("fetch", "Prompt for manually entering a URL to fetch content from", []mcp.PromptArgument{
-		{
-			Name:        "url",
-			Description: "The URL to fetch content from",
-			Required:    true,
-		},
-		{
-			Name:        "reason",
-			Description: "Why you want to fetch this URL (optional context)",
-			Required:    false,
-		},
-	}, func(req mcp.GetPromptRequest) (mcp.GetPromptResult, error) {
-		url, hasURL := req.Arguments["url"].(string)
-		reason, hasReason := req.Arguments["reason"].(string)
-
-		if !hasURL || url == "" {
-			return mcp.GetPromptResult{}, fmt.Errorf("url argument is required")
-		}
-
-		// Create the prompt messages
-		var messages []mcp.PromptMessage
-
-		// User message with the URL and optional reason
-		userContent := fmt.Sprintf("Please fetch the content from this URL: %s", url)
-		if hasReason && reason != "" {
-			userContent += fmt.Sprintf("\n\nReason: %s", reason)
-		}
-
-		messages = append(messages, mcp.PromptMessage{
-			Role:    "user",
-			Content: mcp.NewTextContent(userContent),
-		})
-
-		// Assistant message suggesting the fetch tool usage
-		assistantContent := fmt.Sprintf(`I'll fetch the content from %s for you.
-
-Let me use the fetch tool to retrieve and process the content:`, url)
-
-		messages = append(messages, mcp.PromptMessage{
-			Role:    "assistant",
-			Content: mcp.NewTextContent(assistantContent),
-		})
-
-		description := "Manual URL fetch prompt"
-		if hasReason && reason != "" {
-			description = fmt.Sprintf("Manual URL fetch: %s", reason)
-		}
-
-		return mcp.GetPromptResult{
-			Description: description,
-			Messages:    messages,
-		}, nil
-	}))
-
 	return builder.Build()
 }
 
-// Helper methods for FetchOperations
-
-func (fetch *FetchOperations) fetchContent(urlStr string, maxLength, startIndex int, raw bool) (*FetchResult, error) {
-	// Create HTTP request
+func (fetch *FetchOperations) fetchContent(urlStr string) (*FetchResult, error) {
 	req, err := http.NewRequest("GET", urlStr, nil)
 	if err != nil {
 		return nil, fmt.Errorf("Failed to create request: %v", err)
 	}
 
 	req.Header.Set("User-Agent", fetch.userAgent)
-	req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8")
 
-	// Perform HTTP request
 	resp, err := fetch.httpClient.Do(req)
 	if err != nil {
 		return nil, fmt.Errorf("Failed to fetch URL: %v", err)
 	}
 	defer resp.Body.Close()
 
-	// Check for HTTP errors
 	if resp.StatusCode >= 400 {
 		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
 	}
 
-	// Read response body
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return nil, fmt.Errorf("Failed to read response body: %v", err)
 	}
 
-	// Get content type
 	contentType := resp.Header.Get("Content-Type")
-
-	// Process content
+	isBinary := isBinaryContent(contentType)
+
 	var content string
-	if raw || !isHTMLContent(string(body), contentType) {
-		content = string(body)
+	if isBinary {
+		content = base64.StdEncoding.EncodeToString(body)
 	} else {
-		// Convert HTML to markdown using improved processor
-		var err error
-		content, err = fetch.htmlProcessor.ToMarkdown(string(body))
-		if err != nil {
-			// Fallback to raw content if markdown conversion fails
-			content = string(body)
-		}
-	}
-
-	// Apply start index first
-	originalLength := len(content)
-	if startIndex > 0 {
-		if startIndex >= originalLength {
-			return nil, fmt.Errorf("start_index (%d) is beyond content length (%d)", startIndex, originalLength)
-		}
-		content = content[startIndex:]
-	}
-
-	// Apply max length and check for truncation
-	truncated := false
-	nextIndex := 0
-	if len(content) > maxLength {
-		content = content[:maxLength]
-		truncated = true
-		nextIndex = startIndex + maxLength
+		content = string(body)
 	}
 
-	result := &FetchResult{
+	return &FetchResult{
 		URL:         urlStr,
 		Content:     content,
 		ContentType: contentType,
-		Length:      len(content),
-	}
-
-	if truncated {
-		result.Truncated = true
-		result.NextIndex = nextIndex
-	}
-
-	return result, nil
+		IsBinary:    isBinary,
+	}, nil
 }
 
-func isHTMLContent(content, contentType string) bool {
-	// Check content type header
-	if strings.Contains(strings.ToLower(contentType), "text/html") {
-		return true
+func isBinaryContent(contentType string) bool {
+	if contentType == "" {
+		return false
 	}
-
-	// Check if content starts with HTML tags (first 100 chars)
-	prefix := content
-	if len(prefix) > 100 {
-		prefix = prefix[:100]
+
+	contentType = strings.ToLower(strings.Split(contentType, ";")[0])
+
+	textTypes := []string{
+		"text/",
+		"application/json",
+		"application/xml",
+		"application/javascript",
+		"application/x-javascript",
 	}
-
-	return strings.Contains(strings.ToLower(prefix), "<html")
+
+	for _, textType := range textTypes {
+		if strings.HasPrefix(contentType, textType) {
+			return false
+		}
+	}
+
+	return true
 }
pkg/htmlprocessor/processor.go
@@ -1,98 +0,0 @@
-package htmlprocessor
-
-import (
-	"strings"
-
-	"github.com/JohannesKaufmann/html-to-markdown"
-	"github.com/PuerkitoBio/goquery"
-)
-
-// ContentExtractor handles HTML content extraction and conversion
-type ContentExtractor struct {
-	converter *md.Converter
-}
-
-// NewContentExtractor creates a new ContentExtractor with default settings
-func NewContentExtractor() *ContentExtractor {
-	converter := md.NewConverter("", true, nil)
-
-	// Add custom rules to remove unwanted elements
-	converter.AddRules(
-		md.Rule{
-			Filter: []string{"script", "style", "nav", "header", "footer", "aside"},
-			Replacement: func(content string, selec *goquery.Selection, opt *md.Options) *string {
-				// Remove these elements entirely
-				empty := ""
-				return &empty
-			},
-		},
-	)
-
-	return &ContentExtractor{
-		converter: converter,
-	}
-}
-
-// ExtractReadableContent extracts the main readable content from HTML
-// It removes navigation, ads, scripts, styles, and other non-content elements
-func (e *ContentExtractor) ExtractReadableContent(html string) (string, error) {
-	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
-	if err != nil {
-		return "", err
-	}
-
-	// Remove unwanted elements
-	doc.Find("script, style, nav, header, footer, aside, .sidebar, .ads, .advertisement").Remove()
-
-	// Try to find main content areas in order of preference
-	var contentSelection *goquery.Selection
-
-	// Look for semantic HTML5 elements first
-	if main := doc.Find("main"); main.Length() > 0 {
-		contentSelection = main.First()
-	} else if article := doc.Find("article"); article.Length() > 0 {
-		contentSelection = article.First()
-	} else if content := doc.Find(".content, .main-content, #content, #main"); content.Length() > 0 {
-		contentSelection = content.First()
-	} else {
-		// Fallback to body
-		contentSelection = doc.Find("body")
-	}
-
-	// Extract text content
-	var textParts []string
-	contentSelection.Find("h1, h2, h3, h4, h5, h6, p, li").Each(func(i int, s *goquery.Selection) {
-		text := strings.TrimSpace(s.Text())
-		if text != "" {
-			textParts = append(textParts, text)
-		}
-	})
-
-	return strings.Join(textParts, "\n"), nil
-}
-
-// ToMarkdown converts HTML to markdown format
-func (e *ContentExtractor) ToMarkdown(html string) (string, error) {
-	markdown, err := e.converter.ConvertString(html)
-	if err != nil {
-		return "", err
-	}
-
-	// Clean up extra whitespace
-	lines := strings.Split(markdown, "\n")
-	var cleanLines []string
-
-	for _, line := range lines {
-		trimmed := strings.TrimSpace(line)
-		if trimmed != "" || (len(cleanLines) > 0 && cleanLines[len(cleanLines)-1] != "") {
-			cleanLines = append(cleanLines, trimmed)
-		}
-	}
-
-	// Remove trailing empty lines
-	for len(cleanLines) > 0 && cleanLines[len(cleanLines)-1] == "" {
-		cleanLines = cleanLines[:len(cleanLines)-1]
-	}
-
-	return strings.Join(cleanLines, "\n"), nil
-}