1313import grpc
1414from google .protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
1515
16- from immudb import header_manipulator_client_interceptor
16+ from immudb import grpcutils
17+ from immudb .grpc .schema_pb2 import OpenSessionResponse
1718from immudb .handler import (batchGet , batchSet , changePassword , changePermission , createUser ,
1819 currentRoot , createDatabase , databaseList , deleteKeys , useDatabase ,
1920 get , listUsers , sqldescribe , verifiedGet , verifiedSet , setValue , history ,
2021 scan , reference , verifiedreference , zadd , verifiedzadd ,
2122 zscan , healthcheck , health , txbyid , verifiedtxbyid , sqlexec , sqlquery ,
22- listtables , execAll )
23+ listtables , execAll , transaction )
2324from immudb .rootService import *
2425from immudb .grpc import schema_pb2_grpc
2526import warnings
2627import ecdsa
2728from immudb .datatypes import DeleteKeysRequest
2829from immudb .embedded .store import KVMetadata
30+ import threading
31+ import queue
2932
3033import datetime
3134
@@ -36,13 +39,15 @@ def __init__(self, immudUrl=None, rs: RootService = None, publicKeyFile: str = N
3639 if immudUrl is None :
3740 immudUrl = "localhost:3322"
3841 self .channel = grpc .insecure_channel (immudUrl )
39- self .__stub = schema_pb2_grpc . ImmuServiceStub ( self . channel )
42+ self ._resetStub ( )
4043 if rs is None :
4144 self .__rs = RootService ()
4245 else :
4346 self .__rs = rs
4447 self .__url = immudUrl
4548 self .loadKey (publicKeyFile )
49+ self .__login_response = None
50+ self ._session_response = None
4651
4752 def loadKey (self , kfile : str ):
4853 if kfile is None :
@@ -58,22 +63,29 @@ def shutdown(self):
5863 self .intercept_channel = None
5964 self .__rs = None
6065
66+ def set_session_id_interceptor (self , openSessionResponse ):
67+ sessionId = openSessionResponse .sessionID
68+ self .headersInterceptors = [
69+ grpcutils .header_adder_interceptor ('sessionid' , sessionId )]
70+ return self .get_intercepted_stub ()
71+
6172 def set_token_header_interceptor (self , response ):
6273 try :
6374 token = response .token
6475 except AttributeError :
6576 token = response .reply .token
66- self .header_interceptor = \
67- header_manipulator_client_interceptor .header_adder_interceptor (
77+ self .headersInterceptors = [
78+ grpcutils .header_adder_interceptor (
6879 'authorization' , "Bearer " + token
6980 )
70- try :
71- self .intercept_channel = grpc .intercept_channel (
72- self .channel , self .header_interceptor )
73- except ValueError as e :
74- raise Exception (
75- "Attempted to login on termninated client, channel has been shutdown" ) from e
76- return schema_pb2_grpc .ImmuServiceStub (self .intercept_channel )
81+ ]
82+ return self .get_intercepted_stub ()
83+
84+ def get_intercepted_stub (self ):
85+ intercepted , newStub = grpcutils .get_intercepted_stub (
86+ self .channel , self .headersInterceptors )
87+ self .intercept_channel = intercepted
88+ return newStub
7789
7890 @property
7991 def stub (self ):
@@ -88,12 +100,17 @@ def healthCheck(self):
88100 return healthcheck .call (self .__stub , self .__rs )
89101
90102 # Not implemented: connect
103+ def _convertToBytes (self , what ):
104+ if (type (what ) != bytes ):
105+ return bytes (what , encoding = 'utf-8' )
106+ return what
91107
92108 def login (self , username , password , database = b"defaultdb" ):
93- req = schema_pb2_grpc .schema__pb2 .LoginRequest (user = bytes (
94- username , encoding = 'utf-8' ), password = bytes (
95- password , encoding = 'utf-8'
96- ))
109+ convertedUsername = self ._convertToBytes (username )
110+ convertedPassword = self ._convertToBytes (password )
111+ convertedDatabase = self ._convertToBytes (database )
112+ req = schema_pb2_grpc .schema__pb2 .LoginRequest (
113+ user = convertedUsername , password = convertedPassword )
97114 try :
98115 self .__login_response = schema_pb2_grpc .schema__pb2 .LoginResponse = \
99116 self .__stub .Login (
@@ -105,7 +122,8 @@ def login(self, username, password, database=b"defaultdb"):
105122
106123 self .__stub = self .set_token_header_interceptor (self .__login_response )
107124 # Select database, modifying stub function accordingly
108- request = schema_pb2_grpc .schema__pb2 .Database (databaseName = database )
125+ request = schema_pb2_grpc .schema__pb2 .Database (
126+ databaseName = convertedDatabase )
109127 resp = self .__stub .UseDatabase (request )
110128 self .__stub = self .set_token_header_interceptor (resp )
111129
@@ -115,9 +133,62 @@ def login(self, username, password, database=b"defaultdb"):
115133 def logout (self ):
116134 self .__stub .Logout (google_dot_protobuf_dot_empty__pb2 .Empty ())
117135 self .__login_response = None
136+ self ._resetStub ()
137+
138+ def _resetStub (self ):
139+ self .headersInterceptors = []
140+ self .__stub = schema_pb2_grpc .ImmuServiceStub (self .channel )
118141
119- # Not implemented: openSession
120- # Not implemented: closeSession
142+ def keepAlive (self ):
143+ self .__stub .KeepAlive (google_dot_protobuf_dot_empty__pb2 .Empty ())
144+
145+ def openManagedSession (self , username , password , database = b"defaultdb" , keepAliveInterval = 60 ):
146+ class ManagedSession :
147+ def __init__ (this , keepAliveInterval ):
148+ this .keepAliveInterval = keepAliveInterval
149+ this .keepAliveStarted = False
150+ this .keepAliveProcess = None
151+ this .queue = queue .Queue ()
152+
153+ def manage (this ):
154+ while this .keepAliveStarted :
155+ try :
156+ what = this .queue .get (True , this .keepAliveInterval )
157+ except queue .Empty :
158+ self .keepAlive ()
159+
160+ def __enter__ (this ):
161+ interface = self .openSession (username , password , database )
162+ this .keepAliveStarted = True
163+ this .keepAliveProcess = threading .Thread (target = this .manage )
164+ this .keepAliveProcess .start ()
165+ return interface
166+
167+ def __exit__ (this , type , value , traceback ):
168+ this .keepAliveStarted = False
169+ this .queue .put (b'0' )
170+ self .closeSession ()
171+
172+ return ManagedSession (keepAliveInterval )
173+
174+ def openSession (self , username , password , database = b"defaultdb" ):
175+ convertedUsername = self ._convertToBytes (username )
176+ convertedPassword = self ._convertToBytes (password )
177+ convertedDatabase = self ._convertToBytes (database )
178+ req = schema_pb2_grpc .schema__pb2 .OpenSessionRequest (
179+ username = convertedUsername ,
180+ password = convertedPassword ,
181+ databaseName = convertedDatabase
182+ )
183+ self ._session_response = schema_pb2_grpc .schema__pb2 .OpenSessionResponse = self .__stub .OpenSession (
184+ req )
185+ self .__stub = self .set_session_id_interceptor (self ._session_response )
186+ return transaction .Tx (self .__stub , self ._session_response , self .channel )
187+
188+ def closeSession (self ):
189+ self .__stub .CloseSession (google_dot_protobuf_dot_empty__pb2 .Empty ())
190+ self ._session_response = None
191+ self ._resetStub ()
121192
122193 def createUser (self , user , password , permission , database ):
123194 request = schema_pb2_grpc .schema__pb2 .CreateUserRequest (
@@ -213,7 +284,7 @@ def verifiedGet(self, key: bytes):
213284 return verifiedGet .call (self .__stub , self .__rs , key , verifying_key = self .__vk )
214285
215286 def verifiedGetSince (self , key : bytes , sinceTx : int ):
216- return verifiedGet .call (self .__stub , self .__rs , key , sinceTx , self .__vk )
287+ return verifiedGet .call (self .__stub , self .__rs , key , sinceTx = sinceTx , verifying_key = self .__vk )
217288
218289 def verifiedGetAt (self , key : bytes , atTx : int ):
219290 return verifiedGet .call (self .__stub , self .__rs , key , atTx , self .__vk )
@@ -299,7 +370,7 @@ def sqlExec(self, stmt, params={}, noWait=False):
299370
300371 return sqlexec .call (self .__stub , self .__rs , stmt , params , noWait )
301372
302- def sqlQuery (self , query , params = {}):
373+ def sqlQuery (self , query , params = {}, columnNameMode = constants . COLUMN_NAME_MODE_NONE ):
303374 """Queries the database using SQL
304375 Args:
305376 query: a query in immudb SQL dialect.
@@ -310,7 +381,7 @@ def sqlQuery(self, query, params={}):
310381
311382 ['table1', 'table2']
312383 """
313- return sqlquery .call (self .__stub , self .__rs , query , params )
384+ return sqlquery .call (self .__stub , self .__rs , query , params , columnNameMode )
314385
315386 def listTables (self ):
316387 """List all tables in the current database
@@ -326,7 +397,6 @@ def describeTable(self, table):
326397 return sqldescribe .call (self .__stub , self .__rs , table )
327398
328399 # Not implemented: verifyRow
329- # Not implemented: newTx
330400
331401# deprecated
332402 def databaseCreate (self , dbName : bytes ):
@@ -360,7 +430,6 @@ def safeSet(self, key: bytes, value: bytes): # deprecated
360430
361431# immudb-py only
362432
363-
364433 def getAllValues (self , keys : list ): # immudb-py only
365434 resp = batchGet .call (self .__stub , self .__rs , keys )
366435 return resp
0 commit comments