af_unix: improve SCM_RIGHTS file descriptor retrieval

parent da63ea6b
......@@ -112,22 +112,18 @@ int lxc_abstract_unix_connect(const char *path)
return move_fd(fd);
}
int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds,
int lxc_abstract_unix_send_fds_iov(int fd, const int *sendfds, int num_sendfds,
struct iovec *iov, size_t iovlen)
{
__do_free char *cmsgbuf = NULL;
int ret;
struct msghdr msg;
struct msghdr msg = {};
struct cmsghdr *cmsg = NULL;
size_t cmsgbufsize = CMSG_SPACE(num_sendfds * sizeof(int));
memset(&msg, 0, sizeof(msg));
cmsgbuf = malloc(cmsgbufsize);
if (!cmsgbuf) {
errno = ENOMEM;
return -1;
}
if (!cmsgbuf)
return ret_errno(-ENOMEM);
msg.msg_control = cmsgbuf;
msg.msg_controllen = cmsgbufsize;
......@@ -151,13 +147,13 @@ int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds,
return ret;
}
/*
 * Convenience wrapper around lxc_abstract_unix_send_fds_iov() for a single
 * payload buffer.
 *
 * @fd, @sendfds and @num_sendfds are passed through unchanged. When @data is
 * NULL a one-byte dummy buffer is used as the payload instead of @data/@size.
 * The return value is whatever lxc_abstract_unix_send_fds_iov() returns.
 *
 * NOTE(review): this span is diff output; removed and added lines appear
 * next to each other without +/- markers.
 */
int lxc_abstract_unix_send_fds(int fd, int *sendfds, int num_sendfds,
int lxc_abstract_unix_send_fds(int fd, const int *sendfds, int num_sendfds,
void *data, size_t size)
{
char buf[1] = {0};
char buf[1] = {};
struct iovec iov = {
.iov_base = data ? data : buf,
.iov_len = data ? size : sizeof(buf),
.iov_base = data ? data : buf,
.iov_len = data ? size : sizeof(buf),
};
return lxc_abstract_unix_send_fds_iov(fd, sendfds, num_sendfds, &iov, 1);
}
......@@ -168,60 +164,168 @@ int lxc_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data,
return lxc_abstract_unix_send_fds(fd, sendfds, num_sendfds, data, size);
}
/*
 * Receive a message on @fd into @ret_iov and collect any SCM_RIGHTS file
 * descriptors into @ret_fds.
 *
 * The control buffer is sized for one ucred message plus
 * @ret_fds->fd_count_max descriptors. recvmsg() is called with
 * MSG_CMSG_CLOEXEC and retried on EINTR.
 *
 * Return: the recvmsg() byte count on success, 0 if recvmsg() returned 0, or
 * the value produced by ret_errno()/syserrno()/syserrno_set() on failure
 * (presumably negative - confirm against those macros).
 *
 * If no SOL_SOCKET/SCM_RIGHTS control message is found, @ret_fds is left
 * untouched.
 *
 * NOTE(review): this span is diff output; removed and added lines appear
 * interleaved without +/- markers, so it is not compilable as shown.
 */
