2
3
4
5
6
7
8
9
10
11
15#include "behaviortree_cpp/contrib/any.hpp"
16#include "behaviortree_cpp/contrib/expected.hpp"
24#include <shared_mutex>
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// Signature shared by every stored conversion: takes a linb::any by const
// reference and returns the converted value as a new linb::any.
49 using CastFunction = std::function<linb::any(
const linb::any&)>;
// Construction and destruction are compiler-generated.
51 PolymorphicCastRegistry() =
default;
52 ~PolymorphicCastRegistry() =
default;
61
62
63
64
65
66
67
68
69
70
// registerCast<Derived, Base>: records one Derived -> Base relationship.
// Visible effects, all performed while holding an exclusive lock on mutex_:
//   - stores an upcast function in upcasts_ keyed by
//     (shared_ptr<Derived>, shared_ptr<Base>),
//   - stores a downcast function in downcasts_ keyed by
//     (shared_ptr<Base>, shared_ptr<Derived>),
//   - adds shared_ptr<Base> to the base_types_ entry of shared_ptr<Derived>.
// NOTE(review): the function header line, several brace lines and the tail of
// the second static_assert message are not present in this listing.
71 template <
typename Derived,
typename Base>
// Compile-time checks: Base must be a base class of Derived and polymorphic.
74 static_assert(std::is_base_of_v<Base, Derived>,
"Derived must inherit from Base");
75 static_assert(std::is_polymorphic_v<Base>,
"Base must be polymorphic (have virtual "
// Exclusive lock: the three containers below are modified.
78 std::unique_lock<std::shared_mutex> lock(mutex_);
// Keys are pairs of type_index values for the shared_ptr types, not the
// pointee types.
81 auto upcast_key = std::make_pair(std::type_index(
typeid(std::shared_ptr<Derived>)),
82 std::type_index(
typeid(std::shared_ptr<Base>)));
// Upcast: extract shared_ptr<Derived> from the any and convert it with
// static_pointer_cast.
84 upcasts_[upcast_key] = [](
const linb::any& from) -> linb::any {
85 auto ptr = linb::any_cast<std::shared_ptr<Derived>>(from);
86 return std::static_pointer_cast<Base>(ptr);
// Same two types in the opposite order for the downcast entry.
90 auto downcast_key = std::make_pair(std::type_index(
typeid(std::shared_ptr<Base>)),
91 std::type_index(
typeid(std::shared_ptr<Derived>)));
// Downcast: extract shared_ptr<Base> and attempt dynamic_pointer_cast.
93 downcasts_[downcast_key] = [](
const linb::any& from) -> linb::any {
94 auto ptr = linb::any_cast<std::shared_ptr<Base>>(from);
95 auto derived = std::dynamic_pointer_cast<Derived>(ptr);
// NOTE(review): presumably thrown only when the dynamic cast yields a null
// pointer; the guarding condition is not visible here - confirm.
98 throw std::bad_cast();
// Record the direct base link consulted by the transitive searches.
104 base_types_[std::type_index(
typeid(std::shared_ptr<Derived>))].insert(
105 std::type_index(
typeid(std::shared_ptr<Base>)));
109
110
111
112
113
114
115
117 std::type_index to_type)
const
// Tail of a const member function's signature followed by its checks; the
// first signature line is not present in this listing.
// NOTE(review): the result returned by each branch is not visible here;
// presumably every successful check yields true - confirm.
// Check 1: identical source and target type (evaluated before locking).
119 if(from_type == to_type)
// Shared (reader) lock for the lookups that follow.
124 std::shared_lock<std::shared_mutex> lock(mutex_);
// Check 2: a directly registered upcast for this exact pair.
127 auto upcast_key = std::make_pair(from_type, to_type);
128 if(upcasts_.find(upcast_key) != upcasts_.end())
// Check 3: an upcast reachable through a chain of registered bases.
134 if(canUpcastTransitive(from_type, to_type))
// Check 4: a directly registered downcast for this exact pair.
140 auto downcast_key = std::make_pair(from_type, to_type);
141 if(downcasts_.find(downcast_key) != downcasts_.end())
150
151
152
153
154
155 [[
nodiscard]]
bool canUpcast(std::type_index from_type, std::type_index to_type)
const
157 if(from_type == to_type)
162 std::shared_lock<std::shared_mutex> lock(mutex_);
163 return canUpcastTransitive(from_type, to_type);
167
168
169
170
171
172
173
// tryCast: attempts to convert `from` (whose type is from_type) to to_type.
// Returns nonstd::expected holding the converted linb::any, or an error string.
// Strategies appear in this order: identical types, direct upcast, transitive
// upcast, direct downcast, transitive downcast, then a final error.
// NOTE(review): brace lines, the try keywords and the tails of some error
// messages are not present in this listing.
174 [[
nodiscard]] nonstd::expected<linb::any, std::string>
175 tryCast(
const linb::any& from, std::type_index from_type, std::type_index to_type)
const
// NOTE(review): the same-type result is not visible; presumably `from` is
// returned unchanged - confirm.
177 if(from_type == to_type)
// Shared (reader) lock for all lookups below.
182 std::shared_lock<std::shared_mutex> lock(mutex_);
// Direct upcast registered for this exact pair.
185 auto upcast_key = std::make_pair(from_type, to_type);
186 auto upcast_it = upcasts_.find(upcast_key);
187 if(upcast_it != upcasts_.end())
191 return upcast_it->second(from);
// A std::exception from the cast function is converted into an error result.
193 catch(
const std::exception& e)
195 return nonstd::make_unexpected(std::string(
"Direct upcast failed: ") + e.what());
// Multi-step upcast: search from from_type towards to_type using upcasts_,
// with reverse_path = true.
200 auto transitive_up = applyTransitiveCasts(from, from_type, to_type, upcasts_,
true);
// NOTE(review): presumably returned only when it holds a value; the condition
// is not visible here - confirm.
203 return transitive_up;
// Direct downcast registered for this exact pair.
207 auto downcast_key = std::make_pair(from_type, to_type);
208 auto downcast_it = downcasts_.find(downcast_key);
209 if(downcast_it != downcasts_.end())
213 return downcast_it->second(from);
215 catch(
const std::exception& e)
217 return nonstd::make_unexpected(std::string(
"Downcast failed "
218 "(dynamic_pointer_cast returned "
// Multi-step downcast: start and target are swapped relative to the upcast
// call (search starts at to_type), uses downcasts_, reverse_path = false.
225 auto transitive_down =
226 applyTransitiveCasts(from, to_type, from_type, downcasts_,
false);
229 return transitive_down;
// Fallback error when none of the strategies above produced a result.
232 return nonstd::make_unexpected(std::string(
"No registered polymorphic conversion "
237
238
241 std::shared_lock<std::shared_mutex> lock(mutex_);
242 auto it = base_types_.find(type);
243 if(it != base_types_.end())
251
252
255 std::unique_lock<std::shared_mutex> lock(mutex_);
263 [[
nodiscard]]
bool canUpcastTransitive(std::type_index from_type,
264 std::type_index to_type)
const
267 std::set<std::type_index> visited;
268 std::vector<std::type_index> queue;
269 queue.push_back(from_type);
271 while(!queue.empty())
273 auto current = queue.back();
276 if(visited.count(current) != 0)
280 visited.insert(current);
282 auto it = base_types_.find(current);
283 if(it == base_types_.end())
288 for(
const auto& base : it->second)
294 queue.push_back(base);
// applyTransitiveCasts: searches base_types_ for a chain of base links from
// dfs_start to dfs_target and, once the target is reached, applies the
// cast_map function registered for each consecutive pair of the resulting
// path to `from`.
// Returns the final converted value, or an error string when a step is
// missing from cast_map, a step throws std::exception, or no path exists.
// No lock is taken here; the visible call sites in tryCast hold a shared lock
// on mutex_.
// NOTE(review): brace lines, some statements and the tails of two error
// messages are not present in this listing.
311 [[
nodiscard]] nonstd::expected<linb::any, std::string> applyTransitiveCasts(
312 const linb::any& from, std::type_index dfs_start, std::type_index dfs_target,
313 const CastMap& cast_map,
bool reverse_path)
const
// parent maps each discovered type to the type it was reached from; the
// start type maps to itself. It also serves as the visited set.
316 std::map<std::type_index, std::type_index> parent;
317 std::vector<std::type_index> stack;
318 stack.push_back(dfs_start);
319 parent.insert({ dfs_start, dfs_start });
321 while(!stack.empty())
323 auto current = stack.back();
// Look up the registered direct bases of the current type.
326 auto it = base_types_.find(current);
327 if(it == base_types_.end())
332 for(
const auto& base : it->second)
// A base that already has a parent entry was discovered earlier.
334 if(parent.find(base) != parent.end())
338 parent.insert({ base, current });
339 if(base == dfs_target)
// Target reached: follow parent links back to dfs_start, which yields the
// path in target-to-start order.
342 std::vector<std::type_index> path;
343 auto node = dfs_target;
344 while(node != dfs_start)
346 path.push_back(node);
347 node = parent.at(node);
349 path.push_back(dfs_start);
// NOTE(review): presumably executed only when reverse_path is true; the
// guarding line is not visible here - confirm.
353 std::reverse(path.begin(), path.end());
// Apply one registered cast per consecutive pair (path[i], path[i + 1]).
357 linb::any current_value = from;
358 for(size_t i = 0; i + 1 < path.size(); ++i)
360 auto cast_key = std::make_pair(path[i], path[i + 1]);
361 auto cast_it = cast_map.find(cast_key);
362 if(cast_it == cast_map.end())
364 return nonstd::make_unexpected(std::string(
"Transitive cast: missing step "
369 current_value = cast_it->second(current_value);
// A std::exception thrown by a step is converted into an error result.
371 catch(
const std::exception& e)
373 return nonstd::make_unexpected(std::string(
"Transitive cast step "
378 return current_value;
// Not the target: keep exploring from this base.
380 stack.push_back(base);
// Work list exhausted without reaching dfs_target.
383 return nonstd::make_unexpected(std::string(
"No transitive path found"));
// Reader/writer mutex; mutable so const member functions can lock it.
386 mutable std::shared_mutex mutex_;
// Cast functions keyed by a (source type, target type) pair, Derived -> Base.
387 std::map<std::pair<std::type_index, std::type_index>, CastFunction> upcasts_;
// Cast functions keyed by a (source type, target type) pair, Base -> Derived.
388 std::map<std::pair<std::type_index, std::type_index>, CastFunction> downcasts_;
// For each type, the set of its directly registered base types.
389 std::map<std::type_index, std::set<std::type_index>> base_types_;
Registry for polymorphic shared_ptr cast relationships.
Definition: polymorphic_cast_registry.hpp:47
void clear()
Clear all registrations (mainly for testing).
Definition: polymorphic_cast_registry.hpp:253
std::set< std::type_index > getBaseTypes(std::type_index type) const
Get all registered base types for a given type.
Definition: polymorphic_cast_registry.hpp:239
void registerCast()
Register a Derived -> Base inheritance relationship.
Definition: polymorphic_cast_registry.hpp:72
bool isConvertible(std::type_index from_type, std::type_index to_type) const
Check if from_type can be converted to to_type.
Definition: polymorphic_cast_registry.hpp:116
bool canUpcast(std::type_index from_type, std::type_index to_type) const
Check if from_type can be UPCAST to to_type (not downcast).
Definition: polymorphic_cast_registry.hpp:155
nonstd::expected< linb::any, std::string > tryCast(const linb::any &from, std::type_index from_type, std::type_index to_type) const
Attempt to cast the value to the target type.
Definition: polymorphic_cast_registry.hpp:175
Definition: action_node.h:24