From aca1ad90ac68780598f7bcce7b475670f2a48cb0 Mon Sep 17 00:00:00 2001 From: Pierre-Marc Fournier Date: Tue, 24 Feb 2009 19:43:56 -0500 Subject: [PATCH] ust: change communication socket for STREAM --- libtracectl/localerr.h | 1 + libustcomm/ustcomm.c | 163 ++++++++++++++++++++++++++++++++++------- libustcomm/ustcomm.h | 15 +++- share/usterr.c | 18 +++++ ust/Makefile | 6 ++ ust/localerr.h | 11 +++ 6 files changed, 185 insertions(+), 29 deletions(-) create mode 100644 libtracectl/localerr.h create mode 100644 share/usterr.c create mode 100644 ust/Makefile create mode 100644 ust/localerr.h diff --git a/libtracectl/localerr.h b/libtracectl/localerr.h new file mode 100644 index 0000000..eef0d4f --- /dev/null +++ b/libtracectl/localerr.h @@ -0,0 +1 @@ +#include "usterr.h" diff --git a/libustcomm/ustcomm.c b/libustcomm/ustcomm.c index c944214..2773b1a 100644 --- a/libustcomm/ustcomm.c +++ b/libustcomm/ustcomm.c @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -20,6 +21,11 @@ #define MSG_MAX 1000 +/* FIXME: ustcomm blocks on message sending, which might be problematic in + * some cases. Fix the poll() usage so sends are buffered until they don't + * block. + */ + //static void bt(void) //{ // void *buffer[100]; @@ -48,7 +54,7 @@ int send_message_path(const char *path, const char *msg, char **reply, int signa int result; struct sockaddr_un addr; - result = fd = socket(PF_UNIX, SOCK_DGRAM, 0); + result = fd = socket(PF_UNIX, SOCK_STREAM, 0); if(result == -1) { PERROR("socket"); return -1; @@ -65,9 +71,15 @@ int send_message_path(const char *path, const char *msg, char **reply, int signa if(signalpid >= 0) signal_process(signalpid); - result = sendto(fd, msg, strlen(msg), 0, (struct sockaddr *)&addr, sizeof(addr)); + result = connect(fd, (struct sockaddr *)&addr, sizeof(addr)); if(result == -1) { - PERROR("sendto"); + PERROR("connect"); + return -1; + } + + result = send(fd, msg, strlen(msg), 0); + if(result == -1) { + PERROR("send"); return -1; } @@ -133,23 +145,10 @@ int ustcomm_request_consumer(pid_t pid, const char *channel) static int recv_message_fd(int fd, char **msg, struct ustcomm_source *src) { int result; - size_t initial_addrlen,addrlen; *msg = (char *) malloc(MSG_MAX+1); - if(src) { - initial_addrlen = addrlen = sizeof(src->addr); - - result = recvfrom(fd, *msg, MSG_MAX, 0, &src->addr, &addrlen); - if(initial_addrlen != addrlen) { - ERR("recvfrom: unexpected address length"); - return -1; - } - } - else { - result = recvfrom(fd, *msg, MSG_MAX, 0, NULL, NULL); - } - + result = recv(fd, *msg, MSG_MAX, 0); if(result == -1) { PERROR("recvfrom"); return -1; @@ -164,12 +163,94 @@ static int recv_message_fd(int fd, char **msg, struct ustcomm_source *src) int ustcomm_ustd_recv_message(struct ustcomm_ustd *ustd, char **msg, struct ustcomm_source *src) { - return recv_message_fd(ustd->fd, msg, src); + struct pollfd *fds; + struct ustcomm_connection *conn; + int result; + int retval; + + for(;;) { + int idx = 0; + int n_fds = 1; + + list_for_each_entry(conn, &ustd->connections, list) { + n_fds++; + } + + fds = (struct pollfd *) malloc(n_fds * sizeof(struct pollfd)); + if(fds == NULL) { + ERR("malloc returned NULL"); + return -1; + } + + /* special idx 0 is for listening socket */ + fds[idx].fd = ustd->listen_fd; + fds[idx].events = POLLIN; + idx++; + + list_for_each_entry(conn, &ustd->connections, list) { + fds[idx].fd = conn->fd; + fds[idx].events = POLLIN; + idx++; + } + + result = poll(fds, n_fds, -1); + if(result == -1) { + PERROR("poll"); + return -1; + } + + if(fds[0].revents) { + struct ustcomm_connection *newconn; + int newfd; + + result = newfd = accept(ustd->listen_fd, NULL, NULL); + if(result == -1) { + PERROR("accept"); + return -1; + } + + newconn = (struct ustcomm_connection *) malloc(sizeof(struct ustcomm_connection)); + if(newconn == NULL) { + ERR("malloc returned NULL"); + return -1; + } + + newconn->fd = newfd; + + list_add(&newconn->list, &ustd->connections); + } + + for(idx=1; idxconnections, list) { + if(conn->fd == fds[idx].fd) { + list_del(&conn->list); + break; + } + } + } + else { + goto free_fds_return; + } + } + } + + free(fds); + } + +free_fds_return: + free(fds); + return retval; } int ustcomm_app_recv_message(struct ustcomm_app *app, char **msg, struct ustcomm_source *src) { - return recv_message_fd(app->fd, msg, src); + return ustcomm_ustd_recv_message((struct ustcomm_ustd *)app, msg, src); } static int init_named_socket(char *name, char **path_out) @@ -179,7 +260,7 @@ static int init_named_socket(char *name, char **path_out) struct sockaddr_un addr; - result = fd = socket(PF_UNIX, SOCK_DGRAM, 0); + result = fd = socket(PF_UNIX, SOCK_STREAM, 0); if(result == -1) { PERROR("socket"); return -1; @@ -190,12 +271,29 @@ static int init_named_socket(char *name, char **path_out) strncpy(addr.sun_path, name, UNIX_PATH_MAX); addr.sun_path[UNIX_PATH_MAX-1] = '\0'; + result = access(name, F_OK); + if(result == 0) { + /* file exists */ + result = unlink(name); + if(result == -1) { + PERROR("unlink of socket file"); + goto close_sock; + } + WARN("socket already exists; overwriting"); + } + result = bind(fd, (struct sockaddr *)&addr, sizeof(addr)); if(result == -1) { PERROR("bind"); goto close_sock; } + result = listen(fd, 1); + if(result == -1) { + PERROR("listen"); + goto close_sock; + } + if(path_out) { *path_out = ""; *path_out = strdupa(addr.sun_path); @@ -220,12 +318,15 @@ int ustcomm_init_app(pid_t pid, struct ustcomm_app *handle) return -1; } - handle->fd = init_named_socket(name, &(handle->socketpath)); - if(handle->fd < 0) { + handle->listen_fd = init_named_socket(name, &(handle->socketpath)); + if(handle->listen_fd < 0) { + ERR("error initializing named socket"); goto free_name; } free(name); + INIT_LIST_HEAD(&handle->connections); + return 0; free_name: @@ -244,15 +345,23 @@ int ustcomm_init_ustd(struct ustcomm_ustd *handle) return -1; } - handle->fd = init_named_socket(name, &handle->socketpath); - if(handle->fd < 0) - return handle->fd; + handle->listen_fd = init_named_socket(name, &handle->socketpath); + if(handle->listen_fd < 0) { + ERR("error initializing named socket"); + goto free_name; + } free(name); + INIT_LIST_HEAD(&handle->connections); + return 0; + +free_name: + free(name); + return -1; } -char *find_tok(const char *str) +static char *find_tok(char *str) { while(*str == ' ') { str++; @@ -331,7 +440,7 @@ char *nth_token(char *str, int tok_no) retval = NULL; } - retval = strndupa(start, end-start); + asprintf(&retval, "%.*s", (int)(end-start), start); return retval; } diff --git a/libustcomm/ustcomm.h b/libustcomm/ustcomm.h index 53ced51..7d84592 100644 --- a/libustcomm/ustcomm.h +++ b/libustcomm/ustcomm.h @@ -4,16 +4,27 @@ #include #include +#include "kcompat.h" + +struct ustcomm_connection { + struct list_head list; + int fd; +}; + struct ustcomm_app { /* the "server" socket for serving the external requests */ - int fd; + int listen_fd; char *socketpath; + + struct list_head connections; }; struct ustcomm_ustd { /* the "server" socket for serving the external requests */ - int fd; + int listen_fd; char *socketpath; + + struct list_head connections; }; struct ustcomm_source { diff --git a/share/usterr.c b/share/usterr.c new file mode 100644 index 0000000..ee7aeb0 --- /dev/null +++ b/share/usterr.c @@ -0,0 +1,18 @@ +#include +#include + +int safe_printf(const char *fmt, ...) +{ + static char buf[500]; + va_list ap; + int n; + + va_start(ap, fmt); + + n = vsnprintf(buf, sizeof(buf), fmt, ap); + + write(STDOUT_FILENO, buf, n); + + va_end(ap); +} + diff --git a/ust/Makefile b/ust/Makefile new file mode 100644 index 0000000..3aeec04 --- /dev/null +++ b/ust/Makefile @@ -0,0 +1,6 @@ +all: ust + +ust: ust.c + gcc -g -Wall -I ../libustcomm -I. -I ../../../../libkcompat -o ust ust.c ../libustcomm/ustcomm.c + +.PHONY: ust diff --git a/ust/localerr.h b/ust/localerr.h new file mode 100644 index 0000000..7fe99e2 --- /dev/null +++ b/ust/localerr.h @@ -0,0 +1,11 @@ +#ifndef LOCALERR_H +#define LOCALERR_H + +#include + +#define DBG(fmt, args...) fprintf(stderr, "ustd: " fmt "\n", ## args); fflush(stderr) +#define WARN(fmt, args...) fprintf(stderr, "ustd: WARNING: " fmt "\n", ## args); fflush(stderr) +#define ERR(fmt, args...) fprintf(stderr, "ustd: ERROR: " fmt "\n", ## args); fflush(stderr) +#define PERROR(a) perror(a) + +#endif /* LOCALERR_H */ -- 2.34.1