-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtst_atts.py
147 lines (125 loc) · 4.84 KB
/
tst_atts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#
# Copyright (C) 2024, Northwestern University and Argonne National Laboratory
# See COPYRIGHT notice in top-level directory.
#
import math
import subprocess
import sys
import unittest
import os
import tempfile
import warnings
import math
import numpy as np
from collections import OrderedDict
from numpy.random.mtrand import uniform
from utils import validate_nc_file
import io
from mpi4py import MPI
import pnetcdf
# test attribute creation
FILE_NAME = 'tst_atts.nc'
VAR_NAME="dummy_var"
DIM1_NAME="x"
DIM1_LEN=2
DIM2_NAME="y"
DIM2_LEN=3
DIM3_NAME="z"
DIM3_LEN=25
STRATT = 'string attribute'
EMPTYSTRATT = ''
INTATT = np.int32(1)
FLOATATT = math.pi
SEQATT = np.arange(10).astype('i4')
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
ATTDICT = {'stratt':STRATT,'floatatt':FLOATATT,'seqatt':SEQATT,
'emptystratt':EMPTYSTRATT,'intatt':INTATT}
NUM_TESTS = 3
file_formats = [fmt for fmt in ['NC_64BIT_DATA', 'NC_64BIT_OFFSET', None] for i in range(NUM_TESTS)]
class AttrTestCase(unittest.TestCase):
def setUp(self):
if (len(sys.argv) == 2) and os.path.isdir(sys.argv[1]):
self.file_path = os.path.join(sys.argv[1], FILE_NAME)
else:
self.file_path = FILE_NAME
self._file_format = file_formats.pop(0)
with pnetcdf.File(self.file_path,'w', format = self._file_format) as f:
# try to set a dataset attribute with one of the reserved names.
f.put_att('file_format','netcdf5_format')
# test attribute renaming
f.stratt_tmp = STRATT
f.rename_att('stratt_tmp','stratt')
f.emptystratt = EMPTYSTRATT
f.floatatt = FLOATATT
f.intatt = INTATT
f.seqatt = SEQATT
# sequences of strings converted to a single string.
f.def_dim(DIM1_NAME, DIM1_LEN)
f.def_dim(DIM2_NAME, DIM2_LEN)
f.def_dim(DIM3_NAME, DIM3_LEN)
v = f.def_var(VAR_NAME, pnetcdf.NC_DOUBLE, (DIM1_NAME,DIM2_NAME,DIM3_NAME))
# try to set a variable attribute with one of the reserved names.
v.put_att('ndim','three')
v.stratt_tmp = STRATT
v.rename_att('stratt_tmp','stratt')
v.emptystratt = EMPTYSTRATT
v.intatt = INTATT
v.floatatt = FLOATATT
v.seqatt = SEQATT
# try set the attribute "_FillValue" to set the fill value of netCDF fill value
v._FillValue = -999.
f.foo = np.array('bar','S')
f.foo = np.array('bar','U')
assert validate_nc_file(os.environ.get('PNETCDF_DIR'), self.file_path) == 0 if os.environ.get('PNETCDF_DIR') is not None else True
def tearDown(self):
# Remove the temporary files
if (rank == 0) and not((len(sys.argv) == 2) and os.path.isdir(sys.argv[1])):
os.remove(self.file_path)
def test_file_attr_dict_(self):
with pnetcdf.File(self.file_path, 'r') as f:
# check __dict__ method for accessing all netCDF attributes.
for key,val in ATTDICT.items():
if type(val) == np.ndarray:
assert f.__dict__[key].tolist() == val.tolist()
else:
assert f.__dict__[key] == val
def test_attr_access(self):
with pnetcdf.File(self.file_path, 'r') as f:
v = f.variables[VAR_NAME]
# check accessing individual attributes.
assert f.intatt == INTATT
assert f.floatatt == FLOATATT
assert f.stratt == STRATT
assert f.emptystratt == EMPTYSTRATT
# check accessing variable individual attributes.
assert v.intatt == INTATT
assert v.floatatt == FLOATATT
assert v.stratt == STRATT
assert v.seqatt.tolist() == SEQATT.tolist()
assert v.get_att('ndim') == 'three'
assert v._FillValue == -999.
def test_var_attr_dict_(self):
with pnetcdf.File(self.file_path, 'r') as f:
v = f.variables[VAR_NAME]
# variable attributes.
# check __dict__ method for accessing all netCDF attributes.
for key,val in ATTDICT.items():
if type(val) == np.ndarray:
assert v.__dict__[key].tolist() == val.tolist()
else:
assert v.__dict__[key] == val
if __name__ == '__main__':
suite = unittest.TestSuite()
for i in range(len(file_formats) // NUM_TESTS):
suite.addTest(AttrTestCase("test_file_attr_dict_"))
suite.addTest(AttrTestCase("test_attr_access"))
suite.addTest(AttrTestCase("test_var_attr_dict_"))
runner = unittest.TextTestRunner()
output = io.StringIO()
runner = unittest.TextTestRunner(stream=output)
result = runner.run(suite)
if not result.wasSuccessful():
print(output.getvalue())
sys.exit(1)