Skip to content

Commit 8421646

Browse files
Make ModelValue and SymmetricValue mutable reference (#311)
Co-authored-by: jayaprabhakar <jayaprabhakar@gmail.com>
1 parent 44767b9 commit 8421646

7 files changed

Lines changed: 146 additions & 110 deletions

File tree

lib/starlark_types.go

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ type SymmetricValues struct {
9595
starlark.Tuple
9696
}
9797

98-
func (s SymmetricValues) Index(i int) SymmetricValue { return s.Tuple.Index(i).(SymmetricValue) }
99-
func (s SymmetricValues) Type() string { return "symmetric_values" }
98+
func (s SymmetricValues) Index(i int) *SymmetricValue { return s.Tuple.Index(i).(*SymmetricValue) }
99+
func (s SymmetricValues) Type() string { return "symmetric_values" }
100100

101101
func MakeSymmetricValues(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
102102
prefix := ""
@@ -123,7 +123,7 @@ func MakeSymmetricValues(thread *starlark.Thread, fn *starlark.Builtin, args sta
123123
return values, nil
124124
}
125125

126-
func NewSymmetricValues(values []SymmetricValue) *SymmetricValues {
126+
func NewSymmetricValues(values []*SymmetricValue) *SymmetricValues {
127127
sv := &SymmetricValues{}
128128
for _, v := range values {
129129
sv.Tuple = append(sv.Tuple, v)
@@ -1192,8 +1192,8 @@ var _ starlark.Value = (*Bag)(nil)
11921192
// There is a limit in the API, this is required to get 'in' keyword to work
11931193
var _ starlark.Mapping = (*Bag)(nil)
11941194

1195-
func NewModelValue(prefix string, i int64) ModelValue {
1196-
return ModelValue{
1195+
func NewModelValue(prefix string, i int64) *ModelValue {
1196+
return &ModelValue{
11971197
prefix: prefix,
11981198
id: i,
11991199
}
@@ -1206,16 +1206,20 @@ type ModelValue struct {
12061206
id int64
12071207
}
12081208

1209-
func (m ModelValue) GetPrefix() string {
1209+
func (m *ModelValue) GetPrefix() string {
12101210
return m.prefix
12111211
}
12121212

1213-
func (m ModelValue) GetId() int64 {
1213+
func (m *ModelValue) GetId() int64 {
12141214
return m.id
12151215
}
12161216

1217-
func (m ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
1218-
y := y_.(ModelValue)
1217+
func (m *ModelValue) SetId(id int64) {
1218+
m.id = id
1219+
}
1220+
1221+
func (m *ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
1222+
y := y_.(*ModelValue)
12191223
switch op {
12201224
case syntax.EQL:
12211225
return modelValueEqual(m, y), nil
@@ -1226,57 +1230,57 @@ func (m ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth in
12261230
}
12271231
}
12281232

1229-
func modelValueEqual(s ModelValue, y ModelValue) bool {
1233+
func modelValueEqual(s *ModelValue, y *ModelValue) bool {
12301234
return s.prefix == y.prefix && s.id == y.id
12311235
}
12321236

1233-
func (m ModelValue) String() string {
1237+
func (m *ModelValue) String() string {
12341238
return fmt.Sprintf("%s%d", m.prefix, m.id)
12351239
}
12361240

1237-
func (m ModelValue) FullString() string {
1241+
func (m *ModelValue) FullString() string {
12381242
return fmt.Sprintf("%s%d", m.prefix, m.id)
12391243
}
12401244

1241-
func (m ModelValue) ShortString() string {
1245+
func (m *ModelValue) ShortString() string {
12421246
return fmt.Sprintf("%s%d", m.prefix, m.id)
12431247
}
12441248

1245-
func (m ModelValue) MarshalJSON() ([]byte, error) {
1249+
func (m *ModelValue) MarshalJSON() ([]byte, error) {
12461250
return json.Marshal(m.FullString())
12471251
}
12481252

1249-
func (m ModelValue) Type() string {
1253+
func (m *ModelValue) Type() string {
12501254
return ModelValueType
12511255
}
12521256

1253-
func (m ModelValue) Freeze() {}
1257+
func (m *ModelValue) Freeze() {}
12541258

1255-
func (m ModelValue) Truth() starlark.Bool {
1259+
func (m *ModelValue) Truth() starlark.Bool {
12561260
return true
12571261
}
12581262

1259-
func (m ModelValue) Hash() (uint32, error) {
1263+
func (m *ModelValue) Hash() (uint32, error) {
12601264
return starlark.String(m.FullString()).Hash()
12611265
}
12621266

1263-
var _ starlark.Value = ModelValue{}
1264-
var _ starlark.Comparable = ModelValue{}
1267+
var _ starlark.Value = (*ModelValue)(nil)
1268+
var _ starlark.Comparable = (*ModelValue)(nil)
12651269

12661270
// NewSymmetricValue creates a new symmetric value with the default Nominal kind.
12671271
// This maintains backward compatibility with existing code.
1268-
func NewSymmetricValue(prefix string, i int64) SymmetricValue {
1269-
return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindNominal}
1272+
func NewSymmetricValue(prefix string, i int64) *SymmetricValue {
1273+
return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindNominal}
12701274
}
12711275

12721276
// NewSymmetricValueWithKind creates a new symmetric value with the specified kind.
1273-
func NewSymmetricValueWithKind(prefix string, i int64, kind SymmetryKind) SymmetricValue {
1274-
return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: kind}
1277+
func NewSymmetricValueWithKind(prefix string, i int64, kind SymmetryKind) *SymmetricValue {
1278+
return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: kind}
12751279
}
12761280

12771281
// NewRotationalSymmetricValue creates a new symmetric value for rotational domains with the limit set.
1278-
func NewRotationalSymmetricValue(prefix string, i int64, limit int) SymmetricValue {
1279-
return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindRotational, Limit: limit}
1282+
func NewRotationalSymmetricValue(prefix string, i int64, limit int) *SymmetricValue {
1283+
return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindRotational, Limit: limit}
12801284
}
12811285

12821286
const SymmetricValueType = "symmetric_value"
@@ -1287,12 +1291,12 @@ type SymmetricValue struct {
12871291
Limit int // Only used for rotational kind (domain size for mod arithmetic). 0 for other kinds.
12881292
}
12891293

1290-
func (s SymmetricValue) GetKind() SymmetryKind {
1294+
func (s *SymmetricValue) GetKind() SymmetryKind {
12911295
return s.Kind
12921296
}
12931297

1294-
func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
1295-
y := y_.(SymmetricValue)
1298+
func (s *SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
1299+
y := y_.(*SymmetricValue)
12961300

12971301
// Values from different domains cannot be compared
12981302
if s.prefix != y.prefix {
@@ -1301,9 +1305,9 @@ func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, dept
13011305

13021306
switch op {
13031307
case syntax.EQL:
1304-
return modelValueEqual(s.ModelValue, y.ModelValue), nil
1308+
return symmetricValueEqual(s, y), nil
13051309
case syntax.NEQ:
1306-
return !modelValueEqual(s.ModelValue, y.ModelValue), nil
1310+
return !symmetricValueEqual(s, y), nil
13071311
case syntax.LT, syntax.LE, syntax.GT, syntax.GE:
13081312
// Ordering only allowed for Ordinal and Interval kinds
13091313
if s.Kind == SymmetryKindNominal {
@@ -1326,18 +1330,22 @@ func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, dept
13261330
return false, fmt.Errorf("%s %s %s not implemented", s.Type(), op, y.Type())
13271331
}
13281332

1329-
func (s SymmetricValue) Type() string {
1333+
func symmetricValueEqual(s *SymmetricValue, y *SymmetricValue) bool {
1334+
return s.prefix == y.prefix && s.id == y.id
1335+
}
1336+
1337+
func (s *SymmetricValue) Type() string {
13301338
return SymmetricValueType
13311339
}
13321340

1333-
var _ starlark.Value = SymmetricValue{}
1334-
var _ starlark.Comparable = SymmetricValue{}
1341+
var _ starlark.Value = (*SymmetricValue)(nil)
1342+
var _ starlark.Comparable = (*SymmetricValue)(nil)
13351343

13361344
func CompareStringer[E fmt.Stringer](a, b E) int {
13371345
return strings.Compare(a.String(), b.String())
13381346
}
13391347

1340-
func (s SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
1348+
func (s *SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
13411349
if s.Kind == SymmetryKindRotational {
13421350
return s.binaryRotational(op, y, side)
13431351
}
@@ -1357,7 +1365,7 @@ func (s SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.
13571365
return NewSymmetricValueWithKind(s.prefix, s.id-int64(i), s.Kind), nil
13581366
}
13591367
// SymmetricValue - SymmetricValue -> int (signed difference)
1360-
if other, ok := y.(SymmetricValue); ok {
1368+
if other, ok := y.(*SymmetricValue); ok {
13611369
if s.prefix != other.prefix {
13621370
return nil, fmt.Errorf("cannot subtract values from different domains: %s vs %s", s.prefix, other.prefix)
13631371
}
@@ -1373,7 +1381,7 @@ func modPositive(a, m int64) int64 {
13731381
return ((a % m) + m) % m
13741382
}
13751383

1376-
func (s SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
1384+
func (s *SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
13771385
limit := int64(s.Limit)
13781386
if limit <= 0 {
13791387
return nil, fmt.Errorf("rotational value has invalid limit %d", s.Limit)
@@ -1390,7 +1398,7 @@ func (s SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side
13901398
return NewRotationalSymmetricValue(s.prefix, modPositive(s.id-int64(i), limit), s.Limit), nil
13911399
}
13921400
// SymmetricValue - SymmetricValue -> plain int (mod limit)
1393-
if other, ok := y.(SymmetricValue); ok {
1401+
if other, ok := y.(*SymmetricValue); ok {
13941402
if s.prefix != other.prefix {
13951403
return nil, fmt.Errorf("cannot subtract values from different domains: %s vs %s", s.prefix, other.prefix)
13961404
}

lib/symmetry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ func (d *SymmetryDomain) segments(thread *starlark.Thread, _ *starlark.Builtin,
272272

273273
// Helper to extract int64 ID from value
274274
getID := func(v starlark.Value) (int64, bool) {
275-
if sv, ok := v.(SymmetricValue); ok {
275+
if sv, ok := v.(*SymmetricValue); ok {
276276
if sv.prefix == d.Name {
277277
return sv.id, true
278278
}

modelchecker/channel_message.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func (cm *ChannelMessage) HashCode() string {
4949
return fmt.Sprintf("%x", h.Sum(nil))
5050
}
5151

52-
func (cm *ChannelMessage) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *ChannelMessage {
52+
func (cm *ChannelMessage) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *ChannelMessage {
5353
frame, err := cm.frame.Clone(refs, permutations, alt)
5454
PanicOnError(err)
5555
params := CloneDict(cm.params, refs, permutations, alt)

modelchecker/clone.go

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func deepCloneStarlarkValue(value starlark.Value, refs map[starlark.Value]starla
1212
return deepCloneStarlarkValueWithPermutations(value, refs, nil, 0)
1313
}
1414

15-
func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) (starlark.Value, error) {
15+
func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) (starlark.Value, error) {
1616
if refs == nil {
1717
refs = make(map[starlark.Value]starlark.Value)
1818
} else if reflect.TypeOf(value).Kind() != reflect.Ptr {
@@ -25,19 +25,37 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
2525
// "builtin_function_or_method".
2626
switch value.Type() {
2727

28-
case "NoneType", "int", "float", "bool", "string", "bytes", "function", "range", "struct", "symmetric_values", "model_value", "Channel", "symmetry_domain":
28+
case "NoneType", "int", "float", "bool", "string", "bytes", "function", "range", "struct", "symmetric_values", "Channel", "symmetry_domain":
2929
//case starlark.Bool, starlark.String, starlark.Int:
3030
// For simple values, just return a copy
3131
// Also starlark struct is immutable
3232
// symmetry_domain is stateless configuration (Name, Limit, DomainType)
3333
return value, nil
3434

35+
case "model_value":
36+
// ModelValue is now a pointer type - clone it as a reference
37+
v := value.(*lib.ModelValue)
38+
newVal := lib.NewModelValue(v.GetPrefix(), v.GetId())
39+
refs[value] = newVal
40+
return newVal, nil
41+
3542
case "symmetric_value":
43+
// SymmetricValue is now a pointer type - clone it as a reference
44+
v := value.(*lib.SymmetricValue)
3645
if permutations != nil {
37-
v := value.(lib.SymmetricValue)
46+
// First try direct pointer lookup
3847
if other, ok := permutations[v]; ok {
48+
// Return the permuted value - also add to refs for consistency
49+
refs[value] = other[alt]
3950
return other[alt], nil
4051
}
52+
// Fall back to value-based lookup (for cases where SymmetricValues are created on-the-fly)
53+
for k, other := range permutations {
54+
if k.GetPrefix() == v.GetPrefix() && k.GetId() == v.GetId() {
55+
refs[value] = other[alt]
56+
return other[alt], nil
57+
}
58+
}
4159
panic(fmt.Sprintf("symmetric_value %#v (Kind: %d) should be in %v and alt %v. Keys: %+v", v, v.Kind, permutations, alt, func() []string {
4260
keys := make([]string, 0, len(permutations))
4361
for k := range permutations {
@@ -46,7 +64,13 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
4664
return keys
4765
}()))
4866
}
49-
return value, nil
67+
// Clone the symmetric value - now it's a pointer type
68+
newVal := lib.NewSymmetricValueWithKind(v.GetPrefix(), v.GetId(), v.GetKind())
69+
if v.Kind == lib.SymmetryKindRotational {
70+
newVal = lib.NewRotationalSymmetricValue(v.GetPrefix(), v.GetId(), v.Limit)
71+
}
72+
refs[value] = newVal
73+
return newVal, nil
5074
case "list":
5175
newVal := starlark.NewList(make([]starlark.Value, 0))
5276
refs[value] = newVal
@@ -121,16 +145,20 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
121145
newRight := s.Right
122146

123147
if permutations != nil {
124-
// Map Left ID
125-
leftSV := lib.NewSymmetricValueWithKind(s.Domain.Name, s.Left, s.Domain.Kind)
126-
if other, ok := permutations[leftSV]; ok {
127-
newLeft = other[alt].GetId()
148+
// Map Left ID - look up by pointer in permutations
149+
for oldSV, newSVs := range permutations {
150+
if oldSV.GetPrefix() == s.Domain.Name && oldSV.GetId() == s.Left {
151+
newLeft = newSVs[alt].GetId()
152+
break
153+
}
128154
}
129155

130-
// Map Right ID
131-
rightSV := lib.NewSymmetricValueWithKind(s.Domain.Name, s.Right, s.Domain.Kind)
132-
if other, ok := permutations[rightSV]; ok {
133-
newRight = other[alt].GetId()
156+
// Map Right ID - look up by pointer in permutations
157+
for oldSV, newSVs := range permutations {
158+
if oldSV.GetPrefix() == s.Domain.Name && oldSV.GetId() == s.Right {
159+
newRight = newSVs[alt].GetId()
160+
break
161+
}
134162
}
135163
}
136164

@@ -201,7 +229,7 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
201229
if err != nil {
202230
return nil, err
203231
}
204-
newRoleId := newVal.(lib.SymmetricValue)
232+
newRoleId := newVal.(*lib.SymmetricValue)
205233
prefix = newRoleId.GetPrefix()
206234
id = newRoleId.GetId()
207235
}
@@ -236,7 +264,7 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
236264
}
237265

238266
func deepCloneStringDict(v *starlark.Dict, refs map[starlark.Value]starlark.Value,
239-
src map[lib.SymmetricValue][]lib.SymmetricValue, alt int) (*starlark.Dict, error) {
267+
src map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) (*starlark.Dict, error) {
240268

241269
newDict := &starlark.Dict{}
242270
refs[v] = newDict
@@ -255,7 +283,7 @@ func deepCloneStringDict(v *starlark.Dict, refs map[starlark.Value]starlark.Valu
255283
return newDict, nil
256284
}
257285

258-
func deepCloneIterableToList(iterable starlark.Iterable, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) ([]starlark.Value, error) {
286+
func deepCloneIterableToList(iterable starlark.Iterable, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) ([]starlark.Value, error) {
259287
var newList []starlark.Value
260288
iter := iterable.Iterate()
261289
defer iter.Done()

0 commit comments

Comments
 (0)