Commit c9f394f

mo khan <mo@mokhan.ca>
2025-03-13 22:43:47
refactor: extract authz interface to test out different PaC libraries
1 parent b55a661
bin/api
@@ -27,20 +27,10 @@ $scheme = ENV.fetch("SCHEME", "http")
 $port = ENV.fetch("PORT", 8284).to_i
 $host = ENV.fetch("HOST", "localhost:#{$port}")
 
-class Organization
-  def initialize(attributes = {})
-    @attributes = attributes
-  end
-
-  def id
-    @attributes[:id]
-  end
-end
-
-class Project
+class Entity
   class << self
     def all
-      @projects ||= []
+      @items ||= []
     end
 
     def create!(attributes)
@@ -54,47 +44,32 @@ class Project
     @attributes = attributes
   end
 
-  def to_h
-    @attributes
+  def id
+    self[:id]
   end
-end
 
-class API
-  attr_reader :rpc
+  def [](attribute)
+    @attributes.fetch(attribute)
+  end
 
-  def initialize
-    @rpc = ::Authx::Rpc::AbilityClient.new("http://idp.example.com:8080/twirp")
+  def to_h
+    @attributes
   end
+end
 
-  def call(env)
-    request = Rack::Request.new(env)
-    path = env['PATH_INFO']
-    case env['REQUEST_METHOD']
-    when 'GET'
-      case path
-      when '/projects.json'
-        return json_ok(Project.all.map(&:to_h))
-      else
-        return json_not_found
-      end
-    when 'POST'
-      case path
-      when "/projects"
-        if authorized?(request, :create_project)
-          return json_created(Project.create!(JSON.parse(request.body.read, symbolize_names: true)))
-        else
-          return json_unauthorized(:create_project)
-        end
-      else
-        return json_not_found
-      end
+class Organization < Entity
+  class << self
+    def default
+      @default ||= create!(id: SecureRandom.uuid)
     end
-    json_not_found
   end
+end
 
-  private
+class Project < Entity
+end
 
-  def authorized?(request, permission, resource = Organization.new(id: 1))
+module HTTPHelpers
+  def authorized?(request, permission, resource)
     authorization = Rack::Auth::AbstractRequest.new(request.env)
     return false unless authorization.provided?
 
@@ -136,6 +111,41 @@ class API
   end
 end
 
+class API
+  include HTTPHelpers
+
+  attr_reader :rpc
+
+  def initialize
+    @rpc = ::Authx::Rpc::AbilityClient.new("http://idp.example.com:8080/twirp")
+  end
+
+  def call(env)
+    request = Rack::Request.new(env)
+    case request.request_method
+    when Rack::GET
+      case request.path
+      when "/organizations", "/organizations.json"
+        return json_ok(Organization.all.map(&:to_h))
+      when "/projects", "/projects.json"
+        return json_ok(Project.all.map(&:to_h))
+      end
+    when Rack::POST
+      case request.path
+      when "/projects", "/projects.json"
+        if authorized?(request, :create_project, Organization.default)
+          return json_created(Project.create!(JSON.parse(request.body.read, symbolize_names: true)))
+        else
+          return json_unauthorized(:create_project)
+        end
+      end
+    end
+    json_not_found
+  end
+
+  private
+end
+
 if __FILE__ == $0
   app = Rack::Builder.new do
     use Rack::CommonLogger
cmd/gtwy/main.go
@@ -1,23 +1,49 @@
 package main
 
 import (
+	"fmt"
 	"log"
+	"net"
 	"net/http"
 
+	"github.com/casbin/casbin/v2"
 	"github.com/xlgmokha/x/pkg/env"
+	"github.com/xlgmokha/x/pkg/x"
+	"gitlab.com/mokhax/spike/pkg/authz"
 	"gitlab.com/mokhax/spike/pkg/cfg"
 	"gitlab.com/mokhax/spike/pkg/prxy"
 	"gitlab.com/mokhax/spike/pkg/srv"
 )
 
