Commit 3211970

mo khan <mo@mokhan.ca>
2026-01-20 16:17:04
refactor: merge in net/llm code
1 parent 8959a47
lib/net/llm/anthropic.rb
@@ -0,0 +1,23 @@
# frozen_string_literal: true

module Net
  module Llm
    # Convenience wrapper around Claude that is pre-configured for the public
    # Anthropic Messages API at api.anthropic.com.
    class Anthropic
      attr_reader :api_key, :model

      # @param api_key [String] Anthropic API key (defaults to ANTHROPIC_API_KEY)
      # @param model [String] model identifier sent in the request payload
      # @param http [#post] HTTP client used for requests
      def initialize(api_key: ENV.fetch("ANTHROPIC_API_KEY"), model: "claude-sonnet-4-20250514", http: Net::Llm.http)
        @api_key = api_key
        @model = model
        @claude = build_client(http)
      end

      # The public surface simply delegates to the underlying Claude client.
      def messages(...) = @claude.messages(...)
      def fetch(...) = @claude.fetch(...)

      private

      # Builds a Claude client pointed at the hosted Anthropic endpoint with
      # the required auth and API-version headers.
      def build_client(http)
        Claude.new(
          endpoint: "https://api.anthropic.com/v1/messages",
          headers: { "x-api-key" => api_key, "anthropic-version" => "2023-06-01" },
          model: model,
          http: http
        )
      end
    end
  end
end
lib/net/llm/claude.rb
@@ -0,0 +1,266 @@
# frozen_string_literal: true

