 # -*- coding: utf-8 -*-
+import uuid
 from itertools import chain
-from typing import Dict
 
+import fsspec
 import pytest
 
 from pyathena.filesystem.s3 import S3File, S3FileSystem
 from tests import ENV
 from tests.pyathena.conftest import connect
 
 
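+# Register pyathena's S3FileSystem as the fsspec implementation for the "s3" and
+# "s3a" protocols, so that fsspec-style URLs used in these tests resolve to it.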
+@pytest.fixture(scope="class")
+def register_filesystem():
+    fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
+    fsspec.register_implementation("s3a", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
+
+
+@pytest.mark.usefixtures("register_filesystem")
 class TestS3FileSystem:
     def test_parse_path(self):
         actual = S3FileSystem.parse_path("s3://bucket")
@@ -109,35 +117,34 @@ def test_parse_path_invalid(self):
             S3FileSystem.parse_path("s3a://bucket/path/to/obj?foo=bar")
 
     @pytest.fixture(scope="class")
-    def fs(self) -> Dict[str, S3FileSystem]:
-        fs = {
-            "default": S3FileSystem(connect()),
-            "small_batches": S3FileSystem(connect(), default_block_size=3),
-        }
-        return fs
+    def fs(self, request):
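+        # Indirect parametrization: request.param carries keyword arguments for the
+        # S3FileSystem constructor (e.g. default_block_size); default to no overrides.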
+        if not hasattr(request, "param"):
+            setattr(request, "param", {})
+        return S3FileSystem(connect(), **request.param)
 
     @pytest.mark.parametrize(
-        ["start", "end", "batch_mode", "target_data"],
+        ["fs", "start", "end", "target_data"],
         list(
             chain(
                 *[
                     [
-                        (0, 5, x, b"01234"),
-                        (2, 7, x, b"23456"),
-                        (0, 10, x, b"0123456789"),
+                        ({"default_block_size": x}, 0, 5, b"01234"),
+                        ({"default_block_size": x}, 2, 7, b"23456"),
+                        ({"default_block_size": x}, 0, 10, b"0123456789"),
                     ]
-                    for x in ("default", "small_batches")
+                    for x in (S3FileSystem.DEFAULT_BLOCK_SIZE, 3)
                 ]
             )
         ),
+        indirect=["fs"],
     )
-    def test_read(self, fs, start, end, batch_mode, target_data):
+    def test_read(self, fs, start, end, target_data):
         # lowest level access: use _get_object
-        data = fs[batch_mode]._get_object(
+        data = fs._get_object(
             ENV.s3_staging_bucket, ENV.s3_filesystem_test_file_key, ranges=(start, end)
         )
         assert data == (start, target_data), data
-        with fs[batch_mode].open(
+        with fs.open(
             f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}", "rb"
         ) as file:
             # mid-level access: use _fetch_range
@@ -148,7 +155,84 @@ def test_read(self, fs, start, end, batch_mode, target_data):
             data = file.read(end - start)
             assert data == target_data, data
 
-    def test_compatibility_with_s3fs(self):
+    @pytest.mark.parametrize(
+        ["base", "exp"],
+        [
+            (1, 2**10),
+            (10, 2**10),
+            (100, 2**10),
+            (1, 2**20),
+            (10, 2**20),
+            (100, 2**20),
+            (1000, 2**20),
+        ],
+    )
+    def test_write(self, fs, base, exp):
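+        # Payloads range from 1 KiB (1 * 2**10) to 1000 MiB (1000 * 2**20); the larger
+        # sizes should be big enough to exercise multipart uploads.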
+        data = b"a" * (base * exp)
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
+        with fs.open(path, "wb") as f:
+            f.write(data)
+        with fs.open(path, "rb") as f:
+            actual = f.read()
+        assert len(actual) == len(data)
+        assert actual == data
+
+    @pytest.mark.parametrize(
+        ["base", "exp"],
+        [
+            (1, 2**10),
+            (10, 2**10),
+            (100, 2**10),
+            (1, 2**20),
+            (10, 2**20),
+            (100, 2**20),
+            (1000, 2**20),
+        ],
+    )
+    def test_append(self, fs, base, exp):
+        # TODO: Check the metadata is kept.
+        data = b"a" * (base * exp)
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
+        with fs.open(path, "ab") as f:
+            f.write(data)
+        extra = b"extra"
+        with fs.open(path, "ab") as f:
+            f.write(extra)
+        with fs.open(path, "rb") as f:
+            actual = f.read()
+        assert len(actual) == len(data + extra)
+        assert actual == data + extra
+
+    def test_exists(self, fs):
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}"
+        assert fs.exists(path)
+
+        not_exists_path = (
+            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
+        )
+        assert not fs.exists(not_exists_path)
+
+    def test_touch(self, fs):
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
+        assert not fs.exists(path)
+        fs.touch(path)
+        assert fs.exists(path)
+        assert fs.size(path) == 0
+
+        with fs.open(path, "wb") as f:
+            f.write(b"data")
+        assert fs.size(path) == 4
+        fs.touch(path, truncate=True)
+        assert fs.size(path) == 0
+
+        with fs.open(path, "wb") as f:
+            f.write(b"data")
+        assert fs.size(path) == 4
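+        # Touching an existing object with truncate=False must not overwrite it and is
+        # expected to raise ValueError.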
+        with pytest.raises(ValueError):
+            fs.touch(path, truncate=False)
+        assert fs.size(path) == 4
+
+    def test_pandas_read_csv(self):
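+        # pandas resolves "s3://" URLs through fsspec, which the register_filesystem
+        # fixture points at pyathena's S3FileSystem.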
         import pandas
 
         df = pandas.read_csv(
@@ -158,6 +242,16 @@ def test_compatibility_with_s3fs(self):
         )
         assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)]
 
+    def test_pandas_write_csv(self):
+        import pandas
+
+        df = pandas.DataFrame({"a": [1], "b": [2]})
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.csv"
+        df.to_csv(path, index=False)
+
+        actual = pandas.read_csv(path)
+        pandas.testing.assert_frame_equal(df, actual)
+
 
 class TestS3File:
     @pytest.mark.parametrize(