Skip to content

Commit b1536ed

Browse files
moilerat, Victor Reutenauer, and grololo06
authored
first and maybe sufficiently minimal usage of pgvector to store deepfeature (#85)
* first and sufficiently minimal usage of pgvector to store deepfeature --------- Co-authored-by: Victor Reutenauer <victor@fotonower.com> Co-authored-by: grololo06 <laurent.salinas@laposte.net>
1 parent a8969ea commit b1536ed

File tree

13 files changed

+130
-41
lines changed

13 files changed

+130
-41
lines changed

.github/workflows/auto_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
# Label used to access the service container
2929
postgres:
3030
# Docker Hub image
31-
image: pgvector/pgvector:pg14
31+
image: pgvector/pgvector:0.7.4-pg14
3232
# Provide the password for postgres
3333
env:
3434
POSTGRES_PASSWORD: postgres12

QA/py/pg_files/schem_prod.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ CREATE EXTENSION IF NOT EXISTS tsm_system_time WITH SCHEMA public;
4444

4545
COMMENT ON EXTENSION tsm_system_time IS 'TABLESAMPLE method which accepts time in milliseconds as a limit';
4646

47+
--
48+
-- Name: vector; Type: EXTENSION; Schema: -; Owner: -
49+
--
50+
51+
-- doesn't work and should be there in ankane
52+
-- CREATE EXTENSION IF NOT EXISTS vector;
53+
4754

4855
SET default_tablespace = '';
4956

QA/py/pg_files/upgrade_prod.sql

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2219,6 +2219,31 @@ UPDATE alembic_version SET version_num='4e25988b1e56' WHERE alembic_version.vers
22192219
ALTER TABLE users ADD COLUMN orcid VARCHAR(20) DEFAULT NULL;
22202220
UPDATE alembic_version SET version_num='0a3132f436fb' WHERE alembic_version.version_num = '4e25988b1e56';
22212221

2222+
-- Running upgrade 0a3132f436fb -> a9dd3c62b7b0
2223+
2224+
CREATE TABLE obj_cnn_features_vector (
2225+
objcnnid BIGINT NOT NULL,
2226+
features VECTOR(50),
2227+
PRIMARY KEY (objcnnid),
2228+
FOREIGN KEY(objcnnid) REFERENCES obj_head (objid) ON DELETE CASCADE
2229+
);
2230+
2231+
INSERT INTO obj_cnn_features_vector (objcnnid, features)
2232+
SELECT objcnnid, ARRAY[cnn01, cnn02, cnn03, cnn04, cnn05, cnn06, cnn07, cnn08, cnn09, cnn10,
2233+
cnn11, cnn12, cnn13, cnn14, cnn15, cnn16, cnn17, cnn18, cnn19, cnn20,
2234+
cnn21, cnn22, cnn23, cnn24, cnn25, cnn26, cnn27, cnn28, cnn29, cnn30,
2235+
cnn31, cnn32, cnn33, cnn34, cnn35, cnn36, cnn37, cnn38, cnn39, cnn40,
2236+
cnn41, cnn42, cnn43, cnn44, cnn45, cnn46, cnn47, cnn48, cnn49, cnn50]::vector
2237+
FROM obj_cnn_features;
2238+
2239+
GRANT SELECT ON obj_cnn_features_vector TO readerole;
2240+
2241+
DROP TABLE obj_cnn_features;
2242+
2243+
UPDATE alembic_version SET version_num='a9dd3c62b7b0' WHERE alembic_version.version_num = '0a3132f436fb';
2244+
2245+
COMMIT;
2246+
22222247
------- Leave on tail
22232248

22242249
ALTER TABLE alembic_version REPLICA IDENTITY FULL;

QA/py/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,5 @@ stringcase==1.2.0
8585
# For python API client generated classes
8686
# Fails to build on GH due to obscure compilation issue. Nevermind for the moment.
8787
#backports-datetime-fromisoformat==1.0.0
88+
# pgvector for similarity search
89+
pgvector==0.2.4

py/API_operations/Subset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from BO.TSVFile import TSVFile
1616
from BO.helpers.ImportHelpers import ImportHow
1717
from DB.Acquisition import Acquisition
18-
from DB.CNNFeature import ObjectCNNFeature
18+
from DB.CNNFeatureVector import ObjectCNNFeatureVector
1919
from DB.Image import Image
2020
from DB.Object import ObjectHeader, ObjectFields, ObjectsClassifHisto
2121
from DB.Process import Process
@@ -35,7 +35,7 @@
3535
# Useful typings
3636
# TODO: Put somewhere else if reused in other classes
3737
DBObjectTupleT = Tuple[
38-
ObjectHeader, ObjectFields, ObjectCNNFeature, Image, Sample, Acquisition, Process
38+
ObjectHeader, ObjectFields, ObjectCNNFeatureVector, Image, Sample, Acquisition, Process
3939
]
4040
DBObjectTupleListT = List[DBObjectTupleT]
4141

@@ -164,15 +164,15 @@ def _db_fetch(self, object_ids: ObjectIDListT) -> Iterable[DBObjectTupleT]:
164164
)
165165
ret = (
166166
ret.outerjoin(Image, ObjectHeader.all_images)
167-
.outerjoin(ObjectCNNFeature)
167+
.outerjoin(ObjectCNNFeatureVector)
168168
.join(ObjectFields)
169169
)
170170
ret = ret.filter(ObjectHeader.objid == any_(object_ids))
171171
ret = ret.order_by(ObjectHeader.objid, Image.imgid)
172172
ret = ret.with_entities(
173173
ObjectHeader,
174174
ObjectFields,
175-
ObjectCNNFeature,
175+
ObjectCNNFeatureVector,
176176
Image,
177177
Sample,
178178
Acquisition,

py/BO/Prediction.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# This file is part of Ecotaxa, see license.md in the application root directory for license informations.
3-
# Copyright (C) 2015-2021 Picheral, Colin, Irisson (UPMC-CNRS)
3+
# Copyright (C) 2015-2024 Picheral, Colin, Irisson (UPMC-CNRS), Amblard (LOVNOWER)
44
#
55

66
#
@@ -13,7 +13,7 @@
1313
from numpy import ndarray
1414

1515
from DB.Acquisition import Acquisition
16-
from DB.CNNFeature import DEEP_FEATURES, ObjectCNNFeaturesBean, ObjectCNNFeature
16+
from DB.CNNFeatureVector import N_DEEP_FEATURES, ObjectCNNFeaturesVectorBean, ObjectCNNFeatureVector
1717
from DB.Image import Image
1818
from DB.Object import ObjectHeader, ObjectIDT
1919
from DB.Project import ProjectIDT
@@ -33,7 +33,7 @@ class DeepFeatures(object):
3333
3434
OTOH, it can also _generate_ features, using another class of machine learning algorithm: CNN
3535
@see https://en.wikipedia.org/wiki/Convolutional_neural_network
36-
These other features are stored in a dedicated DB table @see ObjectCNNFeature.
36+
These other features are stored in a dedicated DB table @see ObjectCNNFeatureVector.
3737
"""
3838

3939
SAVE_EVERY: ClassVar = 500
@@ -53,8 +53,8 @@ def delete_all(session: Session, proj_id: ProjectIDT) -> int:
5353
Sample.sampleid == Acquisition.acq_sample_id, Sample.projid == proj_id
5454
),
5555
)
56-
qry = session.query(ObjectCNNFeature)
57-
qry = qry.filter(ObjectCNNFeature.objcnnid.in_(sub_qry))
56+
qry = session.query(ObjectCNNFeatureVector)
57+
qry = qry.filter(ObjectCNNFeatureVector.objcnnid.in_(sub_qry))
5858
nb_deleted = qry.delete(synchronize_session=False)
5959
return nb_deleted
6060

@@ -72,9 +72,9 @@ def find_missing(session: Session, proj_id: ProjectIDT) -> Dict[ObjectIDT, str]:
7272
),
7373
)
7474
qry = qry.outerjoin(Image) # For detecting missing images
75-
qry = qry.outerjoin(ObjectCNNFeature) # For detecting missing features
75+
qry = qry.outerjoin((ObjectCNNFeatureVector)) # For detecting missing features
7676
# noinspection PyComparisonWithNone
77-
qry = qry.filter(ObjectCNNFeature.objcnnid == None) # SQLAlchemy
77+
qry = qry.filter(ObjectCNNFeatureVector.objcnnid == None) # SQLAlchemy
7878
qry = qry.order_by(ObjectHeader.objid, Image.imgrank)
7979
ret = {}
8080
for a_res in session.execute(qry):
@@ -97,7 +97,7 @@ def save(cls, session: Session, features: Any) -> int:
9797
# for a_rec in features.to_records(index=True): # This is nice and can produce tuple()
9898
# but I found no way to feed them into DBWriter without going low-level.
9999
for obj_id, row in features.iterrows():
100-
bean = ObjectCNNFeaturesBean(obj_id, row)
100+
bean = ObjectCNNFeaturesVectorBean(obj_id, row)
101101
writer.add_cnn_features_with_pk(bean)
102102
nb_rows += 1
103103
if nb_rows % cls.SAVE_EVERY == 0:
@@ -112,10 +112,10 @@ def read_for_objects(
112112
"""
113113
Read CNN lines AKA features, in order, for given object_ids
114114
"""
115-
fk_to_objid = ObjectCNNFeature.objcnnid.name
115+
fk_to_objid = ObjectCNNFeatureVector.objcnnid.name
116116
sql = "WITH ordr (seq, objid) AS (select * from UNNEST(:seq, :oids)) "
117-
sql += "SELECT " + ",".join(DEEP_FEATURES)
118-
sql += " FROM " + ObjectCNNFeature.__tablename__
117+
sql += "SELECT " + " features "
118+
sql += " FROM " + ObjectCNNFeatureVector.__tablename__
119119
sql += " JOIN ordr ON " + fk_to_objid + " = ordr.objid "
120120
sql += " ORDER BY ordr.seq "
121121
params = {"seq": list(range(len(oid_lst))), "oids": oid_lst}
@@ -128,12 +128,13 @@ def np_read_for_objects(cls, session: Session, oid_lst: List[int]) -> ndarray:
128128
Read CNN lines AKA features, in order, for given object_ids, into a NumPy array
129129
"""
130130
res = cls.read_for_objects(session, oid_lst)
131-
ret = np.ndarray(shape=(len(oid_lst), len(res.keys())), dtype=np.float32)
131+
ret = np.ndarray(shape=(len(oid_lst), N_DEEP_FEATURES), dtype=np.float32)
132132
ndx = 0
133133
for a_row in res:
134-
ret[ndx] = a_row
134+
all_feats = a_row["features"].strip("[]").split(",") if type(a_row["features"]) == str else a_row["features"]
135+
ret[ndx] = [float(x) for x in all_feats]
135136
ndx += 1
136137
assert ndx == len(
137138
oid_lst
138-
), "No enough CNN features in DB: expected %d read %d" % (len(oid_lst), ndx)
139+
), "Not enough CNN features in DB: expected %d read %d" % (len(oid_lst), ndx)
139140
return ret
Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,36 @@
11
# -*- coding: utf-8 -*-
22
# This file is part of Ecotaxa, see license.md in the application root directory for license informations.
3-
# Copyright (C) 2015-2021 Picheral, Colin, Irisson (UPMC-CNRS)
3+
# Copyright (C) 2022-2024 LOVNOWER : Amblard, Colin, Irisson, Reutenauer (UPMC-CNRS-FOTONOWER)
44
#
55
from typing import List
6+
from pgvector.sqlalchemy import Vector
67

