Skip to content

Commit

Permalink
socket: Make sure errno is always set on error, and always return a m…
Browse files Browse the repository at this point in the history
…eaningful error code
  • Loading branch information
nikias committed Mar 6, 2024
1 parent 6be525c commit fc10c88
Showing 1 changed file with 97 additions and 5 deletions.
102 changes: 97 additions & 5 deletions src/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,51 @@ enum poll_status
poll_status_error
};

#ifdef WIN32
static inline __attribute__((always_inline)) int WSAError_to_errno(int wsaerr)
{
switch (wsaerr) {
case WSAEINVAL:
return EINVAL;
case WSAENOTSOCK:
return ENOTSOCK;
case WSAENOTCONN:
return ENOTCONN;
case WSAESHUTDOWN:
return ENOTCONN;
case WSAECONNRESET:
return ECONNRESET;
case WSAECONNABORTED:
return ECONNABORTED;
case WSAECONNREFUSED:
return ECONNREFUSED;
case WSAENETDOWN:
return ENETDOWN;
case WSAENETRESET:
return ENETRESET;
case WSAEHOSTUNREACH:
return EHOSTUNREACH;
case WSAETIMEDOUT:
return ETIMEDOUT;
case WSAEWOULDBLOCK:
return EWOULDBLOCK;
case WSAEINPROGRESS:
return EINPROGRESS;
case WSAENOBUFS:
return ENOBUFS;
case WSAEINTR:
return EINTR;
case WSAEACCES:
return EACCES;
case WSAEFAULT:
return EFAULT;
default:
break;
}
return wsaerr;
}
#endif

// timeout of -1 means infinity
static inline __attribute__((always_inline)) enum poll_status poll_wrapper(int fd, fd_mode mode, int timeout)
{
Expand Down Expand Up @@ -387,8 +432,17 @@ int socket_connect_unix(const char *filename)
socklen_t len = sizeof(so_error);
getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len);
if (so_error == 0) {
errno = 0;
break;
}
errno = so_error;
} else {
int so_error = 0;
socklen_t len = sizeof(so_error);
getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len);
if (so_error != 0) {
errno = so_error;
}
}
}
socket_close(sfd);
Expand Down Expand Up @@ -1064,7 +1118,20 @@ int socket_connect_addr(struct sockaddr* addr, uint16_t port)
errno = 0;
break;
}
#ifdef WIN32
so_error = WSAError_to_errno(so_error);
#endif
errno = so_error;
} else {
int so_error = 0;
socklen_t len = sizeof(so_error);
getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len);
if (so_error != 0) {
#ifdef WIN32
so_error = WSAError_to_errno(so_error);
#endif
errno = so_error;
}
}
}
socket_close(sfd);
Expand Down Expand Up @@ -1173,8 +1240,23 @@ int socket_connect(const char *addr, uint16_t port)
socklen_t len = sizeof(so_error);
getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len);
if (so_error == 0) {
errno = 0;
break;
}
#ifdef WIN32
so_error = WSAError_to_errno(so_error);
#endif
errno = so_error;
} else {
int so_error = 0;
socklen_t len = sizeof(so_error);
getsockopt(sfd, SOL_SOCKET, SO_ERROR, (void*)&so_error, &len);
if (so_error != 0) {
#ifdef WIN32
so_error = WSAError_to_errno(so_error);
#endif
errno = so_error;
}
}
}
socket_close(sfd);
Expand Down Expand Up @@ -1208,7 +1290,7 @@ int socket_check_fd(int fd, fd_mode fdm, unsigned int timeout)
if (fd < 0) {
if (verbose >= 2)
fprintf(stderr, "ERROR: invalid fd in check_fd %d\n", fd);
return -1;
return -EINVAL;
}

int timeout_ms;
Expand All @@ -1229,10 +1311,10 @@ int socket_check_fd(int fd, fd_mode fdm, unsigned int timeout)
default:
if (verbose >= 2)
fprintf(stderr, "%s: poll_wrapper failed\n", __func__);
return -1;
return -ECONNRESET;
}

return -1;
return -ECONNRESET;
}

int socket_accept(int fd, uint16_t port)
Expand Down Expand Up @@ -1286,13 +1368,16 @@ int socket_receive_timeout(int fd, void *data, size_t length, int flags, unsigne
}
// if we get here, there _is_ data available
result = recv(fd, data, length, flags);
if (res > 0 && result == 0) {
if (result == 0) {
// but this is an error condition
if (verbose >= 3)
fprintf(stderr, "%s: fd=%d recv returned 0\n", __func__, fd);
return -ECONNRESET;
}
if (result < 0) {
#ifdef WIN32
errno = WSAError_to_errno(WSAGetLastError());
#endif
return -errno;
}
return result;
Expand All @@ -1308,7 +1393,14 @@ int socket_send(int fd, void *data, size_t length)
#ifdef MSG_NOSIGNAL
flags |= MSG_NOSIGNAL;
#endif
return send(fd, data, length, flags);
int s = (int)send(fd, data, length, flags);
if (s < 0) {
#ifdef WIN32
errno = WSAError_to_errno(WSAGetLastError());
#endif
return -errno;
}
return s;
}

int socket_get_socket_port(int fd, uint16_t *port)
Expand Down

0 comments on commit fc10c88

Please sign in to comment.