diff options
Diffstat (limited to 'src/rust/protover/protover.rs')
-rw-r--r-- | src/rust/protover/protover.rs | 295 |
1 files changed, 152 insertions, 143 deletions
diff --git a/src/rust/protover/protover.rs b/src/rust/protover/protover.rs index 069b1088cd..826f1b73f1 100644 --- a/src/rust/protover/protover.rs +++ b/src/rust/protover/protover.rs @@ -113,94 +113,108 @@ pub fn get_supported_protocols() -> &'static str { .unwrap() } -/// Translates a vector representation of a protocol list into a HashMap -fn parse_protocols<P, S>( - protocols: P, -) -> Result<HashMap<Proto, HashSet<u32>>, &'static str> -where - P: Iterator<Item = S>, - S: AsRef<str>, -{ - let mut parsed = HashMap::new(); - - for subproto in protocols { - let (name, version) = get_proto_and_vers(subproto.as_ref())?; - parsed.insert(name, version); +pub struct SupportedProtocols(HashMap<Proto, Versions>); + +impl SupportedProtocols { + pub fn from_proto_entries<I, S>(protocol_strs: I) -> Result<Self, &'static str> + where + I: Iterator<Item = S>, + S: AsRef<str>, + { + let mut parsed = HashMap::new(); + for subproto in protocol_strs { + let (name, version) = get_proto_and_vers(subproto.as_ref())?; + parsed.insert(name, version); + } + Ok(SupportedProtocols(parsed)) } - Ok(parsed) -} -/// Translates a string representation of a protocol list to a HashMap -fn parse_protocols_from_string<'a>( - protocol_string: &'a str, -) -> Result<HashMap<Proto, HashSet<u32>>, &'static str> { - parse_protocols(protocol_string.split(" ")) -} - -/// Translates supported tor versions from a string into a HashMap, which is -/// useful when looking up a specific subprotocol. -/// -/// # Returns -/// -/// A `Result` whose `Ok` value is a `HashMap<Proto, <u32>>` holding all -/// subprotocols and versions currently supported by tor. -/// -/// The returned `Result`'s `Err` value is an `&'static str` with a description -/// of the error. -/// -fn tor_supported() -> Result<HashMap<Proto, HashSet<u32>>, &'static str> { - parse_protocols(get_supported_protocols().split(" ")) -} + /// Translates a string representation of a protocol list to a + /// SupportedProtocols instance. + /// + /// # Examples + /// + /// ``` + /// use protover::SupportedProtocols; + /// + /// let supported_protocols = SupportedProtocols::from_proto_entries_string( + /// "HSDir=1-2 HSIntro=3-4" + /// ); + /// ``` + pub fn from_proto_entries_string( + proto_entries: &str, + ) -> Result<Self, &'static str> { + Self::from_proto_entries(proto_entries.split(" ")) + } -/// Get the unique version numbers supported by a subprotocol. -/// -/// # Inputs -/// -/// * `version_string`, a string comprised of "[0-9,-]" -/// -/// # Returns -/// -/// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique -/// version numbers. If there were ranges in the `version_string`, then these -/// are expanded, i.e. `"1-3"` would expand to `HashSet<u32>::new([1, 2, 3])`. -/// The returned HashSet is *unordered*. -/// -/// The returned `Result`'s `Err` value is an `&'static str` with a description -/// of the error. -/// -/// # Errors -/// -/// This function will error if: -/// -/// * the `version_string` is empty or contains an equals (`"="`) sign, -/// * the expansion of a version range produces an error (see -/// `expand_version_range`), -/// * any single version number is not parseable as an `u32` in radix 10, or -/// * there are greater than 2^16 version numbers to expand. -/// -fn get_versions(version_string: &str) -> Result<HashSet<u32>, &'static str> { - if version_string.is_empty() { - return Err("version string is empty"); + /// Translate the supported tor versions from a string into a + /// HashMap, which is useful when looking up a specific + /// subprotocol. + /// + fn tor_supported() -> Result<Self, &'static str> { + Self::from_proto_entries_string(get_supported_protocols()) } +} - let mut versions = HashSet::<u32>::new(); +type Version = u32; + +/// Set of versions for a protocol. +#[derive(Debug, PartialEq, Eq)] +pub struct Versions(HashSet<Version>); + +impl Versions { + /// Get the unique version numbers supported by a subprotocol. + /// + /// # Inputs + /// + /// * `version_string`, a string comprised of "[0-9,-]" + /// + /// # Returns + /// + /// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique + /// version numbers. If there were ranges in the `version_string`, then these + /// are expanded, i.e. `"1-3"` would expand to `HashSet<u32>::new([1, 2, 3])`. + /// The returned HashSet is *unordered*. + /// + /// The returned `Result`'s `Err` value is an `&'static str` with a description + /// of the error. + /// + /// # Errors + /// + /// This function will error if: + /// + /// * the `version_string` is empty or contains an equals (`"="`) sign, + /// * the expansion of a version range produces an error (see + /// `expand_version_range`), + /// * any single version number is not parseable as an `u32` in radix 10, or + /// * there are greater than 2^16 version numbers to expand. + /// + fn from_version_string( + version_string: &str, + ) -> Result<Self, &'static str> { + if version_string.is_empty() { + return Err("version string is empty"); + } - for piece in version_string.split(",") { - if piece.contains("-") { - for p in expand_version_range(piece)? { - versions.insert(p); + let mut versions = HashSet::<Version>::new(); + + for piece in version_string.split(",") { + if piece.contains("-") { + for p in expand_version_range(piece)? { + versions.insert(p); + } + } else { + versions.insert(u32::from_str(piece).or( + Err("invalid protocol entry"), + )?); } - } else { - versions.insert(u32::from_str(piece).or( - Err("invalid protocol entry"), - )?); - } - if versions.len() > MAX_PROTOCOLS_TO_EXPAND as usize { - return Err("Too many versions to expand"); + if versions.len() > MAX_PROTOCOLS_TO_EXPAND as usize { + return Err("Too many versions to expand"); + } } + Ok(Versions(versions)) } - Ok(versions) } @@ -220,7 +234,7 @@ fn get_versions(version_string: &str) -> Result<HashSet<u32>, &'static str> { /// fn get_proto_and_vers<'a>( protocol_entry: &'a str, -) -> Result<(Proto, HashSet<u32>), &'static str> { +) -> Result<(Proto, Versions), &'static str> { let mut parts = protocol_entry.splitn(2, "="); let proto = match parts.next() { @@ -233,7 +247,7 @@ fn get_proto_and_vers<'a>( None => return Err("invalid protover entry"), }; - let versions = get_versions(vers)?; + let versions = Versions::from_version_string(vers)?; let proto_name = proto.parse()?; Ok((proto_name, versions)) @@ -260,19 +274,18 @@ fn contains_only_supported_protocols(proto_entry: &str) -> bool { Err(_) => return false, }; - let currently_supported: HashMap<Proto, HashSet<u32>> = - match tor_supported() { - Ok(n) => n, - Err(_) => return false, - }; + let currently_supported = match SupportedProtocols::tor_supported() { + Ok(n) => n.0, + Err(_) => return false, + }; let supported_versions = match currently_supported.get(&name) { Some(n) => n, None => return false, }; - vers.retain(|x| !supported_versions.contains(x)); - vers.is_empty() + vers.0.retain(|x| !supported_versions.0.contains(x)); + vers.0.is_empty() } /// Determine if we support every protocol a client supports, and if not, @@ -318,7 +331,7 @@ pub fn all_supported(protocols: &str) -> (bool, String) { /// /// * `list`, a string representation of a list of protocol entries. /// * `proto`, a `Proto` to test support for -/// * `vers`, a `u32` version which we will go on to determine whether the +/// * `vers`, a `Version` version which we will go on to determine whether the /// specified protocol supports. /// /// # Examples @@ -336,21 +349,19 @@ pub fn all_supported(protocols: &str) -> (bool, String) { pub fn protover_string_supports_protocol( list: &str, proto: Proto, - vers: u32, + vers: Version, ) -> bool { - let supported: HashMap<Proto, HashSet<u32>>; - - match parse_protocols_from_string(list) { - Ok(result) => supported = result, + let supported = match SupportedProtocols::from_proto_entries_string(list) { + Ok(result) => result.0, Err(_) => return false, - } + }; let supported_versions = match supported.get(&proto) { Some(n) => n, None => return false, }; - supported_versions.contains(&vers) + supported_versions.0.contains(&vers) } /// As protover_string_supports_protocol(), but also returns True if @@ -380,23 +391,21 @@ pub fn protover_string_supports_protocol_or_later( proto: Proto, vers: u32, ) -> bool { - let supported: HashMap<Proto, HashSet<u32>>; - - match parse_protocols_from_string(list) { - Ok(result) => supported = result, + let supported = match SupportedProtocols::from_proto_entries_string(list) { + Ok(result) => result.0, Err(_) => return false, - } + }; let supported_versions = match supported.get(&proto) { Some(n) => n, None => return false, }; - supported_versions.iter().any(|v| v >= &vers) + supported_versions.0.iter().any(|v| v >= &vers) } /// Fully expand a version range. For example, 1-3 expands to 1,2,3 -/// Helper for get_versions +/// Helper for Versions::from_version_string /// /// # Inputs /// @@ -498,10 +507,9 @@ fn find_range(list: &Vec<u32>) -> (bool, u32) { /// /// A `String` representation of this set in ascending order. /// -fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String { - let mut supported: Vec<u32> = supported_set.iter() - .map(|x| *x) - .collect(); +fn contract_protocol_list<'a>(supported_set: &'a HashSet<Version>) -> String { + let mut supported: Vec<Version> = + supported_set.iter().map(|x| *x).collect(); supported.sort(); let mut final_output: Vec<String> = Vec::new(); @@ -537,8 +545,8 @@ fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String { /// /// # Returns /// -/// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique -/// version numbers. +/// A `Result` whose `Ok` value is a `HashSet<Version>` holding all of the +/// unique version numbers. /// /// The returned `Result`'s `Err` value is an `&'static str` with a description /// of the error. @@ -549,12 +557,12 @@ fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String { /// /// * The protocol string does not follow the "protocol_name=version_list" /// expected format -/// * If the version string is malformed. See `get_versions`. +/// * If the version string is malformed. See `Versions::from_version_string`. /// fn parse_protocols_from_string_with_no_validation<'a>( protocol_string: &'a str, -) -> Result<HashMap<String, HashSet<u32>>, &'static str> { - let mut parsed: HashMap<String, HashSet<u32>> = HashMap::new(); +) -> Result<HashMap<String, Versions>, &'static str> { + let mut parsed: HashMap<String, Versions> = HashMap::new(); for subproto in protocol_string.split(" ") { let mut parts = subproto.splitn(2, "="); @@ -569,7 +577,7 @@ fn parse_protocols_from_string_with_no_validation<'a>( None => return Err("invalid protover entry"), }; - let versions = get_versions(vers)?; + let versions = Versions::from_version_string(vers)?; parsed.insert(String::from(name), versions); } @@ -617,21 +625,22 @@ pub fn compute_vote( // } // means that FirstSupportedProtocol has three votes which support version // 1, and one vote that supports version 2 - let mut all_count: HashMap<String, HashMap<u32, usize>> = HashMap::new(); + let mut all_count: HashMap<String, HashMap<Version, usize>> = + HashMap::new(); // parse and collect all of the protos and their versions and collect them for vote in list_of_proto_strings { - let this_vote: HashMap<String, HashSet<u32>> = + let this_vote: HashMap<String, Versions> = match parse_protocols_from_string_with_no_validation(&vote) { Ok(result) => result, Err(_) => continue, }; for (protocol, versions) in this_vote { - let supported_vers: &mut HashMap<u32, usize> = + let supported_vers: &mut HashMap<Version, usize> = all_count.entry(protocol).or_insert(HashMap::new()); - for version in versions { + for version in versions.0 { let counter: &mut usize = supported_vers.entry(version).or_insert(0); *counter += 1; @@ -705,20 +714,18 @@ fn write_vote_to_string(vote: &HashMap<String, String>) -> String { /// let is_supported = is_supported_here(Proto::Link, 1); /// assert_eq!(true, is_supported); /// ``` -pub fn is_supported_here(proto: Proto, vers: u32) -> bool { - let currently_supported: HashMap<Proto, HashSet<u32>>; - - match tor_supported() { - Ok(result) => currently_supported = result, +pub fn is_supported_here(proto: Proto, vers: Version) -> bool { + let currently_supported = match SupportedProtocols::tor_supported() { + Ok(result) => result.0, Err(_) => return false, - } + }; let supported_versions = match currently_supported.get(&proto) { Some(n) => n, None => return false, }; - supported_versions.contains(&vers) + supported_versions.0.contains(&vers) } /// Older versions of Tor cannot infer their own subprotocols @@ -764,48 +771,50 @@ pub fn compute_for_old_tor(version: &str) -> &'static [u8] { #[cfg(test)] mod test { + use super::Version; + #[test] - fn test_get_versions() { + fn test_versions_from_version_string() { use std::collections::HashSet; - use super::get_versions; + use super::Versions; - assert_eq!(Err("version string is empty"), get_versions("")); - assert_eq!(Err("invalid protocol entry"), get_versions("a,b")); - assert_eq!(Err("invalid protocol entry"), get_versions("1,!")); + assert_eq!(Err("version string is empty"), Versions::from_version_string("")); + assert_eq!(Err("invalid protocol entry"), Versions::from_version_string("a,b")); + assert_eq!(Err("invalid protocol entry"), Versions::from_version_string("1,!")); { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); - assert_eq!(Ok(versions), get_versions("1")); + assert_eq!(versions, Versions::from_version_string("1").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(2); - assert_eq!(Ok(versions), get_versions("1,2")); + assert_eq!(versions, Versions::from_version_string("1,2").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(2); versions.insert(3); - assert_eq!(Ok(versions), get_versions("1-3")); + assert_eq!(versions, Versions::from_version_string("1-3").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(2); versions.insert(5); - assert_eq!(Ok(versions), get_versions("1-2,5")); + assert_eq!(versions, Versions::from_version_string("1-2,5").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(3); versions.insert(4); versions.insert(5); - assert_eq!(Ok(versions), get_versions("1,3-5")); + assert_eq!(versions, Versions::from_version_string("1,3-5").unwrap().0); } } @@ -861,7 +870,7 @@ mod test { use super::contract_protocol_list; { - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); assert_eq!(String::from(""), contract_protocol_list(&versions)); versions.insert(1); @@ -872,14 +881,14 @@ mod test { } { - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(3); assert_eq!(String::from("1,3"), contract_protocol_list(&versions)); } { - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(2); versions.insert(3); @@ -888,7 +897,7 @@ mod test { } { - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(3); versions.insert(5); @@ -901,7 +910,7 @@ mod test { } { - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(2); versions.insert(3); |