Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions quotientai/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(self, api_key: str):
self.api_key = api_key
self.token = None
self.token_expiry = 0
self._token_path = token_dir / ".quotient" / "auth_token.json"
self.token_api_key = None
self._token_path = token_dir / ".quotient" / f"{api_key[-6:]+'_' if api_key else ''}auth_token.json"

# Try to load existing token
self._load_token()
Expand Down Expand Up @@ -59,7 +60,7 @@ def _save_token(self, token: str, expiry: int):
return None
# Save to disk
with open(self._token_path, "w") as f:
json.dump({"token": token, "expires_at": expiry}, f)
json.dump({"token": token, "expires_at": expiry, "api_key": self.api_key}, f)

def _load_token(self):
"""Load token from disk if available"""
Expand All @@ -71,14 +72,20 @@ def _load_token(self):
data = json.load(f)
self.token = data.get("token")
self.token_expiry = data.get("expires_at", 0)
self.token_api_key = data.get("api_key")
except Exception:
# If loading fails, token remains None
pass

def _is_token_valid(self):
"""Check if token exists and is not expired"""
self._load_token()

if not self.token:
return False

if self.token_api_key != self.api_key:
return False

# With 5-minute buffer
return time.time() < (self.token_expiry - 300)
Expand Down
11 changes: 9 additions & 2 deletions quotientai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self, api_key: str):
self.api_key = api_key
self.token = None
self.token_expiry = 0
self._token_path = token_dir / ".quotient" / "auth_token.json"
self.token_api_key = None
self._token_path = token_dir / ".quotient" / f"{api_key[-6:]+'_' if api_key else ''}auth_token.json"

# Try to load existing token
self._load_token()
Expand Down Expand Up @@ -62,7 +63,7 @@ def _save_token(self, token: str, expiry: int):

# Save to disk
with open(self._token_path, "w") as f:
json.dump({"token": token, "expires_at": expiry}, f)
json.dump({"token": token, "expires_at": expiry, "api_key": self.api_key}, f)

def _load_token(self):
"""Load token from disk if available"""
Expand All @@ -74,14 +75,20 @@ def _load_token(self):
data = json.load(f)
self.token = data.get("token")
self.token_expiry = data.get("expires_at", 0)
self.token_api_key = data.get("api_key")
except Exception:
# If loading fails, token remains None
pass

def _is_token_valid(self):
"""Check if token exists and is not expired"""
self._load_token()

if not self.token:
return False

if self.token_api_key != self.api_key:
return False

# With 5-minute buffer
return time.time() < (self.token_expiry - 300)
Expand Down
43 changes: 26 additions & 17 deletions tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,15 @@ def test_initialization(self, tmp_path):
# Use a clean temporary directory for token storage
token_dir = tmp_path / ".quotient"

# Test successful home directory case
with patch('pathlib.Path.home', return_value=tmp_path):
client = _AsyncQuotientClient(api_key)

assert client.api_key == api_key
assert client.token is None
assert client.token_expiry == 0
assert client.token_api_key is None
assert client.headers["Authorization"] == f"Bearer {api_key}"
assert client._token_path == tmp_path / ".quotient" / "auth_token.json"

# Test fallback to /root when home fails
with patch('pathlib.Path.home', side_effect=Exception("Test error")), \
patch('pathlib.Path.exists', return_value=True):
client = _AsyncQuotientClient(api_key)
assert client._token_path == Path("/root/.quotient/auth_token.json")

# Test fallback to cwd when home fails and /root doesn't exist
with patch('pathlib.Path.home', side_effect=Exception("Test error")), \
patch('pathlib.Path.exists', return_value=False):
client = _AsyncQuotientClient(api_key)
assert client._token_path == Path.cwd() / ".quotient" / "auth_token.json"
assert client._token_path == tmp_path / ".quotient" / f"{api_key[-6:]}_auth_token.json"