module Net
  module Llm
    # Endpoint-agnostic client for Anthropic's Messages API wire format.
    # Used both by Anthropic (api.anthropic.com) and by VertexAI (Claude models
    # hosted on Google Vertex AI); the two differ only in endpoint, headers,
    # and whether the model / anthropic_version travel in the payload.
    class Claude
      # NOTE: :headers is deliberately not listed here. The previous
      # `attr_reader :headers` was shadowed by the explicit #headers method
      # below (and would have read @headers, which is never assigned — the
      # state lives in @headers_source), producing a "method redefined"
      # warning and misleading dead code.
      attr_reader :endpoint, :model, :http, :anthropic_version

      # @param endpoint [String] full URL of the messages endpoint
      # @param headers [Hash, #call] static header hash, or a callable
      #   returning one (useful for short-lived bearer tokens)
      # @param http [#post] HTTP client used for requests
      # @param model [String, nil] model name; omitted from the payload when
      #   nil (Vertex AI encodes the model in the URL instead)
      # @param anthropic_version [String, nil] payload-level API version
      #   (Vertex AI requires "vertex-2023-10-16"); nil for api.anthropic.com
      def initialize(endpoint:, headers:, http:, model: nil, anthropic_version: nil)
        @endpoint = endpoint
        @headers_source = headers
        @model = model
        @http = http
        @anthropic_version = anthropic_version
      end

      # Resolves request headers, invoking the source when it is callable so
      # that expiring credentials are refreshed on every request.
      def headers
        @headers_source.respond_to?(:call) ? @headers_source.call : @headers_source
      end

      # Sends a messages request. With a block, streams SSE events (parsed
      # JSON hashes) to the block; without, returns the parsed response body
      # or an error hash ({"code" => ..., "body" => ...}).
      def messages(messages, system: nil, max_tokens: 64000, tools: nil, &block)
        payload = build_payload(messages, system, max_tokens, tools, block_given?)

        if block_given?
          stream_request(payload, &block)
        else
          post_request(payload)
        end
      end

      # Normalized fetch interface shared with the other providers. Accepts
      # OpenAI-style messages/tools, converts them to Anthropic's shapes, and
      # yields {type: :delta/:complete, ...} hashes when a block is given.
      def fetch(messages, tools = [], &block)
        system_message, user_messages = extract_system_message(messages)
        anthropic_tools = tools.empty? ? nil : tools.map { |t| normalize_tool_for_anthropic(t) }

        if block_given?
          fetch_streaming(user_messages, anthropic_tools, system: system_message, &block)
        else
          fetch_non_streaming(user_messages, anthropic_tools, system: system_message)
        end
      end

      private

      # Builds the request payload; optional fields are included only when set
      # so the same class serves both hosted-API and Vertex AI payload shapes.
      def build_payload(messages, system, max_tokens, tools, stream)
        payload = { max_tokens: max_tokens, messages: messages, stream: stream }
        payload[:model] = model if model
        payload[:anthropic_version] = anthropic_version if anthropic_version
        payload[:system] = system if system
        payload[:tools] = tools if tools
        payload
      end

      def post_request(payload)
        handle_response(http.post(endpoint, headers: headers, body: payload))
      end

      # Successful responses are parsed as JSON; failures collapse to a hash
      # carrying the HTTP status code and the raw body.
      def handle_response(response)
        if response.is_a?(Net::HTTPSuccess)
          JSON.parse(response.body)
        else
          { "code" => response.code, "body" => response.body }
        end
      end

      # Streams the response as server-sent events, yielding each parsed JSON
      # event until "message_stop" is seen.
      def stream_request(payload, &block)
        http.post(endpoint, headers: headers, body: payload) do |response|
          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(Net::HTTPSuccess)

          buffer = ""
          response.read_body do |chunk|
            buffer += chunk

            while (event = extract_sse_event(buffer))
              next if event[:data].nil? || event[:data].empty?
              next if event[:data] == "[DONE]"

              json = JSON.parse(event[:data])
              block.call(json)

              break if json["type"] == "message_stop"
            end
          end
        end
      end

      # Removes and returns the first complete SSE event ("\n\n"-terminated)
      # from the buffer as {event:, data:}, or nil when none is complete yet.
      # Mutates the buffer in place.
      def extract_sse_event(buffer)
        event_end = buffer.index("\n\n")
        return nil unless event_end

        event_data = buffer[0...event_end]
        buffer.replace(buffer[(event_end + 2)..] || "")

        event = {}
        event_data.split("\n").each do |line|
          if line.start_with?("event: ")
            event[:event] = line[7..]
          elsif line.start_with?("data: ")
            event[:data] = line[6..]
          elsif line == "data:"
            event[:data] = ""
          end
        end

        event
      end

      # Splits an OpenAI-style message list into [system_content, rest], since
      # Anthropic takes the system prompt as a top-level field. Handles both
      # symbol- and string-keyed messages.
      def extract_system_message(messages)
        system_msg = messages.find { |m| m[:role] == "system" || m["role"] == "system" }
        system_content = system_msg ? (system_msg[:content] || system_msg["content"]) : nil
        other_messages = messages.reject { |m| m[:role] == "system" || m["role"] == "system" }
        normalized_messages = normalize_messages_for_claude(other_messages)
        [system_content, normalized_messages]
      end

      # Converts OpenAI-style "tool" results and assistant tool_calls into
      # Anthropic content blocks (tool_result / tool_use); other messages pass
      # through unchanged.
      def normalize_messages_for_claude(messages)
        messages.map do |msg|
          role = msg[:role] || msg["role"]
          tool_calls = msg[:tool_calls] || msg["tool_calls"]

          if role == "tool"
            {
              role: "user",
              content: [{
                type: "tool_result",
                tool_use_id: msg[:tool_call_id] || msg["tool_call_id"],
                content: msg[:content] || msg["content"]
              }]
            }
          elsif role == "assistant" && tool_calls&.any?
            content = []
            text = msg[:content] || msg["content"]
            content << { type: "text", text: text } if text && !text.empty?
            tool_calls.each do |tc|
              func = tc[:function] || tc["function"] || {}
              args = func[:arguments] || func["arguments"]
              # OpenAI serializes arguments as a JSON string; Anthropic wants
              # a hash. Unparseable strings degrade to an empty hash.
              input = args.is_a?(String) ? (JSON.parse(args) rescue {}) : (args || {})
              content << {
                type: "tool_use",
                id: tc[:id] || tc["id"],
                name: func[:name] || func["name"] || tc[:name] || tc["name"],
                input: input
              }
            end
            { role: "assistant", content: content }
          else
            msg
          end
        end
      end

      # Non-streaming branch of #fetch: one request, one :complete hash.
      # Error responses (hashes with a "code" key) are returned as-is.
      def fetch_non_streaming(messages, tools, system: nil)
        result = self.messages(messages, system: system, tools: tools)
        return result if result["code"]

        {
          type: :complete,
          content: extract_text_content(result["content"]),
          thinking: extract_thinking_content(result["content"]),
          tool_calls: extract_tool_calls(result["content"]),
          stop_reason: map_stop_reason(result["stop_reason"])
        }
      end

      # Streaming branch of #fetch: accumulates text/thinking/tool-call deltas
      # across SSE events and emits a final :complete hash at "message_stop".
      def fetch_streaming(messages, tools, system: nil, &block)
        content = ""
        thinking = ""
        tool_calls = []
        stop_reason = :end_turn

        self.messages(messages, system: system, tools: tools) do |event|
          case event["type"]
          when "content_block_start"
            if event.dig("content_block", "type") == "tool_use"
              tool_calls << {
                id: event.dig("content_block", "id"),
                name: event.dig("content_block", "name"),
                arguments: {}
              }
            end
          when "content_block_delta"
            delta = event["delta"]
            case delta["type"]
            when "text_delta"
              text = delta["text"]
              content += text
              block.call({ type: :delta, content: text, thinking: nil, tool_calls: nil })
            when "thinking_delta"
              text = delta["thinking"]
              thinking += text if text
              block.call({ type: :delta, content: nil, thinking: text, tool_calls: nil })
            when "input_json_delta"
              # Tool-call arguments arrive as partial JSON fragments; collect
              # them on the most recently started tool_use block.
              if tool_calls.any?
                tool_calls.last[:arguments_json] ||= ""
                tool_calls.last[:arguments_json] += delta["partial_json"] || ""
              end
            end
          when "message_delta"
            stop_reason = map_stop_reason(event.dig("delta", "stop_reason"))
          when "message_stop"
            # Finalize tool calls: parse accumulated argument JSON, dropping
            # the intermediate :arguments_json key.
            tool_calls.each do |tc|
              if tc[:arguments_json]
                tc[:arguments] = begin
                  JSON.parse(tc[:arguments_json])
                rescue
                  {}
                end
                tc.delete(:arguments_json)
              end
            end
            block.call({
              type: :complete,
              content: content,
              thinking: thinking.empty? ? nil : thinking,
              tool_calls: tool_calls,
              stop_reason: stop_reason
            })
          end
        end
      end

      # Concatenates all "text" content blocks; nil when blocks are absent.
      def extract_text_content(content_blocks)
        return nil unless content_blocks

        content_blocks
          .select { |b| b["type"] == "text" }
          .map { |b| b["text"] }
          .join
      end

      # Concatenates all "thinking" content blocks; nil when empty or absent.
      def extract_thinking_content(content_blocks)
        return nil unless content_blocks

        thinking = content_blocks
          .select { |b| b["type"] == "thinking" }
          .map { |b| b["thinking"] }
          .join

        thinking.empty? ? nil : thinking
      end

      # Maps "tool_use" content blocks onto the shared {id:, name:, arguments:}
      # shape used by every provider's #fetch.
      def extract_tool_calls(content_blocks)
        return [] unless content_blocks

        content_blocks
          .select { |b| b["type"] == "tool_use" }
          .map { |b| { id: b["id"], name: b["name"], arguments: b["input"] || {} } }
      end

      # Converts an OpenAI-style tool definition ({function: {...}}) into
      # Anthropic's flat {name:, description:, input_schema:} form; tools
      # already in Anthropic form pass through unchanged.
      def normalize_tool_for_anthropic(tool)
        if tool[:function]
          { name: tool[:function][:name], description: tool[:function][:description], input_schema: tool[:function][:parameters] }
        else
          tool
        end
      end

      # Translates Anthropic stop reasons into the shared symbols; unknown or
      # nil reasons default to :end_turn.
      def map_stop_reason(reason)
        case reason
        when "end_turn" then :end_turn
        when "tool_use" then :tool_use
        when "max_tokens" then :max_tokens
        else :end_turn
        end
      end
    end
  end
