Commit 7a4a691

amritghimire, dreadatour, ampcode-com, and Copilot authored
job run: add --no-follow and fix behavior when websocket closes early (#1577)
* job run: add --no-follow and fix behavior when websocket closes early

  Add --no-follow so CI can wait for job completion without streaming logs to the console. When --no-follow is set we still consume the log stream and only skip printing log lines and log blobs.

  When the log stream websocket closes before a final status, we now fetch job status via REST and only show dataset versions if the job actually finished. Otherwise we print "Lost connection" and exit 1.

  Also fix the status check to use JobStatus.finished() and break on unknown status to avoid an infinite loop.

* Update docs/commands/job/run.md

  Co-authored-by: Vladimir Rudnykh <dreadatour@gmail.com>

* Skip ping messages

* Add no follow params to studio client

* Pass verbose flag through to job log streaming

  Switch the create_job call in process_jobs_args to keyword arguments for clarity and add the missing verbose parameter. show_logs_from_client now accepts a verbose flag and prints diagnostic messages when the job finishes, retries are exhausted, or an unknown status is encountered. This makes it easier to debug log streaming issues without attaching a debugger.

  Amp-Thread-ID: https://ampcode.com/threads/T-019c3357-1a5b-76bb-8ea5-9f3baf69cc99
  Co-authored-by: Amp <amp@ampcode.com>

* Fix tests

* Increase coverage

* studio: fix job run tests and switch verbose to logging

  - Fix test_studio_run_non_zero_exit_code and the websocket disconnect tests: patch StudioClient.tail_job_logs where it is used (datachain.studio) so the mock is applied. Add no_follow to the mock signature to match the real API. Mock GET jobs with a regex so requests with query params match.
  - In studio.py, drop the verbose flag from create_job and show_logs_from_client; use logger.debug() for debug messages instead.
  - Adjust test_studio_run_invalid_job_status to assert on caplog when checking debug messages. Add tests for verbose (caplog), log blobs, _get_job_status edge cases, rest_status None, dataset versions error, and TASK status.
  - Add a return type to _get_job_status and log on exception.

* Add clarify

* Fix test

* Update tests/test_cli_studio.py

  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Vladimir Rudnykh <dreadatour@gmail.com>
Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
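The infinite-loop fix hinges on the check `JobStatus[latest_status] in JobStatus.finished()` breaking out on unknown statuses. A minimal sketch of that behavior follows; the enum members and the `finished()` helper here are stand-in assumptions, not DataChain's actual definitions.

```python
from enum import Enum

class JobStatus(Enum):
    # Hypothetical stand-in for DataChain's JobStatus; member names are assumed.
    CREATED = "created"
    RUNNING = "running"
    COMPLETE = "complete"
    FAILED = "failed"
    CANCELED = "canceled"

    @classmethod
    def finished(cls):
        # Terminal states: once a job is here, its status will not change again.
        return {cls.COMPLETE, cls.FAILED, cls.CANCELED}

def is_finished(status_name):
    # Mirrors the fix: unknown or missing statuses count as "not finished"
    # instead of raising and spinning in the retry loop forever.
    if not status_name:
        return False
    try:
        return JobStatus[status_name] in JobStatus.finished()
    except KeyError:
        return False

print(is_finished("COMPLETE"))  # True
print(is_finished("RUNNING"))   # False
print(is_finished("BOGUS"))     # False
```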
1 parent f3a7bf0 commit 7a4a691

File tree

5 files changed: +588 −45 lines


docs/commands/job/run.md

Lines changed: 8 additions & 1 deletion
````diff
@@ -14,7 +14,7 @@ usage: datachain job run [-h] [-v] [-q] [--team TEAM] [--env-file ENV_FILE]
                          [--req-file REQ_FILE] [--req REQ [REQ ...]]
                          [--priority PRIORITY]
                          [--start-time START_TIME] [--cron CRON]
-                         [--no-wait] [--ignore-checkpoints]
+                         [--no-wait] [--no-follow] [--ignore-checkpoints]
                          file
 ```
@@ -43,6 +43,7 @@ This command runs a job in Studio using the specified query file. You can config
 * `--start-time START_TIME` - Time to schedule the task in YYYY-MM-DDTHH:mm format or natural language.
 * `--cron CRON` - Cron expression for the cron task.
 * `--no-wait` - Do not wait for the job to finish.
+* `--no-follow` - Do not print the job logs to the console
 * `--ignore-checkpoints` - Ignore existing checkpoints and run from scratch.
 * `-h`, `--help` - Show the help message and exit.
 * `-v`, `--verbose` - Be verbose.
@@ -155,6 +156,12 @@ datachain job run --start-time "tomorrow 3pm" --cron "0 0 * * *" query.py
 datachain job run query.py --no-wait
 ```
 
+14. Start the job and wait for completion but don't print logs
+```bash
+# Useful for CI where you just want to wait for the completion of the jobs.
+datachain job run query.py --no-follow
+```
+
 ## Notes
 
 * **Checkpoints**: Running the same script multiple times via `datachain job run` automatically links jobs together, enabling checkpoint reuse. If a previous run of the same script (by absolute path) exists, DataChain will resume from where it left off.
````

src/datachain/cli/parser/job.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -122,6 +122,11 @@ def add_jobs_parser(subparsers, parent_parser) -> None:
         action="store_true",
         help="Do not wait for the job to finish",
     )
+    studio_run_parser.add_argument(
+        "--no-follow",
+        action="store_true",
+        help="Do not print the job logs to the console",
+    )
     studio_run_parser.add_argument(
         "--ignore-checkpoints",
         action="store_true",
```

src/datachain/remote/studio.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -297,7 +297,9 @@ def _unpacker_hook(code, data):
 
             return msgpack.ExtType(code, data)
 
-    async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
+    async def tail_job_logs(
+        self, job_id: str, no_follow: bool = False
+    ) -> AsyncIterator[dict]:
         """
         Follow job logs via websocket connection.
@@ -312,6 +314,8 @@ async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
             parsed_url._replace(scheme="wss" if parsed_url.scheme == "https" else "ws")
         )
         ws_url = f"{ws_url}/logs/follow/?job_id={job_id}&team_name={self.team}"
+        if no_follow:
+            ws_url += "&no_follow=true"
 
         async with websockets.connect(
             ws_url,
@@ -321,7 +325,8 @@ async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
             try:
                 message = await websocket.recv()
                 data = json.loads(message)
-
+                if data.get("type") == "ping":
+                    continue
                 # Yield the parsed message data
                 yield data
```
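The ping-skipping change above can be sketched with a fake frame source in place of the websocket; `fake_frames` and its payloads are hypothetical stand-ins for what `websocket.recv()` would deliver from the Studio log-follow endpoint.

```python
import asyncio
import json

async def fake_frames():
    # Hypothetical frames standing in for websocket.recv() messages.
    for frame in (
        '{"type": "ping"}',
        '{"logs": [{"message": "hello\\n"}]}',
        '{"type": "ping"}',
        '{"job": {"status": "COMPLETE"}}',
    ):
        yield frame

async def tail(frames):
    # Parse each frame and drop keepalive pings, as the patched
    # tail_job_logs now does before yielding to the caller.
    received = []
    async for raw in frames:
        data = json.loads(raw)
        if data.get("type") == "ping":
            continue
        received.append(data)
    return received

messages = asyncio.run(tail(fake_frames()))
print(len(messages))  # 2: both pings were skipped
```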
src/datachain/studio.py

Lines changed: 77 additions & 31 deletions
```diff
@@ -1,4 +1,5 @@
 import asyncio
+import logging
 import os
 import sys
 import warnings
@@ -20,6 +21,8 @@
 from datachain.remote.studio import StudioClient
 from datachain.utils import STUDIO_URL, flatten
 
+logger = logging.getLogger("datachain")
+
 if TYPE_CHECKING:
     from argparse import Namespace
 
@@ -43,23 +46,24 @@ def process_jobs_args(args: "Namespace"):
 
     if args.cmd == "run":
         return create_job(
-            args.file,
-            args.team,
-            args.env_file,
-            args.env,
-            args.workers,
-            args.files,
-            args.python_version,
-            args.repository,
-            args.req,
-            args.req_file,
-            args.priority,
-            args.cluster,
-            args.start_time,
-            args.cron,
-            args.no_wait,
-            args.credentials_name,
-            args.ignore_checkpoints,
+            query_file=args.file,
+            team_name=args.team,
+            env_file=args.env_file,
+            env=args.env,
+            workers=args.workers,
+            files=args.files,
+            python_version=args.python_version,
+            repository=args.repository,
+            req=args.req,
+            req_file=args.req_file,
+            priority=args.priority,
+            cluster=args.cluster,
+            start_time=args.start_time,
+            cron=args.cron,
+            no_wait=args.no_wait,
+            credentials_name=args.credentials_name,
+            ignore_checkpoints=args.ignore_checkpoints,
+            no_follow=args.no_follow,
         )
 
     if args.cmd == "cancel":
@@ -366,21 +370,33 @@ async def _show_log_blobs(log_blobs: list[str], client):
         print("\n>>>> Warning: Failed to fetch logs from studio")
 
 
-def show_logs_from_client(client, job_id):
+def _get_job_status(client, job_id: str) -> str | None:
+    try:
+        response = client.get_jobs(job_id=job_id)
+        if response.ok and response.data and len(response.data) > 0:
+            return response.data[0].get("status")
+    except (requests.RequestException, OSError, KeyError):
+        logger.debug("Failed to get job status: %s", job_id)
+    return None
+
+
+def show_logs_from_client(  # noqa: C901
+    client, job_id: str, no_follow: bool = False
+):
     async def _run():
         retry_count = 0
         latest_status = None
         processed_statuses = set()
         log_blobs_processed = False
         while True:
-            async for message in client.tail_job_logs(job_id):
-                if "log_blobs" in message:
+            async for message in client.tail_job_logs(job_id, no_follow=no_follow):
+                if "log_blobs" in message and not no_follow:
                     log_blobs = message.get("log_blobs", [])
                     if log_blobs and not log_blobs_processed:
                         log_blobs_processed = True
                         await _show_log_blobs(log_blobs, client)
 
-                elif "logs" in message:
+                elif "logs" in message and not no_follow:
                     for log in message["logs"]:
                         print(log["message"], end="")
                 elif "job" in message:
@@ -390,20 +406,41 @@ async def _run():
                     processed_statuses.add(latest_status)
                     print(f"\n>>>> Job is now in {latest_status} status.")
 
+            # After websocket closes, check actual job status via REST
+            rest_status = _get_job_status(client, job_id)
+            if rest_status and rest_status != latest_status:
+                print(f"\n>>>> Job is now in {rest_status} status.")
+            if rest_status:
+                latest_status = rest_status
+
             try:
-                if retry_count > RETRY_MAX_TIMES or (
-                    latest_status and JobStatus[latest_status].finished()
-                ):
+                if latest_status and JobStatus[latest_status] in JobStatus.finished():
+                    logger.debug("Job is in finished status: %s", latest_status)
+                    break
+                if retry_count > RETRY_MAX_TIMES:
+                    logger.debug("Max retry count reached: %s", retry_count)
                     break
                 await asyncio.sleep(RETRY_SLEEP_SEC)
                 retry_count += 1
             except KeyError:
-                pass
+                break
 
         return latest_status
 
     final_status = asyncio.run(_run())
 
+    try:
+        job_finished = final_status and JobStatus[final_status] in JobStatus.finished()
+    except KeyError:
+        logger.debug("Job status is not a valid status: %s", final_status)
+        job_finished = False
+
+    if not job_finished:
+        logger.debug("Job is not finished: %s.", final_status or "unknown")
+        print(f"\n>>>> Lost connection. Job status: {final_status or 'unknown'}.")
+        return 1
+
+    # Show dataset versions only for finished jobs
     response = client.dataset_job_versions(job_id)
     if not response.ok:
         raise DataChainError(response.message)
@@ -417,11 +454,13 @@ async def _run():
     else:
         print("\n\nNo dataset versions created during the job.")
 
-    exit_code_by_status = {
-        "FAILED": 1,
-        "CANCELED": 2,
-    }
-    return exit_code_by_status.get(final_status.upper(), 0) if final_status else 0
+    if final_status.upper() == "COMPLETE":
+        return 0
+    if final_status.upper() == "FAILED":
+        return 1
+    if final_status.upper() == "CANCELED":
+        return 2
+    return 0
@@ -442,6 +481,7 @@ def create_job(  # noqa: PLR0913
     no_wait: bool | None = False,
     credentials_name: str | None = None,
     ignore_checkpoints: bool = False,
+    no_follow: bool = False,
 ):
     catalog = get_catalog()
 
@@ -532,7 +572,13 @@ def create_job(  # noqa: PLR0913
     print("Open the job in Studio at", job_data.get("url"))
     print("=" * 40)
 
-    return 0 if no_wait else show_logs_from_client(client, job_id)
+    return (
+        0
+        if no_wait
+        else show_logs_from_client(
+            client=client, job_id=str(job_id), no_follow=no_follow
+        )
+    )
 
 
 def upload_files(client: StudioClient, files: list[str]) -> list[str]:
```
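Taken together, the post-disconnect behavior can be sketched as a single decision: prefer the REST status over the last websocket status, exit 1 if no terminal status was reached, otherwise map the terminal status to an exit code. `resolve_exit_code` and the status names below are illustrative stand-ins, not DataChain's actual API.

```python
FINISHED = {"COMPLETE", "FAILED", "CANCELED"}  # terminal statuses (assumed names)

def resolve_exit_code(ws_status, rest_status):
    # After the websocket closes, trust the REST status when available.
    final = rest_status or ws_status
    if final not in FINISHED:
        # Connection lost before a terminal status: fail the run,
        # matching the new "Lost connection" path that returns 1.
        return 1
    return {"COMPLETE": 0, "FAILED": 1, "CANCELED": 2}[final]

print(resolve_exit_code("RUNNING", "COMPLETE"))  # 0
print(resolve_exit_code("RUNNING", None))        # 1
print(resolve_exit_code(None, "CANCELED"))       # 2
```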