def test_handle_jwt_response(self):
"""Test that _handle_response properly processes JWT tokens"""
Expand Down Expand Up @@ -119,11 +108,12 @@ def test_save_token(self, tmp_path):
assert client.token_expiry == test_expiry

# Verify token was saved to disk
token_file = tmp_path / ".quotient" / "auth_token.json"
token_file = tmp_path / ".quotient" / f"{client.api_key[-6:]}_auth_token.json"
assert token_file.exists()
stored_data = json.loads(token_file.read_text())
assert stored_data["token"] == test_token
assert stored_data["expires_at"] == test_expiry
assert stored_data["api_key"] == client.api_key

def test_load_token(self, tmp_path):
"""Test that _load_token reads token data correctly"""
Expand All @@ -135,17 +125,19 @@ def test_load_token(self, tmp_path):
# Write a token file
token_dir = tmp_path / ".quotient"
token_dir.mkdir(parents=True)
token_file = token_dir / "auth_token.json"
token_file = token_dir / f"{client.api_key[-6:]}_auth_token.json"
token_file.write_text(json.dumps({
"token": test_token,
"expires_at": test_expiry
"expires_at": test_expiry,
"api_key": client.api_key
}))

# Load the token
client._load_token()

assert client.token == test_token
assert client.token_expiry == test_expiry
assert client.token_api_key == client.api_key

def test_is_token_valid(self, tmp_path):
"""Test token validity checking"""
Expand All @@ -159,16 +151,25 @@ def test_is_token_valid(self, tmp_path):
# Test with expired token
client.token = "expired.token"
client.token_expiry = int(time.time()) - 3600 # 1 hour ago
client.token_api_key = client.api_key
assert not client._is_token_valid()

# Test with valid token
client.token = "valid.token"
client.token_expiry = int(time.time()) + 3600 # 1 hour from now
client.token_api_key = client.api_key
assert client._is_token_valid()

# Test with token about to expire (within 5 minute buffer)
client.token = "about.to.expire"
client.token_expiry = int(time.time()) + 200 # Less than 5 minutes
client.token_api_key = client.api_key
assert not client._is_token_valid()

# Test with mismatched API key
client.token = "valid.token"
client.token_expiry = int(time.time()) + 3600
client.token_api_key = "different-api-key"
assert not client._is_token_valid()

def test_update_auth_header(self, tmp_path):
Expand All @@ -185,13 +186,21 @@ def test_update_auth_header(self, tmp_path):
test_token = "test.jwt.token"
client.token = test_token
client.token_expiry = int(time.time()) + 3600
client.token_api_key = client.api_key
client._update_auth_header()
assert client.headers["Authorization"] == f"Bearer {test_token}"

# Should revert to API key when token expires
client.token_expiry = int(time.time()) - 3600
client._update_auth_header()
assert client.headers["Authorization"] == f"Bearer {client.api_key}"

# Should revert to API key when API key doesn't match
client.token = test_token
client.token_expiry = int(time.time()) + 3600
client.token_api_key = "different-api-key"
client._update_auth_header()
assert client.headers["Authorization"] == f"Bearer {client.api_key}"

def test_token_directory_creation_failure(self, tmp_path, caplog):
"""Test that appropriate error is raised when token directory creation fails"""
Expand Down
34 changes: 28 additions & 6 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def test_initialization(self, tmp_path):
assert client.api_key == api_key
assert client.token is None
assert client.token_expiry == 0
assert client.token_api_key is None
assert client.headers["Authorization"] == f"Bearer {api_key}"
assert client._token_path == tmp_path / ".quotient" / f"{api_key[-6:]}_auth_token.json"

def test_handle_jwt_response(self):
"""Test that _handle_response properly processes JWT tokens"""
Expand Down Expand Up @@ -111,11 +113,12 @@ def test_save_token(self, tmp_path):
assert client.token_expiry == test_expiry

# Verify token was saved to disk
token_file = tmp_path / ".quotient" / "auth_token.json"
token_file = tmp_path / ".quotient" / f"{client.api_key[-6:]}_auth_token.json"
assert token_file.exists()
stored_data = json.loads(token_file.read_text())
assert stored_data["token"] == test_token
assert stored_data["expires_at"] == test_expiry
assert stored_data["api_key"] == client.api_key