78
from .Object import ObjectIDT
89
from .helpers.Bean import Bean
910
from .helpers.DDL import ForeignKey
1011
from .helpers.ORM import Column, relationship, Model
1112
from .helpers.Postgres import BIGINT, REAL
1213

14+
N_DEEP_FEATURES = 50
1315

14-
class ObjectCNNFeature(Model):
15-
__tablename__ = "obj_cnn_features"
16+
17+
class ObjectCNNFeatureVector(Model):
18+
__tablename__ = "obj_cnn_features_vector"
1619
objcnnid: int = Column(
1720
BIGINT, ForeignKey("obj_head.objid", ondelete="CASCADE"), primary_key=True
1821
)
22+
features: Vector = Column(Vector(N_DEEP_FEATURES))
1923
# The relationships are created in Relations.py but the typing here helps the IDE
2024
object: relationship
2125

2226

23-
# The features in _each_ row
24-
DEEP_FEATURES = ["cnn%02d" % i for i in range(1, 51)]
25-
26-
for a_feat in DEEP_FEATURES:
27-
setattr(ObjectCNNFeature, a_feat, Column(REAL))
28-
29-
30-
class ObjectCNNFeaturesBean(Bean):
27+
class ObjectCNNFeaturesVectorBean(Bean):
3128
"""
3229
A bean for feeding DBWriter.
3330
"""
3431

