Skip to content
Merged
209 changes: 204 additions & 5 deletions src/jdb_to_nwb/convert_photometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@
import struct
import pandas as pd
import numpy as np
import os
import json
import warnings
import scipy.io
from scipy.signal import butter, lfilter, hilbert
from scipy.signal import butter, lfilter, hilbert, filtfilt
from scipy.sparse import diags, eye, csc_matrix
from scipy.sparse.linalg import spsolve
from sklearn.linear_model import Lasso

# Some of these imports are unused for now but will be used for photometry metadata
from ndx_fiber_photometry import (
Indicator,

Check failure on line 16 in src/jdb_to_nwb/convert_photometry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/jdb_to_nwb/convert_photometry.py:16:5: F401 `ndx_fiber_photometry.Indicator` imported but unused
OpticalFiber,

Check failure on line 17 in src/jdb_to_nwb/convert_photometry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/jdb_to_nwb/convert_photometry.py:17:5: F401 `ndx_fiber_photometry.OpticalFiber` imported but unused
ExcitationSource,

Check failure on line 18 in src/jdb_to_nwb/convert_photometry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/jdb_to_nwb/convert_photometry.py:18:5: F401 `ndx_fiber_photometry.ExcitationSource` imported but unused
Photodetector,

Check failure on line 19 in src/jdb_to_nwb/convert_photometry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/jdb_to_nwb/convert_photometry.py:19:5: F401 `ndx_fiber_photometry.Photodetector` imported but unused
DichroicMirror,

Check failure on line 20 in src/jdb_to_nwb/convert_photometry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/jdb_to_nwb/convert_photometry.py:20:5: F401 `ndx_fiber_photometry.DichroicMirror` imported but unused
BandOpticalFilter,

Check failure on line 21 in src/jdb_to_nwb/convert_photometry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/jdb_to_nwb/convert_photometry.py:21:5: F401 `ndx_fiber_photometry.BandOpticalFilter` imported but unused
EdgeOpticalFilter,

Check failure on line 22 in src/jdb_to_nwb/convert_photometry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/jdb_to_nwb/convert_photometry.py:22:5: F401 `ndx_fiber_photometry.EdgeOpticalFilter` imported but unused
FiberPhotometry,
FiberPhotometryTable,
FiberPhotometryResponseSeries,
Expand Down Expand Up @@ -336,6 +338,205 @@
weights[-1] = weights[0]
return baseline

def import_ppd(ppd_file_path):
'''
Credit to the homie: https://github.com/ThomasAkam/photometry_preprocessing.git
I edited it so that his function only returns the data dictionary without the filtered data.
Raw data is filtered later/separately using the process_ppd_photometry function.

Function to import pyPhotometry binary data files into Python. Returns a dictionary with the
following items:
'filename' - Data filename
'subject_ID' - Subject ID
'date_time' - Recording start date and time (ISO 8601 format string)
'end_time' - Recording end date and time (ISO 8601 format string)
'mode' - Acquisition mode
'sampling_rate' - Sampling rate (Hz)
'LED_current' - Current for LEDs 1 and 2 (mA)
'version' - Version number of pyPhotometry
'analog_1' - Raw analog signal 1 (volts)
'analog_2' - Raw analog signal 2 (volts)
'analog_3' - Raw analog signal 3 (if present, volts)
'digital_1' - Digital signal 1
'digital_2' - Digital signal 2 (if present)
'pulse_inds_1' - Locations of rising edges on digital input 1 (samples).
'pulse_inds_2' - Locations of rising edges on digital input 2 (samples).
'pulse_times_1' - Times of rising edges on digital input 1 (ms).
'pulse_times_2' - Times of rising edges on digital input 2 (ms).
'time' - Time of each sample relative to start of recording (ms)
'''
with open(ppd_file_path, "rb") as f:
header_size = int.from_bytes(f.read(2), "little")
data_header = f.read(header_size)
data = np.frombuffer(f.read(), dtype=np.dtype("<u2"))
# Extract header information
header_dict = json.loads(data_header)
volts_per_division = header_dict["volts_per_division"]
sampling_rate = header_dict["sampling_rate"]
# Extract signals.
analog = data >> 1 # Analog signal is most significant 15 bits.
digital = ((data & 1) == 1).astype(int) # Digital signal is least significant bit.
# Alternating samples are different signals.
if "n_analog_signals" in header_dict.keys():
n_analog_signals = header_dict["n_analog_signals"]
n_digital_signals = header_dict["n_digital_signals"]
else: # Pre version 1.0 data file.
n_analog_signals = 2
n_digital_signals = 2
analog_1 = analog[::n_analog_signals] * volts_per_division[0]
analog_2 = analog[1::n_analog_signals] * volts_per_division[1]
analog_3 = analog[2::n_analog_signals] * volts_per_division[0] if n_analog_signals == 3 else None
digital_1 = digital[::n_analog_signals]
digital_2 = digital[1::n_analog_signals] if n_digital_signals == 2 else None
time = np.arange(analog_1.shape[0]) * 1000 / sampling_rate # Time relative to start of recording (ms).

# Extract rising edges for digital inputs.
pulse_inds_1 = 1 + np.where(np.diff(digital_1) == 1)[0]
pulse_inds_2 = 1 + np.where(np.diff(digital_2) == 1)[0] if n_digital_signals == 2 else None
pulse_times_1 = pulse_inds_1 * 1000 / sampling_rate
pulse_times_2 = pulse_inds_2 * 1000 / sampling_rate if n_digital_signals == 2 else None
# Return signals + header information as a dictionary.
data_dict = {
"filename": os.path.basename(ppd_file_path),
"analog_1": analog_1,
"analog_2": analog_2,
"digital_1": digital_1,
"digital_2": digital_2,
"pulse_inds_1": pulse_inds_1,
"pulse_inds_2": pulse_inds_2,
"pulse_times_1": pulse_times_1,
"pulse_times_2": pulse_times_2,
"time": time,
}
if n_analog_signals == 3:
data_dict.update(
{
"analog_3": analog_3,
}
)
data_dict.update(header_dict)
return data_dict

def process_ppd_photometry(nwbfile: NWBFile, ppd_file_path):
"""
Process pyPhotometry data from a .ppd file and add the processed signals to the NWB file.
"""
ppd_data = import_ppd(ppd_file_path)

raw_green = pd.Series(ppd_data['analog_1'])
raw_red = pd.Series(ppd_data ['analog_2'])
raw_405 = pd.Series(ppd_data['analog_3'])

relative_raw_signal = raw_green / raw_405

sampling_rate = ppd_data['sampling_rate']
visits = ppd_data['pulse_inds_1'][1:]

# low pass at 10Hz to remove high frequency noise
print('Filtering data...')
b,a = butter(2, 10, btype='low', fs=sampling_rate)
green_denoised = filtfilt(b,a, raw_green)
red_denoised = filtfilt(b,a, raw_red)
ratio_denoised = filtfilt(b,a, relative_raw_signal)
denoised_405 = filtfilt(b,a, raw_405)
# high pass at 0.001Hz which removes the drift due to bleaching, but will also remove any physiological variation in the signal on very slow timescales.
b,a = butter(2, 0.001, btype='high', fs=sampling_rate)
green_highpass = filtfilt(b,a, green_denoised, padtype='even')
red_highpass = filtfilt(b,a, red_denoised, padtype='even')
ratio_highpass = filtfilt(b,a, ratio_denoised, padtype='even')
highpass_405 = filtfilt(b,a, denoised_405, padtype='even')

# Z-score of each signal to normalize the data
print('Z-scoring data...')
green_zscored = np.divide(np.subtract(green_highpass,green_highpass.mean()),green_highpass.std())

red_zscored = np.divide(np.subtract(red_highpass,red_highpass.mean()),red_highpass.std())

zscored_405 = np.divide(np.subtract(highpass_405,highpass_405.mean()),highpass_405.std())

ratio_zscored = np.divide(np.subtract(ratio_highpass,ratio_highpass.mean()),ratio_highpass.std())
print('Done processing photometry data!')

# Add actual photometry data to the NWB
print("Adding photometry signals to NWB ...")

raw_470_response_series = FiberPhotometryResponseSeries(
name="raw_470",
description="Raw 470 nm",
data=raw_green.T[0],
unit="V",
rate=float(sampling_rate),
)

z_scored_470_response_series = FiberPhotometryResponseSeries(
name="z_scored_470",
description="Z-scored 470 nm",
data=green_zscored.T[0],
unit="z-score",
rate=float(sampling_rate),
)

raw_405_response_series = FiberPhotometryResponseSeries(
name="raw_405",
description="Raw 405 nm",
data=raw_405.T[0],
unit="V",
rate=float(sampling_rate),
)

z_scored_405_response_series = FiberPhotometryResponseSeries(
name="zscored_405",
description="Z-scored 405nm. This is used to calculate the ratiometric index when using GRAB-ACh3.8",
data=zscored_405.T[0],
unit="z-score",
rate=float(sampling_rate),
)

raw_565_response_series = FiberPhotometryResponseSeries(
name="raw_565",
description="Raw 565 nm",
data=raw_red.T[0],
unit="V",
rate=float(sampling_rate),
)

z_scored_565_response_series = FiberPhotometryResponseSeries(
name="zscored_565",
description="Z-scored 565nm",
data=red_zscored.T[0],
unit="z-score",
rate=float(sampling_rate),
)

raw_ratio_response_series = FiberPhotometryResponseSeries(
name="raw_470/405",
description="Raw ratiometric index of 470nm and 405nm",
data=relative_raw_signal.T[0],
unit="V",
rate=float(sampling_rate),
)

z_scored_ratio_response_series = FiberPhotometryResponseSeries(
name="zscored_470/405",
description="Z-scored ratiometric index of 470nm and 405nm",
data=ratio_zscored.T[0],
unit="z-score",
rate=float(sampling_rate),
)

# Add the FiberPhotometryResponseSeries objects to the NWB
nwbfile.add_acquisition(raw_405_response_series)
nwbfile.add_acquisition(raw_470_response_series)
nwbfile.add_acquisition(raw_565_response_series)
nwbfile.add_acquisition(raw_ratio_response_series)
nwbfile.add_acquisition(z_scored_405_response_series)
nwbfile.add_acquisition(z_scored_470_response_series)
nwbfile.add_acquisition(z_scored_565_response_series)
nwbfile.add_acquisition(z_scored_ratio_response_series)

# Return port visits in downsampled photometry time (86 Hz) to use for alignment
return sampling_rate, visits


def add_photometry_metadata(nwbfile: NWBFile, metadata: dict):
# TODO for Ryan - add photometry metadata to NWB :)
Expand Down Expand Up @@ -395,10 +596,8 @@
elif "ppd_file_path" in metadata["photometry"]:
# Process ppd file from pyPhotometry
print("Processing ppd file from pyPhotometry...")

# TODO for Jose - add pyPhotometry processing here!!
# Probably add the processing functions above and just call them here
raise NotImplementedError("pyPhotometry processing is not yet implemented.")
ppd_file_path = metadata["photometry"]["ppd_file_path"]
sampling_rate, visits = process_ppd_photometry(nwbfile, ppd_file_path)

else:
raise ValueError(
Expand Down
Loading