Skip to content

Commit

Permalink
Refactor: move slave pkg inside of cluster
Browse files Browse the repository at this point in the history
Test: middleware for node communication
  • Loading branch information
HFO4 committed Nov 8, 2021
1 parent eaa0f6b commit e41ec9d
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 43 deletions.
3 changes: 1 addition & 2 deletions bootstrap/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/gin-gonic/gin"
)
Expand Down Expand Up @@ -78,7 +77,7 @@ func Init(path string) {
{
"slave",
func() {
slave.Init()
cluster.InitController()
},
},
{
Expand Down
1 change: 1 addition & 0 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
c.Abort()
return
}

c.Next()
}
}
Expand Down
17 changes: 14 additions & 3 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,27 @@ func TestSignRequired(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
SignRequiredFunc := SignRequired(authInstance)

// 鉴权失败
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())

c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())

// Sign verify success
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
c.Request = auth.SignRequest(authInstance, c.Request, 0)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.False(c.IsAborted())
}

func TestWebDAVAuth(t *testing.T) {
Expand Down Expand Up @@ -780,14 +792,13 @@ func TestS3CallbackAuth(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
AuthFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
}
9 changes: 4 additions & 5 deletions middleware/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
"strconv"
)
Expand All @@ -19,11 +18,11 @@ func MasterMetadata() gin.HandlerFunc {
}

// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
func UseSlaveAria2Instance() gin.HandlerFunc {
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
return func(c *gin.Context) {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
caller, err := clusterController.GetAria2Instance(siteID.(string))
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
c.Abort()
Expand All @@ -40,7 +39,7 @@ func UseSlaveAria2Instance() gin.HandlerFunc {
}
}

func SlaveRPCSignRequired() gin.HandlerFunc {
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
if err != nil {
Expand All @@ -49,7 +48,7 @@ func SlaveRPCSignRequired() gin.HandlerFunc {
return
}

slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
slaveNode := nodePool.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
Expand Down
80 changes: 80 additions & 0 deletions middleware/cluster_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package middleware

import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)

func TestMasterMetadata(t *testing.T) {
a := assert.New(t)
masterMetaDataFunc := MasterMetadata()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)

c.Request.Header = map[string][]string{
"X-Site-Id": {"expectedSiteID"},
"X-Site-Url": {"expectedSiteURL"},
"X-Cloudreve-Version": {"expectedMasterVersion"},
}
masterMetaDataFunc(c)
siteID, _ := c.Get("MasterSiteID")
siteURL, _ := c.Get("MasterSiteURL")
siteVersion, _ := c.Get("MasterVersion")

a.Equal("expectedSiteID", siteID.(string))
a.Equal("expectedSiteURL", siteURL.(string))
a.Equal("expectedMasterVersion", siteVersion.(string))
}

func TestSlaveRPCSignRequired(t *testing.T) {
a := assert.New(t)
np := &cluster.NodePool{}
np.Init()
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
rec := httptest.NewRecorder()

// id parse failed
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Node-Id", "unknown")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}

// node id not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Node-Id", "38")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}

// success
{
authInstance := auth.HMACAuth{SecretKey: []byte("")}
np.Add(&model.Node{Model: gorm.Model{
ID: 38,
}})

c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("POST", "/", nil)
c.Request.Header.Set("X-Node-Id", "38")
c.Request = auth.SignRequest(authInstance, c.Request, 0)
slaveRPCSignRequiredFunc(c)
a.False(c.IsAborted())
}
}

func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t)

}
2 changes: 1 addition & 1 deletion models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type User struct {
Storage uint64
TwoFactor string
Avatar string
Options string `json:"-",gorm:"type:text"`
Options string `json:"-" gorm:"type:text"`
Authn string `gorm:"type:text"`

// 关联模型
Expand Down
9 changes: 4 additions & 5 deletions pkg/slave/slave.go → pkg/cluster/controller.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package slave
package cluster

import (
"bytes"
Expand All @@ -8,7 +8,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
Expand Down Expand Up @@ -51,13 +50,13 @@ type MasterInfo struct {
TTL int
URL *url.URL
// used to invoke aria2 rpc calls
Instance cluster.Node
Instance Node
Client request.Client

jobTracker map[string]bool
}

func Init() {
func InitController() {
DefaultController = &slaveController{
masters: make(map[string]MasterInfo),
}
Expand Down Expand Up @@ -95,7 +94,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
}, int64(req.CredentialTTL)),
),
jobTracker: make(map[string]bool),
Instance: cluster.NewNodeFromDBModel(&model.Node{
Instance: NewNodeFromDBModel(&model.Node{
Model: gorm.Model{ID: req.Node.ID},
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType,
Expand Down
6 changes: 5 additions & 1 deletion pkg/cluster/errors.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package cluster

import "errors"
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)

var (
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
)
16 changes: 11 additions & 5 deletions pkg/cluster/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,22 @@ type NodePool struct {

// Init 初始化从机节点池
func Init() {
Default = &NodePool{
featureMap: make(map[string][]Node),
}
Default = &NodePool{}
Default.Init()
if err := Default.initFromDB(); err != nil {
util.Log().Warning("节点池初始化失败, %s", err)
}
}

func (pool *NodePool) Init() {
pool.lock.Lock()
defer pool.lock.Unlock()

pool.featureMap = make(map[string][]Node)
pool.active = make(map[uint]Node)
pool.inactive = make(map[uint]Node)
}

func (pool *NodePool) buildIndexMap() {
pool.lock.Lock()
for _, feature := range featureGroup {
Expand Down Expand Up @@ -98,8 +106,6 @@ func (pool *NodePool) initFromDB() error {
}

pool.lock.Lock()
pool.active = make(map[uint]Node)
pool.inactive = make(map[uint]Node)
for i := 0; i < len(nodes); i++ {
pool.add(&nodes[i])
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/filesystem/driver/onedrive/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package onedrive
import (
"context"
"encoding/json"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"io/ioutil"
"net/http"
"net/url"
Expand All @@ -12,7 +13,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)

Expand Down Expand Up @@ -179,7 +179,7 @@ func (client *Client) UpdateCredential(ctx context.Context) error {

// UpdateCredential 更新凭证,并检查有效期
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
res, err := cluster.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
if err != nil {
return err
}
Expand Down
7 changes: 0 additions & 7 deletions pkg/slave/errors.go

This file was deleted.

8 changes: 4 additions & 4 deletions pkg/task/slavetask/transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package slavetask
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"os"
Expand Down Expand Up @@ -68,7 +68,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) {
},
}

if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
util.Log().Warning("无法发送转存失败通知到从机, ", err)
}
}
Expand All @@ -94,7 +94,7 @@ func (job *TransferTask) Do() {
return
}

master, err := slave.DefaultController.GetMasterInfo(job.MasterID)
master, err := cluster.DefaultController.GetMasterInfo(job.MasterID)
if err != nil {
job.SetErrorMsg("找不到主机节点", err)
return
Expand Down Expand Up @@ -131,7 +131,7 @@ func (job *TransferTask) Do() {
Content: serializer.SlaveTransferResult{},
}

if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
util.Log().Warning("无法发送转存成功通知到从机, ", err)
}
}
Expand Down
5 changes: 3 additions & 2 deletions routers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package routers
import (
"github.com/cloudreve/Cloudreve/v3/middleware"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
Expand Down Expand Up @@ -59,7 +60,7 @@ func InitSlaveRouter() *gin.Engine {

// 离线下载
aria2 := v3.Group("aria2")
aria2.Use(middleware.UseSlaveAria2Instance())
aria2.Use(middleware.UseSlaveAria2Instance(cluster.DefaultController))
{
// 创建离线下载任务
aria2.POST("task", controllers.SlaveAria2Create)
Expand Down Expand Up @@ -205,7 +206,7 @@ func InitMasterRouter() *gin.Engine {

// 从机的 RPC 通信
slave := v3.Group("slave")
slave.Use(middleware.SlaveRPCSignRequired())
slave.Use(middleware.SlaveRPCSignRequired(cluster.Default))
{
// 事件通知
slave.PUT("notification/:subject", controllers.SlaveNotificationPush)
Expand Down
Loading

0 comments on commit e41ec9d

Please sign in to comment.