libustcomm: fix segfault caused by incorrect initialization of buffer size
[ust.git] / libustcomm / ustcomm.c
index 6044c271fa7101d585261d44b1c24f401dfd2ef4..0d2ab339485cb666ba033b24d9896840a7d1c311 100644 (file)
 #include <execinfo.h>
 
 #include "ustcomm.h"
-#include "localerr.h"
+#include "usterr.h"
+#include "share.h"
 
 #define UNIX_PATH_MAX 108
 
-#define MSG_MAX 1000
+#define MSG_MAX 10000
 
 /* FIXME: ustcomm blocks on message sending, which might be problematic in
  * some cases. Fix the poll() usage so sends are buffered until they don't
@@ -67,37 +68,40 @@ char *strdup_malloc(const char *s)
 
 static int signal_process(pid_t pid)
 {
-       int result;
-
-       result = kill(pid, UST_SIGNAL);
-       if(result == -1) {
-               PERROR("kill");
-               return -1;
-       }
-
-       /* FIXME: should wait in a better way */
-       //sleep(1);
-
        return 0;
 }
 
 int pid_is_online(pid_t pid) {
-       return kill(pid, UST_SIGNAL) != -1;
+       return 1;
 }
 
+/* Send a message
+ *
+ * @fd: file descriptor to send to
+ * @msg: a null-terminated string containing the message to send
+ *
+ * Return value:
+ * -1: error
+ * 0: connection closed
+ * 1: success
+ */
+
 static int send_message_fd(int fd, const char *msg)
 {
        int result;
 
-       result = send(fd, msg, strlen(msg), MSG_NOSIGNAL);
+       /* Send including the final \0 */
+       result = patient_send(fd, msg, strlen(msg)+1, MSG_NOSIGNAL);
        if(result == -1) {
-               PERROR("send");
+               if(errno != EPIPE)
+                       PERROR("send");
                return -1;
        }
        else if(result == 0) {
                return 0;
        }
 
+       DBG("sent message \"%s\"", msg);
        return 1;
 }
 
@@ -153,30 +157,67 @@ int ustcomm_request_consumer(pid_t pid, const char *channel)
 }
 
 /* returns 1 to indicate a message was received
- * returns 0 to indicate no message was received (cannot happen)
+ * returns 0 to indicate no message was received (end of stream)
  * returns -1 to indicate an error
  */
 
-static int recv_message_fd(int fd, char **msg, struct ustcomm_source *src)
+#define RECV_INCREMENT 1
+#define RECV_INITIAL_BUF_SIZE 10
+
+static int recv_message_fd(int fd, char **msg)
 {
        int result;
+       int buf_alloc_size = 0;
+       char *buf = NULL;
+       int buf_used_size = 0;
 
-       *msg = (char *) malloc(MSG_MAX+1);
+       buf = malloc(RECV_INITIAL_BUF_SIZE);
+       buf_alloc_size = RECV_INITIAL_BUF_SIZE;
 
-       result = recv(fd, *msg, MSG_MAX, 0);
-       if(result == -1) {
-               PERROR("recv");
-               return -1;
-       }
+       for(;;) {
+               if(buf_used_size + RECV_INCREMENT > buf_alloc_size) {
+                       char *new_buf;
+                       buf_alloc_size *= 2;
+                       new_buf = (char *) realloc(buf, buf_alloc_size);
+                       if(new_buf == NULL) {
+                               ERR("realloc returned NULL");
+                               free(buf);
+                               return -1;
+                       }
+                       buf = new_buf;
+               }
 
-       (*msg)[result] = '\0';
-       
-       DBG("ustcomm_app_recv_message: result is %d, message is %s", result, (*msg));
+               /* FIXME: this is really inefficient; but with count>1 we would
+                * need a buffering mechanism */
+               result = recv(fd, buf+buf_used_size, RECV_INCREMENT, 0);
+               if(result == -1) {
+                       free(buf);
+                       if(errno != ECONNRESET)
+                               PERROR("recv");
+                       return -1;
+               }
+               if(result == 0) {
+                       if(buf_used_size)
+                               goto ret;
+                       else {
+                               free(buf);
+                               return 0;
+                       }
+               }
+
+               buf_used_size += result;
 
-       if(src)
-               src->fd = fd;
+               if(buf[buf_used_size-1] == 0) {
+                       goto ret;
+               }
+       }
+
+ret:
+       *msg = buf;
+       DBG("received message \"%s\"", buf);
 
        return 1;
+
 }
 
 int ustcomm_send_reply(struct ustcomm_server *server, char *msg, struct ustcomm_source *src)
