Skip to content

Commit e20c4c4

Browse files
anandoleecopybara-github
authored andcommitted
Add reference leak check to Python well_known_types_test. Fix two refleak bugs.
-Fix ref leak when assign Struct field with creation -Fix ref leak on "in" operator for ListValue PiperOrigin-RevId: 758351419
1 parent fc8a522 commit e20c4c4

File tree

3 files changed

+33
-24
lines changed

3 files changed

+33
-24
lines changed

python/google/protobuf/internal/well_known_types_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from google.protobuf import json_format
1717
from google.protobuf import text_format
1818
from google.protobuf.internal import more_messages_pb2
19+
from google.protobuf.internal import testing_refleaks
1920
from google.protobuf.internal import well_known_types
2021
from google.protobuf.internal import well_known_types_test_pb2
2122

@@ -54,6 +55,7 @@ def CheckDurationConversion(self, message, text):
5455
self.assertEqual(message, parsed_message)
5556

5657

58+
@testing_refleaks.TestCase
5759
class TimeUtilTest(TimeUtilTestBase):
5860

5961
def testTimestampSerializeAndParse(self):
@@ -704,6 +706,7 @@ def testDurationSub(self, old_time, time_delta, expected_sec, expected_nano):
704706
self.assertEqual(expected_nano, msg.optional_timestamp.nanos)
705707

706708

709+
@testing_refleaks.TestCase
707710
class StructTest(unittest.TestCase):
708711

709712
def testEmptyDict(self):
@@ -987,6 +990,7 @@ def testMergeFrom(self):
987990
self.assertEqual(5, struct['key5'][0][1])
988991

989992

993+
@testing_refleaks.TestCase
990994
class AnyTest(unittest.TestCase):
991995

992996
def testAnyMessage(self):

python/google/protobuf/pyext/message.cc

+16-16
Original file line numberDiff line numberDiff line change
@@ -1124,14 +1124,16 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
11241124
Descriptor::WELLKNOWNTYPE_STRUCT) {
11251125
ScopedPyObjectPtr ok(PyObject_CallMethod(
11261126
reinterpret_cast<PyObject*>(cmessage), "update", "O", value));
1127-
if (ok.get() == nullptr && PyDict_Size(value) == 1 &&
1128-
PyDict_Contains(value, PyUnicode_FromString("fields"))) {
1129-
// Fallback to init as normal message field.
1130-
PyErr_Clear();
1131-
PyObject* tmp = Clear(cmessage);
1132-
Py_DECREF(tmp);
1133-
if (InitAttributes(cmessage, nullptr, value) < 0) {
1134-
return -1;
1127+
if (ok.get() == nullptr && PyDict_Size(value) == 1) {
1128+
ScopedPyObjectPtr fields_str(PyUnicode_FromString("fields"));
1129+
if (PyDict_Contains(value, fields_str.get())) {
1130+
// Fallback to init as normal message field.
1131+
PyErr_Clear();
1132+
PyObject* tmp = Clear(cmessage);
1133+
Py_DECREF(tmp);
1134+
if (InitAttributes(cmessage, nullptr, value) < 0) {
1135+
return -1;
1136+
}
11351137
}
11361138
}
11371139
} else {
@@ -2391,21 +2393,19 @@ PyObject* Contains(CMessage* self, PyObject* arg) {
23912393
const Reflection* reflection = message->GetReflection();
23922394
const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
23932395
const FieldDescriptor* key_field = map_field->message_type()->map_key();
2394-
PyObject* py_string = CheckString(arg, key_field);
2395-
if (!py_string) {
2396+
ScopedPyObjectPtr py_string(CheckString(arg, key_field));
2397+
if (py_string.get() == nullptr) {
23962398
PyErr_SetString(PyExc_TypeError,
23972399
"The key passed to Struct message must be a str.");
23982400
return nullptr;
23992401
}
24002402
char* value;
24012403
Py_ssize_t value_len;
2402-
if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
2403-
Py_DECREF(py_string);
2404+
if (PyBytes_AsStringAndSize(py_string.get(), &value, &value_len) < 0) {
24042405
Py_RETURN_FALSE;
24052406
}
24062407
std::string key_str;
24072408
key_str.assign(value, value_len);
2408-
Py_DECREF(py_string);
24092409

24102410
MapKey map_key;
24112411
map_key.SetStringValue(key_str);
@@ -2414,9 +2414,9 @@ PyObject* Contains(CMessage* self, PyObject* arg) {
24142414
}
24152415
case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
24162416
// For WKT ListValue, check if the key is in the items.
2417-
PyObject* items = PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
2418-
"items", nullptr);
2419-
return PyBool_FromLong(PySequence_Contains(items, arg));
2417+
ScopedPyObjectPtr items(PyObject_CallMethod(
2418+
reinterpret_cast<PyObject*>(self), "items", nullptr));
2419+
return PyBool_FromLong(PySequence_Contains(items.get(), arg));
24202420
}
24212421
default:
24222422
// For other messages, check with HasField.

python/message.c

+13-8
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,16 @@ static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name,
453453
const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field);
454454
if (upb_MessageDef_WellKnownType(msgdef) == kUpb_WellKnown_Struct) {
455455
ok = PyObject_CallMethod(submsg, "_internal_assign", "O", value);
456-
if (!ok && PyDict_Size(value) == 1 &&
457-
PyDict_Contains(value, PyUnicode_FromString("fields"))) {
458-
// Fall back to init as normal message field.
459-
PyErr_Clear();
460-
PyObject* tmp = PyUpb_Message_Clear((PyUpb_Message*)submsg);
461-
Py_DECREF(tmp);
462-
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
456+
if (!ok && PyDict_Size(value) == 1) {
457+
PyObject* fields_str = PyUnicode_FromString("fields");
458+
if (PyDict_Contains(value, fields_str)) {
459+
// Fall back to init as normal message field.
460+
PyErr_Clear();
461+
PyObject* tmp = PyUpb_Message_Clear((PyUpb_Message*)submsg);
462+
Py_DECREF(tmp);
463+
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
464+
}
465+
Py_DECREF(fields_str);
463466
}
464467
} else {
465468
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
@@ -1110,7 +1113,9 @@ static PyObject* PyUpb_Message_Contains(PyObject* _self, PyObject* arg) {
11101113
PyUpb_Message* self = (void*)_self;
11111114
if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
11121115
PyObject* items = PyObject_CallMethod(_self, "items", NULL);
1113-
return PyBool_FromLong(PySequence_Contains(items, arg));
1116+
int ret = PySequence_Contains(items, arg);
1117+
Py_DECREF(items);
1118+
return PyBool_FromLong(ret);
11141119
}
11151120
default:
11161121
// For other messages, check with HasField.

0 commit comments

Comments
 (0)