diff --git a/datadog_checks_base/datadog_checks/base/utils/tracing.py b/datadog_checks_base/datadog_checks/base/utils/tracing.py
index 5899b0443cb51..71bf6b097fe11 100644
--- a/datadog_checks_base/datadog_checks/base/utils/tracing.py
+++ b/datadog_checks_base/datadog_checks/base/utils/tracing.py
@@ -45,12 +45,15 @@ def _get_integration_name(function_name, self, *args, **kwargs):
return integration_name if integration_name else "UNKNOWN_INTEGRATION"
-def tracing_method(f, tracer):
+def tracing_method(f, tracer, is_entry_point):
if inspect.signature(f).parameters.get('self'):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
integration_name = _get_integration_name(f.__name__, self, *args, **kwargs)
+ if is_entry_point:
+ configure_tracer(tracer, self)
+
with tracer.trace(f.__name__, service=INTEGRATION_TRACING_SERVICE_NAME, resource=integration_name) as span:
span.set_tag('_dd.origin', INTEGRATION_TRACING_SERVICE_NAME)
return f(self, *args, **kwargs)
@@ -69,7 +72,7 @@ def wrapper(*args, **kwargs):
def traced_warning(f, tracer):
"""
- Traces the AgentCheck.warning method
+ Traces the AgentCheck.warning method.
The span is always an error span, including the current stack trace.
The error message is set to the warning message.
"""
@@ -102,6 +105,57 @@ def wrapper(self, warning_message, *args, **kwargs):
return f
+def configure_tracer(tracer, self_check):
+ """
+ Generate a tracer context for the given function with configurable sampling rate.
+ If not set or invalid, defaults to 0 (no sampling).
+ The tracer context is only set at entry point functions so we can attach a trace root to the span.
+ """
+ apm_tracing_enabled = False
+ context_provider = None
+ try:
+ integration_tracing, integration_tracing_exhaustive = tracing_enabled()
+ if integration_tracing or integration_tracing_exhaustive:
+ apm_tracing_enabled = True
+
+ # If the check has a dd_trace_id and dd_parent_id, we can use it to create a trace root
+ dd_parent_id = None
+ dd_trace_id = None
+ if hasattr(self_check, "instance") and self_check.instance:
+ dd_trace_id = self_check.instance.get("dd_trace_id", None)
+ dd_parent_id = self_check.instance.get("dd_parent_span_id", None)
+ elif hasattr(self_check, "instances") and self_check.instances and len(self_check.instances) > 0:
+ dd_trace_id = self_check.instances[0].get("dd_trace_id", None)
+ dd_parent_id = self_check.instances[0].get("dd_parent_span_id", None)
+
+ if dd_trace_id and dd_parent_id:
+ from ddtrace.context import Context
+
+ apm_tracing_enabled = True
+ context_provider = Context(
+ trace_id=dd_trace_id,
+ span_id=dd_parent_id,
+ )
+ except (ValueError, TypeError, AttributeError, ImportError):
+ pass
+
+ try:
+ # Update the tracer configuration to make sure we trace only if we really need to
+ tracer.configure(
+ appsec_enabled=False,
+ enabled=apm_tracing_enabled,
+ )
+
+ # If the current trace context is not set or is set to an empty trace_id, activate the context provider
+ current_context = tracer.current_trace_context()
+ if (
+ current_context is None or (current_context is not None and len(current_context.trace_id) == 0)
+ ) and context_provider:
+ tracer.context_provider.activate(context_provider)
+ except Exception:
+ pass
+
+
def tracing_enabled():
"""
:return: (integration_tracing, integration_tracing_exhaustive)
@@ -118,42 +172,46 @@ def tracing_enabled():
def traced_class(cls):
- integration_tracing, integration_tracing_exhaustive = tracing_enabled()
- if integration_tracing:
- try:
- integration_tracing_exhaustive = is_affirmative(datadog_agent.get_config('integration_tracing_exhaustive'))
+ """
+ Decorator that adds tracing to all methods of a class.
+ Only traces specific methods by default, unless exhaustive tracing is enabled.
+ """
+ _, integration_tracing_exhaustive = tracing_enabled()
- from ddtrace import patch_all, tracer
+ try:
+ from ddtrace import patch_all, tracer
+
+ patch_all()
- patch_all()
+ def decorate(cls):
+ for attr in cls.__dict__:
+ attribute = getattr(cls, attr)
- def decorate(cls):
- for attr in cls.__dict__:
- attribute = getattr(cls, attr)
+ if not callable(attribute) or inspect.isclass(attribute):
+ continue
- if not callable(attribute) or inspect.isclass(attribute):
- continue
+ # Ignoring staticmethod and classmethod because they don't need cls in args
+ # also ignore nested classes
+ if isinstance(cls.__dict__[attr], staticmethod) or isinstance(cls.__dict__[attr], classmethod):
+ continue
- # Ignoring staticmethod and classmethod because they don't need cls in args
- # also ignore nested classes
- if isinstance(cls.__dict__[attr], staticmethod) or isinstance(cls.__dict__[attr], classmethod):
- continue
+ # Get rid of SnmpCheck._thread_factory and related
+ if getattr(attribute, '__module__', 'threading') in EXCLUDED_MODULES:
+ continue
- # Get rid of SnmpCheck._thread_factory and related
- if getattr(attribute, '__module__', 'threading') in EXCLUDED_MODULES:
- continue
+ if not integration_tracing_exhaustive and attr not in AGENT_CHECK_DEFAULT_TRACED_METHODS:
+ continue
- if not integration_tracing_exhaustive and attr not in AGENT_CHECK_DEFAULT_TRACED_METHODS:
- continue
+ is_entry_point = attr == 'run' or attr == 'check'
- if attr == 'warning':
- setattr(cls, attr, traced_warning(attribute, tracer))
- else:
- setattr(cls, attr, tracing_method(attribute, tracer))
- return cls
+ if attr == 'warning':
+ setattr(cls, attr, traced_warning(attribute, tracer))
+ else:
+ setattr(cls, attr, tracing_method(attribute, tracer, is_entry_point))
+ return cls
- return decorate(cls)
- except Exception:
- pass
+ return decorate(cls)
+ except Exception:
+ pass
return cls
diff --git a/datadog_checks_base/tests/base/utils/test_tracing.py b/datadog_checks_base/tests/base/utils/test_tracing.py
index 2ee9b88490972..c4925506e3a71 100644
--- a/datadog_checks_base/tests/base/utils/test_tracing.py
+++ b/datadog_checks_base/tests/base/utils/test_tracing.py
@@ -79,37 +79,71 @@ def traced_mock_classes():
'integration_tracing_exhaustive',
[pytest.param(False, id="exhaustive_false"), pytest.param(True, id="exhaustive_true")],
)
-def test_traced_class(integration_tracing, integration_tracing_exhaustive, datadog_agent):
+@pytest.mark.parametrize(
+ 'dd_trace_id', [pytest.param(None, id="no_trace_id"), pytest.param("123456789", id="with_trace_id")]
+)
+@pytest.mark.parametrize(
+ 'dd_parent_id', [pytest.param(None, id="no_parent_id"), pytest.param("987654321", id="with_parent_id")]
+)
+def test_traced_class(integration_tracing, integration_tracing_exhaustive, dd_trace_id, dd_parent_id, datadog_agent):
def _get_config(key):
return {
'integration_tracing': str(integration_tracing).lower(),
'integration_tracing_exhaustive': str(integration_tracing_exhaustive).lower(),
}.get(key, None)
+ instance = {}
+ if dd_trace_id is not None:
+ instance['dd_trace_id'] = dd_trace_id
+ if dd_parent_id is not None:
+ instance['dd_parent_span_id'] = dd_parent_id
+
with mock.patch.object(datadog_agent, 'get_config', _get_config), mock.patch('ddtrace.tracer') as tracer:
+ # Track the last activated context
+ def mock_activate(context):
+ def mock_current_trace_context():
+ return context
+
+ tracer.current_trace_context.side_effect = mock_current_trace_context
+
+ tracer.context_provider.activate.side_effect = mock_activate
+
with traced_mock_classes():
- check = DummyCheck('dummy', {}, [{}])
+ check = DummyCheck('dummy', {}, [instance])
check.run()
- if integration_tracing:
- called_services = {c.kwargs['service'] for c in tracer.trace.mock_calls if 'service' in c.kwargs}
- called_methods = {c.args[0] for c in tracer.trace.mock_calls if c.args}
-
- assert called_services == {INTEGRATION_TRACING_SERVICE_NAME}
- for m in AGENT_CHECK_DEFAULT_TRACED_METHODS:
+ called_services = {c.kwargs['service'] for c in tracer.trace.mock_calls if 'service' in c.kwargs}
+ called_methods = {c.args[0] for c in tracer.trace.mock_calls if c.args}
+
+ assert called_services == {INTEGRATION_TRACING_SERVICE_NAME}
+ for m in AGENT_CHECK_DEFAULT_TRACED_METHODS:
+ assert m in called_methods
+
+ warning_span_tag_calls = tracer.trace().__enter__().set_tag.call_args_list
+ assert mock.call('_dd.origin', INTEGRATION_TRACING_SERVICE_NAME) in warning_span_tag_calls
+ assert mock.call(ERROR_MSG, 'whoops oh no') in warning_span_tag_calls
+ assert mock.call(ERROR_TYPE, 'AgentCheck.warning') in warning_span_tag_calls
+
+ # If dd_trace_id and dd_parent_id are set, verify context provider is activated
+ if dd_trace_id is not None and dd_parent_id is not None:
+ # Assert called once
+ tracer.context_provider.activate.assert_called_once()
+ context = tracer.context_provider.activate.call_args[0][0]
+ assert context.trace_id == dd_trace_id
+ assert context.span_id == dd_parent_id
+
+ # Check that the tracer is configured with the correct enabled value
+ tracing = (
+ integration_tracing
+ or integration_tracing_exhaustive
+ or (dd_trace_id is not None and dd_parent_id is not None)
+ )
+ assert tracer.configure.call_args[1]['enabled'] is tracing
+
+ exhaustive_only_methods = {'__init__', 'dummy_method'}
+ if integration_tracing_exhaustive:
+ for m in exhaustive_only_methods:
assert m in called_methods
-
- warning_span_tag_calls = tracer.trace().__enter__().set_tag.call_args_list
- assert mock.call('_dd.origin', INTEGRATION_TRACING_SERVICE_NAME) in warning_span_tag_calls
- assert mock.call(ERROR_MSG, 'whoops oh no') in warning_span_tag_calls
- assert mock.call(ERROR_TYPE, 'AgentCheck.warning') in warning_span_tag_calls
-
- exhaustive_only_methods = {'__init__', 'dummy_method'}
- if integration_tracing_exhaustive:
- for m in exhaustive_only_methods:
- assert m in called_methods
- else:
- for m in exhaustive_only_methods:
- assert m not in called_methods
else:
- tracer.trace.assert_not_called()
+ for m in exhaustive_only_methods:
+ assert m not in called_methods