11# -*- coding: utf-8 -*-
22# This file is part of Ecotaxa, see license.md in the application root directory for license informations.
3- # Copyright (C) 2015-2021 Picheral, Colin, Irisson (UPMC-CNRS)
3+ # Copyright (C) 2015-2024 Picheral, Colin, Irisson (UPMC-CNRS), Amblard (LOVNOWER )
44#
55
66#
1313from numpy import ndarray
1414
1515from DB .Acquisition import Acquisition
16- from DB .CNNFeature import DEEP_FEATURES , ObjectCNNFeaturesBean , ObjectCNNFeature
16+ from DB .CNNFeatureVector import N_DEEP_FEATURES , ObjectCNNFeaturesVectorBean , ObjectCNNFeatureVector
1717from DB .Image import Image
1818from DB .Object import ObjectHeader , ObjectIDT
1919from DB .Project import ProjectIDT
@@ -33,7 +33,7 @@ class DeepFeatures(object):
3333
3434 OTOH, it can also _generate_ features, using another class of machine learning algorithm: CNN
3535 @see https://en.wikipedia.org/wiki/Convolutional_neural_network
36- These other features are stored in a dedicated DB table @see ObjectCNNFeature .
36+ These other features are stored in a dedicated DB table @see ObjectCNNFeatureVector .
3737 """
3838
3939 SAVE_EVERY : ClassVar = 500
@@ -53,8 +53,8 @@ def delete_all(session: Session, proj_id: ProjectIDT) -> int:
5353 Sample .sampleid == Acquisition .acq_sample_id , Sample .projid == proj_id
5454 ),
5555 )
56- qry = session .query (ObjectCNNFeature )
57- qry = qry .filter (ObjectCNNFeature .objcnnid .in_ (sub_qry ))
56+ qry = session .query (ObjectCNNFeatureVector )
57+ qry = qry .filter (ObjectCNNFeatureVector .objcnnid .in_ (sub_qry ))
5858 nb_deleted = qry .delete (synchronize_session = False )
5959 return nb_deleted
6060
@@ -72,9 +72,9 @@ def find_missing(session: Session, proj_id: ProjectIDT) -> Dict[ObjectIDT, str]:
7272 ),
7373 )
7474 qry = qry .outerjoin (Image ) # For detecting missing images
75- qry = qry .outerjoin (ObjectCNNFeature ) # For detecting missing features
75+ qry = qry .outerjoin (( ObjectCNNFeatureVector ) ) # For detecting missing features
7676 # noinspection PyComparisonWithNone
77- qry = qry .filter (ObjectCNNFeature .objcnnid == None ) # SQLAlchemy
77+ qry = qry .filter (ObjectCNNFeatureVector .objcnnid == None ) # SQLAlchemy
7878 qry = qry .order_by (ObjectHeader .objid , Image .imgrank )
7979 ret = {}
8080 for a_res in session .execute (qry ):
@@ -97,7 +97,7 @@ def save(cls, session: Session, features: Any) -> int:
9797 # for a_rec in features.to_records(index=True): # This is nice and can produce tuple()
9898 # but I found no way to feed them into DBWriter without going low-level.
9999 for obj_id , row in features .iterrows ():
100- bean = ObjectCNNFeaturesBean (obj_id , row )
100+ bean = ObjectCNNFeaturesVectorBean (obj_id , row )
101101 writer .add_cnn_features_with_pk (bean )
102102 nb_rows += 1
103103 if nb_rows % cls .SAVE_EVERY == 0 :
@@ -112,10 +112,10 @@ def read_for_objects(
112112 """
113113 Read CNN lines AKA features, in order, for given object_ids
114114 """
115- fk_to_objid = ObjectCNNFeature .objcnnid .name
115+ fk_to_objid = ObjectCNNFeatureVector .objcnnid .name
116116 sql = "WITH ordr (seq, objid) AS (select * from UNNEST(:seq, :oids)) "
117- sql += "SELECT " + "," . join ( DEEP_FEATURES )
118- sql += " FROM " + ObjectCNNFeature .__tablename__
117+ sql += "SELECT " + " features "
118+ sql += " FROM " + ObjectCNNFeatureVector .__tablename__
119119 sql += " JOIN ordr ON " + fk_to_objid + " = ordr.objid "
120120 sql += " ORDER BY ordr.seq "
121121 params = {"seq" : list (range (len (oid_lst ))), "oids" : oid_lst }
@@ -128,12 +128,13 @@ def np_read_for_objects(cls, session: Session, oid_lst: List[int]) -> ndarray:
128128 Read CNN lines AKA features, in order, for given object_ids, into a NumPy array
129129 """
130130 res = cls .read_for_objects (session , oid_lst )
131- ret = np .ndarray (shape = (len (oid_lst ), len ( res . keys ()) ), dtype = np .float32 )
131+ ret = np .ndarray (shape = (len (oid_lst ), N_DEEP_FEATURES ), dtype = np .float32 )
132132 ndx = 0
133133 for a_row in res :
134- ret [ndx ] = a_row
134+ all_feats = a_row ["features" ].strip ("[]" ).split ("," ) if type (a_row ["features" ]) == str else a_row ["features" ]
135+ ret [ndx ] = [float (x ) for x in all_feats ]
135136 ndx += 1
136137 assert ndx == len (
137138 oid_lst
138- ), "No enough CNN features in DB: expected %d read %d" % (len (oid_lst ), ndx )
139+ ), "Not enough CNN features in DB: expected %d read %d" % (len (oid_lst ), ndx )
139140 return ret
0 commit comments