# frozen_string_literal: true

  3module Net
  4  module Llm
  5    class Claude
  6      attr_reader :endpoint, :headers, :model, :http, :anthropic_version
  7
  8      def initialize(endpoint:, headers:, http:, model: nil, anthropic_version: nil)
  9        @endpoint = endpoint
 10        @headers_source = headers
 11        @model = model
 12        @http = http
 13        @anthropic_version = anthropic_version
 14      end
 15
 16      def headers
 17        @headers_source.respond_to?(:call) ? @headers_source.call : @headers_source
 18      end
 19
 20      def messages(messages, system: nil, max_tokens: 64000, tools: nil, &block)
 21        payload = build_payload(messages, system, max_tokens, tools, block_given?)
 22
 23        if block_given?
 24          stream_request(payload, &block)
 25        else
 26          post_request(payload)
 27        end
 28      end
 29
 30      def fetch(messages, tools = [], &block)
 31        system_message, user_messages = extract_system_message(messages)
 32        anthropic_tools = tools.empty? ? nil : tools.map { |t| normalize_tool_for_anthropic(t) }
 33
 34        if block_given?
 35          fetch_streaming(user_messages, anthropic_tools, system: system_message, &block)
 36        else
 37          fetch_non_streaming(user_messages, anthropic_tools, system: system_message)
 38        end
 39      end
 40
 41      private
 42
 43      def build_payload(messages, system, max_tokens, tools, stream)
 44        payload = { max_tokens: max_tokens, messages: messages, stream: stream }
 45        payload[:model] = model if model
 46        payload[:anthropic_version] = anthropic_version if anthropic_version
 47        payload[:system] = system if system
 48        payload[:tools] = tools if tools
 49        payload
 50      end
 51
 52      def post_request(payload)
 53        handle_response(http.post(endpoint, headers: headers, body: payload))
 54      end
 55
 56      def handle_response(response)
 57        if response.is_a?(Net::HTTPSuccess)
 58          JSON.parse(response.body)
 59        else
 60          { "code" => response.code, "body" => response.body }
 61        end
 62      end
 63
 64      def stream_request(payload, &block)
 65        http.post(endpoint, headers: headers, body: payload) do |response|
 66          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(Net::HTTPSuccess)
 67
 68          buffer = ""
 69          response.read_body do |chunk|
 70            buffer += chunk
 71
 72            while (event = extract_sse_event(buffer))
 73              next if event[:data].nil? || event[:data].empty?
 74              next if event[:data] == "[DONE]"
 75
 76              json = JSON.parse(event[:data])
 77              block.call(json)
 78
 79              break if json["type"] == "message_stop"
 80            end
 81          end
 82        end
 83      end
 84
 85      def extract_sse_event(buffer)
 86        event_end = buffer.index("\n\n")
 87        return nil unless event_end
 88
 89        event_data = buffer[0...event_end]
 90        buffer.replace(buffer[(event_end + 2)..] || "")
 91
 92        event = {}
 93        event_data.split("\n").each do |line|
 94          if line.start_with?("event: ")
 95            event[:event] = line[7..]
 96          elsif line.start_with?("data: ")
 97            event[:data] = line[6..]
 98          elsif line == "data:"
 99            event[:data] = ""
