Skip to content

Commit 29726f2

Browse files
authored
Merge pull request #194 from alwilson/soft_const_priority_over_dist
Remove high-priority, soft weight constraint from distributions
2 parents 4a3a9dc + 2cf28cc commit 29726f2

File tree

7 files changed

+74
-66
lines changed

7 files changed

+74
-66
lines changed

src/vsc/model/constraint_dist_scope_model.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
44
@author: mballance
55
'''
6+
from typing import List, Tuple
67
from vsc.model.constraint_inline_scope_model import ConstraintInlineScopeModel
78
from vsc.model.constraint_dist_model import ConstraintDistModel
89
from vsc.model.constraint_soft_model import ConstraintSoftModel
10+
from vsc.model.rand_state import RandState
911

1012
class ConstraintDistScopeModel(ConstraintInlineScopeModel):
1113
"""Holds implementation data about dist constraint"""
@@ -16,10 +18,36 @@ def __init__(self, dist_c, constraints=None):
1618
self.dist_c : ConstraintDistModel = dist_c
1719

1820
self.dist_soft_c : ConstraintSoftModel = None
19-
20-
# Indicates the current-target range. This is
21-
# updated during the weight-selection process
21+
22+
# List of (weight, index) tuples
23+
self.weight_list : List[Tuple[int, int]] = []
24+
self.total_weight = 0
25+
26+
# Indicates the current-target range. This is used to
27+
# by solvegroup_swizzler_range.
2228
self.target_range = 0
29+
30+
def next_target_range(self, randstate : RandState) -> int:
31+
"""Select the next target range from the weight list"""
32+
33+
seed_v = randstate.rng.randint(1, self.total_weight)
34+
35+
# Find the first range
36+
i = 0
37+
while i < len(self.weight_list):
38+
seed_v -= self.weight_list[i][0]
39+
40+
if seed_v <= 0:
41+
break
42+
43+
i += 1
44+
45+
if i >= len(self.weight_list):
46+
i = len(self.weight_list)-1
47+
48+
self.target_range = self.weight_list[i][1]
49+
50+
return self.target_range
2351

2452
def set_dist_soft_c(self, c : ConstraintSoftModel):
2553
self.addConstraint(c)

src/vsc/model/rand_set.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
# @author: ballance
2222

2323
from builtins import set
24-
from typing import Set, List
24+
from typing import Set, List, Dict
25+
from vsc.model.constraint_dist_scope_model import ConstraintDistScopeModel
2526

2627
from vsc.model.constraint_model import ConstraintModel
2728
from vsc.model.field_model import FieldModel
2829
from vsc.model.constraint_soft_model import ConstraintSoftModel
30+
from vsc.model.field_scalar_model import FieldScalarModel
2931
from vsc.visitors.model_pretty_printer import ModelPrettyPrinter
3032

3133

@@ -43,7 +45,7 @@ def __init__(self, order=-1):
4345
self.soft_constraint_s : Set[ConstraintModel] = set()
4446
self.soft_constraint_l : List[ConstraintModel] = []
4547
self.soft_priority = 0
46-
self.dist_field_m = {}
48+
self.dist_field_m : Dict[FieldScalarModel, List[ConstraintDistScopeModel]] = {}
4749

4850
# List of fields in each ordered set
4951
# Only non-none if order constraints impact this randset

src/vsc/model/rand_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def rand_s(self):
2929

3030
return val
3131

32-
def randint(self, low, high):
32+
def randint(self, low, high) -> int:
3333
low = int(low)
3434
high = int(high)
3535

src/vsc/model/solvegroup_swizzler_partsel.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,24 @@
1010
from vsc.model.expr_literal_model import ExprLiteralModel
1111
from vsc.model.expr_model import ExprModel
1212
from vsc.model.expr_partselect_model import ExprPartselectModel
13+
from vsc.model.field_model import FieldModel
1314
from vsc.model.field_scalar_model import FieldScalarModel
15+
from vsc.model.rand_set import RandSet
16+
from vsc.model.rand_state import RandState
1417
from vsc.model.variable_bound_model import VariableBoundModel
1518

1619

1720
class SolveGroupSwizzlerPartsel(object):
1821

1922
def __init__(self, randstate, solve_info, debug=0):
2023
self.debug = debug
21-
self.randstate = randstate
24+
self.randstate : RandState = randstate
2225
self.solve_info = solve_info
2326

2427
def swizzle(self,
25-
btor,
26-
rs,
27-
bound_m):
28+
btor,
29+
rs : RandSet,
30+
bound_m : VariableBoundModel):
2831
if self.debug > 0:
2932
print("--> swizzle_randvars")
3033

@@ -51,7 +54,7 @@ def swizzle(self,
5154
if self.debug > 0:
5255
print("<-- swizzle_randvars")
5356

54-
def swizzle_field_l(self, field_l, rs, bound_m, btor):
57+
def swizzle_field_l(self, field_l, rs : RandSet, bound_m, btor):
5558
e = None
5659
if len(field_l) > 0:
5760
# Make a copy of the field list so we don't
@@ -96,7 +99,10 @@ def swizzle_field_l(self, field_l, rs, bound_m, btor):
9699
else:
97100
return False
98101

99-
def swizzle_field(self, f, rs, bound_m) -> ExprModel:
102+
def swizzle_field(self,
103+
f : FieldScalarModel,
104+
rs : RandSet,
105+
bound_m : VariableBoundModel)->ExprModel:
100106
ret = None
101107

102108
if self.debug > 0:
@@ -106,14 +112,15 @@ def swizzle_field(self, f, rs, bound_m) -> ExprModel:
106112
if self.debug > 0:
107113
print("Note: field %s is in dist map" % f.name)
108114
for d in rs.dist_field_m[f]:
109-
print(" Target interval %d" % d.target_range)
115+
print(" Weight list %s" % d.weight_list)
110116
if len(rs.dist_field_m[f]) > 1:
111117
target_d = self.randstate.randint(0, len(rs.dist_field_m[f])-1)
112118
dist_scope_c = rs.dist_field_m[f][target_d]
113119
else:
114120
dist_scope_c = rs.dist_field_m[f][0]
115121

116-
target_w = dist_scope_c.dist_c.weights[dist_scope_c.target_range]
122+
target_range = dist_scope_c.next_target_range(self.randstate)
123+
target_w = dist_scope_c.dist_c.weights[target_range]
117124
if target_w.rng_rhs is not None:
118125
# Dual-bound range
119126
val_l = target_w.rng_lhs.val()
@@ -129,6 +136,8 @@ def swizzle_field(self, f, rs, bound_m) -> ExprModel:
129136
else:
130137
# Single value
131138
val = target_w.rng_lhs.val()
139+
if self.debug > 0:
140+
print("Select dist-weight value %d" % (int(val)))
132141
ret = [ExprBinModel(
133142
ExprFieldRefModel(f),
134143
BinExprType.Eq,

src/vsc/model/solvegroup_swizzler_range.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from vsc.model.expr_literal_model import ExprLiteralModel
1212
from vsc.model.expr_model import ExprModel
1313
from vsc.model.field_scalar_model import FieldScalarModel
14+
from vsc.model.rand_set import RandSet
1415
from vsc.model.variable_bound_model import VariableBoundModel
1516

1617

@@ -103,7 +104,10 @@ def swizzle_field_l(self, field_l, rs, bound_m, btor):
103104
else:
104105
return False
105106

106-
def swizzle_field(self, f, rs, bound_m) -> ExprModel:
107+
def swizzle_field(self,
108+
f : FieldScalarModel,
109+
rs : RandSet,
110+
bound_m : VariableBoundModel)->ExprModel:
107111
ret = None
108112

109113
if self.debug > 0:

src/vsc/visitors/dist_constraint_builder.py

Lines changed: 10 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -96,56 +96,21 @@ def visit_constraint_dist(self, c):
9696

9797
# Form a list of non-zero weighted tuples of weight/range
9898
# Sort in ascending order
99-
weight_l = []
100-
total_w = 0
99+
weight_list = []
100+
total_weight = 0
101101
for i,w in enumerate(c.weights):
102102
weight = int(w.weight.val())
103-
total_w += weight
103+
total_weight += weight
104104
if weight > 0:
105-
weight_l.append((weight, i))
106-
weight_l.sort(key=lambda w:w[0])
107-
108-
seed_v = self.rng.randint(1, total_w)
109-
110-
# Find the first range
111-
i = 0
112-
while i < len(weight_l):
113-
seed_v -= weight_l[i][0]
114-
115-
if seed_v <= 0:
116-
break
117-
118-
i += 1
105+
weight_list.append((weight, i))
106+
weight_list.sort(key=lambda w:w[0])
119107

120-
if i >= len(weight_l):
121-
i = len(weight_l)-1
108+
scope.weight_list = weight_list
109+
scope.total_weight = total_weight
110+
111+
# Call next_target_range for solvegroup_swizzler_range to use
112+
_ = scope.next_target_range(self.rng)
122113

123-
scope.target_range = weight_l[i][1]
124-
target_w = c.weights[weight_l[i][1]]
125-
dist_soft_c = None
126-
if target_w.rng_rhs is not None:
127-
dist_soft_c = ConstraintSoftModel(
128-
ExprBinModel(
129-
ExprBinModel(
130-
c.lhs,
131-
BinExprType.Ge,
132-
target_w.rng_lhs),
133-
BinExprType.And,
134-
ExprBinModel(
135-
c.lhs,
136-
BinExprType.Le,
137-
target_w.rng_rhs)))
138-
else:
139-
dist_soft_c = ConstraintSoftModel(
140-
ExprBinModel(
141-
c.lhs,
142-
BinExprType.Eq,
143-
target_w.rng_lhs))
144-
# Give dist constraints a high priority to allow
145-
# them to override all user-defined soft constraints
146-
dist_soft_c.priority = 1000000
147-
scope.set_dist_soft_c(dist_soft_c)
148-
149114
self.override_constraint(scope)
150115

151116

ve/unit/test_constraint_soft.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def dist_a(self):
7070
it.a == 1 #B
7171

7272
def test_soft_dist_priority(self):
73-
"""Ensures that dist constraints take priority over soft constraints"""
73+
"""Ensures that soft constraints take priority over dist constraints"""
7474

7575
@vsc.randobj
7676
class my_item(object):
@@ -90,18 +90,18 @@ def dist_a(self):
9090
vsc.weight(1, 10),
9191
vsc.weight(2, 10),
9292
vsc.weight(4, 10),
93-
vsc.weight(8, 10)])
93+
vsc.weight(8, 10)])
9494

9595
hist = [0]*9
9696
item = my_item()
9797
for i in range(100):
9898
item.randomize()
9999
hist[item.a] += 1
100100

101-
self.assertGreater(hist[0], 0)
102-
self.assertGreater(hist[1], 0)
103-
self.assertGreater(hist[2], 0)
104-
self.assertGreater(hist[4], 0)
101+
self.assertEqual(hist[0], 0)
102+
self.assertEqual(hist[1], 0)
103+
self.assertEqual(hist[2], 0)
104+
self.assertEqual(hist[4], 0)
105105
self.assertGreater(hist[8], 0)
106106

107107
def test_compound_array(self):

0 commit comments

Comments
 (0)