Split read_all and write_all into separate functions

This commit is contained in:
Nick Mathewson 2018-06-27 10:47:09 -04:00
parent 05040a9e84
commit 67135ca8e0
5 changed files with 67 additions and 25 deletions

View File

@ -1076,22 +1076,17 @@ format_time_interval(char *out, size_t out_len, long interval)
* File helpers * File helpers
* ===== */ * ===== */
/** Write <b>count</b> bytes from <b>buf</b> to <b>fd</b>. <b>isSocket</b> /** Write <b>count</b> bytes from <b>buf</b> to <b>fd</b>. Return the number
* must be 1 if fd was returned by socket() or accept(), and 0 if fd * of bytes written, or -1 on error. Only use if fd is a blocking fd. */
* was returned by open(). Return the number of bytes written, or -1
* on error. Only use if fd is a blocking fd. */
ssize_t ssize_t
write_all(tor_socket_t fd, const char *buf, size_t count, int isSocket) write_all_to_fd(int fd, const char *buf, size_t count)
{ {
size_t written = 0; size_t written = 0;
ssize_t result; ssize_t result;
raw_assert(count < SSIZE_MAX); raw_assert(count < SSIZE_MAX);
while (written != count) { while (written != count) {
if (isSocket) result = write(fd, buf+written, count-written);
result = tor_socket_send(fd, buf+written, count-written, 0);
else
result = write((int)fd, buf+written, count-written);
if (result<0) if (result<0)
return -1; return -1;
written += result; written += result;
@ -1099,13 +1094,29 @@ write_all(tor_socket_t fd, const char *buf, size_t count, int isSocket)
return (ssize_t)count; return (ssize_t)count;
} }
/** Read from <b>fd</b> to <b>buf</b>, until we get <b>count</b> bytes /** Write <b>count</b> bytes from <b>buf</b> to <b>sock</b>. Return the number
* or reach the end of the file. <b>isSocket</b> must be 1 if fd * of bytes written, or -1 on error. Only use if fd is a blocking fd. */
* was returned by socket() or accept(), and 0 if fd was returned by
* open(). Return the number of bytes read, or -1 on error. Only use
* if fd is a blocking fd. */
ssize_t ssize_t
read_all(tor_socket_t fd, char *buf, size_t count, int isSocket) write_all_to_socket(tor_socket_t fd, const char *buf, size_t count)
{
size_t written = 0;
ssize_t result;
raw_assert(count < SSIZE_MAX);
while (written != count) {
result = tor_socket_send(fd, buf+written, count-written, 0);
if (result<0)
return -1;
written += result;
}
return (ssize_t)count;
}
/** Read from <b>fd</b> to <b>buf</b>, until we get <b>count</b> bytes or
* reach the end of the file. Return the number of bytes read, or -1 on
* error. Only use if fd is a blocking fd. */
ssize_t
read_all_from_fd(int fd, char *buf, size_t count)
{ {
size_t numread = 0; size_t numread = 0;
ssize_t result; ssize_t result;
@ -1116,10 +1127,32 @@ read_all(tor_socket_t fd, char *buf, size_t count, int isSocket)
} }
while (numread < count) { while (numread < count) {
if (isSocket) result = read(fd, buf+numread, count-numread);
result = tor_socket_recv(fd, buf+numread, count-numread, 0); if (result<0)
else return -1;
result = read((int)fd, buf+numread, count-numread); else if (result == 0)
break;
numread += result;
}
return (ssize_t)numread;
}
/** Read from <b>sock</b> to <b>buf</b>, until we get <b>count</b> bytes or
* reach the end of the file. Return the number of bytes read, or -1 on
* error. Only use if fd is a blocking fd. */
ssize_t
read_all_from_socket(tor_socket_t sock, char *buf, size_t count)
{
size_t numread = 0;
ssize_t result;
if (count > SIZE_T_CEILING || count > SSIZE_MAX) {
errno = EINVAL;
return -1;
}
while (numread < count) {
result = tor_socket_recv(sock, buf+numread, count-numread, 0);
if (result<0) if (result<0)
return -1; return -1;
else if (result == 0) else if (result == 0)

View File

@ -117,8 +117,17 @@ int parse_http_time(const char *buf, struct tm *tm);
int format_time_interval(char *out, size_t out_len, long interval); int format_time_interval(char *out, size_t out_len, long interval);
/* File helpers */ /* File helpers */
ssize_t write_all(tor_socket_t fd, const char *buf, size_t count,int isSocket); ssize_t write_all_to_fd(int fd, const char *buf, size_t count);
ssize_t read_all(tor_socket_t fd, char *buf, size_t count, int isSocket); ssize_t write_all_to_socket(tor_socket_t fd, const char *buf, size_t count);
ssize_t read_all_from_fd(int fd, char *buf, size_t count);
ssize_t read_all_from_socket(tor_socket_t fd, char *buf, size_t count);
#define write_all(fd, buf, count, isSock) \
((isSock) ? write_all_to_socket((fd), (buf), (count)) \
: write_all_to_fd((int)(fd), (buf), (count)))
#define read_all(fd, buf, count, isSock) \
((isSock) ? read_all_from_socket((fd), (buf), (count)) \
: read_all_from_fd((int)(fd), (buf), (count)))
/** Status of an I/O stream. */ /** Status of an I/O stream. */
enum stream_status { enum stream_status {

View File

@ -94,7 +94,7 @@ tor_ftruncate(int fd)
/** Minimal version of write_all, for use by logging. */ /** Minimal version of write_all, for use by logging. */
int int
write_all_to_fd(int fd, const char *buf, size_t count) write_all_to_fd_minimal(int fd, const char *buf, size_t count)
{ {
size_t written = 0; size_t written = 0;
raw_assert(count < SSIZE_MAX); raw_assert(count < SSIZE_MAX);

View File

@ -12,6 +12,6 @@ off_t tor_fd_getpos(int fd);
int tor_fd_setpos(int fd, off_t pos); int tor_fd_setpos(int fd, off_t pos);
int tor_fd_seekend(int fd); int tor_fd_seekend(int fd);
int tor_ftruncate(int fd); int tor_ftruncate(int fd);
int write_all_to_fd(int fd, const char *buf, size_t count); int write_all_to_fd_minimal(int fd, const char *buf, size_t count);
#endif /* !defined(TOR_FDIO_H) */ #endif /* !defined(TOR_FDIO_H) */

View File

@ -346,7 +346,7 @@ log_tor_version(logfile_t *lf, int reset)
tor_snprintf(buf+n, sizeof(buf)-n, tor_snprintf(buf+n, sizeof(buf)-n,
"Tor %s opening %slog file.\n", VERSION, is_new?"new ":""); "Tor %s opening %slog file.\n", VERSION, is_new?"new ":"");
} }
if (write_all_to_fd(lf->fd, buf, strlen(buf)) < 0) /* error */ if (write_all_to_fd_minimal(lf->fd, buf, strlen(buf)) < 0) /* error */
return -1; /* failed */ return -1; /* failed */
return 0; return 0;
} }
@ -560,7 +560,7 @@ logfile_deliver(logfile_t *lf, const char *buf, size_t msg_len,
lf->callback(severity, domain, msg_after_prefix); lf->callback(severity, domain, msg_after_prefix);
} }
} else { } else {
if (write_all_to_fd(lf->fd, buf, msg_len) < 0) { /* error */ if (write_all_to_fd_minimal(lf->fd, buf, msg_len) < 0) { /* error */
/* don't log the error! mark this log entry to be blown away, and /* don't log the error! mark this log entry to be blown away, and
* continue. */ * continue. */
lf->seems_dead = 1; lf->seems_dead = 1;