static int lxc_abstract_unix_recv_fds_iov(int fd, int *recvfds, int num_recvfds,
struct iovec *iov, size_t iovlen)
static ssize_t lxc_abstract_unix_recv_fds_iov(int fd,
struct unix_fds *ret_fds,
struct iovec *ret_iov,
size_t size_ret_iov)
{
__do_free char *cmsgbuf = NULL;
int ret;
struct msghdr msg;
ssize_t ret;
struct msghdr msg = {};
struct cmsghdr *cmsg = NULL;
size_t cmsgbufsize = CMSG_SPACE(sizeof(struct ucred)) +
CMSG_SPACE(num_recvfds * sizeof(int));
CMSG_SPACE(ret_fds->fd_count_max * sizeof(int));
memset(&msg, 0, sizeof(msg));
cmsgbuf = malloc(cmsgbufsize);
cmsgbuf = zalloc(cmsgbufsize);
if (!cmsgbuf)
return ret_errno(ENOMEM);
msg.msg_control = cmsgbuf;
msg.msg_controllen = cmsgbufsize;
msg.msg_control = cmsgbuf;
msg.msg_controllen = cmsgbufsize;
msg.msg_iov = iov;
msg.msg_iovlen = iovlen;
msg.msg_iov = ret_iov;
msg.msg_iovlen = size_ret_iov;
do {
ret = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC);
} while (ret < 0 && errno == EINTR);
if (ret < 0 || ret == 0)
return ret;
again:
ret = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC);
if (ret < 0) {
if (errno == EINTR)
goto again;
/*
* If SO_PASSCRED is set we will always get a ucred message.
*/
for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_type != SCM_RIGHTS)
continue;
memset(recvfds, -1, num_recvfds * sizeof(int));
if (cmsg &&
cmsg->cmsg_len == CMSG_LEN(num_recvfds * sizeof(int)) &&
cmsg->cmsg_level == SOL_SOCKET)
memcpy(recvfds, CMSG_DATA(cmsg), num_recvfds * sizeof(int));
break;
return syserrno(-errno, "Failed to receive response");
}
if (ret == 0)
return 0;
/* If SO_PASSCRED is set we will always get a ucred message. */
for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
__u32 idx;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wcast-align"
int *fds_raw = (int *)CMSG_DATA(cmsg);
#pragma GCC diagnostic pop
/* Payload length of this control message expressed in ints, i.e. the number of descriptors it carries. */
__u32 num_raw = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
/*
* We received an insane amount of file descriptors
* which exceeds the kernel limit we know about so
* close them and return an error.
*/
if (num_raw > KERNEL_SCM_MAX_FD) {
for (idx = 0; idx < num_raw; idx++)
close(fds_raw[idx]);
return syserrno_set(-EFBIG, "Received excessive number of file descriptors");
}
if (ret_fds->fd_count_max > num_raw) {
/*
* Make sure any excess entries in the fd array
* are set to -EBADF so our cleanup functions
* can safely be called.
*/
for (idx = num_raw; idx < ret_fds->fd_count_max; idx++)
ret_fds->fd[idx] = -EBADF;
WARN("Received fewer file descriptors than we expected %u != %u", ret_fds->fd_count_max, num_raw);
} else if (ret_fds->fd_count_max < num_raw) {
/* Make sure we close any excess fds we received. */
for (idx = ret_fds->fd_count_max; idx < num_raw; idx++)
close(fds_raw[idx]);
WARN("Received more file descriptors than we expected %u != %u", ret_fds->fd_count_max, num_raw);
/* Cap the number of received file descriptors. */
num_raw = ret_fds->fd_count_max;
}
/* Only the first SCM_RIGHTS message is consumed; the loop stops after copying it. */
memcpy(ret_fds->fd, CMSG_DATA(cmsg), num_raw * sizeof(int));
ret_fds->fd_count_ret = num_raw;
break;
}
}
return ret;
}
/*
 * Receive a single payload buffer plus file descriptors on @fd.
 *
 * Wraps lxc_abstract_unix_recv_fds_iov() with a one-element iovec. When
 * @ret_data is NULL a one-byte dummy buffer receives the payload instead of
 * @ret_data/@size_ret_data. Descriptors are delivered through @ret_fds.
 *
 * Return: the value returned by lxc_abstract_unix_recv_fds_iov().
 *
 * NOTE(review): the first two lines below are the removed signature from the
 * diff; the added signature follows.
 */
int lxc_abstract_unix_recv_fds(int fd, int *recvfds, int num_recvfds,
void *data, size_t size)
ssize_t lxc_abstract_unix_recv_fds(int fd, struct unix_fds *ret_fds,
void *ret_data, size_t size_ret_data)
{
char buf[1] = {0};
char buf[1] = {};
struct iovec iov = {
.iov_base = ret_data ? ret_data : buf,
.iov_len = ret_data ? size_ret_data : sizeof(buf),
};
ssize_t ret;
ret = lxc_abstract_unix_recv_fds_iov(fd, ret_fds, &iov, 1);
/* NOTE(review): both paths return ret, so this check is redundant. */
if (ret < 0)
return ret;
return ret;
}
/*
 * Receive exactly one file descriptor (plus an optional payload) on @fd.
 *
 * A temporary struct unix_fds with fd_count_max = 1 is handed to
 * lxc_abstract_unix_recv_fds_iov(); it is tagged with the put_unix_fds
 * cleaner. When @ret_data is NULL a one-byte dummy buffer receives the
 * payload.
 *
 * Return: the helper's value if it is negative, ret_errno(ENODATA) if the
 * helper returned 0, otherwise the helper's positive return value. In the
 * last case *@ret_fd is set to the received descriptor, or to -EBADF when the
 * number of descriptors stored does not equal one.
 *
 * NOTE(review): the lines referencing data/size/recvfds/num_recvfds are
 * removed lines from the diff, shown without +/- markers.
 */
