Skip to content

Commit 2e73eb8

Browse files
anandoleecopybara-github
authored andcommitted
Add reference leak check to Python well_known_types_tests. Fix two refleak bugs.
-Fix ref leak on Struct field creation -Fix ref leak on in operator for ListValue PiperOrigin-RevId: 758351419
1 parent 6193a8c commit 2e73eb8

File tree

3 files changed

+33
-24
lines changed

3 files changed

+33
-24
lines changed

python/google/protobuf/internal/well_known_types_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from google.protobuf import json_format
1717
from google.protobuf import text_format
1818
from google.protobuf.internal import more_messages_pb2
19+
from google.protobuf.internal import testing_refleaks
1920
from google.protobuf.internal import well_known_types
2021
from google.protobuf.internal import well_known_types_test_pb2
2122

@@ -54,6 +55,7 @@ def CheckDurationConversion(self, message, text):
5455
self.assertEqual(message, parsed_message)
5556

5657

58+
@testing_refleaks.TestCase
5759
class TimeUtilTest(TimeUtilTestBase):
5860

5961
def testTimestampSerializeAndParse(self):
@@ -704,6 +706,7 @@ def testDurationSub(self, old_time, time_delta, expected_sec, expected_nano):
704706
self.assertEqual(expected_nano, msg.optional_timestamp.nanos)
705707

706708

709+
@testing_refleaks.TestCase
707710
class StructTest(unittest.TestCase):
708711

709712
def testEmptyDict(self):
@@ -987,6 +990,7 @@ def testMergeFrom(self):
987990
self.assertEqual(5, struct['key5'][0][1])
988991

989992

993+
@testing_refleaks.TestCase
990994
class AnyTest(unittest.TestCase):
991995

992996
def testAnyMessage(self):

python/google/protobuf/pyext/message.cc

+16-16
Original file line numberDiff line numberDiff line change
@@ -1124,14 +1124,16 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
11241124
Descriptor::WELLKNOWNTYPE_STRUCT) {
11251125
ScopedPyObjectPtr ok(PyObject_CallMethod(
11261126
reinterpret_cast<PyObject*>(cmessage), "update", "O", value));
1127-
if (ok.get() == nullptr && PyDict_Size(value) == 1 &&
1128-
PyDict_Contains(value, PyUnicode_FromString("fields"))) {
1129-
// Fallback to init as normal message field.
1130-
PyErr_Clear();
1131-
PyObject* tmp = Clear(cmessage);
1132-
Py_DECREF(tmp);
1133-
if (InitAttributes(cmessage, nullptr, value) < 0) {
1134-
return -1;
1127+
if (ok.get() == nullptr && PyDict_Size(value) == 1) {
1128+
ScopedPyObjectPtr fields_str(PyUnicode_FromString("fields"));
1129+
if (PyDict_Contains(value, fields_str.get())) {
1130+
// Fallback to init as normal message field.
1131+
PyErr_Clear();
1132+
PyObject* tmp = Clear(cmessage);
1133+
Py_DECREF(tmp);
1134+
if (InitAttributes(cmessage, nullptr, value) < 0) {
1135+
return -1;
1136+
}
11351137
}
11361138
}
11371139
} else {
@@ -2391,21 +2393,19 @@ PyObject* Contains(CMessage* self, PyObject* arg) {
23912393
const Reflection* reflection = message->GetReflection();
23922394
const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
23932395
const FieldDescriptor* key_field = map_field->message_type()->map_key();
2394-
PyObject* py_string = CheckString(arg, key_field);
2395-
if (!py_string) {
2396+
ScopedPyObjectPtr py_string(CheckString(arg, key_field));
2397+
if (py_string.get() == nullptr) {
23962398
PyErr_SetString(PyExc_TypeError,
23972399
"The key passed to Struct message must be a str.");
23982400
return nullptr;
23992401
}
24002402
char* value;
24012403
Py_ssize_t value_len;
2402-
if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
2403-
Py_DECREF(py_string);
2404+
if (PyBytes_AsStringAndSize(py_string.get(), &value, &value_len) < 0) {
24042405
Py_RETURN_FALSE;
24052406
}
24062407
std::string key_str;
24072408
key_str.assign(value, value_len);
2408-
Py_DECREF(py_string);
24092409

24102410
MapKey map_key;
24112411
map_key.SetStringValue(key_str);
@@ -2414,9 +2414,9 @@ PyObject* Contains(CMessage* self, PyObject* arg) {
24142414
}
24152415
case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
24162416
// For WKT ListValue, check if the key is in the items.
2417-
PyObject* items = PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
2418-
"items", nullptr);
2419-
return PyBool_FromLong(PySequence_Contains(items, arg));
2417+
ScopedPyObjectPtr items(PyObject_CallMethod(
2418+
reinterpret_cast<PyObject*>(self), "items", nullptr));
2419+
return PyBool_FromLong(PySequence_Contains(items.get(), arg));
24202420
}
24212421
default:
24222422
// For other messages, check with HasField.

python/message.c

+13-8
Original file line numberDiff line numberDiff line change
@@ -459,13 +459,16 @@ static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name,
459459
const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field);
460460
if (upb_MessageDef_WellKnownType(msgdef) == kUpb_WellKnown_Struct) {
461461
ok = PyObject_CallMethod(submsg, "_internal_assign", "O", value);
462-
if (!ok && PyDict_Size(value) == 1 &&
463-
PyDict_Contains(value, PyUnicode_FromString("fields"))) {
464-
// Fall back to init as normal message field.
465-
PyErr_Clear();
466-
PyObject* tmp = PyUpb_Message_Clear((PyUpb_Message*)submsg);
467-
Py_DECREF(tmp);
468-
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
462+
if (!ok && PyDict_Size(value) == 1) {
463+
PyObject* fields_str = PyUnicode_FromString("fields");
464+
if (PyDict_Contains(value, fields_str)) {
465+
// Fall back to init as normal message field.
466+
PyErr_Clear();
467+
PyObject* tmp = PyUpb_Message_Clear((PyUpb_Message*)submsg);
468+
Py_DECREF(tmp);
469+
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
470+
}
471+
Py_DECREF(fields_str);
469472
}
470473
} else {
471474
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
@@ -1127,7 +1130,9 @@ static PyObject* PyUpb_Message_Contains(PyObject* _self, PyObject* arg) {
11271130
PyUpb_Message* self = (void*)_self;
11281131
if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
11291132
PyObject* items = PyObject_CallMethod(_self, "items", NULL);
1130-
return PyBool_FromLong(PySequence_Contains(items, arg));
1133+
int ret = PySequence_Contains(items, arg);
1134+
Py_DECREF(items);
1135+
return PyBool_FromLong(ret);
11311136
}
11321137
default:
11331138
// For other messages, check with HasField.

0 commit comments

Comments
 (0)