11
2- from . import run_on_executor
2+ from . import run_on_executor , PrivateSSHKeyContext
3+
34import giturlparse
4- import json
55from concurrent .futures import ThreadPoolExecutor
66import git
77import os
88import os .path
99import subprocess
1010import stat
11+ import yaml
12+
1113from shutil import copyfile , rmtree
1214from tempfile import TemporaryDirectory
1315
1416
17+ def git_ssh_command (private_key ):
18+ return "ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i {0}" .format (private_key )
19+
20+
21+ def git_ssh_environment (g , ssh_private_key_filename = None ):
22+ if ssh_private_key_filename :
23+ return g .custom_environment (GIT_SSH_COMMAND = git_ssh_command (ssh_private_key_filename ))
24+ return g .custom_environment ()
25+
26+
1527class CacheRedirectException (Exception ):
1628 def __init__ (self , url ):
1729 self .url = url
1830
1931
32+ class GitRepositoryError (Exception ):
33+ def __init__ (self , code , message ):
34+ self .code = code
35+ self .message = message
36+
37+
2038class GitRepository (object ):
2139 executor = ThreadPoolExecutor ()
22- g = git .cmd .Git ()
2340
24- def __init__ (self , url , cache_directory , public_url ):
41+ def __init__ (self , settings , cache_directory , public_url , default_private_key = None ):
42+
43+ if isinstance (settings , str ):
44+ url = settings
45+ self .private_key = None
46+ elif isinstance (settings , dict ):
47+ url = settings .get ("url" , None )
48+ if url is None :
49+ raise RuntimeError ("Repo is an object but has no `url` property" )
50+ self .private_key = settings .get ("ssh_key" , default_private_key )
51+ else :
52+ raise RuntimeError ("Repo should be a string or a object" )
53+
2554 self .repo_url = url
2655 self .parsed_url = giturlparse .parse (url )
2756 if not self .parsed_url .valid :
@@ -33,14 +62,20 @@ def __init__(self, url, cache_directory, public_url):
3362 @run_on_executor
3463 def list_versions (self ):
3564 repo_name = self .name ()
36- tags = GitRepository .g .ls_remote (self .repo_url , tags = True )
37- result = {}
38- for line in tags .split ('\n ' ):
39- ref_hash , ref = line .split ('\t ' )
40- tag_name = ref .split ("/" )[- 1 ]
41- tar_name = self .package_tar (tag_name )
42- result [tar_name ] = self .public_url + "/download/" + repo_name + "/" + tar_name
43- return result
65+ with PrivateSSHKeyContext (ssh_private_key = self .private_key ) as ssh_private_key_filename :
66+ g = git .cmd .Git ()
67+ with git_ssh_environment (g , ssh_private_key_filename = ssh_private_key_filename ):
68+ try :
69+ tags = g .ls_remote (self .repo_url , tags = True )
70+ except git .GitCommandError as e :
71+ raise GitRepositoryError (500 , "Failed to fetch remote repository: {0}" .format (e .status ))
72+ result = {}
73+ for line in tags .split ('\n ' ):
74+ ref_hash , ref = line .split ('\t ' )
75+ tag_name = ref .split ("/" )[- 1 ]
76+ tar_name = self .package_tar (tag_name )
77+ result [tar_name ] = self .public_url + "/download/" + repo_name + "/" + tar_name
78+ return result
4479
4580 def get_cache_url (self , package_version ):
4681 return self .cache_public_url + "/" + self .package_tar (package_version )
@@ -59,20 +94,26 @@ def download(self, package_version):
5994 if os .path .isfile (os .path .join (self .cache_directory , tar_name )):
6095 return cache_url
6196
97+ # noinspection PyUnusedLocal
6298 def set_rw (operation , name , exc ):
6399 os .chmod (name , stat .S_IWRITE )
64100 os .remove (name )
65101 return True
66102
67103 with TemporaryDirectory (prefix = "pypigit" ) as temp_dir :
68- r = git .Git (temp_dir )
69- r .clone (self .repo_url , branch = package_version , depth = 1 )
104+ g = git .Git (temp_dir )
105+
106+ with PrivateSSHKeyContext (ssh_private_key = self .private_key ) as ssh_private_key_filename :
107+ with git_ssh_environment (g , ssh_private_key_filename = ssh_private_key_filename ):
108+ g .clone (self .repo_url , branch = package_version , depth = 1 )
109+
70110 build = os .path .join (temp_dir , repo_name )
71111 p = subprocess .Popen ("python setup.py sdist" , stdout = subprocess .PIPE , shell = True , cwd = build )
72112 p .communicate ()
73113
74114 copyfile (os .path .join (build , "dist" , tar_name ), os .path .join (self .cache_directory , tar_name ))
75115
116+ # noinspection PyTypeChecker
76117 rmtree (build , onerror = set_rw )
77118
78119 return cache_url
@@ -85,18 +126,23 @@ class GitRepositories(object):
85126 def __init__ (self , repos_filename , cache_directory , public_url ):
86127
87128 with open (repos_filename , "r" ) as f :
88- repos = json .load (f )
129+ repos = yaml .load (f )
89130
90131 if not isinstance (repos , dict ):
91132 raise RuntimeError ("--repos file should be a json object" )
92133
134+ default_private_key = repos .get ("default_ssh_key" )
135+
93136 if "repositories" not in repos :
94137 raise RuntimeError ("No 'repositories' section in --repos file" )
95138
96139 self .repositories = {
97140 repo .name (): repo
98- for repo in map (lambda path : GitRepository (path , cache_directory , public_url ),
99- filter (lambda s : isinstance (s , str ), repos ["repositories" ]))
141+ for repo in map (
142+ lambda settings : GitRepository (
143+ settings , cache_directory , public_url , default_private_key = default_private_key
144+ ), repos ["repositories" ]
145+ )
100146 }
101147
102148 def find (self , name ):
0 commit comments