|
| 1 | +from copy import deepcopy |
| 2 | +from math import ceil, floor |
| 3 | +from struct import pack, unpack |
| 4 | +import calendar |
| 5 | +import datetime |
| 6 | +import numpy as np |
| 7 | +import os |
| 8 | +import re |
| 9 | +import warnings |
| 10 | + |
def padtrim(buf, num):
    """Return *buf* as a string of exactly *num* characters.

    The value is converted with ``str()`` first, so numbers can be passed
    directly and the result is always a string.  Text shorter than *num* is
    padded on the right with spaces, longer text is cut off at the end.  A
    *num* of zero or less yields an empty string.
    """
    width = max(num, 0)
    text = str(buf)
    # the slice trims over-long input, ljust pads short input with spaces
    return text[:width].ljust(width)
| 19 | + |
| 20 | +#################################################################################################### |
| 21 | +# the EDF header is represented as a tuple of (meas_info, chan_info) |
| 22 | +# meas_info should have ['record_length', 'magic', 'hour', 'subject_id', 'recording_id', 'n_records', 'month', 'subtype', 'second', 'nchan', 'data_size', 'data_offset', 'lowpass', 'year', 'highpass', 'day', 'minute'] |
| 23 | +# chan_info should have ['physical_min', 'transducers', 'physical_max', 'digital_max', 'ch_names', 'n_samps', 'units', 'digital_min'] |
| 24 | +#################################################################################################### |
| 25 | + |
class EDFWriter():
    """Write an EDF file one data record (block) at a time.

    Typical use: ``open()`` a file, call ``writeHeader()`` once, call
    ``writeBlock()`` for every data record and finally ``close()``, which
    stores the number of records that were written in the header.
    """

    # byte offset of the 8-character "number of data records" header field
    _N_RECORDS_OFFSET = 236

    def __init__(self, fname=None):
        """Create a writer and, when *fname* is given, open that file."""
        self._reset()
        if fname:
            self.open(fname)

    def _reset(self):
        """Forget the current file and all header-derived state."""
        self.fname = None
        self.meas_info = None
        self.chan_info = None
        self.calibrate = None
        self.offset = None
        self.n_records = 0

    @staticmethod
    def _field(value, width):
        """Return *value* as ASCII bytes, space-padded or trimmed to *width*.

        The files are opened in binary mode, so header fields have to be
        bytes.  Characters outside ASCII are replaced by '?'.
        """
        return str(value)[:width].ljust(width).encode('ascii', 'replace')

    @staticmethod
    def _number(value):
        """Format a numeric limit, dropping a redundant trailing '.0'."""
        value = float(value)
        return str(int(value)) if value.is_integer() else str(value)

    def open(self, fname):
        """Create (or truncate) *fname* and remember it as the output file."""
        with open(fname, 'wb'):
            pass
        self.fname = fname

    def close(self):
        """Store the final record count in the header and reset the writer.

        The count is patched in place at byte 236.  Nothing is written when
        no file is open or when the header was never written.
        """
        if self.fname is not None and self.meas_info is not None:
            with open(self.fname, 'r+b') as fid:
                fid.seek(self._N_RECORDS_OFFSET)
                fid.write(self._field(self.n_records, 8))
        self._reset()

    def writeHeader(self, header):
        """Write the EDF header, replacing any previous file content.

        *header* is a ``(meas_info, chan_info)`` tuple of dictionaries.
        ``meas_info`` must contain 'nchan', 'record_length', 'day', 'month',
        'year', 'hour', 'minute' and 'second'; 'subject_id', 'recording_id'
        and 'subtype' are optional.  ``chan_info`` must contain
        'physical_min', 'physical_max', 'digital_min', 'digital_max' and
        'n_samps'; 'ch_names', 'transducers' and 'units' are filled in with
        defaults when missing or shorter than the number of channels.

        The dictionaries passed in are not modified; the completed copies
        are kept in ``self.meas_info`` and ``self.chan_info``.
        """
        # work on copies so that the caller's dictionaries are left untouched
        meas_info = dict(header[0])
        chan_info = dict(header[1])
        nchan = meas_info['nchan']
        chans = range(nchan)

        # fill in the missing or incomplete information
        meas_info.setdefault('subject_id', '')
        meas_info.setdefault('recording_id', '')
        meas_info.setdefault('subtype', 'edf')
        defaults = (
            ('ch_names', [str(i) for i in chans]),
            ('transducers', [''] * nchan),
            ('units', [''] * nchan),
        )
        for key, fallback in defaults:
            if key not in chan_info or len(chan_info[key]) < nchan:
                chan_info[key] = fallback

        if meas_info['subtype'] in ('24BIT', 'bdf'):
            meas_info['data_size'] = 3  # 24-bit (3 byte) integers
        else:
            meas_info['data_size'] = 2  # 16-bit (2 byte) integers

        # ensure that these are all np arrays rather than lists; the limits
        # are floats so that the calibration is never an integer division
        for key in ('physical_min', 'physical_max', 'digital_min', 'digital_max'):
            chan_info[key] = np.asarray(chan_info[key], dtype=float)
        chan_info['n_samps'] = np.asarray(chan_info['n_samps'], dtype=int)
        for key in ('ch_names', 'transducers', 'units'):
            chan_info[key] = np.asarray(chan_info[key])

        date = '{:0>2d}.{:0>2d}.{:0>2d}'.format(
            meas_info['day'], meas_info['month'], meas_info['year'] % 100)
        time = '{:0>2d}.{:0>2d}.{:0>2d}'.format(
            meas_info['hour'], meas_info['minute'], meas_info['second'])

        fields = [
            self._field('0', 8),
            self._field(meas_info['subject_id'], 80),
            self._field(meas_info['recording_id'], 80),
            self._field(date, 8),
            self._field(time, 8),
            self._field(256 + 256 * nchan, 8),  # size of the header in bytes
            self._field('', 44),
            # the final n_records is inserted on byte 236 by close()
            self._field(-1, 8),
            self._field(meas_info['record_length'], 8),
            self._field(nchan, 4),
        ]
        fields += [self._field(chan_info['ch_names'][i], 16) for i in chans]
        fields += [self._field(chan_info['transducers'][i], 80) for i in chans]
        fields += [self._field(chan_info['units'][i], 8) for i in chans]
        for key in ('physical_min', 'physical_max', 'digital_min', 'digital_max'):
            fields += [self._field(self._number(chan_info[key][i]), 8) for i in chans]
        fields += [self._field('', 80)] * nchan  # prefiltering
        fields += [self._field(chan_info['n_samps'][i], 8) for i in chans]
        fields += [self._field('', 32)] * nchan  # reserved
        header_bytes = b''.join(fields)
        meas_info['data_offset'] = len(header_bytes)

        with open(self.fname, 'wb') as fid:
            fid.write(header_bytes)

        # physical = digital * calibrate + offset
        calibrate = ((chan_info['physical_max'] - chan_info['physical_min'])
                     / (chan_info['digital_max'] - chan_info['digital_min']))
        offset = chan_info['physical_min'] - calibrate * chan_info['digital_min']
        # channels with inverted limits are stored without scaling
        inverted = calibrate < 0
        calibrate[inverted] = 1
        offset[inverted] = 0

        self.meas_info = meas_info
        self.chan_info = chan_info
        self.calibrate = calibrate
        self.offset = offset

    def writeBlock(self, data):
        """Append one data record to the file.

        *data* holds one sequence of physical values per channel, each with
        the number of samples given by 'n_samps' in the header.  The input
        is not modified.  Values are converted to digital units, rounded to
        the nearest integer, limited to the range of the sample format and
        written as little-endian integers of 'data_size' bytes.

        Raises RuntimeError when the header has not been written yet and
        ValueError when a channel has the wrong number of samples.  A
        warning is issued for values outside the physical range.
        """
        if self.meas_info is None:
            raise RuntimeError('writeHeader() must be called before writeBlock()')
        chan_info = self.chan_info
        data_size = self.meas_info['data_size']

        chunks = []
        for i in range(self.meas_info['nchan']):
            phys = np.array(data[i], dtype=float)  # private floating point copy
            if len(phys) != chan_info['n_samps'][i]:
                raise ValueError('channel {} has {} samples, expected {}'.format(
                    i, len(phys), chan_info['n_samps'][i]))
            if phys.size:
                lowest = phys.min()
                highest = phys.max()
                if lowest < chan_info['physical_min'][i]:
                    warnings.warn('Value exceeds physical_min: ' + str(lowest))
                if highest > chan_info['physical_max'][i]:
                    warnings.warn('Value exceeds physical_max: ' + str(highest))

            # invert physical = digital * calibrate + offset
            digital = np.rint((phys - self.offset[i]) / self.calibrate[i])
            if data_size == 3:
                digital = np.clip(digital, -8388608, 8388607).astype('<i4')
                # keep the three low-order bytes of every 32-bit value
                chunks.append(digital.view(np.uint8).reshape(-1, 4)[:, :3].tobytes())
            else:
                digital = np.clip(digital, -32768, 32767).astype('<i2')
                chunks.append(digital.tobytes())

        with open(self.fname, 'ab') as fid:
            fid.write(b''.join(chunks))
        self.n_records += 1
