Commit addde2a

mo khan <mo@mokhan.ca>
2025-08-16 14:53:18
refactor: extract a Tree type
1 parent 4afe9fc
Changed files (2)
pkg
pkg/filesystem/server.go
@@ -2,7 +2,6 @@ package filesystem
 
 import (
 	"fmt"
-	"mime"
 	"os"
 	"path/filepath"
 	"strings"
@@ -10,29 +9,8 @@ import (
 	"github.com/xlgmokha/mcp/pkg/mcp"
 )
 
-type Server struct {
-	*mcp.Server
-	allowedDirectories []string
-}
-
-func normalizeDirectories(dirs []string) []string {
-	normalizedDirs := make([]string, len(dirs))
-	for i, dir := range dirs {
-		absPath, err := filepath.Abs(expandHome(dir))
-		if err != nil {
-			panic(fmt.Sprintf("Invalid directory: %s", dir))
-		}
-		normalizedDirs[i] = filepath.Clean(absPath)
-	}
-	return normalizedDirs
-}
-
-func New(allowedDirs []string) *Server {
-	allowedDirectories := normalizeDirectories(allowedDirs)
-
-	fsServer := &Server{
-		allowedDirectories: allowedDirectories,
-	}
+func New(allowedDirs []string) *mcp.Server {
+	tree := NewTree(allowedDirs)
 
 	tools := []mcp.Tool{
 		mcp.NewTool("read_file", "Read the contents of a file", map[string]interface{}{
@@ -49,7 +27,7 @@ func New(allowedDirs []string) *Server {
 				return mcp.NewToolError("path is required"), nil
 			}
 
-			validPath, err := fsServer.validatePath(pathStr)
+			validPath, err := tree.validatePath(pathStr)
 			if err != nil {
 				return mcp.NewToolError(err.Error()), nil
 			}
@@ -84,7 +62,7 @@ func New(allowedDirs []string) *Server {
 				return mcp.NewToolError("content is required"), nil
 			}
 
-			validPath, err := fsServer.validatePath(pathStr)
+			validPath, err := tree.validatePath(pathStr)
 			if err != nil {
 				return mcp.NewToolError(err.Error()), nil
 			}
@@ -109,7 +87,7 @@ func New(allowedDirs []string) *Server {
 				}
 
 				filePath := req.URI[7:]
-				validPath, err := fsServer.validatePath(filePath)
+				validPath, err := tree.validatePath(filePath)
 				if err != nil {
 					return mcp.ReadResourceResult{}, fmt.Errorf("access denied: %v", err)
 				}
@@ -139,7 +117,7 @@ func New(allowedDirs []string) *Server {
 		),
 	}
 
-	for _, dir := range allowedDirectories {
+	for _, dir := range tree.Directories {
 		fileURI := "file://" + dir
 		dirName := filepath.Base(dir)
 		if dirName == "." || dirName == "/" {
@@ -158,14 +136,14 @@ func New(allowedDirs []string) *Server {
 				}, nil
 			},
 		))
-		
+
 		// Discover files in this directory at construction time
-		fileResources := fsServer.discoverFiles(dir)
+		fileResources := tree.discoverFiles(dir)
 		resources = append(resources, fileResources...)
 	}
 
 	var roots []mcp.Root
-	for _, dir := range allowedDirectories {
+	for _, dir := range tree.Directories {
 		fileURI := "file://" + dir
 		dirName := filepath.Base(dir)
 		if dirName == "." || dirName == "/" {
@@ -174,157 +152,5 @@ func New(allowedDirs []string) *Server {
 		roots = append(roots, mcp.NewRoot(fileURI, fmt.Sprintf("Directory: %s", dirName)))
 	}
 
-	fsServer.Server = mcp.NewServer("filesystem", "0.2.0", tools, resources, roots)
-
-	return fsServer
-}
-
-func (fs *Server) discoverFiles(dirPath string) []mcp.Resource {
-	var resources []mcp.Resource
-
-	programmingMimeTypes := map[string]string{
-		".go":   "text/x-go",
-		".rs":   "text/x-rust",
-		".py":   "text/x-python",
-		".java": "text/x-java",
-		".c":    "text/x-c",
-		".cpp":  "text/x-c++",
-		".h":    "text/x-c",
-		".hpp":  "text/x-c++",
-		".sh":   "text/x-shellscript",
-	}
-
-	specialFiles := map[string]string{
-		"Makefile":   "text/x-makefile",
-		"README":     "text/plain",
-		"LICENSE":    "text/plain",
-		"go.mod":     "text/x-go-mod",
-		"Cargo.toml": "application/toml",
-	}
-
-	const maxFiles = 500
-	fileCount := 0
-
-	err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
-		if err != nil || fileCount >= maxFiles {
-			return filepath.SkipDir
-		}
-
-		if strings.HasPrefix(info.Name(), ".") {
-			if info.IsDir() {
-				return filepath.SkipDir
-			}
-			return nil
-		}
-
-		if info.IsDir() {
-			dirName := strings.ToLower(info.Name())
-			skipDirs := []string{"node_modules", "vendor", "target", "build", "dist", ".git"}
-			for _, skip := range skipDirs {
-				if dirName == skip {
-					return filepath.SkipDir
-				}
-			}
-			return nil
-		}
-
-		fileName := info.Name()
-		ext := strings.ToLower(filepath.Ext(fileName))
-
-		var mimeType string
-
-		if progType, exists := programmingMimeTypes[ext]; exists {
-			mimeType = progType
-		} else if specialType, exists := specialFiles[fileName]; exists {
-			mimeType = specialType
-		} else {
-			mimeType = mime.TypeByExtension(ext)
-			if mimeType == "" {
-				return nil
-			}
-		}
-
-		fileURI := "file://" + path
-		relPath, _ := filepath.Rel(dirPath, path)
-
-		resource := mcp.Resource{
-			URI:      fileURI,
-			Name:     relPath,
-			MimeType: mimeType,
-		}
-
-		resources = append(resources, resource)
-		fileCount++
-
-		return nil
-	})
-
-	if err != nil {
-		fmt.Printf("Warning: Error discovering files in %s: %v\n", dirPath, err)
-	}
-
-	return resources
-}
-
-func isBinaryContent(content []byte) bool {
-	checkBytes := content
-	if len(content) > 512 {
-		checkBytes = content[:512]
-	}
-
-	for _, b := range checkBytes {
-		if b == 0 {
-			return true
-		}
-	}
-
-	return false
-}
-
-func (fs *Server) validatePath(requestedPath string) (string, error) {
-	expandedPath := expandHome(requestedPath)
-	var absolute string
-
-	if filepath.IsAbs(expandedPath) {
-		absolute = filepath.Clean(expandedPath)
-	} else {
-		wd, _ := os.Getwd()
-		absolute = filepath.Clean(filepath.Join(wd, expandedPath))
-	}
-
-	allowed := false
-	for _, dir := range fs.allowedDirectories {
-		if strings.HasPrefix(absolute, dir) {
-			allowed = true
-			break
-		}
-	}
-
-	if !allowed {
-		return "", fmt.Errorf("access denied: %s", absolute)
-	}
-
-	realPath, err := filepath.EvalSymlinks(absolute)
-	if err != nil {
-		return absolute, nil
-	}
-
-	for _, dir := range fs.allowedDirectories {
-		if strings.HasPrefix(realPath, dir) {
-			return realPath, nil
-		}
-	}
-
-	return "", fmt.Errorf("access denied")
-}
-
-func expandHome(filePath string) string {
-	if strings.HasPrefix(filePath, "~/") || filePath == "~" {
-		homeDir, _ := os.UserHomeDir()
-		if filePath == "~" {
-			return homeDir
-		}
-		return filepath.Join(homeDir, filePath[2:])
-	}
-	return filePath
+	return mcp.NewServer("filesystem", "0.2.0", tools, resources, roots)
 }
pkg/filesystem/tree.go
@@ -0,0 +1,183 @@
+package filesystem
+
+import (
+	"fmt"
+	"mime"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/xlgmokha/mcp/pkg/mcp"
+)
+
+type Tree struct {
+	Directories []string
+}
+
+func NewTree(allowedDirectories []string) *Tree {
+	return &Tree{
+		Directories: normalizeDirectories(allowedDirectories),
+	}
+}
+
+func (fs *Tree) discoverFiles(dirPath string) []mcp.Resource {
+	var resources []mcp.Resource
+
+	programmingMimeTypes := map[string]string{
+		".go":   "text/x-go",
+		".rs":   "text/x-rust",
+		".py":   "text/x-python",
+		".java": "text/x-java",
+		".c":    "text/x-c",
+		".cpp":  "text/x-c++",
+		".h":    "text/x-c",
+		".hpp":  "text/x-c++",
+		".sh":   "text/x-shellscript",
+	}
+
+	specialFiles := map[string]string{
+		"Makefile":   "text/x-makefile",
+		"README":     "text/plain",
+		"LICENSE":    "text/plain",
+		"go.mod":     "text/x-go-mod",
+		"Cargo.toml": "application/toml",
+	}
+
+	const maxFiles = 500
+	fileCount := 0
+
+	err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
+		if err != nil || fileCount >= maxFiles {
+			return filepath.SkipDir
+		}
+
+		if strings.HasPrefix(info.Name(), ".") {
+			if info.IsDir() {
+				return filepath.SkipDir
+			}
+			return nil
+		}
+
+		if info.IsDir() {
+			dirName := strings.ToLower(info.Name())
+			skipDirs := []string{"node_modules", "vendor", "target", "build", "dist", ".git"}
+			for _, skip := range skipDirs {
+				if dirName == skip {
+					return filepath.SkipDir
+				}
+			}
+			return nil
+		}
+
+		fileName := info.Name()
+		ext := strings.ToLower(filepath.Ext(fileName))
+
+		var mimeType string
+
+		if progType, exists := programmingMimeTypes[ext]; exists {
+			mimeType = progType
+		} else if specialType, exists := specialFiles[fileName]; exists {
+			mimeType = specialType
+		} else {
+			mimeType = mime.TypeByExtension(ext)
+			if mimeType == "" {
+				return nil
+			}
+		}
+
+		fileURI := "file://" + path
+		relPath, _ := filepath.Rel(dirPath, path)
+
+		resource := mcp.Resource{
+			URI:      fileURI,
+			Name:     relPath,
+			MimeType: mimeType,
+		}
+
+		resources = append(resources, resource)
+		fileCount++
+
+		return nil
+	})
+
+	if err != nil {
+		fmt.Printf("Warning: Error discovering files in %s: %v\n", dirPath, err)
+	}
+
+	return resources
+}
+
+func isBinaryContent(content []byte) bool {
+	checkBytes := content
+	if len(content) > 512 {
+		checkBytes = content[:512]
+	}
+
+	for _, b := range checkBytes {
+		if b == 0 {
+			return true
+		}
+	}
+
+	return false
+}
+
+func (fs *Tree) validatePath(requestedPath string) (string, error) {
+	expandedPath := expandHome(requestedPath)
+	var absolute string
+
+	if filepath.IsAbs(expandedPath) {
+		absolute = filepath.Clean(expandedPath)
+	} else {
+		wd, _ := os.Getwd()
+		absolute = filepath.Clean(filepath.Join(wd, expandedPath))
+	}
+
+	allowed := false
+	for _, dir := range fs.Directories {
+		if strings.HasPrefix(absolute, dir) {
+			allowed = true
+			break
+		}
+	}
+
+	if !allowed {
+		return "", fmt.Errorf("access denied: %s", absolute)
+	}
+
+	realPath, err := filepath.EvalSymlinks(absolute)
+	if err != nil {
+		return absolute, nil
+	}
+
+	for _, dir := range fs.Directories {
+		if strings.HasPrefix(realPath, dir) {
+			return realPath, nil
+		}
+	}
+
+	return "", fmt.Errorf("access denied")
+}
+
+func expandHome(filePath string) string {
+	if strings.HasPrefix(filePath, "~/") || filePath == "~" {
+		homeDir, _ := os.UserHomeDir()
+		if filePath == "~" {
+			return homeDir
+		}
+		return filepath.Join(homeDir, filePath[2:])
+	}
+	return filePath
+}
+
+func normalizeDirectories(dirs []string) []string {
+	normalizedDirs := make([]string, len(dirs))
+	for i, dir := range dirs {
+		absPath, err := filepath.Abs(expandHome(dir))
+		if err != nil {
+			panic(fmt.Sprintf("Invalid directory: %s", dir))
+		}
+		normalizedDirs[i] = filepath.Clean(absPath)
+	}
+	return normalizedDirs
+}