# frozen_string_literal: true

# Stdlib dependencies for JSON.parse and Net::HTTPSuccess used below.
require "json"
require "net/http"

module Net
  module Llm
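    # Minimal client for the Ollama HTTP API: chat, generate, embeddings,
    # model listing (/api/tags) and model metadata (/api/show), with optional
    # streaming via newline-delimited JSON when a block is given.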
    class Ollama
      attr_reader :host, :model, :http

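      # host defaults to the OLLAMA_HOST environment variable (falling back to
      # localhost:11434); http is the injectable transport used for all requests.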
      def initialize(host: ENV.fetch("OLLAMA_HOST", "localhost:11434"), model: "gpt-oss", http: Net::Llm.http)
        @host = host
        @model = model
        @http = http
      end

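      # POST /api/chat with the given messages (and optional tool definitions).
      # With a block, raw response chunks are yielded as they arrive; without,
      # the parsed response hash is returned.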
      def chat(messages, tools = [], &block)
        payload = { model: model, messages: messages, stream: block_given? }
        payload[:tools] = tools unless tools.empty?

        execute(build_url("/api/chat"), payload, &block)
      end

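      # Wraps #chat and normalizes the result into a hash with :type, :content,
      # :thinking, :tool_calls and :stop_reason. When a block is given, yields
      # :delta chunks while streaming and a final :complete chunk at the end.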
      def fetch(messages, tools = [], &block)
        content = ""
        thinking = ""
        tool_calls = []

        if block_given?
          chat(messages, tools) do |chunk|
            msg = chunk["message"] || {}
            delta_content = msg["content"]
            delta_thinking = msg["thinking"]

            # Accumulate streamed deltas so the final :complete chunk carries
            # the full content, thinking and tool calls.
            content += delta_content if delta_content
            thinking += delta_thinking if delta_thinking
            tool_calls += normalize_tool_calls(msg["tool_calls"]) if msg["tool_calls"]

            if chunk["done"]
              block.call({
                type: :complete,
                content: content,
                thinking: thinking.empty? ? nil : thinking,
                tool_calls: tool_calls,
                stop_reason: map_stop_reason(chunk["done_reason"])
              })
            else
              block.call({
                type: :delta,
                content: delta_content,
                thinking: delta_thinking,
                tool_calls: nil
              })
            end
          end
        else
          result = chat(messages, tools)
          msg = result["message"] || {}
          {
            type: :complete,
            content: msg["content"],
            thinking: msg["thinking"],
            tool_calls: normalize_tool_calls(msg["tool_calls"]),
            stop_reason: map_stop_reason(result["done_reason"])
          }
        end
      end

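      # POST /api/generate for a single-prompt completion.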
      def generate(prompt, &block)
        execute(build_url("/api/generate"), {
          model: model,
          prompt: prompt,
          stream: block_given?
        }, &block)
      end

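      # POST /api/embed to compute embeddings for the given input.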
      def embeddings(input)
        post_request(build_url("/api/embed"), { model: model, input: input })
      end

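      # GET /api/tags to list the models available on the server.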
      def tags
        get_request(build_url("/api/tags"))
      end

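      # POST /api/show to fetch metadata for a named model.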
      def show(name)
        post_request(build_url("/api/show"), { name: name })
      end

      private

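      # Routes to a streaming or plain request depending on whether a block was given.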
      def execute(url, payload, &block)
        if block_given?
          stream_request(url, payload, &block)
        else
          post_request(url, payload)
        end
      end

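      # Prepends http:// when the configured host has no scheme.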
      def build_url(path)
        base = host.start_with?("http://", "https://") ? host : "http://#{host}"
        "#{base}#{path}"
      end

      def get_request(url)
        handle_response(http.get(url))
      end

      def post_request(url, payload)
        handle_response(http.post(url, body: payload))
      end

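      # Parses successful responses as JSON; error responses are returned as a
      # hash of the status code and raw body rather than raising.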
      def handle_response(response)
        if response.is_a?(Net::HTTPSuccess)
          JSON.parse(response.body)
        else
          {
            "code" => response.code,
            "body" => response.body
          }
        end
      end

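      # Issues a streaming POST and yields each newline-delimited JSON object
      # from the response body to the block, stopping at the "done" chunk.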
      def stream_request(url, payload, &block)
        http.post(url, body: payload) do |response|
          raise "HTTP #{response.code}: #{response.body}" unless response.is_a?(Net::HTTPSuccess)

          buffer = ""
          response.read_body do |chunk|
            buffer += chunk

            # Network chunks may split or combine JSON lines, so drain complete
            # lines from the buffer and keep any trailing partial line for later.
            while (message = extract_message(buffer))
              next if message.empty?

              json = JSON.parse(message)
              block.call(json)

              break if json["done"]
            end
          end
        end
      end

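      # Removes and returns the next complete line from the buffer (mutating it
      # in place); returns nil while no full line has arrived yet.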
      def extract_message(buffer)
        message_end = buffer.index("\n")
        return nil unless message_end

        message = buffer[0...message_end]
        buffer.replace(buffer[(message_end + 1)..-1] || "")
        message
      end

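      # Flattens Ollama tool_call entries into { id:, name:, arguments: } hashes.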
      def normalize_tool_calls(tool_calls)
        return [] if tool_calls.nil? || tool_calls.empty?

        tool_calls.map do |tc|
          {
            id: tc["id"] || tc.dig("function", "id"),
            name: tc.dig("function", "name"),
            arguments: tc.dig("function", "arguments") || {}
          }
        end
      end

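      # Maps Ollama's done_reason onto normalized stop-reason symbols.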
      def map_stop_reason(reason)
        case reason
        when "stop" then :end_turn
        when "tool_calls", "tool_use" then :tool_use
        when "length" then :max_tokens
        else :end_turn
        end
      end
    end
  end
end