ustcomm: use connections; don't reconnect at every message
[ust.git] / libustcomm / ustcomm.c
1 #define _GNU_SOURCE
2 #include <sys/types.h>
3 #include <signal.h>
4 #include <errno.h>
5 #include <sys/socket.h>
6 #include <sys/un.h>
7 #include <unistd.h>
8 #include <poll.h>
9
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <string.h>
13 #include <execinfo.h>
14
15 #include "ustcomm.h"
16 #include "localerr.h"
17
18 #define UNIX_PATH_MAX 108
19 #define SOCK_DIR "/tmp/socks"
20 #define UST_SIGNAL SIGIO
21
22 #define MSG_MAX 1000
23
24 /* FIXME: ustcomm blocks on message sending, which might be problematic in
25 * some cases. Fix the poll() usage so sends are buffered until they don't
26 * block.
27 */
28
29 //static void bt(void)
30 //{
31 // void *buffer[100];
32 // int result;
33 //
34 // result = backtrace(&buffer, 100);
35 // backtrace_symbols_fd(buffer, result, STDERR_FILENO);
36 //}
37
/* Return a malloc'd copy of @s that the caller must free().
 *
 * Returns NULL if @s is NULL or if memory could not be allocated.
 */
char *strdup_malloc(const char *s)
{
	size_t size;
	char *copy;

	if(s == NULL)
		return NULL;

	size = strlen(s) + 1;	/* include the terminating NUL */

	copy = malloc(size);
	if(copy == NULL)
		return NULL;

	memcpy(copy, s, size);

	return copy;
}
51
52 static void signal_process(pid_t pid)
53 {
54 int result;
55
56 result = kill(pid, UST_SIGNAL);
57 if(result == -1) {
58 PERROR("kill");
59 return;
60 }
61
62 sleep(1);
63 }
64
/* Send the NUL-terminated string @msg (without its terminator) on @fd.
 *
 * send() on a stream socket may transfer fewer bytes than requested, so this
 * keeps sending until the whole message is out. A send interrupted by a
 * signal is retried. MSG_NOSIGNAL makes a closed peer report EPIPE instead of
 * raising SIGPIPE, which would otherwise terminate the process.
 *
 * returns 1 once the whole message has been sent
 * returns 0 if nothing was sent (empty message, or send() accepted no data)
 * returns -1 on error
 */
static int send_message_fd(int fd, const char *msg)
{
	size_t len = strlen(msg);
	size_t sent = 0;

	if(len == 0)
		return 0;

	while(sent < len) {
		ssize_t result = send(fd, msg + sent, len - sent, MSG_NOSIGNAL);

		if(result == -1) {
			if(errno == EINTR)
				continue;
			PERROR("send");
			return -1;
		}
		else if(result == 0) {
			return 0;
		}

		sent += (size_t)result;
	}

	return 1;
}
94
/* Connect to the UNIX stream socket at @path, send @msg once, then close
 * the connection.
 *
 * If @signalpid is >= 0, that process is signalled with signal_process()
 * before connecting.
 *
 * Returns send_message_fd()'s result (1 sent, 0 nothing sent), or -1 on
 * error. The socket is closed on every path so no descriptor leaks.
 */
