Commit 6148f00
Changed files (3)
pkg/mcp/prompts_test.go
@@ -0,0 +1,149 @@
+package mcp
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestPrompt_Creation(t *testing.T) {
+ prompt := &Prompt{
+ Name: "test-prompt",
+ Description: "A test prompt",
+ Arguments: []PromptArgument{
+ {
+ Name: "input",
+ Description: "Input parameter",
+ Required: true,
+ },
+ {
+ Name: "optional",
+ Description: "Optional parameter",
+ Required: false,
+ },
+ },
+ }
+
+ if prompt.Name != "test-prompt" {
+ t.Errorf("Expected name 'test-prompt', got %s", prompt.Name)
+ }
+ if len(prompt.Arguments) != 2 {
+ t.Errorf("Expected 2 arguments, got %d", len(prompt.Arguments))
+ }
+ if !prompt.Arguments[0].Required {
+ t.Error("Expected first argument to be required")
+ }
+ if prompt.Arguments[1].Required {
+ t.Error("Expected second argument to be optional")
+ }
+}
+
+func TestListPromptsResult_Empty(t *testing.T) {
+ result := ListPromptsResult{
+ Prompts: []Prompt{},
+ }
+
+ if len(result.Prompts) != 0 {
+ t.Errorf("Expected 0 prompts, got %d", len(result.Prompts))
+ }
+}
+
+func TestListPromptsResult_WithPrompts(t *testing.T) {
+ prompts := []Prompt{
+ {Name: "prompt1", Description: "First prompt"},
+ {Name: "prompt2", Description: "Second prompt"},
+ }
+
+ result := ListPromptsResult{
+ Prompts: prompts,
+ }
+
+ if len(result.Prompts) != 2 {
+ t.Errorf("Expected 2 prompts, got %d", len(result.Prompts))
+ }
+ if result.Prompts[0].Name != "prompt1" {
+ t.Errorf("Expected first prompt name 'prompt1', got %s", result.Prompts[0].Name)
+ }
+}
+
+func TestGetPromptRequest_Creation(t *testing.T) {
+ req := GetPromptRequest{
+ Name: "test-prompt",
+ Arguments: map[string]interface{}{
+ "input": "test value",
+ "count": 42,
+ },
+ }
+
+ if req.Name != "test-prompt" {
+ t.Errorf("Expected name 'test-prompt', got %s", req.Name)
+ }
+ if req.Arguments["input"] != "test value" {
+ t.Errorf("Expected input 'test value', got %v", req.Arguments["input"])
+ }
+ if req.Arguments["count"] != 42 {
+ t.Errorf("Expected count 42, got %v", req.Arguments["count"])
+ }
+}
+
+func TestGetPromptResult_WithMessages(t *testing.T) {
+ messages := []PromptMessage{
+ {Role: "user", Content: NewTextContent("Hello")},
+ {Role: "assistant", Content: NewTextContent("Hi there!")},
+ }
+
+ result := GetPromptResult{
+ Description: "Test conversation",
+ Messages: messages,
+ }
+
+ if result.Description != "Test conversation" {
+ t.Errorf("Expected description 'Test conversation', got %s", result.Description)
+ }
+ if len(result.Messages) != 2 {
+ t.Errorf("Expected 2 messages, got %d", len(result.Messages))
+ }
+ if result.Messages[0].Role != "user" {
+ t.Errorf("Expected first message role 'user', got %s", result.Messages[0].Role)
+ }
+}
+
+func TestPromptMessage_Types(t *testing.T) {
+ tests := []struct {
+ name string
+ role string
+ content Content
+ expected PromptMessage
+ }{
+ {
+ name: "user message",
+ role: "user",
+ content: NewTextContent("What is the weather?"),
+ expected: PromptMessage{
+ Role: "user",
+ Content: NewTextContent("What is the weather?"),
+ },
+ },
+ {
+ name: "assistant message",
+ role: "assistant",
+ content: NewTextContent("The weather is sunny."),
+ expected: PromptMessage{
+ Role: "assistant",
+ Content: NewTextContent("The weather is sunny."),
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ msg := PromptMessage{
+ Role: tt.role,
+ Content: tt.content,
+ }
+
+ if !reflect.DeepEqual(msg, tt.expected) {
+ t.Errorf("Expected %+v, got %+v", tt.expected, msg)
+ }
+ })
+ }
+}
\ No newline at end of file
pkg/mcp/server.go
@@ -19,6 +19,7 @@ type Server struct {
// Handler functions
toolHandlers map[string]ToolHandler
promptHandlers map[string]PromptHandler
+ promptDefinitions map[string]Prompt
resourceHandlers map[string]ResourceHandler
// Lifecycle handlers
@@ -36,11 +37,12 @@ type ResourceHandler func(ReadResourceRequest) (ReadResourceResult, error)
// NewServer creates a new MCP server
func NewServer(name, version string) *Server {
return &Server{
- name: name,
- version: version,
- toolHandlers: make(map[string]ToolHandler),
- promptHandlers: make(map[string]PromptHandler),
- resourceHandlers: make(map[string]ResourceHandler),
+ name: name,
+ version: version,
+ toolHandlers: make(map[string]ToolHandler),
+ promptHandlers: make(map[string]PromptHandler),
+ promptDefinitions: make(map[string]Prompt),
+ resourceHandlers: make(map[string]ResourceHandler),
capabilities: ServerCapabilities{
Tools: &ToolsCapability{},
Prompts: &PromptsCapability{},
@@ -57,11 +59,12 @@ func (s *Server) RegisterTool(name string, handler ToolHandler) {
s.toolHandlers[name] = handler
}
-// RegisterPrompt registers a prompt handler
-func (s *Server) RegisterPrompt(name string, handler PromptHandler) {
+// RegisterPrompt registers a prompt with its definition and handler
+func (s *Server) RegisterPrompt(prompt Prompt, handler PromptHandler) {
s.mu.Lock()
defer s.mu.Unlock()
- s.promptHandlers[name] = handler
+ s.promptHandlers[prompt.Name] = handler
+ s.promptDefinitions[prompt.Name] = prompt
}
// RegisterResource registers a resource handler
@@ -103,6 +106,19 @@ func (s *Server) ListTools() []Tool {
return tools
}
+// ListPrompts returns all registered prompts
+func (s *Server) ListPrompts() []Prompt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ prompts := make([]Prompt, 0, len(s.promptDefinitions))
+ for _, prompt := range s.promptDefinitions {
+ prompts = append(prompts, prompt)
+ }
+
+ return prompts
+}
+
// Run starts the server and handles JSON-RPC over stdio
func (s *Server) Run(ctx context.Context) error {
scanner := bufio.NewScanner(os.Stdin)
@@ -230,8 +246,8 @@ func (s *Server) handleCallTool(req JSONRPCRequest) JSONRPCResponse {
}
func (s *Server) handleListPrompts(req JSONRPCRequest) JSONRPCResponse {
- // Return empty prompts list for now
- result := ListPromptsResult{Prompts: []Prompt{}}
+ prompts := s.ListPrompts()
+ result := ListPromptsResult{Prompts: prompts}
return s.createSuccessResponse(req.ID, result)
}
pkg/mcp/server_prompts_test.go
@@ -0,0 +1,113 @@
+package mcp
+
+import (
+ "testing"
+)
+
+func TestServer_RegisterPrompt(t *testing.T) {
+ server := NewServer("test-server", "1.0.0")
+
+ prompt := Prompt{
+ Name: "test-prompt",
+ Description: "A test prompt for verification",
+ Arguments: []PromptArgument{
+ {
+ Name: "input",
+ Description: "Input parameter",
+ Required: true,
+ },
+ },
+ }
+
+ handler := func(req GetPromptRequest) (GetPromptResult, error) {
+ return GetPromptResult{
+ Description: "Test response",
+ Messages: []PromptMessage{
+ {
+ Role: "user",
+ Content: NewTextContent("Hello"),
+ },
+ },
+ }, nil
+ }
+
+ server.RegisterPrompt(prompt, handler)
+
+ // Test that prompt is registered
+ prompts := server.ListPrompts()
+ if len(prompts) != 1 {
+ t.Errorf("Expected 1 prompt, got %d", len(prompts))
+ }
+
+ if prompts[0].Name != "test-prompt" {
+ t.Errorf("Expected prompt name 'test-prompt', got %s", prompts[0].Name)
+ }
+
+ if prompts[0].Description != "A test prompt for verification" {
+ t.Errorf("Expected description 'A test prompt for verification', got %s", prompts[0].Description)
+ }
+
+ if len(prompts[0].Arguments) != 1 {
+ t.Errorf("Expected 1 argument, got %d", len(prompts[0].Arguments))
+ }
+
+ if prompts[0].Arguments[0].Name != "input" {
+ t.Errorf("Expected argument name 'input', got %s", prompts[0].Arguments[0].Name)
+ }
+
+ if !prompts[0].Arguments[0].Required {
+ t.Error("Expected argument to be required")
+ }
+}
+
+func TestServer_ListPrompts_Empty(t *testing.T) {
+ server := NewServer("test-server", "1.0.0")
+
+ prompts := server.ListPrompts()
+ if len(prompts) != 0 {
+ t.Errorf("Expected 0 prompts, got %d", len(prompts))
+ }
+}
+
+func TestServer_MultiplePrompts(t *testing.T) {
+ server := NewServer("test-server", "1.0.0")
+
+ prompt1 := Prompt{
+ Name: "prompt1",
+ Description: "First prompt",
+ }
+
+ prompt2 := Prompt{
+ Name: "prompt2",
+ Description: "Second prompt",
+ Arguments: []PromptArgument{
+ {Name: "arg1", Required: true},
+ {Name: "arg2", Required: false},
+ },
+ }
+
+ handler := func(req GetPromptRequest) (GetPromptResult, error) {
+ return GetPromptResult{}, nil
+ }
+
+ server.RegisterPrompt(prompt1, handler)
+ server.RegisterPrompt(prompt2, handler)
+
+ prompts := server.ListPrompts()
+ if len(prompts) != 2 {
+ t.Errorf("Expected 2 prompts, got %d", len(prompts))
+ }
+
+ // Check that both prompts are present (order may vary due to map iteration)
+ names := make(map[string]bool)
+ for _, prompt := range prompts {
+ names[prompt.Name] = true
+ }
+
+ if !names["prompt1"] {
+ t.Error("prompt1 not found in list")
+ }
+ if !names["prompt2"] {
+ t.Error("prompt2 not found in list")
+ }
+}
\ No newline at end of file