Commit c1c0242

mo khan <mo@mokhan.ca>
2026-01-26 19:28:43
feat: add support for OAuth login to MCP server
1 parent 5af1c0a
lib/elelem/mcp/oauth.rb
@@ -0,0 +1,211 @@
+# frozen_string_literal: true
+
+module Elelem
+  class MCP
+    class OAuth
+      CALLBACK_PORT = 18273
+      REDIRECT_URI = "http://127.0.0.1:#{CALLBACK_PORT}/callback"
+
+      def initialize(resource_url, http: Elelem::Net.http)
+        @resource_url = resource_url
+        @http = http
+        @storage = TokenStorage.new
+      end
+
+      def token
+        stored = @storage.load(@resource_url)
+        return stored[:access_token] if stored && !expired?(stored)
+        return refresh(stored[:refresh_token]) if stored&.dig(:refresh_token)
+
+        authorize
+      end
+
+      private
+
+      def expired?(stored)
+        return false unless stored[:expires_at]
+
+        Time.now.to_i >= stored[:expires_at] - 60
+      end
+
+      def authorize
+        metadata = discover_auth_server
+        client = load_or_register_client(metadata)
+        verifier, challenge = generate_pkce
+        state = SecureRandom.hex(16)
+
+        auth_url = build_auth_url(metadata, client, challenge, state)
+        open_browser(auth_url)
+
+        code = wait_for_callback(state)
+        tokens = exchange_code(metadata, client, code, verifier)
+
+        @storage.save(
+          @resource_url,
+          access_token: tokens["access_token"],
+          refresh_token: tokens["refresh_token"],
+          expires_in: tokens["expires_in"]
+        )
+
+        tokens["access_token"]
+      end
+
+      def refresh(refresh_token)
+        metadata = discover_auth_server
+        client = load_or_register_client(metadata)
+        uri = URI.parse(metadata["token_endpoint"])
+
+        body = {
+          grant_type: "refresh_token",
+          refresh_token: refresh_token,
+          client_id: client[:client_id]
+        }
+
+        response = post_form(uri, body)
+        tokens = JSON.parse(response.body)
+
+        @storage.save(
+          @resource_url,
+          access_token: tokens["access_token"],
+          refresh_token: tokens["refresh_token"] || refresh_token,
+          expires_in: tokens["expires_in"]
+        )
+
+        tokens["access_token"]
+      rescue StandardError
+        authorize
+      end
+
+      def discover_auth_server
+        resource_uri = URI.parse(@resource_url)
+        metadata_url = "#{resource_uri.scheme}://#{resource_uri.host}/.well-known/oauth-protected-resource"
+
+        resource_metadata = fetch_json(metadata_url)
+        auth_server_url = resource_metadata["authorization_servers"]&.first
+        raise "No authorization server found" unless auth_server_url
+
+        auth_metadata_url = "#{auth_server_url}/.well-known/oauth-authorization-server"
+        fetch_json(auth_metadata_url)
+      end
+
+      def load_or_register_client(metadata)
+        stored = @storage.load_client(@resource_url)
+        return stored if stored
+
+        client = register_client(metadata)
+        @storage.save_client(@resource_url, client)
+        @storage.load_client(@resource_url)
+      end
+
+      def register_client(metadata)
+        endpoint = metadata["registration_endpoint"]
+        raise "Dynamic registration not supported" unless endpoint
+
+        body = {
+          client_name: "elelem",
+          redirect_uris: [REDIRECT_URI],
+          grant_types: %w[authorization_code refresh_token],
+          response_types: ["code"],
+          token_endpoint_auth_method: "none"
+        }
+
+        response = post_json(endpoint, body)
+        JSON.parse(response.body)
+      end
+
+      def generate_pkce
+        verifier = SecureRandom.urlsafe_base64(32)
+        challenge = Base64.urlsafe_encode64(
+          Digest::SHA256.digest(verifier),
+          padding: false
+        )
+        [verifier, challenge]
+      end
+
+      def build_auth_url(metadata, client, challenge, state)
+        params = {
+          response_type: "code",
+          client_id: client[:client_id],
+          redirect_uri: REDIRECT_URI,
+          scope: metadata["scopes_supported"]&.join(" ") || "openid",
+          state: state,
+          code_challenge: challenge,
+          code_challenge_method: "S256"
+        }
+
+        "#{metadata["authorization_endpoint"]}?#{URI.encode_www_form(params)}"
+      end
+
+      def open_browser(url)
+        commands = ["xdg-open", "open", "start"]
+        commands.each do |cmd|
+          return if system(cmd, url, out: File::NULL, err: File::NULL)
+        end
+        warn "Open this URL in your browser: #{url}"
+      end
+
+      def wait_for_callback(expected_state)
+        code = nil
+        server = WEBrick::HTTPServer.new(
+          Port: CALLBACK_PORT,
+          Logger: WEBrick::Log.new(File::NULL),
+          AccessLog: []
+        )
+
+        server.mount_proc("/callback") do |req, res|
+          state = req.query["state"]
+          raise "State mismatch" unless state == expected_state
+
+          code = req.query["code"]
+          res.content_type = "text/html"
+          res.body = "<html><body><h1>Authorization complete</h1><p>You can close this window.</p></body></html>"
+          server.shutdown
+        end
+
+        server.start
+        code
+      end
+
+      def exchange_code(metadata, client, code, verifier)
+        uri = URI.parse(metadata["token_endpoint"])
+
+        body = {
+          grant_type: "authorization_code",
+          code: code,
+          redirect_uri: REDIRECT_URI,
+          client_id: client[:client_id],
+          code_verifier: verifier
+        }
+
+        response = post_form(uri, body)
+        JSON.parse(response.body)
+      end
+
+      def fetch_json(url)
+        response = nil
+        @http.get(url) { |r| response = r }
+        JSON.parse(response.body)
+      end
+
+      def post_json(url, body)
+        response = nil
+        @http.post(
+          url,
+          headers: { "Content-Type" => "application/json" },
+          body: body.to_json
+        ) { |r| response = r }
+        response
+      end
+
+      def post_form(uri, body)
+        response = nil
+        @http.post(
+          uri.to_s,
+          headers: { "Content-Type" => "application/x-www-form-urlencoded" },
+          body: URI.encode_www_form(body)
+        ) { |r| response = r }
+        response
+      end
+    end
+  end
+end
lib/elelem/mcp/token_storage.rb
@@ -0,0 +1,60 @@
+# frozen_string_literal: true
+
+module Elelem
+  class MCP
+    class TokenStorage
+      STORAGE_DIR = File.expand_path("~/.config/elelem/tokens")
+
+      def initialize
+        FileUtils.mkdir_p(STORAGE_DIR, mode: 0o700)
+      end
+
+      def save(resource_url, access_token:, refresh_token: nil, expires_in: nil)
+        data = {
+          access_token: access_token,
+          refresh_token: refresh_token,
+          expires_at: expires_in ? Time.now.to_i + expires_in : nil
+        }
+        path = token_path(resource_url)
+        File.write(path, data.to_json)
+        File.chmod(0o600, path)
+      end
+
+      def load(resource_url)
+        path = token_path(resource_url)
+        return nil unless File.exist?(path)
+
+        JSON.parse(File.read(path), symbolize_names: true)
+      rescue JSON::ParserError
+        nil
+      end
+
+      def save_client(resource_url, client_data)
+        path = client_path(resource_url)
+        File.write(path, client_data.to_json)
+        File.chmod(0o600, path)
+      end
+
+      def load_client(resource_url)
+        path = client_path(resource_url)
+        return nil unless File.exist?(path)
+
+        JSON.parse(File.read(path), symbolize_names: true)
+      rescue JSON::ParserError
+        nil
+      end
+
+      private
+
+      def token_path(resource_url)
+        hash = Digest::SHA256.hexdigest(resource_url)[0, 16]
+        File.join(STORAGE_DIR, "#{hash}.json")
+      end
+
+      def client_path(resource_url)
+        hash = Digest::SHA256.hexdigest(resource_url)[0, 16]
+        File.join(STORAGE_DIR, "#{hash}_client.json")
+      end
+    end
+  end
+end
lib/elelem/mcp.rb
@@ -1,5 +1,8 @@
 # frozen_string_literal: true
 
