Commit 759db6e

mo khan <mo@mokhan.ca>
2022-09-24 01:56:38
feat: ad method to extract value from context
1 parent d8047ae
Changed files (3)
pkg/context/key.go
@@ -1,9 +1,20 @@
 package context
 
-import "context"
+import (
+	"context"
+
+	"github.com/xlgmokha/x/pkg/reflect"
+)
 
 type Key[T any] string
 
 func (self Key[T]) With(ctx context.Context, value T) context.Context {
 	return context.WithValue(ctx, self, value)
 }
+
+func (self Key[T]) From(ctx context.Context) T {
+	if value := ctx.Value(self); value != nil {
+		return value.(T)
+	}
+	return reflect.ZeroValue[T]()
+}
pkg/context/key_test.go
@@ -3,22 +3,37 @@ package context
 import (
 	"context"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
 
 func TestWith(t *testing.T) {
-	t.Run("injects the value into context", func(t *testing.T) {
-		key := Key[int]("ticket")
+	t.Run("With", func(t *testing.T) {
+		t.Run("injects the value into context", func(t *testing.T) {
+			key := Key[int]("ticket")
 
-		value := 42
-		ctx := key.With(context.Background(), value)
+			value := 42
+			ctx := key.With(context.Background(), value)
 
-		assert.Equal(t, value, ctx.Value(key))
+			assert.Equal(t, value, ctx.Value(key))
+		})
 	})
 
-	t.Run("works like this", func(t *testing.T) {
-		ctx := context.WithValue(context.Background(), "ticket", 42)
-		assert.Equal(t, 42, ctx.Value("ticket"))
+	t.Run("From", func(t *testing.T) {
+		t.Run("returns the value for the key", func(t *testing.T) {
+			key := Key[time.Time]("secret")
+			now := time.Now()
+
+			ctx := key.With(context.Background(), now)
+
+			assert.Equal(t, now, key.From(ctx))
+		})
+
+		t.Run("returns the zero value", func(t *testing.T) {
+			key := Key[int]("not-found")
+
+			assert.Equal(t, 0, key.From(context.Background()))
+		})
 	})
 }
pkg/reflect/zero.go
@@ -0,0 +1,6 @@
+package reflect
+
+func ZeroValue[T any]() T {
+	var item T
+	return item
+}