| 167 | + |
| 168 | +#################################################################################################### |
| 169 | + |
class EDFReader():
    """Read the header and the data records (blocks) of an EDF file.

    After ``open()`` the header is available as ``meas_info`` (recording
    level) and ``chan_info`` (per channel), and data can be read per record
    with ``readBlock()`` or per channel with ``readSamples()``.
    """

    def __init__(self, fname=None):
        """Create a reader and, when *fname* is given, open that file."""
        self._reset()
        if fname:
            self.open(fname)

    def _reset(self):
        """Forget the current file and all header-derived state."""
        self.fname = None
        self.meas_info = None
        self.chan_info = None
        self.calibrate = None
        self.offset = None

    def open(self, fname):
        """Read the header of *fname* and return ``(meas_info, chan_info)``."""
        with open(fname, 'rb'):
            pass  # fail early when the file cannot be read
        self.fname = fname
        self.readHeader()
        return self.meas_info, self.chan_info

    def close(self):
        """Forget the current file; no file handle is kept open in between."""
        self._reset()

    def readHeader(self):
        """Parse the header of the current file.

        Returns ``(meas_info, chan_info)`` and stores both on the reader,
        together with the per-channel ``calibrate`` and ``offset`` arrays
        for which ``physical = digital * calibrate + offset``.

        Raises ValueError when the header does not end at the offset that
        it declares itself.
        """
        # the following is copied over from MNE-Python and subsequently modified
        # to more closely reflect the native EDF standard
        meas_info = {}
        chan_info = {}
        with open(self.fname, 'rb') as fid:

            def text(width):
                # one fixed-width header field without the padding
                return fid.read(width).strip().decode()

            def numbers(width, convert):
                # one fixed-width numeric field for every channel
                return np.array([convert(fid.read(width).decode()) for _ in channels])

            # the first byte of a BDF file is 0xFF, which is not valid ASCII
            # or UTF-8; latin-1 maps every byte onto a character
            meas_info['magic'] = fid.read(8).strip().decode('latin-1')
            meas_info['subject_id'] = text(80)
            meas_info['recording_id'] = text(80)

            day, month, year = [int(x) for x in re.findall(r'(\d+)', fid.read(8).decode())]
            hour, minute, second = [int(x) for x in re.findall(r'(\d+)', fid.read(8).decode())]
            meas_info['day'] = day
            meas_info['month'] = month
            meas_info['year'] = year
            meas_info['hour'] = hour
            meas_info['minute'] = minute
            meas_info['second'] = second

            meas_info['data_offset'] = int(fid.read(8).decode())

            subtype = text(44)[:5]
            if len(subtype) > 0:
                meas_info['subtype'] = subtype
            else:
                # fall back on the file extension
                meas_info['subtype'] = os.path.splitext(self.fname)[1][1:].lower()

            if meas_info['subtype'] in ('24BIT', 'bdf'):
                meas_info['data_size'] = 3  # 24-bit (3 byte) integers
            else:
                meas_info['data_size'] = 2  # 16-bit (2 byte) integers

            meas_info['n_records'] = int(fid.read(8).decode())

            # record length in seconds
            record_length = float(fid.read(8).decode())
            if record_length == 0:
                record_length = 1.
                warnings.warn('Headermeas_information is incorrect for record length. '
                              'Default record length set to 1.')
            meas_info['record_length'] = record_length
            meas_info['nchan'] = nchan = int(fid.read(4).decode())

            channels = range(nchan)
            chan_info['ch_names'] = [text(16) for _ in channels]
            chan_info['transducers'] = [text(80) for _ in channels]
            chan_info['units'] = [text(8) for _ in channels]
            chan_info['physical_min'] = numbers(8, float)
            chan_info['physical_max'] = numbers(8, float)
            chan_info['digital_min'] = numbers(8, float)
            chan_info['digital_max'] = numbers(8, float)

            # NOTE(review): the filter description of the last channel is
            # ignored, as in the code this was derived from - confirm that
            # this is intended for files without a trailing special channel
            prefiltering = [text(80) for _ in channels][:-1]

            # number of samples per record
            chan_info['n_samps'] = n_samps = numbers(8, int)

            fid.read(32 * nchan)  # reserved
            if fid.tell() != meas_info['data_offset']:
                raise ValueError('header size does not match the declared data offset')

        # one token per channel that specifies the filter
        highpass = [tok for filt in prefiltering
                    for tok in re.findall(r'HP:\s+(\w+)', filt)]
        lowpass = [tok for filt in prefiltering
                   for tok in re.findall(r'LP:\s+(\w+)', filt)]

        # 'NaN' (unknown) and 'DC' both correspond to no highpass filtering
        hp_values = [0. if tok in ('NaN', 'DC') else float(tok) for tok in highpass]
        if len(set(hp_values)) > 1:
            warnings.warn('Channels contain different highpass filters. '
                          'Highest filter setting will be stored.')
        meas_info['highpass'] = max(hp_values) if hp_values else 0.

        # 'NaN' (unknown) does not contribute a lowpass setting
        lp_values = [float(tok) for tok in lowpass if tok != 'NaN']
        if len(set(lowpass)) > 1:
            warnings.warn('Channels contain different lowpass filters.'
                          ' Lowest filter setting will be stored.')
        meas_info['lowpass'] = min(lp_values) if lp_values else None

        if meas_info['n_records'] == -1:
            # this happens if the n_records is not updated at the end of recording;
            # derive it from the file size, counting complete records only
            record_nbytes = int(n_samps.sum()) * meas_info['data_size']
            data_nbytes = os.path.getsize(self.fname) - meas_info['data_offset']
            meas_info['n_records'] = int(data_nbytes // record_nbytes) if record_nbytes > 0 else 0

        # physical = digital * calibrate + offset
        calibrate = ((chan_info['physical_max'] - chan_info['physical_min'])
                     / (chan_info['digital_max'] - chan_info['digital_min']))
        offset = chan_info['physical_min'] - calibrate * chan_info['digital_min']
        # channels with inverted limits are returned without scaling
        inverted = calibrate < 0
        calibrate[inverted] = 1
        offset[inverted] = 0

        self.calibrate = calibrate
        self.offset = offset
        self.meas_info = meas_info
        self.chan_info = chan_info
        return (meas_info, chan_info)

    @staticmethod
    def _unpack24(buf, count):
        """Decode *count* little-endian signed 24-bit integers from *buf*."""
        if len(buf) != 3 * count:
            raise ValueError('expected {} bytes of data, got {}'.format(3 * count, len(buf)))
        octets = np.frombuffer(buf, dtype=np.uint8).reshape(count, 3).astype(np.int32)
        values = octets[:, 0] | (octets[:, 1] << 8) | (octets[:, 2] << 16)
        # values with the top bit set are negative in two's complement
        values[values >= 1 << 23] -= 1 << 24
        return values.astype(np.float32)

    def readBlock(self, block):
        """Return data record *block* (zero-based) as a list of arrays.

        The list has one float32 array of physical values per channel.
        Raises ValueError for a negative *block*.
        """
        if block < 0:
            raise ValueError('block index must not be negative')
        n_samps = self.chan_info['n_samps']
        data_size = self.meas_info['data_size']
        blocksize = int(np.sum(n_samps)) * data_size
        data = []
        with open(self.fname, 'rb') as fid:
            fid.seek(self.meas_info['data_offset'] + block * blocksize)
            for i in range(self.meas_info['nchan']):
                count = int(n_samps[i])
                buf = fid.read(count * data_size)
                if data_size == 3:
                    raw = self._unpack24(buf, count)
                else:
                    raw = np.asarray(unpack('<{}h'.format(count), buf), dtype=np.float32)
                # physical = digital * calibrate + offset
                raw *= self.calibrate[i]
                raw += self.offset[i]
                data.append(raw)
        return data

    def readSamples(self, channel, begsample, endsample):
        """Return samples *begsample* up to and including *endsample*.

        Sample numbers are zero-based and count from the start of the
        recording for the given *channel*.
        """
        n_samps = int(self.chan_info['n_samps'][channel])
        begblock = int(begsample // n_samps)
        endblock = max(begblock, int(endsample // n_samps))
        # read every record that overlaps the selection and join them once
        pieces = [self.readBlock(block)[channel]
                  for block in range(begblock, endblock + 1)]
        data = np.concatenate(pieces)
        first = int(begsample - begblock * n_samps)
        last = int(endsample - begblock * n_samps)
        return data[first:(last + 1)]

    ################################################################################################
    # the following are a number of helper functions to make the behaviour of this EDFReader
    # class more similar to https://bitbucket.org/cleemesser/python-edf/
    ################################################################################################

    def getSignalTextLabels(self):
        """Return the channel names as a list of plain strings."""
        # convert from unicode to string
        return [str(x) for x in self.chan_info['ch_names']]

    def getNSignals(self):
        """Return the number of channels."""
        return self.meas_info['nchan']

    def getSignalFreqs(self):
        """Return the sampling frequency of every channel."""
        return self.chan_info['n_samps'] / self.meas_info['record_length']

    def getNSamples(self):
        """Return the total number of samples of every channel."""
        return self.chan_info['n_samps'] * self.meas_info['n_records']

    def readSignal(self, chanindx):
        """Return all samples of channel *chanindx*."""
        begsample = 0
        endsample = self.chan_info['n_samps'][chanindx] * self.meas_info['n_records'] - 1
        return self.readSamples(chanindx, begsample, endsample)
| 354 | + |
| 355 | +#################################################################################################### |
| 356 | + |
# The two blocks below are disabled manual smoke tests.  They still have to
# be valid syntax for the module to be importable, hence print() is called as
# a function, which works on both Python 2 and Python 3.

if False:
    # read a few samples from the first channel of an existing recording
    file_in = EDFReader()
    file_in.open('/Users/roboos/day 01[10.03].edf')
    print(file_in.readSamples(0, 0, 0))
    print(file_in.readSamples(0, 0, 128))


if False:
    # copy a recording record by record
    file_in = EDFReader()
    file_in.open('/Users/roboos/test_generator.edf')

    file_out = EDFWriter()
    file_out.open('/Users/roboos/test_generator copy.edf')

    header = file_in.readHeader()

    file_out.writeHeader(header)

    meas_info = header[0]
    for i in range(meas_info['n_records']):
        data = file_in.readBlock(i)
        file_out.writeBlock(data)

    file_in.close()
    file_out.close()
0 commit comments