static int send_message_path(const char *path, const char *msg, int signalpid)
{
	int fd;
	int result;
	struct sockaddr_un addr;

	fd = socket(PF_UNIX, SOCK_STREAM, 0);
	if(fd == -1) {
		PERROR("socket");
		return -1;
	}

	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;

	result = snprintf(addr.sun_path, sizeof(addr.sun_path), "%s", path);
	if(result < 0 || (size_t)result >= sizeof(addr.sun_path)) {
		ERR("string overflow allocating socket name");
		close(fd);
		return -1;
	}

	if(signalpid >= 0)
		signal_process(signalpid);

	result = connect(fd, (struct sockaddr *)&addr, sizeof(addr));
	if(result == -1) {
		PERROR("connect");
		close(fd);
		return -1;
	}

	result = send_message_fd(fd, msg);

	/* No reply is read on this one-shot connection. Data already queued on
	 * a stream socket is still delivered to the peer after close(). */
	close(fd);

	return result;
}
126
127 ///* pid: the pid of the trace process that must receive the msg
128 // msg: pointer to a null-terminated message to send
129 // reply: location where to put the null-terminated string of the reply;
130 // it must be free'd after usage
131 // */
132 //
133 //int send_message_pid(pid_t pid, const char *msg, char **reply)
134 //{
135 // int result;
136 // char path[UNIX_PATH_MAX];
137 //
138 // result = snprintf(path, UNIX_PATH_MAX, "%s/%d", SOCK_DIR, pid);
139 // if(result >= UNIX_PATH_MAX) {
140 // fprintf(stderr, "string overflow allocating socket name");
141 // return -1;
142 // }
143 //
144 // send_message_path(path, msg, reply, pid);
145 //
146 // return 0;
147 //}
148
149 /* Called by an app to ask the consumer daemon to connect to it. */
150
151 int ustcomm_request_consumer(pid_t pid, const char *channel)
152 {
153 char path[UNIX_PATH_MAX];
154 int result;
155 char *msg;
156
157 result = snprintf(path, UNIX_PATH_MAX, "%s/ustd", SOCK_DIR);
158 if(result >= UNIX_PATH_MAX) {
159 fprintf(stderr, "string overflow allocating socket name");
160 return -1;
161 }
162
163 asprintf(&msg, "collect %d %s", pid, channel);
164
165 send_message_path(path, msg, -1);
166 free(msg);
167
168 return 0;
169 }
170
171 /* returns 1 to indicate a message was received
172 * returns 0 to indicate no message was received (cannot happen)
173 * returns -1 to indicate an error
174 */
175
176 static int recv_message_fd(int fd, char **msg, struct ustcomm_source *src)
177 {
178 int result;
179
180 *msg = (char *) malloc(MSG_MAX+1);
181
182 result = recv(fd, *msg, MSG_MAX, 0);
183 if(result == -1) {
184 PERROR("recv");
185 return -1;
186 }
187
188 (*msg)[result] = '\0';
189
190 DBG("ustcomm_app_recv_message: result is %d, message is %s", result, (*msg));
191
192 if(src)
193 src->fd = fd;
194
195 return 1;
196 }
197
198 int ustcomm_send_reply(struct ustcomm_server *server, char *msg, struct ustcomm_source *src)
199 {
200 int result;
201
202 result = send_message_fd(src->fd, msg);
203 if(result < 0) {
204 ERR("error in send_message_fd");
205 return -1;
206 }
207
208 return 0;
209 }
210
211 /* @timeout: max blocking time in milliseconds, -1 means infinity
212 *
213 * returns 1 to indicate a message was received
214 * returns 0 to indicate no message was received
215 * returns -1 to indicate an error
216 */
217
218 int ustcomm_recv_message(struct ustcomm_server *server, char **msg, struct ustcomm_source *src, int timeout)
219 {
220 struct pollfd *fds;
221 struct ustcomm_connection *conn;
222 int result;
223 int retval;
224
225 for(;;) {
226 int idx = 0;
227 int n_fds = 1;
228
229 list_for_each_entry(conn, &server->connections, list) {
230 n_fds++;
231 }
232
233 fds = (struct pollfd *) malloc(n_fds * sizeof(struct pollfd));
234 if(fds == NULL) {
235 ERR("malloc returned NULL");
236 return -1;
237 }
238
239 /* special idx 0 is for listening socket */
240 fds[idx].fd = server->listen_fd;
241 fds[idx].events = POLLIN;
242 idx++;
243
244 list_for_each_entry(conn, &server->connections, list) {
245 fds[idx].fd = conn->fd;
246 fds[idx].events = POLLIN;
247 idx++;
248 }
249
250 result = poll(fds, n_fds, timeout);
251 if(result == -1) {
252 PERROR("poll");
253 return -1;
254 }
255
256 if(result == 0)
257 return 0;
258
259 if(fds[0].revents) {
260 struct ustcomm_connection *newconn;
261 int newfd;
262
263 result = newfd = accept(server->listen_fd, NULL, NULL);
264 if(result == -1) {
265 PERROR("accept");
266 return -1;
267 }
268
269 newconn = (struct ustcomm_connection *) malloc(sizeof(struct ustcomm_connection));
270 if(newconn == NULL) {
271 ERR("malloc returned NULL");
272 return -1;
273 }
274
275 newconn->fd = newfd;
276
277 list_add(&newconn->list, &server->connections);
278 }
279
280 for(idx=1; idx<n_fds; idx++) {
281 if(fds[idx].revents) {
282 retval = recv_message_fd(fds[idx].fd, msg, src);
283 if(**msg == 0) {
284 /* connection finished */
285 close(fds[idx].fd);
286
287 list_for_each_entry(conn, &server->connections, list) {
288 if(conn->fd == fds[idx].fd) {
289 list_del(&conn->list);
290 break;
291 }
292 }
293 }
294 else {
295 goto free_fds_return;
296 }
297 }
298 }
299
300 free(fds);
301 }
302
303 free_fds_return:
304 free(fds);
305 return retval;
306 }
307
308 int ustcomm_ustd_recv_message(struct ustcomm_ustd *ustd, char **msg, struct ustcomm_source *src, int timeout)
309 {
310 return ustcomm_recv_message(&ustd->server, msg, src, timeout);
311 }
312
313 int ustcomm_app_recv_message(struct ustcomm_app *app, char **msg, struct ustcomm_source *src, int timeout)
314 {
315 return ustcomm_recv_message(&app->server, msg, src, timeout);
316 }
317
318 /* This removes src from the list of active connections of app.
319 */
320
321 int ustcomm_app_detach_client(struct ustcomm_app *app, struct ustcomm_source *src)
322 {
323 struct ustcomm_server *server = (struct ustcomm_server *)app;
324 struct ustcomm_connection *conn;
325
326 list_for_each_entry(conn, &server->connections, list) {
327 if(conn->fd == src->fd) {
328 list_del(&conn->list);
329 goto found;
330 }
331 }
332
333 return -1;
334 found:
335 return src->fd;
336 }
337
/* Create a listening UNIX stream socket bound at path @name. Any file
 * already present at that path is removed first.
 *
 * If @path_out is non-NULL, *path_out receives a heap-allocated copy of the
 * bound path. The caller owns it. It must be heap memory: the previous
 * strdupa() copy lived on this function's stack and was dangling as soon as
 * we returned.
 *
 * Returns the listening fd, or -1 on error.
 */
