Skip to content

Commit 421bcd3

Browse files
committed
New tests, prior automatic installation, some other corrections
1 parent dc03e77 commit 421bcd3

File tree

19 files changed

+485
-312
lines changed

19 files changed

+485
-312
lines changed

.DS_Store

-6 KB
Binary file not shown.

files/.DS_Store

-6 KB
Binary file not shown.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ classifiers = [
3030
]
3131
dependencies = [
3232
'celery',
33-
'polars',
3433
'chembl_webresource_client',
3534
'django',
3635
'django-celery-results',

src/conftest.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

src/genui/generators/admin.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from django.contrib import admin
22

3-
# import genui.generators.extensions.genuidrugex.models
4-
# from . import models
3+
import genui.generators.extensions.genuireinvent.models
4+
from . import models
55

6-
# @admin.register(models.Generator)
7-
# class GeneratorAdmin(admin.ModelAdmin):
8-
# pass
6+
@admin.register(models.Generator)
7+
class GeneratorAdmin(admin.ModelAdmin):
8+
pass
99

10-
# @admin.register(genui.generators.extensions.genuidrugex.models.DrugExNet)
11-
# class DrugExNetAdmin(admin.ModelAdmin):
12-
# pass
10+
@admin.register(genui.generators.extensions.genuireinvent.models.ReinventNet)
11+
class ReinventNetAdmin(admin.ModelAdmin):
12+
pass
1313

14-
# @admin.register(genui.generators.extensions.genuidrugex.models.DrugExAgent)
15-
# class DrugExAgentAdmin(admin.ModelAdmin):
16-
# pass
14+
@admin.register(genui.generators.extensions.genuireinvent.models.ReinventAgent)
15+
class ReinventAgentAdmin(admin.ModelAdmin):
16+
pass
-6 KB
Binary file not shown.
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
# genui/src/genui/generators/extensions/genuireinvent/apps.py
2+
13
from django.apps import AppConfig
24

3-
class GenuiReinventConfig(AppConfig):
4-
default_auto_field = "django.db.models.BigAutoField"
5-
name = "genui.generators.extensions.genuireinvent"
6-
label = "genuireinvent"
7-
verbose_name = "GenUI REINVENT Integration"
5+
6+
class GenuireinventConfig(AppConfig):
7+
name = "genui.generators.extensions.genuireinvent"

src/genui/generators/extensions/genuireinvent/genuimodels/builders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class ReinventNetBuilder(bases.ProgressMixIn, bases.ModelBuilder):
2828

2929
def __init__(self, instance: ReinventNet, initial: ReinventNet = None, progress=None):
3030
super().__init__(instance, progress, None)
31+
# super().__init__(instance, progress, getattr(instance, "validationStrategy", None))
3132
self.initial = initial
3233
self.progressStages.append("Creating Corpus...")
3334
self.progressStages.append("Corpus Done.")
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
genuisetup
3+
4+
Registers models from the genuireinvent extension into the default GenUI group.
5+
"""
6+
7+
PARENT = "genui.generators"
8+
9+
10+
def setup(*args, **kwargs):
11+
from genui.utils.init import createGroup
12+
from . import models
13+
14+
createGroup(
15+
"GenUI_Users",
16+
[
17+
# Transfer learning / prior network
18+
models.ReinventNet,
19+
models.ReinventNetTraining,
20+
models.ReinventNetValidation,
21+
models.ModelPerformanceReinvent,
22+
# RL environment + scoring
23+
models.ReinventEnvironment,
24+
models.ReinventEnvironmentScores,
25+
models.ReinventDiversityFilter,
26+
models.ScoreModifier,
27+
models.ClippedScore,
28+
models.SmoothHump,
29+
models.ScoringMethod,
30+
models.PropertyScorer,
31+
models.GenUIModelScorer,
32+
models.UnwantedSmartsScorer,
33+
# RL agent + staged learning generator
34+
models.ReinventAgent,
35+
models.ReinventAgentTraining,
36+
models.ReinventAgentValidation,
37+
models.Reinvent,
38+
models.ReinventStage,
39+
],
40+
force=kwargs.get("force", False),
41+
)

src/genui/generators/extensions/genuireinvent/models.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from django.conf import settings
2020
from django.core.files.base import ContentFile
21-
from django.db import models
21+
from django.db import models, close_old_connections
2222
from django.utils import timezone
2323

2424
from genui.compounds.models import MolSet, ActivitySet
@@ -156,6 +156,7 @@ class ReinventNet(Model):
156156
CHECKPOINT_FILE_NOTE = "reinvent_tl_checkpoint" # where REINVENT writes
157157
CORPUS_TRAIN_NOTE = "reinvent_corpus_train"
158158
CORPUS_VALID_NOTE = "reinvent_corpus_valid"
159+
PRIOR_FILE_NOTE = "reinvent_prior_copy"
159160

160161
molset = models.ForeignKey(MolSet, on_delete=models.CASCADE, null=True)
161162
parent = models.ForeignKey("self", on_delete=models.CASCADE, null=True)
@@ -207,7 +208,33 @@ def get_clean_corpus_path(self) -> str:
207208
return self.corpusFullFile.path
208209

209210
# ── Prior path ────────────────────────────────────────────────
211+
@property
212+
def priorFile(self) -> ModelFile:
213+
# stored as an AUX file tied to this model
214+
return self._get_or_create_aux(self.PRIOR_FILE_NOTE, f"reinvent_prior_{self.pk}.prior")
215+
216+
def ensure_prior_copy(self) -> ModelFile:
217+
"""
218+
Make sure this model has its own prior copy (for reproducibility).
219+
"""
220+
src = _resolve_reinvent_prior_path()
221+
with open(src, "rb") as f:
222+
data = f.read()
223+
_overwrite_filefield(self.priorFile, data, filename=os.path.basename(src))
224+
return self.priorFile
225+
210226
def get_prior_path(self) -> str:
227+
"""
228+
Prefer the project/model-owned copy if present, else fall back to global.
229+
"""
230+
mf = self.files.filter(kind=ModelFile.AUXILIARY, note=self.PRIOR_FILE_NOTE).first()
231+
if mf and mf.file:
232+
try:
233+
# local storage
234+
return mf.file.path
235+
except Exception:
236+
pass
237+
# fallback (global location ensured by genuisetup)
211238
return _resolve_reinvent_prior_path()
212239

213240
# ── Clean corpus preparation (hashed AUX only) ─────────────────────────────
@@ -547,7 +574,6 @@ def to_reinvent(self):
547574
d["penalty_multiplier"] = self.penalty_multiplier
548575
return d
549576

550-
# TODO: predelat na Dataset
551577
class ReinventEnvironment(DataSet):
552578
name = models.CharField(max_length=255)
553579

@@ -749,7 +775,6 @@ class ReinventAgent(Model):
749775
environment = models.ForeignKey(ReinventEnvironment, on_delete=models.PROTECT)
750776
training = models.ForeignKey(ReinventAgentTraining, on_delete=models.PROTECT)
751777
validation = models.ForeignKey(ReinventAgentValidation, null=True, blank=True, on_delete=models.SET_NULL)
752-
#TODO pouzit Models aby dedilo logiku
753778
output_model = models.ForeignKey(ModelFile, null=True, blank=True, on_delete=models.SET_NULL)
754779

755780
# moved from Generator (keep defaults to avoid breaking existing configs)

0 commit comments

Comments
 (0)