+require_relative "mcp/token_storage"
+require_relative "mcp/oauth"
+
 module Elelem
   class MCP
     def initialize(config_path = ".mcp.json")
@@ -30,12 +33,22 @@ module Elelem
     private
 
     def server(name)
-      @servers[name] ||= Server.new(**@config.dig("mcpServers", name).transform_keys(&:to_sym))
+      @servers[name] ||= build_server(@config.dig("mcpServers", name))
+    end
+
+    def build_server(config)
+      if config["type"] == "http"
+        HttpServer.new(url: config["url"], headers: config["headers"] || {})
+      else
+        Server.new(**config.transform_keys(&:to_sym))
+      end
     end
 
     class Server
       def initialize(command:, args: [], env: {})
-        resolved_env = env.transform_values { |v| v.gsub(/\$\{(\w+)\}/) { ENV[$1] } }
+        resolved_env = env.transform_values do |v|
+          v.gsub(/\$\{(\w+)\}/) { ENV[$1] || raise("Missing environment variable: #{$1}") }
+        end
         @stdin, @stdout, @stderr, @wait = Open3.popen3(resolved_env, command, *args)
         @id = 0
         initialize!
@@ -61,7 +74,7 @@ module Elelem
 
       def initialize!
         request("initialize", {
-          protocolVersion: "2024-11-05",
+          protocolVersion: "2025-06-18",
           capabilities: {},
           clientInfo: { name: "elelem", version: VERSION }
         })
