forked from facebookresearch/PrivacyGuard
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_test_analysis_node.py
More file actions
99 lines (84 loc) · 3.4 KB
/
base_test_analysis_node.py
File metadata and controls
99 lines (84 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-strict
import unittest
from typing import Any, Dict, List, Tuple
import numpy as np
import pandas as pd
class BaseTestAnalysisNode(unittest.TestCase):
    """
    Util test class which sets up common dataframes for use in testing.

    Subclasses get:
      * ``self.df_train_merge`` / ``self.df_test_merge``: small fixed
        dataframes with ``user_id`` and ``score`` columns, built in ``setUp``.
      * ``self.user_id_key``: the user id column name (``"user_id"``).
      * Helper assertions for lists of floats and dictionary keys.
      * Generators for larger normally distributed ``user_id``/``score``
        dataframes.
    """

    def assertIsListOfFloats(self, value: Any, msg: str = "") -> None:
        """Assert that value is a list containing only float or np.floating elements.

        Args:
            value: Object under test.
            msg: Optional custom failure message; when non-empty it replaces
                the default messages.

        Raises:
            AssertionError: (``self.failureException``) if ``value`` is not a
                ``list`` or any element is not a ``float`` / ``np.floating``.
        """
        self.assertIsInstance(value, list, msg or "Expected a list")
        # Collect only the offending elements: nothing is allocated on the
        # success path, and failures on large lists stay readable.
        non_floats: List[Tuple[int, str]] = [
            (index, type(element).__name__)
            for index, element in enumerate(value)
            if not isinstance(element, (float, np.floating))
        ]
        if non_floats:
            self.fail(
                msg
                or "Expected all elements to be float, got non-float elements "
                f"(index, type): {non_floats}"
            )

    def assertIsListOfFloatsWithLength(
        self, value: Any, expected_length: int, msg: str = ""
    ) -> None:
        """Assert that value is a list of floats with a specific length.

        Args:
            value: Object under test.
            expected_length: Required ``len(value)``.
            msg: Optional custom failure message shared by both checks.

        Raises:
            AssertionError: if ``value`` is not a list of floats or its length
                differs from ``expected_length``.
        """
        self.assertIsListOfFloats(value, msg)
        self.assertEqual(
            len(value),
            expected_length,
            msg or f"Expected list of length {expected_length}, got {len(value)}",
        )

    def assertAllKeysPresent(
        self, d: Dict[str, Any], keys: List[str], msg: str = ""
    ) -> None:
        """Assert that all specified keys are present in dictionary.

        Args:
            d: Dictionary under test.
            keys: Keys that must all be present in ``d``.
            msg: Optional custom failure message.

        Raises:
            AssertionError: if any of ``keys`` is missing from ``d``; the
                default message lists the missing keys in sorted order so it
                is deterministic across runs.
        """
        missing = sorted(set(keys).difference(d))
        if missing:
            self.fail(msg or f"Missing keys: {missing}")

    def sample_normal_distribution(
        self, mean: float = 0.0, std_dev: float = 1.0, num_samples: int = 20000
    ) -> pd.DataFrame:
        """Build a dataframe of normally distributed scores.

        Scores are drawn with ``np.random.normal(loc=mean, scale=std_dev)``
        from NumPy's global random state, so results depend on the current
        global seed.

        Args:
            mean: Mean of the normal distribution.
            std_dev: Standard deviation of the normal distribution.
            num_samples: Number of rows to generate.

        Returns:
            DataFrame with columns ``user_id`` (``0 .. num_samples - 1``) and
            ``score`` (the sampled values).
        """
        scores = np.random.normal(loc=mean, scale=std_dev, size=num_samples)
        user_ids = list(range(0, num_samples))
        return pd.DataFrame({"user_id": user_ids, "score": scores})

    def setUp(self) -> None:
        """Create the small fixed train/test dataframes used by subclasses."""
        # NOTE(review): user_id 123456 appears twice; presumably this exercises
        # handling of multiple rows per user — confirm against callers.
        self.df_train_merge = pd.DataFrame(
            {
                "user_id": [
                    123456,
                    123456,
                    789012,
                    345678,
                    901234,
                ],
                "score": [0.8, 0.7, 0.6, 0.9, 0.5],
            }
        )
        # Create sample data for testing (same structure but different values).
        # user_id 567890 is likewise repeated.
        self.df_test_merge = pd.DataFrame(
            {
                "user_id": [
                    567890,
                    567890,
                    112233,
                    445566,
                    778899,
                ],
                "score": [0.2, 0.3, 0.4, 0.1, 0.5],
            }
        )
        self.user_id_key = "user_id"
        super().setUp()

    def get_long_dataframes(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Return two deterministic 10,000-row ``user_id``/``score`` dataframes.

        Both are sampled from N(0.5, 0.1^2) via ``sample_normal_distribution``.

        Side effect: reseeds NumPy's *global* RNG with ``np.random.seed(0)``,
        so the output is reproducible and the global state after this call is
        fixed as well (later ``np.random`` draws in the test are affected).

        Returns:
            ``(df_train_user_long, df_test_user_long)``; the two frames share
            user ids ``0..9999`` but have different scores (consecutive draws).
        """
        np.random.seed(0)
        df_train_user_long = self.sample_normal_distribution(0.5, 0.1, 10000)
        df_test_user_long = self.sample_normal_distribution(0.5, 0.1, 10000)
        return (df_train_user_long, df_test_user_long)