3532
def __init__(self, obj_id: ObjectIDT, features: List[float]):
36-
super().__init__(zip(DEEP_FEATURES, features))
37-
self["objcnnid"] = obj_id
33+
super().__init__({
34+
"objcnnid": obj_id,
35+
"features": features,
36+
})

py/DB/Relations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# Trick to prevent accidental re-export of the DB Models involved
1010
# Note: The trick doesn't work :(
1111
from .Acquisition import Acquisition
12-
from .CNNFeature import ObjectCNNFeature
12+
from .CNNFeatureVector import ObjectCNNFeatureVector
1313
from .Collection import (
1414
Collection,
1515
CollectionProject,
@@ -126,13 +126,13 @@
126126
uselist=False,
127127
)
128128

129-
ObjectCNNFeature.object = relationship(
129+
ObjectCNNFeatureVector.object = relationship(
130130
ObjectHeader,
131131
foreign_keys="ObjectHeader.objid",
132-
primaryjoin="ObjectCNNFeature.objcnnid==ObjectHeader.objid",
132+
primaryjoin="ObjectCNNFeatureVector.objcnnid==ObjectHeader.objid",
133133
uselist=False,
134134
)
135-
ObjectHeader.cnn_features = relationship(ObjectCNNFeature, uselist=False)
135+
ObjectHeader.cnn_features = relationship(ObjectCNNFeatureVector, uselist=False)
136136

137137
ObjectHeader.all_images = relationship(Image)
138138

py/DB/helpers/DBWriter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
# This file is part of Ecotaxa, see license.md in the application root directory for license informations.
33
# Copyright (C) 2015-2020 Picheral, Colin, Irisson (UPMC-CNRS)
44
#
5-
from typing import List, Optional, ClassVar
5+
from typing import Dict, Tuple, List, Type, Optional, ClassVar
66

77
from helpers.DynamicLogs import get_logger
88
from .Bean import Bean
99
from .Direct import text
10-
from .ORM import Session, MetaData, minimal_table_of
10+
from .ORM import Session, Table, MetaData, minimal_table_of
1111
from .Postgres import SequenceCache
12-
from ..CNNFeature import ObjectCNNFeature
12+
from ..CNNFeatureVector import ObjectCNNFeatureVector
1313
from ..Image import Image
1414
from ..Object import ObjectHeader, ObjectFields, ObjectsClassifHisto
1515

@@ -35,7 +35,7 @@ def __init__(self, session: Session):
3535
self.obj_tbl = ObjectHeader.__table__
3636
self.obj_fields_tbl = ObjectFields.__table__ # Slow by default @see narrow_to
3737
self.img_tbl = Image.__table__
38-
self.obj_cnn_tbl = ObjectCNNFeature.__table__
38+
self.obj_cnn_vector_tbl = ObjectCNNFeatureVector.__table__
3939
self.obj_history_tbl = ObjectsClassifHisto.__table__
4040
# Data
4141
self.obj_bulks: List[Bean] = []
@@ -70,7 +70,7 @@ def do_bulk_save(self) -> None:
7070
inserts = [
7171
self.obj_tbl.insert(),
7272
self.obj_fields_tbl.insert(),
73-
self.obj_cnn_tbl.insert(),
73+
self.obj_cnn_vector_tbl.insert(),
7474
self.img_tbl.insert(),
7575
self.obj_history_tbl.insert(),
7676
]
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""similarity search
2+
3+
Revision ID: a9dd3c62b7b0
4+
Revises: 0a3132f436fb
5+
Create Date: 2024-02-19 16:53:45.397975
6+
7+
"""
8+
9+
# revision identifiers, used by Alembic.
10+
revision = "a9dd3c62b7b0"
11+
down_revision = "0a3132f436fb"
12+
13+
import sqlalchemy as sa
14+
from alembic import op
15+
from pgvector.sqlalchemy import Vector # type:ignore
16+
17+
18+
def upgrade():
19+
# ### commands auto generated by Alembic - please adjust! ###
20+
op.create_table(
21+
"obj_cnn_features_vector",
22+
sa.Column("objcnnid", sa.BIGINT(), nullable=False),
23+
sa.Column("features", Vector(dim=50), nullable=True),
24+
sa.ForeignKeyConstraint(["objcnnid"], ["obj_head.objid"], ondelete="CASCADE"),
25+
sa.PrimaryKeyConstraint("objcnnid"),
26+
)
27+
op.execute(
28+
"""
29+
INSERT INTO obj_cnn_features_vector (objcnnid, features)
30+
SELECT objcnnid, ARRAY[cnn01, cnn02, cnn03, cnn04, cnn05, cnn06, cnn07, cnn08, cnn09, cnn10,
31+
cnn11, cnn12, cnn13, cnn14, cnn15, cnn16, cnn17, cnn18, cnn19, cnn20,
32+
cnn21, cnn22, cnn23, cnn24, cnn25, cnn26, cnn27, cnn28, cnn29, cnn30,
33+
cnn31, cnn32, cnn33, cnn34, cnn35, cnn36, cnn37, cnn38, cnn39, cnn40,
34+
cnn41, cnn42, cnn43, cnn44, cnn45, cnn46, cnn47, cnn48, cnn49, cnn50]::vector
35+
FROM obj_cnn_features
36+
"""
37+
)
38+
op.execute(
39+
"""
40+
GRANT SELECT ON obj_cnn_features_vector TO readerole
41+
"""
42+
)
43+
op.drop_table("obj_cnn_features")
44+
# ### end Alembic commands ###
45+
46+
47+
def downgrade():
48+
# ### commands auto generated by Alembic - please adjust! ###
49+
op.drop_table("obj_cnn_features_vector")
50+
# ### end Alembic commands ###

0 commit comments

Comments
 (0)