Commit 56c7693

mo khan <mo@mokhan.ca>
2025-08-15 23:34:57
refactor: collapse more code into the constructor
1 parent d67e930
Changed files (1)
pkg
filesystem
pkg/filesystem/server.go
@@ -1,371 +1,361 @@
 package filesystem
 
 import (
-  "encoding/json"
-  "fmt"
-  "mime"
-  "os"
-  "path/filepath"
-  "strings"
-
-  "github.com/xlgmokha/mcp/pkg/mcp"
+	"encoding/json"
+	"fmt"
+	"mime"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/xlgmokha/mcp/pkg/mcp"
 )
 
 type Server struct {
-  *mcp.Server
-  allowedDirectories []string
+	*mcp.Server
+	allowedDirectories []string
 }
 
-func New(allowedDirs []string) *Server {
-  normalizedDirs := make([]string, len(allowedDirs))
-  for i, dir := range allowedDirs {
-    absPath, err := filepath.Abs(expandHome(dir))
-    if err != nil {
-      panic(fmt.Sprintf("Invalid directory: %s", dir))
-    }
-    normalizedDirs[i] = filepath.Clean(absPath)
-  }
-
-  fsServer := &Server{
-    allowedDirectories: normalizedDirs,
-  }
-
-  tools := []mcp.Tool{
-    mcp.NewTool("read_file", "Read the contents of a file", map[string]interface{}{
-      "type": "object",
-      "properties": map[string]interface{}{
-        "path": map[string]interface{}{
-          "type": "string",
-        },
-      },
-      "required": []string{"path"},
-    }, func(req mcp.CallToolRequest) (mcp.CallToolResult, error) {
-      pathStr, ok := req.Arguments["path"].(string)
-      if !ok {
-        return mcp.NewToolError("path is required"), nil
-      }
-
-      validPath, err := fsServer.validatePath(pathStr)
-      if err != nil {
-        return mcp.NewToolError(err.Error()), nil
-      }
-
-      content, err := os.ReadFile(validPath)
-      if err != nil {
-        return mcp.NewToolError(fmt.Sprintf("Failed to read file: %v", err)), nil
-      }
-
-      return mcp.NewToolResult(mcp.NewTextContent(string(content))), nil
-    }),
-
-    mcp.NewTool("write_file", "Write content to a file", map[string]interface{}{
-      "type": "object",
-      "properties": map[string]interface{}{
-        "path": map[string]interface{}{
-          "type": "string",
-        },
-        "content": map[string]interface{}{
-          "type": "string",
-        },
-      },
-      "required": []string{"path", "content"},
-    }, func(req mcp.CallToolRequest) (mcp.CallToolResult, error) {
-      pathStr, ok := req.Arguments["path"].(string)
-      if !ok {
-        return mcp.NewToolError("path is required"), nil
-      }
-
-      content, ok := req.Arguments["content"].(string)
-      if !ok {
-        return mcp.NewToolError("content is required"), nil
-      }
-
-      validPath, err := fsServer.validatePath(pathStr)
-      if err != nil {
-        return mcp.NewToolError(err.Error()), nil
-      }
-
-      err = os.WriteFile(validPath, []byte(content), 0644)
-      if err != nil {
-        return mcp.NewToolError(fmt.Sprintf("Failed to write file: %v", err)), nil
-      }
-
-      return mcp.NewToolResult(mcp.NewTextContent(fmt.Sprintf("Successfully wrote to %s", pathStr))), nil
-    }),
-  }
-
-  resources := []mcp.Resource{
-    mcp.NewResource(
-      "file://",
-      "File System",
-      "",
-      func(req mcp.ReadResourceRequest) (mcp.ReadResourceResult, error) {
-        if !strings.HasPrefix(req.URI, "file://") {
-          return mcp.ReadResourceResult{}, fmt.Errorf("invalid file URI: %s", req.URI)
-        }
-
-        filePath := req.URI[7:]
-        validPath, err := fsServer.validatePath(filePath)
-        if err != nil {
-          return mcp.ReadResourceResult{}, fmt.Errorf("access denied: %v", err)
-        }
-
-        content, err := os.ReadFile(validPath)
-        if err != nil {
-          return mcp.ReadResourceResult{}, fmt.Errorf("failed to read file: %v", err)
-        }
-
-        if isBinaryContent(content) {
-          return mcp.ReadResourceResult{
-            Contents: []mcp.Content{
-              mcp.TextContent{
-                Type: "text",
-                Text: fmt.Sprintf("Binary file (size: %d bytes)", len(content)),
-              },
-            },
-          }, nil
-        }
-
-        return mcp.ReadResourceResult{
-          Contents: []mcp.Content{
-            mcp.NewTextContent(string(content)),
-          },
-        }, nil
-      },
-    ),
-  }
-
-  for _, dir := range normalizedDirs {
-    fileURI := "file://" + dir
-    dirName := filepath.Base(dir)
-    if dirName == "." || dirName == "/" {
-      dirName = dir
-    }
-
-    resources = append(resources, mcp.NewResource(
-      fileURI,
-      fmt.Sprintf("Directory: %s", dirName),
-      "inode/directory",
-      func(req mcp.ReadResourceRequest) (mcp.ReadResourceResult, error) {
-        return mcp.ReadResourceResult{
-          Contents: []mcp.Content{
-            mcp.NewTextContent(fmt.Sprintf("Directory: %s", dir)),
-          },
-        }, nil
-      },
-    ))
-  }
-
-  server := mcp.NewServer("filesystem", "0.2.0", tools, resources)
-  fsServer.Server = server
-
-  fsServer.registerRoots()
-  fsServer.setupResourceHandling()
-
-  return fsServer
+func normalizeDirectories(dirs []string) []string {
+	normalizedDirs := make([]string, len(dirs))
+	for i, dir := range dirs {
+		absPath, err := filepath.Abs(expandHome(dir))
+		if err != nil {
+			panic(fmt.Sprintf("Invalid directory: %s", dir))
+		}
+		normalizedDirs[i] = filepath.Clean(absPath)
+	}
+	return normalizedDirs
 }
 
+func New(allowedDirs []string) *Server {
+	allowedDirectories := normalizeDirectories(allowedDirs)
+
+	fsServer := &Server{
+		allowedDirectories: allowedDirectories,
+	}
+
+	tools := []mcp.Tool{
+		mcp.NewTool("read_file", "Read the contents of a file", map[string]interface{}{
+			"type": "object",
+			"properties": map[string]interface{}{
+				"path": map[string]interface{}{
+					"type": "string",
+				},
+			},
+			"required": []string{"path"},
+		}, func(req mcp.CallToolRequest) (mcp.CallToolResult, error) {
+			pathStr, ok := req.Arguments["path"].(string)
+			if !ok {
+				return mcp.NewToolError("path is required"), nil
+			}
+
+			validPath, err := fsServer.validatePath(pathStr)
+			if err != nil {
+				return mcp.NewToolError(err.Error()), nil
+			}
+
+			content, err := os.ReadFile(validPath)
+			if err != nil {
+				return mcp.NewToolError(fmt.Sprintf("Failed to read file: %v", err)), nil
+			}
+
+			return mcp.NewToolResult(mcp.NewTextContent(string(content))), nil
+		}),
+
+		mcp.NewTool("write_file", "Write content to a file", map[string]interface{}{
+			"type": "object",
+			"properties": map[string]interface{}{
+				"path": map[string]interface{}{
+					"type": "string",
+				},
+				"content": map[string]interface{}{
+					"type": "string",
+				},
+			},
+			"required": []string{"path", "content"},
+		}, func(req mcp.CallToolRequest) (mcp.CallToolResult, error) {
+			pathStr, ok := req.Arguments["path"].(string)
+			if !ok {
+				return mcp.NewToolError("path is required"), nil
+			}
+
+			content, ok := req.Arguments["content"].(string)
+			if !ok {
+				return mcp.NewToolError("content is required"), nil
+			}
+
+			validPath, err := fsServer.validatePath(pathStr)
+			if err != nil {
+				return mcp.NewToolError(err.Error()), nil
+			}
+
+			err = os.WriteFile(validPath, []byte(content), 0644)
+			if err != nil {
+				return mcp.NewToolError(fmt.Sprintf("Failed to write file: %v", err)), nil
+			}
+
+			return mcp.NewToolResult(mcp.NewTextContent(fmt.Sprintf("Successfully wrote to %s", pathStr))), nil
+		}),
+	}
+
+	resources := []mcp.Resource{
+		mcp.NewResource(
+			"file://",
+			"File System",
+			"",
+			func(req mcp.ReadResourceRequest) (mcp.ReadResourceResult, error) {
+				if !strings.HasPrefix(req.URI, "file://") {
+					return mcp.ReadResourceResult{}, fmt.Errorf("invalid file URI: %s", req.URI)
+				}
+
+				filePath := req.URI[7:]
+				validPath, err := fsServer.validatePath(filePath)
+				if err != nil {
+					return mcp.ReadResourceResult{}, fmt.Errorf("access denied: %v", err)
+				}
+
+				content, err := os.ReadFile(validPath)
+				if err != nil {
+					return mcp.ReadResourceResult{}, fmt.Errorf("failed to read file: %v", err)
+				}
+
+				if isBinaryContent(content) {
+					return mcp.ReadResourceResult{
+						Contents: []mcp.Content{
+							mcp.TextContent{
+								Type: "text",
+								Text: fmt.Sprintf("Binary file (size: %d bytes)", len(content)),
+							},
+						},
+					}, nil
+				}
+
+				return mcp.ReadResourceResult{
+					Contents: []mcp.Content{
+						mcp.NewTextContent(string(content)),
+					},
+				}, nil
+			},
+		),
+	}
+
+	for _, dir := range allowedDirectories {
+		fileURI := "file://" + dir
+		dirName := filepath.Base(dir)
+		if dirName == "." || dirName == "/" {
+			dirName = dir
+		}
+
+		resources = append(resources, mcp.NewResource(
+			fileURI,
+			fmt.Sprintf("Directory: %s", dirName),
+			"inode/directory",
+			func(req mcp.ReadResourceRequest) (mcp.ReadResourceResult, error) {
+				return mcp.ReadResourceResult{
+					Contents: []mcp.Content{
+						mcp.NewTextContent(fmt.Sprintf("Directory: %s", dir)),
+					},
+				}, nil
+			},
+		))
+	}
+
+	fsServer.Server = mcp.NewServer("filesystem", "0.2.0", tools, resources)
+
+	for _, dir := range allowedDirectories {
+		fileURI := "file://" + dir
+		dirName := filepath.Base(dir)
+		if dirName == "." || dirName == "/" {
+			dirName = dir
+		}
+
+		root := mcp.NewRoot(fileURI, fmt.Sprintf("Directory: %s", dirName))
+		fsServer.RegisterRoot(root)
+	}
+
+	handlers := map[string]func(mcp.JSONRPCRequest) mcp.JSONRPCResponse{
+		"resources/list": func(req mcp.JSONRPCRequest) mcp.JSONRPCResponse {
+			resources := fsServer.ListResources()
+			result := mcp.ListResourcesResult{Resources: resources}
+
+			id := req.ID
+			bytes, _ := json.Marshal(result)
+			rawMsg := json.RawMessage(bytes)
+			resultBytes := &rawMsg
+
+			return mcp.JSONRPCResponse{
+				JSONRPC: "2.0",
+				ID:      id,
+				Result:  resultBytes,
+			}
+		},
+	}
+	fsServer.SetCustomRequestHandler(handlers)
+
+	return fsServer
+}
 
+func (fs *Server) ListResources() []mcp.Resource {
+	var resources []mcp.Resource
 
+	parentResources := fs.Server.ListResources()
+	resources = append(resources, parentResources...)
 
-func (fs *Server) registerRoots() {
-  for _, dir := range fs.allowedDirectories {
-    fileURI := "file://" + dir
-    dirName := filepath.Base(dir)
-    if dirName == "." || dirName == "/" {
-      dirName = dir
-    }
-
-    root := mcp.NewRoot(fileURI, fmt.Sprintf("Directory: %s", dirName))
-    fs.RegisterRoot(root)
-  }
-}
+	for _, dir := range fs.allowedDirectories {
+		fileResources := fs.discoverFiles(dir)
+		resources = append(resources, fileResources...)
+	}
 
-func (fs *Server) setupResourceHandling() {
-  customListResourcesHandler := func(req mcp.JSONRPCRequest) mcp.JSONRPCResponse {
-    resources := fs.ListResources()
-    result := mcp.ListResourcesResult{Resources: resources}
-    
-    id := req.ID
-    bytes, _ := json.Marshal(result)
-    rawMsg := json.RawMessage(bytes)
-    resultBytes := &rawMsg
-    
-    return mcp.JSONRPCResponse{
-      JSONRPC: "2.0",
-      ID:      id,
-      Result:  resultBytes,
-    }
-  }
-  
-  handlers := map[string]func(mcp.JSONRPCRequest) mcp.JSONRPCResponse{
-    "resources/list": customListResourcesHandler,
-  }
-  fs.SetCustomRequestHandler(handlers)
-}
-
-func (fs *Server) ListResources() []mcp.Resource {
-  var resources []mcp.Resource
-  
-  parentResources := fs.Server.ListResources()
-  resources = append(resources, parentResources...)
-  
-  for _, dir := range fs.allowedDirectories {
-    fileResources := fs.discoverFiles(dir)
-    resources = append(resources, fileResources...)
-  }
-  
-  return resources
+	return resources
 }
 
 func (fs *Server) discoverFiles(dirPath string) []mcp.Resource {
-  var resources []mcp.Resource
-  
-  programmingMimeTypes := map[string]string{
-    ".go": "text/x-go", 
-    ".rs": "text/x-rust", 
-    ".py": "text/x-python",
-    ".java": "text/x-java",
-    ".c": "text/x-c", 
-    ".cpp": "text/x-c++", 
-    ".h": "text/x-c", 
-    ".hpp": "text/x-c++",
-    ".sh": "text/x-shellscript",
-  }
-  
-  specialFiles := map[string]string{
-    "Makefile": "text/x-makefile", 
-    "README": "text/plain", 
-    "LICENSE": "text/plain",
-    "go.mod": "text/x-go-mod", 
-    "Cargo.toml": "application/toml",
-  }
-  
-  const maxFiles = 500
-  fileCount := 0
-  
-  err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
-    if err != nil || fileCount >= maxFiles {
-      return filepath.SkipDir
-    }
-    
-    if strings.HasPrefix(info.Name(), ".") {
-      if info.IsDir() {
-        return filepath.SkipDir
-      }
-      return nil
-    }
-    
-    if info.IsDir() {
-      dirName := strings.ToLower(info.Name())
-      skipDirs := []string{"node_modules", "vendor", "target", "build", "dist", ".git"}
-      for _, skip := range skipDirs {
-        if dirName == skip {
-          return filepath.SkipDir
-        }
-      }
-      return nil
-    }
-    
-    fileName := info.Name()
-    ext := strings.ToLower(filepath.Ext(fileName))
-    
-    var mimeType string
-    
-    if progType, exists := programmingMimeTypes[ext]; exists {
-      mimeType = progType
-    } else if specialType, exists := specialFiles[fileName]; exists {
-      mimeType = specialType
-    } else {
-      mimeType = mime.TypeByExtension(ext)
-      if mimeType == "" {
-        return nil
-      }
-    }
-    
-    fileURI := "file://" + path
-    relPath, _ := filepath.Rel(dirPath, path)
-    
-    resource := mcp.Resource{
-      URI:      fileURI,
-      Name:     relPath,
-      MimeType: mimeType,
-    }
-    
-    resources = append(resources, resource)
-    fileCount++
-    
-    return nil
-  })
-  
-  if err != nil {
-    fmt.Printf("Warning: Error discovering files in %s: %v\n", dirPath, err)
-  }
-  
-  return resources
+	var resources []mcp.Resource
+
+	programmingMimeTypes := map[string]string{
+		".go":   "text/x-go",
+		".rs":   "text/x-rust",
+		".py":   "text/x-python",
+		".java": "text/x-java",
+		".c":    "text/x-c",
+		".cpp":  "text/x-c++",
+		".h":    "text/x-c",
+		".hpp":  "text/x-c++",
+		".sh":   "text/x-shellscript",
+	}
+
+	specialFiles := map[string]string{
+		"Makefile":   "text/x-makefile",
+		"README":     "text/plain",
+		"LICENSE":    "text/plain",
+		"go.mod":     "text/x-go-mod",
+		"Cargo.toml": "application/toml",
+	}
+
+	const maxFiles = 500
+	fileCount := 0
+
+	err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
+		if err != nil || fileCount >= maxFiles {
+			return filepath.SkipDir
+		}
+
+		if strings.HasPrefix(info.Name(), ".") {
+			if info.IsDir() {
+				return filepath.SkipDir
+			}
+			return nil
+		}
+
+		if info.IsDir() {
+			dirName := strings.ToLower(info.Name())
+			skipDirs := []string{"node_modules", "vendor", "target", "build", "dist", ".git"}
+			for _, skip := range skipDirs {
+				if dirName == skip {
+					return filepath.SkipDir
+				}
+			}
+			return nil
+		}
+
+		fileName := info.Name()
+		ext := strings.ToLower(filepath.Ext(fileName))
+
+		var mimeType string
+
+		if progType, exists := programmingMimeTypes[ext]; exists {
+			mimeType = progType
+		} else if specialType, exists := specialFiles[fileName]; exists {
+			mimeType = specialType
+		} else {
+			mimeType = mime.TypeByExtension(ext)
+			if mimeType == "" {
+				return nil
+			}
+		}
+
+		fileURI := "file://" + path
+		relPath, _ := filepath.Rel(dirPath, path)
+
+		resource := mcp.Resource{
+			URI:      fileURI,
+			Name:     relPath,
+			MimeType: mimeType,
+		}
+
+		resources = append(resources, resource)
+		fileCount++
+
+		return nil
+	})
+
+	if err != nil {
+		fmt.Printf("Warning: Error discovering files in %s: %v\n", dirPath, err)
+	}
+
+	return resources
 }
 