end
lib/net/llm/ollama.rb
@@ -0,0 +1,171 @@
# frozen_string_literal: true

module Net
  module Llm
    # Client for a local or remote Ollama server: chat, generate, embeddings,
    # model listing (tags) and model inspection (show).
    class Ollama
      attr_reader :host, :model, :http

      # @param host [String] "host:port" or a full http(s) URL
      # @param model [String] model name sent with each request
      # @param http [#get, #post] HTTP client used for requests
      def initialize(host: ENV.fetch("OLLAMA_HOST", "localhost:11434"), model: "gpt-oss", http: Net::Llm.http)
        @host = host
        @model = model
        @http = http
      end

      # POST /api/chat. With a block, streams parsed NDJSON chunks to it;
      # without, returns the parsed response (or an error hash).
      def chat(messages, tools = [], &block)
        body = { model: model, messages: messages, stream: block_given? }
        body[:tools] = tools unless tools.empty?

        execute(build_url("/api/chat"), body, &block)
      end

      # Normalized fetch interface shared with the other providers. Yields
      # {type: :delta, ...} chunks and a final {type: :complete, ...} when a
      # block is given; otherwise returns a single :complete hash.
      def fetch(messages, tools = [], &block)
        return fetch_blocking(messages, tools) unless block_given?

        text = +""
        reasoning = +""
        calls = []

        chat(messages, tools) do |chunk|
          message = chunk["message"] || {}
          piece = message["content"]
          thought = message["thinking"]

          text << piece if piece
          reasoning << thought if thought
          calls.concat(normalize_tool_calls(message["tool_calls"])) if message["tool_calls"]

          if chunk["done"]
            block.call({
              type: :complete,
              content: text,
              thinking: reasoning.empty? ? nil : reasoning,
              tool_calls: calls,
              stop_reason: map_stop_reason(chunk["done_reason"])
            })
          else
            block.call({ type: :delta, content: piece, thinking: thought, tool_calls: nil })
          end
        end
      end

      # POST /api/generate for plain-prompt completion; streams when a block
      # is given.
      def generate(prompt, &block)
        payload = { model: model, prompt: prompt, stream: block_given? }
        execute(build_url("/api/generate"), payload, &block)
      end

      # POST /api/embed for the given input.
      def embeddings(input)
        post_request(build_url("/api/embed"), { model: model, input: input })
      end

      # GET /api/tags — lists locally available models.
      def tags
        get_request(build_url("/api/tags"))
      end

      # POST /api/show — details for a named model.
      def show(name)
        post_request(build_url("/api/show"), { name: name })
      end

      private

      # Non-streaming branch of #fetch: one request, one :complete hash.
      def fetch_blocking(messages, tools)
        result = chat(messages, tools)
        message = result["message"] || {}
        {
          type: :complete,
          content: message["content"],
          thinking: message["thinking"],
          tool_calls: normalize_tool_calls(message["tool_calls"]),
          stop_reason: map_stop_reason(result["done_reason"])
        }
      end

      def execute(url, payload, &block)
        block_given? ? stream_request(url, payload, &block) : post_request(url, payload)
      end

      # Prefixes bare "host:port" values with http://; full URLs pass through.
      def build_url(path)
        prefix = host.start_with?("http://", "https://") ? host : "http://#{host}"
        "#{prefix}#{path}"
      end

      def get_request(url)
        handle_response(http.get(url))
      end

      def post_request(url, payload)
        handle_response(http.post(url, body: payload))
      end

      # Successful responses are parsed as JSON; failures collapse to a hash
      # carrying the HTTP status code and raw body.
      def handle_response(response)
        return JSON.parse(response.body) if response.is_a?(Net::HTTPSuccess)

        { "code" => response.code, "body" => response.body }
      end

      # Streams NDJSON: each newline-terminated line is parsed and yielded
      # until a chunk with "done" is seen.
      def stream_request(url, payload, &block)
        http.post(url, body: payload) do |response|
          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(Net::HTTPSuccess)

          pending = +""
          response.read_body do |chunk|
            pending << chunk

            while (line = extract_message(pending))
              next if line.empty?

              parsed = JSON.parse(line)
              block.call(parsed)

              break if parsed["done"]
            end
          end
        end
      end

      # Removes and returns the first newline-terminated line from the buffer
      # (mutating it in place), or nil when no complete line is available yet.
      def extract_message(buffer)
        newline = buffer.index("\n")
        return nil if newline.nil?

        line = buffer[0...newline]
        buffer.replace(buffer[(newline + 1)..] || "")
        line
      end

      # Maps Ollama tool-call hashes onto the shared {id:, name:, arguments:}
      # shape; nil or empty input yields [].
      def normalize_tool_calls(tool_calls)
        Array(tool_calls).map do |call|
          {
            id: call["id"] || call.dig("function", "id"),
            name: call.dig("function", "name"),
            arguments: call.dig("function", "arguments") || {}
          }
        end
      end

      # Translates Ollama's done_reason into the shared stop-reason symbols;
      # unknown or nil reasons default to :end_turn.
      def map_stop_reason(reason)
        {
          "stop" => :end_turn,
          "tool_calls" => :tool_use,
          "tool_use" => :tool_use,
          "length" => :max_tokens
        }.fetch(reason, :end_turn)
      end
    end
  end
