From ef945e4deb2e24b7d7ce05c3289a29a22c6b427d Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=A9mie=20Galarneau?= Date: Thu, 17 Nov 2022 11:44:09 -0500 Subject: [PATCH] Tests: add basic ust context tests for $app, vpid, vuid, vgid MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Context tracing has very little testing coverage beyond the namespace tests. This test is initially introduced to help troubleshoot an issue with application contexts. However, with the scaffolding now in place, it's trivial to exercise some other context types. This also adds a basic framework to write tests in Python. Signed-off-by: Jérémie Galarneau Change-Id: I85be842fab252d8b853392d3f742b1461a69f0fe --- configure.ac | 2 + tests/regression/Makefile.am | 3 +- tests/regression/tools/Makefile.am | 1 + tests/regression/tools/context/Makefile.am | 18 + tests/regression/tools/context/test_ust.py | 175 +++++++++ tests/utils/Makefile.am | 2 +- tests/utils/lttngtest/Makefile.am | 29 ++ tests/utils/lttngtest/__init__.py | 12 + tests/utils/lttngtest/environment.py | 373 +++++++++++++++++++ tests/utils/lttngtest/logger.py | 20 + tests/utils/lttngtest/lttng.py | 375 +++++++++++++++++++ tests/utils/lttngtest/lttngctl.py | 406 +++++++++++++++++++++ tests/utils/lttngtest/tap_generator.py | 157 ++++++++ 13 files changed, 1571 insertions(+), 2 deletions(-) create mode 100644 tests/regression/tools/context/Makefile.am create mode 100755 tests/regression/tools/context/test_ust.py create mode 100644 tests/utils/lttngtest/Makefile.am create mode 100644 tests/utils/lttngtest/__init__.py create mode 100644 tests/utils/lttngtest/environment.py create mode 100644 tests/utils/lttngtest/logger.py create mode 100644 tests/utils/lttngtest/lttng.py create mode 100644 tests/utils/lttngtest/lttngctl.py create mode 100644 tests/utils/lttngtest/tap_generator.py diff --git a/configure.ac b/configure.ac index 0f12145db..6697fcd46 100644 --- a/configure.ac +++ b/configure.ac @@ -1235,6 +1235,7 @@ AC_CONFIG_FILES([ tests/regression/tools/trigger/utils/Makefile tests/regression/tools/trigger/name/Makefile tests/regression/tools/trigger/hidden/Makefile + tests/regression/tools/context/Makefile tests/regression/ust/Makefile tests/regression/ust/nprocesses/Makefile tests/regression/ust/high-throughput/Makefile @@ -1267,6 +1268,7 @@ AC_CONFIG_FILES([ tests/unit/ini_config/Makefile tests/perf/Makefile tests/utils/Makefile + tests/utils/lttngtest/Makefile tests/utils/tap/Makefile tests/utils/testapp/Makefile tests/utils/testapp/gen-ns-events/Makefile diff --git a/tests/regression/Makefile.am b/tests/regression/Makefile.am index 1ff5ad7f3..3aa3b25fe 100644 --- a/tests/regression/Makefile.am +++ b/tests/regression/Makefile.am @@ -63,7 +63,8 @@ TESTS = tools/base-path/test_ust \ tools/trigger/test_list_triggers_cli \ tools/trigger/test_remove_trigger_cli \ tools/trigger/name/test_trigger_name_backwards_compat \ - tools/trigger/hidden/test_hidden_trigger + tools/trigger/hidden/test_hidden_trigger \ + tools/context/test_ust.py # Only build kernel tests on Linux. if IS_LINUX diff --git a/tests/regression/tools/Makefile.am b/tests/regression/tools/Makefile.am index fc39333c0..e77c7b6b9 100644 --- a/tests/regression/tools/Makefile.am +++ b/tests/regression/tools/Makefile.am @@ -3,6 +3,7 @@ SUBDIRS = base-path \ channel \ clear \ + context \ crash \ exclusion \ filtering \ diff --git a/tests/regression/tools/context/Makefile.am b/tests/regression/tools/context/Makefile.am new file mode 100644 index 000000000..53d92a279 --- /dev/null +++ b/tests/regression/tools/context/Makefile.am @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: GPL-2.0-only + +noinst_SCRIPTS = test_ust.py +EXTRA_DIST = test_ust.py + +all-local: + @if [ x"$(srcdir)" != x"$(builddir)" ]; then \ + for script in $(EXTRA_DIST); do \ + cp -f $(srcdir)/$$script $(builddir); \ + done; \ + fi + +clean-local: + @if [ x"$(srcdir)" != x"$(builddir)" ]; then \ + for script in $(EXTRA_DIST); do \ + rm -f $(builddir)/$$script; \ + done; \ + fi diff --git a/tests/regression/tools/context/test_ust.py b/tests/regression/tools/context/test_ust.py new file mode 100755 index 000000000..d3215f0f4 --- /dev/null +++ b/tests/regression/tools/context/test_ust.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 Jérémie Galarneau +# +# SPDX-License-Identifier: GPL-2.0-only + +from cgi import test +import pathlib +import sys +import os +from typing import Any, Callable, Type + +""" +Test the addition of various user space contexts. + +This test successively sets up a session with a certain context enabled, traces +a test application, and then reads the resulting trace to determine if: + - the context field is present in the trace + - the context field has the expected value. + +The vpid, vuid, vgid and java application contexts are validated by this test. +""" + +# Import in-tree test utils +test_utils_import_path = pathlib.Path(__file__).absolute().parents[3] / "utils" +sys.path.append(str(test_utils_import_path)) + +import lttngtest +import bt2 + + +def context_trace_field_name(context_type: Type[lttngtest.ContextType]) -> str: + if isinstance(context_type, lttngtest.VpidContextType): + return "vpid" + elif isinstance(context_type, lttngtest.VuidContextType): + return "vuid" + elif isinstance(context_type, lttngtest.VgidContextType): + return "vgid" + elif isinstance(context_type, lttngtest.JavaApplicationContextType): + # Depends on the trace format and will need to be adapted for CTF 2. + return "_app_{retriever}_{name}".format( + retriever=context_type.retriever_name, name=context_type.field_name + ) + else: + raise NotImplementedError + + +def trace_stream_class_has_context_field_in_event_context( + trace_location: pathlib.Path, context_field_name: str +) -> bool: + iterator = bt2.TraceCollectionMessageIterator(str(trace_location)) + + # A bt2 message sequence is guaranteed to begin with a StreamBeginningMessage. + # Since we only have one channel (one stream class) and one trace, it is + # safe to use it to determine if the stream class contains the expected + # context field. + stream_begin_msg = next(iterator) + + trace_class = stream_begin_msg.stream.trace.cls + # Ensure the trace class has only one stream class. + assert len(trace_class) + + stream_class_id = next(iter(trace_class)) + stream_class = trace_class[stream_class_id] + event_common_context_field_class = stream_class.event_common_context_field_class + + return context_field_name in event_common_context_field_class + + +def trace_events_have_context_value( + trace_location: pathlib.Path, context_field_name: str, value: Any +) -> bool: + for msg in bt2.TraceCollectionMessageIterator(str(trace_location)): + if type(msg) is not bt2._EventMessageConst: + continue + + if msg.event.common_context_field[context_field_name] != value: + print(msg.event.common_context_field[context_field_name]) + return False + return True + + +def test_static_context( + tap: lttngtest.TapGenerator, + test_env: lttngtest._Environment, + context_type: lttngtest.ContextType, + context_value_retriever: Callable[[lttngtest.WaitTraceTestApplication], Any], +) -> None: + tap.diagnostic( + "Test presence and expected value of context `{context_name}`".format( + context_name=type(context_type).__name__ + ) + ) + + session_output_location = lttngtest.LocalSessionOutputLocation( + test_env.create_temporary_directory("trace") + ) + + client: lttngtest.Controller = lttngtest.LTTngClient(test_env, log=tap.diagnostic) + + with tap.case("Create a session") as test_case: + session = client.create_session(output=session_output_location) + tap.diagnostic("Created session `{session_name}`".format(session_name=session.name)) + + with tap.case( + "Add a channel to session `{session_name}`".format(session_name=session.name) + ) as test_case: + channel = session.add_channel(lttngtest.TracingDomain.User) + tap.diagnostic("Created channel `{channel_name}`".format(channel_name=channel.name)) + + with tap.case( + "Add {context_type} context to channel `{channel_name}`".format( + context_type=type(context_type).__name__, channel_name=channel.name + ) + ) as test_case: + channel.add_context(context_type) + + test_app = test_env.launch_wait_trace_test_application(50) + + # Only track the test application + session.user_vpid_process_attribute_tracker.track(test_app.vpid) + expected_context_value = context_value_retriever(test_app) + + # Enable all user space events, the default for a user tracepoint event rule. + channel.add_recording_rule(lttngtest.UserTracepointEventRule()) + + session.start() + test_app.trace() + test_app.wait_for_exit() + session.stop() + session.destroy() + + tap.test( + trace_stream_class_has_context_field_in_event_context( + session_output_location.path, context_trace_field_name(context_type) + ), + "Stream class contains field `{context_field_name}`".format( + context_field_name=context_trace_field_name(context_type) + ), + ) + + tap.test( + trace_events_have_context_value( + session_output_location.path, + context_trace_field_name(context_type), + expected_context_value, + ), + "Trace's events contain the expected `{context_field_name}` value `{expected_context_value}`".format( + context_field_name=context_trace_field_name(context_type), + expected_context_value=expected_context_value, + ), + ) + + +tap = lttngtest.TapGenerator(20) +tap.diagnostic("Test user space context tracing") + +with lttngtest.test_environment(with_sessiond=True, log=tap.diagnostic) as test_env: + test_static_context( + tap, test_env, lttngtest.VpidContextType(), lambda test_app: test_app.vpid + ) + test_static_context( + tap, test_env, lttngtest.VuidContextType(), lambda test_app: os.getuid() + ) + test_static_context( + tap, test_env, lttngtest.VgidContextType(), lambda test_app: os.getgid() + ) + test_static_context( + tap, + test_env, + lttngtest.JavaApplicationContextType("mayo", "ketchup"), + lambda test_app: {}, + ) + +sys.exit(0 if tap.is_successful else 1) diff --git a/tests/utils/Makefile.am b/tests/utils/Makefile.am index 96dbe2852..941476e31 100644 --- a/tests/utils/Makefile.am +++ b/tests/utils/Makefile.am @@ -1,6 +1,6 @@ # SPDX-License-Identifier: GPL-2.0-only -SUBDIRS = . tap testapp xml-utils +SUBDIRS = . tap testapp xml-utils lttngtest EXTRA_DIST = utils.sh test_utils.py babelstats.pl warn_processes.sh \ parse-callstack.py diff --git a/tests/utils/lttngtest/Makefile.am b/tests/utils/lttngtest/Makefile.am new file mode 100644 index 000000000..a78c62638 --- /dev/null +++ b/tests/utils/lttngtest/Makefile.am @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: GPL-2.0-only + +EXTRA_DIST = __init__.py \ + environment.py \ + logger.py \ + lttngctl.py \ + lttng.py \ + tap_generator.py + +dist_noinst_SCRIPTS = __init__.py \ + environment.py \ + logger.py \ + lttngctl.py \ + lttng.py \ + tap_generator.py + +all-local: + @if [ x"$(srcdir)" != x"$(builddir)" ]; then \ + for script in $(EXTRA_DIST); do \ + cp -f $(srcdir)/$$script $(builddir); \ + done; \ + fi + +clean-local: + @if [ x"$(srcdir)" != x"$(builddir)" ]; then \ + for script in $(EXTRA_DIST); do \ + rm -f $(builddir)/$$script; \ + done; \ + fi diff --git a/tests/utils/lttngtest/__init__.py b/tests/utils/lttngtest/__init__.py new file mode 100644 index 000000000..3676667fc --- /dev/null +++ b/tests/utils/lttngtest/__init__.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 Jérémie Galarneau +# +# SPDX-License-Identifier: GPL-2.0-only +# + +from .tap_generator import * +from .environment import * +from .environment import _Environment +from .lttngctl import * +from .lttng import * diff --git a/tests/utils/lttngtest/environment.py b/tests/utils/lttngtest/environment.py new file mode 100644 index 000000000..2e4b48569 --- /dev/null +++ b/tests/utils/lttngtest/environment.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 Jérémie Galarneau +# +# SPDX-License-Identifier: GPL-2.0-only +# + +from types import FrameType +from typing import Callable, Optional, Tuple, List +import sys +import pathlib +import signal +import subprocess +import shlex +import shutil +import os +import queue +import tempfile +from . import logger +import time +import threading +import contextlib + + +class TemporaryDirectory: + def __init__(self, prefix: str): + self._directory_path = tempfile.mkdtemp(prefix=prefix) + + def __del__(self): + shutil.rmtree(self._directory_path, ignore_errors=True) + + @property + def path(self) -> pathlib.Path: + return pathlib.Path(self._directory_path) + + +class _SignalWaitQueue: + """ + Utility class useful to wait for a signal before proceeding. + + Simply register the `signal` method as the handler for the signal you are + interested in and call `wait_for_signal` to wait for its reception. + + Registering a signal: + signal.signal(signal.SIGWHATEVER, queue.signal) + + Waiting for the signal: + queue.wait_for_signal() + """ + + def __init__(self): + self._queue: queue.Queue = queue.Queue() + + def signal(self, signal_number, frame: Optional[FrameType]): + self._queue.put_nowait(signal_number) + + def wait_for_signal(self): + self._queue.get(block=True) + + +class WaitTraceTestApplication: + """ + Create an application that waits before tracing. This allows a test to + launch an application, get its PID, and get it to start tracing when it + has completed its setup. + """ + + def __init__( + self, + binary_path: pathlib.Path, + event_count: int, + environment: "Environment", + wait_time_between_events_us: int = 0, + ): + self._environment: Environment = environment + if event_count % 5: + # The test application currently produces 5 different events per iteration. + raise ValueError("event count must be a multiple of 5") + self._iteration_count: int = int(event_count / 5) + # File that the application will wait to see before tracing its events. + self._app_start_tracing_file_path: pathlib.Path = pathlib.Path( + tempfile.mktemp( + prefix="app_", + suffix="_start_tracing", + dir=environment.lttng_home_location, + ) + ) + self._has_returned = False + + test_app_env = os.environ.copy() + test_app_env["LTTNG_HOME"] = str(environment.lttng_home_location) + # Make sure the app is blocked until it is properly registered to + # the session daemon. + test_app_env["LTTNG_UST_REGISTER_TIMEOUT"] = "-1" + + # File that the application will create to indicate it has completed its initialization. + app_ready_file_path: str = tempfile.mktemp( + prefix="app_", suffix="_ready", dir=environment.lttng_home_location + ) + + test_app_args = [str(binary_path)] + test_app_args.extend( + shlex.split( + "--iter {iteration_count} --create-in-main {app_ready_file_path} --wait-before-first-event {app_start_tracing_file_path} --wait {wait_time_between_events_us}".format( + iteration_count=self._iteration_count, + app_ready_file_path=app_ready_file_path, + app_start_tracing_file_path=self._app_start_tracing_file_path, + wait_time_between_events_us=wait_time_between_events_us, + ) + ) + ) + + self._process: subprocess.Popen = subprocess.Popen( + test_app_args, + env=test_app_env, + ) + + # Wait for the application to create the file indicating it has fully + # initialized. Make sure the app hasn't crashed in order to not wait + # forever. + while True: + if os.path.exists(app_ready_file_path): + break + + if self._process.poll() is not None: + # Application has unexepectedly returned. + raise RuntimeError( + "Test application has unexepectedly returned during its initialization with return code `{return_code}`".format( + return_code=self._process.returncode + ) + ) + + time.sleep(0.1) + + def trace(self) -> None: + if self._process.poll() is not None: + # Application has unexepectedly returned. + raise RuntimeError( + "Test application has unexepectedly before tracing with return code `{return_code}`".format( + return_code=self._process.returncode + ) + ) + open(self._app_start_tracing_file_path, mode="x") + + def wait_for_exit(self) -> None: + if self._process.wait() != 0: + raise RuntimeError( + "Test application has exit with return code `{return_code}`".format( + return_code=self._process.returncode + ) + ) + self._has_returned = True + + @property + def vpid(self) -> int: + return self._process.pid + + def __del__(self): + if not self._has_returned: + # This is potentially racy if the pid has been recycled. However, + # we can't use pidfd_open since it is only available in python >= 3.9. + self._process.kill() + self._process.wait() + + +class ProcessOutputConsumer(threading.Thread, logger._Logger): + def __init__( + self, process: subprocess.Popen, name: str, log: Callable[[str], None] + ): + threading.Thread.__init__(self) + self._prefix = name + logger._Logger.__init__(self, log) + self._process = process + + def run(self) -> None: + while self._process.poll() is None: + assert self._process.stdout + line = self._process.stdout.readline().decode("utf-8").replace("\n", "") + if len(line) != 0: + self._log("{prefix}: {line}".format(prefix=self._prefix, line=line)) + + +# Generate a temporary environment in which to execute a test. +class _Environment(logger._Logger): + def __init__( + self, with_sessiond: bool, log: Optional[Callable[[str], None]] = None + ): + super().__init__(log) + signal.signal(signal.SIGTERM, self._handle_termination_signal) + signal.signal(signal.SIGINT, self._handle_termination_signal) + + # Assumes the project's hierarchy to this file is: + # tests/utils/python/this_file + self._project_root: pathlib.Path = pathlib.Path(__file__).absolute().parents[3] + self._lttng_home: Optional[TemporaryDirectory] = TemporaryDirectory( + "lttng_test_env_home" + ) + + self._sessiond: Optional[subprocess.Popen[bytes]] = ( + self._launch_lttng_sessiond() if with_sessiond else None + ) + + @property + def lttng_home_location(self) -> pathlib.Path: + if self._lttng_home is None: + raise RuntimeError("Attempt to access LTTng home after clean-up") + return self._lttng_home.path + + @property + def lttng_client_path(self) -> pathlib.Path: + return self._project_root / "src" / "bin" / "lttng" / "lttng" + + def create_temporary_directory(self, prefix: Optional[str] = None) -> pathlib.Path: + # Simply return a path that is contained within LTTNG_HOME; it will + # be destroyed when the temporary home goes out of scope. + assert self._lttng_home + return pathlib.Path( + tempfile.mkdtemp( + prefix="tmp" if prefix is None else prefix, + dir=str(self._lttng_home.path), + ) + ) + + # Unpack a list of environment variables from a string + # such as "HELLO=is_it ME='/you/are/looking/for'" + @staticmethod + def _unpack_env_vars(env_vars_string: str) -> List[Tuple[str, str]]: + unpacked_vars = [] + for var in shlex.split(env_vars_string): + equal_position = var.find("=") + # Must have an equal sign and not end with an equal sign + if equal_position == -1 or equal_position == len(var) - 1: + raise ValueError( + "Invalid sessiond environment variable: `{}`".format(var) + ) + + var_name = var[0:equal_position] + var_value = var[equal_position + 1 :] + # Unquote any paths + var_value = var_value.replace("'", "") + var_value = var_value.replace('"', "") + unpacked_vars.append((var_name, var_value)) + + return unpacked_vars + + def _launch_lttng_sessiond(self) -> Optional[subprocess.Popen]: + is_64bits_host = sys.maxsize > 2**32 + + sessiond_path = ( + self._project_root / "src" / "bin" / "lttng-sessiond" / "lttng-sessiond" + ) + consumerd_path_option_name = "--consumerd{bitness}-path".format( + bitness="64" if is_64bits_host else "32" + ) + consumerd_path = ( + self._project_root / "src" / "bin" / "lttng-consumerd" / "lttng-consumerd" + ) + + no_sessiond_var = os.environ.get("TEST_NO_SESSIOND") + if no_sessiond_var and no_sessiond_var == "1": + # Run test without a session daemon; the user probably + # intends to run one under gdb for example. + return None + + # Setup the session daemon's environment + sessiond_env_vars = os.environ.get("LTTNG_SESSIOND_ENV_VARS") + sessiond_env = os.environ.copy() + if sessiond_env_vars: + self._log("Additional lttng-sessiond environment variables:") + additional_vars = self._unpack_env_vars(sessiond_env_vars) + for var_name, var_value in additional_vars: + self._log(" {name}={value}".format(name=var_name, value=var_value)) + sessiond_env[var_name] = var_value + + sessiond_env["LTTNG_SESSION_CONFIG_XSD_PATH"] = str( + self._project_root / "src" / "common" + ) + + assert self._lttng_home is not None + sessiond_env["LTTNG_HOME"] = str(self._lttng_home.path) + + wait_queue = _SignalWaitQueue() + signal.signal(signal.SIGUSR1, wait_queue.signal) + + self._log( + "Launching session daemon with LTTNG_HOME=`{home_dir}`".format( + home_dir=str(self._lttng_home.path) + ) + ) + process = subprocess.Popen( + [ + str(sessiond_path), + consumerd_path_option_name, + str(consumerd_path), + "--sig-parent", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=sessiond_env, + ) + + if self._logging_function: + self._sessiond_output_consumer: Optional[ + ProcessOutputConsumer + ] = ProcessOutputConsumer(process, "lttng-sessiond", self._logging_function) + self._sessiond_output_consumer.daemon = True + self._sessiond_output_consumer.start() + + # Wait for SIGUSR1, indicating the sessiond is ready to proceed + wait_queue.wait_for_signal() + signal.signal(signal.SIGUSR1, wait_queue.signal) + + return process + + def _handle_termination_signal( + self, signal_number: int, frame: Optional[FrameType] + ) -> None: + self._log( + "Killed by {signal_name} signal, cleaning-up".format( + signal_name=signal.strsignal(signal_number) + ) + ) + self._cleanup() + + def launch_wait_trace_test_application( + self, event_count: int + ) -> WaitTraceTestApplication: + """ + Launch an application that will wait before tracing `event_count` events. + """ + return WaitTraceTestApplication( + self._project_root + / "tests" + / "utils" + / "testapp" + / "gen-ust-nevents" + / "gen-ust-nevents", + event_count, + self, + ) + + # Clean-up managed processes + def _cleanup(self) -> None: + if self._sessiond and self._sessiond.poll() is None: + # The session daemon is alive; kill it. + self._log( + "Killing session daemon (pid = {sessiond_pid})".format( + sessiond_pid=self._sessiond.pid + ) + ) + + self._sessiond.terminate() + self._sessiond.wait() + if self._sessiond_output_consumer: + self._sessiond_output_consumer.join() + self._sessiond_output_consumer = None + + self._log("Session daemon killed") + self._sessiond = None + + self._lttng_home = None + + def __del__(self): + self._cleanup() + + +@contextlib.contextmanager +def test_environment(with_sessiond: bool, log: Optional[Callable[[str], None]] = None): + env = _Environment(with_sessiond, log) + try: + yield env + finally: + env._cleanup() diff --git a/tests/utils/lttngtest/logger.py b/tests/utils/lttngtest/logger.py new file mode 100644 index 000000000..9f90ec06a --- /dev/null +++ b/tests/utils/lttngtest/logger.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 Jérémie Galarneau +# +# SPDX-License-Identifier: GPL-2.0-only + +from typing import Callable, Optional + + +class _Logger: + def __init__(self, log: Optional[Callable[[str], None]]): + self._logging_function: Optional[Callable[[str], None]] = log + + def _log(self, msg: str) -> None: + if self._logging_function: + self._logging_function(msg) + + @property + def logger(self) -> Optional[Callable[[str], None]]: + return self._logging_function diff --git a/tests/utils/lttngtest/lttng.py b/tests/utils/lttngtest/lttng.py new file mode 100644 index 000000000..6829fa707 --- /dev/null +++ b/tests/utils/lttngtest/lttng.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 Jérémie Galarneau +# +# SPDX-License-Identifier: GPL-2.0-only + +from concurrent.futures import process +from . import lttngctl, logger, environment +import pathlib +import os +from typing import Callable, Optional, Type, Union +import shlex +import subprocess +import enum + +""" +Implementation of the lttngctl interface based on the `lttng` command line client. +""" + + +class Unsupported(lttngctl.ControlException): + def __init__(self, msg: str): + super().__init__(msg) + + +def _get_domain_option_name(domain: lttngctl.TracingDomain) -> str: + if domain == lttngctl.TracingDomain.User: + return "userspace" + elif domain == lttngctl.TracingDomain.Kernel: + return "kernel" + elif domain == lttngctl.TracingDomain.Log4j: + return "log4j" + elif domain == lttngctl.TracingDomain.JUL: + return "jul" + elif domain == lttngctl.TracingDomain.Python: + return "python" + else: + raise Unsupported("Domain `{domain_name}` is not supported by the LTTng client") + + +def _get_context_type_name(context: lttngctl.ContextType) -> str: + if isinstance(context, lttngctl.VgidContextType): + return "vgid" + elif isinstance(context, lttngctl.VuidContextType): + return "vuid" + elif isinstance(context, lttngctl.VpidContextType): + return "vpid" + elif isinstance(context, lttngctl.JavaApplicationContextType): + return "$app.{retriever}:{field}".format( + retriever=context.retriever_name, field=context.field_name + ) + else: + raise Unsupported( + "Context `{context_name}` is not supported by the LTTng client".format( + type(context).__name__ + ) + ) + + +class _Channel(lttngctl.Channel): + def __init__( + self, + client: "LTTngClient", + name: str, + domain: lttngctl.TracingDomain, + session: "_Session", + ): + self._client: LTTngClient = client + self._name: str = name + self._domain: lttngctl.TracingDomain = domain + self._session: _Session = session + + def add_context(self, context_type: lttngctl.ContextType) -> None: + domain_option_name = _get_domain_option_name(self.domain) + context_type_name = _get_context_type_name(context_type) + self._client._run_cmd( + "add-context --{domain_option_name} --type {context_type_name}".format( + domain_option_name=domain_option_name, + context_type_name=context_type_name, + ) + ) + + def add_recording_rule(self, rule: Type[lttngctl.EventRule]) -> None: + client_args = ( + "enable-event --session {session_name} --channel {channel_name}".format( + session_name=self._session.name, channel_name=self.name + ) + ) + if isinstance(rule, lttngctl.TracepointEventRule): + domain_option_name = ( + "userspace" + if isinstance(rule, lttngctl.UserTracepointEventRule) + else "kernel" + ) + client_args = client_args + " --{domain_option_name}".format( + domain_option_name=domain_option_name + ) + + if rule.name_pattern: + client_args = client_args + " " + rule.name_pattern + else: + client_args = client_args + " --all" + + if rule.filter_expression: + client_args = client_args + " " + rule.filter_expression + + if rule.log_level_rule: + if isinstance(rule.log_level_rule, lttngctl.LogLevelRuleAsSevereAs): + client_args = client_args + " --loglevel {log_level}".format( + log_level=rule.log_level_rule.level + ) + elif isinstance(rule.log_level_rule, lttngctl.LogLevelRuleExactly): + client_args = client_args + " --loglevel-only {log_level}".format( + log_level=rule.log_level_rule.level + ) + else: + raise Unsupported( + "Unsupported log level rule type `{log_level_rule_type}`".format( + log_level_rule_type=type(rule.log_level_rule).__name__ + ) + ) + + if rule.name_pattern_exclusions: + client_args = client_args + " --exclude " + for idx, pattern in enumerate(rule.name_pattern_exclusions): + if idx != 0: + client_args = client_args + "," + client_args = client_args + pattern + else: + raise Unsupported( + "event rule type `{event_rule_type}` is unsupported by LTTng client".format( + event_rule_type=type(rule).__name__ + ) + ) + + self._client._run_cmd(client_args) + + @property + def name(self) -> str: + return self._name + + @property + def domain(self) -> lttngctl.TracingDomain: + return self._domain + + +class _ProcessAttribute(enum.Enum): + PID = (enum.auto(),) + VPID = (enum.auto(),) + UID = (enum.auto(),) + VUID = (enum.auto(),) + GID = (enum.auto(),) + VGID = (enum.auto(),) + + +def _get_process_attribute_option_name(attribute: _ProcessAttribute) -> str: + return { + _ProcessAttribute.PID: "pid", + _ProcessAttribute.VPID: "vpid", + _ProcessAttribute.UID: "uid", + _ProcessAttribute.VUID: "vuid", + _ProcessAttribute.GID: "gid", + _ProcessAttribute.VGID: "vgid", + }[attribute] + + +class _ProcessAttributeTracker(lttngctl.ProcessAttributeTracker): + def __init__( + self, + client: "LTTngClient", + attribute: _ProcessAttribute, + domain: lttngctl.TracingDomain, + session: "_Session", + ): + self._client: LTTngClient = client + self._tracked_attribute: _ProcessAttribute = attribute + self._domain: lttngctl.TracingDomain = domain + self._session: "_Session" = session + if attribute == _ProcessAttribute.PID or attribute == _ProcessAttribute.VPID: + self._allowed_value_types: list[type] = [int, str] + else: + self._allowed_value_types: list[type] = [int] + + def _call_client(self, cmd_name: str, value: Union[int, str]) -> None: + if type(value) not in self._allowed_value_types: + raise TypeError( + "Value of type `{value_type}` is not allowed for process attribute {attribute_name}".format( + value_type=type(value).__name__, + attribute_name=self._tracked_attribute.name, + ) + ) + + process_attribute_option_name = _get_process_attribute_option_name( + self._tracked_attribute + ) + domain_name = _get_domain_option_name(self._domain) + self._client._run_cmd( + "{cmd_name} --session {session_name} --{domain_name} --{tracked_attribute_name} {value}".format( + cmd_name=cmd_name, + session_name=self._session.name, + domain_name=domain_name, + tracked_attribute_name=process_attribute_option_name, + value=value, + ) + ) + + def track(self, value: Union[int, str]) -> None: + self._call_client("track", value) + + def untrack(self, value: Union[int, str]) -> None: + self._call_client("untrack", value) + + +class _Session(lttngctl.Session): + def __init__( + self, + client: "LTTngClient", + name: str, + output: Optional[Type[lttngctl.SessionOutputLocation]], + ): + self._client: LTTngClient = client + self._name: str = name + self._output: Optional[Type[lttngctl.SessionOutputLocation]] = output + + @property + def name(self) -> str: + return self._name + + def add_channel( + self, domain: lttngctl.TracingDomain, channel_name: Optional[str] = None + ) -> lttngctl.Channel: + channel_name = lttngctl.Channel._generate_name() + domain_option_name = _get_domain_option_name(domain) + self._client._run_cmd( + "enable-channel --{domain_name} {channel_name}".format( + domain_name=domain_option_name, channel_name=channel_name + ) + ) + return _Channel(self._client, channel_name, domain, self) + + def add_context(self, context_type: lttngctl.ContextType) -> None: + pass + + @property + def output(self) -> Optional[Type[lttngctl.SessionOutputLocation]]: + return self._output + + def start(self) -> None: + self._client._run_cmd("start {session_name}".format(session_name=self.name)) + + def stop(self) -> None: + self._client._run_cmd("stop {session_name}".format(session_name=self.name)) + + def destroy(self) -> None: + self._client._run_cmd("destroy {session_name}".format(session_name=self.name)) + + @property + def kernel_pid_process_attribute_tracker( + self, + ) -> Type[lttngctl.ProcessIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.PID, lttngctl.TracingDomain.Kernel, self) # type: ignore + + @property + def kernel_vpid_process_attribute_tracker( + self, + ) -> Type[lttngctl.VirtualProcessIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.VPID, lttngctl.TracingDomain.Kernel, self) # type: ignore + + @property + def user_vpid_process_attribute_tracker( + self, + ) -> Type[lttngctl.VirtualProcessIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.VPID, lttngctl.TracingDomain.User, self) # type: ignore + + @property + def kernel_gid_process_attribute_tracker( + self, + ) -> Type[lttngctl.GroupIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.GID, lttngctl.TracingDomain.Kernel, self) # type: ignore + + @property + def kernel_vgid_process_attribute_tracker( + self, + ) -> Type[lttngctl.VirtualGroupIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.VGID, lttngctl.TracingDomain.Kernel, self) # type: ignore + + @property + def user_vgid_process_attribute_tracker( + self, + ) -> Type[lttngctl.VirtualGroupIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.VGID, lttngctl.TracingDomain.User, self) # type: ignore + + @property + def kernel_uid_process_attribute_tracker( + self, + ) -> Type[lttngctl.UserIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.UID, lttngctl.TracingDomain.Kernel, self) # type: ignore + + @property + def kernel_vuid_process_attribute_tracker( + self, + ) -> Type[lttngctl.VirtualUserIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.VUID, lttngctl.TracingDomain.Kernel, self) # type: ignore + + @property + def user_vuid_process_attribute_tracker( + self, + ) -> Type[lttngctl.VirtualUserIDProcessAttributeTracker]: + return _ProcessAttributeTracker(self._client, _ProcessAttribute.VUID, lttngctl.TracingDomain.User, self) # type: ignore + + +class LTTngClientError(lttngctl.ControlException): + def __init__(self, command_args: str, error_output: str): + self._command_args: str = command_args + self._output: str = error_output + + +class LTTngClient(logger._Logger, lttngctl.Controller): + """ + Implementation of a LTTngCtl Controller that uses the `lttng` client as a back-end. + """ + + def __init__( + self, + test_environment: environment._Environment, + log: Optional[Callable[[str], None]], + ): + logger._Logger.__init__(self, log) + self._environment: environment._Environment = test_environment + + def _run_cmd(self, command_args: str) -> None: + """ + Invoke the `lttng` client with a set of arguments. The command is + executed in the context of the client's test environment. + """ + args: list[str] = [str(self._environment.lttng_client_path)] + args.extend(shlex.split(command_args)) + + self._log("lttng {command_args}".format(command_args=command_args)) + + client_env: dict[str, str] = os.environ.copy() + client_env["LTTNG_HOME"] = str(self._environment.lttng_home_location) + + process = subprocess.Popen( + args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=client_env + ) + + out = process.communicate()[0] + + if process.returncode != 0: + decoded_output = out.decode("utf-8") + for error_line in decoded_output.splitlines(): + self._log(error_line) + raise LTTngClientError(command_args, decoded_output) + + def create_session( + self, + name: Optional[str] = None, + output: Optional[lttngctl.SessionOutputLocation] = None, + ) -> lttngctl.Session: + name = name if name else lttngctl.Session._generate_name() + + if isinstance(output, lttngctl.LocalSessionOutputLocation): + output_option = "--output {output_path}".format(output_path=output.path) + elif output is None: + output_option = "--no-output" + else: + raise TypeError("LTTngClient only supports local or no output") + + self._run_cmd( + "create {session_name} {output_option}".format( + session_name=name, output_option=output_option + ) + ) + return _Session(self, name, output) diff --git a/tests/utils/lttngtest/lttngctl.py b/tests/utils/lttngtest/lttngctl.py new file mode 100644 index 000000000..632c13083 --- /dev/null +++ b/tests/utils/lttngtest/lttngctl.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 Jérémie Galarneau +# +# SPDX-License-Identifier: GPL-2.0-only + +import abc +import random +import string +import pathlib +import enum +from typing import Optional, Type, Union, List + +""" +Defines an abstract interface to control LTTng tracing. + +The various control concepts are defined by this module. You can use them with a +Controller to interact with a session daemon. + +This interface is not comprehensive; it currently provides a subset of the +control functionality that is used by tests. +""" + + +def _generate_random_string(length: int) -> str: + return "".join( + random.choice(string.ascii_lowercase + string.digits) for _ in range(length) + ) + + +class ContextType(abc.ABC): + """Base class representing a tracing context field.""" + + pass + + +class VpidContextType(ContextType): + """Application's virtual process id.""" + + pass + + +class VuidContextType(ContextType): + """Application's virtual user id.""" + + pass + + +class VgidContextType(ContextType): + """Application's virtual group id.""" + + pass + + +class JavaApplicationContextType(ContextType): + """A java application-specific context field is a piece of state which the application provides.""" + + def __init__(self, retriever_name: str, field_name: str): + self._retriever_name: str = retriever_name + self._field_name: str = field_name + + @property + def retriever_name(self) -> str: + return self._retriever_name + + @property + def field_name(self) -> str: + return self._field_name + + +class TracingDomain(enum.Enum): + """Tracing domain.""" + + User = enum.auto(), "User space tracing domain" + Kernel = enum.auto(), "Linux kernel tracing domain." + Log4j = enum.auto(), "Log4j tracing back-end." + JUL = enum.auto(), "Java Util Logging tracing back-end." + Python = enum.auto(), "Python logging module tracing back-end." + + +class EventRule(abc.ABC): + """Event rule base class, see LTTNG-EVENT-RULE(7).""" + + pass + + +class LogLevelRule: + pass + + +class LogLevelRuleAsSevereAs(LogLevelRule): + def __init__(self, level: int): + self._level = level + + @property + def level(self) -> int: + return self._level + + +class LogLevelRuleExactly(LogLevelRule): + def __init__(self, level: int): + self._level = level + + @property + def level(self) -> int: + return self._level + + +class TracepointEventRule(EventRule): + def __init__( + self, + name_pattern: Optional[str] = None, + filter_expression: Optional[str] = None, + log_level_rule: Optional[LogLevelRule] = None, + name_pattern_exclusions: Optional[List[str]] = None, + ): + self._name_pattern: Optional[str] = name_pattern + self._filter_expression: Optional[str] = filter_expression + self._log_level_rule: Optional[LogLevelRule] = log_level_rule + self._name_pattern_exclusions: Optional[List[str]] = name_pattern_exclusions + + @property + def name_pattern(self) -> Optional[str]: + return self._name_pattern + + @property + def filter_expression(self) -> Optional[str]: + return self._filter_expression + + @property + def log_level_rule(self) -> Optional[LogLevelRule]: + return self._log_level_rule + + @property + def name_pattern_exclusions(self) -> Optional[List[str]]: + return self._name_pattern_exclusions + + +class UserTracepointEventRule(TracepointEventRule): + def __init__( + self, + name_pattern: Optional[str] = None, + filter_expression: Optional[str] = None, + log_level_rule: Optional[LogLevelRule] = None, + name_pattern_exclusions: Optional[List[str]] = None, + ): + TracepointEventRule.__init__(**locals()) + + +class KernelTracepointEventRule(TracepointEventRule): + def __init__( + self, + name_pattern: Optional[str] = None, + filter_expression: Optional[str] = None, + log_level_rule: Optional[LogLevelRule] = None, + name_pattern_exclusions: Optional[List[str]] = None, + ): + TracepointEventRule.__init__(**locals()) + + +class Channel(abc.ABC): + """ + A channel is an object which is responsible for a set of ring buffers. It is + associated to a domain and + """ + + @staticmethod + def _generate_name() -> str: + return "channel_{random_id}".format(random_id=_generate_random_string(8)) + + @abc.abstractmethod + def add_context(self, context_type: ContextType) -> None: + pass + + @property + @abc.abstractmethod + def domain(self) -> TracingDomain: + pass + + @property + @abc.abstractmethod + def name(self) -> str: + pass + + @abc.abstractmethod + def add_recording_rule(self, rule: Type[EventRule]) -> None: + pass + + +class SessionOutputLocation(abc.ABC): + pass + + +class LocalSessionOutputLocation(SessionOutputLocation): + def __init__(self, trace_path: pathlib.Path): + self._path = trace_path + + @property + def path(self) -> pathlib.Path: + return self._path + + +class ProcessAttributeTracker(abc.ABC): + """ + Process attribute tracker used to filter before the evaluation of event + rules. + + Note that this interface is currently limited as it doesn't allow changing + the tracking policy. For instance, it is not possible to set the tracking + policy back to "all" once it has transitioned to "include set". + """ + + class TrackingPolicy(enum.Enum): + INCLUDE_ALL = ( + enum.auto(), + """ + Track all possible process attribute value of a given type (i.e. no filtering). + This is the default state of a process attribute tracker. + """, + ) + EXCLUDE_ALL = ( + enum.auto(), + "Exclude all possible process attribute values of a given type.", + ) + INCLUDE_SET = enum.auto(), "Track a set of specific process attribute values." + + def __init__(self, policy: TrackingPolicy): + self._policy = policy + + @property + def tracking_policy(self) -> TrackingPolicy: + return self._policy + + +class ProcessIDProcessAttributeTracker(ProcessAttributeTracker): + @abc.abstractmethod + def track(self, pid: int) -> None: + pass + + @abc.abstractmethod + def untrack(self, pid: int) -> None: + pass + + +class VirtualProcessIDProcessAttributeTracker(ProcessAttributeTracker): + @abc.abstractmethod + def track(self, vpid: int) -> None: + pass + + @abc.abstractmethod + def untrack(self, vpid: int) -> None: + pass + + +class UserIDProcessAttributeTracker(ProcessAttributeTracker): + @abc.abstractmethod + def track(self, uid: Union[int, str]) -> None: + pass + + @abc.abstractmethod + def untrack(self, uid: Union[int, str]) -> None: + pass + + +class VirtualUserIDProcessAttributeTracker(ProcessAttributeTracker): + @abc.abstractmethod + def track(self, vuid: Union[int, str]) -> None: + pass + + @abc.abstractmethod + def untrack(self, vuid: Union[int, str]) -> None: + pass + + +class GroupIDProcessAttributeTracker(ProcessAttributeTracker): + @abc.abstractmethod + def track(self, gid: Union[int, str]) -> None: + pass + + @abc.abstractmethod + def untrack(self, gid: Union[int, str]) -> None: + pass + + +class VirtualGroupIDProcessAttributeTracker(ProcessAttributeTracker): + @abc.abstractmethod + def track(self, vgid: Union[int, str]) -> None: + pass + + @abc.abstractmethod + def untrack(self, vgid: Union[int, str]) -> None: + pass + + +class Session(abc.ABC): + @staticmethod + def _generate_name() -> str: + return "session_{random_id}".format(random_id=_generate_random_string(8)) + + @property + @abc.abstractmethod + def name(self) -> str: + pass + + @property + @abc.abstractmethod + def output(self) -> Optional[Type[SessionOutputLocation]]: + pass + + @abc.abstractmethod + def add_channel( + self, domain: TracingDomain, channel_name: Optional[str] = None + ) -> Channel: + """Add a channel with default attributes to the session.""" + pass + + @abc.abstractmethod + def start(self) -> None: + pass + + @abc.abstractmethod + def stop(self) -> None: + pass + + @abc.abstractmethod + def destroy(self) -> None: + pass + + @abc.abstractproperty + def kernel_pid_process_attribute_tracker( + self, + ) -> Type[ProcessIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def kernel_vpid_process_attribute_tracker( + self, + ) -> Type[VirtualProcessIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def user_vpid_process_attribute_tracker( + self, + ) -> Type[VirtualProcessIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def kernel_gid_process_attribute_tracker( + self, + ) -> Type[GroupIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def kernel_vgid_process_attribute_tracker( + self, + ) -> Type[VirtualGroupIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def user_vgid_process_attribute_tracker( + self, + ) -> Type[VirtualGroupIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def kernel_uid_process_attribute_tracker( + self, + ) -> Type[UserIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def kernel_vuid_process_attribute_tracker( + self, + ) -> Type[VirtualUserIDProcessAttributeTracker]: + raise NotImplementedError + + @abc.abstractproperty + def user_vuid_process_attribute_tracker( + self, + ) -> Type[VirtualUserIDProcessAttributeTracker]: + raise NotImplementedError + + +class ControlException(RuntimeError): + """Base type for exceptions thrown by a controller.""" + + def __init__(self, msg: str): + super().__init__(msg) + + +class Controller(abc.ABC): + """ + Interface of a top-level control interface. A control interface can be, for + example, the LTTng client or a wrapper around liblttng-ctl. It is used to + create and manage top-level objects of a session daemon instance. + """ + + @abc.abstractmethod + def create_session( + self, name: Optional[str] = None, output: Optional[SessionOutputLocation] = None + ) -> Session: + """ + Create a session with an output. Don't specify an output + to create a session without an output. + """ + pass diff --git a/tests/utils/lttngtest/tap_generator.py b/tests/utils/lttngtest/tap_generator.py new file mode 100644 index 000000000..c28e87d36 --- /dev/null +++ b/tests/utils/lttngtest/tap_generator.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 Jérémie Galarneau +# +# SPDX-License-Identifier: GPL-2.0-only +# + +import contextlib +import sys +from typing import Optional + + +class InvalidTestPlan(RuntimeError): + def __init__(self, msg: str): + super().__init__(msg) + + +class BailOut(RuntimeError): + def __init__(self, msg: str): + super().__init__(msg) + + +class TestCase: + def __init__(self, tap_generator: "TapGenerator", description: str): + self._tap_generator = tap_generator + self._result: Optional[bool] = None + self._description = description + + @property + def result(self) -> Optional[bool]: + return self._result + + @property + def description(self) -> str: + return self._description + + def _set_result(self, result: bool) -> None: + if self._result is not None: + raise RuntimeError("Can't set test case result twice") + + self._result = result + self._tap_generator.test(result, self._description) + + def success(self) -> None: + self._set_result(True) + + def fail(self) -> None: + self._set_result(False) + + +# Produces a test execution report in the TAP format. +class TapGenerator: + def __init__(self, total_test_count: int): + if total_test_count <= 0: + raise ValueError("Test count must be greater than zero") + + self._total_test_count: int = total_test_count + self._last_test_case_id: int = 0 + self._printed_plan: bool = False + self._has_failure: bool = False + + def __del__(self): + if self.remaining_test_cases > 0: + self.bail_out( + "Missing {remaining_test_cases} test cases".format( + remaining_test_cases=self.remaining_test_cases + ) + ) + + @property + def remaining_test_cases(self) -> int: + return self._total_test_count - self._last_test_case_id + + def _print(self, msg: str) -> None: + if not self._printed_plan: + print( + "1..{total_test_count}".format(total_test_count=self._total_test_count), + flush=True, + ) + self._printed_plan = True + + print(msg, flush=True) + + def skip_all(self, reason) -> None: + if self._last_test_case_id != 0: + raise RuntimeError("Can't skip all tests after running test cases") + + if reason: + self._print("1..0 # Skip all: {reason}".format(reason=reason)) + + self._last_test_case_id = self._total_test_count + + def skip(self, reason, skip_count: int = 1) -> None: + for i in range(skip_count): + self._last_test_case_id = self._last_test_case_id + 1 + self._print( + "ok {test_number} # Skip: {reason}".format( + reason=reason, test_number=(i + self._last_test_case_id) + ) + ) + + def bail_out(self, reason: str) -> None: + self._print("Bail out! {reason}".format(reason=reason)) + self._last_test_case_id = self._total_test_count + raise BailOut(reason) + + def test(self, result: bool, description: str) -> None: + if self._last_test_case_id == self._total_test_count: + raise InvalidTestPlan("Executing too many tests") + + if result is False: + self._has_failure = True + + result_string = "ok" if result else "not ok" + self._last_test_case_id = self._last_test_case_id + 1 + self._print( + "{result_string} {case_id} - {description}".format( + result_string=result_string, + case_id=self._last_test_case_id, + description=description, + ) + ) + + def ok(self, description: str) -> None: + self.test(True, description) + + def fail(self, description: str) -> None: + self.test(False, description) + + @property + def is_successful(self) -> bool: + return ( + self._last_test_case_id == self._total_test_count and not self._has_failure + ) + + @contextlib.contextmanager + def case(self, description: str): + test_case = TestCase(self, description) + try: + yield test_case + except Exception as e: + self.diagnostic( + "Exception `{exception_type}` thrown during test case `{description}`, marking as failure.".format( + description=test_case.description, exception_type=type(e).__name__ + ) + ) + + if str(e) != "": + self.diagnostic(str(e)) + + test_case.fail() + finally: + if test_case.result is None: + test_case.success() + + def diagnostic(self, msg) -> None: + print("# {msg}".format(msg=msg), file=sys.stderr, flush=True) -- 2.34.1