
refactor code to remove dependency on THC/THCDeviceUtils.cuh #23


Open · wants to merge 4 commits into master
11 changes: 10 additions & 1 deletion .gitignore
@@ -16,7 +16,16 @@ demo/
compare/

# images/
images_det/
images_*/
imgs_*/
imgs/

# archived files
archive/

# packages from other repos
frankmocap/
hand_object_detector/

models/*
predata_dandan/multitask/*.json
Empty file added __init__.py
30 changes: 25 additions & 5 deletions demo.py
@@ -38,12 +38,16 @@
from model.faster_rcnn.vgg16 import vgg16
from model.faster_rcnn.resnet import resnet
import pdb
import re

try:
xrange # Python 2
except NameError:
xrange = range # Python 3

def num_sort(input_string):
    # sort key: the first integer embedded in the filename (e.g. the frame index)
    return list(map(int, re.findall(r'\d+', input_string)))[0]


def parse_args():
"""
@@ -149,8 +153,7 @@ def _get_image_blob(im):

return blob, np.array(im_scale_factors)

if __name__ == '__main__':

def main():
args = parse_args()

# print('Called with args:')
@@ -241,10 +244,12 @@ def _get_image_blob(im):
print(f'image dir = {args.image_dir}')
print(f'save dir = {args.save_dir}')
imglist = os.listdir(args.image_dir)
imglist = sorted(imglist, key=num_sort, reverse=True)
num_images = len(imglist)

print('Loaded Photo: {} images.'.format(num_images))

contact = []  # one entry per frame: 1 if a target object was detected, else 0

while (num_images >= 0):
total_tic = time.time()
@@ -343,6 +348,7 @@ def _get_image_blob(im):
if vis:
im2show = np.copy(im)
obj_dets, hand_dets = None, None
obj_detect = 0  # per-frame flag: set to 1 when a target object is detected
for j in xrange(1, len(pascal_classes)):
# inds = torch.nonzero(scores[:,j] > thresh).view(-1)
if pascal_classes[j] == 'hand':
@@ -365,13 +371,21 @@ def _get_image_blob(im):
cls_dets = cls_dets[keep.view(-1).long()]
if pascal_classes[j] == 'targetobject':
obj_dets = cls_dets.cpu().numpy()
# print(imglist[num_images])
obj_detect = 1
if pascal_classes[j] == 'hand':
hand_dets = cls_dets.cpu().numpy()


contact.append(obj_detect)
if vis:
# visualization
im2show = vis_detections_filtered_objects_PIL(im2show, obj_dets, hand_dets, thresh_hand, thresh_obj)

# if contact_curr:
# contact.append(1)
# else:
# contact.append(0)

misc_toc = time.time()
nms_time = misc_toc - misc_tic

@@ -395,7 +409,13 @@ def _get_image_blob(im):
print('Frame rate:', frame_rate)
if cv2.waitKey(1) & 0xFF == ord('q'):
break

print(contact)
print(len(contact))
if webcam_num >= 0:
cap.release()
cv2.destroyAllWindows()



if __name__ == '__main__':
main()
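
For reference, a minimal sketch of what the new num_sort key changes versus plain lexical sorting (filenames below are hypothetical):

import re

def num_sort(input_string):
    # sort key: the first integer embedded in the filename
    return list(map(int, re.findall(r'\d+', input_string)))[0]

imglist = ['frame_10.jpg', 'frame_2.jpg', 'frame_1.jpg']
print(sorted(imglist))                # ['frame_1.jpg', 'frame_10.jpg', 'frame_2.jpg'] (lexical)
print(sorted(imglist, key=num_sort))  # ['frame_1.jpg', 'frame_2.jpg', 'frame_10.jpg']
print(sorted(imglist, key=num_sort, reverse=True))  # highest frame index first, as demo.py uses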
242 changes: 242 additions & 0 deletions get_obj_contact.py
@@ -0,0 +1,242 @@
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# import _init_paths
import os
import numpy as np
import cv2
import torch
# from scipy.misc import imread
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.rpn.bbox_transform import clip_boxes
# from model.nms.nms_wrapper import nms
from model.roi_layers import nms
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.net_utils import save_net, load_net, vis_detections, vis_detections_PIL, vis_detections_filtered_objects_PIL, vis_detections_filtered_objects # (1) here add a function to viz
from model.utils.blob import im_list_to_blob
from model.faster_rcnn.resnet import resnet
import pdb
import re
from tqdm import tqdm
try:
xrange # Python 2
except NameError:
xrange = range # Python 3

def num_sort(input_string):
    # sort key: the first integer embedded in the filename (e.g. the frame index)
    return list(map(int, re.findall(r'\d+', input_string)))[0]

lr = cfg.TRAIN.LEARNING_RATE
momentum = cfg.TRAIN.MOMENTUM
weight_decay = cfg.TRAIN.WEIGHT_DECAY

def _get_image_blob(im):
"""Converts an image into a network input.
Arguments:
im (ndarray): a color image in BGR order
Returns:
blob (ndarray): a data blob holding an image pyramid
im_scale_factors (list): list of image scales (relative to im) used
in the image pyramid
"""
im_orig = im.astype(np.float32, copy=True)
im_orig -= cfg.PIXEL_MEANS

im_shape = im_orig.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])

processed_ims = []
im_scale_factors = []

for target_size in cfg.TEST.SCALES:
im_scale = float(target_size) / float(im_size_min)
# Prevent the biggest axis from being more than MAX_SIZE
if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
interpolation=cv2.INTER_LINEAR)
im_scale_factors.append(im_scale)
processed_ims.append(im)

# Create a blob to hold the input images
blob = im_list_to_blob(processed_ims)

return blob, np.array(im_scale_factors)
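
# Worked example (hypothetical numbers): with cfg.TEST.SCALES = (600,) and
# cfg.TEST.MAX_SIZE = 1000, a 480x1280 frame first gets im_scale = 600/480 = 1.25;
# since round(1.25 * 1280) = 1600 exceeds MAX_SIZE, the scale is clamped to
# 1000/1280 = 0.78125 so the longest side stays within MAX_SIZE.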

def get_obj_contact(hand_traj, args):

torch.cuda.empty_cache()

if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)

cfg.USE_GPU_NMS = True
np.random.seed(cfg.RNG_SEED)
# load model
model_dir = args.load_dir + "/" + args.net + "_handobj_100K" + "/" + args.dataset
if not os.path.exists(model_dir):
raise Exception('There is no input directory for loading network from ' + model_dir)
load_name = os.path.join(model_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))

pascal_classes = np.asarray(['__background__', 'targetobject', 'hand'])
args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32, 64]', 'ANCHOR_RATIOS', '[0.5, 1, 2]']

fasterRCNN = resnet(pascal_classes, 101, pretrained=False, class_agnostic=args.class_agnostic)
fasterRCNN.create_architecture()

print("load checkpoint %s" % (load_name))
checkpoint = torch.load(load_name)
fasterRCNN.load_state_dict(checkpoint['model'])
if 'pooling_mode' in checkpoint.keys():
cfg.POOLING_MODE = checkpoint['pooling_mode']

print('Hand-Contact model loaded successfully!')

# initialize the tensor holders here.
im_data = torch.FloatTensor(1)
im_info = torch.FloatTensor(1)
num_boxes = torch.LongTensor(1)
gt_boxes = torch.FloatTensor(1)
box_info = torch.FloatTensor(1)

# ship to cuda
im_data = im_data.cuda()
im_info = im_info.cuda()
num_boxes = num_boxes.cuda()
gt_boxes = gt_boxes.cuda()

with torch.no_grad():
cfg.CUDA = True

fasterRCNN.cuda()
fasterRCNN.eval()

max_per_image = 100
thresh_hand = args.thresh_hand
thresh_obj = args.thresh_obj
vis = args.vis

hand_traj_with_contact = []
for point_idx in tqdm(range(len(hand_traj)), dynamic_ncols=True):
traj_point = hand_traj[point_idx]
im_path = traj_point["image_path"]
im = cv2.imread(im_path)

im_blob, im_scales = _get_image_blob(im)
im_info_np = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)

im_data_pt = torch.from_numpy(im_blob)
im_data_pt = im_data_pt.permute(0, 3, 1, 2)
im_info_pt = torch.from_numpy(im_info_np)

with torch.no_grad():
im_data.resize_(im_data_pt.size()).copy_(im_data_pt)
im_info.resize_(im_info_pt.size()).copy_(im_info_pt)
gt_boxes.resize_(1, 1, 5).zero_()
num_boxes.resize_(1).zero_()
box_info.resize_(1, 1, 5).zero_()

rois, cls_prob, bbox_pred, \
rpn_loss_cls, rpn_loss_box, \
RCNN_loss_cls, RCNN_loss_bbox, \
rois_label, loss_list = fasterRCNN(im_data, im_info, gt_boxes, num_boxes, box_info)

scores = cls_prob.data
boxes = rois.data[:, :, 1:5]

# extract predicted params
contact_vector = loss_list[0][0] # hand contact state info
offset_vector = loss_list[1][0].detach() # offset vector (factored into a unit vector and a magnitude)
lr_vector = loss_list[2][0].detach() # hand side info (left/right)

# get hand contact
_, contact_indices = torch.max(contact_vector, 2)
contact_indices = contact_indices.squeeze(0).unsqueeze(-1).float()
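# hedged note: in the hand_object_detector convention the contact state is one of
# {0: none, 1: self, 2: other person, 3: portable object, 4: stationary object}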

# get hand side (0 = left, 1 = right)
lr = torch.sigmoid(lr_vector) > 0.5
lr = lr.squeeze(0).float()

if cfg.TEST.BBOX_REG:
# Apply bounding-box regression deltas
box_deltas = bbox_pred.data
if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
# Optionally normalize targets by a precomputed mean and stdev
if args.class_agnostic:
box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
+ torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()

box_deltas = box_deltas.view(1, -1, 4)
else:
box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
+ torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
box_deltas = box_deltas.view(1, -1, 4 * len(pascal_classes))
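# with the 3 classes here, each box carries 12 regression values (4 per class),
# versus a single shared set of 4 in the class-agnostic branch above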

pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
else:
# Simply repeat the boxes, once for each class
pred_boxes = np.tile(boxes, (1, scores.shape[1]))

pred_boxes /= im_scales[0]

scores = scores.squeeze()
pred_boxes = pred_boxes.squeeze()

im2show = np.copy(im)

contact = 0
obj_dets, hand_dets = None, None
for j in xrange(1, len(pascal_classes)):
# inds = torch.nonzero(scores[:,j] > thresh).view(-1)
if pascal_classes[j] == 'hand':
inds = torch.nonzero(scores[:,j]>thresh_hand).view(-1)
elif pascal_classes[j] == 'targetobject':
inds = torch.nonzero(scores[:,j]>thresh_obj).view(-1)

# if there is det
if inds.numel() > 0:
cls_scores = scores[:,j][inds]
_, order = torch.sort(cls_scores, 0, True)
if args.class_agnostic:
cls_boxes = pred_boxes[inds, :]
else:
cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]

# each detection row: 4 box coords, score, contact state, offset vector, hand side
cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1), contact_indices[inds], offset_vector.squeeze(0)[inds], lr[inds]), 1)
cls_dets = cls_dets[order]
keep = nms(cls_boxes[order, :], cls_scores[order], cfg.TEST.NMS)
cls_dets = cls_dets[keep.view(-1).long()]
if pascal_classes[j] == 'targetobject':
    obj_dets = cls_dets.cpu().numpy()
    # flag contact when a detected target object co-occurs with the queried hand side
    if 0 in lr[inds].cpu().numpy() and args.hand == 'left_hand':
        contact = 1
    elif 1 in lr[inds].cpu().numpy() and args.hand == 'right_hand':
        contact = 1
    else:
        contact = 0
if pascal_classes[j] == 'hand':
hand_dets = cls_dets.cpu().numpy()

im2show = vis_detections_filtered_objects_PIL(im2show, obj_dets, hand_dets, thresh_hand, thresh_obj)
result_path = os.path.join(args.data_dir, "hand_object_detect")
os.makedirs(result_path, exist_ok=True)
img_name = os.path.basename(im_path)
img_name = img_name[:-4] + '.jpg'  # normalize the saved extension to .jpg
res_img_path = os.path.join(result_path, img_name)
im2show.save(res_img_path)

traj_point["contact"] = contact
hand_traj_with_contact.append(traj_point)
return hand_traj_with_contact
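
For reference, a hedged usage sketch of get_obj_contact: the args fields mirror the flags demo.py parses, and the paths and checkpoint numbers below are hypothetical.

from types import SimpleNamespace
from get_obj_contact import get_obj_contact

# Hypothetical args; field names follow what get_obj_contact reads.
args = SimpleNamespace(
    cfg_file='cfgs/res101.yml', set_cfgs=None,
    load_dir='models', net='res101', dataset='pascal_voc',
    checksession=1, checkepoch=8, checkpoint=132028,
    class_agnostic=False, thresh_hand=0.5, thresh_obj=0.5,
    vis=True, hand='right_hand', data_dir='data')

# One dict per trajectory point; get_obj_contact reads only "image_path"
# and adds a binary "contact" flag to each point.
hand_traj = [{"image_path": "data/frames/frame_1.jpg"},
             {"image_path": "data/frames/frame_2.jpg"}]

hand_traj = get_obj_contact(hand_traj, args)
print([p["contact"] for p in hand_traj])  # e.g. [1, 0]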

