1010
1111
1212import os
13+ import re
14+ import subprocess
1315from typing import Optional
1416
1517from sapling import encoding , error , json , util
@@ -71,6 +73,14 @@ def __init__(self, repodir=None, repo=None, ui=None):
7173 # TODO: remove this force_arcrc safety check when sl release has soaked
7274 # for a while.
7375 self ._force_arcrc = ui .configbool ("phabricator" , "force_arcrc" , False )
76+ self ._phabricator_auth_command = ui .configlist (
77+ "phabricator" , "auth-command" , ["jf" , "auth" ]
78+ )
79+ self ._auth_token_regex = ui .config (
80+ "phabricator" ,
81+ "auth-expired-regex" ,
82+ r"invalid auth token|the provided crypto auth token\(s\) are expired" ,
83+ )
7484 if not self ._mock :
7585 self ._app_id = ui .config ("phabricator" , "graphql_app_id" )
7686 self ._host = ui .config ("phabricator" , "graphql_host" )
@@ -99,6 +109,51 @@ def _get_phab_client(self):
99109 self ._host ,
100110 )
101111
112+ def _is_token_expired (self , response ):
113+ if not self ._auth_token_regex :
114+ return False
115+ pattern = re .compile (self ._auth_token_regex , re .IGNORECASE )
116+ possible_errors = [response .get (key ) for key in ("error" , "errors" )]
117+ possible_errors = [e if isinstance (e , list ) else [e ] for e in possible_errors ]
118+ possible_messages = [str (m ) for e in possible_errors for m in e ]
119+ return any (pattern .search (m ) for m in possible_messages )
120+
121+ # _query wraps self._client.query and provides transparent re-auth
122+ # if the response looks like it failed due to expired tokens.
123+ def _query (self , timeout , request , params = None ):
124+ ret = self ._client .query (timeout , request , params )
125+ if not self ._phabricator_auth_command or not self ._is_token_expired (ret ):
126+ return ret
127+
128+ try :
129+ subprocess .run (
130+ self ._phabricator_auth_command ,
131+ stdin = subprocess .DEVNULL ,
132+ stdout = subprocess .PIPE ,
133+ stderr = subprocess .PIPE ,
134+ check = True ,
135+ timeout = 60 ,
136+ )
137+ except Exception as ex :
138+ cmd_str = " " .join (self ._phabricator_auth_command )
139+ self ._ui .warn (_ ("warning: `%s` failed: %s\n " ) % (cmd_str , ex ))
140+ stderr = getattr (ex , "stderr" , None ) or b""
141+ stdout = getattr (ex , "stdout" , None ) or b""
142+ if stderr :
143+ self ._ui .warn (
144+ _ (" stderr: %s\n " )
145+ % stderr .decode ("utf-8" , errors = "replace" ).strip ()
146+ )
147+ if stdout :
148+ self ._ui .warn (
149+ _ (" stdout: %s\n " )
150+ % stdout .decode ("utf-8" , errors = "replace" ).strip ()
151+ )
152+ return ret
153+
154+ self ._client = self ._get_phab_client ()
155+ return self ._client .query (timeout , request , params )
156+
102157 def _applyarcconfig (self , config , defaultarcrchost ):
103158 arcrchost = config .get ("graphql_uri" , None )
104159 if "OVERRIDE_GRAPHQL_URI" in encoding .environ :
@@ -221,7 +276,7 @@ def getdiffversion(self, timeout, diffid, version=None):
221276 query = query % extra_query
222277
223278 params = {"diffid" : diffid }
224- ret = self ._client . query (timeout , query , params )
279+ ret = self ._query (timeout , query , params )
225280
226281 try :
227282 latest : Optional [dict ] = ret ["data" ]["phabricator_diff_query" ][0 ][
@@ -305,7 +360,7 @@ def getnodes(self, repo, diffids, diff_status, timeout=10):
305360 }
306361 """
307362 params = {"diffids" : diffids }
308- ret = self ._client . query (timeout , query , params )
363+ ret = self ._query (timeout , query , params )
309364 # Example result:
310365 # { "data": {
311366 # "phabricator_diff_query": [
@@ -425,13 +480,13 @@ def getrevisioninfo(self, timeout, signalstatus, *revision_numbers):
425480 ret = self ._mocked_responses .pop ()
426481 else :
427482 params = {"params" : {"numbers" : rev_numbers }}
428- ret = self ._client . query (timeout , self ._getquery (signalstatus ), params )
483+ ret = self ._query (timeout , self ._getquery (signalstatus ), params )
429484 return self ._processrevisioninfo (ret )
430485
431486 def graphqlquery (self , query , variables , timeout = 60_000 ):
432487 if self ._mock :
433488 return self ._mocked_responses .pop ()
434- return self ._client . query (timeout , query , variables )
489+ return self ._query (timeout , query , variables )
435490
436491 def _getquery (self , signalstatus ):
437492 signalquery = ""
@@ -592,7 +647,7 @@ def getmirroredrev(self, fromrepo, fromtype, torepo, totype, rev, timeout=15):
592647 "revs" : [rev ],
593648 }
594649 }
595- ret = self ._client . query (timeout , query , json .dumps (params ))
650+ ret = self ._query (timeout , query , json .dumps (params ))
596651 self ._raise_errors (ret )
597652 for pair in ret ["data" ]["query" ]["rev_map" ]:
598653 if pair ["from_rev" ] == rev :
@@ -625,7 +680,7 @@ def getmirroredrevmap(self, repo, nodes, fromtype, totype, timeout=15):
625680 "revs" : list (map (fromenc , nodes )),
626681 }
627682 }
628- ret = self ._client . query (timeout , query , json .dumps (params ))
683+ ret = self ._query (timeout , query , json .dumps (params ))
629684 self ._raise_errors (ret )
630685 result = {}
631686 for pair in ret ["data" ]["query" ]["rev_map" ]:
@@ -688,7 +743,7 @@ def scmquery_log(
688743 "follow_mutable_file_history" : use_mutable_history ,
689744 }
690745 }
691- ret = self ._client . query (timeout , query , json .dumps (params ))
746+ ret = self ._query (timeout , query , json .dumps (params ))
692747 self ._raise_errors (ret )
693748 return ret ["data" ]["query" ]
694749
@@ -701,7 +756,7 @@ def get_username(self, unixname=None, timeout=10) -> str:
701756 query = "query($u: String!) { intern_user_for_unixname(unixname: $u) { access_name email } }"
702757 params = {"u" : unixname }
703758 # {'data': {'intern_user_for_unixname': {'access_name': 'Name', 'email': 'foo@example.com'}}}
704- ret = self ._client . query (timeout , query , json .dumps (params ))
759+ ret = self ._query (timeout , query , json .dumps (params ))
705760 self ._raise_errors (ret )
706761 data = ret ["data" ]["intern_user_for_unixname" ]
707762 if not data :
0 commit comments