aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Mathewson <nickm@torproject.org>2024-05-09 21:03:08 +0000
committerNick Mathewson <nickm@torproject.org>2024-05-09 21:03:08 +0000
commit1fb9afb478e6edb37071ab6804a92905795fd81e (patch)
tree5b9e36055c78a8112151431b9a9e5905d9a8530e
parent229804dccab97e51f82c997aae0dbe69774f31cf (diff)
parentf7cef4bb0e7fe0ad46d87a3deea58addccd6ed7b (diff)
downloadarti-1fb9afb478e6edb37071ab6804a92905795fd81e.tar.gz
arti-1fb9afb478e6edb37071ab6804a92905795fd81e.zip
Merge branch 'rpc_stream_preliminaries' into 'main'
RPC: Preliminaries for RPC-stream integration See merge request tpo/core/arti!2140
-rw-r--r--Cargo.lock1
-rw-r--r--crates/arti-client/src/client.rs7
-rw-r--r--crates/arti-client/src/lib.rs2
-rw-r--r--crates/arti-client/src/rpc.rs28
-rw-r--r--crates/arti-rpcserver/src/codecs.rs (renamed from crates/arti-rpcserver/src/streams.rs)0
-rw-r--r--crates/arti-rpcserver/src/connection.rs2
-rw-r--r--crates/arti-rpcserver/src/lib.rs2
-rw-r--r--crates/arti-rpcserver/src/mgr.rs9
-rw-r--r--crates/arti-rpcserver/src/session.rs69
-rw-r--r--crates/tor-proto/src/stream/data.rs2
-rw-r--r--crates/tor-rpcbase/Cargo.toml1
-rw-r--r--crates/tor-rpcbase/src/dispatch.rs2
-rw-r--r--crates/tor-rpcbase/src/lib.rs10
-rw-r--r--crates/tor-rpcbase/src/obj.rs47
-rw-r--r--crates/tor-rpcbase/src/obj/cast.rs110
15 files changed, 241 insertions, 51 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 9810cf256..9540aa3aa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5834,6 +5834,7 @@ version = "0.18.0"
dependencies = [
"assert-impl",
"derive-deftly",
+ "derive_more",
"downcast-rs",
"erased-serde",
"futures",
diff --git a/crates/arti-client/src/client.rs b/crates/arti-client/src/client.rs
index 8e5e69fa0..77646f19e 100644
--- a/crates/arti-client/src/client.rs
+++ b/crates/arti-client/src/client.rs
@@ -70,7 +70,12 @@ use tracing::{debug, info};
// implicit Arcs inside them! maybe it's time to replace much of the insides of
// this with an Arc<TorClientInner>?
#[derive(Clone)]
-#[cfg_attr(feature = "rpc", derive(Deftly), derive_deftly(Object))]
+#[cfg_attr(
+ feature = "rpc",
+ derive(Deftly),
+ derive_deftly(Object),
+ deftly(rpc(expose_outside_of_session))
+)]
pub struct TorClient<R: Runtime> {
/// Asynchronous runtime object.
runtime: R,
diff --git a/crates/arti-client/src/lib.rs b/crates/arti-client/src/lib.rs
index 0d2948015..bc2229366 100644
--- a/crates/arti-client/src/lib.rs
+++ b/crates/arti-client/src/lib.rs
@@ -45,7 +45,7 @@ mod address;
mod builder;
mod client;
#[cfg(feature = "rpc")]
-mod rpc;
+pub mod rpc;
mod util;
pub mod config;
diff --git a/crates/arti-client/src/rpc.rs b/crates/arti-client/src/rpc.rs
index 6af391290..9e866dead 100644
--- a/crates/arti-client/src/rpc.rs
+++ b/crates/arti-client/src/rpc.rs
@@ -17,8 +17,9 @@ impl<R: Runtime> crate::TorClient<R> {
/// parameterized.
pub fn rpc_methods() -> Vec<rpc::dispatch::InvokerEnt> {
rpc::invoker_ent_list![
- get_client_status::<R>, //
- watch_client_status::<R>
+ get_client_status::<R>,
+ watch_client_status::<R>,
+ isolated_client::<R>,
]
}
}
@@ -113,3 +114,26 @@ async fn watch_client_status<R: Runtime>(
// This can only happen if the client exits.
Ok(rpc::NIL)
}
+
+/// RPC method: Return an owned ID for a new isolated client instance.
+#[derive(Deftly, Debug, Serialize, Deserialize)]
+#[derive_deftly(rpc::DynMethod)]
+#[deftly(rpc(method_name = "arti::isolated-client"))]
+#[non_exhaustive]
+pub struct IsolatedClient {}
+
+impl rpc::Method for IsolatedClient {
+ type Output = rpc::SingletonId;
+ type Update = rpc::NoUpdates;
+}
+
+/// RPC method implementation: return a new isolated client based on a given client.
+async fn isolated_client<R: Runtime>(
+ client: Arc<TorClient<R>>,
+ _method: Box<IsolatedClient>,
+ ctx: Box<dyn rpc::Context>,
+) -> Result<rpc::SingletonId, rpc::RpcError> {
+ let new_client = Arc::new(client.isolated_client());
+ let client_id = ctx.register_owned(new_client);
+ Ok(rpc::SingletonId::from(client_id))
+}
diff --git a/crates/arti-rpcserver/src/streams.rs b/crates/arti-rpcserver/src/codecs.rs
index 43ebc6524..43ebc6524 100644
--- a/crates/arti-rpcserver/src/streams.rs
+++ b/crates/arti-rpcserver/src/codecs.rs
diff --git a/crates/arti-rpcserver/src/connection.rs b/crates/arti-rpcserver/src/connection.rs
index 44c41ab8d..807ba5de6 100644
--- a/crates/arti-rpcserver/src/connection.rs
+++ b/crates/arti-rpcserver/src/connection.rs
@@ -223,7 +223,7 @@ impl Connection {
{
let write = Box::pin(asynchronous_codec::FramedWrite::new(
output,
- crate::streams::JsonLinesEncoder::<BoxedResponse>::default(),
+ crate::codecs::JsonLinesEncoder::<BoxedResponse>::default(),
));
let read = Box::pin(
diff --git a/crates/arti-rpcserver/src/lib.rs b/crates/arti-rpcserver/src/lib.rs
index 0a244435c..16758e217 100644
--- a/crates/arti-rpcserver/src/lib.rs
+++ b/crates/arti-rpcserver/src/lib.rs
@@ -41,6 +41,7 @@
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
mod cancel;
+mod codecs;
mod connection;
mod err;
mod globalid;
@@ -48,7 +49,6 @@ mod mgr;
mod msgs;
mod objmap;
mod session;
-mod streams;
pub use connection::{auth::RpcAuthentication, Connection, ConnectionError};
pub use mgr::RpcMgr;
diff --git a/crates/arti-rpcserver/src/mgr.rs b/crates/arti-rpcserver/src/mgr.rs
index b66f6bdef..f81261230 100644
--- a/crates/arti-rpcserver/src/mgr.rs
+++ b/crates/arti-rpcserver/src/mgr.rs
@@ -9,13 +9,13 @@ use weak_table::WeakValueHashMap;
use crate::{
connection::{Connection, ConnectionId},
globalid::{GlobalId, MacKey},
- RpcAuthentication, RpcSession,
+ RpcAuthentication,
};
/// A function we use to construct Session objects in response to authentication.
//
// TODO RPC: Perhaps this should return a Result?
-type SessionFactory = Box<dyn Fn(&RpcAuthentication) -> Arc<RpcSession> + Send + Sync>;
+type SessionFactory = Box<dyn Fn(&RpcAuthentication) -> Arc<dyn rpc::Object> + Send + Sync>;
/// Shared state, configuration, and data for all RPC sessions.
///
@@ -93,10 +93,9 @@ pub(crate) struct Inner {
impl RpcMgr {
/// Create a new RpcMgr.
- ///
pub fn new<F>(make_session: F) -> Arc<Self>
where
- F: Fn(&RpcAuthentication) -> Arc<RpcSession> + Send + Sync + 'static,
+ F: Fn(&RpcAuthentication) -> Arc<dyn rpc::Object> + Send + Sync + 'static,
{
Arc::new(RpcMgr {
global_id_mac_key: MacKey::new(&mut rand::thread_rng()),
@@ -189,7 +188,7 @@ impl RpcMgr {
}
/// Construct a new object to serve as the `session` for a connection.
- pub(crate) fn create_session(&self, auth: &RpcAuthentication) -> Arc<RpcSession> {
+ pub(crate) fn create_session(&self, auth: &RpcAuthentication) -> Arc<dyn rpc::Object> {
(self.session_factory)(auth)
}
}
diff --git a/crates/arti-rpcserver/src/session.rs b/crates/arti-rpcserver/src/session.rs
index b9485d96f..0db7f1402 100644
--- a/crates/arti-rpcserver/src/session.rs
+++ b/crates/arti-rpcserver/src/session.rs
@@ -3,8 +3,11 @@
//! A "session" is created when a user authenticates on an RPC connection. It
//! is the root for all other RPC capabilities.
+use arti_client::TorClient;
use derive_deftly::Deftly;
+use rpc::static_rpc_invoke_fn;
use std::sync::Arc;
+use tor_rtcompat::Runtime;
use tor_rpcbase as rpc;
use tor_rpcbase::templates::*;
@@ -27,14 +30,31 @@ pub struct RpcSession {
/// An inner TorClient object that we use to implement remaining
/// functionality.
#[allow(unused)]
- client: Arc<dyn rpc::Object>,
+ client: Arc<dyn Client>,
+}
+
+/// Type-erased `TorClient``, as used within an RpcSession.
+trait Client: rpc::Object {
+ /// Return a new isolated TorClient.
+ fn isolated_client(&self) -> Arc<dyn rpc::Object>;
+
+ /// Upcast `self` to an rpc::Object.
+ fn upcast_arc(self: Arc<Self>) -> Arc<dyn rpc::Object>;
+}
+
+impl<R: Runtime> Client for TorClient<R> {
+ fn isolated_client(&self) -> Arc<dyn rpc::Object> {
+ Arc::new(TorClient::isolated_client(self))
+ }
+
+ fn upcast_arc(self: Arc<Self>) -> Arc<dyn rpc::Object> {
+ self
+ }
}
impl RpcSession {
/// Create a new session object containing a single client object.
- pub fn new_with_client<R: tor_rtcompat::Runtime>(
- client: Arc<arti_client::TorClient<R>>,
- ) -> Arc<Self> {
+ pub fn new_with_client<R: Runtime>(client: Arc<arti_client::TorClient<R>>) -> Arc<Self> {
Arc::new(Self { client })
}
}
@@ -76,9 +96,6 @@ async fn rpc_release(
ctx.release_owned(&method.obj)?;
Ok(rpc::Nil::default())
}
-rpc::static_rpc_invoke_fn! {
- rpc_release;
-}
/// A simple temporary method to echo a reply.
#[derive(Debug, serde::Deserialize, serde::Serialize, Deftly)]
@@ -105,6 +122,42 @@ async fn echo_on_session(
Ok(*method)
}
-rpc::static_rpc_invoke_fn! {
+/// An RPC method to return the default client for a session.
+#[derive(Debug, serde::Deserialize, serde::Serialize, Deftly)]
+#[derive_deftly(DynMethod)]
+#[deftly(rpc(method_name = "arti:get-client"))]
+struct GetClient {}
+
+impl rpc::Method for GetClient {
+ type Output = rpc::SingletonId;
+ type Update = rpc::NoUpdates;
+}
+
+/// Implement GetClient on an RpcSession.
+async fn get_client_on_session(
+ session: Arc<RpcSession>,
+ _method: Box<GetClient>,
+ ctx: Box<dyn rpc::Context>,
+) -> Result<rpc::SingletonId, rpc::RpcError> {
+ Ok(rpc::SingletonId::from(
+ // TODO RPC: This relies (somewhat) on deduplication properties for register_owned.
+ ctx.register_owned(session.client.clone().upcast_arc()),
+ ))
+}
+
+/// Implement IsolatedClient on an RpcSession.
+async fn isolated_client_on_session(
+ session: Arc<RpcSession>,
+ _method: Box<arti_client::rpc::IsolatedClient>,
+ ctx: Box<dyn rpc::Context>,
+) -> Result<rpc::SingletonId, rpc::RpcError> {
+ let new_client = session.client.isolated_client();
+ Ok(rpc::SingletonId::from(ctx.register_owned(new_client)))
+}
+
+static_rpc_invoke_fn! {
+ rpc_release;
echo_on_session;
+ get_client_on_session;
+ isolated_client_on_session;
}
diff --git a/crates/tor-proto/src/stream/data.rs b/crates/tor-proto/src/stream/data.rs
index 84f8040ff..d74ef5b67 100644
--- a/crates/tor-proto/src/stream/data.rs
+++ b/crates/tor-proto/src/stream/data.rs
@@ -382,7 +382,7 @@ impl DataStream {
/// is received to indicate an error.
///
/// Does nothing if this stream is already connected.
- pub(crate) async fn wait_for_connection(&mut self) -> Result<()> {
+ pub async fn wait_for_connection(&mut self) -> Result<()> {
// We must put state back before returning
let state = self.r.state.take().expect("Missing state in DataReader");
diff --git a/crates/tor-rpcbase/Cargo.toml b/crates/tor-rpcbase/Cargo.toml
index 78718eeb2..813dae39a 100644
--- a/crates/tor-rpcbase/Cargo.toml
+++ b/crates/tor-rpcbase/Cargo.toml
@@ -13,6 +13,7 @@ repository = "https://gitlab.torproject.org/tpo/core/arti.git/"
[dependencies]
derive-deftly = "0.10.3"
+derive_more = "0.99.3"
downcast-rs = "1.2.1"
erased-serde = "0.4.2"
futures = "0.3.14"
diff --git a/crates/tor-rpcbase/src/dispatch.rs b/crates/tor-rpcbase/src/dispatch.rs
index f15c3b575..ba485bb36 100644
--- a/crates/tor-rpcbase/src/dispatch.rs
+++ b/crates/tor-rpcbase/src/dispatch.rs
@@ -282,7 +282,7 @@ macro_rules! invoker_ent {
/// ```
#[macro_export]
macro_rules! invoker_ent_list {
- { $($func:expr),* } => {
+ { $($func:expr),* $(,)? } => {
vec![
$(
$crate::invoker_ent!($func)
diff --git a/crates/tor-rpcbase/src/lib.rs b/crates/tor-rpcbase/src/lib.rs
index d851351c3..3b9b65a93 100644
--- a/crates/tor-rpcbase/src/lib.rs
+++ b/crates/tor-rpcbase/src/lib.rs
@@ -50,7 +50,7 @@ use std::{convert::Infallible, sync::Arc};
pub use dispatch::{DispatchTable, InvokeError, UpdateSink};
pub use err::RpcError;
pub use method::{is_method_name, iter_method_names, DynMethod, Method, NoUpdates};
-pub use obj::{Object, ObjectId, ObjectRefExt};
+pub use obj::{Object, ObjectArcExt, ObjectId};
#[doc(hidden)]
pub use obj::cast::CastTable;
@@ -174,3 +174,11 @@ impl<T: Context> ContextExt for T {}
pub struct Nil {}
/// An instance of rpc::Nil.
pub const NIL: Nil = Nil {};
+
+/// Common return type for RPC methods that return a single object ID
+/// and nothing else.
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, derive_more::From)]
+pub struct SingletonId {
+ /// The ID of the object that we're returning.
+ id: ObjectId,
+}
diff --git a/crates/tor-rpcbase/src/obj.rs b/crates/tor-rpcbase/src/obj.rs
index 6438457d9..cf2cb5461 100644
--- a/crates/tor-rpcbase/src/obj.rs
+++ b/crates/tor-rpcbase/src/obj.rs
@@ -2,6 +2,8 @@
pub(crate) mod cast;
+use std::sync::Arc;
+
use derive_deftly::define_derive_deftly;
use downcast_rs::DowncastSync;
use serde::{Deserialize, Serialize};
@@ -75,7 +77,7 @@ where
}
}
-/// Extension trait for `dyn Object` and similar to support convenient
+/// Extension trait for `Arc<dyn Object>` to support convenient
/// downcasting to `dyn Trait`.
///
/// You don't need to use this for downcasting to an object's concrete
@@ -84,7 +86,7 @@ where
/// # Examples
///
/// ```
-/// use tor_rpcbase::{Object, ObjectRefExt, templates::*};
+/// use tor_rpcbase::{Object, ObjectArcExt, templates::*};
/// use derive_deftly::Deftly;
/// use std::sync::Arc;
///
@@ -111,22 +113,35 @@ where
///
/// assert_eq!(check_feet(Arc::new(Frog{})), 4);
/// ```
-pub trait ObjectRefExt {
- /// Try to cast this `Object` to a `T`. On success, return a reference to
+pub trait ObjectArcExt {
+ /// Try to cast this `Arc<dyn Object>` to a `T`. On success, return a reference to
/// T; on failure, return None.
fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T>;
+
+ /// Try to cast this `Arc<dyn Object>` to an `Arc<T>`.
+ fn cast_to_arc_trait<T: ?Sized + 'static>(self) -> Result<Arc<T>, Arc<dyn Object>>;
}
-impl ObjectRefExt for dyn Object {
- fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T> {
+impl dyn Object {
+ /// Try to cast this `Object` to a `T`. On success, return a reference to
+ /// T; on failure, return None.
+ ///
+ /// This method is only for casting to `&dyn Trait`;
+ /// see [`ObjectArcExt`] for limitations.
+ pub fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T> {
let table = self.get_cast_table();
table.cast_object_to(self)
}
}
-impl ObjectRefExt for std::sync::Arc<dyn Object> {
+impl ObjectArcExt for Arc<dyn Object> {
fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T> {
- self.as_ref().cast_to_trait()
+ let obj: &dyn Object = self.as_ref();
+ obj.cast_to_trait()
+ }
+ fn cast_to_arc_trait<T: ?Sized + 'static>(self) -> Result<Arc<T>, Arc<dyn Object>> {
+ let table = self.get_cast_table();
+ table.cast_object_to_arc(self.clone())
}
}
@@ -181,7 +196,7 @@ define_derive_deftly! {
/// impl Doodad for Frobnitz {}
///
/// use std::sync::Arc;
-/// use rpc::ObjectRefExt; // for the cast_to method.
+/// use rpc::ObjectArcExt; // for the cast_to method.
/// let frob_obj: Arc<dyn rpc::Object> = Arc::new(Frobnitz {});
/// let gizmo: &dyn Gizmo = frob_obj.cast_to_trait().unwrap();
/// let doodad: &dyn Doodad = frob_obj.cast_to_trait().unwrap();
@@ -209,7 +224,7 @@ define_derive_deftly! {
/// impl<T:Clone,U:PartialEq> ExampleTrait for Generic<T,U> {}
///
/// use std::sync::Arc;
-/// use rpc::ObjectRefExt; // for the cast_to method.
+/// use rpc::ObjectArcExt; // for the cast_to method.
/// let obj: Arc<dyn rpc::Object> = Arc::new(Generic { t: 42_u8, u: 42_u8 });
/// let tr: &dyn ExampleTrait = obj.cast_to_trait().unwrap();
/// ```
@@ -381,5 +396,17 @@ mod test {
let erased_bikes: &dyn Object = &bikes;
let has_wheels: &dyn HasWheels = erased_bikes.cast_to_trait().unwrap();
assert_eq!(has_wheels.num_wheels(), 4);
+
+ let arc_bikes = Arc::new(bikes);
+ let erased_arc_bytes: Arc<dyn Object> = arc_bikes.clone();
+ let arc_has_wheels: Arc<dyn HasWheels> =
+ erased_arc_bytes.clone().cast_to_arc_trait().ok().unwrap();
+ assert_eq!(arc_has_wheels.num_wheels(), 4);
+
+ trait SomethingElse {}
+ let arc_something_else: Result<Arc<dyn SomethingElse>, _> =
+ erased_arc_bytes.clone().cast_to_arc_trait();
+ let err_arc = arc_something_else.err().unwrap();
+ assert!(Arc::ptr_eq(&err_arc, &erased_arc_bytes));
}
}
diff --git a/crates/tor-rpcbase/src/obj/cast.rs b/crates/tor-rpcbase/src/obj/cast.rs
index dafaa54b3..8b9c46a05 100644
--- a/crates/tor-rpcbase/src/obj/cast.rs
+++ b/crates/tor-rpcbase/src/obj/cast.rs
@@ -7,6 +7,7 @@
use std::{
any::{Any, TypeId},
collections::HashMap,
+ sync::Arc,
};
use once_cell::sync::Lazy;
@@ -20,7 +21,7 @@ use crate::Object;
/// `derive_deftly(Object)`.
///
/// You shouldn't use this directly; instead use
-/// [`ObjectRefExt`](super::ObjectRefExt).
+/// [`ObjectArcExt`](super::ObjectArcExt).
///
/// Note that the concrete object type `O`
/// is *not* represented in the type of `CastTable`;
@@ -34,14 +35,27 @@ pub struct CastTable {
/// Every entry in this table must contain:
///
/// * A key that is `typeid::of::<&'static dyn Tr>()` for some trait `Tr`.
- /// * A function of type `fn(&dyn Object) -> &dyn Tr` for the same trait
- /// `Tr`. This function must accept a `&dyn Object` whose concrete type
- /// is actually `O`, and it SHOULD panic for other input types.
- ///
- /// Note that we use `Box` here in order to support generic types: you can't
- /// get a `&'static` reference to a function that takes a generic type in
- /// current rust.
- table: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
+ /// * A [`Caster`] whose functions are suitable for casting objects from this table's
+ /// type to `dyn Tr`.
+ table: HashMap<TypeId, Caster>,
+}
+
+/// A single entry in a `CastTable`.
+///
+/// Each `Caster` exists for one concrete object type "`O`", and one trait type "`Tr`".
+///
+/// Note that we use `Box` here in order to support generic types: you can't
+/// get a `&'static` reference to a function that takes a generic type in
+/// current rust.
+struct Caster {
+ /// Actual type: `fn(Arc<dyn Object>) -> Arc<dyn Tr>`
+ ///
+ /// Panics if Object does not have the expected type (`O`).
+ cast_to_ref: Box<dyn Any + Send + Sync>,
+ /// Actual type: `fn(Arc<dyn Object>) -> Arc<dyn Tr>`
+ ///
+ /// Panics if Object does not have the expected type (`O`).
+ cast_to_arc: Box<dyn Any + Send + Sync>,
}
impl CastTable {
@@ -55,8 +69,11 @@ impl CastTable {
/// `T` must be `dyn Tr` for some trait `Tr`.
/// (Not checked by the compiler.)
///
- /// `func` is a downcaster from `&dyn Object` to `&dyn Tr`.
- /// `func` SHOULD
+ /// `cast_to_ref` is a downcaster from `&dyn Object` to `&dyn Tr`.
+ ///
+ /// `cast_to_arc` is a downcaster from `Arc<dyn Object>` to `Arc<dyn Tr>``
+ ///
+ /// These functions SHOULD
/// panic if the concrete type of its argument is not the concrete type `O`
/// associated with this `CastTable`.
///
@@ -74,8 +91,17 @@ impl CastTable {
// We insert and look up by `TypeId::of::<&'static dyn SomeTrait>`,
// which must mean `&'static (dyn SomeTrait + 'static)`
// since a 'static reference to anything non-'static is an ill-formed type.
- pub fn insert<T: 'static + ?Sized>(&mut self, func: fn(&dyn Object) -> &T) {
- self.insert_erased(TypeId::of::<&'static T>(), Box::new(func) as _);
+ pub fn insert<T: 'static + ?Sized>(
+ &mut self,
+ cast_to_ref: fn(&dyn Object) -> &T,
+ cast_to_arc: fn(Arc<dyn Object>) -> Arc<T>,
+ ) {
+ let type_id = TypeId::of::<&'static T>();
+ let caster = Caster {
+ cast_to_ref: Box::new(cast_to_ref),
+ cast_to_arc: Box::new(cast_to_arc),
+ };
+ self.insert_erased(type_id, caster);
}
/// Implementation for adding an entry to the `CastTable`
@@ -87,8 +113,8 @@ impl CastTable {
/// Like `insert`, but less compile-time checking.
/// `type_id` is the identity of `&'static dyn Tr`,
/// and `func` has been boxed and type-erased.
- fn insert_erased(&mut self, type_id: TypeId, func: Box<dyn Any + Send + Sync>) {
- let old_val = self.table.insert(type_id, func);
+ fn insert_erased(&mut self, type_id: TypeId, caster: Caster) {
+ let old_val = self.table.insert(type_id, caster);
assert!(
old_val.is_none(),
"Tried to insert a duplicate entry in a cast table.",
@@ -102,6 +128,7 @@ impl CastTable {
/// `T` should be `dyn Tr`.
/// If `T` is not one of the `dyn Tr` for which `insert` was called,
/// returns `None`.
+ ///
/// # Panics
///
/// Panics if the concrete type of `obj` does not match `O`.
@@ -110,12 +137,40 @@ impl CastTable {
/// violated.
pub fn cast_object_to<'a, T: 'static + ?Sized>(&self, obj: &'a dyn Object) -> Option<&'a T> {
let target_type = TypeId::of::<&'static T>();
- let caster = self.table.get(&target_type)?.as_ref();
+ let caster = self.table.get(&target_type)?;
let caster: &fn(&dyn Object) -> &T = caster
+ .cast_to_ref
.downcast_ref()
.expect("Incorrect cast-function type found in cast table!");
Some(caster(obj))
}
+
+ /// As [`cast_object_to`](CastTable::cast_object_to), but returns an `Arc<dyn Tr>`.
+ ///
+ /// If `T` is not one of the `dyn Tr` types for which `insert_arc` was called,
+ /// return `Err(obj)`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the concrete type of `obj` does not match `O`.
+ ///
+ /// May panic if any of the Requirements for [`CastTable::insert_arc`] were
+ /// violated.
+ pub fn cast_object_to_arc<T: 'static + ?Sized>(
+ &self,
+ obj: Arc<dyn Object>,
+ ) -> Result<Arc<T>, Arc<dyn Object>> {
+ let target_type = TypeId::of::<&'static T>();
+ let caster = match self.table.get(&target_type) {
+ Some(c) => c,
+ None => return Err(obj),
+ };
+ let caster: &fn(Arc<dyn Object>) -> Arc<T> = caster
+ .cast_to_arc
+ .downcast_ref()
+ .expect("Incorrect cast-function type found in cast table!");
+ Ok(caster(obj))
+ }
}
/// Static cast table that doesn't support casting anything to anything.
@@ -140,15 +195,24 @@ macro_rules! cast_table_deftness_helper{
#[allow(unused_mut)]
let mut table = $crate::CastTable::default();
$({
- // `f` is the actual function that does the downcasting.
+ use std::sync::Arc;
+ // These are the actual functions that does the downcasting.
// It works by downcasting with Any to the concrete type, and then
// upcasting from the concrete type to &dyn Trait.
- let f: fn(&dyn $crate::Object) -> &(dyn $traitname + 'static) = |self_| {
+ let cast_to_ref: fn(&dyn $crate::Object) -> &(dyn $traitname + 'static) = |self_| {
let self_: &Self = self_.downcast_ref().unwrap();
let self_: &dyn $traitname = self_ as _;
self_
};
- table.insert::<dyn $traitname>(f);
+ let cast_to_arc: fn(Arc<dyn $crate::Object>) -> Arc<dyn $traitname> = |self_| {
+ let self_: Arc<Self> = self_
+ .downcast_arc()
+ .ok()
+ .expect("used with incorrect type");
+ let self_: Arc<dyn $traitname> = self_ as _;
+ self_
+ };
+ table.insert::<dyn $traitname>(cast_to_ref, cast_to_arc);
})*
table
}
@@ -189,6 +253,10 @@ mod test {
let tab = Simple::make_cast_table();
let obj: &dyn Object = &concrete;
let _cast: &(dyn Tr1 + '_) = tab.cast_object_to(obj).expect("cast failed");
+
+ let arc = Arc::new(Simple);
+ let arc_obj: Arc<dyn Object> = arc.clone();
+ let _cast: Arc<dyn Tr1> = tab.cast_object_to_arc(arc_obj).ok().expect("cast failed");
}
#[derive(Deftly)]
@@ -205,5 +273,9 @@ mod test {
let tab = Generic::<&'static str>::make_cast_table();
let obj: &dyn Object = &gen;
let _cast: &(dyn Tr1 + '_) = tab.cast_object_to(obj).expect("cast failed");
+
+ let arc = Arc::new(Generic("bar"));
+ let arc_obj: Arc<dyn Object> = arc.clone();
+ let _cast: Arc<dyn Tr2> = tab.cast_object_to_arc(arc_obj).ok().expect("cast failed");
}
}