Skip to content

Commit

Permalink
Add RDB.RunAllAggregatingTasks
Browse files Browse the repository at this point in the history
  • Loading branch information
hibiken committed Apr 11, 2022
1 parent 725105c commit 74db013
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
51 changes: 51 additions & 0 deletions internal/rdb/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,57 @@ func (r *RDB) RunAllArchivedTasks(qname string) (int64, error) {
return n, nil
}

// runAllAggregatingCmd schedules all tasks in the group to run individually.
//
// Input:
// KEYS[1] -> asynq:{<qname>}:g:<gname>
// KEYS[2] -> asynq:{<qname>}:pending
// KEYS[3] -> asynq:{<qname>}:groups
// -------
// ARGV[1] -> task key prefix
// ARGV[2] -> group name
//
// Output:
// integer: number of tasks scheduled to run
var runAllAggregatingCmd = redis.NewScript(`
local ids = redis.call("ZRANGE", KEYS[1], 0, -1)
for _, id in ipairs(ids) do
redis.call("LPUSH", KEYS[2], id)
redis.call("HSET", ARGV[1] .. id, "state", "pending")
end
redis.call("DEL", KEYS[1])
redis.call("SREM", KEYS[3], ARGV[2])
return table.getn(ids)
`)

// RunAllAggregatingTasks schedules all tasks from the given queue to run
// and returns the number of tasks scheduled to run.
// If a queue with the given name doesn't exist, it returns QueueNotFoundError.
func (r *RDB) RunAllAggregatingTasks(qname, gname string) (int64, error) {
var op errors.Op = "rdb.RunAllAggregatingTasks"
if err := r.checkQueueExists(qname); err != nil {
return 0, errors.E(op, errors.CanonicalCode(err), err)
}
keys := []string{
base.GroupKey(qname, gname),
base.PendingKey(qname),
base.AllGroups(qname),
}
argv := []interface{}{
base.TaskKeyPrefix(qname),
gname,
}
res, err := runAllAggregatingCmd.Run(context.Background(), r.client, keys, argv...).Result()
if err != nil {
return 0, errors.E(op, errors.Internal, err)
}
n, ok := res.(int64)
if !ok {
return 0, errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from script %v", res))
}
return n, nil
}

// runTaskCmd is a Lua script that updates the given task to pending state.
//
// Input:
Expand Down
112 changes: 112 additions & 0 deletions internal/rdb/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2566,6 +2566,118 @@ func TestRunAllTasksError(t *testing.T) {
if _, got := r.RunAllArchivedTasks(tc.qname); !tc.match(got) {
t.Errorf("%s: RunAllArchivedTasks returned %v", tc.desc, got)
}
if _, got := r.RunAllAggregatingTasks(tc.qname, "mygroup"); !tc.match(got) {
t.Errorf("%s: RunAllAggregatingTasks returned %v", tc.desc, got)
}
}
}

func TestRunAllAggregatingTasks(t *testing.T) {
r := setup(t)
defer r.Close()
now := time.Now()
r.SetClock(timeutil.NewSimulatedClock(now))

m1 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task1").SetGroup("group1").Build()
m2 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task2").SetGroup("group1").Build()
m3 := h.NewTaskMessageBuilder().SetQueue("custom").SetType("task3").SetGroup("group2").Build()

fxt := struct {
tasks []*h.TaskSeedData
allQueues []string
allGroups map[string][]string
groups map[string][]*redis.Z
}{
tasks: []*h.TaskSeedData{
{Msg: m1, State: base.TaskStateAggregating},
{Msg: m2, State: base.TaskStateAggregating},
{Msg: m3, State: base.TaskStateAggregating},
},
allQueues: []string{"default", "custom"},
allGroups: map[string][]string{
base.AllGroups("default"): {"group1"},
base.AllGroups("custom"): {"group2"},
},
groups: map[string][]*redis.Z{
base.GroupKey("default", "group1"): {
{Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())},
{Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())},
},
base.GroupKey("custom", "group2"): {
{Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())},
},
},
}

tests := []struct {
desc string
qname string
gname string
want int64
wantPending map[string][]string
wantGroups map[string][]redis.Z
wantAllGroups map[string][]string
}{
{
desc: "schedules tasks in a group with multiple tasks",
qname: "default",
gname: "group1",
want: 2,
wantPending: map[string][]string{
base.PendingKey("default"): {m1.ID, m2.ID},
},
wantGroups: map[string][]redis.Z{
base.GroupKey("default", "group1"): {},
base.GroupKey("custom", "group2"): {
{Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())},
},
},
wantAllGroups: map[string][]string{
base.AllGroups("default"): {},
base.AllGroups("custom"): {"group2"},
},
},
{
desc: "schedules tasks in a group with a single task",
qname: "custom",
gname: "group2",
want: 1,
wantPending: map[string][]string{
base.PendingKey("custom"): {m3.ID},
},
wantGroups: map[string][]redis.Z{
base.GroupKey("default", "group1"): {
{Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())},
{Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())},
},
base.GroupKey("custom", "group2"): {},
},
wantAllGroups: map[string][]string{
base.AllGroups("default"): {"group1"},
base.AllGroups("custom"): {},
},
},
}

for _, tc := range tests {
h.FlushDB(t, r.client)
h.SeedTasks(t, r.client, fxt.tasks)
h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues)
h.SeedRedisSets(t, r.client, fxt.allGroups)
h.SeedRedisZSets(t, r.client, fxt.groups)

t.Run(tc.desc, func(t *testing.T) {
got, err := r.RunAllAggregatingTasks(tc.qname, tc.gname)
if err != nil {
t.Fatalf("RunAllAggregatingTasks returned error: %v", err)
}
if got != tc.want {
t.Errorf("RunAllAggregatingTasks = %d, want %d", got, tc.want)
}
h.AssertRedisLists(t, r.client, tc.wantPending)
h.AssertRedisZSets(t, r.client, tc.wantGroups)
h.AssertRedisSets(t, r.client, tc.wantAllGroups)
})
}
}

Expand Down
12 changes: 12 additions & 0 deletions internal/testutil/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,18 @@ func SeedRedisLists(tb testing.TB, r redis.UniversalClient, lists map[string][]s
}
}

func AssertRedisLists(t *testing.T, r redis.UniversalClient, wantLists map[string][]string) {
for key, want := range wantLists {
got, err := r.LRange(context.Background(), key, 0, -1).Result()
if err != nil {
t.Fatalf("Failed to read list (key=%q): %v", key, err)
}
if diff := cmp.Diff(want, got, SortStringSliceOpt); diff != "" {
t.Errorf("mismatch found in list (key=%q): (-want,+got)\n%s", key, diff)
}
}
}

func AssertRedisSets(t *testing.T, r redis.UniversalClient, wantSets map[string][]string) {
for key, want := range wantSets {
got, err := r.SMembers(context.Background(), key).Result()
Expand Down

0 comments on commit 74db013

Please sign in to comment.