Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions spyglass_dlc/dgramling_dlc_orient.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def get_default(cls):
cls().insert_default(skip_duplicates=True)
default = (cls & {"dlc_orientation_params_name": "default"}).fetch1()
return default


@classmethod
def get_orient_methods(cls):
return _key_to_func_dict


@schema
class DLCOrientationSelection(dj.Manual):
""" """
Expand Down Expand Up @@ -159,7 +164,7 @@ def fetch1_dataframe(self):
)


def two_pt_head_orientation(pos_df: pd.DataFrame, **params):
def two_pt_orientation(pos_df: pd.DataFrame, **params):
"""Determines orientation based on vector between two points"""
BP1 = params.pop("bodypart1", None)
BP2 = params.pop("bodypart2", None)
Expand Down Expand Up @@ -209,6 +214,6 @@ def red_led_bisector_orientation(pos_df: pd.DataFrame, **params):

_key_to_func_dict = {
"none": no_orientation,
"red_green_orientation": two_pt_head_orientation,
"two_pt_orientation": two_pt_orientation,
"red_led_bisector": red_led_bisector_orientation,
}
4 changes: 3 additions & 1 deletion spyglass_dlc/dgramling_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def insert_estimation_task(
video_path, video_filename, _, _ = get_video_path(key)
output_dir = cls.infer_output_dir(key, video_filename=video_filename)
video_dir = os.path.dirname(video_path) + "/"
video_path = check_videofile(video_path=video_dir, video_filename=video_filename)[0]
video_path = check_videofile(
video_path=video_dir, video_filename=video_filename
)[0]
cls.insert1(
{
**key,
Expand Down
5 changes: 5 additions & 0 deletions spyglass_dlc/dgramling_dlc_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ def interp_pos(dlc_df, **kwargs):
)
subthresh_spans = get_span_start_stop(subthresh_inds)
for ind, (span_start, span_stop) in enumerate(subthresh_spans):
if (span_stop + 1) >= len(dlc_df):
dlc_df.loc[idx[span_start:span_stop], idx["x"]] = np.nan
dlc_df.loc[idx[span_start:span_stop], idx["y"]] = np.nan
print(f"ind: {ind} has no endpoint with which to interpolate")
continue
x = [dlc_df["x"].iloc[span_start - 1], dlc_df["x"].iloc[span_stop + 1]]
y = [dlc_df["y"].iloc[span_start - 1], dlc_df["y"].iloc[span_stop + 1]]
span_len = int(span_stop - span_start + 1)
Expand Down