ssize_t lxc_abstract_unix_recv_one_fd(int fd, int *ret_fd, void *ret_data,
size_t size_ret_data)
{
call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
char buf[1] = {};
struct iovec iov = {
.iov_base = data ? data : buf,
.iov_len = data ? size : sizeof(buf),
.iov_base = ret_data ? ret_data : buf,
.iov_len = ret_data ? size_ret_data : sizeof(buf),
};
ssize_t ret;
fds = &(struct unix_fds){
.fd_count_max = 1,
};
return lxc_abstract_unix_recv_fds_iov(fd, recvfds, num_recvfds, &iov, 1);
ret = lxc_abstract_unix_recv_fds_iov(fd, fds, &iov, 1);
if (ret < 0)
return ret;
if (ret == 0)
return ret_errno(ENODATA);
/* Anything other than exactly one stored descriptor is reported as "no fd". */
if (fds->fd_count_ret != fds->fd_count_max)
*ret_fd = -EBADF;
else
*ret_fd = move_fd(fds->fd[0]);
return ret;
}
/*
 * Receive exactly two file descriptors on @fd into @ret_fd[0] and @ret_fd[1].
 *
 * The payload is a single dummy byte. A temporary struct unix_fds with
 * fd_count_max = 2 collects the descriptors and is tagged with the
 * put_unix_fds cleaner.
 *
 * Return: the negative value from lxc_abstract_unix_recv_fds_iov() on
 * failure, ret_errno(ENODATA) if it returned 0, and 0 otherwise. In the
 * last case both entries of @ret_fd hold the received descriptors, or both
 * are -EBADF when the number of descriptors stored was not two.
 */
ssize_t lxc_abstract_unix_recv_two_fds(int fd, int *ret_fd)
{
	call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
	char dummy[1] = {};
	struct iovec iov = {
		.iov_base	= dummy,
		.iov_len	= sizeof(dummy),
	};
	ssize_t nbytes;

	fds = &(struct unix_fds){
		.fd_count_max = 2,
	};

	nbytes = lxc_abstract_unix_recv_fds_iov(fd, fds, &iov, 1);
	if (nbytes < 0)
		return nbytes;

	if (nbytes == 0)
		return ret_errno(ENODATA);

	if (fds->fd_count_ret == fds->fd_count_max) {
		/* Full set received: hand both descriptors to the caller. */
		ret_fd[0] = move_fd(fds->fd[0]);
		ret_fd[1] = move_fd(fds->fd[1]);
		return 0;
	}

	/* Short or missing set: report "no fd" for both slots. */
	ret_fd[0] = -EBADF;
	ret_fd[1] = -EBADF;
	return 0;
}
int lxc_abstract_unix_send_credential(int fd, void *data, size_t size)
......
......@@ -5,9 +5,24 @@
#include <stdio.h>
#include <sys/socket.h>
#include <stddef.h>
#include <sys/un.h>
#include "compiler.h"
#include "macro.h"
#include "memory_utils.h"
/*
* Technically 253 is the kernel limit but we want the size of the struct to
* be a multiple of 8.
*/
#define KERNEL_SCM_MAX_FD 252
/*
 * Container used by the SCM_RIGHTS receive helpers to hand file descriptors
 * back to the caller.
 */
