af_unix: improve SCM_RIGHTS file descriptor retrieval

parent da63ea6b
...@@ -112,22 +112,18 @@ int lxc_abstract_unix_connect(const char *path) ...@@ -112,22 +112,18 @@ int lxc_abstract_unix_connect(const char *path)
return move_fd(fd); return move_fd(fd);
} }
int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds, int lxc_abstract_unix_send_fds_iov(int fd, const int *sendfds, int num_sendfds,
struct iovec *iov, size_t iovlen) struct iovec *iov, size_t iovlen)
{ {
__do_free char *cmsgbuf = NULL; __do_free char *cmsgbuf = NULL;
int ret; int ret;
struct msghdr msg; struct msghdr msg = {};
struct cmsghdr *cmsg = NULL; struct cmsghdr *cmsg = NULL;
size_t cmsgbufsize = CMSG_SPACE(num_sendfds * sizeof(int)); size_t cmsgbufsize = CMSG_SPACE(num_sendfds * sizeof(int));
memset(&msg, 0, sizeof(msg));
cmsgbuf = malloc(cmsgbufsize); cmsgbuf = malloc(cmsgbufsize);
if (!cmsgbuf) { if (!cmsgbuf)
errno = ENOMEM; return ret_errno(-ENOMEM);
return -1;
}
msg.msg_control = cmsgbuf; msg.msg_control = cmsgbuf;
msg.msg_controllen = cmsgbufsize; msg.msg_controllen = cmsgbufsize;
...@@ -151,13 +147,13 @@ int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds, ...@@ -151,13 +147,13 @@ int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds,
return ret; return ret;
} }
int lxc_abstract_unix_send_fds(int fd, int *sendfds, int num_sendfds, int lxc_abstract_unix_send_fds(int fd, const int *sendfds, int num_sendfds,
void *data, size_t size) void *data, size_t size)
{ {
char buf[1] = {0}; char buf[1] = {};
struct iovec iov = { struct iovec iov = {
.iov_base = data ? data : buf, .iov_base = data ? data : buf,
.iov_len = data ? size : sizeof(buf), .iov_len = data ? size : sizeof(buf),
}; };
return lxc_abstract_unix_send_fds_iov(fd, sendfds, num_sendfds, &iov, 1); return lxc_abstract_unix_send_fds_iov(fd, sendfds, num_sendfds, &iov, 1);
} }
...@@ -168,60 +164,168 @@ int lxc_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data, ...@@ -168,60 +164,168 @@ int lxc_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data,
return lxc_abstract_unix_send_fds(fd, sendfds, num_sendfds, data, size); return lxc_abstract_unix_send_fds(fd, sendfds, num_sendfds, data, size);
} }
static int lxc_abstract_unix_recv_fds_iov(int fd, int *recvfds, int num_recvfds, static ssize_t lxc_abstract_unix_recv_fds_iov(int fd,
struct iovec *iov, size_t iovlen) struct unix_fds *ret_fds,
struct iovec *ret_iov,
size_t size_ret_iov)
{ {
__do_free char *cmsgbuf = NULL; __do_free char *cmsgbuf = NULL;
int ret; ssize_t ret;
struct msghdr msg; struct msghdr msg = {};
struct cmsghdr *cmsg = NULL;
size_t cmsgbufsize = CMSG_SPACE(sizeof(struct ucred)) + size_t cmsgbufsize = CMSG_SPACE(sizeof(struct ucred)) +
CMSG_SPACE(num_recvfds * sizeof(int)); CMSG_SPACE(ret_fds->fd_count_max * sizeof(int));
memset(&msg, 0, sizeof(msg)); cmsgbuf = zalloc(cmsgbufsize);
cmsgbuf = malloc(cmsgbufsize);
if (!cmsgbuf) if (!cmsgbuf)
return ret_errno(ENOMEM); return ret_errno(ENOMEM);
msg.msg_control = cmsgbuf; msg.msg_control = cmsgbuf;
msg.msg_controllen = cmsgbufsize; msg.msg_controllen = cmsgbufsize;
msg.msg_iov = iov; msg.msg_iov = ret_iov;
msg.msg_iovlen = iovlen; msg.msg_iovlen = size_ret_iov;
do { again:
ret = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC); ret = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC);
} while (ret < 0 && errno == EINTR); if (ret < 0) {
if (ret < 0 || ret == 0) if (errno == EINTR)
return ret; goto again;
/* return syserrno(-errno, "Failed to receive response");
* If SO_PASSCRED is set we will always get a ucred message. }
*/ if (ret == 0)
for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) { return 0;
if (cmsg->cmsg_type != SCM_RIGHTS)
continue; /* If SO_PASSCRED is set we will always get a ucred message. */
for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
memset(recvfds, -1, num_recvfds * sizeof(int)); if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
if (cmsg && __u32 idx;
cmsg->cmsg_len == CMSG_LEN(num_recvfds * sizeof(int)) && #pragma GCC diagnostic push
cmsg->cmsg_level == SOL_SOCKET) #pragma GCC diagnostic ignored "-Wcast-align"
memcpy(recvfds, CMSG_DATA(cmsg), num_recvfds * sizeof(int)); int *fds_raw = (int *)CMSG_DATA(cmsg);
break; #pragma GCC diagnostic pop
__u32 num_raw = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
/*
* We received an insane amount of file descriptors
* which exceeds the kernel limit we know about so
* close them and return an error.
*/
if (num_raw > KERNEL_SCM_MAX_FD) {
for (idx = 0; idx < num_raw; idx++)
close(fds_raw[idx]);
return syserrno_set(-EFBIG, "Received excessive number of file descriptors");
}
if (ret_fds->fd_count_max > num_raw) {
/*
* Make sure any excess entries in the fd array
* are set to -EBADF so our cleanup functions
* can safely be called.
*/
for (idx = num_raw; idx < ret_fds->fd_count_max; idx++)
ret_fds->fd[idx] = -EBADF;
WARN("Received fewer file descriptors than we expected %u != %u", ret_fds->fd_count_max, num_raw);
} else if (ret_fds->fd_count_max < num_raw) {
/* Make sure we close any excess fds we received. */
for (idx = ret_fds->fd_count_max; idx < num_raw; idx++)
close(fds_raw[idx]);
WARN("Received more file descriptors than we expected %u != %u", ret_fds->fd_count_max, num_raw);
/* Cap the number of received file descriptors. */
num_raw = ret_fds->fd_count_max;
}
memcpy(ret_fds->fd, CMSG_DATA(cmsg), num_raw * sizeof(int));
ret_fds->fd_count_ret = num_raw;
break;
}
} }
return ret; return ret;
} }
int lxc_abstract_unix_recv_fds(int fd, int *recvfds, int num_recvfds, ssize_t lxc_abstract_unix_recv_fds(int fd, struct unix_fds *ret_fds,
void *data, size_t size) void *ret_data, size_t size_ret_data)
{ {
char buf[1] = {0}; char buf[1] = {};
struct iovec iov = {
.iov_base = ret_data ? ret_data : buf,
.iov_len = ret_data ? size_ret_data : sizeof(buf),
};
ssize_t ret;
ret = lxc_abstract_unix_recv_fds_iov(fd, ret_fds, &iov, 1);
if (ret < 0)
return ret;
return ret;
}
ssize_t lxc_abstract_unix_recv_one_fd(int fd, int *ret_fd, void *ret_data,
size_t size_ret_data)
{
call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
char buf[1] = {};
struct iovec iov = { struct iovec iov = {
.iov_base = data ? data : buf, .iov_base = ret_data ? ret_data : buf,
.iov_len = data ? size : sizeof(buf), .iov_len = ret_data ? size_ret_data : sizeof(buf),
};
ssize_t ret;
fds = &(struct unix_fds){
.fd_count_max = 1,
}; };
return lxc_abstract_unix_recv_fds_iov(fd, recvfds, num_recvfds, &iov, 1);
ret = lxc_abstract_unix_recv_fds_iov(fd, fds, &iov, 1);
if (ret < 0)
return ret;
if (ret == 0)
return ret_errno(ENODATA);
if (fds->fd_count_ret != fds->fd_count_max)
*ret_fd = -EBADF;
else
*ret_fd = move_fd(fds->fd[0]);
return ret;
}
ssize_t lxc_abstract_unix_recv_two_fds(int fd, int *ret_fd)
{
call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
char buf[1] = {};
struct iovec iov = {
.iov_base = buf,
.iov_len = sizeof(buf),
};
ssize_t ret;
fds = &(struct unix_fds){
.fd_count_max = 2,
};
ret = lxc_abstract_unix_recv_fds_iov(fd, fds, &iov, 1);
if (ret < 0)
return ret;
if (ret == 0)
return ret_errno(ENODATA);
if (fds->fd_count_ret != fds->fd_count_max) {
ret_fd[0] = -EBADF;
ret_fd[1] = -EBADF;
} else {
ret_fd[0] = move_fd(fds->fd[0]);
ret_fd[1] = move_fd(fds->fd[1]);
}
return 0;
} }
int lxc_abstract_unix_send_credential(int fd, void *data, size_t size) int lxc_abstract_unix_send_credential(int fd, void *data, size_t size)
......
...@@ -5,9 +5,24 @@ ...@@ -5,9 +5,24 @@
#include <stdio.h> #include <stdio.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <stddef.h>
#include <sys/un.h> #include <sys/un.h>
#include "compiler.h" #include "compiler.h"
#include "macro.h"
#include "memory_utils.h"
/*
* Technically 253 is the kernel limit but we want to the struct to be a
* multiple of 8.
*/
#define KERNEL_SCM_MAX_FD 252
struct unix_fds {
__u32 fd_count_max;
__u32 fd_count_ret;
__s32 fd[KERNEL_SCM_MAX_FD];
} __attribute__((aligned(8)));
/* does not enforce \0-termination */ /* does not enforce \0-termination */
__hidden extern int lxc_abstract_unix_open(const char *path, int type, int flags); __hidden extern int lxc_abstract_unix_open(const char *path, int type, int flags);
...@@ -15,14 +30,29 @@ __hidden extern void lxc_abstract_unix_close(int fd); ...@@ -15,14 +30,29 @@ __hidden extern void lxc_abstract_unix_close(int fd);
/* does not enforce \0-termination */ /* does not enforce \0-termination */
__hidden extern int lxc_abstract_unix_connect(const char *path); __hidden extern int lxc_abstract_unix_connect(const char *path);
__hidden extern int lxc_abstract_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data, __hidden extern int lxc_abstract_unix_send_fds(int fd, const int *sendfds,
size_t size) __access_r(2, 3) __access_r(4, 5); int num_sendfds, void *data,
size_t size) __access_r(2, 3)
__access_r(4, 5);
__hidden extern int lxc_abstract_unix_send_fds_iov(int fd, const int *sendfds,
int num_sendfds,
struct iovec *iov,
size_t iovlen)
__access_r(2, 3);
__hidden extern ssize_t lxc_abstract_unix_recv_fds(int fd,
struct unix_fds *ret_fds,
void *ret_data,
size_t size_ret_data)
__access_r(3, 4);
__hidden extern int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds, __hidden extern ssize_t lxc_abstract_unix_recv_one_fd(int fd, int *ret_fd,
struct iovec *iov, size_t iovlen) __access_r(2, 3); void *ret_data,
size_t size_ret_data)
__access_r(3, 4);
__hidden extern int lxc_abstract_unix_recv_fds(int fd, int *recvfds, int num_recvfds, void *data, __hidden extern ssize_t lxc_abstract_unix_recv_two_fds(int fd, int *ret_fd);
size_t size) __access_r(2, 3) __access_r(4, 5);
__hidden extern int lxc_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data, size_t size); __hidden extern int lxc_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data, size_t size);
...@@ -37,4 +67,13 @@ __hidden extern int lxc_unix_connect(struct sockaddr_un *addr); ...@@ -37,4 +67,13 @@ __hidden extern int lxc_unix_connect(struct sockaddr_un *addr);
__hidden extern int lxc_unix_connect_type(struct sockaddr_un *addr, int type); __hidden extern int lxc_unix_connect_type(struct sockaddr_un *addr, int type);
__hidden extern int lxc_socket_set_timeout(int fd, int rcv_timeout, int snd_timeout); __hidden extern int lxc_socket_set_timeout(int fd, int rcv_timeout, int snd_timeout);
static inline void put_unix_fds(struct unix_fds *fds)
{
if (!IS_ERR_OR_NULL(fds)) {
for (size_t idx = 0; idx < fds->fd_count_ret; idx++)
close_prot_errno_disarm(fds->fd[idx]);
}
}
define_cleanup_function(struct unix_fds *, put_unix_fds);
#endif /* __LXC_AF_UNIX_H */ #endif /* __LXC_AF_UNIX_H */
...@@ -164,7 +164,7 @@ static inline bool sync_wake_fd(int fd, int fd_send) ...@@ -164,7 +164,7 @@ static inline bool sync_wake_fd(int fd, int fd_send)
static inline bool sync_wait_fd(int fd, int *fd_recv) static inline bool sync_wait_fd(int fd, int *fd_recv)
{ {
return lxc_abstract_unix_recv_fds(fd, fd_recv, 1, NULL, 0) > 0; return lxc_abstract_unix_recv_one_fd(fd, fd_recv, NULL, 0) > 0;
} }
static bool attach_lsm(lxc_attach_options_t *options) static bool attach_lsm(lxc_attach_options_t *options)
......
...@@ -2183,8 +2183,8 @@ static int cgroup_attach_move_into_leaf(const struct lxc_conf *conf, ...@@ -2183,8 +2183,8 @@ static int cgroup_attach_move_into_leaf(const struct lxc_conf *conf,
size_t pidstr_len; size_t pidstr_len;
ssize_t ret; ssize_t ret;
ret = lxc_abstract_unix_recv_fds(sk, target_fds, 2, NULL, 0); ret = lxc_abstract_unix_recv_two_fds(sk, target_fds);
if (ret <= 0) if (ret < 0)
return log_error_errno(-1, errno, "Failed to receive target cgroup fd"); return log_error_errno(-1, errno, "Failed to receive target cgroup fd");
target_fd0 = target_fds[0]; target_fd0 = target_fds[0];
target_fd1 = target_fds[1]; target_fd1 = target_fds[1];
......
...@@ -115,11 +115,15 @@ static const char *lxc_cmd_str(lxc_cmd_t cmd) ...@@ -115,11 +115,15 @@ static const char *lxc_cmd_str(lxc_cmd_t cmd)
*/ */
static int lxc_cmd_rsp_recv(int sock, struct lxc_cmd_rr *cmd) static int lxc_cmd_rsp_recv(int sock, struct lxc_cmd_rr *cmd)
{ {
__do_close int fd_rsp = -EBADF; call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
int ret; int ret;
struct lxc_cmd_rsp *rsp = &cmd->rsp; struct lxc_cmd_rsp *rsp = &cmd->rsp;
ret = lxc_abstract_unix_recv_fds(sock, &fd_rsp, 1, rsp, sizeof(*rsp)); fds = &(struct unix_fds){
.fd_count_max = 1,
};
ret = lxc_abstract_unix_recv_fds(sock, fds, rsp, sizeof(*rsp));
if (ret < 0) if (ret < 0)
return log_warn_errno(-1, return log_warn_errno(-1,
errno, "Failed to receive response for command \"%s\"", errno, "Failed to receive response for command \"%s\"",
...@@ -141,30 +145,29 @@ static int lxc_cmd_rsp_recv(int sock, struct lxc_cmd_rr *cmd) ...@@ -141,30 +145,29 @@ static int lxc_cmd_rsp_recv(int sock, struct lxc_cmd_rr *cmd)
ENOMEM, "Failed to receive response for command \"%s\"", ENOMEM, "Failed to receive response for command \"%s\"",
lxc_cmd_str(cmd->req.cmd)); lxc_cmd_str(cmd->req.cmd));
rspdata->ptxfd = move_fd(fd_rsp); rspdata->ptxfd = move_fd(fds->fd[0]);
rspdata->ttynum = PTR_TO_INT(rsp->data); rspdata->ttynum = PTR_TO_INT(rsp->data);
rsp->data = rspdata; rsp->data = rspdata;
} }
if (cmd->req.cmd == LXC_CMD_GET_CGROUP2_FD || if (cmd->req.cmd == LXC_CMD_GET_CGROUP2_FD ||
cmd->req.cmd == LXC_CMD_GET_LIMITING_CGROUP2_FD) cmd->req.cmd == LXC_CMD_GET_LIMITING_CGROUP2_FD) {
{ int cgroup2_fd = move_fd(fds->fd[0]);
int cgroup2_fd = move_fd(fd_rsp);
rsp->data = INT_TO_PTR(cgroup2_fd); rsp->data = INT_TO_PTR(cgroup2_fd);
} }
if (cmd->req.cmd == LXC_CMD_GET_INIT_PIDFD) { if (cmd->req.cmd == LXC_CMD_GET_INIT_PIDFD) {
int init_pidfd = move_fd(fd_rsp); int init_pidfd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(init_pidfd); rsp->data = INT_TO_PTR(init_pidfd);
} }
if (cmd->req.cmd == LXC_CMD_GET_DEVPTS_FD) { if (cmd->req.cmd == LXC_CMD_GET_DEVPTS_FD) {
int devpts_fd = move_fd(fd_rsp); int devpts_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(devpts_fd); rsp->data = INT_TO_PTR(devpts_fd);
} }
if (cmd->req.cmd == LXC_CMD_GET_SECCOMP_NOTIFY_FD) { if (cmd->req.cmd == LXC_CMD_GET_SECCOMP_NOTIFY_FD) {
int seccomp_notify_fd = move_fd(fd_rsp); int seccomp_notify_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(seccomp_notify_fd); rsp->data = INT_TO_PTR(seccomp_notify_fd);
} }
...@@ -1371,7 +1374,7 @@ static int lxc_cmd_seccomp_notify_add_listener_callback(int fd, ...@@ -1371,7 +1374,7 @@ static int lxc_cmd_seccomp_notify_add_listener_callback(int fd,
int ret; int ret;
__do_close int recv_fd = -EBADF; __do_close int recv_fd = -EBADF;
ret = lxc_abstract_unix_recv_fds(fd, &recv_fd, 1, NULL, 0); ret = lxc_abstract_unix_recv_one_fd(fd, &recv_fd, NULL, 0);
if (ret <= 0) { if (ret <= 0) {
rsp.ret = -errno; rsp.ret = -errno;
goto out; goto out;
......
...@@ -1509,8 +1509,10 @@ int lxc_setup_devpts_parent(struct lxc_handler *handler) ...@@ -1509,8 +1509,10 @@ int lxc_setup_devpts_parent(struct lxc_handler *handler)
if (handler->conf->pty_max <= 0) if (handler->conf->pty_max <= 0)
return 0; return 0;
ret = lxc_abstract_unix_recv_fds(handler->data_sock[1], &handler->conf->devpts_fd, 1, ret = lxc_abstract_unix_recv_one_fd(handler->data_sock[1],
&handler->conf->devpts_fd, sizeof(handler->conf->devpts_fd)); &handler->conf->devpts_fd,
&handler->conf->devpts_fd,
sizeof(handler->conf->devpts_fd));
if (ret < 0) if (ret < 0)
return log_error_errno(-1, errno, "Failed to receive devpts fd from child"); return log_error_errno(-1, errno, "Failed to receive devpts fd from child");
......
...@@ -1637,9 +1637,9 @@ int lxc_seccomp_recv_notifier_fd(struct lxc_seccomp *seccomp, int socket_fd) ...@@ -1637,9 +1637,9 @@ int lxc_seccomp_recv_notifier_fd(struct lxc_seccomp *seccomp, int socket_fd)
if (seccomp->notifier.wants_supervision) { if (seccomp->notifier.wants_supervision) {
int ret; int ret;
ret = lxc_abstract_unix_recv_fds(socket_fd, ret = lxc_abstract_unix_recv_one_fd(socket_fd,
&seccomp->notifier.notify_fd, &seccomp->notifier.notify_fd,
1, NULL, 0); NULL, 0);
if (ret < 0) if (ret < 0)
return -1; return -1;
} }
......
...@@ -1041,7 +1041,7 @@ static int do_start(void *data) ...@@ -1041,7 +1041,7 @@ static int do_start(void *data)
lxc_sync_fini_parent(handler); lxc_sync_fini_parent(handler);
if (lxc_abstract_unix_recv_fds(data_sock1, &status_fd, 1, NULL, 0) < 0) { if (lxc_abstract_unix_recv_one_fd(data_sock1, &status_fd, NULL, 0) < 0) {
ERROR("Failed to receive status file descriptor to child process"); ERROR("Failed to receive status file descriptor to child process");
goto out_warn_father; goto out_warn_father;
} }
...@@ -1460,7 +1460,7 @@ static int lxc_recv_ttys_from_child(struct lxc_handler *handler) ...@@ -1460,7 +1460,7 @@ static int lxc_recv_ttys_from_child(struct lxc_handler *handler)
for (i = 0; i < conf->ttys.max; i++) { for (i = 0; i < conf->ttys.max; i++) {
int ttyfds[2]; int ttyfds[2];
ret = lxc_abstract_unix_recv_fds(sock, ttyfds, 2, NULL, 0); ret = lxc_abstract_unix_recv_two_fds(sock, ttyfds);
if (ret < 0) if (ret < 0)
break; break;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment