Skip to content

Commit

Permalink
Add asynq/x/rate package
Browse files Browse the repository at this point in the history
- Added a directory /x for external, experimental packeges
- Added a `rate` package to enable rate limiting across multiple asynq worker servers
  • Loading branch information
ajatprabha authored Nov 3, 2021
1 parent 0d2c0f6 commit 23c522d
Show file tree
Hide file tree
Showing 8 changed files with 661 additions and 56 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
.asynq.*

# Ignore editor config files
.vscode
.vscode
.idea
55 changes: 5 additions & 50 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,82 +6,37 @@ package asynq

import (
"context"
"time"

"github.com/hibiken/asynq/internal/base"
asynqcontext "github.com/hibiken/asynq/internal/context"
)

// A taskMetadata holds task scoped data to put in context.
type taskMetadata struct {
id string
maxRetry int
retryCount int
qname string
}

// ctxKey type is unexported to prevent collisions with context keys defined in
// other packages.
type ctxKey int

// metadataCtxKey is the context key for the task metadata.
// Its value of zero is arbitrary.
const metadataCtxKey ctxKey = 0

// createContext returns a context and cancel function for a given task message.
func createContext(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) {
metadata := taskMetadata{
id: msg.ID.String(),
maxRetry: msg.Retry,
retryCount: msg.Retried,
qname: msg.Queue,
}
ctx := context.WithValue(context.Background(), metadataCtxKey, metadata)
return context.WithDeadline(ctx, deadline)
}

// GetTaskID extracts a task ID from a context, if any.
//
// ID of a task is guaranteed to be unique.
// ID of a task doesn't change if the task is being retried.
func GetTaskID(ctx context.Context) (id string, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return "", false
}
return metadata.id, true
return asynqcontext.GetTaskID(ctx)
}

// GetRetryCount extracts retry count from a context, if any.
//
// Return value n indicates the number of times associated task has been
// retried so far.
func GetRetryCount(ctx context.Context) (n int, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return 0, false
}
return metadata.retryCount, true
return asynqcontext.GetRetryCount(ctx)
}

// GetMaxRetry extracts maximum retry from a context, if any.
//
// Return value n indicates the maximum number of times the assoicated task
// can be retried if ProcessTask returns a non-nil error.
func GetMaxRetry(ctx context.Context) (n int, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return 0, false
}
return metadata.maxRetry, true
return asynqcontext.GetMaxRetry(ctx)
}

// GetQueueName extracts queue name from a context, if any.
//
// Return value qname indicates which queue the task was pulled from.
func GetQueueName(ctx context.Context) (qname string, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return "", false
}
return metadata.qname, true
return asynqcontext.GetQueueName(ctx)
}
87 changes: 87 additions & 0 deletions internal/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2020 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.

package context

import (
"context"
"time"

"github.com/hibiken/asynq/internal/base"
)

// A taskMetadata holds task scoped data to put in context.
type taskMetadata struct {
id string
maxRetry int
retryCount int
qname string
}

// ctxKey type is unexported to prevent collisions with context keys defined in
// other packages.
type ctxKey int

// metadataCtxKey is the context key for the task metadata.
// Its value of zero is arbitrary.
const metadataCtxKey ctxKey = 0

// New returns a context and cancel function for a given task message.
func New(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) {
metadata := taskMetadata{
id: msg.ID.String(),
maxRetry: msg.Retry,
retryCount: msg.Retried,
qname: msg.Queue,
}
ctx := context.WithValue(context.Background(), metadataCtxKey, metadata)
return context.WithDeadline(ctx, deadline)
}

// GetTaskID extracts a task ID from a context, if any.
//
// ID of a task is guaranteed to be unique.
// ID of a task doesn't change if the task is being retried.
func GetTaskID(ctx context.Context) (id string, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return "", false
}
return metadata.id, true
}

// GetRetryCount extracts retry count from a context, if any.
//
// Return value n indicates the number of times associated task has been
// retried so far.
func GetRetryCount(ctx context.Context) (n int, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return 0, false
}
return metadata.retryCount, true
}

// GetMaxRetry extracts maximum retry from a context, if any.
//
// Return value n indicates the maximum number of times the assoicated task
// can be retried if ProcessTask returns a non-nil error.
func GetMaxRetry(ctx context.Context) (n int, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return 0, false
}
return metadata.maxRetry, true
}

// GetQueueName extracts queue name from a context, if any.
//
// Return value qname indicates which queue the task was pulled from.
func GetQueueName(ctx context.Context) (qname string, ok bool) {
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
if !ok {
return "", false
}
return metadata.qname, true
}
8 changes: 4 additions & 4 deletions context_test.go → internal/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.

package asynq
package context

