Skip to content
This repository was archived by the owner on Jun 9, 2023. It is now read-only.

Commit 7d46c3a

Browse files
author
DEKHTIARJonathan
committed
Bug Fix with TF Package Name
1 parent b2b5d99 commit 7d46c3a

2 files changed

Lines changed: 15 additions & 10 deletions

File tree

nvtx_plugins/python/nvtx/plugins/tf/package_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
MAJOR = 0
1919
MINOR = 1
20-
PATCH = 6
20+
PATCH = 7
2121
PRE_RELEASE = ''
2222

2323
# Use the following formatting: (major, minor, patch, pre-release)

setup.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import os
2222
import sys
2323

24-
import pkgutil
24+
import subprocess
2525

2626
from setuptools import setup
2727
from setuptools import Extension
@@ -46,16 +46,21 @@
4646
from setup_utils import custom_build_ext
4747

4848

49-
REQUIRED_PACKAGES = ['wrapt']
49+
def run_piped_subprocess(command):
50+
ps = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
51+
return ps.communicate()[0].decode('utf-8').strip()
5052

51-
for tf_package_name in ['tensorflow', 'tensorflow-gpu']:
52-
tf_loader = pkgutil.find_loader(tf_package_name)
53-
if tf_loader is not None:
54-
REQUIRED_PACKAGES.append(tf_package_name)
55-
break
56-
else: # in case no available package is found, default to 'tensorflow'
57-
REQUIRED_PACKAGES.append('tensorflow')
5853

54+
def get_tf_pkgname():
55+
cmd_rslt = run_piped_subprocess("pip freeze | grep tensorflow-gpu")
56+
57+
if "tensorflow-gpu" in cmd_rslt:
58+
return "tensorflow-gpu"
59+
else:
60+
return "tensorflow" # Default if not found
61+
62+
63+
REQUIRED_PACKAGES = ['wrapt', get_tf_pkgname()]
5964

6065
tensorflow_nvtx_lib = Extension(
6166
'nvtx.plugins.tf.lib.nvtx_ops',

0 commit comments

Comments
 (0)