end
lib/net/llm/openai.rb
@@ -0,0 +1,172 @@
# frozen_string_literal: true

module Net
  module Llm
    # Client for the OpenAI Chat Completions API, also usable against
    # OpenAI-compatible servers via OPENAI_BASE_URL.
    class OpenAI
      attr_reader :api_key, :base_url, :model, :http

      # @param api_key [String] bearer token (defaults to OPENAI_API_KEY)
      # @param base_url [String] API root, e.g. "https://api.openai.com/v1"
      # @param model [String] default model for chat/fetch calls
      # @param http [#get, #post] HTTP client used for requests
      def initialize(api_key: ENV.fetch("OPENAI_API_KEY"), base_url: ENV.fetch("OPENAI_BASE_URL", "https://api.openai.com/v1"), model: "gpt-4o-mini", http: Net::Llm.http)
        @api_key = api_key
        @base_url = base_url
        @model = model
        @http = http
      end

      # Raw chat-completions call; tools are always attached and tool_choice
      # is forced to "auto". Returns the parsed response (or an error hash).
      def chat(messages, tools)
        payload = { model: model, messages: messages, tools: tools, tool_choice: "auto" }
        response = http.post("#{base_url}/chat/completions", headers: headers, body: payload)
        handle_response(response)
      end

      # Normalized fetch interface shared with the other providers. Yields
      # {type: :delta, ...} chunks and a final {type: :complete, ...} when a
      # block is given; otherwise returns a single :complete hash.
      def fetch(messages, tools = [], &block)
        block ? fetch_streaming(messages, tools, &block) : fetch_non_streaming(messages, tools)
      end

      # GET /models — lists models available to this key.
      def models
        handle_response(http.get("#{base_url}/models", headers: headers))
      end

      # POST /embeddings for the given input (string or array of strings).
      def embeddings(input, model: "text-embedding-ada-002")
        response = http.post("#{base_url}/embeddings", headers: headers, body: { model: model, input: input })
        handle_response(response)
      end

      private

      def headers
        { "Authorization" => Net::Hippie.bearer_auth(api_key) }
      end

      # Successful responses are parsed as JSON; failures collapse to a hash
      # carrying the HTTP status code and raw body.
      def handle_response(response)
        return JSON.parse(response.body) if response.is_a?(Net::HTTPSuccess)

        { "code" => response.code, "body" => response.body }
      end

      # Shared request body for #fetch; tools/tool_choice are attached only
      # when tools are present.
      def chat_body(messages, tools, stream: false)
        body = { model: model, messages: messages }
        body[:stream] = true if stream
        unless tools.empty?
          body[:tools] = tools
          body[:tool_choice] = "auto"
        end
        body
      end

      # Non-streaming branch of #fetch: one request, one :complete hash.
      # Error responses (hashes with a "code" key) are returned as-is.
      def fetch_non_streaming(messages, tools)
        result = handle_response(http.post("#{base_url}/chat/completions", headers: headers, body: chat_body(messages, tools)))
        return result if result["code"]

        message = result.dig("choices", 0, "message") || {}
        {
          type: :complete,
          content: message["content"],
          thinking: nil,
          tool_calls: normalize_tool_calls(message["tool_calls"]),
          stop_reason: map_stop_reason(result.dig("choices", 0, "finish_reason"))
        }
      end

      # Streaming branch of #fetch: reads SSE "data:" lines, forwards content
      # deltas, accumulates tool-call fragments by index, and emits one final
      # :complete hash after the stream closes.
      def fetch_streaming(messages, tools, &block)
        text = +""
        pending_calls = {}
        finish = :end_turn

        http.post("#{base_url}/chat/completions", headers: headers, body: chat_body(messages, tools, stream: true)) do |response|
          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(Net::HTTPSuccess)

          buffer = +""
          response.read_body do |chunk|
            buffer << chunk

            while (line = extract_line(buffer))
              next unless line.start_with?("data: ")

              data = line[6..]
              break if data == "[DONE]"

              event = JSON.parse(data)
              delta = event.dig("choices", 0, "delta") || {}

              if (piece = delta["content"])
                text << piece
                block.call({ type: :delta, content: piece, thinking: nil, tool_calls: nil })
              end

              # Tool-call arguments stream as JSON fragments keyed by index.
              (delta["tool_calls"] || []).each do |tc|
                slot = pending_calls[tc["index"]] ||= { id: nil, name: nil, arguments_json: +"" }
                slot[:id] = tc["id"] if tc["id"]
                name = tc.dig("function", "name")
                slot[:name] = name if name
                slot[:arguments_json] << (tc.dig("function", "arguments") || "")
              end

              reason = event.dig("choices", 0, "finish_reason")
              finish = map_stop_reason(reason) if reason
            end
          end
        end

        completed = pending_calls.values.map do |slot|
          arguments = begin
            JSON.parse(slot[:arguments_json])
          rescue JSON::ParserError
            {}
          end
          { id: slot[:id], name: slot[:name], arguments: arguments }
        end

        block.call({
          type: :complete,
          content: text,
          thinking: nil,
          tool_calls: completed,
          stop_reason: finish
        })
      end

      # Removes and returns the first newline-terminated line from the buffer
      # (mutating it in place), or nil when no complete line is available yet.
      def extract_line(buffer)
        pos = buffer.index("\n")
        return nil if pos.nil?

        line = buffer[0...pos]
        buffer.replace(buffer[(pos + 1)..] || "")
        line
      end

      # Maps OpenAI tool-call hashes onto the shared {id:, name:, arguments:}
      # shape; string arguments are parsed as JSON (empty hash on failure).
      def normalize_tool_calls(tool_calls)
        Array(tool_calls).map do |tc|
          raw = tc.dig("function", "arguments")
          parsed =
            if raw.is_a?(String)
              begin
                JSON.parse(raw)
              rescue JSON::ParserError
                {}
              end
            else
              raw || {}
            end
          { id: tc["id"], name: tc.dig("function", "name"), arguments: parsed }
        end
      end

      # Translates OpenAI finish reasons into the shared stop-reason symbols;
      # unknown or nil reasons default to :end_turn.
      def map_stop_reason(reason)
        {
          "stop" => :end_turn,
          "tool_calls" => :tool_use,
          "length" => :max_tokens
        }.fetch(reason, :end_turn)
      end
    end
  end
