diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index aa029443..a5a1c366 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -922,6 +922,57 @@ func (r *RDB) RunAllArchivedTasks(qname string) (int64, error) { return n, nil } +// runAllAggregatingCmd schedules all tasks in the group to run individually. +// +// Input: +// KEYS[1] -> asynq:{}:g: +// KEYS[2] -> asynq:{}:pending +// KEYS[3] -> asynq:{}:groups +// ------- +// ARGV[1] -> task key prefix +// ARGV[2] -> group name +// +// Output: +// integer: number of tasks scheduled to run +var runAllAggregatingCmd = redis.NewScript(` +local ids = redis.call("ZRANGE", KEYS[1], 0, -1) +for _, id in ipairs(ids) do + redis.call("LPUSH", KEYS[2], id) + redis.call("HSET", ARGV[1] .. id, "state", "pending") +end +redis.call("DEL", KEYS[1]) +redis.call("SREM", KEYS[3], ARGV[2]) +return table.getn(ids) +`) + +// RunAllAggregatingTasks schedules all tasks from the given queue to run +// and returns the number of tasks scheduled to run. +// If a queue with the given name doesn't exist, it returns QueueNotFoundError. +func (r *RDB) RunAllAggregatingTasks(qname, gname string) (int64, error) { + var op errors.Op = "rdb.RunAllAggregatingTasks" + if err := r.checkQueueExists(qname); err != nil { + return 0, errors.E(op, errors.CanonicalCode(err), err) + } + keys := []string{ + base.GroupKey(qname, gname), + base.PendingKey(qname), + base.AllGroups(qname), + } + argv := []interface{}{ + base.TaskKeyPrefix(qname), + gname, + } + res, err := runAllAggregatingCmd.Run(context.Background(), r.client, keys, argv...).Result() + if err != nil { + return 0, errors.E(op, errors.Internal, err) + } + n, ok := res.(int64) + if !ok { + return 0, errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from script %v", res)) + } + return n, nil +} + // runTaskCmd is a Lua script that updates the given task to pending state. // // Input: diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 5cb58ea3..5212920a 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -2566,6 +2566,118 @@ func TestRunAllTasksError(t *testing.T) { if _, got := r.RunAllArchivedTasks(tc.qname); !tc.match(got) { t.Errorf("%s: RunAllArchivedTasks returned %v", tc.desc, got) } + if _, got := r.RunAllAggregatingTasks(tc.qname, "mygroup"); !tc.match(got) { + t.Errorf("%s: RunAllAggregatingTasks returned %v", tc.desc, got) + } + } +} + +func TestRunAllAggregatingTasks(t *testing.T) { + r := setup(t) + defer r.Close() + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + + m1 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task1").SetGroup("group1").Build() + m2 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task2").SetGroup("group1").Build() + m3 := h.NewTaskMessageBuilder().SetQueue("custom").SetType("task3").SetGroup("group2").Build() + + fxt := struct { + tasks []*h.TaskSeedData + allQueues []string + allGroups map[string][]string + groups map[string][]*redis.Z + }{ + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, + }, + allQueues: []string{"default", "custom"}, + allGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {"group2"}, + }, + groups: map[string][]*redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group2"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + } + + tests := []struct { + desc string + qname string + gname string + want int64 + wantPending map[string][]string + wantGroups map[string][]redis.Z + wantAllGroups map[string][]string + }{ + { + desc: "schedules tasks in a group with multiple tasks", + qname: "default", + gname: "group1", + want: 2, + wantPending: map[string][]string{ + base.PendingKey("default"): {m1.ID, m2.ID}, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): {}, + base.GroupKey("custom", "group2"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {}, + base.AllGroups("custom"): {"group2"}, + }, + }, + { + desc: "schedules tasks in a group with a single task", + qname: "custom", + gname: "group2", + want: 1, + wantPending: map[string][]string{ + base.PendingKey("custom"): {m3.ID}, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group2"): {}, + }, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {}, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + h.SeedTasks(t, r.client, fxt.tasks) + h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues) + h.SeedRedisSets(t, r.client, fxt.allGroups) + h.SeedRedisZSets(t, r.client, fxt.groups) + + t.Run(tc.desc, func(t *testing.T) { + got, err := r.RunAllAggregatingTasks(tc.qname, tc.gname) + if err != nil { + t.Fatalf("RunAllAggregatingTasks returned error: %v", err) + } + if got != tc.want { + t.Errorf("RunAllAggregatingTasks = %d, want %d", got, tc.want) + } + h.AssertRedisLists(t, r.client, tc.wantPending) + h.AssertRedisZSets(t, r.client, tc.wantGroups) + h.AssertRedisSets(t, r.client, tc.wantAllGroups) + }) } } diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index d22904e4..33a87f4a 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -605,6 +605,18 @@ func SeedRedisLists(tb testing.TB, r redis.UniversalClient, lists map[string][]s } } +func AssertRedisLists(t *testing.T, r redis.UniversalClient, wantLists map[string][]string) { + for key, want := range wantLists { + got, err := r.LRange(context.Background(), key, 0, -1).Result() + if err != nil { + t.Fatalf("Failed to read list (key=%q): %v", key, err) + } + if diff := cmp.Diff(want, got, SortStringSliceOpt); diff != "" { + t.Errorf("mismatch found in list (key=%q): (-want,+got)\n%s", key, diff) + } + } +} + func AssertRedisSets(t *testing.T, r redis.UniversalClient, wantSets map[string][]string) { for key, want := range wantSets { got, err := r.SMembers(context.Background(), key).Result()