Commit 7884cd4

mo khan <mo@mokhan.ca>
2026-01-20 22:41:08
refactor: cleanup net code
1 parent 5904cd6
lib/elelem/net/claude.rb
@@ -16,93 +16,172 @@ module Elelem
         new(
           endpoint: "https://#{region}-aiplatform.googleapis.com/v1/projects/#{project}/locations/#{region}/publishers/anthropic/models/#{model}:rawPredict",
           headers: -> { { "Authorization" => "Bearer #{`gcloud auth application-default print-access-token`.strip}" } },
+          model:,
           version: "vertex-2023-10-16",
           http:
         )
       end
 
       def initialize(endpoint:, headers:, model:, version: nil, http: Elelem::Net.http)
-        @endpoint, @headers_src, @model, @version, @http = endpoint, headers, model, version, http
+        @endpoint = endpoint
+        @headers_source = headers
+        @model = model
+        @version = version
+        @http = http
       end
 
       def fetch(messages, tools = [], &block)
-        system, msgs = extract_system(messages)
+        system_prompt, normalized_messages = extract_system(messages)
         tool_calls = []
 
-        stream(msgs, system, tools) do |event|
-          case event["type"]
-          when "content_block_start"
-            if event.dig("content_block", "type") == "tool_use"
-              tool_calls << { id: event.dig("content_block", "id"), name: event.dig("content_block", "name"), args: "" }
-            end
-          when "content_block_delta"
-            case event.dig("delta", "type")
-            when "text_delta"
-              block.call(content: event.dig("delta", "text"), thinking: nil)
-            when "thinking_delta"
-              block.call(content: nil, thinking: event.dig("delta", "thinking"))
-            when "input_json_delta"
-              tool_calls.last[:args] += event.dig("delta", "partial_json").to_s if tool_calls.any?
-            end
-          when "message_stop"
-            tool_calls.each { |tool_call| tool_call[:arguments] = begin; JSON.parse(tool_call.delete(:args)); rescue; {}; end }
-          end
+        stream(normalized_messages, system_prompt, tools) do |event|
+          handle_event(event, tool_calls, &block)
         end
-        tool_calls
+
+        finalize_tool_calls(tool_calls)
       end
 
       private
 
-      def headers = @headers_src.respond_to?(:call) ? @headers_src.call : @headers_src
+      def headers
+        @headers_source.respond_to?(:call) ? @headers_source.call : @headers_source
+      end
+
+      def handle_event(event, tool_calls, &block)
+        case event["type"]
+        when "content_block_start"
+          handle_content_block_start(event, tool_calls)
+        when "content_block_delta"
+          handle_content_block_delta(event, tool_calls, &block)
+        end
+      end
+
+      def handle_content_block_start(event, tool_calls)
+        content_block = event["content_block"]
+        return unless content_block["type"] == "tool_use"
+
+        tool_calls << {
+          id: content_block["id"],
+          name: content_block["name"],
+          args: String.new
+        }
+      end
+
+      def handle_content_block_delta(event, tool_calls, &block)
+        delta = event["delta"]
+
+        case delta["type"]
+        when "text_delta"
+          block.call(content: delta["text"], thinking: nil)
+        when "thinking_delta"
+          block.call(content: nil, thinking: delta["thinking"])
+        when "input_json_delta"
+          tool_calls.last[:args] << delta["partial_json"].to_s if tool_calls.any?
+        end
+      end
+
+      def finalize_tool_calls(tool_calls)
+        tool_calls.each do |tool_call|
+          tool_call[:arguments] = JSON.parse(tool_call.delete(:args))
+        end
+      end
+
+      def stream(messages, system_prompt, tools)
+        body = build_request_body(messages, system_prompt, tools)
+
+        @http.post(@endpoint, headers:, body:) do |response|
+          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(::Net::HTTPSuccess)
+
+          read_sse_stream(response) { |event| yield event }
+        end
+      end
 
-      def stream(messages, system, tools, &block)
+      def build_request_body(messages, system_prompt, tools)
         body = { max_tokens: 64000, messages:, stream: true }
-        body[:model] = @model if @model
+        body[:model] = @model unless @version
         body[:anthropic_version] = @version if @version
-        body[:system] = system if system
+        body[:system] = system_prompt if system_prompt
         body[:tools] = unwrap_tools(tools) unless tools.empty?
+        body
+      end
+
+      def read_sse_stream(response)
+        buffer = String.new
+
+        response.read_body do |chunk|
+          buffer << chunk
 
-        @http.post(@endpoint, headers:, body:) do |res|
-          raise "HTTP #{res.code}: #{res.body}" unless res.is_a?(::Net::HTTPSuccess)
-          buf = ""
-          res.read_body do |chunk|
-            buf += chunk
-            while (i = buf.index("\n\n"))
-              parse_sse(buf.slice!(0, i + 2))&.then { |data| block.call(data) }
-            end
+          while (index = buffer.index("\n\n"))
+            raw_event = buffer.slice!(0, index + 2)
+            event = parse_sse(raw_event)
+            yield event if event
           end
         end
       end
 
       def parse_sse(raw)
