Use a mutex to protect the count of open sockets.

This matters because a cpuworker can close its socket when it
finishes.  Cpuworker typically runs in another thread, so without a
lock here, we can have a race condition and get confused about how
many sockets are open.  Possible fix for bug 939.
This commit is contained in:
Nick Mathewson 2009-05-12 16:17:32 -04:00
parent a28215a150
commit c36efb0c45
2 changed files with 45 additions and 10 deletions

View File

@ -5,6 +5,9 @@ Changes in version 0.2.1.15??? - ????-??-??
Bugfix on 0.2.0.9-alpha. Bugfix on 0.2.0.9-alpha.
- Provide a more useful log message if bug 977 (related to buffer - Provide a more useful log message if bug 977 (related to buffer
freelists) ever reappears, and do not crash right away. freelists) ever reappears, and do not crash right away.
- Protect the count of open sockets with a mutex, so we can't
corrupt it when two threads are closing or opening sockets at once.
Fix for bug 939. Bugfix on 0.2.0.1-alpha.
Changes in version 0.2.1.14-rc - 2009-04-12 Changes in version 0.2.1.14-rc - 2009-04-12

View File

@ -676,6 +676,23 @@ static int max_socket = -1;
* eventdns and libevent.) */ * eventdns and libevent.) */
static int n_sockets_open = 0; static int n_sockets_open = 0;
/** Mutex to protect open_sockets, max_socket, and n_sockets_open. */
static tor_mutex_t *socket_accounting_mutex = NULL;
static INLINE void
socket_accounting_lock(void)
{
if (PREDICT_UNLIKELY(!socket_accounting_mutex))
socket_accounting_mutex = tor_mutex_new();
tor_mutex_acquire(socket_accounting_mutex);
}
static INLINE void
socket_accounting_unlock(void)
{
tor_mutex_release(socket_accounting_mutex);
}
/** As close(), but guaranteed to work for sockets across platforms (including /** As close(), but guaranteed to work for sockets across platforms (including
* Windows, where close()ing a socket doesn't work. Returns 0 on success, -1 * Windows, where close()ing a socket doesn't work. Returns 0 on success, -1
* on failure. */ * on failure. */
@ -683,15 +700,7 @@ int
tor_close_socket(int s) tor_close_socket(int s)
{ {
int r = 0; int r = 0;
#ifdef DEBUG_SOCKET_COUNTING
if (s > max_socket || ! bitarray_is_set(open_sockets, s)) {
log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_"
"socket(), or that was already closed or something.", s);
} else {
tor_assert(open_sockets && s <= max_socket);
bitarray_clear(open_sockets, s);
}
#endif
/* On Windows, you have to call close() on fds returned by open(), /* On Windows, you have to call close() on fds returned by open(),
* and closesocket() on fds returned by socket(). On Unix, everything * and closesocket() on fds returned by socket(). On Unix, everything
* gets close()'d. We abstract this difference by always using * gets close()'d. We abstract this difference by always using
@ -703,6 +712,17 @@ tor_close_socket(int s)
#else #else
r = close(s); r = close(s);
#endif #endif
socket_accounting_lock();
#ifdef DEBUG_SOCKET_COUNTING
if (s > max_socket || ! bitarray_is_set(open_sockets, s)) {
log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_"
"socket(), or that was already closed or something.", s);
} else {
tor_assert(open_sockets && s <= max_socket);
bitarray_clear(open_sockets, s);
}
#endif
if (r == 0) { if (r == 0) {
--n_sockets_open; --n_sockets_open;
} else { } else {
@ -717,9 +737,11 @@ tor_close_socket(int s)
#endif #endif
r = -1; r = -1;
} }
if (n_sockets_open < 0) if (n_sockets_open < 0)
log_warn(LD_BUG, "Our socket count is below zero: %d. Please submit a " log_warn(LD_BUG, "Our socket count is below zero: %d. Please submit a "
"bug report.", n_sockets_open); "bug report.", n_sockets_open);
socket_accounting_unlock();
return r; return r;
} }
@ -754,8 +776,10 @@ tor_open_socket(int domain, int type, int protocol)
{ {
int s = socket(domain, type, protocol); int s = socket(domain, type, protocol);
if (s >= 0) { if (s >= 0) {
socket_accounting_lock();
++n_sockets_open; ++n_sockets_open;
mark_socket_open(s); mark_socket_open(s);
socket_accounting_unlock();
} }
return s; return s;
} }
@ -766,8 +790,10 @@ tor_accept_socket(int sockfd, struct sockaddr *addr, socklen_t *len)
{ {
int s = accept(sockfd, addr, len); int s = accept(sockfd, addr, len);
if (s >= 0) { if (s >= 0) {
socket_accounting_lock();
++n_sockets_open; ++n_sockets_open;
mark_socket_open(s); mark_socket_open(s);
socket_accounting_unlock();
} }
return s; return s;
} }
@ -776,7 +802,11 @@ tor_accept_socket(int sockfd, struct sockaddr *addr, socklen_t *len)
int int
get_n_open_sockets(void) get_n_open_sockets(void)
{ {
return n_sockets_open; int n;
socket_accounting_lock();
n = n_sockets_open;
socket_accounting_unlock();
return n;
} }
/** Turn <b>socket</b> into a nonblocking socket. /** Turn <b>socket</b> into a nonblocking socket.
@ -817,6 +847,7 @@ tor_socketpair(int family, int type, int protocol, int fd[2])
int r; int r;
r = socketpair(family, type, protocol, fd); r = socketpair(family, type, protocol, fd);
if (r == 0) { if (r == 0) {
socket_accounting_lock();
if (fd[0] >= 0) { if (fd[0] >= 0) {
++n_sockets_open; ++n_sockets_open;
mark_socket_open(fd[0]); mark_socket_open(fd[0]);
@ -825,6 +856,7 @@ tor_socketpair(int family, int type, int protocol, int fd[2])
++n_sockets_open; ++n_sockets_open;
mark_socket_open(fd[1]); mark_socket_open(fd[1]);
} }
socket_accounting_unlock();
} }
return r < 0 ? -errno : r; return r < 0 ? -errno : r;
#else #else