Commit c9f394f
Changed files (8)
bin
cmd
gtwy
pkg
test
bin/api
@@ -27,20 +27,10 @@ $scheme = ENV.fetch("SCHEME", "http")
$port = ENV.fetch("PORT", 8284).to_i
$host = ENV.fetch("HOST", "localhost:#{$port}")
-class Organization
- def initialize(attributes = {})
- @attributes = attributes
- end
-
- def id
- @attributes[:id]
- end
-end
-
-class Project
+class Entity
class << self
def all
- @projects ||= []
+ @items ||= []
end
def create!(attributes)
@@ -54,47 +44,32 @@ class Project
@attributes = attributes
end
- def to_h
- @attributes
+ def id
+ self[:id]
end
-end
-class API
- attr_reader :rpc
+ def [](attribute)
+ @attributes.fetch(attribute)
+ end
- def initialize
- @rpc = ::Authx::Rpc::AbilityClient.new("http://idp.example.com:8080/twirp")
+ def to_h
+ @attributes
end
+end
- def call(env)
- request = Rack::Request.new(env)
- path = env['PATH_INFO']
- case env['REQUEST_METHOD']
- when 'GET'
- case path
- when '/projects.json'
- return json_ok(Project.all.map(&:to_h))
- else
- return json_not_found
- end
- when 'POST'
- case path
- when "/projects"
- if authorized?(request, :create_project)
- return json_created(Project.create!(JSON.parse(request.body.read, symbolize_names: true)))
- else
- return json_unauthorized(:create_project)
- end
- else
- return json_not_found
- end
+class Organization < Entity
+ class << self
+ def default
+ @default ||= create!(id: SecureRandom.uuid)
end
- json_not_found
end
+end
- private
+class Project < Entity
+end
- def authorized?(request, permission, resource = Organization.new(id: 1))
+module HTTPHelpers
+ def authorized?(request, permission, resource)
authorization = Rack::Auth::AbstractRequest.new(request.env)
return false unless authorization.provided?
@@ -136,6 +111,41 @@ class API
end
end
+class API
+ include HTTPHelpers
+
+ attr_reader :rpc
+
+ def initialize
+ @rpc = ::Authx::Rpc::AbilityClient.new("http://idp.example.com:8080/twirp")
+ end
+
+ def call(env)
+ request = Rack::Request.new(env)
+ case request.request_method
+ when Rack::GET
+ case request.path
+ when "/organizations", "/organizations.json"
+ return json_ok(Organization.all.map(&:to_h))
+ when "/projects", "/projects.json"
+ return json_ok(Project.all.map(&:to_h))
+ end
+ when Rack::POST
+ case request.path
+ when "/projects", "/projects.json"
+ if authorized?(request, :create_project, Organization.default)
+ return json_created(Project.create!(JSON.parse(request.body.read, symbolize_names: true)))
+ else
+ return json_unauthorized(:create_project)
+ end
+ end
+ end
+ json_not_found
+ end
+
+ private
+end
+
if __FILE__ == $0
app = Rack::Builder.new do
use Rack::CommonLogger
cmd/gtwy/main.go
@@ -1,23 +1,49 @@
package main
import (
+ "fmt"
"log"
+ "net"
"net/http"
+ "github.com/casbin/casbin/v2"
"github.com/xlgmokha/x/pkg/env"
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/mokhax/spike/pkg/authz"
"gitlab.com/mokhax/spike/pkg/cfg"
"gitlab.com/mokhax/spike/pkg/prxy"
"gitlab.com/mokhax/spike/pkg/srv"
)
+func WithCasbin() authz.Authorizer {
+ enforcer := x.Must(casbin.NewEnforcer("model.conf", "policy.csv"))
+
+ return authz.AuthorizerFunc(func(r *http.Request) bool {
+ host, _, err := net.SplitHostPort(r.Host)
+ if err != nil {
+ return false
+ }
+
+ subject := "71cbc18e-bd41-4229-9ad2-749546a2a4a7" // TODO:: unpack sub claim in JWT
+ ok, err := enforcer.Enforce(subject, host, r.Method, r.URL.Path)
+ if err != nil {
+ fmt.Printf("%v\n", err)
+ return false
+ }
+
+ fmt.Printf("%v: %v %v %v\n", ok, r.Method, host, r.URL.Path)
+ return ok
+ })
+}
+
func WithRoutes() cfg.Option {
return func(c *cfg.Config) {
mux := http.NewServeMux()
- mux.Handle("/", prxy.New(map[string]string{
- "idp.example.com": "localhost:8282",
- "ui.example.com": "localhost:8283",
- "api.example.com": "localhost:8284",
- }))
+ mux.Handle("/", authz.HTTP(WithCasbin(), prxy.New(map[string]string{
+ "idp.example.com": "http://localhost:8282",
+ "ui.example.com": "http://localhost:8283",
+ "api.example.com": "http://localhost:8284",
+ })))
cfg.WithMux(mux)(c)
}
pkg/authz/authz.go
@@ -0,0 +1,23 @@
+package authz
+
+import "net/http"
+
+type Authorizer interface {
+ Authorize(*http.Request) bool
+}
+
+type AuthorizerFunc func(*http.Request) bool
+
+func (f AuthorizerFunc) Authorize(r *http.Request) bool {
+ return f(r)
+}
+
+func HTTP(authorizer Authorizer, h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if authorizer.Authorize(r) {
+ h.ServeHTTP(w, r)
+ } else {
+ w.WriteHeader(http.StatusForbidden)
+ }
+ })
+}
pkg/prxy/prxy.go
@@ -3,33 +3,32 @@ package prxy
import (
"fmt"
"log"
+ "net"
"net/http"
"net/http/httputil"
- "strings"
+ "net/url"
- "github.com/casbin/casbin/v2"
"github.com/xlgmokha/x/pkg/x"
)
func New(routes map[string]string) http.Handler {
- authz := x.Must(casbin.NewEnforcer("model.conf", "policy.csv"))
+ mapped := map[string]*url.URL{}
+ for source, destination := range routes {
+ mapped[source] = x.Must(url.Parse(destination))
+ }
return &httputil.ReverseProxy{
Director: func(r *http.Request) {
- segments := strings.SplitN(r.Host, ":", 2)
- host := segments[0]
- destinationHost := routes[host]
-
- log.Printf("%v (from: %v to: %v)\n", r.URL, host, destinationHost)
-
- subject := "71cbc18e-bd41-4229-9ad2-749546a2a4a7" // TODO:: unpack sub claim in JWT
- if x.Must(authz.Enforce(subject, host, r.Method, r.URL.Path)) {
- r.URL.Scheme = "http" // TODO:: use TLS
- r.Host = destinationHost
- r.URL.Host = destinationHost
- } else {
- log.Println("UNAUTHORIZED") // TODO:: Return forbidden, unauthorized or not found status code
+ host, _, err := net.SplitHostPort(r.Host)
+ if err != nil {
+ fmt.Printf("%v\n", err)
+ return
}
+
+ destination := mapped[host]
+ r.URL.Scheme = destination.Scheme
+ r.Host = destination.Host
+ r.URL.Host = destination.Host
},
Transport: http.DefaultTransport,
FlushInterval: -1,
pkg/prxy/prxy_test.go
@@ -0,0 +1,49 @@
+package prxy
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/mokhax/spike/pkg/test"
+)
+
+func TestProxy(t *testing.T) {
+ t.Run("http://idp.test", func(t *testing.T) {
+ var lastIdPRequest *http.Request
+ var lastUiRequest *http.Request
+
+ idp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ lastIdPRequest = r
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer idp.Close()
+
+ ui := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ lastUiRequest = r
+ w.WriteHeader(http.StatusTeapot)
+ }))
+ defer ui.Close()
+
+ subject := New(map[string]string{
+ "idp.test": idp.URL,
+ "ui.test": ui.URL,
+ })
+
+ r, w := test.RequestResponse("GET", "http://idp.test:8080/saml/new")
+
+ subject.ServeHTTP(w, r)
+
+ url := x.Must(url.Parse(idp.URL))
+
+ assert.Nil(t, lastUiRequest)
+ assert.Equal(t, http.StatusOK, w.Code)
+
+ require.NotNil(t, lastIdPRequest)
+ assert.Equal(t, url.Host, lastIdPRequest.Host)
+ })
+}
pkg/test/test.go
@@ -0,0 +1,49 @@
+package test
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+)
+
+type RequestOption func(*http.Request) *http.Request
+
+func Request(method, target string, options ...RequestOption) *http.Request {
+ request := httptest.NewRequest(method, target, nil)
+ for _, option := range options {
+ request = option(request)
+ }
+ return request
+}
+
+func RequestResponse(method, target string, options ...RequestOption) (*http.Request, *httptest.ResponseRecorder) {
+ return Request(method, target, options...), httptest.NewRecorder()
+}
+
+func WithRequestHeader(key, value string) RequestOption {
+ return func(r *http.Request) *http.Request {
+ r.Header.Set(key, value)
+ return r
+ }
+}
+
+func WithRequestBody(body io.ReadCloser) RequestOption {
+ return func(r *http.Request) *http.Request {
+ r.Body = body
+ return r
+ }
+}
+
+func WithContext(ctx context.Context) RequestOption {
+ return func(r *http.Request) *http.Request {
+ return r.WithContext(ctx)
+ }
+}
+
+func WithCookie(cookie *http.Cookie) RequestOption {
+ return func(r *http.Request) *http.Request {
+ r.AddCookie(cookie)
+ return r
+ }
+}
test/e2e_test.go
@@ -80,19 +80,41 @@ func TestAuthx(t *testing.T) {
assert.Equal(t, "Bearer", item.TokenType)
assert.NotEmpty(t, item.RefreshToken)
- response := x.Must(http.Get("http://api.example.com:8080/projects.json"))
- assert.Equal(t, http.StatusOK, response.StatusCode)
- projects := x.Must(serde.FromJSON[[]map[string]string](response.Body))
- assert.NotNil(t, projects)
-
- io := bytes.NewBuffer(nil)
- assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example"}))
- request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects", io))
- request.Header.Add("Authorization", "Bearer "+item.AccessToken)
- response = x.Must(client.Do(request))
- assert.Equal(t, http.StatusCreated, response.StatusCode)
- project := x.Must(serde.FromJSON[map[string]string](response.Body))
- assert.Equal(t, "example", project["name"])
+ t.Run("lists all the organizations", func(t *testing.T) {
+ response := x.Must(http.Get("http://api.example.com:8080/organizations.json"))
+ assert.Equal(t, http.StatusOK, response.StatusCode)
+ organizations := x.Must(serde.FromJSON[[]map[string]string](response.Body))
+ assert.NotNil(t, organizations)
+ })
+
+ t.Run("lists all the projects", func(t *testing.T) {
+ response := x.Must(http.Get("http://api.example.com:8080/projects.json"))
+ assert.Equal(t, http.StatusOK, response.StatusCode)
+ projects := x.Must(serde.FromJSON[[]map[string]string](response.Body))
+ assert.NotNil(t, projects)
+ })
+
+ t.Run("creates a new project", func(t *testing.T) {
+ io := bytes.NewBuffer(nil)
+ assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example"}))
+ request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects", io))
+ request.Header.Add("Authorization", "Bearer "+item.AccessToken)
+ response := x.Must(client.Do(request))
+ assert.Equal(t, http.StatusCreated, response.StatusCode)
+ project := x.Must(serde.FromJSON[map[string]string](response.Body))
+ assert.Equal(t, "example", project["name"])
+ })
+
+ t.Run("creates another projects", func(t *testing.T) {
+ io := bytes.NewBuffer(nil)
+ assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example2"}))
+ request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects.json", io))
+ request.Header.Add("Authorization", "Bearer "+item.AccessToken)
+ response := x.Must(client.Do(request))
+ assert.Equal(t, http.StatusCreated, response.StatusCode)
+ project := x.Must(serde.FromJSON[map[string]string](response.Body))
+ assert.Equal(t, "example2", project["name"])
+ })
})
})
magefile.go
@@ -6,7 +6,6 @@ package main
import (
"context"
"path/filepath"
- "runtime"
"github.com/magefile/mage/mg"
"github.com/magefile/mage/sh"
@@ -56,16 +55,6 @@ func Api() error {
return sh.RunWithV(env, "ruby", "./bin/api")
}
-// Open a web browser to the login page
-func Browser() error {
- url := "http://localhost:8080/ui/sessions/new"
- if runtime.GOOS == "linux" {
- return sh.RunV("xdg-open", url)
- } else {
- return sh.RunV("open", url)
- }
-}
-
// Generate gRPC from protocal buffers
func Protos() error {
outDir := "lib/authx/rpc"
@@ -94,5 +83,5 @@ func Test(ctx context.Context) error {
mg.CtxDeps(ctx, func() error {
return sh.RunV("go", "clean", "-testcache")
})
- return sh.RunV("go", "test", "-v", "./...")
+ return sh.RunV("go", "test", "-shuffle=on", "-v", "./...")
}