-        data = raw.lines.find { |l| l.start_with?("data: ") }&.then { |l| l[6..] }
-        data && data != "[DONE]" ? JSON.parse(data) : nil
+        line = raw.lines.find { |l| l.start_with?("data: ") }
+        return nil unless line
+
+        data = line.delete_prefix("data: ").strip
+        return nil if data == "[DONE]"
+
+        JSON.parse(data)
       end
 
       def extract_system(messages)
-        sys = messages.find { |m| m[:role] == "system" || m["role"] == "system" }
-        [sys && (sys[:content] || sys["content"]), normalize(messages.reject { |m| m[:role] == "system" || m["role"] == "system" })]
+        system_messages, other_messages = messages.partition { |message| message[:role] == "system" }
+        system_content = system_messages.first&.dig(:content)
+        [system_content, normalize(other_messages)]
       end
 
       def normalize(messages)
-        messages.map do |m|
-          role, tool_calls = m[:role] || m["role"], m[:tool_calls] || m["tool_calls"]
-
-          if role == "tool"
-            { role: "user", content: [{ type: "tool_result", tool_use_id: m[:tool_call_id] || m["tool_call_id"], content: m[:content] || m["content"] }] }
-          elsif role == "assistant" && tool_calls&.any?
-            content = []
-            text = m[:content] || m["content"]
-            content << { type: "text", text: } if text && !text.empty?
-            tool_calls.each do |tool_call|
-              fn = tool_call[:function] || tool_call["function"] || {}
-              args = fn[:arguments] || fn["arguments"]
-              content << { type: "tool_use", id: tool_call[:id] || tool_call["id"], name: fn[:name] || fn["name"] || tool_call[:name] || tool_call["name"], input: args.is_a?(String) ? (JSON.parse(args) rescue {}) : (args || {}) }
-            end
-            { role: "assistant", content: }
-          else
-            m
-          end
+        messages.map { |message| normalize_message(message) }
+      end
+
+      def normalize_message(message)
+        case message[:role]
+        when "tool"
+          tool_result_message(message)
+        when "assistant"
+          message[:tool_calls]&.any? ? assistant_with_tools_message(message) : message
+        else
+          message
+        end
+      end
+
+      def tool_result_message(message)
+        {
+          role: "user",
+          content: [{
+            type: "tool_result",
+            tool_use_id: message[:tool_call_id],
+            content: message[:content]
+          }]
+        }
+      end
+
+      def assistant_with_tools_message(message)
+        text_content = build_text_content(message[:content])
+        tool_content = build_tool_content(message[:tool_calls])
+
+        { role: "assistant", content: text_content + tool_content }
+      end
+
+      def build_text_content(content)
+        return [] if content.to_s.empty?
+
+        [{ type: "text", text: content }]
+      end
+
+      def build_tool_content(tool_calls)
+        tool_calls.map do |tool_call|
+          {
+            type: "tool_use",
+            id: tool_call[:id],
+            name: tool_call[:name],
+            input: tool_call[:arguments]
+          }
         end
       end
 
lib/elelem/net/ollama.rb
@@ -4,17 +4,17 @@ module Elelem
   module Net
     class Ollama
       def initialize(model:, host: "localhost:11434", http: Elelem::Net.http)
-        @url = "#{host.start_with?('http') ? host : "http://#{host}"}/api/chat"
-        @model, @http = model, http
+        @url = normalize_url(host)
+        @model = model
+        @http = http
       end
 
       def fetch(messages, tools = [], &block)
         tool_calls = []
+        body = build_request_body(messages, tools)
 
-        stream({ model: @model, messages:, tools:, stream: true }) do |json|
-          msg = json["message"] || {}
-          block.call(content: msg["content"], thinking: msg["thinking"]) unless json["done"]
-          tool_calls.concat(parse_tools(msg["tool_calls"])) if msg["tool_calls"]
+        stream(body) do |event|
+          handle_event(event, tool_calls, &block)
         end
 
         tool_calls
@@ -22,20 +22,49 @@ module Elelem
 
       private
 
-      def stream(body, &block)
-        @http.post(@url, body:) do |res|
-          raise "HTTP #{res.code}: #{res.body}" unless res.is_a?(::Net::HTTPSuccess)
-          buf = ""
-          res.read_body do |chunk|
-            buf += chunk
-            while (i = buf.index("\n"))
-              block.call(JSON.parse(buf.slice!(0, i + 1)))
-            end
+      def normalize_url(host)
+        base = host.start_with?("http") ? host : "http://#{host}"
+        "#{base}/api/chat"
+      end
+
+      def build_request_body(messages, tools)
+        { model: @model, messages:, tools:, stream: true }
+      end
+
+      def handle_event(event, tool_calls, &block)
+        message = event["message"] || {}
+
+        unless event["done"]
+          block.call(content: message["content"], thinking: message["thinking"])
+        end
+
+        if message["tool_calls"]
+          tool_calls.concat(parse_tool_calls(message["tool_calls"]))
+        end
+      end
+
+      def stream(body)
+        @http.post(@url, body:) do |response|
+          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(::Net::HTTPSuccess)
+
+          read_ndjson_stream(response) { |event| yield event }
+        end
+      end
+
+      def read_ndjson_stream(response)
+        buffer = String.new
+
+        response.read_body do |chunk|
+          buffer << chunk
+
+          while (index = buffer.index("\n"))
+            line = buffer.slice!(0, index + 1)
+            yield JSON.parse(line)
           end
         end
       end
 
