Commit 894e270

mo khan <mo@mokhan.ca>
2025-04-02 19:32:22
refactor: combine cedar policies and add tests
1 parent 2a37de4
pkg/authz/cedar.go
@@ -1,25 +1,33 @@
 package authz
 
 import (
+	"net"
 	"net/http"
 
 	cedar "github.com/cedar-policy/cedar-go"
+	"gitlab.com/mokhax/spike/pkg/gid"
+	xlog "gitlab.com/mokhax/spike/pkg/log"
 	"gitlab.com/mokhax/spike/pkg/policies"
 )
 
 func WithCedar() Authorizer {
 	return AuthorizerFunc(func(r *http.Request) bool {
+		host, _, err := net.SplitHostPort(r.Host)
+		if err != nil {
+			xlog.WithFields(r, xlog.Fields{"error": err})
+			return false
+		}
 		subject, found := TokenFrom(r).Subject()
 		if !found {
-			subject = "*"
+			subject = "gid://User/*"
 		}
 
 		return policies.Allowed(cedar.Request{
-			Principal: cedar.NewEntityUID("Subject", cedar.String(subject)),
-			Action:    cedar.NewEntityUID("Action", cedar.String(r.Method)),
-			Resource:  cedar.NewEntityUID("Path", cedar.String(r.URL.Path)),
+			Principal: gid.NewEntityUID(subject),
+			Action:    cedar.NewEntityUID("HttpMethod", cedar.String(r.Method)),
+			Resource:  cedar.NewEntityUID("HttpPath", cedar.String(r.URL.Path)),
 			Context: cedar.NewRecord(cedar.RecordMap{
-				"Host": cedar.String(r.Host),
+				"host": cedar.String(host),
 			}),
 		})
 	})
pkg/gid/gid.go
@@ -0,0 +1,20 @@
+package gid
+
+import (
+	"net/url"
+	"strings"
+
+	"github.com/cedar-policy/cedar-go"
+)
+
+func NewEntityUID(globalID string) cedar.EntityUID {
+	url, err := url.Parse(globalID)
+	if err != nil {
+		return cedar.NewEntityUID("User", cedar.String(globalID))
+	}
+
+	return cedar.NewEntityUID(
+		cedar.EntityType(url.Hostname()),
+		cedar.String(strings.TrimPrefix(url.Path, "/")),
+	)
+}
pkg/policies/album.cedar
@@ -1,5 +1,5 @@
 permit (
-	principal == User::"alice",
-	action == Action::"view",
-	resource in Album::"jane_vacation"
+  principal == User::"alice",
+  action == Permission::"view",
+  resource in Album::"jane_vacation"
 );
pkg/policies/entities.json
@@ -26,7 +26,8 @@
     "uid": {
       "type": "User",
       "id": "1"
-    }
+    },
+    "parents": []
   },
   {
     "uid": {
@@ -301,5 +302,11 @@
         "id": "4"
       }
     ]
+  },
+  {
+    "uid": {
+      "type": "HttpPath",
+      "id": "/projects.json"
+    }
   }
 ]
pkg/policies/init.go
@@ -5,7 +5,6 @@ import (
 	_ "embed"
 	"fmt"
 	"io/fs"
-	"log"
 	"strings"
 
 	"github.com/cedar-policy/cedar-go"
@@ -57,7 +56,7 @@ func init() {
 	})
 
 	if err != nil {
-		log.Fatal(err)
+		xlog.Default.Printf("error: %v\n", err)
 	}
 }
 
pkg/policies/organization.cedar
@@ -1,5 +1,5 @@
 permit (
-	principal == User::"1",
-	action == Action::"read",
-	resource in Organization::"1"
+  principal == User::"1",
+  action == Permission::"read",
+  resource in Organization::"2"
 );