end
lib/net/llm/version.rb
@@ -0,0 +1,7 @@
# frozen_string_literal: true

module Net
  module Llm
    # Gem version, following Semantic Versioning (https://semver.org).
    VERSION = "0.6.1"
  end
end
lib/net/llm/vertex_ai.rb
@@ -0,0 +1,40 @@
# frozen_string_literal: true

module Net
  module Llm
    # Claude models hosted on Google Vertex AI. Builds a Claude client aimed
    # at the region-scoped rawPredict endpoint, authenticating with a Google
    # OAuth access token that is re-resolved on every request.
    class VertexAI
      attr_reader :project_id, :region, :model

      # @param project_id [String] GCP project (defaults to GOOGLE_CLOUD_PROJECT)
      # @param region [String] Vertex AI region (defaults to GOOGLE_CLOUD_REGION)
      # @param model [String] Claude model id; only "claude-*" is supported
      # @param http [#post] HTTP client used for requests
      # @raise [NotImplementedError] for non-Claude models
      def initialize(project_id: ENV.fetch("GOOGLE_CLOUD_PROJECT"), region: ENV.fetch("GOOGLE_CLOUD_REGION", "us-east5"), model: "claude-opus-4-5@20251101", http: Net::Llm.http)
        @project_id = project_id
        @region = region
        @model = model
        @handler = build_handler(http)
      end

      # The public surface simply delegates to the underlying handler.
      def messages(...) = @handler.messages(...)
      def fetch(...) = @handler.fetch(...)

      private

      def build_handler(http)
        unless model.start_with?("claude-")
          raise NotImplementedError, "Model '#{model}' is not yet supported. Only Claude models (claude-*) are currently implemented."
        end

        # The headers lambda defers token resolution until request time so
        # short-lived credentials stay fresh across a long session.
        Claude.new(
          endpoint: endpoint_url,
          headers: -> { { "Authorization" => "Bearer #{access_token}" } },
          http: http,
          anthropic_version: "vertex-2023-10-16"
        )
      end

      # Region-scoped rawPredict URL for the configured Anthropic model.
      def endpoint_url
        "https://#{region}-aiplatform.googleapis.com/v1/projects/#{project_id}/locations/#{region}/publishers/anthropic/models/#{model}:rawPredict"
      end

      # Prefers an explicit token from the environment; otherwise shells out
      # to gcloud for application-default credentials.
      def access_token
        ENV.fetch("GOOGLE_OAUTH_ACCESS_TOKEN") do
          `gcloud auth application-default print-access-token`.strip
        end
      end
    end
  end
