diff --git a/lib/starlark_types.go b/lib/starlark_types.go index 7850b4f..7572bc9 100644 --- a/lib/starlark_types.go +++ b/lib/starlark_types.go @@ -95,8 +95,8 @@ type SymmetricValues struct { starlark.Tuple } -func (s SymmetricValues) Index(i int) SymmetricValue { return s.Tuple.Index(i).(SymmetricValue) } -func (s SymmetricValues) Type() string { return "symmetric_values" } +func (s SymmetricValues) Index(i int) *SymmetricValue { return s.Tuple.Index(i).(*SymmetricValue) } +func (s SymmetricValues) Type() string { return "symmetric_values" } func MakeSymmetricValues(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { prefix := "" @@ -123,7 +123,7 @@ func MakeSymmetricValues(thread *starlark.Thread, fn *starlark.Builtin, args sta return values, nil } -func NewSymmetricValues(values []SymmetricValue) *SymmetricValues { +func NewSymmetricValues(values []*SymmetricValue) *SymmetricValues { sv := &SymmetricValues{} for _, v := range values { sv.Tuple = append(sv.Tuple, v) @@ -1192,8 +1192,8 @@ var _ starlark.Value = (*Bag)(nil) // There is a limit in the API, this is required to get 'in' keyword to work var _ starlark.Mapping = (*Bag)(nil) -func NewModelValue(prefix string, i int64) ModelValue { - return ModelValue{ +func NewModelValue(prefix string, i int64) *ModelValue { + return &ModelValue{ prefix: prefix, id: i, } @@ -1206,16 +1206,20 @@ type ModelValue struct { id int64 } -func (m ModelValue) GetPrefix() string { +func (m *ModelValue) GetPrefix() string { return m.prefix } -func (m ModelValue) GetId() int64 { +func (m *ModelValue) GetId() int64 { return m.id } -func (m ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) { - y := y_.(ModelValue) +func (m *ModelValue) SetId(id int64) { + m.id = id +} + +func (m *ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) { + y := y_.(*ModelValue) switch op { case syntax.EQL: return modelValueEqual(m, y), nil @@ -1226,57 +1230,57 @@ func (m ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth in } } -func modelValueEqual(s ModelValue, y ModelValue) bool { +func modelValueEqual(s *ModelValue, y *ModelValue) bool { return s.prefix == y.prefix && s.id == y.id } -func (m ModelValue) String() string { +func (m *ModelValue) String() string { return fmt.Sprintf("%s%d", m.prefix, m.id) } -func (m ModelValue) FullString() string { +func (m *ModelValue) FullString() string { return fmt.Sprintf("%s%d", m.prefix, m.id) } -func (m ModelValue) ShortString() string { +func (m *ModelValue) ShortString() string { return fmt.Sprintf("%s%d", m.prefix, m.id) } -func (m ModelValue) MarshalJSON() ([]byte, error) { +func (m *ModelValue) MarshalJSON() ([]byte, error) { return json.Marshal(m.FullString()) } -func (m ModelValue) Type() string { +func (m *ModelValue) Type() string { return ModelValueType } -func (m ModelValue) Freeze() {} +func (m *ModelValue) Freeze() {} -func (m ModelValue) Truth() starlark.Bool { +func (m *ModelValue) Truth() starlark.Bool { return true } -func (m ModelValue) Hash() (uint32, error) { +func (m *ModelValue) Hash() (uint32, error) { return starlark.String(m.FullString()).Hash() } -var _ starlark.Value = ModelValue{} -var _ starlark.Comparable = ModelValue{} +var _ starlark.Value = (*ModelValue)(nil) +var _ starlark.Comparable = (*ModelValue)(nil) // NewSymmetricValue creates a new symmetric value with the default Nominal kind. // This maintains backward compatibility with existing code. -func NewSymmetricValue(prefix string, i int64) SymmetricValue { - return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindNominal} +func NewSymmetricValue(prefix string, i int64) *SymmetricValue { + return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindNominal} } // NewSymmetricValueWithKind creates a new symmetric value with the specified kind. -func NewSymmetricValueWithKind(prefix string, i int64, kind SymmetryKind) SymmetricValue { - return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: kind} +func NewSymmetricValueWithKind(prefix string, i int64, kind SymmetryKind) *SymmetricValue { + return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: kind} } // NewRotationalSymmetricValue creates a new symmetric value for rotational domains with the limit set. -func NewRotationalSymmetricValue(prefix string, i int64, limit int) SymmetricValue { - return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindRotational, Limit: limit} +func NewRotationalSymmetricValue(prefix string, i int64, limit int) *SymmetricValue { + return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindRotational, Limit: limit} } const SymmetricValueType = "symmetric_value" @@ -1287,12 +1291,12 @@ type SymmetricValue struct { Limit int // Only used for rotational kind (domain size for mod arithmetic). 0 for other kinds. } -func (s SymmetricValue) GetKind() SymmetryKind { +func (s *SymmetricValue) GetKind() SymmetryKind { return s.Kind } -func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) { - y := y_.(SymmetricValue) +func (s *SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) { + y := y_.(*SymmetricValue) // Values from different domains cannot be compared if s.prefix != y.prefix { @@ -1301,9 +1305,9 @@ func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, dept switch op { case syntax.EQL: - return modelValueEqual(s.ModelValue, y.ModelValue), nil + return symmetricValueEqual(s, y), nil case syntax.NEQ: - return !modelValueEqual(s.ModelValue, y.ModelValue), nil + return !symmetricValueEqual(s, y), nil case syntax.LT, syntax.LE, syntax.GT, syntax.GE: // Ordering only allowed for Ordinal and Interval kinds if s.Kind == SymmetryKindNominal { @@ -1326,18 +1330,22 @@ func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, dept return false, fmt.Errorf("%s %s %s not implemented", s.Type(), op, y.Type()) } -func (s SymmetricValue) Type() string { +func symmetricValueEqual(s *SymmetricValue, y *SymmetricValue) bool { + return s.prefix == y.prefix && s.id == y.id +} + +func (s *SymmetricValue) Type() string { return SymmetricValueType } -var _ starlark.Value = SymmetricValue{} -var _ starlark.Comparable = SymmetricValue{} +var _ starlark.Value = (*SymmetricValue)(nil) +var _ starlark.Comparable = (*SymmetricValue)(nil) func CompareStringer[E fmt.Stringer](a, b E) int { return strings.Compare(a.String(), b.String()) } -func (s SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) { +func (s *SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) { if s.Kind == SymmetryKindRotational { return s.binaryRotational(op, y, side) } @@ -1357,7 +1365,7 @@ func (s SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark. return NewSymmetricValueWithKind(s.prefix, s.id-int64(i), s.Kind), nil } // SymmetricValue - SymmetricValue -> int (signed difference) - if other, ok := y.(SymmetricValue); ok { + if other, ok := y.(*SymmetricValue); ok { if s.prefix != other.prefix { return nil, fmt.Errorf("cannot subtract values from different domains: %s vs %s", s.prefix, other.prefix) } @@ -1373,7 +1381,7 @@ func modPositive(a, m int64) int64 { return ((a % m) + m) % m } -func (s SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) { +func (s *SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) { limit := int64(s.Limit) if limit <= 0 { return nil, fmt.Errorf("rotational value has invalid limit %d", s.Limit) @@ -1390,7 +1398,7 @@ func (s SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side return NewRotationalSymmetricValue(s.prefix, modPositive(s.id-int64(i), limit), s.Limit), nil } // SymmetricValue - SymmetricValue -> plain int (mod limit) - if other, ok := y.(SymmetricValue); ok { + if other, ok := y.(*SymmetricValue); ok { if s.prefix != other.prefix { return nil, fmt.Errorf("cannot subtract values from different domains: %s vs %s", s.prefix, other.prefix) } diff --git a/lib/symmetry.go b/lib/symmetry.go index 8b264db..4fcbcc7 100644 --- a/lib/symmetry.go +++ b/lib/symmetry.go @@ -272,7 +272,7 @@ func (d *SymmetryDomain) segments(thread *starlark.Thread, _ *starlark.Builtin, // Helper to extract int64 ID from value getID := func(v starlark.Value) (int64, bool) { - if sv, ok := v.(SymmetricValue); ok { + if sv, ok := v.(*SymmetricValue); ok { if sv.prefix == d.Name { return sv.id, true } diff --git a/modelchecker/channel_message.go b/modelchecker/channel_message.go index bbfa58a..e02780b 100644 --- a/modelchecker/channel_message.go +++ b/modelchecker/channel_message.go @@ -49,7 +49,7 @@ func (cm *ChannelMessage) HashCode() string { return fmt.Sprintf("%x", h.Sum(nil)) } -func (cm *ChannelMessage) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *ChannelMessage { +func (cm *ChannelMessage) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *ChannelMessage { frame, err := cm.frame.Clone(refs, permutations, alt) PanicOnError(err) params := CloneDict(cm.params, refs, permutations, alt) diff --git a/modelchecker/clone.go b/modelchecker/clone.go index ddb176a..2ddface 100644 --- a/modelchecker/clone.go +++ b/modelchecker/clone.go @@ -12,7 +12,7 @@ func deepCloneStarlarkValue(value starlark.Value, refs map[starlark.Value]starla return deepCloneStarlarkValueWithPermutations(value, refs, nil, 0) } -func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) (starlark.Value, error) { +func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) (starlark.Value, error) { if refs == nil { refs = make(map[starlark.Value]starlark.Value) } else if reflect.TypeOf(value).Kind() != reflect.Ptr { @@ -25,19 +25,37 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl // "builtin_function_or_method". switch value.Type() { - case "NoneType", "int", "float", "bool", "string", "bytes", "function", "range", "struct", "symmetric_values", "model_value", "Channel", "symmetry_domain": + case "NoneType", "int", "float", "bool", "string", "bytes", "function", "range", "struct", "symmetric_values", "Channel", "symmetry_domain": //case starlark.Bool, starlark.String, starlark.Int: // For simple values, just return a copy // Also starlark struct is immutable // symmetry_domain is stateless configuration (Name, Limit, DomainType) return value, nil + case "model_value": + // ModelValue is now a pointer type - clone it as a reference + v := value.(*lib.ModelValue) + newVal := lib.NewModelValue(v.GetPrefix(), v.GetId()) + refs[value] = newVal + return newVal, nil + case "symmetric_value": + // SymmetricValue is now a pointer type - clone it as a reference + v := value.(*lib.SymmetricValue) if permutations != nil { - v := value.(lib.SymmetricValue) + // First try direct pointer lookup if other, ok := permutations[v]; ok { + // Return the permuted value - also add to refs for consistency + refs[value] = other[alt] return other[alt], nil } + // Fall back to value-based lookup (for cases where SymmetricValues are created on-the-fly) + for k, other := range permutations { + if k.GetPrefix() == v.GetPrefix() && k.GetId() == v.GetId() { + refs[value] = other[alt] + return other[alt], nil + } + } panic(fmt.Sprintf("symmetric_value %#v (Kind: %d) should be in %v and alt %v. Keys: %+v", v, v.Kind, permutations, alt, func() []string { keys := make([]string, 0, len(permutations)) for k := range permutations { @@ -46,7 +64,13 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl return keys }())) } - return value, nil + // Clone the symmetric value - now it's a pointer type + newVal := lib.NewSymmetricValueWithKind(v.GetPrefix(), v.GetId(), v.GetKind()) + if v.Kind == lib.SymmetryKindRotational { + newVal = lib.NewRotationalSymmetricValue(v.GetPrefix(), v.GetId(), v.Limit) + } + refs[value] = newVal + return newVal, nil case "list": newVal := starlark.NewList(make([]starlark.Value, 0)) refs[value] = newVal @@ -121,16 +145,20 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl newRight := s.Right if permutations != nil { - // Map Left ID - leftSV := lib.NewSymmetricValueWithKind(s.Domain.Name, s.Left, s.Domain.Kind) - if other, ok := permutations[leftSV]; ok { - newLeft = other[alt].GetId() + // Map Left ID - look up by pointer in permutations + for oldSV, newSVs := range permutations { + if oldSV.GetPrefix() == s.Domain.Name && oldSV.GetId() == s.Left { + newLeft = newSVs[alt].GetId() + break + } } - // Map Right ID - rightSV := lib.NewSymmetricValueWithKind(s.Domain.Name, s.Right, s.Domain.Kind) - if other, ok := permutations[rightSV]; ok { - newRight = other[alt].GetId() + // Map Right ID - look up by pointer in permutations + for oldSV, newSVs := range permutations { + if oldSV.GetPrefix() == s.Domain.Name && oldSV.GetId() == s.Right { + newRight = newSVs[alt].GetId() + break + } } } @@ -201,7 +229,7 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl if err != nil { return nil, err } - newRoleId := newVal.(lib.SymmetricValue) + newRoleId := newVal.(*lib.SymmetricValue) prefix = newRoleId.GetPrefix() id = newRoleId.GetId() } @@ -236,7 +264,7 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl } func deepCloneStringDict(v *starlark.Dict, refs map[starlark.Value]starlark.Value, - src map[lib.SymmetricValue][]lib.SymmetricValue, alt int) (*starlark.Dict, error) { + src map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) (*starlark.Dict, error) { newDict := &starlark.Dict{} refs[v] = newDict @@ -255,7 +283,7 @@ func deepCloneStringDict(v *starlark.Dict, refs map[starlark.Value]starlark.Valu return newDict, nil } -func deepCloneIterableToList(iterable starlark.Iterable, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) ([]starlark.Value, error) { +func deepCloneIterableToList(iterable starlark.Iterable, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) ([]starlark.Value, error) { var newList []starlark.Value iter := iterable.Iterate() defer iter.Done() diff --git a/modelchecker/processor.go b/modelchecker/processor.go index 8274e31..747264a 100644 --- a/modelchecker/processor.go +++ b/modelchecker/processor.go @@ -356,7 +356,7 @@ func (p *Process) Fork() *Process { return p2 } -func (p *Process) CloneForAssert(permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *Process { +func (p *Process) CloneForAssert(permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *Process { refs := make(map[starlark.Value]starlark.Value) p2 := &Process{ Name: p.Name, @@ -1731,13 +1731,13 @@ func (p *Processor) findVisitedSymmetric(node *Node) (*Node, bool, string) { return other, ok, minHash } -func (p *Process) symmetricHash(permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) string { +func (p *Process) symmetricHash(permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) string { p2 := p.CloneForAssert(permutations, alt) return p2.HashCode() } func (p *Process) GetSymmetryRoles() []*lib.SymmetricValues { - m := make(map[string][]lib.SymmetricValue) + m := make(map[string][]*lib.SymmetricValue) for _, role := range p.Roles { if role != nil && role.IsSymmetric() { m[role.Name] = append(m[role.Name], lib.NewSymmetricValue(role.Name, role.Ref)) @@ -1765,15 +1765,15 @@ func (p *Process) addChannelMessage(channel *lib.Channel, roleShortRef string, f } } -func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.SymmetricValue, int) { - var values [][]lib.SymmetricValue - var usedValues [][]lib.SymmetricValue +func getSymmetryPermutations(process *Process) (map[*lib.SymmetricValue][]*lib.SymmetricValue, int) { + var values [][]*lib.SymmetricValue + var usedValues [][]*lib.SymmetricValue // Static mappings for ordinal (rank squash) and interval (zero-shift). // These have a single canonical form, so they're appended to each permutation. type staticMapping struct { - used []lib.SymmetricValue - canonical []lib.SymmetricValue + used []*lib.SymmetricValue + canonical []*lib.SymmetricValue } var postProcessingMappings []staticMapping // rotationalExtraMappings holds additional shift alternatives for tied pivots/reflections. @@ -1804,7 +1804,7 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym } kind := used[0].GetKind() prefix := used[0].GetPrefix() - canonical := make([]lib.SymmetricValue, len(used)) + canonical := make([]*lib.SymmetricValue, len(used)) for j := 0; j < len(used); j++ { canonical[j] = lib.NewSymmetricValueWithKind(prefix, int64(j), kind) } @@ -1822,7 +1822,7 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym postProcessingMappings = append(postProcessingMappings, staticMapping{used: used, canonical: canonical}) if tied { // Add reverse mapping as extra alternative - revCanonical := make([]lib.SymmetricValue, len(used)) + revCanonical := make([]*lib.SymmetricValue, len(used)) for j := range used { revCanonical[j] = lib.NewSymmetricValueWithKind(prefix, rev[j], kind) } @@ -1846,26 +1846,26 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym fwd, rev, tied := lib.GetIntervalReflectionCandidates(usedIDs) if rev != nil && fwd == nil { // Reverse wins - revCanonical := make([]lib.SymmetricValue, len(used)) + revCanonical := make([]*lib.SymmetricValue, len(used)) for j := range used { revCanonical[j] = lib.NewSymmetricValueWithKind(prefix, rev[j], kind) } postProcessingMappings = append(postProcessingMappings, staticMapping{used: used, canonical: revCanonical}) } else if fwd != nil && !tied { // Forward wins - shifted := make([]lib.SymmetricValue, len(used)) + shifted := make([]*lib.SymmetricValue, len(used)) for j := range used { shifted[j] = lib.NewSymmetricValueWithKind(prefix, fwd[j], kind) } postProcessingMappings = append(postProcessingMappings, staticMapping{used: used, canonical: shifted}) } else { // Tied: forward as base, reverse as extra - shifted := make([]lib.SymmetricValue, len(used)) + shifted := make([]*lib.SymmetricValue, len(used)) for j := range used { shifted[j] = lib.NewSymmetricValueWithKind(prefix, fwd[j], kind) } postProcessingMappings = append(postProcessingMappings, staticMapping{used: used, canonical: shifted}) - revCanonical := make([]lib.SymmetricValue, len(used)) + revCanonical := make([]*lib.SymmetricValue, len(used)) for j := range used { revCanonical[j] = lib.NewSymmetricValueWithKind(prefix, rev[j], kind) } @@ -1873,7 +1873,7 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym rotationalExtraMappings = append(rotationalExtraMappings, []staticMapping{{used: used, canonical: revCanonical}}) } } else { - shifted := make([]lib.SymmetricValue, len(used)) + shifted := make([]*lib.SymmetricValue, len(used)) for j := 0; j < len(used); j++ { shifted[j] = lib.NewSymmetricValueWithKind(prefix, used[j].GetId()-minID, kind) } @@ -1894,7 +1894,7 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym candidates := lib.GetCanonicalRotationsWithReflection(usedIDs, limit) // First candidate is the base mapping first := candidates[0] - shifted := make([]lib.SymmetricValue, len(used)) + shifted := make([]*lib.SymmetricValue, len(used)) for j, v := range used { val := v.GetId() if first.Reflected { @@ -1907,7 +1907,7 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym if len(candidates) > 1 { extraMappings := make([]staticMapping, 0, len(candidates)-1) for _, c := range candidates[1:] { - sh := make([]lib.SymmetricValue, len(used)) + sh := make([]*lib.SymmetricValue, len(used)) for j, v := range used { val := v.GetId() if c.Reflected { @@ -1924,7 +1924,7 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym shifts := lib.GetCanonicalRotations(usedIDs, limit) // Single canonical rotation — static mapping shift := shifts[0] - shifted := make([]lib.SymmetricValue, len(used)) + shifted := make([]*lib.SymmetricValue, len(used)) for j, v := range used { shifted[j] = lib.NewRotationalSymmetricValue(prefix, ((v.GetId()+shift)%lim+lim)%lim, limit) } @@ -1935,7 +1935,7 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym extraMappings := make([]staticMapping, 0, len(shifts)-1) for _, s := range shifts[1:] { - sh := make([]lib.SymmetricValue, len(used)) + sh := make([]*lib.SymmetricValue, len(used)) for j, v := range used { sh[j] = lib.NewRotationalSymmetricValue(prefix, ((v.GetId()+s)%lim+lim)%lim, limit) } @@ -1955,11 +1955,11 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym defs := process.Heap.GetSymmetryDefs() for _, def := range defs { - v := make([]lib.SymmetricValue, def.Len()) + v := make([]*lib.SymmetricValue, def.Len()) for j := 0; j < def.Len(); j++ { v[j] = def.Index(j) } - slices.SortFunc(v, lib.CompareStringer[lib.SymmetricValue]) + slices.SortFunc(v, lib.CompareStringer[*lib.SymmetricValue]) // If canonicalization is disabled, we treat everything as nominal (full permutation) // to be safe, or we could potentially return identity for ordinal. @@ -1973,25 +1973,25 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym // Roles are assumed Nominal roles := process.GetSymmetryRoles() for _, role := range roles { - v := make([]lib.SymmetricValue, role.Len()) + v := make([]*lib.SymmetricValue, role.Len()) for j := 0; j < role.Len(); j++ { v[j] = role.Index(j) } - slices.SortFunc(v, lib.CompareStringer[lib.SymmetricValue]) + slices.SortFunc(v, lib.CompareStringer[*lib.SymmetricValue]) values = append(values, v) usedValues = append(usedValues, v) } // Generate all permutations permutations := lib.GenerateAllPermutations(values) - v := make([][]lib.SymmetricValue, len(permutations)) + v := make([][]*lib.SymmetricValue, len(permutations)) for i, permutation := range permutations { v[i] = slices.Concat(permutation...) } // Build permutation map with actual used values as keys - permMap := make(map[lib.SymmetricValue][]lib.SymmetricValue) - actualKeys := make([]lib.SymmetricValue, 0) + permMap := make(map[*lib.SymmetricValue][]*lib.SymmetricValue) + actualKeys := make([]*lib.SymmetricValue, 0) for _, used := range usedValues { actualKeys = append(actualKeys, used...) } @@ -2029,8 +2029,8 @@ func getSymmetryPermutations(process *Process) (map[lib.SymmetricValue][]lib.Sym // Rotational ties: replicate permutations for each combination of rotational shifts. // Build all rotational shift combinations. type rotChoice struct { - used []lib.SymmetricValue - canonical []lib.SymmetricValue + used []*lib.SymmetricValue + canonical []*lib.SymmetricValue } // For each domain with ties, collect all shift options (base + extras) var rotOptions [][]rotChoice diff --git a/modelchecker/state_visitor.go b/modelchecker/state_visitor.go index 9451266..623eac3 100644 --- a/modelchecker/state_visitor.go +++ b/modelchecker/state_visitor.go @@ -9,13 +9,14 @@ import ( // StateVisitor defines the interface for visiting state elements during traversal type StateVisitor interface { - VisitSymmetricValue(sv lib.SymmetricValue) + VisitSymmetricValue(sv *lib.SymmetricValue) } // UsedSymmetricValuesCollector collects all SymmetricValue instances in the current state type UsedSymmetricValuesCollector struct { - usedValues map[string]map[int64]bool // prefix -> set of IDs - limits map[string]int // prefix -> Limit (for rotational domains) + usedValues map[string]map[int64]bool // prefix -> set of IDs + limits map[string]int // prefix -> Limit (for rotational domains) + pointers map[string]map[int64]*lib.SymmetricValue // prefix -> ID -> actual pointer from state } // NewUsedSymmetricValuesCollector creates a new collector @@ -23,18 +24,21 @@ func NewUsedSymmetricValuesCollector() *UsedSymmetricValuesCollector { return &UsedSymmetricValuesCollector{ usedValues: make(map[string]map[int64]bool), limits: make(map[string]int), + pointers: make(map[string]map[int64]*lib.SymmetricValue), } } // VisitSymmetricValue records that a symmetric value is used in the state -func (c *UsedSymmetricValuesCollector) VisitSymmetricValue(sv lib.SymmetricValue) { +func (c *UsedSymmetricValuesCollector) VisitSymmetricValue(sv *lib.SymmetricValue) { prefix := sv.GetPrefix() id := sv.GetId() if c.usedValues[prefix] == nil { c.usedValues[prefix] = make(map[int64]bool) + c.pointers[prefix] = make(map[int64]*lib.SymmetricValue) } c.usedValues[prefix][id] = true + c.pointers[prefix][id] = sv // Store the actual pointer if sv.Limit > 0 { c.limits[prefix] = sv.Limit } @@ -54,17 +58,13 @@ func (c *UsedSymmetricValuesCollector) GetUsedIds(prefix string) []int64 { return ids } -// GetUsedSymmetricValues returns SymmetricValues for a given prefix -func (c *UsedSymmetricValuesCollector) GetUsedSymmetricValues(prefix string, kind lib.SymmetryKind) []lib.SymmetricValue { +// GetUsedSymmetricValues returns the actual SymmetricValue pointers for a given prefix +func (c *UsedSymmetricValuesCollector) GetUsedSymmetricValues(prefix string, kind lib.SymmetryKind) []*lib.SymmetricValue { ids := c.GetUsedIds(prefix) - values := make([]lib.SymmetricValue, len(ids)) - limit := c.limits[prefix] + values := make([]*lib.SymmetricValue, len(ids)) for i, id := range ids { - if kind == lib.SymmetryKindRotational && limit > 0 { - values[i] = lib.NewRotationalSymmetricValue(prefix, id, limit) - } else { - values[i] = lib.NewSymmetricValueWithKind(prefix, id, kind) - } + // Return the actual pointers from the state + values[i] = c.pointers[prefix][id] } return values } @@ -110,7 +110,7 @@ func visitStarlarkValue(value starlark.Value, visitor StateVisitor, visited map[ return case "symmetric_value": - sv := value.(lib.SymmetricValue) + sv := value.(*lib.SymmetricValue) visitor.VisitSymmetricValue(sv) case "list": @@ -183,7 +183,7 @@ func visitStarlarkValue(value starlark.Value, visitor StateVisitor, visited map[ role := value.(*lib.Role) if role.IsSymmetric() { roleId := lib.NewSymmetricValue(role.Name, role.Ref) - visitor.VisitSymmetricValue(roleId) + visitor.VisitSymmetricValue(roleId) // roleId is already a pointer now } visitStarlarkValue(role.Fields, visitor, visited) visitStarlarkValue(role.Params, visitor, visited) @@ -262,7 +262,7 @@ func visitScope(scope *Scope, visitor StateVisitor, visited map[starlark.Value]b } // getUsedSymmetricValues returns the symmetric values that are actually used in the process state -func (p *Process) getUsedSymmetricValues() [][]lib.SymmetricValue { +func (p *Process) getUsedSymmetricValues() [][]*lib.SymmetricValue { collector := NewUsedSymmetricValuesCollector() p.AcceptVisitor(collector) @@ -279,7 +279,7 @@ func (p *Process) getUsedSymmetricValues() [][]lib.SymmetricValue { // Get all symmetric value definitions defs := p.Heap.GetSymmetryDefs() - result := make([][]lib.SymmetricValue, 0) + result := make([][]*lib.SymmetricValue, 0) for _, def := range defs { if def.Len() == 0 { continue @@ -291,7 +291,7 @@ func (p *Process) getUsedSymmetricValues() [][]lib.SymmetricValue { // For materialized rotational domains, return the full set [0..limit-1] if domain, ok := materializedDomains[prefix]; ok { - fullSet := make([]lib.SymmetricValue, domain.Limit) + fullSet := make([]*lib.SymmetricValue, domain.Limit) for i := 0; i < domain.Limit; i++ { fullSet[i] = lib.NewRotationalSymmetricValue(prefix, int64(i), domain.Limit) } diff --git a/modelchecker/thread.go b/modelchecker/thread.go index fd1ec46..0ba39ec 100644 --- a/modelchecker/thread.go +++ b/modelchecker/thread.go @@ -54,7 +54,7 @@ func (h *Heap) GetSymmetryDefs() []*lib.SymmetricValues { // Also handle SymmetryDomain from the new symmetry.nominal() API if domain, ok := value.(*lib.SymmetryDomain); ok { // Create a SymmetricValues containing all possible values for this domain - values := make([]lib.SymmetricValue, domain.Limit) + values := make([]*lib.SymmetricValue, domain.Limit) for i := 0; i < domain.Limit; i++ { if domain.Kind == lib.SymmetryKindRotational { values[i] = lib.NewRotationalSymmetricValue(domain.Name, int64(i), domain.Limit) @@ -239,7 +239,7 @@ func (h *Heap) insert(k string, v starlark.Value) bool { return true } -func (h *Heap) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *Heap { +func (h *Heap) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *Heap { return &Heap{state: CloneDict(h.state, refs, permutations, alt), globals: h.globals} } @@ -267,7 +267,7 @@ func (s *Scope) MarshalJSON() ([]byte, error) { }) } -func (s *Scope) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *Scope { +func (s *Scope) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *Scope { if s == nil { return nil } @@ -362,12 +362,12 @@ func (s *Scope) getAllVisibleVariablesToDict(dict starlark.StringDict) { s.getAllVisibleVariablesResolveRoles(dict, make(map[starlark.Value]starlark.Value)) } -func CloneDict(oldDict starlark.StringDict, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) starlark.StringDict { +func CloneDict(oldDict starlark.StringDict, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) starlark.StringDict { return CopyDict(oldDict, nil, refs, permutations, alt) } // CopyDict copies values `from` to `to` overriding existing values. If the `to` is nil, creates a new dict. -func CopyDict(from starlark.StringDict, to starlark.StringDict, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) starlark.StringDict { +func CopyDict(from starlark.StringDict, to starlark.StringDict, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) starlark.StringDict { if to == nil { to = make(starlark.StringDict) } @@ -423,7 +423,7 @@ func (c *CallFrame) HashCode() string { } return fmt.Sprintf("%x", h.Sum(nil)) } -func (c *CallFrame) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) (*CallFrame, error) { +func (c *CallFrame) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) (*CallFrame, error) { obj := c.obj if c.obj != nil { cloned, err := deepCloneStarlarkValueWithPermutations(c.obj, refs, permutations, alt) @@ -454,7 +454,7 @@ func NewCallStack() *CallStack { return &CallStack{lib.NewStack[*CallFrame]()} } -func (s *CallStack) CloneFrames(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) []*CallFrame { +func (s *CallStack) CloneFrames(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) []*CallFrame { frames := s.RawArray() cloned := make([]*CallFrame, len(frames)) for i, frame := range frames { @@ -465,7 +465,7 @@ func (s *CallStack) CloneFrames(refs map[starlark.Value]starlark.Value, permutat return cloned } -func (s *CallStack) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *CallStack { +func (s *CallStack) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *CallStack { newStack := NewCallStack() for _, frame := range s.CloneFrames(refs, permutations, alt) { newStack.Push(frame) @@ -568,7 +568,7 @@ func (t *Thread) popFrame() *CallFrame { return frame } -func (t *Thread) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *Thread { +func (t *Thread) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *Thread { // TODO: handle symmetry in stack.Clone() return &Thread{Id: t.Id, Process: t.Process, Files: t.Files, Stack: t.Stack.Clone(refs, permutations, alt), Fairness: t.Fairness} }