Skip to content

Commit 5f54197

Browse files
feat: include learned abbreviations in venue intelligence dataset [AI-assisted] (#1011)
Updated the acronym export/import functionality to include learned abbreviations, creating a more comprehensive 'venue intelligence' dataset.

- Enhanced AcronymCache with abbreviation export, import, and clear methods
- Updated CLI 'export' and 'import' to use a unified JSON dataset format
- Added --include-abbreviations flag to 'acronym clear' command
- Added 'acronym abbreviation-stats' command
- Maintained backward compatibility for legacy list-only import format

Co-authored-by: florath-ai-assistant[bot] <Andreas.Florath@telekom.de>
1 parent 1557936 commit 5f54197

File tree

2 files changed

+170
-20
lines changed

2 files changed

+170
-20
lines changed

src/aletheia_probe/cache/acronym_cache.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,26 @@ def export_all_variants(self) -> list[dict[str, Any]]:
358358
)
359359
return [dict(row) for row in cursor.fetchall()]
360360

361+
def export_all_abbreviations(self) -> list[dict[str, Any]]:
    """Dump every row of the ``learned_abbreviations`` table.

    Returns:
        One dictionary per row holding ``abbreviated_form``,
        ``expanded_form``, ``confidence_score``, ``occurrence_count`` and
        ``context``, sorted by abbreviated form and then expanded form.
    """
    detail_logger.debug("Exporting all learned abbreviations")

    with self.get_connection_with_row_factory() as connection:
        db_cursor = connection.cursor()
        db_cursor.execute(
            """
            SELECT abbreviated_form, expanded_form, confidence_score,
                   occurrence_count, context
            FROM learned_abbreviations
            ORDER BY abbreviated_form, expanded_form
            """
        )
        # Convert while the connection is still open; each fetched row is
        # turned into a plain dict so callers can serialize it directly.
        fetched_rows = db_cursor.fetchall()
        return list(map(dict, fetched_rows))
380+
361381
def import_variants(
362382
self, variants: list[dict[str, Any]], merge: bool = True
363383
) -> int:
@@ -413,6 +433,46 @@ def import_variants(
413433

414434
return count
415435

436+
def import_abbreviations(
    self, abbreviations: list[dict[str, Any]], merge: bool = True
) -> int:
    """Import learned abbreviations into the database.

    Every valid entry is handed to ``store_learned_abbreviation``; entries
    that are not dictionaries or that lack a required key are logged and
    skipped so that one malformed record cannot abort the whole import
    after earlier records have already been stored.

    Args:
        abbreviations: List of abbreviation dictionaries. Each entry must
            provide ``abbreviated_form`` and ``expanded_form``;
            ``confidence_score`` and ``context`` are optional.
        merge: Kept for signature parity with ``import_variants``. This
            method does not read it; how an entry combines with existing
            data is decided by ``store_learned_abbreviation``.

    Returns:
        Number of entries passed to ``store_learned_abbreviation``.
    """
    if not abbreviations:
        return 0

    detail_logger.debug(f"Importing {len(abbreviations)} learned abbreviations")
    required_keys = ("abbreviated_form", "expanded_form")
    count = 0

    for abbrev in abbreviations:
        # The data comes from a user-supplied JSON file, so an entry may be
        # any JSON value. Without the isinstance guard, `key in entry` is a
        # substring test for strings and a TypeError for numbers/null.
        if not isinstance(abbrev, dict) or not all(
            key in abbrev for key in required_keys
        ):
            detail_logger.warning(f"Skipping invalid abbreviation: {abbrev}")
            continue

        # Treat an explicit JSON null the same as a missing field, so the
        # documented default applies in both cases.
        confidence = abbrev.get("confidence_score")
        if confidence is None:
            confidence = 0.1

        # NOTE(review): an exported "occurrence_count" is not forwarded
        # here; confirm whether store_learned_abbreviation should get it.
        self.store_learned_abbreviation(
            abbrev=abbrev["abbreviated_form"],
            expanded=abbrev["expanded_form"],
            confidence=confidence,
            context=abbrev.get("context"),
            log_prefix="[import] ",
        )
        count += 1

    return count
475+
416476
def list_all_acronyms(
417477
self, entity_type: str | None = None, limit: int | None = None, offset: int = 0
418478
) -> list[dict[str, str]]:
@@ -540,6 +600,35 @@ def clear_acronym_database(self, entity_type: str | None = None) -> int:
540600
)
541601
return count
542602

