Commit 9309624

mo khan <mo@mokhan.ca>
2025-03-11 22:45:12
refactor: move more saml code into authn namespace
1 parent ae6a630
Changed files (1)
bin
bin/idp
@@ -63,10 +63,26 @@ module Authn
     end
   end
 
+  class OnDemandRegistry < Saml::Kit::DefaultRegistry
+    def metadata_for(entity_id)
+      found = super(entity_id)
+      return found if found
+
+      register_url(entity_id, verify_ssl: false)
+      super(entity_id)
+    end
+  end
+
   class SAMLController
     include ::HTTPHelpers
 
     def initialize(scheme, host)
+      Saml::Kit.configure do |x|
+        x.entity_id = "#{$scheme}://#{$host}/saml/metadata.xml"
+        x.registry = OnDemandRegistry.new
+        x.logger = Logger.new("/dev/stderr")
+      end
+
       @saml_metadata = Saml::Kit::Metadata.build do |builder|
         builder.contact_email = 'hi@example.com'
         builder.organization_name = "Acme, Inc"
@@ -168,43 +184,27 @@ module Authn
   end
 end
 
-class OnDemandRegistry < Saml::Kit::DefaultRegistry
-  def metadata_for(entity_id)
-    found = super(entity_id)
-    return found if found
-
-    register_url(entity_id, verify_ssl: false)
-    super(entity_id)
+class Organization
+  class << self
+    def find(id)
+      new
+    end
   end
 end
 
-Saml::Kit.configure do |x|
-  x.entity_id = "#{$scheme}://#{$host}/saml/metadata.xml"
-  x.registry = OnDemandRegistry.new
-  x.logger = Logger.new("/dev/stderr")
-end
-
-class OrganizationPolicy < DeclarativePolicy::Base
-  condition(:owner) { true }
-
-  rule { owner }.enable :create_project
-end
-
 DeclarativePolicy.configure do
   name_transformation do |name|
-    "#{name}Policy"
+    "::Authz::#{name}Policy"
   end
 end
 
-class Organization
-  class << self
-    def find(id)
-      new
-    end
+module Authz
+  class OrganizationPolicy < DeclarativePolicy::Base
+    condition(:owner) { true }
+
+    rule { owner }.enable :create_project
   end
-end
 
-module Authz
   class JWT
     attr_reader :claims
 
@@ -222,7 +222,7 @@ module Authz
   end
 
   module Rpc
