Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions intentkit/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class UserUpdate(BaseModel):
str | None, Field(None, description="User's Telegram username")
]
extra: Annotated[
dict | None, Field(None, description="Additional user information")
dict[str, object] | None, Field(None, description="Additional user information")
]
evm_wallet_address: Annotated[
str | None, Field(None, description="User's EVM wallet address")
Expand All @@ -151,7 +151,8 @@ class UserUpdate(BaseModel):
str | None, Field(None, description="User's Solana wallet address")
]
linked_accounts: Annotated[
dict | None, Field(None, description="User's linked accounts information")
dict[str, object] | None,
Field(None, description="User's linked accounts information"),
]

async def _update_quota_for_nft_count(
Expand Down Expand Up @@ -348,3 +349,26 @@ async def get_by_tg(cls, telegram_username: str) -> UserModelType | None:
if user is None:
return None
return user_model_class.model_validate(user)

@classmethod
async def get_by_evm_wallet(cls, evm_wallet_address: str) -> UserModelType | None:
"""Get a user by EVM wallet address or matching ID."""
user_model_class = user_model_registry.get_user_model_class()
assert issubclass(user_model_class, User)
user_table_class = user_model_registry.get_user_table_class()
assert issubclass(user_table_class, UserTable)

async with get_session() as session:
result = await session.execute(
select(user_table_class).where(
user_table_class.evm_wallet_address == evm_wallet_address
)
)
user = result.scalars().first()
if user is not None:
return user_model_class.model_validate(user)

fallback_user = await session.get(user_table_class, evm_wallet_address)
if fallback_user is None:
return None
return user_model_class.model_validate(fallback_user)
58 changes: 58 additions & 0 deletions tests/models/test_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from intentkit.models.base import Base
from intentkit.models.user import User, UserTable


@pytest_asyncio.fixture()
async def sqlite_engine():
from intentkit.models import db as db_module

test_engine = create_async_engine(
"sqlite+aiosqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
db_module.engine = test_engine

async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

try:
yield test_engine
finally:
await test_engine.dispose()
db_module.engine = None


@pytest.mark.asyncio
async def test_get_by_evm_wallet(sqlite_engine):
session_factory = async_sessionmaker(sqlite_engine, expire_on_commit=False)

async with session_factory() as session:
session.add(
UserTable(
id="user_with_wallet",
evm_wallet_address="0x123",
)
)
session.add(
UserTable(
id="user_id_only",
)
)
await session.commit()

direct_match = await User.get_by_evm_wallet("0x123")
assert direct_match is not None
assert direct_match.id == "user_with_wallet"

id_fallback = await User.get_by_evm_wallet("user_id_only")
assert id_fallback is not None
assert id_fallback.id == "user_id_only"

missing = await User.get_by_evm_wallet("0xnotfound")
assert missing is None
Loading