Skip to content

Revert "add timing information for sampler outputs" #784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,8 @@ def __init__(
self._max_treedepths: np.ndarray = np.zeros(
self.runset.chains, dtype=int
)
self._chain_time: List[Dict[str, float]] = []

# info from CSV header and initial and final comment blocks
# info from CSV initial comments and header
config = self._validate_csv_files()
self._metadata: InferenceMetadata = InferenceMetadata(config)
if not self._is_fixed_param:
Expand Down Expand Up @@ -241,14 +240,6 @@ def max_treedepths(self) -> Optional[np.ndarray]:
"""
return self._max_treedepths if not self._is_fixed_param else None

@property
def time(self) -> List[Dict[str, float]]:
"""
List of per-chain time info scraped from the Stan CSV files.

One dict per chain, each with keys "warmup", "sampling", "total".
Values are floats (presumably elapsed seconds as reported by the
sampler -- confirm against the CSV trailing comments).
"""
return self._chain_time

def draws(
self, *, inc_warmup: bool = False, concat_chains: bool = False
) -> np.ndarray:
Expand Down Expand Up @@ -310,7 +301,6 @@ def _validate_csv_files(self) -> Dict[str, Any]:
save_warmup=self._save_warmup,
thin=self._thin,
)
self._chain_time.append(dzero['time']) # type: ignore
if not self._is_fixed_param:
self._divergences[i] = dzero['ct_divergences']
self._max_treedepths[i] = dzero['ct_max_treedepth']
Expand All @@ -323,7 +313,6 @@ def _validate_csv_files(self) -> Dict[str, Any]:
save_warmup=self._save_warmup,
thin=self._thin,
)
self._chain_time.append(drest['time']) # type: ignore
for key in dzero:
# check args that matter for parsing, plus name, version
if (
Expand Down
61 changes: 0 additions & 61 deletions cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
lineno = scan_warmup_iters(fd, dict, lineno)
lineno = scan_hmc_params(fd, dict, lineno)
lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param)
lineno = scan_time(fd, dict, lineno)
except ValueError as e:
raise ValueError("Error in reading csv file: " + path) from e
return dict
Expand Down Expand Up @@ -382,66 +381,6 @@ def scan_sampling_iters(
return lineno


def scan_time(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
    """
    Scan time information from the trailing comment lines in a Stan CSV file.

    The expected comment block looks like::

        # Elapsed Time: 0.001332 seconds (Warm-up)
        #               0.000249 seconds (Sampling)
        #               0.001581 seconds (Total)

    The three values are stored in ``config_dict['time']`` as a dictionary
    with keys 'warmup', 'sampling', and 'total'.  Empty comment lines
    (a bare ``#``) are skipped.  Scanning stops at end of file or at the
    first line that is not a comment; in the latter case the stream is
    rewound to the start of that line so the caller can re-read it.

    :param fd: Open file descriptor at comment row following all sample data.
        Must support ``tell``/``seek``.
    :param config_dict: Dictionary to which the time info is added.
    :param lineno: Current line number
    :return: The line number of the last comment line consumed.
    :raises ValueError: If a non-empty comment line is not a recognisable
        timing line, if its value is not a float, or if any of the three
        timings is missing once scanning stops.
    """
    # (marker text, output key, index of the numeric token on that line).
    # The warm-up line is prefixed by "Elapsed Time:", hence index 2.
    layouts = (
        ('Warm-up', 'warmup', 2),
        ('Sampling', 'sampling', 0),
        ('Total', 'total', 0),
    )
    timings: Dict[str, float] = {}

    while True:
        # Remember where this line starts so it can be handed back untouched.
        pos = fd.tell()
        line = fd.readline()
        if not line:
            break
        stripped = line.strip()
        if not stripped.startswith('#'):
            # Not part of the trailing comments: rewind and do not count it.
            fd.seek(pos)
            break
        lineno += 1

        content = stripped.lstrip('#').strip()
        if not content:
            continue

        tokens = content.split()
        layout = next((lay for lay in layouts if lay[0] in content), None)
        if len(tokens) < 3 or layout is None:
            raise ValueError(f"Invalid time at line {lineno}: {content}")

        _, key, index = layout
        try:
            timings[key] = float(tokens[index])
        except ValueError as e:
            raise ValueError(f"Invalid time at line {lineno}: {content}") from e

    missing = [key for _, key, _ in layouts if key not in timings]
    if missing:
        # Name the absent entries so a truncated CSV is easy to diagnose.
        raise ValueError(
            f"Invalid time, stopped at {lineno}: missing {missing}"
        )

    config_dict['time'] = timings
    return lineno


def read_metric(path: str) -> List[int]:
"""
Read metric file in JSON or Rdump format.
Expand Down
6 changes: 0 additions & 6 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,12 +1714,6 @@ def test_metadata() -> None:
assert fit.column_names == col_names
assert fit.metric_type == 'diag_e'

assert len(fit.time) == 4
for i in range(4):
assert 'warmup' in fit.time[i].keys()
assert 'sampling' in fit.time[i].keys()
assert 'total' in fit.time[i].keys()

assert fit.metadata.cmdstan_config['num_samples'] == 100
assert fit.metadata.cmdstan_config['thin'] == 1
assert fit.metadata.cmdstan_config['algorithm'] == 'hmc'
Expand Down
54 changes: 0 additions & 54 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,57 +702,3 @@ def test_munge_varnames() -> None:

var = 'y.2.3:1.2:5:6'
assert stancsv.munge_varname(var) == 'y[2,3].1[2].5.6'


def test_scan_time_normal() -> None:
    """A well-formed three-line timing block is parsed into floats."""
    timing_lines = [
        "# Elapsed Time: 0.005 seconds (Warm-up)\n",
        "# 0 seconds (Sampling)\n",
        "# 0.005 seconds (Total)\n",
    ]
    stream = io.StringIO("".join(timing_lines))
    parsed = {}

    last_line = stancsv.scan_time(stream, parsed, 0)

    # All three comment lines are consumed, starting the count from zero.
    assert last_line == 3
    assert parsed.get('time') == {
        'warmup': 0.005,
        'sampling': 0.0,
        'total': 0.005,
    }


def test_scan_time_no_timing() -> None:
    """Comment lines that carry no timing information are rejected."""
    comment_lines = [
        "# merrily we roll along\n",
        "# roll along\n",
        "# very merrily we roll along\n",
    ]
    stream = io.StringIO("".join(comment_lines))
    parsed = {}

    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)


def test_scan_time_invalid_value() -> None:
    """A non-numeric warm-up value is reported as an invalid time."""
    timing_lines = [
        "# Elapsed Time: abc seconds (Warm-up)\n",
        "# 0.200 seconds (Sampling)\n",
        "# 0.300 seconds (Total)\n",
    ]
    stream = io.StringIO("".join(timing_lines))
    parsed = {}

    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)


def test_scan_time_invalid_string() -> None:
    """A line with an unrecognised phase label is reported as invalid."""
    timing_lines = [
        "# Elapsed Time: 0.22 seconds (foo)\n",
        "# 0.200 seconds (Sampling)\n",
        "# 0.300 seconds (Total)\n",
    ]
    stream = io.StringIO("".join(timing_lines))
    parsed = {}

    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)
Loading