end
lib/net/llm.rb
@@ -0,0 +1,24 @@
# frozen_string_literal: true

require "net/hippie"
require "json"

require_relative "llm/version"
require_relative "llm/openai"
require_relative "llm/ollama"
require_relative "llm/claude"
require_relative "llm/anthropic"
require_relative "llm/vertex_ai"

module Net
  # Namespace for the net-llm provider clients (OpenAI, Ollama, Claude,
  # Anthropic, VertexAI).
  module Llm
    # Base error class for the gem.
    class Error < StandardError; end

    # Shared, memoized Net::Hippie HTTP client used as the default transport
    # by every provider. The one-hour read timeout accommodates slow streaming
    # responses; connection attempts fail fast after 10 seconds.
    def self.http
      @http ||= Net::Hippie::Client.new(
        read_timeout: 3600,
        open_timeout: 10
      )
    end
  end
end
elelem.gemspec
@@ -43,10 +43,11 @@ Gem::Specification.new do |spec|
   spec.add_dependency "date", "~> 3.0"
   spec.add_dependency "fileutils", "~> 1.0"
   spec.add_dependency "json", "~> 2.0"
-  spec.add_dependency "net-llm", "~> 0.5", ">= 0.5.0"
+  spec.add_dependency "net-hippie", "~> 1.0"
   spec.add_dependency "open3", "~> 0.1"
   spec.add_dependency "pathname", "~> 0.1"
   spec.add_dependency "reline", "~> 0.6"
   spec.add_dependency "stringio", "~> 3.0"
   spec.add_dependency "tempfile", "~> 0.3"
+  spec.add_dependency "uri", "~> 1.0"
 end
Gemfile.lock
@@ -5,12 +5,13 @@ PATH
       date (~> 3.0)
       fileutils (~> 1.0)
       json (~> 2.0)
-      net-llm (~> 0.5, >= 0.5.0)
+      net-hippie (~> 1.0)
       open3 (~> 0.1)
       pathname (~> 0.1)
       reline (~> 0.6)
       stringio (~> 3.0)
       tempfile (~> 0.3)
+      uri (~> 1.0)
 
 GEM
   remote: https://rubygems.org/
@@ -35,10 +36,6 @@ GEM
       openssl (~> 3.0)
     net-http (0.9.1)
       uri (>= 0.11.1)
-    net-llm (0.6.1)
-      json (~> 2.0)
-      net-hippie (~> 1.0)
-      uri (~> 1.0)
     open3 (0.2.1)
     openssl (3.3.2)
     pathname (0.4.0)