Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase performance with tf.function decorator #234

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

julesdesir
Copy link
Collaborator

@julesdesir julesdesir commented Feb 27, 2025

Description of the goal of the PR

Description: increase performance with tf.function decorator

Changes this PR introduces (fill it before implementation)

  • : ENH: make @tf.function work with train_step and compute_batch_loss
  • : MAINT: small fix in notebook
  • : TST: small fix in test_get_negative_samples

Checklist before requesting a review

  • I have commented my code, particularly in hard-to-understand areas
  • I have typed my code
  • I have created / updated the docstrings
  • I have updated the README, if relevant
  • I have updated the requirements files if a new package is used
  • I have tested my code
  • The CI pipeline passes
  • I have performed a self-review of my code

Copy link

Important

The terms of service for this installation has not been accepted. Please ask the Organization owners to visit the Gemini Code Assist Admin Console to sign it.

@julesdesir julesdesir self-assigned this Feb 27, 2025
@julesdesir julesdesir added the enhancement New feature or request label Feb 27, 2025
Copy link
Contributor

Coverage

Coverage Report for Python 3.9
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py480100% 
choice_learn/basket_models
   __init__.py30100% 
   dataset.py117497%71–74
   preprocessing.py947817%43–45, 128–364
   shopper.py3582792%165, 194, 343, 363, 378, 381, 395, 684–688, 781–785, 883–887, 1218, 1297, 1335–1336, 1440–1441, 1517–1518
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6473395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2392390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 577
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py393599%43–44, 154–155, 715
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py2541295%144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439
   baseline_models.py490100% 
   conditional_logit.py2362191%46, 49, 51, 82, 85, 88–92, 95–99, 133, 298, 335, 392, 467–473, 598, 632, 739, 743
   halo_mnl.py124298%186, 374
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2362360%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
TOTAL479170085% 

Tests Skipped Failures Errors Time
165 0 💤 0 ❌ 0 🔥 4m 48s ⏱️

Copy link
Contributor

Coverage

Coverage Report for Python 3.11
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py480100% 
choice_learn/basket_models
   __init__.py30100% 
   dataset.py117497%71–74
   preprocessing.py947817%43–45, 128–364
   shopper.py3582792%165, 194, 343, 363, 378, 381, 395, 684–688, 781–785, 883–887, 1218, 1297, 1335–1336, 1440–1441, 1517–1518
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6473395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2392390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 577
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py393499%39, 154–155, 715
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py2541295%144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439
   baseline_models.py490100% 
   conditional_logit.py2362191%46, 49, 51, 82, 85, 88–92, 95–99, 133, 298, 335, 392, 467–473, 598, 632, 739, 743
   halo_mnl.py1241885%186, 341, 360, 364–380
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
TOTAL479371785% 

Tests Skipped Failures Errors Time
165 0 💤 1 ❌ 0 🔥 5m 12s ⏱️

Copy link
Contributor

Coverage

Coverage Report for Python 3.10
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py480100% 
choice_learn/basket_models
   __init__.py30100% 
   dataset.py117497%71–74
   preprocessing.py947817%43–45, 128–364
   shopper.py3582792%165, 194, 343, 363, 378, 381, 395, 684–688, 781–785, 883–887, 1218, 1297, 1335–1336, 1440–1441, 1517–1518
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6473395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2392390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 577
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py393499%39, 154–155, 715
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py2541295%144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439
   baseline_models.py490100% 
   conditional_logit.py2362191%46, 49, 51, 82, 85, 88–92, 95–99, 133, 298, 335, 392, 467–473, 598, 632, 739, 743
   halo_mnl.py1241885%186, 341, 360, 364–380
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
TOTAL479371785% 

Tests Skipped Failures Errors Time
165 0 💤 1 ❌ 0 🔥 5m 12s ⏱️

Copy link
Contributor

Coverage

Coverage Report for Python 3.12
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py480100% 
choice_learn/basket_models
   __init__.py30100% 
   dataset.py117497%71–74
   preprocessing.py947817%43–45, 128–364
   shopper.py3582792%165, 194, 343, 363, 378, 381, 395, 684–688, 781–785, 883–887, 1218, 1297, 1335–1336, 1440–1441, 1517–1518
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6473395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2392390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 577
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py393499%39, 154–155, 715
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py2541295%144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439
   baseline_models.py490100% 
   conditional_logit.py2362191%46, 49, 51, 82, 85, 88–92, 95–99, 133, 298, 335, 392, 467–473, 598, 632, 739, 743
   halo_mnl.py124298%186, 374
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
TOTAL479370185% 

Tests Skipped Failures Errors Time
165 0 💤 0 ❌ 0 🔥 6m 7s ⏱️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant