Commit 9b267c4

mo khan <mo@mokhan.ca>
2025-03-12 22:15:20
feat: require a login before authorizing an auth grant
1 parent f62507b
Changed files (3)
bin/idp
@@ -89,8 +89,11 @@ module Authn
       end
     end
 
+    attr_reader :id
+
     def initialize(attributes)
       @attributes = attributes
+      @id = self[:id]
     end
 
     def [](attribute)
@@ -106,7 +109,7 @@ module Authn
     end
 
     def create_access_token
-      ::Authz::JWT.new(sub: self[:id], iat: Time.now.to_i)
+      ::Authz::JWT.new(sub: to_global_id.to_s, iat: Time.now.to_i)
     end
 
     def assertion_attributes_for(request)
@@ -118,6 +121,10 @@ module Authn
     def valid_password?(entered_password)
       ::BCrypt::Password.new(self[:password_digest]) == entered_password
     end
+
+    def to_global_id
+      ::GlobalID.create(self, app: "example").to_s
+    end
   end
 
   class OnDemandRegistry < Saml::Kit::DefaultRegistry
@@ -139,6 +146,7 @@ module Authn
       when Rack::GET
         case request.path
         when '/sessions/new'
+          request.session.delete(:user_id)
           return get_login(request)
         end
       when Rack::POST
@@ -376,29 +384,38 @@ module Authz
 
       case request.request_method
       when Rack::GET
-        case request.path_info
-        when "/authorize" # RFC-6749
+        case request.path
+        when "/oauth/authorize/continue"
+          if current_user?(request)
+            return get_authorize(request.session[:oauth_params])
+          end
+        when "/oauth/authorize" # RFC-6749
+          oauth_params = request.params.slice('client_id', 'scope', 'redirect_uri', 'response_mode', 'response_type', 'state', 'code_challenge_method', 'code_challenge')
           if current_user?(request)
-            return get_authorize(request)
+            return get_authorize(oauth_params)
           else
-            http_redirect_to("/saml/")
+            request.session[:oauth_params] = oauth_params
+            return http_redirect_to("/sessions/new?redirect_back=/oauth/authorize/continue")
           end
         else
           return http_not_found
         end
       when Rack::POST
-        case request.path_info
-        when "/authorize" # RFC-6749
+        case request.path
+        when "/oauth/authorize" # RFC-6749
           return post_authorize(request)
-        when "/token" # RFC-6749
+        when "/oauth/token" # RFC-6749
+          # TODO:: Look up authorization grant by (code, saml_assertion)
+          user = Authn::User.new(id: SecureRandom.uuid)
           return [200, { 'Content-Type' => "application/json" }, [JSON.pretty_generate({
-            access_token: ::Authz::JWT.new(sub: SecureRandom.uuid, iat: Time.now.to_i).to_jwt,
+            access_token: user.create_access_token.to_jwt,
             token_type: "Bearer",
             issued_token_type: "urn:ietf:params:oauth:token-type:access_token",
             expires_in: 3600,
             refresh_token: SecureRandom.hex(32)
           })]]
         when "/oauth/revoke" # RFC-7009
+          # TODO:: Revoke the JWT token and make it ineligible for usage
           return http_not_found
         else
           return http_not_found
@@ -407,7 +424,7 @@ module Authz
       http_not_found
     end
 
-    def get_authorize(request)
+    def get_authorize(oauth_params)
       template = <<~ERB
         <!DOCTYPE html>
         <html>
@@ -415,14 +432,14 @@ module Authz
           <body>
             <h2>Authorize?</h2>
             <form id="authorize-form" action="/oauth/authorize" method="post">
-              <input type="hidden" name="client_id" value="<%= request.params['client_id'] %>" />
-              <input type="hidden" name="scope" value="<%= request.params['scope'] %>" />
-              <input type="hidden" name="redirect_uri" value="<%= request.params['redirect_uri'] %>" />
-              <input type="hidden" name="response_mode" value="<%= request.params['response_mode'] %>" />
-              <input type="hidden" name="response_type" value="<%= request.params['response_type'] %>" />
-              <input type="hidden" name="state" value="<%= request.params['state'] %>" />
-              <input type="hidden" name="code_challenge_method" value="<%= request.params['code_challenge_method'] %>" />
-              <input type="hidden" name="code_challenge" value="<%= request.params['code_challenge'] %>" />
+              <input type="hidden" name="client_id" value="<%= oauth_params['client_id'] %>" />
+              <input type="hidden" name="scope" value="<%= oauth_params['scope'] %>" />
+              <input type="hidden" name="redirect_uri" value="<%= oauth_params['redirect_uri'] %>" />
+              <input type="hidden" name="response_mode" value="<%= oauth_params['response_mode'] %>" />
+              <input type="hidden" name="response_type" value="<%= oauth_params['response_type'] %>" />
+              <input type="hidden" name="state" value="<%= oauth_params['state'] %>" />
+              <input type="hidden" name="code_challenge_method" value="<%= oauth_params['code_challenge_method'] %>" />
+              <input type="hidden" name="code_challenge" value="<%= oauth_params['code_challenge'] %>" />
               <input id="submit-button" type="submit" value="Submit" />
             </form>
           </body>
