Commit 94ec3d6

mo khan <mo@mokhan.ca>
2025-08-14 17:12:08
feat: validate tool call function arguments
1 parent b3677eb
lib/elelem/configuration.rb
@@ -51,6 +51,8 @@ module Elelem
     end
 
     def mcp_tools(clients = [serena_client])
+      return [] if ENV["SMALL"]
+
       @mcp_tools ||= clients.map { |client| client.tools.map { |tool| MCPTool.new(client, tui, tool) } }.flatten
     end
 
lib/elelem/conversation.rb
@@ -16,7 +16,7 @@ module Elelem
     def add(role: :user, content: "")
       role = role.to_sym
       raise "unknown role: #{role}" unless ROLES.include?(role)
-      return if content&.empty?
+      return if content.nil? || content.empty?
 
       if @items.last && @items.last[:role] == role
         @items.last[:content] += content
lib/elelem/state.rb
@@ -80,15 +80,7 @@ module Elelem
             agent.show_progress(tool_name, "[>]", colour: :magenta)
             agent.say("\n\n", newline: false)
 
-            result = agent.execute(tool_call)
-
-            if result.is_a?(Hash) && result[:success] == false
-              agent.say("\n", newline: false)
-              agent.complete_progress("#{tool_name} failed", colour: :red)
-              return Error.new(agent, result[:error])
-            end
-
-            output = result.is_a?(Hash) ? result[:output] : result
+            output = agent.execute(tool_call)
             agent.conversation.add(role: :tool, content: output)
 
             agent.say("\n", newline: false)
lib/elelem/tool.rb
@@ -14,6 +14,10 @@ module Elelem
       [name, description].join(": ")
     end
 
+    def valid?(args)
+      JSON::Validator.validate(parameters, args, insert_defaults: true)
+    end
+
     def to_h
       {
         type: "function",
@@ -29,9 +33,9 @@ module Elelem
   class BashTool < Tool
     attr_reader :tui
 
-    def initialize(tui)
-      @tui = tui
-      super("bash(command)", "Execute a shell command.", {
+    def initialize(configuration)
+      @tui = configuration.tui
+      super("bash", "Execute a shell command.", {
         parameters: {
           type: "object",
           properties: {
@@ -91,28 +95,24 @@ module Elelem
 
     def call(args)
       unless client.connected?
-        error_msg = "MCP connection lost"
-        tui.say(error_msg, colour: :red)
-        return { success: false, output: "", error: error_msg }
+        tui.say("MCP connection lost", colour: :red)
+        return ""
       end
 
       result = client.call(name, args)
+      tui.say(result)
 
       if result.nil? || result.empty?
-        error_msg = "Tool call failed: no response from MCP server"
-        tui.say(error_msg, colour: :red)
-        return { success: false, output: "", error: error_msg }
+        tui.say("Tool call failed: no response from MCP server", colour: :red)
+        return result
       end
 
       if result["error"]
-        error_msg = "Tool error: #{result["error"]}"
-        tui.say(error_msg, colour: :red)
-        return { success: false, output: "", error: error_msg }
+        tui.say(result["error"], colour: :red)
+        return result
       end
 
-      output = result.dig("content", 0, "text") || result.to_s
-      tui.say(output)
-      { success: true, output: output, error: nil }
+      result.dig("content", 0, "text") || result.to_s
     end
   end
 end
lib/elelem/tools.rb
@@ -15,7 +15,11 @@ module Elelem
       name = tool_call.dig("function", "name")
       args = tool_call.dig("function", "arguments")
 
-      tools.find { |tool| tool.name == name }&.call(args)
+      tool = tools.find { |tool| tool.name == name }
+      return "Invalid function name: #{name}" if tool.nil?
+      return "Invalid function arguments: #{args}" unless tool.valid?(args)
+
+      tool.call(args)
     end
 
     def to_h
lib/elelem.rb
@@ -2,6 +2,7 @@
 
 require "erb"
 require "json"
+require "json-schema"
 require "logger"
 require "net/http"
 require "open3"
elelem.gemspec
@@ -33,6 +33,7 @@ Gem::Specification.new do |spec|
 
   spec.add_dependency "erb"
   spec.add_dependency "json"
+  spec.add_dependency "json-schema"
   spec.add_dependency "logger"
   spec.add_dependency "net-http"
   spec.add_dependency "open3"
Gemfile.lock
@@ -4,6 +4,7 @@ PATH
     elelem (0.1.1)
       erb
       json
+      json-schema
       logger
       net-http
       open3
@@ -14,7 +15,10 @@ PATH
 GEM
   remote: https://rubygems.org/
   specs:
+    addressable (2.8.7)
+      public_suffix (>= 2.0.2, < 7.0)
     ast (2.4.3)
+    bigdecimal (3.2.2)
     date (3.4.1)
     diff-lcs (1.6.2)
     erb (5.0.2)
@@ -24,6 +28,9 @@ GEM
       rdoc (>= 4.0.0)
       reline (>= 0.4.2)
     json (2.13.2)
+    json-schema (6.0.0)
+      addressable (~> 2.8)
+      bigdecimal (~> 3.1)
     language_server-protocol (3.17.0.5)
     lint_roller (1.1.0)
     logger (1.7.0)
@@ -41,6 +48,7 @@ GEM
     psych (5.2.6)
       date
       stringio
+    public_suffix (6.0.2)
     racc (1.8.1)
     rainbow (3.1.1)
     rake (13.3.0)