diff --git a/src/lib/tls/include.am b/src/lib/tls/include.am index b25e2e16bf..a664b29fb2 100644 --- a/src/lib/tls/include.am +++ b/src/lib/tls/include.am @@ -12,6 +12,7 @@ src_lib_libtor_tls_a_SOURCES = \ if USE_NSS src_lib_libtor_tls_a_SOURCES += \ + src/lib/tls/nss_countbytes.c \ src/lib/tls/tortls_nss.c \ src/lib/tls/x509_nss.c else @@ -31,6 +32,7 @@ src_lib_libtor_tls_testing_a_CFLAGS = \ noinst_HEADERS += \ src/lib/tls/ciphers.inc \ src/lib/tls/buffers_tls.h \ + src/lib/tls/nss_countbytes.h \ src/lib/tls/tortls.h \ src/lib/tls/tortls_internal.h \ src/lib/tls/tortls_st.h \ diff --git a/src/lib/tls/nss_countbytes.c b/src/lib/tls/nss_countbytes.c new file mode 100644 index 0000000000..c727684529 --- /dev/null +++ b/src/lib/tls/nss_countbytes.c @@ -0,0 +1,244 @@ +/* Copyright 2018, The Tor Project Inc. */ +/* See LICENSE for licensing information */ + +/** + * \file nss_countbytes.c + * \brief A PRFileDesc layer to let us count the number of bytes + * bytes actually written on a PRFileDesc. + **/ + +#include "orconfig.h" + +#include "lib/log/util_bug.h" +#include "lib/malloc/malloc.h" +#include "lib/tls/nss_countbytes.h" + +#include +#include + +#include + +/** Boolean: have we initialized this module */ +static bool countbytes_initialized = false; + +/** Integer to identity this layer. */ +static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER; + +/** Table of methods for this layer.*/ +static PRIOMethods countbytes_methods; + +/** Default close function provided by NSPR. We use this to help + * implement our own close function.*/ +static PRStatus(*default_close_fn)(PRFileDesc *fd); + +static PRStatus countbytes_close_fn(PRFileDesc *fd); +static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount); +static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf, + PRInt32 amount); +static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov, + PRInt32 size, PRIntervalTime timeout); +static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf, + PRInt32 amount, PRIntn flags, + PRIntervalTime timeout); +static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount, + PRIntn flags, PRIntervalTime timeout); + +/** Private fields for the byte-counter layer. We cast this to and from + * PRFilePrivate*, which is supposed to be allowed. */ +typedef struct tor_nss_bytecounts_t { + uint64_t n_read; + uint64_t n_written; +} tor_nss_bytecounts_t; + +/** + * Initialize this module, if it is not already initialized. + **/ +void +tor_nss_countbytes_init(void) +{ + if (countbytes_initialized) + return; + + countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer"); + tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER); + + memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods)); + + default_close_fn = countbytes_methods.close; + countbytes_methods.close = countbytes_close_fn; + countbytes_methods.read = countbytes_read_fn; + countbytes_methods.write = countbytes_write_fn; + countbytes_methods.writev = countbytes_writev_fn; + countbytes_methods.send = countbytes_send_fn; + countbytes_methods.recv = countbytes_recv_fn; + /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think + * NSS won't be using them for TLS connections. */ + + countbytes_initialized = true; +} + +/** + * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that + * the IO layer is in fact a layer created by this module. + */ +static tor_nss_bytecounts_t * +get_counts(PRFileDesc *fd) +{ + tor_assert(fd->identity == countbytes_layer_id); + return (tor_nss_bytecounts_t*) fd->secret; +} + +/** Helper: increment the read-count of an fd by n. */ +#define INC_READ(fd, n) STMT_BEGIN \ + get_counts(fd)->n_read += (n); \ + STMT_END + +/** Helper: increment the write-count of an fd by n. */ +#define INC_WRITTEN(fd, n) STMT_BEGIN \ + get_counts(fd)->n_written += (n); \ + STMT_END + +/** Implementation for PR_Close: frees the 'secret' field, then passes control + * to the default close function */ +static PRStatus +countbytes_close_fn(PRFileDesc *fd) +{ + tor_assert(fd); + + tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret; + tor_free(counts); + fd->secret = NULL; + + return default_close_fn(fd); +} + +/** Implementation for PR_Read: Calls the lower-level read function, + * and records what it said. */ +static PRInt32 +countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount); + if (result > 0) + INC_READ(fd, result); + return result; +} +/** Implementation for PR_Write: Calls the lower-level write function, + * and records what it said. */ +static PRInt32 +countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount); + if (result > 0) + INC_WRITTEN(fd, result); + return result; +} +/** Implementation for PR_Writev: Calls the lower-level writev function, + * and records what it said. */ +static PRInt32 +countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov, + PRInt32 size, PRIntervalTime timeout) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout); + if (result > 0) + INC_WRITTEN(fd, result); + return result; +} +/** Implementation for PR_Send: Calls the lower-level send function, + * and records what it said. */ +static PRInt32 +countbytes_send_fn(PRFileDesc *fd, const void *buf, + PRInt32 amount, PRIntn flags, PRIntervalTime timeout) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags, + timeout); + if (result > 0) + INC_WRITTEN(fd, result); + return result; +} +/** Implementation for PR_Recv: Calls the lower-level recv function, + * and records what it said. */ +static PRInt32 +countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount, + PRIntn flags, PRIntervalTime timeout) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags, + timeout); + if (result > 0) + INC_READ(fd, result); + return result; +} + +/** + * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the + * total number of bytes read and written. Return the new PRFileDesc. + * + * This function takes ownership of its input. + */ +PRFileDesc * +tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack) +{ + if (BUG(! countbytes_initialized)) { + tor_nss_countbytes_init(); + } + + tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts)); + + PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id, + &countbytes_methods); + tor_assert(newfd); + newfd->secret = (PRFilePrivate *)bytecounts; + + /* This does some complicated messing around with the headers of these + objects; see the NSPR documentation for more. The upshot is that + after PushIOLayer, "stack" will be the head of the stack. + */ + PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd); + tor_assert(status == PR_SUCCESS); + + return stack; +} + +/** + * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(), + * or another PRFileDesc wrapping that PRFileDesc, set the provided + * pointers to the number of bytes read and written on the descriptor since + * it was created. + * + * Return 0 on success, -1 on failure. + */ +int +tor_get_prfiledesc_byte_counts(PRFileDesc *fd, + uint64_t *n_read_out, + uint64_t *n_written_out) +{ + if (BUG(! countbytes_initialized)) { + tor_nss_countbytes_init(); + } + + tor_assert(fd); + PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id); + if (BUG(bclayer == NULL)) + return -1; + + tor_nss_bytecounts_t *counts = get_counts(bclayer); + + *n_read_out = counts->n_read; + *n_written_out = counts->n_written; + + return 0; +} diff --git a/src/lib/tls/nss_countbytes.h b/src/lib/tls/nss_countbytes.h new file mode 100644 index 0000000000..f26280edf2 --- /dev/null +++ b/src/lib/tls/nss_countbytes.h @@ -0,0 +1,25 @@ +/* Copyright 2018, The Tor Project, Inc. */ +/* See LICENSE for licensing information */ + +/** + * \file nss_countbytes.h + * \brief Header for nss_countbytes.c, which lets us count the number of + * bytes actually written on a PRFileDesc. + **/ + +#ifndef TOR_NSS_COUNTBYTES_H +#define TOR_NSS_COUNTBYTES_H + +#include "lib/cc/torint.h" + +void tor_nss_countbytes_init(void); + +struct PRFileDesc; +struct PRFileDesc *tor_wrap_prfiledesc_with_byte_counter( + struct PRFileDesc *stack); + +int tor_get_prfiledesc_byte_counts(struct PRFileDesc *fd, + uint64_t *n_read_out, + uint64_t *n_written_out); + +#endif diff --git a/src/lib/tls/tortls_nss.c b/src/lib/tls/tortls_nss.c index 53adfedf32..0944c57a34 100644 --- a/src/lib/tls/tortls_nss.c +++ b/src/lib/tls/tortls_nss.c @@ -31,11 +31,12 @@ #include "lib/tls/tortls.h" #include "lib/tls/tortls_st.h" #include "lib/tls/tortls_internal.h" +#include "lib/tls/nss_countbytes.h" #include "lib/log/util_bug.h" DISABLE_GCC_WARNING(strict-prototypes) #include -// For access to raw sockets. +// For access to rar sockets. #include #include #include @@ -158,6 +159,8 @@ tor_tls_context_new(crypto_pk_t *identity, SECStatus s; tor_assert(identity); + tor_tls_init(); + tor_tls_context_t *ctx = tor_malloc_zero(sizeof(tor_tls_context_t)); ctx->refcnt = 1; @@ -320,7 +323,7 @@ tor_tls_get_state_description(tor_tls_t *tls, char *buf, size_t sz) void tor_tls_init(void) { - /* We don't have any global setup to do yet, but that will change */ + tor_nss_countbytes_init(); } void @@ -373,7 +376,11 @@ tor_tls_new(tor_socket_t sock, int is_server) if (!tcp) return NULL; - PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, tcp); + PRFileDesc *count = tor_wrap_prfiledesc_with_byte_counter(tcp); + if (! count) + return NULL; + + PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, count); if (!ssl) { PR_Close(tcp); return NULL; @@ -465,7 +472,6 @@ tor_tls_read, (tor_tls_t *tls, char *cp, size_t len)) PRInt32 rv = PR_Read(tls->ssl, cp, (int)len); // log_debug(LD_NET, "PR_Read(%zu) returned %d", n, (int)rv); if (rv > 0) { - tls->n_read_since_last_check += rv; return rv; } if (rv == 0) @@ -489,7 +495,6 @@ tor_tls_write(tor_tls_t *tls, const char *cp, size_t n) PRInt32 rv = PR_Write(tls->ssl, cp, (int)n); // log_debug(LD_NET, "PR_Write(%zu) returned %d", n, (int)rv); if (rv > 0) { - tls->n_written_since_last_check += rv; return rv; } if (rv == 0) @@ -579,13 +584,17 @@ tor_tls_get_n_raw_bytes(tor_tls_t *tls, tor_assert(tls); tor_assert(n_read); tor_assert(n_written); - /* XXXX We don't curently have a way to measure this information correctly - * in NSS; we could do that with a PRIO layer, but it'll take a little - * coding. For now, we just track the number of bytes sent _in_ the TLS - * stream. Doing this will make our rate-limiting slightly inaccurate. */ - *n_read = tls->n_read_since_last_check; - *n_written = tls->n_written_since_last_check; - tls->n_read_since_last_check = tls->n_written_since_last_check = 0; + uint64_t r, w; + if (tor_get_prfiledesc_byte_counts(tls->ssl, &r, &w) < 0) { + *n_read = *n_written = 0; + return; + } + + *n_read = (size_t)(r - tls->last_read_count); + *n_written = (size_t)(w - tls->last_write_count); + + tls->last_read_count = r; + tls->last_write_count = w; } int diff --git a/src/lib/tls/tortls_st.h b/src/lib/tls/tortls_st.h index a1b59a37af..549443a4e7 100644 --- a/src/lib/tls/tortls_st.h +++ b/src/lib/tls/tortls_st.h @@ -66,8 +66,9 @@ struct tor_tls_t { void *callback_arg; #endif #ifdef ENABLE_NSS - size_t n_read_since_last_check; - size_t n_written_since_last_check; + /** Last values retried from tor_get_prfiledesc_byte_counts(). */ + uint64_t last_write_count; + uint64_t last_read_count; #endif };