@@ -92,5 +105,120 @@ module Elelem
         end
       end
     end
+
+    class HttpServer
+      def initialize(url:, headers: {}, http: Elelem::Net.http)
+        @url = url
+        @headers = resolve_headers(headers)
+        @http = http
+        @id = 0
+        @session_id = nil
+        @access_token = nil
+        initialize!
+      end
+
+      def tools
+        request("tools/list")["tools"]
+      end
+
+      def call(name, args)
+        result = request("tools/call", { name: name, arguments: args })
+        { content: result["content"]&.map { |c| c["text"] }&.join("\n") }
+      end
+
+      def close
+      end
+
+      private
+
+      def resolve_headers(headers)
+        headers.transform_values do |v|
+          v.gsub(/\$\{(\w+)\}/) do
+            ENV[$1] || raise("Missing environment variable: #{$1}")
+          end
+        end
+      end
+
+      def initialize!
+        request("initialize", {
+          protocolVersion: "2025-06-18",
+          capabilities: {},
+          clientInfo: { name: "elelem", version: VERSION }
+        })
+        notify("notifications/initialized")
+      end
+
+      def request(method, params = {})
+        msg = { jsonrpc: "2.0", id: @id += 1, method: method, params: params }
+        response = post(msg)
+        raise response["error"]["message"] if response["error"]
+        response["result"]
+      end
+
+      def notify(method, params = {})
+        msg = { jsonrpc: "2.0", method: method, params: params }
+        post(msg)
+      end
+
+      def post(msg, retry_auth: true)
+        result = nil
+        needs_auth = false
+        error = nil
+
+        @http.post(@url, headers: request_headers, body: msg) do |response|
+          case response
+          when ::Net::HTTPSuccess
+            @session_id ||= response["Mcp-Session-Id"]
+            result = parse_response(response)
+          when ::Net::HTTPUnauthorized
+            needs_auth = true
+          else
+            error = "HTTP #{response.code}: #{response.body}"
+          end
+        end
+
+        raise error if error
+        if needs_auth
+          raise "Authorization failed" unless retry_auth
+
+          @access_token = OAuth.new(@url, http: @http).token
+          return post(msg, retry_auth: false)
+        end
+        result
+      end
+
+      def request_headers
+        base = { "Accept" => "application/json, text/event-stream" }
+        base["Mcp-Session-Id"] = @session_id if @session_id
+        base["Authorization"] = "Bearer #{@access_token}" if @access_token
+        @headers.merge(base)
+      end
+
+      def parse_response(response)
+        if response.content_type&.include?("text/event-stream")
+          parse_sse(response)
+        elsif response.body && !response.body.empty?
+          JSON.parse(response.body)
+        end
+      end
+
+      def parse_sse(response)
+        buffer = String.new
+        result = nil
+
+        response.read_body do |chunk|
+          buffer << chunk
+
+          while (index = buffer.index("\n"))
+            line = buffer.slice!(0, index + 1).strip
+            next unless line.start_with?("data: ")
+
+            result = JSON.parse(line.delete_prefix("data: "))
+          end
+        end
+
+        result
+      end
+    end
   end
 end
