diff options
Diffstat (limited to 'src/lib')
43 files changed, 1319 insertions, 97 deletions
diff --git a/src/lib/container/smartlist.c b/src/lib/container/smartlist.c index 4b29d834d9..64cabfcc6f 100644 --- a/src/lib/container/smartlist.c +++ b/src/lib/container/smartlist.c @@ -408,7 +408,7 @@ smartlist_uniq(smartlist_t *sl, * less than member, and greater than 0 if key is greater then member. */ void * -smartlist_bsearch(smartlist_t *sl, const void *key, +smartlist_bsearch(const smartlist_t *sl, const void *key, int (*compare)(const void *key, const void **member)) { int found, idx; diff --git a/src/lib/container/smartlist.h b/src/lib/container/smartlist.h index 9705396ac9..0f5af3a923 100644 --- a/src/lib/container/smartlist.h +++ b/src/lib/container/smartlist.h @@ -64,7 +64,7 @@ const uint8_t *smartlist_get_most_frequent_digest256(smartlist_t *sl); void smartlist_uniq_strings(smartlist_t *sl); void smartlist_uniq_digests(smartlist_t *sl); void smartlist_uniq_digests256(smartlist_t *sl); -void *smartlist_bsearch(smartlist_t *sl, const void *key, +void *smartlist_bsearch(const smartlist_t *sl, const void *key, int (*compare)(const void *key, const void **member)); int smartlist_bsearch_idx(const smartlist_t *sl, const void *key, int (*compare)(const void *key, const void **member), diff --git a/src/lib/crypt_ops/aes_openssl.c b/src/lib/crypt_ops/aes_openssl.c index 387f5d3df0..f2990fc06d 100644 --- a/src/lib/crypt_ops/aes_openssl.c +++ b/src/lib/crypt_ops/aes_openssl.c @@ -11,7 +11,9 @@ #include "orconfig.h" #include "lib/crypt_ops/aes.h" +#include "lib/crypt_ops/crypto_util.h" #include "lib/log/util_bug.h" +#include "lib/arch/bytes.h" #ifdef _WIN32 /*wrkard for dtls1.h >= 0.9.8m of "#include <winsock.h>"*/ #include <winsock2.h> @@ -396,10 +398,10 @@ static void aes_set_iv(aes_cnt_cipher_t *cipher, const uint8_t *iv) { #ifdef USING_COUNTER_VARS - cipher->counter3 = ntohl(get_uint32(iv)); - cipher->counter2 = ntohl(get_uint32(iv+4)); - cipher->counter1 = ntohl(get_uint32(iv+8)); - cipher->counter0 = ntohl(get_uint32(iv+12)); + cipher->counter3 = tor_ntohl(get_uint32(iv)); + cipher->counter2 = tor_ntohl(get_uint32(iv+4)); + cipher->counter1 = tor_ntohl(get_uint32(iv+8)); + cipher->counter0 = tor_ntohl(get_uint32(iv+12)); #endif /* defined(USING_COUNTER_VARS) */ cipher->pos = 0; memcpy(cipher->ctr_buf.buf, iv, 16); diff --git a/src/lib/crypt_ops/crypto_dh_nss.c b/src/lib/crypt_ops/crypto_dh_nss.c index 9a14b809b4..e2d9040f5e 100644 --- a/src/lib/crypt_ops/crypto_dh_nss.c +++ b/src/lib/crypt_ops/crypto_dh_nss.c @@ -53,6 +53,8 @@ crypto_dh_init_nss(void) circuit_dh_param.prime.len = DH1024_KEY_LEN; circuit_dh_param.base.data = dh_generator_data; circuit_dh_param.base.len = 1; + + dh_initialized = 1; } void diff --git a/src/lib/crypt_ops/crypto_init.c b/src/lib/crypt_ops/crypto_init.c index c731662d49..9d6e2da0d0 100644 --- a/src/lib/crypt_ops/crypto_init.c +++ b/src/lib/crypt_ops/crypto_init.c @@ -191,3 +191,14 @@ crypto_get_header_version_string(void) return crypto_nss_get_header_version_str(); #endif } + +/** Return true iff Tor is using the NSS library. */ +int +tor_is_using_nss(void) +{ +#ifdef ENABLE_NSS + return 1; +#else + return 0; +#endif +} diff --git a/src/lib/crypt_ops/crypto_init.h b/src/lib/crypt_ops/crypto_init.h index 5b6d65d48c..b71f144276 100644 --- a/src/lib/crypt_ops/crypto_init.h +++ b/src/lib/crypt_ops/crypto_init.h @@ -31,4 +31,6 @@ const char *crypto_get_library_name(void); const char *crypto_get_library_version_string(void); const char *crypto_get_header_version_string(void); +int tor_is_using_nss(void); + #endif /* !defined(TOR_CRYPTO_H) */ diff --git a/src/lib/crypt_ops/crypto_ope.c b/src/lib/crypt_ops/crypto_ope.c index fd5d5f3770..789517eba2 100644 --- a/src/lib/crypt_ops/crypto_ope.c +++ b/src/lib/crypt_ops/crypto_ope.c @@ -48,17 +48,17 @@ struct crypto_ope_t { /** The type to add up in order to produce our OPE ciphertexts */ typedef uint16_t ope_val_t; -#ifdef WORDS_BIG_ENDIAN -/** Convert an OPE value to little-endian */ +#ifdef WORDS_BIGENDIAN +/** Convert an OPE value from little-endian. */ static inline ope_val_t -ope_val_to_le(ope_val_t x) +ope_val_from_le(ope_val_t x) { return ((x) >> 8) | (((x)&0xff) << 8); } #else -#define ope_val_to_le(x) (x) +#define ope_val_from_le(x) (x) #endif /** @@ -104,7 +104,7 @@ sum_values_from_cipher(crypto_cipher_t *c, size_t n) crypto_cipher_crypt_inplace(c, (char*)buf, BUFSZ*sizeof(ope_val_t)); for (i = 0; i < BUFSZ; ++i) { - total += ope_val_to_le(buf[i]); + total += ope_val_from_le(buf[i]); total += 1; } n -= BUFSZ; @@ -113,7 +113,7 @@ sum_values_from_cipher(crypto_cipher_t *c, size_t n) memset(buf, 0, n*sizeof(ope_val_t)); crypto_cipher_crypt_inplace(c, (char*)buf, n*sizeof(ope_val_t)); for (i = 0; i < n; ++i) { - total += ope_val_to_le(buf[i]); + total += ope_val_from_le(buf[i]); total += 1; } diff --git a/src/lib/crypt_ops/crypto_pwbox.c b/src/lib/crypt_ops/crypto_pwbox.c index 2377f216a0..91536e891b 100644 --- a/src/lib/crypt_ops/crypto_pwbox.c +++ b/src/lib/crypt_ops/crypto_pwbox.c @@ -61,6 +61,7 @@ crypto_pwbox(uint8_t **out, size_t *outlen_out, int rv; enc = pwbox_encoded_new(); + tor_assert(enc); pwbox_encoded_setlen_skey_header(enc, S2K_MAXLEN); diff --git a/src/lib/crypt_ops/crypto_rand.c b/src/lib/crypt_ops/crypto_rand.c index 313d829a57..cffd0610f3 100644 --- a/src/lib/crypt_ops/crypto_rand.c +++ b/src/lib/crypt_ops/crypto_rand.c @@ -335,8 +335,18 @@ crypto_strongest_rand_raw(uint8_t *out, size_t out_len) * Try to get <b>out_len</b> bytes of the strongest entropy we can generate, * storing it into <b>out</b>. **/ +void +crypto_strongest_rand(uint8_t *out, size_t out_len) +{ + crypto_strongest_rand_(out, out_len); +} + +/** + * Try to get <b>out_len</b> bytes of the strongest entropy we can generate, + * storing it into <b>out</b>. (Mockable version.) + **/ MOCK_IMPL(void, -crypto_strongest_rand,(uint8_t *out, size_t out_len)) +crypto_strongest_rand_,(uint8_t *out, size_t out_len)) { #define DLEN DIGEST512_LEN diff --git a/src/lib/crypt_ops/crypto_rand.h b/src/lib/crypt_ops/crypto_rand.h index 25bcfa1f1c..0c538d81ac 100644 --- a/src/lib/crypt_ops/crypto_rand.h +++ b/src/lib/crypt_ops/crypto_rand.h @@ -21,7 +21,8 @@ int crypto_seed_rng(void) ATTR_WUR; MOCK_DECL(void,crypto_rand,(char *to, size_t n)); void crypto_rand_unmocked(char *to, size_t n); -MOCK_DECL(void,crypto_strongest_rand,(uint8_t *out, size_t out_len)); +void crypto_strongest_rand(uint8_t *out, size_t out_len); +MOCK_DECL(void,crypto_strongest_rand_,(uint8_t *out, size_t out_len)); int crypto_rand_int(unsigned int max); int crypto_rand_int_range(unsigned int min, unsigned int max); uint64_t crypto_rand_uint64_range(uint64_t min, uint64_t max); diff --git a/src/lib/crypt_ops/crypto_rsa.c b/src/lib/crypt_ops/crypto_rsa.c index 6a9e2948f1..a510e12964 100644 --- a/src/lib/crypt_ops/crypto_rsa.c +++ b/src/lib/crypt_ops/crypto_rsa.c @@ -540,6 +540,9 @@ crypto_pk_read_private_key_from_string(crypto_pk_t *env, return crypto_pk_read_from_string_generic(env, src, len, true); } +/** If a file is longer than this, we won't try to decode its private key */ +#define MAX_PRIVKEY_FILE_LEN (16*1024*1024) + /** Read a PEM-encoded private key from the file named by * <b>keyfile</b> into <b>env</b>. Return 0 on success, -1 on failure. */ @@ -551,9 +554,14 @@ crypto_pk_read_private_key_from_filename(crypto_pk_t *env, char *buf = read_file_to_str(keyfile, 0, &st); if (!buf) return -1; + if (st.st_size > MAX_PRIVKEY_FILE_LEN) { + tor_free(buf); + return -1; + } - int rv = crypto_pk_read_private_key_from_string(env, buf, st.st_size); - memwipe(buf, 0, st.st_size); + int rv = crypto_pk_read_private_key_from_string(env, buf, + (ssize_t)st.st_size); + memwipe(buf, 0, (size_t)st.st_size); tor_free(buf); return rv; } diff --git a/src/lib/encoding/confline.c b/src/lib/encoding/confline.c index dd5193d3a7..71ce5b8424 100644 --- a/src/lib/encoding/confline.c +++ b/src/lib/encoding/confline.c @@ -148,6 +148,9 @@ config_get_lines_aux(const char *string, config_line_t **result, int extended, tor_free(v); return -1; } + log_notice(LD_CONFIG, "Included configuration file or " + "directory at recursion level %d: \"%s\".", + recursion_level, v); *next = include_list; if (list_last) next = &list_last->next; diff --git a/src/lib/evloop/procmon.c b/src/lib/evloop/procmon.c index e0c26caab2..02e167377f 100644 --- a/src/lib/evloop/procmon.c +++ b/src/lib/evloop/procmon.c @@ -20,6 +20,9 @@ #ifdef HAVE_ERRNO_H #include <errno.h> #endif +#ifdef HAVE_SYS_TIME_H +#include <sys/time.h> +#endif #ifdef _WIN32 #include <winsock2.h> diff --git a/src/lib/evloop/workqueue.c b/src/lib/evloop/workqueue.c index 931f65e710..5471f87b04 100644 --- a/src/lib/evloop/workqueue.c +++ b/src/lib/evloop/workqueue.c @@ -15,7 +15,7 @@ * * The main thread informs the worker threads of pending work by using a * condition variable. The workers inform the main process of completed work - * by using an alert_sockets_t object, as implemented in compat_threads.c. + * by using an alert_sockets_t object, as implemented in net/alertsock.c. * * The main thread can also queue an "update" that will be handled by all the * workers. This is useful for updating state that all the workers share. @@ -622,8 +622,8 @@ reply_event_cb(evutil_socket_t sock, short events, void *arg) tp->reply_cb(tp); } -/** Register the threadpool <b>tp</b>'s reply queue with the libevent - * mainloop of <b>base</b>. If <b>tp</b> is provided, it is run after +/** Register the threadpool <b>tp</b>'s reply queue with Tor's global + * libevent mainloop. If <b>cb</b> is provided, it is run after * each time there is work to process from the reply queue. Return 0 on * success, -1 on failure. */ diff --git a/src/lib/evloop/workqueue.h b/src/lib/evloop/workqueue.h index da292d1f05..10d5d47464 100644 --- a/src/lib/evloop/workqueue.h +++ b/src/lib/evloop/workqueue.h @@ -63,7 +63,6 @@ replyqueue_t *threadpool_get_replyqueue(threadpool_t *tp); replyqueue_t *replyqueue_new(uint32_t alertsocks_flags); void replyqueue_process(replyqueue_t *queue); -struct event_base; int threadpool_register_reply_event(threadpool_t *tp, void (*cb)(threadpool_t *tp)); diff --git a/src/lib/geoip/.may_include b/src/lib/geoip/.may_include new file mode 100644 index 0000000000..b1ee2dcfe9 --- /dev/null +++ b/src/lib/geoip/.may_include @@ -0,0 +1,13 @@ +orconfig.h +lib/cc/*.h +lib/container/*.h +lib/crypt_ops/*.h +lib/ctime/*.h +lib/encoding/*.h +lib/fs/*.h +lib/geoip/*.h +lib/log/*.h +lib/malloc/*.h +lib/net/*.h +lib/string/*.h +lib/testsupport/*.h diff --git a/src/lib/geoip/country.h b/src/lib/geoip/country.h new file mode 100644 index 0000000000..080c156023 --- /dev/null +++ b/src/lib/geoip/country.h @@ -0,0 +1,16 @@ +/* Copyright (c) 2001 Matej Pfajfar. + * Copyright (c) 2001-2004, Roger Dingledine. + * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson. + * Copyright (c) 2007-2018, The Tor Project, Inc. */ +/* See LICENSE for licensing information */ + +#ifndef TOR_COUNTRY_H +#define TOR_COUNTRY_H + +#include "lib/cc/torint.h" +/** A signed integer representing a country code. */ +typedef int16_t country_t; + +#define COUNTRY_MAX INT16_MAX + +#endif diff --git a/src/lib/geoip/geoip.c b/src/lib/geoip/geoip.c new file mode 100644 index 0000000000..b1c0973d03 --- /dev/null +++ b/src/lib/geoip/geoip.c @@ -0,0 +1,510 @@ +/* Copyright (c) 2007-2018, The Tor Project, Inc. */ +/* See LICENSE for licensing information */ + +/** + * \file geoip.c + * \brief Functions related to maintaining an IP-to-country database; + * to summarizing client connections by country to entry guards, bridges, + * and directory servers; and for statistics on answering network status + * requests. + * + * There are two main kinds of functions in this module: geoip functions, + * which map groups of IPv4 and IPv6 addresses to country codes, and + * statistical functions, which collect statistics about different kinds of + * per-country usage. + * + * The geoip lookup tables are implemented as sorted lists of disjoint address + * ranges, each mapping to a singleton geoip_country_t. These country objects + * are also indexed by their names in a hashtable. + * + * The tables are populated from disk at startup by the geoip_load_file() + * function. For more information on the file format they read, see that + * function. See the scripts and the README file in src/config for more + * information about how those files are generated. + * + * Tor uses GeoIP information in order to implement user requests (such as + * ExcludeNodes {cc}), and to keep track of how much usage relays are getting + * for each country. + */ + +#define GEOIP_PRIVATE +#include "lib/geoip/geoip.h" +#include "lib/container/map.h" +#include "lib/container/order.h" +#include "lib/container/smartlist.h" +#include "lib/crypt_ops/crypto_digest.h" +#include "lib/ctime/di_ops.h" +#include "lib/encoding/binascii.h" +#include "lib/fs/files.h" +#include "lib/log/escape.h" +#include "lib/malloc/malloc.h" +#include "lib/net/address.h" //???? +#include "lib/net/inaddr.h" +#include "lib/string/compat_ctype.h" +#include "lib/string/compat_string.h" +#include "lib/string/scanf.h" +#include "lib/string/util_string.h" + +#include <stdio.h> +#include <string.h> + +static void init_geoip_countries(void); + +/** An entry from the GeoIP IPv4 file: maps an IPv4 range to a country. */ +typedef struct geoip_ipv4_entry_t { + uint32_t ip_low; /**< The lowest IP in the range, in host order */ + uint32_t ip_high; /**< The highest IP in the range, in host order */ + intptr_t country; /**< An index into geoip_countries */ +} geoip_ipv4_entry_t; + +/** An entry from the GeoIP IPv6 file: maps an IPv6 range to a country. */ +typedef struct geoip_ipv6_entry_t { + struct in6_addr ip_low; /**< The lowest IP in the range, in host order */ + struct in6_addr ip_high; /**< The highest IP in the range, in host order */ + intptr_t country; /**< An index into geoip_countries */ +} geoip_ipv6_entry_t; + +/** A list of geoip_country_t */ +static smartlist_t *geoip_countries = NULL; +/** A map from lowercased country codes to their position in geoip_countries. + * The index is encoded in the pointer, and 1 is added so that NULL can mean + * not found. */ +static strmap_t *country_idxplus1_by_lc_code = NULL; +/** Lists of all known geoip_ipv4_entry_t and geoip_ipv6_entry_t, sorted + * by their respective ip_low. */ +static smartlist_t *geoip_ipv4_entries = NULL, *geoip_ipv6_entries = NULL; + +/** SHA1 digest of the GeoIP files to include in extra-info descriptors. */ +static char geoip_digest[DIGEST_LEN]; +static char geoip6_digest[DIGEST_LEN]; + +/** Return a list of geoip_country_t for all known countries. */ +const smartlist_t * +geoip_get_countries(void) +{ + if (geoip_countries == NULL) { + init_geoip_countries(); + } + return geoip_countries; +} + +/** Return the index of the <b>country</b>'s entry in the GeoIP + * country list if it is a valid 2-letter country code, otherwise + * return -1. */ +MOCK_IMPL(country_t, +geoip_get_country,(const char *country)) +{ + void *idxplus1_; + intptr_t idx; + + idxplus1_ = strmap_get_lc(country_idxplus1_by_lc_code, country); + if (!idxplus1_) + return -1; + + idx = ((uintptr_t)idxplus1_)-1; + return (country_t)idx; +} + +/** Add an entry to a GeoIP table, mapping all IP addresses between <b>low</b> + * and <b>high</b>, inclusive, to the 2-letter country code <b>country</b>. */ +static void +geoip_add_entry(const tor_addr_t *low, const tor_addr_t *high, + const char *country) +{ + intptr_t idx; + void *idxplus1_; + + IF_BUG_ONCE(tor_addr_family(low) != tor_addr_family(high)) + return; + IF_BUG_ONCE(tor_addr_compare(high, low, CMP_EXACT) < 0) + return; + + idxplus1_ = strmap_get_lc(country_idxplus1_by_lc_code, country); + + if (!idxplus1_) { + geoip_country_t *c = tor_malloc_zero(sizeof(geoip_country_t)); + strlcpy(c->countrycode, country, sizeof(c->countrycode)); + tor_strlower(c->countrycode); + smartlist_add(geoip_countries, c); + idx = smartlist_len(geoip_countries) - 1; + strmap_set_lc(country_idxplus1_by_lc_code, country, (void*)(idx+1)); + } else { + idx = ((uintptr_t)idxplus1_)-1; + } + { + geoip_country_t *c = smartlist_get(geoip_countries, (int)idx); + tor_assert(!strcasecmp(c->countrycode, country)); + } + + if (tor_addr_family(low) == AF_INET) { + geoip_ipv4_entry_t *ent = tor_malloc_zero(sizeof(geoip_ipv4_entry_t)); + ent->ip_low = tor_addr_to_ipv4h(low); + ent->ip_high = tor_addr_to_ipv4h(high); + ent->country = idx; + smartlist_add(geoip_ipv4_entries, ent); + } else if (tor_addr_family(low) == AF_INET6) { + geoip_ipv6_entry_t *ent = tor_malloc_zero(sizeof(geoip_ipv6_entry_t)); + ent->ip_low = *tor_addr_to_in6_assert(low); + ent->ip_high = *tor_addr_to_in6_assert(high); + ent->country = idx; + smartlist_add(geoip_ipv6_entries, ent); + } +} + +/** Add an entry to the GeoIP table indicated by <b>family</b>, + * parsing it from <b>line</b>. The format is as for geoip_load_file(). */ +STATIC int +geoip_parse_entry(const char *line, sa_family_t family) +{ + tor_addr_t low_addr, high_addr; + char c[3]; + char *country = NULL; + + if (!geoip_countries) + init_geoip_countries(); + if (family == AF_INET) { + if (!geoip_ipv4_entries) + geoip_ipv4_entries = smartlist_new(); + } else if (family == AF_INET6) { + if (!geoip_ipv6_entries) + geoip_ipv6_entries = smartlist_new(); + } else { + log_warn(LD_GENERAL, "Unsupported family: %d", family); + return -1; + } + + while (TOR_ISSPACE(*line)) + ++line; + if (*line == '#') + return 0; + + char buf[512]; + if (family == AF_INET) { + unsigned int low, high; + if (tor_sscanf(line,"%u,%u,%2s", &low, &high, c) == 3 || + tor_sscanf(line,"\"%u\",\"%u\",\"%2s\",", &low, &high, c) == 3) { + tor_addr_from_ipv4h(&low_addr, low); + tor_addr_from_ipv4h(&high_addr, high); + } else + goto fail; + country = c; + } else { /* AF_INET6 */ + char *low_str, *high_str; + struct in6_addr low, high; + char *strtok_state; + strlcpy(buf, line, sizeof(buf)); + low_str = tor_strtok_r(buf, ",", &strtok_state); + if (!low_str) + goto fail; + high_str = tor_strtok_r(NULL, ",", &strtok_state); + if (!high_str) + goto fail; + country = tor_strtok_r(NULL, "\n", &strtok_state); + if (!country) + goto fail; + if (strlen(country) != 2) + goto fail; + if (tor_inet_pton(AF_INET6, low_str, &low) <= 0) + goto fail; + tor_addr_from_in6(&low_addr, &low); + if (tor_inet_pton(AF_INET6, high_str, &high) <= 0) + goto fail; + tor_addr_from_in6(&high_addr, &high); + } + geoip_add_entry(&low_addr, &high_addr, country); + return 0; + + fail: + log_warn(LD_GENERAL, "Unable to parse line from GEOIP %s file: %s", + family == AF_INET ? "IPv4" : "IPv6", escaped(line)); + return -1; +} + +/** Sorting helper: return -1, 1, or 0 based on comparison of two + * geoip_ipv4_entry_t */ +static int +geoip_ipv4_compare_entries_(const void **_a, const void **_b) +{ + const geoip_ipv4_entry_t *a = *_a, *b = *_b; + if (a->ip_low < b->ip_low) + return -1; + else if (a->ip_low > b->ip_low) + return 1; + else + return 0; +} + +/** bsearch helper: return -1, 1, or 0 based on comparison of an IP (a pointer + * to a uint32_t in host order) to a geoip_ipv4_entry_t */ +static int +geoip_ipv4_compare_key_to_entry_(const void *_key, const void **_member) +{ + /* No alignment issue here, since _key really is a pointer to uint32_t */ + const uint32_t addr = *(uint32_t *)_key; + const geoip_ipv4_entry_t *entry = *_member; + if (addr < entry->ip_low) + return -1; + else if (addr > entry->ip_high) + return 1; + else + return 0; +} + +/** Sorting helper: return -1, 1, or 0 based on comparison of two + * geoip_ipv6_entry_t */ +static int +geoip_ipv6_compare_entries_(const void **_a, const void **_b) +{ + const geoip_ipv6_entry_t *a = *_a, *b = *_b; + return fast_memcmp(a->ip_low.s6_addr, b->ip_low.s6_addr, + sizeof(struct in6_addr)); +} + +/** bsearch helper: return -1, 1, or 0 based on comparison of an IPv6 + * (a pointer to a in6_addr) to a geoip_ipv6_entry_t */ +static int +geoip_ipv6_compare_key_to_entry_(const void *_key, const void **_member) +{ + const struct in6_addr *addr = (struct in6_addr *)_key; + const geoip_ipv6_entry_t *entry = *_member; + + if (fast_memcmp(addr->s6_addr, entry->ip_low.s6_addr, + sizeof(struct in6_addr)) < 0) + return -1; + else if (fast_memcmp(addr->s6_addr, entry->ip_high.s6_addr, + sizeof(struct in6_addr)) > 0) + return 1; + else + return 0; +} + +/** Set up a new list of geoip countries with no countries (yet) set in it, + * except for the unknown country. + */ +static void +init_geoip_countries(void) +{ + geoip_country_t *geoip_unresolved; + geoip_countries = smartlist_new(); + /* Add a geoip_country_t for requests that could not be resolved to a + * country as first element (index 0) to geoip_countries. */ + geoip_unresolved = tor_malloc_zero(sizeof(geoip_country_t)); + strlcpy(geoip_unresolved->countrycode, "??", + sizeof(geoip_unresolved->countrycode)); + smartlist_add(geoip_countries, geoip_unresolved); + country_idxplus1_by_lc_code = strmap_new(); + strmap_set_lc(country_idxplus1_by_lc_code, "??", (void*)(1)); +} + +/** Clear appropriate GeoIP database, based on <b>family</b>, and + * reload it from the file <b>filename</b>. Return 0 on success, -1 on + * failure. + * + * Recognized line formats for IPv4 are: + * INTIPLOW,INTIPHIGH,CC + * and + * "INTIPLOW","INTIPHIGH","CC","CC3","COUNTRY NAME" + * where INTIPLOW and INTIPHIGH are IPv4 addresses encoded as 4-byte unsigned + * integers, and CC is a country code. + * + * Recognized line format for IPv6 is: + * IPV6LOW,IPV6HIGH,CC + * where IPV6LOW and IPV6HIGH are IPv6 addresses and CC is a country code. + * + * It also recognizes, and skips over, blank lines and lines that start + * with '#' (comments). + */ +int +geoip_load_file(sa_family_t family, const char *filename, int severity) +{ + FILE *f; + crypto_digest_t *geoip_digest_env = NULL; + + tor_assert(family == AF_INET || family == AF_INET6); + + if (!(f = tor_fopen_cloexec(filename, "r"))) { + log_fn(severity, LD_GENERAL, "Failed to open GEOIP file %s.", + filename); + return -1; + } + if (!geoip_countries) + init_geoip_countries(); + + if (family == AF_INET) { + if (geoip_ipv4_entries) { + SMARTLIST_FOREACH(geoip_ipv4_entries, geoip_ipv4_entry_t *, e, + tor_free(e)); + smartlist_free(geoip_ipv4_entries); + } + geoip_ipv4_entries = smartlist_new(); + } else { /* AF_INET6 */ + if (geoip_ipv6_entries) { + SMARTLIST_FOREACH(geoip_ipv6_entries, geoip_ipv6_entry_t *, e, + tor_free(e)); + smartlist_free(geoip_ipv6_entries); + } + geoip_ipv6_entries = smartlist_new(); + } + geoip_digest_env = crypto_digest_new(); + + log_notice(LD_GENERAL, "Parsing GEOIP %s file %s.", + (family == AF_INET) ? "IPv4" : "IPv6", filename); + while (!feof(f)) { + char buf[512]; + if (fgets(buf, (int)sizeof(buf), f) == NULL) + break; + crypto_digest_add_bytes(geoip_digest_env, buf, strlen(buf)); + /* FFFF track full country name. */ + geoip_parse_entry(buf, family); + } + /*XXXX abort and return -1 if no entries/illformed?*/ + fclose(f); + + /* Sort list and remember file digests so that we can include it in + * our extra-info descriptors. */ + if (family == AF_INET) { + smartlist_sort(geoip_ipv4_entries, geoip_ipv4_compare_entries_); + crypto_digest_get_digest(geoip_digest_env, geoip_digest, DIGEST_LEN); + } else { + /* AF_INET6 */ + smartlist_sort(geoip_ipv6_entries, geoip_ipv6_compare_entries_); + crypto_digest_get_digest(geoip_digest_env, geoip6_digest, DIGEST_LEN); + } + crypto_digest_free(geoip_digest_env); + + return 0; +} + +/** Given an IP address in host order, return a number representing the + * country to which that address belongs, -1 for "No geoip information + * available", or 0 for the 'unknown country'. The return value will always + * be less than geoip_get_n_countries(). To decode it, call + * geoip_get_country_name(). + */ +int +geoip_get_country_by_ipv4(uint32_t ipaddr) +{ + geoip_ipv4_entry_t *ent; + if (!geoip_ipv4_entries) + return -1; + ent = smartlist_bsearch(geoip_ipv4_entries, &ipaddr, + geoip_ipv4_compare_key_to_entry_); + return ent ? (int)ent->country : 0; +} + +/** Given an IPv6 address, return a number representing the country to + * which that address belongs, -1 for "No geoip information available", or + * 0 for the 'unknown country'. The return value will always be less than + * geoip_get_n_countries(). To decode it, call geoip_get_country_name(). + */ +int +geoip_get_country_by_ipv6(const struct in6_addr *addr) +{ + geoip_ipv6_entry_t *ent; + + if (!geoip_ipv6_entries) + return -1; + ent = smartlist_bsearch(geoip_ipv6_entries, addr, + geoip_ipv6_compare_key_to_entry_); + return ent ? (int)ent->country : 0; +} + +/** Given an IP address, return a number representing the country to which + * that address belongs, -1 for "No geoip information available", or 0 for + * the 'unknown country'. The return value will always be less than + * geoip_get_n_countries(). To decode it, call geoip_get_country_name(). + */ +MOCK_IMPL(int, +geoip_get_country_by_addr,(const tor_addr_t *addr)) +{ + if (tor_addr_family(addr) == AF_INET) { + return geoip_get_country_by_ipv4(tor_addr_to_ipv4h(addr)); + } else if (tor_addr_family(addr) == AF_INET6) { + return geoip_get_country_by_ipv6(tor_addr_to_in6(addr)); + } else { + return -1; + } +} + +/** Return the number of countries recognized by the GeoIP country list. */ +MOCK_IMPL(int, +geoip_get_n_countries,(void)) +{ + if (!geoip_countries) + init_geoip_countries(); + return (int) smartlist_len(geoip_countries); +} + +/** Return the two-letter country code associated with the number <b>num</b>, + * or "??" for an unknown value. */ +const char * +geoip_get_country_name(country_t num) +{ + if (geoip_countries && num >= 0 && num < smartlist_len(geoip_countries)) { + geoip_country_t *c = smartlist_get(geoip_countries, num); + return c->countrycode; + } else + return "??"; +} + +/** Return true iff we have loaded a GeoIP database.*/ +MOCK_IMPL(int, +geoip_is_loaded,(sa_family_t family)) +{ + tor_assert(family == AF_INET || family == AF_INET6); + if (geoip_countries == NULL) + return 0; + if (family == AF_INET) + return geoip_ipv4_entries != NULL; + else /* AF_INET6 */ + return geoip_ipv6_entries != NULL; +} + +/** Return the hex-encoded SHA1 digest of the loaded GeoIP file. The + * result does not need to be deallocated, but will be overwritten by the + * next call of hex_str(). */ +const char * +geoip_db_digest(sa_family_t family) +{ + tor_assert(family == AF_INET || family == AF_INET6); + if (family == AF_INET) + return hex_str(geoip_digest, DIGEST_LEN); + else /* AF_INET6 */ + return hex_str(geoip6_digest, DIGEST_LEN); +} + +/** Release all storage held by the GeoIP databases and country list. */ +STATIC void +clear_geoip_db(void) +{ + if (geoip_countries) { + SMARTLIST_FOREACH(geoip_countries, geoip_country_t *, c, tor_free(c)); + smartlist_free(geoip_countries); + } + + strmap_free(country_idxplus1_by_lc_code, NULL); + if (geoip_ipv4_entries) { + SMARTLIST_FOREACH(geoip_ipv4_entries, geoip_ipv4_entry_t *, ent, + tor_free(ent)); + smartlist_free(geoip_ipv4_entries); + } + if (geoip_ipv6_entries) { + SMARTLIST_FOREACH(geoip_ipv6_entries, geoip_ipv6_entry_t *, ent, + tor_free(ent)); + smartlist_free(geoip_ipv6_entries); + } + geoip_countries = NULL; + country_idxplus1_by_lc_code = NULL; + geoip_ipv4_entries = NULL; + geoip_ipv6_entries = NULL; +} + +/** Release all storage held in this file. */ +void +geoip_free_all(void) +{ + clear_geoip_db(); + + memset(geoip_digest, 0, sizeof(geoip_digest)); + memset(geoip6_digest, 0, sizeof(geoip6_digest)); +} diff --git a/src/lib/geoip/geoip.h b/src/lib/geoip/geoip.h new file mode 100644 index 0000000000..6ef27d66d0 --- /dev/null +++ b/src/lib/geoip/geoip.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2001 Matej Pfajfar. + * Copyright (c) 2001-2004, Roger Dingledine. + * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson. + * Copyright (c) 2007-2018, The Tor Project, Inc. */ +/* See LICENSE for licensing information */ + +/** + * \file geoip.h + * \brief Header file for geoip.c. + **/ + +#ifndef TOR_GEOIP_H +#define TOR_GEOIP_H + +#include "orconfig.h" +#include "lib/net/nettypes.h" +#include "lib/testsupport/testsupport.h" +#include "lib/net/inaddr_st.h" +#include "lib/geoip/country.h" + +#ifdef GEOIP_PRIVATE +STATIC int geoip_parse_entry(const char *line, sa_family_t family); +STATIC void clear_geoip_db(void); +#endif /* defined(GEOIP_PRIVATE) */ + +struct in6_addr; +struct tor_addr_t; + +int geoip_get_country_by_ipv4(uint32_t ipaddr); +int geoip_get_country_by_ipv6(const struct in6_addr *addr); + +/** A per-country GeoIP record. */ +typedef struct geoip_country_t { + char countrycode[3]; +} geoip_country_t; + +struct smartlist_t; +const struct smartlist_t *geoip_get_countries(void); + +int geoip_load_file(sa_family_t family, const char *filename, int severity); +MOCK_DECL(int, geoip_get_country_by_addr, (const struct tor_addr_t *addr)); +MOCK_DECL(int, geoip_get_n_countries, (void)); +const char *geoip_get_country_name(country_t num); +MOCK_DECL(int, geoip_is_loaded, (sa_family_t family)); +const char *geoip_db_digest(sa_family_t family); +MOCK_DECL(country_t, geoip_get_country, (const char *countrycode)); + +void geoip_free_all(void); + +#endif /* !defined(TOR_GEOIP_H) */ diff --git a/src/lib/geoip/include.am b/src/lib/geoip/include.am new file mode 100644 index 0000000000..9710d75ac7 --- /dev/null +++ b/src/lib/geoip/include.am @@ -0,0 +1,17 @@ +noinst_LIBRARIES += src/lib/libtor-geoip.a + +if UNITTESTS_ENABLED +noinst_LIBRARIES += src/lib/libtor-geoip-testing.a +endif + +src_lib_libtor_geoip_a_SOURCES = \ + src/lib/geoip/geoip.c + +src_lib_libtor_geoip_testing_a_SOURCES = \ + $(src_lib_libtor_geoip_a_SOURCES) +src_lib_libtor_geoip_testing_a_CPPFLAGS = $(AM_CPPFLAGS) $(TEST_CPPFLAGS) +src_lib_libtor_geoip_testing_a_CFLAGS = $(AM_CFLAGS) $(TEST_CFLAGS) + +noinst_HEADERS += \ + src/lib/geoip/geoip.h \ + src/lib/geoip/country.h diff --git a/src/lib/log/util_bug.h b/src/lib/log/util_bug.h index 44a4f8381c..557d932ac3 100644 --- a/src/lib/log/util_bug.h +++ b/src/lib/log/util_bug.h @@ -56,6 +56,35 @@ #error "Sorry; we don't support building with NDEBUG." #endif /* defined(NDEBUG) */ +#if defined(TOR_UNIT_TESTS) && defined(__GNUC__) +/* We define this GCC macro as a replacement for PREDICT_UNLIKELY() in this + * header, so that in our unit test builds, we'll get compiler warnings about + * stuff like tor_assert(n = 5). + * + * The key here is that (e) is wrapped in exactly one layer of parentheses, + * and then passed right to a conditional. If you do anything else to the + * expression here, or introduce any more parentheses, the compiler won't + * help you. + * + * We only do this for the unit-test build case because it interferes with + * the likely-branch labeling. Note below that in the other case, we define + * these macros to just be synonyms for PREDICT_(UN)LIKELY. + */ +#define ASSERT_PREDICT_UNLIKELY_(e) \ + ( { \ + int tor__assert_tmp_value__; \ + if (e) \ + tor__assert_tmp_value__ = 1; \ + else \ + tor__assert_tmp_value__ = 0; \ + tor__assert_tmp_value__; \ + } ) +#define ASSERT_PREDICT_LIKELY_(e) ASSERT_PREDICT_UNLIKELY_(e) +#else +#define ASSERT_PREDICT_UNLIKELY_(e) PREDICT_UNLIKELY(e) +#define ASSERT_PREDICT_LIKELY_(e) PREDICT_LIKELY(e) +#endif + /* Sometimes we don't want to use assertions during branch coverage tests; it * leads to tons of unreached branches which in reality are only assertions we * didn't hit. */ @@ -67,13 +96,19 @@ /** Like assert(3), but send assertion failures to the log as well as to * stderr. */ #define tor_assert(expr) STMT_BEGIN \ - if (PREDICT_UNLIKELY(!(expr))) { \ + if (ASSERT_PREDICT_LIKELY_(expr)) { \ + } else { \ tor_assertion_failed_(SHORT_FILE__, __LINE__, __func__, #expr); \ abort(); \ } STMT_END #endif /* defined(TOR_UNIT_TESTS) && defined(DISABLE_ASSERTS_IN_UNIT_TESTS) */ -#define tor_assert_unreached() tor_assert(0) +#define tor_assert_unreached() \ + STMT_BEGIN { \ + tor_assertion_failed_(SHORT_FILE__, __LINE__, __func__, \ + "line should be unreached"); \ + abort(); \ + } STMT_END /* Non-fatal bug assertions. The "unreached" variants mean "this line should * never be reached." The "once" variants mean "Don't log a warning more than @@ -104,7 +139,7 @@ #define tor_assert_nonfatal_unreached_once() tor_assert(0) #define tor_assert_nonfatal_once(cond) tor_assert((cond)) #define BUG(cond) \ - (PREDICT_UNLIKELY(cond) ? \ + (ASSERT_PREDICT_UNLIKELY_(cond) ? \ (tor_assertion_failed_(SHORT_FILE__,__LINE__,__func__,"!("#cond")"), \ abort(), 1) \ : 0) @@ -113,14 +148,15 @@ #define tor_assert_nonfatal(cond) ((void)(cond)) #define tor_assert_nonfatal_unreached_once() STMT_NIL #define tor_assert_nonfatal_once(cond) ((void)(cond)) -#define BUG(cond) (PREDICT_UNLIKELY(cond) ? 1 : 0) +#define BUG(cond) (ASSERT_PREDICT_UNLIKELY_(cond) ? 1 : 0) #else /* Normal case, !ALL_BUGS_ARE_FATAL, !DISABLE_ASSERTS_IN_UNIT_TESTS */ #define tor_assert_nonfatal_unreached() STMT_BEGIN \ tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, NULL, 0); \ STMT_END #define tor_assert_nonfatal(cond) STMT_BEGIN \ - if (PREDICT_UNLIKELY(!(cond))) { \ - tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, #cond, 0); \ + if (ASSERT_PREDICT_LIKELY_(cond)) { \ + } else { \ + tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, #cond, 0); \ } \ STMT_END #define tor_assert_nonfatal_unreached_once() STMT_BEGIN \ @@ -132,13 +168,14 @@ STMT_END #define tor_assert_nonfatal_once(cond) STMT_BEGIN \ static int warning_logged__ = 0; \ - if (!warning_logged__ && PREDICT_UNLIKELY(!(cond))) { \ + if (ASSERT_PREDICT_LIKELY_(cond)) { \ + } else if (!warning_logged__) { \ warning_logged__ = 1; \ tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, #cond, 1); \ } \ STMT_END #define BUG(cond) \ - (PREDICT_UNLIKELY(cond) ? \ + (ASSERT_PREDICT_UNLIKELY_(cond) ? \ (tor_bug_occurred_(SHORT_FILE__,__LINE__,__func__,"!("#cond")",0), 1) \ : 0) #endif /* defined(ALL_BUGS_ARE_FATAL) || ... */ @@ -147,17 +184,17 @@ #define IF_BUG_ONCE__(cond,var) \ if (( { \ static int var = 0; \ - int bool_result = (cond); \ - if (PREDICT_UNLIKELY(bool_result) && !var) { \ + int bool_result = !!(cond); \ + if (bool_result && !var) { \ var = 1; \ tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, \ "!("#cond")", 1); \ } \ - PREDICT_UNLIKELY(bool_result); } )) + bool_result; } )) #else /* !(defined(__GNUC__)) */ #define IF_BUG_ONCE__(cond,var) \ static int var = 0; \ - if (PREDICT_UNLIKELY(cond) ? \ + if ((cond) ? \ (var ? 1 : \ (var=1, \ tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, \ @@ -175,7 +212,7 @@ */ #define IF_BUG_ONCE(cond) \ - IF_BUG_ONCE__((cond), \ + IF_BUG_ONCE__(ASSERT_PREDICT_UNLIKELY_(cond), \ IF_BUG_ONCE_VARNAME__(__LINE__)) /** Define this if you want Tor to crash when any problem comes up, diff --git a/src/lib/net/address.c b/src/lib/net/address.c index 03767e2950..c97a17037a 100644 --- a/src/lib/net/address.c +++ b/src/lib/net/address.c @@ -1187,14 +1187,22 @@ tor_addr_parse(tor_addr_t *addr, const char *src) int result; struct in_addr in_tmp; struct in6_addr in6_tmp; + int brackets_detected = 0; + tor_assert(addr && src); - if (src[0] == '[' && src[1]) + + size_t len = strlen(src); + + if (len && src[0] == '[' && src[len - 1] == ']') { + brackets_detected = 1; src = tmp = tor_strndup(src+1, strlen(src)-2); + } if (tor_inet_pton(AF_INET6, src, &in6_tmp) > 0) { result = AF_INET6; tor_addr_from_in6(addr, &in6_tmp); - } else if (tor_inet_pton(AF_INET, src, &in_tmp) > 0) { + } else if (!brackets_detected && + tor_inet_pton(AF_INET, src, &in_tmp) > 0) { result = AF_INET; tor_addr_from_in(addr, &in_tmp); } else { diff --git a/src/lib/net/inaddr.c b/src/lib/net/inaddr.c index dcd8fcdd65..0960d323c5 100644 --- a/src/lib/net/inaddr.c +++ b/src/lib/net/inaddr.c @@ -168,6 +168,13 @@ tor_inet_pton(int af, const char *src, void *dst) if (af == AF_INET) { return tor_inet_aton(src, dst); } else if (af == AF_INET6) { + ssize_t len = strlen(src); + + /* Reject if src has needless trailing ':'. */ + if (len > 2 && src[len - 1] == ':' && src[len - 2] != ':') { + return 0; + } + struct in6_addr *out = dst; uint16_t words[8]; int gapPos = -1, i, setWords=0; @@ -207,7 +214,6 @@ tor_inet_pton(int af, const char *src, void *dst) return 0; if (TOR_ISXDIGIT(*src)) { char *next; - ssize_t len; long r = strtol(src, &next, 16); if (next == NULL || next == src) { /* The 'next == src' error case can happen on versions of openbsd diff --git a/src/lib/net/socket.c b/src/lib/net/socket.c index 06421b080d..cd7c9685cd 100644 --- a/src/lib/net/socket.c +++ b/src/lib/net/socket.c @@ -142,41 +142,6 @@ tor_close_socket_simple(tor_socket_t s) return r; } -/** As tor_close_socket_simple(), but keeps track of the number - * of open sockets. Returns 0 on success, -1 on failure. */ -MOCK_IMPL(int, -tor_close_socket,(tor_socket_t s)) -{ - int r = tor_close_socket_simple(s); - - socket_accounting_lock(); -#ifdef DEBUG_SOCKET_COUNTING - if (s > max_socket || ! bitarray_is_set(open_sockets, s)) { - log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_" - "socket(), or that was already closed or something.", s); - } else { - tor_assert(open_sockets && s <= max_socket); - bitarray_clear(open_sockets, s); - } -#endif /* defined(DEBUG_SOCKET_COUNTING) */ - if (r == 0) { - --n_sockets_open; - } else { -#ifdef _WIN32 - if (r != WSAENOTSOCK) - --n_sockets_open; -#else - if (r != EBADF) - --n_sockets_open; // LCOV_EXCL_LINE -- EIO and EINTR too hard to force. -#endif /* defined(_WIN32) */ - r = -1; - } - - tor_assert_nonfatal(n_sockets_open >= 0); - socket_accounting_unlock(); - return r; -} - /** @{ */ #ifdef DEBUG_SOCKET_COUNTING /** Helper: if DEBUG_SOCKET_COUNTING is enabled, remember that <b>s</b> is @@ -201,11 +166,50 @@ mark_socket_open(tor_socket_t s) } bitarray_set(open_sockets, s); } +static inline void +mark_socket_closed(tor_socket_t s) +{ + if (s > max_socket || ! bitarray_is_set(open_sockets, s)) { + log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_" + "socket(), or that was already closed or something.", s); + } else { + tor_assert(open_sockets && s <= max_socket); + bitarray_clear(open_sockets, s); + } +} #else /* !(defined(DEBUG_SOCKET_COUNTING)) */ #define mark_socket_open(s) ((void) (s)) +#define mark_socket_closed(s) ((void) (s)) #endif /* defined(DEBUG_SOCKET_COUNTING) */ /** @} */ +/** As tor_close_socket_simple(), but keeps track of the number + * of open sockets. Returns 0 on success, -1 on failure. */ +MOCK_IMPL(int, +tor_close_socket,(tor_socket_t s)) +{ + int r = tor_close_socket_simple(s); + + socket_accounting_lock(); + mark_socket_closed(s); + if (r == 0) { + --n_sockets_open; + } else { +#ifdef _WIN32 + if (r != WSAENOTSOCK) + --n_sockets_open; +#else + if (r != EBADF) + --n_sockets_open; // LCOV_EXCL_LINE -- EIO and EINTR too hard to force. +#endif /* defined(_WIN32) */ + r = -1; + } + + tor_assert_nonfatal(n_sockets_open >= 0); + socket_accounting_unlock(); + return r; +} + /** As socket(), but counts the number of open sockets. */ MOCK_IMPL(tor_socket_t, tor_open_socket,(int domain, int type, int protocol)) @@ -307,6 +311,20 @@ tor_take_socket_ownership(tor_socket_t s) socket_accounting_unlock(); } +/** + * For socket accounting: declare that we are no longer the owner of the + * socket <b>s</b>. This will prevent us from overallocating sockets, and + * prevent us from asserting later when we close the socket <b>s</b>. + */ +void +tor_release_socket_ownership(tor_socket_t s) +{ + socket_accounting_lock(); + --n_sockets_open; + mark_socket_closed(s); + socket_accounting_unlock(); +} + /** As accept(), but counts the number of open sockets. */ tor_socket_t tor_accept_socket(tor_socket_t sockfd, struct sockaddr *addr, socklen_t *len) diff --git a/src/lib/net/socket.h b/src/lib/net/socket.h index 5b7d6dbbc6..2b87441fc6 100644 --- a/src/lib/net/socket.h +++ b/src/lib/net/socket.h @@ -23,6 +23,7 @@ struct sockaddr; int tor_close_socket_simple(tor_socket_t s); MOCK_DECL(int, tor_close_socket, (tor_socket_t s)); void tor_take_socket_ownership(tor_socket_t s); +void tor_release_socket_ownership(tor_socket_t s); tor_socket_t tor_open_socket_with_extensions( int domain, int type, int protocol, int cloexec, int nonblock); diff --git a/src/lib/net/socketpair.c b/src/lib/net/socketpair.c index 8dd04d0486..380338f15c 100644 --- a/src/lib/net/socketpair.c +++ b/src/lib/net/socketpair.c @@ -2,6 +2,7 @@ * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson. * Copyright (c) 2007-2018, The Tor Project, Inc. */ +#include "lib/cc/torint.h" #include "lib/net/socketpair.h" #include "lib/net/inaddr_st.h" #include "lib/arch/bytes.h" diff --git a/src/lib/process/daemon.c b/src/lib/process/daemon.c index c64affd8b9..ab3ac73ad5 100644 --- a/src/lib/process/daemon.c +++ b/src/lib/process/daemon.c @@ -38,6 +38,16 @@ static int finish_daemon_called = 0; /** Socketpair used to communicate between parent and child process while * daemonizing. */ static int daemon_filedes[2]; + +/** + * Return true iff we've called start_daemon() at least once. + */ +bool +start_daemon_has_been_called(void) +{ + return start_daemon_called != 0; +} + /** Start putting the process into daemon mode: fork and drop all resources * except standard fds. The parent process never returns, but stays around * until finish_daemon is called. (Note: it's safe to call this more @@ -168,4 +178,10 @@ finish_daemon(const char *cp) (void)cp; return 0; } +bool +start_daemon_has_been_called(void) +{ + return false; +} + #endif /* !defined(_WIN32) */ diff --git a/src/lib/process/daemon.h b/src/lib/process/daemon.h index c3b78029af..e33bd56701 100644 --- a/src/lib/process/daemon.h +++ b/src/lib/process/daemon.h @@ -11,7 +11,11 @@ #ifndef TOR_DAEMON_H #define TOR_DAEMON_H +#include <stdbool.h> + int start_daemon(void); int finish_daemon(const char *desired_cwd); +bool start_daemon_has_been_called(void); + #endif diff --git a/src/lib/string/util_string.c b/src/lib/string/util_string.c index a6b0a3d68a..b2b85d151d 100644 --- a/src/lib/string/util_string.c +++ b/src/lib/string/util_string.c @@ -451,3 +451,93 @@ string_is_C_identifier(const char *string) return 1; } + +/** A byte with the top <b>x</b> bits set. */ +#define TOP_BITS(x) ((uint8_t)(0xFF << (8 - (x)))) +/** A byte with the lowest <b>x</b> bits set. */ +#define LOW_BITS(x) ((uint8_t)(0xFF >> (8 - (x)))) + +/** Given the leading byte <b>b</b>, return the total number of bytes in the + * UTF-8 character. Returns 0 if it's an invalid leading byte. + */ +static uint8_t +bytes_in_char(uint8_t b) +{ + if ((TOP_BITS(1) & b) == 0x00) + return 1; // a 1-byte UTF-8 char, aka ASCII + if ((TOP_BITS(3) & b) == TOP_BITS(2)) + return 2; // a 2-byte UTF-8 char + if ((TOP_BITS(4) & b) == TOP_BITS(3)) + return 3; // a 3-byte UTF-8 char + if ((TOP_BITS(5) & b) == TOP_BITS(4)) + return 4; // a 4-byte UTF-8 char + + // Invalid: either the top 2 bits are 10, or the top 5 bits are 11111. + return 0; +} + +/** Returns true iff <b>b</b> is a UTF-8 continuation byte. */ +static bool +is_continuation_byte(uint8_t b) +{ + uint8_t top2bits = b & TOP_BITS(2); + return top2bits == TOP_BITS(1); +} + +/** Returns true iff the <b>len</b> bytes in <b>c</b> are a valid UTF-8 + * character. + */ +static bool +validate_char(const uint8_t *c, uint8_t len) +{ + if (len == 1) + return true; // already validated this is an ASCII char earlier. + + uint8_t mask = LOW_BITS(7 - len); // bitmask for the leading byte. + uint32_t codepoint = c[0] & mask; + + mask = LOW_BITS(6); // bitmask for continuation bytes. + for (uint8_t i = 1; i < len; i++) { + if (!is_continuation_byte(c[i])) + return false; + codepoint <<= 6; + codepoint |= (c[i] & mask); + } + + if (len == 2 && codepoint <= 0x7f) + return false; // Invalid, overly long encoding, should have fit in 1 byte. + + if (len == 3 && codepoint <= 0x7ff) + return false; // Invalid, overly long encoding, should have fit in 2 bytes. + + if (len == 4 && codepoint <= 0xffff) + return false; // Invalid, overly long encoding, should have fit in 3 bytes. + + if (codepoint >= 0xd800 && codepoint <= 0xdfff) + return false; // Invalid, reserved for UTF-16 surrogate pairs. + + return codepoint <= 0x10ffff; // Check if within maximum. +} + +/** Returns true iff the first <b>len</b> bytes in <b>str</b> are a + valid UTF-8 string. */ +int +string_is_utf8(const char *str, size_t len) +{ + for (size_t i = 0; i < len;) { + uint8_t num_bytes = bytes_in_char(str[i]); + if (num_bytes == 0) // Invalid leading byte found. + return false; + + size_t next_char = i + num_bytes; + if (next_char > len) + return false; + + // Validate the continuation bytes in this multi-byte character, + // and advance to the next character in the string. + if (!validate_char((const uint8_t*)&str[i], num_bytes)) + return false; + i = next_char; + } + return true; +} diff --git a/src/lib/string/util_string.h b/src/lib/string/util_string.h index 471613462a..746ece0d33 100644 --- a/src/lib/string/util_string.h +++ b/src/lib/string/util_string.h @@ -52,4 +52,6 @@ const char *find_str_at_start_of_line(const char *haystack, int string_is_C_identifier(const char *string); +int string_is_utf8(const char *str, size_t len); + #endif /* !defined(TOR_UTIL_STRING_H) */ diff --git a/src/lib/time/compat_time.c b/src/lib/time/compat_time.c index d26cb6880d..f1ddb4fdc4 100644 --- a/src/lib/time/compat_time.c +++ b/src/lib/time/compat_time.c @@ -237,6 +237,7 @@ monotime_reset_ratchets_for_testing(void) */ static struct mach_timebase_info mach_time_info; static struct mach_timebase_info mach_time_info_msec_cvt; +static int32_t mach_time_msec_cvt_threshold; static int monotime_shift = 0; static void @@ -256,11 +257,15 @@ monotime_init_internal(void) } { // For converting ticks to milliseconds in a 32-bit-friendly way, we - // will first right-shift by 20, and then multiply by 20/19, since - // (1<<20) * 19/20 is about 1e6. We precompute a new numerate and + // will first right-shift by 20, and then multiply by 2048/1953, since + // (1<<20) * 1953/2048 is about 1e6. We precompute a new numerator and // denominator here to avoid multiple multiplies. - mach_time_info_msec_cvt.numer = mach_time_info.numer * 20; - mach_time_info_msec_cvt.denom = mach_time_info.denom * 19; + mach_time_info_msec_cvt.numer = mach_time_info.numer * 2048; + mach_time_info_msec_cvt.denom = mach_time_info.denom * 1953; + // For any value above this amount, we should divide before multiplying, + // to avoid overflow. For a value below this, we should multiply + // before dividing, to improve accuracy. + mach_time_msec_cvt_threshold = INT32_MAX / mach_time_info_msec_cvt.numer; } } @@ -323,8 +328,13 @@ monotime_coarse_diff_msec32_(const monotime_coarse_t *start, /* We already require in di_ops.c that right-shift performs a sign-extend. */ const int32_t diff_microticks = (int32_t)(diff_ticks >> 20); - return (diff_microticks * mach_time_info_msec_cvt.numer) / - mach_time_info_msec_cvt.denom; + if (diff_microticks >= mach_time_msec_cvt_threshold) { + return (diff_microticks / mach_time_info_msec_cvt.denom) * + mach_time_info_msec_cvt.numer; + } else { + return (diff_microticks * mach_time_info_msec_cvt.numer) / + mach_time_info_msec_cvt.denom; + } } uint32_t diff --git a/src/lib/time/compat_time.h b/src/lib/time/compat_time.h index 4427ce8f92..44fab62de5 100644 --- a/src/lib/time/compat_time.h +++ b/src/lib/time/compat_time.h @@ -200,6 +200,7 @@ monotime_coarse_diff_msec32(const monotime_coarse_t *start, // on a 64-bit platform, let's assume 64/64 division is cheap. return (int32_t) monotime_coarse_diff_msec(start, end); #else +#define USING_32BIT_MSEC_HACK return monotime_coarse_diff_msec32_(start, end); #endif } diff --git a/src/lib/time/tvdiff.c b/src/lib/time/tvdiff.c index 8617110e52..bc8a1166e7 100644 --- a/src/lib/time/tvdiff.c +++ b/src/lib/time/tvdiff.c @@ -165,3 +165,25 @@ tv_to_msec(const struct timeval *tv) conv += ((int64_t)tv->tv_usec+500)/1000L; return conv; } + +/** + * Return duration in seconds between time_t values + * <b>t1</b> and <b>t2</b> iff <b>t1</b> is numerically + * less or equal than <b>t2</b>. Otherwise, return TIME_MAX. + * + * This provides a safe way to compute difference between + * two UNIX timestamps (<b>t2</b> can be assumed by calling + * code to be later than <b>t1</b>) or two durations measured + * in seconds (<b>t2</b> can be assumed to be longer than + * <b>t1</b>). Calling code is expected to check for TIME_MAX + * return value and interpret that as error condition. + */ +time_t +time_diff(const time_t t1, const time_t t2) +{ + if (t1 <= t2) + return t2 - t1; + + return TIME_MAX; +} + diff --git a/src/lib/time/tvdiff.h b/src/lib/time/tvdiff.h index d78330d7d8..a15ce52ad6 100644 --- a/src/lib/time/tvdiff.h +++ b/src/lib/time/tvdiff.h @@ -18,4 +18,6 @@ long tv_udiff(const struct timeval *start, const struct timeval *end); long tv_mdiff(const struct timeval *start, const struct timeval *end); int64_t tv_to_msec(const struct timeval *tv); +time_t time_diff(const time_t from, const time_t to); + #endif diff --git a/src/lib/tls/.may_include b/src/lib/tls/.may_include index ca7cb455e4..2840e590b8 100644 --- a/src/lib/tls/.may_include +++ b/src/lib/tls/.may_include @@ -8,6 +8,7 @@ lib/ctime/*.h lib/encoding/*.h lib/intmath/*.h lib/log/*.h +lib/malloc/*.h lib/net/*.h lib/string/*.h lib/testsupport/testsupport.h diff --git a/src/lib/tls/include.am b/src/lib/tls/include.am index b25e2e16bf..a664b29fb2 100644 --- a/src/lib/tls/include.am +++ b/src/lib/tls/include.am @@ -12,6 +12,7 @@ src_lib_libtor_tls_a_SOURCES = \ if USE_NSS src_lib_libtor_tls_a_SOURCES += \ + src/lib/tls/nss_countbytes.c \ src/lib/tls/tortls_nss.c \ src/lib/tls/x509_nss.c else @@ -31,6 +32,7 @@ src_lib_libtor_tls_testing_a_CFLAGS = \ noinst_HEADERS += \ src/lib/tls/ciphers.inc \ src/lib/tls/buffers_tls.h \ + src/lib/tls/nss_countbytes.h \ src/lib/tls/tortls.h \ src/lib/tls/tortls_internal.h \ src/lib/tls/tortls_st.h \ diff --git a/src/lib/tls/nss_countbytes.c b/src/lib/tls/nss_countbytes.c new file mode 100644 index 0000000000..c727684529 --- /dev/null +++ b/src/lib/tls/nss_countbytes.c @@ -0,0 +1,244 @@ +/* Copyright 2018, The Tor Project Inc. */ +/* See LICENSE for licensing information */ + +/** + * \file nss_countbytes.c + * \brief A PRFileDesc layer to let us count the number of bytes + * bytes actually written on a PRFileDesc. + **/ + +#include "orconfig.h" + +#include "lib/log/util_bug.h" +#include "lib/malloc/malloc.h" +#include "lib/tls/nss_countbytes.h" + +#include <stdlib.h> +#include <string.h> + +#include <prio.h> + +/** Boolean: have we initialized this module */ +static bool countbytes_initialized = false; + +/** Integer to identity this layer. */ +static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER; + +/** Table of methods for this layer.*/ +static PRIOMethods countbytes_methods; + +/** Default close function provided by NSPR. We use this to help + * implement our own close function.*/ +static PRStatus(*default_close_fn)(PRFileDesc *fd); + +static PRStatus countbytes_close_fn(PRFileDesc *fd); +static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount); +static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf, + PRInt32 amount); +static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov, + PRInt32 size, PRIntervalTime timeout); +static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf, + PRInt32 amount, PRIntn flags, + PRIntervalTime timeout); +static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount, + PRIntn flags, PRIntervalTime timeout); + +/** Private fields for the byte-counter layer. We cast this to and from + * PRFilePrivate*, which is supposed to be allowed. */ +typedef struct tor_nss_bytecounts_t { + uint64_t n_read; + uint64_t n_written; +} tor_nss_bytecounts_t; + +/** + * Initialize this module, if it is not already initialized. + **/ +void +tor_nss_countbytes_init(void) +{ + if (countbytes_initialized) + return; + + countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer"); + tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER); + + memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods)); + + default_close_fn = countbytes_methods.close; + countbytes_methods.close = countbytes_close_fn; + countbytes_methods.read = countbytes_read_fn; + countbytes_methods.write = countbytes_write_fn; + countbytes_methods.writev = countbytes_writev_fn; + countbytes_methods.send = countbytes_send_fn; + countbytes_methods.recv = countbytes_recv_fn; + /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think + * NSS won't be using them for TLS connections. */ + + countbytes_initialized = true; +} + +/** + * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that + * the IO layer is in fact a layer created by this module. + */ +static tor_nss_bytecounts_t * +get_counts(PRFileDesc *fd) +{ + tor_assert(fd->identity == countbytes_layer_id); + return (tor_nss_bytecounts_t*) fd->secret; +} + +/** Helper: increment the read-count of an fd by n. */ +#define INC_READ(fd, n) STMT_BEGIN \ + get_counts(fd)->n_read += (n); \ + STMT_END + +/** Helper: increment the write-count of an fd by n. */ +#define INC_WRITTEN(fd, n) STMT_BEGIN \ + get_counts(fd)->n_written += (n); \ + STMT_END + +/** Implementation for PR_Close: frees the 'secret' field, then passes control + * to the default close function */ +static PRStatus +countbytes_close_fn(PRFileDesc *fd) +{ + tor_assert(fd); + + tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret; + tor_free(counts); + fd->secret = NULL; + + return default_close_fn(fd); +} + +/** Implementation for PR_Read: Calls the lower-level read function, + * and records what it said. */ +static PRInt32 +countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount); + if (result > 0) + INC_READ(fd, result); + return result; +} +/** Implementation for PR_Write: Calls the lower-level write function, + * and records what it said. */ +static PRInt32 +countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount); + if (result > 0) + INC_WRITTEN(fd, result); + return result; +} +/** Implementation for PR_Writev: Calls the lower-level writev function, + * and records what it said. */ +static PRInt32 +countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov, + PRInt32 size, PRIntervalTime timeout) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout); + if (result > 0) + INC_WRITTEN(fd, result); + return result; +} +/** Implementation for PR_Send: Calls the lower-level send function, + * and records what it said. */ +static PRInt32 +countbytes_send_fn(PRFileDesc *fd, const void *buf, + PRInt32 amount, PRIntn flags, PRIntervalTime timeout) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags, + timeout); + if (result > 0) + INC_WRITTEN(fd, result); + return result; +} +/** Implementation for PR_Recv: Calls the lower-level recv function, + * and records what it said. */ +static PRInt32 +countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount, + PRIntn flags, PRIntervalTime timeout) +{ + tor_assert(fd); + tor_assert(fd->lower); + + PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags, + timeout); + if (result > 0) + INC_READ(fd, result); + return result; +} + +/** + * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the + * total number of bytes read and written. Return the new PRFileDesc. + * + * This function takes ownership of its input. + */ +PRFileDesc * +tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack) +{ + if (BUG(! countbytes_initialized)) { + tor_nss_countbytes_init(); + } + + tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts)); + + PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id, + &countbytes_methods); + tor_assert(newfd); + newfd->secret = (PRFilePrivate *)bytecounts; + + /* This does some complicated messing around with the headers of these + objects; see the NSPR documentation for more. The upshot is that + after PushIOLayer, "stack" will be the head of the stack. + */ + PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd); + tor_assert(status == PR_SUCCESS); + + return stack; +} + +/** + * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(), + * or another PRFileDesc wrapping that PRFileDesc, set the provided + * pointers to the number of bytes read and written on the descriptor since + * it was created. + * + * Return 0 on success, -1 on failure. + */ +int +tor_get_prfiledesc_byte_counts(PRFileDesc *fd, + uint64_t *n_read_out, + uint64_t *n_written_out) +{ + if (BUG(! countbytes_initialized)) { + tor_nss_countbytes_init(); + } + + tor_assert(fd); + PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id); + if (BUG(bclayer == NULL)) + return -1; + + tor_nss_bytecounts_t *counts = get_counts(bclayer); + + *n_read_out = counts->n_read; + *n_written_out = counts->n_written; + + return 0; +} diff --git a/src/lib/tls/nss_countbytes.h b/src/lib/tls/nss_countbytes.h new file mode 100644 index 0000000000..f26280edf2 --- /dev/null +++ b/src/lib/tls/nss_countbytes.h @@ -0,0 +1,25 @@ +/* Copyright 2018, The Tor Project, Inc. */ +/* See LICENSE for licensing information */ + +/** + * \file nss_countbytes.h + * \brief Header for nss_countbytes.c, which lets us count the number of + * bytes actually written on a PRFileDesc. + **/ + +#ifndef TOR_NSS_COUNTBYTES_H +#define TOR_NSS_COUNTBYTES_H + +#include "lib/cc/torint.h" + +void tor_nss_countbytes_init(void); + +struct PRFileDesc; +struct PRFileDesc *tor_wrap_prfiledesc_with_byte_counter( + struct PRFileDesc *stack); + +int tor_get_prfiledesc_byte_counts(struct PRFileDesc *fd, + uint64_t *n_read_out, + uint64_t *n_written_out); + +#endif diff --git a/src/lib/tls/tortls.c b/src/lib/tls/tortls.c index 3ae3a1a096..56f70bc371 100644 --- a/src/lib/tls/tortls.c +++ b/src/lib/tls/tortls.c @@ -71,13 +71,19 @@ tor_tls_get_my_certs(int server, const tor_x509_cert_t **id_cert_out) { tor_tls_context_t *ctx = tor_tls_context_get(server); - if (! ctx) - return -1; + int rv = -1; + const tor_x509_cert_t *link_cert = NULL; + const tor_x509_cert_t *id_cert = NULL; + if (ctx) { + rv = 0; + link_cert = server ? ctx->my_link_cert : ctx->my_auth_cert; + id_cert = ctx->my_id_cert; + } if (link_cert_out) - *link_cert_out = server ? ctx->my_link_cert : ctx->my_auth_cert; + *link_cert_out = link_cert; if (id_cert_out) - *id_cert_out = ctx->my_id_cert; - return 0; + *id_cert_out = id_cert; + return rv; } /** diff --git a/src/lib/tls/tortls.h b/src/lib/tls/tortls.h index 4591927081..81db5ce5a9 100644 --- a/src/lib/tls/tortls.h +++ b/src/lib/tls/tortls.h @@ -94,6 +94,7 @@ void tor_tls_set_renegotiate_callback(tor_tls_t *tls, void (*cb)(tor_tls_t *, void *arg), void *arg); int tor_tls_is_server(tor_tls_t *tls); +void tor_tls_release_socket(tor_tls_t *tls); void tor_tls_free_(tor_tls_t *tls); #define tor_tls_free(tls) FREE_AND_NULL(tor_tls_t, tor_tls_free_, (tls)) int tor_tls_peer_has_cert(tor_tls_t *tls); @@ -125,6 +126,10 @@ int tor_tls_server_got_renegotiate(tor_tls_t *tls); MOCK_DECL(int,tor_tls_cert_matches_key,(const tor_tls_t *tls, const struct tor_x509_cert_t *cert)); MOCK_DECL(int,tor_tls_get_tlssecrets,(tor_tls_t *tls, uint8_t *secrets_out)); +#ifdef ENABLE_OPENSSL +/* OpenSSL lets us see these master secrets; NSS sensibly does not. */ +#define HAVE_WORKING_TOR_TLS_GET_TLSSECRETS +#endif MOCK_DECL(int,tor_tls_export_key_material,( tor_tls_t *tls, uint8_t *secrets_out, const uint8_t *context, diff --git a/src/lib/tls/tortls_nss.c b/src/lib/tls/tortls_nss.c index 53adfedf32..462cd5b0ff 100644 --- a/src/lib/tls/tortls_nss.c +++ b/src/lib/tls/tortls_nss.c @@ -31,11 +31,12 @@ #include "lib/tls/tortls.h" #include "lib/tls/tortls_st.h" #include "lib/tls/tortls_internal.h" +#include "lib/tls/nss_countbytes.h" #include "lib/log/util_bug.h" DISABLE_GCC_WARNING(strict-prototypes) #include <prio.h> -// For access to raw sockets. +// For access to rar sockets. #include <private/pprio.h> #include <ssl.h> #include <sslt.h> @@ -158,6 +159,8 @@ tor_tls_context_new(crypto_pk_t *identity, SECStatus s; tor_assert(identity); + tor_tls_init(); + tor_tls_context_t *ctx = tor_malloc_zero(sizeof(tor_tls_context_t)); ctx->refcnt = 1; @@ -320,7 +323,7 @@ tor_tls_get_state_description(tor_tls_t *tls, char *buf, size_t sz) void tor_tls_init(void) { - /* We don't have any global setup to do yet, but that will change */ + tor_nss_countbytes_init(); } void @@ -373,7 +376,11 @@ tor_tls_new(tor_socket_t sock, int is_server) if (!tcp) return NULL; - PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, tcp); + PRFileDesc *count = tor_wrap_prfiledesc_with_byte_counter(tcp); + if (! count) + return NULL; + + PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, count); if (!ssl) { PR_Close(tcp); return NULL; @@ -414,6 +421,47 @@ tor_tls_set_renegotiate_callback(tor_tls_t *tls, /* We don't support renegotiation-based TLS with NSS. */ } +/** + * Tell the TLS library that the underlying socket for <b>tls</b> has been + * closed, and the library should not attempt to free that socket itself. + */ +void +tor_tls_release_socket(tor_tls_t *tls) +{ + if (! tls) + return; + + /* NSS doesn't have the equivalent of BIO_NO_CLOSE. If you replace the + * fd with something that's invalid, it causes a memory leak in PR_Close. + * + * If there were a way to put the PRFileDesc into the CLOSED state, that + * would prevent it from closing its fd -- but there doesn't seem to be a + * supported way to do that either. + * + * So instead: we make a new sacrificial socket, and replace the original + * socket with that one. This seems to be the best we can do, until we + * redesign the mainloop code enough to make this function unnecessary. + */ + tor_socket_t sock = + tor_open_socket_nonblocking(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (! SOCKET_OK(sock)) { + log_warn(LD_NET, "Out of sockets when trying to shut down an NSS " + "connection"); + return; + } + + PRFileDesc *tcp = PR_GetIdentitiesLayer(tls->ssl, PR_NSPR_IO_LAYER); + if (BUG(! tcp)) { + tor_close_socket(sock); + return; + } + + PR_ChangeFileDescNativeHandle(tcp, sock); + /* Tell our socket accounting layer that we don't own this socket any more: + * NSS is about to free it for us. */ + tor_release_socket_ownership(sock); +} + void tor_tls_impl_free_(tor_tls_impl_t *tls) { @@ -465,7 +513,6 @@ tor_tls_read, (tor_tls_t *tls, char *cp, size_t len)) PRInt32 rv = PR_Read(tls->ssl, cp, (int)len); // log_debug(LD_NET, "PR_Read(%zu) returned %d", n, (int)rv); if (rv > 0) { - tls->n_read_since_last_check += rv; return rv; } if (rv == 0) @@ -489,7 +536,6 @@ tor_tls_write(tor_tls_t *tls, const char *cp, size_t n) PRInt32 rv = PR_Write(tls->ssl, cp, (int)n); // log_debug(LD_NET, "PR_Write(%zu) returned %d", n, (int)rv); if (rv > 0) { - tls->n_written_since_last_check += rv; return rv; } if (rv == 0) @@ -579,13 +625,17 @@ tor_tls_get_n_raw_bytes(tor_tls_t *tls, tor_assert(tls); tor_assert(n_read); tor_assert(n_written); - /* XXXX We don't curently have a way to measure this information correctly - * in NSS; we could do that with a PRIO layer, but it'll take a little - * coding. For now, we just track the number of bytes sent _in_ the TLS - * stream. Doing this will make our rate-limiting slightly inaccurate. */ - *n_read = tls->n_read_since_last_check; - *n_written = tls->n_written_since_last_check; - tls->n_read_since_last_check = tls->n_written_since_last_check = 0; + uint64_t r, w; + if (tor_get_prfiledesc_byte_counts(tls->ssl, &r, &w) < 0) { + *n_read = *n_written = 0; + return; + } + + *n_read = (size_t)(r - tls->last_read_count); + *n_written = (size_t)(w - tls->last_write_count); + + tls->last_read_count = r; + tls->last_write_count = w; } int diff --git a/src/lib/tls/tortls_openssl.c b/src/lib/tls/tortls_openssl.c index dc6c0bee9c..227225b96e 100644 --- a/src/lib/tls/tortls_openssl.c +++ b/src/lib/tls/tortls_openssl.c @@ -1048,7 +1048,7 @@ tor_tls_new(tor_socket_t sock, int isServer) goto err; } result->socket = sock; - bio = BIO_new_socket(sock, BIO_NOCLOSE); + bio = BIO_new_socket(sock, BIO_CLOSE); if (! bio) { tls_log_errors(NULL, LOG_WARN, LD_NET, "opening BIO"); #ifdef SSL_set_tlsext_host_name @@ -1154,6 +1154,28 @@ tor_tls_assert_renegotiation_unblocked(tor_tls_t *tls) #endif /* defined(SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION) && ... */ } +/** + * Tell the TLS library that the underlying socket for <b>tls</b> has been + * closed, and the library should not attempt to free that socket itself. + */ +void +tor_tls_release_socket(tor_tls_t *tls) +{ + if (! tls) + return; + + BIO *rbio, *wbio; + rbio = SSL_get_rbio(tls->ssl); + wbio = SSL_get_wbio(tls->ssl); + + if (rbio) { + (void) BIO_set_close(rbio, BIO_NOCLOSE); + } + if (wbio && wbio != rbio) { + (void) BIO_set_close(wbio, BIO_NOCLOSE); + } +} + void tor_tls_impl_free_(tor_tls_impl_t *ssl) { diff --git a/src/lib/tls/tortls_st.h b/src/lib/tls/tortls_st.h index a1b59a37af..549443a4e7 100644 --- a/src/lib/tls/tortls_st.h +++ b/src/lib/tls/tortls_st.h @@ -66,8 +66,9 @@ struct tor_tls_t { void *callback_arg; #endif #ifdef ENABLE_NSS - size_t n_read_since_last_check; - size_t n_written_since_last_check; + /** Last values retried from tor_get_prfiledesc_byte_counts(). */ + uint64_t last_write_count; + uint64_t last_read_count; #endif }; |