static int init_named_socket(char *name, char **path_out)
{
	int result;
	int fd;
	struct sockaddr_un addr;

	fd = socket(PF_UNIX, SOCK_STREAM, 0);
	if(fd == -1) {
		PERROR("socket");
		return -1;
	}

	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;

	/* Refuse to silently truncate the path (strncpy used to). */
	result = snprintf(addr.sun_path, sizeof(addr.sun_path), "%s", name);
	if(result < 0 || (size_t)result >= sizeof(addr.sun_path)) {
		errno = ENAMETOOLONG;
		PERROR("socket name");
		goto close_sock;
	}

	/* Unlink directly instead of access()+unlink(), which raced. ENOENT
	 * just means there was nothing to remove. */
	result = unlink(name);
	if(result == 0) {
		WARN("socket already exists; overwriting");
	}
	else if(errno != ENOENT) {
		PERROR("unlink of socket file");
		goto close_sock;
	}

	result = bind(fd, (struct sockaddr *)&addr, sizeof(addr));
	if(result == -1) {
		PERROR("bind");
		goto close_sock;
	}

	result = listen(fd, 1);
	if(result == -1) {
		PERROR("listen");
		goto close_sock;
	}

	if(path_out) {
		*path_out = strdup(addr.sun_path);
		if(*path_out == NULL) {
			PERROR("strdup");
			goto close_sock;
		}
	}

	return fd;

close_sock:
	close(fd);

	return -1;
}
391
392 int ustcomm_send_request(struct ustcomm_connection *conn, char *req, char **reply)
393 {
394 int result;
395
396 result = send(conn->fd, req, strlen(req), 0);
397 if(result == -1) {
398 PERROR("send");
399 return -1;
400 }
401 else if(result == 0) {
402 return 0;
403 }
404
405 if(!reply)
406 return 1;
407
408 *reply = (char *) malloc(MSG_MAX+1);
409 result = recv(conn->fd, *reply, MSG_MAX, 0);
410 if(result == -1) {
411 PERROR("recv");
412 return -1;
413 }
414 else if(result == 0) {
415 return 0;
416 }
417
418 (*reply)[result] = '\0';
419
420 return 1;
421 }
422
423 int ustcomm_connect_path(char *path, struct ustcomm_connection *conn, pid_t signalpid)
424 {
425 int fd;
426 int result;
427 struct sockaddr_un addr;
428
429 result = fd = socket(PF_UNIX, SOCK_STREAM, 0);
430 if(result == -1) {
431 PERROR("socket");
432 return -1;
433 }
434
435 addr.sun_family = AF_UNIX;
436
437 result = snprintf(addr.sun_path, UNIX_PATH_MAX, "%s", path);
438 if(result >= UNIX_PATH_MAX) {
439 ERR("string overflow allocating socket name");
440 return -1;
441 }
442
443 if(signalpid >= 0)
444 signal_process(signalpid);
445
446 result = connect(fd, (struct sockaddr *)&addr, sizeof(addr));
447 if(result == -1) {
448 PERROR("connect");
449 return -1;
450 }
451
452 conn->fd = fd;
453
454 return 0;
455 }
456
457 int ustcomm_disconnect(struct ustcomm_connection *conn)
458 {
459 return close(conn->fd);
460 }
461
462 int ustcomm_connect_app(pid_t pid, struct ustcomm_connection *conn)
463 {
464 int result;
465 char path[UNIX_PATH_MAX];
466
467
468 result = snprintf(path, UNIX_PATH_MAX, "%s/%d", SOCK_DIR, pid);
469 if(result >= UNIX_PATH_MAX) {
470 fprintf(stderr, "string overflow allocating socket name");
471 return -1;
472 }
473
474 return ustcomm_connect_path(path, conn, pid);
475 }
476
477 int ustcomm_disconnect_app(struct ustcomm_connection *conn)
478 {
479 close(conn->fd);
480 return 0;
481 }
482
483 int ustcomm_init_app(pid_t pid, struct ustcomm_app *handle)
484 {
485 int result;
486 char *name;
487
488 result = asprintf(&name, "%s/%d", SOCK_DIR, (int)pid);
489 if(result >= UNIX_PATH_MAX) {
490 ERR("string overflow allocating socket name");
491 return -1;
492 }
493
494 handle->server.listen_fd = init_named_socket(name, &(handle->server.socketpath));
495 if(handle->server.listen_fd < 0) {
496 ERR("error initializing named socket");
497 goto free_name;
498 }
499 free(name);
500
501 INIT_LIST_HEAD(&handle->server.connections);
502
503 return 0;
504
505 free_name:
506 free(name);
507 return -1;
508 }
509
510 int ustcomm_init_ustd(struct ustcomm_ustd *handle)
511 {
512 int result;
513 char *name;
514
515 result = asprintf(&name, "%s/%s", SOCK_DIR, "ustd");
516 if(result >= UNIX_PATH_MAX) {
517 ERR("string overflow allocating socket name");
518 return -1;
519 }
520
521 handle->server.listen_fd = init_named_socket(name, &handle->server.socketpath);
522 if(handle->server.listen_fd < 0) {
523 ERR("error initializing named socket");
524 goto free_name;
525 }
526 free(name);
527
528 INIT_LIST_HEAD(&handle->server.connections);
529
530 return 0;
531
532 free_name:
533 free(name);
534 return -1;
535 }
536
537 static char *find_tok(char *str)
538 {
539 while(*str == ' ') {
540 str++;
541
542 if(*str == 0)
543 return NULL;
544 }
545
546 return str;
547 }
548
/* Advance over the current token of @str.
 *
 * Returns a pointer to the first space character (' '), or to the
 * terminating NUL if there is none. The terminator is checked before
 * advancing, so an empty string never causes a read past its end. Never
 * returns NULL.
 */
