mirror of
https://gitlab.torproject.org/tpo/core/tor.git
synced 2024-11-11 05:33:47 +01:00
Add support for lower-level byte counting with NSS
This is harder than with OpenSSL, since OpenSSL counts the bytes on its own and NSS doesn't. To fix this, we need to define a new PRFileDesc layer that has its own byte-counting support. Closes ticket 27289.
This commit is contained in:
parent
4b646e30d8
commit
126819c947
@ -12,6 +12,7 @@ src_lib_libtor_tls_a_SOURCES = \
|
||||
|
||||
if USE_NSS
|
||||
src_lib_libtor_tls_a_SOURCES += \
|
||||
src/lib/tls/nss_countbytes.c \
|
||||
src/lib/tls/tortls_nss.c \
|
||||
src/lib/tls/x509_nss.c
|
||||
else
|
||||
@ -31,6 +32,7 @@ src_lib_libtor_tls_testing_a_CFLAGS = \
|
||||
noinst_HEADERS += \
|
||||
src/lib/tls/ciphers.inc \
|
||||
src/lib/tls/buffers_tls.h \
|
||||
src/lib/tls/nss_countbytes.h \
|
||||
src/lib/tls/tortls.h \
|
||||
src/lib/tls/tortls_internal.h \
|
||||
src/lib/tls/tortls_st.h \
|
||||
|
244
src/lib/tls/nss_countbytes.c
Normal file
244
src/lib/tls/nss_countbytes.c
Normal file
@ -0,0 +1,244 @@
|
||||
/* Copyright 2018, The Tor Project Inc. */
|
||||
/* See LICENSE for licensing information */
|
||||
|
||||
/**
|
||||
* \file nss_countbytes.c
|
||||
* \brief A PRFileDesc layer to let us count the number of bytes
|
||||
* bytes actually written on a PRFileDesc.
|
||||
**/
|
||||
|
||||
#include "orconfig.h"
|
||||
|
||||
#include "lib/log/util_bug.h"
|
||||
#include "lib/malloc/malloc.h"
|
||||
#include "lib/tls/nss_countbytes.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <prio.h>
|
||||
|
||||
/** Boolean: have we initialized this module */
|
||||
static bool countbytes_initialized = false;
|
||||
|
||||
/** Integer to identity this layer. */
|
||||
static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER;
|
||||
|
||||
/** Table of methods for this layer.*/
|
||||
static PRIOMethods countbytes_methods;
|
||||
|
||||
/** Default close function provided by NSPR. We use this to help
|
||||
* implement our own close function.*/
|
||||
static PRStatus(*default_close_fn)(PRFileDesc *fd);
|
||||
|
||||
static PRStatus countbytes_close_fn(PRFileDesc *fd);
|
||||
static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount);
|
||||
static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf,
|
||||
PRInt32 amount);
|
||||
static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
|
||||
PRInt32 size, PRIntervalTime timeout);
|
||||
static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf,
|
||||
PRInt32 amount, PRIntn flags,
|
||||
PRIntervalTime timeout);
|
||||
static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
|
||||
PRIntn flags, PRIntervalTime timeout);
|
||||
|
||||
/** Private fields for the byte-counter layer. We cast this to and from
|
||||
* PRFilePrivate*, which is supposed to be allowed. */
|
||||
typedef struct tor_nss_bytecounts_t {
|
||||
uint64_t n_read;
|
||||
uint64_t n_written;
|
||||
} tor_nss_bytecounts_t;
|
||||
|
||||
/**
|
||||
* Initialize this module, if it is not already initialized.
|
||||
**/
|
||||
void
|
||||
tor_nss_countbytes_init(void)
|
||||
{
|
||||
if (countbytes_initialized)
|
||||
return;
|
||||
|
||||
countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer");
|
||||
tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER);
|
||||
|
||||
memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods));
|
||||
|
||||
default_close_fn = countbytes_methods.close;
|
||||
countbytes_methods.close = countbytes_close_fn;
|
||||
countbytes_methods.read = countbytes_read_fn;
|
||||
countbytes_methods.write = countbytes_write_fn;
|
||||
countbytes_methods.writev = countbytes_writev_fn;
|
||||
countbytes_methods.send = countbytes_send_fn;
|
||||
countbytes_methods.recv = countbytes_recv_fn;
|
||||
/* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think
|
||||
* NSS won't be using them for TLS connections. */
|
||||
|
||||
countbytes_initialized = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that
|
||||
* the IO layer is in fact a layer created by this module.
|
||||
*/
|
||||
static tor_nss_bytecounts_t *
|
||||
get_counts(PRFileDesc *fd)
|
||||
{
|
||||
tor_assert(fd->identity == countbytes_layer_id);
|
||||
return (tor_nss_bytecounts_t*) fd->secret;
|
||||
}
|
||||
|
||||
/** Helper: increment the read-count of an fd by n. */
|
||||
#define INC_READ(fd, n) STMT_BEGIN \
|
||||
get_counts(fd)->n_read += (n); \
|
||||
STMT_END
|
||||
|
||||
/** Helper: increment the write-count of an fd by n. */
|
||||
#define INC_WRITTEN(fd, n) STMT_BEGIN \
|
||||
get_counts(fd)->n_written += (n); \
|
||||
STMT_END
|
||||
|
||||
/** Implementation for PR_Close: frees the 'secret' field, then passes control
|
||||
* to the default close function */
|
||||
static PRStatus
|
||||
countbytes_close_fn(PRFileDesc *fd)
|
||||
{
|
||||
tor_assert(fd);
|
||||
|
||||
tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret;
|
||||
tor_free(counts);
|
||||
fd->secret = NULL;
|
||||
|
||||
return default_close_fn(fd);
|
||||
}
|
||||
|
||||
/** Implementation for PR_Read: Calls the lower-level read function,
|
||||
* and records what it said. */
|
||||
static PRInt32
|
||||
countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount)
|
||||
{
|
||||
tor_assert(fd);
|
||||
tor_assert(fd->lower);
|
||||
|
||||
PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount);
|
||||
if (result > 0)
|
||||
INC_READ(fd, result);
|
||||
return result;
|
||||
}
|
||||
/** Implementation for PR_Write: Calls the lower-level write function,
|
||||
* and records what it said. */
|
||||
static PRInt32
|
||||
countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount)
|
||||
{
|
||||
tor_assert(fd);
|
||||
tor_assert(fd->lower);
|
||||
|
||||
PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount);
|
||||
if (result > 0)
|
||||
INC_WRITTEN(fd, result);
|
||||
return result;
|
||||
}
|
||||
/** Implementation for PR_Writev: Calls the lower-level writev function,
|
||||
* and records what it said. */
|
||||
static PRInt32
|
||||
countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
|
||||
PRInt32 size, PRIntervalTime timeout)
|
||||
{
|
||||
tor_assert(fd);
|
||||
tor_assert(fd->lower);
|
||||
|
||||
PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout);
|
||||
if (result > 0)
|
||||
INC_WRITTEN(fd, result);
|
||||
return result;
|
||||
}
|
||||
/** Implementation for PR_Send: Calls the lower-level send function,
|
||||
* and records what it said. */
|
||||
static PRInt32
|
||||
countbytes_send_fn(PRFileDesc *fd, const void *buf,
|
||||
PRInt32 amount, PRIntn flags, PRIntervalTime timeout)
|
||||
{
|
||||
tor_assert(fd);
|
||||
tor_assert(fd->lower);
|
||||
|
||||
PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags,
|
||||
timeout);
|
||||
if (result > 0)
|
||||
INC_WRITTEN(fd, result);
|
||||
return result;
|
||||
}
|
||||
/** Implementation for PR_Recv: Calls the lower-level recv function,
|
||||
* and records what it said. */
|
||||
static PRInt32
|
||||
countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
|
||||
PRIntn flags, PRIntervalTime timeout)
|
||||
{
|
||||
tor_assert(fd);
|
||||
tor_assert(fd->lower);
|
||||
|
||||
PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags,
|
||||
timeout);
|
||||
if (result > 0)
|
||||
INC_READ(fd, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the
|
||||
* total number of bytes read and written. Return the new PRFileDesc.
|
||||
*
|
||||
* This function takes ownership of its input.
|
||||
*/
|
||||
PRFileDesc *
|
||||
tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack)
|
||||
{
|
||||
if (BUG(! countbytes_initialized)) {
|
||||
tor_nss_countbytes_init();
|
||||
}
|
||||
|
||||
tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts));
|
||||
|
||||
PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id,
|
||||
&countbytes_methods);
|
||||
tor_assert(newfd);
|
||||
newfd->secret = (PRFilePrivate *)bytecounts;
|
||||
|
||||
/* This does some complicated messing around with the headers of these
|
||||
objects; see the NSPR documentation for more. The upshot is that
|
||||
after PushIOLayer, "stack" will be the head of the stack.
|
||||
*/
|
||||
PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd);
|
||||
tor_assert(status == PR_SUCCESS);
|
||||
|
||||
return stack;
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(),
|
||||
* or another PRFileDesc wrapping that PRFileDesc, set the provided
|
||||
* pointers to the number of bytes read and written on the descriptor since
|
||||
* it was created.
|
||||
*
|
||||
* Return 0 on success, -1 on failure.
|
||||
*/
|
||||
int
|
||||
tor_get_prfiledesc_byte_counts(PRFileDesc *fd,
|
||||
uint64_t *n_read_out,
|
||||
uint64_t *n_written_out)
|
||||
{
|
||||
if (BUG(! countbytes_initialized)) {
|
||||
tor_nss_countbytes_init();
|
||||
}
|
||||
|
||||
tor_assert(fd);
|
||||
PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id);
|
||||
if (BUG(bclayer == NULL))
|
||||
return -1;
|
||||
|
||||
tor_nss_bytecounts_t *counts = get_counts(bclayer);
|
||||
|
||||
*n_read_out = counts->n_read;
|
||||
*n_written_out = counts->n_written;
|
||||
|
||||
return 0;
|
||||
}
|
25
src/lib/tls/nss_countbytes.h
Normal file
25
src/lib/tls/nss_countbytes.h
Normal file
@ -0,0 +1,25 @@
|
||||
/* Copyright 2018, The Tor Project, Inc. */
|
||||
/* See LICENSE for licensing information */
|
||||
|
||||
/**
|
||||
* \file nss_countbytes.h
|
||||
* \brief Header for nss_countbytes.c, which lets us count the number of
|
||||
* bytes actually written on a PRFileDesc.
|
||||
**/
|
||||
|
||||
#ifndef TOR_NSS_COUNTBYTES_H
|
||||
#define TOR_NSS_COUNTBYTES_H
|
||||
|
||||
#include "lib/cc/torint.h"
|
||||
|
||||
void tor_nss_countbytes_init(void);
|
||||
|
||||
struct PRFileDesc;
|
||||
struct PRFileDesc *tor_wrap_prfiledesc_with_byte_counter(
|
||||
struct PRFileDesc *stack);
|
||||
|
||||
int tor_get_prfiledesc_byte_counts(struct PRFileDesc *fd,
|
||||
uint64_t *n_read_out,
|
||||
uint64_t *n_written_out);
|
||||
|
||||
#endif
|
@ -31,11 +31,12 @@
|
||||
#include "lib/tls/tortls.h"
|
||||
#include "lib/tls/tortls_st.h"
|
||||
#include "lib/tls/tortls_internal.h"
|
||||
#include "lib/tls/nss_countbytes.h"
|
||||
#include "lib/log/util_bug.h"
|
||||
|
||||
DISABLE_GCC_WARNING(strict-prototypes)
|
||||
#include <prio.h>
|
||||
// For access to raw sockets.
|
||||
// For access to rar sockets.
|
||||
#include <private/pprio.h>
|
||||
#include <ssl.h>
|
||||
#include <sslt.h>
|
||||
@ -158,6 +159,8 @@ tor_tls_context_new(crypto_pk_t *identity,
|
||||
SECStatus s;
|
||||
tor_assert(identity);
|
||||
|
||||
tor_tls_init();
|
||||
|
||||
tor_tls_context_t *ctx = tor_malloc_zero(sizeof(tor_tls_context_t));
|
||||
ctx->refcnt = 1;
|
||||
|
||||
@ -320,7 +323,7 @@ tor_tls_get_state_description(tor_tls_t *tls, char *buf, size_t sz)
|
||||
void
|
||||
tor_tls_init(void)
|
||||
{
|
||||
/* We don't have any global setup to do yet, but that will change */
|
||||
tor_nss_countbytes_init();
|
||||
}
|
||||
|
||||
void
|
||||
@ -373,7 +376,11 @@ tor_tls_new(tor_socket_t sock, int is_server)
|
||||
if (!tcp)
|
||||
return NULL;
|
||||
|
||||
PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, tcp);
|
||||
PRFileDesc *count = tor_wrap_prfiledesc_with_byte_counter(tcp);
|
||||
if (! count)
|
||||
return NULL;
|
||||
|
||||
PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, count);
|
||||
if (!ssl) {
|
||||
PR_Close(tcp);
|
||||
return NULL;
|
||||
@ -465,7 +472,6 @@ tor_tls_read, (tor_tls_t *tls, char *cp, size_t len))
|
||||
PRInt32 rv = PR_Read(tls->ssl, cp, (int)len);
|
||||
// log_debug(LD_NET, "PR_Read(%zu) returned %d", n, (int)rv);
|
||||
if (rv > 0) {
|
||||
tls->n_read_since_last_check += rv;
|
||||
return rv;
|
||||
}
|
||||
if (rv == 0)
|
||||
@ -489,7 +495,6 @@ tor_tls_write(tor_tls_t *tls, const char *cp, size_t n)
|
||||
PRInt32 rv = PR_Write(tls->ssl, cp, (int)n);
|
||||
// log_debug(LD_NET, "PR_Write(%zu) returned %d", n, (int)rv);
|
||||
if (rv > 0) {
|
||||
tls->n_written_since_last_check += rv;
|
||||
return rv;
|
||||
}
|
||||
if (rv == 0)
|
||||
@ -579,13 +584,17 @@ tor_tls_get_n_raw_bytes(tor_tls_t *tls,
|
||||
tor_assert(tls);
|
||||
tor_assert(n_read);
|
||||
tor_assert(n_written);
|
||||
/* XXXX We don't curently have a way to measure this information correctly
|
||||
* in NSS; we could do that with a PRIO layer, but it'll take a little
|
||||
* coding. For now, we just track the number of bytes sent _in_ the TLS
|
||||
* stream. Doing this will make our rate-limiting slightly inaccurate. */
|
||||
*n_read = tls->n_read_since_last_check;
|
||||
*n_written = tls->n_written_since_last_check;
|
||||
tls->n_read_since_last_check = tls->n_written_since_last_check = 0;
|
||||
uint64_t r, w;
|
||||
if (tor_get_prfiledesc_byte_counts(tls->ssl, &r, &w) < 0) {
|
||||
*n_read = *n_written = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
*n_read = (size_t)(r - tls->last_read_count);
|
||||
*n_written = (size_t)(w - tls->last_write_count);
|
||||
|
||||
tls->last_read_count = r;
|
||||
tls->last_write_count = w;
|
||||
}
|
||||
|
||||
int
|
||||
|
@ -66,8 +66,9 @@ struct tor_tls_t {
|
||||
void *callback_arg;
|
||||
#endif
|
||||
#ifdef ENABLE_NSS
|
||||
size_t n_read_since_last_check;
|
||||
size_t n_written_since_last_check;
|
||||
/** Last values retried from tor_get_prfiledesc_byte_counts(). */
|
||||
uint64_t last_write_count;
|
||||
uint64_t last_read_count;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user