@@ -281,7 +322,10 @@ int ustcomm_recv_message(struct ustcomm_server *server, char **msg, struct ustco
 
                for(idx=1; idx<n_fds; idx++) {
                        if(fds[idx].revents) {
-                               retval = recv_message_fd(fds[idx].fd, msg, src);
+                               retval = recv_message_fd(fds[idx].fd, msg);
+                               if(src)
+                                       src->fd = fds[idx].fd;
+
                                if(**msg == 0) {
                                        /* connection finished */
                                        close(fds[idx].fd);
@@ -404,29 +448,22 @@ int ustcomm_send_request(struct ustcomm_connection *conn, const char *req, char
 {
        int result;
 
-       result = send(conn->fd, req, strlen(req), MSG_NOSIGNAL);
-       if(result == -1) {
-               if(errno != EPIPE)
-                       PERROR("send");
-               return -1;
-       }
+       /* Send including the final \0 */
+       result = send_message_fd(conn->fd, req);
+       if(result != 1)
+               return result;
 
        if(!reply)
                return 1;
 
-       *reply = (char *) malloc(MSG_MAX+1);
-       result = recv(conn->fd, *reply, MSG_MAX, 0);
+       result = recv_message_fd(conn->fd, reply);
        if(result == -1) {
-               if(errno != ECONNRESET)
-                       PERROR("recv");
                return -1;
        }
        else if(result == 0) {
                return 0;
        }
        
-       (*reply)[result] = '\0';
-
        return 1;
 }
 
@@ -460,7 +497,7 @@ int ustcomm_connect_path(const char *path, struct ustcomm_connection *conn, pid_
 
        result = connect(fd, (struct sockaddr *)&addr, sizeof(addr));
        if(result == -1) {
-               PERROR("connect");
+               PERROR("connect (path=%s)", path);
                return -1;
        }
 
@@ -489,6 +526,38 @@ int ustcomm_connect_app(pid_t pid, struct ustcomm_connection *conn)
        return ustcomm_connect_path(path, conn, pid);
 }
 
+static int ensure_dir_exists(const char *dir)
+{
+       struct stat st;
+       int result;
+
+       if(!strcmp(dir, ""))
+               return -1;
+
+       result = stat(dir, &st);
+       if(result == -1 && errno != ENOENT) {
+               return -1;
+       }
+       else if(result == -1) {
+               /* ENOENT */
+               char buf[200];
+               int result;
+
+               result = snprintf(buf, sizeof(buf), "mkdir -p \"%s\"", dir);
+               if(result >= sizeof(buf)) {
+                       ERR("snprintf buffer overflow");
+                       return -1;
+               }
+               result = system(buf);
+               if(result != 0) {
+                       ERR("executing command %s", buf);
+                       return -1;
+               }
+       }
+
+       return 0;
+}
+
 /* Called by an application to initialize its server so daemons can
  * connect to it.
  */
@@ -504,6 +573,12 @@ int ustcomm_init_app(pid_t pid, struct ustcomm_app *handle)
                return -1;
        }
 
+       result = ensure_dir_exists(SOCK_DIR);
+       if(result == -1) {
+               ERR("Unable to create socket directory %s", SOCK_DIR);
+               return -1;
+       }
+
        handle->server.listen_fd = init_named_socket(name, &(handle->server.socketpath));
        if(handle->server.listen_fd < 0) {
                ERR("Error initializing named socket (%s). Check that directory exists and that it is writable.", name);
@@ -533,6 +608,15 @@ int ustcomm_init_ustd(struct ustcomm_ustd *handle, const char *sock_path)
                asprintf(&name, "%s", sock_path);
        }
        else {
+               int result;
+
+               /* Only check if socket dir exists if we are using the default directory */
+               result = ensure_dir_exists(SOCK_DIR);
+               if(result == -1) {
+                       ERR("Unable to create socket directory %s", SOCK_DIR);
+                       return -1;
+               }
+
                asprintf(&name, "%s/%s", SOCK_DIR, "ustd");
        }
 
@@ -576,7 +660,7 @@ void ustcomm_fini_app(struct ustcomm_app *handle)
        }
 }
 
-static char *find_tok(char *str)
+static const char *find_tok(const char *str)
 {
        while(*str == ' ') {
                str++;
@@ -588,7 +672,7 @@ static char *find_tok(char *str)
        return str;
 }
 
-static char *find_sep(char *str)
+static const char *find_sep(const char *str)
 {
        while(*str != ' ') {
                str++;
@@ -600,11 +684,11 @@ static char *find_sep(char *str)
        return str;
 }
 
-int nth_token_is(char *str, char *token, int tok_no)
+int nth_token_is(const char *str, const char *token, int tok_no)
 {
        int i;
-       char *start;
-       char *end;
+       const char *start;
+       const char *end;
 
        for(i=0; i<=tok_no; i++) {
                str = find_tok(str);
@@ -629,12 +713,12 @@ int nth_token_is(char *str, char *token, int tok_no)
        return 1;
 }
 
-char *nth_token(char *str, int tok_no)
+char *nth_token(const char *str, int tok_no)
 {
        static char *retval = NULL;
        int i;
-       char *start;
-       char *end;
+       const char *start;
+       const char *end;
 
        for(i=0; i<=tok_no; i++) {
                str = find_tok(str);
This page took 0.027152 seconds and 4 git commands to generate.