struct unix_fds {
/* Number of descriptors the caller expects; set by the caller before receiving. */
__u32 fd_count_max;
/* Number of descriptors the receive helper actually stored in @fd. */
__u32 fd_count_ret;
/* Received descriptors; the receive helper sets unused slots below @fd_count_max to -EBADF on a short receive. */
__s32 fd[KERNEL_SCM_MAX_FD];
} __attribute__((aligned(8)));
/* does not enforce \0-termination */
__hidden extern int lxc_abstract_unix_open(const char *path, int type, int flags);
......@@ -15,14 +30,29 @@ __hidden extern void lxc_abstract_unix_close(int fd);
/* does not enforce \0-termination */
__hidden extern int lxc_abstract_unix_connect(const char *path);
__hidden extern int lxc_abstract_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data,
size_t size) __access_r(2, 3) __access_r(4, 5);
__hidden extern int lxc_abstract_unix_send_fds(int fd, const int *sendfds,
int num_sendfds, void *data,
size_t size) __access_r(2, 3)
__access_r(4, 5);
__hidden extern int lxc_abstract_unix_send_fds_iov(int fd, const int *sendfds,
int num_sendfds,
struct iovec *iov,
size_t iovlen)
__access_r(2, 3);
__hidden extern ssize_t lxc_abstract_unix_recv_fds(int fd,
struct unix_fds *ret_fds,
void *ret_data,
size_t size_ret_data)
__access_r(3, 4);
__hidden extern int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds,
struct iovec *iov, size_t iovlen) __access_r(2, 3);
__hidden extern ssize_t lxc_abstract_unix_recv_one_fd(int fd, int *ret_fd,
void *ret_data,
size_t size_ret_data)
__access_r(3, 4);
__hidden extern int lxc_abstract_unix_recv_fds(int fd, int *recvfds, int num_recvfds, void *data,
size_t size) __access_r(2, 3) __access_r(4, 5);
__hidden extern ssize_t lxc_abstract_unix_recv_two_fds(int fd, int *ret_fd);
__hidden extern int lxc_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data, size_t size);
......@@ -37,4 +67,13 @@ __hidden extern int lxc_unix_connect(struct sockaddr_un *addr);
__hidden extern int lxc_unix_connect_type(struct sockaddr_un *addr, int type);
__hidden extern int lxc_socket_set_timeout(int fd, int rcv_timeout, int snd_timeout);
/*
 * Release every descriptor recorded in @fds.
 *
 * Does nothing when @fds is NULL or an error pointer. Otherwise each of the
 * first @fds->fd_count_ret entries is passed to close_prot_errno_disarm()
 * (presumably closing it while preserving errno - confirm against the macro).
 */
static inline void put_unix_fds(struct unix_fds *fds)
{
	if (IS_ERR_OR_NULL(fds))
		return;

	for (size_t i = 0; i < fds->fd_count_ret; i++)
		close_prot_errno_disarm(fds->fd[i]);
}
define_cleanup_function(struct unix_fds *, put_unix_fds);
#endif /* __LXC_AF_UNIX_H */
......@@ -164,7 +164,7 @@ static inline bool sync_wake_fd(int fd, int fd_send)
/*
 * Wait for a single file descriptor on @fd and store it in *@fd_recv.
 * Returns true only when the receive helper reports a positive byte count.
 *
 * NOTE(review): the first return line is the removed line from the diff; the
 * second is its replacement.
 */
