+/*
+ * Variants produced by LTTng-UST contain TSDL-unsafe names. A variant/selector
+ * sanitization pass is performed before serializing a trace class hierarchy to
+ * TSDL.
+ *
+ * The variant_tsdl_keyword_sanitizer visitor is used to visit field before it
+ * is handed-over to the actual TSDL-producing visitor.
+ *
+ * As it visits fields, the variant_tsdl_keyword_sanitizer populates a
+ * "type_overrider" with TSDL-safe replacements for any variant or enumeration
+ * that uses TSDL-unsafe identifiers (reserved keywords).
+ *
+ * The type_overrider, in turn, is used by the rest of the TSDL serialization
+ * visitor (tsdl_field_visitor) to swap any TSDL-unsafe types with their
+ * sanitized version.
+ *
+ * The tsdl_field_visitor owns the type_overrider and only briefly shares it
+ * with the variant_tsdl_keyword_sanitizer which takes a reference to it.
+ */
+class variant_tsdl_keyword_sanitizer : public lttng::sessiond::trace::field_visitor,
+ public lttng::sessiond::trace::type_visitor {
+public:
+ using type_lookup_function = std::function<const lst::type&(const lst::field_location&)>;
+
+ variant_tsdl_keyword_sanitizer(tsdl::details::type_overrider& type_overrides,
+ type_lookup_function lookup_type) :
+ _type_overrides{type_overrides}, _lookup_type(lookup_type)
+ {
+ }
+
+private:
+ class _c_string_comparator {
+ public:
+ int operator()(const char *lhs, const char *rhs) const
+ {
+ return std::strcmp(lhs, rhs) < 0;
+ }
+ };
+ using unsafe_names = std::set<const char *, _c_string_comparator>;
+
+ virtual void visit(const lst::field& field) override final
+ {
+ _type_overrides.type(field.get_type()).accept(*this);
+ }
+
+ virtual void visit(const lst::integer_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::floating_point_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::signed_enumeration_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::unsigned_enumeration_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::static_length_array_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::dynamic_length_array_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::static_length_blob_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::dynamic_length_blob_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::null_terminated_string_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::structure_type& type) override final
+ {
+ /* Recurse into structure attributes. */
+ for (const auto& field : type.fields_) {
+ field->accept(*this);
+ }
+ }
+
+ /*
+ * Create a new enumeration type replacing any mapping that match, by name, the elements in `unsafe_names_found`
+ * with a TSDL-safe version. Currently, unsafe identifiers are made safe by adding
+ * a leading underscore.
+ */
+ template <typename MappingIntegerType>
+ lst::type::cuptr _create_sanitized_selector(
+ const lst::typed_enumeration_type<MappingIntegerType>& original_selector,
+ const unsafe_names& unsafe_names_found)
+ {
+ auto new_mappings = std::make_shared<typename lst::typed_enumeration_type<
+ MappingIntegerType>::mappings>();
+
+ for (const auto& mapping : *original_selector.mappings_) {
+ if (unsafe_names_found.find(mapping.name.c_str()) ==
+ unsafe_names_found.end()) {
+ /* Mapping is safe, simply copy it. */
+ new_mappings->emplace_back(mapping);
+ } else {
+ /* Unsafe mapping, rename it and keep the rest of its attributes. */
+ new_mappings->emplace_back(
+ fmt::format("_{}", mapping.name), mapping.range);
+ }
+ }
+
+ return lttng::make_unique<lst::typed_enumeration_type<MappingIntegerType>>(
+ original_selector.alignment, original_selector.byte_order,
+ original_selector.size, original_selector.base_, new_mappings);
+ }
+
+ template <typename MappingIntegerType>
+ const typename lst::typed_enumeration_type<MappingIntegerType>::mapping&
+ _find_enumeration_mapping_by_range(
+ const typename lst::typed_enumeration_type<MappingIntegerType>&
+ enumeration_type,
+ const typename lst::typed_enumeration_type<
+ MappingIntegerType>::mapping::range_t& target_mapping_range)
+ {
+ for (const auto& mapping : *enumeration_type.mappings_) {
+ if (mapping.range == target_mapping_range) {
+ return mapping;
+ }
+ }
+
+ LTTNG_THROW_ERROR(fmt::format(
+ "Failed to find mapping by range in enumeration while sanitizing a variant: target_mapping_range={}",
+ target_mapping_range));
+ }
+
+ /*
+ * Copy `original_variant`, but use the mappings of a previously-published sanitized tag
+ * to produce a TSDL-safe version of the variant.
+ */
+ template <typename MappingIntegerType>
+ lst::type::cuptr _create_sanitized_variant(
+ const lst::variant_type<MappingIntegerType>& original_variant)
+ {
+ typename lst::variant_type<MappingIntegerType>::choices new_choices;
+ const auto& sanitized_selector = static_cast<
+ const lst::typed_enumeration_type<MappingIntegerType>&>(
+ _type_overrides.type(_lookup_type(
+ original_variant.selector_field_location)));
+
+ /* Visit variant choices to sanitize them as needed. */
+ for (const auto& choice : original_variant.choices_) {
+ choice.second->accept(*this);
+ }
+
+ for (const auto& choice : original_variant.choices_) {
+ const auto& sanitized_choice_type = _type_overrides.type(*choice.second);
+
+ new_choices.emplace_back(
+ _find_enumeration_mapping_by_range(
+ sanitized_selector, choice.first.range),
+ sanitized_choice_type.copy());
+ }
+
+ return lttng::make_unique<lst::variant_type<MappingIntegerType>>(
+ original_variant.alignment,
+ original_variant.selector_field_location,
+ std::move(new_choices));
+ }
+
+ template <typename MappingIntegerType>
+ void visit_variant(const lst::variant_type<MappingIntegerType>& type)
+ {
+ unsafe_names unsafe_names_found;
+ static const std::unordered_set<std::string> tsdl_protected_keywords = {
+ "align",
+ "callsite",
+ "const",
+ "char",
+ "clock",
+ "double",
+ "enum",
+ "env",
+ "event",
+ "floating_point",
+ "float",
+ "integer",
+ "int",
+ "long",
+ "short",
+ "signed",
+ "stream",
+ "string",
+ "struct",
+ "trace",
+ "typealias",
+ "typedef",
+ "unsigned",
+ "variant",
+ "void",
+ "_Bool",
+ "_Complex",
+ "_Imaginary",
+ };
+
+ for (const auto& choice : type.choices_) {
+ if (tsdl_protected_keywords.find(choice.first.name) != tsdl_protected_keywords.cend()) {
+ /* Choice name is illegal, we have to rename it and its matching mapping. */
+ unsafe_names_found.insert(choice.first.name.c_str());
+ }
+ }
+
+ if (unsafe_names_found.empty()) {
+ return;
+ }
+
+ /*
+ * Look-up selector field type.
+ *
+ * Since it may have been overriden previously, keep the original and overriden
+ * selector field types (which may be the same, if the original was not overriden).
+ *
+ * We work from the "overriden" selector field type to preserve any existing
+ * modifications. However, the original field type will be used to publish the new
+ * version of the type leaving only the most recent overriden type in the type
+ * overrides.
+ */
+ const auto& original_selector_type = _lookup_type(type.selector_field_location);
+ const auto& overriden_selector_type = _type_overrides.type(original_selector_type);
+
+ auto sanitized_selector_type = _create_sanitized_selector(
+ static_cast<const lst::typed_enumeration_type<MappingIntegerType>&>(
+ overriden_selector_type), unsafe_names_found);
+ _type_overrides.publish(original_selector_type, std::move(sanitized_selector_type));
+
+ auto sanitized_variant_type = _create_sanitized_variant(
+ static_cast<const lst::variant_type<MappingIntegerType>&>(type));
+ _type_overrides.publish(type, std::move(sanitized_variant_type));
+ }
+
+ virtual void visit(const lst::variant_type<lst::signed_enumeration_type::mapping::range_t::range_integer_t>& type) override final
+ {
+ visit_variant(type);
+ }
+
+ virtual void visit(const lst::variant_type<lst::unsigned_enumeration_type::mapping::range_t::range_integer_t>& type) override final
+ {
+ visit_variant(type);
+ }
+
+ virtual void visit(const lst::static_length_string_type& type __attribute__((unused))) override final
+ {
+ }
+
+ virtual void visit(const lst::dynamic_length_string_type& type __attribute__((unused))) override final
+ {
+ }
+
+ tsdl::details::type_overrider& _type_overrides;
+ const type_lookup_function _lookup_type;
+};
+