603+
def clear_learned_abbreviations(self) -> int:
    """Remove every row from the ``learned_abbreviations`` table.

    Returns:
        How many rows the table held immediately before the delete.
    """
    detail_logger.debug("Clearing entire learned abbreviations database")

    with self.get_connection() as connection:
        db_cursor = connection.cursor()

        # Count first: the number reported to the caller is taken before
        # the rows are removed.
        db_cursor.execute("SELECT COUNT(*) FROM learned_abbreviations")
        count_row = db_cursor.fetchone()
        deleted = 0
        if count_row:
            deleted = count_row[0]
        detail_logger.debug(f"Found {deleted} abbreviations to delete")

        # Wipe the table, then persist the change.
        db_cursor.execute("DELETE FROM learned_abbreviations")
        detail_logger.debug("Deleted all entries from learned abbreviations database")

        connection.commit()
        detail_logger.debug(
            f"Abbreviation clear operation completed, {deleted} entries deleted"
        )
        return deleted
631+
543632
def mark_acronym_as_ambiguous(
544633
self, acronym: str, entity_type: str, venues: list[str] | None = None
545634
) -> None:

src/aletheia_probe/cli.py

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,22 @@ def stats() -> None:
412412
status_logger.info(f"Total acronyms: {total:,}")
413413

414414

415+
@acronym.command(name="abbreviation-stats")
@handle_cli_errors
def abbreviation_stats() -> None:
    """Report how many learned abbreviations are currently stored."""
    logger = get_status_logger()

    # The count is derived from the full export of the abbreviation table.
    cache = AcronymCache()
    learned_total = len(cache.export_all_abbreviations())

    if learned_total:
        logger.info(f"Total learned abbreviations: {learned_total:,}")
    else:
        logger.info("No abbreviations learned yet")
429+
430+
415431
@acronym.command(name="list")
416432
@click.option("--limit", type=int, help="Maximum number of entries to display")
417433
@click.option("--offset", type=int, default=0, help="Number of entries to skip")
@@ -450,7 +466,7 @@ def list_acronyms(limit: int | None, offset: int) -> None:
450466
@click.argument("output_file", type=click.Path())
451467
@handle_cli_errors
452468
def export(output_file: str) -> None:
453-
"""Export the entire acronym database to a JSON file.
469+
"""Export the entire acronym and abbreviation database to a JSON file.
454470
455471
Args:
456472
output_file: Path to the output JSON file.
@@ -459,16 +475,27 @@ def export(output_file: str) -> None:
459475
acronym_cache = AcronymCache()
460476

461477
variants = acronym_cache.export_all_variants()
478+
abbreviations = acronym_cache.export_all_abbreviations()
479+
480+
dataset = {
481+
"acronyms": variants,
482+
"abbreviations": abbreviations,
483+
"metadata": {
484+
"version": "1.0",
485+
"type": "aletheia-probe-venue-intelligence",
486+
},
487+
}
462488

463489
try:
464490
with open(output_file, "w", encoding="utf-8") as f:
465-
json.dump(variants, f, indent=2, ensure_ascii=False)
491+
json.dump(dataset, f, indent=2, ensure_ascii=False)
466492

467493
status_logger.info(
468-
f"Successfully exported {len(variants)} acronym variants to {output_file}"
494+
f"Successfully exported {len(variants)} acronym variants and "
495+
f"{len(abbreviations)} learned abbreviations to {output_file}"
469496
)
470497
except Exception as e:
471-
status_logger.error(f"Failed to export acronyms: {e}")
498+
status_logger.error(f"Failed to export dataset: {e}")
472499
raise click.ClickException(str(e)) from e
473500

474501

