Skip to content

[hma][api] bank_get_content can optionally return signals #1763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
58 changes: 56 additions & 2 deletions hasher-matcher-actioner/src/OpenMediaMatch/blueprints/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ class BankedContentMetadata(t.TypedDict):
json: t.NotRequired[dict[t.Any, t.Any]]


class BankContentResponse(t.TypedDict):
id: int
bank: str
enabled: bool
original_media_uri: t.Optional[str]
signals: t.NotRequired[dict[str, str]]


# Blueprint that the curation routes in this module are registered on.
bp = Blueprint("curation", __name__)
# Any HTTPException raised by these routes is handed to flask_utils.api_error_handler.
bp.register_error_handler(HTTPException, flask_utils.api_error_handler)

Expand Down Expand Up @@ -136,14 +144,60 @@ def _validate_bank_add_metadata() -> t.Optional[BankedContentMetadata]:

@bp.route("/bank/<bank_name>/content/<content_id>", methods=["GET"])
def bank_get_content(bank_name: str, content_id: int):
    """
    Get content from a bank by ID.

    Query Parameters:
      signal_type (optional): If specified, includes the signal value for
        this signal type

    Returns: JSON representation of the bank content

    Without signal_type parameter:
    {
      'id': 1234,
      'bank': 'TEST_BANK',
      'enabled': true,
      'original_media_uri': 'file:///data/media/uploaded_content_123.jpg'
    }

    With signal_type parameter:
    {
      'id': 1234,
      'bank': 'TEST_BANK',
      'enabled': true,
      'original_media_uri': 'file:///data/media/uploaded_content_123.jpg',
      'signals': {
        'pdq': 'f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22'
      }
    }

    Aborts with:
      404 if the bank or the content cannot be found
      400 if signal_type is given but is not a configured signal type
    """
    # NOTE(review): the route has no <int:...> converter, so content_id
    # presumably arrives as a string despite the annotation - confirm.
    storage = persistence.get_storage()
    if not storage.get_bank(bank_name):
        abort(404, f"bank '{bank_name}' not found")

    # Treat an empty `signal_type=` the same as an absent one so storage is
    # never asked to look up signals for the type "".
    signal_type = request.args.get("signal_type") or None
    if signal_type is not None:
        if signal_type not in storage.get_signal_type_configs():
            abort(400, f"No such signal type '{signal_type}'")

    # Single fetch: signals are only loaded when a signal type was requested.
    content = storage.bank_content_get([content_id], signal_type)
    if not content:
        abort(404, f"content '{content_id}' not found")

    content_obj = content[0]
    response: BankContentResponse = {
        "id": content_obj.id,
        "bank": content_obj.bank.name,
        "enabled": content_obj.enabled,
        "original_media_uri": content_obj.original_media_uri,
    }

    # Leave the key out entirely when there is nothing to report.
    if signal_type is not None and content_obj.signals:
        response["signals"] = content_obj.signals

    return response


@bp.route("/bank/<bank_name>/content", methods=["POST"])
Expand Down
33 changes: 29 additions & 4 deletions hasher-matcher-actioner/src/OpenMediaMatch/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

import abc
from dataclasses import dataclass
from dataclasses import dataclass, field
import typing as t
import time

Expand Down Expand Up @@ -308,7 +308,11 @@ class BankContentConfig:
Represents all the signals (hashes) for one piece of content.

When signals come from external sources, or the original content
has been lost
has been lost.

Signals are only included in this object when explicitly requested via
signal_type parameter in bank_content_get().
The signals field will be None unless specifically requested.
"""

ENABLED: t.ClassVar[int] = 1
Expand All @@ -330,6 +334,10 @@ class BankContentConfig:

bank: BankConfig

# Dictionary mapping signal_type names to signal values
# which is only populated when explicitly requested
signals: t.Optional[dict[str, str]] = None

@property
def enabled(self) -> bool:
if self.disable_until_ts == 0:
Expand Down Expand Up @@ -409,8 +417,25 @@ def bank_delete(self, name: str) -> None:

# Bank content
@abc.abstractmethod
def bank_content_get(
    self, id: t.Iterable[int], signal_type: t.Optional[str] = None
) -> t.Sequence[BankContentConfig]:
    """
    Get the content config for a bank.

    Args:
        id: The IDs of the bank content to retrieve
        signal_type: Optional signal type to include in the response.
            If provided, signals of this type will be included in the result.
            If not provided, no signals will be fetched, which avoids
            unnecessary database joins.

    Returns:
        List of bank content configs. The 'signals' field will only be
        populated when signal_type is specified, and will only contain
        signals of the requested type. When signal_type is not specified,
        the 'signals' field of the returned objects is None.
    """
    # NOTE(review): a reviewer suggested, as an alternative to carrying the
    # signals on BankContentConfig, returning Mapping[int, Ret] where Ret
    # holds the config and the requested signals separately.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsure: As an alternative, we could provide a method that allows fetching the signals one at a time directly. The upsides would be that it's simpler to implement multiple interfaces, and would allow fetching strictly only the signals.

@abc.abstractmethod
def bank_content_update(self, val: BankContentConfig) -> None:
Expand Down
25 changes: 22 additions & 3 deletions hasher-matcher-actioner/src/OpenMediaMatch/storage/mocked.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,29 @@ def bank_delete(self, name: str) -> None:
self.banks.pop(name, None)

def bank_content_get(
    self, id: t.Iterable[int], signal_type: t.Optional[str] = None
) -> t.Sequence[interface.BankContentConfig]:
    """
    Return one synthetic, enabled content config per requested ID.

    Every config is attached to the "MOCK_BANK" bank; if this store has no
    such bank, a stand-in BankConfig is created for the returned objects.

    Args:
        id: The content IDs to fabricate configs for.
        signal_type: When given, each config carries a single fake signal
            of this type; otherwise `signals` is left as None.
    """
    mock_bank = self.get_bank("MOCK_BANK")
    if mock_bank is None:
        mock_bank = interface.BankConfig("MOCK_BANK", matching_enabled_ratio=1.0)

    return [
        interface.BankContentConfig(
            id=content_id,
            disable_until_ts=interface.BankContentConfig.ENABLED,
            collab_metadata={},
            original_media_uri=None,
            bank=mock_bank,
            # For mocked data, the signal value is just derived from the ID.
            signals=(
                None
                if signal_type is None
                else {signal_type: f"mock_signal_{content_id}"}
            ),
        )
        for content_id in id
    ]

def bank_content_update(self, val: interface.BankContentConfig) -> None:
# TODO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,33 @@ def set_typed_config(self, cfg: BankContentConfig) -> t.Self:
self.disable_until_ts = cfg.disable_until_ts
return self

def as_storage_iface_cls(self, *, include_signals: bool = False) -> BankContentConfig:
    """
    Convert a database BankContent record to a BankContentConfig interface object.

    Args:
        include_signals: If True, include the signals in the returned object.
                         If False (default), signals will not be included, so
                         the signals relationship is not touched here.

    Returns:
        A BankContentConfig with signals only included if explicitly requested.
    """
    content_config = BankContentConfig(
        id=self.id,
        disable_until_ts=self.disable_until_ts,
        collab_metadata={},
        original_media_uri=self.original_content_uri,
        bank=self.bank.as_storage_iface_cls(),
    )

    # Only include signals when explicitly requested
    if include_signals:
        content_config.signals = {s.signal_type: s.signal_val for s in self.signals}

    return content_config


class ContentSignal(db.Model): # type: ignore[name-defined]
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,14 +523,44 @@ def bank_delete(self, name: str) -> None:
database.db.session.commit()

def bank_content_get(
    self, ids: t.Iterable[int], signal_type: t.Optional[str] = None
) -> t.Sequence[interface.BankContentConfig]:
    """
    Fetch the BankContent rows with the given IDs as interface configs.

    Args:
        ids: The IDs of the bank content rows to load.
        signal_type: When given, each returned config gets a `signals` dict
            holding the value for this signal type, or an empty dict if the
            content has no signal of that type. When None, `signals` is
            left as None.
    """
    query = database.db.session.query(database.BankContent).filter(
        database.BankContent.id.in_(ids)
    )

    if signal_type is not None:
        # Eager-load the signals in the same round trip. joinedload builds
        # its own join, so no explicit join on ContentSignal is needed; the
        # requested type is picked out below.
        query = query.options(
            joinedload(database.BankContent.signals).load_only(
                database.ContentSignal.signal_type,
                database.ContentSignal.signal_val,
            )
        )

    result = []
    for bc in query.all():
        content_config = bc.as_storage_iface_cls(include_signals=False)

        if signal_type is not None:
            # Report only the first signal of the requested type, if any.
            match = next(
                (s for s in bc.signals if s.signal_type == signal_type), None
            )
            content_config.signals = (
                {} if match is None else {signal_type: match.signal_val}
            )

        result.append(content_config)

    return result

def bank_content_update(self, val: interface.BankContentConfig) -> None:
sesh = database.db.session
Expand Down
34 changes: 34 additions & 0 deletions hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api_banks.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,37 @@ def test_banks_add_hash_index(app: Flask, client: FlaskClient):
)
assert post_response.status_code == 200
assert post_response.json == {"matches": [2]}


def test_bank_get_content_signal_validation(client: FlaskClient):
    """Check that bank content lookup returns and validates `signal_type`."""
    bank_name = "TEST_BANK"
    pdq_hash = "0" * 64
    create_bank(client, bank_name)

    # Bank a single PDQ signal and note the content ID it was given
    add_resp = client.post(f"/c/bank/{bank_name}/signal", json={"pdq": pdq_hash})
    assert add_resp.status_code == 200
    add_body = add_resp.get_json()
    assert add_body is not None
    content_url = f"/c/bank/{bank_name}/content/{add_body['id']}"

    # A known signal type is echoed back under "signals" with its value
    found_resp = client.get(f"{content_url}?signal_type=pdq")
    assert found_resp.status_code == 200
    found_body = found_resp.get_json()
    assert found_body is not None
    assert "signals" in found_body
    assert found_body["signals"] == {"pdq": pdq_hash}

    # An unknown signal type is rejected with a 400 and an explanatory message
    rejected_resp = client.get(f"{content_url}?signal_type=invalid")
    assert rejected_resp.status_code == 400
    rejected_body = rejected_resp.get_json()
    assert rejected_body is not None
    assert "message" in rejected_body
    assert "No such signal type" in rejected_body["message"]
Loading