summaryrefslogtreecommitdiff
path: root/src/lib
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib')
-rw-r--r--src/lib/container/smartlist.c2
-rw-r--r--src/lib/container/smartlist.h2
-rw-r--r--src/lib/crypt_ops/aes_openssl.c10
-rw-r--r--src/lib/crypt_ops/crypto_dh_nss.c2
-rw-r--r--src/lib/crypt_ops/crypto_init.c11
-rw-r--r--src/lib/crypt_ops/crypto_init.h2
-rw-r--r--src/lib/crypt_ops/crypto_ope.c12
-rw-r--r--src/lib/crypt_ops/crypto_pwbox.c1
-rw-r--r--src/lib/crypt_ops/crypto_rand.c12
-rw-r--r--src/lib/crypt_ops/crypto_rand.h3
-rw-r--r--src/lib/crypt_ops/crypto_rsa.c12
-rw-r--r--src/lib/encoding/confline.c3
-rw-r--r--src/lib/evloop/procmon.c3
-rw-r--r--src/lib/evloop/workqueue.c6
-rw-r--r--src/lib/evloop/workqueue.h1
-rw-r--r--src/lib/geoip/.may_include13
-rw-r--r--src/lib/geoip/country.h16
-rw-r--r--src/lib/geoip/geoip.c510
-rw-r--r--src/lib/geoip/geoip.h50
-rw-r--r--src/lib/geoip/include.am17
-rw-r--r--src/lib/log/util_bug.h63
-rw-r--r--src/lib/net/address.c12
-rw-r--r--src/lib/net/inaddr.c8
-rw-r--r--src/lib/net/socket.c88
-rw-r--r--src/lib/net/socket.h1
-rw-r--r--src/lib/net/socketpair.c1
-rw-r--r--src/lib/process/daemon.c16
-rw-r--r--src/lib/process/daemon.h4
-rw-r--r--src/lib/string/util_string.c90
-rw-r--r--src/lib/string/util_string.h2
-rw-r--r--src/lib/time/compat_time.c22
-rw-r--r--src/lib/time/compat_time.h1
-rw-r--r--src/lib/time/tvdiff.c22
-rw-r--r--src/lib/time/tvdiff.h2
-rw-r--r--src/lib/tls/.may_include1
-rw-r--r--src/lib/tls/include.am2
-rw-r--r--src/lib/tls/nss_countbytes.c244
-rw-r--r--src/lib/tls/nss_countbytes.h25
-rw-r--r--src/lib/tls/tortls.c16
-rw-r--r--src/lib/tls/tortls.h5
-rw-r--r--src/lib/tls/tortls_nss.c74
-rw-r--r--src/lib/tls/tortls_openssl.c24
-rw-r--r--src/lib/tls/tortls_st.h5
43 files changed, 1319 insertions, 97 deletions
diff --git a/src/lib/container/smartlist.c b/src/lib/container/smartlist.c
index 4b29d834d9..64cabfcc6f 100644
--- a/src/lib/container/smartlist.c
+++ b/src/lib/container/smartlist.c
@@ -408,7 +408,7 @@ smartlist_uniq(smartlist_t *sl,
* less than member, and greater than 0 if key is greater then member.
*/
void *
-smartlist_bsearch(smartlist_t *sl, const void *key,
+smartlist_bsearch(const smartlist_t *sl, const void *key,
int (*compare)(const void *key, const void **member))
{
int found, idx;
diff --git a/src/lib/container/smartlist.h b/src/lib/container/smartlist.h
index 9705396ac9..0f5af3a923 100644
--- a/src/lib/container/smartlist.h
+++ b/src/lib/container/smartlist.h
@@ -64,7 +64,7 @@ const uint8_t *smartlist_get_most_frequent_digest256(smartlist_t *sl);
void smartlist_uniq_strings(smartlist_t *sl);
void smartlist_uniq_digests(smartlist_t *sl);
void smartlist_uniq_digests256(smartlist_t *sl);
-void *smartlist_bsearch(smartlist_t *sl, const void *key,
+void *smartlist_bsearch(const smartlist_t *sl, const void *key,
int (*compare)(const void *key, const void **member));
int smartlist_bsearch_idx(const smartlist_t *sl, const void *key,
int (*compare)(const void *key, const void **member),
diff --git a/src/lib/crypt_ops/aes_openssl.c b/src/lib/crypt_ops/aes_openssl.c
index 387f5d3df0..f2990fc06d 100644
--- a/src/lib/crypt_ops/aes_openssl.c
+++ b/src/lib/crypt_ops/aes_openssl.c
@@ -11,7 +11,9 @@
#include "orconfig.h"
#include "lib/crypt_ops/aes.h"
+#include "lib/crypt_ops/crypto_util.h"
#include "lib/log/util_bug.h"
+#include "lib/arch/bytes.h"
#ifdef _WIN32 /*wrkard for dtls1.h >= 0.9.8m of "#include <winsock.h>"*/
#include <winsock2.h>
@@ -396,10 +398,10 @@ static void
aes_set_iv(aes_cnt_cipher_t *cipher, const uint8_t *iv)
{
#ifdef USING_COUNTER_VARS
- cipher->counter3 = ntohl(get_uint32(iv));
- cipher->counter2 = ntohl(get_uint32(iv+4));
- cipher->counter1 = ntohl(get_uint32(iv+8));
- cipher->counter0 = ntohl(get_uint32(iv+12));
+ cipher->counter3 = tor_ntohl(get_uint32(iv));
+ cipher->counter2 = tor_ntohl(get_uint32(iv+4));
+ cipher->counter1 = tor_ntohl(get_uint32(iv+8));
+ cipher->counter0 = tor_ntohl(get_uint32(iv+12));
#endif /* defined(USING_COUNTER_VARS) */
cipher->pos = 0;
memcpy(cipher->ctr_buf.buf, iv, 16);
diff --git a/src/lib/crypt_ops/crypto_dh_nss.c b/src/lib/crypt_ops/crypto_dh_nss.c
index 9a14b809b4..e2d9040f5e 100644
--- a/src/lib/crypt_ops/crypto_dh_nss.c
+++ b/src/lib/crypt_ops/crypto_dh_nss.c
@@ -53,6 +53,8 @@ crypto_dh_init_nss(void)
circuit_dh_param.prime.len = DH1024_KEY_LEN;
circuit_dh_param.base.data = dh_generator_data;
circuit_dh_param.base.len = 1;
+
+ dh_initialized = 1;
}
void
diff --git a/src/lib/crypt_ops/crypto_init.c b/src/lib/crypt_ops/crypto_init.c
index c731662d49..9d6e2da0d0 100644
--- a/src/lib/crypt_ops/crypto_init.c
+++ b/src/lib/crypt_ops/crypto_init.c
@@ -191,3 +191,14 @@ crypto_get_header_version_string(void)
return crypto_nss_get_header_version_str();
#endif
}
+
+/** Return true iff Tor is using the NSS library. */
+int
+tor_is_using_nss(void)
+{
+#ifdef ENABLE_NSS
+ return 1;
+#else
+ return 0;
+#endif
+}
diff --git a/src/lib/crypt_ops/crypto_init.h b/src/lib/crypt_ops/crypto_init.h
index 5b6d65d48c..b71f144276 100644
--- a/src/lib/crypt_ops/crypto_init.h
+++ b/src/lib/crypt_ops/crypto_init.h
@@ -31,4 +31,6 @@ const char *crypto_get_library_name(void);
const char *crypto_get_library_version_string(void);
const char *crypto_get_header_version_string(void);
+int tor_is_using_nss(void);
+
#endif /* !defined(TOR_CRYPTO_H) */
diff --git a/src/lib/crypt_ops/crypto_ope.c b/src/lib/crypt_ops/crypto_ope.c
index fd5d5f3770..789517eba2 100644
--- a/src/lib/crypt_ops/crypto_ope.c
+++ b/src/lib/crypt_ops/crypto_ope.c
@@ -48,17 +48,17 @@ struct crypto_ope_t {
/** The type to add up in order to produce our OPE ciphertexts */
typedef uint16_t ope_val_t;
-#ifdef WORDS_BIG_ENDIAN
-/** Convert an OPE value to little-endian */
+#ifdef WORDS_BIGENDIAN
+/** Convert an OPE value from little-endian. */
static inline ope_val_t
-ope_val_to_le(ope_val_t x)
+ope_val_from_le(ope_val_t x)
{
return
((x) >> 8) |
(((x)&0xff) << 8);
}
#else
-#define ope_val_to_le(x) (x)
+#define ope_val_from_le(x) (x)
#endif
/**
@@ -104,7 +104,7 @@ sum_values_from_cipher(crypto_cipher_t *c, size_t n)
crypto_cipher_crypt_inplace(c, (char*)buf, BUFSZ*sizeof(ope_val_t));
for (i = 0; i < BUFSZ; ++i) {
- total += ope_val_to_le(buf[i]);
+ total += ope_val_from_le(buf[i]);
total += 1;
}
n -= BUFSZ;
@@ -113,7 +113,7 @@ sum_values_from_cipher(crypto_cipher_t *c, size_t n)
memset(buf, 0, n*sizeof(ope_val_t));
crypto_cipher_crypt_inplace(c, (char*)buf, n*sizeof(ope_val_t));
for (i = 0; i < n; ++i) {
- total += ope_val_to_le(buf[i]);
+ total += ope_val_from_le(buf[i]);
total += 1;
}
diff --git a/src/lib/crypt_ops/crypto_pwbox.c b/src/lib/crypt_ops/crypto_pwbox.c
index 2377f216a0..91536e891b 100644
--- a/src/lib/crypt_ops/crypto_pwbox.c
+++ b/src/lib/crypt_ops/crypto_pwbox.c
@@ -61,6 +61,7 @@ crypto_pwbox(uint8_t **out, size_t *outlen_out,
int rv;
enc = pwbox_encoded_new();
+ tor_assert(enc);
pwbox_encoded_setlen_skey_header(enc, S2K_MAXLEN);
diff --git a/src/lib/crypt_ops/crypto_rand.c b/src/lib/crypt_ops/crypto_rand.c
index 313d829a57..cffd0610f3 100644
--- a/src/lib/crypt_ops/crypto_rand.c
+++ b/src/lib/crypt_ops/crypto_rand.c
@@ -335,8 +335,18 @@ crypto_strongest_rand_raw(uint8_t *out, size_t out_len)
* Try to get <b>out_len</b> bytes of the strongest entropy we can generate,
* storing it into <b>out</b>.
**/
+void
+crypto_strongest_rand(uint8_t *out, size_t out_len)
+{
+ crypto_strongest_rand_(out, out_len);
+}
+
+/**
+ * Try to get <b>out_len</b> bytes of the strongest entropy we can generate,
+ * storing it into <b>out</b>. (Mockable version.)
+ **/
MOCK_IMPL(void,
-crypto_strongest_rand,(uint8_t *out, size_t out_len))
+crypto_strongest_rand_,(uint8_t *out, size_t out_len))
{
#define DLEN DIGEST512_LEN
diff --git a/src/lib/crypt_ops/crypto_rand.h b/src/lib/crypt_ops/crypto_rand.h
index 25bcfa1f1c..0c538d81ac 100644
--- a/src/lib/crypt_ops/crypto_rand.h
+++ b/src/lib/crypt_ops/crypto_rand.h
@@ -21,7 +21,8 @@
int crypto_seed_rng(void) ATTR_WUR;
MOCK_DECL(void,crypto_rand,(char *to, size_t n));
void crypto_rand_unmocked(char *to, size_t n);
-MOCK_DECL(void,crypto_strongest_rand,(uint8_t *out, size_t out_len));
+void crypto_strongest_rand(uint8_t *out, size_t out_len);
+MOCK_DECL(void,crypto_strongest_rand_,(uint8_t *out, size_t out_len));
int crypto_rand_int(unsigned int max);
int crypto_rand_int_range(unsigned int min, unsigned int max);
uint64_t crypto_rand_uint64_range(uint64_t min, uint64_t max);
diff --git a/src/lib/crypt_ops/crypto_rsa.c b/src/lib/crypt_ops/crypto_rsa.c
index 6a9e2948f1..a510e12964 100644
--- a/src/lib/crypt_ops/crypto_rsa.c
+++ b/src/lib/crypt_ops/crypto_rsa.c
@@ -540,6 +540,9 @@ crypto_pk_read_private_key_from_string(crypto_pk_t *env,
return crypto_pk_read_from_string_generic(env, src, len, true);
}
+/** If a file is longer than this, we won't try to decode its private key */
+#define MAX_PRIVKEY_FILE_LEN (16*1024*1024)
+
/** Read a PEM-encoded private key from the file named by
* <b>keyfile</b> into <b>env</b>. Return 0 on success, -1 on failure.
*/
@@ -551,9 +554,14 @@ crypto_pk_read_private_key_from_filename(crypto_pk_t *env,
char *buf = read_file_to_str(keyfile, 0, &st);
if (!buf)
return -1;
+ if (st.st_size > MAX_PRIVKEY_FILE_LEN) {
+ tor_free(buf);
+ return -1;
+ }
- int rv = crypto_pk_read_private_key_from_string(env, buf, st.st_size);
- memwipe(buf, 0, st.st_size);
+ int rv = crypto_pk_read_private_key_from_string(env, buf,
+ (ssize_t)st.st_size);
+ memwipe(buf, 0, (size_t)st.st_size);
tor_free(buf);
return rv;
}
diff --git a/src/lib/encoding/confline.c b/src/lib/encoding/confline.c
index dd5193d3a7..71ce5b8424 100644
--- a/src/lib/encoding/confline.c
+++ b/src/lib/encoding/confline.c
@@ -148,6 +148,9 @@ config_get_lines_aux(const char *string, config_line_t **result, int extended,
tor_free(v);
return -1;
}
+ log_notice(LD_CONFIG, "Included configuration file or "
+ "directory at recursion level %d: \"%s\".",
+ recursion_level, v);
*next = include_list;
if (list_last)
next = &list_last->next;
diff --git a/src/lib/evloop/procmon.c b/src/lib/evloop/procmon.c
index e0c26caab2..02e167377f 100644
--- a/src/lib/evloop/procmon.c
+++ b/src/lib/evloop/procmon.c
@@ -20,6 +20,9 @@
#ifdef HAVE_ERRNO_H
#include <errno.h>
#endif
+#ifdef HAVE_SYS_TIME_H
+#include <sys/time.h>
+#endif
#ifdef _WIN32
#include <winsock2.h>
diff --git a/src/lib/evloop/workqueue.c b/src/lib/evloop/workqueue.c
index 931f65e710..5471f87b04 100644
--- a/src/lib/evloop/workqueue.c
+++ b/src/lib/evloop/workqueue.c
@@ -15,7 +15,7 @@
*
* The main thread informs the worker threads of pending work by using a
* condition variable. The workers inform the main process of completed work
- * by using an alert_sockets_t object, as implemented in compat_threads.c.
+ * by using an alert_sockets_t object, as implemented in net/alertsock.c.
*
* The main thread can also queue an "update" that will be handled by all the
* workers. This is useful for updating state that all the workers share.
@@ -622,8 +622,8 @@ reply_event_cb(evutil_socket_t sock, short events, void *arg)
tp->reply_cb(tp);
}
-/** Register the threadpool <b>tp</b>'s reply queue with the libevent
- * mainloop of <b>base</b>. If <b>tp</b> is provided, it is run after
+/** Register the threadpool <b>tp</b>'s reply queue with Tor's global
+ * libevent mainloop. If <b>cb</b> is provided, it is run after
* each time there is work to process from the reply queue. Return 0 on
* success, -1 on failure.
*/
diff --git a/src/lib/evloop/workqueue.h b/src/lib/evloop/workqueue.h
index da292d1f05..10d5d47464 100644
--- a/src/lib/evloop/workqueue.h
+++ b/src/lib/evloop/workqueue.h
@@ -63,7 +63,6 @@ replyqueue_t *threadpool_get_replyqueue(threadpool_t *tp);
replyqueue_t *replyqueue_new(uint32_t alertsocks_flags);
void replyqueue_process(replyqueue_t *queue);
-struct event_base;
int threadpool_register_reply_event(threadpool_t *tp,
void (*cb)(threadpool_t *tp));
diff --git a/src/lib/geoip/.may_include b/src/lib/geoip/.may_include
new file mode 100644
index 0000000000..b1ee2dcfe9
--- /dev/null
+++ b/src/lib/geoip/.may_include
@@ -0,0 +1,13 @@
+orconfig.h
+lib/cc/*.h
+lib/container/*.h
+lib/crypt_ops/*.h
+lib/ctime/*.h
+lib/encoding/*.h
+lib/fs/*.h
+lib/geoip/*.h
+lib/log/*.h
+lib/malloc/*.h
+lib/net/*.h
+lib/string/*.h
+lib/testsupport/*.h
diff --git a/src/lib/geoip/country.h b/src/lib/geoip/country.h
new file mode 100644
index 0000000000..080c156023
--- /dev/null
+++ b/src/lib/geoip/country.h
@@ -0,0 +1,16 @@
+/* Copyright (c) 2001 Matej Pfajfar.
+ * Copyright (c) 2001-2004, Roger Dingledine.
+ * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson.
+ * Copyright (c) 2007-2018, The Tor Project, Inc. */
+/* See LICENSE for licensing information */
+
+#ifndef TOR_COUNTRY_H
+#define TOR_COUNTRY_H
+
+#include "lib/cc/torint.h"
+/** A signed integer representing a country code. */
+typedef int16_t country_t;
+
+#define COUNTRY_MAX INT16_MAX
+
+#endif
diff --git a/src/lib/geoip/geoip.c b/src/lib/geoip/geoip.c
new file mode 100644
index 0000000000..b1c0973d03
--- /dev/null
+++ b/src/lib/geoip/geoip.c
@@ -0,0 +1,510 @@
+/* Copyright (c) 2007-2018, The Tor Project, Inc. */
+/* See LICENSE for licensing information */
+
+/**
+ * \file geoip.c
+ * \brief Functions related to maintaining an IP-to-country database;
+ * to summarizing client connections by country to entry guards, bridges,
+ * and directory servers; and for statistics on answering network status
+ * requests.
+ *
+ * There are two main kinds of functions in this module: geoip functions,
+ * which map groups of IPv4 and IPv6 addresses to country codes, and
+ * statistical functions, which collect statistics about different kinds of
+ * per-country usage.
+ *
+ * The geoip lookup tables are implemented as sorted lists of disjoint address
+ * ranges, each mapping to a singleton geoip_country_t. These country objects
+ * are also indexed by their names in a hashtable.
+ *
+ * The tables are populated from disk at startup by the geoip_load_file()
+ * function. For more information on the file format they read, see that
+ * function. See the scripts and the README file in src/config for more
+ * information about how those files are generated.
+ *
+ * Tor uses GeoIP information in order to implement user requests (such as
+ * ExcludeNodes {cc}), and to keep track of how much usage relays are getting
+ * for each country.
+ */
+
+#define GEOIP_PRIVATE
+#include "lib/geoip/geoip.h"
+#include "lib/container/map.h"
+#include "lib/container/order.h"
+#include "lib/container/smartlist.h"
+#include "lib/crypt_ops/crypto_digest.h"
+#include "lib/ctime/di_ops.h"
+#include "lib/encoding/binascii.h"
+#include "lib/fs/files.h"
+#include "lib/log/escape.h"
+#include "lib/malloc/malloc.h"
+#include "lib/net/address.h" //????
+#include "lib/net/inaddr.h"
+#include "lib/string/compat_ctype.h"
+#include "lib/string/compat_string.h"
+#include "lib/string/scanf.h"
+#include "lib/string/util_string.h"
+
+#include <stdio.h>
+#include <string.h>
+
+static void init_geoip_countries(void);
+
+/** An entry from the GeoIP IPv4 file: maps an IPv4 range to a country. */
+typedef struct geoip_ipv4_entry_t {
+ uint32_t ip_low; /**< The lowest IP in the range, in host order */
+ uint32_t ip_high; /**< The highest IP in the range, in host order */
+ intptr_t country; /**< An index into geoip_countries */
+} geoip_ipv4_entry_t;
+
+/** An entry from the GeoIP IPv6 file: maps an IPv6 range to a country. */
+typedef struct geoip_ipv6_entry_t {
+ struct in6_addr ip_low; /**< The lowest IP in the range, in host order */
+ struct in6_addr ip_high; /**< The highest IP in the range, in host order */
+ intptr_t country; /**< An index into geoip_countries */
+} geoip_ipv6_entry_t;
+
+/** A list of geoip_country_t */
+static smartlist_t *geoip_countries = NULL;
+/** A map from lowercased country codes to their position in geoip_countries.
+ * The index is encoded in the pointer, and 1 is added so that NULL can mean
+ * not found. */
+static strmap_t *country_idxplus1_by_lc_code = NULL;
+/** Lists of all known geoip_ipv4_entry_t and geoip_ipv6_entry_t, sorted
+ * by their respective ip_low. */
+static smartlist_t *geoip_ipv4_entries = NULL, *geoip_ipv6_entries = NULL;
+
+/** SHA1 digest of the GeoIP files to include in extra-info descriptors. */
+static char geoip_digest[DIGEST_LEN];
+static char geoip6_digest[DIGEST_LEN];
+
+/** Return a list of geoip_country_t for all known countries. */
+const smartlist_t *
+geoip_get_countries(void)
+{
+ if (geoip_countries == NULL) {
+ init_geoip_countries();
+ }
+ return geoip_countries;
+}
+
+/** Return the index of the <b>country</b>'s entry in the GeoIP
+ * country list if it is a valid 2-letter country code, otherwise
+ * return -1. */
+MOCK_IMPL(country_t,
+geoip_get_country,(const char *country))
+{
+ void *idxplus1_;
+ intptr_t idx;
+
+ idxplus1_ = strmap_get_lc(country_idxplus1_by_lc_code, country);
+ if (!idxplus1_)
+ return -1;
+
+ idx = ((uintptr_t)idxplus1_)-1;
+ return (country_t)idx;
+}
+
+/** Add an entry to a GeoIP table, mapping all IP addresses between <b>low</b>
+ * and <b>high</b>, inclusive, to the 2-letter country code <b>country</b>. */
+static void
+geoip_add_entry(const tor_addr_t *low, const tor_addr_t *high,
+ const char *country)
+{
+ intptr_t idx;
+ void *idxplus1_;
+
+ IF_BUG_ONCE(tor_addr_family(low) != tor_addr_family(high))
+ return;
+ IF_BUG_ONCE(tor_addr_compare(high, low, CMP_EXACT) < 0)
+ return;
+
+ idxplus1_ = strmap_get_lc(country_idxplus1_by_lc_code, country);
+
+ if (!idxplus1_) {
+ geoip_country_t *c = tor_malloc_zero(sizeof(geoip_country_t));
+ strlcpy(c->countrycode, country, sizeof(c->countrycode));
+ tor_strlower(c->countrycode);
+ smartlist_add(geoip_countries, c);
+ idx = smartlist_len(geoip_countries) - 1;
+ strmap_set_lc(country_idxplus1_by_lc_code, country, (void*)(idx+1));
+ } else {
+ idx = ((uintptr_t)idxplus1_)-1;
+ }
+ {
+ geoip_country_t *c = smartlist_get(geoip_countries, (int)idx);
+ tor_assert(!strcasecmp(c->countrycode, country));
+ }
+
+ if (tor_addr_family(low) == AF_INET) {
+ geoip_ipv4_entry_t *ent = tor_malloc_zero(sizeof(geoip_ipv4_entry_t));
+ ent->ip_low = tor_addr_to_ipv4h(low);
+ ent->ip_high = tor_addr_to_ipv4h(high);
+ ent->country = idx;
+ smartlist_add(geoip_ipv4_entries, ent);
+ } else if (tor_addr_family(low) == AF_INET6) {
+ geoip_ipv6_entry_t *ent = tor_malloc_zero(sizeof(geoip_ipv6_entry_t));
+ ent->ip_low = *tor_addr_to_in6_assert(low);
+ ent->ip_high = *tor_addr_to_in6_assert(high);
+ ent->country = idx;
+ smartlist_add(geoip_ipv6_entries, ent);
+ }
+}
+
+/** Add an entry to the GeoIP table indicated by <b>family</b>,
+ * parsing it from <b>line</b>. The format is as for geoip_load_file(). */
+STATIC int
+geoip_parse_entry(const char *line, sa_family_t family)
+{
+ tor_addr_t low_addr, high_addr;
+ char c[3];
+ char *country = NULL;
+
+ if (!geoip_countries)
+ init_geoip_countries();
+ if (family == AF_INET) {
+ if (!geoip_ipv4_entries)
+ geoip_ipv4_entries = smartlist_new();
+ } else if (family == AF_INET6) {
+ if (!geoip_ipv6_entries)
+ geoip_ipv6_entries = smartlist_new();
+ } else {
+ log_warn(LD_GENERAL, "Unsupported family: %d", family);
+ return -1;
+ }
+
+ while (TOR_ISSPACE(*line))
+ ++line;
+ if (*line == '#')
+ return 0;
+
+ char buf[512];
+ if (family == AF_INET) {
+ unsigned int low, high;
+ if (tor_sscanf(line,"%u,%u,%2s", &low, &high, c) == 3 ||
+ tor_sscanf(line,"\"%u\",\"%u\",\"%2s\",", &low, &high, c) == 3) {
+ tor_addr_from_ipv4h(&low_addr, low);
+ tor_addr_from_ipv4h(&high_addr, high);
+ } else
+ goto fail;
+ country = c;
+ } else { /* AF_INET6 */
+ char *low_str, *high_str;
+ struct in6_addr low, high;
+ char *strtok_state;
+ strlcpy(buf, line, sizeof(buf));
+ low_str = tor_strtok_r(buf, ",", &strtok_state);
+ if (!low_str)
+ goto fail;
+ high_str = tor_strtok_r(NULL, ",", &strtok_state);
+ if (!high_str)
+ goto fail;
+ country = tor_strtok_r(NULL, "\n", &strtok_state);
+ if (!country)
+ goto fail;
+ if (strlen(country) != 2)
+ goto fail;
+ if (tor_inet_pton(AF_INET6, low_str, &low) <= 0)
+ goto fail;
+ tor_addr_from_in6(&low_addr, &low);
+ if (tor_inet_pton(AF_INET6, high_str, &high) <= 0)
+ goto fail;
+ tor_addr_from_in6(&high_addr, &high);
+ }
+ geoip_add_entry(&low_addr, &high_addr, country);
+ return 0;
+
+ fail:
+ log_warn(LD_GENERAL, "Unable to parse line from GEOIP %s file: %s",
+ family == AF_INET ? "IPv4" : "IPv6", escaped(line));
+ return -1;
+}
+
+/** Sorting helper: return -1, 1, or 0 based on comparison of two
+ * geoip_ipv4_entry_t */
+static int
+geoip_ipv4_compare_entries_(const void **_a, const void **_b)
+{
+ const geoip_ipv4_entry_t *a = *_a, *b = *_b;
+ if (a->ip_low < b->ip_low)
+ return -1;
+ else if (a->ip_low > b->ip_low)
+ return 1;
+ else
+ return 0;
+}
+
+/** bsearch helper: return -1, 1, or 0 based on comparison of an IP (a pointer
+ * to a uint32_t in host order) to a geoip_ipv4_entry_t */
+static int
+geoip_ipv4_compare_key_to_entry_(const void *_key, const void **_member)
+{
+ /* No alignment issue here, since _key really is a pointer to uint32_t */
+ const uint32_t addr = *(uint32_t *)_key;
+ const geoip_ipv4_entry_t *entry = *_member;
+ if (addr < entry->ip_low)
+ return -1;
+ else if (addr > entry->ip_high)
+ return 1;
+ else
+ return 0;
+}
+
+/** Sorting helper: return -1, 1, or 0 based on comparison of two
+ * geoip_ipv6_entry_t */
+static int
+geoip_ipv6_compare_entries_(const void **_a, const void **_b)
+{
+ const geoip_ipv6_entry_t *a = *_a, *b = *_b;
+ return fast_memcmp(a->ip_low.s6_addr, b->ip_low.s6_addr,
+ sizeof(struct in6_addr));
+}
+
+/** bsearch helper: return -1, 1, or 0 based on comparison of an IPv6
+ * (a pointer to a in6_addr) to a geoip_ipv6_entry_t */
+static int
+geoip_ipv6_compare_key_to_entry_(const void *_key, const void **_member)
+{
+ const struct in6_addr *addr = (struct in6_addr *)_key;
+ const geoip_ipv6_entry_t *entry = *_member;
+
+ if (fast_memcmp(addr->s6_addr, entry->ip_low.s6_addr,
+ sizeof(struct in6_addr)) < 0)
+ return -1;
+ else if (fast_memcmp(addr->s6_addr, entry->ip_high.s6_addr,
+ sizeof(struct in6_addr)) > 0)
+ return 1;
+ else
+ return 0;
+}
+
+/** Set up a new list of geoip countries with no countries (yet) set in it,
+ * except for the unknown country.
+ */
+static void
+init_geoip_countries(void)
+{
+ geoip_country_t *geoip_unresolved;
+ geoip_countries = smartlist_new();
+ /* Add a geoip_country_t for requests that could not be resolved to a
+ * country as first element (index 0) to geoip_countries. */
+ geoip_unresolved = tor_malloc_zero(sizeof(geoip_country_t));
+ strlcpy(geoip_unresolved->countrycode, "??",
+ sizeof(geoip_unresolved->countrycode));
+ smartlist_add(geoip_countries, geoip_unresolved);
+ country_idxplus1_by_lc_code = strmap_new();
+ strmap_set_lc(country_idxplus1_by_lc_code, "??", (void*)(1));
+}
+
+/** Clear appropriate GeoIP database, based on <b>family</b>, and
+ * reload it from the file <b>filename</b>. Return 0 on success, -1 on
+ * failure.
+ *
+ * Recognized line formats for IPv4 are:
+ * INTIPLOW,INTIPHIGH,CC
+ * and
+ * "INTIPLOW","INTIPHIGH","CC","CC3","COUNTRY NAME"
+ * where INTIPLOW and INTIPHIGH are IPv4 addresses encoded as 4-byte unsigned
+ * integers, and CC is a country code.
+ *
+ * Recognized line format for IPv6 is:
+ * IPV6LOW,IPV6HIGH,CC
+ * where IPV6LOW and IPV6HIGH are IPv6 addresses and CC is a country code.
+ *
+ * It also recognizes, and skips over, blank lines and lines that start
+ * with '#' (comments).
+ */
+int
+geoip_load_file(sa_family_t family, const char *filename, int severity)
+{
+ FILE *f;
+ crypto_digest_t *geoip_digest_env = NULL;
+
+ tor_assert(family == AF_INET || family == AF_INET6);
+
+ if (!(f = tor_fopen_cloexec(filename, "r"))) {
+ log_fn(severity, LD_GENERAL, "Failed to open GEOIP file %s.",
+ filename);
+ return -1;
+ }
+ if (!geoip_countries)
+ init_geoip_countries();
+
+ if (family == AF_INET) {
+ if (geoip_ipv4_entries) {
+ SMARTLIST_FOREACH(geoip_ipv4_entries, geoip_ipv4_entry_t *, e,
+ tor_free(e));
+ smartlist_free(geoip_ipv4_entries);
+ }
+ geoip_ipv4_entries = smartlist_new();
+ } else { /* AF_INET6 */
+ if (geoip_ipv6_entries) {
+ SMARTLIST_FOREACH(geoip_ipv6_entries, geoip_ipv6_entry_t *, e,
+ tor_free(e));
+ smartlist_free(geoip_ipv6_entries);
+ }
+ geoip_ipv6_entries = smartlist_new();
+ }
+ geoip_digest_env = crypto_digest_new();
+
+ log_notice(LD_GENERAL, "Parsing GEOIP %s file %s.",
+ (family == AF_INET) ? "IPv4" : "IPv6", filename);
+ while (!feof(f)) {
+ char buf[512];
+ if (fgets(buf, (int)sizeof(buf), f) == NULL)
+ break;
+ crypto_digest_add_bytes(geoip_digest_env, buf, strlen(buf));
+ /* FFFF track full country name. */
+ geoip_parse_entry(buf, family);
+ }
+ /*XXXX abort and return -1 if no entries/illformed?*/
+ fclose(f);
+
+ /* Sort list and remember file digests so that we can include it in
+ * our extra-info descriptors. */
+ if (family == AF_INET) {
+ smartlist_sort(geoip_ipv4_entries, geoip_ipv4_compare_entries_);
+ crypto_digest_get_digest(geoip_digest_env, geoip_digest, DIGEST_LEN);
+ } else {
+ /* AF_INET6 */
+ smartlist_sort(geoip_ipv6_entries, geoip_ipv6_compare_entries_);
+ crypto_digest_get_digest(geoip_digest_env, geoip6_digest, DIGEST_LEN);
+ }
+ crypto_digest_free(geoip_digest_env);
+
+ return 0;
+}
+
+/** Given an IP address in host order, return a number representing the
+ * country to which that address belongs, -1 for "No geoip information
+ * available", or 0 for the 'unknown country'. The return value will always
+ * be less than geoip_get_n_countries(). To decode it, call
+ * geoip_get_country_name().
+ */
+int
+geoip_get_country_by_ipv4(uint32_t ipaddr)
+{
+ geoip_ipv4_entry_t *ent;
+ if (!geoip_ipv4_entries)
+ return -1;
+ ent = smartlist_bsearch(geoip_ipv4_entries, &ipaddr,
+ geoip_ipv4_compare_key_to_entry_);
+ return ent ? (int)ent->country : 0;
+}
+
+/** Given an IPv6 address, return a number representing the country to
+ * which that address belongs, -1 for "No geoip information available", or
+ * 0 for the 'unknown country'. The return value will always be less than
+ * geoip_get_n_countries(). To decode it, call geoip_get_country_name().
+ */
+int
+geoip_get_country_by_ipv6(const struct in6_addr *addr)
+{
+ geoip_ipv6_entry_t *ent;
+
+ if (!geoip_ipv6_entries)
+ return -1;
+ ent = smartlist_bsearch(geoip_ipv6_entries, addr,
+ geoip_ipv6_compare_key_to_entry_);
+ return ent ? (int)ent->country : 0;
+}
+
+/** Given an IP address, return a number representing the country to which
+ * that address belongs, -1 for "No geoip information available", or 0 for
+ * the 'unknown country'. The return value will always be less than
+ * geoip_get_n_countries(). To decode it, call geoip_get_country_name().
+ */
+MOCK_IMPL(int,
+geoip_get_country_by_addr,(const tor_addr_t *addr))
+{
+ if (tor_addr_family(addr) == AF_INET) {
+ return geoip_get_country_by_ipv4(tor_addr_to_ipv4h(addr));
+ } else if (tor_addr_family(addr) == AF_INET6) {
+ return geoip_get_country_by_ipv6(tor_addr_to_in6(addr));
+ } else {
+ return -1;
+ }
+}
+
+/** Return the number of countries recognized by the GeoIP country list. */
+MOCK_IMPL(int,
+geoip_get_n_countries,(void))
+{
+ if (!geoip_countries)
+ init_geoip_countries();
+ return (int) smartlist_len(geoip_countries);
+}
+
+/** Return the two-letter country code associated with the number <b>num</b>,
+ * or "??" for an unknown value. */
+const char *
+geoip_get_country_name(country_t num)
+{
+ if (geoip_countries && num >= 0 && num < smartlist_len(geoip_countries)) {
+ geoip_country_t *c = smartlist_get(geoip_countries, num);
+ return c->countrycode;
+ } else
+ return "??";
+}
+
+/** Return true iff we have loaded a GeoIP database.*/
+MOCK_IMPL(int,
+geoip_is_loaded,(sa_family_t family))
+{
+ tor_assert(family == AF_INET || family == AF_INET6);
+ if (geoip_countries == NULL)
+ return 0;
+ if (family == AF_INET)
+ return geoip_ipv4_entries != NULL;
+ else /* AF_INET6 */
+ return geoip_ipv6_entries != NULL;
+}
+
+/** Return the hex-encoded SHA1 digest of the loaded GeoIP file. The
+ * result does not need to be deallocated, but will be overwritten by the
+ * next call of hex_str(). */
+const char *
+geoip_db_digest(sa_family_t family)
+{
+ tor_assert(family == AF_INET || family == AF_INET6);
+ if (family == AF_INET)
+ return hex_str(geoip_digest, DIGEST_LEN);
+ else /* AF_INET6 */
+ return hex_str(geoip6_digest, DIGEST_LEN);
+}
+
+/** Release all storage held by the GeoIP databases and country list. */
+STATIC void
+clear_geoip_db(void)
+{
+ if (geoip_countries) {
+ SMARTLIST_FOREACH(geoip_countries, geoip_country_t *, c, tor_free(c));
+ smartlist_free(geoip_countries);
+ }
+
+ strmap_free(country_idxplus1_by_lc_code, NULL);
+ if (geoip_ipv4_entries) {
+ SMARTLIST_FOREACH(geoip_ipv4_entries, geoip_ipv4_entry_t *, ent,
+ tor_free(ent));
+ smartlist_free(geoip_ipv4_entries);
+ }
+ if (geoip_ipv6_entries) {
+ SMARTLIST_FOREACH(geoip_ipv6_entries, geoip_ipv6_entry_t *, ent,
+ tor_free(ent));
+ smartlist_free(geoip_ipv6_entries);
+ }
+ geoip_countries = NULL;
+ country_idxplus1_by_lc_code = NULL;
+ geoip_ipv4_entries = NULL;
+ geoip_ipv6_entries = NULL;
+}
+
+/** Release all storage held in this file. */
+void
+geoip_free_all(void)
+{
+ clear_geoip_db();
+
+ memset(geoip_digest, 0, sizeof(geoip_digest));
+ memset(geoip6_digest, 0, sizeof(geoip6_digest));
+}
diff --git a/src/lib/geoip/geoip.h b/src/lib/geoip/geoip.h
new file mode 100644
index 0000000000..6ef27d66d0
--- /dev/null
+++ b/src/lib/geoip/geoip.h
@@ -0,0 +1,50 @@
+/* Copyright (c) 2001 Matej Pfajfar.
+ * Copyright (c) 2001-2004, Roger Dingledine.
+ * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson.
+ * Copyright (c) 2007-2018, The Tor Project, Inc. */
+/* See LICENSE for licensing information */
+
+/**
+ * \file geoip.h
+ * \brief Header file for geoip.c.
+ **/
+
+#ifndef TOR_GEOIP_H
+#define TOR_GEOIP_H
+
+#include "orconfig.h"
+#include "lib/net/nettypes.h"
+#include "lib/testsupport/testsupport.h"
+#include "lib/net/inaddr_st.h"
+#include "lib/geoip/country.h"
+
+#ifdef GEOIP_PRIVATE
+STATIC int geoip_parse_entry(const char *line, sa_family_t family);
+STATIC void clear_geoip_db(void);
+#endif /* defined(GEOIP_PRIVATE) */
+
+struct in6_addr;
+struct tor_addr_t;
+
+int geoip_get_country_by_ipv4(uint32_t ipaddr);
+int geoip_get_country_by_ipv6(const struct in6_addr *addr);
+
+/** A per-country GeoIP record. */
+typedef struct geoip_country_t {
+ char countrycode[3];
+} geoip_country_t;
+
+struct smartlist_t;
+const struct smartlist_t *geoip_get_countries(void);
+
+int geoip_load_file(sa_family_t family, const char *filename, int severity);
+MOCK_DECL(int, geoip_get_country_by_addr, (const struct tor_addr_t *addr));
+MOCK_DECL(int, geoip_get_n_countries, (void));
+const char *geoip_get_country_name(country_t num);
+MOCK_DECL(int, geoip_is_loaded, (sa_family_t family));
+const char *geoip_db_digest(sa_family_t family);
+MOCK_DECL(country_t, geoip_get_country, (const char *countrycode));
+
+void geoip_free_all(void);
+
+#endif /* !defined(TOR_GEOIP_H) */
diff --git a/src/lib/geoip/include.am b/src/lib/geoip/include.am
new file mode 100644
index 0000000000..9710d75ac7
--- /dev/null
+++ b/src/lib/geoip/include.am
@@ -0,0 +1,17 @@
+noinst_LIBRARIES += src/lib/libtor-geoip.a
+
+if UNITTESTS_ENABLED
+noinst_LIBRARIES += src/lib/libtor-geoip-testing.a
+endif
+
+src_lib_libtor_geoip_a_SOURCES = \
+ src/lib/geoip/geoip.c
+
+src_lib_libtor_geoip_testing_a_SOURCES = \
+ $(src_lib_libtor_geoip_a_SOURCES)
+src_lib_libtor_geoip_testing_a_CPPFLAGS = $(AM_CPPFLAGS) $(TEST_CPPFLAGS)
+src_lib_libtor_geoip_testing_a_CFLAGS = $(AM_CFLAGS) $(TEST_CFLAGS)
+
+noinst_HEADERS += \
+ src/lib/geoip/geoip.h \
+ src/lib/geoip/country.h
diff --git a/src/lib/log/util_bug.h b/src/lib/log/util_bug.h
index 44a4f8381c..557d932ac3 100644
--- a/src/lib/log/util_bug.h
+++ b/src/lib/log/util_bug.h
@@ -56,6 +56,35 @@
#error "Sorry; we don't support building with NDEBUG."
#endif /* defined(NDEBUG) */
+#if defined(TOR_UNIT_TESTS) && defined(__GNUC__)
+/* We define this GCC macro as a replacement for PREDICT_UNLIKELY() in this
+ * header, so that in our unit test builds, we'll get compiler warnings about
+ * stuff like tor_assert(n = 5).
+ *
+ * The key here is that (e) is wrapped in exactly one layer of parentheses,
+ * and then passed right to a conditional. If you do anything else to the
+ * expression here, or introduce any more parentheses, the compiler won't
+ * help you.
+ *
+ * We only do this for the unit-test build case because it interferes with
+ * the likely-branch labeling. Note below that in the other case, we define
+ * these macros to just be synonyms for PREDICT_(UN)LIKELY.
+ */
+#define ASSERT_PREDICT_UNLIKELY_(e) \
+ ( { \
+ int tor__assert_tmp_value__; \
+ if (e) \
+ tor__assert_tmp_value__ = 1; \
+ else \
+ tor__assert_tmp_value__ = 0; \
+ tor__assert_tmp_value__; \
+ } )
+#define ASSERT_PREDICT_LIKELY_(e) ASSERT_PREDICT_UNLIKELY_(e)
+#else
+#define ASSERT_PREDICT_UNLIKELY_(e) PREDICT_UNLIKELY(e)
+#define ASSERT_PREDICT_LIKELY_(e) PREDICT_LIKELY(e)
+#endif
+
/* Sometimes we don't want to use assertions during branch coverage tests; it
* leads to tons of unreached branches which in reality are only assertions we
* didn't hit. */
@@ -67,13 +96,19 @@
/** Like assert(3), but send assertion failures to the log as well as to
* stderr. */
#define tor_assert(expr) STMT_BEGIN \
- if (PREDICT_UNLIKELY(!(expr))) { \
+ if (ASSERT_PREDICT_LIKELY_(expr)) { \
+ } else { \
tor_assertion_failed_(SHORT_FILE__, __LINE__, __func__, #expr); \
abort(); \
} STMT_END
#endif /* defined(TOR_UNIT_TESTS) && defined(DISABLE_ASSERTS_IN_UNIT_TESTS) */
-#define tor_assert_unreached() tor_assert(0)
+#define tor_assert_unreached() \
+ STMT_BEGIN { \
+ tor_assertion_failed_(SHORT_FILE__, __LINE__, __func__, \
+ "line should be unreached"); \
+ abort(); \
+ } STMT_END
/* Non-fatal bug assertions. The "unreached" variants mean "this line should
* never be reached." The "once" variants mean "Don't log a warning more than
@@ -104,7 +139,7 @@
#define tor_assert_nonfatal_unreached_once() tor_assert(0)
#define tor_assert_nonfatal_once(cond) tor_assert((cond))
#define BUG(cond) \
- (PREDICT_UNLIKELY(cond) ? \
+ (ASSERT_PREDICT_UNLIKELY_(cond) ? \
(tor_assertion_failed_(SHORT_FILE__,__LINE__,__func__,"!("#cond")"), \
abort(), 1) \
: 0)
@@ -113,14 +148,15 @@
#define tor_assert_nonfatal(cond) ((void)(cond))
#define tor_assert_nonfatal_unreached_once() STMT_NIL
#define tor_assert_nonfatal_once(cond) ((void)(cond))
-#define BUG(cond) (PREDICT_UNLIKELY(cond) ? 1 : 0)
+#define BUG(cond) (ASSERT_PREDICT_UNLIKELY_(cond) ? 1 : 0)
#else /* Normal case, !ALL_BUGS_ARE_FATAL, !DISABLE_ASSERTS_IN_UNIT_TESTS */
#define tor_assert_nonfatal_unreached() STMT_BEGIN \
tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, NULL, 0); \
STMT_END
#define tor_assert_nonfatal(cond) STMT_BEGIN \
- if (PREDICT_UNLIKELY(!(cond))) { \
- tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, #cond, 0); \
+ if (ASSERT_PREDICT_LIKELY_(cond)) { \
+ } else { \
+ tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, #cond, 0); \
} \
STMT_END
#define tor_assert_nonfatal_unreached_once() STMT_BEGIN \
@@ -132,13 +168,14 @@
STMT_END
#define tor_assert_nonfatal_once(cond) STMT_BEGIN \
static int warning_logged__ = 0; \
- if (!warning_logged__ && PREDICT_UNLIKELY(!(cond))) { \
+ if (ASSERT_PREDICT_LIKELY_(cond)) { \
+ } else if (!warning_logged__) { \
warning_logged__ = 1; \
tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, #cond, 1); \
} \
STMT_END
#define BUG(cond) \
- (PREDICT_UNLIKELY(cond) ? \
+ (ASSERT_PREDICT_UNLIKELY_(cond) ? \
(tor_bug_occurred_(SHORT_FILE__,__LINE__,__func__,"!("#cond")",0), 1) \
: 0)
#endif /* defined(ALL_BUGS_ARE_FATAL) || ... */
@@ -147,17 +184,17 @@
#define IF_BUG_ONCE__(cond,var) \
if (( { \
static int var = 0; \
- int bool_result = (cond); \
- if (PREDICT_UNLIKELY(bool_result) && !var) { \
+ int bool_result = !!(cond); \
+ if (bool_result && !var) { \
var = 1; \
tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, \
"!("#cond")", 1); \
} \
- PREDICT_UNLIKELY(bool_result); } ))
+ bool_result; } ))
#else /* !(defined(__GNUC__)) */
#define IF_BUG_ONCE__(cond,var) \
static int var = 0; \
- if (PREDICT_UNLIKELY(cond) ? \
+ if ((cond) ? \
(var ? 1 : \
(var=1, \
tor_bug_occurred_(SHORT_FILE__, __LINE__, __func__, \
@@ -175,7 +212,7 @@
*/
#define IF_BUG_ONCE(cond) \
- IF_BUG_ONCE__((cond), \
+ IF_BUG_ONCE__(ASSERT_PREDICT_UNLIKELY_(cond), \
IF_BUG_ONCE_VARNAME__(__LINE__))
/** Define this if you want Tor to crash when any problem comes up,
diff --git a/src/lib/net/address.c b/src/lib/net/address.c
index 03767e2950..c97a17037a 100644
--- a/src/lib/net/address.c
+++ b/src/lib/net/address.c
@@ -1187,14 +1187,22 @@ tor_addr_parse(tor_addr_t *addr, const char *src)
int result;
struct in_addr in_tmp;
struct in6_addr in6_tmp;
+ int brackets_detected = 0;
+
tor_assert(addr && src);
- if (src[0] == '[' && src[1])
+
+ size_t len = strlen(src);
+
+ if (len && src[0] == '[' && src[len - 1] == ']') {
+ brackets_detected = 1;
src = tmp = tor_strndup(src+1, strlen(src)-2);
+ }
if (tor_inet_pton(AF_INET6, src, &in6_tmp) > 0) {
result = AF_INET6;
tor_addr_from_in6(addr, &in6_tmp);
- } else if (tor_inet_pton(AF_INET, src, &in_tmp) > 0) {
+ } else if (!brackets_detected &&
+ tor_inet_pton(AF_INET, src, &in_tmp) > 0) {
result = AF_INET;
tor_addr_from_in(addr, &in_tmp);
} else {
diff --git a/src/lib/net/inaddr.c b/src/lib/net/inaddr.c
index dcd8fcdd65..0960d323c5 100644
--- a/src/lib/net/inaddr.c
+++ b/src/lib/net/inaddr.c
@@ -168,6 +168,13 @@ tor_inet_pton(int af, const char *src, void *dst)
if (af == AF_INET) {
return tor_inet_aton(src, dst);
} else if (af == AF_INET6) {
+ ssize_t len = strlen(src);
+
+ /* Reject if src has needless trailing ':'. */
+ if (len > 2 && src[len - 1] == ':' && src[len - 2] != ':') {
+ return 0;
+ }
+
struct in6_addr *out = dst;
uint16_t words[8];
int gapPos = -1, i, setWords=0;
@@ -207,7 +214,6 @@ tor_inet_pton(int af, const char *src, void *dst)
return 0;
if (TOR_ISXDIGIT(*src)) {
char *next;
- ssize_t len;
long r = strtol(src, &next, 16);
if (next == NULL || next == src) {
/* The 'next == src' error case can happen on versions of openbsd
diff --git a/src/lib/net/socket.c b/src/lib/net/socket.c
index 06421b080d..cd7c9685cd 100644
--- a/src/lib/net/socket.c
+++ b/src/lib/net/socket.c
@@ -142,41 +142,6 @@ tor_close_socket_simple(tor_socket_t s)
return r;
}
-/** As tor_close_socket_simple(), but keeps track of the number
- * of open sockets. Returns 0 on success, -1 on failure. */
-MOCK_IMPL(int,
-tor_close_socket,(tor_socket_t s))
-{
- int r = tor_close_socket_simple(s);
-
- socket_accounting_lock();
-#ifdef DEBUG_SOCKET_COUNTING
- if (s > max_socket || ! bitarray_is_set(open_sockets, s)) {
- log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_"
- "socket(), or that was already closed or something.", s);
- } else {
- tor_assert(open_sockets && s <= max_socket);
- bitarray_clear(open_sockets, s);
- }
-#endif /* defined(DEBUG_SOCKET_COUNTING) */
- if (r == 0) {
- --n_sockets_open;
- } else {
-#ifdef _WIN32
- if (r != WSAENOTSOCK)
- --n_sockets_open;
-#else
- if (r != EBADF)
- --n_sockets_open; // LCOV_EXCL_LINE -- EIO and EINTR too hard to force.
-#endif /* defined(_WIN32) */
- r = -1;
- }
-
- tor_assert_nonfatal(n_sockets_open >= 0);
- socket_accounting_unlock();
- return r;
-}
-
/** @{ */
#ifdef DEBUG_SOCKET_COUNTING
/** Helper: if DEBUG_SOCKET_COUNTING is enabled, remember that <b>s</b> is
@@ -201,11 +166,50 @@ mark_socket_open(tor_socket_t s)
}
bitarray_set(open_sockets, s);
}
+static inline void
+mark_socket_closed(tor_socket_t s)
+{
+ if (s > max_socket || ! bitarray_is_set(open_sockets, s)) {
+ log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_"
+ "socket(), or that was already closed or something.", s);
+ } else {
+ tor_assert(open_sockets && s <= max_socket);
+ bitarray_clear(open_sockets, s);
+ }
+}
#else /* !(defined(DEBUG_SOCKET_COUNTING)) */
#define mark_socket_open(s) ((void) (s))
+#define mark_socket_closed(s) ((void) (s))
#endif /* defined(DEBUG_SOCKET_COUNTING) */
/** @} */
+/** As tor_close_socket_simple(), but keeps track of the number
+ * of open sockets. Returns 0 on success, -1 on failure. */
+MOCK_IMPL(int,
+tor_close_socket,(tor_socket_t s))
+{
+ int r = tor_close_socket_simple(s);
+
+ socket_accounting_lock();
+ mark_socket_closed(s);
+ if (r == 0) {
+ --n_sockets_open;
+ } else {
+#ifdef _WIN32
+ if (r != WSAENOTSOCK)
+ --n_sockets_open;
+#else
+ if (r != EBADF)
+ --n_sockets_open; // LCOV_EXCL_LINE -- EIO and EINTR too hard to force.
+#endif /* defined(_WIN32) */
+ r = -1;
+ }
+
+ tor_assert_nonfatal(n_sockets_open >= 0);
+ socket_accounting_unlock();
+ return r;
+}
+
/** As socket(), but counts the number of open sockets. */
MOCK_IMPL(tor_socket_t,
tor_open_socket,(int domain, int type, int protocol))
@@ -307,6 +311,20 @@ tor_take_socket_ownership(tor_socket_t s)
socket_accounting_unlock();
}
+/**
+ * For socket accounting: declare that we are no longer the owner of the
+ * socket <b>s</b>. This will prevent us from overallocating sockets, and
+ * prevent us from asserting later when we close the socket <b>s</b>.
+ */
+void
+tor_release_socket_ownership(tor_socket_t s)
+{
+ socket_accounting_lock();
+ --n_sockets_open;
+ mark_socket_closed(s);
+ socket_accounting_unlock();
+}
+
/** As accept(), but counts the number of open sockets. */
tor_socket_t
tor_accept_socket(tor_socket_t sockfd, struct sockaddr *addr, socklen_t *len)
diff --git a/src/lib/net/socket.h b/src/lib/net/socket.h
index 5b7d6dbbc6..2b87441fc6 100644
--- a/src/lib/net/socket.h
+++ b/src/lib/net/socket.h
@@ -23,6 +23,7 @@ struct sockaddr;
int tor_close_socket_simple(tor_socket_t s);
MOCK_DECL(int, tor_close_socket, (tor_socket_t s));
void tor_take_socket_ownership(tor_socket_t s);
+void tor_release_socket_ownership(tor_socket_t s);
tor_socket_t tor_open_socket_with_extensions(
int domain, int type, int protocol,
int cloexec, int nonblock);
diff --git a/src/lib/net/socketpair.c b/src/lib/net/socketpair.c
index 8dd04d0486..380338f15c 100644
--- a/src/lib/net/socketpair.c
+++ b/src/lib/net/socketpair.c
@@ -2,6 +2,7 @@
* Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson.
* Copyright (c) 2007-2018, The Tor Project, Inc. */
+#include "lib/cc/torint.h"
#include "lib/net/socketpair.h"
#include "lib/net/inaddr_st.h"
#include "lib/arch/bytes.h"
diff --git a/src/lib/process/daemon.c b/src/lib/process/daemon.c
index c64affd8b9..ab3ac73ad5 100644
--- a/src/lib/process/daemon.c
+++ b/src/lib/process/daemon.c
@@ -38,6 +38,16 @@ static int finish_daemon_called = 0;
/** Socketpair used to communicate between parent and child process while
* daemonizing. */
static int daemon_filedes[2];
+
+/**
+ * Return true iff we've called start_daemon() at least once.
+ */
+bool
+start_daemon_has_been_called(void)
+{
+ return start_daemon_called != 0;
+}
+
/** Start putting the process into daemon mode: fork and drop all resources
* except standard fds. The parent process never returns, but stays around
* until finish_daemon is called. (Note: it's safe to call this more
@@ -168,4 +178,10 @@ finish_daemon(const char *cp)
(void)cp;
return 0;
}
+bool
+start_daemon_has_been_called(void)
+{
+ return false;
+}
+
#endif /* !defined(_WIN32) */
diff --git a/src/lib/process/daemon.h b/src/lib/process/daemon.h
index c3b78029af..e33bd56701 100644
--- a/src/lib/process/daemon.h
+++ b/src/lib/process/daemon.h
@@ -11,7 +11,11 @@
#ifndef TOR_DAEMON_H
#define TOR_DAEMON_H
+#include <stdbool.h>
+
int start_daemon(void);
int finish_daemon(const char *desired_cwd);
+bool start_daemon_has_been_called(void);
+
#endif
diff --git a/src/lib/string/util_string.c b/src/lib/string/util_string.c
index a6b0a3d68a..b2b85d151d 100644
--- a/src/lib/string/util_string.c
+++ b/src/lib/string/util_string.c
@@ -451,3 +451,93 @@ string_is_C_identifier(const char *string)
return 1;
}
+
+/** A byte with the top <b>x</b> bits set. */
+#define TOP_BITS(x) ((uint8_t)(0xFF << (8 - (x))))
+/** A byte with the lowest <b>x</b> bits set. */
+#define LOW_BITS(x) ((uint8_t)(0xFF >> (8 - (x))))
+
+/** Given the leading byte <b>b</b>, return the total number of bytes in the
+ * UTF-8 character. Returns 0 if it's an invalid leading byte.
+ */
+static uint8_t
+bytes_in_char(uint8_t b)
+{
+ if ((TOP_BITS(1) & b) == 0x00)
+ return 1; // a 1-byte UTF-8 char, aka ASCII
+ if ((TOP_BITS(3) & b) == TOP_BITS(2))
+ return 2; // a 2-byte UTF-8 char
+ if ((TOP_BITS(4) & b) == TOP_BITS(3))
+ return 3; // a 3-byte UTF-8 char
+ if ((TOP_BITS(5) & b) == TOP_BITS(4))
+ return 4; // a 4-byte UTF-8 char
+
+ // Invalid: either the top 2 bits are 10, or the top 5 bits are 11111.
+ return 0;
+}
+
+/** Returns true iff <b>b</b> is a UTF-8 continuation byte. */
+static bool
+is_continuation_byte(uint8_t b)
+{
+ uint8_t top2bits = b & TOP_BITS(2);
+ return top2bits == TOP_BITS(1);
+}
+
+/** Returns true iff the <b>len</b> bytes in <b>c</b> are a valid UTF-8
+ * character.
+ */
+static bool
+validate_char(const uint8_t *c, uint8_t len)
+{
+ if (len == 1)
+ return true; // already validated this is an ASCII char earlier.
+
+ uint8_t mask = LOW_BITS(7 - len); // bitmask for the leading byte.
+ uint32_t codepoint = c[0] & mask;
+
+ mask = LOW_BITS(6); // bitmask for continuation bytes.
+ for (uint8_t i = 1; i < len; i++) {
+ if (!is_continuation_byte(c[i]))
+ return false;
+ codepoint <<= 6;
+ codepoint |= (c[i] & mask);
+ }
+
+ if (len == 2 && codepoint <= 0x7f)
+ return false; // Invalid, overly long encoding, should have fit in 1 byte.
+
+ if (len == 3 && codepoint <= 0x7ff)
+ return false; // Invalid, overly long encoding, should have fit in 2 bytes.
+
+ if (len == 4 && codepoint <= 0xffff)
+ return false; // Invalid, overly long encoding, should have fit in 3 bytes.
+
+ if (codepoint >= 0xd800 && codepoint <= 0xdfff)
+ return false; // Invalid, reserved for UTF-16 surrogate pairs.
+
+ return codepoint <= 0x10ffff; // Check if within maximum.
+}
+
+/** Returns true iff the first <b>len</b> bytes in <b>str</b> are a
+ valid UTF-8 string. */
+int
+string_is_utf8(const char *str, size_t len)
+{
+ for (size_t i = 0; i < len;) {
+ uint8_t num_bytes = bytes_in_char(str[i]);
+ if (num_bytes == 0) // Invalid leading byte found.
+ return false;
+
+ size_t next_char = i + num_bytes;
+ if (next_char > len)
+ return false;
+
+ // Validate the continuation bytes in this multi-byte character,
+ // and advance to the next character in the string.
+ if (!validate_char((const uint8_t*)&str[i], num_bytes))
+ return false;
+ i = next_char;
+ }
+ return true;
+}
diff --git a/src/lib/string/util_string.h b/src/lib/string/util_string.h
index 471613462a..746ece0d33 100644
--- a/src/lib/string/util_string.h
+++ b/src/lib/string/util_string.h
@@ -52,4 +52,6 @@ const char *find_str_at_start_of_line(const char *haystack,
int string_is_C_identifier(const char *string);
+int string_is_utf8(const char *str, size_t len);
+
#endif /* !defined(TOR_UTIL_STRING_H) */
diff --git a/src/lib/time/compat_time.c b/src/lib/time/compat_time.c
index d26cb6880d..f1ddb4fdc4 100644
--- a/src/lib/time/compat_time.c
+++ b/src/lib/time/compat_time.c
@@ -237,6 +237,7 @@ monotime_reset_ratchets_for_testing(void)
*/
static struct mach_timebase_info mach_time_info;
static struct mach_timebase_info mach_time_info_msec_cvt;
+static int32_t mach_time_msec_cvt_threshold;
static int monotime_shift = 0;
static void
@@ -256,11 +257,15 @@ monotime_init_internal(void)
}
{
// For converting ticks to milliseconds in a 32-bit-friendly way, we
- // will first right-shift by 20, and then multiply by 20/19, since
- // (1<<20) * 19/20 is about 1e6. We precompute a new numerate and
+ // will first right-shift by 20, and then multiply by 2048/1953, since
+ // (1<<20) * 1953/2048 is about 1e6. We precompute a new numerator and
// denominator here to avoid multiple multiplies.
- mach_time_info_msec_cvt.numer = mach_time_info.numer * 20;
- mach_time_info_msec_cvt.denom = mach_time_info.denom * 19;
+ mach_time_info_msec_cvt.numer = mach_time_info.numer * 2048;
+ mach_time_info_msec_cvt.denom = mach_time_info.denom * 1953;
+ // For any value above this amount, we should divide before multiplying,
+ // to avoid overflow. For a value below this, we should multiply
+ // before dividing, to improve accuracy.
+ mach_time_msec_cvt_threshold = INT32_MAX / mach_time_info_msec_cvt.numer;
}
}
@@ -323,8 +328,13 @@ monotime_coarse_diff_msec32_(const monotime_coarse_t *start,
/* We already require in di_ops.c that right-shift performs a sign-extend. */
const int32_t diff_microticks = (int32_t)(diff_ticks >> 20);
- return (diff_microticks * mach_time_info_msec_cvt.numer) /
- mach_time_info_msec_cvt.denom;
+ if (diff_microticks >= mach_time_msec_cvt_threshold) {
+ return (diff_microticks / mach_time_info_msec_cvt.denom) *
+ mach_time_info_msec_cvt.numer;
+ } else {
+ return (diff_microticks * mach_time_info_msec_cvt.numer) /
+ mach_time_info_msec_cvt.denom;
+ }
}
uint32_t
diff --git a/src/lib/time/compat_time.h b/src/lib/time/compat_time.h
index 4427ce8f92..44fab62de5 100644
--- a/src/lib/time/compat_time.h
+++ b/src/lib/time/compat_time.h
@@ -200,6 +200,7 @@ monotime_coarse_diff_msec32(const monotime_coarse_t *start,
// on a 64-bit platform, let's assume 64/64 division is cheap.
return (int32_t) monotime_coarse_diff_msec(start, end);
#else
+#define USING_32BIT_MSEC_HACK
return monotime_coarse_diff_msec32_(start, end);
#endif
}
diff --git a/src/lib/time/tvdiff.c b/src/lib/time/tvdiff.c
index 8617110e52..bc8a1166e7 100644
--- a/src/lib/time/tvdiff.c
+++ b/src/lib/time/tvdiff.c
@@ -165,3 +165,25 @@ tv_to_msec(const struct timeval *tv)
conv += ((int64_t)tv->tv_usec+500)/1000L;
return conv;
}
+
+/**
+ * Return duration in seconds between time_t values
+ * <b>t1</b> and <b>t2</b> iff <b>t1</b> is numerically
+ * less or equal than <b>t2</b>. Otherwise, return TIME_MAX.
+ *
+ * This provides a safe way to compute difference between
+ * two UNIX timestamps (<b>t2</b> can be assumed by calling
+ * code to be later than <b>t1</b>) or two durations measured
+ * in seconds (<b>t2</b> can be assumed to be longer than
+ * <b>t1</b>). Calling code is expected to check for TIME_MAX
+ * return value and interpret that as error condition.
+ */
+time_t
+time_diff(const time_t t1, const time_t t2)
+{
+ if (t1 <= t2)
+ return t2 - t1;
+
+ return TIME_MAX;
+}
+
diff --git a/src/lib/time/tvdiff.h b/src/lib/time/tvdiff.h
index d78330d7d8..a15ce52ad6 100644
--- a/src/lib/time/tvdiff.h
+++ b/src/lib/time/tvdiff.h
@@ -18,4 +18,6 @@ long tv_udiff(const struct timeval *start, const struct timeval *end);
long tv_mdiff(const struct timeval *start, const struct timeval *end);
int64_t tv_to_msec(const struct timeval *tv);
+time_t time_diff(const time_t from, const time_t to);
+
#endif
diff --git a/src/lib/tls/.may_include b/src/lib/tls/.may_include
index ca7cb455e4..2840e590b8 100644
--- a/src/lib/tls/.may_include
+++ b/src/lib/tls/.may_include
@@ -8,6 +8,7 @@ lib/ctime/*.h
lib/encoding/*.h
lib/intmath/*.h
lib/log/*.h
+lib/malloc/*.h
lib/net/*.h
lib/string/*.h
lib/testsupport/testsupport.h
diff --git a/src/lib/tls/include.am b/src/lib/tls/include.am
index b25e2e16bf..a664b29fb2 100644
--- a/src/lib/tls/include.am
+++ b/src/lib/tls/include.am
@@ -12,6 +12,7 @@ src_lib_libtor_tls_a_SOURCES = \
if USE_NSS
src_lib_libtor_tls_a_SOURCES += \
+ src/lib/tls/nss_countbytes.c \
src/lib/tls/tortls_nss.c \
src/lib/tls/x509_nss.c
else
@@ -31,6 +32,7 @@ src_lib_libtor_tls_testing_a_CFLAGS = \
noinst_HEADERS += \
src/lib/tls/ciphers.inc \
src/lib/tls/buffers_tls.h \
+ src/lib/tls/nss_countbytes.h \
src/lib/tls/tortls.h \
src/lib/tls/tortls_internal.h \
src/lib/tls/tortls_st.h \
diff --git a/src/lib/tls/nss_countbytes.c b/src/lib/tls/nss_countbytes.c
new file mode 100644
index 0000000000..c727684529
--- /dev/null
+++ b/src/lib/tls/nss_countbytes.c
@@ -0,0 +1,244 @@
+/* Copyright 2018, The Tor Project Inc. */
+/* See LICENSE for licensing information */
+
+/**
+ * \file nss_countbytes.c
+ * \brief A PRFileDesc layer to let us count the number of bytes
+ * bytes actually written on a PRFileDesc.
+ **/
+
+#include "orconfig.h"
+
+#include "lib/log/util_bug.h"
+#include "lib/malloc/malloc.h"
+#include "lib/tls/nss_countbytes.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <prio.h>
+
+/** Boolean: have we initialized this module */
+static bool countbytes_initialized = false;
+
+/** Integer to identity this layer. */
+static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER;
+
+/** Table of methods for this layer.*/
+static PRIOMethods countbytes_methods;
+
+/** Default close function provided by NSPR. We use this to help
+ * implement our own close function.*/
+static PRStatus(*default_close_fn)(PRFileDesc *fd);
+
+static PRStatus countbytes_close_fn(PRFileDesc *fd);
+static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount);
+static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf,
+ PRInt32 amount);
+static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
+ PRInt32 size, PRIntervalTime timeout);
+static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf,
+ PRInt32 amount, PRIntn flags,
+ PRIntervalTime timeout);
+static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
+ PRIntn flags, PRIntervalTime timeout);
+
+/** Private fields for the byte-counter layer. We cast this to and from
+ * PRFilePrivate*, which is supposed to be allowed. */
+typedef struct tor_nss_bytecounts_t {
+ uint64_t n_read;
+ uint64_t n_written;
+} tor_nss_bytecounts_t;
+
+/**
+ * Initialize this module, if it is not already initialized.
+ **/
+void
+tor_nss_countbytes_init(void)
+{
+ if (countbytes_initialized)
+ return;
+
+ countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer");
+ tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER);
+
+ memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods));
+
+ default_close_fn = countbytes_methods.close;
+ countbytes_methods.close = countbytes_close_fn;
+ countbytes_methods.read = countbytes_read_fn;
+ countbytes_methods.write = countbytes_write_fn;
+ countbytes_methods.writev = countbytes_writev_fn;
+ countbytes_methods.send = countbytes_send_fn;
+ countbytes_methods.recv = countbytes_recv_fn;
+ /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think
+ * NSS won't be using them for TLS connections. */
+
+ countbytes_initialized = true;
+}
+
+/**
+ * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that
+ * the IO layer is in fact a layer created by this module.
+ */
+static tor_nss_bytecounts_t *
+get_counts(PRFileDesc *fd)
+{
+ tor_assert(fd->identity == countbytes_layer_id);
+ return (tor_nss_bytecounts_t*) fd->secret;
+}
+
+/** Helper: increment the read-count of an fd by n. */
+#define INC_READ(fd, n) STMT_BEGIN \
+ get_counts(fd)->n_read += (n); \
+ STMT_END
+
+/** Helper: increment the write-count of an fd by n. */
+#define INC_WRITTEN(fd, n) STMT_BEGIN \
+ get_counts(fd)->n_written += (n); \
+ STMT_END
+
+/** Implementation for PR_Close: frees the 'secret' field, then passes control
+ * to the default close function */
+static PRStatus
+countbytes_close_fn(PRFileDesc *fd)
+{
+ tor_assert(fd);
+
+ tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret;
+ tor_free(counts);
+ fd->secret = NULL;
+
+ return default_close_fn(fd);
+}
+
+/** Implementation for PR_Read: Calls the lower-level read function,
+ * and records what it said. */
+static PRInt32
+countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount)
+{
+ tor_assert(fd);
+ tor_assert(fd->lower);
+
+ PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount);
+ if (result > 0)
+ INC_READ(fd, result);
+ return result;
+}
+/** Implementation for PR_Write: Calls the lower-level write function,
+ * and records what it said. */
+static PRInt32
+countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount)
+{
+ tor_assert(fd);
+ tor_assert(fd->lower);
+
+ PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount);
+ if (result > 0)
+ INC_WRITTEN(fd, result);
+ return result;
+}
+/** Implementation for PR_Writev: Calls the lower-level writev function,
+ * and records what it said. */
+static PRInt32
+countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
+ PRInt32 size, PRIntervalTime timeout)
+{
+ tor_assert(fd);
+ tor_assert(fd->lower);
+
+ PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout);
+ if (result > 0)
+ INC_WRITTEN(fd, result);
+ return result;
+}
+/** Implementation for PR_Send: Calls the lower-level send function,
+ * and records what it said. */
+static PRInt32
+countbytes_send_fn(PRFileDesc *fd, const void *buf,
+ PRInt32 amount, PRIntn flags, PRIntervalTime timeout)
+{
+ tor_assert(fd);
+ tor_assert(fd->lower);
+
+ PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags,
+ timeout);
+ if (result > 0)
+ INC_WRITTEN(fd, result);
+ return result;
+}
+/** Implementation for PR_Recv: Calls the lower-level recv function,
+ * and records what it said. */
+static PRInt32
+countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
+ PRIntn flags, PRIntervalTime timeout)
+{
+ tor_assert(fd);
+ tor_assert(fd->lower);
+
+ PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags,
+ timeout);
+ if (result > 0)
+ INC_READ(fd, result);
+ return result;
+}
+
+/**
+ * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the
+ * total number of bytes read and written. Return the new PRFileDesc.
+ *
+ * This function takes ownership of its input.
+ */
+PRFileDesc *
+tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack)
+{
+ if (BUG(! countbytes_initialized)) {
+ tor_nss_countbytes_init();
+ }
+
+ tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts));
+
+ PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id,
+ &countbytes_methods);
+ tor_assert(newfd);
+ newfd->secret = (PRFilePrivate *)bytecounts;
+
+ /* This does some complicated messing around with the headers of these
+ objects; see the NSPR documentation for more. The upshot is that
+ after PushIOLayer, "stack" will be the head of the stack.
+ */
+ PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd);
+ tor_assert(status == PR_SUCCESS);
+
+ return stack;
+}
+
+/**
+ * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(),
+ * or another PRFileDesc wrapping that PRFileDesc, set the provided
+ * pointers to the number of bytes read and written on the descriptor since
+ * it was created.
+ *
+ * Return 0 on success, -1 on failure.
+ */
+int
+tor_get_prfiledesc_byte_counts(PRFileDesc *fd,
+ uint64_t *n_read_out,
+ uint64_t *n_written_out)
+{
+ if (BUG(! countbytes_initialized)) {
+ tor_nss_countbytes_init();
+ }
+
+ tor_assert(fd);
+ PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id);
+ if (BUG(bclayer == NULL))
+ return -1;
+
+ tor_nss_bytecounts_t *counts = get_counts(bclayer);
+
+ *n_read_out = counts->n_read;
+ *n_written_out = counts->n_written;
+
+ return 0;
+}
diff --git a/src/lib/tls/nss_countbytes.h b/src/lib/tls/nss_countbytes.h
new file mode 100644
index 0000000000..f26280edf2
--- /dev/null
+++ b/src/lib/tls/nss_countbytes.h
@@ -0,0 +1,25 @@
+/* Copyright 2018, The Tor Project, Inc. */
+/* See LICENSE for licensing information */
+
+/**
+ * \file nss_countbytes.h
+ * \brief Header for nss_countbytes.c, which lets us count the number of
+ * bytes actually written on a PRFileDesc.
+ **/
+
+#ifndef TOR_NSS_COUNTBYTES_H
+#define TOR_NSS_COUNTBYTES_H
+
+#include "lib/cc/torint.h"
+
+void tor_nss_countbytes_init(void);
+
+struct PRFileDesc;
+struct PRFileDesc *tor_wrap_prfiledesc_with_byte_counter(
+ struct PRFileDesc *stack);
+
+int tor_get_prfiledesc_byte_counts(struct PRFileDesc *fd,
+ uint64_t *n_read_out,
+ uint64_t *n_written_out);
+
+#endif
diff --git a/src/lib/tls/tortls.c b/src/lib/tls/tortls.c
index 3ae3a1a096..56f70bc371 100644
--- a/src/lib/tls/tortls.c
+++ b/src/lib/tls/tortls.c
@@ -71,13 +71,19 @@ tor_tls_get_my_certs(int server,
const tor_x509_cert_t **id_cert_out)
{
tor_tls_context_t *ctx = tor_tls_context_get(server);
- if (! ctx)
- return -1;
+ int rv = -1;
+ const tor_x509_cert_t *link_cert = NULL;
+ const tor_x509_cert_t *id_cert = NULL;
+ if (ctx) {
+ rv = 0;
+ link_cert = server ? ctx->my_link_cert : ctx->my_auth_cert;
+ id_cert = ctx->my_id_cert;
+ }
if (link_cert_out)
- *link_cert_out = server ? ctx->my_link_cert : ctx->my_auth_cert;
+ *link_cert_out = link_cert;
if (id_cert_out)
- *id_cert_out = ctx->my_id_cert;
- return 0;
+ *id_cert_out = id_cert;
+ return rv;
}
/**
diff --git a/src/lib/tls/tortls.h b/src/lib/tls/tortls.h
index 4591927081..81db5ce5a9 100644
--- a/src/lib/tls/tortls.h
+++ b/src/lib/tls/tortls.h
@@ -94,6 +94,7 @@ void tor_tls_set_renegotiate_callback(tor_tls_t *tls,
void (*cb)(tor_tls_t *, void *arg),
void *arg);
int tor_tls_is_server(tor_tls_t *tls);
+void tor_tls_release_socket(tor_tls_t *tls);
void tor_tls_free_(tor_tls_t *tls);
#define tor_tls_free(tls) FREE_AND_NULL(tor_tls_t, tor_tls_free_, (tls))
int tor_tls_peer_has_cert(tor_tls_t *tls);
@@ -125,6 +126,10 @@ int tor_tls_server_got_renegotiate(tor_tls_t *tls);
MOCK_DECL(int,tor_tls_cert_matches_key,(const tor_tls_t *tls,
const struct tor_x509_cert_t *cert));
MOCK_DECL(int,tor_tls_get_tlssecrets,(tor_tls_t *tls, uint8_t *secrets_out));
+#ifdef ENABLE_OPENSSL
+/* OpenSSL lets us see these master secrets; NSS sensibly does not. */
+#define HAVE_WORKING_TOR_TLS_GET_TLSSECRETS
+#endif
MOCK_DECL(int,tor_tls_export_key_material,(
tor_tls_t *tls, uint8_t *secrets_out,
const uint8_t *context,
diff --git a/src/lib/tls/tortls_nss.c b/src/lib/tls/tortls_nss.c
index 53adfedf32..462cd5b0ff 100644
--- a/src/lib/tls/tortls_nss.c
+++ b/src/lib/tls/tortls_nss.c
@@ -31,11 +31,12 @@
#include "lib/tls/tortls.h"
#include "lib/tls/tortls_st.h"
#include "lib/tls/tortls_internal.h"
+#include "lib/tls/nss_countbytes.h"
#include "lib/log/util_bug.h"
DISABLE_GCC_WARNING(strict-prototypes)
#include <prio.h>
-// For access to raw sockets.
+// For access to rar sockets.
#include <private/pprio.h>
#include <ssl.h>
#include <sslt.h>
@@ -158,6 +159,8 @@ tor_tls_context_new(crypto_pk_t *identity,
SECStatus s;
tor_assert(identity);
+ tor_tls_init();
+
tor_tls_context_t *ctx = tor_malloc_zero(sizeof(tor_tls_context_t));
ctx->refcnt = 1;
@@ -320,7 +323,7 @@ tor_tls_get_state_description(tor_tls_t *tls, char *buf, size_t sz)
void
tor_tls_init(void)
{
- /* We don't have any global setup to do yet, but that will change */
+ tor_nss_countbytes_init();
}
void
@@ -373,7 +376,11 @@ tor_tls_new(tor_socket_t sock, int is_server)
if (!tcp)
return NULL;
- PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, tcp);
+ PRFileDesc *count = tor_wrap_prfiledesc_with_byte_counter(tcp);
+ if (! count)
+ return NULL;
+
+ PRFileDesc *ssl = SSL_ImportFD(ctx->ctx, count);
if (!ssl) {
PR_Close(tcp);
return NULL;
@@ -414,6 +421,47 @@ tor_tls_set_renegotiate_callback(tor_tls_t *tls,
/* We don't support renegotiation-based TLS with NSS. */
}
+/**
+ * Tell the TLS library that the underlying socket for <b>tls</b> has been
+ * closed, and the library should not attempt to free that socket itself.
+ */
+void
+tor_tls_release_socket(tor_tls_t *tls)
+{
+ if (! tls)
+ return;
+
+ /* NSS doesn't have the equivalent of BIO_NO_CLOSE. If you replace the
+ * fd with something that's invalid, it causes a memory leak in PR_Close.
+ *
+ * If there were a way to put the PRFileDesc into the CLOSED state, that
+ * would prevent it from closing its fd -- but there doesn't seem to be a
+ * supported way to do that either.
+ *
+ * So instead: we make a new sacrificial socket, and replace the original
+ * socket with that one. This seems to be the best we can do, until we
+ * redesign the mainloop code enough to make this function unnecessary.
+ */
+ tor_socket_t sock =
+ tor_open_socket_nonblocking(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (! SOCKET_OK(sock)) {
+ log_warn(LD_NET, "Out of sockets when trying to shut down an NSS "
+ "connection");
+ return;
+ }
+
+ PRFileDesc *tcp = PR_GetIdentitiesLayer(tls->ssl, PR_NSPR_IO_LAYER);
+ if (BUG(! tcp)) {
+ tor_close_socket(sock);
+ return;
+ }
+
+ PR_ChangeFileDescNativeHandle(tcp, sock);
+ /* Tell our socket accounting layer that we don't own this socket any more:
+ * NSS is about to free it for us. */
+ tor_release_socket_ownership(sock);
+}
+
void
tor_tls_impl_free_(tor_tls_impl_t *tls)
{
@@ -465,7 +513,6 @@ tor_tls_read, (tor_tls_t *tls, char *cp, size_t len))
PRInt32 rv = PR_Read(tls->ssl, cp, (int)len);
// log_debug(LD_NET, "PR_Read(%zu) returned %d", n, (int)rv);
if (rv > 0) {
- tls->n_read_since_last_check += rv;
return rv;
}
if (rv == 0)
@@ -489,7 +536,6 @@ tor_tls_write(tor_tls_t *tls, const char *cp, size_t n)
PRInt32 rv = PR_Write(tls->ssl, cp, (int)n);
// log_debug(LD_NET, "PR_Write(%zu) returned %d", n, (int)rv);
if (rv > 0) {
- tls->n_written_since_last_check += rv;
return rv;
}
if (rv == 0)
@@ -579,13 +625,17 @@ tor_tls_get_n_raw_bytes(tor_tls_t *tls,
tor_assert(tls);
tor_assert(n_read);
tor_assert(n_written);
- /* XXXX We don't curently have a way to measure this information correctly
- * in NSS; we could do that with a PRIO layer, but it'll take a little
- * coding. For now, we just track the number of bytes sent _in_ the TLS
- * stream. Doing this will make our rate-limiting slightly inaccurate. */
- *n_read = tls->n_read_since_last_check;
- *n_written = tls->n_written_since_last_check;
- tls->n_read_since_last_check = tls->n_written_since_last_check = 0;
+ uint64_t r, w;
+ if (tor_get_prfiledesc_byte_counts(tls->ssl, &r, &w) < 0) {
+ *n_read = *n_written = 0;
+ return;
+ }
+
+ *n_read = (size_t)(r - tls->last_read_count);
+ *n_written = (size_t)(w - tls->last_write_count);
+
+ tls->last_read_count = r;
+ tls->last_write_count = w;
}
int
diff --git a/src/lib/tls/tortls_openssl.c b/src/lib/tls/tortls_openssl.c
index dc6c0bee9c..227225b96e 100644
--- a/src/lib/tls/tortls_openssl.c
+++ b/src/lib/tls/tortls_openssl.c
@@ -1048,7 +1048,7 @@ tor_tls_new(tor_socket_t sock, int isServer)
goto err;
}
result->socket = sock;
- bio = BIO_new_socket(sock, BIO_NOCLOSE);
+ bio = BIO_new_socket(sock, BIO_CLOSE);
if (! bio) {
tls_log_errors(NULL, LOG_WARN, LD_NET, "opening BIO");
#ifdef SSL_set_tlsext_host_name
@@ -1154,6 +1154,28 @@ tor_tls_assert_renegotiation_unblocked(tor_tls_t *tls)
#endif /* defined(SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION) && ... */
}
+/**
+ * Tell the TLS library that the underlying socket for <b>tls</b> has been
+ * closed, and the library should not attempt to free that socket itself.
+ */
+void
+tor_tls_release_socket(tor_tls_t *tls)
+{
+ if (! tls)
+ return;
+
+ BIO *rbio, *wbio;
+ rbio = SSL_get_rbio(tls->ssl);
+ wbio = SSL_get_wbio(tls->ssl);
+
+ if (rbio) {
+ (void) BIO_set_close(rbio, BIO_NOCLOSE);
+ }
+ if (wbio && wbio != rbio) {
+ (void) BIO_set_close(wbio, BIO_NOCLOSE);
+ }
+}
+
void
tor_tls_impl_free_(tor_tls_impl_t *ssl)
{
diff --git a/src/lib/tls/tortls_st.h b/src/lib/tls/tortls_st.h
index a1b59a37af..549443a4e7 100644
--- a/src/lib/tls/tortls_st.h
+++ b/src/lib/tls/tortls_st.h
@@ -66,8 +66,9 @@ struct tor_tls_t {
void *callback_arg;
#endif
#ifdef ENABLE_NSS
- size_t n_read_since_last_check;
- size_t n_written_since_last_check;
+ /** Last values retried from tor_get_prfiledesc_byte_counts(). */
+ uint64_t last_write_count;
+ uint64_t last_read_count;
#endif
};