100          end
101        end
102
103        event
104      end
105
106      def extract_system_message(messages)
107        system_msg = messages.find { |m| m[:role] == "system" || m["role"] == "system" }
108        system_content = system_msg ? (system_msg[:content] || system_msg["content"]) : nil
109        other_messages = messages.reject { |m| m[:role] == "system" || m["role"] == "system" }
110        normalized_messages = normalize_messages_for_claude(other_messages)
111        [system_content, normalized_messages]
112      end
113
114      def normalize_messages_for_claude(messages)
115        messages.map do |msg|
116          role = msg[:role] || msg["role"]
117          tool_calls = msg[:tool_calls] || msg["tool_calls"]
118
119          if role == "tool"
120            {
121              role: "user",
122              content: [{
123                type: "tool_result",
124                tool_use_id: msg[:tool_call_id] || msg["tool_call_id"],
125                content: msg[:content] || msg["content"]
126              }]
127            }
128          elsif role == "assistant" && tool_calls&.any?
129            content = []
130            text = msg[:content] || msg["content"]
131            content << { type: "text", text: text } if text && !text.empty?
132            tool_calls.each do |tc|
133              func = tc[:function] || tc["function"] || {}
134              args = func[:arguments] || func["arguments"]
135              input = args.is_a?(String) ? (JSON.parse(args) rescue {}) : (args || {})
136              content << {
137                type: "tool_use",
138                id: tc[:id] || tc["id"],
139                name: func[:name] || func["name"] || tc[:name] || tc["name"],
140                input: input
141              }
142            end
143            { role: "assistant", content: content }
144          else
145            msg
146          end
147        end
148      end
149
150      def fetch_non_streaming(messages, tools, system: nil)
151        result = self.messages(messages, system: system, tools: tools)
152        return result if result["code"]
153
154        {
155          type: :complete,
156          content: extract_text_content(result["content"]),
157          thinking: extract_thinking_content(result["content"]),
158          tool_calls: extract_tool_calls(result["content"]),
159          stop_reason: map_stop_reason(result["stop_reason"])
160        }
161      end
162
163      def fetch_streaming(messages, tools, system: nil, &block)
164        content = ""
165        thinking = ""
166        tool_calls = []
167        stop_reason = :end_turn
168
169        self.messages(messages, system: system, tools: tools) do |event|
170          case event["type"]
171          when "content_block_start"
172            if event.dig("content_block", "type") == "tool_use"
173              tool_calls << {
174                id: event.dig("content_block", "id"),
175                name: event.dig("content_block", "name"),
176                arguments: {}
177              }
178            end
179          when "content_block_delta"
180            delta = event["delta"]
181            case delta["type"]
182            when "text_delta"
183              text = delta["text"]
184              content += text
185              block.call({ type: :delta, content: text, thinking: nil, tool_calls: nil })
186            when "thinking_delta"
187              text = delta["thinking"]
188              thinking += text if text
189              block.call({ type: :delta, content: nil, thinking: text, tool_calls: nil })
190            when "input_json_delta"
191              if tool_calls.any?
192                tool_calls.last[:arguments_json] ||= ""
193                tool_calls.last[:arguments_json] += delta["partial_json"] || ""
194              end
195            end
196          when "message_delta"
197            stop_reason = map_stop_reason(event.dig("delta", "stop_reason"))
198          when "message_stop"
199            tool_calls.each do |tc|
200              if tc[:arguments_json]
201                tc[:arguments] = begin
202                  JSON.parse(tc[:arguments_json])
203                rescue
204                  {}
205                end
206                tc.delete(:arguments_json)
207              end
208            end
209            block.call({
210              type: :complete,
211              content: content,
212              thinking: thinking.empty? ? nil : thinking,
213              tool_calls: tool_calls,
214              stop_reason: stop_reason
215            })
216          end
217        end
218      end
219
220      def extract_text_content(content_blocks)
221        return nil unless content_blocks
222
223        content_blocks
224          .select { |b| b["type"] == "text" }
225          .map { |b| b["text"] }
226          .join
227      end
228
229      def extract_thinking_content(content_blocks)
230        return nil unless content_blocks
231
232        thinking = content_blocks
233          .select { |b| b["type"] == "thinking" }
234          .map { |b| b["thinking"] }
235          .join
236
237        thinking.empty? ? nil : thinking
238      end
239
240      def extract_tool_calls(content_blocks)
241        return [] unless content_blocks
242
243        content_blocks
244          .select { |b| b["type"] == "tool_use" }
245          .map { |b| { id: b["id"], name: b["name"], arguments: b["input"] || {} } }
246      end
247
248      def normalize_tool_for_anthropic(tool)
249        if tool[:function]
250          { name: tool[:function][:name], description: tool[:function][:description], input_schema: tool[:function][:parameters] }
251        else
252          tool
253        end
254      end
255
256      def map_stop_reason(reason)
257        case reason
258        when "end_turn" then :end_turn
259        when "tool_use" then :tool_use
260        when "max_tokens" then :max_tokens
261        else :end_turn
262        end
263      end
264    end
265  end
266end