Commit ead9021

mo khan <mo@mokhan.ca>
2025-03-26 17:26:06
feat: exchange a saml response for a token
1 parent 3135a4d
Changed files (2)
bin/ui
@@ -69,18 +69,32 @@ module OAuth
       ].join("?")
     end
 
-    def exchange(grant_type:, code:, code_verifier: "not_implemented")
+    def exchange(grant_type, params = {})
       with_http do |client|
-        client.post(self[:token_endpoint], body: {
-          grant_type: grant_type,
-          code: code,
-          code_verifier: code_verifier,
-        })
+        client.post(self[:token_endpoint], body: body_for(grant_type, params))
       end
     end
 
     private
 
+    def body_for(grant_type, params)
+      case grant_type
+      when "authorization_code"
+        {
+          grant_type: grant_type,
+          code: params.fetch(:code),
+          code_verifier: params.fetch(:code_verifier, "not_implemented"),
+        }
+      when "urn:ietf:params:oauth:grant-type:saml2-bearer"
+        {
+          grant_type: grant_type,
+          assertion: params.fetch(:assertion),
+        }
+      else
+        raise NotImplementedError.new(grant_type)
+      end
+    end
+
     def to_query(params = {})
       params.map do |(key, value)|
         [key, value].join("=")
@@ -221,7 +235,7 @@ class UI
   end
 
   def oauth_callback(request)
-    response = oauth_client.exchange(grant_type: "authorization_code", code: request.params['code'])
+    response = oauth_client.exchange("authorization_code", code: request.params['code'])
     if response.code == "200"
       tokens = JSON.parse(response.body, symbolize_names: true)
       request.session[:access_token] = tokens[:access_token]
@@ -379,21 +393,38 @@ class UI
     saml_response = saml_binding.deserialize(request.params)
     raise saml_response.errors unless saml_response.valid?
 
-    template = <<~ERB
-      <!doctype html>
-      <html>
-        <head>
-          <title></title>
-        </head>
-        <body style="background-color: pink;">
-          <h2>Received SAML Response</h2>
-          <textarea readonly="readonly" disabled="disabled" cols=220 rows=40><%=- saml_response.to_xml(pretty: true) -%></textarea>
-          <pre id="saml-response"><%= request.params["SAMLResponse"] %></pre>
-        </body>
-      </html>
-    ERB
-    html = ERB.new(template, trim_mode: '-').result(binding)
-    [200, { 'Content-Type' => "text/html" }, [html]]
+    response = oauth_client.exchange(
+      "urn:ietf:params:oauth:grant-type:saml2-bearer",
+      assertion: request.params["SAMLResponse"],
+    )
+    if response.code == "200"
+      tokens = JSON.parse(response.body, symbolize_names: true)
+      request.session[:access_token] = tokens[:access_token]
+      request.session[:id_token] = tokens[:id_token]
+      request.session[:refresh_token] = tokens[:access_token]
+
+      template = <<~ERB
+        <!doctype html>
+        <html>
+          <head>
+            <title></title>
+          </head>
+          <body style="background-color: pink;">
+            <h2>Received SAML Response</h2>
+            <textarea readonly="readonly" disabled="disabled" cols=220 rows=40><%=- saml_response.to_xml(pretty: true) -%></textarea>
+            <pre id="raw-saml-response"><%= request.params["SAMLResponse"] %></pre>
+            <pre id="access-token"><%= JSON.pretty_generate(request.session[:access_token]) %></pre>
+
+            <a href="/index.html">Home</a>
+            <a href="/groups.html">Groups</a>
+          </body>
+        </html>
+      ERB
+      html = ERB.new(template, trim_mode: '-').result(binding)
+      [200, { 'Content-Type' => "text/html" }, [html]]
+    else
+      [response.code, response.header, [response.body]]
+    end
   end
 end
 
test/e2e_test.go
@@ -60,8 +60,23 @@ func TestAuthx(t *testing.T) {
 			assert.NoError(t, page.Locator("#submit-button").Click())
 			assert.Contains(t, x.Must(page.Content()), "Received SAML Response")
 
+			t.Run("generates a usable access token", func(t *testing.T) {
+				rawToken := x.Must(page.Locator("#access-token").TextContent())
+				accessToken := strings.Replace(rawToken, "\"", "", -1)
+				assert.NotEmpty(t, accessToken)
+
+				t.Run("GET http://api.example.com:8080/projects.json", func(t *testing.T) {
+					request := x.Must(http.NewRequestWithContext(t.Context(), "GET", "http://api.example.com:8080/projects.json", nil))
+					request.Header.Add("Authorization", "Bearer "+accessToken)
+					response := x.Must(client.Do(request))
+					require.Equal(t, http.StatusOK, response.StatusCode)
+					projects := x.Must(serde.FromJSON[[]map[string]string](response.Body))
+					assert.NotNil(t, projects)
+				})
+			})
+
 			t.Run("exchange SAML assertion for access token", func(t *testing.T) {
-				samlAssertion := x.Must(page.Locator("#saml-response").TextContent())
+				samlAssertion := x.Must(page.Locator("#raw-saml-response").TextContent())
 				io := bytes.NewBuffer(nil)
 				assert.NoError(t, serde.ToJSON(io, map[string]string{
 					"assertion":  samlAssertion,