static char *find_sep(char *str)
{
	while(*str != ' ' && *str != '\0')
		str++;

	return str;
}
560
/* Check whether token number @tok_no (0-based) of @str is exactly @token.
 * Tokens are separated by one or more space characters (' ').
 *
 * returns 1 if the token matches
 * returns 0 if it differs
 * returns -1 if @str has fewer than tok_no+1 tokens or tok_no is negative
 */
int nth_token_is(char *str, char *token, int tok_no)
{
	int i;
	char *start = NULL;
	char *end = NULL;
	size_t len;

	if(tok_no < 0)
		return -1;

	for(i=0; i<=tok_no; i++) {
		str = find_tok(str);
		/* An empty remainder means there is no such token. Stop here so
		 * find_sep() is never asked to scan an empty string. */
		if(str == NULL || *str == '\0')
			return -1;

		start = str;
		end = str = find_sep(str);
	}

	len = (size_t)(end - start);
	if(len != strlen(token))
		return 0;

	if(strncmp(start, token, len))
		return 0;

	return 1;
}
589
/* Return a copy of token number @tok_no (0-based) of @str. Tokens are
 * separated by one or more space characters (' ').
 *
 * The copy is kept in a static buffer that is freed and replaced by the next
 * successful call. The caller must not free it, and the function is not
 * thread-safe.
 *
 * Returns NULL if @str has fewer than tok_no+1 tokens, tok_no is negative,
 * or allocation fails. A previously returned token stays valid when the
 * lookup fails; after an allocation failure it has already been freed.
 */
char *nth_token(char *str, int tok_no)
{
	static char *retval = NULL;
	int i;
	char *start = NULL;
	char *end = NULL;

	if(tok_no < 0)
		return NULL;

	for(i=0; i<=tok_no; i++) {
		str = find_tok(str);
		/* An empty remainder means there is no such token. Stop here so
		 * find_sep() is never asked to scan an empty string. */
		if(str == NULL || *str == '\0')
			return NULL;

		start = str;
		end = str = find_sep(str);
	}

	free(retval);
	retval = NULL;

	/* On failure asprintf() leaves retval indeterminate: reset it. */
	if(asprintf(&retval, "%.*s", (int)(end-start), start) < 0) {
		retval = NULL;
		return NULL;
	}

	return retval;
}
620
This page took 0.041837 seconds and 4 git commands to generate.