+func WithCasbin() authz.Authorizer {
+	enforcer := x.Must(casbin.NewEnforcer("model.conf", "policy.csv"))
+
+	return authz.AuthorizerFunc(func(r *http.Request) bool {
+		host, _, err := net.SplitHostPort(r.Host)
+		if err != nil {
+			return false
+		}
+
+		subject := "71cbc18e-bd41-4229-9ad2-749546a2a4a7" // TODO:: unpack sub claim in JWT
+		ok, err := enforcer.Enforce(subject, host, r.Method, r.URL.Path)
+		if err != nil {
+			fmt.Printf("%v\n", err)
+			return false
+		}
+
+		fmt.Printf("%v: %v %v %v\n", ok, r.Method, host, r.URL.Path)
+		return ok
+	})
+}
+
 func WithRoutes() cfg.Option {
 	return func(c *cfg.Config) {
 		mux := http.NewServeMux()
-		mux.Handle("/", prxy.New(map[string]string{
-			"idp.example.com": "localhost:8282",
-			"ui.example.com":  "localhost:8283",
-			"api.example.com": "localhost:8284",
-		}))
+		mux.Handle("/", authz.HTTP(WithCasbin(), prxy.New(map[string]string{
+			"idp.example.com": "http://localhost:8282",
+			"ui.example.com":  "http://localhost:8283",
+			"api.example.com": "http://localhost:8284",
+		})))
 
 		cfg.WithMux(mux)(c)
 	}
pkg/authz/authz.go
@@ -0,0 +1,23 @@
+package authz
+
+import "net/http"
+
+type Authorizer interface {
+	Authorize(*http.Request) bool
+}
+
+type AuthorizerFunc func(*http.Request) bool
+
+func (f AuthorizerFunc) Authorize(r *http.Request) bool {
+	return f(r)
+}
+
+func HTTP(authorizer Authorizer, h http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if authorizer.Authorize(r) {
+			h.ServeHTTP(w, r)
+		} else {
+			w.WriteHeader(http.StatusForbidden)
+		}
+	})
+}
pkg/prxy/prxy.go
@@ -3,33 +3,32 @@ package prxy
 import (
 	"fmt"
 	"log"
+	"net"
 	"net/http"
 	"net/http/httputil"
-	"strings"
+	"net/url"
 
-	"github.com/casbin/casbin/v2"
 	"github.com/xlgmokha/x/pkg/x"
 )
 
 func New(routes map[string]string) http.Handler {
-	authz := x.Must(casbin.NewEnforcer("model.conf", "policy.csv"))
+	mapped := map[string]*url.URL{}
+	for source, destination := range routes {
+		mapped[source] = x.Must(url.Parse(destination))
+	}
 
 	return &httputil.ReverseProxy{
 		Director: func(r *http.Request) {
-			segments := strings.SplitN(r.Host, ":", 2)
-			host := segments[0]
-			destinationHost := routes[host]
-
-			log.Printf("%v (from: %v to: %v)\n", r.URL, host, destinationHost)
-
-			subject := "71cbc18e-bd41-4229-9ad2-749546a2a4a7" // TODO:: unpack sub claim in JWT
-			if x.Must(authz.Enforce(subject, host, r.Method, r.URL.Path)) {
-				r.URL.Scheme = "http" // TODO:: use TLS
-				r.Host = destinationHost
-				r.URL.Host = destinationHost
-			} else {
-				log.Println("UNAUTHORIZED") // TODO:: Return forbidden, unauthorized or not found status code
+			host, _, err := net.SplitHostPort(r.Host)
+			if err != nil {
+				fmt.Printf("%v\n", err)
+				return
 			}
+
+			destination := mapped[host]
+			r.URL.Scheme = destination.Scheme
+			r.Host = destination.Host
+			r.URL.Host = destination.Host
 		},
 		Transport:     http.DefaultTransport,
 		FlushInterval: -1,
pkg/prxy/prxy_test.go
@@ -0,0 +1,49 @@
+package prxy
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/xlgmokha/x/pkg/x"
+	"gitlab.com/mokhax/spike/pkg/test"
+)
+
+func TestProxy(t *testing.T) {
+	t.Run("http://idp.test", func(t *testing.T) {
+		var lastIdPRequest *http.Request
+		var lastUiRequest *http.Request
+
+		idp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			lastIdPRequest = r
+			w.WriteHeader(http.StatusOK)
+		}))
+		defer idp.Close()
+
+		ui := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			lastUiRequest = r
+			w.WriteHeader(http.StatusTeapot)
+		}))
+		defer ui.Close()
+
+		subject := New(map[string]string{
+			"idp.test": idp.URL,
+			"ui.test":  ui.URL,
+		})
+
+		r, w := test.RequestResponse("GET", "http://idp.test:8080/saml/new")
+
+		subject.ServeHTTP(w, r)
+
+		url := x.Must(url.Parse(idp.URL))
+
+		assert.Nil(t, lastUiRequest)
+		assert.Equal(t, http.StatusOK, w.Code)
+
+		require.NotNil(t, lastIdPRequest)
+		assert.Equal(t, url.Host, lastIdPRequest.Host)
+	})
+}
pkg/test/test.go
@@ -0,0 +1,49 @@
+package test
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+)
+
+type RequestOption func(*http.Request) *http.Request
+
+func Request(method, target string, options ...RequestOption) *http.Request {
+	request := httptest.NewRequest(method, target, nil)
+	for _, option := range options {
+		request = option(request)
+	}
+	return request
+}
+
+func RequestResponse(method, target string, options ...RequestOption) (*http.Request, *httptest.ResponseRecorder) {
+	return Request(method, target, options...), httptest.NewRecorder()
+}
+
+func WithRequestHeader(key, value string) RequestOption {
+	return func(r *http.Request) *http.Request {
+		r.Header.Set(key, value)
+		return r
+	}
+}
+
+func WithRequestBody(body io.ReadCloser) RequestOption {
+	return func(r *http.Request) *http.Request {
+		r.Body = body
+		return r
+	}
+}
+
+func WithContext(ctx context.Context) RequestOption {
+	return func(r *http.Request) *http.Request {
+		return r.WithContext(ctx)
+	}
+}
+
+func WithCookie(cookie *http.Cookie) RequestOption {
+	return func(r *http.Request) *http.Request {
+		r.AddCookie(cookie)
+		return r
+	}
+}
test/e2e_test.go
@@ -80,19 +80,41 @@ func TestAuthx(t *testing.T) {
 			assert.Equal(t, "Bearer", item.TokenType)
 			assert.NotEmpty(t, item.RefreshToken)
 
-			response := x.Must(http.Get("http://api.example.com:8080/projects.json"))
-			assert.Equal(t, http.StatusOK, response.StatusCode)
-			projects := x.Must(serde.FromJSON[[]map[string]string](response.Body))
-			assert.NotNil(t, projects)
-
-			io := bytes.NewBuffer(nil)
-			assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example"}))
-			request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects", io))
-			request.Header.Add("Authorization", "Bearer "+item.AccessToken)
-			response = x.Must(client.Do(request))
-			assert.Equal(t, http.StatusCreated, response.StatusCode)
-			project := x.Must(serde.FromJSON[map[string]string](response.Body))
-			assert.Equal(t, "example", project["name"])
+			t.Run("lists all the organizations", func(t *testing.T) {
+				response := x.Must(http.Get("http://api.example.com:8080/organizations.json"))
+				assert.Equal(t, http.StatusOK, response.StatusCode)
+				organizations := x.Must(serde.FromJSON[[]map[string]string](response.Body))
+				assert.NotNil(t, organizations)
+			})
+
+			t.Run("lists all the projects", func(t *testing.T) {
+				response := x.Must(http.Get("http://api.example.com:8080/projects.json"))
+				assert.Equal(t, http.StatusOK, response.StatusCode)
+				projects := x.Must(serde.FromJSON[[]map[string]string](response.Body))
+				assert.NotNil(t, projects)
+			})
+
+			t.Run("creates a new project", func(t *testing.T) {
+				io := bytes.NewBuffer(nil)
+				assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example"}))
+				request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects", io))
+				request.Header.Add("Authorization", "Bearer "+item.AccessToken)
+				response := x.Must(client.Do(request))
+				assert.Equal(t, http.StatusCreated, response.StatusCode)
+				project := x.Must(serde.FromJSON[map[string]string](response.Body))
+				assert.Equal(t, "example", project["name"])
+			})
+
+			t.Run("creates another projects", func(t *testing.T) {
+				io := bytes.NewBuffer(nil)
+				assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example2"}))
+				request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects.json", io))
+				request.Header.Add("Authorization", "Bearer "+item.AccessToken)
+				response := x.Must(client.Do(request))
+				assert.Equal(t, http.StatusCreated, response.StatusCode)
+				project := x.Must(serde.FromJSON[map[string]string](response.Body))
+				assert.Equal(t, "example2", project["name"])
+			})
 		})
 	})
 
magefile.go
@@ -6,7 +6,6 @@ package main
 import (
 	"context"
 	"path/filepath"
-	"runtime"
 
 	"github.com/magefile/mage/mg"
 	"github.com/magefile/mage/sh"
@@ -56,16 +55,6 @@ func Api() error {
 	return sh.RunWithV(env, "ruby", "./bin/api")
 }
 
-// Open a web browser to the login page
-func Browser() error {
-	url := "http://localhost:8080/ui/sessions/new"
-	if runtime.GOOS == "linux" {
-		return sh.RunV("xdg-open", url)
-	} else {
-		return sh.RunV("open", url)
-	}
-}
-
 // Generate gRPC from protocal buffers
 func Protos() error {
 	outDir := "lib/authx/rpc"
@@ -94,5 +83,5 @@ func Test(ctx context.Context) error {
 	mg.CtxDeps(ctx, func() error {
 		return sh.RunV("go", "clean", "-testcache")
 	})
-	return sh.RunV("go", "test", "-v", "./...")
+	return sh.RunV("go", "test", "-shuffle=on", "-v", "./...")
 }