diff options
author | Nick Mathewson <nickm@torproject.org> | 2024-05-09 21:03:08 +0000 |
---|---|---|
committer | Nick Mathewson <nickm@torproject.org> | 2024-05-09 21:03:08 +0000 |
commit | 1fb9afb478e6edb37071ab6804a92905795fd81e (patch) | |
tree | 5b9e36055c78a8112151431b9a9e5905d9a8530e | |
parent | 229804dccab97e51f82c997aae0dbe69774f31cf (diff) | |
parent | f7cef4bb0e7fe0ad46d87a3deea58addccd6ed7b (diff) | |
download | arti-1fb9afb478e6edb37071ab6804a92905795fd81e.tar.gz arti-1fb9afb478e6edb37071ab6804a92905795fd81e.zip |
Merge branch 'rpc_stream_preliminaries' into 'main'
RPC: Preliminaries for RPC-stream integration
See merge request tpo/core/arti!2140
-rw-r--r-- | Cargo.lock | 1 | ||||
-rw-r--r-- | crates/arti-client/src/client.rs | 7 | ||||
-rw-r--r-- | crates/arti-client/src/lib.rs | 2 | ||||
-rw-r--r-- | crates/arti-client/src/rpc.rs | 28 | ||||
-rw-r--r-- | crates/arti-rpcserver/src/codecs.rs (renamed from crates/arti-rpcserver/src/streams.rs) | 0 | ||||
-rw-r--r-- | crates/arti-rpcserver/src/connection.rs | 2 | ||||
-rw-r--r-- | crates/arti-rpcserver/src/lib.rs | 2 | ||||
-rw-r--r-- | crates/arti-rpcserver/src/mgr.rs | 9 | ||||
-rw-r--r-- | crates/arti-rpcserver/src/session.rs | 69 | ||||
-rw-r--r-- | crates/tor-proto/src/stream/data.rs | 2 | ||||
-rw-r--r-- | crates/tor-rpcbase/Cargo.toml | 1 | ||||
-rw-r--r-- | crates/tor-rpcbase/src/dispatch.rs | 2 | ||||
-rw-r--r-- | crates/tor-rpcbase/src/lib.rs | 10 | ||||
-rw-r--r-- | crates/tor-rpcbase/src/obj.rs | 47 | ||||
-rw-r--r-- | crates/tor-rpcbase/src/obj/cast.rs | 110 |
15 files changed, 241 insertions, 51 deletions
diff --git a/Cargo.lock b/Cargo.lock index 9810cf256..9540aa3aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5834,6 +5834,7 @@ version = "0.18.0" dependencies = [ "assert-impl", "derive-deftly", + "derive_more", "downcast-rs", "erased-serde", "futures", diff --git a/crates/arti-client/src/client.rs b/crates/arti-client/src/client.rs index 8e5e69fa0..77646f19e 100644 --- a/crates/arti-client/src/client.rs +++ b/crates/arti-client/src/client.rs @@ -70,7 +70,12 @@ use tracing::{debug, info}; // implicit Arcs inside them! maybe it's time to replace much of the insides of // this with an Arc<TorClientInner>? #[derive(Clone)] -#[cfg_attr(feature = "rpc", derive(Deftly), derive_deftly(Object))] +#[cfg_attr( + feature = "rpc", + derive(Deftly), + derive_deftly(Object), + deftly(rpc(expose_outside_of_session)) +)] pub struct TorClient<R: Runtime> { /// Asynchronous runtime object. runtime: R, diff --git a/crates/arti-client/src/lib.rs b/crates/arti-client/src/lib.rs index 0d2948015..bc2229366 100644 --- a/crates/arti-client/src/lib.rs +++ b/crates/arti-client/src/lib.rs @@ -45,7 +45,7 @@ mod address; mod builder; mod client; #[cfg(feature = "rpc")] -mod rpc; +pub mod rpc; mod util; pub mod config; diff --git a/crates/arti-client/src/rpc.rs b/crates/arti-client/src/rpc.rs index 6af391290..9e866dead 100644 --- a/crates/arti-client/src/rpc.rs +++ b/crates/arti-client/src/rpc.rs @@ -17,8 +17,9 @@ impl<R: Runtime> crate::TorClient<R> { /// parameterized. pub fn rpc_methods() -> Vec<rpc::dispatch::InvokerEnt> { rpc::invoker_ent_list![ - get_client_status::<R>, // - watch_client_status::<R> + get_client_status::<R>, + watch_client_status::<R>, + isolated_client::<R>, ] } } @@ -113,3 +114,26 @@ async fn watch_client_status<R: Runtime>( // This can only happen if the client exits. Ok(rpc::NIL) } + +/// RPC method: Return an owned ID for a new isolated client instance. +#[derive(Deftly, Debug, Serialize, Deserialize)] +#[derive_deftly(rpc::DynMethod)] +#[deftly(rpc(method_name = "arti::isolated-client"))] +#[non_exhaustive] +pub struct IsolatedClient {} + +impl rpc::Method for IsolatedClient { + type Output = rpc::SingletonId; + type Update = rpc::NoUpdates; +} + +/// RPC method implementation: return a new isolated client based on a given client. +async fn isolated_client<R: Runtime>( + client: Arc<TorClient<R>>, + _method: Box<IsolatedClient>, + ctx: Box<dyn rpc::Context>, +) -> Result<rpc::SingletonId, rpc::RpcError> { + let new_client = Arc::new(client.isolated_client()); + let client_id = ctx.register_owned(new_client); + Ok(rpc::SingletonId::from(client_id)) +} diff --git a/crates/arti-rpcserver/src/streams.rs b/crates/arti-rpcserver/src/codecs.rs index 43ebc6524..43ebc6524 100644 --- a/crates/arti-rpcserver/src/streams.rs +++ b/crates/arti-rpcserver/src/codecs.rs diff --git a/crates/arti-rpcserver/src/connection.rs b/crates/arti-rpcserver/src/connection.rs index 44c41ab8d..807ba5de6 100644 --- a/crates/arti-rpcserver/src/connection.rs +++ b/crates/arti-rpcserver/src/connection.rs @@ -223,7 +223,7 @@ impl Connection { { let write = Box::pin(asynchronous_codec::FramedWrite::new( output, - crate::streams::JsonLinesEncoder::<BoxedResponse>::default(), + crate::codecs::JsonLinesEncoder::<BoxedResponse>::default(), )); let read = Box::pin( diff --git a/crates/arti-rpcserver/src/lib.rs b/crates/arti-rpcserver/src/lib.rs index 0a244435c..16758e217 100644 --- a/crates/arti-rpcserver/src/lib.rs +++ b/crates/arti-rpcserver/src/lib.rs @@ -41,6 +41,7 @@ //! <!-- @@ end lint list maintained by maint/add_warning @@ --> mod cancel; +mod codecs; mod connection; mod err; mod globalid; @@ -48,7 +49,6 @@ mod mgr; mod msgs; mod objmap; mod session; -mod streams; pub use connection::{auth::RpcAuthentication, Connection, ConnectionError}; pub use mgr::RpcMgr; diff --git a/crates/arti-rpcserver/src/mgr.rs b/crates/arti-rpcserver/src/mgr.rs index b66f6bdef..f81261230 100644 --- a/crates/arti-rpcserver/src/mgr.rs +++ b/crates/arti-rpcserver/src/mgr.rs @@ -9,13 +9,13 @@ use weak_table::WeakValueHashMap; use crate::{ connection::{Connection, ConnectionId}, globalid::{GlobalId, MacKey}, - RpcAuthentication, RpcSession, + RpcAuthentication, }; /// A function we use to construct Session objects in response to authentication. // // TODO RPC: Perhaps this should return a Result? -type SessionFactory = Box<dyn Fn(&RpcAuthentication) -> Arc<RpcSession> + Send + Sync>; +type SessionFactory = Box<dyn Fn(&RpcAuthentication) -> Arc<dyn rpc::Object> + Send + Sync>; /// Shared state, configuration, and data for all RPC sessions. /// @@ -93,10 +93,9 @@ pub(crate) struct Inner { impl RpcMgr { /// Create a new RpcMgr. - /// pub fn new<F>(make_session: F) -> Arc<Self> where - F: Fn(&RpcAuthentication) -> Arc<RpcSession> + Send + Sync + 'static, + F: Fn(&RpcAuthentication) -> Arc<dyn rpc::Object> + Send + Sync + 'static, { Arc::new(RpcMgr { global_id_mac_key: MacKey::new(&mut rand::thread_rng()), @@ -189,7 +188,7 @@ impl RpcMgr { } /// Construct a new object to serve as the `session` for a connection. - pub(crate) fn create_session(&self, auth: &RpcAuthentication) -> Arc<RpcSession> { + pub(crate) fn create_session(&self, auth: &RpcAuthentication) -> Arc<dyn rpc::Object> { (self.session_factory)(auth) } } diff --git a/crates/arti-rpcserver/src/session.rs b/crates/arti-rpcserver/src/session.rs index b9485d96f..0db7f1402 100644 --- a/crates/arti-rpcserver/src/session.rs +++ b/crates/arti-rpcserver/src/session.rs @@ -3,8 +3,11 @@ //! A "session" is created when a user authenticates on an RPC connection. It //! is the root for all other RPC capabilities. +use arti_client::TorClient; use derive_deftly::Deftly; +use rpc::static_rpc_invoke_fn; use std::sync::Arc; +use tor_rtcompat::Runtime; use tor_rpcbase as rpc; use tor_rpcbase::templates::*; @@ -27,14 +30,31 @@ pub struct RpcSession { /// An inner TorClient object that we use to implement remaining /// functionality. #[allow(unused)] - client: Arc<dyn rpc::Object>, + client: Arc<dyn Client>, +} + +/// Type-erased `TorClient``, as used within an RpcSession. +trait Client: rpc::Object { + /// Return a new isolated TorClient. + fn isolated_client(&self) -> Arc<dyn rpc::Object>; + + /// Upcast `self` to an rpc::Object. + fn upcast_arc(self: Arc<Self>) -> Arc<dyn rpc::Object>; +} + +impl<R: Runtime> Client for TorClient<R> { + fn isolated_client(&self) -> Arc<dyn rpc::Object> { + Arc::new(TorClient::isolated_client(self)) + } + + fn upcast_arc(self: Arc<Self>) -> Arc<dyn rpc::Object> { + self + } } impl RpcSession { /// Create a new session object containing a single client object. - pub fn new_with_client<R: tor_rtcompat::Runtime>( - client: Arc<arti_client::TorClient<R>>, - ) -> Arc<Self> { + pub fn new_with_client<R: Runtime>(client: Arc<arti_client::TorClient<R>>) -> Arc<Self> { Arc::new(Self { client }) } } @@ -76,9 +96,6 @@ async fn rpc_release( ctx.release_owned(&method.obj)?; Ok(rpc::Nil::default()) } -rpc::static_rpc_invoke_fn! { - rpc_release; -} /// A simple temporary method to echo a reply. #[derive(Debug, serde::Deserialize, serde::Serialize, Deftly)] @@ -105,6 +122,42 @@ async fn echo_on_session( Ok(*method) } -rpc::static_rpc_invoke_fn! { +/// An RPC method to return the default client for a session. +#[derive(Debug, serde::Deserialize, serde::Serialize, Deftly)] +#[derive_deftly(DynMethod)] +#[deftly(rpc(method_name = "arti:get-client"))] +struct GetClient {} + +impl rpc::Method for GetClient { + type Output = rpc::SingletonId; + type Update = rpc::NoUpdates; +} + +/// Implement GetClient on an RpcSession. +async fn get_client_on_session( + session: Arc<RpcSession>, + _method: Box<GetClient>, + ctx: Box<dyn rpc::Context>, +) -> Result<rpc::SingletonId, rpc::RpcError> { + Ok(rpc::SingletonId::from( + // TODO RPC: This relies (somewhat) on deduplication properties for register_owned. + ctx.register_owned(session.client.clone().upcast_arc()), + )) +} + +/// Implement IsolatedClient on an RpcSession. +async fn isolated_client_on_session( + session: Arc<RpcSession>, + _method: Box<arti_client::rpc::IsolatedClient>, + ctx: Box<dyn rpc::Context>, +) -> Result<rpc::SingletonId, rpc::RpcError> { + let new_client = session.client.isolated_client(); + Ok(rpc::SingletonId::from(ctx.register_owned(new_client))) +} + +static_rpc_invoke_fn! { + rpc_release; echo_on_session; + get_client_on_session; + isolated_client_on_session; } diff --git a/crates/tor-proto/src/stream/data.rs b/crates/tor-proto/src/stream/data.rs index 84f8040ff..d74ef5b67 100644 --- a/crates/tor-proto/src/stream/data.rs +++ b/crates/tor-proto/src/stream/data.rs @@ -382,7 +382,7 @@ impl DataStream { /// is received to indicate an error. /// /// Does nothing if this stream is already connected. - pub(crate) async fn wait_for_connection(&mut self) -> Result<()> { + pub async fn wait_for_connection(&mut self) -> Result<()> { // We must put state back before returning let state = self.r.state.take().expect("Missing state in DataReader"); diff --git a/crates/tor-rpcbase/Cargo.toml b/crates/tor-rpcbase/Cargo.toml index 78718eeb2..813dae39a 100644 --- a/crates/tor-rpcbase/Cargo.toml +++ b/crates/tor-rpcbase/Cargo.toml @@ -13,6 +13,7 @@ repository = "https://gitlab.torproject.org/tpo/core/arti.git/" [dependencies] derive-deftly = "0.10.3" +derive_more = "0.99.3" downcast-rs = "1.2.1" erased-serde = "0.4.2" futures = "0.3.14" diff --git a/crates/tor-rpcbase/src/dispatch.rs b/crates/tor-rpcbase/src/dispatch.rs index f15c3b575..ba485bb36 100644 --- a/crates/tor-rpcbase/src/dispatch.rs +++ b/crates/tor-rpcbase/src/dispatch.rs @@ -282,7 +282,7 @@ macro_rules! invoker_ent { /// ``` #[macro_export] macro_rules! invoker_ent_list { - { $($func:expr),* } => { + { $($func:expr),* $(,)? } => { vec![ $( $crate::invoker_ent!($func) diff --git a/crates/tor-rpcbase/src/lib.rs b/crates/tor-rpcbase/src/lib.rs index d851351c3..3b9b65a93 100644 --- a/crates/tor-rpcbase/src/lib.rs +++ b/crates/tor-rpcbase/src/lib.rs @@ -50,7 +50,7 @@ use std::{convert::Infallible, sync::Arc}; pub use dispatch::{DispatchTable, InvokeError, UpdateSink}; pub use err::RpcError; pub use method::{is_method_name, iter_method_names, DynMethod, Method, NoUpdates}; -pub use obj::{Object, ObjectId, ObjectRefExt}; +pub use obj::{Object, ObjectArcExt, ObjectId}; #[doc(hidden)] pub use obj::cast::CastTable; @@ -174,3 +174,11 @@ impl<T: Context> ContextExt for T {} pub struct Nil {} /// An instance of rpc::Nil. pub const NIL: Nil = Nil {}; + +/// Common return type for RPC methods that return a single object ID +/// and nothing else. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, derive_more::From)] +pub struct SingletonId { + /// The ID of the object that we're returning. + id: ObjectId, +} diff --git a/crates/tor-rpcbase/src/obj.rs b/crates/tor-rpcbase/src/obj.rs index 6438457d9..cf2cb5461 100644 --- a/crates/tor-rpcbase/src/obj.rs +++ b/crates/tor-rpcbase/src/obj.rs @@ -2,6 +2,8 @@ pub(crate) mod cast; +use std::sync::Arc; + use derive_deftly::define_derive_deftly; use downcast_rs::DowncastSync; use serde::{Deserialize, Serialize}; @@ -75,7 +77,7 @@ where } } -/// Extension trait for `dyn Object` and similar to support convenient +/// Extension trait for `Arc<dyn Object>` to support convenient /// downcasting to `dyn Trait`. /// /// You don't need to use this for downcasting to an object's concrete @@ -84,7 +86,7 @@ where /// # Examples /// /// ``` -/// use tor_rpcbase::{Object, ObjectRefExt, templates::*}; +/// use tor_rpcbase::{Object, ObjectArcExt, templates::*}; /// use derive_deftly::Deftly; /// use std::sync::Arc; /// @@ -111,22 +113,35 @@ where /// /// assert_eq!(check_feet(Arc::new(Frog{})), 4); /// ``` -pub trait ObjectRefExt { - /// Try to cast this `Object` to a `T`. On success, return a reference to +pub trait ObjectArcExt { + /// Try to cast this `Arc<dyn Object>` to a `T`. On success, return a reference to /// T; on failure, return None. fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T>; + + /// Try to cast this `Arc<dyn Object>` to an `Arc<T>`. + fn cast_to_arc_trait<T: ?Sized + 'static>(self) -> Result<Arc<T>, Arc<dyn Object>>; } -impl ObjectRefExt for dyn Object { - fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T> { +impl dyn Object { + /// Try to cast this `Object` to a `T`. On success, return a reference to + /// T; on failure, return None. + /// + /// This method is only for casting to `&dyn Trait`; + /// see [`ObjectArcExt`] for limitations. + pub fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T> { let table = self.get_cast_table(); table.cast_object_to(self) } } -impl ObjectRefExt for std::sync::Arc<dyn Object> { +impl ObjectArcExt for Arc<dyn Object> { fn cast_to_trait<T: ?Sized + 'static>(&self) -> Option<&T> { - self.as_ref().cast_to_trait() + let obj: &dyn Object = self.as_ref(); + obj.cast_to_trait() + } + fn cast_to_arc_trait<T: ?Sized + 'static>(self) -> Result<Arc<T>, Arc<dyn Object>> { + let table = self.get_cast_table(); + table.cast_object_to_arc(self.clone()) } } @@ -181,7 +196,7 @@ define_derive_deftly! { /// impl Doodad for Frobnitz {} /// /// use std::sync::Arc; -/// use rpc::ObjectRefExt; // for the cast_to method. +/// use rpc::ObjectArcExt; // for the cast_to method. /// let frob_obj: Arc<dyn rpc::Object> = Arc::new(Frobnitz {}); /// let gizmo: &dyn Gizmo = frob_obj.cast_to_trait().unwrap(); /// let doodad: &dyn Doodad = frob_obj.cast_to_trait().unwrap(); @@ -209,7 +224,7 @@ define_derive_deftly! { /// impl<T:Clone,U:PartialEq> ExampleTrait for Generic<T,U> {} /// /// use std::sync::Arc; -/// use rpc::ObjectRefExt; // for the cast_to method. +/// use rpc::ObjectArcExt; // for the cast_to method. /// let obj: Arc<dyn rpc::Object> = Arc::new(Generic { t: 42_u8, u: 42_u8 }); /// let tr: &dyn ExampleTrait = obj.cast_to_trait().unwrap(); /// ``` @@ -381,5 +396,17 @@ mod test { let erased_bikes: &dyn Object = &bikes; let has_wheels: &dyn HasWheels = erased_bikes.cast_to_trait().unwrap(); assert_eq!(has_wheels.num_wheels(), 4); + + let arc_bikes = Arc::new(bikes); + let erased_arc_bytes: Arc<dyn Object> = arc_bikes.clone(); + let arc_has_wheels: Arc<dyn HasWheels> = + erased_arc_bytes.clone().cast_to_arc_trait().ok().unwrap(); + assert_eq!(arc_has_wheels.num_wheels(), 4); + + trait SomethingElse {} + let arc_something_else: Result<Arc<dyn SomethingElse>, _> = + erased_arc_bytes.clone().cast_to_arc_trait(); + let err_arc = arc_something_else.err().unwrap(); + assert!(Arc::ptr_eq(&err_arc, &erased_arc_bytes)); } } diff --git a/crates/tor-rpcbase/src/obj/cast.rs b/crates/tor-rpcbase/src/obj/cast.rs index dafaa54b3..8b9c46a05 100644 --- a/crates/tor-rpcbase/src/obj/cast.rs +++ b/crates/tor-rpcbase/src/obj/cast.rs @@ -7,6 +7,7 @@ use std::{ any::{Any, TypeId}, collections::HashMap, + sync::Arc, }; use once_cell::sync::Lazy; @@ -20,7 +21,7 @@ use crate::Object; /// `derive_deftly(Object)`. /// /// You shouldn't use this directly; instead use -/// [`ObjectRefExt`](super::ObjectRefExt). +/// [`ObjectArcExt`](super::ObjectArcExt). /// /// Note that the concrete object type `O` /// is *not* represented in the type of `CastTable`; @@ -34,14 +35,27 @@ pub struct CastTable { /// Every entry in this table must contain: /// /// * A key that is `typeid::of::<&'static dyn Tr>()` for some trait `Tr`. - /// * A function of type `fn(&dyn Object) -> &dyn Tr` for the same trait - /// `Tr`. This function must accept a `&dyn Object` whose concrete type - /// is actually `O`, and it SHOULD panic for other input types. - /// - /// Note that we use `Box` here in order to support generic types: you can't - /// get a `&'static` reference to a function that takes a generic type in - /// current rust. - table: HashMap<TypeId, Box<dyn Any + Send + Sync>>, + /// * A [`Caster`] whose functions are suitable for casting objects from this table's + /// type to `dyn Tr`. + table: HashMap<TypeId, Caster>, +} + +/// A single entry in a `CastTable`. +/// +/// Each `Caster` exists for one concrete object type "`O`", and one trait type "`Tr`". +/// +/// Note that we use `Box` here in order to support generic types: you can't +/// get a `&'static` reference to a function that takes a generic type in +/// current rust. +struct Caster { + /// Actual type: `fn(Arc<dyn Object>) -> Arc<dyn Tr>` + /// + /// Panics if Object does not have the expected type (`O`). + cast_to_ref: Box<dyn Any + Send + Sync>, + /// Actual type: `fn(Arc<dyn Object>) -> Arc<dyn Tr>` + /// + /// Panics if Object does not have the expected type (`O`). + cast_to_arc: Box<dyn Any + Send + Sync>, } impl CastTable { @@ -55,8 +69,11 @@ impl CastTable { /// `T` must be `dyn Tr` for some trait `Tr`. /// (Not checked by the compiler.) /// - /// `func` is a downcaster from `&dyn Object` to `&dyn Tr`. - /// `func` SHOULD + /// `cast_to_ref` is a downcaster from `&dyn Object` to `&dyn Tr`. + /// + /// `cast_to_arc` is a downcaster from `Arc<dyn Object>` to `Arc<dyn Tr>`` + /// + /// These functions SHOULD /// panic if the concrete type of its argument is not the concrete type `O` /// associated with this `CastTable`. /// @@ -74,8 +91,17 @@ impl CastTable { // We insert and look up by `TypeId::of::<&'static dyn SomeTrait>`, // which must mean `&'static (dyn SomeTrait + 'static)` // since a 'static reference to anything non-'static is an ill-formed type. - pub fn insert<T: 'static + ?Sized>(&mut self, func: fn(&dyn Object) -> &T) { - self.insert_erased(TypeId::of::<&'static T>(), Box::new(func) as _); + pub fn insert<T: 'static + ?Sized>( + &mut self, + cast_to_ref: fn(&dyn Object) -> &T, + cast_to_arc: fn(Arc<dyn Object>) -> Arc<T>, + ) { + let type_id = TypeId::of::<&'static T>(); + let caster = Caster { + cast_to_ref: Box::new(cast_to_ref), + cast_to_arc: Box::new(cast_to_arc), + }; + self.insert_erased(type_id, caster); } /// Implementation for adding an entry to the `CastTable` @@ -87,8 +113,8 @@ impl CastTable { /// Like `insert`, but less compile-time checking. /// `type_id` is the identity of `&'static dyn Tr`, /// and `func` has been boxed and type-erased. - fn insert_erased(&mut self, type_id: TypeId, func: Box<dyn Any + Send + Sync>) { - let old_val = self.table.insert(type_id, func); + fn insert_erased(&mut self, type_id: TypeId, caster: Caster) { + let old_val = self.table.insert(type_id, caster); assert!( old_val.is_none(), "Tried to insert a duplicate entry in a cast table.", @@ -102,6 +128,7 @@ impl CastTable { /// `T` should be `dyn Tr`. /// If `T` is not one of the `dyn Tr` for which `insert` was called, /// returns `None`. + /// /// # Panics /// /// Panics if the concrete type of `obj` does not match `O`. @@ -110,12 +137,40 @@ impl CastTable { /// violated. pub fn cast_object_to<'a, T: 'static + ?Sized>(&self, obj: &'a dyn Object) -> Option<&'a T> { let target_type = TypeId::of::<&'static T>(); - let caster = self.table.get(&target_type)?.as_ref(); + let caster = self.table.get(&target_type)?; let caster: &fn(&dyn Object) -> &T = caster + .cast_to_ref .downcast_ref() .expect("Incorrect cast-function type found in cast table!"); Some(caster(obj)) } + + /// As [`cast_object_to`](CastTable::cast_object_to), but returns an `Arc<dyn Tr>`. + /// + /// If `T` is not one of the `dyn Tr` types for which `insert_arc` was called, + /// return `Err(obj)`. + /// + /// # Panics + /// + /// Panics if the concrete type of `obj` does not match `O`. + /// + /// May panic if any of the Requirements for [`CastTable::insert_arc`] were + /// violated. + pub fn cast_object_to_arc<T: 'static + ?Sized>( + &self, + obj: Arc<dyn Object>, + ) -> Result<Arc<T>, Arc<dyn Object>> { + let target_type = TypeId::of::<&'static T>(); + let caster = match self.table.get(&target_type) { + Some(c) => c, + None => return Err(obj), + }; + let caster: &fn(Arc<dyn Object>) -> Arc<T> = caster + .cast_to_arc + .downcast_ref() + .expect("Incorrect cast-function type found in cast table!"); + Ok(caster(obj)) + } } /// Static cast table that doesn't support casting anything to anything. @@ -140,15 +195,24 @@ macro_rules! cast_table_deftness_helper{ #[allow(unused_mut)] let mut table = $crate::CastTable::default(); $({ - // `f` is the actual function that does the downcasting. + use std::sync::Arc; + // These are the actual functions that does the downcasting. // It works by downcasting with Any to the concrete type, and then // upcasting from the concrete type to &dyn Trait. - let f: fn(&dyn $crate::Object) -> &(dyn $traitname + 'static) = |self_| { + let cast_to_ref: fn(&dyn $crate::Object) -> &(dyn $traitname + 'static) = |self_| { let self_: &Self = self_.downcast_ref().unwrap(); let self_: &dyn $traitname = self_ as _; self_ }; - table.insert::<dyn $traitname>(f); + let cast_to_arc: fn(Arc<dyn $crate::Object>) -> Arc<dyn $traitname> = |self_| { + let self_: Arc<Self> = self_ + .downcast_arc() + .ok() + .expect("used with incorrect type"); + let self_: Arc<dyn $traitname> = self_ as _; + self_ + }; + table.insert::<dyn $traitname>(cast_to_ref, cast_to_arc); })* table } @@ -189,6 +253,10 @@ mod test { let tab = Simple::make_cast_table(); let obj: &dyn Object = &concrete; let _cast: &(dyn Tr1 + '_) = tab.cast_object_to(obj).expect("cast failed"); + + let arc = Arc::new(Simple); + let arc_obj: Arc<dyn Object> = arc.clone(); + let _cast: Arc<dyn Tr1> = tab.cast_object_to_arc(arc_obj).ok().expect("cast failed"); } #[derive(Deftly)] @@ -205,5 +273,9 @@ mod test { let tab = Generic::<&'static str>::make_cast_table(); let obj: &dyn Object = &gen; let _cast: &(dyn Tr1 + '_) = tab.cast_object_to(obj).expect("cast failed"); + + let arc = Arc::new(Generic("bar")); + let arc_obj: Arc<dyn Object> = arc.clone(); + let _cast: Arc<dyn Tr2> = tab.cast_object_to_arc(arc_obj).ok().expect("cast failed"); } } |