pkg/policies/policies_test.go
@@ -0,0 +1,59 @@
+package policies
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/cedar-policy/cedar-go"
+	"github.com/stretchr/testify/assert"
+	"gitlab.com/mokhax/spike/pkg/gid"
+)
+
+func build(f func(*cedar.Request)) *cedar.Request {
+	request := &cedar.Request{
+		Principal: gid.NewEntityUID("gid://User/1"),
+		Action:    cedar.NewEntityUID("HttpMethod", cedar.String("GET")),
+		Resource:  cedar.NewEntityUID("HttpPath", cedar.String("/projects.json")),
+		Context:   cedar.NewRecord(cedar.RecordMap{"host": cedar.String("api.example.com")}),
+	}
+	if f != nil {
+		f(request)
+	}
+	return request
+}
+
+func TestAllowed(t *testing.T) {
+	allowed := []*cedar.Request{
+		build(func(r *cedar.Request) {}),
+		build(func(r *cedar.Request) { r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("POST")) }),
+		build(func(r *cedar.Request) { r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("PUT")) }),
+		build(func(r *cedar.Request) { r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("PATCH")) }),
+		build(func(r *cedar.Request) { r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("DELETE")) }),
+		build(func(r *cedar.Request) { r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("HEAD")) }),
+	}
+
+	for _, tt := range allowed {
+		t.Run(fmt.Sprintf("allows: %v %v %v %v", tt.Principal, tt.Action, tt.Resource, tt.Context), func(t *testing.T) {
+			assert.True(t, Allowed(*tt))
+		})
+	}
+
+	denied := []*cedar.Request{
+		build(func(r *cedar.Request) {
+			r.Principal = gid.NewEntityUID("gid://User/*")
+			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("POST"))
+		}),
+		build(func(r *cedar.Request) {
+			r.Context = cedar.NewRecord(cedar.RecordMap{"host": cedar.String("unknown.example.com")})
+		}),
+		build(func(r *cedar.Request) {
+			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("TRACE"))
+		}),
+	}
+
+	for _, tt := range denied {
+		t.Run(fmt.Sprintf("denies: %v %v %v %v", tt.Principal, tt.Action, tt.Resource, tt.Context), func(t *testing.T) {
+			assert.False(t, Allowed(*tt))
+		})
+	}
+}
pkg/policies/rest.cedar
@@ -1,41 +1,12 @@
 permit (
-    principal == Subject::"*",
-    action == Action::"GET",
-    resource in Path::"/projects.json"
-);
-
-permit (
-    principal == Subject::"gid://User/1",
-    action == Action::"GET",
-    resource in Path::"/*.json"
-);
-
-permit (
-    principal == Subject::"gid://User/1",
-    action == Action::"POST",
-    resource in Path::"/*.json"
-);
-
-permit (
-    principal == Subject::"gid://User/1",
-    action == Action::"PUT",
-    resource in Path::"/*.json"
-);
-
-permit (
-    principal == Subject::"gid://User/1",
-    action == Action::"PATCH",
-    resource in Path::"/*.json"
-);
-
-permit (
-    principal == Subject::"gid://User/1",
-    action == Action::"DELETE",
-    resource in Path::"/*.json"
-);
-
-permit (
-    principal == Subject::"gid://User/1",
-    action == Action::"HEAD",
-    resource in Path::"/*.json"
-);
+  principal == User::"1",
+  action in [
+    HttpMethod::"GET",
+    HttpMethod::"POST",
+    HttpMethod::"PUT",
+    HttpMethod::"PATCH",
+    HttpMethod::"DELETE",
+    HttpMethod::"HEAD"
+  ],
+  resource
+) when { context.host == "api.example.com" };
pkg/rpc/ability_service.go
@@ -4,6 +4,7 @@ import (
 	context "context"
 
 	"github.com/cedar-policy/cedar-go"
+	"gitlab.com/mokhax/spike/pkg/gid"
 	"gitlab.com/mokhax/spike/pkg/policies"
 )
 
@@ -17,9 +18,9 @@ func NewAbilityService() *AbilityService {
 
 func (h *AbilityService) Allowed(ctx context.Context, req *AllowRequest) (*AllowReply, error) {
 	ok := policies.Allowed(cedar.Request{
-		Principal: cedar.NewEntityUID("User", cedar.String(req.Subject)),
-		Action:    cedar.NewEntityUID("Action", cedar.String(req.Permission)),
-		Resource:  cedar.NewEntityUID("Album", cedar.String(req.Resource)),
+		Principal: gid.NewEntityUID(req.Subject),
+		Action:    cedar.NewEntityUID("Permission", cedar.String(req.Permission)),
+		Resource:  gid.NewEntityUID(req.Resource),
 		Context:   cedar.NewRecord(cedar.RecordMap{}),
 	})
 	return &AllowReply{Result: ok}, nil
pkg/rpc/server_test.go
@@ -31,7 +31,7 @@ func TestServer(t *testing.T) {
 	defer connection.Close()
 	client := NewAbilityClient(connection)
 
-	t.Run("returns false", func(t *testing.T) {
+	t.Run("forbids", func(t *testing.T) {
 		reply, err := client.Allowed(t.Context(), &AllowRequest{
 			Subject:    "",
 			Permission: "",
@@ -41,17 +41,17 @@ func TestServer(t *testing.T) {
 		assert.False(t, reply.Result)
 	})
 
-	t.Run("returns true for alice:view:jane_vacation", func(t *testing.T) {
+	t.Run("allows alice:view:jane_vacation", func(t *testing.T) {
 		reply, err := client.Allowed(t.Context(), &AllowRequest{
-			Subject:    "alice",
+			Subject:    "gid://User/alice",
 			Permission: "view",
-			Resource:   "jane_vacation",
+			Resource:   "gid://Album/jane_vacation",
 		})
 		require.NoError(t, err)
 		assert.True(t, reply.Result)
 	})
 
-	t.Run("returns gid://User/1:read:gid://Organization/2", func(t *testing.T) {
+	t.Run("allows gid://User/1 read gid://Organization/2", func(t *testing.T) {
 		reply, err := client.Allowed(t.Context(), &AllowRequest{
 			Subject:    "gid://User/1",
 			Permission: "read",