-      def parse_tools(tool_calls)
+      def parse_tool_calls(tool_calls)
         tool_calls.map do |tool_call|
           {
             id: tool_call["id"],
lib/elelem/net/openai.rb
@@ -5,48 +5,74 @@ module Elelem
     class OpenAI
       def initialize(model:, api_key:, base_url: "https://api.openai.com/v1", http: Elelem::Net.http)
         @url = "#{base_url}/chat/completions"
-        @model, @api_key, @http = model, api_key, http
+        @model = model
+        @api_key = api_key
+        @http = http
       end
 
       def fetch(messages, tools = [], &block)
         tool_calls = {}
-        body = { model: @model, messages:, stream: true, tools:, tool_choice: "auto" }
-
-        stream(body) do |json|
-          delta = json.dig("choices", 0, "delta") || {}
-          block.call(content: delta["content"], thinking: nil) if delta["content"]
-
-          delta["tool_calls"]&.each do |tool_call|
-            idx = tool_call["index"]
-            tool_calls[idx] ||= { id: nil, name: nil, args: "" }
-            tool_calls[idx][:id] ||= tool_call["id"]
-            tool_calls[idx][:name] ||= tool_call.dig("function", "name")
-            tool_calls[idx][:args] += tool_call.dig("function", "arguments").to_s
-          end
+        body = build_request_body(messages, tools)
+
+        stream(body) do |event|
+          handle_event(event, tool_calls, &block)
         end
 
-        finalize_tools(tool_calls)
+        finalize_tool_calls(tool_calls)
       end
 
       private
 
-      def stream(body, &block)
-        @http.post(@url, headers: { "Authorization" => "Bearer #{@api_key}" }, body:) do |res|
-          raise "HTTP #{res.code}: #{res.body}" unless res.is_a?(::Net::HTTPSuccess)
-
-          buf = ""
-          res.read_body do |chunk|
-            buf += chunk
-            while (i = buf.index("\n"))
-              line = buf.slice!(0, i + 1).strip
-              next unless line.start_with?("data: ") && line != "data: [DONE]"
-              block.call(JSON.parse(line[6..]))
-            end
+      def build_request_body(messages, tools)
+        { model: @model, messages:, stream: true, tools:, tool_choice: "auto" }
+      end
+
+      def handle_event(event, tool_calls, &block)
+        delta = event.dig("choices", 0, "delta") || {}
+
+        block.call(content: delta["content"], thinking: nil) if delta["content"]
+
+        accumulate_tool_calls(delta["tool_calls"], tool_calls) if delta["tool_calls"]
+      end
+
+      def accumulate_tool_calls(incoming_tool_calls, tool_calls)
+        incoming_tool_calls.each do |tool_call|
+          index = tool_call["index"]
+          tool_calls[index] ||= { id: nil, name: nil, args: String.new }
+          tool_calls[index][:id] ||= tool_call["id"]
+          tool_calls[index][:name] ||= tool_call.dig("function", "name")
+          tool_calls[index][:args] << tool_call.dig("function", "arguments").to_s
+        end
+      end
+
+      def stream(body)
+        @http.post(@url, headers: headers, body:) do |response|
+          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(::Net::HTTPSuccess)
+
+          read_sse_stream(response) { |event| yield event }
+        end
+      end
+
+      def headers
+        { "Authorization" => "Bearer #{@api_key}" }
+      end
+
+      def read_sse_stream(response)
+        buffer = String.new
+
+        response.read_body do |chunk|
+          buffer << chunk
+
+          while (index = buffer.index("\n"))
+            line = buffer.slice!(0, index + 1).strip
+            next unless line.start_with?("data: ") && line != "data: [DONE]"
+
+            yield JSON.parse(line.delete_prefix("data: "))
           end
         end
       end
 
-      def finalize_tools(tool_calls)
+      def finalize_tool_calls(tool_calls)
         tool_calls.values.map do |tool_call|
           {
             id: tool_call[:id],
lib/elelem/agent.rb
@@ -84,7 +84,7 @@ module Elelem
 
     def task_tool
       {
-        desc: "Delegate subtask to focused agent (complex searches, multi-file analysis)",
+        description: "Delegate subtask to focused agent (complex searches, multi-file analysis)",
         params: { prompt: { type: "string" } },
         required: ["prompt"],
         fn: ->(a) {
lib/elelem/mcp.rb
@@ -13,7 +13,7 @@ module Elelem
           [
             "#{name}_#{tool["name"]}",
             {
-              desc: tool["description"],
+              description: tool["description"],
               params: tool.dig("inputSchema", "properties") || {},
               required: tool.dig("inputSchema", "required") || [],
               fn: ->(a) { server(name).call(tool["name"], a) }