def test_load_token(self, tmp_path):
"""Test that _load_token reads token data correctly"""
Expand All @@ -127,17 +130,19 @@ def test_load_token(self, tmp_path):
# Write a token file
token_dir = tmp_path / ".quotient"
token_dir.mkdir(parents=True)
token_file = token_dir / "auth_token.json"
token_file = token_dir / f"{client.api_key[-6:]}_auth_token.json"
token_file.write_text(json.dumps({
"token": test_token,
"expires_at": test_expiry
"expires_at": test_expiry,
"api_key": client.api_key
}))

# Load the token
client._load_token()

assert client.token == test_token
assert client.token_expiry == test_expiry
assert client.token_api_key == client.api_key

def test_is_token_valid(self, tmp_path):
"""Test token validity checking"""
Expand All @@ -151,16 +156,25 @@ def test_is_token_valid(self, tmp_path):
# Test with expired token
client.token = "expired.token"
client.token_expiry = int(time.time()) - 3600 # 1 hour ago
client.token_api_key = client.api_key
assert not client._is_token_valid()

# Test with valid token
client.token = "valid.token"
client.token_expiry = int(time.time()) + 3600 # 1 hour from now
client.token_api_key = client.api_key
assert client._is_token_valid()

# Test with token about to expire (within 5 minute buffer)
client.token = "about.to.expire"
client.token_expiry = int(time.time()) + 200 # Less than 5 minutes
client.token_api_key = client.api_key
assert not client._is_token_valid()

# Test with mismatched API key
client.token = "valid.token"
client.token_expiry = int(time.time()) + 3600
client.token_api_key = "different-api-key"
assert not client._is_token_valid()

def test_update_auth_header(self, tmp_path):
Expand All @@ -177,19 +191,27 @@ def test_update_auth_header(self, tmp_path):
test_token = "test.jwt.token"
client.token = test_token
client.token_expiry = int(time.time()) + 3600
client.token_api_key = client.api_key
client._update_auth_header()
assert client.headers["Authorization"] == f"Bearer {test_token}"

# Should revert to API key when token expires
client.token_expiry = int(time.time()) - 3600
client._update_auth_header()
assert client.headers["Authorization"] == f"Bearer {client.api_key}"

# Should revert to API key when API key doesn't match
client.token = test_token
client.token_expiry = int(time.time()) + 3600
client.token_api_key = "different-api-key"
client._update_auth_header()
assert client.headers["Authorization"] == f"Bearer {client.api_key}"

def test_token_path_uses_home(self):
with patch('pathlib.Path.home') as mock_home:
mock_home.return_value = Path('/home/user')
client = _BaseQuotientClient('test-key')
assert client._token_path == Path('/home/user/.quotient/auth_token.json')
assert client._token_path == Path('/home/user/.quotient/st-key_auth_token.json')

def test_token_path_fallback_to_root(self):
with patch('pathlib.Path.home') as mock_home, \
Expand All @@ -200,7 +222,7 @@ def test_token_path_fallback_to_root(self):
mock_exists.return_value = True

client = _BaseQuotientClient('test-key')
assert client._token_path == Path('/root/.quotient/auth_token.json')
assert client._token_path == Path('/root/.quotient/st-key_auth_token.json')

def test_token_path_fallback_to_cwd(self):
with patch('pathlib.Path.home') as mock_home, \
Expand All @@ -214,7 +236,7 @@ def test_token_path_fallback_to_cwd(self):
mock_cwd.return_value = Path('/current/dir')

client = _BaseQuotientClient('test-key')
assert client._token_path == Path('/current/dir/.quotient/auth_token.json')
assert client._token_path == Path('/current/dir/.quotient/st-key_auth_token.json')

def test_handle_jwt_token_success(self):
client = _BaseQuotientClient('test-key')
Expand Down