Skip to content

Commit d371a89

Browse files
committed
Support compare URLs in extract_upstream_repository
Fetch commit lists from compare URLs for cherry-pick workflow.
1 parent b296d5b commit d371a89

File tree

2 files changed

+377
-40
lines changed

2 files changed

+377
-40
lines changed

agents/tests/unit/test_tools.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
SearchTextToolInput,
4949
)
5050
from tools.filesystem import GetCWDTool, GetCWDToolInput, RemoveTool, RemoveToolInput
51+
from tools.upstream_tools import (
52+
ExtractUpstreamRepositoryTool,
53+
ExtractUpstreamRepositoryInput,
54+
)
5155

5256

5357
@pytest.mark.parametrize(
@@ -706,3 +710,276 @@ async def test_find_rej_files(git_repo):
706710
(git_repo / "foo-bar.rej").write_text("rej content 2")
707711
result = await find_rej_files(git_repo)
708712
assert sorted(result) == sorted(["file.txt.rej", "foo-bar.rej"])
713+
714+
715+
# Tests for ExtractUpstreamRepositoryTool
716+
@pytest.mark.asyncio
717+
@pytest.mark.parametrize(
718+
"url, expected_owner, expected_repo, expected_commit",
719+
[
720+
# Regular commit URLs
721+
(
722+
"https://github.com/owner/repo/commit/abc123def456",
723+
"owner",
724+
"repo",
725+
"abc123def456",
726+
),
727+
(
728+
"https://github.com/owner/repo/commit/abc123def456.patch",
729+
"owner",
730+
"repo",
731+
"abc123def456",
732+
),
733+
(
734+
"https://gitlab.com/owner/repo/-/commit/abc123def456",
735+
"owner",
736+
"repo",
737+
"abc123def456",
738+
),
739+
],
740+
)
741+
async def test_extract_upstream_repository_commit_url(url, expected_owner, expected_repo, expected_commit):
742+
"""Test that regular commit URLs are properly parsed without API calls."""
743+
tool = ExtractUpstreamRepositoryTool()
744+
745+
output = await tool.run(
746+
input=ExtractUpstreamRepositoryInput(upstream_fix_url=url)
747+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
748+
749+
result = output.to_json_safe()
750+
751+
# Verify the results
752+
expected_netloc = "github.com" if "github" in url else "gitlab.com"
753+
assert result.repo_url == f"https://{expected_netloc}/{expected_owner}/{expected_repo}.git"
754+
assert result.commit_hash == expected_commit
755+
assert result.is_pr is False
756+
assert result.is_compare is False
757+
758+
759+
@pytest.mark.asyncio
760+
async def test_extract_upstream_repository_pr_url():
761+
"""Test that PR URLs fetch commit hash from GitHub API."""
762+
import aiohttp
763+
from unittest.mock import AsyncMock, MagicMock, patch
764+
765+
tool = ExtractUpstreamRepositoryTool()
766+
url = "https://github.com/owner/repo/pull/123"
767+
768+
# Mock GitHub API response
769+
mock_data = {"head": {"sha": "pr_commit_hash_abc123"}}
770+
771+
mock_response = AsyncMock()
772+
mock_response.status = 200
773+
mock_response.raise_for_status = MagicMock()
774+
mock_response.json = AsyncMock(return_value=mock_data)
775+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
776+
mock_response.__aexit__ = AsyncMock(return_value=None)
777+
778+
mock_session = MagicMock()
779+
mock_session.get = MagicMock(return_value=mock_response)
780+
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
781+
mock_session.__aexit__ = AsyncMock(return_value=None)
782+
783+
with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session):
784+
output = await tool.run(
785+
input=ExtractUpstreamRepositoryInput(upstream_fix_url=url)
786+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
787+
788+
result = output.to_json_safe()
789+
790+
# Verify the results
791+
assert result.repo_url == "https://github.com/owner/repo.git"
792+
assert result.commit_hash == "pr_commit_hash_abc123"
793+
assert result.is_pr is True
794+
assert result.pr_number == "123"
795+
assert result.is_compare is False
796+
797+
798+
@pytest.mark.asyncio
799+
@pytest.mark.parametrize(
800+
"url, expected_owner, expected_repo, expected_base_ref, expected_target_ref, is_github",
801+
[
802+
(
803+
"https://github.com/git-lfs/git-lfs/compare/v3.7.0...v3.7.1",
804+
"git-lfs",
805+
"git-lfs",
806+
"v3.7.0",
807+
"v3.7.1",
808+
True,
809+
),
810+
(
811+
"https://github.com/owner/repo/compare/v1.0.0...v2.0.0.patch",
812+
"owner",
813+
"repo",
814+
"v1.0.0",
815+
"v2.0.0",
816+
True,
817+
),
818+
(
819+
"https://gitlab.com/owner/project/-/compare/release-1.0...release-2.0",
820+
"owner",
821+
"project",
822+
"release-1.0",
823+
"release-2.0",
824+
False,
825+
),
826+
(
827+
"https://github.com/libsndfile/libsndfile/compare/1.2.2..1.2.3",
828+
"libsndfile",
829+
"libsndfile",
830+
"1.2.2",
831+
"1.2.3",
832+
True,
833+
),
834+
],
835+
)
836+
async def test_extract_upstream_repository_compare_url(
837+
url, expected_owner, expected_repo, expected_base_ref, expected_target_ref, is_github
838+
):
839+
"""Test that compare URLs are properly parsed and repository info is extracted."""
840+
import aiohttp
841+
from unittest.mock import AsyncMock, MagicMock, patch
842+
843+
tool = ExtractUpstreamRepositoryTool()
844+
845+
# Create mock response data
846+
if is_github:
847+
mock_data = {
848+
"commits": [
849+
{"sha": "abc123"},
850+
{"sha": "def456"},
851+
{"sha": "ghi789"}
852+
]
853+
}
854+
else:
855+
# GitLab returns commits in reverse order (newest first)
856+
mock_data = {
857+
"commits": [
858+
{"id": "ghi789"},
859+
{"id": "def456"},
860+
{"id": "abc123"}
861+
]
862+
}
863+
864+
# Mock the aiohttp session and response
865+
mock_response = AsyncMock()
866+
mock_response.status = 200
867+
mock_response.raise_for_status = MagicMock()
868+
mock_response.json = AsyncMock(return_value=mock_data)
869+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
870+
mock_response.__aexit__ = AsyncMock(return_value=None)
871+
872+
mock_session = MagicMock()
873+
mock_session.get = MagicMock(return_value=mock_response)
874+
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
875+
mock_session.__aexit__ = AsyncMock(return_value=None)
876+
877+
with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session):
878+
output = await tool.run(
879+
input=ExtractUpstreamRepositoryInput(upstream_fix_url=url)
880+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
881+
882+
result = output.to_json_safe()
883+
884+
# Verify the results
885+
expected_netloc = "github.com" if is_github else "gitlab.com"
886+
assert result.repo_url == f"https://{expected_netloc}/{expected_owner}/{expected_repo}.git"
887+
assert result.is_compare is True
888+
assert result.is_pr is False
889+
assert result.base_ref == expected_base_ref
890+
assert result.target_ref == expected_target_ref
891+
assert result.compare_commits == ["abc123", "def456", "ghi789"]
892+
assert result.commit_hash == "ghi789" # Should be the last (newest) commit
893+
894+
895+
@pytest.mark.asyncio
896+
async def test_extract_upstream_repository_compare_url_with_special_chars():
897+
"""Test that compare URLs with special characters in refs are properly URL-encoded."""
898+
import aiohttp
899+
from unittest.mock import AsyncMock, MagicMock, patch
900+
901+
tool = ExtractUpstreamRepositoryTool()
902+
url = "https://github.com/owner/repo/compare/feature/branch-1...bugfix/issue-123"
903+
904+
mock_data = {
905+
"commits": [
906+
{"sha": "commit1"},
907+
{"sha": "commit2"}
908+
]
909+
}
910+
911+
mock_response = AsyncMock()
912+
mock_response.status = 200
913+
mock_response.raise_for_status = MagicMock()
914+
mock_response.json = AsyncMock(return_value=mock_data)
915+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
916+
mock_response.__aexit__ = AsyncMock(return_value=None)
917+
918+
mock_session = MagicMock()
919+
mock_session.get = MagicMock(return_value=mock_response)
920+
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
921+
mock_session.__aexit__ = AsyncMock(return_value=None)
922+
923+
with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session):
924+
output = await tool.run(
925+
input=ExtractUpstreamRepositoryInput(upstream_fix_url=url)
926+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
927+
928+
result = output.to_json_safe()
929+
930+
# Verify URL encoding happened
931+
call_args = mock_session.get.call_args
932+
called_url = call_args[0][0]
933+
934+
# Check that slashes in branch names were URL-encoded
935+
assert "feature%2Fbranch-1" in called_url
936+
assert "bugfix%2Fissue-123" in called_url
937+
assert result.base_ref == "feature/branch-1"
938+
assert result.target_ref == "bugfix/issue-123"
939+
940+
941+
@pytest.mark.asyncio
942+
async def test_extract_upstream_repository_compare_url_api_failure():
943+
"""Test that compare URLs gracefully fall back when API fails."""
944+
import aiohttp
945+
from unittest.mock import AsyncMock, MagicMock, patch
946+
947+
tool = ExtractUpstreamRepositoryTool()
948+
url = "https://github.com/owner/repo/compare/v1.0.0...v2.0.0"
949+
950+
# Mock aiohttp session to raise aiohttp.ClientError
951+
mock_response = AsyncMock()
952+
mock_response.raise_for_status = MagicMock(side_effect=aiohttp.ClientError("API Error"))
953+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
954+
mock_response.__aexit__ = AsyncMock(return_value=None)
955+
956+
mock_session = MagicMock()
957+
mock_session.get = MagicMock(return_value=mock_response)
958+
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
959+
mock_session.__aexit__ = AsyncMock(return_value=None)
960+
961+
with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session):
962+
output = await tool.run(
963+
input=ExtractUpstreamRepositoryInput(upstream_fix_url=url)
964+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
965+
966+
result = output.to_json_safe()
967+
968+
# Should fall back gracefully
969+
assert result.is_compare is True
970+
assert result.commit_hash == "v2.0.0" # Falls back to target_ref
971+
assert result.compare_commits is None # No commits fetched
972+
973+
974+
@pytest.mark.asyncio
975+
async def test_extract_upstream_repository_invalid_url():
976+
"""Test that invalid URLs raise appropriate errors."""
977+
tool = ExtractUpstreamRepositoryTool()
978+
url = "https://github.com/invalid/url/structure"
979+
980+
with pytest.raises(ToolError) as exc_info:
981+
await tool.run(
982+
input=ExtractUpstreamRepositoryInput(upstream_fix_url=url)
983+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
984+
985+
assert "Could not extract commit hash from URL" in str(exc_info.value)

0 commit comments

Comments
 (0)