-    class AbilityHandler
+    class Ability
       def allowed(request, env)
         {
           result: can?(request)
@@ -255,84 +255,84 @@ module Authz
       end
     end
   end
-end
 
-class OAuthController
-  include ::HTTPHelpers
+  class OAuthController
+    include ::HTTPHelpers
 
-  def call(env)
-    path = env['PATH_INFO']
-    case env['REQUEST_METHOD']
-    when 'GET'
-      case path
-      when "/authorize" # RFC-6749
-        return get_authorize(Rack::Request.new(env))
-      else
-        return http_not_found
-      end
-    when 'POST'
-      case path
-      when "/authorize" # RFC-6749
-        return post_authorize(Rack::Request.new(env))
-      when "/token" # RFC-6749
-        return [200, { 'Content-Type' => "application/json" }, [JSON.pretty_generate({
-          access_token: ::Authz::JWT.new(sub: SecureRandom.uuid, iat: Time.now.to_i).to_jwt,
-          token_type: "Bearer",
-          issued_token_type: "urn:ietf:params:oauth:token-type:access_token",
-          expires_in: 3600,
-          refresh_token: SecureRandom.hex(32)
-        })]]
-      when "/oauth/revoke" # RFC-7009
-        return http_not_found
-      else
-        return http_not_found
+    def call(env)
+      path = env['PATH_INFO']
+      case env['REQUEST_METHOD']
+      when 'GET'
+        case path
+        when "/authorize" # RFC-6749
+          return get_authorize(Rack::Request.new(env))
+        else
+          return http_not_found
+        end
+      when 'POST'
+        case path
+        when "/authorize" # RFC-6749
+          return post_authorize(Rack::Request.new(env))
+        when "/token" # RFC-6749
+          return [200, { 'Content-Type' => "application/json" }, [JSON.pretty_generate({
+            access_token: ::Authz::JWT.new(sub: SecureRandom.uuid, iat: Time.now.to_i).to_jwt,
+            token_type: "Bearer",
+            issued_token_type: "urn:ietf:params:oauth:token-type:access_token",
+            expires_in: 3600,
+            refresh_token: SecureRandom.hex(32)
+          })]]
+        when "/oauth/revoke" # RFC-7009
+          return http_not_found
+        else
+          return http_not_found
+        end
       end
+      http_not_found
     end
-    http_not_found
-  end
 
-  def get_authorize(request)
-    template = <<~ERB
-      <!doctype html>
-      <html>
-        <head><title></title></head>
-        <body>
-          <h2>Authorize?</h2>
-          <form id="authorize-form" action="/oauth/authorize" method="post">
-            <input type="hidden" name="client_id" value="<%= request.params['client_id'] %>" />
-            <input type="hidden" name="scope" value="<%= request.params['scope'] %>" />
-            <input type="hidden" name="redirect_uri" value="<%= request.params['redirect_uri'] %>" />
-            <input type="hidden" name="response_mode" value="<%= request.params['response_mode'] %>" />
-            <input type="hidden" name="response_type" value="<%= request.params['response_type'] %>" />
-            <input type="hidden" name="state" value="<%= request.params['state'] %>" />
-            <input type="hidden" name="code_challenge_method" value="<%= request.params['code_challenge_method'] %>" />
-            <input type="hidden" name="code_challenge" value="<%= request.params['code_challenge'] %>" />
-            <input id="submit-button" type="submit" value="Submit" />
-          </form>
-        </body>
-      </html>
-    ERB
-    html = ERB.new(template, trim_mode: '-').result(binding)
-    [200, { 'Content-Type' => "text/html" }, [html]]
-  end
+    def get_authorize(request)
+      template = <<~ERB
+        <!doctype html>
+        <html>
+          <head><title></title></head>
+          <body>
+            <h2>Authorize?</h2>
+            <form id="authorize-form" action="/oauth/authorize" method="post">
+              <input type="hidden" name="client_id" value="<%= request.params['client_id'] %>" />
+              <input type="hidden" name="scope" value="<%= request.params['scope'] %>" />
+              <input type="hidden" name="redirect_uri" value="<%= request.params['redirect_uri'] %>" />
+              <input type="hidden" name="response_mode" value="<%= request.params['response_mode'] %>" />
+              <input type="hidden" name="response_type" value="<%= request.params['response_type'] %>" />
+              <input type="hidden" name="state" value="<%= request.params['state'] %>" />
+              <input type="hidden" name="code_challenge_method" value="<%= request.params['code_challenge_method'] %>" />
+              <input type="hidden" name="code_challenge" value="<%= request.params['code_challenge'] %>" />
+              <input id="submit-button" type="submit" value="Submit" />
+            </form>
+          </body>
+        </html>
+      ERB
+      html = ERB.new(template, trim_mode: '-').result(binding)
+      [200, { 'Content-Type' => "text/html" }, [html]]
+    end
+
+    def post_authorize(request)
+      params = request.params.slice('client_id', 'redirect_uri', 'response_type', 'response_mode', 'state', 'code_challenge_method', 'code_challenge', 'scope')
+      case params['response_type']
+      when 'code'
+        case params['response_mode']
+        when 'fragment'
+          return [302, { 'Location' => "#{params['redirect_uri']}#code=#{SecureRandom.uuid}&state=#{params['state']}" }, []]
+        when 'query'
+          return [302, { 'Location' => "#{params['redirect_uri']}?code=#{SecureRandom.uuid}&state=#{params['state']}" }, []]
+        else
+          # TODO:: form post
+        end
 
-  def post_authorize(request)
-    params = request.params.slice('client_id', 'redirect_uri', 'response_type', 'response_mode', 'state', 'code_challenge_method', 'code_challenge', 'scope')
-    case params['response_type']
-    when 'code'
-      case params['response_mode']
-      when 'fragment'
-        return [302, { 'Location' => "#{params['redirect_uri']}#code=#{SecureRandom.uuid}&state=#{params['state']}" }, []]
-      when 'query'
-        return [302, { 'Location' => "#{params['redirect_uri']}?code=#{SecureRandom.uuid}&state=#{params['state']}" }, []]
+      when 'token'
+        return http_not_found
       else
-        # TODO:: form post
+        return http_not_found
       end
-
-    when 'token'
-      return http_not_found
-    else
-      return http_not_found
     end
   end
 end
@@ -451,10 +451,10 @@ if __FILE__ == $0
     use Rack::Reloader
     map "/twirp" do
       # https://github.com/arthurnn/twirp-ruby/wiki/Service-Handlers
-      run ::Authx::Rpc::AbilityService.new(::Authz::Rpc::AbilityHandler.new)
+      run ::Authx::Rpc::AbilityService.new(::Authz::Rpc::Ability.new)
     end
     map "/oauth" do
-      run OAuthController.new
+      run ::Authz::OAuthController.new
     end
 
     map "/saml" do