-
 func isBinaryContent(content []byte) bool {
-  checkBytes := content
-  if len(content) > 512 {
-    checkBytes = content[:512]
-  }
-
-  for _, b := range checkBytes {
-    if b == 0 {
-      return true
-    }
-  }
-
-  return false
+	checkBytes := content
+	if len(content) > 512 {
+		checkBytes = content[:512]
+	}
+
+	for _, b := range checkBytes {
+		if b == 0 {
+			return true
+		}
+	}
+
+	return false
 }
 
-
 func (fs *Server) validatePath(requestedPath string) (string, error) {
-  expandedPath := expandHome(requestedPath)
-  var absolute string
-
-  if filepath.IsAbs(expandedPath) {
-    absolute = filepath.Clean(expandedPath)
-  } else {
-    wd, _ := os.Getwd()
-    absolute = filepath.Clean(filepath.Join(wd, expandedPath))
-  }
-
-  allowed := false
-  for _, dir := range fs.allowedDirectories {
-    if strings.HasPrefix(absolute, dir) {
-      allowed = true
-      break
-    }
-  }
-
-  if !allowed {
-    return "", fmt.Errorf("access denied: %s", absolute)
-  }
-
-  realPath, err := filepath.EvalSymlinks(absolute)
-  if err != nil {
-    return absolute, nil
-  }
-
-  for _, dir := range fs.allowedDirectories {
-    if strings.HasPrefix(realPath, dir) {
-      return realPath, nil
-    }
-  }
-
-  return "", fmt.Errorf("access denied")
+	expandedPath := expandHome(requestedPath)
+	var absolute string
+
+	if filepath.IsAbs(expandedPath) {
+		absolute = filepath.Clean(expandedPath)
+	} else {
+		wd, _ := os.Getwd()
+		absolute = filepath.Clean(filepath.Join(wd, expandedPath))
+	}
+
+	allowed := false
+	for _, dir := range fs.allowedDirectories {
+		if strings.HasPrefix(absolute, dir) {
+			allowed = true
+			break
+		}
+	}
+
+	if !allowed {
+		return "", fmt.Errorf("access denied: %s", absolute)
+	}
+
+	realPath, err := filepath.EvalSymlinks(absolute)
+	if err != nil {
+		return absolute, nil
+	}
+
+	for _, dir := range fs.allowedDirectories {
+		if strings.HasPrefix(realPath, dir) {
+			return realPath, nil
+		}
+	}
+
+	return "", fmt.Errorf("access denied")
 }
 
 func expandHome(filePath string) string {
-  if strings.HasPrefix(filePath, "~/") || filePath == "~" {
-    homeDir, _ := os.UserHomeDir()
-    if filePath == "~" {
-      return homeDir
-    }
-    return filepath.Join(homeDir, filePath[2:])
-  }
-  return filePath
+	if strings.HasPrefix(filePath, "~/") || filePath == "~" {
+		homeDir, _ := os.UserHomeDir()
+		if filePath == "~" {
+			return homeDir
+		}
+		return filepath.Join(homeDir, filePath[2:])
+	}
+	return filePath
 }