BehaviorTree
Core Library to create and execute Behavior Trees
Loading...
Searching...
No Matches
polymorphic_cast_registry.hpp
1/* Copyright (C) 2022-2025 Davide Faconti - All Rights Reserved
2*
3* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
4* to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
5* and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7*
8* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
9* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
10* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
11*/
12
13#pragma once
14
#include "behaviortree_cpp/contrib/any.hpp"
#include "behaviortree_cpp/contrib/expected.hpp"

#include <algorithm>
#include <exception>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <shared_mutex>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>
26
27namespace BT
28{
29
30/**
31 * @brief Registry for polymorphic shared_ptr cast relationships.
32 *
33 * This enables passing shared_ptr<Derived> to ports expecting shared_ptr<Base>
34 * without breaking ABI compatibility. Users register inheritance relationships
35 * at runtime, and the registry handles upcasting/downcasting transparently.
36 *
37 * This class is typically owned by BehaviorTreeFactory and passed to Blackboard
38 * during tree creation. This avoids global state and makes testing easier.
39 *
40 * Usage with BehaviorTreeFactory:
41 * BehaviorTreeFactory factory;
42 * factory.registerPolymorphicCast<Cat, Animal>();
43 * factory.registerPolymorphicCast<Sphynx, Cat>();
44 * auto tree = factory.createTreeFromText(xml);
45 */
47{
48public:
49 using CastFunction = std::function<linb::any(const linb::any&)>;
50
51 PolymorphicCastRegistry() = default;
52 ~PolymorphicCastRegistry() = default;
53
54 // Non-copyable, non-movable (contains mutex)
55 PolymorphicCastRegistry(const PolymorphicCastRegistry&) = delete;
56 PolymorphicCastRegistry& operator=(const PolymorphicCastRegistry&) = delete;
57 PolymorphicCastRegistry(PolymorphicCastRegistry&&) = delete;
59
60 /**
61 * @brief Register a Derived -> Base inheritance relationship.
62 *
63 * This enables:
64 * - Upcasting: shared_ptr<Derived> can be retrieved as shared_ptr<Base>
65 * - Downcasting: shared_ptr<Base> can be retrieved as shared_ptr<Derived>
66 * (via dynamic_pointer_cast, may return nullptr if types don't match)
67 *
68 * @tparam Derived The derived class (must inherit from Base)
69 * @tparam Base The base class (must be polymorphic - have virtual functions)
70 */
71 template <typename Derived, typename Base>
72 void registerCast()
73 {
74 static_assert(std::is_base_of_v<Base, Derived>, "Derived must inherit from Base");
75 static_assert(std::is_polymorphic_v<Base>, "Base must be polymorphic (have virtual "
76 "functions)");
77
78 std::unique_lock<std::shared_mutex> lock(mutex_);
79
80 // Register upcast: Derived -> Base
81 auto upcast_key = std::make_pair(std::type_index(typeid(std::shared_ptr<Derived>)),
82 std::type_index(typeid(std::shared_ptr<Base>)));
83
84 upcasts_[upcast_key] = [](const linb::any& from) -> linb::any {
85 auto ptr = linb::any_cast<std::shared_ptr<Derived>>(from);
86 return std::static_pointer_cast<Base>(ptr);
87 };
88
89 // Register downcast: Base -> Derived (uses dynamic_pointer_cast)
90 auto downcast_key = std::make_pair(std::type_index(typeid(std::shared_ptr<Base>)),
91 std::type_index(typeid(std::shared_ptr<Derived>)));
92
93 downcasts_[downcast_key] = [](const linb::any& from) -> linb::any {
94 auto ptr = linb::any_cast<std::shared_ptr<Base>>(from);
95 auto derived = std::dynamic_pointer_cast<Derived>(ptr);
96 if(!derived)
97 {
98 throw std::bad_cast();
99 }
100 return derived;
101 };
102
103 // Track inheritance relationship for port compatibility checks
104 base_types_[std::type_index(typeid(std::shared_ptr<Derived>))].insert(
105 std::type_index(typeid(std::shared_ptr<Base>)));
106 }
107
108 /**
109 * @brief Check if from_type can be converted to to_type.
110 *
111 * Returns true if:
112 * - from_type == to_type
113 * - from_type is a registered derived type of to_type (upcast)
114 * - to_type is a registered derived type of from_type (downcast)
115 */
116 [[nodiscard]] bool isConvertible(std::type_index from_type,
117 std::type_index to_type) const
118 {
119 if(from_type == to_type)
120 {
121 return true;
122 }
123
124 std::shared_lock<std::shared_mutex> lock(mutex_);
125
126 // Check direct upcast
127 auto upcast_key = std::make_pair(from_type, to_type);
128 if(upcasts_.find(upcast_key) != upcasts_.end())
129 {
130 return true;
131 }
132
133 // Check transitive upcast (e.g., Sphynx -> Cat -> Animal)
134 if(canUpcastTransitive(from_type, to_type))
135 {
136 return true;
137 }
138
139 // Check downcast
140 auto downcast_key = std::make_pair(from_type, to_type);
141 if(downcasts_.find(downcast_key) != downcasts_.end())
142 {
143 return true;
144 }
145
146 return false;
147 }
148
149 /**
150 * @brief Check if from_type can be UPCAST to to_type (not downcast).
151 *
152 * This is stricter than isConvertible - only allows going from
153 * derived to base, not the reverse.
154 */
155 [[nodiscard]] bool canUpcast(std::type_index from_type, std::type_index to_type) const
156 {
157 if(from_type == to_type)
158 {
159 return true;
160 }
161
162 std::shared_lock<std::shared_mutex> lock(mutex_);
163 return canUpcastTransitive(from_type, to_type);
164 }
165
166 /**
167 * @brief Attempt to cast the value to the target type.
168 *
169 * @param from The source any containing a shared_ptr
170 * @param from_type The type_index of the stored type
171 * @param to_type The target type_index
172 * @return The casted any on success, or an error string on failure
173 */
174 [[nodiscard]] nonstd::expected<linb::any, std::string>
175 tryCast(const linb::any& from, std::type_index from_type, std::type_index to_type) const
176 {
177 if(from_type == to_type)
178 {
179 return from;
180 }
181
182 std::shared_lock<std::shared_mutex> lock(mutex_);
183
184 // Try direct upcast
185 auto upcast_key = std::make_pair(from_type, to_type);
186 auto upcast_it = upcasts_.find(upcast_key);
187 if(upcast_it != upcasts_.end())
188 {
189 try
190 {
191 return upcast_it->second(from);
192 }
193 catch(const std::exception& e)
194 {
195 return nonstd::make_unexpected(std::string("Direct upcast failed: ") + e.what());
196 }
197 }
198
199 // Try transitive upcast
200 auto transitive_up = applyTransitiveCasts(from, from_type, to_type, upcasts_, true);
201 if(transitive_up)
202 {
203 return transitive_up;
204 }
205
206 // Try direct downcast
207 auto downcast_key = std::make_pair(from_type, to_type);
208 auto downcast_it = downcasts_.find(downcast_key);
209 if(downcast_it != downcasts_.end())
210 {
211 try
212 {
213 return downcast_it->second(from);
214 }
215 catch(const std::exception& e)
216 {
217 return nonstd::make_unexpected(std::string("Downcast failed "
218 "(dynamic_pointer_cast returned "
219 "null): ") +
220 e.what());
221 }
222 }
223
224 // Try transitive downcast
225 auto transitive_down =
226 applyTransitiveCasts(from, to_type, from_type, downcasts_, false);
227 if(transitive_down)
228 {
229 return transitive_down;
230 }
231
232 return nonstd::make_unexpected(std::string("No registered polymorphic conversion "
233 "available"));
234 }
235
236 /**
237 * @brief Get all registered base types for a given type.
238 */
240 {
241 std::shared_lock<std::shared_mutex> lock(mutex_);
242 auto it = base_types_.find(type);
243 if(it != base_types_.end())
244 {
245 return it->second;
246 }
247 return {};
248 }
249
250 /**
251 * @brief Clear all registrations (mainly for testing).
252 */
253 void clear()
254 {
255 std::unique_lock<std::shared_mutex> lock(mutex_);
256 upcasts_.clear();
257 downcasts_.clear();
258 base_types_.clear();
259 }
260
261private:
262 // Check if we can upcast from_type to to_type through a chain of registered casts
263 [[nodiscard]] bool canUpcastTransitive(std::type_index from_type,
264 std::type_index to_type) const
265 {
266 // Depth-first search to find a path from from_type to to_type
267 std::set<std::type_index> visited;
268 std::vector<std::type_index> queue;
269 queue.push_back(from_type);
270
271 while(!queue.empty())
272 {
273 auto current = queue.back();
274 queue.pop_back();
275
276 if(visited.count(current) != 0)
277 {
278 continue;
279 }
280 visited.insert(current);
281
282 auto it = base_types_.find(current);
283 if(it == base_types_.end())
284 {
285 continue;
286 }
287
288 for(const auto& base : it->second)
289 {
290 if(base == to_type)
291 {
292 return true;
293 }
294 queue.push_back(base);
295 }
296 }
297 return false;
298 }
299
300 // Common helper for transitive upcast and downcast.
301 //
302 // Performs depth-first search from dfs_start through base_types_ edges,
303 // looking for dfs_target. When found, reconstructs the path from dfs_target
304 // back to dfs_start. If reverse_path is true, reverses it so casts are applied
305 // in [dfs_start -> dfs_target] order; otherwise applies in traced order.
306 //
307 // For upcast: dfs_start=from_type, dfs_target=to_type, reverse=true, map=upcasts_
308 // For downcast: dfs_start=to_type, dfs_target=from_type, reverse=false, map=downcasts_
310
311 [[nodiscard]] nonstd::expected<linb::any, std::string> applyTransitiveCasts(
312 const linb::any& from, std::type_index dfs_start, std::type_index dfs_target,
313 const CastMap& cast_map, bool reverse_path) const
314 {
315 // Note: std::type_index has no default constructor, so we can't use operator[]
316 std::map<std::type_index, std::type_index> parent;
317 std::vector<std::type_index> stack;
318 stack.push_back(dfs_start);
319 parent.insert({ dfs_start, dfs_start });
320
321 while(!stack.empty())
322 {
323 auto current = stack.back();
324 stack.pop_back();
325
326 auto it = base_types_.find(current);
327 if(it == base_types_.end())
328 {
329 continue;
330 }
331
332 for(const auto& base : it->second)
333 {
334 if(parent.find(base) != parent.end())
335 {
336 continue;
337 }
338 parent.insert({ base, current });
339 if(base == dfs_target)
340 {
341 // Reconstruct path: trace from dfs_target back to dfs_start
342 std::vector<std::type_index> path;
343 auto node = dfs_target;
344 while(node != dfs_start)
345 {
346 path.push_back(node);
347 node = parent.at(node);
348 }
349 path.push_back(dfs_start);
350
351 if(reverse_path)
352 {
353 std::reverse(path.begin(), path.end());
354 }
355
356 // Apply casts along the path
357 linb::any current_value = from;
358 for(size_t i = 0; i + 1 < path.size(); ++i)
359 {
360 auto cast_key = std::make_pair(path[i], path[i + 1]);
361 auto cast_it = cast_map.find(cast_key);
362 if(cast_it == cast_map.end())
363 {
364 return nonstd::make_unexpected(std::string("Transitive cast: missing step "
365 "in chain"));
366 }
367 try
368 {
369 current_value = cast_it->second(current_value);
370 }
371 catch(const std::exception& e)
372 {
373 return nonstd::make_unexpected(std::string("Transitive cast step "
374 "failed: ") +
375 e.what());
376 }
377 }
378 return current_value;
379 }
380 stack.push_back(base);
381 }
382 }
383 return nonstd::make_unexpected(std::string("No transitive path found"));
384 }
385
386 mutable std::shared_mutex mutex_;
387 std::map<std::pair<std::type_index, std::type_index>, CastFunction> upcasts_;
388 std::map<std::pair<std::type_index, std::type_index>, CastFunction> downcasts_;
389 std::map<std::type_index, std::set<std::type_index>> base_types_;
390};
391
392} // namespace BT
Registry for polymorphic shared_ptr cast relationships.
Definition: polymorphic_cast_registry.hpp:47
void clear()
Clear all registrations (mainly for testing).
Definition: polymorphic_cast_registry.hpp:253
std::set< std::type_index > getBaseTypes(std::type_index type) const
Get all registered base types for a given type.
Definition: polymorphic_cast_registry.hpp:239
void registerCast()
Register a Derived -> Base inheritance relationship.
Definition: polymorphic_cast_registry.hpp:72
bool isConvertible(std::type_index from_type, std::type_index to_type) const
Check if from_type can be converted to to_type.
Definition: polymorphic_cast_registry.hpp:116
bool canUpcast(std::type_index from_type, std::type_index to_type) const
Check if from_type can be UPCAST to to_type (not downcast).
Definition: polymorphic_cast_registry.hpp:155
nonstd::expected< linb::any, std::string > tryCast(const linb::any &from, std::type_index from_type, std::type_index to_type) const
Attempt to cast the value to the target type.
Definition: polymorphic_cast_registry.hpp:175
Definition: action_node.h:24