bin/ui
@@ -20,6 +20,8 @@ $port = ENV.fetch("PORT", 8283).to_i
 $host = ENV.fetch("HOST", "localhost:#{$port}")
 $idp_host = ENV.fetch("IDP_HOST", "localhost:8282")
 
+Net::Hippie.logger = Logger.new($stdout, level: :debug)
+
 class OnDemandRegistry < Saml::Kit::DefaultRegistry
   def metadata_for(entity_id)
     found = super(entity_id)
@@ -90,16 +92,25 @@ class UI
   end
 
   def oauth_callback(request)
-    response = Net::Hippie.default_client.post(
-      "http://#{$idp_host}/oauth/token",
-      headers: { 'Authorization' => Net::Hippie.basic_auth('client_id', 'secret') },
-      body: {
-        grant_type: "authorization_code",
-        code: request.params['code'],
-        code_verifier: "not_implemented"
-      }
-    )
-    [200, { "Content-Type" => "application/json" }, [JSON.pretty_generate(request.params.merge(JSON.parse(response.body)))]]
+    client = Net::Hippie::Client.new
+    response = client.with_retry do |x|
+      client.post(
+        "http://#{$idp_host}/oauth/token",
+        headers: { 'Authorization' => Net::Hippie.basic_auth('client_id', 'secret') },
+        body: {
+          grant_type: "authorization_code",
+          code: request.params['code'],
+          code_verifier: "not_implemented"
+        }
+      )
+    end
+    if response.code.to_i == 200
+      [200, { "Content-Type" => "application/json" }, [JSON.pretty_generate(
+        request.params.merge(JSON.parse(response.body))
+      )]]
+    else
+      [response.code, response.header, [response.body]]
+    end
   end
 
   def saml_post_to_idp(request)
test/e2e_test.go
@@ -43,6 +43,7 @@ func TestAuthx(t *testing.T) {
 		}
 
 		t.Run("GET http://ui.example.com:8080/saml/new", func(t *testing.T) {
+			assert.NoError(t, page.Context().ClearCookies())
 			x.Must(page.Goto("http://ui.example.com:8080/saml/new"))
 			action := x.Must(page.Locator("#idp-form").GetAttribute("action"))
 			assert.Equal(t, "http://idp.example.com:8080/saml/new", action)
@@ -61,8 +62,15 @@ func TestAuthx(t *testing.T) {
 
 	t.Run("OIDC", func(t *testing.T) {
 		t.Run("GET http://ui.example.com:8080/oidc/new", func(t *testing.T) {
+			assert.NoError(t, page.Context().ClearCookies())
 			x.Must(page.Goto("http://ui.example.com:8080/oidc/new"))
-			assert.Contains(t, page.URL(), "http://idp.example.com:8080/oauth/authorize")
+
+			assert.Contains(t, page.URL(), "http://idp.example.com:8080/sessions/new")
+			page.Locator("#username").Fill("username1")
+			page.Locator("#password").Fill("password1")
+			assert.NoError(t, page.Locator("#login-button").Click())
+
+			assert.Contains(t, page.URL(), "http://idp.example.com:8080/oauth/authorize/continue")
 			assert.NoError(t, page.Locator("#submit-button").Click())
 
 			assert.Contains(t, page.URL(), "http://ui.example.com:8080/oauth/callback")
@@ -177,7 +185,13 @@ func TestAuthx(t *testing.T) {
 				oauth2.SetAuthURLParam("response_type", "code"),
 				oauth2.SetAuthURLParam("response_mode", "fragment"),
 			)
+			assert.NoError(t, page.Context().ClearCookies())
 			x.Must(page.Goto(authURL))
+
+			page.Locator("#username").Fill("username1")
+			page.Locator("#password").Fill("password1")
+			assert.NoError(t, page.Locator("#login-button").Click())
+
 			assert.NoError(t, page.Locator("#submit-button").Click())
 
 			uri := x.Must(url.Parse(page.URL()))