Skip to content

Commit 6919976

Browse files
authored
Merge branch 'main' into mps
2 parents 8c470f1 + f066e63 commit 6919976

File tree

4 files changed

+22
-9
lines changed

4 files changed

+22
-9
lines changed

cirq-google/cirq_google/engine/engine_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ async def list_programs_async(
211211
created_before: datetime.datetime | datetime.date | None = None,
212212
created_after: datetime.datetime | datetime.date | None = None,
213213
has_labels: dict[str, str] | None = None,
214-
):
214+
) -> list[quantum.QuantumProgram]:
215215
"""Returns a list of previously executed quantum programs.
216216
217217
Args:
@@ -242,7 +242,7 @@ async def list_programs_async(
242242
request = quantum.ListQuantumProgramsRequest(
243243
parent=_project_name(project_id), filter=" AND ".join(filters)
244244
)
245-
return await self._send_request_async(self.grpc_client.list_quantum_programs, request)
245+
return await self._send_list_request_async(self.grpc_client.list_quantum_programs, request)
246246

247247
list_programs = duet.sync(list_programs_async)
248248

@@ -485,7 +485,7 @@ async def list_jobs_async(
485485
execution_states: set[quantum.ExecutionStatus.State] | None = None,
486486
executed_processor_ids: list[str] | None = None,
487487
scheduled_processor_ids: list[str] | None = None,
488-
):
488+
) -> list[quantum.QuantumJob]:
489489
"""Returns the list of jobs for a given program.
490490
491491
Args:
@@ -545,7 +545,7 @@ async def list_jobs_async(
545545
program_id = "-"
546546
parent = _program_name_from_ids(project_id, program_id)
547547
request = quantum.ListQuantumJobsRequest(parent=parent, filter=" AND ".join(filters))
548-
return await self._send_request_async(self.grpc_client.list_quantum_jobs, request)
548+
return await self._send_list_request_async(self.grpc_client.list_quantum_jobs, request)
549549

550550
list_jobs = duet.sync(list_jobs_async)
551551

cirq-google/cirq_google/engine/engine_client_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_list_program(client_constructor, default_engine_client):
157157
quantum.QuantumProgram(name='projects/proj/programs/prog1'),
158158
quantum.QuantumProgram(name='projects/proj/programs/prog2'),
159159
]
160-
grpc_client.list_quantum_programs.return_value = results
160+
grpc_client.list_quantum_programs.return_value = _AsyncIterable(results)
161161

162162
assert default_engine_client.list_programs(project_id='proj') == results
163163
grpc_client.list_quantum_programs.assert_called_with(
@@ -1252,7 +1252,7 @@ def test_list_jobs(client_constructor, default_engine_client):
12521252
quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1'),
12531253
quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job2'),
12541254
]
1255-
grpc_client.list_quantum_jobs.return_value = results
1255+
grpc_client.list_quantum_jobs.return_value = _AsyncIterable(results)
12561256

12571257
assert default_engine_client.list_jobs(project_id='proj', program_id='prog1') == results
12581258
grpc_client.list_quantum_jobs.assert_called_with(
@@ -1265,6 +1265,15 @@ def test_list_jobs(client_constructor, default_engine_client):
12651265
)
12661266

12671267

1268+
class _AsyncIterable:
1269+
def __init__(self, items):
1270+
self.items = items
1271+
1272+
async def __aiter__(self):
1273+
for item in self.items:
1274+
yield item
1275+
1276+
12681277
@pytest.mark.parametrize(
12691278
'expected_filter, '
12701279
'created_after, '

cirq-google/cirq_google/engine/engine_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,9 @@ async def get_circuit_async(self) -> cirq.Circuit:
361361
Returns:
362362
The program's cirq Circuit.
363363
"""
364-
if self._program is None or self._program.code is None:
364+
# The code field is an any_pb2.Any and is always set. But if the program has not
365+
# been fetched this field may be empty, which we can see by checking the type_url.
366+
if self._program is None or not self._program.code or not self._program.code.type_url:
365367
self._program = await self.context.client.get_program_async(
366368
self.project_id, self.program_id, True
367369
)

cirq-google/cirq_google/engine/engine_program_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,14 @@ def test_get_circuit_v1(get_program_async):
298298

299299

300300
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async')
301-
def test_get_circuit_v2(get_program_async):
301+
@pytest.mark.parametrize("include_empty_program", [False, True])
302+
def test_get_circuit_v2(get_program_async, include_empty_program: bool) -> None:
302303
circuit = cirq.Circuit(
303304
cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result')
304305
)
305306

306-
program = cg.EngineProgram('a', 'b', EngineContext())
307+
program_msg = quantum.QuantumProgram() if include_empty_program else None
308+
program = cg.EngineProgram('a', 'b', EngineContext(), _program=program_msg)
307309
get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2)
308310
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
309311
program.get_circuit(), circuit

0 commit comments

Comments
 (0)