Skip to content

Commit 2fe82b2

Browse files
committed
[skip ci] Simplify test case.
1 parent b160f6c commit 2fe82b2

File tree

2 files changed

+39
-58
lines changed

2 files changed

+39
-58
lines changed
+26-36
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,37 @@
1-
// Copyright (c) 2021 The Pybind Development Team.
1+
// Copyright (c) 2025 The Pybind Development Team.
22
// All rights reserved. Use of this source code is governed by a
33
// BSD-style license that can be found in the LICENSE file.
44

55
#include "pybind11_tests.h"
66

7-
#include <utility>
7+
#include <memory>
88

99
namespace pybind11_tests {
1010
namespace class_sh_trampoline_weak_ptr {
1111

12-
// // For testing whether a python subclass of a C++ object can be accessed from a C++ weak_ptr
13-
struct WpBase {
14-
// returns true if the base virtual function is called
15-
virtual bool is_base_used() { return true; }
16-
17-
// returns true if there's an associated python instance
18-
bool has_python_instance() {
19-
auto *tinfo = py::detail::get_type_info(typeid(WpBase));
20-
return (bool) py::detail::get_object_handle(this, tinfo);
21-
}
22-
23-
WpBase() = default;
24-
WpBase(const WpBase &) = delete;
25-
virtual ~WpBase() = default;
12+
struct VirtBase {
13+
virtual ~VirtBase() = default;
14+
virtual int get_code() { return 100; }
2615
};
2716

28-
struct PyWpBase : WpBase {
29-
using WpBase::WpBase;
30-
bool is_base_used() override { PYBIND11_OVERRIDE(bool, WpBase, is_base_used); }
17+
struct PyVirtBase : VirtBase, py::trampoline_self_life_support {
18+
using VirtBase::VirtBase;
19+
int get_code() override { PYBIND11_OVERRIDE(int, VirtBase, get_code); }
3120
};
3221

33-
struct WpBaseTester {
34-
std::shared_ptr<WpBase> get_object() const { return m_obj.lock(); }
35-
void set_object(std::shared_ptr<WpBase> obj) { m_obj = obj; }
36-
bool is_expired() { return m_obj.expired(); }
37-
bool is_base_used() { return m_obj.lock()->is_base_used(); }
38-
std::weak_ptr<WpBase> m_obj;
22+
struct WpOwner {
23+
void set_wp(std::shared_ptr<VirtBase> sp) { wp = sp; }
24+
25+
int get_code() {
26+
auto sp = wp.lock();
27+
if (!sp) {
28+
return -999;
29+
}
30+
return sp->get_code();
31+
}
32+
33+
private:
34+
std::weak_ptr<VirtBase> wp;
3935
};
4036

4137
} // namespace class_sh_trampoline_weak_ptr
@@ -44,18 +40,12 @@ struct WpBaseTester {
4440
using namespace pybind11_tests::class_sh_trampoline_weak_ptr;
4541

4642
TEST_SUBMODULE(class_sh_trampoline_weak_ptr, m) {
47-
// For testing whether a python subclass of a C++ object can be accessed from a C++ weak_ptr
48-
49-
py::classh<WpBase, PyWpBase>(m, "WpBase")
43+
py::classh<VirtBase, PyVirtBase>(m, "VirtBase")
5044
.def(py::init<>())
51-
.def(py::init([](int) { return std::make_shared<PyWpBase>(); }))
52-
.def("is_base_used", &WpBase::is_base_used)
53-
.def("has_python_instance", &WpBase::has_python_instance);
45+
.def("get_code", &VirtBase::get_code);
5446

55-
py::classh<WpBaseTester>(m, "WpBaseTester")
47+
py::classh<WpOwner>(m, "WpOwner")
5648
.def(py::init<>())
57-
.def("get_object", &WpBaseTester::get_object)
58-
.def("set_object", &WpBaseTester::set_object)
59-
.def("is_expired", &WpBaseTester::is_expired)
60-
.def("is_base_used", &WpBaseTester::is_base_used);
49+
.def("set_wp", &WpOwner::set_wp)
50+
.def("get_code", &WpOwner::get_code);
6151
}

tests/test_class_sh_trampoline_weak_ptr.py

+13-22
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,22 @@
66
import pybind11_tests.class_sh_trampoline_weak_ptr as m
77

88

9-
@pytest.mark.skipif("env.GRAALPY", reason="Cannot reliably trigger GC")
10-
def test_weak_ptr_base():
11-
tester = m.WpBaseTester()
12-
13-
obj = m.WpBase()
14-
15-
tester.set_object(obj)
16-
17-
assert tester.is_expired() is False
18-
assert tester.is_base_used() is True
19-
assert tester.get_object().is_base_used() is True
9+
class PyDrvd(m.VirtBase):
10+
def get_code(self):
11+
return 200
2012

2113

2214
@pytest.mark.skipif("env.GRAALPY", reason="Cannot reliably trigger GC")
23-
def test_weak_ptr_child():
24-
class PyChild(m.WpBase):
25-
def is_base_used(self):
26-
return False
27-
28-
tester = m.WpBaseTester()
15+
@pytest.mark.parametrize(("vtype", "expected_code"), [(m.VirtBase, 100), (PyDrvd, 200)])
16+
def test_weak_ptr_base(vtype, expected_code):
17+
wpo = m.WpOwner()
18+
assert wpo.get_code() == -999
2919

30-
obj = PyChild()
20+
obj = vtype()
21+
assert obj.get_code() == expected_code
3122

32-
tester.set_object(obj)
23+
wpo.set_wp(obj)
24+
assert wpo.get_code() == expected_code
3325

34-
assert tester.is_expired() is False
35-
assert tester.is_base_used() is False
36-
assert tester.get_object().is_base_used() is False
26+
del obj
27+
assert wpo.get_code() == -999

0 commit comments

Comments
 (0)