@@ -481,7 +508,9 @@ def export(output_file: str) -> None:
481508
)
482509
@handle_cli_errors
483510
def import_acronyms(input_file: str, merge: bool) -> None:
484-
"""Import acronyms from a JSON file.
511+
"""Import acronyms and abbreviations from a JSON file.
512+
513+
Supports both unified dataset format and legacy list-only format.
485514
486515
Args:
487516
input_file: Path to the input JSON file.
@@ -492,44 +521,69 @@ def import_acronyms(input_file: str, merge: bool) -> None:
492521

493522
try:
494523
with open(input_file, encoding="utf-8") as f:
495-
variants = json.load(f)
496-
497-
if not isinstance(variants, list):
498-
raise ValueError("Input file must contain a JSON list of variants")
499-
500-
status_logger.info(f"Read {len(variants)} variants from {input_file}")
524+
data = json.load(f)
525+
526+
variants = []
527+
abbreviations = []
528+
529+
# Detect format
530+
if isinstance(data, list):
531+
# Legacy format: list of variants
532+
variants = data
533+
status_logger.info(f"Detected legacy format with {len(variants)} variants")
534+
elif isinstance(data, dict):
535+
# Unified format
536+
variants = data.get("acronyms", [])
537+
abbreviations = data.get("abbreviations", [])
538+
status_logger.info(
539+
f"Detected unified format with {len(variants)} variants and "
540+
f"{len(abbreviations)} abbreviations"
541+
)
542+
else:
543+
raise ValueError("Input file must contain a JSON list or object")
501544

502545
if not merge:
503546
if click.confirm(
504-
"This will clear existing acronyms before importing. Continue?",
547+
"This will clear existing data before importing. Continue?",
505548
abort=True,
506549
):
507550
acronym_cache.clear_acronym_database()
551+
acronym_cache.clear_learned_abbreviations()
508552

509-
count = acronym_cache.import_variants(variants, merge=True)
553+
variant_count = acronym_cache.import_variants(variants, merge=True)
554+
abbrev_count = acronym_cache.import_abbreviations(abbreviations, merge=True)
510555

511-
status_logger.info(f"Successfully imported {count} acronym variants")
556+
status_logger.info(
557+
f"Successfully imported {variant_count} acronym variants and "
558+
f"{abbrev_count} learned abbreviations"
559+
)
512560

513561
except Exception as e:
514-
status_logger.error(f"Failed to import acronyms: {e}")
562+
status_logger.error(f"Failed to import dataset: {e}")
515563
raise click.ClickException(str(e)) from e
516564

517565

518566
@acronym.command()
519567
@click.option("--confirm", is_flag=True, help="Skip confirmation prompt")
568+
@click.option(
569+
"--include-abbreviations", is_flag=True, help="Also clear learned abbreviations"
570+
)
520571
@handle_cli_errors
521-
def clear(confirm: bool) -> None:
522-
"""Clear all entries from the acronym database.
572+
def clear(confirm: bool, include_abbreviations: bool) -> None:
573+
"""Clear entries from the acronym database.
523574
524575
Args:
525576
confirm: Whether to skip the confirmation prompt.
577+
include_abbreviations: Whether to also clear learned abbreviations.
526578
"""
527579
status_logger = get_status_logger()
528580

581+
msg = "This will delete all conference acronym mappings."
582+
if include_abbreviations:
583+
msg = "This will delete all acronym mappings AND learned abbreviations."
584+
529585
if not confirm:
530-
click.confirm(
531-
"This will delete all conference acronym mappings. Continue?", abort=True
532-
)
586+
click.confirm(f"{msg} Continue?", abort=True)
533587

534588
acronym_cache = AcronymCache()
535589
count = acronym_cache.clear_acronym_database()
@@ -539,6 +593,13 @@ def clear(confirm: bool) -> None:
539593
else:
540594
status_logger.info(f"Cleared {count:,} acronym mapping(s).")
541595

596+
if include_abbreviations:
597+
abbrev_count = acronym_cache.clear_learned_abbreviations()
598+
if abbrev_count == 0:
599+
status_logger.info("Learned abbreviations database is already empty.")
600+
else:
601+
status_logger.info(f"Cleared {abbrev_count:,} learned abbreviation(s).")
602+
542603

543604
@acronym.command()
544605
@click.argument("acronym")

0 commit comments

Comments
 (0)