lib/elelem/net.rb
@@ -1,8 +1,5 @@
 # frozen_string_literal: true
 
-require "net/hippie"
-require "json"
-
 require_relative "net/ollama"
 require_relative "net/openai"
 require_relative "net/claude"
lib/elelem.rb
@@ -1,15 +1,21 @@
 # frozen_string_literal: true
 
+require "base64"
 require "date"
+require "digest"
 require "erb"
 require "fileutils"
 require "json"
 require "json_schemer"
+require "net/hippie"
 require "open3"
 require "pathname"
 require "reline"
+require "securerandom"
 require "stringio"
 require "tempfile"
+require "uri"
+require "webrick"
 
 require_relative "elelem/agent"
 require_relative "elelem/mcp"
elelem.gemspec
@@ -28,6 +28,8 @@ Gem::Specification.new do |spec|
     "lib/elelem.rb",
     "lib/elelem/agent.rb",
     "lib/elelem/mcp.rb",
+    "lib/elelem/mcp/oauth.rb",
+    "lib/elelem/mcp/token_storage.rb",
     "lib/elelem/net.rb",
     "lib/elelem/net/claude.rb",
     "lib/elelem/net/ollama.rb",
@@ -52,7 +54,9 @@ Gem::Specification.new do |spec|
   spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
   spec.require_paths = ["lib"]
 
+  spec.add_dependency "base64", "~> 0.1"
   spec.add_dependency "date", "~> 3.0"
+  spec.add_dependency "digest", "~> 3.0"
   spec.add_dependency "erb", "~> 6.0"
   spec.add_dependency "fileutils", "~> 1.0"
   spec.add_dependency "json", "~> 2.0"
@@ -62,7 +66,9 @@ Gem::Specification.new do |spec|
   spec.add_dependency "optparse", "~> 0.1"
   spec.add_dependency "pathname", "~> 0.1"
   spec.add_dependency "reline", "~> 0.6"
+  spec.add_dependency "securerandom", "~> 0.1"
   spec.add_dependency "stringio", "~> 3.0"
   spec.add_dependency "tempfile", "~> 0.3"
   spec.add_dependency "uri", "~> 1.0"
+  spec.add_dependency "webrick", "~> 1.9"
 end
Gemfile.lock
@@ -2,7 +2,9 @@ PATH
   remote: .
   specs:
     elelem (0.9.2)
+      base64 (~> 0.1)
       date (~> 3.0)
+      digest (~> 3.0)
       erb (~> 6.0)
       fileutils (~> 1.0)
       json (~> 2.0)
@@ -12,9 +14,11 @@ PATH
       optparse (~> 0.1)
       pathname (~> 0.1)
       reline (~> 0.6)
+      securerandom (~> 0.1)
       stringio (~> 3.0)
       tempfile (~> 0.3)
       uri (~> 1.0)
+      webrick (~> 1.9)
 
 GEM
   remote: https://rubygems.org/
@@ -23,6 +27,7 @@ GEM
     bigdecimal (4.0.1)
     date (3.5.1)
     diff-lcs (1.6.2)
+    digest (3.2.1)
     erb (6.0.1)
     fileutils (1.8.0)
     hana (1.3.7)
@@ -77,11 +82,13 @@ GEM
       diff-lcs (>= 1.2.0, < 2.0)
       rspec-support (~> 3.13.0)
     rspec-support (3.13.6)
+    securerandom (0.4.1)
     simpleidn (0.2.3)
     stringio (3.2.0)
     tempfile (0.3.1)
     tsort (0.2.0)
     uri (1.1.1)
+    webrick (1.9.2)
 
 PLATFORMS
   ruby