Skip to content

Commit 156ca7a

Browse files
cicichen01facebook-github-bot
authored andcommitted
Simplify the get_random_model_and_data() method (#1251)
Summary: As titled. The sample data type is fixed after construction and should be handled separately based on its type. Differential Revision: D55153967
1 parent fabac35 commit 156ca7a

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

tests/influence/_utils/common.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -285,21 +285,17 @@ def get_random_model_and_data(
285285
torch.normal(0, 1, (num_samples, in_features)).double()
286286
for _ in range(num_inputs)
287287
]
288-
all_samples = (
289-
_move_sample_to_cuda(all_samples)
290-
if isinstance(all_samples, list) and use_gpu
291-
else (all_samples.cuda() if use_gpu else all_samples)
292-
)
288+
if use_gpu:
289+
all_samples = _move_sample_to_cuda(all_samples)
290+
293291
train_samples = [ts[:num_train] for ts in all_samples]
294292
test_samples = [ts[num_train:] for ts in all_samples]
295293
hessian_samples = [ts[:num_hessian] for ts in all_samples]
296294
else:
297295
all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
298-
all_samples = (
299-
_move_sample_to_cuda(all_samples)
300-
if isinstance(all_samples, list) and use_gpu
301-
else (all_samples.cuda() if use_gpu else all_samples)
302-
)
296+
297+
if use_gpu:
298+
all_samples = all_samples.cuda()
303299
train_samples = all_samples[:num_train]
304300
test_samples = all_samples[num_train:]
305301
hessian_samples = all_samples[:num_hessian]

0 commit comments

Comments
 (0)