static inline bool sync_wait_fd(int fd, int *fd_recv)
{
return lxc_abstract_unix_recv_fds(fd, fd_recv, 1, NULL, 0) > 0;
return lxc_abstract_unix_recv_one_fd(fd, fd_recv, NULL, 0) > 0;
}
static bool attach_lsm(lxc_attach_options_t *options)
......
......@@ -2183,8 +2183,8 @@ static int cgroup_attach_move_into_leaf(const struct lxc_conf *conf,
size_t pidstr_len;
ssize_t ret;
ret = lxc_abstract_unix_recv_fds(sk, target_fds, 2, NULL, 0);
if (ret <= 0)
ret = lxc_abstract_unix_recv_two_fds(sk, target_fds);
if (ret < 0)
return log_error_errno(-1, errno, "Failed to receive target cgroup fd");
target_fd0 = target_fds[0];
target_fd1 = target_fds[1];
......
......@@ -115,11 +115,15 @@ static const char *lxc_cmd_str(lxc_cmd_t cmd)
*/
static int lxc_cmd_rsp_recv(int sock, struct lxc_cmd_rr *cmd)
{
__do_close int fd_rsp = -EBADF;
call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
int ret;
struct lxc_cmd_rsp *rsp = &cmd->rsp;
ret = lxc_abstract_unix_recv_fds(sock, &fd_rsp, 1, rsp, sizeof(*rsp));
fds = &(struct unix_fds){
.fd_count_max = 1,
};
ret = lxc_abstract_unix_recv_fds(sock, fds, rsp, sizeof(*rsp));
if (ret < 0)
return log_warn_errno(-1,
errno, "Failed to receive response for command \"%s\"",
......@@ -141,30 +145,29 @@ static int lxc_cmd_rsp_recv(int sock, struct lxc_cmd_rr *cmd)
ENOMEM, "Failed to receive response for command \"%s\"",
lxc_cmd_str(cmd->req.cmd));
rspdata->ptxfd = move_fd(fd_rsp);
rspdata->ptxfd = move_fd(fds->fd[0]);
rspdata->ttynum = PTR_TO_INT(rsp->data);
rsp->data = rspdata;
}
if (cmd->req.cmd == LXC_CMD_GET_CGROUP2_FD ||
cmd->req.cmd == LXC_CMD_GET_LIMITING_CGROUP2_FD)
{
int cgroup2_fd = move_fd(fd_rsp);
cmd->req.cmd == LXC_CMD_GET_LIMITING_CGROUP2_FD) {
int cgroup2_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(cgroup2_fd);
}
if (cmd->req.cmd == LXC_CMD_GET_INIT_PIDFD) {
int init_pidfd = move_fd(fd_rsp);
int init_pidfd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(init_pidfd);
}
if (cmd->req.cmd == LXC_CMD_GET_DEVPTS_FD) {
int devpts_fd = move_fd(fd_rsp);
int devpts_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(devpts_fd);
}
if (cmd->req.cmd == LXC_CMD_GET_SECCOMP_NOTIFY_FD) {
int seccomp_notify_fd = move_fd(fd_rsp);
int seccomp_notify_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(seccomp_notify_fd);
}
......@@ -1371,7 +1374,7 @@ static int lxc_cmd_seccomp_notify_add_listener_callback(int fd,
int ret;
__do_close int recv_fd = -EBADF;
ret = lxc_abstract_unix_recv_fds(fd, &recv_fd, 1, NULL, 0);
ret = lxc_abstract_unix_recv_one_fd(fd, &recv_fd, NULL, 0);
if (ret <= 0) {
rsp.ret = -errno;
goto out;
......
......@@ -1509,8 +1509,10 @@ int lxc_setup_devpts_parent(struct lxc_handler *handler)
if (handler->conf->pty_max <= 0)
return 0;
ret = lxc_abstract_unix_recv_fds(handler->data_sock[1], &handler->conf->devpts_fd, 1,
&handler->conf->devpts_fd, sizeof(handler->conf->devpts_fd));
ret = lxc_abstract_unix_recv_one_fd(handler->data_sock[1],
&handler->conf->devpts_fd,
&handler->conf->devpts_fd,
sizeof(handler->conf->devpts_fd));
if (ret < 0)
return log_error_errno(-1, errno, "Failed to receive devpts fd from child");
......
......@@ -1637,9 +1637,9 @@ int lxc_seccomp_recv_notifier_fd(struct lxc_seccomp *seccomp, int socket_fd)
if (seccomp->notifier.wants_supervision) {
int ret;
ret = lxc_abstract_unix_recv_fds(socket_fd,
&seccomp->notifier.notify_fd,
1, NULL, 0);
ret = lxc_abstract_unix_recv_one_fd(socket_fd,
&seccomp->notifier.notify_fd,
NULL, 0);
if (ret < 0)
return -1;
}
......
......@@ -1041,7 +1041,7 @@ static int do_start(void *data)
lxc_sync_fini_parent(handler);
if (lxc_abstract_unix_recv_fds(data_sock1, &status_fd, 1, NULL, 0) < 0) {
if (lxc_abstract_unix_recv_one_fd(data_sock1, &status_fd, NULL, 0) < 0) {
ERROR("Failed to receive status file descriptor to child process");
goto out_warn_father;
}
......@@ -1460,7 +1460,7 @@ static int lxc_recv_ttys_from_child(struct lxc_handler *handler)
for (i = 0; i < conf->ttys.max; i++) {
int ttyfds[2];
ret = lxc_abstract_unix_recv_fds(sock, ttyfds, 2, NULL, 0);
ret = lxc_abstract_unix_recv_two_fds(sock, ttyfds);
if (ret < 0)
break;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment