Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions moneyflow/backends/ynab.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,40 @@ async def get_all_merchants(self) -> List[str]:
"""
return self.client.get_all_merchants()

def batch_update_merchant(
    self, old_merchant_name: str, new_merchant_name: str
) -> Dict[str, Any]:
    """
    Rename every transaction under a merchant via one payee update (YNAB-only fast path).

    Instead of touching transactions one by one, this delegates to the YNAB
    client, which renames the payee record itself; YNAB then cascades the new
    name to every transaction attached to that payee. A bulk rename therefore
    costs a single API call rather than one call per transaction
    (e.g. 100 transactions: 1 call instead of 100).

    Args:
        old_merchant_name: Existing merchant/payee name to be renamed.
        new_merchant_name: Replacement merchant/payee name.

    Returns:
        Result dictionary produced by ``YNABClient.batch_update_merchant``
        (see that method for the exact format).

    Example:
        >>> backend = YNABBackend()
        >>> await backend.login(password=token)
        >>> result = backend.batch_update_merchant("Amazon.com/abc", "Amazon")
        >>> if result['success']:
        ...     print(f"Optimized: Updated payee {result['payee_id']}")

    Note:
        Deliberately synchronous (not async) because the underlying YNAB SDK
        is synchronous. Other backends may not offer this optimization.
    """
    # Hand the rename straight to the client; it performs the single
    # payee update and reports the outcome.
    result = self.client.batch_update_merchant(old_merchant_name, new_merchant_name)
    return result

def get_currency_symbol(self) -> str:
"""
Get the currency symbol from YNAB budget settings.
Expand Down
187 changes: 148 additions & 39 deletions moneyflow/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,9 +902,12 @@ async def commit_pending_edits(self, edits: List[Any]) -> Tuple[int, int]:
"""
Commit pending edits to backend API in parallel.

This method groups edits by transaction ID (in case multiple edits
affect the same transaction) and sends update requests in parallel
for maximum speed.
This method intelligently optimizes commits based on backend capabilities:
- For backends with batch_update_merchant (e.g., YNAB), bulk merchant
renames are handled with a single API call per (old, new) pair instead
of one call per transaction (100x performance improvement)
- For other backends, or non-merchant edits, uses individual transaction
updates in parallel for maximum speed

The method is resilient to partial failures - if some updates fail,
others will still succeed. The caller receives counts for both.
Expand Down Expand Up @@ -934,48 +937,154 @@ async def commit_pending_edits(self, edits: List[Any]) -> Tuple[int, int]:
logger.info("No edits to commit")
return 0, 0

# Group edits by transaction ID
edits_by_txn: Dict[str, Dict[str, Any]] = {}
for edit in edits:
txn_id = edit.transaction_id
if txn_id not in edits_by_txn:
edits_by_txn[txn_id] = {}

if edit.field == "merchant":
edits_by_txn[txn_id]["merchant_name"] = edit.new_value
elif edit.field == "category":
edits_by_txn[txn_id]["category_id"] = edit.new_value
elif edit.field == "hide_from_reports":
edits_by_txn[txn_id]["hide_from_reports"] = edit.new_value

# Create update tasks
tasks = []
for txn_id, updates in edits_by_txn.items():
tasks.append(self.mm.update_transaction(transaction_id=txn_id, **updates))

# Execute in parallel
results = await asyncio.gather(*tasks, return_exceptions=True)

# Count successes and failures, and log errors
success_count = 0
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Unvalidated external input in batch merchant update (medium severity)

The old_merchant_name and new_merchant_name from user edits are passed directly to batch_update_merchant without validation. While YNAB client does some validation (empty/length checks), SQL injection-style attacks could occur if the backend uses these names in database queries. Add input validation here to sanitize merchant names (max length, allowed characters, no SQL/NoSQL injection patterns) before passing to backend.


Automated security review by Claude 4.5 Sonnet - Human review still required

failure_count = 0

# Check for auth errors that should trigger retry
auth_errors = []

for i, result in enumerate(results):
if isinstance(result, Exception):
failure_count += 1
logger.error(
f"Transaction update {i + 1}/{len(results)} FAILED: {result}", exc_info=result
# Check if backend supports batch merchant updates
has_batch_update = hasattr(self.mm, "batch_update_merchant")

# Separate merchant edits from other edits
merchant_edits = [e for e in edits if e.field == "merchant"]
other_edits = [e for e in edits if e.field != "merchant"]

# OPTIMIZATION: Group merchant edits by (old_value, new_value) for batch updates
if has_batch_update and merchant_edits:
logger.info(
f"Backend supports batch updates - optimizing {len(merchant_edits)} merchant edits"
)

# Group merchant edits by (old_name, new_name)
merchant_groups: Dict[Tuple[str, str], List[Any]] = {}
for edit in merchant_edits:
key = (edit.old_value, edit.new_value)
if key not in merchant_groups:
merchant_groups[key] = []
merchant_groups[key].append(edit)

# Try batch update for each (old, new) pair
successfully_batched_edits = []
failed_batch_edits = []

# Track processed transaction IDs to prevent double-counting
# Note: If the same transaction has multiple merchant edits (e.g., A→B then B→C),
# they'll be in different batch groups. We can only batch one of them.
processed_txn_ids = set()

for (old_name, new_name), group_edits in merchant_groups.items():
# Filter out edits for transactions already processed in a different batch
unprocessed_edits = [
e for e in group_edits if e.transaction_id not in processed_txn_ids
]

if not unprocessed_edits:
# All edits in this group were already processed in a previous batch
# Add them to failed list so they get individual processing with latest values
failed_batch_edits.extend(group_edits)
continue

group_edits = unprocessed_edits # Only batch the unprocessed ones
group_txn_ids = {e.transaction_id for e in group_edits}
logger.info(
f"Attempting batch update: '{old_name}' -> '{new_name}' "
f"({len(group_edits)} transactions)"
)

# Check if it's a 401/auth error
error_str = str(result).lower()
if "401" in error_str or "unauthorized" in error_str:
auth_errors.append(result)
else:
success_count += 1
try:
# Call batch_update_merchant in thread to avoid blocking event loop
result = await asyncio.to_thread(
self.mm.batch_update_merchant, # type: ignore[attr-defined]
old_name,
new_name,
)

if result.get("success"):
# Batch update succeeded - mark edits as processed and count as successful
processed_txn_ids.update(group_txn_ids)
success_count += len(group_edits)
successfully_batched_edits.extend(group_edits)
logger.info(
f"✓ Batch update succeeded for '{old_name}' -> '{new_name}' "
f"({len(group_edits)} transactions updated via 1 API call)"
)
else:
# Batch update failed - mark as processed but add to fallback list
processed_txn_ids.update(group_txn_ids)
logger.warning(
f"Batch update failed for '{old_name}' -> '{new_name}': "
f"{result.get('message', 'Unknown error')}. "
f"Falling back to individual transaction updates."
)
failed_batch_edits.extend(group_edits)

except Exception as e:
# Exception during batch - mark as processed and add to fallback list
processed_txn_ids.update(group_txn_ids)
logger.warning(
f"Batch update exception for '{old_name}' -> '{new_name}': {e}. "
f"Falling back to individual transaction updates.",
exc_info=True,
)
failed_batch_edits.extend(group_edits)

# Add failed batch edits back to the list for individual processing
merchant_edits = failed_batch_edits

# Safety check: ensure no overlap between successful and failed batches
successful_ids = {e.transaction_id for e in successfully_batched_edits}
failed_ids = {e.transaction_id for e in failed_batch_edits}
overlap = successful_ids & failed_ids
assert not overlap, (
f"Found {len(overlap)} edits in both successful and failed batches - "
"this indicates a race condition or logic error"
)

# Process remaining edits (non-merchant + failed batch updates) individually
edits_to_process = merchant_edits + other_edits

if edits_to_process:
logger.info(
f"Processing {len(edits_to_process)} edits individually "
f"({len(merchant_edits)} merchant, {len(other_edits)} other)"
)

# Group edits by transaction ID
edits_by_txn: Dict[str, Dict[str, Any]] = {}
for edit in edits_to_process:
txn_id = edit.transaction_id
if txn_id not in edits_by_txn:
edits_by_txn[txn_id] = {}

if edit.field == "merchant":
edits_by_txn[txn_id]["merchant_name"] = edit.new_value
elif edit.field == "category":
edits_by_txn[txn_id]["category_id"] = edit.new_value
elif edit.field == "hide_from_reports":
edits_by_txn[txn_id]["hide_from_reports"] = edit.new_value

# Create update tasks
tasks = []
for txn_id, updates in edits_by_txn.items():
tasks.append(self.mm.update_transaction(transaction_id=txn_id, **updates))

# Execute in parallel
results = await asyncio.gather(*tasks, return_exceptions=True)

# Count successes and failures
for i, result in enumerate(results):
if isinstance(result, Exception):
failure_count += 1
logger.error(
f"Transaction update {i + 1}/{len(results)} FAILED: {result}",
exc_info=result,
)

# Check if it's a 401/auth error
error_str = str(result).lower()
if "401" in error_str or "unauthorized" in error_str:
auth_errors.append(result)
else:
success_count += 1

logger.info(f"Commit completed: {success_count} succeeded, {failure_count} failed")

Expand Down
Loading