main
  1package policies
  2
  3import (
  4	"fmt"
  5	"testing"
  6
  7	"github.com/cedar-policy/cedar-go"
  8	"github.com/stretchr/testify/assert"
  9	"gitlab.com/mokhax/spike/pkg/gid"
 10)
 11
 12func build(f func(*cedar.Request)) *cedar.Request {
 13	request := &cedar.Request{
 14		Principal: gid.NewEntityUID("gid://example/User/1"),
 15		Action:    cedar.NewEntityUID("HttpMethod", "GET"),
 16		Resource:  cedar.NewEntityUID("HttpPath", "/"),
 17		Context: cedar.NewRecord(cedar.RecordMap{
 18			"host": cedar.String("idp.example.com"),
 19		}),
 20	}
 21	f(request)
 22	return request
 23}
 24
 25func TestAllowed(t *testing.T) {
 26	allowed := []*cedar.Request{
 27		build(func(r *cedar.Request) {}),
 28		build(func(r *cedar.Request) {
 29			r.Principal = gid.NewEntityUID("gid://example/User/1")
 30			r.Action = cedar.NewEntityUID("HttpMethod", "POST")
 31		}),
 32		build(func(r *cedar.Request) {
 33			r.Principal = gid.NewEntityUID("gid://example/User/1")
 34			r.Action = cedar.NewEntityUID("HttpMethod", "PUT")
 35		}),
 36		build(func(r *cedar.Request) {
 37			r.Principal = gid.NewEntityUID("gid://example/User/1")
 38			r.Action = cedar.NewEntityUID("HttpMethod", "PATCH")
 39		}),
 40		build(func(r *cedar.Request) {
 41			r.Principal = gid.NewEntityUID("gid://example/User/1")
 42			r.Action = cedar.NewEntityUID("HttpMethod", "DELETE")
 43		}),
 44		build(func(r *cedar.Request) {
 45			r.Principal = gid.NewEntityUID("gid://example/User/1")
 46			r.Action = cedar.NewEntityUID("HttpMethod", "HEAD")
 47		}),
 48		build(func(r *cedar.Request) {
 49			r.Principal = gid.NewEntityUID("gid://example/User/1")
 50			r.Resource = cedar.NewEntityUID("HttpPath", "/organizations.json")
 51			r.Context = cedar.NewRecord(cedar.RecordMap{
 52				"host": cedar.String("api.example.com"),
 53			})
 54		}),
 55		build(func(r *cedar.Request) {
 56			r.Principal = gid.NewEntityUID("gid://example/User/1")
 57			r.Resource = cedar.NewEntityUID("HttpPath", "/groups.json")
 58			r.Context = cedar.NewRecord(cedar.RecordMap{
 59				"host": cedar.String("api.example.com"),
 60			})
 61		}),
 62		build(func(r *cedar.Request) {
 63			r.Principal = gid.NewEntityUID("gid://example/User/1")
 64			r.Resource = cedar.NewEntityUID("HttpPath", "/.well-known/openid-configuration")
 65			r.Context = cedar.NewRecord(cedar.RecordMap{
 66				"host": cedar.String("idp.example.com"),
 67			})
 68		}),
 69		build(func(r *cedar.Request) {
 70			r.Principal = gid.NewEntityUID("gid://example/User/1")
 71			r.Resource = cedar.NewEntityUID("HttpPath", "/.well-known/oauth-authorization-server")
 72			r.Context = cedar.NewRecord(cedar.RecordMap{
 73				"host": cedar.String("idp.example.com"),
 74			})
 75		}),
 76		build(func(r *cedar.Request) {
 77			r.Principal = gid.NewEntityUID("gid://example/User/*")
 78			r.Resource = cedar.NewEntityUID("HttpPath", "/.well-known/openid-configuration")
 79			r.Context = cedar.NewRecord(cedar.RecordMap{
 80				"host": cedar.String("idp.example.com"),
 81			})
 82		}),
 83		build(func(r *cedar.Request) {
 84			r.Principal = gid.NewEntityUID("gid://example/User/*")
 85			r.Resource = cedar.NewEntityUID("HttpPath", "/.well-known/oauth-authorization-server")
 86			r.Context = cedar.NewRecord(cedar.RecordMap{
 87				"host": cedar.String("idp.example.com"),
 88			})
 89		}),
 90		build(func(r *cedar.Request) {
 91			r.Principal = gid.NewEntityUID("gid://example/User/1")
 92			r.Action = cedar.NewEntityUID("HttpMethod", "POST")
 93			r.Resource = cedar.NewEntityUID("HttpPath", "/twirp/authx.rpc.Ability/Allowed")
 94			r.Context = cedar.NewRecord(cedar.RecordMap{
 95				"host": cedar.String("idp.example.com"),
 96			})
 97		}),
 98		build(func(r *cedar.Request) {
 99			r.Principal = gid.NewEntityUID("gid://example/User/1")
100			r.Action = cedar.NewEntityUID("HttpMethod", "GET")
101			r.Resource = cedar.NewEntityUID("HttpPath", "/index.html")
102			r.Context = cedar.NewRecord(cedar.RecordMap{
103				"host": cedar.String("ui.example.com"),
104			})
105		}),
106	}
107
108	for _, tt := range allowed {
109		t.Run(fmt.Sprintf("allows: %v/%v %v %v%v", tt.Principal.Type, tt.Principal.ID, tt.Action.ID, tt.Context.Map()["host"], tt.Resource.ID), func(t *testing.T) {
110			assert.True(t, Allowed(*tt))
111		})
112	}
113
114	denied := []*cedar.Request{
115		build(func(r *cedar.Request) {
116			r.Principal = gid.ZeroEntityUID()
117			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("POST"))
118		}),
119		build(func(r *cedar.Request) {
120			r.Principal = gid.ZeroEntityUID()
121			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("PUT"))
122		}),
123		build(func(r *cedar.Request) {
124			r.Principal = gid.ZeroEntityUID()
125			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("PATCH"))
126		}),
127		build(func(r *cedar.Request) {
128			r.Principal = gid.ZeroEntityUID()
129			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("DELETE"))
130		}),
131		build(func(r *cedar.Request) {
132			r.Principal = gid.ZeroEntityUID()
133			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("HEAD"))
134		}),
135		build(func(r *cedar.Request) {
136			r.Principal = gid.ZeroEntityUID()
137			r.Action = cedar.NewEntityUID("HttpMethod", cedar.String("TRACE"))
138		}),
139	}
140
141	for _, tt := range denied {
142		t.Run(fmt.Sprintf("denies: %v/%v %v %v%v", tt.Principal.Type, tt.Principal.ID, tt.Action.ID, tt.Context.Map()["host"], tt.Resource.ID), func(t *testing.T) {
143			assert.False(t, Allowed(*tt))
144		})
145	}
146}