1+ import atexit
12import os
23import shutil
3- import atexit
4- import bencoder
54from hashlib import sha256
5+ from typing import Any
6+
67from argon2 import low_level
7- from jmbase import aes_cbc_encrypt , aes_cbc_decrypt
8+ from fastbencode import bdecode , bencode_utf8
9+
10+ from jmbase import aes_cbc_decrypt , aes_cbc_encrypt
11+
812from .support import get_random_bytes
913
1014
1115class Argon2Hash (object ):
12- def __init__ (self , password , salt = None , hash_len = 32 , salt_len = 16 ,
13- time_cost = 500 , memory_cost = 1000 , parallelism = 4 ,
14- argon2_type = low_level .Type .I , version = 19 ):
16+ def __init__ (
17+ self ,
18+ password ,
19+ salt = None ,
20+ hash_len = 32 ,
21+ salt_len = 16 ,
22+ time_cost = 500 ,
23+ memory_cost = 1000 ,
24+ parallelism = 4 ,
25+ argon2_type = low_level .Type .I ,
26+ version = 19 ,
27+ ):
1528 """
1629 args:
1730 password: password as bytes
@@ -30,11 +43,12 @@ def __init__(self, password, salt=None, hash_len=32, salt_len=16,
3043 'parallelism' : parallelism ,
3144 'hash_len' : hash_len ,
3245 'type' : argon2_type ,
33- 'version' : version
46+ 'version' : version ,
3447 }
3548 self .salt = salt if salt is not None else get_random_bytes (salt_len )
36- self .hash = low_level .hash_secret_raw (password , self .salt ,
37- ** self .settings )
49+ self .hash = low_level .hash_secret_raw (
50+ password , self .salt , ** self .settings
51+ )
3852
3953
4054class StorageError (Exception ):
@@ -62,8 +76,9 @@ class Storage(object):
6276
6377 KDF: argon2, ENC: AES-256-CBC
6478 """
79+
6580 MAGIC_UNENC = b'JMWALLET'
66- MAGIC_ENC = b'JMENCWLT'
81+ MAGIC_ENC = b'JMENCWLT'
6782 MAGIC_DETECT_ENC = b'JMWALLET'
6883
6984 ENC_KEY_BYTES = 32 # AES-256
@@ -116,7 +131,10 @@ def was_changed(self):
116131 return self ._data_checksum != self ._get_data_checksum ()
117132
118133 def check_password (self , password ):
119- return self ._hash .hash == self ._hash_password (password , self ._hash .salt ).hash
134+ return (
135+ self ._hash .hash
136+ == self ._hash_password (password , self ._hash .salt ).hash
137+ )
120138
121139 def change_password (self , password ):
122140 if self .read_only :
@@ -128,7 +146,7 @@ def save(self):
128146 """
129147 Write file to disk if data was modified
130148 """
131- #if not self.was_changed():
149+ # if not self.was_changed():
132150 # return
133151 if self .read_only :
134152 raise StorageError ("Read-only storage cannot be saved." )
@@ -149,7 +167,7 @@ def _get_file_magic(cls, path):
149167 return fh .read (len (cls .MAGIC_ENC ))
150168
151169 def _get_data_checksum (self ):
152- if self .data is None : #pragma: no cover
170+ if self .data is None : # pragma: no cover
153171 return None
154172 return sha256 (self ._serialize (self .data )).digest ()
155173
@@ -167,7 +185,8 @@ def _set_hash(self, password):
167185 self ._hash = self ._hash_password (password )
168186
169187 def _save_file (self ):
170- assert self .read_only == False
188+ if self .read_only :
189+ raise StorageError ("Read-only storage cannot be saved." )
171190 data = self ._serialize (self .data )
172191 enc_data = self ._encrypt_file (data )
173192
@@ -181,13 +200,17 @@ def _load_file(self, password):
181200 magic = data [:8 ]
182201
183202 if magic not in (self .MAGIC_ENC , self .MAGIC_UNENC ):
184- raise StorageError ("File does not appear to be a joinmarket wallet." )
203+ raise StorageError (
204+ "File does not appear to be a joinmarket wallet."
205+ )
185206
186207 data = data [8 :]
187208
188209 if magic == self .MAGIC_ENC :
189210 if password is None :
190- raise RetryableStorageError ("Password required to open wallet." )
211+ raise RetryableStorageError (
212+ "Password required to open wallet."
213+ )
191214 data = self ._decrypt_file (password , data )
192215 else :
193216 assert magic == self .MAGIC_UNENC
@@ -211,7 +234,7 @@ def _write_file(self, data):
211234 shutil .copystat (self .path , tmpfile )
212235 fh .write (data )
213236
214- #FIXME: behaviour with symlinks might be weird
237+ # FIXME: behaviour with symlinks might be weird
215238 shutil .move (tmpfile , self .path )
216239
217240 def _read_file (self ):
@@ -223,12 +246,12 @@ def get_location(self):
223246 return self .path
224247
225248 @staticmethod
226- def _serialize (data ) :
227- return bencoder . bencode (data )
249+ def _serialize (data : Any ) -> bytes :
250+ return bencode_utf8 (data )
228251
229252 @staticmethod
230- def _deserialize (data ) :
231- return bencoder . bdecode (data )
253+ def _deserialize (data : bytes ) -> Any :
254+ return bdecode (data )
232255
233256 def _encrypt_file (self , data ):
234257 if not self .is_encrypted ():
@@ -237,7 +260,7 @@ def _encrypt_file(self, data):
237260 iv = get_random_bytes (16 )
238261 container = {
239262 b'enc' : {b'salt' : self ._hash .salt , b'iv' : iv },
240- b'data' : self ._encrypt (data , iv )
263+ b'data' : self ._encrypt (data , iv ),
241264 }
242265 return self ._serialize (container )
243266
@@ -253,8 +276,9 @@ def _decrypt_file(self, password, data):
253276 return self ._decrypt (container [b'data' ], container [b'enc' ][b'iv' ])
254277
255278 def _encrypt (self , data : bytes , iv : bytes ) -> bytes :
256- return aes_cbc_encrypt (self ._hash .hash ,
257- self .MAGIC_DETECT_ENC + data , iv )
279+ return aes_cbc_encrypt (
280+ self ._hash .hash , self .MAGIC_DETECT_ENC + data , iv
281+ )
258282
259283 def _decrypt (self , data : bytes , iv : bytes ) -> bytes :
260284 try :
@@ -265,12 +289,16 @@ def _decrypt(self, data: bytes, iv: bytes) -> bytes:
265289
266290 if not dec_data .startswith (self .MAGIC_DETECT_ENC ):
267291 raise StoragePasswordError ("Wrong password." )
268- return dec_data [len (self .MAGIC_DETECT_ENC ):]
292+ return dec_data [len (self .MAGIC_DETECT_ENC ) :]
269293
270294 @classmethod
271295 def _hash_password (cls , password , salt = None ):
272- return Argon2Hash (password , salt ,
273- hash_len = cls .ENC_KEY_BYTES , salt_len = cls .SALT_LENGTH )
296+ return Argon2Hash (
297+ password ,
298+ salt ,
299+ hash_len = cls .ENC_KEY_BYTES ,
300+ salt_len = cls .SALT_LENGTH ,
301+ )
274302
275303 @staticmethod
276304 def _get_lock_filename (path : str ) -> str :
@@ -296,8 +324,9 @@ def verify_lock(cls, path: str):
296324 raise RetryableStorageError (
297325 "File is currently in use (locked by pid {}). "
298326 "If this is a leftover from a crashed instance "
299- "you need to remove the lock file `{}` manually." .
300- format (locked_by_pid , cls ._get_lock_filename (path ))
327+ "you need to remove the lock file `{}` manually." .format (
328+ locked_by_pid , cls ._get_lock_filename (path )
329+ )
301330 )
302331
303332 def _create_lock (self ):
0 commit comments