|
48 | 48 | SearchTextToolInput, |
49 | 49 | ) |
50 | 50 | from tools.filesystem import GetCWDTool, GetCWDToolInput, RemoveTool, RemoveToolInput |
| 51 | +from tools.upstream_tools import ( |
| 52 | + ExtractUpstreamRepositoryTool, |
| 53 | + ExtractUpstreamRepositoryInput, |
| 54 | +) |
51 | 55 |
|
52 | 56 |
|
53 | 57 | @pytest.mark.parametrize( |
@@ -706,3 +710,276 @@ async def test_find_rej_files(git_repo): |
706 | 710 | (git_repo / "foo-bar.rej").write_text("rej content 2") |
707 | 711 | result = await find_rej_files(git_repo) |
708 | 712 | assert sorted(result) == sorted(["file.txt.rej", "foo-bar.rej"]) |
| 713 | + |
| 714 | + |
| 715 | +# Tests for ExtractUpstreamRepositoryTool |
| 716 | +@pytest.mark.asyncio |
| 717 | +@pytest.mark.parametrize( |
| 718 | + "url, expected_owner, expected_repo, expected_commit", |
| 719 | + [ |
| 720 | + # Regular commit URLs |
| 721 | + ( |
| 722 | + "https://github.com/owner/repo/commit/abc123def456", |
| 723 | + "owner", |
| 724 | + "repo", |
| 725 | + "abc123def456", |
| 726 | + ), |
| 727 | + ( |
| 728 | + "https://github.com/owner/repo/commit/abc123def456.patch", |
| 729 | + "owner", |
| 730 | + "repo", |
| 731 | + "abc123def456", |
| 732 | + ), |
| 733 | + ( |
| 734 | + "https://gitlab.com/owner/repo/-/commit/abc123def456", |
| 735 | + "owner", |
| 736 | + "repo", |
| 737 | + "abc123def456", |
| 738 | + ), |
| 739 | + ], |
| 740 | +) |
| 741 | +async def test_extract_upstream_repository_commit_url(url, expected_owner, expected_repo, expected_commit): |
| 742 | + """Test that regular commit URLs are properly parsed without API calls.""" |
| 743 | + tool = ExtractUpstreamRepositoryTool() |
| 744 | + |
| 745 | + output = await tool.run( |
| 746 | + input=ExtractUpstreamRepositoryInput(upstream_fix_url=url) |
| 747 | + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) |
| 748 | + |
| 749 | + result = output.to_json_safe() |
| 750 | + |
| 751 | + # Verify the results |
| 752 | + expected_netloc = "github.com" if "github" in url else "gitlab.com" |
| 753 | + assert result.repo_url == f"https://{expected_netloc}/{expected_owner}/{expected_repo}.git" |
| 754 | + assert result.commit_hash == expected_commit |
| 755 | + assert result.is_pr is False |
| 756 | + assert result.is_compare is False |
| 757 | + |
| 758 | + |
| 759 | +@pytest.mark.asyncio |
| 760 | +async def test_extract_upstream_repository_pr_url(): |
| 761 | + """Test that PR URLs fetch commit hash from GitHub API.""" |
| 762 | + import aiohttp |
| 763 | + from unittest.mock import AsyncMock, MagicMock, patch |
| 764 | + |
| 765 | + tool = ExtractUpstreamRepositoryTool() |
| 766 | + url = "https://github.com/owner/repo/pull/123" |
| 767 | + |
| 768 | + # Mock GitHub API response |
| 769 | + mock_data = {"head": {"sha": "pr_commit_hash_abc123"}} |
| 770 | + |
| 771 | + mock_response = AsyncMock() |
| 772 | + mock_response.status = 200 |
| 773 | + mock_response.raise_for_status = MagicMock() |
| 774 | + mock_response.json = AsyncMock(return_value=mock_data) |
| 775 | + mock_response.__aenter__ = AsyncMock(return_value=mock_response) |
| 776 | + mock_response.__aexit__ = AsyncMock(return_value=None) |
| 777 | + |
| 778 | + mock_session = MagicMock() |
| 779 | + mock_session.get = MagicMock(return_value=mock_response) |
| 780 | + mock_session.__aenter__ = AsyncMock(return_value=mock_session) |
| 781 | + mock_session.__aexit__ = AsyncMock(return_value=None) |
| 782 | + |
| 783 | + with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session): |
| 784 | + output = await tool.run( |
| 785 | + input=ExtractUpstreamRepositoryInput(upstream_fix_url=url) |
| 786 | + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) |
| 787 | + |
| 788 | + result = output.to_json_safe() |
| 789 | + |
| 790 | + # Verify the results |
| 791 | + assert result.repo_url == "https://github.com/owner/repo.git" |
| 792 | + assert result.commit_hash == "pr_commit_hash_abc123" |
| 793 | + assert result.is_pr is True |
| 794 | + assert result.pr_number == "123" |
| 795 | + assert result.is_compare is False |
| 796 | + |
| 797 | + |
| 798 | +@pytest.mark.asyncio |
| 799 | +@pytest.mark.parametrize( |
| 800 | + "url, expected_owner, expected_repo, expected_base_ref, expected_target_ref, is_github", |
| 801 | + [ |
| 802 | + ( |
| 803 | + "https://github.com/git-lfs/git-lfs/compare/v3.7.0...v3.7.1", |
| 804 | + "git-lfs", |
| 805 | + "git-lfs", |
| 806 | + "v3.7.0", |
| 807 | + "v3.7.1", |
| 808 | + True, |
| 809 | + ), |
| 810 | + ( |
| 811 | + "https://github.com/owner/repo/compare/v1.0.0...v2.0.0.patch", |
| 812 | + "owner", |
| 813 | + "repo", |
| 814 | + "v1.0.0", |
| 815 | + "v2.0.0", |
| 816 | + True, |
| 817 | + ), |
| 818 | + ( |
| 819 | + "https://gitlab.com/owner/project/-/compare/release-1.0...release-2.0", |
| 820 | + "owner", |
| 821 | + "project", |
| 822 | + "release-1.0", |
| 823 | + "release-2.0", |
| 824 | + False, |
| 825 | + ), |
| 826 | + ( |
| 827 | + "https://github.com/libsndfile/libsndfile/compare/1.2.2..1.2.3", |
| 828 | + "libsndfile", |
| 829 | + "libsndfile", |
| 830 | + "1.2.2", |
| 831 | + "1.2.3", |
| 832 | + True, |
| 833 | + ), |
| 834 | + ], |
| 835 | +) |
| 836 | +async def test_extract_upstream_repository_compare_url( |
| 837 | + url, expected_owner, expected_repo, expected_base_ref, expected_target_ref, is_github |
| 838 | +): |
| 839 | + """Test that compare URLs are properly parsed and repository info is extracted.""" |
| 840 | + import aiohttp |
| 841 | + from unittest.mock import AsyncMock, MagicMock, patch |
| 842 | + |
| 843 | + tool = ExtractUpstreamRepositoryTool() |
| 844 | + |
| 845 | + # Create mock response data |
| 846 | + if is_github: |
| 847 | + mock_data = { |
| 848 | + "commits": [ |
| 849 | + {"sha": "abc123"}, |
| 850 | + {"sha": "def456"}, |
| 851 | + {"sha": "ghi789"} |
| 852 | + ] |
| 853 | + } |
| 854 | + else: |
| 855 | + # GitLab returns commits in reverse order (newest first) |
| 856 | + mock_data = { |
| 857 | + "commits": [ |
| 858 | + {"id": "ghi789"}, |
| 859 | + {"id": "def456"}, |
| 860 | + {"id": "abc123"} |
| 861 | + ] |
| 862 | + } |
| 863 | + |
| 864 | + # Mock the aiohttp session and response |
| 865 | + mock_response = AsyncMock() |
| 866 | + mock_response.status = 200 |
| 867 | + mock_response.raise_for_status = MagicMock() |
| 868 | + mock_response.json = AsyncMock(return_value=mock_data) |
| 869 | + mock_response.__aenter__ = AsyncMock(return_value=mock_response) |
| 870 | + mock_response.__aexit__ = AsyncMock(return_value=None) |
| 871 | + |
| 872 | + mock_session = MagicMock() |
| 873 | + mock_session.get = MagicMock(return_value=mock_response) |
| 874 | + mock_session.__aenter__ = AsyncMock(return_value=mock_session) |
| 875 | + mock_session.__aexit__ = AsyncMock(return_value=None) |
| 876 | + |
| 877 | + with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session): |
| 878 | + output = await tool.run( |
| 879 | + input=ExtractUpstreamRepositoryInput(upstream_fix_url=url) |
| 880 | + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) |
| 881 | + |
| 882 | + result = output.to_json_safe() |
| 883 | + |
| 884 | + # Verify the results |
| 885 | + expected_netloc = "github.com" if is_github else "gitlab.com" |
| 886 | + assert result.repo_url == f"https://{expected_netloc}/{expected_owner}/{expected_repo}.git" |
| 887 | + assert result.is_compare is True |
| 888 | + assert result.is_pr is False |
| 889 | + assert result.base_ref == expected_base_ref |
| 890 | + assert result.target_ref == expected_target_ref |
| 891 | + assert result.compare_commits == ["abc123", "def456", "ghi789"] |
| 892 | + assert result.commit_hash == "ghi789" # Should be the last (newest) commit |
| 893 | + |
| 894 | + |
| 895 | +@pytest.mark.asyncio |
| 896 | +async def test_extract_upstream_repository_compare_url_with_special_chars(): |
| 897 | + """Test that compare URLs with special characters in refs are properly URL-encoded.""" |
| 898 | + import aiohttp |
| 899 | + from unittest.mock import AsyncMock, MagicMock, patch |
| 900 | + |
| 901 | + tool = ExtractUpstreamRepositoryTool() |
| 902 | + url = "https://github.com/owner/repo/compare/feature/branch-1...bugfix/issue-123" |
| 903 | + |
| 904 | + mock_data = { |
| 905 | + "commits": [ |
| 906 | + {"sha": "commit1"}, |
| 907 | + {"sha": "commit2"} |
| 908 | + ] |
| 909 | + } |
| 910 | + |
| 911 | + mock_response = AsyncMock() |
| 912 | + mock_response.status = 200 |
| 913 | + mock_response.raise_for_status = MagicMock() |
| 914 | + mock_response.json = AsyncMock(return_value=mock_data) |
| 915 | + mock_response.__aenter__ = AsyncMock(return_value=mock_response) |
| 916 | + mock_response.__aexit__ = AsyncMock(return_value=None) |
| 917 | + |
| 918 | + mock_session = MagicMock() |
| 919 | + mock_session.get = MagicMock(return_value=mock_response) |
| 920 | + mock_session.__aenter__ = AsyncMock(return_value=mock_session) |
| 921 | + mock_session.__aexit__ = AsyncMock(return_value=None) |
| 922 | + |
| 923 | + with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session): |
| 924 | + output = await tool.run( |
| 925 | + input=ExtractUpstreamRepositoryInput(upstream_fix_url=url) |
| 926 | + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) |
| 927 | + |
| 928 | + result = output.to_json_safe() |
| 929 | + |
| 930 | + # Verify URL encoding happened |
| 931 | + call_args = mock_session.get.call_args |
| 932 | + called_url = call_args[0][0] |
| 933 | + |
| 934 | + # Check that slashes in branch names were URL-encoded |
| 935 | + assert "feature%2Fbranch-1" in called_url |
| 936 | + assert "bugfix%2Fissue-123" in called_url |
| 937 | + assert result.base_ref == "feature/branch-1" |
| 938 | + assert result.target_ref == "bugfix/issue-123" |
| 939 | + |
| 940 | + |
| 941 | +@pytest.mark.asyncio |
| 942 | +async def test_extract_upstream_repository_compare_url_api_failure(): |
| 943 | + """Test that compare URLs gracefully fall back when API fails.""" |
| 944 | + import aiohttp |
| 945 | + from unittest.mock import AsyncMock, MagicMock, patch |
| 946 | + |
| 947 | + tool = ExtractUpstreamRepositoryTool() |
| 948 | + url = "https://github.com/owner/repo/compare/v1.0.0...v2.0.0" |
| 949 | + |
| 950 | + # Mock aiohttp session to raise aiohttp.ClientError |
| 951 | + mock_response = AsyncMock() |
| 952 | + mock_response.raise_for_status = MagicMock(side_effect=aiohttp.ClientError("API Error")) |
| 953 | + mock_response.__aenter__ = AsyncMock(return_value=mock_response) |
| 954 | + mock_response.__aexit__ = AsyncMock(return_value=None) |
| 955 | + |
| 956 | + mock_session = MagicMock() |
| 957 | + mock_session.get = MagicMock(return_value=mock_response) |
| 958 | + mock_session.__aenter__ = AsyncMock(return_value=mock_session) |
| 959 | + mock_session.__aexit__ = AsyncMock(return_value=None) |
| 960 | + |
| 961 | + with patch('tools.upstream_tools.aiohttp.ClientSession', return_value=mock_session): |
| 962 | + output = await tool.run( |
| 963 | + input=ExtractUpstreamRepositoryInput(upstream_fix_url=url) |
| 964 | + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) |
| 965 | + |
| 966 | + result = output.to_json_safe() |
| 967 | + |
| 968 | + # Should fall back gracefully |
| 969 | + assert result.is_compare is True |
| 970 | + assert result.commit_hash == "v2.0.0" # Falls back to target_ref |
| 971 | + assert result.compare_commits is None # No commits fetched |
| 972 | + |
| 973 | + |
| 974 | +@pytest.mark.asyncio |
| 975 | +async def test_extract_upstream_repository_invalid_url(): |
| 976 | + """Test that invalid URLs raise appropriate errors.""" |
| 977 | + tool = ExtractUpstreamRepositoryTool() |
| 978 | + url = "https://github.com/invalid/url/structure" |
| 979 | + |
| 980 | + with pytest.raises(ToolError) as exc_info: |
| 981 | + await tool.run( |
| 982 | + input=ExtractUpstreamRepositoryInput(upstream_fix_url=url) |
| 983 | + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) |
| 984 | + |
| 985 | + assert "Could not extract commit hash from URL" in str(exc_info.value) |
0 commit comments