import (
"context"
Expand All @@ -28,7 +28,7 @@ func TestCreateContextWithFutureDeadline(t *testing.T) {
Payload: nil,
}

ctx, cancel := createContext(msg, tc.deadline)
ctx, cancel := New(msg, tc.deadline)

select {
case x := <-ctx.Done():
Expand Down Expand Up @@ -68,7 +68,7 @@ func TestCreateContextWithPastDeadline(t *testing.T) {
Payload: nil,
}

ctx, cancel := createContext(msg, tc.deadline)
ctx, cancel := New(msg, tc.deadline)
defer cancel()

select {
Expand Down Expand Up @@ -98,7 +98,7 @@ func TestGetTaskMetadataFromContext(t *testing.T) {
}

for _, tc := range tests {
ctx, cancel := createContext(tc.msg, time.Now().Add(30*time.Minute))
ctx, cancel := New(tc.msg, time.Now().Add(30*time.Minute))
defer cancel()

id, ok := GetTaskID(ctx)
Expand Down
3 changes: 2 additions & 1 deletion processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/hibiken/asynq/internal/base"
asynqcontext "github.com/hibiken/asynq/internal/context"
"github.com/hibiken/asynq/internal/errors"
"github.com/hibiken/asynq/internal/log"
"golang.org/x/time/rate"
Expand Down Expand Up @@ -189,7 +190,7 @@ func (p *processor) exec() {
<-p.sema // release token
}()

ctx, cancel := createContext(msg, deadline)
ctx, cancel := asynqcontext.New(msg, deadline)
p.cancelations.Add(msg.ID.String(), cancel)
defer func() {
cancel()
Expand Down
40 changes: 40 additions & 0 deletions x/rate/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package rate_test

import (
"context"
"fmt"
"time"

"github.com/hibiken/asynq"
"github.com/hibiken/asynq/x/rate"
)

type RateLimitError struct {
RetryIn time.Duration
}

func (e *RateLimitError) Error() string {
return fmt.Sprintf("rate limited (retry in %v)", e.RetryIn)
}

func ExampleNewSemaphore() {
redisConnOpt := asynq.RedisClientOpt{Addr: ":6379"}
sema := rate.NewSemaphore(redisConnOpt, "my_queue", 10)
// call sema.Close() when appropriate

_ = asynq.HandlerFunc(func(ctx context.Context, task *asynq.Task) error {
ok, err := sema.Acquire(ctx)
if err != nil {
return err
}
if !ok {
return &RateLimitError{RetryIn: 30 * time.Second}
}

// Make sure to release the token once we're done.
defer sema.Release(ctx)

// Process task
return nil
})
}
114 changes: 114 additions & 0 deletions x/rate/semaphore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Package rate contains rate limiting strategies for asynq.Handler(s).
package rate

import (
"context"
"fmt"
"strings"
"time"

"github.com/go-redis/redis/v8"
"github.com/hibiken/asynq"
asynqcontext "github.com/hibiken/asynq/internal/context"
)

// NewSemaphore creates a counting Semaphore for the given scope with the given number of tokens.
func NewSemaphore(rco asynq.RedisConnOpt, scope string, maxTokens int) *Semaphore {
rc, ok := rco.MakeRedisClient().(redis.UniversalClient)
if !ok {
panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco))
}

if maxTokens < 1 {
panic("rate.NewSemaphore: maxTokens cannot be less than 1")
}

if len(strings.TrimSpace(scope)) == 0 {
panic("rate.NewSemaphore: scope should not be empty")
}

return &Semaphore{
rc: rc,
scope: scope,
maxTokens: maxTokens,
}
}

// Semaphore is a distributed counting semaphore which can be used to set maxTokens across multiple asynq servers.
type Semaphore struct {
rc redis.UniversalClient
maxTokens int
scope string
}

// KEYS[1] -> asynq:sema:<scope>
// ARGV[1] -> max concurrency
// ARGV[2] -> current time in unix time
// ARGV[3] -> deadline in unix time
// ARGV[4] -> task ID
var acquireCmd = redis.NewScript(`
redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1)
local count = redis.call("ZCARD", KEYS[1])
if (count < tonumber(ARGV[1])) then
redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4])
return 'true'
else
return 'false'
end
`)

// Acquire attempts to acquire a token from the semaphore.
// - Returns (true, nil), iff semaphore key exists and current value is less than maxTokens
// - Returns (false, nil) when token cannot be acquired
// - Returns (false, error) otherwise
//
// The context.Context passed to Acquire must have a deadline set,
// this ensures that token is released if the job goroutine crashes and does not call Release.
func (s *Semaphore) Acquire(ctx context.Context) (bool, error) {
d, ok := ctx.Deadline()
if !ok {
return false, fmt.Errorf("provided context must have a deadline")
}

taskID, ok := asynqcontext.GetTaskID(ctx)
if !ok {
return false, fmt.Errorf("provided context is missing task ID value")
}

return acquireCmd.Run(ctx, s.rc,
[]string{semaphoreKey(s.scope)},
s.maxTokens,
time.Now().Unix(),
d.Unix(),
taskID,
).Bool()
}

// Release will release the token on the counting semaphore.
func (s *Semaphore) Release(ctx context.Context) error {
taskID, ok := asynqcontext.GetTaskID(ctx)
if !ok {
return fmt.Errorf("provided context is missing task ID value")
}

n, err := s.rc.ZRem(ctx, semaphoreKey(s.scope), taskID).Result()
if err != nil {
return fmt.Errorf("redis command failed: %v", err)
}

if n == 0 {
return fmt.Errorf("no token found for task %q", taskID)
}

return nil
}

// Close closes the connection to redis.
func (s *Semaphore) Close() error {
return s.rc.Close()
}

func semaphoreKey(scope string) string {
return fmt.Sprintf("asynq:sema:%s", scope)
}
Loading

0 comments on commit 23c522d

Please sign in to comment.