Skip to content

Commit

Permalink
Handle protocol_status_error, rename conn to socket (#3359)
Browse files Browse the repository at this point in the history
* Handle protocol_status_error, rename conn to socket

* Fix typo

* Fix typo again
  • Loading branch information
twose authored Jun 5, 2020
1 parent 916478b commit bc5d853
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 43 deletions.
6 changes: 3 additions & 3 deletions include/swoole.h
Original file line number Diff line number Diff line change
Expand Up @@ -2401,9 +2401,9 @@ int swThreadPool_run(swThreadPool *pool);
int swThreadPool_free(swThreadPool *pool);

//--------------------------------protocol------------------------------
ssize_t swProtocol_get_package_length(swProtocol *protocol, swSocket *conn, const char *data, uint32_t size);
int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString *buffer);
int swProtocol_recv_check_eof(swProtocol *protocol, swSocket *conn, swString *buffer);
ssize_t swProtocol_get_package_length(swProtocol *protocol, swSocket *socket, const char *data, uint32_t size);
int swProtocol_recv_check_length(swProtocol *protocol, swSocket *socket, swString *buffer);
int swProtocol_recv_check_eof(swProtocol *protocol, swSocket *socket, swString *buffer);

//--------------------------------timer------------------------------
#define SW_TIMER_MIN_MS 1
Expand Down
83 changes: 50 additions & 33 deletions src/protocol/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@ using namespace swoole;
/**
* return the package total length
*/
ssize_t swProtocol_get_package_length(swProtocol *protocol, swSocket *conn, const char *data, uint32_t size)
ssize_t swProtocol_get_package_length(swProtocol *protocol, swSocket *socket, const char *data, uint32_t size)
{
uint16_t length_offset = protocol->package_length_offset;
uint8_t package_length_size = protocol->get_package_length_size ? protocol->get_package_length_size(conn) : protocol->package_length_size;
uint8_t package_length_size = protocol->get_package_length_size ? protocol->get_package_length_size(socket) : protocol->package_length_size;
int32_t body_length;

if (package_length_size == 0)
{
// protocol error
return SW_ERR;
}
/**
* no have length field, wait more data
*/
Expand All @@ -41,17 +47,22 @@ ssize_t swProtocol_get_package_length(swProtocol *protocol, swSocket *conn, cons
//Protocol length is not legitimate, out of bounds or exceed the allocated length
if (body_length < 0)
{
swWarn("invalid package, remote_addr=%s:%d, length=%d, size=%d",
swSocket_get_ip(conn->socket_type, &conn->info),
swSocket_get_port(conn->socket_type, &conn->info), body_length, size);
swConnection *conn = (swConnection *) socket->object;
swWarn(
"invalid package (size=%d) from session#%u<%s:%d>",
size, conn->session_id,
swSocket_get_ip(socket->socket_type, &socket->info),
swSocket_get_port(socket->socket_type, &socket->info)
);
return SW_ERR;
}
swDebug("length=%d", protocol->package_body_offset + body_length);

//total package length
return protocol->package_body_offset + body_length;
}

static sw_inline int swProtocol_split_package_by_eof(swProtocol *protocol, swSocket *conn, swString *buffer)
static sw_inline int swProtocol_split_package_by_eof(swProtocol *protocol, swSocket *socket, swString *buffer)
{
if (buffer->length < protocol->package_eof_len)
{
Expand All @@ -61,19 +72,19 @@ static sw_inline int swProtocol_split_package_by_eof(swProtocol *protocol, swSoc
int retval;

size_t n = string_split(buffer, protocol->package_eof, protocol->package_eof_len, [&](char *data, size_t length) -> int {
if (protocol->onPackage(protocol, conn, data, length) < 0)
if (protocol->onPackage(protocol, socket, data, length) < 0)
{
retval = SW_CLOSE;
return false;
}
if (conn->removed)
if (socket->removed)
{
return false;
}
return true;
});

if (conn->removed)
if (socket->removed)
{
return SW_CLOSE;
}
Expand All @@ -99,7 +110,7 @@ static sw_inline int swProtocol_split_package_by_eof(swProtocol *protocol, swSoc
}

#ifdef SW_USE_OPENSSL
if (conn->ssl)
if (socket->ssl)
{
return SW_CONTINUE;
}
Expand All @@ -112,21 +123,27 @@ static sw_inline int swProtocol_split_package_by_eof(swProtocol *protocol, swSoc
* @return SW_ERR: close the connection
* @return SW_OK: continue
*/
int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString *buffer)
int swProtocol_recv_check_length(swProtocol *protocol, swSocket *socket, swString *buffer)
{
ssize_t package_length;
uint8_t package_length_size = protocol->get_package_length_size ? protocol->get_package_length_size(conn) : protocol->package_length_size;
uint8_t package_length_size = protocol->get_package_length_size ? protocol->get_package_length_size(socket) : protocol->package_length_size;
uint32_t recv_size;
ssize_t recv_n = 0;

if (conn->skip_recv)
if (package_length_size == 0)
{
// protocol error
return SW_ERR;
}

if (socket->skip_recv)
{
conn->skip_recv = 0;
socket->skip_recv = 0;
goto _do_get_length;
}

_do_recv:
if (conn->removed)
if (socket->removed)
{
return SW_OK;
}
Expand All @@ -139,13 +156,13 @@ int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString
recv_size = protocol->package_length_offset + package_length_size;
}

recv_n = swSocket_recv(conn, buffer->str + buffer->length, recv_size, 0);
recv_n = swSocket_recv(socket, buffer->str + buffer->length, recv_size, 0);
if (recv_n < 0)
{
switch (swSocket_error(errno))
{
case SW_ERROR:
swSysWarn("recv(%d, %d) failed", conn->fd, recv_size);
swSysWarn("recv(%d, %d) failed", socket->fd, recv_size);
return SW_OK;
case SW_CLOSE:
return SW_ERR;
Expand All @@ -161,20 +178,20 @@ int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString
{
buffer->length += recv_n;

if (conn->recv_wait)
if (socket->recv_wait)
{
if (buffer->length >= (size_t) buffer->offset)
{
_do_dispatch:
if (protocol->onPackage(protocol, conn, buffer->str, buffer->offset) < 0)
if (protocol->onPackage(protocol, socket, buffer->str, buffer->offset) < 0)
{
return SW_ERR;
}
if (conn->removed)
if (socket->removed)
{
return SW_OK;
}
conn->recv_wait = 0;
socket->recv_wait = 0;

if (buffer->length > (size_t) buffer->offset)
{
Expand All @@ -187,7 +204,7 @@ int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString
}
}
#ifdef SW_USE_OPENSSL
if (conn->ssl)
if (socket->ssl)
{
goto _do_recv;
}
Expand All @@ -197,7 +214,7 @@ int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString
else
{
_do_get_length:
package_length = protocol->get_package_length(protocol, conn, buffer->str, buffer->length);
package_length = protocol->get_package_length(protocol, socket, buffer->str, buffer->length);
//invalid package, close connection.
if (package_length < 0)
{
Expand All @@ -220,8 +237,8 @@ int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString
{
swoole_error_log(SW_LOG_WARNING, SW_ERROR_PACKAGE_LENGTH_TOO_LARGE,
"package is too big, remote_addr=%s:%d, length=%zu",
swSocket_get_ip(conn->socket_type, &conn->info),
swSocket_get_port(conn->socket_type, &conn->info), package_length);
swSocket_get_ip(socket->socket_type, &socket->info),
swSocket_get_port(socket->socket_type, &socket->info), package_length);
return SW_ERR;
}
//get length success
Expand All @@ -234,7 +251,7 @@ int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString
return SW_ERR;
}
}
conn->recv_wait = 1;
socket->recv_wait = 1;
buffer->offset = package_length;

if (buffer->length >= (size_t) package_length)
Expand All @@ -255,7 +272,7 @@ int swProtocol_recv_check_length(swProtocol *protocol, swSocket *conn, swString
* @return SW_ERR: close the connection
* @return SW_OK: continue
*/
int swProtocol_recv_check_eof(swProtocol *protocol, swSocket *conn, swString *buffer)
int swProtocol_recv_check_eof(swProtocol *protocol, swSocket *socket, swString *buffer)
{
int recv_again = SW_FALSE;
int buf_size;
Expand All @@ -269,13 +286,13 @@ int swProtocol_recv_check_eof(swProtocol *protocol, swSocket *conn, swString *bu
buf_size = SW_BUFFER_SIZE_STD;
}

int n = swSocket_recv(conn, buf_ptr, buf_size, 0);
int n = swSocket_recv(socket, buf_ptr, buf_size, 0);
if (n < 0)
{
switch (swSocket_error(errno))
{
case SW_ERROR:
swSysWarn("recv from socket#%d failed", conn->fd);
swSysWarn("recv from socket#%d failed", socket->fd);
return SW_OK;
case SW_CLOSE:
return SW_ERR;
Expand All @@ -298,7 +315,7 @@ int swProtocol_recv_check_eof(swProtocol *protocol, swSocket *conn, swString *bu

if (protocol->split_by_eof)
{
int retval = swProtocol_split_package_by_eof(protocol, conn, buffer);
int retval = swProtocol_split_package_by_eof(protocol, socket, buffer);
if (retval == SW_CONTINUE)
{
recv_again = SW_TRUE;
Expand All @@ -314,17 +331,17 @@ int swProtocol_recv_check_eof(swProtocol *protocol, swSocket *conn, swString *bu
}
else if (memcmp(buffer->str + buffer->length - protocol->package_eof_len, protocol->package_eof, protocol->package_eof_len) == 0)
{
if (protocol->onPackage(protocol, conn, buffer->str, buffer->length) < 0)
if (protocol->onPackage(protocol, socket, buffer->str, buffer->length) < 0)
{
return SW_ERR;
}
if (conn->removed)
if (socket->removed)
{
return SW_OK;
}
swString_clear(buffer);
#ifdef SW_USE_OPENSSL
if (conn->ssl && SSL_pending(conn->ssl) > 0)
if (socket->ssl && SSL_pending(socket->ssl) > 0)
{
goto _recv_data;
}
Expand Down
22 changes: 15 additions & 7 deletions src/protocol/http.cc
Original file line number Diff line number Diff line change
Expand Up @@ -825,12 +825,20 @@ string swHttpRequest_get_date_if_modified_since(swHttpRequest *request)
return string("");
}


#ifdef SW_USE_HTTP2
static void protocol_status_error(swSocket *socket, swConnection *conn)
{
swoole_error_log(
SW_LOG_WARNING, SW_ERROR_PROTOCOL_ERROR,
"unexpected protocol status of session#%u<%s:%d>",
conn->session_id, swSocket_get_ip(conn->socket_type, &conn->info), swSocket_get_port(conn->socket_type, &conn->info)
);
}

ssize_t swHttpMix_get_package_length(swProtocol *protocol, swSocket *socket, const char *data, uint32_t length)
{
swConnection *conn = (swConnection *) socket->object;
if (conn->websocket_status == WEBSOCKET_STATUS_ACTIVE)
if (conn->websocket_status >= WEBSOCKET_STATUS_HANDSHAKE)
{
return swWebSocket_get_package_length(protocol, socket, data, length);
}
Expand All @@ -840,15 +848,15 @@ ssize_t swHttpMix_get_package_length(swProtocol *protocol, swSocket *socket, con
}
else
{
abort();
protocol_status_error(socket, conn);
return SW_ERR;
}
}

uint8_t swHttpMix_get_package_length_size(swSocket *socket)
{
swConnection *conn = (swConnection *) socket->object;
if (conn->websocket_status == WEBSOCKET_STATUS_ACTIVE)
if (conn->websocket_status >= WEBSOCKET_STATUS_HANDSHAKE)
{
return SW_WEBSOCKET_HEADER_LEN + SW_WEBSOCKET_MASK_LEN + sizeof(uint64_t);
}
Expand All @@ -858,15 +866,15 @@ uint8_t swHttpMix_get_package_length_size(swSocket *socket)
}
else
{
abort();
protocol_status_error(socket, conn);
return 0;
}
}

int swHttpMix_dispatch_frame(swProtocol *proto, swSocket *socket, const char *data, uint32_t length)
{
swConnection *conn = (swConnection *) socket->object;
if (conn->websocket_status == WEBSOCKET_STATUS_ACTIVE)
if (conn->websocket_status >= WEBSOCKET_STATUS_HANDSHAKE)
{
return swWebSocket_dispatch_frame(proto, socket, data, length);
}
Expand All @@ -876,7 +884,7 @@ int swHttpMix_dispatch_frame(swProtocol *proto, swSocket *socket, const char *da
}
else
{
abort();
protocol_status_error(socket, conn);
return SW_ERR;
}
}
Expand Down

0 comments on commit bc5d853

Please sign in to comment.